use std::sync::Arc;
use actix_web::{Error, HttpRequest, HttpResponse, get, web};
use athena_chat::ChatError;
use athena_chat::dto::commands::MarkReadUpTo;
use athena_wss::WsGateway;
use athena_wss::commands::typing_event;
use athena_wss::protocol::{ClientWsCommand, ServerWsEvent};
use futures_util::StreamExt;
use tokio::sync::mpsc;
use tracing::{instrument, warn};
use crate::AppState;
/// Upgrades `GET /wss/gateway` to a WebSocket session for an authenticated chat user.
///
/// The request is authenticated through `state.auth_resolver`; if it rejects the request,
/// the response it produced is returned unchanged and no upgrade happens. Otherwise the
/// connection is registered with `state.ws_hub`, a `HelloOk` event is queued, and two
/// background tasks are spawned:
///
/// * a **writer** that drains the per-connection outbound queue, serialises each
///   [`ServerWsEvent`] to JSON and sends it as a text frame, then unregisters the
///   connection once the queue closes or a write fails;
/// * a **reader** that answers pings, dispatches text frames to [`handle_text_command`],
///   and, when the client leaves or the stream errors, closes the session, stops this
///   connection's room-subscription tasks and unregisters the connection.
///
/// # Errors
///
/// Returns an error if `actix_ws::handle` rejects the WebSocket handshake.
#[get("/wss/gateway")]
#[instrument(name = "wss.upgrade", skip(req, body, state))]
pub async fn gateway_wss_route(
    req: HttpRequest,
    body: web::Payload,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
    let chat_ctx = match state.auth_resolver.resolve(&req, state.get_ref()).await {
        Ok(ctx) => ctx,
        // The resolver already built the rejection response; pass it through untouched.
        Err(resp) => return Ok(resp),
    };
    let (response, mut session, mut message_stream) = actix_ws::handle(&req, body)?;
    // Bounded queue shared by every producer for this socket (command replies, room
    // subscription forwarders); the writer task is its single consumer.
    let (outbound_tx, mut outbound_rx) = mpsc::channel::<ServerWsEvent>(256);
    let connection = state
        .ws_hub
        .register_connection(
            chat_ctx.user_id.clone(),
            chat_ctx.organization_id.clone(),
            chat_ctx.client_name.clone(),
            outbound_tx.clone(),
        )
        .await;

    // Writer: serialise queued events onto the socket until the queue closes or a write fails.
    let hub = state.ws_hub.clone();
    let mut writer_session = session.clone();
    let writer_connection_id = connection.connection_id;
    actix_web::rt::spawn(async move {
        while let Some(event) = outbound_rx.recv().await {
            let payload = match serde_json::to_string(&event) {
                Ok(payload) => payload,
                Err(err) => {
                    warn!(error = %err, "failed to serialize websocket event");
                    continue;
                }
            };
            if writer_session.text(payload).await.is_err() {
                break;
            }
        }
        hub.unregister_connection(writer_connection_id).await;
    });

    let hello = ServerWsEvent::HelloOk {
        connection_id: connection.connection_id,
        server_time: chrono::Utc::now(),
    };
    let _ = outbound_tx.send(hello).await;

    // Reader: handle inbound frames, then tear the connection down.
    let state_for_reader = state.clone();
    actix_web::rt::spawn(async move {
        // Reason sent by the client in its Close frame, echoed back when we close.
        let mut close_reason = None;
        while let Some(message) = message_stream.next().await {
            match message {
                Ok(actix_ws::Message::Text(text)) => {
                    if handle_text_command(
                        text.to_string(),
                        &state_for_reader,
                        &chat_ctx,
                        connection.clone(),
                    )
                    .await
                    .is_err()
                    {
                        break;
                    }
                }
                Ok(actix_ws::Message::Ping(bytes)) => {
                    let _ = session.pong(&bytes).await;
                }
                Ok(actix_ws::Message::Pong(_)) => {}
                Ok(actix_ws::Message::Close(reason)) => {
                    close_reason = reason;
                    break;
                }
                Ok(actix_ws::Message::Continuation(_)) => break,
                Ok(actix_ws::Message::Binary(_)) => {}
                Ok(actix_ws::Message::Nop) => {}
                Err(err) => {
                    warn!(error = %err, "websocket reader error");
                    break;
                }
            }
        }

        // Complete the close handshake (or close on error) instead of leaving the
        // session open; the error only means the session was already closed.
        let _ = session.close(close_reason).await;

        // Each subscription forwarder owns a clone of the outbound sender, which keeps the
        // writer's `recv()` pending for as long as the forwarder lives (for a quiet room,
        // possibly forever). Stop them explicitly; aborted tasks skip their own hub
        // cleanup, so remove the hub entries here, mirroring `ChatUnsubscribe`.
        let subscriptions = std::mem::take(&mut *connection.subscriptions.write().await);
        for (room_id, task) in subscriptions {
            task.abort();
            state_for_reader
                .ws_hub
                .remove_room_subscription(
                    &connection.client_name,
                    room_id,
                    connection.connection_id,
                )
                .await;
        }

        state_for_reader
            .ws_hub
            .unregister_connection(connection.connection_id)
            .await;
    });
    Ok(response)
}
async fn handle_text_command(
text: String,
state: &web::Data<AppState>,
chat_ctx: &athena_chat::ChatContext,
connection: Arc<athena_wss::connection::ConnectionHandle>,
) -> Result<(), ()> {
let command: ClientWsCommand = match serde_json::from_str(&text) {
Ok(command) => command,
Err(err) => {
let _ = connection
.outbound
.send(ServerWsEvent::Error {
code: "invalid_json".to_string(),
message: err.to_string(),
})
.await;
return Ok(());
}
};
match command {
ClientWsCommand::AuthHello { .. } => {
let _ = connection
.outbound
.send(ServerWsEvent::HelloOk {
connection_id: connection.connection_id,
server_time: chrono::Utc::now(),
})
.await;
}
ClientWsCommand::Ping { at } => {
let _ = connection.outbound.send(ServerWsEvent::Pong { at }).await;
}
ClientWsCommand::ChatSubscribe { room_id, from_seq } => {
let room = match state.chat_app.get_room(chat_ctx.clone(), room_id).await {
Ok(room) => room,
Err(_) => {
let _ = connection
.outbound
.send(ServerWsEvent::Error {
code: "subscribe_denied".to_string(),
message: "room subscription denied".to_string(),
})
.await;
return Ok(());
}
};
let topic = state
.ws_hub
.ensure_room_topic(&chat_ctx.client_name, room_id)
.await;
let known_room_seq = known_room_seq(room.last_message_seq, topic.current_seq());
topic.observe_seq(known_room_seq);
if let Some(from_seq) = from_seq
&& known_room_seq > from_seq
{
let _ = connection
.outbound
.send(ServerWsEvent::ChatSyncRequired {
room_id,
reason: "resume_gap".to_string(),
expected_from_seq: Some(from_seq.saturating_add(1)),
})
.await;
}
state
.ws_hub
.add_room_subscription(&connection.client_name, room_id, connection.connection_id)
.await;
let mut receiver = topic.sender.subscribe();
let outbound = connection.outbound.clone();
let hub = state.ws_hub.clone();
let app = state.chat_app.clone();
let chat_ctx = chat_ctx.clone();
let connection_id = connection.connection_id;
let client_name = connection.client_name.clone();
let task = tokio::spawn(async move {
loop {
match receiver.recv().await {
Ok(event) => {
if matches!(event, ServerWsEvent::ChatMembersUpdated { .. }) {
match app.get_room(chat_ctx.clone(), room_id).await {
Err(ChatError::Unauthorized)
| Err(ChatError::Forbidden(_))
| Err(ChatError::NotFound(_)) => {
let _ = outbound
.send(ServerWsEvent::Error {
code: "subscription_revoked".to_string(),
message: "room subscription revoked".to_string(),
})
.await;
break;
}
Err(_) | Ok(_) => {}
}
}
if outbound.send(event).await.is_err() {
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
let _ = outbound
.send(ServerWsEvent::ChatSyncRequired {
room_id,
reason: "lagged".to_string(),
expected_from_seq: None,
})
.await;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
hub.remove_room_subscription(&client_name, room_id, connection_id)
.await;
});
let mut subscriptions = connection.subscriptions.write().await;
if let Some(existing) = subscriptions.insert(room_id, task) {
existing.abort();
}
let _ = connection
.outbound
.send(ServerWsEvent::ChatSubscribed { room_id, from_seq })
.await;
}
ClientWsCommand::ChatUnsubscribe { room_id } => {
if let Some(task) = connection.subscriptions.write().await.remove(&room_id) {
task.abort();
}
state
.ws_hub
.remove_room_subscription(
&connection.client_name,
room_id,
connection.connection_id,
)
.await;
}
ClientWsCommand::ChatResume { rooms } => {
for room in rooms {
let room_view = match state
.chat_app
.get_room(chat_ctx.clone(), room.room_id)
.await
{
Ok(room_view) => room_view,
Err(ChatError::Unauthorized)
| Err(ChatError::Forbidden(_))
| Err(ChatError::NotFound(_)) => {
let _ = connection
.outbound
.send(ServerWsEvent::Error {
code: "resume_denied".to_string(),
message: format!("room resume denied for {}", room.room_id),
})
.await;
continue;
}
Err(err) => {
let _ = connection
.outbound
.send(ServerWsEvent::Error {
code: "resume_unavailable".to_string(),
message: format!(
"room resume check failed for {}: {err}",
room.room_id
),
})
.await;
continue;
}
};
let topic = state
.ws_hub
.ensure_room_topic(&chat_ctx.client_name, room.room_id)
.await;
let known_room_seq =
known_room_seq(room_view.last_message_seq, topic.current_seq());
topic.observe_seq(known_room_seq);
if known_room_seq > room.last_seq {
let _ = connection
.outbound
.send(ServerWsEvent::ChatSyncRequired {
room_id: room.room_id,
reason: "resume_gap".to_string(),
expected_from_seq: Some(room.last_seq.saturating_add(1)),
})
.await;
}
}
}
ClientWsCommand::ChatTypingStart { room_id } => {
if state
.chat_app
.get_room(chat_ctx.clone(), room_id)
.await
.is_ok()
{
let topic = state
.ws_hub
.ensure_room_topic(&chat_ctx.client_name, room_id)
.await;
let _ = topic
.sender
.send(typing_event(room_id, chat_ctx.user_id.clone(), "start"));
}
}
ClientWsCommand::ChatTypingStop { room_id } => {
if state
.chat_app
.get_room(chat_ctx.clone(), room_id)
.await
.is_ok()
{
let topic = state
.ws_hub
.ensure_room_topic(&chat_ctx.client_name, room_id)
.await;
let _ = topic
.sender
.send(typing_event(room_id, chat_ctx.user_id.clone(), "stop"));
}
}
ClientWsCommand::ChatPresenceHeartbeat { active_room_id } => {
if let Some(room_id) = active_room_id
&& state
.chat_app
.get_room(chat_ctx.clone(), room_id)
.await
.is_ok()
{
let snapshot = state
.ws_hub
.current_presence(&chat_ctx.client_name, room_id)
.await;
let topic = state
.ws_hub
.ensure_room_topic(&chat_ctx.client_name, room_id)
.await;
let _ = topic.sender.send(ServerWsEvent::ChatPresenceUpdated {
room_id,
users: snapshot.users,
});
}
}
ClientWsCommand::ChatReadUpTo {
room_id,
message_id,
seq,
} => {
let _ = state
.chat_app
.mark_read_up_to(
chat_ctx.clone(),
MarkReadUpTo {
room_id,
message_id,
seq,
},
)
.await;
}
}
Ok(())
}
/// Returns the newest room sequence number the server is aware of.
///
/// `durable_room_seq` is the sequence stored with the room, and `topic_seq` is the one
/// last observed on the in-memory broadcast topic. The larger of the two wins, and the
/// result is never negative.
fn known_room_seq(durable_room_seq: i64, topic_seq: i64) -> i64 {
    let newest = if topic_seq > durable_room_seq {
        topic_seq
    } else {
        durable_room_seq
    };
    if newest < 0 { 0 } else { newest }
}
#[cfg(test)]
mod tests {
    use super::known_room_seq;

    #[test]
    fn known_room_seq_prefers_durable_state_after_restart() {
        // Models a restart: the in-memory topic sequence is back at 0, so the durable one wins.
        assert_eq!(known_room_seq(14, 0), 14);
    }

    #[test]
    fn known_room_seq_keeps_newer_topic_state_when_broadcasts_advance_first() {
        assert_eq!(known_room_seq(9, 11), 11);
    }

    #[test]
    fn known_room_seq_is_clamped_at_zero() {
        // Covers the `.max(0)` clamp, which the other cases never reach.
        assert_eq!(known_room_seq(-3, -7), 0);
        assert_eq!(known_room_seq(0, 0), 0);
    }
}