athena_rs 1.1.0

Database gateway API
Documentation
//! WebSocket server for Athena CDC events.
//!
//! Clients subscribe by passing the `X-Athena-Client` header when upgrading to a WebSocket.
//! Events are delivered only to connections whose header value matches
//! `EventMessage::organization_id`.

use axum::{
    Router,
    extract::{
        Json, State,
        ws::{Message, WebSocket, WebSocketUpgrade},
    },
    http::HeaderMap,
    http::StatusCode,
    response::IntoResponse,
    routing::{get, post},
};

use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::{Value, from_str, json};
use std::error::Error as stdError;
use std::{collections::HashSet, net::SocketAddr, sync::Arc};
use tokio::net::TcpListener;
use tokio::sync::broadcast::Receiver;
use tokio::sync::{Mutex, MutexGuard, broadcast};
use tokio::task::JoinHandle;

use tokio::time::{Duration, Interval, interval};
use tracing::{info, warn};

pub mod events;

use events::{EventMessage, set_broadcast_tx};

type Tx = broadcast::Sender<String>;

/// Legacy JSON text frame a connected client may send to subscribe to an extra
/// organization id.
///
/// Deprecated: clients should pass the `X-Athena-Client` header when upgrading the
/// connection instead.
#[deprecated(
    since = "0.80.3",
    note = "Use X-Athena-Client header when connecting to WebSocket instead"
)]
#[derive(Debug, Serialize, Deserialize)]
struct SubscriptionRequest {
    /// Organization id to add to the connection's subscriptions.
    subscribe_to_organization_id: String,
}

/// Shared state cloned into every route handler.
#[derive(Clone)]
struct AppState {
    /// Ids (from the `X-Athena-Client` header or legacy subscription requests) that some
    /// connection has subscribed to; `publish_event` consults it to void events nobody
    /// listens for. Entries are only ever inserted in this file.
    active_subscribers: Arc<Mutex<HashSet<String>>>,
    /// Broadcast sender used to fan serialized events out to all connections.
    tx: Tx,
}

/// Starts the DMS WebSocket server and serves until it stops.
///
/// Clients must pass the `X-Athena-Client` header when connecting to subscribe
/// to events for that client. The header value is matched against `EventMessage::organization_id`.
///
/// Routes: `GET /ws` (WebSocket upgrade), `POST /events` (publish), `GET /status`.
///
/// # Errors
///
/// Returns an error if the port cannot be bound, the bound address cannot be read,
/// or the server fails while serving.
pub async fn websocket_server(port: u16) -> Result<(), Box<dyn stdError>> {
    let capacity: usize = 1000;
    // Create broadcast channel for messages. The initial receiver is dropped; each
    // WebSocket connection subscribes its own.
    let (tx, _) = broadcast::channel::<String>(capacity);
    set_broadcast_tx(tx.clone());
    info!("DMS Broadcast channel created with capacity: {}", capacity);

    // Shared state: the sender plus the set of ids anyone has subscribed to.
    let state: AppState = AppState {
        tx,
        active_subscribers: Arc::new(Mutex::new(HashSet::new())),
    };

    // Set up routes
    let app: Router = Router::new()
        .route("/ws", get(ws_handler))
        .route("/events", post(publish_event))
        .route("/status", get(status))
        .with_state(state);

    // Bind before announcing anything, so the logs never claim the server is running
    // when the port could not be acquired.
    let requested: SocketAddr = SocketAddr::from(([0, 0, 0, 0], port));
    let listener: TcpListener = TcpListener::bind(requested).await?;
    // Log the address actually bound (differs from `requested` when `port` is 0).
    let addr: SocketAddr = listener.local_addr()?;
    info!("Server running at http://{}", addr);
    info!("WebSocket endpoint available at ws://{}/ws", addr);
    info!("Server started successfully");

    // `axum::serve` only returns once the server has stopped (or failed).
    axum::serve(listener, app).await?;
    info!("Server stopped");

    Ok(())
}

/// Upgrades `GET /ws` to a WebSocket after validating the client header.
///
/// The `X-Athena-Client` header (trimmed) names the id the connection is subscribed to;
/// it is later matched against `EventMessage::organization_id`. A missing, non-ASCII,
/// or blank header yields `400 Bad Request` with a JSON error body instead of an upgrade.
async fn ws_handler(
    ws: WebSocketUpgrade,
    headers: HeaderMap,
    State(state): State<AppState>,
) -> impl IntoResponse {
    // `HeaderMap` lookups are case-insensitive, so one lookup covers every
    // capitalization of the header name.
    let header_value: Option<String> = headers
        .get("x-athena-client")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|id| !id.is_empty())
        .map(String::from);

    let Some(client_id) = header_value else {
        let body = Json(json!({
            "error": "Missing X-Athena-Client header",
            "message": "Pass X-Athena-Client when connecting to subscribe to CDC events"
        }));
        return (StatusCode::BAD_REQUEST, body).into_response();
    };

    ws.on_upgrade(move |socket| handle_socket(socket, state, client_id))
}

/// Handles the WebSocket connection.
///
/// The client is pre-subscribed to `client_id` from the `X-Athena-Client` header.
/// Additional subscriptions can be sent via JSON messages (deprecated).
async fn handle_socket(socket: WebSocket, state: AppState, client_id: String) {
    let (sender, mut receiver) = socket.split();
    let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::new(Mutex::new(sender));

    let mut rx: Receiver<String> = state.tx.subscribe();
    let subscribed_companies: Arc<Mutex<HashSet<String>>> = {
        let mut set: HashSet<String> = HashSet::new();
        set.insert(client_id.clone());
        Arc::new(Mutex::new(set))
    };

    // Register as active subscriber for the client from header
    {
        let mut subscribers: MutexGuard<'_, HashSet<String>> =
            state.active_subscribers.lock().await;
        subscribers.insert(client_id.clone());
    }

    // Send initial subscription confirmation
    {
        let confirmation: Value = json!({
            "status": "subscribed",
            "client_id": client_id,
            "message": "Subscribed via X-Athena-Client header"
        });
        let mut sender_guard: MutexGuard<'_, SplitSink<WebSocket, Message>> = sender.lock().await;
        if sender_guard
            .send(Message::Text(confirmation.to_string().into()))
            .await
            .is_err()
        {
            warn!("Failed to send subscription confirmation to client");
        }
    }

    let subscribed_companies_clone: Arc<Mutex<HashSet<String>>> = subscribed_companies.clone();

    // Buffer to accumulate messages
    let message_buffer: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    let message_buffer_clone: Arc<Mutex<Vec<String>>> = message_buffer.clone();

    // Task to send messages to the client with throttling
    let mut send_task: JoinHandle<()> = tokio::spawn({
        let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::clone(&sender);
        let subscribed_companies: Arc<Mutex<HashSet<String>>> = subscribed_companies.clone();
        async move {
            let mut ticker: Interval = interval(Duration::from_millis(1000));

            loop {
                tokio::select! {
                    msg_result = rx.recv() => {
                        match msg_result {
                            Ok(msg) => {
                                if let Ok(event) = from_str::<EventMessage>(&msg) {
                                    if is_client_subscribed(&subscribed_companies, &event.organization_id).await {
                                        let mut buffer = message_buffer_clone.lock().await;
                                        buffer.push(msg);
                                    }
                                } else {
                                    warn!("Failed to parse EventMessage from broadcast message");
                                }
                            }
                            Err(_) => {
                                break;
                            }
                        }
                    }
                    _ = ticker.tick() => {
                        let mut buffer = message_buffer_clone.lock().await;
                        if !buffer.is_empty() {
                            // Send all accumulated messages
                            for msg in buffer.drain(..) {
                                send_message_to_client(&sender, msg).await;
                            }
                        }
                    }
                }
            }
        }
    });

    // Task to receive messages from the client
    let mut recv_task: JoinHandle<()> = tokio::spawn({
        let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::clone(&sender);

        async move {
            while let Some(Ok(Message::Text(text))) = receiver.next().await {
                #[allow(deprecated)]
                if let Ok(subscription) = serde_json::from_str::<SubscriptionRequest>(&text) {
                    handle_subscription_request(
                        &state,
                        &subscribed_companies_clone,
                        &sender,
                        subscription,
                    )
                    .await;
                } else {
                    warn!("Failed to parse SubscriptionRequest from client message");
                }
            }
        }
    });

    // Wait for either task to complete
    tokio::select! {
        _ = &mut send_task => {

        },
        _ = &mut recv_task => {

        },
    }
}

async fn is_client_subscribed(
    subscribed_companies: &Arc<Mutex<HashSet<String>>>,
    organization_id: &str,
) -> bool {
    let companies: MutexGuard<'_, HashSet<String>> = subscribed_companies.lock().await;
    companies.contains(organization_id)
}

/// Sends `msg` to the client as a single text frame.
///
/// Best effort: a failed send is logged and otherwise ignored.
async fn send_message_to_client(sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>, msg: String) {
    let frame = Message::Text(msg.into());
    let outcome = sender.lock().await.send(frame).await;
    if outcome.is_err() {
        warn!("Failed to send message to client");
    }
}

#[allow(deprecated)]
async fn handle_subscription_request(
    state: &AppState,
    subscribed_companies: &Arc<Mutex<HashSet<String>>>,
    sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>,
    subscription: SubscriptionRequest,
) {
    let organization_id: String = subscription.subscribe_to_organization_id;

    // Add to subscribed companies
    {
        let mut companies: MutexGuard<'_, HashSet<String>> = subscribed_companies.lock().await;
        companies.insert(organization_id.clone());
    }

    // Add to active subscribers
    {
        let mut subscribers: MutexGuard<'_, HashSet<String>> =
            state.active_subscribers.lock().await;
        subscribers.insert(organization_id.clone());
    }

    // Send confirmation to client
    let confirmation: Value = json!({
        "status": "subscribed",
        "organization_id": organization_id
    });

    let mut sender: MutexGuard<'_, SplitSink<WebSocket, Message>> = sender.lock().await;

    if sender
        .send(Message::Text(confirmation.to_string().into()))
        .await
        .is_err()
    {
        warn!("Failed to send subscription confirmation to client");
    }
}

/// Endpoint to publish events
async fn publish_event(
    State(state): State<AppState>,
    Json(event): Json<EventMessage>,
) -> impl IntoResponse {
    let organization_id: String = event.organization_id.clone();

    // Check if anyone is subscribed to this organization_id
    let has_subscribers: bool = {
        let subscribers: MutexGuard<'_, HashSet<String>> = state.active_subscribers.lock().await;
        subscribers.contains(&organization_id)
    };

    if !has_subscribers {
        let response: Value = json!({
            "message": "there was no subscriber to this organization_id channel so the message has been voided",
            "organization_id": organization_id,
            "status": "voided",
            "success": true
        });

        return (StatusCode::OK, Json(response)).into_response();
    }

    let event_json: String = serde_json::to_string(&event).unwrap_or_default();
    if let Err(err) = state.tx.send(event_json) {
        warn!("Failed to broadcast event: {}", err);
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Failed to broadcast event",
        )
            .into_response();
    }
    let response: Value = json!({
        "status": "delivered",
        "success": true,
        "organization_id": organization_id
    });

    (StatusCode::OK, Json(response)).into_response()
}

async fn status() -> impl IntoResponse {
    let response: Value = json!({
        "status": "ok",
        "success": true,
        "server": "dms-server-api"
    });

    (StatusCode::OK, Json(response)).into_response()
}