Skip to main content

cli_speedtest/
client.rs

1// src/client.rs
2
3use bytes::Bytes;
4use futures_util::StreamExt;
5use indicatif::HumanBytes;
6use rand::{Rng, RngCore};
7use reqwest::Client;
8use std::sync::Arc;
9use std::sync::OnceLock;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12use tokio::sync::Barrier;
13use tokio_util::sync::CancellationToken;
14
15use crate::models::{AppConfig, PingStats};
16use crate::theme;
17use crate::utils::{NonRetryableError, WARMUP_SECS, calculate_mbps, create_spinner, with_retry};
18
19// src/client.rs - shared helper used in both test_download and test_upload
20fn check_status(r: &reqwest::Response) -> anyhow::Result<()> {
21    match r.status() {
22        s if s.is_success() => Ok(()),
23
24        reqwest::StatusCode::TOO_MANY_REQUESTS => {
25            let (wait_secs, source) = r
26                .headers()
27                .get("retry-after")
28                .and_then(|v| v.to_str().ok())
29                .and_then(|s| s.parse::<u64>().ok())
30                .map(|s| (s, "server says"))
31                .unwrap_or((900, "estimated - no Retry-After header"));
32
33            Err(anyhow::Error::new(NonRetryableError(anyhow::anyhow!(
34                "You've been rate-limited by Cloudflare. \
35                 Please wait {} minutes ({}).\n\n\
36                 Alternatives:\n  \
37                 - Use a custom server:  cli-speedtest --server <URL>\n  \
38                 - Run ping only:        cli-speedtest --no-download --no-upload\n  \
39                 - Force immediate run:  cli-speedtest --force-run",
40                wait_secs / 60,
41                source
42            ))))
43        }
44
45        reqwest::StatusCode::FORBIDDEN => {
46            Err(anyhow::Error::new(NonRetryableError(anyhow::anyhow!(
47                "Cloudflare returned 403 Forbidden. Your IP may have \
48                 triggered Bot Fight Mode. Wait 15 minutes or switch \
49                 servers with: speedtest --server <URL>"
50            ))))
51        }
52
53        s => anyhow::bail!("Request failed with status: {}", s),
54    }
55}
56
57pub async fn test_ping_stats(
58    client: &Client,
59    base_url: &str,
60    count: u32,
61    config: Arc<AppConfig>,
62) -> anyhow::Result<PingStats> {
63    let pb = create_spinner(
64        "Measuring latency & jitter...",
65        &config,
66        "{spinner:.cyan} {msg}",
67    );
68
69    let url = format!("{}/cdn-cgi/trace", base_url);
70    let mut samples: Vec<u128> = Vec::with_capacity(count as usize);
71    let mut lost: u32 = 0;
72
73    for _ in 0..count {
74        let start = Instant::now();
75        match tokio::time::timeout(Duration::from_secs(2), client.head(&url).send()).await {
76            Ok(Ok(_)) => samples.push(start.elapsed().as_millis()),
77            _ => lost += 1,
78        }
79        tokio::time::sleep(Duration::from_millis(50)).await;
80    }
81
82    pb.finish_and_clear();
83
84    if samples.is_empty() {
85        anyhow::bail!("All ping attempts failed - server unreachable");
86    }
87
88    let min_ms = *samples.iter().min().unwrap();
89    let max_ms = *samples.iter().max().unwrap();
90    let avg_ms = samples.iter().sum::<u128>() as f64 / samples.len() as f64;
91
92    let jitter_ms = if samples.len() > 1 {
93        let diffs: Vec<f64> = samples
94            .windows(2)
95            .map(|w| (w[1] as f64 - w[0] as f64).abs())
96            .collect();
97        diffs.iter().sum::<f64>() / diffs.len() as f64
98    } else {
99        0.0
100    };
101
102    let packet_loss_pct = (lost as f64 / count as f64) * 100.0;
103
104    if !config.quiet {
105        println!(
106            "Ping: {} avg  |  Jitter: {}  |  Loss: {}\n",
107            theme::color_ping(avg_ms, &config),
108            theme::color_jitter(jitter_ms, &config),
109            theme::color_loss(packet_loss_pct, &config)
110        );
111    }
112
113    Ok(PingStats {
114        min_ms,
115        max_ms,
116        avg_ms,
117        jitter_ms,
118        packet_loss_pct,
119    })
120}
121
122pub async fn test_download(
123    client: &Client,
124    base_url: &str,
125    duration_secs: u64,
126    num_connections: usize,
127    config: Arc<AppConfig>,
128) -> anyhow::Result<f64> {
129    let chunk_size_bytes = 50 * 1024 * 1024;
130    let total_downloaded = Arc::new(AtomicU64::new(0));
131
132    let pb = create_spinner(
133        "Downloading...",
134        &config,
135        "{spinner:.green} [{elapsed_precise}] {msg}",
136    );
137
138    let token = CancellationToken::new();
139    let barrier = Arc::new(Barrier::new(num_connections + 1)); // +1 for the display task
140    let shared_start: Arc<OnceLock<Instant>> = Arc::new(OnceLock::new());
141    let mut tasks = vec![];
142
143    // Worker tasks
144    for _ in 0..num_connections {
145        let client = client.clone();
146        let pb = pb.clone();
147        let total_downloaded = total_downloaded.clone();
148        let url = format!("{}/__down?bytes={}", base_url, chunk_size_bytes);
149        let barrier = barrier.clone();
150        let shared_start = shared_start.clone();
151        let token = token.clone();
152
153        let task = tokio::spawn(async move {
154            barrier.wait().await;
155            let start = *shared_start.get_or_init(Instant::now);
156
157            'request: loop {
158                if token.is_cancelled() {
159                    break;
160                }
161
162                let res = match with_retry(3, || async {
163                    let r = client.get(&url).send().await?;
164                    check_status(&r)?;
165                    Ok(r)
166                })
167                .await
168                {
169                    Ok(r) => r,
170                    Err(e) => return Err(e),
171                };
172
173                let mut stream = res.bytes_stream();
174
175                loop {
176                    tokio::select! {
177                        biased;
178                        _ = token.cancelled() => break 'request,
179                        item = stream.next() => {
180                            match item {
181                                Some(Ok(chunk)) => {
182                                    let len = chunk.len() as u64;
183                                    pb.inc(len);
184                                    if start.elapsed().as_secs_f64() >= WARMUP_SECS {
185                                        total_downloaded.fetch_add(len, Ordering::Relaxed);
186                                    }
187                                }
188                                Some(Err(e)) => return Err(e.into()),
189                                None => break,
190                            }
191                        }
192                    }
193                }
194
195                let jitter_ms = rand::rng().random_range(50u64..=150);
196                tokio::time::sleep(std::time::Duration::from_millis(jitter_ms)).await;
197            }
198
199            Ok::<(), anyhow::Error>(())
200        });
201
202        tasks.push(task);
203    }
204
205    // Display task
206    let display_task = {
207        let pb = pb.clone();
208        let total_downloaded = total_downloaded.clone();
209        let token = token.clone();
210        let config = config.clone();
211        let barrier = barrier.clone();
212
213        tokio::spawn(async move {
214            barrier.wait().await;
215            let mut prev_bytes = 0;
216            let mut prev_instant = Instant::now();
217
218            loop {
219                tokio::select! {
220                    _ = token.cancelled() => break,
221                    _ = tokio::time::sleep(Duration::from_millis(250)) => {
222                        let now_bytes = total_downloaded.load(Ordering::Relaxed);
223                        let delta = now_bytes.saturating_sub(prev_bytes);
224                        let elapsed = prev_instant.elapsed().as_secs_f64();
225                        let speed = calculate_mbps(delta, elapsed);
226
227                        let speed_str = if speed == 0.0 && now_bytes == 0 {
228                            "↓  --.- Mbps".to_string()
229                        } else {
230                            format!("↓  {}", theme::color_speed(speed, &config))
231                        };
232
233                        pb.set_message(format!(
234                            "{}    {} total",
235                            speed_str,
236                            HumanBytes(now_bytes)
237                        ));
238
239                        prev_bytes = now_bytes;
240                        prev_instant = Instant::now();
241                    }
242                }
243            }
244        })
245    };
246
247    tokio::time::sleep(Duration::from_secs(duration_secs)).await;
248    token.cancel();
249
250    for task in tasks {
251        task.await??;
252    }
253    display_task.await?;
254
255    pb.finish_and_clear();
256
257    let start = shared_start.get().copied().unwrap_or_else(Instant::now);
258    let effective_duration = (start.elapsed().as_secs_f64() - WARMUP_SECS).max(0.0);
259    Ok(calculate_mbps(
260        total_downloaded.load(Ordering::Relaxed),
261        effective_duration,
262    ))
263}
264
265pub async fn test_upload(
266    client: &Client,
267    base_url: &str,
268    duration_secs: u64,
269    num_connections: usize,
270    config: Arc<AppConfig>,
271) -> anyhow::Result<f64> {
272    let chunk_size = 2 * 1024 * 1024;
273    let total_uploaded = Arc::new(AtomicU64::new(0));
274
275    let pb = create_spinner(
276        "Uploading...",
277        &config,
278        "{spinner:.red} [{elapsed_precise}] {msg}",
279    );
280
281    let token = CancellationToken::new();
282    let barrier = Arc::new(Barrier::new(num_connections + 1));
283    let shared_start: Arc<OnceLock<Instant>> = Arc::new(OnceLock::new());
284    let mut tasks = vec![];
285
286    // Worker tasks
287    for _ in 0..num_connections {
288        let client = client.clone();
289        let pb = pb.clone();
290        let total_uploaded = total_uploaded.clone();
291        let url = format!("{}/__up", base_url);
292        let barrier = barrier.clone();
293        let shared_start = shared_start.clone();
294        let token = token.clone();
295
296        let task = tokio::spawn(async move {
297            barrier.wait().await;
298            let start = *shared_start.get_or_init(Instant::now);
299
300            let mut raw_payload = vec![0u8; chunk_size];
301            rand::rng().fill_bytes(&mut raw_payload);
302            let payload = Bytes::from(raw_payload);
303
304            loop {
305                if token.is_cancelled() {
306                    break;
307                }
308
309                match with_retry(3, || async {
310                    let r = client
311                        .post(url.clone())
312                        .body(payload.clone())
313                        .send()
314                        .await?;
315                    check_status(&r)?;
316                    Ok(r)
317                })
318                .await
319                {
320                    Ok(_) => {
321                        let len = payload.len() as u64;
322                        pb.inc(len);
323                        if start.elapsed().as_secs_f64() >= WARMUP_SECS {
324                            total_uploaded.fetch_add(len, Ordering::Relaxed);
325                        }
326                    }
327                    Err(e) => return Err(e),
328                }
329
330                let jitter_ms = rand::rng().random_range(50u64..=150);
331                tokio::time::sleep(Duration::from_millis(jitter_ms)).await;
332            }
333
334            Ok::<(), anyhow::Error>(())
335        });
336
337        tasks.push(task);
338    }
339
340    // Display task
341    let display_task = {
342        let pb = pb.clone();
343        let total_uploaded = total_uploaded.clone();
344        let token = token.clone();
345        let config = config.clone();
346        let barrier = barrier.clone();
347
348        tokio::spawn(async move {
349            barrier.wait().await;
350            let mut prev_bytes = 0;
351            let mut prev_instant = Instant::now();
352
353            loop {
354                tokio::select! {
355                    _ = token.cancelled() => break,
356                    _ = tokio::time::sleep(Duration::from_millis(250)) => {
357                        let now_bytes = total_uploaded.load(Ordering::Relaxed);
358                        let delta = now_bytes.saturating_sub(prev_bytes);
359                        let elapsed = prev_instant.elapsed().as_secs_f64();
360                        let speed = calculate_mbps(delta, elapsed);
361
362                        let speed_str = if speed == 0.0 && now_bytes == 0 {
363                            "↑  --.- Mbps".to_string()
364                        } else {
365                            format!("↑  {}", theme::color_speed(speed, &config))
366                        };
367
368                        pb.set_message(format!(
369                            "{}    {} total",
370                            speed_str,
371                            HumanBytes(now_bytes)
372                        ));
373
374                        prev_bytes = now_bytes;
375                        prev_instant = Instant::now();
376                    }
377                }
378            }
379        })
380    };
381
382    tokio::time::sleep(Duration::from_secs(duration_secs)).await;
383    token.cancel();
384
385    for task in tasks {
386        task.await??;
387    }
388    display_task.await?;
389
390    pb.finish_and_clear();
391
392    let start = shared_start.get().copied().unwrap_or_else(Instant::now);
393    let effective_duration = (start.elapsed().as_secs_f64() - WARMUP_SECS).max(0.0);
394    Ok(calculate_mbps(
395        total_uploaded.load(Ordering::Relaxed),
396        effective_duration,
397    ))
398}