// allsource_core/infrastructure/web/websocket.rs

use std::sync::Arc;

use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::{sink::SinkExt, stream::StreamExt};
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use uuid::Uuid;

use crate::domain::entities::Event;

9pub struct WebSocketManager {
11 event_tx: broadcast::Sender<Arc<Event>>,
13
14 clients: Arc<DashMap<Uuid, ClientInfo>>,
16}
17
18#[derive(Debug, Clone)]
19struct ClientInfo {
20 id: Uuid,
21 filters: EventFilters,
22}
23
24#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
25pub struct EventFilters {
26 pub entity_id: Option<String>,
27 pub event_type: Option<String>,
28}
29
30impl WebSocketManager {
31 pub fn new() -> Self {
32 let (event_tx, _) = broadcast::channel(1000);
33
34 Self {
35 event_tx,
36 clients: Arc::new(DashMap::new()),
37 }
38 }
39
40 pub fn broadcast_event(&self, event: Arc<Event>) {
42 let _ = self.event_tx.send(event);
44 }
45
46 pub async fn handle_socket(&self, socket: WebSocket) {
48 let client_id = Uuid::new_v4();
49 tracing::info!("🔌 WebSocket client connected: {}", client_id);
50
51 let mut event_rx = self.event_tx.subscribe();
53
54 let (mut sender, mut receiver) = socket.split();
56
57 self.clients.insert(
59 client_id,
60 ClientInfo {
61 id: client_id,
62 filters: EventFilters::default(),
63 },
64 );
65
66 let clients = Arc::clone(&self.clients);
68 let send_task = tokio::spawn(async move {
69 while let Ok(event) = event_rx.recv().await {
70 let filters = clients
72 .get(&client_id)
73 .map(|entry| entry.value().filters.clone())
74 .unwrap_or_default();
75
76 if let Some(ref entity_id) = filters.entity_id {
78 if event.entity_id_str() != entity_id {
79 continue;
80 }
81 }
82
83 if let Some(ref event_type) = filters.event_type {
84 if event.event_type_str() != event_type {
85 continue;
86 }
87 }
88
89 match serde_json::to_string(&*event) {
91 Ok(json) => {
92 if sender.send(Message::Text(json.into())).await.is_err() {
93 tracing::warn!("Failed to send event to client {}", client_id);
94 break;
95 }
96 }
97 Err(e) => {
98 tracing::error!("Failed to serialize event: {}", e);
99 }
100 }
101 }
102 });
103
104 let clients = Arc::clone(&self.clients);
106 let recv_task = tokio::spawn(async move {
107 while let Some(Ok(msg)) = receiver.next().await {
108 if let Message::Text(text) = msg {
109 if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
111 tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
112 if let Some(mut client) = clients.get_mut(&client_id) {
113 client.filters = filters;
114 }
115 }
116 }
117 }
118 });
119
120 tokio::select! {
122 _ = send_task => {
123 tracing::info!("Send task ended for client {}", client_id);
124 }
125 _ = recv_task => {
126 tracing::info!("Receive task ended for client {}", client_id);
127 }
128 }
129
130 self.clients.remove(&client_id);
132 tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
133 }
134
135 pub fn stats(&self) -> WebSocketStats {
137 WebSocketStats {
138 connected_clients: self.clients.len(),
139 total_capacity: self.event_tx.receiver_count(),
140 }
141 }
142}
143
144impl Default for WebSocketManager {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150#[derive(Debug, serde::Serialize)]
151pub struct WebSocketStats {
152 pub connected_clients: usize,
153 pub total_capacity: usize,
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event for exercising the manager.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // The receiver created alongside the channel is dropped immediately.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        // With no subscribers the underlying send fails; this must not panic.
        let manager = WebSocketManager::new();
        let event = Arc::new(create_test_event());

        manager.broadcast_event(event);
    }

    #[test]
    fn test_event_broadcast_reaches_subscriber() {
        let manager = WebSocketManager::new();
        let mut rx = manager.event_tx.subscribe();
        assert_eq!(manager.stats().total_capacity, 1);

        let event = Arc::new(create_test_event());
        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("a subscribed receiver should get the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }
}