use std::fmt;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::error::Result;
use crate::session::Message;
use super::{ChatOptions, LLMProvider, LLMResponse, StreamEvent, ToolDefinition};
/// Strategy that decides which wrapped provider a `RotationProvider` tries first.
///
/// Serialized in `snake_case` (for example `"round_robin"` or `"cost_aware"`).
/// `Priority` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RotationStrategy {
/// Try the first healthy provider in configured order.
#[default]
Priority,
/// Advance a shared counter on every selection and start from the first healthy
/// provider at or after that position.
RoundRobin,
/// Start with the healthy provider whose default model has the lowest
/// `input_cost_per_million` in the default pricing table.
CostAware,
}
/// Failure bookkeeping for a single wrapped provider.
struct ProviderHealth {
/// Number of failures recorded since the last recorded success.
failure_count: AtomicU32,
/// Unix timestamp (seconds) of the most recent recorded failure; 0 until one is recorded.
last_failure_epoch: AtomicU64,
/// Failure count at which the provider starts being treated as unhealthy.
failure_threshold: u32,
/// Seconds that must pass after the last failure before an unhealthy provider
/// is treated as healthy again.
cooldown_secs: u64,
}
impl ProviderHealth {
    /// Creates a tracker with no recorded failures.
    ///
    /// `failure_threshold` is the failure count at which the provider counts as
    /// unhealthy; `cooldown_secs` is how long after the latest failure it stays so.
    fn new(failure_threshold: u32, cooldown_secs: u64) -> Self {
        let failure_count = AtomicU32::new(0);
        let last_failure_epoch = AtomicU64::new(0);
        Self {
            failure_count,
            last_failure_epoch,
            failure_threshold,
            cooldown_secs,
        }
    }

    /// Whole seconds since the Unix epoch, or 0 when the system clock reads
    /// earlier than the epoch.
    fn epoch_secs_now() -> u64 {
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        }
    }

    /// Returns `true` while the failure count is below the threshold, or once
    /// at least `cooldown_secs` have passed since the last recorded failure.
    fn is_healthy(&self) -> bool {
        let below_threshold = self.failure_count.load(Ordering::Relaxed) < self.failure_threshold;
        below_threshold || {
            let last_failure_at = self.last_failure_epoch.load(Ordering::Relaxed);
            // Saturating: a clock that moved backwards yields 0 elapsed seconds
            // instead of an underflow.
            let secs_since_failure = Self::epoch_secs_now().saturating_sub(last_failure_at);
            secs_since_failure >= self.cooldown_secs
        }
    }

    /// Resets the failure count to zero, logging when the provider had
    /// previously reached the unhealthy threshold.
    fn record_success(&self) {
        let before = self.failure_count.swap(0, Ordering::Relaxed);
        let was_at_threshold = before >= self.failure_threshold;
        if was_at_threshold {
            info!(
                previous_failures = before,
                "Rotation: provider recovered, resetting health"
            );
        }
    }

    /// Increments the failure count and stamps the failure time, logging once
    /// when this failure is the one that reaches the threshold.
    fn record_failure(&self) {
        let before = self.failure_count.fetch_add(1, Ordering::Relaxed);
        self.last_failure_epoch
            .store(Self::epoch_secs_now(), Ordering::Relaxed);
        let reached_threshold = before + 1 == self.failure_threshold;
        if reached_threshold {
            info!(
                threshold = self.failure_threshold,
                "Rotation: provider marked unhealthy"
            );
        }
    }
}
impl fmt::Debug for ProviderHealth {
    /// Formats a snapshot of the counters together with the derived
    /// `is_healthy` flag; the last-failure timestamp is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let failures = self.failure_count.load(Ordering::Relaxed);
        let healthy = self.is_healthy();
        let mut out = f.debug_struct("ProviderHealth");
        out.field("failure_count", &failures);
        out.field("failure_threshold", &self.failure_threshold);
        out.field("cooldown_secs", &self.cooldown_secs);
        out.field("is_healthy", &healthy);
        out.finish()
    }
}
/// An `LLMProvider` that wraps several providers and forwards each request to
/// one of them, moving on to the next when a call fails with a rotatable error.
pub struct RotationProvider {
/// Wrapped providers in configured order, each paired with its health tracker.
providers: Vec<(Box<dyn LLMProvider>, ProviderHealth)>,
/// How the first provider to try is chosen.
strategy: RotationStrategy,
/// Counter advanced on every round-robin selection.
round_robin_index: AtomicU32,
/// Name reported by `name()`, built as `rotation(<name>, <name>, ...)`.
composite_name: String,
}
impl fmt::Debug for RotationProvider {
    /// Prints the wrapped providers by name plus the active strategy.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut provider_names: Vec<&str> = Vec::with_capacity(self.providers.len());
        for (provider, _) in &self.providers {
            provider_names.push(provider.name());
        }
        let mut out = f.debug_struct("RotationProvider");
        out.field("providers", &provider_names);
        out.field("strategy", &self.strategy);
        out.finish()
    }
}
impl RotationProvider {
    /// Builds a rotation over `providers`, selecting the first provider to try
    /// according to `strategy`.
    ///
    /// Every provider gets its own health tracker configured with
    /// `failure_threshold` and `cooldown_secs`.
    ///
    /// # Panics
    ///
    /// Panics if `providers` is empty.
    pub fn new(
        providers: Vec<Box<dyn LLMProvider>>,
        strategy: RotationStrategy,
        failure_threshold: u32,
        cooldown_secs: u64,
    ) -> Self {
        assert!(
            !providers.is_empty(),
            "RotationProvider requires at least one provider"
        );
        let names: Vec<&str> = providers.iter().map(|p| p.name()).collect();
        let composite_name = format!("rotation({})", names.join(", "));
        let providers = providers
            .into_iter()
            .map(|p| (p, ProviderHealth::new(failure_threshold, cooldown_secs)))
            .collect();
        Self {
            providers,
            strategy,
            round_robin_index: AtomicU32::new(0),
            composite_name,
        }
    }

    /// Returns the index of the provider that should be tried first.
    ///
    /// Each strategy looks for a healthy provider; only when none is healthy
    /// does selection fall back to the provider whose last failure is oldest.
    fn select_provider_index(&self) -> usize {
        let len = self.providers.len();
        // First healthy provider scanning circularly from `start` (`start < len`).
        let first_healthy_from = |start: usize| -> Option<usize> {
            (0..len)
                .map(|offset| (start + offset) % len)
                .find(|&i| self.providers[i].1.is_healthy())
        };
        let healthy = match self.strategy {
            RotationStrategy::Priority => first_healthy_from(0),
            RotationStrategy::RoundRobin => {
                let ticket = self.round_robin_index.fetch_add(1, Ordering::Relaxed) as usize;
                // Reduce before adding offsets so `start + offset` cannot
                // overflow on targets where `usize` is 32 bits wide.
                first_healthy_from(ticket % len)
            }
            RotationStrategy::CostAware => self.cheapest_healthy_index(),
        };
        healthy.unwrap_or_else(|| self.oldest_unhealthy_index())
    }

    /// Index of the healthy provider whose default model has the lowest
    /// `input_cost_per_million`, or `None` when no provider is healthy.
    ///
    /// Models missing from the pricing table are costed at `f64::MAX`. Ties go
    /// to the provider that appears first.
    fn cheapest_healthy_index(&self) -> Option<usize> {
        let pricing = crate::utils::cost::default_pricing();
        let mut best: Option<(usize, f64)> = None;
        for (i, (provider, health)) in self.providers.iter().enumerate() {
            if !health.is_healthy() {
                continue;
            }
            let model = provider.default_model();
            let cost = pricing
                .get(model)
                .map(|p| p.input_cost_per_million)
                .unwrap_or(f64::MAX);
            // The first healthy candidate is always accepted. Comparing it
            // against an `f64::MAX` sentinel would reject unknown-priced models
            // and let selection fall through to an unhealthy provider.
            let is_better = match best {
                None => true,
                Some((_, best_cost)) => cost < best_cost,
            };
            if is_better {
                best = Some((i, cost));
            }
        }
        best.map(|(i, _)| i)
    }

    /// Index of the provider with the smallest last-failure timestamp.
    ///
    /// Used as the fallback when no provider is healthy. It scans every
    /// provider without checking health, so a provider that never failed
    /// (timestamp 0) always wins.
    fn oldest_unhealthy_index(&self) -> usize {
        self.providers
            .iter()
            .enumerate()
            .min_by_key(|(_, (_, h))| h.last_failure_epoch.load(Ordering::Relaxed))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Decides whether `err` should make the rotation try the next provider.
    ///
    /// Typed provider errors defer to `should_fallback()`; every other error
    /// kind rotates.
    fn should_rotate(err: &crate::error::ZeptoError) -> bool {
        match err {
            crate::error::ZeptoError::ProviderTyped(pe) => pe.should_fallback(),
            _ => true,
        }
    }
}
#[async_trait]
impl LLMProvider for RotationProvider {
fn name(&self) -> &str {
&self.composite_name
}
fn default_model(&self) -> &str {
self.providers[0].0.default_model()
}
async fn chat(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
model: Option<&str>,
options: ChatOptions,
) -> Result<LLMResponse> {
let len = self.providers.len();
let start_index = self.select_provider_index();
let mut last_err = None;
for offset in 0..len {
let i = (start_index + offset) % len;
let (provider, health) = &self.providers[i];
if offset > 0 && !health.is_healthy() {
continue;
}
match provider
.chat(messages.clone(), tools.clone(), model, options.clone())
.await
{
Ok(response) => {
health.record_success();
return Ok(response);
}
Err(err) => {
if Self::should_rotate(&err) {
health.record_failure();
warn!(
provider = provider.name(),
error = %err,
"Rotation: provider failed, trying next"
);
last_err = Some(err);
} else {
warn!(
provider = provider.name(),
error = %err,
"Rotation: non-recoverable error, not rotating"
);
return Err(err);
}
}
}
}
Err(last_err.unwrap_or_else(|| {
crate::error::ZeptoError::Provider("All rotation providers failed".into())
}))
}
async fn chat_stream(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
model: Option<&str>,
options: ChatOptions,
) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>> {
let len = self.providers.len();
let start_index = self.select_provider_index();
let mut last_err = None;
for offset in 0..len {
let i = (start_index + offset) % len;
let (provider, health) = &self.providers[i];
if offset > 0 && !health.is_healthy() {
continue;
}
match provider
.chat_stream(messages.clone(), tools.clone(), model, options.clone())
.await
{
Ok(receiver) => {
health.record_success();
return Ok(receiver);
}
Err(err) => {
if Self::should_rotate(&err) {
health.record_failure();
warn!(
provider = provider.name(),
error = %err,
"Rotation: provider streaming failed, trying next"
);
last_err = Some(err);
} else {
warn!(
provider = provider.name(),
error = %err,
"Rotation: non-recoverable streaming error, not rotating"
);
return Err(err);
}
}
}
}
Err(last_err.unwrap_or_else(|| {
crate::error::ZeptoError::Provider("All rotation providers failed (streaming)".into())
}))
}
async fn embed(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
let index = self.select_provider_index();
self.providers[index].0.embed(texts).await
}
}
// Unit tests for `RotationProvider` selection, failover and health tracking,
// driven by small in-memory `LLMProvider` doubles defined below.
#[cfg(test)]
mod tests {
use super::*;
use crate::error::{ProviderError, ZeptoError};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
// Test double whose `chat` always succeeds with the text `success from <name>`.
struct SuccessProvider {
name: &'static str,
}
impl fmt::Debug for SuccessProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SuccessProvider")
.field("name", &self.name)
.finish()
}
}
#[async_trait]
impl LLMProvider for SuccessProvider {
fn name(&self) -> &str {
self.name
}
fn default_model(&self) -> &str {
"success-model-v1"
}
async fn chat(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_model: Option<&str>,
_options: ChatOptions,
) -> Result<LLMResponse> {
Ok(LLMResponse::text(&format!("success from {}", self.name)))
}
}
// Test double whose `chat` always fails with an untyped `ZeptoError::Provider`.
struct FailProvider {
name: &'static str,
}
impl fmt::Debug for FailProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FailProvider")
.field("name", &self.name)
.finish()
}
}
#[async_trait]
impl LLMProvider for FailProvider {
fn name(&self) -> &str {
self.name
}
fn default_model(&self) -> &str {
"fail-model-v1"
}
async fn chat(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_model: Option<&str>,
_options: ChatOptions,
) -> Result<LLMResponse> {
Err(ZeptoError::Provider("provider failed".into()))
}
}
// Test double that succeeds and counts how many times `chat` was invoked.
struct CountingProvider {
name: &'static str,
call_count: Arc<AtomicU32>,
}
impl fmt::Debug for CountingProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CountingProvider")
.field("name", &self.name)
.field("call_count", &self.call_count.load(Ordering::SeqCst))
.finish()
}
}
#[async_trait]
impl LLMProvider for CountingProvider {
fn name(&self) -> &str {
self.name
}
fn default_model(&self) -> &str {
"counting-model-v1"
}
async fn chat(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_model: Option<&str>,
_options: ChatOptions,
) -> Result<LLMResponse> {
self.call_count.fetch_add(1, Ordering::SeqCst);
Ok(LLMResponse::text(&format!("success from {}", self.name)))
}
}
// Test double that fails with whatever error its `error` function pointer builds.
struct TypedFailProvider {
name: &'static str,
error: fn() -> ZeptoError,
}
impl fmt::Debug for TypedFailProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TypedFailProvider")
.field("name", &self.name)
.finish()
}
}
#[async_trait]
impl LLMProvider for TypedFailProvider {
fn name(&self) -> &str {
self.name
}
fn default_model(&self) -> &str {
"typed-fail-model"
}
async fn chat(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_model: Option<&str>,
_options: ChatOptions,
) -> Result<LLMResponse> {
Err((self.error)())
}
}
// Test double that always fails and counts how many times `chat` was invoked.
struct CountingFailProvider {
name: &'static str,
call_count: Arc<AtomicU32>,
}
impl fmt::Debug for CountingFailProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CountingFailProvider")
.field("name", &self.name)
.field("call_count", &self.call_count.load(Ordering::SeqCst))
.finish()
}
}
#[async_trait]
impl LLMProvider for CountingFailProvider {
fn name(&self) -> &str {
self.name
}
fn default_model(&self) -> &str {
"counting-fail-model"
}
async fn chat(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_model: Option<&str>,
_options: ChatOptions,
) -> Result<LLMResponse> {
self.call_count.fetch_add(1, Ordering::SeqCst);
Err(ZeptoError::Provider("provider failed".into()))
}
}
// The composite name lists the wrapped provider names in order.
#[test]
fn test_rotation_name() {
let provider = RotationProvider::new(
vec![
Box::new(SuccessProvider { name: "claude" }),
Box::new(SuccessProvider { name: "openai" }),
Box::new(SuccessProvider { name: "groq" }),
],
RotationStrategy::Priority,
3,
30,
);
assert_eq!(provider.name(), "rotation(claude, openai, groq)");
}
// `default_model` reports the first provider's default model.
#[test]
fn test_rotation_default_model() {
let provider = RotationProvider::new(
vec![
Box::new(SuccessProvider { name: "claude" }),
Box::new(SuccessProvider { name: "openai" }),
],
RotationStrategy::Priority,
3,
30,
);
assert_eq!(provider.default_model(), "success-model-v1");
}
// With all providers healthy, Priority sends the request to the first one only.
#[tokio::test]
async fn test_rotation_priority_uses_first_healthy() {
let calls_a = Arc::new(AtomicU32::new(0));
let calls_b = Arc::new(AtomicU32::new(0));
let calls_c = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(CountingProvider {
name: "alpha",
call_count: Arc::clone(&calls_a),
}),
Box::new(CountingProvider {
name: "beta",
call_count: Arc::clone(&calls_b),
}),
Box::new(CountingProvider {
name: "gamma",
call_count: Arc::clone(&calls_c),
}),
],
RotationStrategy::Priority,
3,
30,
);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(response.content, "success from alpha");
assert_eq!(calls_a.load(Ordering::SeqCst), 1);
assert_eq!(calls_b.load(Ordering::SeqCst), 0);
assert_eq!(calls_c.load(Ordering::SeqCst), 0);
}
// Three RoundRobin requests over three healthy providers hit each provider once.
#[tokio::test]
async fn test_rotation_round_robin() {
let calls_a = Arc::new(AtomicU32::new(0));
let calls_b = Arc::new(AtomicU32::new(0));
let calls_c = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(CountingProvider {
name: "alpha",
call_count: Arc::clone(&calls_a),
}),
Box::new(CountingProvider {
name: "beta",
call_count: Arc::clone(&calls_b),
}),
Box::new(CountingProvider {
name: "gamma",
call_count: Arc::clone(&calls_c),
}),
],
RotationStrategy::RoundRobin,
3,
30,
);
for _ in 0..3 {
let _ = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
}
assert_eq!(calls_a.load(Ordering::SeqCst), 1);
assert_eq!(calls_b.load(Ordering::SeqCst), 1);
assert_eq!(calls_c.load(Ordering::SeqCst), 1);
}
// A failing first provider gets one failure recorded and the request falls over to the next.
#[tokio::test]
async fn test_rotation_records_failure() {
let provider = RotationProvider::new(
vec![
Box::new(FailProvider { name: "alpha" }),
Box::new(SuccessProvider { name: "beta" }),
],
RotationStrategy::Priority,
3,
30,
);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("beta should succeed");
assert_eq!(response.content, "success from beta");
assert_eq!(
provider.providers[0]
.1
.failure_count
.load(Ordering::Relaxed),
1
);
}
// A success resets a failure count that was pre-set below the threshold (2 < 3).
#[tokio::test]
async fn test_rotation_records_success_resets() {
let provider = RotationProvider::new(
vec![
Box::new(SuccessProvider { name: "alpha" }),
Box::new(SuccessProvider { name: "beta" }),
],
RotationStrategy::Priority,
3,
30,
);
provider.providers[0]
.1
.failure_count
.store(2, Ordering::Relaxed);
let _ = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(
provider.providers[0]
.1
.failure_count
.load(Ordering::Relaxed),
0
);
}
// Three failing requests bring alpha to the threshold (3); the fourth request
// must go to beta without calling alpha again.
#[tokio::test]
async fn test_rotation_skips_unhealthy() {
let calls_a = Arc::new(AtomicU32::new(0));
let calls_b = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(CountingFailProvider {
name: "alpha",
call_count: Arc::clone(&calls_a),
}),
Box::new(CountingProvider {
name: "beta",
call_count: Arc::clone(&calls_b),
}),
],
RotationStrategy::Priority,
3,
30,
);
for _ in 0..3 {
let _ = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
}
assert_eq!(calls_a.load(Ordering::SeqCst), 3);
assert_eq!(calls_b.load(Ordering::SeqCst), 3);
let prev_a = calls_a.load(Ordering::SeqCst);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("beta should succeed");
assert_eq!(response.content, "success from beta");
assert_eq!(
calls_a.load(Ordering::SeqCst),
prev_a,
"unhealthy alpha should be skipped"
);
}
// With threshold 1 every provider becomes unhealthy after failing; selection
// must then pick the one whose last failure timestamp is oldest.
#[tokio::test]
async fn test_rotation_all_unhealthy_uses_oldest() {
let provider = RotationProvider::new(
vec![
Box::new(FailProvider { name: "alpha" }),
Box::new(FailProvider { name: "beta" }),
Box::new(FailProvider { name: "gamma" }),
],
RotationStrategy::Priority,
1, 30,
);
for _ in 0..3 {
let _ = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
// Overwrite the timestamps so alpha's failure is the oldest; all three
// remain inside the 30 s cooldown except alpha/beta, which the assertion
// below does not depend on because index 0 is expected either way.
provider.providers[0]
.1
.last_failure_epoch
.store(now - 100, Ordering::Relaxed);
provider.providers[1]
.1
.last_failure_epoch
.store(now - 50, Ordering::Relaxed);
provider.providers[2]
.1
.last_failure_epoch
.store(now - 10, Ordering::Relaxed);
let selected = provider.select_provider_index();
assert_eq!(selected, 0, "should select provider with oldest failure");
}
// An `Auth` typed error is returned to the caller instead of rotating to beta.
#[tokio::test]
async fn test_rotation_auth_error_no_rotation() {
let provider = RotationProvider::new(
vec![
Box::new(TypedFailProvider {
name: "alpha",
error: || ZeptoError::ProviderTyped(ProviderError::Auth("invalid key".into())),
}),
Box::new(SuccessProvider { name: "beta" }),
],
RotationStrategy::Priority,
3,
30,
);
let result = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Authentication error"));
}
// A `RateLimit` typed error rotates to the next provider.
#[tokio::test]
async fn test_rotation_rate_limit_triggers_rotation() {
let provider = RotationProvider::new(
vec![
Box::new(TypedFailProvider {
name: "alpha",
error: || {
ZeptoError::ProviderTyped(ProviderError::RateLimit("quota exceeded".into()))
},
}),
Box::new(SuccessProvider { name: "beta" }),
],
RotationStrategy::Priority,
3,
30,
);
let result = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().content, "success from beta");
}
// A rotation with a single provider still names itself and serves requests.
#[tokio::test]
async fn test_rotation_single_provider() {
let provider = RotationProvider::new(
vec![Box::new(SuccessProvider { name: "solo" })],
RotationStrategy::Priority,
3,
30,
);
assert_eq!(provider.name(), "rotation(solo)");
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(response.content, "success from solo");
}
// Pins the values produced by `RotationConfig::default()`.
#[test]
fn test_rotation_config_defaults() {
use crate::config::RotationConfig;
let config = RotationConfig::default();
assert!(!config.enabled);
assert!(config.order.is_empty());
assert_eq!(config.strategy, RotationStrategy::Priority);
assert_eq!(config.failure_threshold, 3);
assert_eq!(config.cooldown_secs, 30);
}
// Strategy variants serialize to and parse from snake_case JSON strings.
#[test]
fn test_rotation_strategy_serialize() {
let strategy = RotationStrategy::RoundRobin;
let json = serde_json::to_string(&strategy).unwrap();
assert_eq!(json, "\"round_robin\"");
let parsed: RotationStrategy = serde_json::from_str("\"priority\"").unwrap();
assert_eq!(parsed, RotationStrategy::Priority);
}
// Test double with a configurable name and default model; succeeds with
// `from <name>` and counts calls. Used by the CostAware tests.
struct ModelProvider {
name_str: &'static str,
model_str: &'static str,
call_count: Arc<AtomicU32>,
}
impl fmt::Debug for ModelProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ModelProvider")
.field("name", &self.name_str)
.field("model", &self.model_str)
.finish()
}
}
#[async_trait]
impl LLMProvider for ModelProvider {
fn name(&self) -> &str {
self.name_str
}
fn default_model(&self) -> &str {
self.model_str
}
async fn chat(
&self,
_messages: Vec<Message>,
_tools: Vec<ToolDefinition>,
_model: Option<&str>,
_options: ChatOptions,
) -> Result<LLMResponse> {
self.call_count.fetch_add(1, Ordering::SeqCst);
Ok(LLMResponse::text(&format!("from {}", self.name_str)))
}
}
// CostAware should route to the haiku provider even though it is listed last;
// this relies on the default pricing table pricing that model lowest.
#[tokio::test]
async fn test_cost_aware_picks_cheapest() {
let calls_haiku = Arc::new(AtomicU32::new(0));
let calls_sonnet = Arc::new(AtomicU32::new(0));
let calls_opus = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(ModelProvider {
name_str: "opus",
model_str: "claude-opus-4-6",
call_count: Arc::clone(&calls_opus),
}),
Box::new(ModelProvider {
name_str: "sonnet",
model_str: "claude-sonnet-4-5-20250929",
call_count: Arc::clone(&calls_sonnet),
}),
Box::new(ModelProvider {
name_str: "haiku",
model_str: "claude-3-haiku-20240307",
call_count: Arc::clone(&calls_haiku),
}),
],
RotationStrategy::CostAware,
3,
30,
);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(response.content, "from haiku");
assert_eq!(calls_haiku.load(Ordering::SeqCst), 1);
assert_eq!(calls_sonnet.load(Ordering::SeqCst), 0);
assert_eq!(calls_opus.load(Ordering::SeqCst), 0);
}
// When the haiku provider is marked unhealthy (threshold 1, failure just now),
// CostAware must use the healthy sonnet provider instead.
#[tokio::test]
async fn test_cost_aware_skips_unhealthy_cheap() {
let calls_sonnet = Arc::new(AtomicU32::new(0));
let calls_haiku = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(ModelProvider {
name_str: "haiku",
model_str: "claude-3-haiku-20240307",
call_count: Arc::clone(&calls_haiku),
}),
Box::new(ModelProvider {
name_str: "sonnet",
model_str: "claude-sonnet-4-5-20250929",
call_count: Arc::clone(&calls_sonnet),
}),
],
RotationStrategy::CostAware,
1, 30,
);
provider.providers[0]
.1
.failure_count
.store(1, Ordering::Relaxed);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
provider.providers[0]
.1
.last_failure_epoch
.store(now, Ordering::Relaxed);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(response.content, "from sonnet");
assert_eq!(calls_haiku.load(Ordering::SeqCst), 0);
assert_eq!(calls_sonnet.load(Ordering::SeqCst), 1);
}
// A model name that is presumably absent from the pricing table must lose to a
// known model, even when the unknown one is listed first.
#[tokio::test]
async fn test_cost_aware_unknown_model_treated_expensive() {
let calls_known = Arc::new(AtomicU32::new(0));
let calls_unknown = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(ModelProvider {
name_str: "unknown",
model_str: "some-obscure-model-v99",
call_count: Arc::clone(&calls_unknown),
}),
Box::new(ModelProvider {
name_str: "haiku",
model_str: "claude-3-haiku-20240307",
call_count: Arc::clone(&calls_known),
}),
],
RotationStrategy::CostAware,
3,
30,
);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(response.content, "from haiku");
assert_eq!(calls_known.load(Ordering::SeqCst), 1);
assert_eq!(calls_unknown.load(Ordering::SeqCst), 0);
}
// Two providers with the same default model: the one listed first wins the tie.
#[tokio::test]
async fn test_cost_aware_equal_cost_uses_priority() {
let calls_a = Arc::new(AtomicU32::new(0));
let calls_b = Arc::new(AtomicU32::new(0));
let provider = RotationProvider::new(
vec![
Box::new(ModelProvider {
name_str: "alpha",
model_str: "claude-sonnet-4-5-20250929",
call_count: Arc::clone(&calls_a),
}),
Box::new(ModelProvider {
name_str: "beta",
model_str: "claude-sonnet-4-5-20250929",
call_count: Arc::clone(&calls_b),
}),
],
RotationStrategy::CostAware,
3,
30,
);
let response = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await
.expect("should succeed");
assert_eq!(response.content, "from alpha");
assert_eq!(calls_a.load(Ordering::SeqCst), 1);
assert_eq!(calls_b.load(Ordering::SeqCst), 0);
}
// With the cheap provider still in cooldown (failed 10 s ago, cooldown 30 s),
// CostAware selection must return index 0, the provider with the older failure.
#[tokio::test]
async fn test_cost_aware_all_unhealthy_uses_oldest() {
let provider = RotationProvider::new(
vec![
Box::new(ModelProvider {
name_str: "expensive",
model_str: "claude-opus-4-6",
call_count: Arc::new(AtomicU32::new(0)),
}),
Box::new(ModelProvider {
name_str: "cheap",
model_str: "claude-3-haiku-20240307",
call_count: Arc::new(AtomicU32::new(0)),
}),
],
RotationStrategy::CostAware,
1,
30,
);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
provider.providers[0]
.1
.failure_count
.store(1, Ordering::Relaxed);
provider.providers[0]
.1
.last_failure_epoch
.store(now - 100, Ordering::Relaxed);
provider.providers[1]
.1
.failure_count
.store(1, Ordering::Relaxed);
provider.providers[1]
.1
.last_failure_epoch
.store(now - 10, Ordering::Relaxed);
let idx = provider.select_provider_index();
assert_eq!(idx, 0, "all unhealthy should use oldest-failure fallback");
}
// `CostAware` round-trips through JSON as `"cost_aware"`.
#[test]
fn test_cost_aware_strategy_serialize() {
let strategy = RotationStrategy::CostAware;
let json = serde_json::to_string(&strategy).unwrap();
assert_eq!(json, "\"cost_aware\"");
let parsed: RotationStrategy = serde_json::from_str("\"cost_aware\"").unwrap();
assert_eq!(parsed, RotationStrategy::CostAware);
}
// A `Billing` typed error is returned to the caller instead of rotating to beta.
#[tokio::test]
async fn test_rotation_billing_error_no_rotation() {
let provider = RotationProvider::new(
vec![
Box::new(TypedFailProvider {
name: "alpha",
error: || ZeptoError::ProviderTyped(ProviderError::Billing("no funds".into())),
}),
Box::new(SuccessProvider { name: "beta" }),
],
RotationStrategy::Priority,
3,
30,
);
let result = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Billing error"));
}
// A `ServerError` typed error rotates, and the next provider in order (beta) serves.
#[tokio::test]
async fn test_rotation_server_error_triggers_rotation() {
let provider = RotationProvider::new(
vec![
Box::new(TypedFailProvider {
name: "alpha",
error: || {
ZeptoError::ProviderTyped(ProviderError::ServerError(
"internal error".into(),
))
},
}),
Box::new(SuccessProvider { name: "beta" }),
Box::new(SuccessProvider { name: "gamma" }),
],
RotationStrategy::Priority,
3,
30,
);
let result = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().content, "success from beta");
}
// When every provider fails, `chat` returns an error.
#[tokio::test]
async fn test_rotation_all_fail_returns_last_error() {
let provider = RotationProvider::new(
vec![
Box::new(FailProvider { name: "alpha" }),
Box::new(FailProvider { name: "beta" }),
],
RotationStrategy::Priority,
3,
30,
);
let result = provider
.chat(vec![], vec![], None, ChatOptions::default())
.await;
assert!(result.is_err());
}
// A fresh tracker is healthy with a zero failure count.
#[test]
fn test_provider_health_starts_healthy() {
let health = ProviderHealth::new(3, 30);
assert!(health.is_healthy());
assert_eq!(health.failure_count.load(Ordering::Relaxed), 0);
}
// The tracker stays healthy below the threshold and flips on the third failure.
#[test]
fn test_provider_health_becomes_unhealthy() {
let health = ProviderHealth::new(3, 30);
health.record_failure();
assert!(health.is_healthy());
health.record_failure();
assert!(health.is_healthy());
health.record_failure();
assert!(!health.is_healthy());
}
// With a 1 s cooldown, back-dating the last failure by 2 s makes the tracker
// healthy again without sleeping.
#[test]
fn test_provider_health_recovers_after_cooldown() {
let health = ProviderHealth::new(3, 1); health.record_failure();
health.record_failure();
health.record_failure();
assert!(!health.is_healthy());
let past = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
- 2; health.last_failure_epoch.store(past, Ordering::Relaxed);
assert!(health.is_healthy());
}
// `record_success` clears previously recorded failures.
#[test]
fn test_provider_health_success_resets() {
let health = ProviderHealth::new(3, 30);
health.record_failure();
health.record_failure();
health.record_success();
assert_eq!(health.failure_count.load(Ordering::Relaxed), 0);
assert!(health.is_healthy());
}
}