use cid::Cid;
use dashmap::DashMap;
use ipfrs_core::error::{Error, Result};
use libp2p::PeerId;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tracing::{debug, info};

/// Default time-to-live for locally tracked provider records (24 hours).
const DEFAULT_PROVIDER_TTL: Duration = Duration::from_secs(24 * 60 * 60);

/// Default time-to-live for cached DHT query results (5 minutes).
const DEFAULT_QUERY_CACHE_TTL: Duration = Duration::from_secs(5 * 60);

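/// Configuration for [`DhtManager`]: provider record TTL and refresh cadence,
/// query-cache TTL, and the soft cap on cached query results.
///
/// A minimal construction sketch (the `ipfrs_network::dht` path used in the
/// import is an assumption; adjust it to wherever this module is exposed):
///
/// ```ignore
/// use std::time::Duration;
/// use ipfrs_network::dht::DhtConfig;
///
/// // Keep the defaults but use a shorter query cache TTL.
/// let config = DhtConfig {
///     query_cache_ttl: Duration::from_secs(60),
///     ..Default::default()
/// };
/// assert!(config.enable_provider_refresh);
/// ```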
#[derive(Debug, Clone)]
pub struct DhtConfig {
    /// How long an announced provider record is considered valid.
    pub provider_ttl: Duration,
    /// How long a cached query result is served before it expires.
    pub query_cache_ttl: Duration,
    /// Whether to spawn the background task that re-announces provider records.
    pub enable_provider_refresh: bool,
    /// How often tracked provider records are refreshed.
    pub provider_refresh_interval: Duration,
    /// Soft limit on the number of cached query results.
    pub max_cached_queries: usize,
}

impl Default for DhtConfig {
    fn default() -> Self {
        Self {
            provider_ttl: DEFAULT_PROVIDER_TTL,
            query_cache_ttl: DEFAULT_QUERY_CACHE_TTL,
            enable_provider_refresh: true,
            provider_refresh_interval: Duration::from_secs(12 * 60 * 60),
            max_cached_queries: 10_000,
        }
    }
}

/// A cached "who provides this CID" query result.
#[derive(Debug, Clone)]
struct CachedQuery {
    peers: Vec<PeerId>,
    cached_at: Instant,
    hit_count: usize,
}

/// A provider record this node is responsible for re-announcing.
#[derive(Debug, Clone)]
struct ProviderRecord {
    cid: Cid,
    last_announced: Instant,
}

/// Manages DHT-related bookkeeping: provider record tracking and refresh,
/// query result caching, peer caching, and health statistics.
pub struct DhtManager {
    config: DhtConfig,
    query_cache: Arc<DashMap<String, CachedQuery>>,
    peer_cache: Arc<DashMap<PeerId, Instant>>,
    provider_records: Arc<RwLock<HashMap<String, ProviderRecord>>>,
    stats: Arc<RwLock<DhtStats>>,
    refresh_handle: Option<tokio::task::JoinHandle<()>>,
    cmd_tx: Option<mpsc::Sender<DhtCommand>>,
}

/// Counters and gauges describing DHT cache and query activity.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct DhtStats {
    pub total_queries: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub provider_refreshes: u64,
    pub active_providers: usize,
    pub cached_queries: usize,
    pub cached_peers: usize,
    pub successful_queries: u64,
    pub failed_queries: u64,
}

/// A point-in-time health snapshot derived from [`DhtStats`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct DhtHealth {
    pub health_score: f64,
    pub query_success_rate: f64,
    pub cache_hit_rate: f64,
    pub peer_count: usize,
    pub cached_query_count: usize,
    pub provider_count: usize,
    pub status: DhtHealthStatus,
}

/// Coarse health classification derived from the health score.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum DhtHealthStatus {
    Healthy,
    Degraded,
    Unhealthy,
    /// Not enough tracked queries yet to make a judgement.
    Unknown,
}

/// Commands sent to the background provider refresh task.
pub(crate) enum DhtCommand {
    TrackProvider { cid: Cid },
    StopTracking { cid: String },
    #[allow(dead_code)]
    RefreshProviders { response_tx: mpsc::Sender<Vec<Cid>> },
    Shutdown,
}

impl DhtManager {
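    /// Creates a new `DhtManager` with the given configuration. No background
    /// work starts until [`DhtManager::start_provider_refresh`] is called.
    ///
    /// A minimal usage sketch (the `ipfrs_network::dht` module path in the
    /// import is an assumption):
    ///
    /// ```ignore
    /// use ipfrs_network::dht::{DhtConfig, DhtManager};
    ///
    /// let manager = DhtManager::new(DhtConfig::default());
    /// assert_eq!(manager.get_stats().total_queries, 0);
    /// ```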
    pub fn new(config: DhtConfig) -> Self {
        let manager = Self {
            config,
            query_cache: Arc::new(DashMap::new()),
            peer_cache: Arc::new(DashMap::new()),
            provider_records: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(DhtStats::default())),
            refresh_handle: None,
            cmd_tx: None,
        };

        info!(
            "DHT Manager initialized (provider_ttl={:?}, query_cache_ttl={:?})",
            manager.config.provider_ttl, manager.config.query_cache_ttl
        );

        manager
    }

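    /// Spawns the background task that periodically refreshes tracked provider
    /// records and services command messages. Does nothing when
    /// `enable_provider_refresh` is false.
    ///
    /// A usage sketch mirroring the tests below, to be run inside an async
    /// context (error handling elided):
    ///
    /// ```ignore
    /// let mut manager = DhtManager::new(DhtConfig::default());
    /// manager.start_provider_refresh();
    /// manager.track_provider(Cid::default()).await?;
    /// // ... later, on teardown:
    /// manager.shutdown().await;
    /// ```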
    pub fn start_provider_refresh(&mut self) {
        if !self.config.enable_provider_refresh {
            info!("Provider refresh disabled");
            return;
        }

        let (cmd_tx, mut cmd_rx) = mpsc::channel::<DhtCommand>(100);
        let provider_records = Arc::clone(&self.provider_records);
        let stats = Arc::clone(&self.stats);
        let refresh_interval = self.config.provider_refresh_interval;

        let handle = tokio::spawn(async move {
            info!(
                "Starting provider refresh task (interval={:?})",
                refresh_interval
            );

            let mut interval = tokio::time::interval(refresh_interval);

            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        let now = Instant::now();
                        let mut refresh_needed = Vec::new();

                        {
                            let records = provider_records.read();
                            for (key, record) in records.iter() {
                                if now.duration_since(record.last_announced) >= refresh_interval {
                                    refresh_needed.push((key.clone(), record.cid));
                                }
                            }
                        }

                        if !refresh_needed.is_empty() {
                            info!("Refreshing {} provider records", refresh_needed.len());
                            stats.write().provider_refreshes += refresh_needed.len() as u64;

                            let mut records = provider_records.write();
                            for (key, _cid) in refresh_needed {
                                if let Some(record) = records.get_mut(&key) {
                                    record.last_announced = now;
                                }
                            }
                        }
                    }
                    Some(cmd) = cmd_rx.recv() => {
                        match cmd {
                            DhtCommand::TrackProvider { cid } => {
                                let key = cid.to_string();
                                let mut records = provider_records.write();
                                records.insert(key.clone(), ProviderRecord {
                                    cid,
                                    last_announced: Instant::now(),
                                });
                                debug!("Tracking provider record: {}", key);
                                stats.write().active_providers = records.len();
                            }
                            DhtCommand::StopTracking { cid } => {
                                let mut records = provider_records.write();
                                records.remove(&cid);
                                debug!("Stopped tracking provider: {}", cid);
                                stats.write().active_providers = records.len();
                            }
                            DhtCommand::RefreshProviders { response_tx } => {
                                let cids: Vec<Cid> = {
                                    let records = provider_records.read();
                                    records.values().map(|r| r.cid).collect()
                                };
                                let _ = response_tx.send(cids).await;
                            }
                            DhtCommand::Shutdown => {
                                info!("Shutting down provider refresh task");
                                break;
                            }
                        }
                    }
                }
            }
        });

        self.refresh_handle = Some(handle);
        self.cmd_tx = Some(cmd_tx);
    }

    /// Starts tracking a provider record so the background task re-announces it.
    pub async fn track_provider(&self, cid: Cid) -> Result<()> {
        if let Some(cmd_tx) = &self.cmd_tx {
            cmd_tx
                .send(DhtCommand::TrackProvider { cid })
                .await
                .map_err(|e| Error::Network(format!("Failed to track provider: {}", e)))?;
        }
        Ok(())
    }

    /// Stops tracking a previously announced provider record.
    pub async fn stop_tracking(&self, cid: &Cid) -> Result<()> {
        if let Some(cmd_tx) = &self.cmd_tx {
            cmd_tx
                .send(DhtCommand::StopTracking {
                    cid: cid.to_string(),
                })
                .await
                .map_err(|e| Error::Network(format!("Failed to stop tracking: {}", e)))?;
        }
        Ok(())
    }

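    /// Caches the peers returned by a provider query for `cid`. When the cache
    /// is at capacity, entries older than twice the query cache TTL are evicted
    /// before inserting.
    ///
    /// A round-trip sketch, mirroring the tests below:
    ///
    /// ```ignore
    /// let manager = DhtManager::new(DhtConfig::default());
    /// let cid = Cid::default();
    /// manager.cache_query_result(&cid, vec![PeerId::random()]);
    /// assert!(manager.get_cached_query(&cid).is_some());
    /// ```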
    pub fn cache_query_result(&self, cid: &Cid, peers: Vec<PeerId>) {
        let key = cid.to_string();

        if self.query_cache.len() >= self.config.max_cached_queries {
            // Age-based eviction: drop entries older than twice the TTL.
            let now = Instant::now();
            let mut to_remove = Vec::new();

            for entry in self.query_cache.iter() {
                if now.duration_since(entry.value().cached_at) > self.config.query_cache_ttl * 2 {
                    to_remove.push(entry.key().clone());
                }
            }

            for key in to_remove {
                self.query_cache.remove(&key);
            }
        }

        self.query_cache.insert(
            key.clone(),
            CachedQuery {
                peers,
                cached_at: Instant::now(),
                hit_count: 0,
            },
        );

        debug!("Cached query result for {}", key);
        self.stats.write().cached_queries = self.query_cache.len();
    }

    /// Returns the cached peers for `cid` if the entry exists and has not
    /// expired. Updates hit/miss statistics either way.
    pub fn get_cached_query(&self, cid: &Cid) -> Option<Vec<PeerId>> {
        let key = cid.to_string();
        let mut stats = self.stats.write();
        stats.total_queries += 1;

        if let Some(mut cached) = self.query_cache.get_mut(&key) {
            let age = Instant::now().duration_since(cached.cached_at);

            if age < self.config.query_cache_ttl {
                cached.hit_count += 1;
                stats.cache_hits += 1;
                debug!(
                    "Cache hit for {} (age={:?}, hits={})",
                    key, age, cached.hit_count
                );
                return Some(cached.peers.clone());
            } else {
                debug!("Cache entry expired for {} (age={:?})", key, age);
                // Release the DashMap guard before removing the entry.
                drop(cached);
                self.query_cache.remove(&key);
            }
        }

        stats.cache_misses += 1;
        None
    }

    /// Records that a peer was recently seen.
    pub fn cache_peer(&self, peer_id: PeerId) {
        self.peer_cache.insert(peer_id, Instant::now());
        self.stats.write().cached_peers = self.peer_cache.len();
    }

    /// Returns true if the peer is present in the peer cache.
    pub fn is_peer_cached(&self, peer_id: &PeerId) -> bool {
        self.peer_cache.contains_key(peer_id)
    }

    /// Returns all currently cached peer IDs.
    pub fn get_cached_peers(&self) -> Vec<PeerId> {
        self.peer_cache.iter().map(|entry| *entry.key()).collect()
    }

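    /// Evicts expired query results and stale peers (the peer TTL is fixed at
    /// one hour), then refreshes the cache-size gauges in the stats.
    ///
    /// Callers are expected to invoke this periodically; a sketch of one way to
    /// do that with a tokio interval (the one-minute cadence is an assumption):
    ///
    /// ```ignore
    /// let manager = Arc::new(DhtManager::new(DhtConfig::default()));
    /// let cleanup = Arc::clone(&manager);
    /// tokio::spawn(async move {
    ///     let mut tick = tokio::time::interval(Duration::from_secs(60));
    ///     loop {
    ///         tick.tick().await;
    ///         cleanup.cleanup_cache();
    ///     }
    /// });
    /// ```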
    pub fn cleanup_cache(&self) {
        let now = Instant::now();
        let mut removed_queries = 0;
        let mut removed_peers = 0;

        let to_remove: Vec<String> = self
            .query_cache
            .iter()
            .filter(|entry| {
                now.duration_since(entry.value().cached_at) > self.config.query_cache_ttl
            })
            .map(|entry| entry.key().clone())
            .collect();

        for key in to_remove {
            self.query_cache.remove(&key);
            removed_queries += 1;
        }

        let peer_ttl = Duration::from_secs(3600);
        let to_remove: Vec<PeerId> = self
            .peer_cache
            .iter()
            .filter(|entry| now.duration_since(*entry.value()) > peer_ttl)
            .map(|entry| *entry.key())
            .collect();

        for peer_id in to_remove {
            self.peer_cache.remove(&peer_id);
            removed_peers += 1;
        }

        if removed_queries > 0 || removed_peers > 0 {
            debug!(
                "Cache cleanup: removed {} queries, {} peers",
                removed_queries, removed_peers
            );
        }

        let mut stats = self.stats.write();
        stats.cached_queries = self.query_cache.len();
        stats.cached_peers = self.peer_cache.len();
    }

    /// Returns a snapshot of the current statistics.
    pub fn get_stats(&self) -> DhtStats {
        self.stats.read().clone()
    }

    /// Records a successful DHT query for health tracking.
    pub fn record_query_success(&self) {
        self.stats.write().successful_queries += 1;
    }

    /// Records a failed DHT query for health tracking.
    pub fn record_query_failure(&self) {
        self.stats.write().failed_queries += 1;
    }

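    /// Computes a health snapshot from the current stats. Once more than ten
    /// queries have been tracked, the score is a weighted blend of query
    /// success rate (0.6), cache hit rate (0.2), and whether any peers are
    /// cached (0.2); fewer than ten tracked queries yields `Unknown`.
    ///
    /// A sketch of reacting to the snapshot (the log message is illustrative):
    ///
    /// ```ignore
    /// let health = manager.get_health();
    /// match health.status {
    ///     DhtHealthStatus::Healthy | DhtHealthStatus::Unknown => {}
    ///     _ => tracing::warn!(
    ///         "DHT degraded: score={:.2}, success={:.2}",
    ///         health.health_score, health.query_success_rate
    ///     ),
    /// }
    /// ```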
    pub fn get_health(&self) -> DhtHealth {
        let stats = self.stats.read();

        let total_tracked_queries = stats.successful_queries + stats.failed_queries;
        let query_success_rate = if total_tracked_queries > 0 {
            stats.successful_queries as f64 / total_tracked_queries as f64
        } else {
            1.0
        };

        let total_cache_queries = stats.cache_hits + stats.cache_misses;
        let cache_hit_rate = if total_cache_queries > 0 {
            stats.cache_hits as f64 / total_cache_queries as f64
        } else {
            0.0
        };

        let health_score = if total_tracked_queries > 10 {
            let query_weight = 0.6;
            let cache_weight = 0.2;
            let peer_weight = 0.2;

            let peer_score = if stats.cached_peers > 0 { 1.0 } else { 0.0 };

            query_success_rate * query_weight
                + cache_hit_rate * cache_weight
                + peer_score * peer_weight
        } else {
            1.0
        };

        let status = if total_tracked_queries < 10 {
            DhtHealthStatus::Unknown
        } else if health_score >= 0.8 {
            DhtHealthStatus::Healthy
        } else if health_score >= 0.5 {
            DhtHealthStatus::Degraded
        } else {
            DhtHealthStatus::Unhealthy
        };

        DhtHealth {
            health_score,
            query_success_rate,
            cache_hit_rate,
            peer_count: stats.cached_peers,
            cached_query_count: stats.cached_queries,
            provider_count: stats.active_providers,
            status,
        }
    }

    /// Returns true when the DHT is `Healthy` or has not yet seen enough
    /// queries to be judged (`Unknown`).
    pub fn is_healthy(&self) -> bool {
        let health = self.get_health();
        matches!(
            health.status,
            DhtHealthStatus::Healthy | DhtHealthStatus::Unknown
        )
    }

    /// Sends a shutdown command to the refresh task and aborts its handle.
    pub async fn shutdown(&mut self) {
        if let Some(tx) = self.cmd_tx.take() {
            let _ = tx.send(DhtCommand::Shutdown).await;
        }

        if let Some(handle) = self.refresh_handle.take() {
            handle.abort();
        }

        info!("DHT Manager shut down");
    }
}

impl Drop for DhtManager {
    fn drop(&mut self) {
        if let Some(handle) = self.refresh_handle.take() {
            handle.abort();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_dht_manager_creation() {
        let config = DhtConfig::default();
        let manager = DhtManager::new(config);
        let stats = manager.get_stats();
        assert_eq!(stats.total_queries, 0);
        assert_eq!(stats.cache_hits, 0);
    }

    #[tokio::test]
    async fn test_query_caching() {
        let manager = DhtManager::new(DhtConfig::default());
        let cid = Cid::default();
        let peers = vec![PeerId::random(), PeerId::random()];

        manager.cache_query_result(&cid, peers.clone());

        let cached = manager.get_cached_query(&cid);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), peers.len());

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.total_queries, 1);
    }

    #[tokio::test]
    async fn test_query_cache_expiration() {
        let config = DhtConfig {
            query_cache_ttl: Duration::from_millis(100),
            ..Default::default()
        };

        let manager = DhtManager::new(config);
        let cid = Cid::default();
        let peers = vec![PeerId::random()];

        manager.cache_query_result(&cid, peers);

        assert!(manager.get_cached_query(&cid).is_some());

        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(manager.get_cached_query(&cid).is_none());
    }

    #[tokio::test]
    async fn test_peer_caching() {
        let manager = DhtManager::new(DhtConfig::default());
        let peer1 = PeerId::random();
        let peer2 = PeerId::random();

        manager.cache_peer(peer1);
        manager.cache_peer(peer2);

        assert!(manager.is_peer_cached(&peer1));
        assert!(manager.is_peer_cached(&peer2));

        let cached_peers = manager.get_cached_peers();
        assert_eq!(cached_peers.len(), 2);
    }

    #[tokio::test]
    async fn test_provider_tracking() {
        let mut manager = DhtManager::new(DhtConfig::default());
        manager.start_provider_refresh();

        let cid = Cid::default();
        manager.track_provider(cid).await.unwrap();

        tokio::time::sleep(Duration::from_millis(50)).await;

        let stats = manager.get_stats();
        assert_eq!(stats.active_providers, 1);

        manager.shutdown().await;
    }

    #[tokio::test]
    async fn test_cache_cleanup() {
        let config = DhtConfig {
            query_cache_ttl: Duration::from_millis(100),
            ..Default::default()
        };

        let manager = DhtManager::new(config);
        let cid = Cid::default();
        let peers = vec![PeerId::random()];

        manager.cache_query_result(&cid, peers);
        assert_eq!(manager.get_stats().cached_queries, 1);

        tokio::time::sleep(Duration::from_millis(150)).await;

        manager.cleanup_cache();
        assert_eq!(manager.get_stats().cached_queries, 0);
    }

    #[tokio::test]
    async fn test_cache_size_limit() {
        let config = DhtConfig {
            max_cached_queries: 5,
            ..Default::default()
        };

        let manager = DhtManager::new(config);

        for i in 0..10 {
            let key = format!("test-{}", i);
            manager.query_cache.insert(
                key,
                CachedQuery {
                    peers: vec![PeerId::random()],
                    cached_at: Instant::now(),
                    hit_count: 0,
                },
            );
        }

        // Direct inserts bypass the age-based eviction in cache_query_result,
        // so this only checks that the cache stays within a loose bound.
        assert!(manager.query_cache.len() <= 15);
    }

    #[tokio::test]
    async fn test_health_monitoring_unknown() {
        let manager = DhtManager::new(DhtConfig::default());

        let health = manager.get_health();
        assert_eq!(health.status, DhtHealthStatus::Unknown);
        assert!(manager.is_healthy());
    }

    #[tokio::test]
    async fn test_health_monitoring_healthy() {
        let manager = DhtManager::new(DhtConfig::default());

        for _ in 0..15 {
            manager.record_query_success();
        }

        let cid = Cid::default();
        let peers = vec![PeerId::random()];
        manager.cache_query_result(&cid, peers);
        let _ = manager.get_cached_query(&cid);

        manager.cache_peer(PeerId::random());

        let health = manager.get_health();
        assert_eq!(health.status, DhtHealthStatus::Healthy);
        assert!(health.health_score >= 0.8);
        assert_eq!(health.query_success_rate, 1.0);
        assert!(manager.is_healthy());
    }

    #[tokio::test]
    async fn test_health_monitoring_degraded() {
        let manager = DhtManager::new(DhtConfig::default());

        for _ in 0..7 {
            manager.record_query_success();
        }
        for _ in 0..5 {
            manager.record_query_failure();
        }

        let health = manager.get_health();
        assert!(health.query_success_rate > 0.5);
        assert!(health.query_success_rate < 1.0);
    }

    #[tokio::test]
    async fn test_health_monitoring_unhealthy() {
        let manager = DhtManager::new(DhtConfig::default());

        for _ in 0..2 {
            manager.record_query_success();
        }
        for _ in 0..10 {
            manager.record_query_failure();
        }

        let health = manager.get_health();
        assert!(matches!(
            health.status,
            DhtHealthStatus::Unhealthy | DhtHealthStatus::Degraded
        ));
        assert!(health.query_success_rate < 0.5);
        assert!(!manager.is_healthy());
    }

    #[tokio::test]
    async fn test_health_cache_hit_rate() {
        let manager = DhtManager::new(DhtConfig::default());

        for _ in 0..15 {
            manager.record_query_success();
        }

        let cid1 = Cid::default();
        let peers = vec![PeerId::random()];
        manager.cache_query_result(&cid1, peers);

        let _ = manager.get_cached_query(&cid1);

        // Manually record one miss so hits and misses are balanced 1:1.
        manager.stats.write().total_queries += 1;
        manager.stats.write().cache_misses += 1;

        let health = manager.get_health();
        assert_eq!(health.cache_hit_rate, 0.5);
    }
}