1use crate::domain::entities::Event;
2use crate::store::EventStore;
3use axum::extract::ws::{Message, WebSocket};
4use dashmap::DashMap;
5use futures::{sink::SinkExt, stream::StreamExt};
6use std::sync::Arc;
7use tokio::sync::broadcast;
8use uuid::Uuid;
9
/// Tuning knobs for [`WebSocketManager`].
#[derive(Clone, Debug)]
pub struct WebSocketConfig {
    /// Number of events the shared broadcast channel buffers. A client that falls further
    /// behind than this misses events and is sent a `{"type": "lagged"}` notice instead.
    /// The value is passed to `tokio::sync::broadcast::channel`, which panics on `0`.
    pub capacity: usize,
    /// `None` sends every live event in its own frame; `Some(ms)` groups live events into
    /// JSON arrays that are flushed on a timer of `ms` milliseconds.
    pub batch_interval_ms: Option<u64>,
    /// In batched mode, a batch is flushed early as soon as it holds this many events.
    pub max_batch_size: usize,
}

impl Default for WebSocketConfig {
    /// A 1000-event broadcast buffer, no batching, and a 100-event batch limit (the limit is
    /// only consulted when batching is enabled).
    fn default() -> Self {
        WebSocketConfig {
            max_batch_size: 100,
            batch_interval_ms: None,
            capacity: 1000,
        }
    }
}
32
33pub struct WebSocketManager {
35 event_tx: broadcast::Sender<Arc<Event>>,
37
38 clients: Arc<DashMap<Uuid, ClientInfo>>,
40
41 config: WebSocketConfig,
43}
44
45#[derive(Debug, Clone)]
46struct ClientInfo {
47 id: Uuid,
48 filters: EventFilters,
49}
50
51#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
52pub struct EventFilters {
53 pub entity_id: Option<String>,
54 pub event_type: Option<String>,
55 #[serde(default)]
58 pub event_type_prefixes: Vec<String>,
59}
60
61#[derive(Debug, serde::Deserialize)]
64struct SubscribeMessage {
65 #[serde(rename = "type")]
66 msg_type: String,
67 #[serde(default)]
68 filters: Vec<String>,
69}
70
71impl WebSocketManager {
72 pub fn new() -> Self {
73 Self::with_config(WebSocketConfig::default())
74 }
75
76 pub fn with_config(config: WebSocketConfig) -> Self {
78 let (event_tx, _) = broadcast::channel(config.capacity);
79
80 Self {
81 event_tx,
82 clients: Arc::new(DashMap::new()),
83 config,
84 }
85 }
86
87 pub fn broadcast_event(&self, event: Arc<Event>) {
89 let _ = self.event_tx.send(event);
91 }
92
93 pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
95 self.event_tx.subscribe()
96 }
97
98 pub async fn handle_socket(&self, socket: WebSocket) {
100 self.handle_socket_inner(socket, None, None).await;
101 }
102
103 pub async fn handle_socket_with_consumer(
108 &self,
109 socket: WebSocket,
110 consumer_id: String,
111 store: Arc<EventStore>,
112 ) {
113 self.handle_socket_inner(socket, Some(consumer_id), Some(store))
114 .await;
115 }
116
117 async fn handle_socket_inner(
118 &self,
119 socket: WebSocket,
120 consumer_id: Option<String>,
121 store: Option<Arc<EventStore>>,
122 ) {
123 let client_id = Uuid::new_v4();
124 tracing::info!(
125 "🔌 WebSocket client connected: {} (consumer: {:?})",
126 client_id,
127 consumer_id
128 );
129
130 let event_rx = self.event_tx.subscribe();
132
133 let (mut sender, mut receiver) = socket.split();
135
136 let mut consumer_filters: Vec<String> = Vec::new();
138 if let (Some(cid), Some(store)) = (&consumer_id, &store) {
139 let registry = store.consumer_registry();
140 let consumer = registry.get_or_create(cid);
141 consumer_filters = consumer.event_type_filters.clone();
142 let cursor = consumer.cursor_position.unwrap_or(0);
143
144 let replay_events =
145 store.events_after_offset(cursor, &consumer_filters, usize::MAX);
146
147 tracing::info!(
148 "Replaying {} events for consumer '{}' from offset {}",
149 replay_events.len(),
150 cid,
151 cursor
152 );
153
154 for (position, event) in &replay_events {
155 let dto = serde_json::json!({
156 "type": "replay",
157 "position": position,
158 "event": event,
159 });
160 if let Ok(json) = serde_json::to_string(&dto)
161 && sender.send(Message::Text(json.into())).await.is_err()
162 {
163 tracing::warn!("Failed to send replay event to client {}", client_id);
164 return;
165 }
166 }
167
168 let sentinel = serde_json::json!({"type": "replay_complete", "replayed": replay_events.len()});
170 if let Ok(json) = serde_json::to_string(&sentinel) {
171 let _ = sender.send(Message::Text(json.into())).await;
172 }
173 }
174
175 let initial_filters = if !consumer_filters.is_empty() {
177 EventFilters {
178 event_type_prefixes: consumer_filters,
179 ..Default::default()
180 }
181 } else {
182 EventFilters::default()
183 };
184
185 self.clients.insert(
186 client_id,
187 ClientInfo {
188 id: client_id,
189 filters: initial_filters,
190 },
191 );
192
193 let clients = Arc::clone(&self.clients);
195 let config = self.config.clone();
196 let send_task = tokio::spawn(async move {
197 if let Some(interval_ms) = config.batch_interval_ms {
198 Self::send_batched(
199 event_rx,
200 sender,
201 clients,
202 client_id,
203 interval_ms,
204 config.max_batch_size,
205 )
206 .await;
207 } else {
208 Self::send_unbatched(event_rx, sender, clients, client_id).await;
209 }
210 });
211
212 let clients = Arc::clone(&self.clients);
214 let recv_task = tokio::spawn(async move {
215 while let Some(Ok(msg)) = receiver.next().await {
216 if let Message::Text(text) = msg {
217 let text_str = text.as_str();
218 if let Ok(sub) = serde_json::from_str::<SubscribeMessage>(text_str)
220 && sub.msg_type == "subscribe"
221 {
222 tracing::info!(
223 "Setting prefix filters for client {}: {:?}",
224 client_id,
225 sub.filters
226 );
227 if let Some(mut client) = clients.get_mut(&client_id) {
228 client.filters.event_type_prefixes = sub.filters;
229 client.filters.event_type = None;
231 }
232 continue;
233 }
234 if let Ok(filters) = serde_json::from_str::<EventFilters>(text_str) {
236 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
237 if let Some(mut client) = clients.get_mut(&client_id) {
238 client.filters = filters;
239 }
240 }
241 }
242 }
243 });
244
245 tokio::select! {
247 _ = send_task => {
248 tracing::info!("Send task ended for client {}", client_id);
249 }
250 _ = recv_task => {
251 tracing::info!("Receive task ended for client {}", client_id);
252 }
253 }
254
255 self.clients.remove(&client_id);
257 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
258 }
259
260 async fn send_unbatched(
262 mut event_rx: broadcast::Receiver<Arc<Event>>,
263 mut sender: futures::stream::SplitSink<WebSocket, Message>,
264 clients: Arc<DashMap<Uuid, ClientInfo>>,
265 client_id: Uuid,
266 ) {
267 loop {
268 match event_rx.recv().await {
269 Ok(event) => {
270 if !Self::passes_filters(&clients, client_id, &event) {
271 continue;
272 }
273
274 match serde_json::to_string(&*event) {
275 Ok(json) => {
276 if sender.send(Message::Text(json.into())).await.is_err() {
277 tracing::warn!("Failed to send event to client {}", client_id);
278 break;
279 }
280 }
281 Err(e) => {
282 tracing::error!("Failed to serialize event: {}", e);
283 }
284 }
285 }
286 Err(broadcast::error::RecvError::Lagged(n)) => {
287 let msg = serde_json::json!({"type": "lagged", "missed": n});
288 let _ = sender.send(Message::Text(msg.to_string().into())).await;
289 tracing::warn!("Client {} lagged, missed {} events", client_id, n);
290 }
291 Err(broadcast::error::RecvError::Closed) => break,
292 }
293 }
294 }
295
296 async fn send_batched(
298 mut event_rx: broadcast::Receiver<Arc<Event>>,
299 mut sender: futures::stream::SplitSink<WebSocket, Message>,
300 clients: Arc<DashMap<Uuid, ClientInfo>>,
301 client_id: Uuid,
302 interval_ms: u64,
303 max_batch_size: usize,
304 ) {
305 let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
306 let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
307 ticker.tick().await; loop {
310 tokio::select! {
311 result = event_rx.recv() => {
312 match result {
313 Ok(event) => {
314 if !Self::passes_filters(&clients, client_id, &event) {
315 continue;
316 }
317
318 if let Ok(val) = serde_json::to_value(&*event) {
319 batch.push(val);
320 }
321
322 if batch.len() >= max_batch_size
324 && !Self::flush_batch(&mut sender, &mut batch, client_id).await
325 {
326 break;
327 }
328 }
329 Err(broadcast::error::RecvError::Lagged(n)) => {
330 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
332 let msg = serde_json::json!({"type": "lagged", "missed": n});
333 let _ = sender
334 .send(Message::Text(msg.to_string().into()))
335 .await;
336 tracing::warn!(
337 "Client {} lagged, missed {} events",
338 client_id,
339 n
340 );
341 }
342 Err(broadcast::error::RecvError::Closed) => {
343 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
345 break;
346 }
347 }
348 }
349 _ = ticker.tick() => {
350 if !batch.is_empty()
351 && !Self::flush_batch(&mut sender, &mut batch, client_id).await
352 {
353 break;
354 }
355 }
356 }
357 }
358 }
359
360 async fn flush_batch(
362 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
363 batch: &mut Vec<serde_json::Value>,
364 client_id: Uuid,
365 ) -> bool {
366 if batch.is_empty() {
367 return true;
368 }
369
370 let json_array = serde_json::Value::Array(std::mem::take(batch));
371 match serde_json::to_string(&json_array) {
372 Ok(json) => {
373 if sender.send(Message::Text(json.into())).await.is_err() {
374 tracing::warn!("Failed to send batch to client {}", client_id);
375 return false;
376 }
377 true
378 }
379 Err(e) => {
380 tracing::error!("Failed to serialize batch: {}", e);
381 batch.clear();
382 true
383 }
384 }
385 }
386
387 fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
389 let filters = clients
390 .get(&client_id)
391 .map(|entry| entry.value().filters.clone())
392 .unwrap_or_default();
393
394 if let Some(ref entity_id) = filters.entity_id
395 && event.entity_id_str() != entity_id
396 {
397 return false;
398 }
399
400 if let Some(ref event_type) = filters.event_type
402 && event.event_type_str() != event_type
403 {
404 return false;
405 }
406
407 if !filters.event_type_prefixes.is_empty() {
409 let event_type = event.event_type_str();
410 let matches = filters.event_type_prefixes.iter().any(|pattern| {
411 if let Some(prefix) = pattern.strip_suffix(".*") {
412 event_type.starts_with(prefix)
413 && event_type.as_bytes().get(prefix.len()) == Some(&b'.')
414 } else {
415 event_type == pattern
416 }
417 });
418 if !matches {
419 return false;
420 }
421 }
422
423 true
424 }
425
426 pub fn stats(&self) -> WebSocketStats {
428 WebSocketStats {
429 connected_clients: self.clients.len(),
430 total_capacity: self.event_tx.receiver_count(),
431 }
432 }
433}
434
435impl Default for WebSocketManager {
436 fn default() -> Self {
437 Self::new()
438 }
439}
440
441#[derive(Debug, serde::Serialize)]
442pub struct WebSocketStats {
443 pub connected_clients: usize,
444 pub total_capacity: usize,
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Event with the given type and entity id and an empty payload.
    fn event_with(event_type: &str, entity_id: &str) -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            event_type.to_string(),
            entity_id.to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Registers a fresh client with `filters` and returns its id.
    fn register_client(manager: &WebSocketManager, filters: EventFilters) -> Uuid {
        let client_id = Uuid::new_v4();
        manager.clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters,
            },
        );
        client_id
    }

    /// Filters that only set `event_type_prefixes`.
    fn prefix_filters(patterns: &[&str]) -> EventFilters {
        EventFilters {
            event_type_prefixes: patterns.iter().map(|p| (*p).to_string()).collect(),
            ..Default::default()
        }
    }

    fn passes(manager: &WebSocketManager, client_id: Uuid, event: &Event) -> bool {
        WebSocketManager::passes_filters(&manager.clients, client_id, event)
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_broadcast_without_subscribers_is_noop() {
        let manager = WebSocketManager::new();
        // Must not panic even though the send fails for lack of receivers.
        manager.broadcast_event(Arc::new(create_test_event()));
    }

    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::default();
        assert_eq!(config.capacity, 1000);
        assert_eq!(config.batch_interval_ms, None);
        assert_eq!(config.max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        let config = WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        };
        let manager = WebSocketManager::with_config(config);

        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        // Only the last 2 of the 5 events fit in the buffer, so 3 were overwritten unread.
        match rx.try_recv() {
            Err(broadcast::error::TryRecvError::Lagged(n)) => assert_eq!(n, 3),
            Err(e) => panic!("unexpected error: {:?}", e),
            Ok(_) => panic!("expected the receiver to report lag"),
        }
        // After the lag report the two retained events are still readable.
        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_ok());
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_batch_mode_groups_events() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // A batch is sent as a JSON array of serialized events.
        let events: Vec<serde_json::Value> = (0..3)
            .map(|_| serde_json::to_value(create_test_event()).unwrap())
            .collect();
        let serialized = serde_json::to_string(&serde_json::Value::Array(events)).unwrap();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(parsed.len(), 3);
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000),
            max_batch_size: 5,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_prefix_filter_matching() {
        let manager = WebSocketManager::new();
        let client_id = register_client(&manager, prefix_filters(&["scheduler.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.started", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e2")));
    }

    #[test]
    fn test_prefix_filter_multiple() {
        let manager = WebSocketManager::new();
        let client_id = register_client(&manager, prefix_filters(&["scheduler.*", "index.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.completed", "e1")));
        assert!(passes(&manager, client_id, &event_with("index.created", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e1")));
    }

    #[test]
    fn test_prefix_filter_requires_dot_boundary() {
        let manager = WebSocketManager::new();
        let client_id = register_client(&manager, prefix_filters(&["scheduler.*"]));

        assert!(!passes(&manager, client_id, &event_with("schedulerx.started", "e1")));
        assert!(!passes(&manager, client_id, &event_with("scheduler", "e1")));
        assert!(passes(&manager, client_id, &event_with("scheduler.job.started", "e1")));
    }

    #[test]
    fn test_pattern_without_wildcard_is_exact() {
        let manager = WebSocketManager::new();
        let client_id = register_client(&manager, prefix_filters(&["trade.executed"]));

        assert!(passes(&manager, client_id, &event_with("trade.executed", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed.partial", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade", "e1")));
    }

    #[test]
    fn test_entity_and_event_type_filters() {
        let manager = WebSocketManager::new();
        let client_id = register_client(
            &manager,
            EventFilters {
                entity_id: Some("e1".to_string()),
                event_type: Some("trade.executed".to_string()),
                event_type_prefixes: Vec::new(),
            },
        );

        assert!(passes(&manager, client_id, &event_with("trade.executed", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e2")));
        assert!(!passes(&manager, client_id, &event_with("trade.cancelled", "e1")));
    }

    #[test]
    fn test_unregistered_client_receives_everything() {
        let manager = WebSocketManager::new();
        assert!(passes(&manager, Uuid::new_v4(), &create_test_event()));
    }

    #[test]
    fn test_no_prefix_filters_matches_all() {
        let manager = WebSocketManager::new();
        let client_id = register_client(&manager, EventFilters::default());

        assert!(passes(&manager, client_id, &create_test_event()));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}"#;
        let msg: SubscribeMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.msg_type, "subscribe");
        assert_eq!(msg.filters, vec!["scheduler.*", "index.*"]);

        let bare: SubscribeMessage = serde_json::from_str(r#"{"type": "subscribe"}"#).unwrap();
        assert!(bare.filters.is_empty());
    }

    #[test]
    fn test_event_filters_deserialize_defaults() {
        let empty: EventFilters = serde_json::from_str("{}").unwrap();
        assert!(empty.entity_id.is_none());
        assert!(empty.event_type.is_none());
        assert!(empty.event_type_prefixes.is_empty());

        let prefixes: EventFilters =
            serde_json::from_str(r#"{"event_type_prefixes": ["a.*"]}"#).unwrap();
        assert_eq!(prefixes.event_type_prefixes, vec!["a.*"]);
        assert!(prefixes.entity_id.is_none());
    }

    #[test]
    fn test_backward_compat_no_config() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        manager.broadcast_event(Arc::new(create_test_event()));

        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }
}