Skip to main content

allsource_core/infrastructure/web/
websocket.rs

1use crate::domain::entities::Event;
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
9/// WebSocket manager for real-time event streaming (v0.2 feature)
10pub struct WebSocketManager {
11    /// Broadcast channel for sending events to all connected clients
12    event_tx: broadcast::Sender<Arc<Event>>,
13
14    /// Connected clients by ID - using DashMap for lock-free concurrent access
15    clients: Arc<DashMap<Uuid, ClientInfo>>,
16}
17
18#[derive(Debug, Clone)]
19struct ClientInfo {
20    id: Uuid,
21    filters: EventFilters,
22}
23
24#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
25pub struct EventFilters {
26    pub entity_id: Option<String>,
27    pub event_type: Option<String>,
28}
29
30impl WebSocketManager {
31    pub fn new() -> Self {
32        let (event_tx, _) = broadcast::channel(1000);
33
34        Self {
35            event_tx,
36            clients: Arc::new(DashMap::new()),
37        }
38    }
39
40    /// Broadcast an event to all connected WebSocket clients
41    pub fn broadcast_event(&self, event: Arc<Event>) {
42        // Send to broadcast channel (non-blocking)
43        let _ = self.event_tx.send(event);
44    }
45
46    /// Subscribe to the event broadcast channel (used by RESP3 SUBSCRIBE).
47    pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
48        self.event_tx.subscribe()
49    }
50
51    /// Handle a new WebSocket connection
52    pub async fn handle_socket(&self, socket: WebSocket) {
53        let client_id = Uuid::new_v4();
54        tracing::info!("🔌 WebSocket client connected: {}", client_id);
55
56        // Subscribe to broadcast channel
57        let mut event_rx = self.event_tx.subscribe();
58
59        // Split socket into sender and receiver
60        let (mut sender, mut receiver) = socket.split();
61
62        // Register client
63        self.clients.insert(
64            client_id,
65            ClientInfo {
66                id: client_id,
67                filters: EventFilters::default(),
68            },
69        );
70
71        // Spawn task to send events to this client
72        let clients = Arc::clone(&self.clients);
73        let send_task = tokio::spawn(async move {
74            while let Ok(event) = event_rx.recv().await {
75                // Get client filters
76                let filters = clients
77                    .get(&client_id)
78                    .map(|entry| entry.value().filters.clone())
79                    .unwrap_or_default();
80
81                // Apply filters
82                if let Some(ref entity_id) = filters.entity_id
83                    && event.entity_id_str() != entity_id
84                {
85                    continue;
86                }
87
88                if let Some(ref event_type) = filters.event_type
89                    && event.event_type_str() != event_type
90                {
91                    continue;
92                }
93
94                // Serialize event to JSON
95                match serde_json::to_string(&*event) {
96                    Ok(json) => {
97                        if sender.send(Message::Text(json.into())).await.is_err() {
98                            tracing::warn!("Failed to send event to client {}", client_id);
99                            break;
100                        }
101                    }
102                    Err(e) => {
103                        tracing::error!("Failed to serialize event: {}", e);
104                    }
105                }
106            }
107        });
108
109        // Handle incoming messages from client (for setting filters)
110        let clients = Arc::clone(&self.clients);
111        let recv_task = tokio::spawn(async move {
112            while let Some(Ok(msg)) = receiver.next().await {
113                if let Message::Text(text) = msg {
114                    // Parse filter commands (text is Utf8Bytes in axum 0.8+)
115                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
116                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
117                        if let Some(mut client) = clients.get_mut(&client_id) {
118                            client.filters = filters;
119                        }
120                    }
121                }
122            }
123        });
124
125        // Wait for either task to finish
126        tokio::select! {
127            _ = send_task => {
128                tracing::info!("Send task ended for client {}", client_id);
129            }
130            _ = recv_task => {
131                tracing::info!("Receive task ended for client {}", client_id);
132            }
133        }
134
135        // Clean up client
136        self.clients.remove(&client_id);
137        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
138    }
139
140    /// Get statistics about connected clients
141    pub fn stats(&self) -> WebSocketStats {
142        WebSocketStats {
143            connected_clients: self.clients.len(),
144            total_capacity: self.event_tx.receiver_count(),
145        }
146    }
147}
148
149impl Default for WebSocketManager {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155#[derive(Debug, serde::Serialize)]
156pub struct WebSocketStats {
157    pub connected_clients: usize,
158    pub total_capacity: usize,
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event with a fixed type and entity id.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());

        // With no subscribers this must be a silent no-op, not a panic.
        manager.broadcast_event(event);
    }

    #[test]
    fn test_event_broadcast_reaches_subscribers() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_stats_counts_broadcast_receivers() {
        let manager = WebSocketManager::new();
        let _rx1 = manager.subscribe_events();
        let _rx2 = manager.subscribe_events();

        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        assert_eq!(stats.total_capacity, 2);
    }
}