allsource_core/infrastructure/web/
websocket.rs1use crate::domain::entities::Event;
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
/// Tuning knobs for [`WebSocketManager`].
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Capacity of the shared broadcast channel. A client that falls further
    /// behind than this is sent a `{"type": "lagged", "missed": n}` notice.
    /// Must be non-zero: `tokio::sync::broadcast::channel` panics on 0.
    pub capacity: usize,
    /// `Some(ms)`: buffer events per client and flush them as one JSON array
    /// every `ms` milliseconds. `None`: send each event as its own text frame.
    pub batch_interval_ms: Option<u64>,
    /// In batched mode, flush early once a batch holds this many events.
    pub max_batch_size: usize,
}
21
22impl Default for WebSocketConfig {
23 fn default() -> Self {
24 Self {
25 capacity: 1000,
26 batch_interval_ms: None,
27 max_batch_size: 100,
28 }
29 }
30}
31
/// Fans out domain events to connected WebSocket clients.
///
/// Each connection subscribes its own receiver on one shared broadcast channel
/// and is registered in `clients` together with the filters it last sent.
pub struct WebSocketManager {
    /// Publishing side of the event broadcast channel.
    event_tx: broadcast::Sender<Arc<Event>>,

    /// Connected clients keyed by their per-connection id.
    clients: Arc<DashMap<Uuid, ClientInfo>>,

    /// Settings cloned into each connection's send task.
    config: WebSocketConfig,
}
43
/// Per-connection state tracked by [`WebSocketManager`].
#[derive(Debug, Clone)]
struct ClientInfo {
    /// Connection id (same value as the map key).
    /// NOTE(review): not read anywhere in this module as shown; rustc may
    /// report it as dead code.
    id: Uuid,
    /// Filters applied to outgoing events; replaced whenever the client sends
    /// a text frame that parses as [`EventFilters`].
    filters: EventFilters,
}
49
/// Optional per-client event filters, sent by the client as a JSON text frame
/// such as `{"entity_id": "...", "event_type": "..."}`.
///
/// Each `Some` field must equal the event's value exactly for the event to be
/// forwarded; a `None` field matches every event.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EventFilters {
    /// Only forward events whose entity id equals this value.
    pub entity_id: Option<String>,
    /// Only forward events whose event type equals this value.
    pub event_type: Option<String>,
}
55
56impl WebSocketManager {
57 pub fn new() -> Self {
58 Self::with_config(WebSocketConfig::default())
59 }
60
61 pub fn with_config(config: WebSocketConfig) -> Self {
63 let (event_tx, _) = broadcast::channel(config.capacity);
64
65 Self {
66 event_tx,
67 clients: Arc::new(DashMap::new()),
68 config,
69 }
70 }
71
72 pub fn broadcast_event(&self, event: Arc<Event>) {
74 let _ = self.event_tx.send(event);
76 }
77
78 pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
80 self.event_tx.subscribe()
81 }
82
83 pub async fn handle_socket(&self, socket: WebSocket) {
85 let client_id = Uuid::new_v4();
86 tracing::info!("🔌 WebSocket client connected: {}", client_id);
87
88 let event_rx = self.event_tx.subscribe();
90
91 let (sender, mut receiver) = socket.split();
93
94 self.clients.insert(
96 client_id,
97 ClientInfo {
98 id: client_id,
99 filters: EventFilters::default(),
100 },
101 );
102
103 let clients = Arc::clone(&self.clients);
105 let config = self.config.clone();
106 let send_task = tokio::spawn(async move {
107 if let Some(interval_ms) = config.batch_interval_ms {
108 Self::send_batched(
109 event_rx,
110 sender,
111 clients,
112 client_id,
113 interval_ms,
114 config.max_batch_size,
115 )
116 .await;
117 } else {
118 Self::send_unbatched(event_rx, sender, clients, client_id).await;
119 }
120 });
121
122 let clients = Arc::clone(&self.clients);
124 let recv_task = tokio::spawn(async move {
125 while let Some(Ok(msg)) = receiver.next().await {
126 if let Message::Text(text) = msg {
127 if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
129 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
130 if let Some(mut client) = clients.get_mut(&client_id) {
131 client.filters = filters;
132 }
133 }
134 }
135 }
136 });
137
138 tokio::select! {
140 _ = send_task => {
141 tracing::info!("Send task ended for client {}", client_id);
142 }
143 _ = recv_task => {
144 tracing::info!("Receive task ended for client {}", client_id);
145 }
146 }
147
148 self.clients.remove(&client_id);
150 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
151 }
152
153 async fn send_unbatched(
155 mut event_rx: broadcast::Receiver<Arc<Event>>,
156 mut sender: futures::stream::SplitSink<WebSocket, Message>,
157 clients: Arc<DashMap<Uuid, ClientInfo>>,
158 client_id: Uuid,
159 ) {
160 loop {
161 match event_rx.recv().await {
162 Ok(event) => {
163 if !Self::passes_filters(&clients, client_id, &event) {
164 continue;
165 }
166
167 match serde_json::to_string(&*event) {
168 Ok(json) => {
169 if sender.send(Message::Text(json.into())).await.is_err() {
170 tracing::warn!("Failed to send event to client {}", client_id);
171 break;
172 }
173 }
174 Err(e) => {
175 tracing::error!("Failed to serialize event: {}", e);
176 }
177 }
178 }
179 Err(broadcast::error::RecvError::Lagged(n)) => {
180 let msg = serde_json::json!({"type": "lagged", "missed": n});
181 let _ = sender.send(Message::Text(msg.to_string().into())).await;
182 tracing::warn!("Client {} lagged, missed {} events", client_id, n);
183 }
184 Err(broadcast::error::RecvError::Closed) => break,
185 }
186 }
187 }
188
189 async fn send_batched(
191 mut event_rx: broadcast::Receiver<Arc<Event>>,
192 mut sender: futures::stream::SplitSink<WebSocket, Message>,
193 clients: Arc<DashMap<Uuid, ClientInfo>>,
194 client_id: Uuid,
195 interval_ms: u64,
196 max_batch_size: usize,
197 ) {
198 let mut batch: Vec<serde_json::Value> = Vec::with_capacity(max_batch_size);
199 let mut ticker = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
200 ticker.tick().await; loop {
203 tokio::select! {
204 result = event_rx.recv() => {
205 match result {
206 Ok(event) => {
207 if !Self::passes_filters(&clients, client_id, &event) {
208 continue;
209 }
210
211 if let Ok(val) = serde_json::to_value(&*event) {
212 batch.push(val);
213 }
214
215 if batch.len() >= max_batch_size
217 && !Self::flush_batch(&mut sender, &mut batch, client_id).await
218 {
219 break;
220 }
221 }
222 Err(broadcast::error::RecvError::Lagged(n)) => {
223 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
225 let msg = serde_json::json!({"type": "lagged", "missed": n});
226 let _ = sender
227 .send(Message::Text(msg.to_string().into()))
228 .await;
229 tracing::warn!(
230 "Client {} lagged, missed {} events",
231 client_id,
232 n
233 );
234 }
235 Err(broadcast::error::RecvError::Closed) => {
236 let _ = Self::flush_batch(&mut sender, &mut batch, client_id).await;
238 break;
239 }
240 }
241 }
242 _ = ticker.tick() => {
243 if !batch.is_empty()
244 && !Self::flush_batch(&mut sender, &mut batch, client_id).await
245 {
246 break;
247 }
248 }
249 }
250 }
251 }
252
253 async fn flush_batch(
255 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
256 batch: &mut Vec<serde_json::Value>,
257 client_id: Uuid,
258 ) -> bool {
259 if batch.is_empty() {
260 return true;
261 }
262
263 let json_array = serde_json::Value::Array(std::mem::take(batch));
264 match serde_json::to_string(&json_array) {
265 Ok(json) => {
266 if sender.send(Message::Text(json.into())).await.is_err() {
267 tracing::warn!("Failed to send batch to client {}", client_id);
268 return false;
269 }
270 true
271 }
272 Err(e) => {
273 tracing::error!("Failed to serialize batch: {}", e);
274 batch.clear();
275 true
276 }
277 }
278 }
279
280 fn passes_filters(clients: &DashMap<Uuid, ClientInfo>, client_id: Uuid, event: &Event) -> bool {
282 let filters = clients
283 .get(&client_id)
284 .map(|entry| entry.value().filters.clone())
285 .unwrap_or_default();
286
287 if let Some(ref entity_id) = filters.entity_id
288 && event.entity_id_str() != entity_id
289 {
290 return false;
291 }
292
293 if let Some(ref event_type) = filters.event_type
294 && event.event_type_str() != event_type
295 {
296 return false;
297 }
298
299 true
300 }
301
302 pub fn stats(&self) -> WebSocketStats {
304 WebSocketStats {
305 connected_clients: self.clients.len(),
306 total_capacity: self.event_tx.receiver_count(),
307 }
308 }
309}
310
311impl Default for WebSocketManager {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
/// Snapshot of [`WebSocketManager`] state, as returned by `WebSocketManager::stats`.
#[derive(Debug, serde::Serialize)]
pub struct WebSocketStats {
    /// Number of clients currently registered with the manager.
    pub connected_clients: usize,
    /// NOTE(review): despite the name, `stats()` fills this with the broadcast
    /// channel's `receiver_count()` (live subscribers, including any created via
    /// `subscribe_events`), not the configured channel capacity.
    pub total_capacity: usize,
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a throwaway event for broadcast and filter tests.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    /// Registers a fresh client with `filters` and returns its id.
    fn register_client(clients: &DashMap<Uuid, ClientInfo>, filters: EventFilters) -> Uuid {
        let id = Uuid::new_v4();
        clients.insert(id, ClientInfo { id, filters });
        id
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // The receiver created alongside the sender is dropped immediately.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::default();
        assert_eq!(config.capacity, 1000);
        assert_eq!(config.batch_interval_ms, None);
        assert_eq!(config.max_batch_size, 100);
    }

    #[test]
    fn test_lagged_notification() {
        let config = WebSocketConfig {
            capacity: 2,
            batch_interval_ms: None,
            max_batch_size: 100,
        };
        let manager = WebSocketManager::with_config(config);

        let mut rx = manager.subscribe_events();
        for _ in 0..5 {
            manager.broadcast_event(Arc::new(create_test_event()));
        }

        // Capacity 2 retains only the newest two events, so the first read
        // reports the three that were overwritten.
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Lagged(3))
        ));
        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_ok());
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_batch_mode_groups_events() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(50),
            max_batch_size: 10,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.batch_interval_ms, Some(50));
        assert_eq!(manager.config.max_batch_size, 10);

        // A batch goes over the wire as one JSON array of serialized events.
        let events: Vec<serde_json::Value> = (0..3)
            .map(|_| serde_json::to_value(create_test_event()).unwrap())
            .collect();
        let serialized = serde_json::to_string(&serde_json::Value::Array(events)).unwrap();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(parsed.len(), 3);
    }

    #[test]
    fn test_batch_flush_on_max_size() {
        let config = WebSocketConfig {
            capacity: 1000,
            batch_interval_ms: Some(1000),
            max_batch_size: 5,
        };
        let manager = WebSocketManager::with_config(config);
        assert_eq!(manager.config.max_batch_size, 5);
    }

    #[test]
    fn test_backward_compat_no_config() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.config.capacity, 1000);
        assert!(manager.config.batch_interval_ms.is_none());

        // Broadcasting with no subscribers must not panic.
        let event = Arc::new(create_test_event());
        manager.broadcast_event(event);

        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
    }

    #[test]
    fn test_default_filters_pass_everything() {
        let clients = DashMap::new();
        let id = register_client(&clients, EventFilters::default());
        let event = create_test_event();

        assert!(WebSocketManager::passes_filters(&clients, id, &event));
        // Unknown clients fall back to default filters.
        assert!(WebSocketManager::passes_filters(&clients, Uuid::new_v4(), &event));
    }

    #[test]
    fn test_filters_match_and_reject() {
        let event = create_test_event();
        let clients = DashMap::new();

        let matching = register_client(
            &clients,
            EventFilters {
                entity_id: Some(event.entity_id_str().to_string()),
                event_type: Some(event.event_type_str().to_string()),
            },
        );
        assert!(WebSocketManager::passes_filters(&clients, matching, &event));

        let wrong_entity = register_client(
            &clients,
            EventFilters {
                entity_id: Some("no-such-entity".to_string()),
                event_type: None,
            },
        );
        assert!(!WebSocketManager::passes_filters(&clients, wrong_entity, &event));

        let wrong_type = register_client(
            &clients,
            EventFilters {
                entity_id: None,
                event_type: Some("no.such.type".to_string()),
            },
        );
        assert!(!WebSocketManager::passes_filters(&clients, wrong_type, &event));
    }

    #[test]
    fn test_filters_deserialize_partial_json() {
        let filters: EventFilters =
            serde_json::from_str(r#"{"event_type":"order.created"}"#).unwrap();
        assert_eq!(filters.event_type.as_deref(), Some("order.created"));
        assert!(filters.entity_id.is_none());
    }
}