Skip to main content

ailoop_core/
retry.rs

//! Reliability decorator: [`RetryingModel`] wraps any [`CompletionModel`]
//! and retries setup-time failures with exponential backoff that honours
//! `Retry-After` hints surfaced by the adapter.
//!
//! Each adapter classifies its own error type via the [`Retryable`] trait
//! (`AnthropicError` in `ailoop-anthropic`, `AzureOpenAIError` in
//! `ailoop-azure-openai`); the decorator only needs to know whether an
//! error is `Permanent` or `Transient { retry_after }`.
//!
//! ## Scope
//!
//! Only the call to [`CompletionModel::chat_stream`] is retried. Once the
//! stream is open, mid-stream errors propagate to the caller unchanged —
//! the engine has already consumed deltas and built partial assistant
//! state, and replaying from that point is not safe in general. Reissuing
//! the entire request from the engine layer would also discard work
//! the model has already done. If retry-on-mid-stream becomes important
//! later, it belongs above this layer (in the engine), not here.
19
20use std::time::Duration;
21
22use async_trait::async_trait;
23use futures::stream::BoxStream;
24
25use crate::request::ChatRequest;
26use crate::stream::StreamChunk;
27use crate::traits::CompletionModel;
28
/// Provider-side verdict on whether a failed call is worth repeating.
///
/// Each provider's error type (e.g. `AnthropicError`, `AzureOpenAIError`)
/// implements this, which keeps [`RetryingModel`] free of any
/// provider-specific knowledge.
pub trait Retryable {
    /// Report whether this error should be retried and, when the server
    /// said so, how long to wait before the next attempt.
    fn retry_classification(&self) -> RetryClassification;
}

/// Answer returned by [`Retryable::retry_classification`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryClassification {
    /// Give up immediately: authentication, validation, schema errors and
    /// anything else that would fail the same way on a second try.
    Permanent,
    /// The failure may clear up on its own; try again after a delay.
    Transient {
        /// Delay requested by the server (parsed from the `Retry-After` /
        /// `retry-after-ms` headers). `Some` is used as the minimum wait:
        /// jitter may lengthen it but never shorten it, and it is not
        /// clamped to the configured maximum backoff. `None` lets the
        /// decorator pick the delay from its own exponential schedule.
        retry_after: Option<Duration>,
    },
}
55
/// Backoff settings for [`RetryingModel`].
///
/// `max_attempts` counts the *total* number of calls to the inner model,
/// including the first one: `max_attempts: 3` means at most two retries
/// after the original failure.
///
/// [`RetryConfig::default`] gives 3 attempts, a 1 s base delay, a 30 s
/// backoff cap and jitter enabled.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RetryConfig {
    /// Total number of calls to the inner model, **including the
    /// first one**. `1` disables retries; `3` allows two retries
    /// after the initial failure. `0` is treated the same as `1`.
    pub max_attempts: u32,
    /// Delay before the first retry when the server supplied no
    /// `Retry-After`. Doubles on each subsequent retry (`base_delay`,
    /// `2 * base_delay`, `4 * base_delay`, ...) until it reaches
    /// [`Self::max_delay`].
    pub base_delay: Duration,
    /// Upper bound on the computed exponential backoff. It does **not**
    /// limit a server-supplied `Retry-After`, and jitter (when enabled)
    /// may add up to 10% on top of it.
    pub max_delay: Duration,
    /// When `true`, lengthen each delay by a random 0–10% so concurrent
    /// clients don't retry in lockstep. Jitter only ever adds time, so a
    /// server-supplied `Retry-After` remains a lower bound on the wait.
    pub jitter: bool,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            jitter: true,
        }
    }
}
91
92/// Decorator that retries setup-time failures of an inner
93/// [`CompletionModel`] using exponential backoff with optional jitter,
94/// honouring `Retry-After` when the inner error provides it.
95pub struct RetryingModel<M> {
96    inner: M,
97    config: RetryConfig,
98}
99
100impl<M> RetryingModel<M> {
101    /// Wrap `inner` with [`RetryConfig::default`].
102    pub fn new(inner: M) -> Self {
103        Self::with_config(inner, RetryConfig::default())
104    }
105
106    /// Wrap `inner` with explicit retry settings.
107    pub fn with_config(inner: M, config: RetryConfig) -> Self {
108        Self { inner, config }
109    }
110
111    /// Borrow the wrapped model. Useful for tests that need to assert
112    /// on the inner state without unwrapping.
113    pub fn inner(&self) -> &M {
114        &self.inner
115    }
116
117    /// Unwrap and return the inner model, dropping the retry layer.
118    pub fn into_inner(self) -> M {
119        self.inner
120    }
121}
122
123#[async_trait]
124impl<M> CompletionModel for RetryingModel<M>
125where
126    M: CompletionModel + Send + Sync,
127    M::Error: Retryable,
128{
129    type Error = M::Error;
130
131    fn name(&self) -> &str {
132        self.inner.name()
133    }
134
135    fn model(&self) -> &str {
136        self.inner.model()
137    }
138
139    async fn chat_stream(
140        &self,
141        req: ChatRequest,
142    ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
143        let max = self.config.max_attempts.max(1);
144        let mut attempt: u32 = 0;
145        loop {
146            let try_req = req.clone();
147            let err = match self.inner.chat_stream(try_req).await {
148                Ok(stream) => return Ok(stream),
149                Err(e) => e,
150            };
151            attempt += 1;
152            if attempt >= max {
153                return Err(err);
154            }
155            let delay = match err.retry_classification() {
156                RetryClassification::Permanent => return Err(err),
157                RetryClassification::Transient { retry_after } => {
158                    compute_delay(&self.config, attempt, retry_after)
159                }
160            };
161            tokio::time::sleep(delay).await;
162        }
163    }
164}
165
166fn compute_delay(cfg: &RetryConfig, attempt: u32, retry_after: Option<Duration>) -> Duration {
167    let base = match retry_after {
168        Some(d) => d,
169        None => exponential(cfg.base_delay, cfg.max_delay, attempt),
170    };
171    if cfg.jitter { apply_jitter(base) } else { base }
172}
173
/// Exponential backoff delay for the given retry number, clamped to `max`.
///
/// `attempt` is 1-based on the first retry, so the result is
/// `base * 2^(attempt - 1)` capped at `max`; `attempt == 0` behaves like
/// `1`. All arithmetic saturates, so huge attempt counts or durations
/// never overflow or panic.
fn exponential(base: Duration, max: Duration, attempt: u32) -> Duration {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    // Clamp the shift only to the width of u128 so `<<` cannot overflow;
    // `saturating_mul` absorbs any product that no longer fits. A smaller
    // clamp would silently stop the doubling before `max` is reached.
    let shift = attempt.saturating_sub(1).min(u128::BITS - 1);
    let factor: u128 = 1u128 << shift;
    let nanos = base.as_nanos().saturating_mul(factor);
    let capped = nanos.min(max.as_nanos());

    // Rebuild from seconds + remainder so caps beyond u64::MAX nanoseconds
    // are returned exactly instead of being truncated.
    let secs = u64::try_from(capped / NANOS_PER_SEC).unwrap_or(u64::MAX);
    // The remainder is below one billion, so the cast cannot truncate.
    let subsec = (capped % NANOS_PER_SEC) as u32;
    Duration::new(secs, subsec)
}
182
/// Lengthens `d` by a pseudo-random 0–10% so concurrent clients spread
/// their retries out. The offset is never negative, so a server-supplied
/// `Retry-After` passed in as `d` is still respected as a minimum.
///
/// Randomness comes from the standard library's randomly keyed
/// `RandomState` hasher mixed with the current time. It is not
/// cryptographically secure, but it is enough to de-synchronize clients
/// without an extra dependency.
fn apply_jitter(d: Duration) -> Duration {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Resolution of the random fraction: the offset is chosen in steps of
    // 1/PARTS of the maximum jitter.
    const PARTS: u128 = 1_000_000;
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    let nanos = d.as_nanos();
    let jitter_max = nanos / 10;
    if jitter_max == 0 {
        // Zero or sub-10ns delays have no room for jitter.
        return d;
    }

    // Each `RandomState::new()` yields fresh hasher keys; mixing in the
    // clock further varies the draw across calls and processes.
    let mut hasher = RandomState::new().build_hasher();
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    hasher.write_u128(now.as_nanos());
    let fraction = u128::from(hasher.finish()) % PARTS;

    // Scale the fraction onto [0, jitter_max). The previous
    // `subsec_nanos % jitter_max` could never exceed one second, so any
    // delay longer than 10 s got far less than the documented 10%.
    let offset = jitter_max.saturating_mul(fraction) / PARTS;
    let total = nanos.saturating_add(offset);

    match u64::try_from(total / NANOS_PER_SEC) {
        // The remainder is below one billion, so the cast cannot truncate.
        Ok(secs) => Duration::new(secs, (total % NANOS_PER_SEC) as u32),
        Err(_) => Duration::MAX,
    }
}
203
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Instant;

    use async_trait::async_trait;
    use futures::StreamExt;
    use futures::stream::{self, BoxStream};

    use super::*;
    use crate::stream::{FinishReason, Usage};
    use crate::testing::{ScriptedError, ScriptedModel, ScriptedTurn};

    fn empty_request() -> ChatRequest {
        ChatRequest::new(vec![], 0)
    }

    /// Millisecond-scale backoff with jitter off, so retry tests finish
    /// quickly and their timing is deterministic.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            jitter: false,
        }
    }

    /// Wraps a `ScriptedModel` and counts how many times `chat_stream`
    /// was invoked. Used to assert on retry behaviour from the outside.
    struct CountingModel {
        inner: ScriptedModel,
        calls: AtomicUsize,
    }

    impl CountingModel {
        fn new(turns: Vec<ScriptedTurn>) -> Self {
            Self {
                inner: ScriptedModel::with_turns(turns),
                calls: AtomicUsize::new(0),
            }
        }

        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CompletionModel for CountingModel {
        type Error = ScriptedError;

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn model(&self) -> &str {
            self.inner.model()
        }

        async fn chat_stream(
            &self,
            req: ChatRequest,
        ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.inner.chat_stream(req).await
        }
    }

    /// A single end-of-turn chunk, used as the payload of successful turns.
    fn ok_chunk() -> StreamChunk {
        StreamChunk::TurnFinished {
            reason: FinishReason::EndTurn,
            usage: Usage::default(),
            service_tier: None,
        }
    }

    #[tokio::test]
    async fn retries_until_success() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("retry should succeed on second attempt");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 1);
        assert_eq!(model.inner().calls(), 2);
    }

    #[tokio::test]
    async fn gives_up_after_max_attempts() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            3,
            "max_attempts is total calls including the first"
        );
    }

    #[tokio::test]
    async fn respects_retry_after() {
        // `fast_config` caps backoff at 5ms with no jitter, so a wait of at
        // least 50ms can only come from the retry-after hint scripted into
        // the error (the `:50` suffix — presumably milliseconds; see
        // `crate::testing`).
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:50".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let started = Instant::now();
        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("second attempt should succeed");
        let _: Vec<_> = stream.collect().await;
        let elapsed = started.elapsed();
        assert!(
            elapsed >= Duration::from_millis(50),
            "expected at least 50ms wait, got {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn permanent_errors_are_not_retried() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("permanent: bad auth".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "permanent errors must not trigger a retry"
        );
    }

    #[tokio::test]
    async fn mid_stream_errors_are_not_retried() {
        // The setup call succeeds; the stream then yields one chunk and
        // a transient-looking error. RetryingModel must not reissue the
        // request — the engine has already consumed deltas.
        let inner = CountingModel::new(vec![
            Ok(vec![
                Ok(StreamChunk::TextDelta {
                    delta: "hello".into(),
                }),
                Err(ScriptedError("transient:1".into())),
            ]),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("setup should succeed");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(chunks[0], Ok(StreamChunk::TextDelta { .. })));
        assert!(matches!(chunks[1], Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "mid-stream errors must not trigger a setup-time retry"
        );
    }

    /// Pure arithmetic on the backoff schedule; no runtime needed.
    #[test]
    fn exponential_backoff_grows_then_caps() {
        let base = Duration::from_millis(10);
        let max = Duration::from_millis(40);
        assert_eq!(exponential(base, max, 1), Duration::from_millis(10));
        assert_eq!(exponential(base, max, 2), Duration::from_millis(20));
        assert_eq!(exponential(base, max, 3), Duration::from_millis(40));
        assert_eq!(exponential(base, max, 4), Duration::from_millis(40));
    }

    /// Sanity check that `BoxStream` from `stream::iter` still conforms
    /// to what `chat_stream` is supposed to return — guards against the
    /// test scaffolding accidentally lying about the type.
    #[tokio::test]
    async fn box_stream_conforms() {
        let s: BoxStream<'static, Result<StreamChunk, ScriptedError>> =
            Box::pin(stream::iter(vec![Ok(ok_chunk())]));
        let chunks: Vec<_> = s.collect().await;
        assert_eq!(chunks.len(), 1);
    }

    /// Checks that a `std::sync::Mutex` guard confined to an inner block is
    /// dropped before the following `.await`, so the lock can be taken
    /// again afterwards. This does not exercise `RetryingModel`; it only
    /// pins the guard-scoping pattern.
    #[tokio::test]
    async fn no_lock_held_across_await() {
        let m = Mutex::new(0u32);
        {
            let mut g = m.lock().unwrap();
            *g += 1;
        }
        tokio::task::yield_now().await;
        assert_eq!(*m.lock().unwrap(), 1);
    }
}