Skip to main content

adk_model/
retry.rs

1use adk_core::{AdkError, Result};
2use std::{future::Future, time::Duration};
3
/// Settings controlling automatic retries with exponential backoff.
///
/// Start from [`RetryConfig::default`] (retries on) or [`RetryConfig::disabled`]
/// (retries off) and adjust individual knobs with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Master switch; when `false` the operation is attempted exactly once.
    pub enabled: bool,
    /// Number of retries allowed after the initial attempt before giving up.
    pub max_retries: u32,
    /// Backoff delay slept before the first retry.
    pub initial_delay: Duration,
    /// Upper bound for the exponentially growing backoff delay.
    pub max_delay: Duration,
    /// Factor applied to the backoff delay after each retry (values below
    /// `1.0` are treated as `1.0` by the backoff computation).
    pub backoff_multiplier: f32,
}

impl Default for RetryConfig {
    /// Retries on: at most 3 retries, starting at 250 ms and doubling up to 5 s.
    fn default() -> Self {
        RetryConfig {
            enabled: true,
            max_retries: 3,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns a configuration with retries switched off; every other field
    /// keeps its default value.
    #[must_use]
    pub fn disabled() -> Self {
        RetryConfig { enabled: false, ..RetryConfig::default() }
    }

    /// Returns `self` with the retry budget replaced by `max_retries`.
    #[must_use]
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Returns `self` with the first backoff delay replaced by `initial_delay`.
    #[must_use]
    pub fn with_initial_delay(self, initial_delay: Duration) -> Self {
        Self { initial_delay, ..self }
    }

    /// Returns `self` with the backoff ceiling replaced by `max_delay`.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        Self { max_delay, ..self }
    }

    /// Returns `self` with the per-retry growth factor replaced by `backoff_multiplier`.
    #[must_use]
    pub fn with_backoff_multiplier(self, backoff_multiplier: f32) -> Self {
        Self { backoff_multiplier, ..self }
    }
}
66
/// Reports whether an HTTP status code signals a transient failure worth retrying.
///
/// The retryable set is 408 (request timeout), 429 (too many requests),
/// 500, 502, 503, 504, and 529 (overloaded).
#[must_use]
pub fn is_retryable_status_code(status_code: u16) -> bool {
    const TRANSIENT_STATUS_CODES: [u16; 7] = [408, 429, 500, 502, 503, 504, 529];
    TRANSIENT_STATUS_CODES.contains(&status_code)
}
72
/// Returns `true` if the error message contains patterns indicating a transient failure.
///
/// Matching is ASCII case-insensitive. A message counts as retryable when it
/// mentions one of the transient HTTP status codes 408, 429, 500, 502, 503,
/// 504 or 529, or contains a phrase such as "rate limit", "too many requests",
/// "resource_exhausted", "unavailable", "deadline_exceeded", "timeout",
/// "timed out", "connection reset" or "overloaded".
///
/// Status codes only count as standalone numbers (no ASCII digit directly
/// before or after them), so incidental numbers such as a token count of
/// `5000` or a request id like `15029` do not make an otherwise permanent
/// error look transient.
#[must_use]
pub fn is_retryable_error_message(message: &str) -> bool {
    const STATUS_CODES: [&str; 7] = ["408", "429", "500", "502", "503", "504", "529"];
    const PHRASES: [&str; 9] = [
        "RATE LIMIT",
        "TOO MANY REQUESTS",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "TIMEOUT",
        "TIMED OUT",
        "CONNECTION RESET",
        "OVERLOADED",
    ];

    let normalized = message.to_ascii_uppercase();
    STATUS_CODES.iter().any(|&code| contains_standalone_number(&normalized, code))
        || PHRASES.iter().any(|&phrase| normalized.contains(phrase))
}

/// Returns `true` if `number` occurs in `haystack` with no ASCII digit on either side.
///
/// Works on bytes, so it is safe for non-ASCII text: ASCII digits never occur
/// inside a multi-byte UTF-8 sequence. An empty `number` never matches.
fn contains_standalone_number(haystack: &str, number: &str) -> bool {
    let (hay, needle) = (haystack.as_bytes(), number.as_bytes());
    if needle.is_empty() {
        // `slice::windows` panics on a zero window size.
        return false;
    }
    hay.windows(needle.len()).enumerate().any(|(start, window)| {
        window == needle
            && (start == 0 || !hay[start - 1].is_ascii_digit())
            && !hay.get(start + needle.len()).is_some_and(u8::is_ascii_digit)
    })
}
94
95/// Returns `true` if the `AdkError` should be retried based on its retry hint or message.
96#[must_use]
97pub fn is_retryable_model_error(error: &AdkError) -> bool {
98    // Primary: use structured retry hint (single source of truth)
99    if error.retry.should_retry {
100        return true;
101    }
102    // Fallback: for backward-compat `.legacy` errors during transition,
103    // check the error message for retryable patterns
104    if error.code.ends_with(".legacy") && error.is_model() {
105        return is_retryable_error_message(&error.message);
106    }
107    false
108}
109
110fn next_retry_delay(current: Duration, retry_config: &RetryConfig) -> Duration {
111    if current >= retry_config.max_delay {
112        return retry_config.max_delay;
113    }
114
115    let multiplier = retry_config.backoff_multiplier.max(1.0) as f64;
116    let scaled = Duration::from_secs_f64(current.as_secs_f64() * multiplier);
117    scaled.min(retry_config.max_delay)
118}
119
/// Retry timing suggested by the server, e.g. from a `retry-after` header.
///
/// When `retry_after` is set, [`execute_with_retry_hint`] uses it in place of
/// the exponential backoff delay for the first retry (unless the failing
/// error carries its own `retry_after`).
///
/// # Example
///
/// ```rust
/// use adk_model::retry::ServerRetryHint;
/// use std::time::Duration;
///
/// let hint = ServerRetryHint { retry_after: Some(Duration::from_secs(30)) };
/// assert_eq!(hint.retry_after, Some(Duration::from_secs(30)));
///
/// // The default hint carries no suggestion.
/// assert!(ServerRetryHint::default().retry_after.is_none());
/// ```
#[derive(Clone, Debug, Default)]
pub struct ServerRetryHint {
    /// Delay the server asked the client to wait before retrying, if any.
    pub retry_after: Option<Duration>,
}
139
140/// Execute an operation with automatic retry on transient errors.
141pub async fn execute_with_retry<T, Op, Fut, Classify>(
142    retry_config: &RetryConfig,
143    classify_error: Classify,
144    mut operation: Op,
145) -> Result<T>
146where
147    Op: FnMut() -> Fut,
148    Fut: Future<Output = Result<T>>,
149    Classify: Fn(&AdkError) -> bool,
150{
151    execute_with_retry_hint(retry_config, classify_error, None, &mut operation).await
152}
153
154/// Execute an operation with retry logic, optionally using a server-provided
155/// retry hint to override the backoff delay.
156///
157/// When `server_hint` contains a `retry_after` duration, that duration is used
158/// instead of the exponential backoff calculation. This respects server-provided
159/// timing from `retry-after` headers (Requirement 5.1).
160pub async fn execute_with_retry_hint<T, Op, Fut, Classify>(
161    retry_config: &RetryConfig,
162    classify_error: Classify,
163    server_hint: Option<&ServerRetryHint>,
164    operation: &mut Op,
165) -> Result<T>
166where
167    Op: FnMut() -> Fut,
168    Fut: Future<Output = Result<T>>,
169    Classify: Fn(&AdkError) -> bool,
170{
171    if !retry_config.enabled {
172        return operation().await;
173    }
174
175    let mut attempt: u32 = 0;
176    let mut delay = retry_config.initial_delay;
177
178    // If the server provided a retry-after hint, use it for the first retry delay.
179    let server_delay = server_hint.and_then(|h| h.retry_after);
180
181    loop {
182        match operation().await {
183            Ok(value) => return Ok(value),
184            Err(error) if attempt < retry_config.max_retries && classify_error(&error) => {
185                attempt += 1;
186
187                // Priority: 1) structured retry_after from AdkError, 2) server hint, 3) backoff
188                let error_retry_after = error.retry.retry_after();
189                let effective_delay = if let Some(d) = error_retry_after {
190                    d
191                } else if attempt == 1 {
192                    server_delay.unwrap_or(delay)
193                } else {
194                    delay
195                };
196
197                adk_telemetry::warn!(
198                    attempt = attempt,
199                    max_retries = retry_config.max_retries,
200                    delay_ms = effective_delay.as_millis(),
201                    error = %error,
202                    "Provider request failed with retryable error; retrying"
203                );
204                tokio::time::sleep(effective_delay).await;
205                delay = next_retry_delay(delay, retry_config);
206            }
207            Err(error) => return Err(error),
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    #[tokio::test]
    async fn execute_with_retry_retries_when_classified_retryable() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("ok")
            }
        })
        .await
        .expect("operation should succeed after retries");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_stops_on_non_retryable_error() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 400 bad request"))
            }
        })
        .await
        .expect_err("operation should fail without retries");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("HTTP 429 too many requests"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn retryable_status_code_matches_transient_errors() {
        assert!(is_retryable_status_code(429));
        assert!(is_retryable_status_code(503));
        assert!(is_retryable_status_code(529));
        assert!(!is_retryable_status_code(400));
        assert!(!is_retryable_status_code(401));
    }

    #[test]
    fn retryable_error_message_matches_529_and_overloaded() {
        assert!(is_retryable_error_message("HTTP 529 overloaded"));
        assert!(is_retryable_error_message("Server OVERLOADED, try again"));
    }

    /// The server hint must replace the (deliberately long) backoff delay on
    /// the first retry; the timeout fails the test quickly if it does not.
    #[tokio::test]
    async fn execute_with_retry_hint_uses_server_delay() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::from_secs(30))
            .with_max_delay(Duration::from_secs(30));
        let attempts = Arc::new(AtomicU32::new(0));
        let hint = ServerRetryHint { retry_after: Some(Duration::ZERO) };

        let result = tokio::time::timeout(
            Duration::from_secs(5),
            execute_with_retry_hint(
                &retry_config,
                is_retryable_model_error,
                Some(&hint),
                &mut || {
                    let attempts = Arc::clone(&attempts);
                    async move {
                        let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                        if attempt < 1 {
                            return Err(AdkError::model("HTTP 429 rate limit"));
                        }
                        Ok("ok")
                    }
                },
            ),
        )
        .await
        .expect("server hint should replace the 30s backoff delay")
        .expect("operation should succeed after retry with hint");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Requirement 5.2: HTTP 529 (overloaded) is retried end-to-end.
    #[tokio::test]
    async fn status_529_is_retried_end_to_end() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::ZERO)
            .with_max_delay(Duration::ZERO);
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt == 0 {
                    return Err(AdkError::model("HTTP 529 overloaded"));
                }
                Ok("recovered")
            }
        })
        .await
        .expect("529 should be retried and succeed on second attempt");

        assert_eq!(result, "recovered");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// Requirement 5.4: Exponential backoff when retry-after is absent.
    /// With initial_delay=20ms and multiplier=2.0, the delays should be
    /// ~20ms, ~40ms and ~80ms. Timers never fire early, so each gap is
    /// checked against a lower bound (with ~10% tolerance). Gaps are not
    /// compared with each other: scheduler jitter can inflate any single
    /// gap, which would make relative comparisons flaky.
    #[tokio::test]
    async fn exponential_backoff_without_retry_after() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(20))
            .with_max_delay(Duration::from_millis(200))
            .with_backoff_multiplier(2.0);

        let timestamps: Arc<std::sync::Mutex<Vec<std::time::Instant>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let timestamps = Arc::clone(&timestamps);
            async move {
                let now = std::time::Instant::now();
                let mut ts = timestamps.lock().unwrap();
                let attempt = ts.len();
                ts.push(now);
                if attempt < 3 {
                    return Err(AdkError::model("HTTP 429 rate limit"));
                }
                Ok("done")
            }
        })
        .await
        .expect("should succeed after backoff retries");

        assert_eq!(result, "done");

        let ts = timestamps.lock().unwrap();
        assert_eq!(ts.len(), 4); // initial + 3 retries

        // Gap between attempt 0 and 1 should be >= initial_delay (20ms).
        let gap1 = ts[1].duration_since(ts[0]);
        assert!(gap1 >= Duration::from_millis(18), "first backoff gap {gap1:?} should be >= ~20ms");

        // Gap between attempt 1 and 2 should be >= 2 * initial_delay (40ms).
        let gap2 = ts[2].duration_since(ts[1]);
        assert!(
            gap2 >= Duration::from_millis(36),
            "second backoff gap {gap2:?} should be >= ~40ms"
        );

        // Gap between attempt 2 and 3 should be >= 4 * initial_delay (80ms).
        let gap3 = ts[3].duration_since(ts[2]);
        assert!(
            gap3 >= Duration::from_millis(72),
            "third backoff gap {gap3:?} should be >= ~80ms"
        );
    }
}