1use super::ProviderError;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10
/// Tuning knobs for [`retry_with_backoff`] and [`backoff_delay`].
///
/// Derives `PartialEq`/`Eq` so two configurations can be compared directly
/// (for example in assertions, or to detect that a configuration changed).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Total number of calls to the operation, including the first one.
    /// `retry_with_backoff` always makes at least one call, even when this is 0.
    pub max_attempts: usize,
    /// Delay before the first retry, in milliseconds. It doubles for each later
    /// retry, for at most 2^10 doublings.
    pub base_delay_ms: u64,
    /// Cap, in milliseconds, on the exponential delay *before* the ±20% jitter
    /// is applied. Actual sleeps can therefore exceed it by up to 20%.
    pub max_delay_ms: u64,
}
21
22impl Default for RetryConfig {
23 fn default() -> Self {
24 Self {
25 max_attempts: 8,
26 base_delay_ms: 500,
27 max_delay_ms: 30_000,
28 }
29 }
30}
31
/// Snapshot of a failed attempt that is about to be retried. It is handed to
/// the optional [`RetryCallback`] before the retry loop sleeps.
#[derive(Clone, Debug)]
pub struct RetryInfo {
    /// 1-based number of the attempt that just failed.
    pub attempt: usize,
    /// The configured attempt budget (`RetryConfig::max_attempts`).
    pub max_attempts: usize,
    /// How long the retry loop will sleep before the next attempt.
    pub delay: Duration,
    /// `Display` rendering of the error that triggered the retry.
    pub error: String,
}
44
45pub type RetryCallback = Arc<dyn Fn(RetryInfo) + Send + Sync>;
47
48pub fn is_retryable_error(err: &ProviderError) -> bool {
50 match err {
51 ProviderError::RateLimited(_) => true,
53 ProviderError::ServiceUnavailable(_) => true,
54 ProviderError::Network(_) => true,
55 ProviderError::Communication(_) => true,
56
57 ProviderError::Authentication(_) => false,
59 ProviderError::Configuration(_) => false,
60 ProviderError::Model(_) => false,
61 ProviderError::Other(_) => false,
62 }
63}
64
65pub fn backoff_delay(attempt: usize, config: &RetryConfig) -> Duration {
67 let shift = (attempt.saturating_sub(1)).min(10) as u32;
68 let exp = 1_u64.checked_shl(shift).unwrap_or(u64::MAX);
69 let base = config.base_delay_ms.saturating_mul(exp);
70 let capped = base.min(config.max_delay_ms);
71 let jittered = jitter_ms(capped);
72 Duration::from_millis(jittered)
73}
74
/// Applies a pseudo-random jitter of -20%..=+20% to `base_ms`.
///
/// The "randomness" is the sub-second nanosecond part of the wall clock. That is
/// cheap and dependency-free, but it is not a real RNG; it only spreads retries out.
///
/// The arithmetic is done in `i128`, so every `u64` input is handled without
/// sign wrap-around or multiplication overflow. The result is clamped into
/// `0..=u64::MAX`.
fn jitter_ms(base_ms: u64) -> u64 {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    // 41 buckets give an integer percentage in -20..=20.
    let jitter_pct = i128::from(nanos % 41) - 20;
    let base = i128::from(base_ms);
    // `u64::MAX * 20` fits comfortably in i128. Division truncates toward zero,
    // exactly as the i64 arithmetic did for in-range inputs.
    let jittered = base + base * jitter_pct / 100;
    // Lossless cast: the value was just clamped into u64's range.
    jittered.clamp(0, i128::from(u64::MAX)) as u64
}
86
87pub async fn retry_with_backoff<F, Fut, T>(
102 mut op: F,
103 config: &RetryConfig,
104 on_retry: &Option<RetryCallback>,
105) -> Result<T, ProviderError>
106where
107 F: FnMut() -> Fut,
108 Fut: std::future::Future<Output = Result<T, ProviderError>>,
109{
110 let mut attempt = 0;
111 loop {
112 attempt += 1;
113 match op().await {
114 Ok(result) => return Ok(result),
115 Err(err) => {
116 if attempt >= config.max_attempts || !is_retryable_error(&err) {
117 return Err(err);
118 }
119 let delay = backoff_delay(attempt, config);
120
121 if let Some(callback) = on_retry {
123 callback(RetryInfo {
124 attempt,
125 max_attempts: config.max_attempts,
126 delay,
127 error: err.to_string(),
128 });
129 }
130
131 tokio::time::sleep(delay).await;
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Fresh shared counter, used to count calls made from inside closures.
    fn counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    #[test]
    fn test_retry_config_default() {
        let defaults = RetryConfig::default();
        assert_eq!(defaults.max_attempts, 8);
        assert_eq!(defaults.base_delay_ms, 500);
        assert_eq!(defaults.max_delay_ms, 30_000);
    }

    #[test]
    fn test_is_retryable_error_rate_limited() {
        let err = ProviderError::RateLimited("too many requests".into());
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn test_is_retryable_error_service_unavailable() {
        let err = ProviderError::ServiceUnavailable("503".into());
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn test_is_retryable_error_network() {
        let err = ProviderError::Network("connection refused".into());
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn test_is_retryable_error_communication() {
        let err = ProviderError::Communication("timeout".into());
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn test_is_retryable_error_not_retryable() {
        let permanent = [
            ProviderError::Authentication("bad creds".into()),
            ProviderError::Configuration("invalid model".into()),
            ProviderError::Model("content filtered".into()),
            ProviderError::Other("unknown".into()),
        ];
        for err in &permanent {
            assert!(!is_retryable_error(err));
        }
    }

    #[test]
    fn test_backoff_delay_first_attempt() {
        let millis = backoff_delay(1, &RetryConfig::default()).as_millis();
        // 500 ms base with ±20% jitter.
        assert!((400..=600).contains(&millis));
    }

    #[test]
    fn test_backoff_delay_exponential_growth() {
        let config = RetryConfig {
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            max_attempts: 10,
        };

        let delays: Vec<u128> = (1..=3)
            .map(|n| backoff_delay(n, &config).as_millis())
            .collect();

        // Jitter ranges for consecutive doublings do not overlap, so growth is strict.
        assert!(delays.windows(2).all(|pair| pair[1] > pair[0]));
    }

    #[test]
    fn test_backoff_delay_respects_max() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 2000,
            max_attempts: 10,
        };

        // Cap of 2000 ms plus at most 20% jitter.
        assert!(backoff_delay(10, &config).as_millis() <= 2400);
    }

    #[test]
    fn test_jitter_ms_produces_variation() {
        let jittered = jitter_ms(1000);
        assert!((800..=1200).contains(&jittered));
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success_first_try() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 10,
            max_delay_ms: 100,
        };

        let mut calls = 0;
        let outcome = retry_with_backoff(
            || {
                calls += 1;
                async { Ok::<_, ProviderError>("success") }
            },
            &config,
            &None,
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_retries_on_transient_error() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let calls = counter();
        let calls_in_op = Arc::clone(&calls);

        let outcome = retry_with_backoff(
            || {
                let seen = calls_in_op.fetch_add(1, Ordering::SeqCst);
                async move {
                    if seen >= 2 {
                        Ok("success after retry")
                    } else {
                        Err(ProviderError::RateLimited("throttled".into()))
                    }
                }
            },
            &config,
            &None,
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success after retry");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_gives_up_after_max_attempts() {
        let config = RetryConfig {
            max_attempts: 2,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let calls = counter();
        let calls_in_op = Arc::clone(&calls);

        let outcome: Result<(), ProviderError> = retry_with_backoff(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Err(ProviderError::RateLimited("always throttled".into())) }
            },
            &config,
            &None,
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_no_retry_on_permanent_error() {
        let config = RetryConfig {
            max_attempts: 5,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let calls = counter();
        let calls_in_op = Arc::clone(&calls);

        let outcome: Result<(), ProviderError> = retry_with_backoff(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Err(ProviderError::Authentication("bad credentials".into())) }
            },
            &config,
            &None,
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_callback_invoked() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let notified = counter();
        let notified_in_cb = Arc::clone(&notified);
        let callback: RetryCallback = Arc::new(move |info: RetryInfo| {
            notified_in_cb.fetch_add(1, Ordering::SeqCst);
            assert!(info.attempt > 0);
            assert_eq!(info.max_attempts, 3);
        });

        let attempts = counter();
        let attempts_in_op = Arc::clone(&attempts);

        let _outcome: Result<(), ProviderError> = retry_with_backoff(
            || {
                let seen = attempts_in_op.fetch_add(1, Ordering::SeqCst);
                async move {
                    if seen >= 2 {
                        Ok(())
                    } else {
                        Err(ProviderError::ServiceUnavailable("503".into()))
                    }
                }
            },
            &config,
            &Some(callback),
        )
        .await;

        // Two failures before success, so two retry notifications.
        assert_eq!(notified.load(Ordering::SeqCst), 2);
    }
}