Skip to main content

adk_model/
retry.rs

1use adk_core::{AdkError, Result};
2use std::{future::Future, time::Duration};
3
/// Configuration for retrying provider requests that fail with transient errors.
///
/// The backoff delay starts at `initial_delay` and is multiplied by
/// `backoff_multiplier` after every retry, never growing beyond `max_delay`.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryConfig {
    /// Whether retries are performed at all; when `false` the operation runs once.
    pub enabled: bool,
    /// Maximum number of retries after the first attempt, so at most
    /// `max_retries + 1` attempts are made in total.
    pub max_retries: u32,
    /// Backoff delay used before the first retry (unless a retry-after hint
    /// from the error or the server takes precedence).
    pub initial_delay: Duration,
    /// Upper bound for the exponential backoff delay. Delays supplied by a
    /// retry-after hint are used as-is and are not capped by this value.
    pub max_delay: Duration,
    /// Factor applied to the backoff delay after each retry; values below
    /// `1.0` (and `NaN`) are treated as `1.0`.
    pub backoff_multiplier: f32,
}
12
13impl Default for RetryConfig {
14    fn default() -> Self {
15        Self {
16            enabled: true,
17            max_retries: 3,
18            initial_delay: Duration::from_millis(250),
19            max_delay: Duration::from_secs(5),
20            backoff_multiplier: 2.0,
21        }
22    }
23}
24
25impl RetryConfig {
26    #[must_use]
27    pub fn disabled() -> Self {
28        Self { enabled: false, ..Self::default() }
29    }
30
31    #[must_use]
32    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
33        self.max_retries = max_retries;
34        self
35    }
36
37    #[must_use]
38    pub fn with_initial_delay(mut self, initial_delay: Duration) -> Self {
39        self.initial_delay = initial_delay;
40        self
41    }
42
43    #[must_use]
44    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
45        self.max_delay = max_delay;
46        self
47    }
48
49    #[must_use]
50    pub fn with_backoff_multiplier(mut self, backoff_multiplier: f32) -> Self {
51        self.backoff_multiplier = backoff_multiplier;
52        self
53    }
54}
55
/// Reports whether an HTTP status code signals a transient failure worth
/// retrying: request timeout (408), rate limiting (429), the 500/502/503/504
/// server and gateway errors, and 529 (overloaded).
#[must_use]
pub fn is_retryable_status_code(status_code: u16) -> bool {
    const RETRYABLE: [u16; 7] = [408, 429, 500, 502, 503, 504, 529];
    RETRYABLE.contains(&status_code)
}
60
/// Heuristically decides whether a free-form error message describes a
/// transient failure worth retrying.
///
/// Matching is ASCII case-insensitive. A message is retryable when it contains
/// either:
/// - one of the transient HTTP status codes `408`, `429`, `500`, `502`, `503`,
///   `504` or `529` as a *standalone* number — not directly preceded or
///   followed by another digit — so `"HTTP 503"` and `"status=429"` match,
///   but `"15000 tokens"` or `"request 45029"` do not; or
/// - one of the phrases `RATE LIMIT`, `TOO MANY REQUESTS`,
///   `RESOURCE_EXHAUSTED`, `UNAVAILABLE`, `DEADLINE_EXCEEDED`, `TIMEOUT`,
///   `TIMED OUT`, `CONNECTION RESET` or `OVERLOADED`.
#[must_use]
pub fn is_retryable_error_message(message: &str) -> bool {
    const STATUS_CODES: [&str; 7] = ["429", "408", "500", "502", "503", "504", "529"];
    const PHRASES: [&str; 9] = [
        "RATE LIMIT",
        "TOO MANY REQUESTS",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "TIMEOUT",
        "TIMED OUT",
        "CONNECTION RESET",
        "OVERLOADED",
    ];

    let normalized = message.to_ascii_uppercase();
    STATUS_CODES.iter().any(|code| contains_standalone_number(&normalized, code))
        || PHRASES.iter().any(|phrase| normalized.contains(phrase))
}

/// Returns `true` if `needle` (a run of ASCII digits) occurs in `haystack`
/// with no ASCII digit immediately before or after it.
fn contains_standalone_number(haystack: &str, needle: &str) -> bool {
    let bytes = haystack.as_bytes();
    haystack.match_indices(needle).any(|(start, matched)| {
        let end = start + matched.len();
        // Inspecting neighbouring bytes is sound: ASCII digits are single-byte
        // in UTF-8 and never occur inside a multi-byte sequence.
        let digit_before = start > 0 && bytes[start - 1].is_ascii_digit();
        let digit_after = bytes.get(end).is_some_and(u8::is_ascii_digit);
        !digit_before && !digit_after
    })
}
81
82#[must_use]
83pub fn is_retryable_model_error(error: &AdkError) -> bool {
84    // Primary: use structured retry hint (single source of truth)
85    if error.retry.should_retry {
86        return true;
87    }
88    // Fallback: for backward-compat `.legacy` errors during transition,
89    // check the error message for retryable patterns
90    if error.code.ends_with(".legacy") && error.is_model() {
91        return is_retryable_error_message(&error.message);
92    }
93    false
94}
95
96fn next_retry_delay(current: Duration, retry_config: &RetryConfig) -> Duration {
97    if current >= retry_config.max_delay {
98        return retry_config.max_delay;
99    }
100
101    let multiplier = retry_config.backoff_multiplier.max(1.0) as f64;
102    let scaled = Duration::from_secs_f64(current.as_secs_f64() * multiplier);
103    scaled.min(retry_config.max_delay)
104}
105
/// Hint from the server about when to retry.
///
/// When the server provides a `retry-after` header, [`execute_with_retry_hint`]
/// uses this hint instead of the exponential backoff delay for the **first**
/// retry only. Later retries fall back to backoff. A `retry_after` carried by
/// the failing error itself always takes precedence over this hint.
///
/// # Example
///
/// ```rust
/// use adk_model::retry::ServerRetryHint;
/// use std::time::Duration;
///
/// let hint = ServerRetryHint { retry_after: Some(Duration::from_secs(30)) };
/// assert_eq!(hint.retry_after, Some(Duration::from_secs(30)));
/// assert_eq!(ServerRetryHint::default(), ServerRetryHint { retry_after: None });
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerRetryHint {
    /// Server-suggested delay before retrying; `None` means no suggestion.
    pub retry_after: Option<Duration>,
}
125
126pub async fn execute_with_retry<T, Op, Fut, Classify>(
127    retry_config: &RetryConfig,
128    classify_error: Classify,
129    mut operation: Op,
130) -> Result<T>
131where
132    Op: FnMut() -> Fut,
133    Fut: Future<Output = Result<T>>,
134    Classify: Fn(&AdkError) -> bool,
135{
136    execute_with_retry_hint(retry_config, classify_error, None, &mut operation).await
137}
138
139/// Execute an operation with retry logic, optionally using a server-provided
140/// retry hint to override the backoff delay.
141///
142/// When `server_hint` contains a `retry_after` duration, that duration is used
143/// instead of the exponential backoff calculation. This respects server-provided
144/// timing from `retry-after` headers (Requirement 5.1).
145pub async fn execute_with_retry_hint<T, Op, Fut, Classify>(
146    retry_config: &RetryConfig,
147    classify_error: Classify,
148    server_hint: Option<&ServerRetryHint>,
149    operation: &mut Op,
150) -> Result<T>
151where
152    Op: FnMut() -> Fut,
153    Fut: Future<Output = Result<T>>,
154    Classify: Fn(&AdkError) -> bool,
155{
156    if !retry_config.enabled {
157        return operation().await;
158    }
159
160    let mut attempt: u32 = 0;
161    let mut delay = retry_config.initial_delay;
162
163    // If the server provided a retry-after hint, use it for the first retry delay.
164    let server_delay = server_hint.and_then(|h| h.retry_after);
165
166    loop {
167        match operation().await {
168            Ok(value) => return Ok(value),
169            Err(error) if attempt < retry_config.max_retries && classify_error(&error) => {
170                attempt += 1;
171
172                // Priority: 1) structured retry_after from AdkError, 2) server hint, 3) backoff
173                let error_retry_after = error.retry.retry_after();
174                let effective_delay = if let Some(d) = error_retry_after {
175                    d
176                } else if attempt == 1 {
177                    server_delay.unwrap_or(delay)
178                } else {
179                    delay
180                };
181
182                adk_telemetry::warn!(
183                    attempt = attempt,
184                    max_retries = retry_config.max_retries,
185                    delay_ms = effective_delay.as_millis(),
186                    error = %error,
187                    "Provider request failed with retryable error; retrying"
188                );
189                tokio::time::sleep(effective_delay).await;
190                delay = next_retry_delay(delay, retry_config);
191            }
192            Err(error) => return Err(error),
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    #[tokio::test]
    async fn execute_with_retry_retries_when_classified_retryable() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("ok")
            }
        })
        .await
        .expect("operation should succeed after retries");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_stops_on_non_retryable_error() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 400 bad request"))
            }
        })
        .await
        .expect_err("operation should fail without retries");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 429 too many requests"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn retryable_status_code_matches_transient_errors() {
        assert!(is_retryable_status_code(429));
        assert!(is_retryable_status_code(503));
        assert!(is_retryable_status_code(529));
        assert!(!is_retryable_status_code(400));
        assert!(!is_retryable_status_code(401));
    }

    #[test]
    fn retryable_error_message_matches_529_and_overloaded() {
        assert!(is_retryable_error_message("HTTP 529 overloaded"));
        assert!(is_retryable_error_message("Server OVERLOADED, try again"));
    }

    #[tokio::test]
    async fn execute_with_retry_hint_uses_server_delay() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));
        let hint = ServerRetryHint { retry_after: Some(Duration::ZERO) };

        let result = execute_with_retry_hint(
            &retry_config,
            is_retryable_model_error,
            Some(&hint),
            &mut || {
                let attempts = Arc::clone(&attempts);
                async move {
                    let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                    if attempt < 1 {
                        return Err(AdkError::model("HTTP 429 rate limit"));
                    }
                    Ok("ok")
                }
            },
        )
        .await
        .expect("operation should succeed after retry with hint");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Requirement 5.2: HTTP 529 (overloaded) is retried end-to-end.
    #[tokio::test]
    async fn status_529_is_retried_end_to_end() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt == 0 {
                    return Err(AdkError::model("HTTP 529 overloaded"));
                }
                Ok("recovered")
            }
        })
        .await
        .expect("529 should be retried and succeed on second attempt");

        assert_eq!(result, "recovered");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Requirement 5.4: Exponential backoff when retry-after is absent.
    /// With initial_delay=20ms and multiplier=2.0, the delays should be
    /// ~20ms (attempt 1), ~40ms (attempt 2) and ~80ms (attempt 3).
    ///
    /// Only lower bounds are asserted. A sleep may overshoot by an arbitrary
    /// amount on a loaded machine, so comparing the measured gaps with each
    /// other or against upper bounds would make this test flaky. A second gap
    /// of at least ~40ms, well above the 20ms initial delay, already shows
    /// that the delay grew.
    #[tokio::test]
    async fn exponential_backoff_without_retry_after() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(20))
            .with_max_delay(Duration::from_millis(200))
            .with_backoff_multiplier(2.0);

        let timestamps: Arc<std::sync::Mutex<Vec<std::time::Instant>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let timestamps = Arc::clone(&timestamps);
            async move {
                let now = std::time::Instant::now();
                let mut ts = timestamps.lock().unwrap();
                let attempt = ts.len();
                ts.push(now);
                if attempt < 3 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("done")
            }
        })
        .await
        .expect("should succeed after backoff retries");

        assert_eq!(result, "done");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 4); // initial + 3 retries

        // Gap between attempt 0 and 1 should be >= initial_delay (20ms).
        let gap1 = ts[1].duration_since(ts[0]);
        assert!(gap1 >= Duration::from_millis(18), "first backoff gap {gap1:?} should be >= ~20ms");

        // Gap between attempt 1 and 2 should be >= 2 * initial_delay (40ms).
        let gap2 = ts[2].duration_since(ts[1]);
        assert!(
            gap2 >= Duration::from_millis(36),
            "second backoff gap {gap2:?} should be >= ~40ms"
        );

        // Gap between attempt 2 and 3 should be >= 4 * initial_delay (80ms).
        let gap3 = ts[3].duration_since(ts[2]);
        assert!(
            gap3 >= Duration::from_millis(72),
            "third backoff gap {gap3:?} should be >= ~80ms"
        );
    }
}