Skip to main content

llmsdk_provider/middleware/
retry.rs

1//! Retry middleware with exponential backoff.
2//!
3//! Retries only errors that report [`ProviderError::is_retryable`] (typically
4//! HTTP 408 / 409 / 429 / 5xx). Non-retryable errors fail fast; the stream
5//! variant only retries before the stream opens, never mid-stream.
6//!
7//! # Runtime requirement
8//!
9//! Uses [`tokio::time::sleep`] for backoff, so the caller must run inside a
10//! tokio runtime (any flavor). No assumption is made about the wider
11//! `Provider` implementation's runtime, but in practice every `llmsdk-*`
12//! provider already uses tokio via `reqwest`.
13// Rust guideline compliant 2026-02-21
14
15use std::sync::Mutex;
16use std::time::{Duration, SystemTime, UNIX_EPOCH};
17
18use async_trait::async_trait;
19
20use crate::error::{ProviderError, Result};
21use crate::language_model::{CallOptions, GenerateResult, LanguageModel, StreamResult};
22
23use super::language_model::LanguageModelMiddleware;
24
/// Total number of calls made by default: the initial attempt plus retries.
pub const DEFAULT_MAX_ATTEMPTS: u32 = 3;
/// Delay slept before the first retry unless overridden.
pub const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Growth factor applied to the delay for each further retry by default.
pub const DEFAULT_BACKOFF_MULTIPLIER: f32 = 2.0;
/// Upper bound on any single backoff sleep by default.
pub const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(5);
/// Jitter ratio used by default; `0.0` keeps the backoff deterministic.
pub const DEFAULT_JITTER_RATIO: f32 = 0.0;
35
/// Middleware that retries failed calls with exponential backoff.
///
/// Retry policy:
///
/// - Only errors with [`ProviderError::is_retryable`] are retried; everything
///   else propagates immediately.
/// - [`Self::wrap_generate`]: re-issues the full request, making at most
///   `max_attempts` calls in total (the initial call included).
/// - [`Self::wrap_stream`]: retries opening the stream only. Once the stream
///   is open (any item, including [`crate::language_model::StreamPart::Error`],
///   has been delivered), the retry policy stops; callers decide whether to
///   re-issue the call.
///
/// Backoff is deterministic by default ([`DEFAULT_JITTER_RATIO`] is `0.0`).
/// Set a non-zero [`RetryMiddlewareBuilder::jitter_ratio`] if you have
/// hundreds of concurrent retriers hitting the same upstream.
///
/// # Examples
///
/// ```ignore
/// use std::sync::Arc;
/// use std::time::Duration;
/// use llmsdk_provider::{wrap_language_model, LanguageModel, LanguageModelMiddleware};
/// use llmsdk_provider::middleware::RetryMiddleware;
///
/// fn add_retry(model: Arc<dyn LanguageModel>) -> Arc<dyn LanguageModel> {
///     let retry = RetryMiddleware::builder()
///         .max_attempts(5)
///         .initial_backoff(Duration::from_millis(200))
///         .build();
///     wrap_language_model(model, [Arc::new(retry) as Arc<dyn LanguageModelMiddleware>])
/// }
/// ```
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Total number of calls allowed, counting the initial attempt.
    max_attempts: u32,
    /// Delay slept before the first retry.
    initial_backoff: Duration,
    /// Factor the delay is multiplied by for each further retry.
    backoff_multiplier: f32,
    /// Upper bound on any single backoff sleep.
    max_backoff: Duration,
    /// Jitter ratio `r` in `[0.0, 1.0]`; `0.0` disables jitter. The delay is
    /// scaled by a uniform factor drawn from `[1 - r/2, 1 + r/2)`, i.e. it is
    /// perturbed symmetrically around the base value, and then capped at
    /// `max_backoff` again.
    jitter_ratio: f32,
    /// `SplitMix64` state, seeded from `SystemTime` nanos by `Default`.
    /// Wrapped in a `Mutex` so a `&self` retry callback can mutate it;
    /// uncontended in practice (one mutation per jittered backoff).
    rng: Mutex<u64>,
}
82
83impl Clone for RetryMiddleware {
84    fn clone(&self) -> Self {
85        Self {
86            max_attempts: self.max_attempts,
87            initial_backoff: self.initial_backoff,
88            backoff_multiplier: self.backoff_multiplier,
89            max_backoff: self.max_backoff,
90            jitter_ratio: self.jitter_ratio,
91            rng: Mutex::new(*self.rng.lock().expect("rng mutex poisoned")),
92        }
93    }
94}
95
96impl Default for RetryMiddleware {
97    fn default() -> Self {
98        Self {
99            max_attempts: DEFAULT_MAX_ATTEMPTS,
100            initial_backoff: DEFAULT_INITIAL_BACKOFF,
101            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
102            max_backoff: DEFAULT_MAX_BACKOFF,
103            jitter_ratio: DEFAULT_JITTER_RATIO,
104            rng: Mutex::new(seed_from_clock()),
105        }
106    }
107}
108
109/// Mix the current wall-clock nanoseconds into a 64-bit seed.
110#[allow(
111    clippy::cast_possible_truncation,
112    reason = "low 64 bits of clock are intentionally taken as PRNG seed"
113)]
114fn seed_from_clock() -> u64 {
115    let nanos = SystemTime::now()
116        .duration_since(UNIX_EPOCH)
117        .map_or(0xDEAD_BEEF_CAFE_BABE, |d| d.as_nanos() as u64);
118    // One mix step so callers that build many middlewares back-to-back don't
119    // get correlated streams.
120    splitmix64(&mut { nanos })
121}
122
/// Advance a `SplitMix64` generator by one step.
///
/// Adds the golden-ratio increment to `state` (wrapping) and returns the
/// updated state after the two multiply/xor-shift finalization rounds.
fn splitmix64(state: &mut u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    let advanced = state.wrapping_add(GOLDEN_GAMMA);
    *state = advanced;

    let round1 = (advanced ^ (advanced >> 30)).wrapping_mul(MIX_A);
    let round2 = (round1 ^ (round1 >> 27)).wrapping_mul(MIX_B);
    round2 ^ (round2 >> 31)
}
131
132impl RetryMiddleware {
133    /// Build with default policy ([`DEFAULT_MAX_ATTEMPTS`], 100ms initial,
134    /// x2 multiplier, 5s cap).
135    #[must_use]
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    /// Open a builder for non-default tuning.
141    #[must_use]
142    pub fn builder() -> RetryMiddlewareBuilder {
143        RetryMiddlewareBuilder(Self::default())
144    }
145
146    /// Compute the backoff for retry index `attempt` (0-based: 0 = before
147    /// first retry, 1 = before second retry, ...).
148    ///
149    /// When `jitter_ratio > 0`, the result is uniformly perturbed within
150    /// `[base * (1 - r/2), base * (1 + r/2)]` and re-clamped to `max_backoff`.
151    fn backoff_for(&self, attempt: u32) -> Duration {
152        // Use `as_secs_f64` / `from_secs_f64` so the math stays in `f64`
153        // throughout and `Duration` handles bounds for us.
154        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
155        let factor = f64::from(self.backoff_multiplier).powi(exponent);
156        let mut scaled = self.initial_backoff.as_secs_f64() * factor;
157        if !scaled.is_finite() || scaled <= 0.0 {
158            return self.initial_backoff;
159        }
160        let cap = self.max_backoff.as_secs_f64();
161        scaled = scaled.min(cap);
162        if self.jitter_ratio > 0.0 {
163            scaled = self.apply_jitter(scaled).min(cap);
164        }
165        Duration::from_secs_f64(scaled.max(0.0))
166    }
167
168    /// Multiply `base` by a uniform factor in `[1 - r/2, 1 + r/2]` where
169    /// `r = jitter_ratio.clamp(0.0, 1.0)`.
170    #[allow(
171        clippy::cast_precision_loss,
172        reason = "f64 mantissa is 52 bits; raw is masked to 53 bits before the cast"
173    )]
174    fn apply_jitter(&self, base: f64) -> f64 {
175        let r = f64::from(self.jitter_ratio.clamp(0.0, 1.0));
176        // Sample uniform u in [0, 1).
177        let raw = {
178            let mut guard = self.rng.lock().expect("rng mutex poisoned");
179            splitmix64(&mut guard)
180        };
181        let u = (raw >> 11) as f64 / (1u64 << 53) as f64;
182        let factor = 1.0 + r * (u - 0.5);
183        base * factor
184    }
185}
186
187/// Builder for [`RetryMiddleware`]; create via [`RetryMiddleware::builder`].
188#[derive(Debug)]
189pub struct RetryMiddlewareBuilder(RetryMiddleware);
190
191impl RetryMiddlewareBuilder {
192    /// Set the maximum number of attempts (must be `>= 1`; `1` disables retry).
193    #[must_use]
194    pub fn max_attempts(mut self, attempts: u32) -> Self {
195        self.0.max_attempts = attempts.max(1);
196        self
197    }
198
199    /// Set the initial backoff applied before the first retry.
200    #[must_use]
201    pub fn initial_backoff(mut self, dur: Duration) -> Self {
202        self.0.initial_backoff = dur;
203        self
204    }
205
206    /// Set the multiplier applied between successive retries.
207    #[must_use]
208    pub fn backoff_multiplier(mut self, factor: f32) -> Self {
209        self.0.backoff_multiplier = factor.max(1.0);
210        self
211    }
212
213    /// Set the upper bound on a single backoff sleep.
214    #[must_use]
215    pub fn max_backoff(mut self, dur: Duration) -> Self {
216        self.0.max_backoff = dur;
217        self
218    }
219
220    /// Set the full-jitter ratio (clamped to `[0.0, 1.0]`).
221    ///
222    /// `0.0` (default) means deterministic backoff. `1.0` spreads each sleep
223    /// uniformly over `[base/2, base*1.5]`. Use a non-zero value when many
224    /// callers retry the same upstream simultaneously to avoid thundering-herd.
225    #[must_use]
226    pub fn jitter_ratio(mut self, ratio: f32) -> Self {
227        self.0.jitter_ratio = ratio.clamp(0.0, 1.0);
228        self
229    }
230
231    /// Finalize the middleware.
232    #[must_use]
233    pub fn build(self) -> RetryMiddleware {
234        self.0
235    }
236}
237
238#[async_trait]
239impl LanguageModelMiddleware for RetryMiddleware {
240    async fn wrap_generate(
241        &self,
242        next: &dyn LanguageModel,
243        params: CallOptions,
244    ) -> Result<GenerateResult> {
245        let mut attempt: u32 = 0;
246        loop {
247            let outcome = next.do_generate(params.clone()).await;
248            match outcome {
249                Ok(result) => return Ok(result),
250                Err(err) => {
251                    if !should_retry(&err, attempt, self.max_attempts) {
252                        return Err(err);
253                    }
254                    tokio::time::sleep(self.backoff_for(attempt)).await;
255                    attempt += 1;
256                }
257            }
258        }
259    }
260
261    async fn wrap_stream(
262        &self,
263        next: &dyn LanguageModel,
264        params: CallOptions,
265    ) -> Result<StreamResult> {
266        let mut attempt: u32 = 0;
267        loop {
268            let outcome = next.do_stream(params.clone()).await;
269            match outcome {
270                Ok(result) => return Ok(result),
271                Err(err) => {
272                    if !should_retry(&err, attempt, self.max_attempts) {
273                        return Err(err);
274                    }
275                    tokio::time::sleep(self.backoff_for(attempt)).await;
276                    attempt += 1;
277                }
278            }
279        }
280    }
281}
282
283/// True when `err` is retryable and we have attempts left.
284///
285/// `attempt` is the zero-based index of the *failed* attempt: 0 = first call
286/// just failed, so we still have `max_attempts - 1` retries available.
287fn should_retry(err: &ProviderError, attempt: u32, max_attempts: u32) -> bool {
288    err.is_retryable() && attempt + 1 < max_attempts
289}
290
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::language_model::{FinishReason, FinishReasonKind, Usage};

    use super::*;

    /// Mock model that fails the first `fail_until` calls with a configurable
    /// error, then succeeds. Counts every call it receives.
    #[derive(Debug)]
    struct FlakyModel {
        provider: String,
        model_id: String,
        /// Number of leading calls that fail.
        fail_until: u32,
        /// Produces the error returned by each failing call.
        make_error: fn() -> ProviderError,
        /// Total calls seen so far (`do_generate` and `do_stream` combined).
        call_count: AtomicUsize,
    }

    impl FlakyModel {
        fn new(fail_until: u32, err_factory: fn() -> ProviderError) -> Self {
            Self {
                provider: "test".to_owned(),
                model_id: "flaky".to_owned(),
                fail_until,
                make_error: err_factory,
                call_count: AtomicUsize::new(0),
            }
        }

        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Record one call and return the configured error while the call
        /// index is still inside the failing window.
        fn next_failure(&self) -> Option<ProviderError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            u32::try_from(n)
                .is_ok_and(|n| n < self.fail_until)
                .then(|| (self.make_error)())
        }
    }

    /// Error the retry policy is expected to retry (HTTP 503).
    fn retryable_503() -> ProviderError {
        ProviderError::api_call_builder("https://api.test", "service unavailable")
            .status_code(503)
            .build()
    }

    /// Error the retry policy is expected to surface immediately (HTTP 400).
    fn non_retryable_400() -> ProviderError {
        ProviderError::api_call_builder("https://api.test", "bad request")
            .status_code(400)
            .build()
    }

    /// Minimal successful generation result.
    fn ok_result() -> GenerateResult {
        GenerateResult {
            content: vec![],
            finish_reason: FinishReason::new(FinishReasonKind::Stop),
            usage: Usage::default(),
            provider_metadata: None,
            request: None,
            response: None,
            warnings: vec![],
        }
    }

    #[async_trait]
    impl LanguageModel for FlakyModel {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            match self.next_failure() {
                Some(err) => Err(err),
                None => Ok(ok_result()),
            }
        }

        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            if let Some(err) = self.next_failure() {
                return Err(err);
            }
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(Vec::new())),
                request: None,
                response: None,
            })
        }
    }

    /// Two retryable failures followed by a success fit in three attempts.
    #[tokio::test(start_paused = true)]
    async fn retries_retryable_then_succeeds() {
        let model = Arc::new(FlakyModel::new(2, retryable_503));
        let retry = RetryMiddleware::builder()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(10))
            .build();
        retry
            .wrap_generate(&*model, CallOptions::default())
            .await
            .expect("third attempt succeeds");
        assert_eq!(model.calls(), 3, "two failures + one success");
    }

    /// A non-retryable error is returned after a single call.
    #[tokio::test(start_paused = true)]
    async fn non_retryable_fails_fast() {
        let model = Arc::new(FlakyModel::new(5, non_retryable_400));
        let retry = RetryMiddleware::builder().max_attempts(5).build();
        let err = retry
            .wrap_generate(&*model, CallOptions::default())
            .await
            .expect_err("non-retryable error propagates");
        assert!(!err.is_retryable());
        assert_eq!(model.calls(), 1, "no retries for non-retryable error");
    }

    /// When every attempt fails, the last error is returned after exactly
    /// `max_attempts` calls.
    #[tokio::test(start_paused = true)]
    async fn exhausts_attempts_and_returns_last_error() {
        let model = Arc::new(FlakyModel::new(10, retryable_503));
        let retry = RetryMiddleware::builder()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .build();
        let err = retry
            .wrap_generate(&*model, CallOptions::default())
            .await
            .expect_err("attempts exhausted");
        assert_eq!(err.status_code(), Some(503));
        assert_eq!(model.calls(), 3, "max_attempts == 3 total calls");
    }

    /// `max_attempts(1)` means the initial call only.
    #[tokio::test(start_paused = true)]
    async fn max_attempts_one_disables_retry() {
        let model = Arc::new(FlakyModel::new(5, retryable_503));
        let retry = RetryMiddleware::builder().max_attempts(1).build();
        let err = retry
            .wrap_generate(&*model, CallOptions::default())
            .await
            .expect_err("first failure propagates");
        assert!(err.is_retryable());
        assert_eq!(model.calls(), 1);
    }

    /// Failures to open the stream are retried like generate failures.
    #[tokio::test(start_paused = true)]
    async fn stream_retries_open_failures() {
        let model = Arc::new(FlakyModel::new(2, retryable_503));
        let retry = RetryMiddleware::builder()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .build();
        retry
            .wrap_stream(&*model, CallOptions::default())
            .await
            .expect("stream opens on third attempt");
        assert_eq!(model.calls(), 3);
    }

    /// The exponential delay never exceeds `max_backoff`.
    #[test]
    fn backoff_caps_at_max() {
        let retry = RetryMiddleware::builder()
            .initial_backoff(Duration::from_millis(100))
            .backoff_multiplier(10.0)
            .max_backoff(Duration::from_secs(1))
            .build();
        // attempt 0 -> 100ms, attempt 1 -> 1000ms, attempt 2 -> capped at 1s.
        assert_eq!(retry.backoff_for(0), Duration::from_millis(100));
        assert_eq!(retry.backoff_for(1), Duration::from_secs(1));
        assert_eq!(retry.backoff_for(2), Duration::from_secs(1));
    }

    /// With ratio 0.5 every sample stays within +/-25% of the base delay.
    #[test]
    fn jitter_perturbs_within_expected_range() {
        let retry = RetryMiddleware::builder()
            .initial_backoff(Duration::from_millis(100))
            .backoff_multiplier(1.0) // keep base constant across attempts
            .jitter_ratio(0.5) // -25%..+25%
            .max_backoff(Duration::from_secs(10))
            .build();
        let base = 100.0;
        let lo = base * (1.0 - 0.25);
        let hi = base * (1.0 + 0.25);
        for _ in 0..32 {
            let sample = retry.backoff_for(0).as_secs_f64() * 1000.0;
            assert!(
                sample >= lo && sample <= hi,
                "jitter sample {sample}ms outside [{lo},{hi}]"
            );
        }
    }
}