// allsource_core/infrastructure/web/websocket.rs
use crate::domain::entities::Event;
use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::{sink::SinkExt, stream::StreamExt};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use uuid::Uuid;

/// Tuning knobs for [`WebSocketManager`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Capacity of the single broadcast channel shared by all connected clients.
    ///
    /// A client whose receiver falls more than this many events behind is told how many
    /// events it missed. Should be at least 1.
    pub capacity: usize,
    /// Batch flush interval in milliseconds.
    ///
    /// `None` sends each event as its own text frame. `Some(ms)` accumulates events and
    /// sends them as one JSON array every `ms` milliseconds, or sooner when
    /// `max_batch_size` is reached.
    pub batch_interval_ms: Option<u64>,
    /// Number of queued events that forces an early flush in batched mode.
    ///
    /// Ignored when `batch_interval_ms` is `None`.
    pub max_batch_size: usize,
}

22impl Default for WebSocketConfig {
23 fn default() -> Self {
24 Self {
25 capacity: 1000,
26 batch_interval_ms: None,
27 max_batch_size: 100,
28 }
29 }
30}
31
32pub struct WebSocketManager {
34 event_tx: broadcast::Sender<Arc<Event>>,
36
37 clients: Arc<DashMap<Uuid, ClientInfo>>,
39
40 config: WebSocketConfig,
42}
43
44#[derive(Debug, Clone)]
45struct ClientInfo {
46 id: Uuid,
47 filters: EventFilters,
48}
49
50#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
51pub struct EventFilters {
52 pub entity_id: Option<String>,
53 pub event_type: Option<String>,
54}
55
56impl WebSocketManager {
57 pub fn new() -> Self {
58 Self::with_config(WebSocketConfig::default())
59 }
60
61 pub fn with_config(config: WebSocketConfig) -> Self {
63 let (event_tx, _) = broadcast::channel(config.capacity);
64
65 Self {
66 event_tx,
67 clients: Arc::new(DashMap::new()),
68 config,
69 }
70 }
71
72 pub fn broadcast_event(&self, event: Arc<Event>) {
74 let _ = self.event_tx.send(event);
76 }
77
78 pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
80 self.event_tx.subscribe()
81 }
82
83 pub async fn handle_socket(&self, socket: WebSocket) {
85 let client_id = Uuid::new_v4();
86 tracing::info!("🔌 WebSocket client connected: {}", client_id);
87
88 let event_rx = self.event_tx.subscribe();
90
91 let (sender, mut receiver) = socket.split();
93
94 self.clients.insert(
96 client_id,
97 ClientInfo {
98 id: client_id,
99 filters: EventFilters::default(),
100 },
101 );
102
103 let clients = Arc::clone(&self.clients);
105 let config = self.config.clone();
106 let send_task = tokio::spawn(async move {
107 if let Some(interval_ms) = config.batch_interval_ms {
108 Self::send_batched(
109 event_rx,
110 sender,
111 clients,
112 client_id,
113 interval_ms,
114 config.max_batch_size,
115 )
116 .await;
117 } else {
118 Self::send_unbatched(event_rx, sender, clients, client_id).await;
119 }
120 });
121
122 let clients = Arc::clone(&self.clients);
124 let recv_task = tokio::spawn(async move {
125 while let Some(Ok(msg)) = receiver.next().await {
126 if let Message::Text(text) = msg {
127 if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
129 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
130 if let Some(mut client) = clients.get_mut(&client_id) {
131 client.filters = filters;
132 }
133 }
134 }
135 }
136 });
137
138 tokio::select! {
140 _ = send_task => {
141 tracing::info!("Send task ended for client {}", client_id);
142 }
143 _ = recv_task => {
144 tracing::info!("Receive task ended for client {}", client_id);
145 }
146 }
147
148 self.clients.remove(&client_id);
150 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
151 }
152
153 async fn send_unbatched(
155 mut event_rx: broadcast::Receiver<Arc<Event>>,
156 mut sender: futures::stream::SplitSink<WebSocket, Message>,
157 clients: Arc<DashMap<Uuid, ClientInfo>>,
158 client_id: Uuid,
159 ) {
160 loop {
161 match event_rx.recv().await {
162 Ok(event) => {
163 if !Self::passes_filters(&clients, client_id, &event) {
164 continue;
165 }
166
167 match serde_json::to_string(&*event) {
168 Ok(json) => {
169 if sender.send(Message::Text(json.into())).await.is_err() {
170 tracing::warn!(
171 "Failed to send event to client {}",
172 client_id
173 );
174 break;
175 }
176 }
177 Err(e) => {
178 tracing::error!("Failed to serialize event: {}", e);
179 }
180 }
181 }
182 Err(broadcast::error::RecvError::Lagged(n)) => {
183 let msg = serde_json::json!({"type": "lagged", "missed": n});
184 let _ = sender
185 .send(Message::Text(msg.to_string().into()))
186 .await;
187 tracing::warn!(
188 "Client {} lagged, missed {} events",
189 client_id,
190 n
191 );
192 }
193 Err(broadcast::error::RecvError::Closed) => break,
194 }
195 }
196 }
197
198 async fn send_batched(
200 mut event_rx: broadcast::Receiver<Arc<Event>>,
201 mut sender: futures::stream::SplitSink<WebSocket, Message>,
202 clients: Arc<DashMap<Uuid, ClientInfo>>,
203 client_id: Uuid,
204 interval_ms: u64,
205 max_batch_size: usize,
206 ) {
207 let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
208 let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
209 ticker.tick().await; loop {
212 tokio::select! {
213 result = event_rx.recv() => {
214 match result {
215 Ok(event) => {
216 if !Self::passes_filters(&clients, client_id, &event) {
217 continue;
218 }
219
220 if let Ok(val) = serde_json::to_value(&*event) {
221 batch.push(val);
222 }
223
224 if batch.len() >= max_batch_size {
226 if !Self::flush_batch(&mut sender, &mut batch, client_id).await {
227 break;
228 }
229 }
230 }
231 Err(broadcast::error::RecvError::Lagged(n)) => {
232 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
234 let msg = serde_json::json!({"type": "lagged", "missed": n});
235 let _ = sender
236 .send(Message::Text(msg.to_string().into()))
237 .await;
238 tracing::warn!(
239 "Client {} lagged, missed {} events",
240 client_id,
241 n
242 );
243 }
244 Err(broadcast::error::RecvError::Closed) => {
245 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
247 break;
248 }
249 }
250 }
251 _ = ticker.tick() => {
252 if !batch.is_empty() {
253 if !Self::flush_batch(&mut sender, &mut batch, client_id).await {
254 break;
255 }
256 }
257 }
258 }
259 }
260 }
261
262 async fn flush_batch(
264 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
265 batch: &mut Vec<serde_json::Value>,
266 client_id: Uuid,
267 ) -> bool {
268 if batch.is_empty() {
269 return true;
270 }
271
272 let json_array = serde_json::Value::Array(std::mem::take(batch));
273 match serde_json::to_string(&json_array) {
274 Ok(json) => {
275 if sender.send(Message::Text(json.into())).await.is_err() {
276 tracing::warn!("Failed to send batch to client {}", client_id);
277 return false;
278 }
279 true
280 }
281 Err(e) => {
282 tracing::error!("Failed to serialize batch: {}", e);
283 batch.clear();
284 true
285 }
286 }
287 }
288
289 fn passes_filters(
291 clients: &DashMap<Uuid, ClientInfo>,
292 client_id: Uuid,
293 event: &Event,
294 ) -> bool {
295 let filters = clients
296 .get(&client_id)
297 .map(|entry| entry.value().filters.clone())
298 .unwrap_or_default();
299
300 if let Some(ref entity_id) = filters.entity_id
301 && event.entity_id_str() != entity_id
302 {
303 return false;
304 }
305
306 if let Some(ref event_type) = filters.event_type
307 && event.event_type_str() != event_type
308 {
309 return false;
310 }
311
312 true
313 }
314
315 pub fn stats(&self) -> WebSocketStats {
317 WebSocketStats {
318 connected_clients: self.clients.len(),
319 total_capacity: self.event_tx.receiver_count(),
320 }
321 }
322}
323
324impl Default for WebSocketManager {
325 fn default() -> Self {
326 Self::new()
327 }
328}
329
330#[derive(Debug, serde::Serialize)]
331pub struct WebSocketStats {
332 pub connected_clients: usize,
333 pub total_capacity: usize,
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a throwaway event with a fresh id and the current timestamp.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Registers (or replaces) `client_id` in `clients` with the given filters.
    fn register(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, filters: EventFilters) {
        clients.insert(
            client_id,
            ClientInfo {
                id: client_id,
                filters,
            },
        );
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(
            Arc::ptr_eq(&received, &event),
            "broadcast should fan out the same Arc, not a copy"
        );
    }

    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::default();
        assert_eq!(config.capacity, 1000);
        assert_eq!(config.batch_interval_ms, None);
        assert_eq!(config.max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        let config = WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        };
        let manager = WebSocketManager::with_config(config);

        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        // A 2-slot channel retains only the last two of five events, so the receiver has
        // deterministically missed exactly three.
        match rx.try_recv() {
            Err(broadcast::error::TryRecvError::Lagged(n)) => {
                assert_eq!(n, 3, "should report the number of dropped events");
            }
            Ok(_) => panic!("expected a Lagged error, but an event was received"),
            Err(e) => panic!("unexpected error: {:?}", e),
        }

        // After reporting the gap the receiver resumes with the retained events.
        assert!(rx.try_recv().is_ok(), "receiver should resume after lagging");
    }

    #[test]
    fn test_passes_filters() {
        let event = create_test_event();
        let clients: DashMap<Uuid, ClientInfo> = DashMap::new();
        let client_id = Uuid::new_v4();

        // Unknown clients have no filters and receive everything.
        assert!(WebSocketManager::passes_filters(&clients, client_id, &event));

        register(&clients, client_id, EventFilters::default());
        assert!(WebSocketManager::passes_filters(&clients, client_id, &event));

        register(
            &clients,
            client_id,
            EventFilters {
                entity_id: Some(event.entity_id_str().to_string()),
                event_type: Some(event.event_type_str().to_string()),
            },
        );
        assert!(WebSocketManager::passes_filters(&clients, client_id, &event));

        register(
            &clients,
            client_id,
            EventFilters {
                entity_id: Some("no-such-entity".to_string()),
                event_type: None,
            },
        );
        assert!(!WebSocketManager::passes_filters(&clients, client_id, &event));

        register(
            &clients,
            client_id,
            EventFilters {
                entity_id: None,
                event_type: Some("no.such.type".to_string()),
            },
        );
        assert!(!WebSocketManager::passes_filters(&clients, client_id, &event));
    }

    #[test]
    fn test_batch_mode_groups_events() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // A flushed batch is one JSON array of serialized events; check that shape round-trips.
        let events: Vec<serde_json::Value> = (0..3)
            .map(|_| serde_json::to_value(&create_test_event()).unwrap())
            .collect();

        let json_array = serde_json::Value::Array(events);
        let serialized = serde_json::to_string(&json_array).unwrap();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(parsed.len(), 3);
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000),
            max_batch_size: 5,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_backward_compat_no_config() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        // Broadcasting with no subscribers must be a silent no-op.
        let event = Arc::new(create_test_event());
        manager.broadcast_event(event);

        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }
}