1use adk_core::{AdkError, Result};
2use std::{future::Future, time::Duration};
3
/// Retry policy used by [`execute_with_retry`] and [`execute_with_retry_hint`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// When `false`, the operation is attempted exactly once and its result is
    /// returned unchanged (no classification, no sleeping).
    pub enabled: bool,
    /// Maximum number of retries after the first attempt, so an operation is
    /// attempted at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Backoff delay used before the first retry when neither the error nor a
    /// server hint supplies a retry-after value.
    pub initial_delay: Duration,
    /// Upper bound for the computed exponential backoff. Retry-after values taken
    /// from the error itself or from a `ServerRetryHint` are not clamped to it.
    pub max_delay: Duration,
    /// Factor the backoff delay is multiplied by after each retry; values below
    /// `1.0` are treated as `1.0` so the delay never shrinks.
    pub backoff_multiplier: f32,
}
18
19impl Default for RetryConfig {
20 fn default() -> Self {
21 Self {
22 enabled: true,
23 max_retries: 3,
24 initial_delay: Duration::from_millis(250),
25 max_delay: Duration::from_secs(5),
26 backoff_multiplier: 2.0,
27 }
28 }
29}
30
31impl RetryConfig {
32 #[must_use]
34 pub fn disabled() -> Self {
35 Self { enabled: false, ..Self::default() }
36 }
37
38 #[must_use]
40 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
41 self.max_retries = max_retries;
42 self
43 }
44
45 #[must_use]
47 pub fn with_initial_delay(mut self, initial_delay: Duration) -> Self {
48 self.initial_delay = initial_delay;
49 self
50 }
51
52 #[must_use]
54 pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
55 self.max_delay = max_delay;
56 self
57 }
58
59 #[must_use]
61 pub fn with_backoff_multiplier(mut self, backoff_multiplier: f32) -> Self {
62 self.backoff_multiplier = backoff_multiplier;
63 self
64 }
65}
66
/// Returns `true` for HTTP status codes that signal a transient failure:
/// 408 (request timeout), 429 (too many requests), 500/502/503/504 (server
/// errors) and 529 (overloaded).
#[must_use]
pub fn is_retryable_status_code(status_code: u16) -> bool {
    const TRANSIENT_CODES: [u16; 7] = [408, 429, 500, 502, 503, 504, 529];
    TRANSIENT_CODES.contains(&status_code)
}
72
/// Heuristically decides whether an error message describes a transient failure.
///
/// A message is retryable when either:
/// * it mentions one of the status codes 408, 429, 500, 502, 503, 504 or 529 as a
///   standalone number (a maximal run of ASCII digits equal to the code), or
/// * it contains (case-insensitively) one of the phrases "rate limit",
///   "too many requests", "resource_exhausted", "unavailable",
///   "deadline_exceeded", "timeout", "timed out", "connection reset" or
///   "overloaded".
///
/// Codes are matched on digit boundaries so that unrelated numbers such as
/// token counts ("150023 tokens") or ids ("5004") are not mistaken for HTTP
/// status codes.
#[must_use]
pub fn is_retryable_error_message(message: &str) -> bool {
    const STATUS_CODES: [&str; 7] = ["408", "429", "500", "502", "503", "504", "529"];
    const PHRASES: [&str; 9] = [
        "RATE LIMIT",
        "TOO MANY REQUESTS",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "TIMEOUT",
        "TIMED OUT",
        "CONNECTION RESET",
        "OVERLOADED",
    ];

    // Splitting on every non-digit yields the maximal digit runs; a code only
    // counts when an entire run equals it (letters/punctuation may touch it).
    let mentions_status_code = message
        .split(|c: char| !c.is_ascii_digit())
        .any(|run| STATUS_CODES.contains(&run));
    if mentions_status_code {
        return true;
    }

    let normalized = message.to_ascii_uppercase();
    PHRASES.iter().any(|&phrase| normalized.contains(phrase))
}
94
95#[must_use]
97pub fn is_retryable_model_error(error: &AdkError) -> bool {
98 if error.retry.should_retry {
100 return true;
101 }
102 if error.code.ends_with(".legacy") && error.is_model() {
105 return is_retryable_error_message(&error.message);
106 }
107 false
108}
109
110fn next_retry_delay(current: Duration, retry_config: &RetryConfig) -> Duration {
111 if current >= retry_config.max_delay {
112 return retry_config.max_delay;
113 }
114
115 let multiplier = retry_config.backoff_multiplier.max(1.0) as f64;
116 let scaled = Duration::from_secs_f64(current.as_secs_f64() * multiplier);
117 scaled.min(retry_config.max_delay)
118}
119
/// Caller-supplied retry guidance for [`execute_with_retry_hint`].
///
/// `retry_after` is used only as the delay before the *first* retry, and only
/// when the failing error does not carry its own retry-after value. Later
/// retries fall back to the configured exponential backoff. The value is not
/// clamped to `RetryConfig::max_delay`.
#[derive(Debug, Clone, Default)]
pub struct ServerRetryHint {
    /// Delay to wait before the first retry; `None` means no hint.
    pub retry_after: Option<Duration>,
}
139
140pub async fn execute_with_retry<T, Op, Fut, Classify>(
142 retry_config: &RetryConfig,
143 classify_error: Classify,
144 mut operation: Op,
145) -> Result<T>
146where
147 Op: FnMut() -> Fut,
148 Fut: Future<Output = Result<T>>,
149 Classify: Fn(&AdkError) -> bool,
150{
151 execute_with_retry_hint(retry_config, classify_error, None, &mut operation).await
152}
153
154pub async fn execute_with_retry_hint<T, Op, Fut, Classify>(
161 retry_config: &RetryConfig,
162 classify_error: Classify,
163 server_hint: Option<&ServerRetryHint>,
164 operation: &mut Op,
165) -> Result<T>
166where
167 Op: FnMut() -> Fut,
168 Fut: Future<Output = Result<T>>,
169 Classify: Fn(&AdkError) -> bool,
170{
171 if !retry_config.enabled {
172 return operation().await;
173 }
174
175 let mut attempt: u32 = 0;
176 let mut delay = retry_config.initial_delay;
177
178 let server_delay = server_hint.and_then(|h| h.retry_after);
180
181 loop {
182 match operation().await {
183 Ok(value) => return Ok(value),
184 Err(error) if attempt < retry_config.max_retries && classify_error(&error) => {
185 attempt += 1;
186
187 let error_retry_after = error.retry.retry_after();
189 let effective_delay = if let Some(d) = error_retry_after {
190 d
191 } else if attempt == 1 {
192 server_delay.unwrap_or(delay)
193 } else {
194 delay
195 };
196
197 adk_telemetry::warn!(
198 attempt = attempt,
199 max_retries = retry_config.max_retries,
200 delay_ms = effective_delay.as_millis(),
201 error = %error,
202 "Provider request failed with retryable error; retrying"
203 );
204 tokio::time::sleep(effective_delay).await;
205 delay = next_retry_delay(delay, retry_config);
206 }
207 Err(error) => return Err(error),
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    #[tokio::test]
    async fn execute_with_retry_retries_when_classified_retryable() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("ok")
            }
        })
        .await
        .expect("operation should succeed after retries");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_stops_on_non_retryable_error() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 400 bad request"))
            }
        })
        .await
        .expect_err("operation should fail without retries");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 429 too many requests"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn retryable_status_code_matches_transient_errors() {
        assert!(is_retryable_status_code(429));
        assert!(is_retryable_status_code(503));
        assert!(is_retryable_status_code(529));
        assert!(!is_retryable_status_code(400));
        assert!(!is_retryable_status_code(401));
    }

    #[test]
    fn retryable_error_message_matches_529_and_overloaded() {
        assert!(is_retryable_error_message("HTTP 529 overloaded"));
        assert!(is_retryable_error_message("Server OVERLOADED, try again"));
    }

    #[tokio::test]
    async fn execute_with_retry_hint_uses_server_delay() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));
        let hint = ServerRetryHint { retry_after: Some(Duration::ZERO) };

        let result = execute_with_retry_hint(
            &retry_config,
            is_retryable_model_error,
            Some(&hint),
            &mut || {
                let attempts = Arc::clone(&attempts);
                async move {
                    let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                    if attempt < 1 {
                        return Err(AdkError::model("HTTP 429 rate limit"));
                    }
                    Ok("ok")
                }
            },
        )
        .await
        .expect("operation should succeed after retry with hint");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn status_529_is_retried_end_to_end() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt == 0 {
                    return Err(AdkError::model("HTTP 529 overloaded"));
                }
                Ok("recovered")
            }
        })
        .await
        .expect("529 should be retried and succeed on second attempt");

        assert_eq!(result, "recovered");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Uses real sleeps: with a 20 ms initial delay and multiplier 2.0, the gaps
    /// between attempts should be roughly 20 ms, 40 ms and 80 ms. Only the first
    /// two gaps are asserted, with a small tolerance for timer granularity.
    #[tokio::test]
    async fn exponential_backoff_without_retry_after() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(20))
            .with_max_delay(Duration::from_millis(200))
            .with_backoff_multiplier(2.0);

        let timestamps: Arc<std::sync::Mutex<Vec<std::time::Instant>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let timestamps = Arc::clone(&timestamps);
            async move {
                let now = std::time::Instant::now();
                let mut ts = timestamps.lock().unwrap();
                let attempt = ts.len();
                ts.push(now);
                if attempt < 3 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("done")
            }
        })
        .await
        .expect("should succeed after backoff retries");

        assert_eq!(result, "done");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 4);

        let gap1 = ts[1].duration_since(ts[0]);
        assert!(gap1 >= Duration::from_millis(18), "first backoff gap {gap1:?} should be >= ~20ms");

        let gap2 = ts[2].duration_since(ts[1]);
        assert!(
            gap2 >= Duration::from_millis(36),
            "second backoff gap {gap2:?} should be >= ~40ms"
        );

        assert!(gap2 >= gap1, "backoff should increase: gap2={gap2:?} should be >= gap1={gap1:?}");
    }
}