// athena_rs 3.0.1
//
// Database gateway API
// Documentation
//! WebSocket server for CDC Athena.
//!
//! - **URL:** Use `ws://host:port/ws` when the server is plain HTTP; use `wss://` only when
//!   the server is behind TLS (connecting with `wss://` to a plain-HTTP server fails the TLS
//!   handshake with `SSL WRONG_VERSION_NUMBER`).
//! - **Header:** Send `X-Athena-Client: <client_name>` on the upgrade request; without it the
//!   server returns 400. Events are filtered by matching this value to `EventMessage::organization_id`.
//!
//! ## Messages
//!
//! **Client → Server (you send):**
//! - **Travel (replay):** `{"type":"travel","organization_id":"<id>","since_seq":123,"since_ts_ms":0,"limit":1000}`  
//!   All fields except `type` are optional. Omit `organization_id` to replay all subscribed channels.
//!
//! **Server → Client (you receive):**
//! - On connect: `{"status":"subscribed","client_id":"...","message":"..."}`
//! - CDC events: `{"organization_id":"...","data":{"seq":1,"ts_ms":...,"payload":{...}}}`
//! - After travel: same event shape for each replayed record, then `{"status":"travel_complete","organization_id":...,"since_seq":...,"since_ts_ms":...,"limit":...}`

mod connection;
pub mod events;
mod routes;
mod state;

use std::collections::HashSet;
use std::error::Error as StdError;
use std::net::SocketAddr;
use std::sync::Arc;

use axum::Router;
use axum::routing::{get, post};
use tokio::net::TcpListener;
use tokio::sync::{Mutex, broadcast};
use tracing::info;

use events::set_broadcast_tx;
use state::AppState;

/// Capacity passed to `broadcast::channel` when the server starts, i.e. how many
/// messages the channel can hold in flight at once.
const BROADCAST_CAPACITY: usize = 1_000;

/// Starts the DMS WebSocket server and serves requests until it stops.
///
/// Binds to `0.0.0.0:port`, exposes `/ws` (WebSocket), `/events` (POST), and `/status` (GET).
/// Clients must pass the `X-Athena-Client` header when connecting to subscribe
/// to events for that client. The header value is matched against `EventMessage::organization_id`.
///
/// The logged address is the one actually bound by the listener, so passing `port = 0`
/// (let the OS choose a free port) logs the real port rather than `0`.
///
/// # Errors
///
/// Returns an error if binding the TCP listener, reading its local address, or serving fails.
pub async fn websocket_server(port: u16) -> Result<(), Box<dyn StdError>> {
    // The receiver returned here is dropped immediately.
    // NOTE(review): consumers presumably obtain receivers via `tx.subscribe()` elsewhere.
    let (tx, _) = broadcast::channel::<String>(BROADCAST_CAPACITY);
    // Hand a sender to the `events` module so events can be published from outside the router.
    set_broadcast_tx(tx.clone());
    info!(
        "DMS Broadcast channel created with capacity: {}",
        BROADCAST_CAPACITY
    );

    let active_subscribers: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
    let app_state: AppState = AppState {
        tx,
        active_subscribers,
    };

    let app: Router = Router::new()
        .route("/ws", get(routes::ws_handler))
        .route("/events", post(routes::publish_event))
        .route("/status", get(routes::status))
        .with_state(app_state);

    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    // Bind before announcing anything, so a failed bind never logs "Server running".
    let listener = TcpListener::bind(addr).await?;
    // Report the bound address: it differs from `addr` when `port` is 0.
    let local_addr = listener.local_addr()?;
    info!("Server running at http://{}", local_addr);
    info!("WebSocket endpoint available at ws://{}/ws", local_addr);
    info!("Server started successfully");

    // `axum::serve` does not resolve while the server is running; code after it only runs
    // once serving has ended.
    axum::serve(listener, app).await?;
    info!("Server stopped");

    Ok(())
}