use crate::app::AppContext;
use super::connection::Connection;
use super::manager::{ConnectionManager, ConnectionHandle};
use super::message::Message;
use super::traits::WebSocketHandler;
use axum::{
extract::{ws::WebSocketUpgrade, State},
routing::get,
Router,
};
use futures::StreamExt;
use futures::SinkExt;
use std::sync::Arc;
use tokio::sync::mpsc;
use uuid::Uuid;
/// Builds a router exposing a WebSocket endpoint at `path`.
///
/// A `GET` route is registered at `path`. When a request is upgraded, the
/// resulting socket is handed to `handle_socket` together with the shared
/// `handler`, the connection `manager`, and the request's `AppContext`
/// wrapped in an `Arc`.
///
/// `handler` is wrapped in an `Arc` once, so every connection served by this
/// route shares the same handler instance.
pub fn ws<H>(
    path: &str,
    handler: H,
    manager: Arc<ConnectionManager>,
) -> Router<AppContext>
where
    H: WebSocketHandler,
{
    let shared_handler = Arc::new(handler);

    // The closure runs once per request, so it clones the shared pointers
    // into each upgrade callback instead of moving them out.
    let upgrade_route = get(
        move |upgrade: WebSocketUpgrade, State(app_ctx): State<AppContext>| {
            let conn_handler = Arc::clone(&shared_handler);
            let conn_manager = Arc::clone(&manager);
            let conn_ctx = Arc::new(app_ctx);
            async move {
                upgrade.on_upgrade(move |socket| {
                    handle_socket(socket, conn_handler, conn_manager, conn_ctx)
                })
            }
        },
    );

    Router::new().route(path, upgrade_route)
}
async fn handle_socket<H: WebSocketHandler>(
socket: axum::extract::ws::WebSocket,
handler: Arc<H>,
manager: Arc<ConnectionManager>,
ctx: Arc<AppContext>,
) {
let conn_id = Uuid::new_v4().to_string();
let channel_capacity = 1000;
let (tx, mut rx) = mpsc::channel::<Message>(channel_capacity);
let (mut ws_sender, mut ws_receiver) = socket.split();
let conn_handle: ConnectionHandle = Arc::new(tokio::sync::RwLock::new(
Connection::new(conn_id.clone(), tx.clone()),
));
if let Err(e) = manager.register(conn_handle.clone()).await {
tracing::warn!(conn_id = %conn_id, error = %e, "Failed to register connection");
let _ = ws_sender.close().await;
return;
}
{
let mut conn = conn_handle.write().await;
if let Err(e) = handler.on_connect(&mut *conn, &ctx).await {
tracing::error!(conn_id = %conn_id, error = %e, "Error in on_connect");
let _ = manager.unregister(&conn_id).await;
return;
}
}
let (cleanup_tx, mut cleanup_rx) = tokio::sync::oneshot::channel::<()>();
let cleanup_tx = Arc::new(tokio::sync::Mutex::new(Some(cleanup_tx)));
let heartbeat_interval = tokio::time::Duration::from_secs(30);
let heartbeat_timeout = tokio::time::Duration::from_secs(60);
let last_pong = Arc::new(tokio::sync::RwLock::new(std::time::Instant::now()));
let heartbeat_task = {
let conn_handle = conn_handle.clone();
let manager = manager.clone();
let conn_id = conn_id.clone();
let last_pong = last_pong.clone();
let cleanup_tx = cleanup_tx.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(heartbeat_interval);
loop {
interval.tick().await;
let last_pong_time = *last_pong.read().await;
if last_pong_time.elapsed() > heartbeat_timeout {
tracing::warn!(conn_id = %conn_id, "Connection heartbeat timeout");
if let Some(tx) = cleanup_tx.lock().await.take() {
let _ = tx.send(());
let _ = manager.unregister(&conn_id).await;
}
break;
}
if let Ok(conn_guard) = conn_handle.try_read() {
let _ = conn_guard.send(Message::Ping(vec![])).await;
}
}
})
};
let send_task = tokio::spawn({
let conn_id = conn_id.clone();
let manager = manager.clone();
let cleanup_tx = cleanup_tx.clone();
async move {
while let Some(msg) = rx.recv().await {
let axum_msg = msg.into_axum();
if ws_sender.send(axum_msg).await.is_err() {
break;
}
}
if let Some(tx) = cleanup_tx.lock().await.take() {
let _ = tx.send(());
let _ = manager.unregister(&conn_id).await;
}
}
});
let recv_task = tokio::spawn({
let handler = handler.clone();
let conn_handle = conn_handle.clone();
let ctx = ctx.clone();
let manager = manager.clone();
let conn_id = conn_id.clone();
let cleanup_tx = cleanup_tx.clone();
async move {
while let Some(result) = ws_receiver.next().await {
match result {
Ok(axum_msg) => {
let msg = Message::from_axum(axum_msg);
if let Message::Ping(data) = msg {
{
let conn = conn_handle.read().await;
let _ = conn.send(Message::Pong(data)).await;
}
continue;
}
if matches!(msg, Message::Pong(_)) {
*last_pong.write().await = std::time::Instant::now();
continue;
}
if matches!(msg, Message::Close(_)) {
break;
}
{
let mut conn = conn_handle.write().await;
if let Err(e) = handler.on_message(&mut *conn, msg, &ctx).await {
tracing::error!(conn_id = %conn_id, error = %e, "Error in on_message");
}
}
}
Err(e) => {
tracing::error!(conn_id = %conn_id, error = %e, "WebSocket receive error");
break;
}
}
}
{
let mut conn = conn_handle.write().await;
handler.on_disconnect(&mut *conn, &ctx).await;
}
if let Some(tx) = cleanup_tx.lock().await.take() {
let _ = tx.send(());
let _ = manager.unregister(&conn_id).await;
}
}
});
tokio::pin!(send_task);
tokio::pin!(recv_task);
tokio::pin!(heartbeat_task);
tokio::select! {
_ = send_task.as_mut() => {
recv_task.abort();
heartbeat_task.abort();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
_ = recv_task.as_mut() => {
send_task.abort();
heartbeat_task.abort();
}
_ = heartbeat_task.as_mut() => {
send_task.abort();
recv_task.abort();
}
}
if cleanup_rx.try_recv().is_err() {
let _ = manager.unregister(&conn_id).await;
}
}