1use std::sync::Arc;
4use std::time::Duration;
5
6use crate::error::Error;
7use crate::llm::types::{CompletionRequest, CompletionResponse};
8
9use super::LlmProvider;
10
/// Tuning knobs for the retry loop of [`RetryingProvider`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared by value
/// (e.g. in tests or when checking whether a reload changed anything).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Number of retries after the initial attempt; `0` means exactly one attempt.
    pub max_retries: u32,
    /// Lower bound of every backoff delay and the unit the backoff window grows from.
    pub base_delay: Duration,
    /// Upper bound on any single backoff delay.
    pub max_delay: Duration,
}
21
22impl Default for RetryConfig {
23 fn default() -> Self {
24 Self {
25 max_retries: 3,
26 base_delay: Duration::from_millis(500),
27 max_delay: Duration::from_secs(30),
28 }
29 }
30}
31
32pub type OnRetry = dyn Fn(u32, u32, u64, &str) + Send + Sync;
38
39pub struct RetryingProvider<P> {
51 inner: P,
52 config: RetryConfig,
53 on_retry: Option<Arc<OnRetry>>,
54}
55
56impl<P> RetryingProvider<P> {
57 pub fn new(inner: P, config: RetryConfig) -> Self {
59 Self {
60 inner,
61 config,
62 on_retry: None,
63 }
64 }
65
66 pub fn with_defaults(inner: P) -> Self {
68 Self::new(inner, RetryConfig::default())
69 }
70
71 pub fn with_on_retry(mut self, callback: Arc<OnRetry>) -> Self {
75 self.on_retry = Some(callback);
76 self
77 }
78}
79
80fn classify_for_retry(err: &Error) -> &'static str {
82 match err {
83 Error::Api { status: 429, .. } => "rate_limited",
84 Error::Api { status: 500, .. } => "server_error_500",
85 Error::Api { status: 502, .. } => "server_error_502",
86 Error::Api { status: 503, .. } => "server_error_503",
87 Error::Api { status: 529, .. } => "overloaded",
88 Error::Http(_) => "network_error",
89 _ => "unknown",
90 }
91}
92
93fn is_retryable(err: &Error) -> bool {
95 match err {
96 Error::Api { status, .. } => matches!(*status, 429 | 500 | 502 | 503 | 529),
97 Error::Http(_) => true,
98 _ => false,
99 }
100}
101
102fn compute_delay(config: &RetryConfig, attempt: u32) -> Duration {
114 use std::sync::atomic::{AtomicU64, Ordering};
115 let base_ms = config.base_delay.as_millis() as u64;
116 let max_ms = config.max_delay.as_millis() as u64;
117
118 static SEED: AtomicU64 = AtomicU64::new(0x9E3779B97F4A7C15);
121 let prev_max_ms = base_ms.saturating_mul(1u64.checked_shl(attempt).unwrap_or(u32::MAX as u64));
122 let upper = prev_max_ms.saturating_mul(3).min(max_ms.max(base_ms));
123 let lower = base_ms.min(upper);
124 let next = SEED
126 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
127 Some(s.wrapping_mul(1664525).wrapping_add(1013904223))
128 })
129 .unwrap_or(0);
130 let span = upper - lower + 1;
131 let pick = lower + (next % span);
132 Duration::from_millis(pick.min(max_ms))
133}
134
135impl<P: LlmProvider> LlmProvider for RetryingProvider<P> {
136 fn model_name(&self) -> Option<&str> {
137 self.inner.model_name()
138 }
139
140 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
141 let mut last_err: Option<Error> = None;
142
143 for attempt in 0..=self.config.max_retries {
144 if attempt > 0 {
145 let delay = compute_delay(&self.config, attempt - 1);
146 let delay_ms = delay.as_millis() as u64;
147 let error_class =
148 classify_for_retry(last_err.as_ref().expect("last_err set before retry"));
149 if let Some(ref cb) = self.on_retry {
150 cb(attempt, self.config.max_retries, delay_ms, error_class);
151 }
152 tracing::warn!(
153 attempt = attempt,
154 max_retries = self.config.max_retries,
155 delay_ms = delay_ms,
156 error = %last_err.as_ref().expect("last_err set before retry"),
157 "retrying LLM call after transient failure"
158 );
159 tokio::time::sleep(delay).await;
160 }
161
162 match self.inner.complete(request.clone()).await {
163 Ok(response) => return Ok(response),
164 Err(e) if is_retryable(&e) => {
165 last_err = Some(e);
166 }
167 Err(e) => return Err(e),
168 }
169 }
170
171 Err(last_err.expect("at least one attempt must have been made"))
173 }
174
175 async fn stream_complete(
176 &self,
177 request: CompletionRequest,
178 on_text: &super::OnText,
179 ) -> Result<CompletionResponse, Error> {
180 let mut last_err: Option<Error> = None;
181 fn noop_text(_: &str) {}
186 let noop: &super::OnText = &noop_text;
187
188 for attempt in 0..=self.config.max_retries {
189 if attempt > 0 {
190 let delay = compute_delay(&self.config, attempt - 1);
191 let delay_ms = delay.as_millis() as u64;
192 let error_class =
193 classify_for_retry(last_err.as_ref().expect("last_err set before retry"));
194 if let Some(ref cb) = self.on_retry {
195 cb(attempt, self.config.max_retries, delay_ms, error_class);
196 }
197 tracing::warn!(
198 attempt = attempt,
199 max_retries = self.config.max_retries,
200 delay_ms = delay_ms,
201 error = %last_err.as_ref().expect("last_err set before retry"),
202 "retrying streaming LLM call after transient failure (streaming suppressed)"
203 );
204 tokio::time::sleep(delay).await;
205 }
206
207 let callback = if attempt == 0 { on_text } else { &noop };
208 match self.inner.stream_complete(request.clone(), callback).await {
209 Ok(response) => return Ok(response),
210 Err(e) if is_retryable(&e) => {
211 last_err = Some(e);
212 }
213 Err(e) => return Err(e),
214 }
215 }
216
217 Err(last_err.expect("at least one attempt must have been made"))
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{Message, StopReason, TokenUsage};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Mock provider: `complete` returns `error_factory()` for the first
    /// `remaining_failures` calls and a canned success afterwards.
    struct FailNTimes {
        /// Failures still to hand out; decremented by each failing call.
        remaining_failures: AtomicU32,
        /// Builds a fresh error for every failing call.
        error_factory: Box<dyn Fn() -> Error + Send + Sync>,
        /// Counts every `complete` call; shared with the test through `new`.
        call_count: Arc<AtomicU32>,
    }

    impl FailNTimes {
        /// Builds the mock and returns it together with a handle to its call counter.
        fn new(
            failures: u32,
            error_factory: impl Fn() -> Error + Send + Sync + 'static,
        ) -> (Self, Arc<AtomicU32>) {
            let count = Arc::new(AtomicU32::new(0));
            (
                Self {
                    remaining_failures: AtomicU32::new(failures),
                    error_factory: Box::new(error_factory),
                    call_count: count.clone(),
                },
                count,
            )
        }
    }

    /// Canned successful response: a single `"ok"` text block ending with `EndTurn`.
    fn success_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![crate::llm::types::ContentBlock::Text { text: "ok".into() }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
                ..Default::default()
            },
            model: None,
        }
    }

    // Only `complete` is overridden. The streaming tests below that use this mock
    // rely on the trait's default `stream_complete` reaching `complete`, because
    // their call-count assertions depend on it.
    impl LlmProvider for FailNTimes {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            // Atomically consume one pending failure. `fetch_update` yields `Err`
            // (nothing consumed) once the counter has reached zero.
            if self
                .remaining_failures
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
                    if v > 0 { Some(v - 1) } else { None }
                })
                .is_ok()
            {
                return Err((self.error_factory)());
            }
            Ok(success_response())
        }
    }

    /// Simple single-message request shared by the tests.
    fn test_request() -> CompletionRequest {
        CompletionRequest {
            system: String::new(),
            messages: vec![Message::user("test")],
            tools: vec![],
            max_tokens: 100,
            tool_choice: None,
            reasoning_effort: None,
        }
    }

    /// Retry config with 1–10 ms delays so the retry tests finish quickly.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
        }
    }

    #[tokio::test]
    async fn succeeds_on_first_attempt() {
        let (mock, count) = FailNTimes::new(0, || Error::Api {
            status: 429,
            message: "rate limited".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        // No failures, so no retries: exactly one call.
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_on_429_and_succeeds() {
        let (mock, count) = FailNTimes::new(2, || Error::Api {
            status: 429,
            message: "rate limited".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        // Two failing calls followed by one successful call.
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_500_and_succeeds() {
        let (mock, count) = FailNTimes::new(1, || Error::Api {
            status: 500,
            message: "internal server error".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retries_on_502_and_succeeds() {
        let (mock, count) = FailNTimes::new(1, || Error::Api {
            status: 502,
            message: "bad gateway".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retries_on_503_and_succeeds() {
        let (mock, count) = FailNTimes::new(1, || Error::Api {
            status: 503,
            message: "service unavailable".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retries_on_529_and_succeeds() {
        let (mock, count) = FailNTimes::new(1, || Error::Api {
            status: 529,
            message: "overloaded".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn exhausts_retries_and_returns_last_error() {
        let (mock, count) = FailNTimes::new(10, || Error::Api {
            status: 429,
            message: "rate limited".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(2));

        let result = provider.complete(test_request()).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        // The retryable error from the final attempt is surfaced.
        assert!(matches!(err, Error::Api { status: 429, .. }));
        // max_retries = 2: one initial call plus two retries.
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn does_not_retry_400() {
        let (mock, count) = FailNTimes::new(5, || Error::Api {
            status: 400,
            message: "bad request".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_err());
        // Non-retryable errors return after the first call.
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn does_not_retry_401() {
        let (mock, count) = FailNTimes::new(5, || Error::Api {
            status: 401,
            message: "unauthorized".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn does_not_retry_json_parse_error() {
        let (mock, count) = FailNTimes::new(5, || {
            Error::Json(serde_json::from_str::<()>("invalid").unwrap_err())
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let result = provider.complete(test_request()).await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn zero_retries_means_single_attempt() {
        let (mock, count) = FailNTimes::new(1, || Error::Api {
            status: 429,
            message: "rate limited".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(0));

        let result = provider.complete(test_request()).await;
        // Even a retryable error is returned when no retries are configured.
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn stream_complete_retries_on_transient_failure() {
        let (mock, count) = FailNTimes::new(2, || Error::Api {
            status: 429,
            message: "rate limited".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let on_text: &crate::llm::OnText = &|_| {};
        let result = provider.stream_complete(test_request(), on_text).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn stream_complete_does_not_retry_non_retryable() {
        let (mock, count) = FailNTimes::new(5, || Error::Api {
            status: 400,
            message: "bad request".into(),
        });
        let provider = RetryingProvider::new(mock, fast_config(3));

        let on_text: &crate::llm::OnText = &|_| {};
        let result = provider.stream_complete(test_request(), on_text).await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn default_config_values() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(30));
    }

    #[test]
    fn is_retryable_checks() {
        // Retryable: rate limiting, server errors and overload.
        assert!(is_retryable(&Error::Api {
            status: 429,
            message: "".into()
        }));
        assert!(is_retryable(&Error::Api {
            status: 500,
            message: "".into()
        }));
        assert!(is_retryable(&Error::Api {
            status: 502,
            message: "".into()
        }));
        assert!(is_retryable(&Error::Api {
            status: 503,
            message: "".into()
        }));
        assert!(is_retryable(&Error::Api {
            status: 529,
            message: "".into()
        }));

        // Not retryable: client errors and non-API errors.
        assert!(!is_retryable(&Error::Api {
            status: 400,
            message: "".into()
        }));
        assert!(!is_retryable(&Error::Api {
            status: 401,
            message: "".into()
        }));
        assert!(!is_retryable(&Error::Api {
            status: 403,
            message: "".into()
        }));
        assert!(!is_retryable(&Error::Api {
            status: 404,
            message: "".into()
        }));
        assert!(!is_retryable(&Error::Agent("test".into())));
        assert!(!is_retryable(&Error::Config("test".into())));
        assert!(!is_retryable(&Error::Memory("test".into())));
    }

    // The delay is random, so these tests check bounds rather than exact values.
    #[test]
    fn compute_delay_in_jitter_range() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
        };

        for attempt in 0..4 {
            let delay = compute_delay(&config, attempt);
            assert!(
                delay >= config.base_delay,
                "attempt {attempt}: delay {delay:?} below base"
            );
            assert!(
                delay <= config.max_delay,
                "attempt {attempt}: delay {delay:?} above max"
            );
        }
    }

    #[test]
    fn compute_delay_caps_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay: Duration::from_millis(1000),
            max_delay: Duration::from_secs(5),
        };

        // Sample repeatedly because each call draws a fresh random delay.
        for _ in 0..50 {
            let d = compute_delay(&config, 3);
            assert!(d <= config.max_delay, "delay {d:?} exceeds max");
            let d = compute_delay(&config, 10);
            assert!(d <= config.max_delay, "delay {d:?} exceeds max");
        }
    }

    #[test]
    fn compute_delay_handles_overflow() {
        let config = RetryConfig {
            max_retries: 100,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
        };

        // base * 2^50 overflows a naive computation; the result must still be capped.
        for _ in 0..50 {
            let delay = compute_delay(&config, 50);
            assert!(delay <= config.max_delay);
        }
    }

    #[tokio::test]
    async fn stream_retry_suppresses_on_text_on_retry() {
        // Counts how many text chunks reach the caller's callback.
        let text_calls = Arc::new(AtomicU32::new(0));
        let text_calls_clone = text_calls.clone();
        let on_text_fn = move |_: &str| {
            text_calls_clone.fetch_add(1, Ordering::SeqCst);
        };
        let on_text: &crate::llm::OnText = &on_text_fn;

        /// Streams "hello" on every call and fails (503) only on the first call.
        struct StreamFailOnce {
            // 0 until the first (failing) call has happened, then 1.
            failed: AtomicU32,
        }
        impl LlmProvider for StreamFailOnce {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                Ok(success_response())
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                on_text: &crate::llm::OnText,
            ) -> Result<CompletionResponse, Error> {
                on_text("hello");
                if self
                    .failed
                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
                        if v == 0 { Some(1) } else { None }
                    })
                    .is_ok()
                {
                    return Err(Error::Api {
                        status: 503,
                        message: "transient".into(),
                    });
                }
                Ok(success_response())
            }
        }

        let provider = RetryingProvider::new(
            StreamFailOnce {
                failed: AtomicU32::new(0),
            },
            fast_config(3),
        );
        let result = provider.stream_complete(test_request(), on_text).await;
        assert!(result.is_ok());
        // Only the first attempt's "hello" reaches the caller; the retry's text
        // must not be replayed.
        assert_eq!(text_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retrying_provider_fires_on_retry() {
        let (mock, _count) = FailNTimes::new(2, || Error::Api {
            status: 429,
            message: "rate limited".into(),
        });
        let retries_seen = Arc::new(AtomicU32::new(0));
        let retries_clone = retries_seen.clone();
        let provider = RetryingProvider::new(mock, fast_config(3)).with_on_retry(Arc::new(
            move |attempt, max_retries, _delay_ms, error_class| {
                assert!(attempt > 0);
                assert_eq!(max_retries, 3);
                assert_eq!(error_class, "rate_limited");
                retries_clone.fetch_add(1, Ordering::SeqCst);
            },
        ));

        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        // One callback per retry: two failures mean two retries.
        assert_eq!(retries_seen.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retrying_provider_on_retry_none_is_noop() {
        let (mock, count) = FailNTimes::new(1, || Error::Api {
            status: 500,
            message: "server error".into(),
        });
        // No callback installed: retrying must still work.
        let provider = RetryingProvider::new(mock, fast_config(3));
        let result = provider.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn classify_for_retry_returns_correct_classes() {
        assert_eq!(
            classify_for_retry(&Error::Api {
                status: 429,
                message: "".into()
            }),
            "rate_limited"
        );
        assert_eq!(
            classify_for_retry(&Error::Api {
                status: 500,
                message: "".into()
            }),
            "server_error_500"
        );
        assert_eq!(
            classify_for_retry(&Error::Api {
                status: 502,
                message: "".into()
            }),
            "server_error_502"
        );
        assert_eq!(
            classify_for_retry(&Error::Api {
                status: 503,
                message: "".into()
            }),
            "server_error_503"
        );
        assert_eq!(
            classify_for_retry(&Error::Api {
                status: 529,
                message: "".into()
            }),
            "overloaded"
        );
        assert_eq!(classify_for_retry(&Error::Agent("other".into())), "unknown");
    }

    #[test]
    fn model_name_forwards_to_inner() {
        /// Provider that only reports a model name; `complete` is never called here.
        struct NamedProvider;
        impl LlmProvider for NamedProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                unimplemented!()
            }
            fn model_name(&self) -> Option<&str> {
                Some("my-model")
            }
        }
        let provider = RetryingProvider::with_defaults(NamedProvider);
        assert_eq!(provider.model_name(), Some("my-model"));
    }
}