use dashmap::DashMap;
use libp2p::PeerId;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;

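/// Errors returned by the connection quality predictor.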
#[derive(Debug, Error)]
pub enum QualityPredictorError {
    #[error("No historical data available for peer")]
    NoHistoricalData,
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
}

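/// Configuration for the connection quality predictor.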
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityPredictorConfig {
    /// Maximum number of latency/bandwidth samples retained per peer.
    pub max_samples: usize,
    /// Weight of the latency score in the overall quality score.
    pub latency_weight: f64,
    /// Weight of the bandwidth score in the overall quality score.
    pub bandwidth_weight: f64,
    /// Weight of the reliability score in the overall quality score.
    pub reliability_weight: f64,
    /// Weight of the uptime score in the overall quality score.
    pub uptime_weight: f64,
    /// Smoothing factor for the exponential moving average of quality scores.
    pub smoothing_factor: f64,
    /// Minimum overall score for a connection to be considered acceptable.
    pub min_acceptable_quality: f64,
    /// Overall score below which switching connections is recommended.
    pub switch_threshold: f64,
    /// Whether switch recommendations are enabled.
    pub enable_predictions: bool,
}

impl Default for QualityPredictorConfig {
    fn default() -> Self {
        Self {
            max_samples: 100,
            latency_weight: 0.3,
            bandwidth_weight: 0.3,
            reliability_weight: 0.25,
            uptime_weight: 0.15,
            smoothing_factor: 0.2,
            min_acceptable_quality: 0.5,
            switch_threshold: 0.6,
            enable_predictions: true,
        }
    }
}

impl QualityPredictorConfig {
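    /// Preset that prioritizes low latency.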
    pub fn low_latency() -> Self {
        Self {
            latency_weight: 0.5,
            bandwidth_weight: 0.2,
            reliability_weight: 0.2,
            uptime_weight: 0.1,
            ..Default::default()
        }
    }

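    /// Preset that prioritizes high bandwidth.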
    pub fn high_bandwidth() -> Self {
        Self {
            latency_weight: 0.15,
            bandwidth_weight: 0.5,
            reliability_weight: 0.25,
            uptime_weight: 0.1,
            ..Default::default()
        }
    }

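    /// Preset that prioritizes reliability and uptime.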
    pub fn high_reliability() -> Self {
        Self {
            latency_weight: 0.2,
            bandwidth_weight: 0.2,
            reliability_weight: 0.4,
            uptime_weight: 0.2,
            ..Default::default()
        }
    }

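    /// Validates the configuration: `max_samples` must be positive, the four
    /// weights must sum to 1.0 (within 0.01), and the smoothing factor must
    /// lie in `[0.0, 1.0]`.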
    pub fn validate(&self) -> Result<(), QualityPredictorError> {
        if self.max_samples == 0 {
            return Err(QualityPredictorError::InvalidConfig(
                "max_samples must be > 0".to_string(),
            ));
        }

        let total_weight = self.latency_weight
            + self.bandwidth_weight
            + self.reliability_weight
            + self.uptime_weight;

        if (total_weight - 1.0).abs() > 0.01 {
            return Err(QualityPredictorError::InvalidConfig(format!(
                "weights must sum to 1.0, got {}",
                total_weight
            )));
        }

        if !(0.0..=1.0).contains(&self.smoothing_factor) {
            return Err(QualityPredictorError::InvalidConfig(
                "smoothing_factor must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(())
    }
}

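/// Rolling per-peer connection history used as input to quality predictions.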
#[derive(Debug, Clone)]
struct ConnectionHistory {
    /// Recent latency samples in milliseconds (bounded by `max_samples`).
    latency_samples: VecDeque<u64>,
    /// Recent bandwidth samples in bytes per second (bounded by `max_samples`).
    bandwidth_samples: VecDeque<u64>,
    /// Number of successful interactions with the peer.
    success_count: u64,
    /// Number of failed interactions with the peer.
    failure_count: u64,
    /// When the peer was first observed.
    first_seen: Instant,
    /// When the peer was last observed.
    last_seen: Instant,
    /// Exponential moving average of the overall quality score, if computed.
    quality_ema: Option<f64>,
}

impl ConnectionHistory {
    fn new() -> Self {
        let now = Instant::now();
        Self {
            latency_samples: VecDeque::new(),
            bandwidth_samples: VecDeque::new(),
            success_count: 0,
            failure_count: 0,
            first_seen: now,
            last_seen: now,
            quality_ema: None,
        }
    }

    fn record_latency(&mut self, latency_ms: u64, max_samples: usize) {
        if self.latency_samples.len() >= max_samples {
            self.latency_samples.pop_front();
        }
        self.latency_samples.push_back(latency_ms);
        self.last_seen = Instant::now();
    }

    fn record_bandwidth(&mut self, bytes_per_sec: u64, max_samples: usize) {
        if self.bandwidth_samples.len() >= max_samples {
            self.bandwidth_samples.pop_front();
        }
        self.bandwidth_samples.push_back(bytes_per_sec);
        self.last_seen = Instant::now();
    }

    fn record_success(&mut self) {
        self.success_count += 1;
        self.last_seen = Instant::now();
    }

    fn record_failure(&mut self) {
        self.failure_count += 1;
        self.last_seen = Instant::now();
    }

    fn avg_latency(&self) -> Option<f64> {
        if self.latency_samples.is_empty() {
            None
        } else {
            let sum: u64 = self.latency_samples.iter().sum();
            Some(sum as f64 / self.latency_samples.len() as f64)
        }
    }

    fn avg_bandwidth(&self) -> Option<f64> {
        if self.bandwidth_samples.is_empty() {
            None
        } else {
            let sum: u64 = self.bandwidth_samples.iter().sum();
            Some(sum as f64 / self.bandwidth_samples.len() as f64)
        }
    }

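    /// Fraction of recorded interactions that succeeded; a neutral 0.5 when
    /// nothing has been recorded yet.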
    fn reliability_score(&self) -> f64 {
        let total = self.success_count + self.failure_count;
        if total == 0 {
            0.5
        } else {
            self.success_count as f64 / total as f64
        }
    }

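    /// Ratio of the active window (first to last activity) to the total time
    /// the peer has been known, capped at 1.0; peers known for under a second
    /// score 1.0.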
    fn uptime_score(&self) -> f64 {
        let total_duration = self.first_seen.elapsed().as_secs_f64();
        if total_duration < 1.0 {
            1.0
        } else {
            let active_duration = self.last_seen.duration_since(self.first_seen).as_secs_f64();
            (active_duration / total_duration).min(1.0)
        }
    }
}

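/// A point-in-time quality assessment for a single peer.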
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityPrediction {
    /// Weighted combination of the individual scores, in `[0.0, 1.0]`.
    pub overall_score: f64,
    /// Score derived from average latency.
    pub latency_score: f64,
    /// Score derived from average bandwidth.
    pub bandwidth_score: f64,
    /// Fraction of recorded interactions that succeeded.
    pub reliability_score: f64,
    /// Score derived from how recently the peer has been active.
    pub uptime_score: f64,
    /// Average observed latency in milliseconds, if any samples exist.
    pub avg_latency_ms: Option<f64>,
    /// Average observed bandwidth in bytes per second, if any samples exist.
    pub avg_bandwidth_bps: Option<f64>,
    /// Whether the overall score meets `min_acceptable_quality`.
    pub is_acceptable: bool,
    /// Whether switching away from this connection is recommended.
    pub should_switch: bool,
}

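/// Aggregate statistics about the predictor itself.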
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityPredictorStats {
    /// Number of peers with recorded history.
    pub tracked_peers: usize,
    /// Total number of predictions made.
    pub predictions_made: u64,
    /// Number of predictions that recommended switching connections.
    pub switch_recommendations: u64,
    /// Average smoothed quality score across tracked peers.
    pub avg_quality: f64,
}

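/// Predicts per-peer connection quality from recorded latency, bandwidth, and
/// reliability samples, and recommends when to switch connections.
///
/// A minimal usage sketch (assuming the caller already has a `PeerId` and is
/// feeding in its own measurements):
///
/// ```ignore
/// let predictor = QualityPredictor::new(QualityPredictorConfig::default())?;
/// let peer = PeerId::random();
///
/// predictor.record_latency(peer, 42);
/// predictor.record_bandwidth(peer, 5_000_000);
/// predictor.record_success(peer);
///
/// if let Some(prediction) = predictor.predict_quality(&peer) {
///     println!("overall score: {:.2}", prediction.overall_score);
///     if prediction.should_switch {
///         // Try a different peer or transport here.
///     }
/// }
/// ```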
pub struct QualityPredictor {
    config: QualityPredictorConfig,
    history: Arc<DashMap<PeerId, ConnectionHistory>>,
    stats: Arc<parking_lot::RwLock<QualityPredictorStats>>,
}

impl QualityPredictor {
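    /// Creates a new predictor after validating the configuration.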
    pub fn new(config: QualityPredictorConfig) -> Result<Self, QualityPredictorError> {
        config.validate()?;

        Ok(Self {
            config,
            history: Arc::new(DashMap::new()),
            stats: Arc::new(parking_lot::RwLock::new(QualityPredictorStats {
                tracked_peers: 0,
                predictions_made: 0,
                switch_recommendations: 0,
                avg_quality: 0.0,
            })),
        })
    }

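    /// Records a latency sample, in milliseconds, for the given peer.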
    pub fn record_latency(&self, peer: PeerId, latency_ms: u64) {
        let mut entry = self
            .history
            .entry(peer)
            .or_insert_with(ConnectionHistory::new);
        entry.record_latency(latency_ms, self.config.max_samples);
    }

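    /// Records a bandwidth sample, in bytes per second, for the given peer.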
    pub fn record_bandwidth(&self, peer: PeerId, bytes_per_sec: u64) {
        let mut entry = self
            .history
            .entry(peer)
            .or_insert_with(ConnectionHistory::new);
        entry.record_bandwidth(bytes_per_sec, self.config.max_samples);
    }

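    /// Records a successful interaction with the given peer.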
    pub fn record_success(&self, peer: PeerId) {
        let mut entry = self
            .history
            .entry(peer)
            .or_insert_with(ConnectionHistory::new);
        entry.record_success();
    }

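    /// Records a failed interaction with the given peer.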
    pub fn record_failure(&self, peer: PeerId) {
        let mut entry = self
            .history
            .entry(peer)
            .or_insert_with(ConnectionHistory::new);
        entry.record_failure();
    }

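    /// Predicts the connection quality for a peer, returning `None` if no
    /// history has been recorded. Also updates the peer's smoothed quality
    /// score and the predictor's statistics.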
    pub fn predict_quality(&self, peer: &PeerId) -> Option<QualityPrediction> {
        let history = self.history.get(peer)?;

        let latency_score = self.calculate_latency_score(history.avg_latency());
        let bandwidth_score = self.calculate_bandwidth_score(history.avg_bandwidth());
        let reliability_score = history.reliability_score();
        let uptime_score = history.uptime_score();

        let overall_score = latency_score * self.config.latency_weight
            + bandwidth_score * self.config.bandwidth_weight
            + reliability_score * self.config.reliability_weight
            + uptime_score * self.config.uptime_weight;

        // Release the read guard before taking a write guard on the same key.
        drop(history);
        if let Some(mut history) = self.history.get_mut(peer) {
            if let Some(prev_ema) = history.quality_ema {
                history.quality_ema = Some(
                    self.config.smoothing_factor * overall_score
                        + (1.0 - self.config.smoothing_factor) * prev_ema,
                );
            } else {
                history.quality_ema = Some(overall_score);
            }
        }

        let is_acceptable = overall_score >= self.config.min_acceptable_quality;
        let should_switch =
            self.config.enable_predictions && overall_score < self.config.switch_threshold;

        let mut stats = self.stats.write();
        stats.predictions_made += 1;
        if should_switch {
            stats.switch_recommendations += 1;
        }

        Some(QualityPrediction {
            overall_score,
            latency_score,
            bandwidth_score,
            reliability_score,
            uptime_score,
            avg_latency_ms: self.history.get(peer).and_then(|h| h.avg_latency()),
            avg_bandwidth_bps: self.history.get(peer).and_then(|h| h.avg_bandwidth()),
            is_acceptable,
            should_switch,
        })
    }

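    /// Convenience helper: returns true if a fresh prediction recommends
    /// switching away from this peer.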
    pub fn should_switch_connection(&self, peer: &PeerId) -> bool {
        self.predict_quality(peer)
            .map(|p| p.should_switch)
            .unwrap_or(false)
    }

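    /// Returns the peer with the highest predicted overall score among
    /// `peers`, if any of them have recorded history.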
    pub fn get_best_peer(&self, peers: &[PeerId]) -> Option<(PeerId, QualityPrediction)> {
        peers
            .iter()
            .filter_map(|peer| {
                self.predict_quality(peer)
                    .map(|prediction| (*peer, prediction))
            })
            .max_by(|a, b| {
                a.1.overall_score
                    .partial_cmp(&b.1.overall_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    }

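    /// Ranks the given peers by predicted overall score, best first; peers
    /// without recorded history are omitted.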
    pub fn rank_peers(&self, peers: &[PeerId]) -> Vec<(PeerId, QualityPrediction)> {
        let mut ranked: Vec<_> = peers
            .iter()
            .filter_map(|peer| {
                self.predict_quality(peer)
                    .map(|prediction| (*peer, prediction))
            })
            .collect();

        ranked.sort_by(|a, b| {
            b.1.overall_score
                .partial_cmp(&a.1.overall_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        ranked
    }

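    /// Removes all recorded history for the given peer.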
    pub fn remove_peer(&self, peer: &PeerId) {
        self.history.remove(peer);
    }

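    /// Clears all recorded history and resets the statistics.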
    pub fn clear(&self) {
        self.history.clear();
        let mut stats = self.stats.write();
        stats.tracked_peers = 0;
        stats.predictions_made = 0;
        stats.switch_recommendations = 0;
        stats.avg_quality = 0.0;
    }

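    /// Returns a snapshot of predictor statistics, recomputing the tracked
    /// peer count and the average smoothed quality from current history.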
    pub fn stats(&self) -> QualityPredictorStats {
        let mut stats = self.stats.read().clone();
        stats.tracked_peers = self.history.len();

        if stats.tracked_peers > 0 {
            let total_quality: f64 = self
                .history
                .iter()
                .filter_map(|entry| entry.quality_ema)
                .sum();
            stats.avg_quality = total_quality / stats.tracked_peers as f64;
        }

        stats
    }

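    /// Maps average latency (ms) linearly onto `[0.0, 1.0]`: 0 ms scores 1.0,
    /// 1000 ms or more scores 0.0, and missing data scores a neutral 0.5.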
    fn calculate_latency_score(&self, avg_latency: Option<f64>) -> f64 {
        match avg_latency {
            None => 0.5,
            Some(latency) => {
                if latency <= 0.0 {
                    1.0
                } else if latency >= 1000.0 {
                    0.0
                } else {
                    1.0 - (latency / 1000.0)
                }
            }
        }
    }

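    /// Maps average bandwidth linearly onto `[0.0, 1.0]`: 100 MB/s or more
    /// scores 1.0, and missing data scores a neutral 0.5.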
    fn calculate_bandwidth_score(&self, avg_bandwidth: Option<f64>) -> f64 {
        match avg_bandwidth {
            None => 0.5,
            Some(bandwidth) => {
                let mb_per_sec = bandwidth / 1_000_000.0;
                if mb_per_sec >= 100.0 {
                    1.0
                } else if mb_per_sec <= 0.0 {
                    0.0
                } else {
                    (mb_per_sec / 100.0).min(1.0)
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = QualityPredictorConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_weights() {
        let config = QualityPredictorConfig {
            latency_weight: 0.5,
            bandwidth_weight: 0.3,
            reliability_weight: 0.1,
            // Weights sum to 0.95, so validation should reject this config.
            uptime_weight: 0.05,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_presets() {
        assert!(QualityPredictorConfig::low_latency().validate().is_ok());
        assert!(QualityPredictorConfig::high_bandwidth().validate().is_ok());
        assert!(QualityPredictorConfig::high_reliability()
            .validate()
            .is_ok());
    }

    #[test]
    fn test_record_metrics() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();
        let peer = PeerId::random();

        predictor.record_latency(peer, 50);
        predictor.record_bandwidth(peer, 1_000_000);
        predictor.record_success(peer);

        let prediction = predictor.predict_quality(&peer).unwrap();
        assert!(prediction.avg_latency_ms.is_some());
        assert!(prediction.avg_bandwidth_bps.is_some());
        assert!(prediction.overall_score > 0.0);
    }

    #[test]
    fn test_latency_score() {
        let predictor = QualityPredictor::new(QualityPredictorConfig::default()).unwrap();

        assert_eq!(predictor.calculate_latency_score(Some(0.0)), 1.0);
        assert!(predictor.calculate_latency_score(Some(100.0)) > 0.7);
        assert!(predictor.calculate_latency_score(Some(500.0)) < 0.6);
        assert_eq!(predictor.calculate_latency_score(Some(1000.0)), 0.0);
    }

    #[test]
    fn test_bandwidth_score() {
        let predictor = QualityPredictor::new(QualityPredictorConfig::default()).unwrap();

        assert_eq!(predictor.calculate_bandwidth_score(Some(0.0)), 0.0);
        assert!(predictor.calculate_bandwidth_score(Some(1_000_000.0)) > 0.0);
        assert!(predictor.calculate_bandwidth_score(Some(10_000_000.0)) > 0.05);
        assert_eq!(
            predictor.calculate_bandwidth_score(Some(100_000_000.0)),
            1.0
        );
    }

    #[test]
    fn test_reliability_tracking() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();
        let peer = PeerId::random();

        predictor.record_success(peer);
        predictor.record_success(peer);
        predictor.record_failure(peer);

        let prediction = predictor.predict_quality(&peer).unwrap();
        assert!((prediction.reliability_score - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_get_best_peer() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();

        let peer1 = PeerId::random();
        let peer2 = PeerId::random();
        let peer3 = PeerId::random();

        // peer1: low latency, high bandwidth, reliable.
        predictor.record_latency(peer1, 10);
        predictor.record_bandwidth(peer1, 10_000_000);
        predictor.record_success(peer1);

        // peer2: high latency, low bandwidth, unreliable.
        predictor.record_latency(peer2, 500);
        predictor.record_bandwidth(peer2, 100_000);
        predictor.record_failure(peer2);

        // peer3: middling on every metric.
        predictor.record_latency(peer3, 50);
        predictor.record_bandwidth(peer3, 5_000_000);
        predictor.record_success(peer3);

        let peers = vec![peer1, peer2, peer3];
        let (best, _) = predictor.get_best_peer(&peers).unwrap();
        assert_eq!(best, peer1);
    }

    #[test]
    fn test_rank_peers() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();

        let peer1 = PeerId::random();
        let peer2 = PeerId::random();
        let peer3 = PeerId::random();

        predictor.record_latency(peer1, 10);
        predictor.record_latency(peer2, 100);
        predictor.record_latency(peer3, 50);

        let peers = vec![peer1, peer2, peer3];
        let ranked = predictor.rank_peers(&peers);

        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].0, peer1); // Lowest latency ranks first.
        assert_eq!(ranked[2].0, peer2); // Highest latency ranks last.
    }

    #[test]
    fn test_should_switch() {
        let config = QualityPredictorConfig {
            switch_threshold: 0.7,
            enable_predictions: true,
            ..Default::default()
        };
        let predictor = QualityPredictor::new(config).unwrap();
        let peer = PeerId::random();

        // Poor latency, bandwidth, and reliability push the score below the threshold.
        predictor.record_latency(peer, 800);
        predictor.record_bandwidth(peer, 50_000);
        predictor.record_failure(peer);
        predictor.record_failure(peer);
        predictor.record_success(peer);

        assert!(predictor.should_switch_connection(&peer));
    }

    #[test]
    fn test_ema_smoothing() {
        let config = QualityPredictorConfig {
            smoothing_factor: 0.5,
            ..Default::default()
        };
        let predictor = QualityPredictor::new(config).unwrap();
        let peer = PeerId::random();

        predictor.record_latency(peer, 100);
        let pred1 = predictor.predict_quality(&peer).unwrap();

        predictor.record_latency(peer, 50);
        let pred2 = predictor.predict_quality(&peer).unwrap();

        // A lower average latency should raise the overall score.
        assert!(pred2.overall_score > pred1.overall_score);
    }

    #[test]
    fn test_stats() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();

        let peer1 = PeerId::random();
        let peer2 = PeerId::random();

        predictor.record_latency(peer1, 50);
        predictor.record_latency(peer2, 100);

        predictor.predict_quality(&peer1);
        predictor.predict_quality(&peer2);

        let stats = predictor.stats();
        assert_eq!(stats.tracked_peers, 2);
        assert_eq!(stats.predictions_made, 2);
    }

    #[test]
    fn test_remove_peer() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();
        let peer = PeerId::random();

        predictor.record_latency(peer, 50);
        assert!(predictor.predict_quality(&peer).is_some());

        predictor.remove_peer(&peer);
        assert!(predictor.predict_quality(&peer).is_none());
    }

    #[test]
    fn test_clear() {
        let config = QualityPredictorConfig::default();
        let predictor = QualityPredictor::new(config).unwrap();

        predictor.record_latency(PeerId::random(), 50);
        predictor.record_latency(PeerId::random(), 100);

        predictor.clear();
        let stats = predictor.stats();
        assert_eq!(stats.tracked_peers, 0);
    }
}