Skip to main content

adk_model/
retry.rs

1use adk_core::{AdkError, Result};
2use std::{future::Future, time::Duration};
3
/// Retry policy applied to provider requests.
///
/// The [`Default`] policy has retries enabled, allows three retries, starts
/// with a 250 ms delay and doubles it after every retry up to a 5 s ceiling.
/// The `with_*` methods consume the value and return an adjusted copy, so
/// they can be chained.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Master switch; when `false` the operation is attempted exactly once.
    pub enabled: bool,
    /// Maximum number of retries performed after the initial attempt.
    pub max_retries: u32,
    /// Backoff delay used before the first retry.
    pub initial_delay: Duration,
    /// Upper bound for the backoff delay.
    pub max_delay: Duration,
    /// Factor applied to the backoff delay after each retry; values below
    /// `1.0` are treated as `1.0` when the next delay is computed.
    pub backoff_multiplier: f32,
}

impl Default for RetryConfig {
    /// Retries enabled: 3 retries, 250 ms initial delay, 5 s cap, 2x growth.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(250);
        let max_delay = Duration::from_secs(5);

        RetryConfig {
            enabled: true,
            max_retries: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns the default policy with retries switched off.
    #[must_use]
    pub fn disabled() -> Self {
        let defaults = Self::default();
        Self { enabled: false, ..defaults }
    }

    /// Returns this policy with the retry budget replaced by `max_retries`.
    #[must_use]
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Returns this policy with the first backoff delay replaced by `initial_delay`.
    #[must_use]
    pub fn with_initial_delay(self, initial_delay: Duration) -> Self {
        Self { initial_delay, ..self }
    }

    /// Returns this policy with the backoff ceiling replaced by `max_delay`.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        Self { max_delay, ..self }
    }

    /// Returns this policy with the growth factor replaced by `backoff_multiplier`.
    #[must_use]
    pub fn with_backoff_multiplier(self, backoff_multiplier: f32) -> Self {
        Self { backoff_multiplier, ..self }
    }
}
55
/// Reports whether an HTTP status code is treated as a transient failure.
///
/// The retryable codes are 408, 429, 500, 502, 503, 504 and 529; every other
/// code yields `false`.
#[must_use]
pub fn is_retryable_status_code(status_code: u16) -> bool {
    const RETRYABLE_STATUS_CODES: [u16; 7] = [408, 429, 500, 502, 503, 504, 529];

    RETRYABLE_STATUS_CODES.contains(&status_code)
}
60
/// Heuristically decides whether an error message describes a transient
/// failure that is worth retrying.
///
/// A message is retryable when either:
///
/// * it contains one of the status codes 408, 429, 500, 502, 503, 504 or 529
///   as a complete run of ASCII digits, or
/// * it contains, ignoring ASCII case, one of the phrases listed in
///   `RETRYABLE_PHRASES` below.
///
/// Status codes are compared against whole digit runs rather than raw
/// substrings so that unrelated numbers such as `5000` or `14290` do not
/// turn a permanent failure into a retryable one.
#[must_use]
pub fn is_retryable_error_message(message: &str) -> bool {
    const RETRYABLE_STATUS_CODES: [&str; 7] = ["408", "429", "500", "502", "503", "504", "529"];
    const RETRYABLE_PHRASES: [&str; 9] = [
        "RATE LIMIT",
        "TOO MANY REQUESTS",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "TIMEOUT",
        "TIMED OUT",
        "CONNECTION RESET",
        "OVERLOADED",
    ];

    // Splitting on every non-digit yields the maximal digit runs (plus empty
    // strings, which never equal a status code).
    let mentions_retryable_status = message
        .split(|c: char| !c.is_ascii_digit())
        .any(|digits| RETRYABLE_STATUS_CODES.contains(&digits));
    if mentions_retryable_status {
        return true;
    }

    // The phrases are stored upper-cased, so normalise the message once.
    let normalized = message.to_ascii_uppercase();
    RETRYABLE_PHRASES.iter().any(|phrase| normalized.contains(phrase))
}
81
82#[must_use]
83pub fn is_retryable_model_error(error: &AdkError) -> bool {
84    match error {
85        AdkError::Model(message) => is_retryable_error_message(message),
86        _ => false,
87    }
88}
89
90fn next_retry_delay(current: Duration, retry_config: &RetryConfig) -> Duration {
91    if current >= retry_config.max_delay {
92        return retry_config.max_delay;
93    }
94
95    let multiplier = retry_config.backoff_multiplier.max(1.0) as f64;
96    let scaled = Duration::from_secs_f64(current.as_secs_f64() * multiplier);
97    scaled.min(retry_config.max_delay)
98}
99
/// Server-supplied guidance on how long to wait before retrying.
///
/// A `retry_after` value (typically taken from a `retry-after` response
/// header) replaces the computed exponential backoff delay for the next
/// retry attempt. The default hint carries no delay.
///
/// # Example
///
/// ```rust
/// use adk_model::retry::ServerRetryHint;
/// use std::time::Duration;
///
/// let hint = ServerRetryHint { retry_after: Some(Duration::from_secs(30)) };
/// assert_eq!(hint.retry_after, Some(Duration::from_secs(30)));
/// ```
#[derive(Clone, Debug, Default)]
pub struct ServerRetryHint {
    /// Delay the server asked for before the next attempt, if any.
    pub retry_after: Option<Duration>,
}
119
120pub async fn execute_with_retry<T, Op, Fut, Classify>(
121    retry_config: &RetryConfig,
122    classify_error: Classify,
123    mut operation: Op,
124) -> Result<T>
125where
126    Op: FnMut() -> Fut,
127    Fut: Future<Output = Result<T>>,
128    Classify: Fn(&AdkError) -> bool,
129{
130    execute_with_retry_hint(retry_config, classify_error, None, &mut operation).await
131}
132
133/// Execute an operation with retry logic, optionally using a server-provided
134/// retry hint to override the backoff delay.
135///
136/// When `server_hint` contains a `retry_after` duration, that duration is used
137/// instead of the exponential backoff calculation. This respects server-provided
138/// timing from `retry-after` headers (Requirement 5.1).
139pub async fn execute_with_retry_hint<T, Op, Fut, Classify>(
140    retry_config: &RetryConfig,
141    classify_error: Classify,
142    server_hint: Option<&ServerRetryHint>,
143    operation: &mut Op,
144) -> Result<T>
145where
146    Op: FnMut() -> Fut,
147    Fut: Future<Output = Result<T>>,
148    Classify: Fn(&AdkError) -> bool,
149{
150    if !retry_config.enabled {
151        return operation().await;
152    }
153
154    let mut attempt: u32 = 0;
155    let mut delay = retry_config.initial_delay;
156
157    // If the server provided a retry-after hint, use it for the first retry delay.
158    let server_delay = server_hint.and_then(|h| h.retry_after);
159
160    loop {
161        match operation().await {
162            Ok(value) => return Ok(value),
163            Err(error) if attempt < retry_config.max_retries && classify_error(&error) => {
164                attempt += 1;
165
166                // Requirement 5.1: Use server-provided retry-after when present,
167                // Requirement 5.4: Fall back to exponential backoff otherwise.
168                let effective_delay =
169                    if attempt == 1 { server_delay.unwrap_or(delay) } else { delay };
170
171                adk_telemetry::warn!(
172                    attempt = attempt,
173                    max_retries = retry_config.max_retries,
174                    delay_ms = effective_delay.as_millis(),
175                    error = %error,
176                    "Provider request failed with retryable error; retrying"
177                );
178                tokio::time::sleep(effective_delay).await;
179                delay = next_retry_delay(delay, retry_config);
180            }
181            Err(error) => return Err(error),
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicU32, Ordering},
        },
        time::Instant,
    };

    /// Builds a policy that allows `max_retries` retries and never sleeps
    /// between them, so tests that only count attempts stay fast.
    fn zero_delay_config(max_retries: u32) -> RetryConfig {
        RetryConfig::default()
            .with_max_retries(max_retries)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO)
    }

    /// A retryable failure is retried until the operation succeeds.
    #[tokio::test]
    async fn execute_with_retry_retries_when_classified_retryable() {
        let retry_config = zero_delay_config(2);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::Model("HTTP 429 rate limit".to_string()));
                }
                Ok("ok")
            }
        })
        .await
        .expect("operation should succeed after retries");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A failure the classifier rejects is returned after a single attempt.
    #[tokio::test]
    async fn execute_with_retry_stops_on_non_retryable_error() {
        let retry_config = zero_delay_config(3);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::Model("HTTP 400 bad request".to_string()))
            }
        })
        .await
        .expect_err("operation should fail without retries");

        assert!(matches!(error, AdkError::Model(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// A disabled policy ignores the retry budget entirely.
    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::Model("HTTP 429 too many requests".to_string()))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(matches!(error, AdkError::Model(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn retryable_status_code_matches_transient_errors() {
        assert!(is_retryable_status_code(429));
        assert!(is_retryable_status_code(503));
        assert!(is_retryable_status_code(529));
        assert!(!is_retryable_status_code(400));
        assert!(!is_retryable_status_code(401));
    }

    #[test]
    fn retryable_error_message_matches_529_and_overloaded() {
        assert!(is_retryable_error_message("HTTP 529 overloaded"));
        assert!(is_retryable_error_message("Server OVERLOADED, try again"));
    }

    /// Requirement 5.1: the server-provided delay is used for the first retry.
    /// The backoff is pinned to zero, so a measurable pause before the second
    /// attempt can only come from the hint.
    #[tokio::test]
    async fn execute_with_retry_hint_uses_server_delay() {
        let retry_config = zero_delay_config(2);
        let hint = ServerRetryHint { retry_after: Some(Duration::from_millis(30)) };
        let timestamps: Arc<Mutex<Vec<Instant>>> = Arc::new(Mutex::new(Vec::new()));

        let result = execute_with_retry_hint(
            &retry_config,
            is_retryable_model_error,
            Some(&hint),
            &mut || {
                let timestamps = Arc::clone(&timestamps);
                async move {
                    let now = Instant::now();
                    let mut ts = timestamps.lock().unwrap();
                    let attempt = ts.len();
                    ts.push(now);
                    if attempt < 1 {
                        return Err(AdkError::Model("HTTP 429 rate limit".to_string()));
                    }
                    Ok("ok")
                }
            },
        )
        .await
        .expect("operation should succeed after retry with hint");

        assert_eq!(result, "ok");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 2); // initial attempt + 1 retry

        // Same ~10% scheduling tolerance as the backoff test below.
        let gap = ts[1].duration_since(ts[0]);
        assert!(
            gap >= Duration::from_millis(27),
            "hinted gap {gap:?} should be >= ~30ms even though backoff is zero"
        );
    }

    /// Requirement 5.2: HTTP 529 (overloaded) is retried end-to-end.
    #[tokio::test]
    async fn status_529_is_retried_end_to_end() {
        let retry_config = zero_delay_config(2);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt == 0 {
                    return Err(AdkError::Model("HTTP 529 overloaded".to_string()));
                }
                Ok("recovered")
            }
        })
        .await
        .expect("529 should be retried and succeed on second attempt");

        assert_eq!(result, "recovered");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Requirement 5.4: Exponential backoff when retry-after is absent.
    /// With initial_delay=20ms and multiplier=2.0, the delays should be
    /// ~20ms (retry 1), ~40ms (retry 2) and ~80ms (retry 3). We verify each
    /// gap is at least the expected delay, minus a small tolerance.
    #[tokio::test]
    async fn exponential_backoff_without_retry_after() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(20))
            .with_max_delay(Duration::from_millis(200))
            .with_backoff_multiplier(2.0);

        let timestamps: Arc<Mutex<Vec<Instant>>> = Arc::new(Mutex::new(Vec::new()));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let timestamps = Arc::clone(&timestamps);
            async move {
                let now = Instant::now();
                let mut ts = timestamps.lock().unwrap();
                let attempt = ts.len();
                ts.push(now);
                if attempt < 3 {
                    return Err(AdkError::Model("HTTP 429 rate limit".to_string()));
                }
                Ok("done")
            }
        })
        .await
        .expect("should succeed after backoff retries");

        assert_eq!(result, "done");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 4); // initial + 3 retries

        // Gap between attempt 0 and 1 should be >= initial_delay (20ms).
        let gap1 = ts[1].duration_since(ts[0]);
        assert!(gap1 >= Duration::from_millis(18), "first backoff gap {gap1:?} should be >= ~20ms");

        // Gap between attempt 1 and 2 should be >= 2 * initial_delay (40ms).
        let gap2 = ts[2].duration_since(ts[1]);
        assert!(
            gap2 >= Duration::from_millis(36),
            "second backoff gap {gap2:?} should be >= ~40ms"
        );

        // Gap between attempt 2 and 3 should be >= 4 * initial_delay (80ms).
        let gap3 = ts[3].duration_since(ts[2]);
        assert!(gap3 >= Duration::from_millis(72), "third backoff gap {gap3:?} should be >= ~80ms");

        // The delay must never shrink from one retry to the next. Exact
        // ratios are not asserted because scheduling jitter makes them flaky.
        assert!(gap2 >= gap1, "backoff should increase: gap2={gap2:?} should be >= gap1={gap1:?}");
    }
}