Skip to main content

netspeed_cli/
download.rs

1//! Multi-stream download bandwidth measurement.
2//!
3//! This module handles downloading test files from speedtest.net servers
4//! to measure download bandwidth. It supports:
5//! - Multi-stream concurrent downloads (4 streams by default, 1 with `--single`)
6//! - Dynamic test URL construction from server base URL
7//! - Real-time progress tracking with speed calculation
8//! - Peak speed detection through periodic sampling
9
10#![allow(
11    clippy::cast_precision_loss,
12    clippy::cast_possible_truncation,
13    clippy::cast_sign_loss
14)]
15
16use crate::common;
17use crate::error::SpeedtestError;
18use crate::progress::SpeedProgress;
19use crate::types::Server;
20use reqwest::Client;
21use std::sync::Arc;
22use std::sync::Mutex;
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::time::Instant;
25
/// Byte count treated as "100 %" by the download progress bar.
///
/// This is only a rough estimate of how much a test transfers: the completion
/// fraction reported to the progress bar is clamped to 1.0, so once this many
/// bytes have arrived the bar stays full while the test keeps running.
const ESTIMATED_DOWNLOAD_BYTES: u64 = 15_000_000; // 15 MB estimate

/// Minimum spacing, in milliseconds, between speed samples (targets at most
/// 20 samples per second).
///
/// Taking a sample computes a speed, updates the peak, locks the sample list and
/// refreshes the progress bar, so it is throttled instead of running per chunk.
const SAMPLE_INTERVAL_MS: u64 = 50;

/// Number of test files each stream downloads, one after another.
///
/// [`build_test_url`] cycles through four file sizes, so every round fetches a
/// different file only while this value is at most 4.
const DOWNLOAD_TEST_ROUNDS: usize = 4;
37
/// Extract the directory that holds a server's test files from its URL.
///
/// Server URLs typically point at the upload script (`…/upload.php`, or the
/// `.asp` / `.aspx` / `.jsp` variant of it); the download test files live next
/// to that script. When the URL ends in one of those script names, the
/// `/upload.<ext>` suffix is stripped; any other URL is returned unchanged.
#[must_use]
pub fn extract_base_url(url: &str) -> &str {
    const UPLOAD_SCRIPTS: [&str; 4] = [
        "/upload.php",
        "/upload.asp",
        "/upload.aspx",
        "/upload.jsp",
    ];

    UPLOAD_SCRIPTS
        .iter()
        .find_map(|&script| url.strip_suffix(script))
        .unwrap_or(url)
}
43
44/// Build test file URL using Speedtest.net standard naming
45#[must_use]
46pub fn build_test_url(server_url: &str, file_index: usize) -> String {
47    let base = extract_base_url(server_url);
48    let sizes = ["2000x2000", "3000x3000", "3500x3500", "4000x4000"];
49    let size = sizes[file_index % sizes.len()];
50    format!("{base}/random{size}.jpg")
51}
52
/// Outcome of one download stream.
#[derive(Debug, Clone, Copy)]
struct StreamResult {
    /// Bytes this stream received, summed over all of its rounds.
    bytes: u64,
    /// Seconds from the shared test start until this stream finished.
    elapsed_secs: f64,
}
58
59use futures_util::StreamExt;
60
61/// Run download bandwidth test against the given server.
62///
63/// Returns `(avg_speed_bps, peak_speed_bps, total_bytes_downloaded, speed_samples)`.
64///
65/// # Errors
66///
67/// Returns [`SpeedtestError::NetworkError`] if all download streams fail.
68/// Returns [`SpeedtestError::Context`] if the server URL is invalid.
69pub async fn download_test(
70    client: &Client,
71    server: &Server,
72    single: bool,
73    progress: Arc<SpeedProgress>,
74) -> Result<(f64, f64, u64, Vec<f64>), SpeedtestError> {
75    let concurrent_streams = common::determine_stream_count(single);
76    let total_bytes = Arc::new(AtomicU64::new(0));
77    let peak_bps = Arc::new(AtomicU64::new(0));
78    let speed_samples = Arc::new(Mutex::new(Vec::new()));
79    let start = Instant::now();
80
81    // Estimated total: progress bar will update dynamically as data is downloaded
82    let estimated_total: u64 = ESTIMATED_DOWNLOAD_BYTES;
83
84    // Throttle gate: tracks last sample time in millis to limit all expensive ops to 20 Hz.
85    // Initialized to 0 so the first chunk always triggers a sample (any elapsed > 0 fires).
86    let last_sample_ms = Arc::new(AtomicU64::new(0));
87
88    // Spawn streams that report progress
89    let mut handles = Vec::new();
90    for _ in 0..concurrent_streams {
91        let client = client.clone();
92        let server_url = server.url.clone();
93        let total_ref = Arc::clone(&total_bytes);
94        let peak_ref = Arc::clone(&peak_bps);
95        let samples_ref = Arc::clone(&speed_samples);
96        let start_ref = start;
97        let prog = Arc::clone(&progress);
98        let throttle_ref = Arc::clone(&last_sample_ms);
99
100        let handle = tokio::spawn(async move {
101            let mut stream_bytes = 0u64;
102
103            for j in 0..DOWNLOAD_TEST_ROUNDS {
104                let test_url = build_test_url(&server_url, j);
105
106                if let Ok(response) = client.get(&test_url).send().await {
107                    let mut stream = response.bytes_stream();
108                    while let Some(item) = stream.next().await {
109                        if let Ok(chunk) = item {
110                            let len = chunk.len() as u64;
111                            stream_bytes += len;
112                            // Cheap atomic add — runs on every chunk
113                            total_ref.fetch_add(len, Ordering::Relaxed);
114
115                            // Throttle gate: only run expensive ops every 50ms.
116                            // First sample always fires (last_sample_ms == 0 means "never sampled").
117                            let elapsed_ms = start_ref.elapsed().as_millis() as u64;
118                            let last_ms = throttle_ref.load(Ordering::Relaxed);
119                            let should_sample = last_ms == 0
120                                || elapsed_ms.saturating_sub(last_ms) >= SAMPLE_INTERVAL_MS;
121                            if should_sample {
122                                // Update throttle timestamp
123                                throttle_ref.store(elapsed_ms, Ordering::Relaxed);
124
125                                // All expensive ops now run at most every 50ms:
126                                // Acquire ensures we see the latest fetch_add results on ARM64.
127                                let total_so_far = total_ref.load(Ordering::Acquire);
128                                let elapsed = start_ref.elapsed().as_secs_f64();
129                                let speed = common::calculate_bandwidth(total_so_far, elapsed);
130
131                                // Peak tracking (cheap compare-and-swap)
132                                let current_peak = peak_ref.load(Ordering::Relaxed);
133                                if speed > current_peak as f64 {
134                                    peak_ref.store(speed as u64, Ordering::Relaxed);
135                                }
136
137                                // Record speed sample (throttled, no need for additional check)
138                                if let Ok(mut samples) = samples_ref.lock() {
139                                    samples.push(speed);
140                                }
141
142                                let pct = (total_so_far as f64 / estimated_total as f64).min(1.0);
143                                prog.update(speed / 1_000_000.0, pct, total_so_far);
144                            }
145                        }
146                    }
147                }
148            }
149
150            StreamResult {
151                bytes: stream_bytes,
152                elapsed_secs: start_ref.elapsed().as_secs_f64(),
153            }
154        });
155
156        handles.push(handle);
157    }
158
159    // Collect results
160    let mut results = Vec::new();
161    for handle in handles {
162        if let Ok(result) = handle.await {
163            results.push(result);
164        }
165    }
166
167    if results.is_empty() {
168        return Ok((0.0, 0.0, 0, Vec::new()));
169    }
170
171    let total_bandwidth: f64 = results
172        .iter()
173        .map(|r| common::calculate_bandwidth(r.bytes, r.elapsed_secs))
174        .sum();
175
176    let final_total_bytes = total_bytes.load(Ordering::Relaxed);
177    let final_peak_speed = peak_bps.load(Ordering::Relaxed) as f64;
178    let avg_bandwidth = total_bandwidth / results.len() as f64;
179    let samples = speed_samples.lock().unwrap().to_vec();
180    Ok((avg_bandwidth, final_peak_speed, final_total_bytes, samples))
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common;

    /// Upload-script URL shared by the URL-construction tests.
    const UPLOAD_URL: &str = "http://server.example.com/speedtest/upload.php";

    /// Directory that every test file URL built from [`UPLOAD_URL`] lives in.
    const TEST_FILE_BASE: &str = "http://server.example.com/speedtest";

    #[test]
    fn test_download_bandwidth_calculation() {
        // 10 MB in 2 s is 80 Mbit / 2 s = 40 Mbit/s.
        let bps = common::calculate_bandwidth(10_000_000, 2.0);
        assert_eq!(bps, 40_000_000.0);
    }

    #[test]
    fn test_download_bandwidth_zero_elapsed() {
        // A zero-length interval reports 0 instead of dividing by zero.
        let bps = common::calculate_bandwidth(10_000_000, 0.0);
        assert_eq!(bps, 0.0);
    }

    #[test]
    fn test_download_concurrent_streams_single() {
        let streams = common::determine_stream_count(true);
        assert_eq!(streams, 1);
    }

    #[test]
    fn test_download_concurrent_streams_multiple() {
        let streams = common::determine_stream_count(false);
        assert_eq!(streams, 4);
    }

    #[test]
    fn test_download_url_generation() {
        let expected = format!("{TEST_FILE_BASE}/random2000x2000.jpg");
        assert_eq!(build_test_url(UPLOAD_URL, 0), expected);
    }

    #[test]
    fn test_download_url_generation_cycles() {
        // Four sizes exist, so index 4 wraps back to the first file.
        assert_eq!(build_test_url(UPLOAD_URL, 4), build_test_url(UPLOAD_URL, 0));
    }

    #[test]
    fn test_download_url_generation_all_sizes() {
        let sizes = ["2000x2000", "3000x3000", "3500x3500", "4000x4000"];
        for (index, size) in sizes.iter().enumerate() {
            let expected = format!("{TEST_FILE_BASE}/random{size}.jpg");
            assert_eq!(build_test_url(UPLOAD_URL, index), expected);
        }
    }

    #[test]
    fn test_extract_base_url() {
        let base = extract_base_url("http://server.example.com:8080/speedtest/upload.php");
        assert_eq!(base, "http://server.example.com:8080/speedtest");
    }

    #[test]
    fn test_extract_base_url_no_suffix() {
        // Nothing to strip: the URL comes back untouched.
        assert_eq!(extract_base_url(TEST_FILE_BASE), TEST_FILE_BASE);
    }

    #[test]
    fn test_extract_base_url_different_path() {
        let base = extract_base_url("https://cdn.speedtest.net/upload.php");
        assert_eq!(base, "https://cdn.speedtest.net");
    }

    #[test]
    fn test_estimated_download_bytes_constant() {
        // Compile-time check that the estimate stays in the ~15 MB range.
        const _: () = assert!(
            ESTIMATED_DOWNLOAD_BYTES > 10_000_000 && ESTIMATED_DOWNLOAD_BYTES < 20_000_000
        );
    }

    #[test]
    fn test_sample_interval_constant() {
        // Compile-time check for the 50 ms (20 Hz) sampling interval.
        const _: () = assert!(SAMPLE_INTERVAL_MS == 50);
    }
}