1use dashmap::DashMap;
10use ipfrs_core::Cid;
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use thiserror::Error;
15use tokio::sync::RwLock;
16use tracing::{debug, trace};
17
18fn serialize_cid<S>(cid: &Cid, serializer: S) -> Result<S::Ok, S::Error>
20where
21 S: Serializer,
22{
23 serializer.serialize_str(&cid.to_string())
24}
25
26fn deserialize_cid<'de, D>(deserializer: D) -> Result<Cid, D::Error>
28where
29 D: Deserializer<'de>,
30{
31 let s = String::deserialize(deserializer)?;
32 s.parse().map_err(serde::de::Error::custom)
33}
34
35#[derive(Error, Debug)]
37pub enum MulticastError {
38 #[error("Topic not found: {0}")]
39 TopicNotFound(String),
40 #[error("Subscription failed: {0}")]
41 SubscriptionFailed(String),
42 #[error("Already subscribed")]
43 AlreadySubscribed,
44 #[error("Not subscribed")]
45 NotSubscribed,
46}
47
/// Identifier of a remote peer, kept as an owned string.
pub type PeerId = String;
50
51#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
53pub struct Topic(pub String);
54
55impl Topic {
56 pub fn new(name: impl Into<String>) -> Self {
58 Topic(name.into())
59 }
60
61 pub fn all_blocks() -> Self {
63 Topic("blocks:all".to_string())
64 }
65
66 pub fn content_type(content_type: &str) -> Self {
68 Topic(format!("blocks:{}", content_type))
69 }
70
71 pub fn tensors() -> Self {
73 Topic("blocks:tensors".to_string())
74 }
75
76 pub fn gradients() -> Self {
78 Topic("blocks:gradients".to_string())
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct BlockAnnouncement {
85 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
87 pub cid: Cid,
88 pub size: u64,
90 pub topic: Option<Topic>,
92 pub metadata: HashMap<String, String>,
94}
95
96impl BlockAnnouncement {
97 pub fn new(cid: Cid, size: u64) -> Self {
99 Self {
100 cid,
101 size,
102 topic: None,
103 metadata: HashMap::new(),
104 }
105 }
106
107 pub fn with_topic(mut self, topic: Topic) -> Self {
109 self.topic = Some(topic);
110 self
111 }
112
113 pub fn with_metadata(mut self, key: String, value: String) -> Self {
115 self.metadata.insert(key, value);
116 self
117 }
118}
119
120#[derive(Clone)]
122pub enum SubscriptionFilter {
123 All,
125 Topic(Topic),
127 Topics(Vec<Topic>),
129 Custom(Arc<dyn Fn(&BlockAnnouncement) -> bool + Send + Sync>),
131}
132
133impl std::fmt::Debug for SubscriptionFilter {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Self::All => write!(f, "SubscriptionFilter::All"),
137 Self::Topic(topic) => write!(f, "SubscriptionFilter::Topic({:?})", topic),
138 Self::Topics(topics) => write!(f, "SubscriptionFilter::Topics({:?})", topics),
139 Self::Custom(_) => write!(f, "SubscriptionFilter::Custom(<function>)"),
140 }
141 }
142}
143
144impl SubscriptionFilter {
145 pub fn matches(&self, announcement: &BlockAnnouncement) -> bool {
147 match self {
148 SubscriptionFilter::All => true,
149 SubscriptionFilter::Topic(topic) => announcement.topic.as_ref() == Some(topic),
150 SubscriptionFilter::Topics(topics) => {
151 if let Some(ref ann_topic) = announcement.topic {
152 topics.contains(ann_topic)
153 } else {
154 false
155 }
156 }
157 SubscriptionFilter::Custom(predicate) => predicate(announcement),
158 }
159 }
160}
161
162#[derive(Debug)]
164pub struct Subscription {
165 peer_id: PeerId,
166 filter: SubscriptionFilter,
167 created_at: std::time::Instant,
168}
169
170impl Subscription {
171 pub fn new(peer_id: PeerId, filter: SubscriptionFilter) -> Self {
173 Self {
174 peer_id,
175 filter,
176 created_at: std::time::Instant::now(),
177 }
178 }
179
180 pub fn matches(&self, announcement: &BlockAnnouncement) -> bool {
182 self.filter.matches(announcement)
183 }
184
185 pub fn peer_id(&self) -> &str {
187 &self.peer_id
188 }
189
190 pub fn age(&self) -> std::time::Duration {
192 self.created_at.elapsed()
193 }
194}
195
/// Limits and feature switches for the multicast manager.
#[derive(Debug, Clone)]
pub struct MulticastConfig {
    /// Per-peer subscription cap.
    pub max_subscriptions_per_peer: usize,
    /// Cap on subscriptions summed across all peers.
    pub max_total_subscriptions: usize,
    /// Whether `Topic`/`Topics` subscriptions are also kept in a topic index.
    pub enable_topic_routing: bool,
    /// Window, in whole seconds, in which a repeated CID announcement is suppressed.
    pub dedup_window_secs: u64,
}

impl Default for MulticastConfig {
    /// 100 subscriptions per peer, 10 000 in total, topic routing on, and a
    /// 60-second dedup window.
    fn default() -> Self {
        MulticastConfig {
            dedup_window_secs: 60,
            enable_topic_routing: true,
            max_total_subscriptions: 10_000,
            max_subscriptions_per_peer: 100,
        }
    }
}
219
/// Snapshot of multicast activity counters and gauges.
#[derive(Debug, Clone, Default)]
pub struct MulticastStats {
    /// Cumulative number of peers returned by announcement fan-out.
    pub announcements_sent: u64,
    /// NOTE(review): not updated by the manager shown in this module.
    pub announcements_received: u64,
    /// Current number of subscriptions across all peers.
    pub active_subscriptions: usize,
    /// Number of topics present in the topic index when the snapshot was taken.
    pub unique_topics: usize,
    /// NOTE(review): not updated by the manager shown in this module.
    pub filtered_announcements: u64,
}
234
235pub struct MulticastManager {
237 config: MulticastConfig,
239 subscriptions: Arc<DashMap<PeerId, Vec<Subscription>>>,
241 topic_index: Arc<RwLock<HashMap<Topic, HashSet<PeerId>>>>,
243 recent_announcements: Arc<RwLock<HashMap<Cid, std::time::Instant>>>,
245 stats: Arc<RwLock<MulticastStats>>,
247}
248
249impl MulticastManager {
250 pub fn new(config: MulticastConfig) -> Self {
252 Self {
253 config,
254 subscriptions: Arc::new(DashMap::new()),
255 topic_index: Arc::new(RwLock::new(HashMap::new())),
256 recent_announcements: Arc::new(RwLock::new(HashMap::new())),
257 stats: Arc::new(RwLock::new(MulticastStats::default())),
258 }
259 }
260
261 pub async fn subscribe(
263 &self,
264 peer_id: PeerId,
265 filter: SubscriptionFilter,
266 ) -> Result<(), MulticastError> {
267 let total_subs = self
269 .subscriptions
270 .iter()
271 .map(|r| r.value().len())
272 .sum::<usize>();
273 if total_subs >= self.config.max_total_subscriptions {
274 return Err(MulticastError::SubscriptionFailed(
275 "Max total subscriptions reached".to_string(),
276 ));
277 }
278
279 let subscription = Subscription::new(peer_id.clone(), filter.clone());
280
281 self.subscriptions
283 .entry(peer_id.clone())
284 .or_default()
285 .push(subscription);
286
287 if self.config.enable_topic_routing {
289 if let SubscriptionFilter::Topic(topic) = &filter {
290 let mut index = self.topic_index.write().await;
291 index
292 .entry(topic.clone())
293 .or_insert_with(HashSet::new)
294 .insert(peer_id.clone());
295 } else if let SubscriptionFilter::Topics(topics) = &filter {
296 let mut index = self.topic_index.write().await;
297 for topic in topics {
298 index
299 .entry(topic.clone())
300 .or_insert_with(HashSet::new)
301 .insert(peer_id.clone());
302 }
303 }
304 }
305
306 let mut stats = self.stats.write().await;
308 stats.active_subscriptions = self.subscriptions.iter().map(|r| r.value().len()).sum();
309
310 debug!("Peer {} subscribed with filter: {:?}", peer_id, filter);
311 Ok(())
312 }
313
314 pub async fn unsubscribe(&self, peer_id: &str) -> Result<(), MulticastError> {
316 if self.subscriptions.remove(peer_id).is_none() {
318 return Err(MulticastError::NotSubscribed);
319 }
320
321 if self.config.enable_topic_routing {
323 let mut index = self.topic_index.write().await;
324 for peers in index.values_mut() {
325 peers.remove(peer_id);
326 }
327 }
328
329 let mut stats = self.stats.write().await;
331 stats.active_subscriptions = self.subscriptions.iter().map(|r| r.value().len()).sum();
332
333 debug!("Peer {} unsubscribed", peer_id);
334 Ok(())
335 }
336
337 pub async fn announce(&self, announcement: BlockAnnouncement) -> Vec<PeerId> {
339 {
341 let mut recent = self.recent_announcements.write().await;
342 let now = std::time::Instant::now();
343
344 recent.retain(|_, timestamp| {
346 now.duration_since(*timestamp).as_secs() < self.config.dedup_window_secs
347 });
348
349 if recent.contains_key(&announcement.cid) {
351 trace!(
352 "Skipping duplicate announcement for CID: {}",
353 announcement.cid
354 );
355 return Vec::new();
356 }
357
358 recent.insert(announcement.cid, now);
359 }
360
361 let mut interested_peers = HashSet::new();
362
363 if self.config.enable_topic_routing {
365 if let Some(ref topic) = announcement.topic {
366 let index = self.topic_index.read().await;
367 if let Some(peers) = index.get(topic) {
368 interested_peers.extend(peers.iter().cloned());
369 }
370 }
371 }
372
373 for entry in self.subscriptions.iter() {
375 let peer_id = entry.key();
376 let subscriptions = entry.value();
377
378 for subscription in subscriptions {
379 if subscription.matches(&announcement) {
380 interested_peers.insert(peer_id.clone());
381 }
382 }
383 }
384
385 let mut stats = self.stats.write().await;
387 stats.announcements_sent += interested_peers.len() as u64;
388
389 trace!(
390 "Announced CID {} to {} peers",
391 announcement.cid,
392 interested_peers.len()
393 );
394
395 interested_peers.into_iter().collect()
396 }
397
398 pub fn get_subscriptions(&self, peer_id: &str) -> Option<Vec<PeerId>> {
400 self.subscriptions
401 .get(peer_id)
402 .map(|subs| vec![peer_id.to_string(); subs.len()])
403 }
404
405 pub async fn stats(&self) -> MulticastStats {
407 let stats = self.stats.read().await;
408 let mut result = stats.clone();
409 result.unique_topics = self.topic_index.read().await.len();
410 result
411 }
412
413 pub async fn clear(&self) {
415 self.subscriptions.clear();
416 self.topic_index.write().await.clear();
417 self.recent_announcements.write().await.clear();
418
419 let mut stats = self.stats.write().await;
420 stats.active_subscriptions = 0;
421 stats.unique_topics = 0;
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed, valid CID shared by most tests.
    fn test_cid() -> Cid {
        "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap()
    }

    // A second CID, distinct from `test_cid`, so deduplication does not
    // suppress a follow-up announcement within the same test.
    fn test_cid2() -> Cid {
        "bafybeihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
            .parse()
            .unwrap()
    }

    // Topic constructors produce the expected names.
    #[test]
    fn test_topic_creation() {
        let topic = Topic::new("test");
        assert_eq!(topic.0, "test");

        let all_blocks = Topic::all_blocks();
        assert_eq!(all_blocks.0, "blocks:all");

        let tensors = Topic::tensors();
        assert_eq!(tensors.0, "blocks:tensors");
    }

    // Builder methods set size, topic and metadata on an announcement.
    #[test]
    fn test_block_announcement() {
        let cid = test_cid();
        let announcement = BlockAnnouncement::new(cid, 1024)
            .with_topic(Topic::tensors())
            .with_metadata("dtype".to_string(), "float32".to_string());

        assert_eq!(announcement.size, 1024);
        assert_eq!(announcement.topic, Some(Topic::tensors()));
        assert_eq!(announcement.metadata.get("dtype").unwrap(), "float32");
    }

    // Subscribing and unsubscribing is reflected in `active_subscriptions`.
    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let manager = MulticastManager::new(MulticastConfig::default());

        manager
            .subscribe("peer1".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        let stats = manager.stats().await;
        assert_eq!(stats.active_subscriptions, 1);

        manager.unsubscribe("peer1").await.unwrap();

        let stats = manager.stats().await;
        assert_eq!(stats.active_subscriptions, 0);
    }

    // Only the peer subscribed to the announcement's topic is selected.
    #[tokio::test]
    async fn test_topic_based_announcement() {
        let manager = MulticastManager::new(MulticastConfig::default());

        manager
            .subscribe(
                "peer1".to_string(),
                SubscriptionFilter::Topic(Topic::tensors()),
            )
            .await
            .unwrap();

        manager
            .subscribe(
                "peer2".to_string(),
                SubscriptionFilter::Topic(Topic::gradients()),
            )
            .await
            .unwrap();

        let cid = test_cid();
        let announcement = BlockAnnouncement::new(cid, 1024).with_topic(Topic::tensors());

        let peers = manager.announce(announcement).await;
        assert_eq!(peers.len(), 1);
        assert!(peers.contains(&"peer1".to_string()));
    }

    // An `All` subscriber receives a topic-tagged announcement.
    #[tokio::test]
    async fn test_all_announcements_subscription() {
        let manager = MulticastManager::new(MulticastConfig::default());

        manager
            .subscribe("peer1".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        let cid = test_cid();
        let announcement = BlockAnnouncement::new(cid, 1024).with_topic(Topic::tensors());

        let peers = manager.announce(announcement).await;
        assert_eq!(peers.len(), 1);
        assert!(peers.contains(&"peer1".to_string()));
    }

    // A `Topics` subscriber matches each of its listed topics. Distinct CIDs
    // keep the second announcement from being deduplicated.
    #[tokio::test]
    async fn test_multiple_topics_subscription() {
        let manager = MulticastManager::new(MulticastConfig::default());

        manager
            .subscribe(
                "peer1".to_string(),
                SubscriptionFilter::Topics(vec![Topic::tensors(), Topic::gradients()]),
            )
            .await
            .unwrap();

        let cid1 = test_cid();
        let announcement1 = BlockAnnouncement::new(cid1, 1024).with_topic(Topic::tensors());
        let peers1 = manager.announce(announcement1).await;
        assert_eq!(peers1.len(), 1);

        let cid2 = test_cid2();
        let announcement2 = BlockAnnouncement::new(cid2, 2048).with_topic(Topic::gradients());
        let peers2 = manager.announce(announcement2).await;
        assert_eq!(peers2.len(), 1);
    }

    // Re-announcing the same CID within the dedup window returns no peers.
    #[tokio::test]
    async fn test_deduplication() {
        let manager = MulticastManager::new(MulticastConfig::default());

        manager
            .subscribe("peer1".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        let cid = test_cid();
        let announcement = BlockAnnouncement::new(cid, 1024);

        let peers1 = manager.announce(announcement.clone()).await;
        assert_eq!(peers1.len(), 1);

        // Same CID again, well inside the default 60-second window: suppressed.
        let peers2 = manager.announce(announcement).await;
        assert_eq!(peers2.len(), 0);
    }

    // The third subscription is rejected once `max_total_subscriptions` (2) is reached.
    #[tokio::test]
    async fn test_subscription_limits() {
        let config = MulticastConfig {
            max_total_subscriptions: 2,
            ..Default::default()
        };
        let manager = MulticastManager::new(config);

        manager
            .subscribe("peer1".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        manager
            .subscribe("peer2".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        let result = manager
            .subscribe("peer3".to_string(), SubscriptionFilter::All)
            .await;
        assert!(result.is_err());
    }

    // `clear` resets the active-subscription gauge to zero.
    #[tokio::test]
    async fn test_clear_subscriptions() {
        let manager = MulticastManager::new(MulticastConfig::default());

        manager
            .subscribe("peer1".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        manager
            .subscribe("peer2".to_string(), SubscriptionFilter::All)
            .await
            .unwrap();

        let stats = manager.stats().await;
        assert_eq!(stats.active_subscriptions, 2);

        manager.clear().await;

        let stats = manager.stats().await;
        assert_eq!(stats.active_subscriptions, 0);
    }
}