//! Real-time event streaming over WebSockets (`allsource_core/infrastructure/web/websocket.rs`).

1use crate::domain::entities::Event;
2use crate::store::EventStore;
3use axum::extract::ws::{Message, WebSocket};
4use dashmap::DashMap;
5use futures::{sink::SinkExt, stream::StreamExt};
6use std::sync::Arc;
7use tokio::sync::broadcast;
8use uuid::Uuid;
9
/// Backpressure and batching settings for the WebSocket manager.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Capacity of the single broadcast channel every client subscribes to (default: 1000).
    pub capacity: usize,
    /// Batching interval in milliseconds.
    ///
    /// `None` (the default) sends one message per event, matching the behavior from
    /// before batching existed. `Some(ms)` buffers events and flushes them as JSON
    /// arrays every `ms` milliseconds, e.g. `Some(50)` flushes every 50ms.
    pub batch_interval_ms: Option<u64>,
    /// Pending-batch length that forces an early flush (default: 100).
    pub max_batch_size: usize,
}

impl Default for WebSocketConfig {
    /// Unbatched delivery over a 1000-slot channel with a 100-event batch cap.
    fn default() -> Self {
        let capacity = 1000;
        let max_batch_size = 100;
        Self {
            capacity,
            batch_interval_ms: None,
            max_batch_size,
        }
    }
}

33/// WebSocket manager for real-time event streaming (v0.2 feature)
34pub struct WebSocketManager {
35    /// Broadcast channel for sending events to all connected clients
36    event_tx: broadcast::Sender<Arc<Event>>,
37
38    /// Connected clients by ID - using DashMap for lock-free concurrent access
39    clients: Arc<DashMap<Uuid, ClientInfo>>,
40
41    /// Backpressure and batching configuration
42    config: WebSocketConfig,
43}
44
45#[derive(Debug, Clone)]
46struct ClientInfo {
47    id: Uuid,
48    filters: EventFilters,
49}
50
51#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
52pub struct EventFilters {
53    pub entity_id: Option<String>,
54    pub event_type: Option<String>,
55    /// Prefix-based event type filters (e.g. ["scheduler.*", "index.*"]).
56    /// If non-empty, only events matching at least one prefix are delivered.
57    #[serde(default)]
58    pub event_type_prefixes: Vec<String>,
59}
60
61/// Client message for setting prefix-based subscription filters.
62/// Sent as: `{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}`
63#[derive(Debug, serde::Deserialize)]
64struct SubscribeMessage {
65    #[serde(rename = "type")]
66    msg_type: String,
67    #[serde(default)]
68    filters: Vec<String>,
69}
70
71impl WebSocketManager {
72    pub fn new() -> Self {
73        Self::with_config(WebSocketConfig::default())
74    }
75
76    /// Create a WebSocket manager with custom backpressure configuration.
77    pub fn with_config(config: WebSocketConfig) -> Self {
78        let (event_tx, _) = broadcast::channel(config.capacity);
79
80        Self {
81            event_tx,
82            clients: Arc::new(DashMap::new()),
83            config,
84        }
85    }
86
87    /// Broadcast an event to all connected WebSocket clients
88    pub fn broadcast_event(&self, event: Arc<Event>) {
89        // Send to broadcast channel (non-blocking)
90        let _ = self.event_tx.send(event);
91    }
92
93    /// Subscribe to the event broadcast channel (used by RESP3 SUBSCRIBE).
94    pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
95        self.event_tx.subscribe()
96    }
97
98    /// Handle a new WebSocket connection (fire-and-forget, no consumer tracking)
99    pub async fn handle_socket(&self, socket: WebSocket) {
100        self.handle_socket_inner(socket, None, None).await;
101    }
102
103    /// Handle a WebSocket connection with a durable consumer for auto-replay.
104    ///
105    /// Replays all events since the consumer's last acked position, then switches
106    /// to real-time delivery. The consumer's event_type_filters are applied during replay.
107    pub async fn handle_socket_with_consumer(
108        &self,
109        socket: WebSocket,
110        consumer_id: String,
111        store: Arc<EventStore>,
112    ) {
113        self.handle_socket_inner(socket, Some(consumer_id), Some(store))
114            .await;
115    }
116
117    async fn handle_socket_inner(
118        &self,
119        socket: WebSocket,
120        consumer_id: Option<String>,
121        store: Option<Arc<EventStore>>,
122    ) {
123        let client_id = Uuid::new_v4();
124        tracing::info!(
125            "🔌 WebSocket client connected: {} (consumer: {:?})",
126            client_id,
127            consumer_id
128        );
129
130        // Subscribe to broadcast channel BEFORE replay so we don't miss events
131        let event_rx = self.event_tx.subscribe();
132
133        // Split socket into sender and receiver
134        let (mut sender, mut receiver) = socket.split();
135
136        // Replay missed events for durable consumers
137        let mut consumer_filters: Vec<String> = Vec::new();
138        if let (Some(cid), Some(store)) = (&consumer_id, &store) {
139            let registry = store.consumer_registry();
140            let consumer = registry.get_or_create(cid);
141            consumer_filters = consumer.event_type_filters.clone();
142            let cursor = consumer.cursor_position.unwrap_or(0);
143
144            let replay_events =
145                store.events_after_offset(cursor, &consumer_filters, usize::MAX);
146
147            tracing::info!(
148                "Replaying {} events for consumer '{}' from offset {}",
149                replay_events.len(),
150                cid,
151                cursor
152            );
153
154            for (position, event) in &replay_events {
155                let dto = serde_json::json!({
156                    "type": "replay",
157                    "position": position,
158                    "event": event,
159                });
160                if let Ok(json) = serde_json::to_string(&dto)
161                    && sender.send(Message::Text(json.into())).await.is_err()
162                {
163                    tracing::warn!("Failed to send replay event to client {}", client_id);
164                    return;
165                }
166            }
167
168            // Send replay-complete sentinel
169            let sentinel = serde_json::json!({"type": "replay_complete", "replayed": replay_events.len()});
170            if let Ok(json) = serde_json::to_string(&sentinel) {
171                let _ = sender.send(Message::Text(json.into())).await;
172            }
173        }
174
175        // Register client with consumer's prefix filters (if any)
176        let initial_filters = if !consumer_filters.is_empty() {
177            EventFilters {
178                event_type_prefixes: consumer_filters,
179                ..Default::default()
180            }
181        } else {
182            EventFilters::default()
183        };
184
185        self.clients.insert(
186            client_id,
187            ClientInfo {
188                id: client_id,
189                filters: initial_filters,
190            },
191        );
192
193        // Spawn send task based on config
194        let clients = Arc::clone(&self.clients);
195        let config = self.config.clone();
196        let send_task = tokio::spawn(async move {
197            if let Some(interval_ms) = config.batch_interval_ms {
198                Self::send_batched(
199                    event_rx,
200                    sender,
201                    clients,
202                    client_id,
203                    interval_ms,
204                    config.max_batch_size,
205                )
206                .await;
207            } else {
208                Self::send_unbatched(event_rx, sender, clients, client_id).await;
209            }
210        });
211
212        // Handle incoming messages from client (for setting filters)
213        let clients = Arc::clone(&self.clients);
214        let recv_task = tokio::spawn(async move {
215            while let Some(Ok(msg)) = receiver.next().await {
216                if let Message::Text(text) = msg {
217                    let text_str = text.as_str();
218                    // Try subscribe message first (prefix-based filtering)
219                    if let Ok(sub) = serde_json::from_str::<SubscribeMessage>(text_str)
220                        && sub.msg_type == "subscribe"
221                    {
222                        tracing::info!(
223                            "Setting prefix filters for client {}: {:?}",
224                            client_id,
225                            sub.filters
226                        );
227                        if let Some(mut client) = clients.get_mut(&client_id) {
228                            client.filters.event_type_prefixes = sub.filters;
229                            // Clear exact-match filter when prefix filters are set
230                            client.filters.event_type = None;
231                        }
232                        continue;
233                    }
234                    // Fall back to legacy exact-match filter
235                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text_str) {
236                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
237                        if let Some(mut client) = clients.get_mut(&client_id) {
238                            client.filters = filters;
239                        }
240                    }
241                }
242            }
243        });
244
245        // Wait for either task to finish
246        tokio::select! {
247            _ = send_task => {
248                tracing::info!("Send task ended for client {}", client_id);
249            }
250            _ = recv_task => {
251                tracing::info!("Receive task ended for client {}", client_id);
252            }
253        }
254
255        // Clean up client
256        self.clients.remove(&client_id);
257        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
258    }
259
260    /// Unbatched send loop — original behavior (one message per event).
261    async fn send_unbatched(
262        mut event_rx: broadcast::Receiver<Arc<Event>>,
263        mut sender: futures::stream::SplitSink<WebSocket, Message>,
264        clients: Arc<DashMap<Uuid, ClientInfo>>,
265        client_id: Uuid,
266    ) {
267        loop {
268            match event_rx.recv().await {
269                Ok(event) => {
270                    if !Self::passes_filters(&clients, client_id, &event) {
271                        continue;
272                    }
273
274                    match serde_json::to_string(&*event) {
275                        Ok(json) => {
276                            if sender.send(Message::Text(json.into())).await.is_err() {
277                                tracing::warn!("Failed to send event to client {}", client_id);
278                                break;
279                            }
280                        }
281                        Err(e) => {
282                            tracing::error!("Failed to serialize event: {}", e);
283                        }
284                    }
285                }
286                Err(broadcast::error::RecvError::Lagged(n)) => {
287                    let msg = serde_json::json!({"type": "lagged", "missed": n});
288                    let _ = sender.send(Message::Text(msg.to_string().into())).await;
289                    tracing::warn!("Client {} lagged, missed {} events", client_id, n);
290                }
291                Err(broadcast::error::RecvError::Closed) => break,
292            }
293        }
294    }
295
296    /// Batched send loop — buffers events and flushes periodically or on max batch size.
297    async fn send_batched(
298        mut event_rx: broadcast::Receiver<Arc<Event>>,
299        mut sender: futures::stream::SplitSink<WebSocket, Message>,
300        clients: Arc<DashMap<Uuid, ClientInfo>>,
301        client_id: Uuid,
302        interval_ms: u64,
303        max_batch_size: usize,
304    ) {
305        let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
306        let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
307        ticker.tick().await; // first tick completes immediately
308
309        loop {
310            tokio::select! {
311                result = event_rx.recv() => {
312                    match result {
313                        Ok(event) => {
314                            if !Self::passes_filters(&clients, client_id, &event) {
315                                continue;
316                            }
317
318                            if let Ok(val) = serde_json::to_value(&*event) {
319                                batch.push(val);
320                            }
321
322                            // Flush early if batch is full
323                            if batch.len() >= max_batch_size
324                                && !Self::flush_batch(&mut sender, &mut batch, client_id).await
325                            {
326                                break;
327                            }
328                        }
329                        Err(broadcast::error::RecvError::Lagged(n)) => {
330                            // Flush any pending batch first
331                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
332                            let msg = serde_json::json!({"type": "lagged", "missed": n});
333                            let _ = sender
334                                .send(Message::Text(msg.to_string().into()))
335                                .await;
336                            tracing::warn!(
337                                "Client {} lagged, missed {} events",
338                                client_id,
339                                n
340                            );
341                        }
342                        Err(broadcast::error::RecvError::Closed) => {
343                            // Flush remaining batch before exit
344                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
345                            break;
346                        }
347                    }
348                }
349                _ = ticker.tick() => {
350                    if !batch.is_empty()
351                        && !Self::flush_batch(&mut sender, &mut batch, client_id).await
352                    {
353                        break;
354                    }
355                }
356            }
357        }
358    }
359
360    /// Flush the current batch as a JSON array. Returns false if send failed.
361    async fn flush_batch(
362        sender: &mut futures::stream::SplitSink<WebSocket, Message>,
363        batch: &mut Vec<serde_json::Value>,
364        client_id: Uuid,
365    ) -> bool {
366        if batch.is_empty() {
367            return true;
368        }
369
370        let json_array = serde_json::Value::Array(std::mem::take(batch));
371        match serde_json::to_string(&json_array) {
372            Ok(json) => {
373                if sender.send(Message::Text(json.into())).await.is_err() {
374                    tracing::warn!("Failed to send batch to client {}", client_id);
375                    return false;
376                }
377                true
378            }
379            Err(e) => {
380                tracing::error!("Failed to serialize batch: {}", e);
381                batch.clear();
382                true
383            }
384        }
385    }
386
387    /// Check if an event passes the client's filters.
388    fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
389        let filters = clients
390            .get(&client_id)
391            .map(|entry| entry.value().filters.clone())
392            .unwrap_or_default();
393
394        if let Some(ref entity_id) = filters.entity_id
395            && event.entity_id_str() != entity_id
396        {
397            return false;
398        }
399
400        // Exact match filter (legacy)
401        if let Some(ref event_type) = filters.event_type
402            && event.event_type_str() != event_type
403        {
404            return false;
405        }
406
407        // Prefix-based filters: if set, event must match at least one prefix
408        if !filters.event_type_prefixes.is_empty() {
409            let event_type = event.event_type_str();
410            let matches = filters.event_type_prefixes.iter().any(|pattern| {
411                if let Some(prefix) = pattern.strip_suffix(".*") {
412                    event_type.starts_with(prefix)
413                        && event_type.as_bytes().get(prefix.len()) == Some(&b'.')
414                } else {
415                    event_type == pattern
416                }
417            });
418            if !matches {
419                return false;
420            }
421        }
422
423        true
424    }
425
426    /// Get statistics about connected clients
427    pub fn stats(&self) -> WebSocketStats {
428        WebSocketStats {
429            connected_clients: self.clients.len(),
430            total_capacity: self.event_tx.receiver_count(),
431        }
432    }
433}
434
435impl Default for WebSocketManager {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
441#[derive(Debug, serde::Serialize)]
442pub struct WebSocketStats {
443    pub connected_clients: usize,
444    pub total_capacity: usize,
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Generic event used where the type/entity don't matter.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Event with the given type and entity id and an empty payload.
    fn event_with(event_type: &str, entity_id: &str) -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            event_type.to_string(),
            entity_id.to_string(),
            "default".to_string(),
            json!({}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Filters consisting only of the given prefix patterns.
    fn prefixes(patterns: &[&str]) -> EventFilters {
        EventFilters {
            event_type_prefixes: patterns.iter().map(|p| p.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Registers a fresh client with `filters` and returns its id.
    fn register(manager: &WebSocketManager, filters: EventFilters) -> Uuid {
        let client_id = Uuid::new_v4();
        manager.clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters,
            },
        );
        client_id
    }

    fn passes(manager: &WebSocketManager, client_id: Uuid, event: &Event) -> bool {
        WebSocketManager::passes_filters(&manager.clients, client_id, event)
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());

        // Should not panic, even with no subscribers.
        manager.broadcast_event(event);
    }

    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::default();
        assert_eq!(config.capacity, 1000);
        assert_eq!(config.batch_interval_ms, None);
        assert_eq!(config.max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        });

        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        // A capacity-2 channel retains only the last two of five events.
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Lagged(3))
        ));
        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_ok());
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_batch_mode_groups_events() {
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        });
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // Batches are sent as a single JSON array of serialized events.
        let events: Vec<serde_json::Value> = (0..3)
            .map(|_| serde_json::to_value(create_test_event()).unwrap())
            .collect();
        let serialized = serde_json::to_string(&events).unwrap();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(parsed, events);
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000), // long interval
            max_batch_size: 5,             // small batch — triggers early flush
        });
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_prefix_filter_matching() {
        let manager = WebSocketManager::new();
        let client_id = register(&manager, prefixes(&["scheduler.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.started", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e2")));
    }

    #[test]
    fn test_prefix_filter_multiple() {
        let manager = WebSocketManager::new();
        let client_id = register(&manager, prefixes(&["scheduler.*", "index.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.completed", "e1")));
        assert!(passes(&manager, client_id, &event_with("index.created", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed", "e1")));
    }

    #[test]
    fn test_prefix_filter_requires_dot_boundary() {
        let manager = WebSocketManager::new();
        let client_id = register(&manager, prefixes(&["scheduler.*"]));

        assert!(passes(&manager, client_id, &event_with("scheduler.a.b", "e1")));
        assert!(!passes(&manager, client_id, &event_with("schedulers.started", "e1")));
        assert!(!passes(&manager, client_id, &event_with("scheduler", "e1")));
    }

    #[test]
    fn test_pattern_without_wildcard_is_exact() {
        let manager = WebSocketManager::new();
        let client_id = register(&manager, prefixes(&["trade.executed"]));

        assert!(passes(&manager, client_id, &event_with("trade.executed", "e1")));
        assert!(!passes(&manager, client_id, &event_with("trade.executed.late", "e1")));
    }

    #[test]
    fn test_entity_and_exact_type_filters() {
        let manager = WebSocketManager::new();
        let client_id = register(
            &manager,
            EventFilters {
                entity_id: Some("e1".to_string()),
                event_type: Some("a.b".to_string()),
                ..Default::default()
            },
        );

        assert!(passes(&manager, client_id, &event_with("a.b", "e1")));
        assert!(!passes(&manager, client_id, &event_with("a.b", "e2")));
        assert!(!passes(&manager, client_id, &event_with("a.c", "e1")));
    }

    #[test]
    fn test_no_prefix_filters_matches_all() {
        let manager = WebSocketManager::new();
        let client_id = register(&manager, EventFilters::default());

        assert!(passes(&manager, client_id, &create_test_event()));
    }

    #[test]
    fn test_unknown_client_is_unfiltered() {
        let manager = WebSocketManager::new();

        assert!(passes(&manager, Uuid::new_v4(), &create_test_event()));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}"#;
        let msg: SubscribeMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.msg_type, "subscribe");
        assert_eq!(msg.filters, vec!["scheduler.*", "index.*"]);

        // `filters` is optional and defaults to empty.
        let bare: SubscribeMessage = serde_json::from_str(r#"{"type": "subscribe"}"#).unwrap();
        assert!(bare.filters.is_empty());
    }

    #[test]
    fn test_legacy_filters_default_prefixes() {
        let filters: EventFilters = serde_json::from_str(r#"{"event_type": "a.b"}"#).unwrap();
        assert_eq!(filters.event_type.as_deref(), Some("a.b"));
        assert!(filters.entity_id.is_none());
        assert!(filters.event_type_prefixes.is_empty());
    }

    #[test]
    fn test_backward_compat_no_config() {
        // Default constructor should work identically to pre-backpressure behavior
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        manager.broadcast_event(Arc::new(create_test_event()));

        assert_eq!(manager.stats().connected_clients, 0);
    }
}
723
724        let stats = manager.stats();
725        assert_eq!(stats.connected_clients, 0);
726    }
727}