Skip to main content

atomcode_core/provider/
retry.rs

1//! HTTP retry / backoff / Retry-After helpers for LLM providers.
2//!
3//! Retries happen ONLY before the streaming response begins. Once the helper
4//! returns `Ok(Response)`, the caller owns the stream and any error during
5//! SSE iteration must NOT be retried — partial deltas may already have reached
6//! the user.
7
8use std::time::Duration;
9
/// How many times, and how patiently, a provider request is retried.
///
/// Passed by reference to `send_with_retry` / `send_with_retry_blocking`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the initial request.
    pub max_attempts: u32,
    /// Backoff before the first retry; doubled for every later retry.
    pub base_delay: Duration,
    /// Ceiling applied to the exponential backoff (before jitter).
    pub max_delay: Duration,
}

impl RetryPolicy {
    /// Production settings: 3 attempts, 500ms initial backoff, 8s ceiling.
    ///
    /// Identical to `RetryPolicy::default()`.
    pub fn default_policy() -> Self {
        Self::default()
    }

    /// Near-instant settings for unit tests: 3 attempts, 1ms initial
    /// backoff, 10ms ceiling.
    #[cfg(test)]
    pub fn testing() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
        }
    }
}

impl Default for RetryPolicy {
    /// 3 attempts, 500ms initial backoff, 8s ceiling.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(8),
        }
    }
}
44
45/// Status codes that indicate a transient server-side issue worth retrying.
46fn is_retryable_status(status: reqwest::StatusCode) -> bool {
47    matches!(status.as_u16(), 408 | 425 | 429 | 500 | 502 | 503 | 504)
48}
49
50/// Whether a reqwest error is a transient transport issue worth retrying.
51fn is_retryable_error(err: &reqwest::Error) -> bool {
52    err.is_timeout() || err.is_connect()
53}
54
55/// Parse `Retry-After` header as integer seconds. Returns `None` for absent,
56/// malformed, or HTTP-date formats (we currently don't support HTTP-date).
57fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
58    let value = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?;
59    let secs: u64 = value.trim().parse().ok()?;
60    Some(Duration::from_secs(secs))
61}
62
63/// Compute exponential backoff delay with ±25% deterministic jitter.
64fn compute_backoff(attempt: u32, policy: &RetryPolicy) -> Duration {
65    let exp = policy
66        .base_delay
67        .saturating_mul(1u32 << attempt.saturating_sub(1).min(16));
68    let capped = exp.min(policy.max_delay);
69
70    // Deterministic pseudo-jitter from wall-clock nanos: ±25% of capped.
71    let nanos = std::time::SystemTime::now()
72        .duration_since(std::time::UNIX_EPOCH)
73        .map(|d| d.subsec_nanos())
74        .unwrap_or(0);
75    let range = (capped.as_millis() / 2) as u64; // total ±25% = 50% range
76    let jitter_ms = if range > 0 { (nanos as u64) % range } else { 0 };
77    let jitter = Duration::from_millis(jitter_ms);
78    // Center on capped: actual = capped - range/2 + jitter_in_[0, range]
79    let floor = capped.saturating_sub(Duration::from_millis(range / 2));
80    floor + jitter
81}
82
83/// Async retry wrapper for streaming providers.
84///
85/// Uses `RequestBuilder::build_split()` so builder-chain errors
86/// (illegal header value, bad URL, JSON serialization failure)
87/// surface as a real `reqwest::Error` on the return path instead
88/// of crashing the process. A pasted api_key with a trailing
89/// newline is the classic trigger — historically this produced
90/// `panicked at .../retry.rs:94: send_with_retry: request body
91/// must be cloneable (no streams)` with no path to recovery.
92pub async fn send_with_retry(
93    builder: reqwest::RequestBuilder,
94    policy: &RetryPolicy,
95) -> Result<reqwest::Response, reqwest::Error> {
96    // Split the builder into (client, Result<Request>). The Err
97    // variant carries the actual root cause of any failed chain
98    // call; `?` propagates it as a proper reqwest::Error instead
99    // of letting `try_clone` below return None and panic.
100    let (client, built) = builder.build_split();
101    let req = built?;
102    let mut last_err: Option<reqwest::Error> = None;
103    for attempt in 1..=policy.max_attempts {
104        // `Request::try_clone` returns None only for stream bodies
105        // (our callers use `.json(...)` → Bytes, never streams). If
106        // a future caller ever attaches a stream, fall back rather
107        // than panic: surface whatever retryable error we've
108        // accumulated, or single-shot the original request on the
109        // very first attempt.
110        let this_req = match req.try_clone() {
111            Some(c) => c,
112            None => {
113                return match last_err {
114                    Some(e) => Err(e),
115                    None => client.execute(req).await,
116                };
117            }
118        };
119        match client.execute(this_req).await {
120            Ok(resp) => {
121                if is_retryable_status(resp.status()) && attempt < policy.max_attempts {
122                    let wait = parse_retry_after(resp.headers())
123                        .unwrap_or_else(|| compute_backoff(attempt, policy));
124                    tokio::time::sleep(wait).await;
125                    continue;
126                }
127                return Ok(resp);
128            }
129            Err(e) => {
130                if is_retryable_error(&e) && attempt < policy.max_attempts {
131                    let wait = compute_backoff(attempt, policy);
132                    last_err = Some(e);
133                    tokio::time::sleep(wait).await;
134                    continue;
135                }
136                return Err(e);
137            }
138        }
139    }
140    // Unreachable in practice (the loop always returns or continues), but
141    // keeps the type system happy if max_attempts == 0.
142    Err(last_err.expect("send_with_retry: loop terminated without error or response"))
143}
144
145/// Blocking variant for sync code paths (e.g. OAuth token refresh in `create_provider`).
146/// Same contract as `send_with_retry`: builder-chain errors are surfaced
147/// as `reqwest::Error` rather than panics.
148pub fn send_with_retry_blocking(
149    builder: reqwest::blocking::RequestBuilder,
150    policy: &RetryPolicy,
151) -> Result<reqwest::blocking::Response, reqwest::Error> {
152    let (client, built) = builder.build_split();
153    let req = built?;
154    let mut last_err: Option<reqwest::Error> = None;
155    for attempt in 1..=policy.max_attempts {
156        let this_req = match req.try_clone() {
157            Some(c) => c,
158            None => {
159                return match last_err {
160                    Some(e) => Err(e),
161                    None => client.execute(req),
162                };
163            }
164        };
165        match client.execute(this_req) {
166            Ok(resp) => {
167                if is_retryable_status(resp.status()) && attempt < policy.max_attempts {
168                    let wait = parse_retry_after(resp.headers())
169                        .unwrap_or_else(|| compute_backoff(attempt, policy));
170                    std::thread::sleep(wait);
171                    continue;
172                }
173                return Ok(resp);
174            }
175            Err(e) => {
176                if is_retryable_error(&e) && attempt < policy.max_attempts {
177                    let wait = compute_backoff(attempt, policy);
178                    last_err = Some(e);
179                    std::thread::sleep(wait);
180                    continue;
181                }
182                return Err(e);
183            }
184        }
185    }
186    Err(last_err.expect("send_with_retry_blocking: loop terminated without error or response"))
187}
188
// Unit tests for the pure helpers (Retry-After parsing, status
// classification, backoff bounds) plus wiremock-backed tests that drive
// `send_with_retry` against a local HTTP server. The request tests use
// `RetryPolicy::testing()` so backoff sleeps stay in the millisecond range;
// the Retry-After test is the deliberate ~1s exception.
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};

    // An integer `Retry-After` value is interpreted as whole seconds.
    #[test]
    fn parse_retry_after_seconds() {
        let mut h = HeaderMap::new();
        h.insert(RETRY_AFTER, HeaderValue::from_static("3"));
        assert_eq!(parse_retry_after(&h), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_missing_returns_none() {
        let h = HeaderMap::new();
        assert_eq!(parse_retry_after(&h), None);
    }

    // The HTTP-date form is intentionally unsupported; callers then fall
    // back to exponential backoff.
    #[test]
    fn parse_retry_after_http_date_returns_none() {
        let mut h = HeaderMap::new();
        h.insert(
            RETRY_AFTER,
            HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT"),
        );
        assert_eq!(parse_retry_after(&h), None);
    }

    #[test]
    fn retryable_status_includes_429_and_5xx() {
        assert!(is_retryable_status(reqwest::StatusCode::TOO_MANY_REQUESTS));
        assert!(is_retryable_status(
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
        ));
        assert!(is_retryable_status(reqwest::StatusCode::BAD_GATEWAY));
        assert!(is_retryable_status(
            reqwest::StatusCode::SERVICE_UNAVAILABLE
        ));
        assert!(is_retryable_status(reqwest::StatusCode::GATEWAY_TIMEOUT));
        assert!(is_retryable_status(reqwest::StatusCode::REQUEST_TIMEOUT));
    }

    // Auth and validation failures are permanent: retrying them would only
    // delay the error the user needs to see.
    #[test]
    fn retryable_status_excludes_auth_and_validation() {
        assert!(!is_retryable_status(reqwest::StatusCode::UNAUTHORIZED));
        assert!(!is_retryable_status(reqwest::StatusCode::FORBIDDEN));
        assert!(!is_retryable_status(reqwest::StatusCode::BAD_REQUEST));
        assert!(!is_retryable_status(reqwest::StatusCode::NOT_FOUND));
    }

    // Loose upper bound: jitter adds at most ~25% on top of max_delay, so
    // 1500ms leaves headroom beyond the ~1250ms worst case.
    #[test]
    fn backoff_respects_max_delay() {
        let policy = RetryPolicy {
            max_attempts: 10,
            base_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(1),
        };
        // After enough attempts, should cap at max_delay (+/- jitter).
        let d = compute_backoff(10, &policy);
        assert!(d <= Duration::from_millis(1500), "got {:?}", d);
    }

    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Client with short connect/request timeouts so a misbehaving test fails
    // within seconds instead of hanging.
    fn client() -> reqwest::Client {
        reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(2))
            .timeout(Duration::from_secs(2))
            .build()
            .unwrap()
    }

    // `.expect(1)` pins that a successful response triggers no retry.
    #[tokio::test]
    async fn succeeds_on_first_try() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .expect(1)
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    // Relies on mount order: the 500 mock is limited to one hit, after which
    // the next request falls through to the 200 mock.
    #[tokio::test]
    async fn retries_on_500_then_succeeds() {
        let server = MockServer::start().await;
        // First: 500. Second: 200.
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    // When every attempt gets a retryable status, the final response is
    // returned as `Ok` with that status (not converted to an error), and
    // exactly `max_attempts` requests are made.
    #[tokio::test]
    async fn exhausts_on_persistent_500() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(500))
            .expect(3) // max_attempts
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 500);
    }

    #[tokio::test]
    async fn does_not_retry_on_401() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(401))
            .expect(1) // must NOT retry
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 401);
    }

    // The Retry-After value (1s) is far larger than the testing policy's
    // 10ms max_delay. This pins that the header is honored as-is rather than
    // clamped to the policy, which is why the test takes about a second.
    #[tokio::test]
    async fn honors_retry_after_on_429() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "1"))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .mount(&server)
            .await;

        let start = std::time::Instant::now();
        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        let elapsed = start.elapsed();
        assert_eq!(resp.status(), 200);
        assert!(
            elapsed >= Duration::from_millis(900),
            "expected ~1s wait from Retry-After, got {:?}",
            elapsed
        );
    }

    // NOTE(review): this checks only the kind of the final error; it does
    // not count attempts, so it would also pass if no retry happened. The
    // freed port could in principle be reused by another process before the
    // request is made, though that is unlikely in practice.
    #[tokio::test]
    async fn retries_on_connect_error() {
        // Pick a closed port: bind + drop a listener to get an unused port, then target it.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        let builder = client().post(format!("http://{}/chat", addr)).body("req");
        let err = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap_err();
        assert!(err.is_connect() || err.is_request(), "got {:?}", err);
    }

    /// Regression for user-reported crash: `send_with_retry: request
    /// body must be cloneable (no streams)` panic at runtime.
    ///
    /// Real trigger: a `.header(...)` call in the builder chain
    /// stashes an error (e.g. header value contains `\n` from a
    /// copy-pasted token, or base URL is malformed). The builder
    /// sits in `request: Err(...)` state. `try_clone()` returns
    /// None for builders-in-error-state, and the old code called
    /// `.expect("... must be cloneable")` — misleading message AND
    /// a full process crash where the user sees no path to
    /// recovery.
    ///
    /// After the fix `build_split()` pulls out the real error so
    /// callers receive a proper reqwest::Error with an actionable
    /// message instead of a panic.
    #[tokio::test]
    async fn send_with_retry_returns_builder_error_instead_of_panicking() {
        let result = std::panic::AssertUnwindSafe(async {
            let builder = client()
                .post("http://127.0.0.1:1/")
                // `\n` is illegal in an HTTP header value (ASCII
                // control chars are rejected by http::HeaderValue).
                // reqwest stashes the error in the builder and the
                // failure only surfaces when we try to clone or
                // build the request.
                .header("Authorization", "Bearer token-with\n-newline");
            send_with_retry(builder, &RetryPolicy::testing()).await
        });
        // The test runs the future inside catch_unwind so a panic
        // would fail cleanly (AssertUnwindSafe bridges the closure).
        let outcome = futures::FutureExt::catch_unwind(result).await;
        let inner = match outcome {
            Ok(r) => r,
            Err(_) => panic!(
                "send_with_retry panicked on builder-error input \
                 (regression of the user's reported crash)"
            ),
        };
        let err = inner.expect_err(
            "builder with illegal header value must produce Err, \
             not Ok",
        );
        // Sanity: the error must not be our panic masquerading as
        // something else. reqwest::Error's Display should at least
        // reference the underlying issue — we don't pin an exact
        // phrase because reqwest's message text varies, but the
        // error must be a real `reqwest::Error` and `is_builder()`
        // must be true for builder-construction failures.
        assert!(
            err.is_builder(),
            "expected is_builder() error, got {:?}",
            err
        );
    }
}