1use crate::models::AppConfig;
4use indicatif::{ProgressBar, ProgressStyle};
5use std::fmt;
6use std::time::Duration;
7use tracing::debug;
8
/// Length of the warm-up window, in seconds (per the `_SECS` suffix).
///
/// NOTE(review): not referenced in this part of the file; presumably samples taken
/// during this window are excluded from measurements — confirm at the call sites.
pub const WARMUP_SECS: f64 = 2_f64;
10
11pub fn create_spinner(msg: &str, config: &AppConfig, style_template: &str) -> ProgressBar {
12 if config.quiet {
13 ProgressBar::hidden()
14 } else {
15 let pb = ProgressBar::new_spinner();
16 if let Ok(style) = ProgressStyle::default_spinner().template(style_template) {
17 pb.set_style(style);
18 }
19 pb.set_message(msg.to_string());
20 pb.enable_steady_tick(Duration::from_millis(100));
21 pb
22 }
23}
24
/// Converts a byte count and an elapsed time into a rate in megabits per second.
///
/// "Mega" is binary here: the byte count is divided by 1024 * 1024 and then
/// multiplied by 8. For example, 13_107_200 bytes over 1.0 s gives 100.0.
///
/// A zero or negative `duration_secs` returns 0.0 rather than dividing by zero or
/// reporting a negative rate. A NaN duration is not caught by that guard and
/// propagates as NaN.
pub fn calculate_mbps(bytes: u64, duration_secs: f64) -> f64 {
    const BYTES_PER_MEBIBYTE: f64 = 1024.0 * 1024.0;
    const BITS_PER_BYTE: f64 = 8.0;

    if duration_secs <= 0.0 {
        0.0
    } else {
        let megabits = bytes as f64 / BYTES_PER_MEBIBYTE * BITS_PER_BYTE;
        megabits / duration_secs
    }
}
32
33pub async fn with_retry<F, Fut, T>(max_retries: u32, mut f: F) -> anyhow::Result<T>
34where
35 F: FnMut() -> Fut,
36 Fut: std::future::Future<Output = anyhow::Result<T>>,
37{
38 let mut last_err = anyhow::anyhow!("No attempts made");
39 for attempt in 0..=max_retries {
40 match f().await {
41 Ok(val) => return Ok(val),
42 Err(e) => {
43 if let Some(nre) = e.downcast_ref::<NonRetryableError>() {
45 return Err(anyhow::anyhow!("{}", nre.0));
46 }
47 if attempt < max_retries {
48 let backoff = Duration::from_millis(100 * 2u64.pow(attempt));
49 debug!(
50 "Request failed (attempt {}/{}): {}. Retrying in {:?}...",
51 attempt + 1,
52 max_retries + 1,
53 e,
54 backoff
55 );
56 tokio::time::sleep(backoff).await;
57 }
58 last_err = e;
59 }
60 }
61 }
62 Err(last_err)
63}
64
65#[derive(Debug)]
68pub struct NonRetryableError(pub anyhow::Error);
69
70impl fmt::Display for NonRetryableError {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 self.0.fmt(f)
74 }
75}
76
77impl std::error::Error for NonRetryableError {}
79
#[cfg(test)]
mod tests {
    //! Unit tests for the throughput helper and the retry wrapper.
    //!
    //! The retry tests really sleep through `with_retry`'s backoff (100 ms, then
    //! 200 ms, ...), so a case that fails three times spends about 300 ms sleeping.
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn mbps_correct_for_known_value() {
        // 13_107_200 bytes = 12.5 * 1024 * 1024 bytes = 100 binary megabits.
        let bytes = 13_107_200u64;
        let speed = calculate_mbps(bytes, 1.0);
        assert!(
            (speed - 100.0).abs() < 0.001,
            "Expected 100 Mbps, got {}",
            speed
        );
    }

    #[test]
    fn mbps_halves_when_duration_doubles() {
        let speed = calculate_mbps(13_107_200, 2.0);
        assert!(
            (speed - 50.0).abs() < 0.001,
            "Expected 50 Mbps, got {}",
            speed
        );
    }

    #[test]
    fn mbps_zero_for_zero_duration() {
        assert_eq!(calculate_mbps(1_000_000, 0.0), 0.0);
    }

    #[test]
    fn mbps_zero_for_negative_duration() {
        assert_eq!(calculate_mbps(1_000_000, -5.0), 0.0);
    }

    #[test]
    fn mbps_zero_bytes_gives_zero() {
        assert_eq!(calculate_mbps(0, 10.0), 0.0);
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result = with_retry(3, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(result.expect("first attempt succeeds"), 42);
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1,
            "Should not retry after a success"
        );
    }

    #[tokio::test]
    async fn retry_succeeds_on_second_attempt() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result = with_retry(3, move || {
            let counter = attempts_c.clone();
            async move {
                let n = counter.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    anyhow::bail!("transient error");
                }
                Ok::<i32, anyhow::Error>(99)
            }
        })
        .await;

        assert_eq!(result.expect("second attempt succeeds"), 99);
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            2,
            "Should have taken exactly 2 attempts"
        );
    }

    #[tokio::test]
    async fn retry_exhausts_all_attempts_and_returns_last_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        // Each attempt fails with a distinct message so we can tell which one is returned.
        let result: anyhow::Result<()> = with_retry(2, move || {
            let counter = attempts_c.clone();
            async move {
                let n = counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!("failure #{}", n)
            }
        })
        .await;

        let err = result.expect_err("every attempt fails, so with_retry must fail");
        assert_eq!(
            err.to_string(),
            "failure #2",
            "Should return the error from the final attempt"
        );
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            3,
            "Should have attempted exactly max_retries + 1 times"
        );
    }

    #[tokio::test]
    async fn retry_with_zero_retries_attempts_exactly_once() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(0, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!("fail")
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1,
            "Zero retries = exactly 1 attempt"
        );
    }

    #[tokio::test]
    async fn non_retryable_error_skips_retry() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(3, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!(NonRetryableError(anyhow::anyhow!("fatal error")))
            }
        })
        .await;

        let err = result.expect_err("a non-retryable failure must be returned");
        assert_eq!(
            err.to_string(),
            "fatal error",
            "The inner error's message should be surfaced"
        );
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1,
            "Should have attempted exactly once due to NonRetryableError"
        );
    }

    #[tokio::test]
    async fn retryable_error_uses_all_attempts() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_c = attempts.clone();

        let result: anyhow::Result<()> = with_retry(2, move || {
            let counter = attempts_c.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                anyhow::bail!("transient error")
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            3,
            "A regular anyhow::bail! still retries max_retries + 1 times"
        );
    }

    #[test]
    fn non_retryable_error_displays_inner_message() {
        let err = NonRetryableError(anyhow::anyhow!("boom"));
        assert_eq!(err.to_string(), "boom");
    }
}