// File: nd_300/speedtest/mod.rs

1pub mod cloudflare;
2pub mod display;
3pub mod fastcom;
4pub mod librespeed;
5pub mod ndt7;
6pub mod statistics;
7
8use serde::Serialize;
9use std::sync::Arc;
10use std::time::Instant;
11
/// Test duration configuration.
///
/// The value applies per direction, so a fixed duration is spent once on
/// download and once on upload. Derives `PartialEq`/`Eq` so configurations
/// can be compared directly, matching the other enums in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TestDuration {
    /// Fixed duration per direction in seconds (e.g., 30 = 30s download + 30s upload)
    Seconds(u64),
    /// Let providers use their natural duration
    Auto,
}
20
/// Which providers to run.
///
/// Equality on this fieldless enum is total, so it derives `Eq` alongside
/// `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderSet {
    /// All 4 providers: Cloudflare, NDT7, LibreSpeed, fast.com (speedqx default)
    All,
    /// Diagnostic subset: Cloudflare + NDT7 only (nd300 default)
    Diagnostic,
}
29
/// Configuration for the speed test orchestrator.
///
/// The `Default` implementation selects 30-second fixed runs, automatic
/// fast.com timing, 20 latency probes, all providers, and colored output.
#[derive(Debug, Clone)]
pub struct SpeedTestConfig {
    /// Duration per direction for CF, NDT7, LibreSpeed (default: 30s)
    pub duration: TestDuration,
    /// Duration per direction for fast.com (default: Auto)
    pub fastcom_duration: TestDuration,
    /// Number of latency probes (default: 20)
    pub latency_probes: u32,
    /// Which providers to run (default: `ProviderSet::All`)
    pub provider_set: ProviderSet,
    /// Enable colored output (default: true)
    pub use_colors: bool,
}
44
45impl Default for SpeedTestConfig {
46    fn default() -> Self {
47        Self {
48            duration: TestDuration::Seconds(30),
49            fastcom_duration: TestDuration::Auto,
50            latency_probes: 20,
51            provider_set: ProviderSet::All,
52            use_colors: true,
53        }
54    }
55}
56
/// Phase indicator for progress callbacks.
///
/// Variants are grouped by provider prefix (`Cf`, `Ndt7`, `Ls`, `Fc`), in the
/// order the providers are run. Derives `Eq` in addition to `PartialEq`
/// because equality on this fieldless enum is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Phase {
    /// Cloudflare: latency measurement.
    CfLatency,
    /// Cloudflare: download measurement.
    CfDownload,
    /// Cloudflare: upload measurement.
    CfUpload,
    /// NDT7: discovery step before measuring.
    Ndt7Discovery,
    /// NDT7: download measurement.
    Ndt7Download,
    /// NDT7: upload measurement.
    Ndt7Upload,
    /// LibreSpeed: discovery step before measuring.
    LsDiscovery,
    /// LibreSpeed: download measurement.
    LsDownload,
    /// LibreSpeed: upload measurement.
    LsUpload,
    /// fast.com: discovery step before measuring.
    FcDiscovery,
    /// fast.com: download measurement.
    FcDownload,
    /// fast.com: upload measurement.
    FcUpload,
    /// Reported by `run` once all providers have finished, just before the
    /// results are aggregated.
    Computing,
}
74
/// Raw per-request Mbps samples for statistical post-processing.
///
/// `aggregate` derives each provider's bandwidth estimate and variance from
/// these lists; when a direction's list is empty it falls back to the
/// provider-reported figure for that direction.
#[derive(Debug, Clone, Default, Serialize)]
pub struct BandwidthSamples {
    /// Download throughput samples, in Mbps.
    pub download: Vec<f64>,
    /// Upload throughput samples, in Mbps.
    pub upload: Vec<f64>,
}
81
/// Connection stability metrics (coefficient of variation).
///
/// Computed by `aggregate` over the raw samples pooled from every successful
/// provider, not per provider.
#[derive(Debug, Clone, Serialize)]
pub struct StabilityMetrics {
    /// Coefficient of variation of the pooled download samples.
    pub download_cv: f64,
    /// Coefficient of variation of the pooled upload samples.
    pub upload_cv: f64,
    /// `true` when `download_cv` is below 0.15.
    pub download_stable: bool,
    /// `true` when `upload_cv` is below 0.15.
    pub upload_stable: bool,
}
90
/// Provider divergence detection.
///
/// Each ratio is the gap between the lowest and highest positive
/// per-provider estimate, expressed as a fraction of the highest.
/// `0.0` means there was no measurable spread.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderDivergence {
    /// Spread ratio across the per-provider download estimates.
    pub download: f64,
    /// Spread ratio across the per-provider upload estimates.
    pub upload: f64,
    /// `true` when either ratio exceeds `DIVERGENCE_THRESHOLD`.
    pub significant: bool,
}
98
/// Per-provider speed test result.
///
/// Optional fields are left out of the serialized output when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderResult {
    /// Provider name. `aggregate` looks up "Cloudflare" and "M-Lab NDT7" by
    /// this exact string when merging latency and packet loss.
    pub provider: String,
    /// Server the measurement ran against, as reported by the provider module.
    pub server: String,
    /// Server location, when the provider reports one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
    /// Measured ping in milliseconds, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ping_ms: Option<f64>,
    /// Measured jitter in milliseconds, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jitter_ms: Option<f64>,
    /// Provider-reported download speed in Mbps. Used by `aggregate` only
    /// when no raw download samples are present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub download_mbps: Option<f64>,
    /// Provider-reported upload speed in Mbps. Used by `aggregate` only
    /// when no raw upload samples are present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub upload_mbps: Option<f64>,
    /// Byte count for the download direction.
    pub download_bytes: u64,
    /// Byte count for the upload direction.
    pub upload_bytes: u64,
    /// Download phase duration (seconds, per the `_s` suffix).
    pub download_duration_s: f64,
    /// Upload phase duration (seconds, per the `_s` suffix).
    pub upload_duration_s: f64,
    /// Packet loss percentage, if the provider measures it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub packet_loss_pct: Option<f64>,
    /// Failure description. A result with `Some(..)` here is excluded from
    /// aggregation entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Raw throughput samples used for the statistical bandwidth estimate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bandwidth_samples: Option<BandwidthSamples>,
}
125
/// Aggregated speed test result (used by both speedqx and nd300).
///
/// Built by `run` from the per-provider results and their aggregate.
/// Optional fields are left out of the serialized output when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct SpeedTestResult {
    /// Merged ping in milliseconds, if any provider reported one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ping_ms: Option<f64>,
    /// Merged jitter in milliseconds, if any provider reported one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jitter_ms: Option<f64>,
    /// Merged download speed in Mbps (`0.0` when nothing usable was measured).
    pub download_mbps: f64,
    /// Merged upload speed in Mbps (`0.0` when nothing usable was measured).
    pub upload_mbps: f64,
    /// Packet loss percentage taken from the Cloudflare result, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub packet_loss_pct: Option<f64>,
    /// Every provider result, in the order the providers were run.
    pub providers: Vec<ProviderResult>,
    /// Wall-clock time of the whole run, in seconds.
    pub duration_s: f64,
    /// Stability metrics, present when enough raw samples were collected.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stability: Option<StabilityMetrics>,
    /// Divergence between providers, present when at least two providers
    /// produced an estimate for a direction.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_divergence: Option<ProviderDivergence>,
}
144
/// Latency weight for Cloudflare (NDT7 gets 1 - this).
/// NDT7's MinRTT from TCP kernel is structurally superior, not just lower-variance.
/// Passed to `statistics::weighted_merge` when both providers report a value.
const CF_LATENCY_WEIGHT: f64 = 0.4;

/// Divergence threshold: flag when providers differ by more than this fraction.
/// Compared against the download and upload spread ratios in `aggregate`.
const DIVERGENCE_THRESHOLD: f64 = 0.3;
151
/// Relative gap between two readings, as a fraction of the larger one.
///
/// Returns `0.0` when either reading is zero or negative, since a ratio is
/// meaningless without two positive measurements.
fn divergence_ratio(a: f64, b: f64) -> f64 {
    let has_non_positive = a <= 0.0 || b <= 0.0;
    if has_non_positive {
        0.0
    } else {
        let larger = a.max(b);
        let gap = (a - b).abs();
        gap / larger
    }
}
158
159fn divergence_spread(values: &[(f64, f64)]) -> f64 {
160    let mut min = f64::INFINITY;
161    let mut max = f64::NEG_INFINITY;
162
163    for (value, _) in values {
164        if *value <= 0.0 {
165            continue;
166        }
167        min = min.min(*value);
168        max = max.max(*value);
169    }
170
171    if !min.is_finite() || !max.is_finite() || max <= 0.0 || min == max {
172        0.0
173    } else {
174        divergence_ratio(min, max)
175    }
176}
177
178fn inverse_variance_merge_many(values: &[(f64, f64)]) -> f64 {
179    let positive: Vec<(f64, f64)> = values
180        .iter()
181        .copied()
182        .filter(|(value, _)| *value > 0.0)
183        .collect();
184
185    if positive.is_empty() {
186        return 0.0;
187    }
188    if positive.len() == 1 {
189        return positive[0].0;
190    }
191
192    let min_positive_variance = positive
193        .iter()
194        .filter_map(|(_, variance)| (*variance > 0.0).then_some(*variance))
195        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
196
197    if min_positive_variance.is_none() {
198        return positive.iter().map(|(value, _)| value).sum::<f64>() / positive.len() as f64;
199    }
200
201    let variance_floor = min_positive_variance.unwrap().max(0.000_001);
202    let raw_weights: Vec<f64> = positive
203        .iter()
204        .map(|(_, variance)| 1.0 / variance.max(variance_floor))
205        .collect();
206    let raw_total = raw_weights.iter().sum::<f64>();
207    if raw_total <= 0.0 {
208        return positive.iter().map(|(value, _)| value).sum::<f64>() / positive.len() as f64;
209    }
210
211    let weights = capped_inverse_variance_weights(&raw_weights, raw_total, 0.70);
212
213    positive
214        .iter()
215        .zip(weights.iter())
216        .map(|((value, _), weight)| value * weight)
217        .sum()
218}
219
/// Turns raw weights into normalized weights where no entry exceeds `cap`.
///
/// * Empty input yields an empty vector; a single entry gets weight `1.0`.
/// * A non-positive `raw_total` yields equal weights.
/// * `cap` is raised to at least `1 / n` so the weights can still sum to one.
///
/// Entries whose proportional share of the remaining budget exceeds the cap
/// are pinned at the cap and removed from the pool; the leftover budget is
/// then re-split among the rest. This repeats until a pass pins nothing.
fn capped_inverse_variance_weights(raw_weights: &[f64], raw_total: f64, cap: f64) -> Vec<f64> {
    let count = raw_weights.len();
    match count {
        0 => return Vec::new(),
        1 => return vec![1.0],
        _ => {}
    }
    if raw_total <= 0.0 {
        return vec![1.0 / count as f64; count];
    }

    let cap = cap.max(1.0 / count as f64);
    let mut shares = vec![0.0; count];
    let mut open: Vec<usize> = (0..count).collect();
    let mut budget: f64 = 1.0;

    while !open.is_empty() {
        let open_total: f64 = open.iter().map(|&i| raw_weights[i]).sum();
        if open_total <= 0.0 {
            // Nothing left to weigh by: split what remains evenly.
            let even = budget / open.len() as f64;
            for &i in &open {
                shares[i] = even;
            }
            break;
        }

        // The budget shrinks as entries are pinned within this pass, so later
        // candidates are judged against the already-reduced budget; anything
        // missed here is picked up on the next pass.
        let mut pinned: Vec<usize> = Vec::new();
        for &i in &open {
            let proposed = budget * raw_weights[i] / open_total;
            if proposed > cap {
                shares[i] = cap;
                budget = (budget - cap).max(0.0);
                pinned.push(i);
            }
        }

        if pinned.is_empty() {
            for &i in &open {
                shares[i] = budget * raw_weights[i] / open_total;
            }
            break;
        }

        open.retain(|i| !pinned.contains(i));
    }

    shares
}
273
/// Output of `aggregate`: the merged headline figures plus the stability and
/// divergence metrics. `run` copies these into `SpeedTestResult`.
struct AggregateResult {
    /// Merged ping, or `None` when no provider supplied one.
    ping: Option<f64>,
    /// Merged jitter, or `None` when no provider supplied one.
    jitter: Option<f64>,
    /// Merged download speed; `0.0` when no provider produced an estimate.
    download: f64,
    /// Merged upload speed; `0.0` when no provider produced an estimate.
    upload: f64,
    /// Packet loss as reported by the Cloudflare result, if any.
    packet_loss: Option<f64>,
    /// Present when more than two raw samples exist for at least one direction.
    stability: Option<StabilityMetrics>,
    /// Present when at least two providers produced an estimate for a direction.
    divergence: Option<ProviderDivergence>,
}
284
285/// Inverse-variance weighted aggregation across providers.
286/// Uses accurate bandwidth pipeline on raw samples, fixed latency weights,
287/// stability metrics, and divergence detection.
288fn aggregate(providers: &[ProviderResult]) -> AggregateResult {
289    let successful: Vec<&ProviderResult> = providers.iter().filter(|p| p.error.is_none()).collect();
290
291    if successful.is_empty() {
292        return AggregateResult {
293            ping: None,
294            jitter: None,
295            download: 0.0,
296            upload: 0.0,
297            packet_loss: None,
298            stability: None,
299            divergence: None,
300        };
301    }
302
303    // ── Compute accurate bandwidth per provider from raw samples ────
304    let mut provider_dl: Vec<(f64, f64)> = Vec::new(); // (accurate_mbps, variance)
305    let mut provider_ul: Vec<(f64, f64)> = Vec::new();
306    let mut all_dl_samples: Vec<f64> = Vec::new();
307    let mut all_ul_samples: Vec<f64> = Vec::new();
308
309    for p in &successful {
310        if let Some(ref samples) = p.bandwidth_samples {
311            if !samples.download.is_empty() {
312                let acc = statistics::accurate_bandwidth(&samples.download);
313                let var = statistics::variance(&samples.download);
314                if acc > 0.0 {
315                    provider_dl.push((acc, var));
316                }
317                all_dl_samples.extend_from_slice(&samples.download);
318            }
319            if !samples.upload.is_empty() {
320                let acc = statistics::accurate_upload_bandwidth(&samples.upload);
321                let var = statistics::variance(&samples.upload);
322                if acc > 0.0 {
323                    provider_ul.push((acc, var));
324                }
325                all_ul_samples.extend_from_slice(&samples.upload);
326            }
327        }
328        // Fallback: use provider-reported value if no raw samples
329        if p.bandwidth_samples
330            .as_ref()
331            .is_none_or(|s| s.download.is_empty())
332        {
333            if let Some(dl) = p.download_mbps {
334                if dl > 0.0 {
335                    provider_dl.push((dl, 0.0));
336                }
337            }
338        }
339        if p.bandwidth_samples
340            .as_ref()
341            .is_none_or(|s| s.upload.is_empty())
342        {
343            if let Some(ul) = p.upload_mbps {
344                if ul > 0.0 {
345                    provider_ul.push((ul, 0.0));
346                }
347            }
348        }
349    }
350
351    // ── Merge bandwidth via inverse-variance weighting ─────────────
352    let download = inverse_variance_merge_many(&provider_dl);
353
354    let upload = inverse_variance_merge_many(&provider_ul);
355
356    // ── Latency: confidence-weighted merge (CF 0.4 / NDT7 0.6) ─────
357    let cf_ping = successful
358        .iter()
359        .find(|p| p.provider == "Cloudflare")
360        .and_then(|p| p.ping_ms);
361    let ndt_ping = successful
362        .iter()
363        .find(|p| p.provider == "M-Lab NDT7")
364        .and_then(|p| p.ping_ms);
365    let cf_jitter = successful
366        .iter()
367        .find(|p| p.provider == "Cloudflare")
368        .and_then(|p| p.jitter_ms);
369    let ndt_jitter = successful
370        .iter()
371        .find(|p| p.provider == "M-Lab NDT7")
372        .and_then(|p| p.jitter_ms);
373
374    let ping = match (cf_ping, ndt_ping) {
375        (Some(cf), Some(ndt)) => Some(statistics::weighted_merge(cf, ndt, CF_LATENCY_WEIGHT)),
376        (Some(cf), None) => Some(cf),
377        (None, Some(ndt)) => Some(ndt),
378        (None, None) => {
379            // Fallback: minimum across all providers
380            successful
381                .iter()
382                .filter_map(|p| p.ping_ms)
383                .filter(|p| *p > 0.0)
384                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
385        }
386    };
387
388    let jitter = match (cf_jitter, ndt_jitter) {
389        (Some(cf), Some(ndt)) => Some(statistics::weighted_merge(cf, ndt, CF_LATENCY_WEIGHT)),
390        (Some(cf), None) => Some(cf),
391        (None, Some(ndt)) => Some(ndt),
392        (None, None) => {
393            let jitters: Vec<f64> = successful
394                .iter()
395                .filter_map(|p| p.jitter_ms)
396                .filter(|j| *j > 0.0)
397                .collect();
398            if jitters.is_empty() {
399                None
400            } else {
401                Some(statistics::mean(&jitters))
402            }
403        }
404    };
405
406    // Packet loss from Cloudflare (only provider that measures it)
407    let packet_loss = successful
408        .iter()
409        .find(|p| p.provider == "Cloudflare")
410        .and_then(|p| p.packet_loss_pct);
411
412    // ── Stability metrics ──────────────────────────────────────────
413    let stability = if all_dl_samples.len() > 2 || all_ul_samples.len() > 2 {
414        let dl_cv = statistics::coefficient_of_variation(&all_dl_samples);
415        let ul_cv = statistics::coefficient_of_variation(&all_ul_samples);
416        Some(StabilityMetrics {
417            download_cv: dl_cv,
418            upload_cv: ul_cv,
419            download_stable: dl_cv < 0.15,
420            upload_stable: ul_cv < 0.15,
421        })
422    } else {
423        None
424    };
425
426    // ── Provider divergence ────────────────────────────────────────
427    let divergence = if provider_dl.len() >= 2 || provider_ul.len() >= 2 {
428        let dl_div = divergence_spread(&provider_dl);
429        let ul_div = divergence_spread(&provider_ul);
430        Some(ProviderDivergence {
431            download: dl_div,
432            upload: ul_div,
433            significant: dl_div > DIVERGENCE_THRESHOLD || ul_div > DIVERGENCE_THRESHOLD,
434        })
435    } else {
436        None
437    };
438
439    AggregateResult {
440        ping,
441        jitter,
442        download,
443        upload,
444        packet_loss,
445        stability,
446        divergence,
447    }
448}
449
/// Callback type for provider completion notifications.
///
/// `run` invokes it with each provider's result as soon as that provider
/// finishes, before the result is stored for aggregation.
pub type ProviderCompleteCallback = Arc<dyn Fn(&ProviderResult) + Send + Sync>;
452
453/// Run the speed test with the given configuration and progress callback.
454/// The `on_provider_complete` callback is called after each provider finishes,
455/// allowing the UI to show per-provider summaries.
456pub async fn run<F>(
457    config: SpeedTestConfig,
458    progress: F,
459    on_provider_complete: Option<ProviderCompleteCallback>,
460) -> SpeedTestResult
461where
462    F: Fn(Phase, f64) + Send + Sync + 'static,
463{
464    let start = Instant::now();
465    let mut providers = Vec::new();
466    let progress = Arc::new(progress);
467
468    // Cloudflare (always runs)
469    {
470        let pg = progress.clone();
471        let cf_result = cloudflare::run(&config, move |phase, p| pg(phase, p)).await;
472        if let Some(ref cb) = on_provider_complete {
473            cb(&cf_result);
474        }
475        providers.push(cf_result);
476    }
477
478    // M-Lab NDT7 (always runs)
479    {
480        let pg = progress.clone();
481        let ndt_result = ndt7::run(&config, move |phase, p| pg(phase, p)).await;
482        if let Some(ref cb) = on_provider_complete {
483            cb(&ndt_result);
484        }
485        providers.push(ndt_result);
486    }
487
488    // LibreSpeed + fast.com (only in All mode)
489    if config.provider_set == ProviderSet::All {
490        {
491            let pg = progress.clone();
492            let ls_result = librespeed::run(&config, move |phase, p| pg(phase, p)).await;
493            if let Some(ref cb) = on_provider_complete {
494                cb(&ls_result);
495            }
496            providers.push(ls_result);
497        }
498
499        {
500            let pg = progress.clone();
501            let fc_result = fastcom::run(&config, move |phase, p| pg(phase, p)).await;
502            if let Some(ref cb) = on_provider_complete {
503                cb(&fc_result);
504            }
505            providers.push(fc_result);
506        }
507    }
508
509    progress(Phase::Computing, 1.0);
510
511    let agg = aggregate(&providers);
512    let duration = start.elapsed().as_secs_f64();
513
514    SpeedTestResult {
515        ping_ms: agg.ping,
516        jitter_ms: agg.jitter,
517        download_mbps: agg.download,
518        upload_mbps: agg.upload,
519        packet_loss_pct: agg.packet_loss,
520        providers,
521        duration_s: duration,
522        stability: agg.stability,
523        provider_divergence: agg.divergence,
524    }
525}
526
/// Renders a throughput figure for display, picking unit and precision by
/// magnitude:
///
/// * 1000 and above — Gbps with one decimal
/// * 100 up to 1000 — Mbps with no decimals
/// * 10 up to 100 — Mbps with one decimal
/// * anything else — Mbps with two decimals
pub fn format_mbps(mbps: f64) -> String {
    match mbps {
        rate if rate >= 1000.0 => format!("{:.1} Gbps", rate / 1000.0),
        rate if rate >= 100.0 => format!("{:.0} Mbps", rate),
        rate if rate >= 10.0 => format!("{:.1} Mbps", rate),
        rate => format!("{:.2} Mbps", rate),
    }
}
539
/// Renders a byte count for display using 1024-based units.
///
/// GB values get two decimals, MB and KB values one decimal, and anything
/// below 1024 is printed as a whole number of bytes.
pub fn format_bytes(bytes: u64) -> String {
    const KB: u64 = 1024;
    const MB: u64 = 1024 * KB;
    const GB: u64 = 1024 * MB;

    let scaled = |unit: u64| bytes as f64 / unit as f64;

    match bytes {
        _ if bytes >= GB => format!("{:.2} GB", scaled(GB)),
        _ if bytes >= MB => format!("{:.1} MB", scaled(MB)),
        _ if bytes >= KB => format!("{:.1} KB", scaled(KB)),
        _ => format!("{} B", bytes),
    }
}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an error-free provider result with four raw samples per
    /// direction, placed at `center - d`, `center`, `center + d`, `center`,
    /// where `d` is the square root of `variance`.
    fn provider(name: &str, download: f64, upload: f64, variance: f64) -> ProviderResult {
        let delta = variance.sqrt();
        let around = |center: f64| vec![center - delta, center, center + delta, center];

        ProviderResult {
            provider: name.to_string(),
            server: "test".to_string(),
            location: None,
            ping_ms: None,
            jitter_ms: None,
            download_mbps: Some(download),
            upload_mbps: Some(upload),
            download_bytes: 1,
            upload_bytes: 1,
            download_duration_s: 1.0,
            upload_duration_s: 1.0,
            packet_loss_pct: None,
            error: None,
            bandwidth_samples: Some(BandwidthSamples {
                download: around(download),
                upload: around(upload),
            }),
        }
    }

    /// Adding two much faster providers must move the aggregate well above
    /// what the first two providers alone produce.
    #[test]
    fn aggregate_uses_more_than_first_two_providers() {
        let slow_pair = [
            provider("Cloudflare", 100.0, 20.0, 4.0),
            provider("M-Lab NDT7", 100.0, 20.0, 4.0),
        ];
        let mixed_four = [
            provider("Cloudflare", 100.0, 20.0, 4.0),
            provider("M-Lab NDT7", 100.0, 20.0, 4.0),
            provider("LibreSpeed", 900.0, 180.0, 4.0),
            provider("fast.com", 900.0, 180.0, 4.0),
        ];

        let two = aggregate(&slow_pair);
        let four = aggregate(&mixed_four);

        assert!(
            four.download > two.download + 100.0,
            "third/fourth providers should materially influence aggregate: two={}, four={}",
            two.download,
            four.download
        );
        assert!(
            four.upload > two.upload + 20.0,
            "third/fourth providers should materially influence upload aggregate: two={}, four={}",
            two.upload,
            four.upload
        );
    }

    /// Divergence must reflect the extremes across every provider, not just
    /// the first two.
    #[test]
    fn divergence_uses_full_provider_spread() {
        let results = [
            provider("Cloudflare", 100.0, 20.0, 4.0),
            provider("M-Lab NDT7", 105.0, 22.0, 4.0),
            provider("LibreSpeed", 450.0, 90.0, 4.0),
        ];

        let div = aggregate(&results)
            .divergence
            .expect("divergence should be reported");

        assert!(div.significant);
        assert!(
            div.download > 0.70,
            "expected divergence to use 100 vs 450 spread, got {}",
            div.download
        );
        assert!(
            div.upload > 0.70,
            "expected divergence to use 20 vs 90 spread, got {}",
            div.upload
        );
    }

    /// A near-zero-variance provider must not take more than the 70% cap.
    #[test]
    fn inverse_variance_merge_caps_single_provider_dominance() {
        let inputs = [(1000.0, 0.000_001), (1.0, 1000.0)];
        let merged = inverse_variance_merge_many(&inputs);

        assert!(
            merged < 701.0,
            "dominant provider should be capped near 70%, got {}",
            merged
        );
    }
}