1use crate::error::{AixError, AixResult};
7use rand::Rng;
8use std::future::Future;
9use std::time::Duration;
10
/// Settings that control how failed operations are retried.
///
/// `PartialEq` is derived so configurations can be compared as values
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of times the operation may be invoked, including the
    /// first call.
    pub max_attempts: u32,
    /// Delay used for the first retry, before the multiplier is applied.
    pub initial_backoff: Duration,
    /// Upper bound applied to the exponentially growing delay.
    pub max_backoff: Duration,
    /// Factor by which the delay grows with every further attempt.
    pub multiplier: f64,
    /// Whether the computed delay is randomized.
    pub jitter: bool,
    /// Whether rate-limit errors are retried.
    pub retry_on_rate_limit: bool,
    /// Whether transport errors are retried.
    pub retry_on_transport: bool,
    /// Whether timeout errors are retried.
    pub retry_on_timeout: bool,
}
31
32impl RetryConfig {
33 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn builder() -> RetryConfigBuilder {
40 RetryConfigBuilder::new()
41 }
42
43 pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
45 self.max_attempts = max_attempts;
46 self
47 }
48
49 pub fn with_initial_backoff(mut self, initial_backoff: Duration) -> Self {
51 self.initial_backoff = initial_backoff;
52 self
53 }
54
55 pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
57 self.max_backoff = max_backoff;
58 self
59 }
60
61 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
63 self.multiplier = multiplier;
64 self
65 }
66
67 pub fn with_jitter(mut self, jitter: bool) -> Self {
69 self.jitter = jitter;
70 self
71 }
72
73 pub fn with_retry_on_rate_limit(mut self, retry: bool) -> Self {
75 self.retry_on_rate_limit = retry;
76 self
77 }
78
79 pub fn with_retry_on_transport(mut self, retry: bool) -> Self {
81 self.retry_on_transport = retry;
82 self
83 }
84
85 pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
87 self.retry_on_timeout = retry;
88 self
89 }
90
91 pub fn should_retry(&self, error: &AixError) -> bool {
93 if !error.is_retryable() {
94 return false;
95 }
96
97 match error {
98 AixError::RateLimit { .. } => self.retry_on_rate_limit,
99 AixError::Transport { .. } => self.retry_on_transport,
100 AixError::Timeout { .. } => self.retry_on_timeout,
101 AixError::Provider { status, .. } => {
102 status.map_or(false, |s| s >= 500)
104 }
105 _ => false,
106 }
107 }
108
109 pub fn calculate_delay(&self, attempt: u32) -> Duration {
117 let base_delay = self.initial_backoff.as_secs_f64() * self.multiplier.powi(attempt as i32);
119 let base_delay = Duration::from_secs_f64(base_delay);
120
121 let delay = std::cmp::min(base_delay, self.max_backoff);
123
124 if self.jitter {
126 let jitter_range = delay.as_secs_f64() * 0.5; let jitter = rand::thread_rng().gen_range(0.0..jitter_range);
128 let actual_delay = delay.as_secs_f64() * (0.5 + jitter / jitter_range);
129 Duration::from_secs_f64(actual_delay)
130 } else {
131 delay
132 }
133 }
134
135 pub fn extract_retry_delay(&self, error: &AixError) -> Option<Duration> {
137 match error {
138 AixError::RateLimit { retry_after, .. } => *retry_after,
139 _ => None,
140 }
141 }
142}
143
144impl Default for RetryConfig {
145 fn default() -> Self {
146 Self {
147 max_attempts: 3,
148 initial_backoff: Duration::from_millis(1000),
149 max_backoff: Duration::from_secs(30),
150 multiplier: 2.0,
151 jitter: true,
152 retry_on_rate_limit: true,
153 retry_on_transport: true,
154 retry_on_timeout: true,
155 }
156 }
157}
158
159pub struct RetryConfigBuilder {
161 config: RetryConfig,
162}
163
164impl RetryConfigBuilder {
165 pub fn new() -> Self {
167 Self {
168 config: RetryConfig::default(),
169 }
170 }
171
172 pub fn max_attempts(mut self, max_attempts: u32) -> Self {
174 self.config.max_attempts = max_attempts;
175 self
176 }
177
178 pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
180 self.config.initial_backoff = initial_backoff;
181 self
182 }
183
184 pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
186 self.config.max_backoff = max_backoff;
187 self
188 }
189
190 pub fn multiplier(mut self, multiplier: f64) -> Self {
192 self.config.multiplier = multiplier;
193 self
194 }
195
196 pub fn jitter(mut self, jitter: bool) -> Self {
198 self.config.jitter = jitter;
199 self
200 }
201
202 pub fn retry_on_rate_limit(mut self, retry: bool) -> Self {
204 self.config.retry_on_rate_limit = retry;
205 self
206 }
207
208 pub fn retry_on_transport(mut self, retry: bool) -> Self {
210 self.config.retry_on_transport = retry;
211 self
212 }
213
214 pub fn retry_on_timeout(mut self, retry: bool) -> Self {
216 self.config.retry_on_timeout = retry;
217 self
218 }
219
220 pub fn build(self) -> RetryConfig {
222 self.config
223 }
224}
225
226impl Default for RetryConfigBuilder {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
232pub struct RetryStrategy {
234 config: RetryConfig,
235}
236
237impl RetryStrategy {
238 pub fn new(config: RetryConfig) -> Self {
240 Self { config }
241 }
242
243 pub async fn execute<F, Fut, T>(&self, mut f: F) -> AixResult<T>
251 where
252 F: FnMut() -> Fut,
253 Fut: Future<Output = AixResult<T>>,
254 {
255 let mut last_error = None;
256
257 for attempt in 0..self.config.max_attempts {
258 match f().await {
260 Ok(result) => return Ok(result),
261 Err(error) => {
262 last_error = Some(error.clone());
263
264 if !self.config.should_retry(&error) {
266 return Err(error);
267 }
268
269 if attempt == self.config.max_attempts - 1 {
271 return Err(error);
272 }
273
274 let delay = self
276 .config
277 .extract_retry_delay(&error)
278 .unwrap_or_else(|| self.config.calculate_delay(attempt));
279
280 tokio::time::sleep(delay).await;
282 }
283 }
284
285 }
287
288 Err(last_error.unwrap_or_else(|| AixError::other("All retry attempts failed")))
290 }
291
292 pub fn config(&self) -> &RetryConfig {
294 &self.config
295 }
296
297 pub fn config_mut(&mut self) -> &mut RetryConfig {
299 &mut self.config
300 }
301}
302
303impl From<RetryConfig> for RetryStrategy {
304 fn from(config: RetryConfig) -> Self {
305 Self::new(config)
306 }
307}
308
309impl RetryConfig {
311 pub fn no_retry() -> Self {
313 Self {
314 max_attempts: 1,
315 initial_backoff: Duration::from_millis(0),
316 max_backoff: Duration::from_millis(0),
317 multiplier: 1.0,
318 jitter: false,
319 retry_on_rate_limit: false,
320 retry_on_transport: false,
321 retry_on_timeout: false,
322 }
323 }
324
325 pub fn conservative() -> Self {
327 Self {
328 max_attempts: 2,
329 initial_backoff: Duration::from_secs(2),
330 max_backoff: Duration::from_secs(10),
331 multiplier: 2.0,
332 jitter: true,
333 retry_on_rate_limit: true,
334 retry_on_transport: false, retry_on_timeout: false,
336 }
337 }
338
339 pub fn aggressive() -> Self {
341 Self {
342 max_attempts: 5,
343 initial_backoff: Duration::from_millis(500),
344 max_backoff: Duration::from_secs(30),
345 multiplier: 1.5,
346 jitter: true,
347 retry_on_rate_limit: true,
348 retry_on_transport: true,
349 retry_on_timeout: true,
350 }
351 }
352
353 pub fn fast() -> Self {
355 Self {
356 max_attempts: 3,
357 initial_backoff: Duration::from_millis(200),
358 max_backoff: Duration::from_secs(5),
359 multiplier: 1.5,
360 jitter: true,
361 retry_on_rate_limit: true,
362 retry_on_transport: true,
363 retry_on_timeout: false, }
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AixError;

    /// Builds a config whose backoff is 1 ms without jitter, so tests that
    /// exercise the retry loop do not sleep for real seconds.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        RetryConfig::builder()
            .max_attempts(max_attempts)
            .initial_backoff(Duration::from_millis(1))
            .max_backoff(Duration::from_millis(1))
            .jitter(false)
            .build()
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::builder()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .max_backoff(Duration::from_secs(10))
            .multiplier(1.5)
            .jitter(false)
            .retry_on_rate_limit(false)
            .build();

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(500));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.multiplier, 1.5);
        assert!(!config.jitter);
        assert!(!config.retry_on_rate_limit);
    }

    #[test]
    fn test_backoff_calculation() {
        // Jitter is disabled so the delays are deterministic.
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };

        assert_eq!(config.calculate_delay(0), Duration::from_millis(1000));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(2000));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(4000));

        // 1 s * 2^3 = 8 s exceeds the 3 s cap and must be clamped to it.
        let long_config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_millis(3000),
            multiplier: 2.0,
            jitter: false,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };

        assert_eq!(long_config.calculate_delay(3), Duration::from_millis(3000));
    }

    #[test]
    fn test_jitter() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: true,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };

        // The delay is random, so check the bounds over many samples rather
        // than a single draw.
        for _ in 0..100 {
            let delay = config.calculate_delay(0);
            assert!(delay >= Duration::from_millis(500));
            assert!(delay <= Duration::from_millis(1500));
        }
    }

    #[test]
    fn test_should_retry() {
        let config = RetryConfig::default();

        assert!(config.should_retry(&AixError::transport("network error", "request")));
        assert!(config.should_retry(&AixError::rate_limit("openai", "too many requests")));
        assert!(config.should_retry(&AixError::timeout("chat", Duration::from_secs(30))));
        assert!(config.should_retry(&AixError::provider_with_details(
            "openai",
            "server error",
            500,
            "internal_error"
        )));

        assert!(!config.should_retry(&AixError::config("invalid config")));
        assert!(!config.should_retry(&AixError::auth("openai", "unauthorized")));
        assert!(!config.should_retry(&AixError::provider_with_details(
            "openai",
            "bad request",
            400,
            "invalid_request"
        )));
    }

    #[tokio::test]
    async fn test_retry_strategy_success() {
        let strategy = RetryStrategy::new(RetryConfig::default());
        let mut call_count = 0;

        let result = strategy
            .execute(|| {
                call_count += 1;
                async move { Ok::<_, AixError>("success") }
            })
            .await;

        assert_eq!(result.unwrap(), "success");
        // A first-try success must not trigger any retry.
        assert_eq!(call_count, 1);
    }

    #[tokio::test]
    async fn test_retry_strategy_with_retry() {
        let strategy = RetryStrategy::new(fast_config(3));
        let mut call_count = 0;

        let result = strategy
            .execute(|| {
                call_count += 1;
                async move {
                    if call_count < 3 {
                        Err::<_, AixError>(AixError::transport("network error", "request"))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), "success");
        // Two failures followed by the successful third call.
        assert_eq!(call_count, 3);
    }

    #[tokio::test]
    async fn test_retry_strategy_exhausted() {
        let strategy = RetryStrategy::new(fast_config(2));
        let mut call_count = 0;

        let result = strategy
            .execute(|| {
                call_count += 1;
                async move { Err::<(), AixError>(AixError::transport("network error", "request")) }
            })
            .await;

        assert!(result.is_err());
        // Every allowed attempt was used before giving up.
        assert_eq!(call_count, 2);
    }

    #[test]
    fn test_preset_configs() {
        let no_retry = RetryConfig::no_retry();
        assert_eq!(no_retry.max_attempts, 1);
        assert!(!no_retry.retry_on_rate_limit);

        let conservative = RetryConfig::conservative();
        assert_eq!(conservative.max_attempts, 2);
        assert!(!conservative.retry_on_transport);

        let aggressive = RetryConfig::aggressive();
        assert_eq!(aggressive.max_attempts, 5);
        assert!(aggressive.retry_on_transport);

        let fast = RetryConfig::fast();
        assert_eq!(fast.max_attempts, 3);
        assert_eq!(fast.initial_backoff, Duration::from_millis(200));
    }
}