Skip to main content

conclavelib/server/
wss.rs

//! The production transport: the axum WebSocket endpoint and the `conclave serve` entrypoint.
//!
//! TLS terminates at cloudflared and the origin hop is local loopback (DESIGN.md §11/§12), so this
//! is a plain-HTTP axum server: `GET /` upgrades to a WebSocket and `GET /health` answers platform
//! liveness checks. Each accepted socket is split into reader / writer pumps that translate
//! between WS binary messages and [`ProtocolMessage`](crate::protocol::ProtocolMessage) frames,
//! then driven by the shared [`run_session`]. A background reaper enforces the idle-heartbeat
//! timeout (DESIGN.md §10), and a periodic sweep purges history past the retention window (PRD-0013).
8
9use std::{path::PathBuf, sync::Arc, time::Duration};
10
11use anyhow::Context as _;
12use axum::{
13    Router,
14    extract::{
15        State,
16        ws::{Message, WebSocket, WebSocketUpgrade},
17    },
18    response::Response,
19    routing::get,
20};
21use tokio::{net::TcpListener, sync::mpsc};
22
23use crate::{
24    base::{Constant, Void},
25    protocol,
26    store::Store,
27};
28
29use super::{hub::Hub, session::run_session};
30
/// Period between heartbeat-reaper sweeps for idle sessions.
const REAP_INTERVAL: Duration = Duration::from_secs(15);
/// Longest a session may stay silent (no inbound frame at all) before the reaper drops it
/// (DESIGN.md §10).
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Retained-history window of seven days; older messages are purged (PRD-0013, a constant in v1).
const HISTORY_RETENTION: Duration = Duration::from_secs(7 * 86_400);
/// Period between history-retention sweeps (one hour).
const HISTORY_PURGE_INTERVAL: Duration = Duration::from_secs(3_600);
39
40/// The operator-supplied `serve` configuration (DESIGN.md §7, §13).
41pub struct ServerConfig {
42    /// Address to bind the WebSocket endpoint to (e.g. `127.0.0.1:4390`).
43    pub bind: String,
44    /// Data directory for the embedded store; `None` runs a purely in-memory store.
45    pub data_dir: Option<PathBuf>,
46    /// The server-admin allowlist — usernames that may administer server-wide (§7), each
47    /// optionally pinned to the public key permitted to claim it (see [`super::AdminAllowlist`]).
48    pub admins: super::AdminAllowlist,
49}
50
51/// Runs the central server until a shutdown signal (Ctrl-C) is received.
52///
53/// # Errors
54///
55/// Returns an error if the store cannot be opened, the bind address is unavailable, or the
56/// underlying HTTP server fails.
57pub async fn serve(config: ServerConfig) -> Void {
58    let store = match &config.data_dir {
59        Some(path) => Store::open(path).await?,
60        None => Store::open_in_memory().await?,
61    };
62    for (name, pin) in &config.admins {
63        if pin.is_none() {
64            tracing::warn!(admin = %name, "admin username is unpinned and can be squatted by the first client to register it; pin it as `--admin <user>=<pubkey>`");
65        }
66    }
67    spawn_history_purge(store.clone());
68    let hub = Hub::new(store, config.admins).await?;
69
70    spawn_reaper(Arc::clone(&hub));
71
72    let app = router(hub);
73    let listener = TcpListener::bind(&config.bind).await.with_context(|| format!("failed to bind `{}`", config.bind))?;
74    let addr = listener.local_addr().context("failed to read the bound address")?;
75    tracing::info!(%addr, "conclave server listening");
76
77    axum::serve(listener, app).with_graceful_shutdown(shutdown_signal()).await.context("server terminated with an error")?;
78    Ok(())
79}
80
81/// Spawns the background heartbeat reaper (DESIGN.md §10).
82fn spawn_reaper(hub: Arc<Hub>) {
83    tokio::spawn(async move {
84        let mut ticker = tokio::time::interval(REAP_INTERVAL);
85        loop {
86            ticker.tick().await;
87            let reaped = hub.reap_idle(IDLE_TIMEOUT);
88            if reaped > 0 {
89                tracing::debug!(reaped, "reaped idle sessions");
90            }
91        }
92    });
93}
94
95/// Spawns the background history-retention sweep (PRD-0013): hourly, drop rows past the window.
96fn spawn_history_purge(store: Store) {
97    tokio::spawn(async move {
98        let mut ticker = tokio::time::interval(HISTORY_PURGE_INTERVAL);
99        loop {
100            ticker.tick().await;
101            let cutoff = chrono::Utc::now().timestamp_millis().saturating_sub(i64::try_from(HISTORY_RETENTION.as_millis()).unwrap_or(i64::MAX));
102            if let Err(err) = store.purge_messages_before(cutoff).await {
103                tracing::warn!(error = %err, "history retention sweep failed");
104            }
105        }
106    });
107}
108
/// Liveness probe for platform health checks (T-004).
///
/// `/` only speaks WebSocket, so a plain HTTP GET there is answered with 426 Upgrade Required.
/// This route gives any platform's HTTP checker a real 200 response with the body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
114
115/// Builds the server's router: the WS upgrade route plus the health endpoint.
116fn router(hub: Arc<Hub>) -> Router {
117    Router::new().route("/", get(ws_handler)).route("/health", get(health)).with_state(hub)
118}
119
120/// The WebSocket upgrade handler; every connection is dispatched to [`handle_ws`].
121async fn ws_handler(ws: WebSocketUpgrade, State(hub): State<Arc<Hub>>) -> Response {
122    // The instance ID rides the upgrade response so a bridge can recognize the same server behind
123    // two URLs before it ever authenticates (PRD-0012 T-003).
124    let instance_id = axum::http::HeaderValue::from_str(hub.instance_id()).ok();
125    // Enforce the protocol's frame cap (Constant::MAX_FRAME_SIZE) instead of tungstenite's 64 MiB
126    // default, so a pre-auth peer cannot force a large buffer per message (finding #17/#19).
127    let mut response = ws
128        .max_message_size(Constant::MAX_FRAME_SIZE)
129        .max_frame_size(Constant::MAX_FRAME_SIZE)
130        .on_upgrade(move |socket| handle_ws(hub, socket));
131    if let Some(id) = instance_id {
132        response.headers_mut().insert(Constant::SERVER_ID_HEADER, id);
133    }
134    response
135}
136
137/// Bridges a WebSocket to [`run_session`]: each WS binary message is one protocol frame.
138async fn handle_ws(hub: Arc<Hub>, socket: WebSocket) {
139    use futures_util::{SinkExt as _, StreamExt as _};
140
141    let (mut sink, mut stream) = socket.split();
142    let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
143    let (outbound_tx, mut outbound_rx) = mpsc::channel(super::session::OUTBOUND_CAPACITY);
144
145    let read_task = tokio::spawn(async move {
146        while let Some(Ok(message)) = stream.next().await {
147            match message {
148                Message::Binary(data) => match protocol::decode(&data) {
149                    Ok(frame) => {
150                        if inbound_tx.send(frame).is_err() {
151                            break;
152                        }
153                    }
154                    Err(_) => break,
155                },
156                Message::Close(_) => break,
157                // Text / ping / pong are ignored: the heartbeat is an app-level Ping/Pong frame.
158                _ => {}
159            }
160        }
161    });
162
163    let write_task = tokio::spawn(async move {
164        while let Some(frame) = outbound_rx.recv().await {
165            let Ok(bytes) = protocol::encode(&frame) else { break };
166            if sink.send(Message::Binary(bytes.into())).await.is_err() {
167                break;
168            }
169        }
170        let _ = sink.close().await;
171    });
172
173    run_session(hub, inbound_rx, outbound_tx).await;
174    // Await the writer so a final handshake-failure / force-drop frame is flushed and the socket
175    // closed cleanly; abort the reader, which may be parked on an idle-but-open socket.
176    read_task.abort();
177    let _ = write_task.await;
178}
179
180/// Resolves when the process receives Ctrl-C or SIGTERM, driving the graceful shutdown. SIGTERM is
181/// what container platforms (Fly.io, `docker stop`) send on deploy/stop — without handling it the
182/// server ignores the signal, waits out the platform's kill timeout, and dies un-drained.
183async fn shutdown_signal() {
184    #[cfg(unix)]
185    {
186        let mut sigterm = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
187            Ok(sigterm) => sigterm,
188            Err(error) => {
189                tracing::error!(%error, "failed to install the SIGTERM handler; falling back to Ctrl-C only");
190                let _ = tokio::signal::ctrl_c().await;
191                tracing::info!("shutdown signal received; draining connections");
192                return;
193            }
194        };
195        tokio::select! {
196            _ = tokio::signal::ctrl_c() => {}
197            _ = sigterm.recv() => {}
198        }
199    }
200    #[cfg(not(unix))]
201    let _ = tokio::signal::ctrl_c().await;
202
203    tracing::info!("shutdown signal received; draining connections");
204}
205
#[cfg(test)]
mod tests {
    // Tests relax `unwrap_used` (house convention; DESIGN.md §22).
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// The WS upgrade response carries the persistent instance ID, so a bridge can recognise the
    /// same server reached under two URLs (PRD-0012 T-003). It is an HTTP header, outside the
    /// wire protocol, so old peers simply never look at it.
    #[tokio::test]
    async fn wss_upgrade_response_carries_the_server_instance_id() {
        let store = Store::open_in_memory().await.unwrap();
        let instance_id = store.instance_id().await.unwrap();
        let hub = Hub::new(store, super::super::AdminAllowlist::default()).await.unwrap();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}/", listener.local_addr().unwrap());
        tokio::spawn(async move {
            axum::serve(listener, router(hub)).await.unwrap();
        });

        // Stable across connections: every dial must see the same ID.
        for attempt in 0..2 {
            let (_socket, upgrade) = tokio_tungstenite::connect_async(url.as_str()).await.unwrap();
            let header = upgrade
                .headers()
                .get(Constant::SERVER_ID_HEADER)
                .expect("the upgrade response must carry the instance-id header");
            assert_eq!(header.to_str().unwrap(), instance_id, "dial #{attempt}");
        }
    }
}