Skip to main content

cli_speedtest/
utils.rs

1// src/utils.rs
2
3use crate::models::AppConfig;
4use indicatif::{ProgressBar, ProgressStyle};
5use std::fmt;
6use std::time::Duration;
7use tracing::debug;
8
/// Length of the warm-up window, in seconds.
///
/// NOTE(review): not referenced in this file; presumably measurement code
/// ignores samples taken during this initial window — confirm at use sites.
pub const WARMUP_SECS: f64 = 2.0_f64;
10
11pub fn create_spinner(msg: &str, config: &AppConfig, style_template: &str) -> ProgressBar {
12    if config.quiet {
13        ProgressBar::hidden()
14    } else {
15        let pb = ProgressBar::new_spinner();
16        if let Ok(style) = ProgressStyle::default_spinner().template(style_template) {
17            pb.set_style(style);
18        }
19        pb.set_message(msg.to_string());
20        pb.enable_steady_tick(Duration::from_millis(100));
21        pb
22    }
23}
24
/// Converts `bytes` transferred over `duration_secs` seconds into a bit rate.
///
/// The result is `bytes / 2^20 * 8 / duration_secs`, i.e. expressed in units of
/// 2^20 bits per second (binary "mebibits"), not SI megabits (10^6 bits).
/// NOTE(review): many speed-test services report SI Mbps, which reads about
/// 4.9% higher for the same transfer; the unit tests pin the binary convention.
///
/// Returns `0.0` when `duration_secs` is zero, negative, or NaN, so callers
/// never divide by zero or surface a NaN as a measured speed.
pub fn calculate_mbps(bytes: u64, duration_secs: f64) -> f64 {
    // A plain `<= 0.0` check is false for NaN and would let it through.
    if duration_secs.is_nan() || duration_secs <= 0.0 {
        return 0.0;
    }
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
    const BITS_PER_BYTE: f64 = 8.0;
    // Same evaluation order as before, so results are bit-identical.
    (bytes as f64 / BYTES_PER_MIB) * BITS_PER_BYTE / duration_secs
}
32
33pub async fn with_retry<F, Fut, T>(max_retries: u32, mut f: F) -> anyhow::Result<T>
34where
35    F: FnMut() -> Fut,
36    Fut: std::future::Future<Output = anyhow::Result<T>>,
37{
38    let mut last_err = anyhow::anyhow!("No attempts made");
39    for attempt in 0..=max_retries {
40        match f().await {
41            Ok(val) => return Ok(val),
42            Err(e) => {
43                // If the closure wrapped the error as NonRetryable, bail instantly
44                if let Some(nre) = e.downcast_ref::<NonRetryableError>() {
45                    return Err(anyhow::anyhow!("{}", nre.0));
46                }
47                if attempt < max_retries {
48                    let backoff = Duration::from_millis(100 * 2u64.pow(attempt));
49                    debug!(
50                        "Request failed (attempt {}/{}): {}. Retrying in {:?}...",
51                        attempt + 1,
52                        max_retries + 1,
53                        e,
54                        backoff
55                    );
56                    tokio::time::sleep(backoff).await;
57                }
58                last_err = e;
59            }
60        }
61    }
62    Err(last_err)
63}
64
65/// Marker error that tells with_retry to bail immediately without retrying.
66/// Used for HTTP 429 / 403 where retrying is actively harmful.
67#[derive(Debug)]
68pub struct NonRetryableError(pub anyhow::Error);
69
70// anyhow::Error::new() requires StdError, which requires Display
71impl fmt::Display for NonRetryableError {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        self.0.fmt(f)
74    }
75}
76
77// This is the missing trait that causes the compiler error
78impl std::error::Error for NonRetryableError {}
79
#[cfg(test)]
mod tests {
    // Unit tests for `calculate_mbps` and `with_retry`.
    // Note: the retry tests use real `tokio::time::sleep` backoffs (100 ms+).
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // --- calculate_mbps ---

    #[test]
    fn mbps_correct_for_known_value() {
        // 12.5 MiB in 1 second = exactly 100 Mbps (binary, 2^20-bit units)
        let bytes = 13_107_200u64; // 12.5 * 1024 * 1024
        let speed = calculate_mbps(bytes, 1.0);
        assert!(
            (speed - 100.0).abs() < 0.001,
            "Expected 100 Mbps, got {}",
            speed
        );
    }

    #[test]
    fn mbps_zero_for_zero_duration() {
        assert_eq!(calculate_mbps(1_000_000, 0.0), 0.0);
    }

    #[test]
    fn mbps_zero_for_negative_duration() {
        assert_eq!(calculate_mbps(1_000_000, -5.0), 0.0);
    }

    #[test]
    fn mbps_zero_bytes_gives_zero() {
        assert_eq!(calculate_mbps(0, 10.0), 0.0);
    }

    // --- with_retry ---

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let result = with_retry(3, || async { Ok::<i32, anyhow::Error>(42) }).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn retry_succeeds_on_second_attempt() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result = with_retry(3, move || {
            let counter = attempts_c.clone();
            async move {
                let n = counter.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    anyhow::bail!("transient error");
                }
                Ok::<i32, anyhow::Error>(99)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 99);
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            2,
            "Should have taken exactly 2 attempts"
        );
    }

    #[tokio::test]
    async fn retry_exhausts_all_attempts_and_returns_last_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(2, move || {
            let counter = attempts_c.clone();
            async move {
                // Tag each failure with its attempt index so the test can tell
                // which attempt's error was surfaced.
                let n = counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!("failure #{}", n)
            }
        })
        .await;

        let err = result.expect_err("every attempt fails, so the result must be Err");
        assert_eq!(
            err.to_string(),
            "failure #2",
            "Should return the error from the final attempt"
        );
        // max_retries = 2 means 3 total attempts: attempt 0, 1, 2
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            3,
            "Should have attempted exactly max_retries + 1 times"
        );
    }

    #[tokio::test]
    async fn retry_with_zero_retries_attempts_exactly_once() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(0, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!("fail")
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1,
            "Zero retries = exactly 1 attempt"
        );
    }

    #[tokio::test]
    async fn non_retryable_error_skips_retry() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(3, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!(NonRetryableError(anyhow::anyhow!("fatal error")))
            }
        })
        .await;

        let err = result.expect_err("a non-retryable failure must be returned as Err");
        assert_eq!(
            err.to_string(),
            "fatal error",
            "The wrapped error's message should be surfaced unchanged"
        );
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1,
            "Should have attempted exactly once due to NonRetryableError"
        );
    }

    #[tokio::test]
    async fn non_retryable_error_after_transient_failure_stops_retrying() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(5, move || {
            let counter = attempts_c.clone();
            async move {
                if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                    anyhow::bail!("transient error");
                }
                anyhow::bail!(NonRetryableError(anyhow::anyhow!("rate limited")))
            }
        })
        .await;

        let err = result.expect_err("a non-retryable failure must be returned as Err");
        assert_eq!(err.to_string(), "rate limited");
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            2,
            "One transient retry, then stop at the NonRetryableError"
        );
    }

    #[tokio::test]
    async fn retryable_error_uses_all_attempts() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(2, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!("transient error")
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            3,
            "A regular anyhow::bail! still retries max_retries + 1 times"
        );
    }
}