1use crate::{domain::entities::Event, store::EventStore};
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
/// Tuning knobs for [`WebSocketManager`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Capacity of the broadcast channel shared by all connections. A client that falls
    /// further behind than the channel retains is sent a `lagged` notice and skips the
    /// missed events. Must be non-zero: `tokio::sync::broadcast::channel` panics on 0.
    pub capacity: usize,
    /// When `Some(ms)`, matching events are buffered per client and sent as one JSON array
    /// roughly every `ms` milliseconds (sooner once `max_batch_size` is reached).
    /// `None` sends every event as its own text frame.
    pub batch_interval_ms: Option<u64>,
    /// Number of buffered events that triggers an early flush. Only used in batched mode.
    pub max_batch_size: usize,
}

impl Default for WebSocketConfig {
    /// A 1000-event channel, no batching, and a batch limit of 100 events.
    fn default() -> Self {
        Self {
            capacity: 1000,
            batch_interval_ms: None,
            max_batch_size: 100,
        }
    }
}
31
/// Fans out published events to connected WebSocket clients, applying each client's
/// filters before sending.
pub struct WebSocketManager {
    /// Broadcast channel carrying every published event; each connection subscribes its
    /// own receiver.
    event_tx: broadcast::Sender<Arc<Event>>,

    /// Per-connection state (filters), keyed by a random id generated for each socket.
    /// Shared with the connection's spawned send/receive tasks.
    clients: Arc<DashMap<Uuid, ClientInfo>>,

    /// Configuration supplied at construction.
    config: WebSocketConfig,
}
43
/// State stored for one connected client in [`WebSocketManager`]'s client map.
#[derive(Debug, Clone)]
struct ClientInfo {
    /// Connection id; the same value is used as this entry's map key.
    id: Uuid,
    /// Filters applied to live events before they are sent to this client.
    filters: EventFilters,
}
49
/// Client-selectable filters for live events; every filter that is set must match.
///
/// All fields are optional when deserializing and unknown keys are not rejected, so any
/// JSON object parses as an `EventFilters`.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EventFilters {
    /// Only forward events whose entity id equals this value.
    pub entity_id: Option<String>,
    /// Only forward events whose type equals this value exactly.
    pub event_type: Option<String>,
    /// Event-type patterns; an event passes if it matches any of them. A pattern ending in
    /// `.*` matches types that start with the text before `.*` followed by a `.`
    /// (`"scheduler.*"` matches `"scheduler.started"` but not `"scheduler"`); any other
    /// pattern must equal the type exactly. An empty list imposes no restriction.
    #[serde(default)]
    pub event_type_prefixes: Vec<String>,
}
59
/// Client control frame of the form `{"type": "subscribe", "filters": ["scheduler.*", ...]}`,
/// used to replace a client's event-type pattern filters.
#[derive(Debug, serde::Deserialize)]
struct SubscribeMessage {
    /// Message discriminator (JSON key `type`); only `"subscribe"` is acted upon.
    #[serde(rename = "type")]
    msg_type: String,
    /// Event-type patterns, same syntax as `EventFilters::event_type_prefixes`;
    /// an absent key yields an empty list.
    #[serde(default)]
    filters: Vec<String>,
}
69
70impl WebSocketManager {
71 pub fn new() -> Self {
72 Self::with_config(WebSocketConfig::default())
73 }
74
75 pub fn with_config(config: WebSocketConfig) -> Self {
77 let (event_tx, _) = broadcast::channel(config.capacity);
78
79 Self {
80 event_tx,
81 clients: Arc::new(DashMap::new()),
82 config,
83 }
84 }
85
86 #[cfg_attr(feature = "hotpath", hotpath::measure)]
88 pub fn broadcast_event(&self, event: Arc<Event>) {
89 let _ = self.event_tx.send(event);
91 }
92
93 pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
95 self.event_tx.subscribe()
96 }
97
98 #[cfg_attr(feature = "hotpath", hotpath::measure)]
100 pub async fn handle_socket(&self, socket: WebSocket) {
101 self.handle_socket_inner(socket, None, None).await;
102 }
103
104 pub async fn handle_socket_with_consumer(
109 &self,
110 socket: WebSocket,
111 consumer_id: String,
112 store: Arc<EventStore>,
113 ) {
114 self.handle_socket_inner(socket, Some(consumer_id), Some(store))
115 .await;
116 }
117
118 async fn handle_socket_inner(
119 &self,
120 socket: WebSocket,
121 consumer_id: Option<String>,
122 store: Option<Arc<EventStore>>,
123 ) {
124 let client_id = Uuid::new_v4();
125 tracing::info!(
126 "🔌 WebSocket client connected: {} (consumer: {:?})",
127 client_id,
128 consumer_id
129 );
130
131 let event_rx = self.event_tx.subscribe();
133
134 let (mut sender, mut receiver) = socket.split();
136
137 let mut consumer_filters: Vec<String> = Vec::new();
139 if let (Some(cid), Some(store)) = (&consumer_id, &store) {
140 let registry = store.consumer_registry();
141 let consumer = registry.get_or_create(cid);
142 consumer_filters = consumer.event_type_filters.clone();
143 let cursor = consumer.cursor_position.unwrap_or(0);
144
145 let replay_events = store.events_after_offset(cursor, &consumer_filters, usize::MAX);
146
147 tracing::info!(
148 "Replaying {} events for consumer '{}' from offset {}",
149 replay_events.len(),
150 cid,
151 cursor
152 );
153
154 for (position, event) in &replay_events {
155 let dto = serde_json::json!({
156 "type": "replay",
157 "position": position,
158 "event": event,
159 });
160 if let Ok(json) = serde_json::to_string(&dto)
161 && sender.send(Message::Text(json.into())).await.is_err()
162 {
163 tracing::warn!("Failed to send replay event to client {}", client_id);
164 return;
165 }
166 }
167
168 let sentinel =
170 serde_json::json!({"type": "replay_complete", "replayed": replay_events.len()});
171 if let Ok(json) = serde_json::to_string(&sentinel) {
172 let _ = sender.send(Message::Text(json.into())).await;
173 }
174 }
175
176 let initial_filters = if consumer_filters.is_empty() {
178 EventFilters::default()
179 } else {
180 EventFilters {
181 event_type_prefixes: consumer_filters,
182 ..Default::default()
183 }
184 };
185
186 self.clients.insert(
187 client_id,
188 ClientInfo {
189 id: client_id,
190 filters: initial_filters,
191 },
192 );
193
194 let clients = Arc::clone(&self.clients);
196 let config = self.config.clone();
197 let send_task = tokio::spawn(async move {
198 if let Some(interval_ms) = config.batch_interval_ms {
199 Self::send_batched(
200 event_rx,
201 sender,
202 clients,
203 client_id,
204 interval_ms,
205 config.max_batch_size,
206 )
207 .await;
208 } else {
209 Self::send_unbatched(event_rx, sender, clients, client_id).await;
210 }
211 });
212
213 let clients = Arc::clone(&self.clients);
215 let recv_task = tokio::spawn(async move {
216 while let Some(Ok(msg)) = receiver.next().await {
217 if let Message::Text(text) = msg {
218 let text_str = text.as_str();
219 if let Ok(sub) = serde_json::from_str::<SubscribeMessage>(text_str)
221 && sub.msg_type == "subscribe"
222 {
223 tracing::info!(
224 "Setting prefix filters for client {}: {:?}",
225 client_id,
226 sub.filters
227 );
228 if let Some(mut client) = clients.get_mut(&client_id) {
229 client.filters.event_type_prefixes = sub.filters;
230 client.filters.event_type = None;
232 }
233 continue;
234 }
235 if let Ok(filters) = serde_json::from_str::<EventFilters>(text_str) {
237 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
238 if let Some(mut client) = clients.get_mut(&client_id) {
239 client.filters = filters;
240 }
241 }
242 }
243 }
244 });
245
246 tokio::select! {
248 _ = send_task => {
249 tracing::info!("Send task ended for client {}", client_id);
250 }
251 _ = recv_task => {
252 tracing::info!("Receive task ended for client {}", client_id);
253 }
254 }
255
256 self.clients.remove(&client_id);
258 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
259 }
260
261 async fn send_unbatched(
263 mut event_rx: broadcast::Receiver<Arc<Event>>,
264 mut sender: futures::stream::SplitSink<WebSocket, Message>,
265 clients: Arc<DashMap<Uuid, ClientInfo>>,
266 client_id: Uuid,
267 ) {
268 loop {
269 match event_rx.recv().await {
270 Ok(event) => {
271 if !Self::passes_filters(&clients, client_id, &event) {
272 continue;
273 }
274
275 match serde_json::to_string(&*event) {
276 Ok(json) => {
277 if sender.send(Message::Text(json.into())).await.is_err() {
278 tracing::warn!("Failed to send event to client {}", client_id);
279 break;
280 }
281 }
282 Err(e) => {
283 tracing::error!("Failed to serialize event: {}", e);
284 }
285 }
286 }
287 Err(broadcast::error::RecvError::Lagged(n)) => {
288 let msg = serde_json::json!({"type": "lagged", "missed": n});
289 let _ = sender.send(Message::Text(msg.to_string().into())).await;
290 tracing::warn!("Client {} lagged, missed {} events", client_id, n);
291 }
292 Err(broadcast::error::RecvError::Closed) => break,
293 }
294 }
295 }
296
297 async fn send_batched(
299 mut event_rx: broadcast::Receiver<Arc<Event>>,
300 mut sender: futures::stream::SplitSink<WebSocket, Message>,
301 clients: Arc<DashMap<Uuid, ClientInfo>>,
302 client_id: Uuid,
303 interval_ms: u64,
304 max_batch_size: usize,
305 ) {
306 let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
307 let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
308 ticker.tick().await; loop {
311 tokio::select! {
312 result = event_rx.recv() => {
313 match result {
314 Ok(event) => {
315 if !Self::passes_filters(&clients, client_id, &event) {
316 continue;
317 }
318
319 if let Ok(val) = serde_json::to_value(&*event) {
320 batch.push(val);
321 }
322
323 if batch.len() >= max_batch_size
325 && !Self::flush_batch(&mut sender, &mut batch, client_id).await
326 {
327 break;
328 }
329 }
330 Err(broadcast::error::RecvError::Lagged(n)) => {
331 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
333 let msg = serde_json::json!({"type": "lagged", "missed": n});
334 let _ = sender
335 .send(Message::Text(msg.to_string().into()))
336 .await;
337 tracing::warn!(
338 "Client {} lagged, missed {} events",
339 client_id,
340 n
341 );
342 }
343 Err(broadcast::error::RecvError::Closed) => {
344 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
346 break;
347 }
348 }
349 }
350 _ = ticker.tick() => {
351 if !batch.is_empty()
352 && !Self::flush_batch(&mut sender, &mut batch, client_id).await
353 {
354 break;
355 }
356 }
357 }
358 }
359 }
360
361 async fn flush_batch(
363 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
364 batch: &mut Vec<serde_json::Value>,
365 client_id: Uuid,
366 ) -> bool {
367 if batch.is_empty() {
368 return true;
369 }
370
371 let json_array = serde_json::Value::Array(std::mem::take(batch));
372 match serde_json::to_string(&json_array) {
373 Ok(json) => {
374 if sender.send(Message::Text(json.into())).await.is_err() {
375 tracing::warn!("Failed to send batch to client {}", client_id);
376 return false;
377 }
378 true
379 }
380 Err(e) => {
381 tracing::error!("Failed to serialize batch: {}", e);
382 batch.clear();
383 true
384 }
385 }
386 }
387
388 fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
390 let filters = clients
391 .get(&client_id)
392 .map(|entry| entry.value().filters.clone())
393 .unwrap_or_default();
394
395 if let Some(ref entity_id) = filters.entity_id
396 && event.entity_id_str() != entity_id
397 {
398 return false;
399 }
400
401 if let Some(ref event_type) = filters.event_type
403 && event.event_type_str() != event_type
404 {
405 return false;
406 }
407
408 if !filters.event_type_prefixes.is_empty() {
410 let event_type = event.event_type_str();
411 let matches = filters.event_type_prefixes.iter().any(|pattern| {
412 if let Some(prefix) = pattern.strip_suffix(".*") {
413 event_type.starts_with(prefix)
414 && event_type.as_bytes().get(prefix.len()) == Some(&b'.')
415 } else {
416 event_type == pattern
417 }
418 });
419 if !matches {
420 return false;
421 }
422 }
423
424 true
425 }
426
427 pub fn stats(&self) -> WebSocketStats {
429 WebSocketStats {
430 connected_clients: self.clients.len(),
431 total_capacity: self.event_tx.receiver_count(),
432 }
433 }
434}
435
436impl Default for WebSocketManager {
437 fn default() -> Self {
438 Self::new()
439 }
440}
441
/// Snapshot of connection metrics produced by `WebSocketManager::stats`.
#[derive(Debug, serde::Serialize)]
pub struct WebSocketStats {
    /// Number of clients currently registered in the manager's client map.
    pub connected_clients: usize,
    /// NOTE(review): despite its name, `stats()` fills this from
    /// `broadcast::Sender::receiver_count()`. That is the number of live receivers
    /// (per-connection subscriptions plus any `subscribe_events` callers), not the channel
    /// capacity. The name is presumably kept because it is part of the serialized output.
    pub total_capacity: usize,
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `test.event` for entity `test-entity` with a small payload.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Builds an event with the given type and entity id and an empty payload.
    fn event_with(event_type: &str, entity_id: &str) -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            event_type.to_string(),
            entity_id.to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Filters that restrict only by event-type pattern.
    fn pattern_filters(patterns: &[&str]) -> EventFilters {
        EventFilters {
            event_type_prefixes: patterns.iter().map(|p| p.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Creates a manager with a single registered client using `filters`.
    fn manager_with_client(filters: EventFilters) -> (WebSocketManager, Uuid) {
        let manager = WebSocketManager::new();
        let client_id = Uuid::new_v4();
        manager.clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters,
            },
        );
        (manager, client_id)
    }

    /// Shorthand for `WebSocketManager::passes_filters` against a manager's client map.
    fn passes(manager: &WebSocketManager, client_id: Uuid, event: &Event) -> bool {
        WebSocketManager::passes_filters(&manager.clients, client_id, event)
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::default();
        assert_eq!(config.capacity, 1000);
        assert_eq!(config.batch_interval_ms, None);
        assert_eq!(config.max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        let config = WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        };
        let manager = WebSocketManager::with_config(config);

        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        // A capacity-2 channel retains only the last two of five events: three were skipped.
        match rx.try_recv() {
            Err(broadcast::error::TryRecvError::Lagged(n)) => assert_eq!(n, 3),
            Err(e) => panic!("unexpected error: {e:?}"),
            Ok(_) => panic!("expected the receiver to report lag"),
        }
    }

    #[test]
    fn test_batch_mode_groups_events() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // A flushed batch is a JSON array of individually serialized events.
        let events: Vec<serde_json::Value> = (0..3)
            .map(|_| serde_json::to_value(create_test_event()).unwrap())
            .collect();
        let serialized = serde_json::to_string(&events).unwrap();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(parsed.len(), 3);
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000),
            max_batch_size: 5,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_prefix_filter_matching() {
        let (manager, client_id) = manager_with_client(pattern_filters(&["scheduler.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.started", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e2")));
    }

    #[test]
    fn test_prefix_filter_requires_dot_boundary() {
        let (manager, client_id) = manager_with_client(pattern_filters(&["scheduler.*"]));

        assert!(!passes(&manager, client_id, &event_with("scheduler", "e1")));
        assert!(!passes(&manager, client_id, &event_with("schedulerx.started", "e1")));
    }

    #[test]
    fn test_prefix_filter_multiple() {
        let (manager, client_id) =
            manager_with_client(pattern_filters(&["scheduler.*", "index.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.completed", "e1")));
        assert!(passes(&manager, client_id, &event_with("index.created", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e1")));
    }

    #[test]
    fn test_exact_pattern_matches_only_exact_type() {
        let (manager, client_id) = manager_with_client(pattern_filters(&["trade.executed"]));

        assert!(passes(&manager, client_id, &event_with("trade.executed", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed.v2", "e1")));
    }

    #[test]
    fn test_entity_id_filter() {
        let (manager, client_id) = manager_with_client(EventFilters {
            entity_id: Some("e1".to_string()),
            ..Default::default()
        });

        assert!(passes(&manager, client_id, &event_with("trade.executed", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e2")));
    }

    #[test]
    fn test_event_type_filter() {
        let (manager, client_id) = manager_with_client(EventFilters {
            event_type: Some("trade.executed".to_string()),
            ..Default::default()
        });

        assert!(passes(&manager, client_id, &event_with("trade.executed", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.cancelled", "e1")));
    }

    #[test]
    fn test_no_prefix_filters_matches_all() {
        let (manager, client_id) = manager_with_client(EventFilters::default());

        let event = create_test_event();
        assert!(passes(&manager, client_id, &event));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}"#;
        let msg: SubscribeMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.msg_type, "subscribe");
        assert_eq!(msg.filters, vec!["scheduler.*", "index.*"]);
    }

    #[test]
    fn test_subscribe_message_without_filters_defaults_to_empty() {
        let msg: SubscribeMessage = serde_json::from_str(r#"{"type": "subscribe"}"#).unwrap();
        assert_eq!(msg.msg_type, "subscribe");
        assert!(msg.filters.is_empty());
    }

    #[test]
    fn test_event_filters_deserialize_defaults() {
        let filters: EventFilters = serde_json::from_str("{}").unwrap();
        assert!(filters.entity_id.is_none());
        assert!(filters.event_type.is_none());
        assert!(filters.event_type_prefixes.is_empty());
    }

    #[test]
    fn test_backward_compat_no_config() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        // Broadcasting with no subscribers must not panic.
        let event = Arc::new(create_test_event());
        manager.broadcast_event(event);

        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }
}