Skip to main content

allsource_core/infrastructure/web/
websocket.rs

1use crate::{domain::entities::Event, store::EventStore};
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
/// Tunables controlling WebSocket backpressure and optional batching.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Capacity handed to the broadcast channel (default 1000).
    pub capacity: usize,
    /// Batching interval in milliseconds, if batching is wanted.
    /// `None` disables batching (one message per event, backward compatible).
    /// `Some(50)` buffers events and flushes them every 50ms as a JSON array.
    pub batch_interval_ms: Option<u64>,
    /// Batch length that triggers an early flush (default 100).
    pub max_batch_size: usize,
}

impl Default for WebSocketConfig {
    /// Unbatched delivery with a 1000-slot channel and a 100-event batch cap.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 1000;
        const DEFAULT_MAX_BATCH_SIZE: usize = 100;

        Self {
            capacity: DEFAULT_CAPACITY,
            batch_interval_ms: None,
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
        }
    }
}
31
32/// WebSocket manager for real-time event streaming (v0.2 feature)
33pub struct WebSocketManager {
34    /// Broadcast channel for sending events to all connected clients
35    event_tx: broadcast::Sender<Arc<Event>>,
36
37    /// Connected clients by ID - using DashMap for lock-free concurrent access
38    clients: Arc<DashMap<Uuid, ClientInfo>>,
39
40    /// Backpressure and batching configuration
41    config: WebSocketConfig,
42}
43
44#[derive(Debug, Clone)]
45struct ClientInfo {
46    id: Uuid,
47    filters: EventFilters,
48}
49
50#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
51pub struct EventFilters {
52    pub entity_id: Option<String>,
53    pub event_type: Option<String>,
54    /// Prefix-based event type filters (e.g. `["scheduler.*", "index.*"]`).
55    /// If non-empty, only events matching at least one prefix are delivered.
56    #[serde(default)]
57    pub event_type_prefixes: Vec<String>,
58}
59
60/// Client message for setting prefix-based subscription filters.
61/// Sent as: `{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}`
62#[derive(Debug, serde::Deserialize)]
63struct SubscribeMessage {
64    #[serde(rename = "type")]
65    msg_type: String,
66    #[serde(default)]
67    filters: Vec<String>,
68}
69
70impl WebSocketManager {
71    pub fn new() -> Self {
72        Self::with_config(WebSocketConfig::default())
73    }
74
75    /// Create a WebSocket manager with custom backpressure configuration.
76    pub fn with_config(config: WebSocketConfig) -> Self {
77        let (event_tx, _) = broadcast::channel(config.capacity);
78
79        Self {
80            event_tx,
81            clients: Arc::new(DashMap::new()),
82            config,
83        }
84    }
85
86    /// Broadcast an event to all connected WebSocket clients
87    #[cfg_attr(feature = "hotpath", hotpath::measure)]
88    pub fn broadcast_event(&self, event: Arc<Event>) {
89        // Send to broadcast channel (non-blocking)
90        let _ = self.event_tx.send(event);
91    }
92
93    /// Subscribe to the event broadcast channel (used by RESP3 SUBSCRIBE).
94    pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
95        self.event_tx.subscribe()
96    }
97
98    /// Handle a new WebSocket connection (fire-and-forget, no consumer tracking)
99    #[cfg_attr(feature = "hotpath", hotpath::measure)]
100    pub async fn handle_socket(&self, socket: WebSocket) {
101        self.handle_socket_inner(socket, None, None).await;
102    }
103
104    /// Handle a WebSocket connection with a durable consumer for auto-replay.
105    ///
106    /// Replays all events since the consumer's last acked position, then switches
107    /// to real-time delivery. The consumer's event_type_filters are applied during replay.
108    pub async fn handle_socket_with_consumer(
109        &self,
110        socket: WebSocket,
111        consumer_id: String,
112        store: Arc<EventStore>,
113    ) {
114        self.handle_socket_inner(socket, Some(consumer_id), Some(store))
115            .await;
116    }
117
118    async fn handle_socket_inner(
119        &self,
120        socket: WebSocket,
121        consumer_id: Option<String>,
122        store: Option<Arc<EventStore>>,
123    ) {
124        let client_id = Uuid::new_v4();
125        tracing::info!(
126            "🔌 WebSocket client connected: {} (consumer: {:?})",
127            client_id,
128            consumer_id
129        );
130
131        // Subscribe to broadcast channel BEFORE replay so we don't miss events
132        let event_rx = self.event_tx.subscribe();
133
134        // Split socket into sender and receiver
135        let (mut sender, mut receiver) = socket.split();
136
137        // Replay missed events for durable consumers
138        let mut consumer_filters: Vec<String> = Vec::new();
139        if let (Some(cid), Some(store)) = (&consumer_id, &store) {
140            let registry = store.consumer_registry();
141            let consumer = registry.get_or_create(cid);
142            consumer_filters = consumer.event_type_filters.clone();
143            let cursor = consumer.cursor_position.unwrap_or(0);
144
145            let replay_events = store.events_after_offset(cursor, &consumer_filters, usize::MAX);
146
147            tracing::info!(
148                "Replaying {} events for consumer '{}' from offset {}",
149                replay_events.len(),
150                cid,
151                cursor
152            );
153
154            for (position, event) in &replay_events {
155                let dto = serde_json::json!({
156                    "type": "replay",
157                    "position": position,
158                    "event": event,
159                });
160                if let Ok(json) = serde_json::to_string(&dto)
161                    && sender.send(Message::Text(json.into())).await.is_err()
162                {
163                    tracing::warn!("Failed to send replay event to client {}", client_id);
164                    return;
165                }
166            }
167
168            // Send replay-complete sentinel
169            let sentinel =
170                serde_json::json!({"type": "replay_complete", "replayed": replay_events.len()});
171            if let Ok(json) = serde_json::to_string(&sentinel) {
172                let _ = sender.send(Message::Text(json.into())).await;
173            }
174        }
175
176        // Register client with consumer's prefix filters (if any)
177        let initial_filters = if consumer_filters.is_empty() {
178            EventFilters::default()
179        } else {
180            EventFilters {
181                event_type_prefixes: consumer_filters,
182                ..Default::default()
183            }
184        };
185
186        self.clients.insert(
187            client_id,
188            ClientInfo {
189                id: client_id,
190                filters: initial_filters,
191            },
192        );
193
194        // Spawn send task based on config
195        let clients = Arc::clone(&self.clients);
196        let config = self.config.clone();
197        let send_task = tokio::spawn(async move {
198            if let Some(interval_ms) = config.batch_interval_ms {
199                Self::send_batched(
200                    event_rx,
201                    sender,
202                    clients,
203                    client_id,
204                    interval_ms,
205                    config.max_batch_size,
206                )
207                .await;
208            } else {
209                Self::send_unbatched(event_rx, sender, clients, client_id).await;
210            }
211        });
212
213        // Handle incoming messages from client (for setting filters)
214        let clients = Arc::clone(&self.clients);
215        let recv_task = tokio::spawn(async move {
216            while let Some(Ok(msg)) = receiver.next().await {
217                if let Message::Text(text) = msg {
218                    let text_str = text.as_str();
219                    // Try subscribe message first (prefix-based filtering)
220                    if let Ok(sub) = serde_json::from_str::<SubscribeMessage>(text_str)
221                        && sub.msg_type == "subscribe"
222                    {
223                        tracing::info!(
224                            "Setting prefix filters for client {}: {:?}",
225                            client_id,
226                            sub.filters
227                        );
228                        if let Some(mut client) = clients.get_mut(&client_id) {
229                            client.filters.event_type_prefixes = sub.filters;
230                            // Clear exact-match filter when prefix filters are set
231                            client.filters.event_type = None;
232                        }
233                        continue;
234                    }
235                    // Fall back to legacy exact-match filter
236                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text_str) {
237                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
238                        if let Some(mut client) = clients.get_mut(&client_id) {
239                            client.filters = filters;
240                        }
241                    }
242                }
243            }
244        });
245
246        // Wait for either task to finish
247        tokio::select! {
248            _ = send_task => {
249                tracing::info!("Send task ended for client {}", client_id);
250            }
251            _ = recv_task => {
252                tracing::info!("Receive task ended for client {}", client_id);
253            }
254        }
255
256        // Clean up client
257        self.clients.remove(&client_id);
258        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
259    }
260
261    /// Unbatched send loop — original behavior (one message per event).
262    async fn send_unbatched(
263        mut event_rx: broadcast::Receiver<Arc<Event>>,
264        mut sender: futures::stream::SplitSink<WebSocket, Message>,
265        clients: Arc<DashMap<Uuid, ClientInfo>>,
266        client_id: Uuid,
267    ) {
268        loop {
269            match event_rx.recv().await {
270                Ok(event) => {
271                    if !Self::passes_filters(&clients, client_id, &event) {
272                        continue;
273                    }
274
275                    match serde_json::to_string(&*event) {
276                        Ok(json) => {
277                            if sender.send(Message::Text(json.into())).await.is_err() {
278                                tracing::warn!("Failed to send event to client {}", client_id);
279                                break;
280                            }
281                        }
282                        Err(e) => {
283                            tracing::error!("Failed to serialize event: {}", e);
284                        }
285                    }
286                }
287                Err(broadcast::error::RecvError::Lagged(n)) => {
288                    let msg = serde_json::json!({"type": "lagged", "missed": n});
289                    let _ = sender.send(Message::Text(msg.to_string().into())).await;
290                    tracing::warn!("Client {} lagged, missed {} events", client_id, n);
291                }
292                Err(broadcast::error::RecvError::Closed) => break,
293            }
294        }
295    }
296
297    /// Batched send loop — buffers events and flushes periodically or on max batch size.
298    async fn send_batched(
299        mut event_rx: broadcast::Receiver<Arc<Event>>,
300        mut sender: futures::stream::SplitSink<WebSocket, Message>,
301        clients: Arc<DashMap<Uuid, ClientInfo>>,
302        client_id: Uuid,
303        interval_ms: u64,
304        max_batch_size: usize,
305    ) {
306        let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
307        let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
308        ticker.tick().await; // first tick completes immediately
309
310        loop {
311            tokio::select! {
312                result = event_rx.recv() => {
313                    match result {
314                        Ok(event) => {
315                            if !Self::passes_filters(&clients, client_id, &event) {
316                                continue;
317                            }
318
319                            if let Ok(val) = serde_json::to_value(&*event) {
320                                batch.push(val);
321                            }
322
323                            // Flush early if batch is full
324                            if batch.len() >= max_batch_size
325                                && !Self::flush_batch(&mut sender, &mut batch, client_id).await
326                            {
327                                break;
328                            }
329                        }
330                        Err(broadcast::error::RecvError::Lagged(n)) => {
331                            // Flush any pending batch first
332                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
333                            let msg = serde_json::json!({"type": "lagged", "missed": n});
334                            let _ = sender
335                                .send(Message::Text(msg.to_string().into()))
336                                .await;
337                            tracing::warn!(
338                                "Client {} lagged, missed {} events",
339                                client_id,
340                                n
341                            );
342                        }
343                        Err(broadcast::error::RecvError::Closed) => {
344                            // Flush remaining batch before exit
345                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
346                            break;
347                        }
348                    }
349                }
350                _ = ticker.tick() => {
351                    if !batch.is_empty()
352                        && !Self::flush_batch(&mut sender, &mut batch, client_id).await
353                    {
354                        break;
355                    }
356                }
357            }
358        }
359    }
360
361    /// Flush the current batch as a JSON array. Returns false if send failed.
362    async fn flush_batch(
363        sender: &mut futures::stream::SplitSink<WebSocket, Message>,
364        batch: &mut Vec<serde_json::Value>,
365        client_id: Uuid,
366    ) -> bool {
367        if batch.is_empty() {
368            return true;
369        }
370
371        let json_array = serde_json::Value::Array(std::mem::take(batch));
372        match serde_json::to_string(&json_array) {
373            Ok(json) => {
374                if sender.send(Message::Text(json.into())).await.is_err() {
375                    tracing::warn!("Failed to send batch to client {}", client_id);
376                    return false;
377                }
378                true
379            }
380            Err(e) => {
381                tracing::error!("Failed to serialize batch: {}", e);
382                batch.clear();
383                true
384            }
385        }
386    }
387
388    /// Check if an event passes the client's filters.
389    fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
390        let filters = clients
391            .get(&client_id)
392            .map(|entry| entry.value().filters.clone())
393            .unwrap_or_default();
394
395        if let Some(ref entity_id) = filters.entity_id
396            && event.entity_id_str() != entity_id
397        {
398            return false;
399        }
400
401        // Exact match filter (legacy)
402        if let Some(ref event_type) = filters.event_type
403            && event.event_type_str() != event_type
404        {
405            return false;
406        }
407
408        // Prefix-based filters: if set, event must match at least one prefix
409        if !filters.event_type_prefixes.is_empty() {
410            let event_type = event.event_type_str();
411            let matches = filters.event_type_prefixes.iter().any(|pattern| {
412                if let Some(prefix) = pattern.strip_suffix(".*") {
413                    event_type.starts_with(prefix)
414                        && event_type.as_bytes().get(prefix.len()) == Some(&b'.')
415                } else {
416                    event_type == pattern
417                }
418            });
419            if !matches {
420                return false;
421            }
422        }
423
424        true
425    }
426
427    /// Get statistics about connected clients
428    pub fn stats(&self) -> WebSocketStats {
429        WebSocketStats {
430            connected_clients: self.clients.len(),
431            total_capacity: self.event_tx.receiver_count(),
432        }
433    }
434}
435
436impl Default for WebSocketManager {
437    fn default() -> Self {
438        Self::new()
439    }
440}
441
442#[derive(Debug, serde::Serialize)]
443pub struct WebSocketStats {
444    pub connected_clients: usize,
445    pub total_capacity: usize,
446}
447
448#[cfg(test)]
449mod tests {
450    use super::*;
451    use serde_json::json;
452
453    fn create_test_event() -> Event {
454        Event::reconstruct_from_strings(
455            Uuid::new_v4(),
456            "test.event".to_string(),
457            "test-entity".to_string(),
458            "default".to_string(),
459            json!({"test": "data"}),
460            chrono::Utc::now(),
461            None,
462            1,
463        )
464    }
465
466    #[test]
467    fn test_websocket_manager_creation() {
468        let manager = WebSocketManager::new();
469        let stats = manager.stats();
470        assert_eq!(stats.connected_clients, 0);
471    }
472
473    #[test]
474    fn test_event_broadcast() {
475        let manager = WebSocketManager::new();
476        let event = Arc::new(create_test_event());
477
478        // Should not panic
479        manager.broadcast_event(event);
480    }
481
482    #[test]
483    fn test_config_defaults() {
484        let config = WebSocketConfig::default();
485        assert_eq!(config.capacity, 1000);
486        assert_eq!(config.batch_interval_ms, None);
487        assert_eq!(config.max_batch_size, 100);
488    }
489
490    #[test]
491    fn test_lagged_notification() {
492        // Create a tiny channel that will lag quickly
493        let config = WebSocketConfig {
494            capacity: 2,
495            batch_interval_ms: None,
496            max_batch_size: 100,
497        };
498        let manager = WebSocketManager::with_config(config);
499
500        // Subscribe, then overflow the channel
501        let mut rx = manager.subscribe_events();
502        for _ in 0..5 {
503            manager.broadcast_event(Arc::new(create_test_event()));
504        }
505
506        // The receiver should get a Lagged error
507        match rx.try_recv() {
508            Err(broadcast::error::TryRecvError::Lagged(n)) => {
509                assert!(n > 0, "should report missed events");
510            }
511            Ok(_) => {
512                // Got an event — that's fine, lagged may come on next recv
513            }
514            Err(e) => {
515                panic!("unexpected error: {e:?}");
516            }
517        }
518    }
519
520    #[test]
521    fn test_batch_mode_groups_events() {
522        // Verify that with_config creates a manager with batching params
523        let config = WebSocketConfig {
524            capacity: 1000,
525            batch_interval_ms: Some(50),
526            max_batch_size: 10,
527        };
528        let manager = WebSocketManager::with_config(config.clone());
529        assert_eq!(manager.config.batch_interval_ms, Some(50));
530        assert_eq!(manager.config.max_batch_size, 10);
531
532        // The actual batching behavior is tested via the flush_batch helper
533        let rt = tokio::runtime::Builder::new_current_thread()
534            .enable_all()
535            .build()
536            .unwrap();
537
538        rt.block_on(async {
539            // Create a batch of events and serialize as JSON array
540            let events: Vec<serde_json::Value> = (0..3)
541                .map(|_| serde_json::to_value(create_test_event()).unwrap())
542                .collect();
543
544            let json_array = serde_json::Value::Array(events);
545            let serialized = serde_json::to_string(&json_array).unwrap();
546            let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
547            assert_eq!(parsed.len(), 3);
548        });
549    }
550
551    #[test]
552    fn test_batch_flush_on_max_size() {
553        // Verify config with small max_batch_size
554        let config = WebSocketConfig {
555            capacity: 1000,
556            batch_interval_ms: Some(1000), // long interval
557            max_batch_size: 5,             // small batch — triggers early flush
558        };
559        let manager = WebSocketManager::with_config(config);
560        assert_eq!(manager.config.max_batch_size, 5);
561    }
562
563    #[test]
564    fn test_prefix_filter_matching() {
565        let manager = WebSocketManager::new();
566        let client_id = Uuid::new_v4();
567
568        // Register client with prefix filters
569        manager.clients.insert(
570            client_id,
571            ClientInfo {
572                id: client_id,
573                filters: EventFilters {
574                    entity_id: None,
575                    event_type: None,
576                    event_type_prefixes: vec!["scheduler.*".to_string()],
577                },
578            },
579        );
580
581        // Matching event
582        let matching = Event::reconstruct_from_strings(
583            Uuid::new_v4(),
584            "scheduler.started".to_string(),
585            "e1".to_string(),
586            "default".to_string(),
587            json!({}),
588            chrono::Utc::now(),
589            None,
590            1,
591        );
592        assert!(WebSocketManager::passes_filters(
593            &manager.clients,
594            client_id,
595            &matching
596        ));
597
598        // Non-matching event
599        let non_matching = Event::reconstruct_from_strings(
600            Uuid::new_v4(),
601            "trade.executed".to_string(),
602            "e2".to_string(),
603            "default".to_string(),
604            json!({}),
605            chrono::Utc::now(),
606            None,
607            1,
608        );
609        assert!(!WebSocketManager::passes_filters(
610            &manager.clients,
611            client_id,
612            &non_matching
613        ));
614    }
615
616    #[test]
617    fn test_prefix_filter_multiple() {
618        let manager = WebSocketManager::new();
619        let client_id = Uuid::new_v4();
620
621        manager.clients.insert(
622            client_id,
623            ClientInfo {
624                id: client_id,
625                filters: EventFilters {
626                    entity_id: None,
627                    event_type: None,
628                    event_type_prefixes: vec!["scheduler.*".to_string(), "index.*".to_string()],
629                },
630            },
631        );
632
633        let scheduler_event = Event::reconstruct_from_strings(
634            Uuid::new_v4(),
635            "scheduler.completed".to_string(),
636            "e1".to_string(),
637            "default".to_string(),
638            json!({}),
639            chrono::Utc::now(),
640            None,
641            1,
642        );
643        assert!(WebSocketManager::passes_filters(
644            &manager.clients,
645            client_id,
646            &scheduler_event
647        ));
648
649        let index_event = Event::reconstruct_from_strings(
650            Uuid::new_v4(),
651            "index.created".to_string(),
652            "e1".to_string(),
653            "default".to_string(),
654            json!({}),
655            chrono::Utc::now(),
656            None,
657            1,
658        );
659        assert!(WebSocketManager::passes_filters(
660            &manager.clients,
661            client_id,
662            &index_event
663        ));
664
665        let trade_event = Event::reconstruct_from_strings(
666            Uuid::new_v4(),
667            "trade.executed".to_string(),
668            "e1".to_string(),
669            "default".to_string(),
670            json!({}),
671            chrono::Utc::now(),
672            None,
673            1,
674        );
675        assert!(!WebSocketManager::passes_filters(
676            &manager.clients,
677            client_id,
678            &trade_event
679        ));
680    }
681
682    #[test]
683    fn test_no_prefix_filters_matches_all() {
684        let manager = WebSocketManager::new();
685        let client_id = Uuid::new_v4();
686
687        manager.clients.insert(
688            client_id,
689            ClientInfo {
690                id: client_id,
691                filters: EventFilters::default(),
692            },
693        );
694
695        let event = create_test_event();
696        assert!(WebSocketManager::passes_filters(
697            &manager.clients,
698            client_id,
699            &event
700        ));
701    }
702
703    #[test]
704    fn test_subscribe_message_parsing() {
705        let json = r#"{"type": "subscribe", "filters": ["scheduler.*", "index.*"]}"#;
706        let msg: SubscribeMessage = serde_json::from_str(json).unwrap();
707        assert_eq!(msg.msg_type, "subscribe");
708        assert_eq!(msg.filters, vec!["scheduler.*", "index.*"]);
709    }
710
711    #[test]
712    fn test_backward_compat_no_config() {
713        // Default constructor should work identically to pre-backpressure behavior
714        let manager = WebSocketManager::new();
715        assert_eq!(manager.config.capacity, 1000);
716        assert!(manager.config.batch_interval_ms.is_none());
717
718        // Broadcast still works
719        let event = Arc::new(create_test_event());
720        manager.broadcast_event(event);
721
722        let stats = manager.stats();
723        assert_eq!(stats.connected_clients, 0);
724    }
725}