1use axum::{
10 extract::{
11 ws::{Message, WebSocket, WebSocketUpgrade},
12 State,
13 },
14 response::Response,
15};
16use futures::{sink::SinkExt, stream::StreamExt};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use thiserror::Error;
21use tokio::sync::{broadcast, RwLock};
22use tracing::{debug, error, info, warn};
23use uuid::Uuid;
24
/// JSON message exchanged with WebSocket clients, in both directions.
///
/// Internally tagged by a `"type"` field holding the lowercase variant name, e.g.
/// `{"type":"subscribe","topic":"blocks","filter":null}` or `{"type":"ping"}`.
/// Field order here determines the key order of serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum WsMessage {
    /// Client request to start receiving events published on `topic`.
    Subscribe {
        /// Topic name to subscribe to.
        topic: String,
        /// Optional filter expression. `handle_socket` only logs it; it is not applied.
        filter: Option<String>,
    },
    /// Client request to stop receiving events for `topic`.
    Unsubscribe { topic: String },
    /// Server push carrying one event.
    Event {
        /// Topic the event was published on.
        topic: String,
        /// The event as JSON; `handle_socket` fills this with a serialized `RealtimeEvent`.
        data: serde_json::Value,
    },
    /// Client keep-alive; `handle_socket` answers with `Pong`.
    Ping,
    /// Server reply to `Ping`.
    Pong,
    /// Error report sent to the client. `code` is an HTTP-style status
    /// (`handle_socket` sends 400 for malformed input and 500 for subscription failures).
    Error { code: u16, message: String },
}
52
/// Node activity event broadcast to subscribed WebSocket clients.
///
/// Internally tagged by an `"event_type"` field holding the snake_case variant name,
/// e.g. `{"event_type":"block_added","cid":"Qm...","size":1024,"timestamp":12345}`.
/// `RealtimeEvent::topic` maps each variant to the topic it is published on.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type", rename_all = "snake_case")]
pub enum RealtimeEvent {
    /// A block was added (topic `"blocks"`).
    BlockAdded {
        /// Content identifier of the block.
        cid: String,
        // NOTE(review): presumably the block size in bytes — confirm with producers.
        size: usize,
        // NOTE(review): timestamp units (seconds vs. millis since epoch) are not
        // established in this file — confirm with producers.
        timestamp: u64,
    },
    /// A block was removed (topic `"blocks"`).
    BlockRemoved { cid: String, timestamp: u64 },
    /// A peer connected (topic `"peers"`).
    PeerConnected {
        /// Identifier of the peer.
        peer_id: String,
        /// Address the peer connected from/to, as a display string.
        address: String,
        timestamp: u64,
    },
    /// A peer disconnected (topic `"peers"`).
    PeerDisconnected { peer_id: String, timestamp: u64 },
    /// A DHT query began (topic `"dht"`).
    DhtQueryStarted { query_id: String, key: String },
    /// Progress update for a running DHT query (topic `"dht"`).
    DhtQueryProgress {
        /// Identifier shared by all events of the same query.
        query_id: String,
        /// Number of peers queried so far.
        peers_queried: usize,
        /// Number of results found so far.
        results_found: usize,
    },
    /// A DHT query finished (topic `"dht"`).
    DhtQueryCompleted {
        /// Identifier shared by all events of the same query.
        query_id: String,
        /// Whether the query succeeded.
        success: bool,
        /// Number of results the query produced.
        results: usize,
    },
}
92
93impl RealtimeEvent {
94 pub fn topic(&self) -> &str {
96 match self {
97 RealtimeEvent::BlockAdded { .. } | RealtimeEvent::BlockRemoved { .. } => "blocks",
98 RealtimeEvent::PeerConnected { .. } | RealtimeEvent::PeerDisconnected { .. } => "peers",
99 RealtimeEvent::DhtQueryStarted { .. }
100 | RealtimeEvent::DhtQueryProgress { .. }
101 | RealtimeEvent::DhtQueryCompleted { .. } => "dht",
102 }
103 }
104}
105
/// Tracks per-topic broadcast channels and which topics each connection subscribed to.
///
/// Both maps sit behind `Arc`, so a clone is a handle to the same shared state.
#[derive(Clone)]
pub struct SubscriptionManager {
    /// One broadcast sender per topic name. Entries are created on first subscribe;
    /// no method of this type removes them.
    topics: Arc<RwLock<HashMap<String, broadcast::Sender<RealtimeEvent>>>>,
    /// Topic names each connection has subscribed to, keyed by connection id.
    subscriptions: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
}
118
119impl SubscriptionManager {
120 pub fn new() -> Self {
122 Self {
123 topics: Arc::new(RwLock::new(HashMap::new())),
124 subscriptions: Arc::new(RwLock::new(HashMap::new())),
125 }
126 }
127
128 pub async fn subscribe(
130 &self,
131 connection_id: Uuid,
132 topic: String,
133 ) -> Result<broadcast::Receiver<RealtimeEvent>, WsError> {
134 let mut topics = self.topics.write().await;
135
136 let sender = topics
138 .entry(topic.clone())
139 .or_insert_with(|| {
140 let (tx, _rx) = broadcast::channel(100);
141 info!("Created new topic channel: {}", topic);
142 tx
143 })
144 .clone();
145
146 let mut subs = self.subscriptions.write().await;
148 subs.entry(connection_id).or_default().push(topic.clone());
149
150 info!(
151 "Connection {} subscribed to topic: {}",
152 connection_id, topic
153 );
154
155 Ok(sender.subscribe())
156 }
157
158 pub async fn unsubscribe(&self, connection_id: Uuid, topic: &str) {
160 let mut subs = self.subscriptions.write().await;
161 if let Some(topics) = subs.get_mut(&connection_id) {
162 topics.retain(|t| t != topic);
163 info!(
164 "Connection {} unsubscribed from topic: {}",
165 connection_id, topic
166 );
167 }
168 }
169
170 pub async fn remove_connection(&self, connection_id: Uuid) {
172 let mut subs = self.subscriptions.write().await;
173 subs.remove(&connection_id);
174 info!(
175 "Removed all subscriptions for connection: {}",
176 connection_id
177 );
178 }
179
180 pub async fn publish(&self, event: RealtimeEvent) -> Result<usize, WsError> {
182 let topic = event.topic().to_string();
183 let topics = self.topics.read().await;
184
185 if let Some(sender) = topics.get(&topic) {
186 match sender.send(event.clone()) {
187 Ok(count) => {
188 debug!(
189 "Published event to {} subscribers on topic: {}",
190 count, topic
191 );
192 Ok(count)
193 }
194 Err(_) => {
195 warn!("No active subscribers for topic: {}", topic);
196 Ok(0)
197 }
198 }
199 } else {
200 debug!("Topic not found: {}", topic);
201 Ok(0)
202 }
203 }
204
205 pub async fn subscription_count(&self) -> usize {
207 let subs = self.subscriptions.read().await;
208 subs.len()
209 }
210
211 pub async fn topic_count(&self) -> usize {
213 let topics = self.topics.read().await;
214 topics.len()
215 }
216}
217
218impl Default for SubscriptionManager {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
/// Application state handed to `ws_handler` through axum's `State` extractor.
///
/// `SubscriptionManager` keeps its maps behind `Arc`, so clones of this state share
/// the same topics and subscriptions.
#[derive(Clone)]
pub struct WsState {
    /// Shared topic/subscription registry used by every connection.
    pub subscription_manager: SubscriptionManager,
}
233
234impl WsState {
235 pub fn new() -> Self {
237 Self {
238 subscription_manager: SubscriptionManager::new(),
239 }
240 }
241}
242
243impl Default for WsState {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
249pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
253 ws.on_upgrade(|socket| handle_socket(socket, state))
254}
255
256#[allow(clippy::too_many_arguments)]
258async fn handle_socket(socket: WebSocket, state: WsState) {
259 let connection_id = Uuid::new_v4();
260 info!("New WebSocket connection: {}", connection_id);
261
262 let (sender, receiver) = socket.split();
263 let sender = Arc::new(tokio::sync::Mutex::new(sender));
264
265 let active_subscriptions: Arc<
267 tokio::sync::Mutex<HashMap<String, broadcast::Receiver<RealtimeEvent>>>,
268 > = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
269
270 let sender_clone = sender.clone();
272 let subs_clone = active_subscriptions.clone();
273 let event_task = tokio::spawn(async move {
274 loop {
275 let mut subs = subs_clone.lock().await;
277 let topics: Vec<String> = subs.keys().cloned().collect();
278
279 for topic in topics {
280 if let Some(rx) = subs.get_mut(&topic) {
281 match rx.try_recv() {
282 Ok(event) => {
283 let msg = WsMessage::Event {
284 topic: topic.clone(),
285 data: serde_json::to_value(&event).unwrap_or_default(),
286 };
287
288 if let Ok(json) = serde_json::to_string(&msg) {
289 let mut tx = sender_clone.lock().await;
290 if tx.send(Message::Text(json.into())).await.is_err() {
291 return;
292 }
293 }
294 }
295 Err(broadcast::error::TryRecvError::Empty) => {}
296 Err(_) => {}
297 }
298 }
299 }
300
301 drop(subs);
302 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
303 }
304 });
305
306 let mut receiver = receiver;
308 while let Some(msg) = receiver.next().await {
309 match msg {
310 Ok(Message::Text(text)) => match serde_json::from_str::<WsMessage>(&text) {
311 Ok(ws_msg) => match ws_msg {
312 WsMessage::Subscribe { topic, filter } => {
313 debug!(
314 "Connection {} subscribing to topic: {} (filter: {:?})",
315 connection_id, topic, filter
316 );
317
318 match state
319 .subscription_manager
320 .subscribe(connection_id, topic.clone())
321 .await
322 {
323 Ok(rx) => {
324 let mut subs = active_subscriptions.lock().await;
325 subs.insert(topic, rx);
326 }
327 Err(e) => {
328 error!("Failed to subscribe: {}", e);
329 let error_msg = WsMessage::Error {
330 code: 500,
331 message: format!("Subscription failed: {}", e),
332 };
333 if let Ok(json) = serde_json::to_string(&error_msg) {
334 let mut tx = sender.lock().await;
335 let _ = tx.send(Message::Text(json.into())).await;
336 }
337 }
338 }
339 }
340 WsMessage::Unsubscribe { topic } => {
341 debug!(
342 "Connection {} unsubscribing from topic: {}",
343 connection_id, topic
344 );
345 state
346 .subscription_manager
347 .unsubscribe(connection_id, &topic)
348 .await;
349 let mut subs = active_subscriptions.lock().await;
350 subs.remove(&topic);
351 }
352 WsMessage::Ping => {
353 let pong = WsMessage::Pong;
354 if let Ok(json) = serde_json::to_string(&pong) {
355 let mut tx = sender.lock().await;
356 let _ = tx.send(Message::Text(json.into())).await;
357 }
358 }
359 _ => {
360 warn!("Unexpected message type from client");
361 }
362 },
363 Err(e) => {
364 error!("Failed to parse message: {}", e);
365 let error_msg = WsMessage::Error {
366 code: 400,
367 message: format!("Invalid message format: {}", e),
368 };
369 if let Ok(json) = serde_json::to_string(&error_msg) {
370 let mut tx = sender.lock().await;
371 let _ = tx.send(Message::Text(json.into())).await;
372 }
373 }
374 },
375 Ok(Message::Close(_)) => {
376 info!("Connection {} closed by client", connection_id);
377 break;
378 }
379 Err(e) => {
380 error!("WebSocket error: {}", e);
381 break;
382 }
383 _ => {}
384 }
385 }
386
387 event_task.abort();
389 state
390 .subscription_manager
391 .remove_connection(connection_id)
392 .await;
393 info!("Connection {} disconnected", connection_id);
394}
395
/// Errors raised by the WebSocket subscription layer.
///
/// The `#[error]` strings define the `Display` text, which `handle_socket` embeds in
/// error replies to clients.
#[derive(Debug, Error)]
pub enum WsError {
    /// Generic subscription failure with a free-form description.
    #[error("Subscription error: {0}")]
    SubscriptionError(String),

    /// A requested topic name was not acceptable.
    // NOTE(review): payload is presumably the rejected topic name — confirm at
    // construction sites.
    #[error("Invalid topic: {0}")]
    InvalidTopic(String),

    /// Sending a message failed; carries a free-form description.
    #[error("Send error: {0}")]
    SendError(String),
}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_subscription_manager_new() {
        let manager = SubscriptionManager::new();
        assert_eq!(manager.subscription_count().await, 0);
        assert_eq!(manager.topic_count().await, 0);
    }

    #[tokio::test]
    async fn test_subscribe_and_publish() {
        let manager = SubscriptionManager::new();
        let conn_id = Uuid::new_v4();

        let mut rx = manager
            .subscribe(conn_id, "blocks".to_string())
            .await
            .unwrap();

        let event = RealtimeEvent::BlockAdded {
            cid: "QmTest".to_string(),
            size: 1024,
            timestamp: 12345,
        };

        let count = manager.publish(event.clone()).await.unwrap();
        assert_eq!(count, 1);

        let received = rx.recv().await.unwrap();
        match received {
            RealtimeEvent::BlockAdded { cid, size, .. } => {
                assert_eq!(cid, "QmTest");
                assert_eq!(size, 1024);
            }
            _ => panic!("Wrong event type"),
        }
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let manager = SubscriptionManager::new();
        let conn_id = Uuid::new_v4();

        let _rx = manager
            .subscribe(conn_id, "blocks".to_string())
            .await
            .unwrap();
        assert_eq!(manager.subscription_count().await, 1);

        manager.unsubscribe(conn_id, "blocks").await;
        // Unsubscribing keeps the connection's (now empty) entry.
        assert_eq!(manager.subscription_count().await, 1);

        // Only remove_connection drops the entry itself.
        manager.remove_connection(conn_id).await;
        assert_eq!(manager.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let manager = SubscriptionManager::new();
        let conn1 = Uuid::new_v4();
        let conn2 = Uuid::new_v4();

        let mut rx1 = manager
            .subscribe(conn1, "blocks".to_string())
            .await
            .unwrap();
        let mut rx2 = manager
            .subscribe(conn2, "blocks".to_string())
            .await
            .unwrap();

        let event = RealtimeEvent::BlockAdded {
            cid: "QmTest".to_string(),
            size: 2048,
            timestamp: 12345,
        };

        let count = manager.publish(event).await.unwrap();
        assert_eq!(count, 2);

        assert!(rx1.recv().await.is_ok());
        assert!(rx2.recv().await.is_ok());
    }

    #[tokio::test]
    async fn test_publish_without_channel_returns_zero() {
        let manager = SubscriptionManager::new();
        let event = RealtimeEvent::PeerDisconnected {
            peer_id: "peer1".to_string(),
            timestamp: 1,
        };

        // Nobody has subscribed to "peers", so no channel exists and nothing is delivered.
        assert_eq!(manager.publish(event).await.unwrap(), 0);
    }

    #[test]
    fn test_realtime_event_topic() {
        let block_event = RealtimeEvent::BlockAdded {
            cid: "test".to_string(),
            size: 100,
            timestamp: 123,
        };
        assert_eq!(block_event.topic(), "blocks");

        let removed_event = RealtimeEvent::BlockRemoved {
            cid: "test".to_string(),
            timestamp: 123,
        };
        assert_eq!(removed_event.topic(), "blocks");

        let peer_event = RealtimeEvent::PeerConnected {
            peer_id: "peer1".to_string(),
            address: "addr1".to_string(),
            timestamp: 123,
        };
        assert_eq!(peer_event.topic(), "peers");

        let gone_event = RealtimeEvent::PeerDisconnected {
            peer_id: "peer1".to_string(),
            timestamp: 123,
        };
        assert_eq!(gone_event.topic(), "peers");

        let dht_event = RealtimeEvent::DhtQueryStarted {
            query_id: "q1".to_string(),
            key: "key1".to_string(),
        };
        assert_eq!(dht_event.topic(), "dht");

        let progress_event = RealtimeEvent::DhtQueryProgress {
            query_id: "q1".to_string(),
            peers_queried: 3,
            results_found: 1,
        };
        assert_eq!(progress_event.topic(), "dht");

        let done_event = RealtimeEvent::DhtQueryCompleted {
            query_id: "q1".to_string(),
            success: true,
            results: 1,
        };
        assert_eq!(done_event.topic(), "dht");
    }

    #[test]
    fn test_ws_message_serialization() {
        let subscribe = WsMessage::Subscribe {
            topic: "blocks".to_string(),
            filter: Some("cid=Qm*".to_string()),
        };

        let json = serde_json::to_string(&subscribe).unwrap();
        assert!(json.contains("subscribe"));
        assert!(json.contains("blocks"));

        let deserialized: WsMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            WsMessage::Subscribe { topic, .. } => {
                assert_eq!(topic, "blocks");
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_wire_format_tags() {
        // WsMessage is tagged by "type" with lowercase variant names.
        assert_eq!(
            serde_json::to_string(&WsMessage::Ping).unwrap(),
            r#"{"type":"ping"}"#
        );

        // RealtimeEvent is tagged by "event_type" with snake_case variant names.
        let event = RealtimeEvent::BlockRemoved {
            cid: "QmTest".to_string(),
            timestamp: 7,
        };
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(value["event_type"], "block_removed");
        assert_eq!(value["cid"], "QmTest");
    }
}