use async_trait::async_trait;
use tracing::warn;
use crate::error::{Result, ZeptoError};
use crate::session::Message;
use super::{ChatOptions, LLMProvider, LLMResponse, StreamEvent, ToolDefinition};
/// An [`LLMProvider`] decorator that retries transient failures of the wrapped
/// provider, sleeping with exponential backoff plus jitter between attempts and
/// giving up once a wall-clock retry budget is spent.
pub struct RetryProvider {
    /// The provider that actually serves requests.
    inner: Box<dyn LLMProvider>,
    /// Retries allowed after the initial attempt (so at most `max_retries + 1` calls).
    max_retries: u32,
    /// Base backoff delay in milliseconds; doubled for each subsequent retry.
    base_delay_ms: u64,
    /// Upper bound in milliseconds for any single backoff delay, jitter included.
    max_delay_ms: u64,
    /// Elapsed-time budget in milliseconds after which no new retry is started;
    /// `0` disables the budget.
    retry_budget_ms: u64,
}
impl std::fmt::Debug for RetryProvider {
    /// Renders the retry policy; the wrapped provider is shown by its `name()`
    /// rather than by its own `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            inner,
            max_retries,
            base_delay_ms,
            max_delay_ms,
            retry_budget_ms,
        } = self;
        let mut builder = f.debug_struct("RetryProvider");
        builder.field("inner", &inner.name());
        builder.field("max_retries", max_retries);
        builder.field("base_delay_ms", base_delay_ms);
        builder.field("max_delay_ms", max_delay_ms);
        builder.field("retry_budget_ms", retry_budget_ms);
        builder.finish()
    }
}
impl RetryProvider {
    /// Wraps `inner` with the default policy: 3 retries, 1 000 ms base delay,
    /// 30 000 ms per-delay cap and a 45 000 ms retry budget.
    #[must_use]
    pub fn new(inner: Box<dyn LLMProvider>) -> Self {
        Self {
            inner,
            max_retries: 3,
            base_delay_ms: 1000,
            max_delay_ms: 30_000,
            retry_budget_ms: 45_000,
        }
    }

    /// Sets how many retries are attempted after the initial request.
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the base backoff delay in milliseconds (doubled on each retry).
    #[must_use]
    pub fn with_base_delay_ms(mut self, base_delay_ms: u64) -> Self {
        self.base_delay_ms = base_delay_ms;
        self
    }

    /// Sets the cap, in milliseconds, applied to every single backoff delay.
    #[must_use]
    pub fn with_max_delay_ms(mut self, max_delay_ms: u64) -> Self {
        self.max_delay_ms = max_delay_ms;
        self
    }

    /// Sets the retry budget in milliseconds; `0` means unlimited.
    #[must_use]
    pub fn with_retry_budget_ms(mut self, retry_budget_ms: u64) -> Self {
        self.retry_budget_ms = retry_budget_ms;
        self
    }

    /// Returns `true` once at least `retry_budget_ms` has elapsed since `start`.
    ///
    /// A budget of `0` means "unlimited" and never reports exhaustion. The
    /// check compares `Duration`s directly instead of casting `as_millis()`
    /// (a `u128`) down to `u64`; since the budget is a whole number of
    /// milliseconds, the outcome matches comparing truncated milliseconds.
    fn budget_exceeded(&self, start: std::time::Instant) -> bool {
        self.retry_budget_ms != 0
            && start.elapsed() >= std::time::Duration::from_millis(self.retry_budget_ms)
    }
}
/// Heuristically detects "prompt too large for the model" errors from a
/// free-form provider message (case-insensitive phrase match).
fn is_context_window_exceeded(msg: &str) -> bool {
    const CONTEXT_HINTS: [&str; 5] = [
        "exceeds the context window",
        "context window of this model",
        "maximum context length",
        "context length exceeded",
        "max_tokens is too large",
    ];
    let haystack = msg.to_lowercase();
    CONTEXT_HINTS
        .iter()
        .copied()
        .any(|hint| haystack.contains(hint))
}
/// Heuristically detects authentication / authorization failures in a
/// free-form provider message (case-insensitive phrase match).
fn is_auth_failure(msg: &str) -> bool {
    const AUTH_HINTS: &[&str] = &[
        "invalid api key",
        "incorrect api key",
        "missing api key",
        "api key not set",
        "authentication failed",
        "auth failed",
        "unauthorized",
        "forbidden",
        "permission denied",
        "access denied",
        "invalid_token",
    ];
    let lowered = msg.to_lowercase();
    for hint in AUTH_HINTS {
        if lowered.contains(*hint) {
            return true;
        }
    }
    false
}
/// Heuristically detects "unknown / unavailable model" errors: the message must
/// mention "model" and also contain one of the not-found style phrases.
fn is_model_not_found(msg: &str) -> bool {
    let text = msg.to_lowercase();
    if !text.contains("model") {
        return false;
    }
    ["not found", "does not exist", "unknown model", "unsupported model"]
        .iter()
        .any(|phrase| text.contains(*phrase))
}
/// Returns `true` if `code` occurs in `haystack` as a standalone number, i.e. the
/// match is neither directly preceded nor directly followed by an ASCII digit.
///
/// Plain `contains("500")` would also fire on `"15000"` or `"request 45001"`,
/// turning permanent errors that merely mention such numbers into retries.
fn contains_status_code(haystack: &str, code: &str) -> bool {
    let bytes = haystack.as_bytes();
    haystack.match_indices(code).any(|(start, found)| {
        let end = start + found.len();
        let digit_before = start > 0 && bytes[start - 1].is_ascii_digit();
        let digit_after = matches!(bytes.get(end), Some(b) if b.is_ascii_digit());
        !digit_before && !digit_after
    })
}

/// Decides whether a provider error is transient and worth retrying.
///
/// * `ZeptoError::ProviderTyped` defers to the typed error's own `is_retryable`.
/// * `ZeptoError::Provider` (free-form message) is classified heuristically:
///   context-window, authentication and model-not-found messages, as well as
///   `HTTP 400/401/403/404`, are never retried; rate limits, overload, server
///   errors, timeouts and status codes 429/500/502/503/504 are. Status codes
///   only count when they appear as standalone numbers.
/// * Every other error variant is non-retryable.
pub fn is_retryable(err: &ZeptoError) -> bool {
    match err {
        ZeptoError::ProviderTyped(pe) => pe.is_retryable(),
        ZeptoError::Provider(msg) => {
            // Permanent failures first: retrying them can never succeed.
            if is_context_window_exceeded(msg) || is_auth_failure(msg) || is_model_not_found(msg) {
                return false;
            }
            let lower = msg.to_lowercase();
            let non_retryable_codes = ["http 400", "http 401", "http 403", "http 404"];
            if non_retryable_codes.iter().any(|c| lower.contains(c)) {
                return false;
            }
            let retryable_phrases = ["rate limit", "rate_limit", "overload", "server error", "timeout"];
            let retryable_codes = ["429", "500", "502", "503", "504"];
            retryable_phrases.iter().any(|p| lower.contains(p))
                || retryable_codes
                    .iter()
                    .any(|code| contains_status_code(&lower, code))
        }
        _ => false,
    }
}
pub async fn delay_with_jitter(attempt: u32, base_delay_ms: u64, max_delay_ms: u64) {
let exponential = base_delay_ms.saturating_mul(1u64 << attempt.min(16));
let jitter_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.subsec_nanos() as u64 % (base_delay_ms.max(1)))
.unwrap_or(0);
let delay = exponential.saturating_add(jitter_ms).min(max_delay_ms);
tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
}
/// Computes the backoff delay, in milliseconds, for the zero-based retry `attempt`.
///
/// Returns `base_delay_ms * 2^min(attempt, 16) + jitter_ms`, saturating on
/// overflow and never exceeding `max_delay_ms`.
pub fn compute_delay(attempt: u32, base_delay_ms: u64, max_delay_ms: u64, jitter_ms: u64) -> u64 {
    // The exponent is capped at 16, so the multiplier is at most 65 536.
    let multiplier = 2u64.pow(attempt.min(16));
    let backoff = base_delay_ms.saturating_mul(multiplier);
    std::cmp::min(backoff.saturating_add(jitter_ms), max_delay_ms)
}
#[async_trait]
impl LLMProvider for RetryProvider {
    fn name(&self) -> &str {
        self.inner.name()
    }

    fn default_model(&self) -> &str {
        self.inner.default_model()
    }

    /// Sends a chat request through the inner provider, retrying transient failures.
    ///
    /// Makes at most `max_retries + 1` calls. Before every retry it checks the
    /// retry budget (returning the latest error once spent), logs a warning and
    /// sleeps with exponential backoff. Errors rejected by [`is_retryable`] are
    /// returned immediately; the result of the final call is returned as-is.
    async fn chat(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        model: Option<&str>,
        options: ChatOptions,
    ) -> Result<LLMResponse> {
        let started = std::time::Instant::now();
        let mut pending: Option<ZeptoError> = None;
        let mut attempt: u32 = 0;
        loop {
            // `pending` is only set after a retryable failure, so this runs
            // exactly on retries (attempt >= 1).
            if let Some(err) = pending.take() {
                if self.budget_exceeded(started) {
                    warn!(
                        provider = self.inner.name(),
                        elapsed_ms = started.elapsed().as_millis() as u64,
                        budget_ms = self.retry_budget_ms,
                        "Retry budget exhausted for chat request"
                    );
                    return Err(err);
                }
                warn!(
                    provider = self.inner.name(),
                    attempt = attempt,
                    max_retries = self.max_retries,
                    error = %err,
                    "Retrying chat request after transient error"
                );
                delay_with_jitter(attempt - 1, self.base_delay_ms, self.max_delay_ms).await;
            }

            if attempt == self.max_retries {
                // Last call: hand over the owned arguments, no classification.
                return self.inner.chat(messages, tools, model, options).await;
            }

            match self
                .inner
                .chat(messages.clone(), tools.clone(), model, options.clone())
                .await
            {
                Ok(response) => return Ok(response),
                Err(err) if is_retryable(&err) => pending = Some(err),
                Err(err) => return Err(err),
            }
            attempt += 1;
        }
    }

    /// Streaming counterpart of [`chat`](Self::chat) with the same retry policy.
    ///
    /// Only the call that opens the stream is retried; once a receiver is
    /// returned, events flow straight from the inner provider.
    async fn chat_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        model: Option<&str>,
        options: ChatOptions,
    ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>> {
        let started = std::time::Instant::now();
        let mut pending: Option<ZeptoError> = None;
        let mut attempt: u32 = 0;
        loop {
            // Entered only after a retryable failure, i.e. on attempt >= 1.
            if let Some(err) = pending.take() {
                if self.budget_exceeded(started) {
                    warn!(
                        provider = self.inner.name(),
                        elapsed_ms = started.elapsed().as_millis() as u64,
                        budget_ms = self.retry_budget_ms,
                        "Retry budget exhausted for chat_stream request"
                    );
                    return Err(err);
                }
                warn!(
                    provider = self.inner.name(),
                    attempt = attempt,
                    max_retries = self.max_retries,
                    error = %err,
                    "Retrying chat_stream request after transient error"
                );
                delay_with_jitter(attempt - 1, self.base_delay_ms, self.max_delay_ms).await;
            }

            if attempt == self.max_retries {
                // Last call: hand over the owned arguments, no classification.
                return self
                    .inner
                    .chat_stream(messages, tools, model, options)
                    .await;
            }

            match self
                .inner
                .chat_stream(messages.clone(), tools.clone(), model, options.clone())
                .await
            {
                Ok(receiver) => return Ok(receiver),
                Err(err) if is_retryable(&err) => pending = Some(err),
                Err(err) => return Err(err),
            }
            attempt += 1;
        }
    }

    /// Embeddings are forwarded to the inner provider without any retry.
    async fn embed(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
        self.inner.embed(texts).await
    }
}
// Unit tests for `RetryProvider`: constructor/builder defaults, error
// classification via `is_retryable`, backoff arithmetic via `compute_delay`,
// and end-to-end retry / budget behaviour against scripted mock providers.
#[cfg(test)]
mod tests {
    use super::*;

    // Provider whose `chat` always succeeds with the text "mock response".
    // It does not override `chat_stream`, so the streaming test below relies on
    // the trait's default `chat_stream` (which that test expects to emit `Done`).
    struct MockProvider {
        name: &'static str,
        model: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str, model: &'static str) -> Self {
            Self { name, model }
        }
    }

    #[async_trait]
    impl LLMProvider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }
        fn default_model(&self) -> &str {
            self.model
        }
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _model: Option<&str>,
            _options: ChatOptions,
        ) -> Result<LLMResponse> {
            Ok(LLMResponse::text("mock response"))
        }
    }

    // `new` forwards identity to the inner provider and installs the default policy.
    #[test]
    fn test_retry_provider_creation() {
        let mock = MockProvider::new("test-provider", "test-model-v1");
        let provider = RetryProvider::new(Box::new(mock));
        assert_eq!(provider.name(), "test-provider");
        assert_eq!(provider.default_model(), "test-model-v1");
        assert_eq!(provider.max_retries, 3);
        assert_eq!(provider.base_delay_ms, 1000);
        assert_eq!(provider.max_delay_ms, 30_000);
        assert_eq!(provider.retry_budget_ms, 45_000);
    }

    // Each builder method overrides exactly its own field.
    #[test]
    fn test_retry_provider_builder() {
        let mock = MockProvider::new("test", "model");
        let provider = RetryProvider::new(Box::new(mock))
            .with_max_retries(5)
            .with_base_delay_ms(500)
            .with_max_delay_ms(60_000)
            .with_retry_budget_ms(10_000);
        assert_eq!(provider.max_retries, 5);
        assert_eq!(provider.base_delay_ms, 500);
        assert_eq!(provider.max_delay_ms, 60_000);
        assert_eq!(provider.retry_budget_ms, 10_000);
    }

    // --- String-based classification: transient errors that must be retried ---

    #[test]
    fn test_is_retryable_429() {
        let err = ZeptoError::Provider("HTTP 429 Too Many Requests".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_500() {
        let err = ZeptoError::Provider("HTTP 500 Internal Server Error".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_502() {
        let err = ZeptoError::Provider("HTTP 502 Bad Gateway".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_503() {
        let err = ZeptoError::Provider("HTTP 503 Service Unavailable".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_504() {
        let err = ZeptoError::Provider("HTTP 504 Gateway Timeout".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_rate_limit() {
        let err = ZeptoError::Provider("Rate limit exceeded, please retry".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_rate_limit_underscore() {
        let err = ZeptoError::Provider("rate_limit_exceeded".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_overloaded() {
        let err = ZeptoError::Provider("Model is overloaded, try again later".to_string());
        assert!(is_retryable(&err));
    }

    // --- String-based classification: permanent errors that must not be retried ---

    #[test]
    fn test_is_retryable_400() {
        let err = ZeptoError::Provider("HTTP 400 Bad Request: invalid JSON".to_string());
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_401() {
        let err = ZeptoError::Provider("HTTP 401 Unauthorized: invalid API key".to_string());
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_403() {
        let err = ZeptoError::Provider("HTTP 403 Forbidden".to_string());
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_404() {
        let err = ZeptoError::Provider("HTTP 404 Not Found: model not available".to_string());
        assert!(!is_retryable(&err));
    }
    // Messages matching no transient hint default to non-retryable.
    #[test]
    fn test_is_retryable_generic_error() {
        let err = ZeptoError::Provider("Connection reset by peer".to_string());
        assert!(!is_retryable(&err));
    }
    // Non-provider error variants are never retried.
    #[test]
    fn test_is_retryable_non_provider_error() {
        let err = ZeptoError::Config("Missing API key".to_string());
        assert!(!is_retryable(&err));
    }

    // --- compute_delay: base * 2^attempt + jitter, capped at max ---

    #[test]
    fn test_delay_calculation_attempt_0() {
        let delay = compute_delay(0, 1000, 30_000, 0);
        assert_eq!(delay, 1000);
    }
    #[test]
    fn test_delay_calculation_attempt_1() {
        let delay = compute_delay(1, 1000, 30_000, 0);
        assert_eq!(delay, 2000);
    }
    #[test]
    fn test_delay_calculation_attempt_2() {
        let delay = compute_delay(2, 1000, 30_000, 0);
        assert_eq!(delay, 4000);
    }
    #[test]
    fn test_delay_calculation_attempt_3() {
        let delay = compute_delay(3, 1000, 30_000, 0);
        assert_eq!(delay, 8000);
    }
    #[test]
    fn test_delay_calculation_with_jitter() {
        let delay = compute_delay(1, 1000, 30_000, 200);
        assert_eq!(delay, 2200);
    }
    #[test]
    fn test_delay_calculation_capped_at_max() {
        let delay = compute_delay(10, 1000, 30_000, 0);
        assert_eq!(delay, 30_000);
    }
    // The cap applies after jitter is added.
    #[test]
    fn test_delay_calculation_max_with_jitter_still_capped() {
        let delay = compute_delay(10, 1000, 30_000, 5000);
        assert_eq!(delay, 30_000);
    }
    #[test]
    fn test_delay_calculation_custom_base() {
        let delay = compute_delay(0, 500, 30_000, 0);
        assert_eq!(delay, 500);
        let delay = compute_delay(2, 500, 30_000, 0);
        assert_eq!(delay, 2000);
    }

    // --- End-to-end behaviour of the wrapper ---

    #[tokio::test]
    async fn test_retry_provider_chat_success() {
        let mock = MockProvider::new("test", "model");
        let provider = RetryProvider::new(Box::new(mock));
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "mock response");
    }

    #[tokio::test]
    async fn test_retry_provider_chat_stream_success() {
        let mock = MockProvider::new("test", "model");
        let provider = RetryProvider::new(Box::new(mock));
        let result = provider
            .chat_stream(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        let mut rx = result.unwrap();
        let event = rx.recv().await.unwrap();
        match event {
            StreamEvent::Done { content, .. } => {
                assert_eq!(content, "mock response");
            }
            _ => panic!("Expected Done event"),
        }
    }

    // Fails the first `target_failures` `chat` calls with
    // `ZeptoError::Provider(error_message)`, then answers "recovered".
    struct FailThenSucceedProvider {
        fail_count: std::sync::atomic::AtomicU32,
        target_failures: u32,
        error_message: String,
    }

    impl FailThenSucceedProvider {
        fn new(target_failures: u32, error_message: &str) -> Self {
            Self {
                fail_count: std::sync::atomic::AtomicU32::new(0),
                target_failures,
                error_message: error_message.to_string(),
            }
        }
    }

    #[async_trait]
    impl LLMProvider for FailThenSucceedProvider {
        fn name(&self) -> &str {
            "fail-then-succeed"
        }
        fn default_model(&self) -> &str {
            "test-model"
        }
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _model: Option<&str>,
            _options: ChatOptions,
        ) -> Result<LLMResponse> {
            let count = self
                .fail_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if count < self.target_failures {
                Err(ZeptoError::Provider(self.error_message.clone()))
            } else {
                Ok(LLMResponse::text("recovered"))
            }
        }
    }

    // Tiny delays keep these tests fast while still exercising the backoff path.
    #[tokio::test]
    async fn test_retry_provider_retries_on_429() {
        let inner = FailThenSucceedProvider::new(2, "HTTP 429 Too Many Requests");
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(3)
            .with_base_delay_ms(1)
            .with_max_delay_ms(10);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "recovered");
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_500() {
        let inner = FailThenSucceedProvider::new(1, "HTTP 500 Internal Server Error");
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(3)
            .with_base_delay_ms(1)
            .with_max_delay_ms(10);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "recovered");
    }

    // A non-retryable error is surfaced from the first call even though a
    // second call would have succeeded.
    #[tokio::test]
    async fn test_retry_provider_no_retry_on_401() {
        let inner = FailThenSucceedProvider::new(1, "HTTP 401 Unauthorized");
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(3)
            .with_base_delay_ms(1)
            .with_max_delay_ms(10);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("401"));
    }

    // With 2 retries (3 calls) against 10 scripted failures, the last 429 is returned.
    #[tokio::test]
    async fn test_retry_provider_exhausts_retries() {
        let inner = FailThenSucceedProvider::new(10, "HTTP 429 Too Many Requests");
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(2)
            .with_base_delay_ms(1)
            .with_max_delay_ms(10);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("429"));
    }

    // --- Typed errors defer to `ProviderError::is_retryable` ---

    #[test]
    fn test_is_retryable_typed_rate_limit() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::RateLimit("quota exceeded".into()));
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_typed_server_error() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::ServerError("internal error".into()));
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_typed_timeout() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::Timeout("connection timed out".into()));
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_not_retryable_typed_auth() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::Auth("invalid api key".into()));
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_not_retryable_typed_billing() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::Billing("payment required".into()));
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_not_retryable_typed_invalid_request() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::InvalidRequest("bad json".into()));
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_not_retryable_typed_model_not_found() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::ModelNotFound("gpt-99".into()));
        assert!(!is_retryable(&err));
    }
    #[test]
    fn test_is_not_retryable_typed_unknown() {
        use crate::error::ProviderError;
        let err = ZeptoError::ProviderTyped(ProviderError::Unknown("something".into()));
        assert!(!is_retryable(&err));
    }

    // Like `FailThenSucceedProvider`, but fails with a typed
    // `ProviderError::RateLimit` instead of a free-form message.
    struct TypedFailThenSucceedProvider {
        fail_count: std::sync::atomic::AtomicU32,
        target_failures: u32,
    }

    impl TypedFailThenSucceedProvider {
        fn new(target_failures: u32) -> Self {
            Self {
                fail_count: std::sync::atomic::AtomicU32::new(0),
                target_failures,
            }
        }
    }

    #[async_trait]
    impl LLMProvider for TypedFailThenSucceedProvider {
        fn name(&self) -> &str {
            "typed-fail-then-succeed"
        }
        fn default_model(&self) -> &str {
            "test-model"
        }
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _model: Option<&str>,
            _options: ChatOptions,
        ) -> Result<LLMResponse> {
            use crate::error::ProviderError;
            let count = self
                .fail_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if count < self.target_failures {
                Err(ZeptoError::ProviderTyped(ProviderError::RateLimit(
                    "quota exceeded".into(),
                )))
            } else {
                Ok(LLMResponse::text("recovered"))
            }
        }
    }

    #[tokio::test]
    async fn test_retry_provider_retries_typed_rate_limit() {
        let inner = TypedFailThenSucceedProvider::new(2);
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(3)
            .with_base_delay_ms(1)
            .with_max_delay_ms(10);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "recovered");
    }

    // --- Phrase-based permanent-failure detection ---

    #[test]
    fn test_is_retryable_auth_failure_keywords() {
        let cases = [
            "invalid api key provided",
            "Incorrect API key: sk-xxx",
            "authentication failed: bad token",
            "401 Unauthorized - access denied",
            "permission denied for model",
        ];
        for msg in cases {
            let err = ZeptoError::Provider(msg.to_string());
            assert!(!is_retryable(&err), "should not retry auth failure: {msg}");
        }
    }

    #[test]
    fn test_is_retryable_context_window_exceeded() {
        let cases = [
            "This model's maximum context length is 128000 tokens",
            "exceeds the context window of this model",
            "maximum context length exceeded",
            "context length exceeded for gpt-4",
        ];
        for msg in cases {
            let err = ZeptoError::Provider(msg.to_string());
            assert!(
                !is_retryable(&err),
                "should not retry context window error: {msg}"
            );
        }
    }

    #[test]
    fn test_is_retryable_model_not_found() {
        let cases = [
            "model 'gpt-99' not found",
            "Unknown model: claude-99",
            "model does not exist: llama-999",
            "unsupported model specified",
        ];
        for msg in cases {
            let err = ZeptoError::Provider(msg.to_string());
            assert!(
                !is_retryable(&err),
                "should not retry model-not-found: {msg}"
            );
        }
    }

    // Regression guards: the permanent-failure checks must not swallow
    // genuinely transient errors.
    #[test]
    fn test_is_retryable_429_still_retryable_after_change() {
        let err = ZeptoError::Provider("HTTP 429 Too Many Requests".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_500_still_retryable_after_change() {
        let err = ZeptoError::Provider("HTTP 500 Internal Server Error".to_string());
        assert!(is_retryable(&err));
    }
    #[test]
    fn test_is_retryable_timeout_regression() {
        let err = ZeptoError::Provider("request timeout after 30s".to_string());
        assert!(is_retryable(&err));
    }
    // A 403 without the "HTTP" prefix is still caught via the "forbidden" hint.
    #[test]
    fn test_is_retryable_bare_403_forbidden() {
        let err = ZeptoError::Provider("403 Forbidden: quota exceeded".to_string());
        assert!(!is_retryable(&err));
    }

    // --- Retry budget ---

    #[test]
    fn test_budget_exceeded_zero_means_unlimited() {
        let mock = MockProvider::new("test", "model");
        let provider = RetryProvider::new(Box::new(mock)).with_retry_budget_ms(0);
        let start = std::time::Instant::now();
        assert!(!provider.budget_exceeded(start));
    }
    #[test]
    fn test_budget_exceeded_not_yet() {
        let mock = MockProvider::new("test", "model");
        let provider = RetryProvider::new(Box::new(mock)).with_retry_budget_ms(60_000);
        let start = std::time::Instant::now();
        assert!(!provider.budget_exceeded(start));
    }
    // Backdating `start` simulates elapsed time without sleeping.
    #[test]
    fn test_budget_exceeded_past_deadline() {
        let mock = MockProvider::new("test", "model");
        let provider = RetryProvider::new(Box::new(mock)).with_retry_budget_ms(1);
        let start = std::time::Instant::now() - std::time::Duration::from_millis(10);
        assert!(provider.budget_exceeded(start));
    }

    // Always fails with a retryable 429 after sleeping `delay_ms`, so each
    // attempt consumes real time and the retry budget runs out.
    struct SlowFailProvider {
        delay_ms: u64,
    }

    #[async_trait]
    impl LLMProvider for SlowFailProvider {
        fn name(&self) -> &str {
            "slow-fail"
        }
        fn default_model(&self) -> &str {
            "test-model"
        }
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _model: Option<&str>,
            _options: ChatOptions,
        ) -> Result<LLMResponse> {
            tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
            Err(ZeptoError::Provider(
                "HTTP 429 Too Many Requests".to_string(),
            ))
        }
    }

    // 100 retries at >= 20 ms each would take seconds; the 50 ms budget must
    // stop the loop well before that.
    #[tokio::test]
    async fn test_retry_budget_stops_retrying() {
        let inner = SlowFailProvider { delay_ms: 20 };
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(100)
            .with_base_delay_ms(1)
            .with_max_delay_ms(5)
            .with_retry_budget_ms(50);
        let start = std::time::Instant::now();
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_err());
        assert!(start.elapsed().as_millis() < 500);
    }

    #[tokio::test]
    async fn test_retry_budget_zero_unlimited() {
        let inner = FailThenSucceedProvider::new(3, "HTTP 429 Too Many Requests");
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(5)
            .with_base_delay_ms(1)
            .with_max_delay_ms(5)
            .with_retry_budget_ms(0);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "recovered");
    }

    #[tokio::test]
    async fn test_retry_budget_succeeds_within_budget() {
        let inner = FailThenSucceedProvider::new(2, "HTTP 429 Too Many Requests");
        let provider = RetryProvider::new(Box::new(inner))
            .with_max_retries(5)
            .with_base_delay_ms(1)
            .with_max_delay_ms(5)
            .with_retry_budget_ms(5_000);
        let result = provider
            .chat(vec![], vec![], None, ChatOptions::default())
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "recovered");
    }

    // Keeps the config default in sync with `RetryProvider::new`'s 45 s budget.
    #[test]
    fn test_retry_config_default_budget() {
        use crate::config::RetryConfig;
        let config = RetryConfig::default();
        assert_eq!(config.retry_budget_ms, 45_000);
    }
}