ailoop-core 1.0.0-rc.2

Core vocabulary (messages, streams, hooks, middleware) for the ailoop SDK
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
//! Reliability decorator: [`RetryingModel`] wraps any [`CompletionModel`]
//! and retries setup-time failures with exponential backoff that honours
//! `Retry-After` hints surfaced by the adapter.
//!
//! Each adapter classifies its own error type via the [`Retryable`] trait
//! (`AnthropicError` in `ailoop-anthropic`, `AzureOpenAIError` in
//! `ailoop-azure-openai`); the decorator only needs to know whether an
//! error is `Permanent` or `Transient { retry_after }`.
//!
//! ## Scope
//!
//! Only the call to [`CompletionModel::chat_stream`] is retried. Once the
//! stream is open, mid-stream errors propagate to the caller unchanged —
//! the engine has already consumed deltas and built partial assistant
//! state, and replaying from that point is not safe in general. Reissuing
//! the entire request from the engine layer would also discard work
//! the model has already done. If retry-on-mid-stream becomes important
//! later it belongs above this layer (in the engine), not here.

use std::time::Duration;

use async_trait::async_trait;
use futures::stream::BoxStream;

use crate::request::ChatRequest;
use crate::stream::StreamChunk;
use crate::traits::CompletionModel;

/// Lets a provider's error type tell [`RetryingModel`] whether a failed
/// call is worth repeating.
///
/// Each adapter implements this for its own error (for example
/// `AnthropicError` or `AzureOpenAIError`), which keeps the retry
/// decorator free of provider-specific knowledge.
pub trait Retryable {
    /// Classify `self` as permanent, or as transient with an optional
    /// server-requested wait.
    fn retry_classification(&self) -> RetryClassification;
}

/// Verdict produced by [`Retryable::retry_classification`].
///
/// Tells the decorator whether to try the call again and, when the server
/// said so, how long to wait before doing it.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryClassification {
    /// Give up immediately; repeating the call cannot help (auth,
    /// validation, schema errors, etc.).
    Permanent,
    /// Another attempt is worthwhile.
    Transient {
        /// Wait requested by the server, e.g. parsed from a `Retry-After`
        /// or `retry-after-ms` header. `Some` is treated as the minimum
        /// delay (jitter can only lengthen it); `None` means the decorator
        /// picks the delay from its own exponential schedule.
        retry_after: Option<Duration>,
    },
}

/// Tuning knobs for [`RetryingModel`]'s backoff.
///
/// `max_attempts` is a budget for *all* calls to the inner model, the first
/// one included: `max_attempts: 3` allows the original call plus at most
/// two retries.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on calls to the inner model, **counting the first
    /// one**. `1` turns retrying off; `3` permits two retries after the
    /// initial failure.
    pub max_attempts: u32,
    /// Delay before the first retry. Each later retry waits twice as long
    /// as the previous one (`base_delay`, `2 * base_delay`,
    /// `4 * base_delay`, ...) until [`Self::max_delay`] is reached.
    pub base_delay: Duration,
    /// Ceiling on the exponential schedule; once reached, the computed
    /// delay stays flat. A server-supplied `Retry-After` is not clamped
    /// by this value.
    pub max_delay: Duration,
    /// If set, lengthen every delay by a random 0–10% so that clients
    /// failing together do not retry in lockstep. Jitter never shortens a
    /// delay, so a server-supplied `Retry-After` is still a minimum.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three total attempts, a 1 s first delay capped at 30 s, jitter on.
    fn default() -> Self {
        let base_delay = Duration::from_millis(1_000);
        let max_delay = Duration::from_millis(30_000);
        Self {
            jitter: true,
            max_attempts: 3,
            max_delay,
            base_delay,
        }
    }
}

/// Wraps an inner [`CompletionModel`] and retries failures that happen
/// while *opening* a stream. It backs off exponentially (with optional
/// jitter) and honours any `Retry-After` the inner error reports.
pub struct RetryingModel<M> {
    inner: M,
    config: RetryConfig,
}

impl<M> RetryingModel<M> {
    /// Decorate `inner` using the stock [`RetryConfig::default`] settings.
    pub fn new(inner: M) -> Self {
        Self {
            inner,
            config: RetryConfig::default(),
        }
    }

    /// Decorate `inner` using caller-supplied retry settings.
    pub fn with_config(inner: M, config: RetryConfig) -> Self {
        Self { config, inner }
    }

    /// Shared access to the decorated model, e.g. so tests can inspect its
    /// state without consuming the wrapper.
    pub fn inner(&self) -> &M {
        &self.inner
    }

    /// Consume the wrapper and hand back the decorated model; the retry
    /// settings are discarded.
    pub fn into_inner(self) -> M {
        let Self { inner, .. } = self;
        inner
    }
}

#[async_trait]
impl<M> CompletionModel for RetryingModel<M>
where
    M: CompletionModel + Send + Sync,
    M::Error: Retryable,
{
    type Error = M::Error;

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Open a stream on the inner model, retrying setup-time failures.
    ///
    /// - Every attempt except the last works on a clone of `req`.
    /// - A `Permanent` error is returned at once.
    /// - A `Transient` error sleeps for the computed backoff before the
    ///   next try.
    /// - The final attempt receives `req` itself, and its outcome (success
    ///   or any error) is returned unchanged.
    ///
    /// Errors yielded *by* an already-opened stream are never retried.
    async fn chat_stream(
        &self,
        req: ChatRequest,
    ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
        // A budget of 0 would mean "never call"; treat it as a single call.
        let budget = self.config.max_attempts.max(1);

        // `attempt` is the 1-based number of the call that just failed,
        // which is exactly what the backoff schedule expects.
        for attempt in 1..budget {
            let failure = match self.inner.chat_stream(req.clone()).await {
                Ok(stream) => return Ok(stream),
                Err(failure) => failure,
            };
            let pause = match failure.retry_classification() {
                RetryClassification::Transient { retry_after } => {
                    compute_delay(&self.config, attempt, retry_after)
                }
                RetryClassification::Permanent => return Err(failure),
            };
            tokio::time::sleep(pause).await;
        }

        // Last permitted call: no retry can follow, so move the request
        // instead of cloning it and pass the result straight through.
        self.inner.chat_stream(req).await
    }
}

/// Pick the pause before retry number `attempt` (1-based).
///
/// A server-provided `retry_after` wins outright and is not clamped to
/// `cfg.max_delay`. Otherwise the exponential schedule from `cfg` is used.
/// Jitter, when enabled, is applied on top of whichever value was chosen.
fn compute_delay(cfg: &RetryConfig, attempt: u32, retry_after: Option<Duration>) -> Duration {
    let chosen =
        retry_after.unwrap_or_else(|| exponential(cfg.base_delay, cfg.max_delay, attempt));
    if !cfg.jitter {
        return chosen;
    }
    apply_jitter(chosen)
}

/// Exponential backoff: `base * 2^(attempt - 1)`, clamped to `max`.
///
/// `attempt` is 1-based: the first retry passes `1` and gets `base`, and
/// `0` is treated like `1`. The arithmetic is done in `u128` nanoseconds
/// with saturation, so huge attempt counts or delays clamp to `max` instead
/// of overflowing.
fn exponential(base: Duration, max: Duration, attempt: u32) -> Duration {
    let shift = attempt.saturating_sub(1);
    // Don't cap the exponent at a small constant. That would stop growth
    // short of `max` whenever `max > base * 2^cap`. `checked_shl` only fails
    // for shifts >= 128; treat that as "effectively infinite" and let the
    // saturating multiply plus `min(max)` clamp it.
    let factor = 1u128.checked_shl(shift).unwrap_or(u128::MAX);
    let nanos = base.as_nanos().saturating_mul(factor).min(max.as_nanos());
    Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
}

/// Spreads concurrent clients by adding a pseudo-random 0–10% to `d`.
///
/// The offset is only ever added, never subtracted, so a server-supplied
/// `Retry-After` passed in as `d` remains a floor. Entropy comes from the
/// sub-second nanoseconds of the wall clock: cheap and dependency-free, but
/// not cryptographically random.
fn apply_jitter(d: Duration) -> Duration {
    let nanos = d.as_nanos();
    if nanos == 0 {
        return d;
    }
    let entropy = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    let offset = scaled_jitter_offset(nanos / 10, entropy);
    let total = nanos.saturating_add(offset);
    Duration::from_nanos(u64::try_from(total).unwrap_or(u64::MAX))
}

/// Maps `entropy` (a sub-second nanosecond count in `0..1_000_000_000`)
/// proportionally onto `0..jitter_max`.
///
/// Scaling is used instead of `entropy % jitter_max` because the entropy
/// never reaches one second. A modulo would therefore cap the offset below
/// 1 s for every delay above 10 s, breaking the 0–10% promise.
fn scaled_jitter_offset(jitter_max: u128, entropy: u32) -> u128 {
    const NANOS_PER_SEC: u128 = 1_000_000_000;
    // `subsec_nanos` is always < 1e9; clamp anyway so the result can never
    // reach `jitter_max`.
    let fraction = u128::from(entropy).min(NANOS_PER_SEC - 1);
    jitter_max.saturating_mul(fraction) / NANOS_PER_SEC
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Instant;

    use async_trait::async_trait;
    use futures::StreamExt;
    use futures::stream::{self, BoxStream};

    use super::*;
    use crate::stream::{FinishReason, Usage};
    use crate::testing::{ScriptedError, ScriptedModel, ScriptedTurn};

    fn empty_request() -> ChatRequest {
        ChatRequest::new(vec![], 0)
    }

    /// Millisecond-scale, jitter-free config so retry tests run fast and
    /// deterministically.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            jitter: false,
        }
    }

    /// Wraps a `ScriptedModel` and counts how many times `chat_stream`
    /// was invoked. Used to assert on retry behaviour from the outside.
    struct CountingModel {
        inner: ScriptedModel,
        calls: AtomicUsize,
    }

    impl CountingModel {
        fn new(turns: Vec<ScriptedTurn>) -> Self {
            Self {
                inner: ScriptedModel::with_turns(turns),
                calls: AtomicUsize::new(0),
            }
        }

        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CompletionModel for CountingModel {
        type Error = ScriptedError;

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn model(&self) -> &str {
            self.inner.model()
        }

        async fn chat_stream(
            &self,
            req: ChatRequest,
        ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.inner.chat_stream(req).await
        }
    }

    fn ok_chunk() -> StreamChunk {
        StreamChunk::TurnFinished {
            reason: FinishReason::EndTurn,
            usage: Usage::default(),
            service_tier: None,
        }
    }

    #[tokio::test]
    async fn retries_until_success() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("retry should succeed on second attempt");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 1);
        assert_eq!(model.inner().calls(), 2);
    }

    #[tokio::test]
    async fn gives_up_after_max_attempts() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            3,
            "max_attempts is total calls including the first"
        );
    }

    #[tokio::test]
    async fn zero_max_attempts_still_calls_once() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(0));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "max_attempts: 0 must behave like 1 (one call, no retries)"
        );
    }

    #[tokio::test]
    async fn respects_retry_after() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:50".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let started = Instant::now();
        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("second attempt should succeed");
        let _: Vec<_> = stream.collect().await;
        let elapsed = started.elapsed();
        assert!(
            elapsed >= Duration::from_millis(50),
            "expected at least 50ms wait, got {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn permanent_errors_are_not_retried() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("permanent: bad auth".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "permanent errors must not trigger a retry"
        );
    }

    #[tokio::test]
    async fn mid_stream_errors_are_not_retried() {
        // The setup call succeeds; the stream then yields one chunk and
        // a transient-looking error. RetryingModel must not reissue the
        // request — the engine has already consumed deltas.
        let inner = CountingModel::new(vec![
            Ok(vec![
                Ok(StreamChunk::TextDelta {
                    delta: "hello".into(),
                }),
                Err(ScriptedError("transient:1".into())),
            ]),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("setup should succeed");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(chunks[0], Ok(StreamChunk::TextDelta { .. })));
        assert!(matches!(chunks[1], Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "mid-stream errors must not trigger a setup-time retry"
        );
    }

    // Pure arithmetic: no async runtime needed.
    #[test]
    fn exponential_backoff_grows_then_caps() {
        let cfg = RetryConfig {
            max_attempts: 10,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(40),
            jitter: false,
        };
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 1),
            Duration::from_millis(10)
        );
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 2),
            Duration::from_millis(20)
        );
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 3),
            Duration::from_millis(40)
        );
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 4),
            Duration::from_millis(40)
        );
        assert_eq!(
            exponential(Duration::ZERO, cfg.max_delay, 7),
            Duration::ZERO,
            "a zero base delay stays zero"
        );
    }

    #[test]
    fn server_retry_after_overrides_backoff_without_jitter() {
        let cfg = fast_config(3);
        assert_eq!(
            compute_delay(&cfg, 1, Some(Duration::from_millis(50))),
            Duration::from_millis(50),
            "Retry-After is honoured even above max_delay"
        );
        assert_eq!(compute_delay(&cfg, 1, None), Duration::from_millis(1));
        assert_eq!(compute_delay(&cfg, 10, None), Duration::from_millis(5));
    }

    #[test]
    fn jitter_only_adds_up_to_ten_percent() {
        let d = Duration::from_secs(1);
        for _ in 0..16 {
            let j = apply_jitter(d);
            assert!(j >= d && j < d + d / 10, "jitter out of range: {:?}", j);
        }
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    /// Sanity check that `BoxStream` from `stream::iter` still conforms
    /// to what `chat_stream` is supposed to return — guards against the
    /// test scaffolding accidentally lying about the type.
    #[tokio::test]
    async fn box_stream_conforms() {
        let s: BoxStream<'static, Result<StreamChunk, ScriptedError>> =
            Box::pin(stream::iter(vec![Ok(ok_chunk())]));
        let chunks: Vec<_> = s.collect().await;
        assert_eq!(chunks.len(), 1);
    }
}