use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use awaken_contract::contract::executor::{
InferenceExecutionError, InferenceRequest, InferenceStream, LlmExecutor,
};
use awaken_contract::contract::inference::StreamResult;
use super::circuit_breaker::CircuitBreaker;
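/// Upper bound on the computed exponential backoff delay. Explicit
/// `retry_after` hints from the provider may exceed this cap (see
/// `delay_before_retry`).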
const MAX_BACKOFF_MS: u64 = 8_000;
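/// Retry/fallback policy for LLM inference calls: a retry budget, an ordered
/// list of fallback upstream models, and exponential backoff with a separate
/// (slower) base for overload errors.
///
/// A minimal construction sketch (the values are illustrative, not
/// recommendations):
///
/// ```ignore
/// let policy = LlmRetryPolicy::default()
///     .with_max_retries(3)
///     .with_fallback_upstream_model("backup-model")
///     .with_backoff_base_ms(250)
///     .with_overloaded_backoff_base_ms(1_000);
/// ```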
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct LlmRetryPolicy {
pub max_retries: u32,
pub fallback_upstream_models: Vec<String>,
#[serde(default = "default_backoff_base_ms")]
pub backoff_base_ms: u64,
#[serde(default = "default_overloaded_backoff_base_ms")]
pub overloaded_backoff_base_ms: u64,
#[serde(default = "default_max_stream_retries")]
pub max_stream_retries: u32,
#[serde(default = "default_stream_idle_timeout_secs")]
pub stream_idle_timeout_secs: u64,
}
fn default_backoff_base_ms() -> u64 {
500
}
fn default_overloaded_backoff_base_ms() -> u64 {
2_000
}
fn default_max_stream_retries() -> u32 {
2
}
fn default_stream_idle_timeout_secs() -> u64 {
60
}
impl Default for LlmRetryPolicy {
fn default() -> Self {
Self {
max_retries: 2,
fallback_upstream_models: Vec::new(),
backoff_base_ms: default_backoff_base_ms(),
overloaded_backoff_base_ms: default_overloaded_backoff_base_ms(),
max_stream_retries: default_max_stream_retries(),
stream_idle_timeout_secs: default_stream_idle_timeout_secs(),
}
}
}
impl LlmRetryPolicy {
pub fn no_retry() -> Self {
Self {
max_retries: 0,
..Default::default()
}
}
pub fn with_max_retries(mut self, n: u32) -> Self {
self.max_retries = n;
self
}
pub fn with_fallback_upstream_model(mut self, upstream_model: impl Into<String>) -> Self {
self.fallback_upstream_models.push(upstream_model.into());
self
}
pub fn with_backoff_base_ms(mut self, ms: u64) -> Self {
self.backoff_base_ms = ms;
self
}
pub fn with_overloaded_backoff_base_ms(mut self, ms: u64) -> Self {
self.overloaded_backoff_base_ms = ms;
self
}
pub fn with_max_stream_retries(mut self, n: u32) -> Self {
self.max_stream_retries = n;
self
}
pub fn with_stream_idle_timeout_secs(mut self, secs: u64) -> Self {
self.stream_idle_timeout_secs = secs;
self
}
fn backoff_delay(&self, attempt: u32) -> Duration {
Self::backoff_delay_with_base(self.backoff_base_ms, attempt)
}
fn overloaded_backoff_delay(&self, attempt: u32) -> Duration {
Self::backoff_delay_with_base(self.overloaded_backoff_base_ms, attempt)
}
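// Exponential schedule: base_ms << attempt, capped at MAX_BACKOFF_MS. With the
// default base of 500 ms this yields 500, 1_000, 2_000, 4_000, then 8_000 ms
// for every later attempt (see `backoff_delay_exponential` in the tests below).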
fn backoff_delay_with_base(base_ms: u64, attempt: u32) -> Duration {
if base_ms == 0 {
return Duration::ZERO;
}
let delay_ms = base_ms
.saturating_mul(1u64 << attempt.min(16))
.min(MAX_BACKOFF_MS);
Duration::from_millis(delay_ms)
}
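/// Picks the longer of the computed backoff and the provider's `retry_after`
/// hint, using the overloaded base for `Overloaded` errors. For example, a
/// 5 s `retry_after` hint wins over a 500 ms first-attempt backoff, while a
/// hint shorter than the computed delay is ignored.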
pub fn delay_before_retry(&self, err: &InferenceExecutionError, attempt: u32) -> Duration {
let base = match err {
InferenceExecutionError::Overloaded { .. } => self.overloaded_backoff_delay(attempt),
_ => self.backoff_delay(attempt),
};
match err.retry_after() {
Some(hint) if hint > base => hint,
_ => base,
}
}
}
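// Thin wrapper so production code and the tests share a single retryability
// predicate.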
fn is_retryable(err: &InferenceExecutionError) -> bool {
err.is_retryable()
}
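/// Decorator around an [`LlmExecutor`] that adds retries, fallback models, and
/// an optional per-model circuit breaker.
///
/// A wiring sketch (`provider_executor` and `breaker_config` are assumed to
/// come from the surrounding engine setup):
///
/// ```ignore
/// let executor = RetryingExecutor::new(provider_executor, LlmRetryPolicy::default())
///     .with_circuit_breaker(Arc::new(CircuitBreaker::new(breaker_config)));
/// ```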
pub struct RetryingExecutor {
inner: Arc<dyn LlmExecutor>,
policy: LlmRetryPolicy,
circuit_breaker: Option<Arc<CircuitBreaker>>,
}
impl RetryingExecutor {
pub fn new(inner: Arc<dyn LlmExecutor>, policy: LlmRetryPolicy) -> Self {
Self {
inner,
policy,
circuit_breaker: None,
}
}
pub fn with_circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
self.circuit_breaker = Some(cb);
self
}
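// Attempt budget is 1 initial call + `max_retries` retries. Non-retryable
// errors return immediately; retryable ones sleep per `delay_before_retry`
// between attempts.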
async fn try_with_retries(
&self,
request: &InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
let mut last_error = None;
for attempt in 0..=self.policy.max_retries {
if let Some(ref cb) = self.circuit_breaker {
cb.check(&request.upstream_model)?;
}
match self.inner.execute(request.clone()).await {
Ok(result) => {
if let Some(ref cb) = self.circuit_breaker {
cb.record_success(&request.upstream_model);
}
return Ok(result);
}
Err(err) => {
if err.counts_toward_circuit_breaker() {
if let Some(ref cb) = self.circuit_breaker {
cb.record_failure(&request.upstream_model);
}
}
if !is_retryable(&err) {
return Err(err);
}
if attempt == self.policy.max_retries {
last_error = Some(err);
break;
}
let delay = self.policy.delay_before_retry(&err, attempt);
last_error = Some(err);
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
}
}
}
Err(last_error.expect("at least one attempt was made"))
}
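// Per-request override fallbacks *replace* the policy-level list rather than
// extending it (see
// `request_override_fallback_upstream_models_replace_policy_fallbacks`).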
fn fallback_upstream_models_for_request(&self, request: &InferenceRequest) -> Vec<String> {
request
.overrides
.as_ref()
.and_then(|overrides| overrides.fallback_upstream_models.clone())
.unwrap_or_else(|| self.policy.fallback_upstream_models.clone())
}
async fn try_stream_with_retries(
&self,
request: &InferenceRequest,
) -> Result<InferenceStream, InferenceExecutionError> {
let mut last_error = None;
for attempt in 0..=self.policy.max_retries {
if let Some(ref cb) = self.circuit_breaker {
cb.check(&request.upstream_model)?;
}
match self.inner.execute_stream(request.clone()).await {
Ok(stream) => {
if let Some(ref cb) = self.circuit_breaker {
cb.record_success(&request.upstream_model);
}
return Ok(stream);
}
Err(err) => {
if err.counts_toward_circuit_breaker() {
if let Some(ref cb) = self.circuit_breaker {
cb.record_failure(&request.upstream_model);
}
}
if !is_retryable(&err) {
return Err(err);
}
if attempt == self.policy.max_retries {
last_error = Some(err);
break;
}
let delay = self.policy.delay_before_retry(&err, attempt);
last_error = Some(err);
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
}
}
}
Err(last_error.expect("at least one stream attempt was made"))
}
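// Fail fast with `AllModelsUnavailable` only when the breaker has the primary
// *and* every configured fallback open; a single healthy model keeps the
// request flowing.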
fn all_models_blocked(
&self,
request: &InferenceRequest,
fallback_upstream_models: &[String],
) -> bool {
let Some(ref cb) = self.circuit_breaker else {
return false;
};
if cb.check(&request.upstream_model).is_ok() {
return false;
}
fallback_upstream_models
.iter()
.all(|m| cb.check(m).is_err())
}
}
#[async_trait]
impl LlmExecutor for RetryingExecutor {
async fn execute(
&self,
request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
let fallback_upstream_models = self.fallback_upstream_models_for_request(&request);
if self.all_models_blocked(&request, &fallback_upstream_models) {
return Err(InferenceExecutionError::AllModelsUnavailable);
}
match self.try_with_retries(&request).await {
Ok(result) => return Ok(result),
Err(err) if !is_retryable(&err) || fallback_upstream_models.is_empty() => {
return Err(err);
}
Err(_) => {}
}
let mut last_error = None;
for (i, fallback_upstream_model) in fallback_upstream_models.iter().enumerate() {
let mut fallback_request = request.clone();
fallback_request.upstream_model = fallback_upstream_model.clone();
match self.try_with_retries(&fallback_request).await {
Ok(result) => return Ok(result),
Err(err) => {
let retryable = is_retryable(&err);
last_error = Some(err);
if !retryable {
break;
}
}
}
}
Err(last_error.expect("at least one fallback was attempted"))
}
async fn execute_stream(
&self,
request: InferenceRequest,
) -> Result<InferenceStream, InferenceExecutionError> {
let fallback_upstream_models = self.fallback_upstream_models_for_request(&request);
if self.all_models_blocked(&request, &fallback_upstream_models) {
return Err(InferenceExecutionError::AllModelsUnavailable);
}
match self.try_stream_with_retries(&request).await {
Ok(stream) => return Ok(stream),
Err(err) if !is_retryable(&err) || fallback_upstream_models.is_empty() => {
return Err(err);
}
Err(_) => {}
}
let mut last_error = None;
for (i, fallback_upstream_model) in fallback_upstream_models.iter().enumerate() {
let mut fallback_request = request.clone();
fallback_request.upstream_model = fallback_upstream_model.clone();
match self.try_stream_with_retries(&fallback_request).await {
Ok(stream) => return Ok(stream),
Err(err) => {
let retryable = is_retryable(&err);
last_error = Some(err);
if !retryable {
break;
}
}
}
}
Err(last_error.expect("at least one stream fallback was attempted"))
}
fn name(&self) -> &str {
self.inner.name()
}
}
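/// Registry key under which this policy is configured.
///
/// A hedged sketch of the JSON shape the `"retry"` key accepts (field names
/// follow the serde derive above; omitted fields take their defaults, and
/// unknown fields are rejected via `deny_unknown_fields`):
///
/// ```ignore
/// "retry": {
///     "max_retries": 3,
///     "fallback_upstream_models": ["backup-model"],
///     "backoff_base_ms": 250
/// }
/// ```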
pub struct RetryConfigKey;
impl awaken_contract::registry_spec::PluginConfigKey for RetryConfigKey {
const KEY: &'static str = "retry";
type Config = LlmRetryPolicy;
}
#[cfg(test)]
mod tests {
use super::*;
use awaken_contract::contract::content::ContentBlock;
use awaken_contract::contract::inference::{InferenceOverride, StopReason, TokenUsage};
use awaken_contract::contract::message::Message;
use std::sync::atomic::{AtomicU32, Ordering};
fn test_policy() -> LlmRetryPolicy {
LlmRetryPolicy::default().with_backoff_base_ms(0)
}
struct FailNThenSucceed {
fail_count: u32,
error_kind: fn(u32) -> InferenceExecutionError,
calls: AtomicU32,
}
impl FailNThenSucceed {
fn new(fail_count: u32) -> Self {
Self {
fail_count,
error_kind: |_| InferenceExecutionError::Provider("transient".into()),
calls: AtomicU32::new(0),
}
}
fn with_error(mut self, f: fn(u32) -> InferenceExecutionError) -> Self {
self.error_kind = f;
self
}
fn call_count(&self) -> u32 {
self.calls.load(Ordering::SeqCst)
}
}
fn ok_result() -> StreamResult {
StreamResult {
content: vec![ContentBlock::text("ok")],
tool_calls: vec![],
usage: Some(TokenUsage {
prompt_tokens: Some(10),
completion_tokens: Some(5),
total_tokens: Some(15),
..Default::default()
}),
stop_reason: Some(StopReason::EndTurn),
has_incomplete_tool_calls: false,
}
}
fn test_request() -> InferenceRequest {
InferenceRequest {
upstream_model: "primary-model".into(),
messages: vec![Message::user("hello")],
tools: vec![],
system: vec![],
overrides: None,
enable_prompt_cache: false,
}
}
#[async_trait]
impl LlmExecutor for FailNThenSucceed {
async fn execute(
&self,
_request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
let call = self.calls.fetch_add(1, Ordering::SeqCst);
if call < self.fail_count {
Err((self.error_kind)(call))
} else {
Ok(ok_result())
}
}
fn name(&self) -> &str {
"mock"
}
}
struct ModelRecorder {
models: std::sync::Mutex<Vec<String>>,
error: InferenceExecutionError,
}
impl ModelRecorder {
fn always_fail_with(err: InferenceExecutionError) -> Self {
Self {
models: std::sync::Mutex::new(Vec::new()),
error: err,
}
}
fn recorded_models(&self) -> Vec<String> {
self.models.lock().unwrap().clone()
}
}
#[async_trait]
impl LlmExecutor for ModelRecorder {
async fn execute(
&self,
request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
self.models
.lock()
.unwrap()
.push(request.upstream_model.clone());
Err(self.error.clone())
}
fn name(&self) -> &str {
"model-recorder"
}
}
#[tokio::test]
async fn no_retry_policy_first_failure_is_terminal() {
let inner = Arc::new(FailNThenSucceed::new(1));
let executor = RetryingExecutor::new(
inner.clone(),
LlmRetryPolicy::no_retry().with_backoff_base_ms(0),
);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 1);
}
#[tokio::test]
async fn retry_succeeds_on_second_attempt() {
let inner = Arc::new(FailNThenSucceed::new(1));
let policy = test_policy().with_max_retries(2);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn retry_exhausts_all_attempts_returns_last_error() {
let inner = Arc::new(FailNThenSucceed::new(100));
let policy = test_policy().with_max_retries(3);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 4);
}
#[tokio::test]
async fn non_retryable_error_is_not_retried() {
let inner =
Arc::new(FailNThenSucceed::new(1).with_error(|_| InferenceExecutionError::Cancelled));
let policy = test_policy().with_max_retries(5);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 1);
}
#[tokio::test]
async fn fallback_upstream_model_used_after_primary_exhausts_retries() {
let inner = Arc::new(ModelRecorder::always_fail_with(
InferenceExecutionError::rate_limited("overloaded"),
));
let policy = test_policy()
.with_max_retries(1)
.with_fallback_upstream_model("fallback-a")
.with_fallback_upstream_model("fallback-b");
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
let models = inner.recorded_models();
assert_eq!(models.len(), 6);
assert_eq!(models[0], "primary-model");
assert_eq!(models[1], "primary-model");
assert_eq!(models[2], "fallback-a");
assert_eq!(models[3], "fallback-a");
assert_eq!(models[4], "fallback-b");
assert_eq!(models[5], "fallback-b");
}
#[tokio::test]
async fn request_override_fallback_upstream_models_replace_policy_fallbacks() {
let inner = Arc::new(ModelRecorder::always_fail_with(
InferenceExecutionError::rate_limited("overloaded"),
));
let policy = test_policy()
.with_max_retries(0)
.with_fallback_upstream_model("policy-fallback");
let executor = RetryingExecutor::new(inner.clone(), policy);
let mut request = test_request();
request.overrides = Some(InferenceOverride {
fallback_upstream_models: Some(vec!["override-fallback".into()]),
..Default::default()
});
let result = executor.execute(request).await;
assert!(result.is_err());
assert_eq!(
inner.recorded_models(),
vec!["primary-model", "override-fallback"]
);
}
#[tokio::test]
async fn execute_stream_retries_stream_start_until_success() {
let inner = Arc::new(FailNThenSucceed::new(1));
let policy = test_policy().with_max_retries(2);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute_stream(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn execute_stream_uses_request_override_fallback_upstream_models() {
let inner = Arc::new(ModelRecorder::always_fail_with(
InferenceExecutionError::rate_limited("overloaded"),
));
let policy = test_policy()
.with_max_retries(0)
.with_fallback_upstream_model("policy-fallback");
let executor = RetryingExecutor::new(inner.clone(), policy);
let mut request = test_request();
request.overrides = Some(InferenceOverride {
fallback_upstream_models: Some(vec!["override-fallback".into()]),
..Default::default()
});
let result = executor.execute_stream(request).await;
assert!(result.is_err());
assert_eq!(
inner.recorded_models(),
vec!["primary-model", "override-fallback"]
);
}
#[tokio::test]
async fn fallback_succeeds_after_primary_fails() {
let inner = Arc::new(FailNThenSucceed::new(3));
let policy = test_policy()
.with_max_retries(1)
.with_fallback_upstream_model("fallback-model");
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 4);
}
#[tokio::test]
async fn succeeds_on_first_try_no_retry_needed() {
let inner = Arc::new(FailNThenSucceed::new(0));
let policy = test_policy().with_max_retries(3);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 1, "should call executor exactly once");
}
#[tokio::test]
async fn retrying_executor_delegates_name() {
let inner = Arc::new(FailNThenSucceed::new(0));
let executor = RetryingExecutor::new(inner, test_policy());
assert_eq!(executor.name(), "mock");
}
#[tokio::test]
async fn non_retryable_error_during_fallback_stops_immediately() {
let call_count = Arc::new(AtomicU32::new(0));
let cc = call_count.clone();
struct PrimaryRetryableFallbackFatal {
calls: Arc<AtomicU32>,
}
#[async_trait]
impl LlmExecutor for PrimaryRetryableFallbackFatal {
async fn execute(
&self,
request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
self.calls.fetch_add(1, Ordering::SeqCst);
if request.upstream_model.starts_with("primary") {
Err(InferenceExecutionError::Provider("down".into()))
} else {
Err(InferenceExecutionError::Cancelled)
}
}
fn name(&self) -> &str {
"primary-retryable-fallback-fatal"
}
}
let inner = Arc::new(PrimaryRetryableFallbackFatal { calls: cc });
let policy = test_policy()
.with_max_retries(0)
.with_fallback_upstream_model("fallback-a")
.with_fallback_upstream_model("fallback-b");
let executor = RetryingExecutor::new(inner, policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 2);
}
#[test]
fn default_policy_values() {
let policy = LlmRetryPolicy::default();
assert_eq!(policy.max_retries, 2);
assert!(policy.fallback_upstream_models.is_empty());
assert_eq!(policy.backoff_base_ms, 500);
assert_eq!(policy.overloaded_backoff_base_ms, 2_000);
assert_eq!(policy.max_stream_retries, 2);
assert_eq!(policy.stream_idle_timeout_secs, 60);
}
#[test]
fn no_retry_policy_values() {
let policy = LlmRetryPolicy::no_retry();
assert_eq!(policy.max_retries, 0);
assert!(policy.fallback_upstream_models.is_empty());
}
#[test]
fn rate_limit_error_is_retryable() {
assert!(is_retryable(&InferenceExecutionError::rate_limited("429")));
}
#[test]
fn overloaded_error_is_retryable() {
assert!(is_retryable(&InferenceExecutionError::overloaded("529")));
}
#[test]
fn context_overflow_is_not_retryable() {
assert!(!is_retryable(&InferenceExecutionError::ContextOverflow(
"too long".into()
)));
}
#[test]
fn context_overflow_does_not_count_toward_breaker() {
let err = InferenceExecutionError::ContextOverflow("too long".into());
assert!(!err.counts_toward_circuit_breaker());
}
#[test]
fn invalid_request_does_not_count_toward_breaker() {
assert!(
!InferenceExecutionError::InvalidRequest("schema".into())
.counts_toward_circuit_breaker()
);
}
#[test]
fn unauthorized_does_not_count_toward_breaker() {
assert!(
!InferenceExecutionError::Unauthorized("key".into()).counts_toward_circuit_breaker()
);
}
#[test]
fn all_models_unavailable_is_fail_fast() {
let err = InferenceExecutionError::AllModelsUnavailable;
assert!(!err.is_retryable());
assert!(!err.counts_toward_circuit_breaker());
}
#[test]
fn server_error_is_retryable() {
assert!(is_retryable(&InferenceExecutionError::Provider(
"500 internal".into()
)));
}
#[test]
fn timeout_error_is_retryable() {
assert!(is_retryable(&InferenceExecutionError::Timeout(
"timed out".into()
)));
}
#[test]
fn cancelled_error_is_not_retryable() {
assert!(!is_retryable(&InferenceExecutionError::Cancelled));
}
#[test]
fn builder_methods_chain() {
let policy = LlmRetryPolicy::default()
.with_max_retries(5)
.with_fallback_upstream_model("model-a")
.with_fallback_upstream_model("model-b")
.with_backoff_base_ms(100);
assert_eq!(policy.max_retries, 5);
assert_eq!(policy.fallback_upstream_models, vec!["model-a", "model-b"]);
assert_eq!(policy.backoff_base_ms, 100);
}
#[test]
fn backoff_delay_zero_base() {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(0);
assert_eq!(policy.backoff_delay(0), Duration::ZERO);
assert_eq!(policy.backoff_delay(5), Duration::ZERO);
}
#[test]
fn backoff_delay_exponential() {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(500);
assert_eq!(policy.backoff_delay(0), Duration::from_millis(500));
assert_eq!(policy.backoff_delay(1), Duration::from_millis(1000));
assert_eq!(policy.backoff_delay(2), Duration::from_millis(2000));
assert_eq!(policy.backoff_delay(3), Duration::from_millis(4000));
}
#[test]
fn backoff_delay_caps_at_max() {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(500);
assert_eq!(policy.backoff_delay(4), Duration::from_millis(8000));
assert_eq!(policy.backoff_delay(5), Duration::from_millis(8000));
}
#[tokio::test]
async fn circuit_breaker_blocks_when_open() {
use crate::engine::circuit_breaker::CircuitBreakerConfig;
let inner = Arc::new(FailNThenSucceed::new(100));
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 2,
cooldown: std::time::Duration::from_secs(60),
half_open_max: 1,
}));
cb.record_failure("primary-model");
cb.record_failure("primary-model");
let policy = test_policy().with_max_retries(3);
let executor = RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 0);
}
#[tokio::test]
async fn circuit_breaker_records_success() {
use crate::engine::circuit_breaker::CircuitBreakerConfig;
let inner = Arc::new(FailNThenSucceed::new(0));
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 2,
cooldown: std::time::Duration::from_secs(60),
half_open_max: 1,
}));
cb.record_failure("primary-model");
let policy = test_policy().with_max_retries(1);
let executor =
RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb.clone());
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
cb.record_failure("primary-model");
assert!(cb.check("primary-model").is_ok());
}
#[tokio::test]
async fn retry_on_rate_limit_then_succeed() {
let inner = Arc::new(
FailNThenSucceed::new(2)
.with_error(|_| InferenceExecutionError::rate_limited("rate limited")),
);
let policy = test_policy().with_max_retries(3);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 3);
}
#[tokio::test]
async fn retry_on_timeout_then_succeed() {
let inner = Arc::new(
FailNThenSucceed::new(1)
.with_error(|_| InferenceExecutionError::Timeout("timed out".into())),
);
let policy = test_policy().with_max_retries(2);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn zero_retries_with_fallback_tries_fallback_once() {
let inner = Arc::new(FailNThenSucceed::new(1));
let policy = test_policy()
.with_max_retries(0)
.with_fallback_upstream_model("fallback");
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn no_fallback_upstream_models_configured_returns_primary_error() {
let inner = Arc::new(FailNThenSucceed::new(100));
let policy = test_policy().with_max_retries(1);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn all_error_types_handled() {
for error_fn in [
(|_: u32| InferenceExecutionError::Provider("down".into())) as fn(u32) -> _,
|_| InferenceExecutionError::rate_limited("429"),
|_| InferenceExecutionError::Timeout("timeout".into()),
] {
let inner = Arc::new(FailNThenSucceed::new(1).with_error(error_fn));
let policy = test_policy().with_max_retries(2);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok(), "should recover from retryable error");
}
}
#[tokio::test]
async fn max_retries_zero_and_no_fallback_just_one_attempt() {
let inner = Arc::new(FailNThenSucceed::new(100));
let policy = LlmRetryPolicy::no_retry().with_backoff_base_ms(0);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 1);
}
#[tokio::test]
async fn success_on_first_try_no_fallback_attempted() {
let inner = Arc::new(FailNThenSucceed::new(0));
let policy = test_policy()
.with_max_retries(3)
.with_fallback_upstream_model("fallback-a");
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 1, "should not attempt fallback");
}
#[test]
fn retry_policy_serde_roundtrip() {
let policy = LlmRetryPolicy::default()
.with_max_retries(5)
.with_fallback_upstream_model("fallback-a")
.with_fallback_upstream_model("fallback-b")
.with_backoff_base_ms(200)
.with_overloaded_backoff_base_ms(4_000)
.with_max_stream_retries(3)
.with_stream_idle_timeout_secs(90);
let json = serde_json::to_string(&policy).unwrap();
let parsed: LlmRetryPolicy = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.max_retries, 5);
assert_eq!(
parsed.fallback_upstream_models,
vec!["fallback-a", "fallback-b"]
);
assert_eq!(parsed.backoff_base_ms, 200);
assert_eq!(parsed.overloaded_backoff_base_ms, 4_000);
assert_eq!(parsed.max_stream_retries, 3);
assert_eq!(parsed.stream_idle_timeout_secs, 90);
}
#[test]
fn retry_policy_serde_default_backoff() {
let json = r#"{"max_retries":2,"fallback_upstream_models":[]}"#;
let parsed: LlmRetryPolicy = serde_json::from_str(json).unwrap();
assert_eq!(parsed.backoff_base_ms, 500);
assert_eq!(parsed.overloaded_backoff_base_ms, 2_000);
assert_eq!(parsed.max_stream_retries, 2);
assert_eq!(parsed.stream_idle_timeout_secs, 60);
}
#[test]
fn retry_policy_rejects_legacy_fallback_field() {
let json = r#"{"max_retries":2,"fallback_models":[]}"#;
let parsed = serde_json::from_str::<LlmRetryPolicy>(json);
assert!(parsed.is_err());
}
#[tokio::test]
async fn retry_budget_exact_boundary() {
let inner = Arc::new(FailNThenSucceed::new(2));
let policy = test_policy().with_max_retries(2);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 3);
}
#[tokio::test]
async fn retry_budget_one_over_boundary() {
let inner = Arc::new(FailNThenSucceed::new(3));
let policy = test_policy().with_max_retries(2);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 3, "1 initial + 2 retries = 3 calls");
}
#[tokio::test]
async fn circuit_breaker_opens_during_retry_sequence() {
use crate::engine::circuit_breaker::CircuitBreakerConfig;
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_secs(60),
half_open_max: 1,
}));
let inner = Arc::new(FailNThenSucceed::new(100));
let policy = test_policy().with_max_retries(5);
let executor = RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb);
let result = executor.execute(test_request()).await;
assert!(result.is_err());
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn circuit_breaker_independent_per_model_in_fallback() {
use crate::engine::circuit_breaker::CircuitBreakerConfig;
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_secs(60),
half_open_max: 1,
}));
cb.record_failure("primary-model");
cb.record_failure("primary-model");
let inner = Arc::new(FailNThenSucceed::new(0));
let policy = test_policy()
.with_max_retries(0)
.with_fallback_upstream_model("fallback-model");
let executor = RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb);
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
assert_eq!(inner.call_count(), 1);
}
#[test]
fn delay_before_retry_respects_retry_after_when_longer() {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(100);
let err = InferenceExecutionError::RateLimited {
message: "slow".into(),
retry_after: Some(Duration::from_secs(5)),
};
assert_eq!(policy.delay_before_retry(&err, 0), Duration::from_secs(5));
}
#[test]
fn delay_before_retry_uses_exponential_when_longer_than_retry_after() {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(10_000);
let err = InferenceExecutionError::RateLimited {
message: "fast hint".into(),
retry_after: Some(Duration::from_millis(100)),
};
assert_eq!(
policy.delay_before_retry(&err, 0),
Duration::from_millis(MAX_BACKOFF_MS)
);
}
#[test]
fn delay_before_retry_uses_overloaded_base_for_overloaded_errors() {
let policy = LlmRetryPolicy::default()
.with_backoff_base_ms(500)
.with_overloaded_backoff_base_ms(2_000);
let overloaded = InferenceExecutionError::overloaded("surge");
assert_eq!(
policy.delay_before_retry(&overloaded, 0),
Duration::from_millis(2_000)
);
}
#[tokio::test(start_paused = true)]
async fn rate_limited_retry_after_waits_hint_duration() {
let inner = Arc::new(FailNThenSucceed::new(1).with_error(|_| {
InferenceExecutionError::RateLimited {
message: "slow down".into(),
retry_after: Some(Duration::from_secs(3)),
}
}));
let policy = LlmRetryPolicy::default()
.with_max_retries(2)
.with_backoff_base_ms(10);
let executor = RetryingExecutor::new(inner.clone(), policy);
let start = tokio::time::Instant::now();
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
let elapsed = start.elapsed();
assert!(
elapsed >= Duration::from_secs(3),
"expected >=3s retry-after wait, got {elapsed:?}"
);
assert_eq!(inner.call_count(), 2);
}
#[tokio::test]
async fn context_overflow_error_is_not_retried() {
let inner =
Arc::new(FailNThenSucceed::new(5).with_error(|_| {
InferenceExecutionError::ContextOverflow("prompt too long".into())
}));
let policy = test_policy().with_max_retries(3);
let executor = RetryingExecutor::new(inner.clone(), policy);
let result = executor.execute(test_request()).await;
assert!(matches!(
result,
Err(InferenceExecutionError::ContextOverflow(_))
));
assert_eq!(inner.call_count(), 1, "permanent error must not retry");
}
#[tokio::test]
async fn context_overflow_does_not_trip_circuit_breaker() {
use crate::engine::circuit_breaker::CircuitBreakerConfig;
let inner = Arc::new(
FailNThenSucceed::new(100)
.with_error(|_| InferenceExecutionError::ContextOverflow("too long".into())),
);
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_secs(60),
half_open_max: 1,
}));
let policy = test_policy().with_max_retries(0);
let executor =
RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb.clone());
for _ in 0..5 {
let _ = executor.execute(test_request()).await;
}
assert!(
cb.check("primary-model").is_ok(),
"ContextOverflow must not increment the breaker"
);
}
#[tokio::test]
async fn all_models_blocked_short_circuits_with_all_models_unavailable() {
use crate::engine::circuit_breaker::CircuitBreakerConfig;
let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
cooldown: Duration::from_secs(60),
half_open_max: 1,
}));
cb.record_failure("primary-model");
cb.record_failure("fallback-a");
cb.record_failure("fallback-b");
let inner = Arc::new(FailNThenSucceed::new(0));
let policy = test_policy()
.with_max_retries(2)
.with_fallback_upstream_model("fallback-a")
.with_fallback_upstream_model("fallback-b");
let executor =
RetryingExecutor::new(inner.clone(), policy).with_circuit_breaker(cb.clone());
let result = executor.execute(test_request()).await;
assert!(
matches!(result, Err(InferenceExecutionError::AllModelsUnavailable)),
"expected AllModelsUnavailable, got {result:?}"
);
assert_eq!(inner.call_count(), 0, "no inner call should be made");
}
#[tokio::test(start_paused = true)]
async fn backoff_actually_sleeps() {
let inner = Arc::new(FailNThenSucceed::new(2));
let policy = LlmRetryPolicy::default()
.with_max_retries(3)
.with_backoff_base_ms(1000);
let executor = RetryingExecutor::new(inner.clone(), policy);
let start = tokio::time::Instant::now();
let result = executor.execute(test_request()).await;
assert!(result.is_ok());
let elapsed = start.elapsed();
assert!(
elapsed >= Duration::from_secs(3),
"expected >= 3s backoff, got {elapsed:?}"
);
}
mod proptest_retry {
use super::*;
use proptest::prelude::*;
proptest! {
#[test]
fn llm_retry_policy_serde_roundtrip(
max_retries in 0u32..10,
backoff_base_ms in 0u64..10000,
overloaded_backoff_base_ms in 0u64..10000,
max_stream_retries in 0u32..10,
stream_idle_timeout_secs in 1u64..300,
num_fallbacks in 0usize..5,
) {
let policy = LlmRetryPolicy {
max_retries,
fallback_upstream_models: (0..num_fallbacks).map(|i| format!("model-{i}")).collect(),
backoff_base_ms,
overloaded_backoff_base_ms,
max_stream_retries,
stream_idle_timeout_secs,
};
let json = serde_json::to_string(&policy).unwrap();
let parsed: LlmRetryPolicy = serde_json::from_str(&json).unwrap();
prop_assert_eq!(parsed.max_retries, max_retries);
prop_assert_eq!(parsed.backoff_base_ms, backoff_base_ms);
prop_assert_eq!(parsed.overloaded_backoff_base_ms, overloaded_backoff_base_ms);
prop_assert_eq!(parsed.max_stream_retries, max_stream_retries);
prop_assert_eq!(parsed.stream_idle_timeout_secs, stream_idle_timeout_secs);
prop_assert_eq!(parsed.fallback_upstream_models.len(), num_fallbacks);
}
#[test]
fn backoff_delay_is_monotonically_non_decreasing(
base_ms in 1u64..1000,
) {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(base_ms);
let mut prev = Duration::ZERO;
for attempt in 0..10u32 {
let delay = policy.backoff_delay(attempt);
prop_assert!(
delay >= prev,
"delay should be monotonically non-decreasing: attempt={attempt}, delay={delay:?}, prev={prev:?}"
);
prev = delay;
}
}
#[test]
fn backoff_delay_never_exceeds_cap(
base_ms in 0u64..10000,
attempt in 0u32..100,
) {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(base_ms);
let delay = policy.backoff_delay(attempt);
prop_assert!(
delay <= Duration::from_millis(MAX_BACKOFF_MS),
"delay {delay:?} exceeds {MAX_BACKOFF_MS}ms cap"
);
}
#[test]
fn backoff_delay_zero_base_always_zero(
attempt in 0u32..100,
) {
let policy = LlmRetryPolicy::default().with_backoff_base_ms(0);
let delay = policy.backoff_delay(attempt);
prop_assert_eq!(delay, Duration::ZERO);
}
}
}
}