use super::events::{EventMessage, replay_since};
#[allow(deprecated)]
use super::state::SubscriptionRequest;
use super::state::{AppState, ClientControlMessage};
use axum::extract::ws::{Message, WebSocket};
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use serde_json::{Value, from_str, json};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::broadcast::Receiver;
use tokio::sync::{Mutex, MutexGuard};
use tokio::task::JoinHandle;
use tokio::time::{Duration, Interval, interval};
use tracing::warn;
/// Period, in milliseconds, of the ticker that flushes buffered events to the client.
const SEND_THROTTLE_MS: u64 = 1000;
/// Replay limit applied to a `Travel` request that does not carry its own `limit`.
const TRAVEL_LIMIT_DEFAULT: usize = 10_000;
pub async fn handle_socket(socket: WebSocket, state: AppState, client_id: String) {
let (sender, mut receiver) = socket.split();
let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::new(Mutex::new(sender));
let mut rx: Receiver<String> = state.tx.subscribe();
let subscribed_companies: Arc<Mutex<HashSet<String>>> = {
let mut set: HashSet<String> = HashSet::new();
set.insert(client_id.clone());
Arc::new(Mutex::new(set))
};
{
let mut subscribers: MutexGuard<'_, HashSet<String>> =
state.active_subscribers.lock().await;
subscribers.insert(client_id.clone());
}
{
let confirmation: Value = json!({
"status": "subscribed",
"client_id": client_id,
"message": "Subscribed via X-Athena-Client header"
});
let mut sender_guard: MutexGuard<'_, SplitSink<WebSocket, Message>> = sender.lock().await;
if sender_guard
.send(Message::Text(confirmation.to_string()))
.await
.is_err()
{
warn!("Failed to send subscription confirmation to client");
}
}
let subscribed_companies_clone: Arc<Mutex<HashSet<String>>> = subscribed_companies.clone();
let message_buffer: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let message_buffer_clone: Arc<Mutex<Vec<String>>> = message_buffer.clone();
let mut send_task: JoinHandle<()> = tokio::spawn({
let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::clone(&sender);
let subscribed_companies: Arc<Mutex<HashSet<String>>> = subscribed_companies.clone();
async move {
let mut ticker: Interval = interval(Duration::from_millis(SEND_THROTTLE_MS));
loop {
tokio::select! {
msg_result = rx.recv() => {
match msg_result {
Ok(msg) => {
if let Ok(event) = from_str::<EventMessage>(&msg) {
if is_client_subscribed(&subscribed_companies, &event.organization_id).await {
let mut buffer = message_buffer_clone.lock().await;
buffer.push(msg);
}
} else {
warn!("Failed to parse EventMessage from broadcast message");
}
}
Err(_) => break,
}
}
_ = ticker.tick() => {
let mut buffer = message_buffer_clone.lock().await;
if !buffer.is_empty() {
for msg in buffer.drain(..) {
send_message_to_client(&sender, msg).await;
}
}
}
}
}
}
});
let state_for_recv: AppState = state.clone();
let mut recv_task: JoinHandle<()> = tokio::spawn({
let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::clone(&sender);
async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
if let Ok(control) = serde_json::from_str::<ClientControlMessage>(&text) {
let ClientControlMessage::Travel {
organization_id,
since_seq,
since_ts_ms,
limit,
} = control;
{
let limit: usize = limit.unwrap_or(TRAVEL_LIMIT_DEFAULT);
let requested_org: Option<String> =
organization_id.as_ref().and_then(|v| {
let t = v.trim();
if t.is_empty() {
None
} else {
Some(t.to_string())
}
});
let orgs: Vec<String> = if let Some(org) = requested_org.clone() {
vec![org]
} else {
let companies: MutexGuard<'_, HashSet<String>> =
subscribed_companies_clone.lock().await;
companies.iter().cloned().collect::<Vec<_>>()
};
for org in orgs {
let records: Vec<crate::cdc::websocket::events::EventRecord> =
replay_since(&org, since_seq, since_ts_ms, limit).await;
for record in records {
let msg = EventMessage {
organization_id: org.clone(),
data: json!({
"seq": record.seq,
"ts_ms": record.ts_ms,
"payload": record.payload,
}),
};
if let Ok(encoded) = serde_json::to_string(&msg) {
send_message_to_client(&sender, encoded).await;
}
}
}
let ack: Value = json!({
"status": "travel_complete",
"organization_id": requested_org,
"since_seq": since_seq,
"since_ts_ms": since_ts_ms,
"limit": limit,
});
send_message_to_client(&sender, ack.to_string()).await;
}
continue;
}
#[allow(deprecated)]
if let Ok(subscription) = serde_json::from_str::<SubscriptionRequest>(&text) {
handle_subscription_request(
&state_for_recv,
&subscribed_companies_clone,
&sender,
subscription,
)
.await;
} else {
warn!("Failed to parse SubscriptionRequest from client message");
}
}
}
});
tokio::select! {
_ = &mut send_task => {}
_ = &mut recv_task => {}
}
send_task.abort();
recv_task.abort();
{
let mut subscribers: MutexGuard<'_, HashSet<String>> =
state.active_subscribers.lock().await;
subscribers.remove(&client_id);
}
{
let mut buffer: MutexGuard<'_, Vec<String>> = message_buffer.lock().await;
buffer.clear();
}
}
async fn is_client_subscribed(
subscribed_companies: &Arc<Mutex<HashSet<String>>>,
organization_id: &str,
) -> bool {
let companies: MutexGuard<'_, HashSet<String>> = subscribed_companies.lock().await;
companies.contains(organization_id)
}
/// Sends `msg` to the client as a text frame through the shared sink.
///
/// Delivery is best-effort: a failed send is logged with `warn!` and otherwise ignored.
async fn send_message_to_client(sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>, msg: String) {
    let mut sink = sender.lock().await;
    match sink.send(Message::Text(msg)).await {
        Ok(()) => {}
        Err(_) => warn!("Failed to send message to client"),
    }
}
/// Adds the requested organization to this connection's subscriptions and confirms it.
///
/// The organization id is inserted into the per-connection `subscribed_companies` set and
/// into `state.active_subscribers`, after which a `"subscribed"` confirmation is sent to
/// the client. A failed confirmation send is logged and otherwise ignored.
// `SubscriptionRequest` is presumably marked deprecated upstream; the allow keeps this
// legacy path compiling without warnings.
#[allow(deprecated)]
async fn handle_subscription_request(
    state: &AppState,
    subscribed_companies: &Arc<Mutex<HashSet<String>>>,
    sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>,
    subscription: SubscriptionRequest,
) {
    let organization_id = subscription.subscribe_to_organization_id;

    // Each lock is released at the end of its statement, so the two sets are never
    // locked at the same time.
    subscribed_companies
        .lock()
        .await
        .insert(organization_id.clone());
    state
        .active_subscribers
        .lock()
        .await
        .insert(organization_id.clone());

    let payload = json!({
        "status": "subscribed",
        "organization_id": organization_id
    })
    .to_string();

    let mut sink = sender.lock().await;
    let outcome = sink.send(Message::Text(payload)).await;
    if outcome.is_err() {
        warn!("Failed to send subscription confirmation to client");
    }
}