athena_rs 3.26.3

Hyper-performant polyglot database driver
//! Actix WebSocket upgrade handler for [`openapi-wss.yaml`](../../../openapi-wss.yaml) `/wss/gateway`.
//!
//! Authenticates via the same [`crate::api::chat::auth::AuthResolver`] as HTTP `/chat/*`, then
//! dispatches [`ClientWsCommand`] frames to hub subscription and ephemeral control paths.

use std::sync::Arc;

use actix_web::{Error, HttpRequest, HttpResponse, get, web};
use athena_chat::ChatError;
use athena_chat::dto::commands::MarkReadUpTo;
use athena_wss::WsGateway;
use athena_wss::commands::typing_event;
use athena_wss::protocol::{ClientWsCommand, ServerWsEvent};
use futures_util::StreamExt;
use tokio::sync::mpsc;
use tracing::{instrument, warn};

use crate::AppState;

/// Upgrades an authenticated HTTP request to a chat WebSocket session.
///
/// Sends `hello.ok` on connect and processes inbound [`ClientWsCommand`] JSON frames until close.
#[get("/wss/gateway")]
#[instrument(name = "wss.upgrade", skip(req, body, state))]
pub async fn gateway_wss_route(
    req: HttpRequest,
    body: web::Payload,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
    let chat_ctx = match state.auth_resolver.resolve(&req, state.get_ref()).await {
        Ok(ctx) => ctx,
        Err(resp) => return Ok(resp),
    };

    let (response, mut session, mut message_stream) = actix_ws::handle(&req, body)?;
    let (outbound_tx, mut outbound_rx) = mpsc::channel::<ServerWsEvent>(256);
    let connection = state
        .ws_hub
        .register_connection(
            chat_ctx.user_id.clone(),
            chat_ctx.organization_id.clone(),
            chat_ctx.client_name.clone(),
            outbound_tx.clone(),
        )
        .await;

    let hub = state.ws_hub.clone();
    let writer_session = session.clone();
    let writer_connection_id = connection.connection_id;
    actix_web::rt::spawn(async move {
        let mut writer_session = writer_session;
        while let Some(event) = outbound_rx.recv().await {
            let payload = match serde_json::to_string(&event) {
                Ok(payload) => payload,
                Err(err) => {
                    warn!(error = %err, "failed to serialize websocket event");
                    continue;
                }
            };
            if writer_session.text(payload).await.is_err() {
                break;
            }
        }
        hub.unregister_connection(writer_connection_id).await;
    });

    let hello = ServerWsEvent::HelloOk {
        connection_id: connection.connection_id,
        server_time: chrono::Utc::now(),
    };
    let _ = outbound_tx.send(hello).await;

    let state_for_reader = state.clone();
    actix_web::rt::spawn(async move {
        while let Some(message) = message_stream.next().await {
            match message {
                Ok(actix_ws::Message::Text(text)) => {
                    if handle_text_command(
                        text.to_string(),
                        &state_for_reader,
                        &chat_ctx,
                        connection.clone(),
                    )
                    .await
                    .is_err()
                    {
                        break;
                    }
                }
                Ok(actix_ws::Message::Ping(bytes)) => {
                    let _ = session.pong(&bytes).await;
                }
                Ok(actix_ws::Message::Pong(_)) => {}
                Ok(actix_ws::Message::Close(_)) => break,
                Ok(actix_ws::Message::Continuation(_)) => break,
                Ok(actix_ws::Message::Binary(_)) => {}
                Ok(actix_ws::Message::Nop) => {}
                Err(err) => {
                    warn!(error = %err, "websocket reader error");
                    break;
                }
            }
        }
        state_for_reader
            .ws_hub
            .unregister_connection(connection.connection_id)
            .await;
    });

    Ok(response)
}

async fn handle_text_command(
    text: String,
    state: &web::Data<AppState>,
    chat_ctx: &athena_chat::ChatContext,
    connection: Arc<athena_wss::connection::ConnectionHandle>,
) -> Result<(), ()> {
    let command: ClientWsCommand = match serde_json::from_str(&text) {
        Ok(command) => command,
        Err(err) => {
            let _ = connection
                .outbound
                .send(ServerWsEvent::Error {
                    code: "invalid_json".to_string(),
                    message: err.to_string(),
                })
                .await;
            return Ok(());
        }
    };

    match command {
        ClientWsCommand::AuthHello { .. } => {
            let _ = connection
                .outbound
                .send(ServerWsEvent::HelloOk {
                    connection_id: connection.connection_id,
                    server_time: chrono::Utc::now(),
                })
                .await;
        }
        ClientWsCommand::Ping { at } => {
            let _ = connection.outbound.send(ServerWsEvent::Pong { at }).await;
        }
        ClientWsCommand::ChatSubscribe { room_id, from_seq } => {
            let room = match state.chat_app.get_room(chat_ctx.clone(), room_id).await {
                Ok(room) => room,
                Err(_) => {
                    let _ = connection
                        .outbound
                        .send(ServerWsEvent::Error {
                            code: "subscribe_denied".to_string(),
                            message: "room subscription denied".to_string(),
                        })
                        .await;
                    return Ok(());
                }
            };

            let topic = state
                .ws_hub
                .ensure_room_topic(&chat_ctx.client_name, room_id)
                .await;
            let known_room_seq = known_room_seq(room.last_message_seq, topic.current_seq());
            topic.observe_seq(known_room_seq);
            if let Some(from_seq) = from_seq
                && known_room_seq > from_seq
            {
                let _ = connection
                    .outbound
                    .send(ServerWsEvent::ChatSyncRequired {
                        room_id,
                        reason: "resume_gap".to_string(),
                        expected_from_seq: Some(from_seq.saturating_add(1)),
                    })
                    .await;
            }

            state
                .ws_hub
                .add_room_subscription(&connection.client_name, room_id, connection.connection_id)
                .await;
            let mut receiver = topic.sender.subscribe();
            let outbound = connection.outbound.clone();
            let hub = state.ws_hub.clone();
            let app = state.chat_app.clone();
            let chat_ctx = chat_ctx.clone();
            let connection_id = connection.connection_id;
            let client_name = connection.client_name.clone();
            let task = tokio::spawn(async move {
                loop {
                    match receiver.recv().await {
                        Ok(event) => {
                            if matches!(event, ServerWsEvent::ChatMembersUpdated { .. }) {
                                match app.get_room(chat_ctx.clone(), room_id).await {
                                    Err(ChatError::Unauthorized)
                                    | Err(ChatError::Forbidden(_))
                                    | Err(ChatError::NotFound(_)) => {
                                        let _ = outbound
                                            .send(ServerWsEvent::Error {
                                                code: "subscription_revoked".to_string(),
                                                message: "room subscription revoked".to_string(),
                                            })
                                            .await;
                                        break;
                                    }
                                    Err(_) | Ok(_) => {}
                                }
                            }
                            if outbound.send(event).await.is_err() {
                                break;
                            }
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                            let _ = outbound
                                .send(ServerWsEvent::ChatSyncRequired {
                                    room_id,
                                    reason: "lagged".to_string(),
                                    expected_from_seq: None,
                                })
                                .await;
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                    }
                }
                hub.remove_room_subscription(&client_name, room_id, connection_id)
                    .await;
            });
            let mut subscriptions = connection.subscriptions.write().await;
            if let Some(existing) = subscriptions.insert(room_id, task) {
                existing.abort();
            }
            let _ = connection
                .outbound
                .send(ServerWsEvent::ChatSubscribed { room_id, from_seq })
                .await;
        }
        ClientWsCommand::ChatUnsubscribe { room_id } => {
            if let Some(task) = connection.subscriptions.write().await.remove(&room_id) {
                task.abort();
            }
            state
                .ws_hub
                .remove_room_subscription(
                    &connection.client_name,
                    room_id,
                    connection.connection_id,
                )
                .await;
        }
        ClientWsCommand::ChatResume { rooms } => {
            for room in rooms {
                let room_view = match state
                    .chat_app
                    .get_room(chat_ctx.clone(), room.room_id)
                    .await
                {
                    Ok(room_view) => room_view,
                    Err(ChatError::Unauthorized)
                    | Err(ChatError::Forbidden(_))
                    | Err(ChatError::NotFound(_)) => {
                        let _ = connection
                            .outbound
                            .send(ServerWsEvent::Error {
                                code: "resume_denied".to_string(),
                                message: format!("room resume denied for {}", room.room_id),
                            })
                            .await;
                        continue;
                    }
                    Err(err) => {
                        let _ = connection
                            .outbound
                            .send(ServerWsEvent::Error {
                                code: "resume_unavailable".to_string(),
                                message: format!(
                                    "room resume check failed for {}: {err}",
                                    room.room_id
                                ),
                            })
                            .await;
                        continue;
                    }
                };
                let topic = state
                    .ws_hub
                    .ensure_room_topic(&chat_ctx.client_name, room.room_id)
                    .await;
                let known_room_seq =
                    known_room_seq(room_view.last_message_seq, topic.current_seq());
                topic.observe_seq(known_room_seq);
                if known_room_seq > room.last_seq {
                    let _ = connection
                        .outbound
                        .send(ServerWsEvent::ChatSyncRequired {
                            room_id: room.room_id,
                            reason: "resume_gap".to_string(),
                            expected_from_seq: Some(room.last_seq.saturating_add(1)),
                        })
                        .await;
                }
            }
        }
        ClientWsCommand::ChatTypingStart { room_id } => {
            if state
                .chat_app
                .get_room(chat_ctx.clone(), room_id)
                .await
                .is_ok()
            {
                let topic = state
                    .ws_hub
                    .ensure_room_topic(&chat_ctx.client_name, room_id)
                    .await;
                let _ = topic
                    .sender
                    .send(typing_event(room_id, chat_ctx.user_id.clone(), "start"));
            }
        }
        ClientWsCommand::ChatTypingStop { room_id } => {
            if state
                .chat_app
                .get_room(chat_ctx.clone(), room_id)
                .await
                .is_ok()
            {
                let topic = state
                    .ws_hub
                    .ensure_room_topic(&chat_ctx.client_name, room_id)
                    .await;
                let _ = topic
                    .sender
                    .send(typing_event(room_id, chat_ctx.user_id.clone(), "stop"));
            }
        }
        ClientWsCommand::ChatPresenceHeartbeat { active_room_id } => {
            if let Some(room_id) = active_room_id
                && state
                    .chat_app
                    .get_room(chat_ctx.clone(), room_id)
                    .await
                    .is_ok()
            {
                let snapshot = state
                    .ws_hub
                    .current_presence(&chat_ctx.client_name, room_id)
                    .await;
                let topic = state
                    .ws_hub
                    .ensure_room_topic(&chat_ctx.client_name, room_id)
                    .await;
                let _ = topic.sender.send(ServerWsEvent::ChatPresenceUpdated {
                    room_id,
                    users: snapshot.users,
                });
            }
        }
        ClientWsCommand::ChatReadUpTo {
            room_id,
            message_id,
            seq,
        } => {
            let _ = state
                .chat_app
                .mark_read_up_to(
                    chat_ctx.clone(),
                    MarkReadUpTo {
                        room_id,
                        message_id,
                        seq,
                    },
                )
                .await;
        }
    }

    Ok(())
}

/// Returns the newest room sequence number known from either source, never below zero.
///
/// `durable_room_seq` is the persisted value; `topic_seq` is the in-memory broadcast topic's
/// counter. Whichever is larger wins, and negative values are clamped to `0`.
fn known_room_seq(durable_room_seq: i64, topic_seq: i64) -> i64 {
    let newest = std::cmp::max(durable_room_seq, topic_seq);
    if newest < 0 { 0 } else { newest }
}

#[cfg(test)]
mod tests {
    use super::known_room_seq;

    #[test]
    fn known_room_seq_prefers_durable_state_after_restart() {
        assert_eq!(known_room_seq(14, 0), 14);
    }

    #[test]
    fn known_room_seq_keeps_newer_topic_state_when_broadcasts_advance_first() {
        assert_eq!(known_room_seq(9, 11), 11);
    }

    /// Equal sources must not be double-counted or shifted.
    #[test]
    fn known_room_seq_returns_shared_value_when_sources_agree() {
        assert_eq!(known_room_seq(5, 5), 5);
    }

    /// Covers the `.max(0)` clamp: negative inputs never produce a negative sequence.
    #[test]
    fn known_room_seq_clamps_negative_inputs_to_zero() {
        assert_eq!(known_room_seq(-3, -7), 0);
        assert_eq!(known_room_seq(-1, 0), 0);
        assert_eq!(known_room_seq(i64::MIN, i64::MIN), 0);
    }
}