// athena_rs 2.12.1
//
// Database gateway API
// Documentation
//! WebSocket connection lifecycle: per-connection state, send/recv tasks, and cleanup.

use super::events::{EventMessage, replay_since};
#[allow(deprecated)]
use super::state::SubscriptionRequest;
use super::state::{AppState, ClientControlMessage};
use axum::extract::ws::{Message, WebSocket};
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use serde_json::{Value, from_str, json};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::broadcast::Receiver;
use tokio::sync::{Mutex, MutexGuard};
use tokio::task::JoinHandle;
use tokio::time::{Duration, Interval, interval};
use tracing::warn;

/// Period, in milliseconds, of the ticker that flushes buffered events to the client.
const SEND_THROTTLE_MS: u64 = 1_000;

/// Number of replayed events allowed per travel request when the client sends no `limit`.
const TRAVEL_LIMIT_DEFAULT: usize = 10_000;

/// Handles a single WebSocket connection: subscribe to broadcast, run send/recv tasks, cleanup on drop.
pub async fn handle_socket(socket: WebSocket, state: AppState, client_id: String) {
    let (sender, mut receiver) = socket.split();
    let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::new(Mutex::new(sender));

    let mut rx: Receiver<String> = state.tx.subscribe();
    let subscribed_companies: Arc<Mutex<HashSet<String>>> = {
        let mut set = HashSet::new();
        set.insert(client_id.clone());
        Arc::new(Mutex::new(set))
    };

    {
        let mut subscribers = state.active_subscribers.lock().await;
        subscribers.insert(client_id.clone());
    }

    {
        let confirmation: Value = json!({
            "status": "subscribed",
            "client_id": client_id,
            "message": "Subscribed via X-Athena-Client header"
        });
        let mut sender_guard = sender.lock().await;
        if sender_guard
            .send(Message::Text(confirmation.to_string()))
            .await
            .is_err()
        {
            warn!("Failed to send subscription confirmation to client");
        }
    }

    let subscribed_companies_clone = subscribed_companies.clone();
    let message_buffer: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    let message_buffer_clone = message_buffer.clone();

    let mut send_task: JoinHandle<()> = tokio::spawn({
        let sender = Arc::clone(&sender);
        let subscribed_companies = subscribed_companies.clone();
        async move {
            let mut ticker: Interval = interval(Duration::from_millis(SEND_THROTTLE_MS));
            loop {
                tokio::select! {
                    msg_result = rx.recv() => {
                        match msg_result {
                            Ok(msg) => {
                                if let Ok(event) = from_str::<EventMessage>(&msg) {
                                    if is_client_subscribed(&subscribed_companies, &event.organization_id).await {
                                        let mut buffer = message_buffer_clone.lock().await;
                                        buffer.push(msg);
                                    }
                                } else {
                                    warn!("Failed to parse EventMessage from broadcast message");
                                }
                            }
                            Err(_) => break,
                        }
                    }
                    _ = ticker.tick() => {
                        let mut buffer = message_buffer_clone.lock().await;
                        if !buffer.is_empty() {
                            for msg in buffer.drain(..) {
                                send_message_to_client(&sender, msg).await;
                            }
                        }
                    }
                }
            }
        }
    });

    let state_for_recv = state.clone();
    let mut recv_task: JoinHandle<()> = tokio::spawn({
        let sender = Arc::clone(&sender);
        async move {
            while let Some(Ok(Message::Text(text))) = receiver.next().await {
                if let Ok(control) = serde_json::from_str::<ClientControlMessage>(&text) {
                    let ClientControlMessage::Travel {
                        organization_id,
                        since_seq,
                        since_ts_ms,
                        limit,
                    } = control;
                    {
                        let limit = limit.unwrap_or(TRAVEL_LIMIT_DEFAULT);
                        let requested_org = organization_id.as_ref().and_then(|v| {
                            let t = v.trim();
                            if t.is_empty() {
                                None
                            } else {
                                Some(t.to_string())
                            }
                        });
                        let orgs = if let Some(org) = requested_org.clone() {
                            vec![org]
                        } else {
                            let companies = subscribed_companies_clone.lock().await;
                            companies.iter().cloned().collect::<Vec<_>>()
                        };

                        for org in orgs {
                            let records = replay_since(&org, since_seq, since_ts_ms, limit).await;
                            for record in records {
                                let msg = EventMessage {
                                    organization_id: org.clone(),
                                    data: json!({
                                        "seq": record.seq,
                                        "ts_ms": record.ts_ms,
                                        "payload": record.payload,
                                    }),
                                };
                                if let Ok(encoded) = serde_json::to_string(&msg) {
                                    send_message_to_client(&sender, encoded).await;
                                }
                            }
                        }

                        let ack: Value = json!({
                            "status": "travel_complete",
                            "organization_id": requested_org,
                            "since_seq": since_seq,
                            "since_ts_ms": since_ts_ms,
                            "limit": limit,
                        });
                        send_message_to_client(&sender, ack.to_string()).await;
                    }
                    continue;
                }

                #[allow(deprecated)]
                if let Ok(subscription) = serde_json::from_str::<SubscriptionRequest>(&text) {
                    handle_subscription_request(
                        &state_for_recv,
                        &subscribed_companies_clone,
                        &sender,
                        subscription,
                    )
                    .await;
                } else {
                    warn!("Failed to parse SubscriptionRequest from client message");
                }
            }
        }
    });

    tokio::select! {
        _ = &mut send_task => {}
        _ = &mut recv_task => {}
    }

    send_task.abort();
    recv_task.abort();

    {
        let mut subscribers = state.active_subscribers.lock().await;
        subscribers.remove(&client_id);
    }

    {
        let mut buffer = message_buffer.lock().await;
        buffer.clear();
    }
}

/// Reports whether `organization_id` is in this connection's subscription set.
///
/// The set's lock is taken for the lookup and released when the function returns.
async fn is_client_subscribed(
    subscribed_companies: &Arc<Mutex<HashSet<String>>>,
    organization_id: &str,
) -> bool {
    let current: MutexGuard<'_, HashSet<String>> = subscribed_companies.lock().await;
    current.contains(organization_id)
}

/// Sends `msg` to the client as a text frame through the shared sink.
///
/// A failed send is logged at `warn` level and otherwise ignored; the caller gets no result.
async fn send_message_to_client(sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>, msg: String) {
    let frame = Message::Text(msg);
    let delivery = sender.lock().await.send(frame).await;
    if delivery.is_err() {
        warn!("Failed to send message to client");
    }
}

/// Applies a legacy subscription request to this connection.
///
/// The requested organization id is added to the connection's own subscription set and to
/// `state.active_subscribers`. A `"subscribed"` confirmation is then sent to the client; a
/// failed send is logged and otherwise ignored.
#[allow(deprecated)]
async fn handle_subscription_request(
    state: &AppState,
    subscribed_companies: &Arc<Mutex<HashSet<String>>>,
    sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>,
    subscription: SubscriptionRequest,
) {
    let org_id = subscription.subscribe_to_organization_id;

    // Each lock is held only for its own insert, never both at once.
    subscribed_companies.lock().await.insert(org_id.clone());
    state.active_subscribers.lock().await.insert(org_id.clone());

    let payload: Value = json!({
        "status": "subscribed",
        "organization_id": org_id
    });
    let frame = Message::Text(payload.to_string());

    let outcome = sender.lock().await.send(frame).await;
    if outcome.is_err() {
        warn!("Failed to send subscription confirmation to client");
    }
}