1use adk_core::{AdkError, Result};
2use std::{future::Future, time::Duration};
3
/// Settings that control how failed provider calls are retried.
///
/// Start from [`RetryConfig::default`] or [`RetryConfig::disabled`] and adjust the
/// individual knobs with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Master switch; when `false` the operation is attempted exactly once.
    pub enabled: bool,
    /// Maximum number of retries made after the first attempt.
    pub max_retries: u32,
    /// Starting value of the backoff delay.
    pub initial_delay: Duration,
    /// Ceiling applied as the backoff delay grows.
    pub max_delay: Duration,
    /// Factor the backoff delay is multiplied by after each retry. The backoff
    /// computation treats values below `1.0` (and NaN) as `1.0`.
    pub backoff_multiplier: f32,
}

impl Default for RetryConfig {
    /// Retries enabled: up to 3 retries, starting at 250 ms and doubling up to 5 s.
    fn default() -> Self {
        Self {
            enabled: true,
            max_retries: 3,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_millis(5_000),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// The default configuration with retrying switched off.
    #[must_use]
    pub fn disabled() -> Self {
        Self { enabled: false, ..Default::default() }
    }

    /// Sets how many retries are allowed after the first attempt.
    #[must_use]
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Sets the starting backoff delay.
    #[must_use]
    pub fn with_initial_delay(self, initial_delay: Duration) -> Self {
        Self { initial_delay, ..self }
    }

    /// Sets the ceiling for the growing backoff delay.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        Self { max_delay, ..self }
    }

    /// Sets the per-retry growth factor of the backoff delay.
    #[must_use]
    pub fn with_backoff_multiplier(self, backoff_multiplier: f32) -> Self {
        Self { backoff_multiplier, ..self }
    }
}
55
/// Reports whether an HTTP status code signals a transient, retry-worthy failure.
///
/// Covers request timeout (408), rate limiting (429), the common server/gateway
/// errors (500, 502, 503, 504) and 529, which this module treats as provider overload.
#[must_use]
pub fn is_retryable_status_code(status_code: u16) -> bool {
    const RETRYABLE: [u16; 7] = [408, 429, 500, 502, 503, 504, 529];
    RETRYABLE.contains(&status_code)
}
60
/// Returns `true` when a free-form error message looks like a transient failure.
///
/// Two kinds of evidence are accepted:
/// * a retryable HTTP status code (`408`, `429`, `500`, `502`, `503`, `504`, `529`)
///   appearing as a standalone number, i.e. not embedded in a longer run of digits.
///   `"HTTP 429"` and `"status=503,"` match; `"15000 tokens"` and `"id 45029"` do not;
/// * a known transient-failure phrase such as `"rate limit"` or `"timed out"`, matched
///   ASCII-case-insensitively.
#[must_use]
pub fn is_retryable_error_message(message: &str) -> bool {
    const RETRYABLE_CODES: [&str; 7] = ["408", "429", "500", "502", "503", "504", "529"];
    const RETRYABLE_PHRASES: &[&str] = &[
        "RATE LIMIT",
        "TOO MANY REQUESTS",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "TIMEOUT",
        "TIMED OUT",
        "CONNECTION RESET",
        "OVERLOADED",
    ];

    // Splitting on every non-digit yields the maximal runs of ASCII digits, so a code
    // only matches when it is a whole number on its own. Plain substring search would
    // also fire on unrelated numbers such as token counts or request ids.
    let has_retryable_code = message
        .split(|c: char| !c.is_ascii_digit())
        .any(|digits| RETRYABLE_CODES.contains(&digits));
    if has_retryable_code {
        return true;
    }

    let normalized = message.to_ascii_uppercase();
    RETRYABLE_PHRASES.iter().copied().any(|phrase| normalized.contains(phrase))
}
81
82#[must_use]
83pub fn is_retryable_model_error(error: &AdkError) -> bool {
84 if error.retry.should_retry {
86 return true;
87 }
88 if error.code.ends_with(".legacy") && error.is_model() {
91 return is_retryable_error_message(&error.message);
92 }
93 false
94}
95
96fn next_retry_delay(current: Duration, retry_config: &RetryConfig) -> Duration {
97 if current >= retry_config.max_delay {
98 return retry_config.max_delay;
99 }
100
101 let multiplier = retry_config.backoff_multiplier.max(1.0) as f64;
102 let scaled = Duration::from_secs_f64(current.as_secs_f64() * multiplier);
103 scaled.min(retry_config.max_delay)
104}
105
/// Retry guidance supplied by the server separately from the error value.
///
/// When `retry_after` is set, `execute_with_retry_hint` uses it as the wait before the
/// first retry, unless the error itself carries a retry-after delay. Later retries
/// fall back to the backoff schedule.
#[derive(Clone, Debug, Default)]
pub struct ServerRetryHint {
    /// Delay the server asked the client to wait before retrying, if any.
    pub retry_after: Option<Duration>,
}
125
126pub async fn execute_with_retry<T, Op, Fut, Classify>(
127 retry_config: &RetryConfig,
128 classify_error: Classify,
129 mut operation: Op,
130) -> Result<T>
131where
132 Op: FnMut() -> Fut,
133 Fut: Future<Output = Result<T>>,
134 Classify: Fn(&AdkError) -> bool,
135{
136 execute_with_retry_hint(retry_config, classify_error, None, &mut operation).await
137}
138
139pub async fn execute_with_retry_hint<T, Op, Fut, Classify>(
146 retry_config: &RetryConfig,
147 classify_error: Classify,
148 server_hint: Option<&ServerRetryHint>,
149 operation: &mut Op,
150) -> Result<T>
151where
152 Op: FnMut() -> Fut,
153 Fut: Future<Output = Result<T>>,
154 Classify: Fn(&AdkError) -> bool,
155{
156 if !retry_config.enabled {
157 return operation().await;
158 }
159
160 let mut attempt: u32 = 0;
161 let mut delay = retry_config.initial_delay;
162
163 let server_delay = server_hint.and_then(|h| h.retry_after);
165
166 loop {
167 match operation().await {
168 Ok(value) => return Ok(value),
169 Err(error) if attempt < retry_config.max_retries && classify_error(&error) => {
170 attempt += 1;
171
172 let error_retry_after = error.retry.retry_after();
174 let effective_delay = if let Some(d) = error_retry_after {
175 d
176 } else if attempt == 1 {
177 server_delay.unwrap_or(delay)
178 } else {
179 delay
180 };
181
182 adk_telemetry::warn!(
183 attempt = attempt,
184 max_retries = retry_config.max_retries,
185 delay_ms = effective_delay.as_millis(),
186 error = %error,
187 "Provider request failed with retryable error; retrying"
188 );
189 tokio::time::sleep(effective_delay).await;
190 delay = next_retry_delay(delay, retry_config);
191 }
192 Err(error) => return Err(error),
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    #[tokio::test]
    async fn execute_with_retry_retries_when_classified_retryable() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("ok")
            }
        })
        .await
        .expect("operation should succeed after retries");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_stops_on_non_retryable_error() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 400 bad request"))
            }
        })
        .await
        .expect_err("operation should fail without retries");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 429 too many requests"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_gives_up_after_max_retries() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 503 service unavailable"))
            }
        })
        .await
        .expect_err("a persistently failing operation should surface its error");

        assert!(error.is_model());
        // One initial attempt plus `max_retries` retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn retryable_status_code_matches_transient_errors() {
        assert!(is_retryable_status_code(429));
        assert!(is_retryable_status_code(503));
        assert!(is_retryable_status_code(529));
        assert!(!is_retryable_status_code(400));
        assert!(!is_retryable_status_code(401));
    }

    #[test]
    fn retryable_error_message_matches_529_and_overloaded() {
        assert!(is_retryable_error_message("HTTP 529 overloaded"));
        assert!(is_retryable_error_message("Server OVERLOADED, try again"));
    }

    #[test]
    fn next_retry_delay_grows_and_is_capped() {
        let retry_config = RetryConfig::default()
            .with_max_delay(Duration::from_millis(250))
            .with_backoff_multiplier(2.0);

        assert_eq!(
            next_retry_delay(Duration::from_millis(100), &retry_config),
            Duration::from_millis(200)
        );
        assert_eq!(
            next_retry_delay(Duration::from_millis(200), &retry_config),
            Duration::from_millis(250)
        );
        assert_eq!(
            next_retry_delay(Duration::from_millis(300), &retry_config),
            Duration::from_millis(250)
        );

        // Multipliers below 1.0 are clamped, so the delay never shrinks.
        let shrinking = retry_config.with_backoff_multiplier(0.5);
        assert_eq!(
            next_retry_delay(Duration::from_millis(100), &shrinking),
            Duration::from_millis(100)
        );
    }

    #[tokio::test]
    async fn execute_with_retry_hint_uses_server_delay() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));
        let hint = ServerRetryHint { retry_after: Some(Duration::ZERO) };

        let result = execute_with_retry_hint(
            &retry_config,
            is_retryable_model_error,
            Some(&hint),
            &mut || {
                let attempts = Arc::clone(&attempts);
                async move {
                    let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                    if attempt < 1 {
                        return Err(AdkError::model("HTTP 429 rate limit"));
                    }
                    Ok("ok")
                }
            },
        )
        .await
        .expect("operation should succeed after retry with hint");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn status_529_is_retried_end_to_end() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt == 0 {
                    return Err(AdkError::model("HTTP 529 overloaded"));
                }
                Ok("recovered")
            }
        })
        .await
        .expect("529 should be retried and succeed on second attempt");

        assert_eq!(result, "recovered");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn exponential_backoff_without_retry_after() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(20))
            .with_max_delay(Duration::from_millis(200))
            .with_backoff_multiplier(2.0);

        let timestamps: Arc<std::sync::Mutex<Vec<std::time::Instant>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let timestamps = Arc::clone(&timestamps);
            async move {
                // Record the call time and release the lock before deciding the outcome.
                let attempt = {
                    let mut ts = timestamps.lock().unwrap();
                    ts.push(std::time::Instant::now());
                    ts.len() - 1
                };
                if attempt < 3 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("done")
            }
        })
        .await
        .expect("should succeed after backoff retries");

        assert_eq!(result, "done");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 4);

        let gap1 = ts[1].duration_since(ts[0]);
        assert!(gap1 >= Duration::from_millis(18), "first backoff gap {gap1:?} should be >= ~20ms");

        let gap2 = ts[2].duration_since(ts[1]);
        assert!(
            gap2 >= Duration::from_millis(36),
            "second backoff gap {gap2:?} should be >= ~40ms"
        );

        assert!(gap2 >= gap1, "backoff should increase: gap2={gap2:?} should be >= gap1={gap1:?}");
    }
}