conclavelib/server/
wss.rs1use std::{path::PathBuf, sync::Arc, time::Duration};
10
11use anyhow::Context as _;
12use axum::{
13 Router,
14 extract::{
15 State,
16 ws::{Message, WebSocket, WebSocketUpgrade},
17 },
18 response::Response,
19 routing::get,
20};
21use tokio::{net::TcpListener, sync::mpsc};
22
23use crate::{
24 base::{Constant, Void},
25 protocol,
26 store::Store,
27};
28
29use super::{hub::Hub, session::run_session};
30
/// Period of the background ticker in `spawn_reaper`: how often the hub is
/// asked to reap idle sessions.
const REAP_INTERVAL: Duration = Duration::from_secs(15);
/// Idle threshold handed to `Hub::reap_idle` on every reaper tick.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
35
/// Settings consumed by [`serve`] to start the server.
pub struct ServerConfig {
    /// Address handed to `TcpListener::bind` (for example `127.0.0.1:0`).
    pub bind: String,
    /// Directory passed to `Store::open`; when `None`, `serve` uses
    /// `Store::open_in_memory` instead.
    pub data_dir: Option<PathBuf>,
    /// Admin allowlist handed to `Hub::new`. `serve` iterates it as
    /// `(name, pin)` pairs and warns about every entry whose pin is `None`.
    pub admins: super::AdminAllowlist,
}
46
47pub async fn serve(config: ServerConfig) -> Void {
54 let store = match &config.data_dir {
55 Some(path) => Store::open(path).await?,
56 None => Store::open_in_memory().await?,
57 };
58 for (name, pin) in &config.admins {
59 if pin.is_none() {
60 tracing::warn!(admin = %name, "admin username is unpinned and can be squatted by the first client to register it; pin it as `--admin <user>=<pubkey>`");
61 }
62 }
63 let hub = Hub::new(store, config.admins).await?;
64
65 spawn_reaper(Arc::clone(&hub));
66
67 let app = router(hub);
68 let listener = TcpListener::bind(&config.bind).await.with_context(|| format!("failed to bind `{}`", config.bind))?;
69 let addr = listener.local_addr().context("failed to read the bound address")?;
70 tracing::info!(%addr, "conclave server listening");
71
72 axum::serve(listener, app).with_graceful_shutdown(shutdown_signal()).await.context("server terminated with an error")?;
73 Ok(())
74}
75
76fn spawn_reaper(hub: Arc<Hub>) {
78 tokio::spawn(async move {
79 let mut ticker = tokio::time::interval(REAP_INTERVAL);
80 loop {
81 ticker.tick().await;
82 let reaped = hub.reap_idle(IDLE_TIMEOUT);
83 if reaped > 0 {
84 tracing::debug!(reaped, "reaped idle sessions");
85 }
86 }
87 });
88}
89
/// Handler for the `/health` route: always answers with the static body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
95
96fn router(hub: Arc<Hub>) -> Router {
98 Router::new().route("/", get(ws_handler)).route("/health", get(health)).with_state(hub)
99}
100
101async fn ws_handler(ws: WebSocketUpgrade, State(hub): State<Arc<Hub>>) -> Response {
103 let instance_id = axum::http::HeaderValue::from_str(hub.instance_id()).ok();
106 let mut response = ws
109 .max_message_size(Constant::MAX_FRAME_SIZE)
110 .max_frame_size(Constant::MAX_FRAME_SIZE)
111 .on_upgrade(move |socket| handle_ws(hub, socket));
112 if let Some(id) = instance_id {
113 response.headers_mut().insert(Constant::SERVER_ID_HEADER, id);
114 }
115 response
116}
117
118async fn handle_ws(hub: Arc<Hub>, socket: WebSocket) {
120 use futures_util::{SinkExt as _, StreamExt as _};
121
122 let (mut sink, mut stream) = socket.split();
123 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
124 let (outbound_tx, mut outbound_rx) = mpsc::channel(super::session::OUTBOUND_CAPACITY);
125
126 let read_task = tokio::spawn(async move {
127 while let Some(Ok(message)) = stream.next().await {
128 match message {
129 Message::Binary(data) => match protocol::decode(&data) {
130 Ok(frame) => {
131 if inbound_tx.send(frame).is_err() {
132 break;
133 }
134 }
135 Err(_) => break,
136 },
137 Message::Close(_) => break,
138 _ => {}
140 }
141 }
142 });
143
144 let write_task = tokio::spawn(async move {
145 while let Some(frame) = outbound_rx.recv().await {
146 let Ok(bytes) = protocol::encode(&frame) else { break };
147 if sink.send(Message::Binary(bytes.into())).await.is_err() {
148 break;
149 }
150 }
151 let _ = sink.close().await;
152 });
153
154 run_session(hub, inbound_rx, outbound_tx).await;
155 read_task.abort();
158 let _ = write_task.await;
159}
160
161async fn shutdown_signal() {
165 #[cfg(unix)]
166 {
167 let mut sigterm = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
168 Ok(sigterm) => sigterm,
169 Err(error) => {
170 tracing::error!(%error, "failed to install the SIGTERM handler; falling back to Ctrl-C only");
171 let _ = tokio::signal::ctrl_c().await;
172 tracing::info!("shutdown signal received; draining connections");
173 return;
174 }
175 };
176 tokio::select! {
177 _ = tokio::signal::ctrl_c() => {}
178 _ = sigterm.recv() => {}
179 }
180 }
181 #[cfg(not(unix))]
182 let _ = tokio::signal::ctrl_c().await;
183
184 tracing::info!("shutdown signal received; draining connections");
185}
186
#[cfg(test)]
mod tests {
    // Test code may unwrap freely: a panic is the failure report.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Serves the real router on an ephemeral local port and checks that the
    /// WebSocket upgrade response carries the store's instance id in
    /// `Constant::SERVER_ID_HEADER`, on two consecutive connections.
    #[tokio::test]
    async fn wss_upgrade_response_carries_the_server_instance_id() {
        let store = Store::open_in_memory().await.unwrap();
        let expected = store.instance_id().await.unwrap();
        let hub = Hub::new(store, super::super::AdminAllowlist::default()).await.unwrap();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}/", listener.local_addr().unwrap());
        tokio::spawn(async move {
            let served = axum::serve(listener, router(hub)).await;
            served.unwrap();
        });

        let mut remaining = 2;
        while remaining > 0 {
            let (_ws, response) = tokio_tungstenite::connect_async(url.as_str()).await.unwrap();
            let header = response.headers().get(Constant::SERVER_ID_HEADER);
            let got = header.expect("the upgrade response must carry the instance-id header");
            assert_eq!(got.to_str().unwrap(), expected);
            remaining -= 1;
        }
    }
}