use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::http::HeaderMap;
use axum::routing::{get, MethodRouter};
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, oneshot};
use crate::core::engine::FrozenDiContainer;
use crate::realtime::connection::{ConnectionRegistry, WsClient, WsMessage};
use crate::realtime::gateway::GatewayRuntime;
use crate::web::context::Claims;
/// Per-connection tuning knobs consumed by the WebSocket gateway route.
#[derive(Debug, Clone, Copy)]
pub struct WsTuning {
    /// Capacity of each connection's outbound message queue; values below 1 are
    /// clamped to 1 when the queue is created.
    pub outbound_buffer: usize,
    /// Connection count at which new upgrades are refused with `503`; `0` disables
    /// the limit.
    pub max_connections: usize,
    /// Period between server-initiated ping frames; a zero duration disables pinging.
    pub ping_interval: std::time::Duration,
}
/// Builds the `GET` route that upgrades HTTP requests into gateway WebSocket sessions.
///
/// When `tuning.max_connections` is non-zero and `registry` already reports at least that
/// many connections, the request is refused with status `503` and `retry-after: 5`
/// (counted in `ws_upgrades_refused_total`) instead of being upgraded. Otherwise the
/// handshake headers are resolved into a `Provenance` through the DI `container`, a debug
/// line records its trace id and tenant, and the claims and tenant are handed to
/// `handle_socket` once the upgrade completes.
///
/// NOTE(review): the count is read before this connection registers (registration happens
/// inside `handle_socket`, after the upgrade), so concurrent handshakes can presumably
/// overshoot the limit — treat it as a soft cap.
pub fn ws_route(
    runtime: &'static GatewayRuntime,
    registry: &'static ConnectionRegistry,
    container: &'static FrozenDiContainer,
    tuning: WsTuning,
) -> MethodRouter {
    get(move |ws: WebSocketUpgrade, headers: HeaderMap| async move {
        let cap = tuning.max_connections;
        if cap > 0 && registry.connection_count() >= cap {
            metrics::counter!("ws_upgrades_refused_total").increment(1);
            let refusal = axum::http::Response::builder()
                .status(503)
                .header("retry-after", "5")
                .body(axum::body::Body::from("websocket capacity reached"))
                .expect("static refusal");
            return refusal;
        }

        let provenance = crate::pipeline::Provenance::from_headers(&headers, container).await;
        // Field expressions stay inside the macro so they are only evaluated when the
        // debug level is enabled.
        tracing::debug!(
            trace_id = %crate::observability::lean_telemetry::hex_encode(&provenance.trace.trace_id),
            tenant = provenance.tenant.as_deref().map(|t| t.id.as_str()).unwrap_or(""),
            "WS handshake provenance"
        );

        let claims = provenance.claims;
        let tenant = provenance.tenant;
        ws.on_upgrade(move |socket| {
            handle_socket(socket, runtime, registry, claims, tenant, tuning)
        })
    })
}
async fn handle_socket(
socket: WebSocket,
runtime: &'static GatewayRuntime,
registry: &'static ConnectionRegistry,
claims: Option<Arc<Claims>>,
tenant: Option<Arc<crate::web::tenant::TenantConfig>>,
tuning: WsTuning,
) {
let (mut sink, mut stream) = socket.split();
let (tx, mut rx) = mpsc::channel::<WsMessage>(tuning.outbound_buffer.max(1));
let id = registry.register(tx, claims.clone());
let client = WsClient::__new(id, registry, claims, tenant);
let (close_tx, mut close_rx) = oneshot::channel::<()>();
let ping_every = tuning.ping_interval;
let writer = tokio::spawn(async move {
let mut ping = (!ping_every.is_zero()).then(|| {
let mut t = tokio::time::interval(ping_every);
t.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
t
});
loop {
let msg = if let Some(t) = ping.as_mut() {
tokio::select! {
m = rx.recv() => m,
_ = t.tick() => {
if sink.send(Message::Ping(Vec::new())).await.is_err() {
break;
}
continue;
}
}
} else {
rx.recv().await
};
let Some(msg) = msg else { break };
let frame = match msg {
WsMessage::Text(arc) => Message::Text(String::from(arc.as_ref())),
WsMessage::Ping => Message::Ping(Vec::new()),
WsMessage::Close => {
let _ = sink.send(Message::Close(None)).await;
break;
}
};
if sink.send(frame).await.is_err() {
break;
}
}
drop(close_tx);
});
(runtime.on_connect)(client.clone()).await;
loop {
tokio::select! {
biased;
_ = &mut close_rx => break,
frame = stream.next() => match frame {
None => break,
Some(Err(_)) => break,
Some(Ok(frame)) => {
registry.touch(id);
match frame {
Message::Text(text) => dispatch_event(runtime, &client, &text).await,
Message::Binary(_) => { }
Message::Close(_) => break,
Message::Ping(_) | Message::Pong(_) => { }
}
}
}
}
}
(runtime.on_disconnect)(client.clone()).await;
registry.unregister(id);
writer.abort();
}
async fn dispatch_event(runtime: &'static GatewayRuntime, client: &WsClient, raw: &str) {
let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) else {
return;
};
let Some(event) = value.get("event").and_then(|e| e.as_str()) else {
return;
};
let Some(handler) = runtime.handler(event) else {
return;
};
let data = value
.get("data")
.cloned()
.unwrap_or(serde_json::Value::Null);
let data_str: Arc<str> = Arc::from(serde_json::to_string(&data).unwrap_or_default());
metrics::counter!("ws_messages_in_total").increment(1);
if let Err(e) = handler(client.clone(), data_str).await {
metrics::counter!("ws_handler_errors_total").increment(1);
tracing::debug!(conn = client.id(), event, error = %e, "gateway handler error");
}
}