ipfrs_interface/
websocket.rs

1//! WebSocket Support for Real-Time IPFRS Communication
2//!
3//! Provides:
4//! - WebSocket upgrade handler
5//! - Message routing and handling
6//! - Pub/sub pattern for subscriptions
7//! - Real-time event notifications (block additions, peer connections, etc.)
8
9use axum::{
10    extract::{
11        ws::{Message, WebSocket, WebSocketUpgrade},
12        State,
13    },
14    response::Response,
15};
16use futures::{sink::SinkExt, stream::StreamExt};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use thiserror::Error;
21use tokio::sync::{broadcast, RwLock};
22use tracing::{debug, error, info, warn};
23use uuid::Uuid;
24
25// ============================================================================
26// WebSocket Message Types
27// ============================================================================
28
29/// WebSocket message envelope
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(tag = "type", rename_all = "lowercase")]
32pub enum WsMessage {
33    /// Subscribe to a topic
34    Subscribe {
35        topic: String,
36        filter: Option<String>,
37    },
38    /// Unsubscribe from a topic
39    Unsubscribe { topic: String },
40    /// Event notification
41    Event {
42        topic: String,
43        data: serde_json::Value,
44    },
45    /// Ping message for keepalive
46    Ping,
47    /// Pong response
48    Pong,
49    /// Error message
50    Error { code: u16, message: String },
51}
52
53// ============================================================================
54// Event Types
55// ============================================================================
56
57/// Real-time event types
58#[derive(Debug, Clone, Serialize, Deserialize)]
59#[serde(tag = "event_type", rename_all = "snake_case")]
60pub enum RealtimeEvent {
61    /// Block was added to the store
62    BlockAdded {
63        cid: String,
64        size: usize,
65        timestamp: u64,
66    },
67    /// Block was removed from the store
68    BlockRemoved { cid: String, timestamp: u64 },
69    /// Peer connected
70    PeerConnected {
71        peer_id: String,
72        address: String,
73        timestamp: u64,
74    },
75    /// Peer disconnected
76    PeerDisconnected { peer_id: String, timestamp: u64 },
77    /// DHT query started
78    DhtQueryStarted { query_id: String, key: String },
79    /// DHT query progress
80    DhtQueryProgress {
81        query_id: String,
82        peers_queried: usize,
83        results_found: usize,
84    },
85    /// DHT query completed
86    DhtQueryCompleted {
87        query_id: String,
88        success: bool,
89        results: usize,
90    },
91}
92
93impl RealtimeEvent {
94    /// Get the topic for this event
95    pub fn topic(&self) -> &str {
96        match self {
97            RealtimeEvent::BlockAdded { .. } | RealtimeEvent::BlockRemoved { .. } => "blocks",
98            RealtimeEvent::PeerConnected { .. } | RealtimeEvent::PeerDisconnected { .. } => "peers",
99            RealtimeEvent::DhtQueryStarted { .. }
100            | RealtimeEvent::DhtQueryProgress { .. }
101            | RealtimeEvent::DhtQueryCompleted { .. } => "dht",
102        }
103    }
104}
105
106// ============================================================================
107// Subscription Manager
108// ============================================================================
109
110/// Manages WebSocket subscriptions and pub/sub
111#[derive(Clone)]
112pub struct SubscriptionManager {
113    /// Topic-based broadcast channels
114    topics: Arc<RwLock<HashMap<String, broadcast::Sender<RealtimeEvent>>>>,
115    /// Active subscriptions per connection
116    subscriptions: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
117}
118
119impl SubscriptionManager {
120    /// Create a new subscription manager
121    pub fn new() -> Self {
122        Self {
123            topics: Arc::new(RwLock::new(HashMap::new())),
124            subscriptions: Arc::new(RwLock::new(HashMap::new())),
125        }
126    }
127
128    /// Subscribe a connection to a topic
129    pub async fn subscribe(
130        &self,
131        connection_id: Uuid,
132        topic: String,
133    ) -> Result<broadcast::Receiver<RealtimeEvent>, WsError> {
134        let mut topics = self.topics.write().await;
135
136        // Get or create topic channel
137        let sender = topics
138            .entry(topic.clone())
139            .or_insert_with(|| {
140                let (tx, _rx) = broadcast::channel(100);
141                info!("Created new topic channel: {}", topic);
142                tx
143            })
144            .clone();
145
146        // Track subscription
147        let mut subs = self.subscriptions.write().await;
148        subs.entry(connection_id).or_default().push(topic.clone());
149
150        info!(
151            "Connection {} subscribed to topic: {}",
152            connection_id, topic
153        );
154
155        Ok(sender.subscribe())
156    }
157
158    /// Unsubscribe a connection from a topic
159    pub async fn unsubscribe(&self, connection_id: Uuid, topic: &str) {
160        let mut subs = self.subscriptions.write().await;
161        if let Some(topics) = subs.get_mut(&connection_id) {
162            topics.retain(|t| t != topic);
163            info!(
164                "Connection {} unsubscribed from topic: {}",
165                connection_id, topic
166            );
167        }
168    }
169
170    /// Remove all subscriptions for a connection
171    pub async fn remove_connection(&self, connection_id: Uuid) {
172        let mut subs = self.subscriptions.write().await;
173        subs.remove(&connection_id);
174        info!(
175            "Removed all subscriptions for connection: {}",
176            connection_id
177        );
178    }
179
180    /// Publish an event to a topic
181    pub async fn publish(&self, event: RealtimeEvent) -> Result<usize, WsError> {
182        let topic = event.topic().to_string();
183        let topics = self.topics.read().await;
184
185        if let Some(sender) = topics.get(&topic) {
186            match sender.send(event.clone()) {
187                Ok(count) => {
188                    debug!(
189                        "Published event to {} subscribers on topic: {}",
190                        count, topic
191                    );
192                    Ok(count)
193                }
194                Err(_) => {
195                    warn!("No active subscribers for topic: {}", topic);
196                    Ok(0)
197                }
198            }
199        } else {
200            debug!("Topic not found: {}", topic);
201            Ok(0)
202        }
203    }
204
205    /// Get active subscription count
206    pub async fn subscription_count(&self) -> usize {
207        let subs = self.subscriptions.read().await;
208        subs.len()
209    }
210
211    /// Get topic count
212    pub async fn topic_count(&self) -> usize {
213        let topics = self.topics.read().await;
214        topics.len()
215    }
216}
217
218impl Default for SubscriptionManager {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224// ============================================================================
225// WebSocket Handler
226// ============================================================================
227
228/// WebSocket handler state
229#[derive(Clone)]
230pub struct WsState {
231    pub subscription_manager: SubscriptionManager,
232}
233
234impl WsState {
235    /// Create new WebSocket state
236    pub fn new() -> Self {
237        Self {
238            subscription_manager: SubscriptionManager::new(),
239        }
240    }
241}
242
243impl Default for WsState {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249/// WebSocket upgrade handler
250///
251/// GET /ws
252pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
253    ws.on_upgrade(|socket| handle_socket(socket, state))
254}
255
256/// Handle individual WebSocket connection
257#[allow(clippy::too_many_arguments)]
258async fn handle_socket(socket: WebSocket, state: WsState) {
259    let connection_id = Uuid::new_v4();
260    info!("New WebSocket connection: {}", connection_id);
261
262    let (sender, receiver) = socket.split();
263    let sender = Arc::new(tokio::sync::Mutex::new(sender));
264
265    // Subscriptions for this connection
266    let active_subscriptions: Arc<
267        tokio::sync::Mutex<HashMap<String, broadcast::Receiver<RealtimeEvent>>>,
268    > = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
269
270    // Spawn task to handle outgoing events
271    let sender_clone = sender.clone();
272    let subs_clone = active_subscriptions.clone();
273    let event_task = tokio::spawn(async move {
274        loop {
275            // Check all active subscriptions for events
276            let mut subs = subs_clone.lock().await;
277            let topics: Vec<String> = subs.keys().cloned().collect();
278
279            for topic in topics {
280                if let Some(rx) = subs.get_mut(&topic) {
281                    match rx.try_recv() {
282                        Ok(event) => {
283                            let msg = WsMessage::Event {
284                                topic: topic.clone(),
285                                data: serde_json::to_value(&event).unwrap_or_default(),
286                            };
287
288                            if let Ok(json) = serde_json::to_string(&msg) {
289                                let mut tx = sender_clone.lock().await;
290                                if tx.send(Message::Text(json.into())).await.is_err() {
291                                    return;
292                                }
293                            }
294                        }
295                        Err(broadcast::error::TryRecvError::Empty) => {}
296                        Err(_) => {}
297                    }
298                }
299            }
300
301            drop(subs);
302            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
303        }
304    });
305
306    // Handle incoming messages
307    let mut receiver = receiver;
308    while let Some(msg) = receiver.next().await {
309        match msg {
310            Ok(Message::Text(text)) => match serde_json::from_str::<WsMessage>(&text) {
311                Ok(ws_msg) => match ws_msg {
312                    WsMessage::Subscribe { topic, filter } => {
313                        debug!(
314                            "Connection {} subscribing to topic: {} (filter: {:?})",
315                            connection_id, topic, filter
316                        );
317
318                        match state
319                            .subscription_manager
320                            .subscribe(connection_id, topic.clone())
321                            .await
322                        {
323                            Ok(rx) => {
324                                let mut subs = active_subscriptions.lock().await;
325                                subs.insert(topic, rx);
326                            }
327                            Err(e) => {
328                                error!("Failed to subscribe: {}", e);
329                                let error_msg = WsMessage::Error {
330                                    code: 500,
331                                    message: format!("Subscription failed: {}", e),
332                                };
333                                if let Ok(json) = serde_json::to_string(&error_msg) {
334                                    let mut tx = sender.lock().await;
335                                    let _ = tx.send(Message::Text(json.into())).await;
336                                }
337                            }
338                        }
339                    }
340                    WsMessage::Unsubscribe { topic } => {
341                        debug!(
342                            "Connection {} unsubscribing from topic: {}",
343                            connection_id, topic
344                        );
345                        state
346                            .subscription_manager
347                            .unsubscribe(connection_id, &topic)
348                            .await;
349                        let mut subs = active_subscriptions.lock().await;
350                        subs.remove(&topic);
351                    }
352                    WsMessage::Ping => {
353                        let pong = WsMessage::Pong;
354                        if let Ok(json) = serde_json::to_string(&pong) {
355                            let mut tx = sender.lock().await;
356                            let _ = tx.send(Message::Text(json.into())).await;
357                        }
358                    }
359                    _ => {
360                        warn!("Unexpected message type from client");
361                    }
362                },
363                Err(e) => {
364                    error!("Failed to parse message: {}", e);
365                    let error_msg = WsMessage::Error {
366                        code: 400,
367                        message: format!("Invalid message format: {}", e),
368                    };
369                    if let Ok(json) = serde_json::to_string(&error_msg) {
370                        let mut tx = sender.lock().await;
371                        let _ = tx.send(Message::Text(json.into())).await;
372                    }
373                }
374            },
375            Ok(Message::Close(_)) => {
376                info!("Connection {} closed by client", connection_id);
377                break;
378            }
379            Err(e) => {
380                error!("WebSocket error: {}", e);
381                break;
382            }
383            _ => {}
384        }
385    }
386
387    // Cleanup
388    event_task.abort();
389    state
390        .subscription_manager
391        .remove_connection(connection_id)
392        .await;
393    info!("Connection {} disconnected", connection_id);
394}
395
396// ============================================================================
397// Error Types
398// ============================================================================
399
400/// WebSocket errors
401#[derive(Debug, Error)]
402pub enum WsError {
403    #[error("Subscription error: {0}")]
404    SubscriptionError(String),
405
406    #[error("Invalid topic: {0}")]
407    InvalidTopic(String),
408
409    #[error("Send error: {0}")]
410    SendError(String),
411}
412
413// ============================================================================
414// Tests
415// ============================================================================
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `BlockAdded` event shared by the pub/sub tests.
    fn block_added(cid: &str, size: usize) -> RealtimeEvent {
        RealtimeEvent::BlockAdded {
            cid: cid.to_owned(),
            size,
            timestamp: 12345,
        }
    }

    #[tokio::test]
    async fn test_subscription_manager_new() {
        let manager = SubscriptionManager::new();
        assert_eq!(manager.topic_count().await, 0);
        assert_eq!(manager.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_subscribe_and_publish() {
        let manager = SubscriptionManager::new();
        let mut rx = manager
            .subscribe(Uuid::new_v4(), String::from("blocks"))
            .await
            .expect("subscribing to `blocks` succeeds");

        let delivered = manager
            .publish(block_added("QmTest", 1024))
            .await
            .expect("publish succeeds");
        assert_eq!(delivered, 1);

        let RealtimeEvent::BlockAdded { cid, size, .. } = rx.recv().await.unwrap() else {
            panic!("Wrong event type");
        };
        assert_eq!(cid, "QmTest");
        assert_eq!(size, 1024);
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let manager = SubscriptionManager::new();
        let conn_id = Uuid::new_v4();

        let _rx = manager
            .subscribe(conn_id, "blocks".into())
            .await
            .unwrap();
        assert_eq!(manager.subscription_count().await, 1);

        // Dropping a topic leaves the connection itself tracked.
        manager.unsubscribe(conn_id, "blocks").await;
        assert_eq!(manager.subscription_count().await, 1);

        // Removing the connection forgets it entirely.
        manager.remove_connection(conn_id).await;
        assert_eq!(manager.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let manager = SubscriptionManager::new();

        let mut receivers = Vec::new();
        for _ in 0..2 {
            let rx = manager
                .subscribe(Uuid::new_v4(), "blocks".to_string())
                .await
                .unwrap();
            receivers.push(rx);
        }

        // Both subscribers are counted, and both receive the event.
        let delivered = manager.publish(block_added("QmTest", 2048)).await.unwrap();
        assert_eq!(delivered, 2);
        for rx in &mut receivers {
            assert!(rx.recv().await.is_ok());
        }
    }

    #[test]
    fn test_realtime_event_topic() {
        let cases = [
            (
                RealtimeEvent::BlockAdded {
                    cid: "test".to_string(),
                    size: 100,
                    timestamp: 123,
                },
                "blocks",
            ),
            (
                RealtimeEvent::PeerConnected {
                    peer_id: "peer1".to_string(),
                    address: "addr1".to_string(),
                    timestamp: 123,
                },
                "peers",
            ),
            (
                RealtimeEvent::DhtQueryStarted {
                    query_id: "q1".to_string(),
                    key: "key1".to_string(),
                },
                "dht",
            ),
        ];

        for (event, expected) in &cases {
            assert_eq!(event.topic(), *expected);
        }
    }

    #[test]
    fn test_ws_message_serialization() {
        let original = WsMessage::Subscribe {
            topic: "blocks".to_string(),
            filter: Some("cid=Qm*".to_string()),
        };

        let json = serde_json::to_string(&original).unwrap();
        for needle in ["subscribe", "blocks"] {
            assert!(json.contains(needle));
        }

        let round_tripped: WsMessage = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(&round_tripped, WsMessage::Subscribe { topic, .. } if topic == "blocks"),
            "Wrong message type"
        );
    }
}