// netspeed_cli/bandwidth_loop.rs
1//! Shared bandwidth measurement loop for download/upload tests.
2//!
3//! Eliminates duplication between `download.rs` and `upload.rs` by providing:
4//! - [`LoopState`] — unified state for throttled speed sampling,
5//!   peak tracking, progress bar updates, and atomic byte counting
6//! - [`run_concurrent_streams`] — shared spawn/collect/report pattern
7//!   that both download and upload tests delegate to
8//!
9//! Each I/O operation (download chunk, upload round) calls `record_bytes()`
10//! to update shared state. Call `finish()` at the end to compute final results.
11
12use crate::common;
13use crate::error::Error;
14use crate::progress::Tracker;
15use crate::terminal;
16use crate::test_config::TestConfig;
17use owo_colors::OwoColorize;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Instant;
22
/// Default minimum spacing, in milliseconds, between two speed samples.
///
/// One second split into 20 slots: at most 20 samples per second (20 Hz).
pub const SAMPLE_INTERVAL_MS: u64 = 1_000 / 20;
25
/// Shared state for a bandwidth test (download or upload).
///
/// All fields are thread-safe for use across multiple concurrent streams:
/// the mutable counters live behind `Arc<AtomicU64>` / `Arc<Mutex<..>>`,
/// while `start` and `estimated_total` are plain values fixed when the state
/// is built. Streams share one instance through an `Arc<LoopState>`.
pub struct LoopState {
    /// Bytes transferred so far, summed over every stream.
    pub total_bytes: Arc<AtomicU64>,
    /// Highest sampled speed so far, stored as an integer (bits per second).
    pub peak_bps: Arc<AtomicU64>,
    /// Every speed sample recorded, in recording order.
    pub speed_samples: Arc<Mutex<Vec<f64>>>,
    /// Moment the state was created; reference point for all elapsed times.
    pub start: Instant,
    /// Milliseconds after `start` of the most recent sample; `0` means no
    /// sample has been taken yet.
    pub last_sample_ms: Arc<AtomicU64>,
    /// Expected total bytes for the test; denominator of the progress fraction.
    pub estimated_total: u64,
    /// Progress display updated whenever a speed sample is taken.
    pub progress: Arc<Tracker>,
}
38
/// Final result from a bandwidth test, as produced by [`LoopState::finish`].
///
/// Derives `Clone` and `PartialEq` (in addition to `Debug`) so results can be
/// duplicated for reporting and compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct BandwidthResult {
    /// Average speed over the whole test, computed from `total_bytes` and
    /// `duration_secs` (bits per second, same unit as the speed samples).
    pub avg_bps: f64,
    /// Highest throttled speed sample observed (bits per second).
    pub peak_bps: f64,
    /// Total bytes transferred across all streams.
    pub total_bytes: u64,
    /// Wall-clock seconds from the start of the test until the result was taken.
    pub duration_secs: f64,
    /// All intermediate speed samples, in recording order.
    pub speed_samples: Vec<f64>,
}
48
49impl LoopState {
50    /// Create a new bandwidth measurement state.
51    #[must_use]
52    pub fn new(estimated_total: u64, progress: Arc<Tracker>) -> Self {
53        Self {
54            total_bytes: Arc::new(AtomicU64::new(0)),
55            peak_bps: Arc::new(AtomicU64::new(0)),
56            speed_samples: Arc::new(Mutex::new(Vec::new())),
57            start: Instant::now(),
58            last_sample_ms: Arc::new(AtomicU64::new(0)),
59            estimated_total,
60            progress,
61        }
62    }
63
64    /// Record transferred bytes and update progress (throttled to 20 Hz).
65    ///
66    /// This is the single point where all expensive operations (bandwidth calc,
67    /// peak tracking, sample recording, progress update) are throttled.
68    ///
69    /// Note: Uses cached sample_interval to avoid repeated TestConfig::default() calls.
70    pub fn record_bytes(&self, len: u64, sample_interval_ms: u64) {
71        // Release ensures writes are visible to the final Acquire load in finish()
72        self.total_bytes.fetch_add(len, Ordering::Release);
73
74        let elapsed_ms = u64::try_from(self.start.elapsed().as_millis()).unwrap_or(u64::MAX);
75        let last_ms = self.last_sample_ms.load(Ordering::Relaxed);
76        let should_sample =
77            last_ms == 0 || elapsed_ms.saturating_sub(last_ms) >= sample_interval_ms;
78
79        if should_sample {
80            self.last_sample_ms.store(elapsed_ms, Ordering::Relaxed);
81            self.sample_now();
82        }
83    }
84
85    /// Take a speed sample and update progress (no throttle check — caller must gate).
86    fn sample_now(&self) {
87        let total = self.total_bytes.load(Ordering::Acquire);
88        let elapsed = self.start.elapsed().as_secs_f64();
89        let speed = common::calculate_bandwidth(total, elapsed);
90
91        // Safe: peak_bps stores bits-per-second; even 100 Gbps = 1e11, well under 2^53.
92        let current_peak = self.peak_bps.load(Ordering::Relaxed) as f64;
93        if speed > current_peak {
94            let peak_u64 = speed.clamp(0.0, u64::MAX as f64) as u64;
95            // Release pairs with the Acquire load in finish()
96            self.peak_bps.store(peak_u64, Ordering::Release);
97        }
98
99        if let Ok(mut samples) = self.speed_samples.lock() {
100            samples.push(speed);
101        }
102
103        // Safe: total and estimated_total are byte counts from a test lasting seconds;
104        // they cannot approach 2^53 (~9 PB) where f64 loses precision.
105        let pct = (total as f64 / self.estimated_total as f64).min(1.0);
106        self.progress.update(speed / 1_000_000.0, pct, total);
107    }
108
109    /// Compute final results from accumulated state.
110    #[must_use]
111    pub fn finish(&self) -> BandwidthResult {
112        // Acquire pairs with the Release fetch_add/stores to see all writes
113        let total = self.total_bytes.load(Ordering::Acquire);
114        // Safe: peak_bps is bits/sec; even 100 Gbps = 1e11, well under 2^53.
115        let peak = self.peak_bps.load(Ordering::Acquire) as f64;
116        let duration = self.start.elapsed().as_secs_f64();
117        // Graceful fallback: if lock is poisoned (thread panicked), return empty samples
118        let samples = self
119            .speed_samples
120            .lock()
121            .map(|g| g.to_vec())
122            .unwrap_or_default();
123        let avg = common::calculate_bandwidth(total, duration);
124
125        BandwidthResult {
126            avg_bps: avg,
127            peak_bps: peak,
128            total_bytes: total,
129            duration_secs: duration,
130            speed_samples: samples,
131        }
132    }
133}
134
135/// Run a bandwidth test using multiple concurrent streams.
136///
137/// This is the shared spawn/collect/report pattern used by both download
138/// and upload tests. It:
139/// 1. Creates a [`LoopState`] for the test
140/// 2. Spawns `stream_count` tasks via `spawn_fn`
141/// 3. Collects results, logging any task panics
142/// 4. Returns a [`BandwidthResult`] (zeroed if all tasks failed)
143///
144/// The `spawn_fn` closure receives the stream index and a shared reference
145/// to the loop state. Each call should create and return a `JoinHandle<()>`
146/// that performs I/O and calls [`LoopState::record_bytes`] for each
147/// transferred chunk.
148///
149/// # Arguments
150/// * `estimated_total` — Estimated total bytes for progress bar initialization
151/// * `stream_count` — Number of concurrent streams to spawn
152/// * `progress` — Shared progress bar for the test phase
153/// * `spawn_fn` — Closure that creates one stream's async task
154///
155/// # Panics
156///
157/// Individual task panics are caught and logged; they do not propagate.
158#[must_use = "the BandwidthResult should be used to report test outcomes"]
159pub async fn run_concurrent_streams(
160    estimated_total: u64,
161    stream_count: usize,
162    progress: Arc<Tracker>,
163    label: &str,
164    mut spawn_fn: impl FnMut(usize, Arc<LoopState>, u64) -> tokio::task::JoinHandle<Result<(), Error>>,
165) -> Result<BandwidthResult, Error> {
166    let config = TestConfig::default();
167    let sample_interval_ms = config.sample_interval_ms;
168    let state = Arc::new(LoopState::new(estimated_total, progress));
169
170    let mut handles = Vec::with_capacity(stream_count);
171    for i in 0..stream_count {
172        handles.push(spawn_fn(i, Arc::clone(&state), sample_interval_ms));
173    }
174
175    // Collect results — log any task panics so failures aren't silently swallowed.
176    let mut any_succeeded = false;
177    let mut first_error: Option<Error> = None;
178    for (i, handle) in handles.into_iter().enumerate() {
179        match handle.await {
180            Ok(Ok(())) => any_succeeded = true,
181            Ok(Err(err)) => {
182                let msg = format!("Warning: {label} stream {i} failed: {err}");
183                if terminal::no_color() {
184                    eprintln!("\n{msg}");
185                } else {
186                    eprintln!("\n{}", msg.yellow().bold());
187                }
188                if first_error.is_none() {
189                    first_error = Some(err);
190                }
191            }
192            Err(e) => {
193                let msg = format!("Warning: {label} stream {i} failed: {e}");
194                if terminal::no_color() {
195                    eprintln!("\n{msg}");
196                } else {
197                    eprintln!("\n{}", msg.yellow().bold());
198                }
199                if first_error.is_none() {
200                    first_error = Some(Error::context(format!("{label} stream {i} panicked: {e}")));
201                }
202            }
203        }
204    }
205
206    if !any_succeeded {
207        return Err(
208            first_error.unwrap_or_else(|| Error::context(format!("all {label} streams failed")))
209        );
210    }
211
212    let result = state.finish();
213    if result.total_bytes == 0 {
214        return Err(first_error.unwrap_or_else(|| match label {
215            "download" => {
216                Error::DownloadFailure("test completed without transferring data".to_string())
217            }
218            "upload" => {
219                Error::UploadFailure("test completed without transferring data".to_string())
220            }
221            _ => Error::context(format!("{label} test completed without transferring data")),
222        }));
223    }
224
225    Ok(result)
226}
227
#[cfg(test)]
mod tests {
    // Unit tests for `LoopState` and `run_concurrent_streams`.
    //
    // Several tests sleep with `thread::sleep` so that real elapsed time crosses
    // the sampling interval; assertions on sample counts are therefore written
    // as lower bounds rather than exact counts.
    use super::*;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;

    // ── LoopState Tests ──────────────────────────────────────────────────────

    /// Fresh progress tracker for one test.
    fn make_tracker() -> Arc<Tracker> {
        Arc::new(Tracker::new("test"))
    }

    /// A new state starts with zeroed counters, no samples and the given estimate.
    #[test]
    fn test_loop_state_new_fields() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(state.peak_bps.load(Ordering::SeqCst), 0);
        assert_eq!(state.estimated_total, 100_000_000);
        assert!(state.speed_samples.lock().unwrap().is_empty());
    }

    /// Concurrent `record_bytes` calls from several OS threads must not lose bytes.
    #[test]
    fn test_loop_state_concurrent_atomic_updates() {
        let tracker = make_tracker();
        let state = Arc::new(LoopState::new(100_000_000, tracker));

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let s = Arc::clone(&state);
                thread::spawn(move || {
                    for _ in 0..1000 {
                        s.record_bytes(100, SAMPLE_INTERVAL_MS);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        // 4 threads * 1000 * 100 = 400,000
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 400_000);
    }

    /// Recording zero bytes leaves the total unchanged.
    #[test]
    fn test_record_bytes_zero_value() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        state.record_bytes(0, SAMPLE_INTERVAL_MS);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 0);
    }

    /// Successive calls add up.
    #[test]
    fn test_record_bytes_accumulates() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        state.record_bytes(1000, SAMPLE_INTERVAL_MS);
        state.record_bytes(2000, SAMPLE_INTERVAL_MS);
        state.record_bytes(3000, SAMPLE_INTERVAL_MS);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 6000);
    }

    /// A 1 GB chunk is counted exactly, even with an estimate of `u64::MAX`.
    #[test]
    fn test_record_bytes_large_values() {
        let tracker = make_tracker();
        let state = LoopState::new(u64::MAX, tracker);
        state.record_bytes(1_000_000_000, SAMPLE_INTERVAL_MS);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 1_000_000_000);
    }

    /// The first call samples immediately; a call after a sleep longer than the
    /// interval samples again.
    #[test]
    fn test_record_bytes_throttle_mechanism() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        // Test throttle by verifying that samples are recorded
        // The throttle mechanism limits sampling to once per interval
        let interval_ms = 50u64;

        // First call always triggers
        state.record_bytes(1000, interval_ms);
        assert_eq!(state.speed_samples.lock().unwrap().len(), 1);

        // Rapid second call - may or may not trigger depending on elapsed time
        state.record_bytes(1000, interval_ms);

        // Wait enough time for throttle to reset
        thread::sleep(Duration::from_millis(100));
        state.record_bytes(1000, interval_ms);

        // Should have at least 2 samples (first + after wait)
        // The exact count depends on timing, but throttle is working
        let samples = state.speed_samples.lock().unwrap();
        assert!(
            samples.len() >= 2,
            "Expected at least 2 samples, got {}",
            samples.len()
        );
    }

    /// A 5 ms interval with 10 ms gaps between calls yields more than one sample.
    #[test]
    fn test_record_bytes_short_interval_samples_more() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        // Short interval with explicit waits allows more frequent sampling
        for _ in 0..3 {
            state.record_bytes(1_000_000, 5); // 5ms interval
            thread::sleep(Duration::from_millis(10));
        }

        let samples = state.speed_samples.lock().unwrap();
        // With short interval and time between calls, should get multiple samples
        assert!(
            samples.len() >= 2,
            "Expected >= 2 samples with short interval, got {}",
            samples.len()
        );
    }

    /// Samples taken with a sleep longer than the interval between them
    /// produce a non-zero peak.
    #[test]
    fn test_record_bytes_updates_peak() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        state.record_bytes(10_000_000, SAMPLE_INTERVAL_MS);
        thread::sleep(Duration::from_millis(60));
        state.record_bytes(10_000_000, SAMPLE_INTERVAL_MS);

        let peak = state.peak_bps.load(Ordering::SeqCst);
        assert!(peak > 0);
    }

    /// With no traffic, `finish` reports zeros and no samples, but a positive duration.
    #[test]
    fn test_finish_empty_state() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        thread::sleep(Duration::from_millis(10));
        let result = state.finish();

        assert_eq!(result.total_bytes, 0);
        assert_eq!(result.avg_bps, 0.0);
        assert_eq!(result.peak_bps, 0.0);
        assert!(result.duration_secs > 0.0);
        assert!(result.speed_samples.is_empty());
    }

    /// `finish` reports the recorded bytes and a positive average.
    #[test]
    fn test_finish_with_transfer() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        state.record_bytes(20_000_000, SAMPLE_INTERVAL_MS);
        thread::sleep(Duration::from_millis(100));

        let result = state.finish();
        assert_eq!(result.total_bytes, 20_000_000);
        assert!(result.avg_bps > 0.0);
    }

    /// The peak sample is not below the overall average for steady traffic.
    #[test]
    fn test_finish_peak_gte_avg() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        for _ in 0..5 {
            state.record_bytes(5_000_000, SAMPLE_INTERVAL_MS);
            thread::sleep(Duration::from_millis(60));
        }

        let result = state.finish();
        assert!(result.peak_bps >= result.avg_bps);
    }

    /// The estimate only scales progress; it never changes the byte total.
    #[test]
    fn test_finish_various_estimated_totals() {
        for estimated in [1u64, 1000, 1_000_000, u64::MAX / 2] {
            let tracker = make_tracker();
            let state = LoopState::new(estimated, tracker);
            state.record_bytes(100, SAMPLE_INTERVAL_MS);
            thread::sleep(Duration::from_millis(10));
            let result = state.finish();
            assert_eq!(result.total_bytes, 100);
        }
    }

    /// `finish` returns the recorded samples, all non-negative.
    #[test]
    fn test_finish_returns_speed_samples() {
        let tracker = make_tracker();
        let state = LoopState::new(10_000_000, tracker);

        for _ in 0..3 {
            state.record_bytes(1_000_000, 10);
            thread::sleep(Duration::from_millis(20));
        }

        let result = state.finish();
        assert!(!result.speed_samples.is_empty());
        for sample in &result.speed_samples {
            assert!(*sample >= 0.0);
        }
    }

    /// Guards against accidental changes to the default sampling interval.
    #[test]
    fn test_sample_interval_constant() {
        assert_eq!(SAMPLE_INTERVAL_MS, 50);
    }

    /// Every `BandwidthResult` field is populated after a transfer.
    #[test]
    fn test_bandwidth_result_struct() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        state.record_bytes(50_000_000, SAMPLE_INTERVAL_MS);
        thread::sleep(Duration::from_millis(100));

        let result = state.finish();

        // Verify all fields are correctly populated
        assert!(result.avg_bps >= 0.0);
        assert!(result.peak_bps >= 0.0);
        assert!(result.total_bytes > 0);
        assert!(result.duration_secs > 0.0);
    }

    // ── run_concurrent_streams Tests ─────────────────────────────────────────

    /// With no streams nothing can succeed, so an error is returned.
    #[tokio::test]
    async fn test_run_concurrent_streams_zero_streams() {
        let tracker = make_tracker();
        let result = run_concurrent_streams(100_000_000, 0, tracker, "test", |_, _, _| {
            tokio::spawn(async { Ok(()) })
        })
        .await;
        assert!(result.is_err());
    }

    /// A single successful stream reports exactly the bytes it recorded.
    #[tokio::test]
    async fn test_run_concurrent_streams_single_stream_success() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 1, tracker, "download", |_, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    s.record_bytes(10_000_000, interval);
                    Ok(())
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 10_000_000);
    }

    /// Bytes from all streams are summed into one result.
    #[tokio::test]
    async fn test_run_concurrent_streams_four_streams() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 4, tracker, "upload", |_, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    s.record_bytes(1_000_000, interval);
                    Ok(())
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 4_000_000);
    }

    /// Every stream returning an error makes the whole run fail.
    #[tokio::test]
    async fn test_run_concurrent_streams_all_fail() {
        let tracker = make_tracker();
        let result = run_concurrent_streams(100_000_000, 3, tracker, "download", |_, _, _| {
            tokio::spawn(async { Err(Error::DownloadFailure("failed".into())) })
        })
        .await;

        assert!(result.is_err());
    }

    /// Two of four streams fail; the run still succeeds with the survivors' bytes.
    #[tokio::test]
    async fn test_run_concurrent_streams_partial_failure() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 4, tracker, "upload", |i, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    if i < 2 {
                        s.record_bytes(1_000_000, interval);
                        Ok(())
                    } else {
                        Err(Error::UploadFailure("failed".into()))
                    }
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 2_000_000);
    }

    /// A panicking stream surfaces through its JoinHandle and does not abort the
    /// run; the surviving stream's bytes are still reported.
    #[tokio::test]
    async fn test_run_concurrent_streams_stream_panic() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 2, tracker, "download", |i, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    if i == 0 {
                        s.record_bytes(1_000_000, interval);
                        Ok(())
                    } else {
                        panic!("stream panicked");
                    }
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 1_000_000);
    }

    /// Streams that succeed without recording any bytes still produce an error.
    #[tokio::test]
    async fn test_run_concurrent_streams_zero_bytes_returns_error() {
        let tracker = make_tracker();
        let result = run_concurrent_streams(100_000_000, 2, tracker, "download", |_, _, _| {
            tokio::spawn(async { Ok(()) })
        })
        .await;

        assert!(result.is_err());
    }

    /// The error's Debug output mentions the label for known and custom labels.
    #[tokio::test]
    async fn test_run_concurrent_streams_label_different_errors() {
        for label in ["download", "upload", "custom"] {
            let tracker = make_tracker();
            let result = run_concurrent_streams(100_000_000, 0, tracker, label, |_, _, _| {
                tokio::spawn(async { Ok(()) })
            })
            .await;

            assert!(result.is_err());
            let err_str = format!("{:?}", result.unwrap_err());
            assert!(err_str.contains(label));
        }
    }

    /// Runs succeed across a range of estimates, including one equal to the
    /// bytes actually recorded.
    #[tokio::test]
    async fn test_run_concurrent_streams_estimated_total_param() {
        for estimated in [1_000u64, 10_000_000, 1_000_000_000] {
            let tracker = make_tracker();
            let result =
                run_concurrent_streams(estimated, 1, tracker, "test", |_, state, interval| {
                    let s = Arc::clone(&state);
                    tokio::spawn(async move {
                        s.record_bytes(1000, interval);
                        Ok(())
                    })
                })
                .await;
            assert!(result.is_ok());
        }
    }
}