conclavelib/server/
wss.rs1use std::{path::PathBuf, sync::Arc, time::Duration};
10
11use anyhow::Context as _;
12use axum::{
13 Router,
14 extract::{
15 State,
16 ws::{Message, WebSocket, WebSocketUpgrade},
17 },
18 response::Response,
19 routing::get,
20};
21use tokio::{net::TcpListener, sync::mpsc};
22
23use crate::{
24 base::{Constant, Void},
25 protocol,
26 store::Store,
27};
28
29use super::{hub::Hub, session::run_session};
30
/// Cadence of the background sweep that asks the hub to drop idle sessions.
const REAP_INTERVAL: Duration = Duration::from_millis(15_000);
/// Idle threshold handed to `Hub::reap_idle` on every sweep.
///
/// NOTE(review): assuming `reap_idle` measures time since a session's last
/// activity, a session can linger up to roughly `IDLE_TIMEOUT + REAP_INTERVAL`
/// before it is reaped — confirm against the hub.
const IDLE_TIMEOUT: Duration = Duration::from_millis(60_000);
35
36pub struct ServerConfig {
38 pub bind: String,
40 pub data_dir: Option<PathBuf>,
42 pub admins: super::AdminAllowlist,
45}
46
47pub async fn serve(config: ServerConfig) -> Void {
54 let store = match &config.data_dir {
55 Some(path) => Store::open(path).await?,
56 None => Store::open_in_memory().await?,
57 };
58 for (name, pin) in &config.admins {
59 if pin.is_none() {
60 tracing::warn!(admin = %name, "admin username is unpinned and can be squatted by the first client to register it; pin it as `--admin <user>=<pubkey>`");
61 }
62 }
63 let hub = Hub::new(store, config.admins).await?;
64
65 spawn_reaper(Arc::clone(&hub));
66
67 let app = Router::new().route("/", get(ws_handler)).route("/health", get(health)).with_state(hub);
68 let listener = TcpListener::bind(&config.bind).await.with_context(|| format!("failed to bind `{}`", config.bind))?;
69 let addr = listener.local_addr().context("failed to read the bound address")?;
70 tracing::info!(%addr, "conclave server listening");
71
72 axum::serve(listener, app).with_graceful_shutdown(shutdown_signal()).await.context("server terminated with an error")?;
73 Ok(())
74}
75
76fn spawn_reaper(hub: Arc<Hub>) {
78 tokio::spawn(async move {
79 let mut ticker = tokio::time::interval(REAP_INTERVAL);
80 loop {
81 ticker.tick().await;
82 let reaped = hub.reap_idle(IDLE_TIMEOUT);
83 if reaped > 0 {
84 tracing::debug!(reaped, "reaped idle sessions");
85 }
86 }
87 });
88}
89
/// Liveness probe mounted at `/health`; always answers with the body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
95
96async fn ws_handler(ws: WebSocketUpgrade, State(hub): State<Arc<Hub>>) -> Response {
98 ws.max_message_size(Constant::MAX_FRAME_SIZE)
101 .max_frame_size(Constant::MAX_FRAME_SIZE)
102 .on_upgrade(move |socket| handle_ws(hub, socket))
103}
104
105async fn handle_ws(hub: Arc<Hub>, socket: WebSocket) {
107 use futures_util::{SinkExt as _, StreamExt as _};
108
109 let (mut sink, mut stream) = socket.split();
110 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
111 let (outbound_tx, mut outbound_rx) = mpsc::channel(super::session::OUTBOUND_CAPACITY);
112
113 let read_task = tokio::spawn(async move {
114 while let Some(Ok(message)) = stream.next().await {
115 match message {
116 Message::Binary(data) => match protocol::decode(&data) {
117 Ok(frame) => {
118 if inbound_tx.send(frame).is_err() {
119 break;
120 }
121 }
122 Err(_) => break,
123 },
124 Message::Close(_) => break,
125 _ => {}
127 }
128 }
129 });
130
131 let write_task = tokio::spawn(async move {
132 while let Some(frame) = outbound_rx.recv().await {
133 let Ok(bytes) = protocol::encode(&frame) else { break };
134 if sink.send(Message::Binary(bytes.into())).await.is_err() {
135 break;
136 }
137 }
138 let _ = sink.close().await;
139 });
140
141 run_session(hub, inbound_rx, outbound_tx).await;
142 read_task.abort();
145 let _ = write_task.await;
146}
147
148async fn shutdown_signal() {
152 #[cfg(unix)]
153 {
154 let mut sigterm = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
155 Ok(sigterm) => sigterm,
156 Err(error) => {
157 tracing::error!(%error, "failed to install the SIGTERM handler; falling back to Ctrl-C only");
158 let _ = tokio::signal::ctrl_c().await;
159 tracing::info!("shutdown signal received; draining connections");
160 return;
161 }
162 };
163 tokio::select! {
164 _ = tokio::signal::ctrl_c() => {}
165 _ = sigterm.recv() => {}
166 }
167 }
168 #[cfg(not(unix))]
169 let _ = tokio::signal::ctrl_c().await;
170
171 tracing::info!("shutdown signal received; draining connections");
172}