use crate::server::Event;
use axum::{
extract::{ws::WebSocket, State, WebSocketUpgrade},
routing::get,
Router,
};
use server::ServerState;
use stateroom::StateroomServiceFactory;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
time::Duration,
};
use tokio::{net::TcpListener, select};
mod server;
/// Address the server binds to when no other IP is configured (all IPv4 interfaces).
const DEFAULT_IP: &str = "0.0.0.0";

/// Configuration for a Stateroom WebSocket server.
///
/// Start from [`Server::default`] (or `Server::new`) and adjust it with the `with_*`
/// builder methods before calling `serve` / `serve_async`.
///
/// NOTE(review): `serve_async` in this file only reads `ip` and `port`; the heartbeat and
/// path fields are stored but not consumed there — confirm where (or whether) they are used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Server {
    /// Heartbeat interval setting (default: 30 seconds).
    pub heartbeat_interval: Duration,
    /// Heartbeat timeout setting (default: 300 seconds).
    pub heartbeat_timeout: Duration,
    /// TCP port to listen on (default: 8080).
    pub port: u16,
    /// IP address to bind, as text; must parse as an IPv4 or IPv6 address
    /// (default: [`DEFAULT_IP`]).
    pub ip: String,
    /// Optional static-content path (default: `None`).
    pub static_path: Option<String>,
    /// Optional client path (default: `None`).
    pub client_path: Option<String>,
}

impl Default for Server {
    /// Listens on `0.0.0.0:8080` with a 30 s heartbeat interval, a 5 min heartbeat
    /// timeout, and no static or client path.
    fn default() -> Self {
        Self {
            ip: String::from(DEFAULT_IP),
            port: 8080,
            heartbeat_interval: Duration::from_secs(30),
            heartbeat_timeout: Duration::from_secs(5 * 60),
            static_path: None,
            client_path: None,
        }
    }
}
impl Server {
    /// Creates a server configuration with the [`Default`] settings.
    #[must_use]
    pub fn new() -> Self {
        Server::default()
    }

    /// Replaces `static_path` with the given value (`None` clears it).
    #[must_use]
    pub fn with_static_path(mut self, static_path: Option<String>) -> Self {
        self.static_path = static_path;
        self
    }

    /// Replaces `client_path` with the given value (`None` clears it).
    #[must_use]
    pub fn with_client_path(mut self, client_path: Option<String>) -> Self {
        self.client_path = client_path;
        self
    }

    /// Sets `heartbeat_interval`, given in whole seconds.
    #[must_use]
    pub fn with_heartbeat_interval(mut self, duration_seconds: u64) -> Self {
        self.heartbeat_interval = Duration::from_secs(duration_seconds);
        self
    }

    /// Sets `heartbeat_timeout`, given in whole seconds.
    #[must_use]
    pub fn with_heartbeat_timeout(mut self, duration_seconds: u64) -> Self {
        self.heartbeat_timeout = Duration::from_secs(duration_seconds);
        self
    }

    /// Sets the TCP port to listen on.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the IP address to bind; it is parsed (and validated) when serving starts.
    #[must_use]
    pub fn with_ip(mut self, ip: String) -> Self {
        self.ip = ip;
        self
    }

    /// Binds `ip:port` and serves the `/ws` WebSocket endpoint on the current Tokio runtime,
    /// backed by a shared state built from `factory`.
    ///
    /// NOTE(review): only `ip` and `port` are read here; the heartbeat and path settings are
    /// not consulted by this method.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::ErrorKind::InvalidInput`] error if `ip` is not a valid IPv4/IPv6
    /// address, or any I/O error raised while binding the listener or serving connections.
    pub async fn serve_async(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
        // Validate the configured address first so a malformed `ip` is reported as an
        // error (instead of panicking) before any service state is constructed.
        let ip = self.ip.parse::<IpAddr>().map_err(|err| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("invalid IP address {:?}: {err}", self.ip),
            )
        })?;
        let addr = SocketAddr::new(ip, self.port);

        let server_state = Arc::new(ServerState::new(factory));
        let app = Router::new()
            .route("/ws", get(serve_websocket))
            .with_state(server_state);

        let listener = TcpListener::bind(addr).await?;
        axum::serve(listener, app).await?;
        Ok(())
    }

    /// Blocking entry point: builds a multi-threaded Tokio runtime and runs
    /// [`Server::serve_async`] on it until the server stops.
    ///
    /// # Errors
    ///
    /// Returns any error from building the runtime or from [`Server::serve_async`].
    ///
    /// # Panics
    ///
    /// Must not be called from inside an existing Tokio runtime; Tokio's
    /// `Runtime::block_on` panics when used from an asynchronous context.
    pub fn serve(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
        // Runtime construction is fallible (it returns io::Result); propagate rather than panic.
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()?;
        runtime.block_on(self.serve_async(factory))
    }
}
pub async fn serve_websocket(
ws: WebSocketUpgrade,
State(state): State<Arc<ServerState>>,
) -> axum::response::Response {
ws.on_upgrade(move |socket| handle_socket(socket, state))
}
/// Relays messages between one WebSocket client and the shared service until either side
/// closes or fails.
///
/// Messages arriving on this client's channel (from `state.connect()`) are written to the
/// socket. Frames read from the socket are forwarded to the service as [`Event::Message`]
/// tagged with this client's id.
///
/// Every exit path leaves the loop with `break` rather than panicking, so
/// `state.remove(&client_id)` always runs and the client is deregistered.
async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
    let (send, mut recv, client_id) = state.connect();
    loop {
        select! {
            outbound = recv.recv() => {
                match outbound {
                    Some(msg) => {
                        if socket.send(msg).await.is_err() {
                            // The client went away mid-write; stop relaying and clean up.
                            break;
                        }
                    }
                    // The sending side of this client's channel was closed.
                    None => break,
                }
            },
            inbound = socket.recv() => {
                match inbound {
                    Some(Ok(message)) => {
                        let event = Event::Message { client: client_id, message };
                        if send.send(event).await.is_err() {
                            // The service side is no longer receiving events.
                            break;
                        }
                    }
                    // A transport error or end-of-stream both end this connection.
                    Some(Err(_)) | None => break,
                }
            }
        }
    }
    state.remove(&client_id);
}