// athena_rs 3.3.0
//
// Database gateway API — documentation header.
//! Event batching and broadcast for the CDC WebSocket server.
//!
//! Events are buffered per `organization_id` and flushed once per second,
//! broadcasting to WebSocket subscribers via the server's channel when available.

use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::fs::{OpenOptions, create_dir_all};
use tokio::io::Lines;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{Mutex, MutexGuard, broadcast};
use tokio::time::{Duration, Interval, interval};
use tracing::warn;

/// Event shape expected by WebSocket subscribers.
///
/// Serialized to a JSON string by the batcher task and pushed over the
/// broadcast channel registered via [`set_broadcast_tx`].
#[derive(Debug, Serialize, Deserialize)]
pub struct EventMessage {
    /// Channel (organization/client) ID; matched to `X-Athena-Client`.
    pub organization_id: String,
    /// Arbitrary JSON payload (e.g. `{ "seq", "ts_ms", "payload" }`).
    pub data: Value,
}

/// Single event record for replay and persistence (seq, timestamp, payload).
///
/// Held in the bounded in-memory replay cache and appended as one JSONL line
/// per record to the on-disk event log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventRecord {
    /// Monotonic sequence number per stream.
    pub seq: u64,
    /// Event time in milliseconds since Unix epoch.
    pub ts_ms: i64,
    /// Event payload as JSON.
    pub payload: Value,
}

/// Optional broadcast sender, set when the websocket server starts.
/// `OnceCell` makes registration one-shot: the first `set` wins and later
/// attempts are ignored.
static BROADCAST_TX: OnceCell<broadcast::Sender<String>> = OnceCell::new();

/// Register the broadcast sender. Called by `websocket_server` on startup.
///
/// Only the first registration takes effect; the one-shot cell keeps its
/// initial value, so a repeated call is a silent no-op.
pub fn set_broadcast_tx(tx: broadcast::Sender<String>) {
    // Discard the `Err(tx)` returned when the cell was already populated.
    drop(BROADCAST_TX.set(tx));
}

/// Returns true if the broadcast channel is configured (websocket server running).
pub fn has_broadcast() -> bool {
    matches!(BROADCAST_TX.get(), Some(_))
}

/// Default directory for on-disk JSONL event logs when [`set_event_log_dir`] is not called.
const DEFAULT_EVENT_LOG_DIR: &str = "cdc/event-log";

/// Maximum number of events kept in the in-memory replay cache per organization.
const MAX_EVENTS_PER_ORG: usize = 5_000;

/// Maximum number of events returned by [`replay_since`] (enforced via clamp).
const REPLAY_LIMIT_MAX: usize = 50_000;

/// Interval at which the batcher flushes events to the broadcast channel.
const BATCH_INTERVAL_SECS: u64 = 1;

/// Configured event-log directory; [`event_log_dir`] falls back to
/// [`DEFAULT_EVENT_LOG_DIR`] while this is unset.
static EVENT_LOG_DIR: OnceCell<String> = OnceCell::new();
/// Process-wide monotonic sequence counter shared across all organizations.
static NEXT_SEQ: OnceCell<Arc<tokio::sync::Mutex<u64>>> = OnceCell::new();

// Event buffer that collects events per organization_id.
type EventBuckets = HashMap<String, Vec<EventRecord>>;
type SharedEventBuffer = Arc<Mutex<EventBuckets>>;

// In-memory replay cache per organization_id.
type ReplayBuckets = HashMap<String, VecDeque<EventRecord>>;
type SharedReplayCache = Arc<Mutex<ReplayBuckets>>;

/// Pending events awaiting the next batch flush, keyed by organization.
static EVENT_BUFFER: OnceCell<SharedEventBuffer> = OnceCell::new();
/// Bounded per-organization replay cache (capped at [`MAX_EVENTS_PER_ORG`]).
static REPLAY_CACHE: OnceCell<SharedReplayCache> = OnceCell::new();
/// Flag guarding the one-time spawn of the batcher background task.
static EVENT_BATCHER_INIT: OnceCell<Arc<Mutex<bool>>> = OnceCell::new();

/// Lazily create and hand out the shared per-organization event buffer.
fn get_buffer() -> SharedEventBuffer {
    let shared = EVENT_BUFFER.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
    Arc::clone(shared)
}

/// Lazily create and hand out the shared per-organization replay cache.
fn get_replay_cache() -> SharedReplayCache {
    let shared = REPLAY_CACHE.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
    Arc::clone(shared)
}

/// Lazily create and hand out the "batcher already started" flag.
fn get_batcher_init() -> Arc<Mutex<bool>> {
    let flag = EVENT_BATCHER_INIT.get_or_init(|| Arc::new(Mutex::new(false)));
    Arc::clone(flag)
}

/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_ts_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}

/// Allocate the next global sequence number (saturates at `u64::MAX`).
async fn next_seq() -> u64 {
    let counter = Arc::clone(NEXT_SEQ.get_or_init(|| Arc::new(tokio::sync::Mutex::new(0))));
    let mut value = counter.lock().await;
    *value = (*value).saturating_add(1);
    *value
}

/// Resolve the event-log directory, falling back to the compiled-in default.
fn event_log_dir() -> String {
    match EVENT_LOG_DIR.get() {
        Some(dir) => dir.clone(),
        None => DEFAULT_EVENT_LOG_DIR.to_string(),
    }
}

/// Configure the on-disk event log directory.
///
/// If this is never called, defaults to `cdc/event-log`. Only the first call
/// takes effect; the one-shot cell ignores later values.
pub fn set_event_log_dir(dir: impl Into<String>) {
    // Discard the error from a second configuration attempt.
    drop(EVENT_LOG_DIR.set(dir.into()));
}

/// Build the JSONL log path for one organization's event stream.
fn org_log_path(organization_id: &str) -> String {
    // organization_id originates from X-Athena-Client (or explicit publish) and is used as
    // a logical channel name; we keep it filesystem-safe by hex-encoding.
    let encoded = hex::encode(organization_id.as_bytes());
    let dir = event_log_dir();
    format!("{dir}/{encoded}.jsonl")
}

/// Initialize the event batcher background task
async fn ensure_event_batcher_started() {
    let init: Arc<Mutex<bool>> = get_batcher_init();
    let mut guard: tokio::sync::MutexGuard<'_, bool> = init.lock().await;
    if !*guard {
        *guard = true;
        tokio::spawn(async { event_batcher_task().await });
    }
}

/// Background task that flushes events periodically to the WebSocket broadcast.
async fn event_batcher_task() {
    let mut ticker: Interval = interval(Duration::from_secs(BATCH_INTERVAL_SECS));
    let buffer: SharedEventBuffer = get_buffer();

    loop {
        ticker.tick().await;

        let events_to_send: HashMap<String, Vec<EventRecord>> = {
            let mut guard: MutexGuard<'_, HashMap<String, Vec<EventRecord>>> = buffer.lock().await;
            if guard.is_empty() {
                continue;
            }
            std::mem::take(&mut *guard)
        };

        let tx: broadcast::Sender<String> = match BROADCAST_TX.get() {
            Some(t) => t.clone(),
            None => continue,
        };

        for (organization_id, events) in events_to_send {
            for record in events {
                let data = serde_json::json!({
                    "seq": record.seq,
                    "ts_ms": record.ts_ms,
                    "payload": record.payload,
                });
                let msg: EventMessage = EventMessage {
                    organization_id: organization_id.clone(),
                    data,
                };
                let json: String = match serde_json::to_string(&msg) {
                    Ok(s) => s,
                    Err(e) => {
                        warn!("Failed to serialize event: {}", e);
                        continue;
                    }
                };
                if tx.send(json).is_err() {
                    warn!("Failed to broadcast event (no receivers?)");
                }
            }
        }
    }
}

/// Append `record` as one JSONL line to the organization's on-disk event log.
///
/// Best-effort: directory creation, open, serialization, and write failures
/// are silently ignored so persistence problems never block event delivery.
async fn append_to_log(organization_id: &str, record: &EventRecord) {
    let dir: String = event_log_dir();
    if create_dir_all(&dir).await.is_err() {
        return;
    }
    let path: String = org_log_path(organization_id);
    let Ok(mut file) = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)
        .await
    else {
        return;
    };
    // Build the full line (payload + trailing newline) and issue ONE write:
    // with two separate writes, concurrent appenders to the same file could
    // interleave between the payload and its newline, corrupting the JSONL log.
    if let Ok(mut line) = serde_json::to_string(record) {
        line.push('\n');
        let _ = file.write_all(line.as_bytes()).await;
    }
}

/// Push `record` into the organization's replay cache, evicting the oldest
/// entries so the bucket never exceeds [`MAX_EVENTS_PER_ORG`].
async fn cache_record(organization_id: &str, record: &EventRecord) {
    let cache = get_replay_cache();
    let mut buckets = cache.lock().await;
    let bucket = buckets.entry(organization_id.to_string()).or_default();
    bucket.push_back(record.clone());
    // Drop everything past the cap from the front (oldest first).
    let excess = bucket.len().saturating_sub(MAX_EVENTS_PER_ORG);
    if excess > 0 {
        bucket.drain(..excess);
    }
}

/// Load events for an organization since a seq or timestamp.
///
/// Prefers the in-memory replay cache when possible and falls back to the on-disk log.
/// `limit` is clamped to 1 and a maximum of 50_000.
///
/// Filtering semantics (applied identically to cache and disk):
/// - `since_seq`: keep records with `seq` strictly greater than the given seq.
/// - `since_ts_ms`: keep records with `ts_ms` greater than or equal to the
///   given timestamp.
/// - Both filters apply when both are `Some`.
///
/// NOTE(review): the cache result is returned only when it fills `limit`;
/// a partial cache match is discarded and the whole JSONL file is rescanned
/// from the start. That keeps results complete when the cache has evicted
/// older events, at the cost of re-reading disk — and it presumably assumes
/// every cached record was also persisted by `append_to_log`; verify that
/// best-effort persistence failures are acceptable here.
#[must_use]
pub async fn replay_since(
    organization_id: &str,
    since_seq: Option<u64>,
    since_ts_ms: Option<i64>,
    limit: usize,
) -> Vec<EventRecord> {
    let limit: usize = limit.clamp(1, REPLAY_LIMIT_MAX);
    let cache: Arc<Mutex<HashMap<String, VecDeque<EventRecord>>>> = get_replay_cache();
    {
        // Fast path: serve entirely from the in-memory cache when it can
        // satisfy the full `limit`.
        let guard: MutexGuard<'_, HashMap<String, VecDeque<EventRecord>>> = cache.lock().await;
        if let Some(bucket) = guard.get(organization_id) {
            let mut out: Vec<EventRecord> = Vec::new();
            for record in bucket.iter() {
                if let Some(seq) = since_seq
                    && record.seq <= seq
                {
                    continue;
                }
                if let Some(ts) = since_ts_ms
                    && record.ts_ms < ts
                {
                    continue;
                }
                out.push(record.clone());
                if out.len() >= limit {
                    // Cache satisfied the request; skip the disk scan.
                    return out;
                }
            }
        }
    }

    // Disk fallback: scan the JSONL file from the start.
    // This is intended for time travel / catch-up, not for high-frequency reads.
    let path: String = org_log_path(organization_id);
    let Ok(file) = tokio::fs::File::open(&path).await else {
        // No log file (or unreadable): nothing to replay.
        return Vec::new();
    };
    let mut reader: Lines<BufReader<tokio::fs::File>> = BufReader::new(file).lines();
    let mut out: Vec<EventRecord> = Vec::new();
    while let Ok(Some(line)) = reader.next_line().await {
        // Skip lines that do not parse (e.g. torn/corrupted writes).
        let Ok(record) = serde_json::from_str::<EventRecord>(&line) else {
            continue;
        };
        if let Some(seq) = since_seq
            && record.seq <= seq
        {
            continue;
        }
        if let Some(ts) = since_ts_ms
            && record.ts_ms < ts
        {
            continue;
        }
        out.push(record);
        if out.len() >= limit {
            break;
        }
    }
    out
}

/// Post an event to the CDC system (batched).
/// Events are collected and broadcast once per second to WebSocket subscribers.
/// No-ops if the websocket server has not been started (no broadcast channel).
pub async fn post_event(organization_id: String, data: Value) {
    if BROADCAST_TX.get().is_none() {
        return;
    }
    ensure_event_batcher_started().await;
    let record: EventRecord = EventRecord {
        seq: next_seq().await,
        ts_ms: now_ts_ms(),
        payload: data,
    };
    cache_record(&organization_id, &record).await;
    append_to_log(&organization_id, &record).await;

    let buffer: SharedEventBuffer = get_buffer();
    let mut guard: MutexGuard<'_, HashMap<String, Vec<EventRecord>>> = buffer.lock().await;
    guard.entry(organization_id).or_default().push(record);
}