allsource_core/infrastructure/web/
websocket.rs1use crate::domain::entities::Event;
2use axum::extract::ws::{Message, WebSocket};
3use dashmap::DashMap;
4use futures::{sink::SinkExt, stream::StreamExt};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
9pub struct WebSocketManager {
11 event_tx: broadcast::Sender<Arc<Event>>,
13
14 clients: Arc<DashMap<Uuid, ClientInfo>>,
16}
17
18#[derive(Debug, Clone)]
19struct ClientInfo {
20 id: Uuid,
21 filters: EventFilters,
22}
23
24#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
25pub struct EventFilters {
26 pub entity_id: Option<String>,
27 pub event_type: Option<String>,
28}
29
30impl WebSocketManager {
31 pub fn new() -> Self {
32 let (event_tx, _) = broadcast::channel(1000);
33
34 Self {
35 event_tx,
36 clients: Arc::new(DashMap::new()),
37 }
38 }
39
40 pub fn broadcast_event(&self, event: Arc<Event>) {
42 let _ = self.event_tx.send(event);
44 }
45
46 pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
48 self.event_tx.subscribe()
49 }
50
51 pub async fn handle_socket(&self, socket: WebSocket) {
53 let client_id = Uuid::new_v4();
54 tracing::info!("🔌 WebSocket client connected: {}", client_id);
55
56 let mut event_rx = self.event_tx.subscribe();
58
59 let (mut sender, mut receiver) = socket.split();
61
62 self.clients.insert(
64 client_id,
65 ClientInfo {
66 id: client_id,
67 filters: EventFilters::default(),
68 },
69 );
70
71 let clients = Arc::clone(&self.clients);
73 let send_task = tokio::spawn(async move {
74 while let Ok(event) = event_rx.recv().await {
75 let filters = clients
77 .get(&client_id)
78 .map(|entry| entry.value().filters.clone())
79 .unwrap_or_default();
80
81 if let Some(ref entity_id) = filters.entity_id
83 && event.entity_id_str() != entity_id
84 {
85 continue;
86 }
87
88 if let Some(ref event_type) = filters.event_type
89 && event.event_type_str() != event_type
90 {
91 continue;
92 }
93
94 match serde_json::to_string(&*event) {
96 Ok(json) => {
97 if sender.send(Message::Text(json.into())).await.is_err() {
98 tracing::warn!("Failed to send event to client {}", client_id);
99 break;
100 }
101 }
102 Err(e) => {
103 tracing::error!("Failed to serialize event: {}", e);
104 }
105 }
106 }
107 });
108
109 let clients = Arc::clone(&self.clients);
111 let recv_task = tokio::spawn(async move {
112 while let Some(Ok(msg)) = receiver.next().await {
113 if let Message::Text(text) = msg {
114 if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
116 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
117 if let Some(mut client) = clients.get_mut(&client_id) {
118 client.filters = filters;
119 }
120 }
121 }
122 }
123 });
124
125 tokio::select! {
127 _ = send_task => {
128 tracing::info!("Send task ended for client {}", client_id);
129 }
130 _ = recv_task => {
131 tracing::info!("Receive task ended for client {}", client_id);
132 }
133 }
134
135 self.clients.remove(&client_id);
137 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
138 }
139
140 pub fn stats(&self) -> WebSocketStats {
142 WebSocketStats {
143 connected_clients: self.clients.len(),
144 total_capacity: self.event_tx.receiver_count(),
145 }
146 }
147}
148
149impl Default for WebSocketManager {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155#[derive(Debug, serde::Serialize)]
156pub struct WebSocketStats {
157 pub connected_clients: usize,
158 pub total_capacity: usize,
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event with fixed type/entity strings for broadcast tests.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // The channel's initial receiver is discarded at construction.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        // With no subscribers, broadcasting must be a silent no-op.
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(event);
    }

    #[test]
    fn test_event_broadcast_reaches_subscriber() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        assert_eq!(manager.stats().total_capacity, 1);

        let event = Arc::new(create_test_event());
        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }

    #[test]
    fn test_event_filters_deserialize_partial() {
        let filters: EventFilters = serde_json::from_str(r#"{"event_type":"test.event"}"#)
            .expect("partial filter JSON should deserialize");
        assert_eq!(filters.entity_id, None);
        assert_eq!(filters.event_type.as_deref(), Some("test.event"));
    }
}