use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::http::HeaderMap;
use axum::routing::{get, MethodRouter};
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, oneshot};
use crate::core::engine::FrozenDiContainer;
use crate::realtime::connection::{ConnectionRegistry, WsClient, WsMessage};
use crate::realtime::gateway::GatewayRuntime;
use crate::web::context::Claims;
/// Builds the `GET` route that upgrades an HTTP request to a WebSocket.
///
/// For each request the handler calls `extract_claims` on the request headers
/// (using `container`), then upgrades the connection and hands the resulting
/// socket, together with `runtime`, `registry` and the extracted claims, to
/// `handle_socket`. The upgrade is performed whatever `extract_claims`
/// returned; nothing in this function rejects a request.
pub fn ws_route(
    runtime: &'static GatewayRuntime,
    registry: &'static ConnectionRegistry,
    container: &'static FrozenDiContainer,
) -> MethodRouter {
    get(move |upgrade: WebSocketUpgrade, headers: HeaderMap| async move {
        // Claims are resolved before the upgrade, from the handshake headers.
        let claims = crate::auth::extract::extract_claims(&headers, container);
        upgrade.on_upgrade(move |socket| handle_socket(socket, runtime, registry, claims))
    })
}
/// Drives a single upgraded WebSocket connection until it ends.
///
/// Steps, in order:
/// 1. Split the socket and register an unbounded outgoing channel with
///    `registry`, obtaining the connection id used to build the `WsClient`.
/// 2. Spawn a writer task that forwards queued `WsMessage`s to the socket.
/// 3. Await `runtime.on_connect`.
/// 4. Read frames until the peer closes, the stream ends or errors, or the
///    writer task stops; every text frame goes through `dispatch_event`.
/// 5. Await `runtime.on_disconnect`, then unregister the connection and abort
///    the writer task.
///
/// Unregistering and aborting the writer are performed by a drop guard, so
/// they also happen if a callback panics or this future is dropped early.
/// `on_disconnect` is only awaited on the normal exit path.
async fn handle_socket(
    socket: WebSocket,
    runtime: &'static GatewayRuntime,
    registry: &'static ConnectionRegistry,
    claims: Option<Arc<Claims>>,
) {
    // Runs the wrapped closure exactly once when the guard is dropped.
    struct RunOnDrop<F: FnOnce()>(Option<F>);

    impl<F: FnOnce()> Drop for RunOnDrop<F> {
        fn drop(&mut self) {
            if let Some(cleanup) = self.0.take() {
                cleanup();
            }
        }
    }

    let (mut sink, mut stream) = socket.split();
    let (tx, mut rx) = mpsc::unbounded_channel::<WsMessage>();
    let id = registry.register(tx, claims.clone());
    let client = WsClient::__new(id, registry, claims);
    let (close_tx, mut close_rx) = oneshot::channel::<()>();

    let writer = tokio::spawn(async move {
        while let Some(outgoing) = rx.recv().await {
            match outgoing {
                WsMessage::Text(text) => {
                    let frame = Message::Text(String::from(text.as_ref()));
                    if sink.send(frame).await.is_err() {
                        break;
                    }
                }
                WsMessage::Close => {
                    // Best effort: the connection is ending either way.
                    let _ = sink.send(Message::Close(None)).await;
                    break;
                }
            }
        }
        // Dropping the sender resolves `close_rx`, which stops the read loop.
        drop(close_tx);
    });

    // From here on, every exit path (normal return, panic, cancellation)
    // removes the registry entry and stops the writer, in that order.
    let teardown = RunOnDrop(Some(move || {
        registry.unregister(id);
        writer.abort();
    }));

    (runtime.on_connect)(client.clone()).await;

    loop {
        tokio::select! {
            // `biased` polls the branches in source order, so the writer's
            // stop signal is checked before the next incoming frame.
            biased;
            _ = &mut close_rx => break,
            frame = stream.next() => match frame {
                None | Some(Err(_)) => break,
                Some(Ok(Message::Close(_))) => break,
                Some(Ok(Message::Text(text))) => dispatch_event(runtime, &client, &text).await,
                // Binary, ping and pong frames are ignored here.
                Some(Ok(Message::Binary(_) | Message::Ping(_) | Message::Pong(_))) => {}
            },
        }
    }

    (runtime.on_disconnect)(client.clone()).await;
    // Explicit drop keeps the original order: unregister and abort run right
    // after `on_disconnect`, before the remaining locals are dropped.
    drop(teardown);
}
/// Routes one incoming text frame to the handler registered for its event.
///
/// `raw` is parsed as JSON. The string field `"event"` selects a handler via
/// `runtime.handler`; the `"data"` field (JSON `null` when absent) is
/// re-serialized to a string and passed to that handler along with a clone of
/// `client`.
///
/// The frame is silently ignored when it is not valid JSON, when `"event"` is
/// missing or not a string, or when no handler is registered for the event.
/// The handler's return value is discarded.
async fn dispatch_event(runtime: &'static GatewayRuntime, client: &WsClient, raw: &str) {
    let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) else {
        return;
    };
    let Some(event) = value.get("event").and_then(|e| e.as_str()) else {
        return;
    };
    let Some(handler) = runtime.handler(event) else {
        return;
    };
    // Serialize straight from the parsed document instead of deep-cloning the
    // payload first; a missing "data" field is sent as JSON null.
    let null = serde_json::Value::Null;
    let data = value.get("data").unwrap_or(&null);
    let data_str: Arc<str> = Arc::from(serde_json::to_string(data).unwrap_or_default());
    let _ = handler(client.clone(), data_str).await;
}