use std::{collections::HashSet, path::PathBuf, sync::Arc, time::Duration};
use anyhow::Context as _;
use axum::{
Router,
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
routing::get,
};
use tokio::{net::TcpListener, sync::mpsc};
use crate::{base::Void, protocol, store::Store};
use super::{hub::Hub, session::run_session};
/// Period of the ticker that drives the background reaper task in `spawn_reaper`.
const REAP_INTERVAL: Duration = Duration::from_secs(15);
/// Timeout value handed to `Hub::reap_idle` on every reaper pass.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Settings consumed by [`serve`] when starting the server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address string handed to `TcpListener::bind`.
    pub bind: String,
    /// Path handed to `Store::open`; when `None`, the store is opened in memory instead.
    pub data_dir: Option<PathBuf>,
    /// Set passed through to `Hub::new`.
    // NOTE(review): presumably the identities granted admin rights — confirm against `Hub`.
    pub admins: HashSet<String>,
}
/// Starts the server described by `config` and runs it until shutdown.
///
/// Steps, in order: open the store (from `config.data_dir`, or in memory when
/// it is `None`), build the hub, start the idle-session reaper, bind the TCP
/// listener and serve the WebSocket route at `/` until `shutdown_signal`
/// resolves.
///
/// # Errors
///
/// Propagates failures from opening the store, binding `config.bind`, reading
/// the bound address, and from the server itself.
pub async fn serve(config: ServerConfig) -> Void {
    let store = if let Some(path) = &config.data_dir {
        Store::open(path).await?
    } else {
        Store::open_in_memory().await?
    };

    let hub = Hub::new(store, config.admins);
    spawn_reaper(Arc::clone(&hub));

    let bind_result = TcpListener::bind(&config.bind).await;
    let listener = bind_result.with_context(|| format!("failed to bind `{}`", config.bind))?;

    let addr = listener
        .local_addr()
        .context("failed to read the bound address")?;
    tracing::info!(%addr, "conclave server listening");

    let app = Router::new().route("/", get(ws_handler)).with_state(hub);
    let server = axum::serve(listener, app).with_graceful_shutdown(shutdown_signal());
    server.await.context("server terminated with an error")?;

    Ok(())
}
/// Spawns a detached background task that periodically reaps idle sessions.
///
/// The task ticks every `REAP_INTERVAL`, calls `Hub::reap_idle` with
/// `IDLE_TIMEOUT`, and emits a debug event whenever the returned count is
/// greater than zero. The task loops forever; its join handle is dropped.
fn spawn_reaper(hub: Arc<Hub>) {
    let reaper = async move {
        let mut ticker = tokio::time::interval(REAP_INTERVAL);
        loop {
            ticker.tick().await;

            let reaped = hub.reap_idle(IDLE_TIMEOUT);
            if reaped > 0 {
                tracing::debug!(reaped, "reaped idle sessions");
            }
        }
    };

    tokio::spawn(reaper);
}
async fn ws_handler(ws: WebSocketUpgrade, State(hub): State<Arc<Hub>>) -> Response {
ws.on_upgrade(move |socket| handle_ws(hub, socket))
}
async fn handle_ws(hub: Arc<Hub>, socket: WebSocket) {
use futures_util::{SinkExt as _, StreamExt as _};
let (mut sink, mut stream) = socket.split();
let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel();
let read_task = tokio::spawn(async move {
while let Some(Ok(message)) = stream.next().await {
match message {
Message::Binary(data) => match protocol::decode(&data) {
Ok(frame) => {
if inbound_tx.send(frame).is_err() {
break;
}
}
Err(_) => break,
},
Message::Close(_) => break,
_ => {}
}
}
});
let write_task = tokio::spawn(async move {
while let Some(frame) = outbound_rx.recv().await {
let Ok(bytes) = protocol::encode(&frame) else { break };
if sink.send(Message::Binary(bytes.into())).await.is_err() {
break;
}
}
let _ = sink.close().await;
});
run_session(hub, inbound_rx, outbound_tx).await;
read_task.abort();
let _ = write_task.await;
}
async fn shutdown_signal() {
let _ = tokio::signal::ctrl_c().await;
tracing::info!("shutdown signal received; draining connections");
}