ipfrs_transport/
multicast.rs

//! Multicast block announcements
//!
//! Efficient fan-out for block availability notifications:
//! - Subscription management for interested peers
//! - Topic-based filtering
//! - Reduced announcement overhead
//! - Scalable notifications
8
9use dashmap::DashMap;
10use ipfrs_core::Cid;
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use thiserror::Error;
15use tokio::sync::RwLock;
16use tracing::{debug, trace};
17
18/// Serialize CID as string
19fn serialize_cid<S>(cid: &Cid, serializer: S) -> Result<S::Ok, S::Error>
20where
21    S: Serializer,
22{
23    serializer.serialize_str(&cid.to_string())
24}
25
26/// Deserialize CID from string
27fn deserialize_cid<'de, D>(deserializer: D) -> Result<Cid, D::Error>
28where
29    D: Deserializer<'de>,
30{
31    let s = String::deserialize(deserializer)?;
32    s.parse().map_err(serde::de::Error::custom)
33}
34
/// Error types for multicast
///
/// Display messages are generated by `thiserror` from the `#[error(...)]`
/// attributes on each variant.
#[derive(Error, Debug)]
pub enum MulticastError {
    /// A topic could not be found; the payload is interpolated into the message.
    /// NOTE(review): not constructed anywhere in this file — confirm it is used by callers.
    #[error("Topic not found: {0}")]
    TopicNotFound(String),
    /// A subscription request was rejected; the payload explains why
    /// (for example, a subscription limit was reached).
    #[error("Subscription failed: {0}")]
    SubscriptionFailed(String),
    /// The peer already holds the requested subscription.
    /// NOTE(review): not constructed anywhere in this file — confirm it is used by callers.
    #[error("Already subscribed")]
    AlreadySubscribed,
    /// An unsubscribe was requested for a peer that has no subscriptions.
    #[error("Not subscribed")]
    NotSubscribed,
}
47
/// Peer identifier type
///
/// A plain `String` alias; this module compares and hashes it as-is and
/// performs no validation of its contents.
pub type PeerId = String;
50
/// Topic for grouping related announcements
///
/// A newtype over the topic name. Equality and hashing are derived, so two
/// topics are the same exactly when their names are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Topic(pub String);
54
55impl Topic {
56    /// Create a new topic
57    pub fn new(name: impl Into<String>) -> Self {
58        Topic(name.into())
59    }
60
61    /// Topic for all block announcements
62    pub fn all_blocks() -> Self {
63        Topic("blocks:all".to_string())
64    }
65
66    /// Topic for specific content type
67    pub fn content_type(content_type: &str) -> Self {
68        Topic(format!("blocks:{}", content_type))
69    }
70
71    /// Topic for tensor blocks
72    pub fn tensors() -> Self {
73        Topic("blocks:tensors".to_string())
74    }
75
76    /// Topic for gradients
77    pub fn gradients() -> Self {
78        Topic("blocks:gradients".to_string())
79    }
80}
81
/// Block announcement message
///
/// Serializable via serde; the CID travels as a string (see the
/// `serialize_with` / `deserialize_with` attributes on `cid`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockAnnouncement {
    /// CID of the announced block
    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
    pub cid: Cid,
    /// Block size in bytes
    pub size: u64,
    /// Optional topic for filtering; `None` means the announcement is untagged
    pub topic: Option<Topic>,
    /// Optional metadata as free-form key/value strings
    pub metadata: HashMap<String, String>,
}
95
96impl BlockAnnouncement {
97    /// Create a new block announcement
98    pub fn new(cid: Cid, size: u64) -> Self {
99        Self {
100            cid,
101            size,
102            topic: None,
103            metadata: HashMap::new(),
104        }
105    }
106
107    /// Set the topic for this announcement
108    pub fn with_topic(mut self, topic: Topic) -> Self {
109        self.topic = Some(topic);
110        self
111    }
112
113    /// Add metadata to the announcement
114    pub fn with_metadata(mut self, key: String, value: String) -> Self {
115        self.metadata.insert(key, value);
116        self
117    }
118}
119
/// Subscription filter
///
/// Decides which announcements a subscription is interested in. Cloning a
/// `Custom` filter clones the `Arc`, so clones share the same predicate.
#[derive(Clone)]
pub enum SubscriptionFilter {
    /// Subscribe to all announcements
    All,
    /// Subscribe to specific topic (announcement topic must equal this one)
    Topic(Topic),
    /// Subscribe to multiple topics (announcement topic must be one of these)
    Topics(Vec<Topic>),
    /// Subscribe to announcements matching a predicate
    Custom(Arc<dyn Fn(&BlockAnnouncement) -> bool + Send + Sync>),
}
132
133impl std::fmt::Debug for SubscriptionFilter {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match self {
136            Self::All => write!(f, "SubscriptionFilter::All"),
137            Self::Topic(topic) => write!(f, "SubscriptionFilter::Topic({:?})", topic),
138            Self::Topics(topics) => write!(f, "SubscriptionFilter::Topics({:?})", topics),
139            Self::Custom(_) => write!(f, "SubscriptionFilter::Custom(<function>)"),
140        }
141    }
142}
143
144impl SubscriptionFilter {
145    /// Check if an announcement matches this filter
146    pub fn matches(&self, announcement: &BlockAnnouncement) -> bool {
147        match self {
148            SubscriptionFilter::All => true,
149            SubscriptionFilter::Topic(topic) => announcement.topic.as_ref() == Some(topic),
150            SubscriptionFilter::Topics(topics) => {
151                if let Some(ref ann_topic) = announcement.topic {
152                    topics.contains(ann_topic)
153                } else {
154                    false
155                }
156            }
157            SubscriptionFilter::Custom(predicate) => predicate(announcement),
158        }
159    }
160}
161
/// Subscription handle
///
/// Pairs a peer with the filter it subscribed with and remembers when the
/// subscription was created.
#[derive(Debug)]
pub struct Subscription {
    /// Peer that owns this subscription
    peer_id: PeerId,
    /// Filter deciding which announcements this subscription matches
    filter: SubscriptionFilter,
    /// Creation time, used to compute the subscription's age
    created_at: std::time::Instant,
}
169
170impl Subscription {
171    /// Create a new subscription
172    pub fn new(peer_id: PeerId, filter: SubscriptionFilter) -> Self {
173        Self {
174            peer_id,
175            filter,
176            created_at: std::time::Instant::now(),
177        }
178    }
179
180    /// Check if this subscription matches an announcement
181    pub fn matches(&self, announcement: &BlockAnnouncement) -> bool {
182        self.filter.matches(announcement)
183    }
184
185    /// Get the peer ID for this subscription
186    pub fn peer_id(&self) -> &str {
187        &self.peer_id
188    }
189
190    /// Get subscription age
191    pub fn age(&self) -> std::time::Duration {
192        self.created_at.elapsed()
193    }
194}
195
/// Multicast manager configuration
#[derive(Debug, Clone)]
pub struct MulticastConfig {
    /// Maximum subscriptions per peer
    pub max_subscriptions_per_peer: usize,
    /// Maximum total subscriptions (summed over all peers); further
    /// subscribe requests fail once this is reached
    pub max_total_subscriptions: usize,
    /// Enable topic-based routing (maintains a topic -> peers index for
    /// `Topic` / `Topics` subscriptions)
    pub enable_topic_routing: bool,
    /// Announcement deduplication window (seconds); a CID announced again
    /// within this window is suppressed
    pub dedup_window_secs: u64,
}
208
209impl Default for MulticastConfig {
210    fn default() -> Self {
211        Self {
212            max_subscriptions_per_peer: 100,
213            max_total_subscriptions: 10000,
214            enable_topic_routing: true,
215            dedup_window_secs: 60,
216        }
217    }
218}
219
/// Multicast statistics
///
/// A plain snapshot of counters; all fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct MulticastStats {
    /// Total announcements sent (grows by the number of interested peers
    /// for each announcement, i.e. it counts per-peer deliveries)
    pub announcements_sent: u64,
    /// Total announcements received
    /// NOTE(review): nothing visible in this file updates this counter — confirm where it is maintained.
    pub announcements_received: u64,
    /// Active subscriptions (summed over all peers)
    pub active_subscriptions: usize,
    /// Unique topics (size of the topic index when the snapshot is taken)
    pub unique_topics: usize,
    /// Announcements filtered out
    pub filtered_announcements: u64,
}
234
/// Multicast manager for efficient block announcements
///
/// Tracks per-peer subscriptions, an optional topic index, a table of
/// recently announced CIDs for deduplication, and running statistics.
/// Every piece of state sits behind an `Arc`; the subscription table is a
/// `DashMap` and the rest are guarded by async `RwLock`s.
pub struct MulticastManager {
    /// Configuration
    config: MulticastConfig,
    /// Active subscriptions by peer
    subscriptions: Arc<DashMap<PeerId, Vec<Subscription>>>,
    /// Topic index for efficient routing: topic -> peers that subscribed to
    /// it through a `Topic` or `Topics` filter
    topic_index: Arc<RwLock<HashMap<Topic, HashSet<PeerId>>>>,
    /// Recent announcements for deduplication: CID -> time it was accepted
    recent_announcements: Arc<RwLock<HashMap<Cid, std::time::Instant>>>,
    /// Statistics
    stats: Arc<RwLock<MulticastStats>>,
}
248
249impl MulticastManager {
250    /// Create a new multicast manager
251    pub fn new(config: MulticastConfig) -> Self {
252        Self {
253            config,
254            subscriptions: Arc::new(DashMap::new()),
255            topic_index: Arc::new(RwLock::new(HashMap::new())),
256            recent_announcements: Arc::new(RwLock::new(HashMap::new())),
257            stats: Arc::new(RwLock::new(MulticastStats::default())),
258        }
259    }
260
261    /// Subscribe a peer to announcements
262    pub async fn subscribe(
263        &self,
264        peer_id: PeerId,
265        filter: SubscriptionFilter,
266    ) -> Result<(), MulticastError> {
267        // Check subscription limits
268        let total_subs = self
269            .subscriptions
270            .iter()
271            .map(|r| r.value().len())
272            .sum::<usize>();
273        if total_subs >= self.config.max_total_subscriptions {
274            return Err(MulticastError::SubscriptionFailed(
275                "Max total subscriptions reached".to_string(),
276            ));
277        }
278
279        let subscription = Subscription::new(peer_id.clone(), filter.clone());
280
281        // Add subscription
282        self.subscriptions
283            .entry(peer_id.clone())
284            .or_default()
285            .push(subscription);
286
287        // Update topic index if topic-based routing is enabled
288        if self.config.enable_topic_routing {
289            if let SubscriptionFilter::Topic(topic) = &filter {
290                let mut index = self.topic_index.write().await;
291                index
292                    .entry(topic.clone())
293                    .or_insert_with(HashSet::new)
294                    .insert(peer_id.clone());
295            } else if let SubscriptionFilter::Topics(topics) = &filter {
296                let mut index = self.topic_index.write().await;
297                for topic in topics {
298                    index
299                        .entry(topic.clone())
300                        .or_insert_with(HashSet::new)
301                        .insert(peer_id.clone());
302                }
303            }
304        }
305
306        // Update stats
307        let mut stats = self.stats.write().await;
308        stats.active_subscriptions = self.subscriptions.iter().map(|r| r.value().len()).sum();
309
310        debug!("Peer {} subscribed with filter: {:?}", peer_id, filter);
311        Ok(())
312    }
313
314    /// Unsubscribe a peer from all announcements
315    pub async fn unsubscribe(&self, peer_id: &str) -> Result<(), MulticastError> {
316        // Remove from subscriptions
317        if self.subscriptions.remove(peer_id).is_none() {
318            return Err(MulticastError::NotSubscribed);
319        }
320
321        // Remove from topic index
322        if self.config.enable_topic_routing {
323            let mut index = self.topic_index.write().await;
324            for peers in index.values_mut() {
325                peers.remove(peer_id);
326            }
327        }
328
329        // Update stats
330        let mut stats = self.stats.write().await;
331        stats.active_subscriptions = self.subscriptions.iter().map(|r| r.value().len()).sum();
332
333        debug!("Peer {} unsubscribed", peer_id);
334        Ok(())
335    }
336
337    /// Announce a new block to subscribed peers
338    pub async fn announce(&self, announcement: BlockAnnouncement) -> Vec<PeerId> {
339        // Check for duplicate announcement
340        {
341            let mut recent = self.recent_announcements.write().await;
342            let now = std::time::Instant::now();
343
344            // Clean up old announcements
345            recent.retain(|_, timestamp| {
346                now.duration_since(*timestamp).as_secs() < self.config.dedup_window_secs
347            });
348
349            // Check if already announced recently
350            if recent.contains_key(&announcement.cid) {
351                trace!(
352                    "Skipping duplicate announcement for CID: {}",
353                    announcement.cid
354                );
355                return Vec::new();
356            }
357
358            recent.insert(announcement.cid, now);
359        }
360
361        let mut interested_peers = HashSet::new();
362
363        // Topic-based routing for efficiency
364        if self.config.enable_topic_routing {
365            if let Some(ref topic) = announcement.topic {
366                let index = self.topic_index.read().await;
367                if let Some(peers) = index.get(topic) {
368                    interested_peers.extend(peers.iter().cloned());
369                }
370            }
371        }
372
373        // Also check subscriptions that don't use topic routing
374        for entry in self.subscriptions.iter() {
375            let peer_id = entry.key();
376            let subscriptions = entry.value();
377
378            for subscription in subscriptions {
379                if subscription.matches(&announcement) {
380                    interested_peers.insert(peer_id.clone());
381                }
382            }
383        }
384
385        // Update stats
386        let mut stats = self.stats.write().await;
387        stats.announcements_sent += interested_peers.len() as u64;
388
389        trace!(
390            "Announced CID {} to {} peers",
391            announcement.cid,
392            interested_peers.len()
393        );
394
395        interested_peers.into_iter().collect()
396    }
397
398    /// Get subscriptions for a peer
399    pub fn get_subscriptions(&self, peer_id: &str) -> Option<Vec<PeerId>> {
400        self.subscriptions
401            .get(peer_id)
402            .map(|subs| vec![peer_id.to_string(); subs.len()])
403    }
404
405    /// Get statistics
406    pub async fn stats(&self) -> MulticastStats {
407        let stats = self.stats.read().await;
408        let mut result = stats.clone();
409        result.unique_topics = self.topic_index.read().await.len();
410        result
411    }
412
413    /// Clear all subscriptions
414    pub async fn clear(&self) {
415        self.subscriptions.clear();
416        self.topic_index.write().await.clear();
417        self.recent_announcements.write().await.clear();
418
419        let mut stats = self.stats.write().await;
420        stats.active_subscriptions = 0;
421        stats.unique_topics = 0;
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Textual form of the first fixture CID.
    const CID_A: &str = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi";
    /// Textual form of the second fixture CID, distinct from `CID_A`.
    const CID_B: &str = "bafybeihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku";

    fn test_cid() -> Cid {
        CID_A.parse().unwrap()
    }

    fn test_cid2() -> Cid {
        CID_B.parse().unwrap()
    }

    /// Manager built from the default configuration.
    fn default_manager() -> MulticastManager {
        MulticastManager::new(MulticastConfig::default())
    }

    /// Subscribes `peer` with `filter` and panics if the manager refuses.
    async fn must_subscribe(manager: &MulticastManager, peer: &str, filter: SubscriptionFilter) {
        manager.subscribe(peer.to_string(), filter).await.unwrap();
    }

    #[test]
    fn test_topic_creation() {
        assert_eq!(Topic::new("test").0, "test");
        assert_eq!(Topic::all_blocks().0, "blocks:all");
        assert_eq!(Topic::tensors().0, "blocks:tensors");
    }

    #[test]
    fn test_block_announcement() {
        let announcement = BlockAnnouncement::new(test_cid(), 1024)
            .with_topic(Topic::tensors())
            .with_metadata("dtype".to_string(), "float32".to_string());

        assert_eq!(announcement.size, 1024);
        assert_eq!(announcement.topic, Some(Topic::tensors()));
        assert_eq!(announcement.metadata.get("dtype").unwrap(), "float32");
    }

    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let manager = default_manager();

        must_subscribe(&manager, "peer1", SubscriptionFilter::All).await;
        assert_eq!(manager.stats().await.active_subscriptions, 1);

        manager.unsubscribe("peer1").await.unwrap();
        assert_eq!(manager.stats().await.active_subscriptions, 0);
    }

    #[tokio::test]
    async fn test_topic_based_announcement() {
        let manager = default_manager();

        must_subscribe(
            &manager,
            "peer1",
            SubscriptionFilter::Topic(Topic::tensors()),
        )
        .await;
        must_subscribe(
            &manager,
            "peer2",
            SubscriptionFilter::Topic(Topic::gradients()),
        )
        .await;

        let tensor_block = BlockAnnouncement::new(test_cid(), 1024).with_topic(Topic::tensors());
        let peers = manager.announce(tensor_block).await;

        // Only the tensor subscriber is interested.
        assert_eq!(peers.len(), 1);
        assert!(peers.contains(&"peer1".to_string()));
    }

    #[tokio::test]
    async fn test_all_announcements_subscription() {
        let manager = default_manager();

        must_subscribe(&manager, "peer1", SubscriptionFilter::All).await;

        let tensor_block = BlockAnnouncement::new(test_cid(), 1024).with_topic(Topic::tensors());
        let peers = manager.announce(tensor_block).await;

        assert_eq!(peers.len(), 1);
        assert!(peers.contains(&"peer1".to_string()));
    }

    #[tokio::test]
    async fn test_multiple_topics_subscription() {
        let manager = default_manager();

        let both = SubscriptionFilter::Topics(vec![Topic::tensors(), Topic::gradients()]);
        must_subscribe(&manager, "peer1", both).await;

        let tensor_block = BlockAnnouncement::new(test_cid(), 1024).with_topic(Topic::tensors());
        assert_eq!(manager.announce(tensor_block).await.len(), 1);

        let gradient_block =
            BlockAnnouncement::new(test_cid2(), 2048).with_topic(Topic::gradients());
        assert_eq!(manager.announce(gradient_block).await.len(), 1);
    }

    #[tokio::test]
    async fn test_deduplication() {
        let manager = default_manager();

        must_subscribe(&manager, "peer1", SubscriptionFilter::All).await;

        let announcement = BlockAnnouncement::new(test_cid(), 1024);

        let first = manager.announce(announcement.clone()).await;
        assert_eq!(first.len(), 1);

        // Duplicate announcement should be filtered
        let second = manager.announce(announcement).await;
        assert_eq!(second.len(), 0);
    }

    #[tokio::test]
    async fn test_subscription_limits() {
        let manager = MulticastManager::new(MulticastConfig {
            max_total_subscriptions: 2,
            ..Default::default()
        });

        must_subscribe(&manager, "peer1", SubscriptionFilter::All).await;
        must_subscribe(&manager, "peer2", SubscriptionFilter::All).await;

        // The third subscription exceeds the configured total.
        let result = manager
            .subscribe("peer3".to_string(), SubscriptionFilter::All)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_clear_subscriptions() {
        let manager = default_manager();

        must_subscribe(&manager, "peer1", SubscriptionFilter::All).await;
        must_subscribe(&manager, "peer2", SubscriptionFilter::All).await;
        assert_eq!(manager.stats().await.active_subscriptions, 2);

        manager.clear().await;
        assert_eq!(manager.stats().await.active_subscriptions, 0);
    }
}
617}