// conclavelib/server/wss.rs

use std::{path::PathBuf, sync::Arc, time::Duration};

use anyhow::Context as _;
use axum::{
    Router,
    extract::{
        State,
        ws::{Message, WebSocket, WebSocketUpgrade},
    },
    response::Response,
    routing::get,
};
use tokio::{net::TcpListener, sync::mpsc};

use crate::{
    base::{Constant, Void},
    protocol,
    store::Store,
};

use super::{hub::Hub, session::run_session};

/// Period between runs of the background idle-session reaper.
const REAP_INTERVAL: Duration = Duration::from_secs(15);
/// Idle threshold handed to `Hub::reap_idle` on every reaper tick.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Stored messages older than this (one week) are removed by the retention sweep.
const HISTORY_RETENTION: Duration = Duration::from_secs(7 * 24 * 3600);
/// Period between runs of the history retention sweep (one hour).
const HISTORY_PURGE_INTERVAL: Duration = Duration::from_secs(3600);

40pub struct ServerConfig {
42 pub bind: String,
44 pub data_dir: Option<PathBuf>,
46 pub admins: super::AdminAllowlist,
49}
50
51pub async fn serve(config: ServerConfig) -> Void {
58 let store = match &config.data_dir {
59 Some(path) => Store::open(path).await?,
60 None => Store::open_in_memory().await?,
61 };
62 for (name, pin) in &config.admins {
63 if pin.is_none() {
64 tracing::warn!(admin = %name, "admin username is unpinned and can be squatted by the first client to register it; pin it as `--admin <user>=<pubkey>`");
65 }
66 }
67 spawn_history_purge(store.clone());
68 let hub = Hub::new(store, config.admins).await?;
69
70 spawn_reaper(Arc::clone(&hub));
71
72 let app = router(hub);
73 let listener = TcpListener::bind(&config.bind).await.with_context(|| format!("failed to bind `{}`", config.bind))?;
74 let addr = listener.local_addr().context("failed to read the bound address")?;
75 tracing::info!(%addr, "conclave server listening");
76
77 axum::serve(listener, app).with_graceful_shutdown(shutdown_signal()).await.context("server terminated with an error")?;
78 Ok(())
79}
80
81fn spawn_reaper(hub: Arc<Hub>) {
83 tokio::spawn(async move {
84 let mut ticker = tokio::time::interval(REAP_INTERVAL);
85 loop {
86 ticker.tick().await;
87 let reaped = hub.reap_idle(IDLE_TIMEOUT);
88 if reaped > 0 {
89 tracing::debug!(reaped, "reaped idle sessions");
90 }
91 }
92 });
93}
94
95fn spawn_history_purge(store: Store) {
97 tokio::spawn(async move {
98 let mut ticker = tokio::time::interval(HISTORY_PURGE_INTERVAL);
99 loop {
100 ticker.tick().await;
101 let cutoff = chrono::Utc::now().timestamp_millis().saturating_sub(i64::try_from(HISTORY_RETENTION.as_millis()).unwrap_or(i64::MAX));
102 if let Err(err) = store.purge_messages_before(cutoff).await {
103 tracing::warn!(error = %err, "history retention sweep failed");
104 }
105 }
106 });
107}
108
/// Handler for `GET /health`: always replies with the plain-text body `ok`.
async fn health() -> &'static str {
    "ok"
}

115fn router(hub: Arc<Hub>) -> Router {
117 Router::new().route("/", get(ws_handler)).route("/health", get(health)).with_state(hub)
118}
119
120async fn ws_handler(ws: WebSocketUpgrade, State(hub): State<Arc<Hub>>) -> Response {
122 let instance_id = axum::http::HeaderValue::from_str(hub.instance_id()).ok();
125 let mut response = ws
128 .max_message_size(Constant::MAX_FRAME_SIZE)
129 .max_frame_size(Constant::MAX_FRAME_SIZE)
130 .on_upgrade(move |socket| handle_ws(hub, socket));
131 if let Some(id) = instance_id {
132 response.headers_mut().insert(Constant::SERVER_ID_HEADER, id);
133 }
134 response
135}
136
137async fn handle_ws(hub: Arc<Hub>, socket: WebSocket) {
139 use futures_util::{SinkExt as _, StreamExt as _};
140
141 let (mut sink, mut stream) = socket.split();
142 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
143 let (outbound_tx, mut outbound_rx) = mpsc::channel(super::session::OUTBOUND_CAPACITY);
144
145 let read_task = tokio::spawn(async move {
146 while let Some(Ok(message)) = stream.next().await {
147 match message {
148 Message::Binary(data) => match protocol::decode(&data) {
149 Ok(frame) => {
150 if inbound_tx.send(frame).is_err() {
151 break;
152 }
153 }
154 Err(_) => break,
155 },
156 Message::Close(_) => break,
157 _ => {}
159 }
160 }
161 });
162
163 let write_task = tokio::spawn(async move {
164 while let Some(frame) = outbound_rx.recv().await {
165 let Ok(bytes) = protocol::encode(&frame) else { break };
166 if sink.send(Message::Binary(bytes.into())).await.is_err() {
167 break;
168 }
169 }
170 let _ = sink.close().await;
171 });
172
173 run_session(hub, inbound_rx, outbound_tx).await;
174 read_task.abort();
177 let _ = write_task.await;
178}
179
180async fn shutdown_signal() {
184 #[cfg(unix)]
185 {
186 let mut sigterm = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
187 Ok(sigterm) => sigterm,
188 Err(error) => {
189 tracing::error!(%error, "failed to install the SIGTERM handler; falling back to Ctrl-C only");
190 let _ = tokio::signal::ctrl_c().await;
191 tracing::info!("shutdown signal received; draining connections");
192 return;
193 }
194 };
195 tokio::select! {
196 _ = tokio::signal::ctrl_c() => {}
197 _ = sigterm.recv() => {}
198 }
199 }
200 #[cfg(not(unix))]
201 let _ = tokio::signal::ctrl_c().await;
202
203 tracing::info!("shutdown signal received; draining connections");
204}
205
#[cfg(test)]
mod tests {
    // Unwrapping is the intended way for these tests to fail on unexpected errors.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Every WebSocket upgrade response, not just the first, must expose the
    /// instance id stored in the `Store` under `Constant::SERVER_ID_HEADER`.
    #[tokio::test]
    async fn wss_upgrade_response_carries_the_server_instance_id() {
        let store = Store::open_in_memory().await.unwrap();
        let instance_id = store.instance_id().await.unwrap();
        let hub = Hub::new(store, super::super::AdminAllowlist::default()).await.unwrap();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}/", listener.local_addr().unwrap());
        tokio::spawn(async move {
            axum::serve(listener, router(hub)).await.unwrap();
        });

        for _attempt in 1..=2 {
            let (_socket, upgrade) = tokio_tungstenite::connect_async(url.as_str()).await.unwrap();
            let header = upgrade
                .headers()
                .get(Constant::SERVER_ID_HEADER)
                .expect("the upgrade response must carry the instance-id header");
            assert_eq!(header.to_str().unwrap(), instance_id);
        }
    }
}