// allsource_core/infrastructure/web/websocket.rs

use crate::domain::entities::Event;
use axum::extract::ws::{Message, WebSocket};
use futures::{sink::SinkExt, stream::StreamExt};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use uuid::Uuid;
9
10/// WebSocket manager for real-time event streaming (v0.2 feature)
11pub struct WebSocketManager {
12    /// Broadcast channel for sending events to all connected clients
13    event_tx: broadcast::Sender<Arc<Event>>,
14
15    /// Connected clients by ID
16    clients: Arc<RwLock<HashMap<Uuid, ClientInfo>>>,
17}
18
19#[derive(Debug, Clone)]
20struct ClientInfo {
21    id: Uuid,
22    filters: EventFilters,
23}
24
25#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
26pub struct EventFilters {
27    pub entity_id: Option<String>,
28    pub event_type: Option<String>,
29}
30
31impl WebSocketManager {
32    pub fn new() -> Self {
33        let (event_tx, _) = broadcast::channel(1000);
34
35        Self {
36            event_tx,
37            clients: Arc::new(RwLock::new(HashMap::new())),
38        }
39    }
40
41    /// Broadcast an event to all connected WebSocket clients
42    pub fn broadcast_event(&self, event: Arc<Event>) {
43        // Send to broadcast channel (non-blocking)
44        let _ = self.event_tx.send(event);
45    }
46
47    /// Handle a new WebSocket connection
48    pub async fn handle_socket(&self, socket: WebSocket) {
49        let client_id = Uuid::new_v4();
50        tracing::info!("🔌 WebSocket client connected: {}", client_id);
51
52        // Subscribe to broadcast channel
53        let mut event_rx = self.event_tx.subscribe();
54
55        // Split socket into sender and receiver
56        let (mut sender, mut receiver) = socket.split();
57
58        // Register client
59        self.clients.write().insert(
60            client_id,
61            ClientInfo {
62                id: client_id,
63                filters: EventFilters::default(),
64            },
65        );
66
67        // Spawn task to send events to this client
68        let clients = Arc::clone(&self.clients);
69        let send_task = tokio::spawn(async move {
70            while let Ok(event) = event_rx.recv().await {
71                // Get client filters
72                let filters = {
73                    let clients_lock = clients.read();
74                    clients_lock
75                        .get(&client_id)
76                        .map(|c| c.filters.clone())
77                        .unwrap_or_default()
78                };
79
80                // Apply filters
81                if let Some(ref entity_id) = filters.entity_id {
82                    if event.entity_id_str() != entity_id {
83                        continue;
84                    }
85                }
86
87                if let Some(ref event_type) = filters.event_type {
88                    if event.event_type_str() != event_type {
89                        continue;
90                    }
91                }
92
93                // Serialize event to JSON
94                match serde_json::to_string(&*event) {
95                    Ok(json) => {
96                        if sender.send(Message::Text(json.into())).await.is_err() {
97                            tracing::warn!("Failed to send event to client {}", client_id);
98                            break;
99                        }
100                    }
101                    Err(e) => {
102                        tracing::error!("Failed to serialize event: {}", e);
103                    }
104                }
105            }
106        });
107
108        // Handle incoming messages from client (for setting filters)
109        let clients = Arc::clone(&self.clients);
110        let recv_task = tokio::spawn(async move {
111            while let Some(Ok(msg)) = receiver.next().await {
112                if let Message::Text(text) = msg {
113                    // Parse filter commands (text is Utf8Bytes in axum 0.8+)
114                    if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
115                        tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
116                        if let Some(client) = clients.write().get_mut(&client_id) {
117                            client.filters = filters;
118                        }
119                    }
120                }
121            }
122        });
123
124        // Wait for either task to finish
125        tokio::select! {
126            _ = send_task => {
127                tracing::info!("Send task ended for client {}", client_id);
128            }
129            _ = recv_task => {
130                tracing::info!("Receive task ended for client {}", client_id);
131            }
132        }
133
134        // Clean up client
135        self.clients.write().remove(&client_id);
136        tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
137    }
138
139    /// Get statistics about connected clients
140    pub fn stats(&self) -> WebSocketStats {
141        let clients = self.clients.read();
142        WebSocketStats {
143            connected_clients: clients.len(),
144            total_capacity: self.event_tx.receiver_count(),
145        }
146    }
147}
148
149impl Default for WebSocketManager {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155#[derive(Debug, serde::Serialize)]
156pub struct WebSocketStats {
157    pub connected_clients: usize,
158    pub total_capacity: usize,
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event with type `test.event` and entity `test-entity`.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // No connection has subscribed yet, so the broadcast channel has no receivers.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());

        // Broadcasting with no subscribers must be a silent no-op, not a panic.
        manager.broadcast_event(event);
    }

    #[test]
    fn test_event_broadcast_reaches_subscriber() {
        let manager = WebSocketManager::new();
        let mut rx = manager.event_tx.subscribe();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_event_filters_deserialize() {
        let filters: EventFilters =
            serde_json::from_str(r#"{"event_type":"test.event"}"#).unwrap();
        assert_eq!(filters.entity_id, None);
        assert_eq!(filters.event_type.as_deref(), Some("test.event"));

        // An empty object clears both filters (missing Option fields become None).
        let cleared: EventFilters = serde_json::from_str("{}").unwrap();
        assert!(cleared.entity_id.is_none());
        assert!(cleared.event_type.is_none());
    }
}