// ai_agent/services/api/with_retry.rs

use std::future::Future;
use std::time::Duration;

/// Tuning knobs for the retry helpers in this module.
///
/// The delay slept after failed attempt number `attempt` (0-based) is computed by
/// [`calculate_delay`] as `initial_delay_ms * backoff_multiplier^attempt`, capped at
/// `max_delay_ms`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// How many times a failed operation is retried after the first attempt, so the
    /// operation is invoked at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Base delay in milliseconds; this is the wait after the first failed attempt.
    pub initial_delay_ms: u64,
    /// Upper bound in milliseconds on any single computed delay.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each successive retry.
    pub backoff_multiplier: f64,
    /// Status codes that [`is_retryable_status`] reports as retryable
    /// (HTTP 429 and 5xx gateway/server errors in the default configuration).
    pub retryable_status_codes: Vec<u16>,
}

impl Default for RetryConfig {
    /// Three retries. Delays start at 1000 ms and double each retry, up to a 10000 ms cap.
    /// Retries on status codes 429, 500, 502, 503 and 504.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 10000,
            backoff_multiplier: 2.0,
            retryable_status_codes: vec![429, 500, 502, 503, 504],
        }
    }
}

/// Outcome of a retried operation.
///
/// NOTE(review): nothing in this file constructs or matches on this enum. The variant
/// descriptions below are inferred from their names; confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub enum RetryResult<T> {
    /// The operation produced a value.
    Success(T),
    /// Presumably: all retries were used up; carries the last value obtained.
    RetriesExhausted(T),
    /// A failure described by a message.
    Error(String),
}

40pub fn is_retryable_status(code: u16, config: &RetryConfig) -> bool {
42 config.retryable_status_codes.contains(&code)
43}
44
45pub fn calculate_delay(attempt: u32, config: &RetryConfig) -> Duration {
47 let delay = config.initial_delay_ms as f64 * config.backoff_multiplier.powi(attempt as i32);
48 let delay = delay.min(config.max_delay_ms as f64);
49 Duration::from_millis(delay as u64)
50}
51
52pub async fn with_retry<T, E, F, Fut>(operation: F, config: RetryConfig) -> Result<T, E>
54where
55 F: Fn() -> Fut,
56 Fut: Future<Output = Result<T, E>>,
57 E: std::fmt::Debug,
58{
59 let mut last_error: Option<E> = None;
60
61 for attempt in 0..=config.max_retries {
62 match operation().await {
63 Ok(result) => {
64 if attempt > 0 {
65 }
67 return Ok(result);
68 }
69 Err(e) => {
70 last_error = Some(e);
71
72 if attempt < config.max_retries {
73 let delay = calculate_delay(attempt, &config);
74 tokio::time::sleep(delay).await;
75 }
76 }
77 }
78 }
79
80 Err(last_error.unwrap())
81}
82
83pub async fn with_retry_after<T, E, F, Fut>(operation: F, config: RetryConfig) -> Result<T, E>
85where
86 F: Fn(Option<u64>) -> Fut,
87 Fut: Future<Output = Result<T, E>>,
88 E: std::fmt::Debug,
89{
90 let mut last_error: Option<E> = None;
91 let mut retry_after: Option<u64> = None;
92
93 for attempt in 0..=config.max_retries {
94 let delay = retry_after.or_else(|| {
95 if attempt > 0 {
96 Some(calculate_delay(attempt - 1, &config).as_millis() as u64)
97 } else {
98 None
99 }
100 });
101
102 match operation(delay).await {
103 Ok(result) => return Ok(result),
104 Err(e) => {
105 last_error = Some(e);
106
107 if attempt < config.max_retries {
108 let delay = calculate_delay(attempt, &config);
111 tokio::time::sleep(delay).await;
112 }
113 }
114 }
115 }
116
117 Err(last_error.unwrap())
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay_ms, 1000);
        assert_eq!(config.max_delay_ms, 10000);
        assert_eq!(config.backoff_multiplier, 2.0);
        assert_eq!(config.retryable_status_codes, vec![429, 500, 502, 503, 504]);
    }

    #[test]
    fn test_is_retryable_status() {
        let config = RetryConfig::default();
        assert!(is_retryable_status(429, &config));
        assert!(is_retryable_status(503, &config));
        assert!(!is_retryable_status(400, &config));
        assert!(!is_retryable_status(200, &config));
    }

    #[test]
    fn test_is_retryable_status_respects_custom_codes() {
        let config = RetryConfig {
            retryable_status_codes: vec![418],
            ..RetryConfig::default()
        };
        assert!(is_retryable_status(418, &config));
        assert!(!is_retryable_status(429, &config));
    }

    #[test]
    fn test_calculate_delay() {
        let config = RetryConfig::default();

        let delay0 = calculate_delay(0, &config);
        assert_eq!(delay0, Duration::from_millis(1000));

        let delay1 = calculate_delay(1, &config);
        assert_eq!(delay1, Duration::from_millis(2000));

        let delay2 = calculate_delay(2, &config);
        assert_eq!(delay2, Duration::from_millis(4000));
    }

    #[test]
    fn test_calculate_delay_is_capped() {
        let config = RetryConfig::default();
        // 1000 * 2^4 = 16000 ms exceeds the 10000 ms cap.
        assert_eq!(calculate_delay(4, &config), Duration::from_millis(10000));
        assert_eq!(calculate_delay(30, &config), Duration::from_millis(10000));
    }
}