1use adk_core::{AdkError, Result};
2use std::{future::Future, time::Duration};
3
/// Settings that control how a failed provider request is retried.
///
/// The retry executors in this file read these fields to decide whether to
/// retry at all, how many times, and how long to sleep between attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// When `false`, the operation is run exactly once and its result is returned as-is.
    pub enabled: bool,
    /// Maximum number of retries performed after the initial attempt.
    pub max_retries: u32,
    /// Delay used before the first retry when no server hint overrides it.
    pub initial_delay: Duration,
    /// Ceiling applied whenever the delay is grown between retries.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry; values below `1.0` are treated as `1.0`.
    pub backoff_multiplier: f32,
}
12
13impl Default for RetryConfig {
14 fn default() -> Self {
15 Self {
16 enabled: true,
17 max_retries: 3,
18 initial_delay: Duration::from_millis(250),
19 max_delay: Duration::from_secs(5),
20 backoff_multiplier: 2.0,
21 }
22 }
23}
24
25impl RetryConfig {
26 #[must_use]
27 pub fn disabled() -> Self {
28 Self { enabled: false, ..Self::default() }
29 }
30
31 #[must_use]
32 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
33 self.max_retries = max_retries;
34 self
35 }
36
37 #[must_use]
38 pub fn with_initial_delay(mut self, initial_delay: Duration) -> Self {
39 self.initial_delay = initial_delay;
40 self
41 }
42
43 #[must_use]
44 pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
45 self.max_delay = max_delay;
46 self
47 }
48
49 #[must_use]
50 pub fn with_backoff_multiplier(mut self, backoff_multiplier: f32) -> Self {
51 self.backoff_multiplier = backoff_multiplier;
52 self
53 }
54}
55
/// Reports whether an HTTP status code is one this module treats as transient.
///
/// The retryable set is 408, 429, 500, 502, 503, 504 and 529; every other
/// code returns `false`.
#[must_use]
pub fn is_retryable_status_code(status_code: u16) -> bool {
    const RETRYABLE_STATUS_CODES: [u16; 7] = [408, 429, 500, 502, 503, 504, 529];

    RETRYABLE_STATUS_CODES.contains(&status_code)
}
60
/// Heuristically decides whether an error message describes a transient failure.
///
/// A message is retryable when either:
/// * it contains one of the retryable HTTP status codes (408, 429, 500, 502,
///   503, 504, 529) as a *complete* run of digits, or
/// * it contains, case-insensitively (ASCII), one of the known transient
///   phrases such as `rate limit`, `timed out` or `overloaded`.
///
/// Status codes are matched against whole digit runs rather than raw
/// substrings so that unrelated numbers — `15000 tokens`, a request id like
/// `4290017` — do not turn a permanent failure into a retry loop. Codes glued
/// to letters (`HTTP429`, `status=503`) still match.
#[must_use]
pub fn is_retryable_error_message(message: &str) -> bool {
    const RETRYABLE_STATUS_CODES: [&str; 7] = ["408", "429", "500", "502", "503", "504", "529"];
    const RETRYABLE_PHRASES: [&str; 9] = [
        "RATE LIMIT",
        "TOO MANY REQUESTS",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "TIMEOUT",
        "TIMED OUT",
        "CONNECTION RESET",
        "OVERLOADED",
    ];

    // Splitting on every non-digit yields the maximal digit runs (plus empty
    // pieces, which can never equal a three-digit code).
    let mentions_retryable_status = message
        .split(|c: char| !c.is_ascii_digit())
        .any(|digit_run| RETRYABLE_STATUS_CODES.contains(&digit_run));
    if mentions_retryable_status {
        return true;
    }

    let normalized = message.to_ascii_uppercase();
    RETRYABLE_PHRASES.iter().any(|phrase| normalized.contains(*phrase))
}
81
82#[must_use]
83pub fn is_retryable_model_error(error: &AdkError) -> bool {
84 match error {
85 AdkError::Model(message) => is_retryable_error_message(message),
86 _ => false,
87 }
88}
89
90fn next_retry_delay(current: Duration, retry_config: &RetryConfig) -> Duration {
91 if current >= retry_config.max_delay {
92 return retry_config.max_delay;
93 }
94
95 let multiplier = retry_config.backoff_multiplier.max(1.0) as f64;
96 let scaled = Duration::from_secs_f64(current.as_secs_f64() * multiplier);
97 scaled.min(retry_config.max_delay)
98}
99
/// Optional retry guidance supplied alongside a failed request.
///
/// The default value carries no guidance.
#[derive(Clone, Debug, Default)]
pub struct ServerRetryHint {
    /// Delay to wait before the first retry, when one was provided.
    /// Presumably taken from a server response such as a `Retry-After`
    /// header — confirm against the callers that build this hint.
    pub retry_after: Option<Duration>,
}
119
120pub async fn execute_with_retry<T, Op, Fut, Classify>(
121 retry_config: &RetryConfig,
122 classify_error: Classify,
123 mut operation: Op,
124) -> Result<T>
125where
126 Op: FnMut() -> Fut,
127 Fut: Future<Output = Result<T>>,
128 Classify: Fn(&AdkError) -> bool,
129{
130 execute_with_retry_hint(retry_config, classify_error, None, &mut operation).await
131}
132
133pub async fn execute_with_retry_hint<T, Op, Fut, Classify>(
140 retry_config: &RetryConfig,
141 classify_error: Classify,
142 server_hint: Option<&ServerRetryHint>,
143 operation: &mut Op,
144) -> Result<T>
145where
146 Op: FnMut() -> Fut,
147 Fut: Future<Output = Result<T>>,
148 Classify: Fn(&AdkError) -> bool,
149{
150 if !retry_config.enabled {
151 return operation().await;
152 }
153
154 let mut attempt: u32 = 0;
155 let mut delay = retry_config.initial_delay;
156
157 let server_delay = server_hint.and_then(|h| h.retry_after);
159
160 loop {
161 match operation().await {
162 Ok(value) => return Ok(value),
163 Err(error) if attempt < retry_config.max_retries && classify_error(&error) => {
164 attempt += 1;
165
166 let effective_delay =
169 if attempt == 1 { server_delay.unwrap_or(delay) } else { delay };
170
171 adk_telemetry::warn!(
172 attempt = attempt,
173 max_retries = retry_config.max_retries,
174 delay_ms = effective_delay.as_millis(),
175 error = %error,
176 "Provider request failed with retryable error; retrying"
177 );
178 tokio::time::sleep(effective_delay).await;
179 delay = next_retry_delay(delay, retry_config);
180 }
181 Err(error) => return Err(error),
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    /// Two retryable failures followed by a success use exactly three attempts.
    #[tokio::test]
    async fn execute_with_retry_retries_when_classified_retryable() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::Model("HTTP 429 rate limit".to_string()));
                }
                Ok("ok")
            }
        })
        .await
        .expect("operation should succeed after retries");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A non-retryable error is returned after a single attempt.
    #[tokio::test]
    async fn execute_with_retry_stops_on_non_retryable_error() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::Model("HTTP 400 bad request".to_string()))
            }
        })
        .await
        .expect_err("operation should fail without retries");

        assert!(matches!(error, AdkError::Model(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// With retries disabled, even a retryable error is returned immediately.
    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::Model("HTTP 429 too many requests".to_string()))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(matches!(error, AdkError::Model(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn retryable_status_code_matches_transient_errors() {
        assert!(is_retryable_status_code(429));
        assert!(is_retryable_status_code(503));
        assert!(is_retryable_status_code(529));
        assert!(!is_retryable_status_code(400));
        assert!(!is_retryable_status_code(401));
    }

    #[test]
    fn retryable_error_message_matches_529_and_overloaded() {
        assert!(is_retryable_error_message("HTTP 529 overloaded"));
        assert!(is_retryable_error_message("Server OVERLOADED, try again"));
    }

    /// A server hint is accepted and the operation succeeds on the retry.
    #[tokio::test]
    async fn execute_with_retry_hint_uses_server_delay() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));
        let hint = ServerRetryHint { retry_after: Some(Duration::ZERO) };

        let result = execute_with_retry_hint(
            &retry_config,
            is_retryable_model_error,
            Some(&hint),
            &mut || {
                let attempts = Arc::clone(&attempts);
                async move {
                    let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                    if attempt < 1 {
                        return Err(AdkError::Model("HTTP 429 rate limit".to_string()));
                    }
                    Ok("ok")
                }
            },
        )
        .await
        .expect("operation should succeed after retry with hint");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// A 529 failure is classified as retryable and recovered on the second attempt.
    #[tokio::test]
    async fn status_529_is_retried_end_to_end() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt == 0 {
                    return Err(AdkError::Model("HTTP 529 overloaded".to_string()));
                }
                Ok("recovered")
            }
        })
        .await
        .expect("529 should be retried and succeed on second attempt");

        assert_eq!(result, "recovered");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Without a server hint, the gaps between attempts follow the configured
    /// backoff (20 ms, then 40 ms). Only lower bounds are asserted, with a
    /// small tolerance, because sleeps may overshoot.
    #[tokio::test]
    async fn exponential_backoff_without_retry_after() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(20))
            .with_max_delay(Duration::from_millis(200))
            .with_backoff_multiplier(2.0);

        let timestamps: Arc<std::sync::Mutex<Vec<std::time::Instant>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let timestamps = Arc::clone(&timestamps);
            async move {
                let now = std::time::Instant::now();
                let mut ts = timestamps.lock().unwrap();
                let attempt = ts.len();
                ts.push(now);
                if attempt < 3 {
                    return Err(AdkError::Model("HTTP 429 rate limit".to_string()));
                }
                Ok("done")
            }
        })
        .await
        .expect("should succeed after backoff retries");

        assert_eq!(result, "done");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 4);

        let gap1 = ts[1].duration_since(ts[0]);
        assert!(gap1 >= Duration::from_millis(18), "first backoff gap {gap1:?} should be >= ~20ms");

        let gap2 = ts[2].duration_since(ts[1]);
        assert!(
            gap2 >= Duration::from_millis(36),
            "second backoff gap {gap2:?} should be >= ~40ms"
        );

        assert!(gap2 >= gap1, "backoff should increase: gap2={gap2:?} should be >= gap1={gap1:?}");
    }
}