1use std::time::Duration;
21
22use async_trait::async_trait;
23use futures::stream::BoxStream;
24
25use crate::request::ChatRequest;
26use crate::stream::StreamChunk;
27use crate::traits::CompletionModel;
28
/// Lets [`RetryingModel`] decide whether a failed request is worth retrying.
///
/// [`RetryingModel`] only implements [`CompletionModel`] for inner models whose
/// error type implements this trait (`M::Error: Retryable`).
pub trait Retryable {
    /// Classifies this error for the retry loop.
    ///
    /// [`RetryClassification::Permanent`] errors are handed straight back to
    /// the caller. [`RetryClassification::Transient`] errors are retried after
    /// a delay, as long as attempts remain.
    fn retry_classification(&self) -> RetryClassification;
}
37
/// How [`RetryingModel`] should treat a failed `chat_stream` setup call.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum RetryClassification {
    /// Retrying cannot help; the error is returned without another attempt.
    Permanent,
    /// The request may succeed if it is tried again later.
    Transient {
        /// Wait requested by the server before the next attempt.
        ///
        /// When `Some`, this value replaces the exponential backoff and is not
        /// capped by [`RetryConfig::max_delay`]. Jitter, if enabled, is still
        /// added on top of it. When `None`, the configured backoff is used.
        retry_after: Option<Duration>,
    },
}
55
/// Retry policy used by [`RetryingModel`].
///
/// The struct is `#[non_exhaustive]`. Code outside this crate should start from
/// [`RetryConfig::default()`] and then overwrite the public fields it cares about.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RetryConfig {
    /// Total number of `chat_stream` calls, counting the first one.
    /// A value of `0` behaves like `1`.
    pub max_attempts: u32,
    /// Backoff after the first failure; it doubles after each further failure.
    pub base_delay: Duration,
    /// Upper bound on the computed exponential backoff.
    ///
    /// It does not limit a server-supplied `retry_after`, and jitter is added
    /// after this cap is applied.
    pub max_delay: Duration,
    /// When `true`, each delay is stretched by a small pseudo-random extra
    /// amount below 10% of the delay.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three attempts, starting at one second of backoff and doubling up to
    /// thirty seconds, with jitter enabled.
    fn default() -> Self {
        let base_delay = Duration::from_millis(1_000);
        let max_delay = Duration::from_secs(30);
        RetryConfig {
            jitter: true,
            max_attempts: 3,
            max_delay,
            base_delay,
        }
    }
}
91
/// Wraps a [`CompletionModel`] and retries failed `chat_stream` setup calls
/// according to a [`RetryConfig`].
///
/// Only the call that opens the stream is retried. Errors yielded by a stream
/// after it has been returned are passed through untouched.
pub struct RetryingModel<M> {
    // The model every attempt is delegated to.
    inner: M,
    // Backoff policy applied between attempts.
    config: RetryConfig,
}
99
100impl<M> RetryingModel<M> {
101 pub fn new(inner: M) -> Self {
103 Self::with_config(inner, RetryConfig::default())
104 }
105
106 pub fn with_config(inner: M, config: RetryConfig) -> Self {
108 Self { inner, config }
109 }
110
111 pub fn inner(&self) -> &M {
114 &self.inner
115 }
116
117 pub fn into_inner(self) -> M {
119 self.inner
120 }
121}
122
123#[async_trait]
124impl<M> CompletionModel for RetryingModel<M>
125where
126 M: CompletionModel + Send + Sync,
127 M::Error: Retryable,
128{
129 type Error = M::Error;
130
131 fn name(&self) -> &str {
132 self.inner.name()
133 }
134
135 fn model(&self) -> &str {
136 self.inner.model()
137 }
138
139 async fn chat_stream(
140 &self,
141 req: ChatRequest,
142 ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
143 let max = self.config.max_attempts.max(1);
144 let mut attempt: u32 = 0;
145 loop {
146 let try_req = req.clone();
147 let err = match self.inner.chat_stream(try_req).await {
148 Ok(stream) => return Ok(stream),
149 Err(e) => e,
150 };
151 attempt += 1;
152 if attempt >= max {
153 return Err(err);
154 }
155 let delay = match err.retry_classification() {
156 RetryClassification::Permanent => return Err(err),
157 RetryClassification::Transient { retry_after } => {
158 compute_delay(&self.config, attempt, retry_after)
159 }
160 };
161 tokio::time::sleep(delay).await;
162 }
163 }
164}
165
166fn compute_delay(cfg: &RetryConfig, attempt: u32, retry_after: Option<Duration>) -> Duration {
167 let base = match retry_after {
168 Some(d) => d,
169 None => exponential(cfg.base_delay, cfg.max_delay, attempt),
170 };
171 if cfg.jitter { apply_jitter(base) } else { base }
172}
173
/// Exponential backoff: `base * 2^(attempt - 1)`, never more than `max`.
///
/// `attempt` is the 1-based count of failures so far; `0` is treated like `1`.
/// Growth stops after 20 doublings, so when `base` is tiny relative to `max`
/// the delay plateaus at `base * 2^20`.
fn exponential(base: Duration, max: Duration, attempt: u32) -> Duration {
    // Keep the existing 20-doubling ceiling; `1 << 20` fits comfortably in a u32.
    let doublings = attempt.saturating_sub(1).min(20);
    // `Duration` arithmetic is exact and saturates at `Duration::MAX`. This
    // avoids a round trip through u64 nanoseconds, which would silently clamp
    // any result above ~584 years instead of honouring `max`.
    base.saturating_mul(1u32 << doublings).min(max)
}
182
/// Stretches `d` by a pseudo-random extra amount in `[0, d / 10)`.
///
/// Jitter de-synchronises clients that failed at the same moment, so they do
/// not retry in lock-step. Delays too short for a tenth to be at least one
/// nanosecond (including zero) are returned unchanged. The result saturates
/// at `Duration::MAX` instead of overflowing.
fn apply_jitter(d: Duration) -> Duration {
    let jitter_max = d.as_nanos() / 10;
    if jitter_max == 0 {
        return d;
    }
    // `RandomState::new()` is documented to be initialised with random keys.
    // Finishing a hasher built from it therefore yields a fresh pseudo-random
    // u64 without an external RNG crate. Unlike the sub-second part of the wall
    // clock, this value is not capped at 1e9, and it is not correlated across
    // processes that hit the same failure at the same time.
    let state = std::collections::hash_map::RandomState::new();
    let hasher = std::hash::BuildHasher::build_hasher(&state);
    let random = std::hash::Hasher::finish(&hasher);
    // A tenth of any delay under ~58 centuries fits in u64 nanoseconds; beyond
    // that, every u64 value is already below the bound.
    let offset = u64::try_from(jitter_max).map_or(random, |bound| random % bound);
    d.saturating_add(Duration::from_nanos(offset))
}
203
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Instant;

    use async_trait::async_trait;
    use futures::StreamExt;
    use futures::stream::{self, BoxStream};

    use super::*;
    use crate::stream::{FinishReason, Usage};
    use crate::testing::{ScriptedError, ScriptedModel, ScriptedTurn};

    /// Minimal request shared by every test.
    fn empty_request() -> ChatRequest {
        ChatRequest::new(vec![], 0)
    }

    /// Retry policy with millisecond-scale delays and no jitter, so tests run
    /// quickly and their delays are deterministic.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            jitter: false,
        }
    }

    /// Wraps a [`ScriptedModel`] and counts how many times `chat_stream` is called.
    struct CountingModel {
        inner: ScriptedModel,
        calls: AtomicUsize,
    }

    impl CountingModel {
        fn new(turns: Vec<ScriptedTurn>) -> Self {
            Self {
                inner: ScriptedModel::with_turns(turns),
                calls: AtomicUsize::new(0),
            }
        }

        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CompletionModel for CountingModel {
        type Error = ScriptedError;

        fn name(&self) -> &str {
            self.inner.name()
        }

        fn model(&self) -> &str {
            self.inner.model()
        }

        async fn chat_stream(
            &self,
            req: ChatRequest,
        ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.inner.chat_stream(req).await
        }
    }

    /// A single successful end-of-turn chunk.
    fn ok_chunk() -> StreamChunk {
        StreamChunk::TurnFinished {
            reason: FinishReason::EndTurn,
            usage: Usage::default(),
            service_tier: None,
        }
    }

    #[tokio::test]
    async fn retries_until_success() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("retry should succeed on second attempt");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 1);
        assert_eq!(model.inner().calls(), 2);
    }

    #[tokio::test]
    async fn gives_up_after_max_attempts() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
            Err(ScriptedError("transient:1".into())),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            3,
            "max_attempts is total calls including the first"
        );
    }

    #[tokio::test]
    async fn zero_max_attempts_still_makes_one_call() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:1".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(0));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "max_attempts == 0 is treated as a single attempt"
        );
    }

    #[tokio::test]
    async fn respects_retry_after() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("transient:50".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let started = Instant::now();
        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("second attempt should succeed");
        let _: Vec<_> = stream.collect().await;
        let elapsed = started.elapsed();
        assert!(
            elapsed >= Duration::from_millis(50),
            "expected at least 50ms wait, got {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn permanent_errors_are_not_retried() {
        let inner = CountingModel::new(vec![
            Err(ScriptedError("permanent: bad auth".into())),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let result = model.chat_stream(empty_request()).await;
        assert!(matches!(result, Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "permanent errors must not trigger a retry"
        );
    }

    #[tokio::test]
    async fn mid_stream_errors_are_not_retried() {
        let inner = CountingModel::new(vec![
            Ok(vec![
                Ok(StreamChunk::TextDelta {
                    delta: "hello".into(),
                }),
                Err(ScriptedError("transient:1".into())),
            ]),
            Ok(vec![Ok(ok_chunk())]),
        ]);
        let model = RetryingModel::with_config(inner, fast_config(3));

        let stream = model
            .chat_stream(empty_request())
            .await
            .expect("setup should succeed");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(chunks[0], Ok(StreamChunk::TextDelta { .. })));
        assert!(matches!(chunks[1], Err(ScriptedError(_))));
        assert_eq!(
            model.inner().calls(),
            1,
            "mid-stream errors must not trigger a setup-time retry"
        );
    }

    // Pure computation: no async runtime needed.
    #[test]
    fn exponential_backoff_grows_then_caps() {
        let cfg = RetryConfig {
            max_attempts: 10,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(40),
            jitter: false,
        };
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 1),
            Duration::from_millis(10)
        );
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 2),
            Duration::from_millis(20)
        );
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 3),
            Duration::from_millis(40)
        );
        assert_eq!(
            exponential(cfg.base_delay, cfg.max_delay, 4),
            Duration::from_millis(40)
        );
    }

    #[test]
    fn compute_delay_prefers_retry_after_and_caps_backoff() {
        let cfg = fast_config(3);
        assert_eq!(
            compute_delay(&cfg, 1, Some(Duration::from_millis(50))),
            Duration::from_millis(50),
            "retry_after is used verbatim, even above max_delay"
        );
        assert_eq!(compute_delay(&cfg, 1, None), Duration::from_millis(1));
        assert_eq!(compute_delay(&cfg, 10, None), Duration::from_millis(5));
    }

    #[test]
    fn jitter_adds_less_than_ten_percent() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
        assert_eq!(apply_jitter(Duration::from_nanos(5)), Duration::from_nanos(5));
        let d = Duration::from_millis(100);
        for _ in 0..100 {
            let j = apply_jitter(d);
            assert!(
                j >= d && j < d + d / 10,
                "jittered delay {:?} out of range",
                j
            );
        }
    }

    /// Sanity check that a boxed `futures` iterator stream has the
    /// `BoxStream<'static, Result<StreamChunk, _>>` shape the trait returns.
    #[tokio::test]
    async fn box_stream_conforms() {
        let s: BoxStream<'static, Result<StreamChunk, ScriptedError>> =
            Box::pin(stream::iter(vec![Ok(ok_chunk())]));
        let chunks: Vec<_> = s.collect().await;
        assert_eq!(chunks.len(), 1);
    }

    /// Exercises the pattern of dropping a `std::sync::Mutex` guard (via a
    /// block scope) before reaching an `.await` point.
    #[tokio::test]
    async fn no_lock_held_across_await() {
        let m = Mutex::new(0u32);
        {
            let mut g = m.lock().unwrap();
            *g += 1;
        }
        tokio::task::yield_now().await;
        assert_eq!(*m.lock().unwrap(), 1);
    }
}
435}