//! allsource_core/infrastructure/web/websocket.rs
//!
//! Real-time event streaming over WebSockets.
1use crate::domain::entities::Event;
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
/// Configuration for WebSocket backpressure and batching.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Broadcast channel capacity (default 1000).
    pub capacity: usize,
    /// Optional batching interval in milliseconds.
    /// `None` = no batching (current behavior, backward compatible).
    /// `Some(50)` = buffer events and flush every 50ms as JSON arrays.
    pub batch_interval_ms: Option<u64>,
    /// Flush early when the batch reaches this size (default 100).
    pub max_batch_size: usize,
}

impl Default for WebSocketConfig {
    /// 1000-slot broadcast channel, batching disabled, 100-event batch ceiling.
    fn default() -> Self {
        WebSocketConfig {
            capacity: 1000,
            max_batch_size: 100,
            batch_interval_ms: None,
        }
    }
}
31
/// WebSocket manager for real-time event streaming (v0.2 feature).
///
/// One instance is shared across all connections: events fan out through a
/// `tokio::sync::broadcast` channel, and per-client state lives in a
/// concurrent `DashMap` keyed by client UUID.
pub struct WebSocketManager {
    /// Broadcast channel for sending events to all connected clients.
    /// Events are wrapped in `Arc` so fan-out clones are refcount bumps,
    /// not deep copies.
    event_tx: broadcast::Sender<Arc<Event>>,

    /// Connected clients by ID - using DashMap for lock-free concurrent access
    clients: Arc<DashMap<Uuid, ClientInfo>>,

    /// Backpressure and batching configuration (capacity, batch interval,
    /// max batch size); fixed at construction time.
    config: WebSocketConfig,
}
43
/// Per-connection bookkeeping stored in `WebSocketManager::clients`.
#[derive(Debug, Clone)]
struct ClientInfo {
    // Client UUID; duplicates the DashMap key. NOTE(review): never read
    // within this module — candidate for removal if no external use exists.
    id: Uuid,
    // Active filters; replaced wholesale when the client sends an
    // `EventFilters` JSON text frame.
    filters: EventFilters,
}
49
/// Client-supplied event filters, deserialized from incoming text frames.
///
/// A `None` field matches everything for that dimension; when both fields
/// are set, an event must match both to be delivered.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EventFilters {
    /// Only deliver events whose entity id equals this value.
    pub entity_id: Option<String>,
    /// Only deliver events whose event type equals this value.
    pub event_type: Option<String>,
}
55
56impl WebSocketManager {
57    pub fn new() -> Self {
58        Self::with_config(WebSocketConfig::default())
59    }
60
61    /// Create a WebSocket manager with custom backpressure configuration.
62    pub fn with_config(config: WebSocketConfig) -> Self {
63        let (event_tx, _) = broadcast::channel(config.capacity);
64
65        Self {
66            event_tx,
67            clients: Arc::new(DashMap::new()),
68            config,
69        }
70    }
71
72    /// Broadcast an event to all connected WebSocket clients
73    pub fn broadcast_event(&self, event: Arc<Event>) {
74        // Send to broadcast channel (non-blocking)
75        let _ = self.event_tx.send(event);
76    }
77
78    /// Subscribe to the event broadcast channel (used by RESP3 SUBSCRIBE).
79    pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
80        self.event_tx.subscribe()
81    }
82
83    /// Handle a new WebSocket connection
84    pub async fn handle_socket(&self, socket: WebSocket) {
85        let client_id = Uuid::new_v4();
86        tracing::info!("🔌 WebSocket client connected: {}", client_id);
87
88        // Subscribe to broadcast channel
89        let event_rx = self.event_tx.subscribe();
90
91        // Split socket into sender and receiver
92        let (sender, mut receiver) = socket.split();
93
94        // Register client
95        self.clients.insert(
96            client_id,
97            ClientInfo {
98                id: client_id,
99                filters: EventFilters::default(),
100            },
101        );
102
103        // Spawn send task based on config
104        let clients = Arc::clone(&self.clients);
105        let config = self.config.clone();
106        let send_task = tokio::spawn(async move {
107            if let Some(interval_ms) = config.batch_interval_ms {
108                Self::send_batched(
109                    event_rx,
110                    sender,
111                    clients,
112                    client_id,
113                    interval_ms,
114                    config.max_batch_size,
115                )
116                .await;
117            } else {
118                Self::send_unbatched(event_rx, sender, clients, client_id).await;
119            }
120        });
121
122        // Handle incoming messages from client (for setting filters)
123        let clients = Arc::clone(&self.clients);
124        let recv_task = tokio::spawn(async move {
125            while let Some(Ok(msg)) = receiver.next().await {
126                if let Message::Text(text) = msg {
127                    // Parse filter commands (text is Utf8Bytes in axum 0.8+)
128                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
129                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
130                        if let Some(mut client) = clients.get_mut(&client_id) {
131                            client.filters = filters;
132                        }
133                    }
134                }
135            }
136        });
137
138        // Wait for either task to finish
139        tokio::select! {
140            _ = send_task => {
141                tracing::info!("Send task ended for client {}", client_id);
142            }
143            _ = recv_task => {
144                tracing::info!("Receive task ended for client {}", client_id);
145            }
146        }
147
148        // Clean up client
149        self.clients.remove(&client_id);
150        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
151    }
152
153    /// Unbatched send loop — original behavior (one message per event).
154    async fn send_unbatched(
155        mut event_rx: broadcast::Receiver<Arc<Event>>,
156        mut sender: futures::stream::SplitSink<WebSocket, Message>,
157        clients: Arc<DashMap<Uuid, ClientInfo>>,
158        client_id: Uuid,
159    ) {
160        loop {
161            match event_rx.recv().await {
162                Ok(event) => {
163                    if !Self::passes_filters(&clients, client_id, &event) {
164                        continue;
165                    }
166
167                    match serde_json::to_string(&*event) {
168                        Ok(json) => {
169                            if sender.send(Message::Text(json.into())).await.is_err() {
170                                tracing::warn!("Failed to send event to client {}", client_id);
171                                break;
172                            }
173                        }
174                        Err(e) => {
175                            tracing::error!("Failed to serialize event: {}", e);
176                        }
177                    }
178                }
179                Err(broadcast::error::RecvError::Lagged(n)) => {
180                    let msg = serde_json::json!({"type": "lagged", "missed": n});
181                    let _ = sender.send(Message::Text(msg.to_string().into())).await;
182                    tracing::warn!("Client {} lagged, missed {} events", client_id, n);
183                }
184                Err(broadcast::error::RecvError::Closed) => break,
185            }
186        }
187    }
188
189    /// Batched send loop — buffers events and flushes periodically or on max batch size.
190    async fn send_batched(
191        mut event_rx: broadcast::Receiver<Arc<Event>>,
192        mut sender: futures::stream::SplitSink<WebSocket, Message>,
193        clients: Arc<DashMap<Uuid, ClientInfo>>,
194        client_id: Uuid,
195        interval_ms: u64,
196        max_batch_size: usize,
197    ) {
198        let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
199        let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
200        ticker.tick().await; // first tick completes immediately
201
202        loop {
203            tokio::select! {
204                result = event_rx.recv() => {
205                    match result {
206                        Ok(event) => {
207                            if !Self::passes_filters(&clients, client_id, &event) {
208                                continue;
209                            }
210
211                            if let Ok(val) = serde_json::to_value(&*event) {
212                                batch.push(val);
213                            }
214
215                            // Flush early if batch is full
216                            if batch.len() >= max_batch_size
217                                && !Self::flush_batch(&mut sender, &mut batch, client_id).await
218                            {
219                                break;
220                            }
221                        }
222                        Err(broadcast::error::RecvError::Lagged(n)) => {
223                            // Flush any pending batch first
224                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
225                            let msg = serde_json::json!({"type": "lagged", "missed": n});
226                            let _ = sender
227                                .send(Message::Text(msg.to_string().into()))
228                                .await;
229                            tracing::warn!(
230                                "Client {} lagged, missed {} events",
231                                client_id,
232                                n
233                            );
234                        }
235                        Err(broadcast::error::RecvError::Closed) => {
236                            // Flush remaining batch before exit
237                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
238                            break;
239                        }
240                    }
241                }
242                _ = ticker.tick() => {
243                    if !batch.is_empty()
244                        && !Self::flush_batch(&mut sender, &mut batch, client_id).await
245                    {
246                        break;
247                    }
248                }
249            }
250        }
251    }
252
253    /// Flush the current batch as a JSON array. Returns false if send failed.
254    async fn flush_batch(
255        sender: &mut futures::stream::SplitSink<WebSocket, Message>,
256        batch: &mut Vec<serde_json::Value>,
257        client_id: Uuid,
258    ) -> bool {
259        if batch.is_empty() {
260            return true;
261        }
262
263        let json_array = serde_json::Value::Array(std::mem::take(batch));
264        match serde_json::to_string(&json_array) {
265            Ok(json) => {
266                if sender.send(Message::Text(json.into())).await.is_err() {
267                    tracing::warn!("Failed to send batch to client {}", client_id);
268                    return false;
269                }
270                true
271            }
272            Err(e) => {
273                tracing::error!("Failed to serialize batch: {}", e);
274                batch.clear();
275                true
276            }
277        }
278    }
279
280    /// Check if an event passes the client's filters.
281    fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
282        let filters = clients
283            .get(&client_id)
284            .map(|entry| entry.value().filters.clone())
285            .unwrap_or_default();
286
287        if let Some(ref entity_id) = filters.entity_id
288            && event.entity_id_str() != entity_id
289        {
290            return false;
291        }
292
293        if let Some(ref event_type) = filters.event_type
294            && event.event_type_str() != event_type
295        {
296            return false;
297        }
298
299        true
300    }
301
302    /// Get statistics about connected clients
303    pub fn stats(&self) -> WebSocketStats {
304        WebSocketStats {
305            connected_clients: self.clients.len(),
306            total_capacity: self.event_tx.receiver_count(),
307        }
308    }
309}
310
311impl Default for WebSocketManager {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
/// Statistics snapshot returned by `WebSocketManager::stats`.
#[derive(Debug, serde::Serialize)]
pub struct WebSocketStats {
    /// Number of currently registered WebSocket clients.
    pub connected_clients: usize,
    /// Populated from `broadcast::Sender::receiver_count` — i.e. the number
    /// of live broadcast subscribers, despite the "capacity" name.
    pub total_capacity: usize,
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a minimal event for use in unit tests.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        // A fresh manager starts with no registered clients.
        assert_eq!(WebSocketManager::new().stats().connected_clients, 0);
    }

    #[test]
    fn test_event_broadcast() {
        // Broadcasting with zero subscribers must not panic.
        WebSocketManager::new().broadcast_event(Arc::new(create_test_event()));
    }

    #[test]
    fn test_config_defaults() {
        let WebSocketConfig {
            capacity,
            batch_interval_ms,
            max_batch_size,
        } = WebSocketConfig::default();
        assert_eq!(capacity, 1000);
        assert_eq!(batch_interval_ms, None);
        assert_eq!(max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        // A capacity-2 channel overflows after a handful of sends.
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        });

        // Subscribe, then overflow the channel.
        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        match rx.try_recv() {
            // Receiving an event is acceptable — lag may surface on a
            // subsequent recv instead.
            Ok(_) => {}
            Err(broadcast::error::TryRecvError::Lagged(missed)) => {
                assert!(missed > 0, "should report missed events");
            }
            Err(other) => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_batch_mode_groups_events() {
        // with_config must carry the batching parameters through.
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config.clone());
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // Batches are serialized as JSON arrays; verify the round trip.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        rt.block_on(async {
            let events: Vec<serde_json::Value> = (0..3)
                .map(|_| serde_json::to_value(create_test_event()).unwrap())
                .collect();

            let serialized =
                serde_json::to_string(&serde_json::Value::Array(events)).unwrap();
            let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
            assert_eq!(parsed.len(), 3);
        });
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        // Long interval + tiny batch: size, not time, triggers the flush.
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000),
            max_batch_size: 5,
        });
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_backward_compat_no_config() {
        // Default constructor must match pre-backpressure behavior.
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        // Broadcast still works with no subscribers.
        manager.broadcast_event(Arc::new(create_test_event()));
        assert_eq!(manager.stats().connected_clients, 0);
    }
}