use std::{path::PathBuf, sync::Arc, time::Duration};
use anyhow::Context as _;
use axum::{
Router,
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
routing::get,
};
use tokio::{net::TcpListener, sync::mpsc};
use crate::{
base::{Constant, Void},
protocol,
store::Store,
};
use super::{hub::Hub, session::run_session};
/// Period of the background reaper task: how often it calls `Hub::reap_idle`.
const REAP_INTERVAL: Duration = Duration::from_secs(15);
/// Idle threshold passed to `Hub::reap_idle` on every reaper tick; presumably sessions
/// inactive for longer than this are dropped (the exact semantics live in the hub — confirm).
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Message-history retention window (7 days). Each purge sweep computes
/// `now - HISTORY_RETENTION` (in milliseconds) and passes it to `Store::purge_messages_before`.
const HISTORY_RETENTION: Duration = Duration::from_secs(7 * 24 * 60 * 60);
/// Period of the history-retention sweep (hourly).
const HISTORY_PURGE_INTERVAL: Duration = Duration::from_secs(60 * 60);
/// Startup options consumed by [`serve`].
pub struct ServerConfig {
    /// Address the listener binds to; passed verbatim to `TcpListener::bind`.
    pub bind: String,
    /// Location handed to `Store::open`; when `None`, `Store::open_in_memory` is used instead.
    pub data_dir: Option<PathBuf>,
    /// Admin usernames, each with an optional pin. Unpinned entries only produce a startup
    /// warning (they can be squatted); the whole list is handed to `Hub::new`.
    pub admins: super::AdminAllowlist,
}
pub async fn serve(config: ServerConfig) -> Void {
let store = match &config.data_dir {
Some(path) => Store::open(path).await?,
None => Store::open_in_memory().await?,
};
for (name, pin) in &config.admins {
if pin.is_none() {
tracing::warn!(admin = %name, "admin username is unpinned and can be squatted by the first client to register it; pin it as `--admin <user>=<pubkey>`");
}
}
spawn_history_purge(store.clone());
let hub = Hub::new(store, config.admins).await?;
spawn_reaper(Arc::clone(&hub));
let app = router(hub);
let listener = TcpListener::bind(&config.bind).await.with_context(|| format!("failed to bind `{}`", config.bind))?;
let addr = listener.local_addr().context("failed to read the bound address")?;
tracing::info!(%addr, "conclave server listening");
axum::serve(listener, app).with_graceful_shutdown(shutdown_signal()).await.context("server terminated with an error")?;
Ok(())
}
/// Starts a detached background task that periodically reaps idle hub sessions.
///
/// Scheduling:
/// - The first sweep happens right away, because a tokio interval completes its first tick
///   immediately.
/// - After that, a sweep runs every [`REAP_INTERVAL`].
///
/// Each sweep calls `Hub::reap_idle(IDLE_TIMEOUT)` and logs the returned count at debug level
/// when it is positive.
fn spawn_reaper(hub: Arc<Hub>) {
    let sweeper = async move {
        let mut schedule = tokio::time::interval(REAP_INTERVAL);
        loop {
            schedule.tick().await;
            let reaped = hub.reap_idle(IDLE_TIMEOUT);
            if reaped > 0 {
                tracing::debug!(reaped, "reaped idle sessions");
            }
        }
    };
    tokio::spawn(sweeper);
}
/// Starts a detached background task that enforces message-history retention.
///
/// Scheduling:
/// - The first sweep happens right away, because a tokio interval completes its first tick
///   immediately.
/// - After that, a sweep runs every [`HISTORY_PURGE_INTERVAL`].
///
/// Each sweep calls `Store::purge_messages_before` with a cutoff of "now minus
/// [`HISTORY_RETENTION`]" in Unix milliseconds. A failed sweep is logged as a warning and retried
/// on the next tick; the task never exits on its own.
fn spawn_history_purge(store: Store) {
    // Loop-invariant: the retention window in milliseconds. It saturates to `i64::MAX` so the
    // subtraction below can never overflow.
    let retention_ms = i64::try_from(HISTORY_RETENTION.as_millis()).unwrap_or(i64::MAX);
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(HISTORY_PURGE_INTERVAL);
        // The default `Burst` behaviour would follow a slow or stalled sweep with back-to-back
        // catch-up sweeps against the store. `Delay` instead waits a full period after a late
        // tick.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            ticker.tick().await;
            let cutoff = chrono::Utc::now()
                .timestamp_millis()
                .saturating_sub(retention_ms);
            if let Err(err) = store.purge_messages_before(cutoff).await {
                tracing::warn!(error = %err, "history retention sweep failed");
            }
        }
    });
}
/// Liveness probe served at `GET /health`; it always answers with the plain-text body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
/// Builds the HTTP router:
/// - `GET /` upgrades to the conclave WebSocket protocol;
/// - `GET /health` is a liveness probe.
///
/// The shared [`Hub`] is installed as router state.
fn router(hub: Arc<Hub>) -> Router {
    let routes: Router<Arc<Hub>> = Router::new()
        .route("/health", get(health))
        .route("/", get(ws_handler));
    routes.with_state(hub)
}
/// Accepts a WebSocket upgrade and hands the resulting socket to `handle_ws`.
///
/// Both the maximum message size and the maximum frame size are capped at
/// `Constant::MAX_FRAME_SIZE`. The upgrade response carries the hub's instance id in
/// `Constant::SERVER_ID_HEADER` when that id is a valid HTTP header value; otherwise the header
/// is silently omitted.
async fn ws_handler(ws: WebSocketUpgrade, State(hub): State<Arc<Hub>>) -> Response {
    let id_header = axum::http::HeaderValue::from_str(hub.instance_id());
    let limit = Constant::MAX_FRAME_SIZE;
    let mut response = ws
        .max_frame_size(limit)
        .max_message_size(limit)
        .on_upgrade(move |socket| handle_ws(hub, socket));
    if let Ok(value) = id_header {
        response.headers_mut().insert(Constant::SERVER_ID_HEADER, value);
    }
    response
}
/// Drives one upgraded WebSocket connection.
///
/// The socket is split into two tasks around [`run_session`]:
/// - a reader task: binary messages → `protocol::decode` → session inbox;
/// - a writer task: session outbox → `protocol::encode` → binary messages.
///
/// When the session returns, the reader is aborted. The writer is then awaited so it can drain
/// the outbox and close the socket.
async fn handle_ws(hub: Arc<Hub>, socket: WebSocket) {
    use futures_util::{SinkExt as _, StreamExt as _};

    let (mut ws_tx, mut ws_rx) = socket.split();
    // NOTE(review): the inbox is unbounded, so a client that sends faster than the session
    // consumes can grow it without limit. Bounding it would require `run_session` to accept a
    // bounded receiver.
    let (inbox_tx, inbox_rx) = mpsc::unbounded_channel();
    let (outbox_tx, mut outbox_rx) = mpsc::channel(super::session::OUTBOUND_CAPACITY);

    // Reader. It stops on any of:
    // - end of stream or a transport error;
    // - a Close message or an undecodable frame;
    // - a closed inbox.
    // All other message kinds are skipped.
    let reader = tokio::spawn(async move {
        loop {
            let frame = match ws_rx.next().await {
                Some(Ok(Message::Binary(data))) => match protocol::decode(&data) {
                    Ok(frame) => frame,
                    Err(_) => break,
                },
                Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                Some(Ok(_)) => continue,
            };
            if inbox_tx.send(frame).is_err() {
                break;
            }
        }
    });

    // Writer: forwards encoded frames until the outbox closes, encoding fails, or a send fails.
    let writer = tokio::spawn(async move {
        while let Some(frame) = outbox_rx.recv().await {
            let bytes = match protocol::encode(&frame) {
                Ok(bytes) => bytes,
                Err(_) => break,
            };
            let outgoing = Message::Binary(bytes.into());
            if ws_tx.send(outgoing).await.is_err() {
                break;
            }
        }
        // Best-effort close: the peer may already be gone.
        let _ = ws_tx.close().await;
    });

    run_session(hub, inbox_rx, outbox_tx).await;
    // The session is over, so inbound traffic is no longer wanted.
    reader.abort();
    // NOTE(review): the writer only finishes once every sender of the outbox is dropped. If the
    // hub keeps clones of `outbox_tx` past `run_session`, this await also waits on those —
    // confirm.
    let _ = writer.await;
}
async fn shutdown_signal() {
#[cfg(unix)]
{
let mut sigterm = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
Ok(sigterm) => sigterm,
Err(error) => {
tracing::error!(%error, "failed to install the SIGTERM handler; falling back to Ctrl-C only");
let _ = tokio::signal::ctrl_c().await;
tracing::info!("shutdown signal received; draining connections");
return;
}
};
tokio::select! {
_ = tokio::signal::ctrl_c() => {}
_ = sigterm.recv() => {}
}
}
#[cfg(not(unix))]
let _ = tokio::signal::ctrl_c().await;
tracing::info!("shutdown signal received; draining connections");
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Every WebSocket upgrade response, not just the first, must carry the store's instance id.
    #[tokio::test]
    async fn wss_upgrade_response_carries_the_server_instance_id() {
        let store = Store::open_in_memory().await.unwrap();
        let expected = store.instance_id().await.unwrap();
        let hub = Hub::new(store, super::super::AdminAllowlist::default())
            .await
            .unwrap();

        // Port 0 asks the OS for any free port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}/", listener.local_addr().unwrap());
        tokio::spawn(async move { axum::serve(listener, router(hub)).await.unwrap() });

        // Connect twice: the header has to be attached to each upgrade.
        for _attempt in 0..2 {
            let (_socket, response) = tokio_tungstenite::connect_async(url.as_str())
                .await
                .unwrap();
            let header = response
                .headers()
                .get(Constant::SERVER_ID_HEADER)
                .expect("the upgrade response must carry the instance-id header");
            assert_eq!(header.to_str().unwrap(), expected);
        }
    }
}