// opendev_http/models/retry.rs

//! Retry configuration, backoff logic, and error classification.
2
use serde::{Deserialize, Serialize};
4
5/// Configuration for retry behavior.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct RetryConfig {
8    /// Maximum number of retry attempts (not counting the initial request).
9    pub max_retries: u32,
10    /// Base delays in seconds for exponential backoff.
11    pub retry_delays: Vec<f64>,
12    /// HTTP status codes that trigger a retry.
13    pub retryable_status_codes: Vec<u16>,
14    /// Initial delay in milliseconds for exponential backoff.
15    #[serde(default = "default_initial_delay_ms")]
16    pub initial_delay_ms: u64,
17    /// Multiplier for exponential backoff (delay *= factor each attempt).
18    #[serde(default = "default_backoff_factor")]
19    pub backoff_factor: f64,
20    /// Maximum delay in milliseconds (cap for exponential growth).
21    #[serde(default = "default_max_delay_ms")]
22    pub max_delay_ms: u64,
23}
24
/// Serde default for `RetryConfig::initial_delay_ms`: 2000 ms.
fn default_initial_delay_ms() -> u64 {
    2_000
}
/// Serde default for `RetryConfig::backoff_factor`: the delay doubles each attempt.
fn default_backoff_factor() -> f64 {
    2.0
}
/// Serde default for `RetryConfig::max_delay_ms`: 30000 ms.
fn default_max_delay_ms() -> u64 {
    30_000
}
34
35impl Default for RetryConfig {
36    fn default() -> Self {
37        Self {
38            max_retries: 3,
39            retry_delays: vec![1.0, 2.0, 4.0],
40            retryable_status_codes: vec![429, 500, 502, 503, 504],
41            initial_delay_ms: 2000,
42            backoff_factor: 2.0,
43            max_delay_ms: 30000,
44        }
45    }
46}
47
48impl RetryConfig {
49    /// Get the delay for a given attempt index (0-based).
50    ///
51    /// Uses exponential backoff: `initial_delay_ms * backoff_factor^attempt`,
52    /// capped at `max_delay_ms`. Falls back to the legacy `retry_delays` array
53    /// if `initial_delay_ms` is 0.
54    pub fn delay_for_attempt(&self, attempt: u32) -> std::time::Duration {
55        if self.initial_delay_ms > 0 {
56            let delay_ms = self.initial_delay_ms as f64 * self.backoff_factor.powi(attempt as i32);
57            let capped_ms = delay_ms.min(self.max_delay_ms as f64);
58            // Add ±25% random jitter to avoid thundering herd when parallel
59            // agents all retry at the same backoff intervals
60            let jitter_factor = 0.75 + fastrand::f64() * 0.5; // [0.75, 1.25]
61            let final_ms = (capped_ms * jitter_factor) as u64;
62            return std::time::Duration::from_millis(final_ms);
63        }
64        // Legacy fallback: use fixed delay array
65        let idx = (attempt as usize).min(self.retry_delays.len().saturating_sub(1));
66        let secs = self.retry_delays.get(idx).copied().unwrap_or(4.0);
67        std::time::Duration::from_secs_f64(secs)
68    }
69
70    /// Check if a status code is retryable.
71    pub fn is_retryable_status(&self, status: u16) -> bool {
72        self.retryable_status_codes.contains(&status)
73    }
74}
75
76/// Classify an API error response and return a human-readable retry reason.
77///
78/// Returns `Some(message)` if the error is retryable, `None` if it should not be retried.
79pub fn classify_retryable_error(
80    status: Option<u16>,
81    body: Option<&serde_json::Value>,
82) -> Option<String> {
83    // Check status code first
84    match status {
85        Some(429) => {
86            if let Some(body) = body
87                && let Some(msg) = extract_error_message(body)
88            {
89                if msg.contains("rate_limit") || msg.contains("Rate") {
90                    return Some("Rate Limited".to_string());
91                }
92                if msg.contains("too_many_requests") || msg.contains("Too Many") {
93                    return Some("Too Many Requests".to_string());
94                }
95            }
96            return Some("Rate Limited".to_string());
97        }
98        Some(529) => return Some("Provider is overloaded".to_string()),
99        Some(503) => {
100            if let Some(body) = body
101                && let Some(msg) = extract_error_message(body)
102            {
103                if msg.contains("overloaded") || msg.contains("Overloaded") {
104                    return Some("Provider is overloaded".to_string());
105                }
106                if msg.contains("unavailable") || msg.contains("exhausted") {
107                    return Some("Provider is overloaded".to_string());
108                }
109            }
110            return Some("Service Unavailable".to_string());
111        }
112        Some(500) => return Some("Internal Server Error".to_string()),
113        Some(502) => return Some("Bad Gateway".to_string()),
114        Some(504) => return Some("Gateway Timeout".to_string()),
115        _ => {}
116    }
117
118    // Check body for retryable error patterns even without a matching status
119    if let Some(body) = body
120        && let Some(msg) = extract_error_message(body)
121        && (msg.contains("overloaded") || msg.contains("Overloaded"))
122    {
123        return Some("Provider is overloaded".to_string());
124    }
125
126    None
127}
128
129/// Parse a `Retry-After` or `retry-after-ms` header value into a Duration.
130///
131/// Supports:
132/// - `retry-after-ms` header (milliseconds, if provided separately)
133/// - `Retry-After` as seconds (integer or float)
134/// - `Retry-After` as HTTP date (RFC 2822)
135///
136/// Returns `None` if parsing fails.
137pub fn parse_retry_after(
138    retry_after: Option<&str>,
139    retry_after_ms: Option<&str>,
140) -> Option<std::time::Duration> {
141    // Prefer retry-after-ms (more precise)
142    if let Some(ms_str) = retry_after_ms
143        && let Ok(ms) = ms_str.parse::<u64>()
144    {
145        return Some(std::time::Duration::from_millis(ms));
146    }
147
148    let val = retry_after?;
149
150    // Try parsing as seconds (integer or float)
151    if let Ok(secs) = val.parse::<f64>()
152        && secs > 0.0
153    {
154        return Some(std::time::Duration::from_secs_f64(secs));
155    }
156
157    // Try parsing as HTTP date (RFC 2822 / RFC 7231)
158    // Example: "Wed, 21 Oct 2015 07:28:00 GMT"
159    if val.contains(',')
160        && val.contains("GMT")
161        && let Ok(date) = httpdate::parse_http_date(val)
162        && let Ok(duration) = date.duration_since(std::time::SystemTime::now())
163    {
164        return Some(duration);
165    }
166
167    None
168}
169
170/// Extract an error message from an API error response body.
171pub(super) fn extract_error_message(body: &serde_json::Value) -> Option<String> {
172    // OpenAI: {"error": {"message": "...", "type": "...", "code": "..."}}
173    if let Some(err) = body.get("error") {
174        if let Some(msg) = err.get("message").and_then(|v| v.as_str()) {
175            return Some(msg.to_string());
176        }
177        if let Some(code) = err.get("code").and_then(|v| v.as_str()) {
178            return Some(code.to_string());
179        }
180        if let Some(err_type) = err.get("type").and_then(|v| v.as_str()) {
181            return Some(err_type.to_string());
182        }
183        if let Some(msg) = err.as_str() {
184            return Some(msg.to_string());
185        }
186    }
187    // Anthropic: {"type": "error", "error": {"type": "...", "message": "..."}}
188    if body.get("type").and_then(|v| v.as_str()) == Some("error")
189        && let Some(err) = body.get("error")
190        && let Some(msg) = err.get("message").and_then(|v| v.as_str())
191    {
192        return Some(msg.to_string());
193    }
194    // Generic message field
195    body.get("message")
196        .and_then(|v| v.as_str())
197        .map(|s| s.to_string())
198}
199
// Unit tests for this module live in the sibling file `retry_tests.rs`.
#[cfg(test)]
#[path = "retry_tests.rs"]
mod tests;