// allsource_core/infrastructure/web/websocket.rs

use crate::domain::entities::Event;
use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::stream::{SplitSink, SplitStream};
use futures::{sink::SinkExt, stream::StreamExt};
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use uuid::Uuid;

9/// WebSocket manager for real-time event streaming (v0.2 feature)
10pub struct WebSocketManager {
11    /// Broadcast channel for sending events to all connected clients
12    event_tx: broadcast::Sender<Arc<Event>>,
13
14    /// Connected clients by ID - using DashMap for lock-free concurrent access
15    clients: Arc<DashMap<Uuid, ClientInfo>>,
16}
17
18#[derive(Debug, Clone)]
19struct ClientInfo {
20    id: Uuid,
21    filters: EventFilters,
22}
23
24#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
25pub struct EventFilters {
26    pub entity_id: Option<String>,
27    pub event_type: Option<String>,
28}
29
30impl WebSocketManager {
31    pub fn new() -> Self {
32        let (event_tx, _) = broadcast::channel(1000);
33
34        Self {
35            event_tx,
36            clients: Arc::new(DashMap::new()),
37        }
38    }
39
40    /// Broadcast an event to all connected WebSocket clients
41    pub fn broadcast_event(&self, event: Arc<Event>) {
42        // Send to broadcast channel (non-blocking)
43        let _ = self.event_tx.send(event);
44    }
45
46    /// Handle a new WebSocket connection
47    pub async fn handle_socket(&self, socket: WebSocket) {
48        let client_id = Uuid::new_v4();
49        tracing::info!("🔌 WebSocket client connected: {}", client_id);
50
51        // Subscribe to broadcast channel
52        let mut event_rx = self.event_tx.subscribe();
53
54        // Split socket into sender and receiver
55        let (mut sender, mut receiver) = socket.split();
56
57        // Register client
58        self.clients.insert(
59            client_id,
60            ClientInfo {
61                id: client_id,
62                filters: EventFilters::default(),
63            },
64        );
65
66        // Spawn task to send events to this client
67        let clients = Arc::clone(&self.clients);
68        let send_task = tokio::spawn(async move {
69            while let Ok(event) = event_rx.recv().await {
70                // Get client filters
71                let filters = clients
72                    .get(&client_id)
73                    .map(|entry| entry.value().filters.clone())
74                    .unwrap_or_default();
75
76                // Apply filters
77                if let Some(ref entity_id) = filters.entity_id {
78                    if event.entity_id_str() != entity_id {
79                        continue;
80                    }
81                }
82
83                if let Some(ref event_type) = filters.event_type {
84                    if event.event_type_str() != event_type {
85                        continue;
86                    }
87                }
88
89                // Serialize event to JSON
90                match serde_json::to_string(&*event) {
91                    Ok(json) => {
92                        if sender.send(Message::Text(json.into())).await.is_err() {
93                            tracing::warn!("Failed to send event to client {}", client_id);
94                            break;
95                        }
96                    }
97                    Err(e) => {
98                        tracing::error!("Failed to serialize event: {}", e);
99                    }
100                }
101            }
102        });
103
104        // Handle incoming messages from client (for setting filters)
105        let clients = Arc::clone(&self.clients);
106        let recv_task = tokio::spawn(async move {
107            while let Some(Ok(msg)) = receiver.next().await {
108                if let Message::Text(text) = msg {
109                    // Parse filter commands (text is Utf8Bytes in axum 0.8+)
110                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
111                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
112                        if let Some(mut client) = clients.get_mut(&client_id) {
113                            client.filters = filters;
114                        }
115                    }
116                }
117            }
118        });
119
120        // Wait for either task to finish
121        tokio::select! {
122            _ = send_task => {
123                tracing::info!("Send task ended for client {}", client_id);
124            }
125            _ = recv_task => {
126                tracing::info!("Receive task ended for client {}", client_id);
127            }
128        }
129
130        // Clean up client
131        self.clients.remove(&client_id);
132        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
133    }
134
135    /// Get statistics about connected clients
136    pub fn stats(&self) -> WebSocketStats {
137        WebSocketStats {
138            connected_clients: self.clients.len(),
139            total_capacity: self.event_tx.receiver_count(),
140        }
141    }
142}
143
144impl Default for WebSocketManager {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150#[derive(Debug, serde::Serialize)]
151pub struct WebSocketStats {
152    pub connected_clients: usize,
153    pub total_capacity: usize,
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event for exercising the manager.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // The constructor drops its initial receiver, so nothing is subscribed yet.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());

        // No subscribers: the send error must be swallowed, not panic.
        manager.broadcast_event(event);
    }

    #[test]
    fn test_event_broadcast_reaches_subscriber() {
        let manager = WebSocketManager::new();
        let mut rx = manager.event_tx.subscribe();
        assert_eq!(manager.stats().total_capacity, 1);

        let event = Arc::new(create_test_event());
        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_event_filters_deserialize_partial_json() {
        let filters: EventFilters = serde_json::from_str(r#"{"entity_id":"test-entity"}"#)
            .expect("partial filter JSON should parse");
        assert_eq!(filters.entity_id.as_deref(), Some("test-entity"));
        assert_eq!(filters.event_type, None);

        let empty: EventFilters =
            serde_json::from_str("{}").expect("an empty object should parse");
        assert!(empty.entity_id.is_none());
        assert!(empty.event_type.is_none());
    }
}