use crate::{domain::entities::Event, store::EventStore};
use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::{sink::SinkExt, stream::StreamExt};
use std::sync::Arc;
use tokio::sync::broadcast;
use uuid::Uuid;
/// Tunables for the WebSocket fan-out layer.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Capacity of the internal `tokio::sync::broadcast` channel; slow
    /// clients that fall more than this many events behind receive a
    /// "lagged" notification instead of the missed events.
    pub capacity: usize,
    /// When `Some(ms)`, events are coalesced and flushed as JSON arrays
    /// every `ms` milliseconds; `None` sends each event as its own frame.
    pub batch_interval_ms: Option<u64>,
    /// Maximum events per batch before an early (size-triggered) flush.
    /// Only meaningful when `batch_interval_ms` is set.
    pub max_batch_size: usize,
}
impl Default for WebSocketConfig {
fn default() -> Self {
Self {
capacity: 1000,
batch_interval_ms: None,
max_batch_size: 100,
}
}
}
/// Fans events out to all connected WebSocket clients, applying
/// per-client filters and (optionally) batching.
pub struct WebSocketManager {
    /// Producer side of the broadcast channel; each connection gets its
    /// own `Receiver` via `subscribe`.
    event_tx: broadcast::Sender<Arc<Event>>,
    /// Live connections keyed by client id; shared with the per-client
    /// send/receive tasks.
    clients: Arc<DashMap<Uuid, ClientInfo>>,
    /// Behavior knobs captured at construction time.
    config: WebSocketConfig,
}
/// Per-connection state stored in the client registry.
#[derive(Debug, Clone)]
struct ClientInfo {
    // Same value as the DashMap key; kept on the record for convenience.
    id: Uuid,
    // Current filter set; mutated in place when the client sends a
    // subscribe/filter message.
    filters: EventFilters,
}
/// Client-supplied event filters. All fields optional; an empty/default
/// filter set matches every event.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EventFilters {
    /// Exact-match filter on the event's entity id.
    pub entity_id: Option<String>,
    /// Exact-match filter on the event type string.
    pub event_type: Option<String>,
    /// Prefix patterns such as `"scheduler.*"` (or exact type names);
    /// when non-empty, at least one pattern must match.
    #[serde(default)]
    pub event_type_prefixes: Vec<String>,
}
/// Incoming control message: `{"type": "subscribe", "filters": [...]}`.
#[derive(Debug, serde::Deserialize)]
struct SubscribeMessage {
    // The JSON field is named "type", which is a Rust keyword.
    #[serde(rename = "type")]
    msg_type: String,
    // Prefix patterns to install as the client's filters.
    #[serde(default)]
    filters: Vec<String>,
}
impl WebSocketManager {
pub fn new() -> Self {
Self::with_config(WebSocketConfig::default())
}
pub fn with_config(config: WebSocketConfig) -> Self {
let (event_tx, _) = broadcast::channel(config.capacity);
Self {
event_tx,
clients: Arc::new(DashMap::new()),
config,
}
}
#[cfg_attr(feature = "hotpath", hotpath::measure)]
pub fn broadcast_event(&self, event: Arc<Event>) {
let _ = self.event_tx.send(event);
}
pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
self.event_tx.subscribe()
}
#[cfg_attr(feature = "hotpath", hotpath::measure)]
pub async fn handle_socket(&self, socket: WebSocket) {
self.handle_socket_inner(socket, None, None).await;
}
pub async fn handle_socket_with_consumer(
&self,
socket: WebSocket,
consumer_id: String,
store: Arc<EventStore>,
) {
self.handle_socket_inner(socket, Some(consumer_id), Some(store))
.await;
}
async fn handle_socket_inner(
&self,
socket: WebSocket,
consumer_id: Option<String>,
store: Option<Arc<EventStore>>,
) {
let client_id = Uuid::new_v4();
tracing::info!(
"🔌 WebSocket client connected: {} (consumer: {:?})",
client_id,
consumer_id
);
let event_rx = self.event_tx.subscribe();
let (mut sender, mut receiver) = socket.split();
let mut consumer_filters: Vec<String> = Vec::new();
if let (Some(cid), Some(store)) = (&consumer_id, &store) {
let registry = store.consumer_registry();
let consumer = registry.get_or_create(cid);
consumer_filters = consumer.event_type_filters.clone();
let cursor = consumer.cursor_position.unwrap_or(0);
let replay_events = store.events_after_offset(cursor, &consumer_filters, usize::MAX);
tracing::info!(
"Replaying {} events for consumer '{}' from offset {}",
replay_events.len(),
cid,
cursor
);
for (position, event) in &replay_events {
let dto = serde_json::json!({
"type": "replay",
"position": position,
"event": event,
});
if let Ok(json) = serde_json::to_string(&dto)
&& sender.send(Message::Text(json.into())).await.is_err()
{
tracing::warn!("Failed to send replay event to client {}", client_id);
return;
}
}
let sentinel =
serde_json::json!({"type": "replay_complete", "replayed": replay_events.len()});
if let Ok(json) = serde_json::to_string(&sentinel) {
let _ = sender.send(Message::Text(json.into())).await;
}
}
let initial_filters = if consumer_filters.is_empty() {
EventFilters::default()
} else {
EventFilters {
event_type_prefixes: consumer_filters,
..Default::default()
}
};
self.clients.insert(
client_id,
ClientInfo {
id: client_id,
filters: initial_filters,
},
);
let clients = Arc::clone(&self.clients);
let config = self.config.clone();
let send_task = tokio::spawn(async move {
if let Some(interval_ms) = config.batch_interval_ms {
Self::send_batched(
event_rx,
sender,
clients,
client_id,
interval_ms,
config.max_batch_size,
)
.await;
} else {
Self::send_unbatched(event_rx, sender, clients, client_id).await;
}
});
let clients = Arc::clone(&self.clients);
let recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if let Message::Text(text) = msg {
let text_str = text.as_str();
if let Ok(sub) = serde_json::from_str::<SubscribeMessage>(text_str)
&& sub.msg_type == "subscribe"
{
tracing::info!(
"Setting prefix filters for client {}: {:?}",
client_id,
sub.filters
);
if let Some(mut client) = clients.get_mut(&client_id) {
client.filters.event_type_prefixes = sub.filters;
client.filters.event_type = None;
}
continue;
}
if let Ok(filters) = serde_json::from_str::<EventFilters>(text_str) {
tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
if let Some(mut client) = clients.get_mut(&client_id) {
client.filters = filters;
}
}
}
}
});
tokio::select! {
_ = send_task => {
tracing::info!("Send task ended for client {}", client_id);
}
_ = recv_task => {
tracing::info!("Receive task ended for client {}", client_id);
}
}
self.clients.remove(&client_id);
tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
}
async fn send_unbatched(
mut event_rx: broadcast::Receiver<Arc<Event>>,
mut sender: futures::stream::SplitSink<WebSocket, Message>,
clients: Arc<DashMap<Uuid, ClientInfo>>,
client_id: Uuid,
) {
loop {
match event_rx.recv().await {
Ok(event) => {
if !Self::passes_filters(&clients, client_id, &event) {
continue;
}
match serde_json::to_string(&*event) {
Ok(json) => {
if sender.send(Message::Text(json.into())).await.is_err() {
tracing::warn!("Failed to send event to client {}", client_id);
break;
}
}
Err(e) => {
tracing::error!("Failed to serialize event: {}", e);
}
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
let msg = serde_json::json!({"type": "lagged", "missed": n});
let _ = sender.send(Message::Text(msg.to_string().into())).await;
tracing::warn!("Client {} lagged, missed {} events", client_id, n);
}
Err(broadcast::error::RecvError::Closed) => break,
}
}
}
async fn send_batched(
mut event_rx: broadcast::Receiver<Arc<Event>>,
mut sender: futures::stream::SplitSink<WebSocket, Message>,
clients: Arc<DashMap<Uuid, ClientInfo>>,
client_id: Uuid,
interval_ms: u64,
max_batch_size: usize,
) {
let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
ticker.tick().await;
loop {
tokio::select! {
result = event_rx.recv() => {
match result {
Ok(event) => {
if !Self::passes_filters(&clients, client_id, &event) {
continue;
}
if let Ok(val) = serde_json::to_value(&*event) {
batch.push(val);
}
if batch.len() >= max_batch_size
&& !Self::flush_batch(&mut sender, &mut batch, client_id).await
{
break;
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
let msg = serde_json::json!({"type": "lagged", "missed": n});
let _ = sender
.send(Message::Text(msg.to_string().into()))
.await;
tracing::warn!(
"Client {} lagged, missed {} events",
client_id,
n
);
}
Err(broadcast::error::RecvError::Closed) => {
let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
break;
}
}
}
_ = ticker.tick() => {
if !batch.is_empty()
&& !Self::flush_batch(&mut sender, &mut batch, client_id).await
{
break;
}
}
}
}
}
async fn flush_batch(
sender: &mut futures::stream::SplitSink<WebSocket, Message>,
batch: &mut Vec<serde_json::Value>,
client_id: Uuid,
) -> bool {
if batch.is_empty() {
return true;
}
let json_array = serde_json::Value::Array(std::mem::take(batch));
match serde_json::to_string(&json_array) {
Ok(json) => {
if sender.send(Message::Text(json.into())).await.is_err() {
tracing::warn!("Failed to send batch to client {}", client_id);
return false;
}
true
}
Err(e) => {
tracing::error!("Failed to serialize batch: {}", e);
batch.clear();
true
}
}
}
fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
let filters = clients
.get(&client_id)
.map(|entry| entry.value().filters.clone())
.unwrap_or_default();
if let Some(ref entity_id) = filters.entity_id
&& event.entity_id_str() != entity_id
{
return false;
}
if let Some(ref event_type) = filters.event_type
&& event.event_type_str() != event_type
{
return false;
}
if !filters.event_type_prefixes.is_empty() {
let event_type = event.event_type_str();
let matches = filters.event_type_prefixes.iter().any(|pattern| {
if let Some(prefix) = pattern.strip_suffix(".*") {
event_type.starts_with(prefix)
&& event_type.as_bytes().get(prefix.len()) == Some(&b'.')
} else {
event_type == pattern
}
});
if !matches {
return false;
}
}
true
}
pub fn stats(&self) -> WebSocketStats {
WebSocketStats {
connected_clients: self.clients.len(),
total_capacity: self.event_tx.receiver_count(),
}
}
}
impl Default for WebSocketManager {
fn default() -> Self {
Self::new()
}
}
/// Serializable snapshot of WebSocket connection metrics.
#[derive(Debug, serde::Serialize)]
pub struct WebSocketStats {
    /// Number of entries in the client registry.
    pub connected_clients: usize,
    /// Current broadcast receiver count (despite the name, this is not
    /// the configured channel capacity).
    pub total_capacity: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event ("test.event" on "test-entity") for use
    /// across the tests below.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    // A fresh manager starts with no registered clients.
    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }

    // Broadcasting with no subscribers must not panic (the send error is
    // swallowed by design).
    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());
        manager.broadcast_event(event);
    }

    // Pins the documented defaults so they are not changed accidentally.
    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::default();
        assert_eq!(config.capacity, 1000);
        assert_eq!(config.batch_interval_ms, None);
        assert_eq!(config.max_batch_size, 100);
    }

    // Overflowing a capacity-2 channel should surface as a Lagged error
    // on the receiver.
    #[test]
    fn test_lagged_notification() {
        let config = WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        };
        let manager = WebSocketManager::with_config(config);
        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }
        match rx.try_recv() {
            Err(broadcast::error::TryRecvError::Lagged(n)) => {
                assert!(n > 0, "should report missed events");
            }
            // Some broadcast implementations may deliver the oldest
            // retained event first; that is acceptable here.
            Ok(_) => {
            }
            Err(e) => {
                panic!("unexpected error: {e:?}");
            }
        }
    }

    // Checks the batch wire format: a JSON array round-trips with the
    // same number of elements.
    #[test]
    fn test_batch_mode_groups_events() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config.clone());
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let events: Vec<serde_json::Value> = (0..3)
                .map(|_| serde_json::to_value(create_test_event()).unwrap())
                .collect();
            let json_array = serde_json::Value::Array(events);
            let serialized = serde_json::to_string(&json_array).unwrap();
            let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
            assert_eq!(parsed.len(), 3);
        });
    }

    // Construction preserves a custom max_batch_size.
    #[test]
    fn test_batch_flush_on_max_size() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000), max_batch_size: 5, };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.max_batch_size, 5);
    }

    // "scheduler.*" must match "scheduler.started" and reject
    // "trade.executed".
    #[test]
    fn test_prefix_filter_matching() {
        let manager = WebSocketManager::new();
        let client_id = Uuid::new_v4();
        manager.clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters: EventFilters {
                    entity_id: None,
                    event_type: None,
                    event_type_prefixes: vec!["scheduler.*".to_string()],
                },
            },
        );
        let matching = Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "scheduler.started".to_string(),
            "e1".to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        );
        assert!(WebSocketManager::passes_filters(
            &manager.clients,
            client_id,
            &matching
        ));
        let non_matching = Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "trade.executed".to_string(),
            "e2".to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        );
        assert!(!WebSocketManager::passes_filters(
            &manager.clients,
            client_id,
            &non_matching
        ));
    }

    // Multiple prefixes are OR-ed: either "scheduler.*" or "index.*"
    // passes; anything else is rejected.
    #[test]
    fn test_prefix_filter_multiple() {
        let manager = WebSocketManager::new();
        let client_id = Uuid::new_v4();
        manager.clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters: EventFilters {
                    entity_id: None,
                    event_type: None,
                    event_type_prefixes: vec!["scheduler.*".to_string(), "index.*".to_string()],
                },
            },
        );
        let scheduler_event = Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "scheduler.completed".to_string(),
            "e1".to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        );
        assert!(WebSocketManager::passes_filters(
            &manager.clients,
            client_id,
            &scheduler_event
        ));
        let index_event = Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "index.created".to_string(),
            "e1".to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        );
        assert!(WebSocketManager::passes_filters(
            &manager.clients,
            client_id,
            &index_event
        ));
        let trade_event = Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "trade.executed".to_string(),
            "e1".to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        );
        assert!(!WebSocketManager::passes_filters(
            &manager.clients,
            client_id,
            &trade_event
        ));
    }

    // A default (empty) filter set lets every event through.
    #[test]
    fn test_no_prefix_filters_matches_all() {
        let manager = WebSocketManager::new();
        let client_id = Uuid::new_v4();
        manager.clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters: EventFilters::default(),
            },
        );
        let event = create_test_event();
        assert!(WebSocketManager::passes_filters(
            &manager.clients,
            client_id,
            &event
        ));
    }

    // The "type"-renamed field and defaulted filters deserialize from
    // the documented wire format.
    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}"#;
        let msg: SubscribeMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.msg_type, "subscribe");
        assert_eq!(msg.filters, vec!["scheduler.*", "index.*"]);
    }

    // new() must behave like the pre-config API: default capacity, no
    // batching, and broadcasting works without subscribers.
    #[test]
    fn test_backward_compat_no_config() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());
        let event = Arc::new(create_test_event());
        manager.broadcast_event(event);
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }
}