1use std::sync::Arc;
7use std::time::Duration;
8
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11
12use awaken_contract::contract::executor::{
13 InferenceExecutionError, InferenceRequest, InferenceStream, LlmExecutor,
14};
15use awaken_contract::contract::inference::StreamResult;
16
17use super::circuit_breaker::CircuitBreaker;
18
/// Hard ceiling (in milliseconds) on any computed exponential-backoff delay.
const MAX_BACKOFF_MS: u64 = 8_000;
21
22#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
24#[serde(deny_unknown_fields)]
25pub struct LlmRetryPolicy {
26 pub max_retries: u32,
28 pub fallback_upstream_models: Vec<String>,
30 #[serde(default = "default_backoff_base_ms")]
33 pub backoff_base_ms: u64,
34 #[serde(default = "default_overloaded_backoff_base_ms")]
38 pub overloaded_backoff_base_ms: u64,
39 #[serde(default = "default_max_stream_retries")]
44 pub max_stream_retries: u32,
45 #[serde(default = "default_stream_idle_timeout_secs")]
49 pub stream_idle_timeout_secs: u64,
50}
51
/// Serde default for `LlmRetryPolicy::backoff_base_ms` (milliseconds).
fn default_backoff_base_ms() -> u64 {
    500
}
55
/// Serde default for `LlmRetryPolicy::overloaded_backoff_base_ms` (milliseconds).
fn default_overloaded_backoff_base_ms() -> u64 {
    2_000
}
59
/// Serde default for `LlmRetryPolicy::max_stream_retries`.
fn default_max_stream_retries() -> u32 {
    2
}
63
/// Serde default for `LlmRetryPolicy::stream_idle_timeout_secs`.
fn default_stream_idle_timeout_secs() -> u64 {
    60
}
67
68impl Default for LlmRetryPolicy {
69 fn default() -> Self {
70 Self {
71 max_retries: 2,
72 fallback_upstream_models: Vec::new(),
73 backoff_base_ms: default_backoff_base_ms(),
74 overloaded_backoff_base_ms: default_overloaded_backoff_base_ms(),
75 max_stream_retries: default_max_stream_retries(),
76 stream_idle_timeout_secs: default_stream_idle_timeout_secs(),
77 }
78 }
79}
80
81impl LlmRetryPolicy {
82 pub fn no_retry() -> Self {
84 Self {
85 max_retries: 0,
86 ..Default::default()
87 }
88 }
89
90 pub fn with_max_retries(mut self, n: u32) -> Self {
92 self.max_retries = n;
93 self
94 }
95
96 pub fn with_fallback_upstream_model(mut self, upstream_model: impl Into<String>) -> Self {
98 self.fallback_upstream_models.push(upstream_model.into());
99 self
100 }
101
102 pub fn with_backoff_base_ms(mut self, ms: u64) -> Self {
104 self.backoff_base_ms = ms;
105 self
106 }
107
108 pub fn with_overloaded_backoff_base_ms(mut self, ms: u64) -> Self {
110 self.overloaded_backoff_base_ms = ms;
111 self
112 }
113
114 pub fn with_max_stream_retries(mut self, n: u32) -> Self {
116 self.max_stream_retries = n;
117 self
118 }
119
120 pub fn with_stream_idle_timeout_secs(mut self, secs: u64) -> Self {
122 self.stream_idle_timeout_secs = secs;
123 self
124 }
125
126 fn backoff_delay(&self, attempt: u32) -> Duration {
128 Self::backoff_delay_with_base(self.backoff_base_ms, attempt)
129 }
130
131 fn overloaded_backoff_delay(&self, attempt: u32) -> Duration {
133 Self::backoff_delay_with_base(self.overloaded_backoff_base_ms, attempt)
134 }
135
136 fn backoff_delay_with_base(base_ms: u64, attempt: u32) -> Duration {
137 if base_ms == 0 {
138 return Duration::ZERO;
139 }
140 let delay_ms = base_ms
141 .saturating_mul(1u64 << attempt.min(16))
142 .min(MAX_BACKOFF_MS);
143 Duration::from_millis(delay_ms)
144 }
145
146 pub fn delay_before_retry(&self, err: &InferenceExecutionError, attempt: u32) -> Duration {
150 let base = match err {
151 InferenceExecutionError::Overloaded { .. } => self.overloaded_backoff_delay(attempt),
152 _ => self.backoff_delay(attempt),
153 };
154 match err.retry_after() {
155 Some(hint) if hint > base => hint,
156 _ => base,
157 }
158 }
159}
160
/// Whether `err` warrants another attempt. Thin seam over the error's own
/// `is_retryable` so every retry call site in this module reads uniformly.
fn is_retryable(err: &InferenceExecutionError) -> bool {
    err.is_retryable()
}
165
/// Decorator around an [`LlmExecutor`] that adds per-attempt retries with
/// exponential backoff, ordered fallback to alternate upstream models, and
/// optional per-model circuit breaking.
pub struct RetryingExecutor {
    // The wrapped executor that performs the actual inference calls.
    inner: Arc<dyn LlmExecutor>,
    // Retry / backoff / fallback configuration.
    policy: LlmRetryPolicy,
    // Optional per-model circuit breaker consulted before every attempt.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}
176
impl RetryingExecutor {
    /// Wrap `inner` with the given retry policy and no circuit breaker.
    pub fn new(inner: Arc<dyn LlmExecutor>, policy: LlmRetryPolicy) -> Self {
        Self {
            inner,
            policy,
            circuit_breaker: None,
        }
    }

    /// Builder: attach a circuit breaker consulted before every attempt.
    pub fn with_circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
        self.circuit_breaker = Some(cb);
        self
    }

    /// Run `execute` against a single upstream model, retrying up to
    /// `policy.max_retries` times on retryable errors.
    ///
    /// Order matters here and is pinned by tests: the breaker is checked
    /// before each attempt (an open breaker propagates its own error via
    /// `?`), breaker failures are recorded before the retryability check,
    /// and the backoff sleep happens only when more attempts remain.
    async fn try_with_retries(
        &self,
        request: &InferenceRequest,
    ) -> Result<StreamResult, InferenceExecutionError> {
        let mut last_error = None;

        // max_retries retries plus the initial attempt.
        for attempt in 0..=self.policy.max_retries {
            if let Some(ref cb) = self.circuit_breaker {
                // An open breaker aborts the whole sequence immediately.
                cb.check(&request.upstream_model)?;
            }

            match self.inner.execute(request.clone()).await {
                Ok(result) => {
                    if let Some(ref cb) = self.circuit_breaker {
                        cb.record_success(&request.upstream_model);
                    }
                    return Ok(result);
                }
                Err(err) => {
                    // Some errors (e.g. client-side ones) must not trip the breaker.
                    if err.counts_toward_circuit_breaker() {
                        if let Some(ref cb) = self.circuit_breaker {
                            cb.record_failure(&request.upstream_model);
                        }
                    }
                    // Permanent errors are returned as-is, never retried.
                    if !is_retryable(&err) {
                        return Err(err);
                    }
                    // Budget exhausted: keep the error, skip the sleep.
                    if attempt == self.policy.max_retries {
                        last_error = Some(err);
                        break;
                    }
                    let delay = self.policy.delay_before_retry(&err, attempt);
                    last_error = Some(err);
                    if !delay.is_zero() {
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        }

        // Reachable only after at least one Err arm ran, so last_error is Some.
        Err(last_error.expect("at least one attempt was made"))
    }

    /// Fallback model list for this request: a per-request override, when
    /// present, fully replaces (not extends) the policy's list.
    fn fallback_upstream_models_for_request(&self, request: &InferenceRequest) -> Vec<String> {
        request
            .overrides
            .as_ref()
            .and_then(|overrides| overrides.fallback_upstream_models.clone())
            .unwrap_or_else(|| self.policy.fallback_upstream_models.clone())
    }

    /// Stream-start counterpart of `try_with_retries`.
    ///
    /// Intentionally mirrors that function line-for-line, differing only in
    /// calling `execute_stream`; keep the two in sync when editing either.
    /// Note this retries the stream *start* and is governed by `max_retries`
    /// (mid-stream resumption is not handled here).
    async fn try_stream_with_retries(
        &self,
        request: &InferenceRequest,
    ) -> Result<InferenceStream, InferenceExecutionError> {
        let mut last_error = None;

        for attempt in 0..=self.policy.max_retries {
            if let Some(ref cb) = self.circuit_breaker {
                cb.check(&request.upstream_model)?;
            }

            match self.inner.execute_stream(request.clone()).await {
                Ok(stream) => {
                    if let Some(ref cb) = self.circuit_breaker {
                        cb.record_success(&request.upstream_model);
                    }
                    return Ok(stream);
                }
                Err(err) => {
                    if err.counts_toward_circuit_breaker() {
                        if let Some(ref cb) = self.circuit_breaker {
                            cb.record_failure(&request.upstream_model);
                        }
                    }
                    if !is_retryable(&err) {
                        return Err(err);
                    }
                    if attempt == self.policy.max_retries {
                        last_error = Some(err);
                        break;
                    }
                    let delay = self.policy.delay_before_retry(&err, attempt);
                    last_error = Some(err);
                    if !delay.is_zero() {
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        }

        Err(last_error.expect("at least one stream attempt was made"))
    }

    /// Fail-fast check: true only when a circuit breaker exists, the primary
    /// model is blocked, AND every fallback is blocked too. With no breaker
    /// configured this is always false. Note: an empty fallback list makes
    /// `all(..)` vacuously true, so a blocked primary alone blocks the request.
    fn all_models_blocked(
        &self,
        request: &InferenceRequest,
        fallback_upstream_models: &[String],
    ) -> bool {
        let Some(ref cb) = self.circuit_breaker else {
            return false;
        };
        if cb.check(&request.upstream_model).is_ok() {
            return false;
        }
        fallback_upstream_models
            .iter()
            .all(|m| cb.check(m).is_err())
    }
}
313
#[async_trait]
impl LlmExecutor for RetryingExecutor {
    /// Execute with retries on the primary model, then walk the fallback
    /// list in order, retrying each fallback with the same budget.
    async fn execute(
        &self,
        request: InferenceRequest,
    ) -> Result<StreamResult, InferenceExecutionError> {
        let fallback_upstream_models = self.fallback_upstream_models_for_request(&request);

        // Fail fast rather than burn a retry sequence when the breaker has
        // every candidate model open.
        if self.all_models_blocked(&request, &fallback_upstream_models) {
            return Err(InferenceExecutionError::AllModelsUnavailable);
        }

        match self.try_with_retries(&request).await {
            Ok(result) => return Ok(result),
            // Permanent error, or nowhere to fall back to: propagate now.
            Err(err) if !is_retryable(&err) || fallback_upstream_models.is_empty() => {
                return Err(err);
            }
            // Retryable primary failure with fallbacks available: the
            // primary's error is deliberately discarded in favor of the
            // eventual fallback outcome.
            Err(_) => {}
        }

        let mut last_error = None;
        for (i, fallback_upstream_model) in fallback_upstream_models.iter().enumerate() {
            // Same request, swapped upstream model.
            let mut fallback_request = request.clone();
            fallback_request.upstream_model = fallback_upstream_model.clone();

            match self.try_with_retries(&fallback_request).await {
                Ok(result) => return Ok(result),
                Err(err) => {
                    // A permanent error stops the walk early; otherwise keep
                    // the error and move to the next fallback.
                    let is_last = i == fallback_upstream_models.len() - 1;
                    if !is_retryable(&err) || is_last {
                        last_error = Some(err);
                        break;
                    }
                    last_error = Some(err);
                }
            }
        }

        // The guard above ensured the list was non-empty, so the loop ran.
        Err(last_error.expect("at least one fallback was attempted"))
    }

    /// Stream variant of `execute`, mirroring its primary-then-fallbacks
    /// walk but starting streams. Hand-written (boxed future instead of
    /// `async fn`) — presumably to match the trait's declared signature;
    /// confirm against the `LlmExecutor` definition before converting.
    fn execute_stream(
        &self,
        request: InferenceRequest,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async move {
            let fallback_upstream_models = self.fallback_upstream_models_for_request(&request);

            if self.all_models_blocked(&request, &fallback_upstream_models) {
                return Err(InferenceExecutionError::AllModelsUnavailable);
            }

            match self.try_stream_with_retries(&request).await {
                Ok(stream) => return Ok(stream),
                Err(err) if !is_retryable(&err) || fallback_upstream_models.is_empty() => {
                    return Err(err);
                }
                Err(_) => {}
            }

            let mut last_error = None;
            for (i, fallback_upstream_model) in fallback_upstream_models.iter().enumerate() {
                let mut fallback_request = request.clone();
                fallback_request.upstream_model = fallback_upstream_model.clone();

                match self.try_stream_with_retries(&fallback_request).await {
                    Ok(stream) => return Ok(stream),
                    Err(err) => {
                        let is_last = i == fallback_upstream_models.len() - 1;
                        if !is_retryable(&err) || is_last {
                            last_error = Some(err);
                            break;
                        }
                        last_error = Some(err);
                    }
                }
            }

            Err(last_error.expect("at least one stream fallback was attempted"))
        })
    }

    /// Delegates naming to the wrapped executor; this decorator is invisible.
    fn name(&self) -> &str {
        self.inner.name()
    }
}
408
/// Marker type keying this module's retry configuration in the plugin registry.
pub struct RetryConfigKey;
411
impl awaken_contract::registry_spec::PluginConfigKey for RetryConfigKey {
    // Config section name under which an `LlmRetryPolicy` is looked up.
    const KEY: &'static str = "retry";
    type Config = LlmRetryPolicy;
}
416
417#[cfg(test)]
418mod tests {
419 use super::*;
420 use awaken_contract::contract::content::ContentBlock;
421 use awaken_contract::contract::inference::{InferenceOverride, StopReason, TokenUsage};
422 use awaken_contract::contract::message::Message;
423 use std::sync::atomic::{AtomicU32, Ordering};
424
425 fn test_policy() -> LlmRetryPolicy {
427 LlmRetryPolicy::default().with_backoff_base_ms(0)
428 }
429
430 struct FailNThenSucceed {
432 fail_count: u32,
433 error_kind: fn(u32) -> InferenceExecutionError,
434 calls: AtomicU32,
435 }
436
437 impl FailNThenSucceed {
438 fn new(fail_count: u32) -> Self {
439 Self {
440 fail_count,
441 error_kind: |_| InferenceExecutionError::Provider("transient".into()),
442 calls: AtomicU32::new(0),
443 }
444 }
445
446 fn with_error(mut self, f: fn(u32) -> InferenceExecutionError) -> Self {
447 self.error_kind = f;
448 self
449 }
450
451 fn call_count(&self) -> u32 {
452 self.calls.load(Ordering::SeqCst)
453 }
454 }
455
456 fn ok_result() -> StreamResult {
457 StreamResult {
458 content: vec![ContentBlock::text("ok")],
459 tool_calls: vec![],
460 usage: Some(TokenUsage {
461 prompt_tokens: Some(10),
462 completion_tokens: Some(5),
463 total_tokens: Some(15),
464 ..Default::default()
465 }),
466 stop_reason: Some(StopReason::EndTurn),
467 has_incomplete_tool_calls: false,
468 }
469 }
470
471 fn test_request() -> InferenceRequest {
472 InferenceRequest {
473 upstream_model: "primary-model".into(),
474 messages: vec![Message::user("hello")],
475 tools: vec![],
476 system: vec![],
477 overrides: None,
478 enable_prompt_cache: false,
479 }
480 }
481
482 #[async_trait]
483 impl LlmExecutor for FailNThenSucceed {
484 async fn execute(
485 &self,
486 _request: InferenceRequest,
487 ) -> Result<StreamResult, InferenceExecutionError> {
488 let call = self.calls.fetch_add(1, Ordering::SeqCst);
489 if call < self.fail_count {
490 Err((self.error_kind)(call))
491 } else {
492 Ok(ok_result())
493 }
494 }
495
496 fn name(&self) -> &str {
497 "mock"
498 }
499 }
500
501 struct ModelRecorder {
503 models: std::sync::Mutex<Vec<String>>,
504 error: InferenceExecutionError,
505 }
506
507 impl ModelRecorder {
508 fn always_fail_with(err: InferenceExecutionError) -> Self {
509 Self {
510 models: std::sync::Mutex::new(Vec::new()),
511 error: err,
512 }
513 }
514
515 fn recorded_models(&self) -> Vec<String> {
516 self.models.lock().unwrap().clone()
517 }
518 }
519
520 #[async_trait]
521 impl LlmExecutor for ModelRecorder {
522 async fn execute(
523 &self,
524 request: InferenceRequest,
525 ) -> Result<StreamResult, InferenceExecutionError> {
526 self.models
527 .lock()
528 .unwrap()
529 .push(request.upstream_model.clone());
530 Err(self.error.clone())
531 }
532
533 fn name(&self) -> &str {
534 "model-recorder"
535 }
536 }
537
538 #[tokio::test]
539 async fn no_retry_policy_first_failure_is_terminal() {
540 let inner = Arc::new(FailNThenSucceed::new(1));
541 let executor = RetryingExecutor::new(
542 inner.clone(),
543 LlmRetryPolicy::no_retry().with_backoff_base_ms(0),
544 );
545
546 let result = executor.execute(test_request()).await;
547 assert!(result.is_err());
548 assert_eq!(inner.call_count(), 1);
549 }
550
551 #[tokio::test]
552 async fn retry_succeeds_on_second_attempt() {
553 let inner = Arc::new(FailNThenSucceed::new(1));
554 let policy = test_policy().with_max_retries(2);
555 let executor = RetryingExecutor::new(inner.clone(), policy);
556
557 let result = executor.execute(test_request()).await;
558 assert!(result.is_ok());
559 assert_eq!(inner.call_count(), 2);
560 }
561
562 #[tokio::test]
563 async fn retry_exhausts_all_attempts_returns_last_error() {
564 let inner = Arc::new(FailNThenSucceed::new(100)); let policy = test_policy().with_max_retries(3);
566 let executor = RetryingExecutor::new(inner.clone(), policy);
567
568 let result = executor.execute(test_request()).await;
569 assert!(result.is_err());
570 assert_eq!(inner.call_count(), 4);
572 }
573
574 #[tokio::test]
575 async fn non_retryable_error_is_not_retried() {
576 let inner =
577 Arc::new(FailNThenSucceed::new(1).with_error(|_| InferenceExecutionError::Cancelled));
578 let policy = test_policy().with_max_retries(5);
579 let executor = RetryingExecutor::new(inner.clone(), policy);
580
581 let result = executor.execute(test_request()).await;
582 assert!(result.is_err());
583 assert_eq!(inner.call_count(), 1);
584 }
585
586 #[tokio::test]
587 async fn fallback_upstream_model_used_after_primary_exhausts_retries() {
588 let inner = Arc::new(ModelRecorder::always_fail_with(
589 InferenceExecutionError::rate_limited("overloaded"),
590 ));
591 let policy = test_policy()
592 .with_max_retries(1)
593 .with_fallback_upstream_model("fallback-a")
594 .with_fallback_upstream_model("fallback-b");
595 let executor = RetryingExecutor::new(inner.clone(), policy);
596
597 let result = executor.execute(test_request()).await;
598 assert!(result.is_err());
599
600 let models = inner.recorded_models();
601 assert_eq!(models.len(), 6);
605 assert_eq!(models[0], "primary-model");
606 assert_eq!(models[1], "primary-model");
607 assert_eq!(models[2], "fallback-a");
608 assert_eq!(models[3], "fallback-a");
609 assert_eq!(models[4], "fallback-b");
610 assert_eq!(models[5], "fallback-b");
611 }
612
613 #[tokio::test]
614 async fn request_override_fallback_upstream_models_replace_policy_fallbacks() {
615 let inner = Arc::new(ModelRecorder::always_fail_with(
616 InferenceExecutionError::rate_limited("overloaded"),
617 ));
618 let policy = test_policy()
619 .with_max_retries(0)
620 .with_fallback_upstream_model("policy-fallback");
621 let executor = RetryingExecutor::new(inner.clone(), policy);
622
623 let mut request = test_request();
624 request.overrides = Some(InferenceOverride {
625 fallback_upstream_models: Some(vec!["override-fallback".into()]),
626 ..Default::default()
627 });
628
629 let result = executor.execute(request).await;
630 assert!(result.is_err());
631
632 assert_eq!(
633 inner.recorded_models(),
634 vec!["primary-model", "override-fallback"]
635 );
636 }
637
638 #[tokio::test]
639 async fn execute_stream_retries_stream_start_until_success() {
640 let inner = Arc::new(FailNThenSucceed::new(1));
641 let policy = test_policy().with_max_retries(2);
642 let executor = RetryingExecutor::new(inner.clone(), policy);
643
644 let result = executor.execute_stream(test_request()).await;
645 assert!(result.is_ok());
646 assert_eq!(inner.call_count(), 2);
647 }
648
649 #[tokio::test]
650 async fn execute_stream_uses_request_override_fallback_upstream_models() {
651 let inner = Arc::new(ModelRecorder::always_fail_with(
652 InferenceExecutionError::rate_limited("overloaded"),
653 ));
654 let policy = test_policy()
655 .with_max_retries(0)
656 .with_fallback_upstream_model("policy-fallback");
657 let executor = RetryingExecutor::new(inner.clone(), policy);
658
659 let mut request = test_request();
660 request.overrides = Some(InferenceOverride {
661 fallback_upstream_models: Some(vec!["override-fallback".into()]),
662 ..Default::default()
663 });
664
665 let result = executor.execute_stream(request).await;
666 assert!(result.is_err());
667
668 assert_eq!(
669 inner.recorded_models(),
670 vec!["primary-model", "override-fallback"]
671 );
672 }
673
674 #[tokio::test]
675 async fn fallback_succeeds_after_primary_fails() {
676 let inner = Arc::new(FailNThenSucceed::new(3));
677 let policy = test_policy()
678 .with_max_retries(1)
679 .with_fallback_upstream_model("fallback-model");
680 let executor = RetryingExecutor::new(inner.clone(), policy);
681
682 let result = executor.execute(test_request()).await;
683 assert!(result.is_ok());
684 assert_eq!(inner.call_count(), 4);
686 }
687
688 #[tokio::test]
689 async fn succeeds_on_first_try_no_retry_needed() {
690 let inner = Arc::new(FailNThenSucceed::new(0)); let policy = test_policy().with_max_retries(3);
692 let executor = RetryingExecutor::new(inner.clone(), policy);
693
694 let result = executor.execute(test_request()).await;
695 assert!(result.is_ok());
696 assert_eq!(inner.call_count(), 1, "should call executor exactly once");
697 }
698
699 #[tokio::test]
700 async fn retrying_executor_delegates_name() {
701 let inner = Arc::new(FailNThenSucceed::new(0));
702 let executor = RetryingExecutor::new(inner, test_policy());
703 assert_eq!(executor.name(), "mock");
704 }
705
706 #[tokio::test]
707 async fn non_retryable_error_during_fallback_stops_immediately() {
708 let call_count = Arc::new(AtomicU32::new(0));
709 let cc = call_count.clone();
710
711 struct PrimaryRetryableFallbackFatal {
712 calls: Arc<AtomicU32>,
713 }
714
715 #[async_trait]
716 impl LlmExecutor for PrimaryRetryableFallbackFatal {
717 async fn execute(
718 &self,
719 request: InferenceRequest,
720 ) -> Result<StreamResult, InferenceExecutionError> {
721 let n = self.calls.fetch_add(1, Ordering::SeqCst);
722 if request.upstream_model.starts_with("primary") {
723 Err(InferenceExecutionError::Provider("down".into()))
724 } else {
725 let _ = n;
726 Err(InferenceExecutionError::Cancelled)
727 }
728 }
729
730 fn name(&self) -> &str {
731 "primary-retryable-fallback-fatal"
732 }
733 }
734
735 let inner = Arc::new(PrimaryRetryableFallbackFatal { calls: cc });
736 let policy = test_policy()
737 .with_max_retries(0)
738 .with_fallback_upstream_model("fallback-a")
739 .with_fallback_upstream_model("fallback-b");
740 let executor = RetryingExecutor::new(inner, policy);
741
742 let result = executor.execute(test_request()).await;
743 assert!(result.is_err());
744 assert_eq!(call_count.load(Ordering::SeqCst), 2);
746 }
747
748 #[test]
749 fn default_policy_values() {
750 let policy = LlmRetryPolicy::default();
751 assert_eq!(policy.max_retries, 2);
752 assert!(policy.fallback_upstream_models.is_empty());
753 assert_eq!(policy.backoff_base_ms, 500);
754 assert_eq!(policy.overloaded_backoff_base_ms, 2_000);
755 assert_eq!(policy.max_stream_retries, 2);
756 assert_eq!(policy.stream_idle_timeout_secs, 60);
757 }
758
759 #[test]
760 fn no_retry_policy_values() {
761 let policy = LlmRetryPolicy::no_retry();
762 assert_eq!(policy.max_retries, 0);
763 assert!(policy.fallback_upstream_models.is_empty());
764 }
765
766 #[test]
767 fn rate_limit_error_is_retryable() {
768 assert!(is_retryable(&InferenceExecutionError::rate_limited("429")));
769 }
770
771 #[test]
772 fn overloaded_error_is_retryable() {
773 assert!(is_retryable(&InferenceExecutionError::overloaded("529")));
774 }
775
776 #[test]
777 fn context_overflow_is_not_retryable() {
778 assert!(!is_retryable(&InferenceExecutionError::ContextOverflow(
779 "too long".into()
780 )));
781 }
782
783 #[test]
784 fn context_overflow_does_not_count_toward_breaker() {
785 let err = InferenceExecutionError::ContextOverflow("too long".into());
786 assert!(!err.counts_toward_circuit_breaker());
787 }
788
789 #[test]
790 fn invalid_request_does_not_count_toward_breaker() {
791 assert!(
792 !InferenceExecutionError::InvalidRequest("schema".into())
793 .counts_toward_circuit_breaker()
794 );
795 }
796
797 #[test]
798 fn unauthorized_does_not_count_toward_breaker() {
799 assert!(
800 !InferenceExecutionError::Unauthorized("key".into()).counts_toward_circuit_breaker()
801 );
802 }
803
804 #[test]
805 fn all_models_unavailable_is_fail_fast() {
806 let err = InferenceExecutionError::AllModelsUnavailable;
807 assert!(!err.is_retryable());
808 assert!(!err.counts_toward_circuit_breaker());
809 }
810
811 #[test]
812 fn server_error_is_retryable() {
813 assert!(is_retryable(&InferenceExecutionError::Provider(
814 "500 internal".into()
815 )));
816 }
817
818 #[test]
819 fn timeout_error_is_retryable() {
820 assert!(is_retryable(&InferenceExecutionError::Timeout(
821 "timed out".into()
822 )));
823 }
824
825 #[test]
826 fn cancelled_error_is_not_retryable() {
827 assert!(!is_retryable(&InferenceExecutionError::Cancelled));
828 }
829
830 #[test]
831 fn builder_methods_chain() {
832 let policy = LlmRetryPolicy::default()
833 .with_max_retries(5)
834 .with_fallback_upstream_model("model-a")
835 .with_fallback_upstream_model("model-b")
836 .with_backoff_base_ms(100);
837 assert_eq!(policy.max_retries, 5);
838 assert_eq!(policy.fallback_upstream_models, vec!["model-a", "model-b"]);
839 assert_eq!(policy.backoff_base_ms, 100);
840 }
841
842 #[test]
847 fn backoff_delay_zero_base() {
848 let policy = LlmRetryPolicy::default().with_backoff_base_ms(0);
849 assert_eq!(policy.backoff_delay(0), Duration::ZERO);
850 assert_eq!(policy.backoff_delay(5), Duration::ZERO);
851 }
852
853 #[test]
854 fn backoff_delay_exponential() {
855 let policy = LlmRetryPolicy::default().with_backoff_base_ms(500);
856 assert_eq!(policy.backoff_delay(0), Duration::from_millis(500)); assert_eq!(policy.backoff_delay(1), Duration::from_millis(1000)); assert_eq!(policy.backoff_delay(2), Duration::from_millis(2000)); assert_eq!(policy.backoff_delay(3), Duration::from_millis(4000)); }
861
862 #[test]
863 fn backoff_delay_caps_at_max() {
864 let policy = LlmRetryPolicy::default().with_backoff_base_ms(500);
865 assert_eq!(policy.backoff_delay(4), Duration::from_millis(8000));
867 assert_eq!(policy.backoff_delay(5), Duration::from_millis(8000));
869 }
870
871 #[tokio::test]
876 async fn circuit_breaker_blocks_when_open() {
877 use crate::engine::circuit_breaker::CircuitBreakerConfig;
878
879 let inner = Arc::new(FailNThenSucceed::new(100));
880 let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
881 failure_threshold: 2,
882 cooldown: std::time::Duration::from_secs(60),
883 half_open_max: 1,
884 }));
885
886 cb.record_failure("primary-model");
888 cb.record_failure("primary-model");
889
890 let policy = test_policy().with_max_retries(3);
891 let executor = RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb);
892
893 let result = executor.execute(test_request()).await;
894 assert!(result.is_err());
895 assert_eq!(inner.call_count(), 0);
897 }
898
899 #[tokio::test]
900 async fn circuit_breaker_records_success() {
901 use crate::engine::circuit_breaker::CircuitBreakerConfig;
902
903 let inner = Arc::new(FailNThenSucceed::new(0));
904 let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
905 failure_threshold: 2,
906 cooldown: std::time::Duration::from_secs(60),
907 half_open_max: 1,
908 }));
909
910 cb.record_failure("primary-model");
912
913 let policy = test_policy().with_max_retries(1);
914 let executor =
915 RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb.clone());
916
917 let result = executor.execute(test_request()).await;
918 assert!(result.is_ok());
919
920 cb.record_failure("primary-model");
922 assert!(cb.check("primary-model").is_ok());
923 }
924
925 #[tokio::test]
930 async fn retry_on_rate_limit_then_succeed() {
931 let inner = Arc::new(
932 FailNThenSucceed::new(2)
933 .with_error(|_| InferenceExecutionError::rate_limited("rate limited")),
934 );
935 let policy = test_policy().with_max_retries(3);
936 let executor = RetryingExecutor::new(inner.clone(), policy);
937
938 let result = executor.execute(test_request()).await;
939 assert!(result.is_ok());
940 assert_eq!(inner.call_count(), 3); }
942
943 #[tokio::test]
944 async fn retry_on_timeout_then_succeed() {
945 let inner = Arc::new(
946 FailNThenSucceed::new(1)
947 .with_error(|_| InferenceExecutionError::Timeout("timed out".into())),
948 );
949 let policy = test_policy().with_max_retries(2);
950 let executor = RetryingExecutor::new(inner.clone(), policy);
951
952 let result = executor.execute(test_request()).await;
953 assert!(result.is_ok());
954 assert_eq!(inner.call_count(), 2);
955 }
956
957 #[tokio::test]
958 async fn zero_retries_with_fallback_tries_fallback_once() {
959 let inner = Arc::new(FailNThenSucceed::new(1)); let policy = test_policy()
961 .with_max_retries(0)
962 .with_fallback_upstream_model("fallback");
963 let executor = RetryingExecutor::new(inner.clone(), policy);
964
965 let result = executor.execute(test_request()).await;
966 assert!(result.is_ok());
967 assert_eq!(inner.call_count(), 2); }
969
970 #[tokio::test]
971 async fn no_fallback_upstream_models_configured_returns_primary_error() {
972 let inner = Arc::new(FailNThenSucceed::new(100));
973 let policy = test_policy().with_max_retries(1);
974 let executor = RetryingExecutor::new(inner.clone(), policy);
976
977 let result = executor.execute(test_request()).await;
978 assert!(result.is_err());
979 assert_eq!(inner.call_count(), 2); }
981
982 #[tokio::test]
983 async fn all_error_types_handled() {
984 for error_fn in [
985 (|_: u32| InferenceExecutionError::Provider("down".into())) as fn(u32) -> _,
986 |_| InferenceExecutionError::rate_limited("429"),
987 |_| InferenceExecutionError::Timeout("timeout".into()),
988 ] {
989 let inner = Arc::new(FailNThenSucceed::new(1).with_error(error_fn));
990 let policy = test_policy().with_max_retries(2);
991 let executor = RetryingExecutor::new(inner.clone(), policy);
992
993 let result = executor.execute(test_request()).await;
994 assert!(result.is_ok(), "should recover from retryable error");
995 }
996 }
997
998 #[tokio::test]
999 async fn max_retries_zero_and_no_fallback_just_one_attempt() {
1000 let inner = Arc::new(FailNThenSucceed::new(100));
1001 let policy = LlmRetryPolicy::no_retry().with_backoff_base_ms(0);
1002 let executor = RetryingExecutor::new(inner.clone(), policy);
1003
1004 let result = executor.execute(test_request()).await;
1005 assert!(result.is_err());
1006 assert_eq!(inner.call_count(), 1);
1007 }
1008
1009 #[tokio::test]
1010 async fn success_on_first_try_no_fallback_attempted() {
1011 let recorder = Arc::new(ModelRecorder::always_fail_with(
1012 InferenceExecutionError::Provider("down".into()),
1013 ));
1014 let inner = Arc::new(FailNThenSucceed::new(0)); let policy = test_policy()
1016 .with_max_retries(3)
1017 .with_fallback_upstream_model("fallback-a");
1018 let executor = RetryingExecutor::new(inner.clone(), policy);
1019
1020 let result = executor.execute(test_request()).await;
1021 assert!(result.is_ok());
1022 assert_eq!(inner.call_count(), 1, "should not attempt fallback");
1023 let _ = recorder; }
1025
1026 #[test]
1031 fn retry_policy_serde_roundtrip() {
1032 let policy = LlmRetryPolicy::default()
1033 .with_max_retries(5)
1034 .with_fallback_upstream_model("fallback-a")
1035 .with_fallback_upstream_model("fallback-b")
1036 .with_backoff_base_ms(200)
1037 .with_overloaded_backoff_base_ms(4_000)
1038 .with_max_stream_retries(3)
1039 .with_stream_idle_timeout_secs(90);
1040 let json = serde_json::to_string(&policy).unwrap();
1041 let parsed: LlmRetryPolicy = serde_json::from_str(&json).unwrap();
1042 assert_eq!(parsed.max_retries, 5);
1043 assert_eq!(
1044 parsed.fallback_upstream_models,
1045 vec!["fallback-a", "fallback-b"]
1046 );
1047 assert_eq!(parsed.backoff_base_ms, 200);
1048 assert_eq!(parsed.overloaded_backoff_base_ms, 4_000);
1049 assert_eq!(parsed.max_stream_retries, 3);
1050 assert_eq!(parsed.stream_idle_timeout_secs, 90);
1051 }
1052
1053 #[test]
1054 fn retry_policy_serde_default_backoff() {
1055 let json = r#"{"max_retries":2,"fallback_upstream_models":[]}"#;
1057 let parsed: LlmRetryPolicy = serde_json::from_str(json).unwrap();
1058 assert_eq!(parsed.backoff_base_ms, 500);
1059 assert_eq!(parsed.overloaded_backoff_base_ms, 2_000);
1060 assert_eq!(parsed.max_stream_retries, 2);
1061 assert_eq!(parsed.stream_idle_timeout_secs, 60);
1062 }
1063
1064 #[test]
1065 fn retry_policy_rejects_legacy_fallback_field() {
1066 let json = r#"{"max_retries":2,"fallback_models":[]}"#;
1067 let parsed = serde_json::from_str::<LlmRetryPolicy>(json);
1068 assert!(parsed.is_err());
1069 }
1070
1071 #[tokio::test]
1072 async fn retry_budget_exact_boundary() {
1073 let inner = Arc::new(FailNThenSucceed::new(2));
1074 let policy = test_policy().with_max_retries(2);
1075 let executor = RetryingExecutor::new(inner.clone(), policy);
1076
1077 let result = executor.execute(test_request()).await;
1078 assert!(result.is_ok());
1079 assert_eq!(inner.call_count(), 3);
1080 }
1081
1082 #[tokio::test]
1083 async fn retry_budget_one_over_boundary() {
1084 let inner = Arc::new(FailNThenSucceed::new(3));
1085 let policy = test_policy().with_max_retries(2);
1086 let executor = RetryingExecutor::new(inner.clone(), policy);
1087
1088 let result = executor.execute(test_request()).await;
1089 assert!(result.is_err());
1090 assert_eq!(inner.call_count(), 3, "1 initial + 2 retries = 3 calls");
1091 }
1092
1093 #[tokio::test]
1098 async fn circuit_breaker_opens_during_retry_sequence() {
1099 use crate::engine::circuit_breaker::CircuitBreakerConfig;
1100
1101 let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
1102 failure_threshold: 2,
1103 cooldown: Duration::from_secs(60),
1104 half_open_max: 1,
1105 }));
1106 let inner = Arc::new(FailNThenSucceed::new(100)); let policy = test_policy().with_max_retries(5);
1108 let executor = RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb);
1109
1110 let result = executor.execute(test_request()).await;
1111 assert!(result.is_err());
1112 assert_eq!(inner.call_count(), 2);
1114 }
1115
1116 #[tokio::test]
1121 async fn circuit_breaker_independent_per_model_in_fallback() {
1122 use crate::engine::circuit_breaker::CircuitBreakerConfig;
1123
1124 let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
1125 failure_threshold: 2,
1126 cooldown: Duration::from_secs(60),
1127 half_open_max: 1,
1128 }));
1129 cb.record_failure("primary-model");
1131 cb.record_failure("primary-model");
1132
1133 let inner = Arc::new(FailNThenSucceed::new(0));
1135 let policy = test_policy()
1136 .with_max_retries(0)
1137 .with_fallback_upstream_model("fallback-model");
1138 let executor = RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb);
1139
1140 let result = executor.execute(test_request()).await;
1141 assert!(result.is_ok());
1143 assert_eq!(inner.call_count(), 1);
1145 }
1146
1147 #[test]
1153 fn delay_before_retry_respects_retry_after_when_longer() {
1154 let policy = LlmRetryPolicy::default().with_backoff_base_ms(100);
1155 let err = InferenceExecutionError::RateLimited {
1156 message: "slow".into(),
1157 retry_after: Some(Duration::from_secs(5)),
1158 };
1159 assert_eq!(policy.delay_before_retry(&err, 0), Duration::from_secs(5));
1161 }
1162
1163 #[test]
1164 fn delay_before_retry_uses_exponential_when_longer_than_retry_after() {
1165 let policy = LlmRetryPolicy::default().with_backoff_base_ms(10_000);
1166 let err = InferenceExecutionError::RateLimited {
1167 message: "fast hint".into(),
1168 retry_after: Some(Duration::from_millis(100)),
1169 };
1170 assert_eq!(
1172 policy.delay_before_retry(&err, 0),
1173 Duration::from_millis(MAX_BACKOFF_MS)
1174 );
1175 }
1176
1177 #[test]
1178 fn delay_before_retry_uses_overloaded_base_for_overloaded_errors() {
1179 let policy = LlmRetryPolicy::default()
1180 .with_backoff_base_ms(500)
1181 .with_overloaded_backoff_base_ms(2_000);
1182 let overloaded = InferenceExecutionError::overloaded("surge");
1183 assert_eq!(
1185 policy.delay_before_retry(&overloaded, 0),
1186 Duration::from_millis(2_000)
1187 );
1188 }
1189
1190 #[tokio::test(start_paused = true)]
1191 async fn rate_limited_retry_after_waits_hint_duration() {
1192 let inner = Arc::new(FailNThenSucceed::new(1).with_error(|_| {
1193 InferenceExecutionError::RateLimited {
1194 message: "slow down".into(),
1195 retry_after: Some(Duration::from_secs(3)),
1196 }
1197 }));
1198 let policy = LlmRetryPolicy::default()
1199 .with_max_retries(2)
1200 .with_backoff_base_ms(10); let executor = RetryingExecutor::new(inner.clone(), policy);
1202
1203 let start = tokio::time::Instant::now();
1204 let result = executor.execute(test_request()).await;
1205 assert!(result.is_ok());
1206 let elapsed = start.elapsed();
1207 assert!(
1208 elapsed >= Duration::from_secs(3),
1209 "expected >=3s retry-after wait, got {elapsed:?}"
1210 );
1211 assert_eq!(inner.call_count(), 2);
1212 }
1213
1214 #[tokio::test]
1215 async fn context_overflow_error_is_not_retried() {
1216 let inner =
1217 Arc::new(FailNThenSucceed::new(5).with_error(|_| {
1218 InferenceExecutionError::ContextOverflow("prompt too long".into())
1219 }));
1220 let policy = test_policy().with_max_retries(3);
1221 let executor = RetryingExecutor::new(inner.clone(), policy);
1222
1223 let result = executor.execute(test_request()).await;
1224 assert!(matches!(
1225 result,
1226 Err(InferenceExecutionError::ContextOverflow(_))
1227 ));
1228 assert_eq!(inner.call_count(), 1, "permanent error must not retry");
1229 }
1230
1231 #[tokio::test]
1232 async fn context_overflow_does_not_trip_circuit_breaker() {
1233 use crate::engine::circuit_breaker::CircuitBreakerConfig;
1234
1235 let inner = Arc::new(
1236 FailNThenSucceed::new(100)
1237 .with_error(|_| InferenceExecutionError::ContextOverflow("too long".into())),
1238 );
1239 let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
1240 failure_threshold: 2,
1241 cooldown: Duration::from_secs(60),
1242 half_open_max: 1,
1243 }));
1244
1245 let policy = test_policy().with_max_retries(0);
1246 let executor =
1247 RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb.clone());
1248
1249 for _ in 0..5 {
1251 let _ = executor.execute(test_request()).await;
1252 }
1253 assert!(
1254 cb.check("primary-model").is_ok(),
1255 "ContextOverflow must not increment the breaker"
1256 );
1257 }
1258
1259 #[tokio::test]
1260 async fn all_models_blocked_short_circuits_with_all_models_unavailable() {
1261 use crate::engine::circuit_breaker::CircuitBreakerConfig;
1262
1263 let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
1264 failure_threshold: 1,
1265 cooldown: Duration::from_secs(60),
1266 half_open_max: 1,
1267 }));
1268 cb.record_failure("primary-model");
1269 cb.record_failure("fallback-a");
1270 cb.record_failure("fallback-b");
1271
1272 let inner = Arc::new(FailNThenSucceed::new(0)); let policy = test_policy()
1274 .with_max_retries(2)
1275 .with_fallback_upstream_model("fallback-a")
1276 .with_fallback_upstream_model("fallback-b");
1277 let executor =
1278 RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb.clone());
1279
1280 let result = executor.execute(test_request()).await;
1281 assert!(
1282 matches!(result, Err(InferenceExecutionError::AllModelsUnavailable)),
1283 "expected AllModelsUnavailable, got {result:?}"
1284 );
1285 assert_eq!(inner.call_count(), 0, "no inner call should be made");
1286 }
1287
1288 #[tokio::test(start_paused = true)]
1293 async fn backoff_actually_sleeps() {
1294 let inner = Arc::new(FailNThenSucceed::new(2));
1295 let policy = LlmRetryPolicy::default()
1296 .with_max_retries(3)
1297 .with_backoff_base_ms(1000); let executor = RetryingExecutor::new(inner.clone(), policy);
1299
1300 let start = tokio::time::Instant::now();
1301 let result = executor.execute(test_request()).await;
1302 assert!(result.is_ok());
1303
1304 let elapsed = start.elapsed();
1309 assert!(
1310 elapsed >= Duration::from_secs(3),
1311 "expected >= 3s backoff, got {elapsed:?}"
1312 );
1313 }
1314
    // Property-based tests for `LlmRetryPolicy`. The bodies live inside the
    // `proptest!` macro, whose parameter ranges (e.g. `0u32..10`) drive input
    // generation.
    mod proptest_retry {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Round-tripping a policy through serde_json must preserve every
            // field for arbitrary in-range values, including the generated
            // fallback model list.
            #[test]
            fn llm_retry_policy_serde_roundtrip(
                max_retries in 0u32..10,
                backoff_base_ms in 0u64..10000,
                overloaded_backoff_base_ms in 0u64..10000,
                max_stream_retries in 0u32..10,
                stream_idle_timeout_secs in 1u64..300,
                num_fallbacks in 0usize..5,
            ) {
                let policy = LlmRetryPolicy {
                    max_retries,
                    fallback_upstream_models: (0..num_fallbacks).map(|i| format!("model-{i}")).collect(),
                    backoff_base_ms,
                    overloaded_backoff_base_ms,
                    max_stream_retries,
                    stream_idle_timeout_secs,
                };
                let json = serde_json::to_string(&policy).unwrap();
                let parsed: LlmRetryPolicy = serde_json::from_str(&json).unwrap();
                prop_assert_eq!(parsed.max_retries, max_retries);
                prop_assert_eq!(parsed.backoff_base_ms, backoff_base_ms);
                prop_assert_eq!(parsed.overloaded_backoff_base_ms, overloaded_backoff_base_ms);
                prop_assert_eq!(parsed.max_stream_retries, max_stream_retries);
                prop_assert_eq!(parsed.stream_idle_timeout_secs, stream_idle_timeout_secs);
                prop_assert_eq!(parsed.fallback_upstream_models.len(), num_fallbacks);
            }

            // backoff_delay(attempt) must never shrink as the attempt number
            // grows, for any positive base.
            #[test]
            fn backoff_delay_is_monotonically_non_decreasing(
                base_ms in 1u64..1000,
            ) {
                let policy = LlmRetryPolicy::default().with_backoff_base_ms(base_ms);
                let mut prev = Duration::ZERO;
                for attempt in 0..10u32 {
                    let delay = policy.backoff_delay(attempt);
                    prop_assert!(
                        delay >= prev,
                        "delay should be monotonically non-decreasing: attempt={attempt}, delay={delay:?}, prev={prev:?}"
                    );
                    prev = delay;
                }
            }

            // The backoff is clamped at MAX_BACKOFF_MS regardless of base or
            // attempt number (attempt up to 100 would otherwise overflow).
            #[test]
            fn backoff_delay_never_exceeds_cap(
                base_ms in 0u64..10000,
                attempt in 0u32..100,
            ) {
                let policy = LlmRetryPolicy::default().with_backoff_base_ms(base_ms);
                let delay = policy.backoff_delay(attempt);
                prop_assert!(
                    delay <= Duration::from_millis(MAX_BACKOFF_MS),
                    "delay {delay:?} exceeds {MAX_BACKOFF_MS}ms cap"
                );
            }

            // A zero base must yield a zero delay for every attempt.
            #[test]
            fn backoff_delay_zero_base_always_zero(
                attempt in 0u32..100,
            ) {
                let policy = LlmRetryPolicy::default().with_backoff_base_ms(0);
                let delay = policy.backoff_delay(attempt);
                prop_assert_eq!(delay, Duration::ZERO);
            }
        }
    }
1388}