//! allsource_core/infrastructure/web/websocket.rs
1use crate::domain::entities::Event;
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
/// Tuning knobs for WebSocket backpressure and event batching.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Capacity of the broadcast channel shared by all clients (default 1000).
    pub capacity: usize,
    /// Optional batching interval in milliseconds.
    ///
    /// `None` disables batching (one message per event — the original,
    /// backward-compatible behavior); `Some(50)` buffers events and flushes
    /// them every 50 ms as a single JSON array.
    pub batch_interval_ms: Option<u64>,
    /// A batch is flushed early once it holds this many events (default 100).
    pub max_batch_size: usize,
}
21
22impl Default for WebSocketConfig {
23    fn default() -> Self {
24        Self {
25            capacity: 1000,
26            batch_interval_ms: None,
27            max_batch_size: 100,
28        }
29    }
30}
31
32/// WebSocket manager for real-time event streaming (v0.2 feature)
33pub struct WebSocketManager {
34    /// Broadcast channel for sending events to all connected clients
35    event_tx: broadcast::Sender<Arc<Event>>,
36
37    /// Connected clients by ID - using DashMap for lock-free concurrent access
38    clients: Arc<DashMap<Uuid, ClientInfo>>,
39
40    /// Backpressure and batching configuration
41    config: WebSocketConfig,
42}
43
44#[derive(Debug, Clone)]
45struct ClientInfo {
46    id: Uuid,
47    filters: EventFilters,
48}
49
50#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
51pub struct EventFilters {
52    pub entity_id: Option<String>,
53    pub event_type: Option<String>,
54}
55
56impl WebSocketManager {
57    pub fn new() -> Self {
58        Self::with_config(WebSocketConfig::default())
59    }
60
61    /// Create a WebSocket manager with custom backpressure configuration.
62    pub fn with_config(config: WebSocketConfig) -> Self {
63        let (event_tx, _) = broadcast::channel(config.capacity);
64
65        Self {
66            event_tx,
67            clients: Arc::new(DashMap::new()),
68            config,
69        }
70    }
71
72    /// Broadcast an event to all connected WebSocket clients
73    pub fn broadcast_event(&self, event: Arc<Event>) {
74        // Send to broadcast channel (non-blocking)
75        let _ = self.event_tx.send(event);
76    }
77
78    /// Subscribe to the event broadcast channel (used by RESP3 SUBSCRIBE).
79    pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
80        self.event_tx.subscribe()
81    }
82
83    /// Handle a new WebSocket connection
84    pub async fn handle_socket(&self, socket: WebSocket) {
85        let client_id = Uuid::new_v4();
86        tracing::info!("🔌 WebSocket client connected: {}", client_id);
87
88        // Subscribe to broadcast channel
89        let event_rx = self.event_tx.subscribe();
90
91        // Split socket into sender and receiver
92        let (sender, mut receiver) = socket.split();
93
94        // Register client
95        self.clients.insert(
96            client_id,
97            ClientInfo {
98                id: client_id,
99                filters: EventFilters::default(),
100            },
101        );
102
103        // Spawn send task based on config
104        let clients = Arc::clone(&self.clients);
105        let config = self.config.clone();
106        let send_task = tokio::spawn(async move {
107            if let Some(interval_ms) = config.batch_interval_ms {
108                Self::send_batched(
109                    event_rx,
110                    sender,
111                    clients,
112                    client_id,
113                    interval_ms,
114                    config.max_batch_size,
115                )
116                .await;
117            } else {
118                Self::send_unbatched(event_rx, sender, clients, client_id).await;
119            }
120        });
121
122        // Handle incoming messages from client (for setting filters)
123        let clients = Arc::clone(&self.clients);
124        let recv_task = tokio::spawn(async move {
125            while let Some(Ok(msg)) = receiver.next().await {
126                if let Message::Text(text) = msg {
127                    // Parse filter commands (text is Utf8Bytes in axum 0.8+)
128                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
129                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
130                        if let Some(mut client) = clients.get_mut(&client_id) {
131                            client.filters = filters;
132                        }
133                    }
134                }
135            }
136        });
137
138        // Wait for either task to finish
139        tokio::select! {
140            _ = send_task => {
141                tracing::info!("Send task ended for client {}", client_id);
142            }
143            _ = recv_task => {
144                tracing::info!("Receive task ended for client {}", client_id);
145            }
146        }
147
148        // Clean up client
149        self.clients.remove(&client_id);
150        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
151    }
152
153    /// Unbatched send loop — original behavior (one message per event).
154    async fn send_unbatched(
155        mut event_rx: broadcast::Receiver<Arc<Event>>,
156        mut sender: futures::stream::SplitSink<WebSocket, Message>,
157        clients: Arc<DashMap<Uuid, ClientInfo>>,
158        client_id: Uuid,
159    ) {
160        loop {
161            match event_rx.recv().await {
162                Ok(event) => {
163                    if !Self::passes_filters(&clients, client_id, &event) {
164                        continue;
165                    }
166
167                    match serde_json::to_string(&*event) {
168                        Ok(json) => {
169                            if sender.send(Message::Text(json.into())).await.is_err() {
170                                tracing::warn!(
171                                    "Failed to send event to client {}",
172                                    client_id
173                                );
174                                break;
175                            }
176                        }
177                        Err(e) => {
178                            tracing::error!("Failed to serialize event: {}", e);
179                        }
180                    }
181                }
182                Err(broadcast::error::RecvError::Lagged(n)) => {
183                    let msg = serde_json::json!({"type": "lagged", "missed": n});
184                    let _ = sender
185                        .send(Message::Text(msg.to_string().into()))
186                        .await;
187                    tracing::warn!(
188                        "Client {} lagged, missed {} events",
189                        client_id,
190                        n
191                    );
192                }
193                Err(broadcast::error::RecvError::Closed) => break,
194            }
195        }
196    }
197
198    /// Batched send loop — buffers events and flushes periodically or on max batch size.
199    async fn send_batched(
200        mut event_rx: broadcast::Receiver<Arc<Event>>,
201        mut sender: futures::stream::SplitSink<WebSocket, Message>,
202        clients: Arc<DashMap<Uuid, ClientInfo>>,
203        client_id: Uuid,
204        interval_ms: u64,
205        max_batch_size: usize,
206    ) {
207        let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
208        let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
209        ticker.tick().await; // first tick completes immediately
210
211        loop {
212            tokio::select! {
213                result = event_rx.recv() => {
214                    match result {
215                        Ok(event) => {
216                            if !Self::passes_filters(&clients, client_id, &event) {
217                                continue;
218                            }
219
220                            if let Ok(val) = serde_json::to_value(&*event) {
221                                batch.push(val);
222                            }
223
224                            // Flush early if batch is full
225                            if batch.len() >= max_batch_size {
226                                if !Self::flush_batch(&mut sender, &mut batch, client_id).await {
227                                    break;
228                                }
229                            }
230                        }
231                        Err(broadcast::error::RecvError::Lagged(n)) => {
232                            // Flush any pending batch first
233                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
234                            let msg = serde_json::json!({"type": "lagged", "missed": n});
235                            let _ = sender
236                                .send(Message::Text(msg.to_string().into()))
237                                .await;
238                            tracing::warn!(
239                                "Client {} lagged, missed {} events",
240                                client_id,
241                                n
242                            );
243                        }
244                        Err(broadcast::error::RecvError::Closed) => {
245                            // Flush remaining batch before exit
246                            let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
247                            break;
248                        }
249                    }
250                }
251                _ = ticker.tick() => {
252                    if !batch.is_empty() {
253                        if !Self::flush_batch(&mut sender, &mut batch, client_id).await {
254                            break;
255                        }
256                    }
257                }
258            }
259        }
260    }
261
262    /// Flush the current batch as a JSON array. Returns false if send failed.
263    async fn flush_batch(
264        sender: &mut futures::stream::SplitSink<WebSocket, Message>,
265        batch: &mut Vec<serde_json::Value>,
266        client_id: Uuid,
267    ) -> bool {
268        if batch.is_empty() {
269            return true;
270        }
271
272        let json_array = serde_json::Value::Array(std::mem::take(batch));
273        match serde_json::to_string(&json_array) {
274            Ok(json) => {
275                if sender.send(Message::Text(json.into())).await.is_err() {
276                    tracing::warn!("Failed to send batch to client {}", client_id);
277                    return false;
278                }
279                true
280            }
281            Err(e) => {
282                tracing::error!("Failed to serialize batch: {}", e);
283                batch.clear();
284                true
285            }
286        }
287    }
288
289    /// Check if an event passes the client's filters.
290    fn passes_filters(
291        clients: &DashMap<Uuid, ClientInfo>,
292        client_id: Uuid,
293        event: &Event,
294    ) -> bool {
295        let filters = clients
296            .get(&client_id)
297            .map(|entry| entry.value().filters.clone())
298            .unwrap_or_default();
299
300        if let Some(ref entity_id) = filters.entity_id
301            && event.entity_id_str() != entity_id
302        {
303            return false;
304        }
305
306        if let Some(ref event_type) = filters.event_type
307            && event.event_type_str() != event_type
308        {
309            return false;
310        }
311
312        true
313    }
314
315    /// Get statistics about connected clients
316    pub fn stats(&self) -> WebSocketStats {
317        WebSocketStats {
318            connected_clients: self.clients.len(),
319            total_capacity: self.event_tx.receiver_count(),
320        }
321    }
322}
323
324impl Default for WebSocketManager {
325    fn default() -> Self {
326        Self::new()
327    }
328}
329
330#[derive(Debug, serde::Serialize)]
331pub struct WebSocketStats {
332    pub connected_clients: usize,
333    pub total_capacity: usize,
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a minimal event for exercising the manager.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        assert_eq!(WebSocketManager::new().stats().connected_clients, 0);
    }

    #[test]
    fn test_event_broadcast() {
        // Broadcasting with no subscribers must not panic.
        WebSocketManager::new().broadcast_event(Arc::new(create_test_event()));
    }

    #[test]
    fn test_config_defaults() {
        let defaults = WebSocketConfig::default();
        assert_eq!(defaults.capacity, 1000);
        assert_eq!(defaults.batch_interval_ms, None);
        assert_eq!(defaults.max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        // A tiny channel overflows after a handful of sends.
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        });

        // Subscribe, then overflow the channel.
        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        match rx.try_recv() {
            Err(broadcast::error::TryRecvError::Lagged(missed)) => {
                assert!(missed > 0, "should report missed events");
            }
            // Getting an event is acceptable — Lagged may surface on a
            // subsequent recv.
            Ok(_) => {}
            Err(other) => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_batch_mode_groups_events() {
        // with_config must carry the batching parameters through.
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config.clone());
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // The actual batching wire format: a JSON array of events.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        rt.block_on(async {
            let events: Vec<serde_json::Value> = (0..3)
                .map(|_| serde_json::to_value(&create_test_event()).unwrap())
                .collect();

            let wire = serde_json::to_string(&serde_json::Value::Array(events)).unwrap();
            let decoded: Vec<serde_json::Value> = serde_json::from_str(&wire).unwrap();
            assert_eq!(decoded.len(), 3);
        });
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        // Small max_batch_size with a long interval — early flush config.
        let manager = WebSocketManager::with_config(WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000), // long interval
            max_batch_size: 5,             // small batch — triggers early flush
        });
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_backward_compat_no_config() {
        // Default constructor must match pre-backpressure behavior.
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        // Broadcast still works.
        manager.broadcast_event(Arc::new(create_test_event()));

        assert_eq!(manager.stats().connected_clients, 0);
    }
}