ipfrs_network/
dht.rs

1//! Distributed Hash Table (Kademlia DHT) implementation
2//!
3//! Provides DHT operations including:
4//! - Peer discovery and routing
5//! - Content provider management
6//! - Query result caching
7//! - Automatic provider record refresh
8
9use cid::Cid;
10use dashmap::DashMap;
11use ipfrs_core::error::{Error, Result};
12use libp2p::PeerId;
13use parking_lot::RwLock;
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::sync::mpsc;
18use tracing::{debug, info};
19
20/// Default provider record TTL (24 hours)
21const DEFAULT_PROVIDER_TTL: Duration = Duration::from_secs(24 * 60 * 60);
22
23/// Default query cache TTL (5 minutes)
24const DEFAULT_QUERY_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
25
26/// DHT configuration
27#[derive(Debug, Clone)]
28pub struct DhtConfig {
29    /// Provider record TTL
30    pub provider_ttl: Duration,
31    /// Query cache TTL
32    pub query_cache_ttl: Duration,
33    /// Enable automatic provider refresh
34    pub enable_provider_refresh: bool,
35    /// Provider refresh interval (should be < provider_ttl)
36    pub provider_refresh_interval: Duration,
37    /// Maximum cached queries
38    pub max_cached_queries: usize,
39}
40
41impl Default for DhtConfig {
42    fn default() -> Self {
43        Self {
44            provider_ttl: DEFAULT_PROVIDER_TTL,
45            query_cache_ttl: DEFAULT_QUERY_CACHE_TTL,
46            enable_provider_refresh: true,
47            provider_refresh_interval: Duration::from_secs(12 * 60 * 60), // 12 hours
48            max_cached_queries: 10_000,
49        }
50    }
51}
52
53/// Cached query result
54#[derive(Debug, Clone)]
55struct CachedQuery {
56    /// Result peers
57    peers: Vec<PeerId>,
58    /// Timestamp when cached
59    cached_at: Instant,
60    /// Number of times this result was used
61    hit_count: usize,
62}
63
64/// Provider record for refresh tracking
65#[derive(Debug, Clone)]
66struct ProviderRecord {
67    /// Content ID
68    cid: Cid,
69    /// Last announcement time
70    last_announced: Instant,
71}
72
73/// DHT manager for peer and content discovery
74pub struct DhtManager {
75    config: DhtConfig,
76    /// Query result cache (CID -> cached result)
77    query_cache: Arc<DashMap<String, CachedQuery>>,
78    /// Peer routing cache (PeerId -> known addresses count)
79    peer_cache: Arc<DashMap<PeerId, Instant>>,
80    /// Provider records to refresh
81    provider_records: Arc<RwLock<HashMap<String, ProviderRecord>>>,
82    /// Statistics
83    stats: Arc<RwLock<DhtStats>>,
84    /// Refresh task handle
85    refresh_handle: Option<tokio::task::JoinHandle<()>>,
86    /// Command sender for refresh task
87    cmd_tx: Option<mpsc::Sender<DhtCommand>>,
88}
89
90/// DHT statistics
91#[derive(Debug, Clone, Default, serde::Serialize)]
92pub struct DhtStats {
93    /// Total queries performed
94    pub total_queries: u64,
95    /// Cache hits
96    pub cache_hits: u64,
97    /// Cache misses
98    pub cache_misses: u64,
99    /// Total provider refreshes
100    pub provider_refreshes: u64,
101    /// Active provider records
102    pub active_providers: usize,
103    /// Cached queries count
104    pub cached_queries: usize,
105    /// Cached peers count
106    pub cached_peers: usize,
107    /// Successful queries
108    pub successful_queries: u64,
109    /// Failed queries
110    pub failed_queries: u64,
111}
112
113/// DHT health status
114#[derive(Debug, Clone, serde::Serialize)]
115pub struct DhtHealth {
116    /// Overall health score (0.0 - 1.0)
117    pub health_score: f64,
118    /// Query success rate (0.0 - 1.0)
119    pub query_success_rate: f64,
120    /// Cache hit rate (0.0 - 1.0)
121    pub cache_hit_rate: f64,
122    /// Number of cached peers
123    pub peer_count: usize,
124    /// Number of cached queries
125    pub cached_query_count: usize,
126    /// Number of active provider records
127    pub provider_count: usize,
128    /// Health status description
129    pub status: DhtHealthStatus,
130}
131
132/// DHT health status enum
133#[derive(Debug, Clone, PartialEq, serde::Serialize)]
134pub enum DhtHealthStatus {
135    /// DHT is healthy
136    Healthy,
137    /// DHT has degraded performance
138    Degraded,
139    /// DHT is unhealthy
140    Unhealthy,
141    /// Not enough data to determine health
142    Unknown,
143}
144
145/// DHT command for background task
146pub(crate) enum DhtCommand {
147    /// Add a provider record to track
148    TrackProvider { cid: Cid },
149    /// Stop tracking a provider
150    StopTracking { cid: String },
151    /// Refresh all providers (returns sender for refresh requests)
152    #[allow(dead_code)]
153    RefreshProviders { response_tx: mpsc::Sender<Vec<Cid>> },
154    /// Shutdown the refresh task
155    Shutdown,
156}
157
158impl DhtManager {
159    /// Create a new DHT manager
160    pub fn new(config: DhtConfig) -> Self {
161        let manager = Self {
162            config,
163            query_cache: Arc::new(DashMap::new()),
164            peer_cache: Arc::new(DashMap::new()),
165            provider_records: Arc::new(RwLock::new(HashMap::new())),
166            stats: Arc::new(RwLock::new(DhtStats::default())),
167            refresh_handle: None,
168            cmd_tx: None,
169        };
170
171        info!(
172            "DHT Manager initialized (provider_ttl={:?}, query_cache_ttl={:?})",
173            manager.config.provider_ttl, manager.config.query_cache_ttl
174        );
175
176        manager
177    }
178
179    /// Start the provider refresh background task
180    pub fn start_provider_refresh(&mut self) {
181        if !self.config.enable_provider_refresh {
182            info!("Provider refresh disabled");
183            return;
184        }
185
186        let (cmd_tx, mut cmd_rx) = mpsc::channel::<DhtCommand>(100);
187        let provider_records = Arc::clone(&self.provider_records);
188        let stats = Arc::clone(&self.stats);
189        let refresh_interval = self.config.provider_refresh_interval;
190
191        let handle = tokio::spawn(async move {
192            info!(
193                "Starting provider refresh task (interval={:?})",
194                refresh_interval
195            );
196
197            let mut interval = tokio::time::interval(refresh_interval);
198
199            loop {
200                tokio::select! {
201                    _ = interval.tick() => {
202                        // Periodic refresh check
203                        let now = Instant::now();
204                        let mut refresh_needed = Vec::new();
205
206                        {
207                            let records = provider_records.read();
208                            for (key, record) in records.iter() {
209                                if now.duration_since(record.last_announced) >= refresh_interval {
210                                    refresh_needed.push((key.clone(), record.cid));
211                                }
212                            }
213                        }
214
215                        if !refresh_needed.is_empty() {
216                            info!("Refreshing {} provider records", refresh_needed.len());
217                            stats.write().provider_refreshes += refresh_needed.len() as u64;
218
219                            // Update last_announced times
220                            let mut records = provider_records.write();
221                            for (key, _cid) in refresh_needed {
222                                if let Some(record) = records.get_mut(&key) {
223                                    record.last_announced = now;
224                                }
225                            }
226                        }
227                    }
228                    Some(cmd) = cmd_rx.recv() => {
229                        match cmd {
230                            DhtCommand::TrackProvider { cid } => {
231                                let key = cid.to_string();
232                                let mut records = provider_records.write();
233                                records.insert(key.clone(), ProviderRecord {
234                                    cid,
235                                    last_announced: Instant::now(),
236                                });
237                                debug!("Tracking provider record: {}", key);
238                                stats.write().active_providers = records.len();
239                            }
240                            DhtCommand::StopTracking { cid } => {
241                                let mut records = provider_records.write();
242                                records.remove(&cid);
243                                debug!("Stopped tracking provider: {}", cid);
244                                stats.write().active_providers = records.len();
245                            }
246                            DhtCommand::RefreshProviders { response_tx } => {
247                                let cids: Vec<Cid> = {
248                                    let records = provider_records.read();
249                                    records.values().map(|r| r.cid).collect()
250                                };
251                                let _ = response_tx.send(cids).await;
252                            }
253                            DhtCommand::Shutdown => {
254                                info!("Shutting down provider refresh task");
255                                break;
256                            }
257                        }
258                    }
259                }
260            }
261        });
262
263        self.refresh_handle = Some(handle);
264        self.cmd_tx = Some(cmd_tx);
265    }
266
267    /// Track a provider record for automatic refresh
268    pub async fn track_provider(&self, cid: Cid) -> Result<()> {
269        if let Some(cmd_tx) = &self.cmd_tx {
270            cmd_tx
271                .send(DhtCommand::TrackProvider { cid })
272                .await
273                .map_err(|e| Error::Network(format!("Failed to track provider: {}", e)))?;
274        }
275        Ok(())
276    }
277
278    /// Stop tracking a provider record
279    pub async fn stop_tracking(&self, cid: &Cid) -> Result<()> {
280        if let Some(cmd_tx) = &self.cmd_tx {
281            cmd_tx
282                .send(DhtCommand::StopTracking {
283                    cid: cid.to_string(),
284                })
285                .await
286                .map_err(|e| Error::Network(format!("Failed to stop tracking: {}", e)))?;
287        }
288        Ok(())
289    }
290
291    /// Cache a query result
292    pub fn cache_query_result(&self, cid: &Cid, peers: Vec<PeerId>) {
293        let key = cid.to_string();
294
295        // Check cache size limit
296        if self.query_cache.len() >= self.config.max_cached_queries {
297            // Remove oldest entries (LRU-style)
298            let now = Instant::now();
299            let mut to_remove = Vec::new();
300
301            for entry in self.query_cache.iter() {
302                if now.duration_since(entry.value().cached_at) > self.config.query_cache_ttl * 2 {
303                    to_remove.push(entry.key().clone());
304                }
305            }
306
307            for key in to_remove {
308                self.query_cache.remove(&key);
309            }
310        }
311
312        self.query_cache.insert(
313            key.clone(),
314            CachedQuery {
315                peers,
316                cached_at: Instant::now(),
317                hit_count: 0,
318            },
319        );
320
321        debug!("Cached query result for {}", key);
322        self.stats.write().cached_queries = self.query_cache.len();
323    }
324
325    /// Get cached query result
326    pub fn get_cached_query(&self, cid: &Cid) -> Option<Vec<PeerId>> {
327        let key = cid.to_string();
328        let mut stats = self.stats.write();
329        stats.total_queries += 1;
330
331        if let Some(mut cached) = self.query_cache.get_mut(&key) {
332            let age = Instant::now().duration_since(cached.cached_at);
333
334            if age < self.config.query_cache_ttl {
335                cached.hit_count += 1;
336                stats.cache_hits += 1;
337                debug!(
338                    "Cache hit for {} (age={:?}, hits={})",
339                    key, age, cached.hit_count
340                );
341                return Some(cached.peers.clone());
342            } else {
343                debug!("Cache entry expired for {} (age={:?})", key, age);
344                drop(cached);
345                self.query_cache.remove(&key);
346            }
347        }
348
349        stats.cache_misses += 1;
350        None
351    }
352
353    /// Cache a peer
354    pub fn cache_peer(&self, peer_id: PeerId) {
355        self.peer_cache.insert(peer_id, Instant::now());
356        self.stats.write().cached_peers = self.peer_cache.len();
357    }
358
359    /// Check if peer is in cache
360    pub fn is_peer_cached(&self, peer_id: &PeerId) -> bool {
361        self.peer_cache.contains_key(peer_id)
362    }
363
364    /// Get all cached peers
365    pub fn get_cached_peers(&self) -> Vec<PeerId> {
366        self.peer_cache.iter().map(|entry| *entry.key()).collect()
367    }
368
369    /// Clean up expired cache entries
370    pub fn cleanup_cache(&self) {
371        let now = Instant::now();
372        let mut removed_queries = 0;
373        let mut removed_peers = 0;
374
375        // Clean query cache
376        let to_remove: Vec<String> = self
377            .query_cache
378            .iter()
379            .filter(|entry| {
380                now.duration_since(entry.value().cached_at) > self.config.query_cache_ttl
381            })
382            .map(|entry| entry.key().clone())
383            .collect();
384
385        for key in to_remove {
386            self.query_cache.remove(&key);
387            removed_queries += 1;
388        }
389
390        // Clean peer cache (expire after 1 hour)
391        let peer_ttl = Duration::from_secs(3600);
392        let to_remove: Vec<PeerId> = self
393            .peer_cache
394            .iter()
395            .filter(|entry| now.duration_since(*entry.value()) > peer_ttl)
396            .map(|entry| *entry.key())
397            .collect();
398
399        for peer_id in to_remove {
400            self.peer_cache.remove(&peer_id);
401            removed_peers += 1;
402        }
403
404        if removed_queries > 0 || removed_peers > 0 {
405            debug!(
406                "Cache cleanup: removed {} queries, {} peers",
407                removed_queries, removed_peers
408            );
409        }
410
411        let mut stats = self.stats.write();
412        stats.cached_queries = self.query_cache.len();
413        stats.cached_peers = self.peer_cache.len();
414    }
415
416    /// Get DHT statistics
417    pub fn get_stats(&self) -> DhtStats {
418        self.stats.read().clone()
419    }
420
421    /// Record a successful query
422    pub fn record_query_success(&self) {
423        self.stats.write().successful_queries += 1;
424    }
425
426    /// Record a failed query
427    pub fn record_query_failure(&self) {
428        self.stats.write().failed_queries += 1;
429    }
430
431    /// Get DHT health status
432    pub fn get_health(&self) -> DhtHealth {
433        let stats = self.stats.read();
434
435        // Calculate query success rate
436        let total_tracked_queries = stats.successful_queries + stats.failed_queries;
437        let query_success_rate = if total_tracked_queries > 0 {
438            stats.successful_queries as f64 / total_tracked_queries as f64
439        } else {
440            1.0 // No data yet, assume healthy
441        };
442
443        // Calculate cache hit rate
444        let total_cache_queries = stats.cache_hits + stats.cache_misses;
445        let cache_hit_rate = if total_cache_queries > 0 {
446            stats.cache_hits as f64 / total_cache_queries as f64
447        } else {
448            0.0
449        };
450
451        // Calculate overall health score (weighted average)
452        let health_score = if total_tracked_queries > 10 {
453            // Only calculate meaningful health if we have enough data
454            let query_weight = 0.6;
455            let cache_weight = 0.2;
456            let peer_weight = 0.2;
457
458            let peer_score = if stats.cached_peers > 0 { 1.0 } else { 0.0 };
459
460            query_success_rate * query_weight
461                + cache_hit_rate * cache_weight
462                + peer_score * peer_weight
463        } else {
464            1.0 // Not enough data, assume healthy
465        };
466
467        // Determine health status
468        let status = if total_tracked_queries < 10 {
469            DhtHealthStatus::Unknown
470        } else if health_score >= 0.8 {
471            DhtHealthStatus::Healthy
472        } else if health_score >= 0.5 {
473            DhtHealthStatus::Degraded
474        } else {
475            DhtHealthStatus::Unhealthy
476        };
477
478        DhtHealth {
479            health_score,
480            query_success_rate,
481            cache_hit_rate,
482            peer_count: stats.cached_peers,
483            cached_query_count: stats.cached_queries,
484            provider_count: stats.active_providers,
485            status,
486        }
487    }
488
489    /// Check if DHT is healthy
490    pub fn is_healthy(&self) -> bool {
491        let health = self.get_health();
492        matches!(
493            health.status,
494            DhtHealthStatus::Healthy | DhtHealthStatus::Unknown
495        )
496    }
497
498    /// Shutdown the DHT manager
499    pub async fn shutdown(&mut self) {
500        if let Some(tx) = self.cmd_tx.take() {
501            let _ = tx.send(DhtCommand::Shutdown).await;
502        }
503
504        if let Some(handle) = self.refresh_handle.take() {
505            handle.abort();
506        }
507
508        info!("DHT Manager shut down");
509    }
510}
511
512impl Drop for DhtManager {
513    fn drop(&mut self) {
514        if let Some(handle) = self.refresh_handle.take() {
515            handle.abort();
516        }
517    }
518}
519
520#[cfg(test)]
521mod tests {
522    use super::*;
523
524    #[tokio::test]
525    async fn test_dht_manager_creation() {
526        let config = DhtConfig::default();
527        let manager = DhtManager::new(config);
528        let stats = manager.get_stats();
529        assert_eq!(stats.total_queries, 0);
530        assert_eq!(stats.cache_hits, 0);
531    }
532
533    #[tokio::test]
534    async fn test_query_caching() {
535        let manager = DhtManager::new(DhtConfig::default());
536        let cid = Cid::default();
537        let peers = vec![PeerId::random(), PeerId::random()];
538
539        // Cache a result
540        manager.cache_query_result(&cid, peers.clone());
541
542        // Retrieve it
543        let cached = manager.get_cached_query(&cid);
544        assert!(cached.is_some());
545        assert_eq!(cached.unwrap().len(), peers.len());
546
547        let stats = manager.get_stats();
548        assert_eq!(stats.cache_hits, 1);
549        assert_eq!(stats.total_queries, 1);
550    }
551
552    #[tokio::test]
553    async fn test_query_cache_expiration() {
554        let config = DhtConfig {
555            query_cache_ttl: Duration::from_millis(100),
556            ..Default::default()
557        };
558
559        let manager = DhtManager::new(config);
560        let cid = Cid::default();
561        let peers = vec![PeerId::random()];
562
563        manager.cache_query_result(&cid, peers);
564
565        // Should be cached
566        assert!(manager.get_cached_query(&cid).is_some());
567
568        // Wait for expiration
569        tokio::time::sleep(Duration::from_millis(150)).await;
570
571        // Should be expired
572        assert!(manager.get_cached_query(&cid).is_none());
573    }
574
575    #[tokio::test]
576    async fn test_peer_caching() {
577        let manager = DhtManager::new(DhtConfig::default());
578        let peer1 = PeerId::random();
579        let peer2 = PeerId::random();
580
581        manager.cache_peer(peer1);
582        manager.cache_peer(peer2);
583
584        assert!(manager.is_peer_cached(&peer1));
585        assert!(manager.is_peer_cached(&peer2));
586
587        let cached_peers = manager.get_cached_peers();
588        assert_eq!(cached_peers.len(), 2);
589    }
590
591    #[tokio::test]
592    async fn test_provider_tracking() {
593        let mut manager = DhtManager::new(DhtConfig::default());
594        manager.start_provider_refresh();
595
596        let cid = Cid::default();
597        manager.track_provider(cid).await.unwrap();
598
599        // Give it a moment to process
600        tokio::time::sleep(Duration::from_millis(50)).await;
601
602        let stats = manager.get_stats();
603        assert_eq!(stats.active_providers, 1);
604
605        manager.shutdown().await;
606    }
607
608    #[tokio::test]
609    async fn test_cache_cleanup() {
610        let config = DhtConfig {
611            query_cache_ttl: Duration::from_millis(100),
612            ..Default::default()
613        };
614
615        let manager = DhtManager::new(config);
616        let cid = Cid::default();
617        let peers = vec![PeerId::random()];
618
619        manager.cache_query_result(&cid, peers);
620        assert_eq!(manager.get_stats().cached_queries, 1);
621
622        // Wait for expiration
623        tokio::time::sleep(Duration::from_millis(150)).await;
624
625        // Cleanup
626        manager.cleanup_cache();
627        assert_eq!(manager.get_stats().cached_queries, 0);
628    }
629
630    #[tokio::test]
631    async fn test_cache_size_limit() {
632        let config = DhtConfig {
633            max_cached_queries: 5,
634            ..Default::default()
635        };
636
637        let manager = DhtManager::new(config);
638
639        // Add more than the limit
640        for i in 0..10 {
641            let key = format!("test-{}", i);
642            // This is a workaround since we can't easily create different CIDs in test
643            manager.query_cache.insert(
644                key,
645                CachedQuery {
646                    peers: vec![PeerId::random()],
647                    cached_at: Instant::now(),
648                    hit_count: 0,
649                },
650            );
651        }
652
653        // Should not exceed limit significantly (with cleanup)
654        assert!(manager.query_cache.len() <= 15);
655    }
656
657    #[tokio::test]
658    async fn test_health_monitoring_unknown() {
659        let manager = DhtManager::new(DhtConfig::default());
660
661        // With no data, health should be unknown
662        let health = manager.get_health();
663        assert_eq!(health.status, DhtHealthStatus::Unknown);
664        assert!(manager.is_healthy()); // Unknown is considered healthy
665    }
666
667    #[tokio::test]
668    async fn test_health_monitoring_healthy() {
669        let manager = DhtManager::new(DhtConfig::default());
670
671        // Record successful queries
672        for _ in 0..15 {
673            manager.record_query_success();
674        }
675
676        // Add some cache hits
677        let cid = Cid::default();
678        let peers = vec![PeerId::random()];
679        manager.cache_query_result(&cid, peers);
680        let _ = manager.get_cached_query(&cid);
681
682        // Add some peers
683        manager.cache_peer(PeerId::random());
684
685        let health = manager.get_health();
686        assert_eq!(health.status, DhtHealthStatus::Healthy);
687        assert!(health.health_score >= 0.8);
688        assert_eq!(health.query_success_rate, 1.0);
689        assert!(manager.is_healthy());
690    }
691
692    #[tokio::test]
693    async fn test_health_monitoring_degraded() {
694        let manager = DhtManager::new(DhtConfig::default());
695
696        // Record mix of successful and failed queries
697        for _ in 0..7 {
698            manager.record_query_success();
699        }
700        for _ in 0..5 {
701            manager.record_query_failure();
702        }
703
704        let health = manager.get_health();
705        // With 7 success and 5 failures, success rate is ~58%, health score depends on other factors
706        assert!(health.query_success_rate > 0.5);
707        assert!(health.query_success_rate < 1.0);
708    }
709
710    #[tokio::test]
711    async fn test_health_monitoring_unhealthy() {
712        let manager = DhtManager::new(DhtConfig::default());
713
714        // Record mostly failed queries
715        for _ in 0..2 {
716            manager.record_query_success();
717        }
718        for _ in 0..10 {
719            manager.record_query_failure();
720        }
721
722        let health = manager.get_health();
723        assert!(matches!(
724            health.status,
725            DhtHealthStatus::Unhealthy | DhtHealthStatus::Degraded
726        ));
727        assert!(health.query_success_rate < 0.5);
728        assert!(!manager.is_healthy());
729    }
730
731    #[tokio::test]
732    async fn test_health_cache_hit_rate() {
733        let manager = DhtManager::new(DhtConfig::default());
734
735        // Enough queries to be measurable
736        for _ in 0..15 {
737            manager.record_query_success();
738        }
739
740        // Create a CID and cache it
741        let cid1 = Cid::default();
742        let peers = vec![PeerId::random()];
743        manager.cache_query_result(&cid1, peers);
744
745        // Cache hit
746        let _ = manager.get_cached_query(&cid1);
747
748        // Cache miss - use a string key that won't match
749        manager.stats.write().total_queries += 1;
750        manager.stats.write().cache_misses += 1;
751
752        let health = manager.get_health();
753        assert_eq!(health.cache_hit_rate, 0.5);
754    }
755}