// allsource_core/websocket.rs

use std::collections::HashMap;
use std::sync::Arc;

use axum::extract::ws::{Message, WebSocket};
use futures::{sink::SinkExt, stream::StreamExt};
use parking_lot::RwLock;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use uuid::Uuid;

use crate::domain::entities::Event;

10pub struct WebSocketManager {
12 event_tx: broadcast::Sender<Arc<Event>>,
14
15 clients: Arc<RwLock<HashMap<Uuid, ClientInfo>>>,
17}
18
19#[derive(Debug, Clone)]
20struct ClientInfo {
21 id: Uuid,
22 filters: EventFilters,
23}
24
25#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
26pub struct EventFilters {
27 pub entity_id: Option<String>,
28 pub event_type: Option<String>,
29}
30
31impl WebSocketManager {
32 pub fn new() -> Self {
33 let (event_tx, _) = broadcast::channel(1000);
34
35 Self {
36 event_tx,
37 clients: Arc::new(RwLock::new(HashMap::new())),
38 }
39 }
40
41 pub fn broadcast_event(&self, event: Arc<Event>) {
43 let _ = self.event_tx.send(event);
45 }
46
47 pub async fn handle_socket(&self, socket: WebSocket) {
49 let client_id = Uuid::new_v4();
50 tracing::info!("🔌 WebSocket client connected: {}", client_id);
51
52 let mut event_rx = self.event_tx.subscribe();
54
55 let (mut sender, mut receiver) = socket.split();
57
58 self.clients.write().insert(
60 client_id,
61 ClientInfo {
62 id: client_id,
63 filters: EventFilters::default(),
64 },
65 );
66
67 let clients = Arc::clone(&self.clients);
69 let send_task = tokio::spawn(async move {
70 while let Ok(event) = event_rx.recv().await {
71 let filters = {
73 let clients_lock = clients.read();
74 clients_lock
75 .get(&client_id)
76 .map(|c| c.filters.clone())
77 .unwrap_or_default()
78 };
79
80 if let Some(ref entity_id) = filters.entity_id {
82 if event.entity_id_str() != entity_id {
83 continue;
84 }
85 }
86
87 if let Some(ref event_type) = filters.event_type {
88 if event.event_type_str() != event_type {
89 continue;
90 }
91 }
92
93 match serde_json::to_string(&*event) {
95 Ok(json) => {
96 if sender.send(Message::Text(json)).await.is_err() {
97 tracing::warn!("Failed to send event to client {}", client_id);
98 break;
99 }
100 }
101 Err(e) => {
102 tracing::error!("Failed to serialize event: {}", e);
103 }
104 }
105 }
106 });
107
108 let clients = Arc::clone(&self.clients);
110 let recv_task = tokio::spawn(async move {
111 while let Some(Ok(msg)) = receiver.next().await {
112 if let Message::Text(text) = msg {
113 if let Ok(filters) = serde_json::from_str::<EventFilters>(&text) {
115 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
116 if let Some(client) = clients.write().get_mut(&client_id) {
117 client.filters = filters;
118 }
119 }
120 }
121 }
122 });
123
124 tokio::select! {
126 _ = send_task => {
127 tracing::info!("Send task ended for client {}", client_id);
128 }
129 _ = recv_task => {
130 tracing::info!("Receive task ended for client {}", client_id);
131 }
132 }
133
134 self.clients.write().remove(&client_id);
136 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
137 }
138
139 pub fn stats(&self) -> WebSocketStats {
141 let clients = self.clients.read();
142 WebSocketStats {
143 connected_clients: clients.len(),
144 total_capacity: self.event_tx.receiver_count(),
145 }
146 }
147}
148
149impl Default for WebSocketManager {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155#[derive(Debug, serde::Serialize)]
156pub struct WebSocketStats {
157 pub connected_clients: usize,
158 pub total_capacity: usize,
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event for exercising the broadcast path.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // The receiver created alongside the sender is dropped in `new`.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let mut rx = manager.event_tx.subscribe();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("a subscribed receiver should get the broadcast event");
        // The same allocation is shared, not a copy.
        assert!(Arc::ptr_eq(&event, &received));
    }

    #[test]
    fn test_broadcast_without_subscribers_does_not_panic() {
        let manager = WebSocketManager::new();
        manager.broadcast_event(Arc::new(create_test_event()));
        assert_eq!(manager.stats().total_capacity, 0);
    }

    #[test]
    fn test_stats_counts_subscribers() {
        let manager = WebSocketManager::new();
        let rx = manager.event_tx.subscribe();
        assert_eq!(manager.stats().total_capacity, 1);
        drop(rx);
        assert_eq!(manager.stats().total_capacity, 0);
    }
}