llmsdk_provider/middleware/
retry.rs1use std::sync::Mutex;
16use std::time::{Duration, SystemTime, UNIX_EPOCH};
17
18use async_trait::async_trait;
19
20use crate::error::{ProviderError, Result};
21use crate::language_model::{CallOptions, GenerateResult, LanguageModel, StreamResult};
22
23use super::language_model::LanguageModelMiddleware;
24
/// Default total number of calls, counting the first one (so `3` = 1 call + 2 retries).
pub const DEFAULT_MAX_ATTEMPTS: u32 = 3;
/// Default delay slept after the first failed call, before the first retry.
pub const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Default factor the delay is multiplied by for each further failed attempt.
pub const DEFAULT_BACKOFF_MULTIPLIER: f32 = 2.0;
/// Default upper bound on any single delay.
pub const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(5);
/// Default jitter ratio; `0.0` means delays are not randomized.
pub const DEFAULT_JITTER_RATIO: f32 = 0.0;

/// Middleware that re-issues failed language-model calls with exponential backoff.
///
/// A failed call is retried only when the [`ProviderError`] reports itself as
/// retryable and the attempt budget is not yet spent. Build one with
/// [`RetryMiddleware::new`] / [`Default`] or customize it through
/// [`RetryMiddleware::builder`].
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Total number of calls, including the first; the builder keeps it `>= 1`.
    max_attempts: u32,
    /// Delay slept after the first failure.
    initial_backoff: Duration,
    /// Per-attempt growth factor for the delay; the builder keeps it `>= 1.0`.
    backoff_multiplier: f32,
    /// Upper bound on a single delay.
    max_backoff: Duration,
    /// Relative jitter `r`; a jittered delay is scaled by a factor in `[1 - r/2, 1 + r/2)`.
    jitter_ratio: f32,
    /// SplitMix64 state used to draw jitter. Kept behind a `Mutex` so `&self`
    /// methods can advance it.
    rng: Mutex<u64>,
}

impl Clone for RetryMiddleware {
    /// Copies the retry configuration and gives the clone its own jitter stream.
    ///
    /// The clone is seeded with the next output of this instance's generator,
    /// which advances it, rather than with a copy of the raw state. A verbatim
    /// copy would make parent and clone draw identical jitter and retry in
    /// lockstep, which is exactly the correlation jitter exists to break.
    fn clone(&self) -> Self {
        let child_seed = {
            // The state is a plain `u64` that is never left half-written, so a
            // poisoned lock carries no broken invariant; recover instead of panicking.
            let mut state = self
                .rng
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            splitmix64(&mut state)
        };
        Self {
            max_attempts: self.max_attempts,
            initial_backoff: self.initial_backoff,
            backoff_multiplier: self.backoff_multiplier,
            max_backoff: self.max_backoff,
            jitter_ratio: self.jitter_ratio,
            rng: Mutex::new(child_seed),
        }
    }
}

impl Default for RetryMiddleware {
    /// Builds the policy from the `DEFAULT_*` constants, with a fresh jitter seed.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            initial_backoff: DEFAULT_INITIAL_BACKOFF,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
            max_backoff: DEFAULT_MAX_BACKOFF,
            jitter_ratio: DEFAULT_JITTER_RATIO,
            rng: Mutex::new(seed_from_clock()),
        }
    }
}

/// Produces a PRNG seed from the wall clock and a process-wide sequence number.
///
/// The clock alone is not enough: instances created within one clock tick
/// (SystemTime resolution can be 100 ns or coarser) would otherwise share a
/// seed and therefore an identical jitter sequence.
#[allow(
    clippy::cast_possible_truncation,
    reason = "low 64 bits of clock are intentionally taken as PRNG seed"
)]
fn seed_from_clock() -> u64 {
    use std::sync::atomic::{AtomicUsize, Ordering};

    static SEED_SEQUENCE: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0xDEAD_BEEF_CAFE_BABE, |d| d.as_nanos() as u64);
    // Relaxed is enough: only uniqueness of the fetched value matters.
    let sequence = SEED_SEQUENCE.fetch_add(1, Ordering::Relaxed) as u64;
    // Spread consecutive sequence numbers far apart before mixing them in.
    let mut state = nanos ^ sequence.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    splitmix64(&mut state)
}

/// Advances a SplitMix64 generator by one step and returns its next output.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}
131
132impl RetryMiddleware {
133 #[must_use]
136 pub fn new() -> Self {
137 Self::default()
138 }
139
140 #[must_use]
142 pub fn builder() -> RetryMiddlewareBuilder {
143 RetryMiddlewareBuilder(Self::default())
144 }
145
146 fn backoff_for(&self, attempt: u32) -> Duration {
152 let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
155 let factor = f64::from(self.backoff_multiplier).powi(exponent);
156 let mut scaled = self.initial_backoff.as_secs_f64() * factor;
157 if !scaled.is_finite() || scaled <= 0.0 {
158 return self.initial_backoff;
159 }
160 let cap = self.max_backoff.as_secs_f64();
161 scaled = scaled.min(cap);
162 if self.jitter_ratio > 0.0 {
163 scaled = self.apply_jitter(scaled).min(cap);
164 }
165 Duration::from_secs_f64(scaled.max(0.0))
166 }
167
168 #[allow(
171 clippy::cast_precision_loss,
172 reason = "f64 mantissa is 52 bits; raw is masked to 53 bits before the cast"
173 )]
174 fn apply_jitter(&self, base: f64) -> f64 {
175 let r = f64::from(self.jitter_ratio.clamp(0.0, 1.0));
176 let raw = {
178 let mut guard = self.rng.lock().expect("rng mutex poisoned");
179 splitmix64(&mut guard)
180 };
181 let u = (raw >> 11) as f64 / (1u64 << 53) as f64;
182 let factor = 1.0 + r * (u - 0.5);
183 base * factor
184 }
185}
186
187#[derive(Debug)]
189pub struct RetryMiddlewareBuilder(RetryMiddleware);
190
191impl RetryMiddlewareBuilder {
192 #[must_use]
194 pub fn max_attempts(mut self, attempts: u32) -> Self {
195 self.0.max_attempts = attempts.max(1);
196 self
197 }
198
199 #[must_use]
201 pub fn initial_backoff(mut self, dur: Duration) -> Self {
202 self.0.initial_backoff = dur;
203 self
204 }
205
206 #[must_use]
208 pub fn backoff_multiplier(mut self, factor: f32) -> Self {
209 self.0.backoff_multiplier = factor.max(1.0);
210 self
211 }
212
213 #[must_use]
215 pub fn max_backoff(mut self, dur: Duration) -> Self {
216 self.0.max_backoff = dur;
217 self
218 }
219
220 #[must_use]
226 pub fn jitter_ratio(mut self, ratio: f32) -> Self {
227 self.0.jitter_ratio = ratio.clamp(0.0, 1.0);
228 self
229 }
230
231 #[must_use]
233 pub fn build(self) -> RetryMiddleware {
234 self.0
235 }
236}
237
238#[async_trait]
239impl LanguageModelMiddleware for RetryMiddleware {
240 async fn wrap_generate(
241 &self,
242 next: &dyn LanguageModel,
243 params: CallOptions,
244 ) -> Result<GenerateResult> {
245 let mut attempt: u32 = 0;
246 loop {
247 let outcome = next.do_generate(params.clone()).await;
248 match outcome {
249 Ok(result) => return Ok(result),
250 Err(err) => {
251 if !should_retry(&err, attempt, self.max_attempts) {
252 return Err(err);
253 }
254 tokio::time::sleep(self.backoff_for(attempt)).await;
255 attempt += 1;
256 }
257 }
258 }
259 }
260
261 async fn wrap_stream(
262 &self,
263 next: &dyn LanguageModel,
264 params: CallOptions,
265 ) -> Result<StreamResult> {
266 let mut attempt: u32 = 0;
267 loop {
268 let outcome = next.do_stream(params.clone()).await;
269 match outcome {
270 Ok(result) => return Ok(result),
271 Err(err) => {
272 if !should_retry(&err, attempt, self.max_attempts) {
273 return Err(err);
274 }
275 tokio::time::sleep(self.backoff_for(attempt)).await;
276 attempt += 1;
277 }
278 }
279 }
280 }
281}
282
283fn should_retry(err: &ProviderError, attempt: u32, max_attempts: u32) -> bool {
288 err.is_retryable() && attempt + 1 < max_attempts
289}
290
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::language_model::{FinishReason, FinishReasonKind, Usage};

    use super::*;

    /// Test double that fails its first `fail_until` calls with an error built
    /// by `err_factory`, then succeeds. Generate and stream calls share one counter.
    #[derive(Debug)]
    struct FlakyModel {
        provider: String,
        model_id: String,
        fail_until: u32,
        err_factory: fn() -> ProviderError,
        call_count: AtomicUsize,
    }

    impl FlakyModel {
        fn new(fail_until: u32, err_factory: fn() -> ProviderError) -> Self {
            Self {
                provider: "test".to_owned(),
                model_id: "flaky".to_owned(),
                fail_until,
                err_factory,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of calls (generate or stream) made so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Records one call and yields the scripted error while inside the failure window.
        fn record_call(&self) -> Result<()> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if u32::try_from(n).is_ok_and(|n| n < self.fail_until) {
                return Err((self.err_factory)());
            }
            Ok(())
        }
    }

    fn retryable_503() -> ProviderError {
        ProviderError::api_call_builder("https://api.test", "service unavailable")
            .status_code(503)
            .build()
    }

    fn non_retryable_400() -> ProviderError {
        ProviderError::api_call_builder("https://api.test", "bad request")
            .status_code(400)
            .build()
    }

    fn ok_result() -> GenerateResult {
        GenerateResult {
            content: vec![],
            finish_reason: FinishReason::new(FinishReasonKind::Stop),
            usage: Usage::default(),
            provider_metadata: None,
            request: None,
            response: None,
            warnings: vec![],
        }
    }

    #[async_trait]
    impl LanguageModel for FlakyModel {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            self.record_call()?;
            Ok(ok_result())
        }

        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            self.record_call()?;
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(Vec::new())),
                request: None,
                response: None,
            })
        }
    }

    #[tokio::test(start_paused = true)]
    async fn retries_retryable_then_succeeds() {
        let model = FlakyModel::new(2, retryable_503);
        let retry = RetryMiddleware::builder()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(10))
            .build();
        retry
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect("third attempt succeeds");
        assert_eq!(model.calls(), 3, "two failures + one success");
    }

    #[tokio::test(start_paused = true)]
    async fn non_retryable_fails_fast() {
        let model = FlakyModel::new(5, non_retryable_400);
        let retry = RetryMiddleware::builder().max_attempts(5).build();
        let err = retry
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect_err("non-retryable error propagates");
        assert!(!err.is_retryable());
        assert_eq!(model.calls(), 1, "no retries for non-retryable error");
    }

    #[tokio::test(start_paused = true)]
    async fn exhausts_attempts_and_returns_last_error() {
        let model = FlakyModel::new(10, retryable_503);
        let retry = RetryMiddleware::builder()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .build();
        let err = retry
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect_err("attempts exhausted");
        assert_eq!(err.status_code(), Some(503));
        assert_eq!(model.calls(), 3, "max_attempts == 3 total calls");
    }

    #[tokio::test(start_paused = true)]
    async fn max_attempts_one_disables_retry() {
        let model = FlakyModel::new(5, retryable_503);
        let retry = RetryMiddleware::builder().max_attempts(1).build();
        let err = retry
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect_err("first failure propagates");
        assert!(err.is_retryable());
        assert_eq!(model.calls(), 1);
    }

    #[tokio::test(start_paused = true)]
    async fn max_attempts_zero_is_clamped_to_one() {
        let model = FlakyModel::new(5, retryable_503);
        let retry = RetryMiddleware::builder().max_attempts(0).build();
        retry
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect_err("single call fails");
        assert_eq!(model.calls(), 1, "zero attempts still makes one call");
    }

    #[tokio::test(start_paused = true)]
    async fn stream_retries_open_failures() {
        let model = FlakyModel::new(2, retryable_503);
        let retry = RetryMiddleware::builder()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .build();
        retry
            .wrap_stream(&model, CallOptions::default())
            .await
            .expect("stream opens on third attempt");
        assert_eq!(model.calls(), 3);
    }

    #[tokio::test(start_paused = true)]
    async fn stream_non_retryable_fails_fast() {
        let model = FlakyModel::new(5, non_retryable_400);
        let retry = RetryMiddleware::builder().max_attempts(5).build();
        let err = retry
            .wrap_stream(&model, CallOptions::default())
            .await
            .expect_err("non-retryable open error propagates");
        assert!(!err.is_retryable());
        assert_eq!(model.calls(), 1);
    }

    #[test]
    fn backoff_caps_at_max() {
        let retry = RetryMiddleware::builder()
            .initial_backoff(Duration::from_millis(100))
            .backoff_multiplier(10.0)
            .max_backoff(Duration::from_secs(1))
            .build();
        assert_eq!(retry.backoff_for(0), Duration::from_millis(100));
        // 100ms * 10 = 1s, already at the cap; further attempts stay there.
        assert_eq!(retry.backoff_for(1), Duration::from_secs(1));
        assert_eq!(retry.backoff_for(2), Duration::from_secs(1));
    }

    #[test]
    fn zero_initial_backoff_yields_zero_delay() {
        let retry = RetryMiddleware::builder()
            .initial_backoff(Duration::ZERO)
            .build();
        assert_eq!(retry.backoff_for(0), Duration::ZERO);
        assert_eq!(retry.backoff_for(5), Duration::ZERO);
    }

    #[test]
    fn jitter_perturbs_within_expected_range() {
        let retry = RetryMiddleware::builder()
            .initial_backoff(Duration::from_millis(100))
            .backoff_multiplier(1.0)
            .jitter_ratio(0.5)
            .max_backoff(Duration::from_secs(10))
            .build();
        // Ratio 0.5 scales by a factor in [0.75, 1.25).
        let base = 100.0;
        let lo = base * (1.0 - 0.25);
        let hi = base * (1.0 + 0.25);
        for _ in 0..32 {
            let sample = retry.backoff_for(0).as_secs_f64() * 1000.0;
            assert!(
                sample >= lo && sample <= hi,
                "jitter sample {sample}ms outside [{lo},{hi}]"
            );
        }
    }
}
491}