use axum::{
Router,
extract::{
Json, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
http::HeaderMap,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
};
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::{Value, from_str, json};
use std::error::Error as stdError;
use std::{collections::HashSet, net::SocketAddr, sync::Arc};
use tokio::net::TcpListener;
use tokio::sync::broadcast::Receiver;
use tokio::sync::{Mutex, MutexGuard, broadcast};
use tokio::task::JoinHandle;
use tokio::time::{Duration, Interval, interval};
use tracing::{info, warn};
pub mod events;
use events::{EventMessage, set_broadcast_tx};
/// Sending half of the broadcast channel that fans published events out to every
/// WebSocket connection; payloads are JSON-serialized `EventMessage`s.
type Tx = broadcast::Sender<String>;
/// Legacy client message asking this connection to also receive events for another
/// organization.
///
/// Sent by the client as a JSON text frame over the WebSocket, e.g.
/// `{"subscribe_to_organization_id": "org-123"}`.
#[derive(Debug, Serialize, Deserialize)]
#[deprecated(
since = "0.80.3",
note = "Use X-Athena-Client header when connecting to WebSocket instead"
)]
struct SubscriptionRequest {
/// Organization id to add to this connection's subscription set.
subscribe_to_organization_id: String,
}
/// State shared by every axum handler (cloned per request; the fields are shared handles).
#[derive(Clone)]
struct AppState {
/// Broadcast sender used to fan published events out to WebSocket connections.
tx: Tx,
/// Every id (X-Athena-Client ids and legacy organization ids) that any connection has
/// subscribed to; `publish_event` voids events whose organization id is not in here.
/// NOTE(review): entries are only ever inserted, never removed when connections close.
active_subscribers: Arc<Mutex<HashSet<String>>>,
}
/// Runs the DMS WebSocket/HTTP server on `0.0.0.0:port` until the server stops.
///
/// Routes:
/// - `GET /ws` — WebSocket endpoint (requires the `X-Athena-Client` header)
/// - `POST /events` — publish an event to subscribed WebSocket clients
/// - `GET /status` — health check
///
/// A clone of the broadcast sender is handed to `events::set_broadcast_tx` so other
/// code can publish through the same channel.
///
/// # Errors
/// Returns an error if the port cannot be bound or if the server fails while running.
pub async fn websocket_server(port: u16) -> Result<(), Box<dyn stdError>> {
    // How many events the channel retains; receivers falling further behind lag.
    const BROADCAST_CAPACITY: usize = 1000;

    let (tx, _) = broadcast::channel::<String>(BROADCAST_CAPACITY);
    set_broadcast_tx(tx.clone());
    info!("DMS Broadcast channel created with capacity: {}", BROADCAST_CAPACITY);

    let state: AppState = AppState {
        tx,
        active_subscribers: Arc::new(Mutex::new(HashSet::new())),
    };
    let app: Router = Router::new()
        .route("/ws", get(ws_handler))
        .route("/events", post(publish_event))
        .route("/status", get(status))
        .with_state(state);

    let addr: SocketAddr = SocketAddr::from(([0, 0, 0, 0], port));
    // Bind before announcing anything, so the "running" lines are only logged once
    // the port actually belongs to us.
    let listener: TcpListener = TcpListener::bind(addr).await?;
    info!("Server running at http://{}", addr);
    info!("WebSocket endpoint available at ws://{}/ws", addr);
    info!("Server started successfully");

    // `serve` only returns once the server has stopped.
    axum::serve(listener, app).await?;
    info!("Server stopped");
    Ok(())
}
/// `GET /ws`: upgrades the request to a WebSocket subscribed to the caller's client id.
///
/// The client id comes from the `X-Athena-Client` header (trimmed). A missing, empty,
/// or non-visible-ASCII header is rejected with `400 Bad Request` and a JSON error body.
async fn ws_handler(
    ws: WebSocketUpgrade,
    headers: HeaderMap,
    State(state): State<AppState>,
) -> impl IntoResponse {
    // `HeaderMap` lookups are case-insensitive, so this one lookup matches every
    // spelling of the header name (`X-Athena-Client`, `x-athena-client`, ...).
    let client_id: String = headers
        .get("x-athena-client")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .unwrap_or_default()
        .to_string();

    if client_id.is_empty() {
        return (
            StatusCode::BAD_REQUEST,
            Json(json!({
                "error": "Missing X-Athena-Client header",
                "message": "Pass X-Athena-Client when connecting to subscribe to CDC events"
            })),
        )
            .into_response();
    }

    ws.on_upgrade(move |socket| handle_socket(socket, state, client_id))
}
async fn handle_socket(socket: WebSocket, state: AppState, client_id: String) {
let (sender, mut receiver) = socket.split();
let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::new(Mutex::new(sender));
let mut rx: Receiver<String> = state.tx.subscribe();
let subscribed_companies: Arc<Mutex<HashSet<String>>> = {
let mut set: HashSet<String> = HashSet::new();
set.insert(client_id.clone());
Arc::new(Mutex::new(set))
};
{
let mut subscribers: MutexGuard<'_, HashSet<String>> =
state.active_subscribers.lock().await;
subscribers.insert(client_id.clone());
}
{
let confirmation: Value = json!({
"status": "subscribed",
"client_id": client_id,
"message": "Subscribed via X-Athena-Client header"
});
let mut sender_guard: MutexGuard<'_, SplitSink<WebSocket, Message>> = sender.lock().await;
if sender_guard
.send(Message::Text(confirmation.to_string().into()))
.await
.is_err()
{
warn!("Failed to send subscription confirmation to client");
}
}
let subscribed_companies_clone: Arc<Mutex<HashSet<String>>> = subscribed_companies.clone();
let message_buffer: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let message_buffer_clone: Arc<Mutex<Vec<String>>> = message_buffer.clone();
let mut send_task: JoinHandle<()> = tokio::spawn({
let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::clone(&sender);
let subscribed_companies: Arc<Mutex<HashSet<String>>> = subscribed_companies.clone();
async move {
let mut ticker: Interval = interval(Duration::from_millis(1000));
loop {
tokio::select! {
msg_result = rx.recv() => {
match msg_result {
Ok(msg) => {
if let Ok(event) = from_str::<EventMessage>(&msg) {
if is_client_subscribed(&subscribed_companies, &event.organization_id).await {
let mut buffer = message_buffer_clone.lock().await;
buffer.push(msg);
}
} else {
warn!("Failed to parse EventMessage from broadcast message");
}
}
Err(_) => {
break;
}
}
}
_ = ticker.tick() => {
let mut buffer = message_buffer_clone.lock().await;
if !buffer.is_empty() {
for msg in buffer.drain(..) {
send_message_to_client(&sender, msg).await;
}
}
}
}
}
}
});
let mut recv_task: JoinHandle<()> = tokio::spawn({
let sender: Arc<Mutex<SplitSink<WebSocket, Message>>> = Arc::clone(&sender);
async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
#[allow(deprecated)]
if let Ok(subscription) = serde_json::from_str::<SubscriptionRequest>(&text) {
handle_subscription_request(
&state,
&subscribed_companies_clone,
&sender,
subscription,
)
.await;
} else {
warn!("Failed to parse SubscriptionRequest from client message");
}
}
}
});
tokio::select! {
_ = &mut send_task => {
},
_ = &mut recv_task => {
},
}
}
async fn is_client_subscribed(
subscribed_companies: &Arc<Mutex<HashSet<String>>>,
organization_id: &str,
) -> bool {
let companies: MutexGuard<'_, HashSet<String>> = subscribed_companies.lock().await;
companies.contains(organization_id)
}
/// Writes `msg` to the client as a single text frame.
///
/// A failed write is logged as a warning and otherwise ignored; the sink stays locked
/// until the write attempt completes.
async fn send_message_to_client(sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>, msg: String) {
    let mut sink: MutexGuard<'_, SplitSink<WebSocket, Message>> = sender.lock().await;
    let outcome = sink.send(Message::Text(msg.into())).await;
    if outcome.is_err() {
        warn!("Failed to send message to client");
    }
}
#[allow(deprecated)]
async fn handle_subscription_request(
state: &AppState,
subscribed_companies: &Arc<Mutex<HashSet<String>>>,
sender: &Arc<Mutex<SplitSink<WebSocket, Message>>>,
subscription: SubscriptionRequest,
) {
let organization_id: String = subscription.subscribe_to_organization_id;
{
let mut companies: MutexGuard<'_, HashSet<String>> = subscribed_companies.lock().await;
companies.insert(organization_id.clone());
}
{
let mut subscribers: MutexGuard<'_, HashSet<String>> =
state.active_subscribers.lock().await;
subscribers.insert(organization_id.clone());
}
let confirmation: Value = json!({
"status": "subscribed",
"organization_id": organization_id
});
let mut sender: MutexGuard<'_, SplitSink<WebSocket, Message>> = sender.lock().await;
if sender
.send(Message::Text(confirmation.to_string().into()))
.await
.is_err()
{
warn!("Failed to send subscription confirmation to client");
}
}
async fn publish_event(
State(state): State<AppState>,
Json(event): Json<EventMessage>,
) -> impl IntoResponse {
let organization_id: String = event.organization_id.clone();
let has_subscribers: bool = {
let subscribers: MutexGuard<'_, HashSet<String>> = state.active_subscribers.lock().await;
subscribers.contains(&organization_id)
};
if !has_subscribers {
let response: Value = json!({
"message": "there was no subscriber to this organization_id channel so the message has been voided",
"organization_id": organization_id,
"status": "voided",
"success": true
});
return (StatusCode::OK, Json(response)).into_response();
}
let event_json: String = serde_json::to_string(&event).unwrap_or_default();
if let Err(err) = state.tx.send(event_json) {
warn!("Failed to broadcast event: {}", err);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to broadcast event",
)
.into_response();
}
let response: Value = json!({
"status": "delivered",
"success": true,
"organization_id": organization_id
});
(StatusCode::OK, Json(response)).into_response()
}
/// `GET /status`: liveness probe that always answers `200` with a fixed JSON body.
async fn status() -> impl IntoResponse {
    let body = Json(json!({
        "status": "ok",
        "success": true,
        "server": "dms-server-api"
    }));
    (StatusCode::OK, body).into_response()
}