use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use crate::{
circuit_breaker::{CircuitBreakerConfig, ProviderCircuitBreaker},
complexity_router::{ComplexityRouter, DefaultRouter},
context::Context,
error::ProviderError,
fallback_chain::FallbackChain,
model_db::ModelEntry,
providers::{Provider, ProviderEvent, StreamOptions},
Model,
};
/// Tuning knobs for [`MultiProvider`].
#[derive(Debug, Clone)]
pub struct MultiProviderConfig {
/// When `true`, the router classifies the request context and its selected
/// models are placed first in the candidate list.
pub auto_routing: bool,
/// Forwarded to the router's `route` call as its cost-efficiency preference.
pub prefer_cost_efficient: bool,
/// Extra attempts per candidate after a retryable error (0 means a single attempt).
pub max_retries_per_model: usize,
/// Intended per-candidate timeout.
/// NOTE(review): `MultiProvider::stream` in this module never reads this field,
/// so it is currently not enforced — confirm before relying on it.
pub per_model_timeout: Option<Duration>,
/// Settings cloned into each provider's circuit breaker when it is registered.
pub circuit_breaker: CircuitBreakerConfig,
}
impl Default for MultiProviderConfig {
fn default() -> Self {
Self {
auto_routing: true,
prefer_cost_efficient: true,
max_retries_per_model: 1,
per_model_timeout: None,
circuit_breaker: CircuitBreakerConfig::default(),
}
}
}
impl MultiProviderConfig {
    /// Returns a copy of the config with complexity-based routing switched on or off.
    #[must_use]
    pub fn with_auto_routing(self, enabled: bool) -> Self {
        Self {
            auto_routing: enabled,
            ..self
        }
    }

    /// Returns a copy of the config with the router's cost-efficiency preference set.
    #[must_use]
    pub fn with_prefer_cost_efficient(self, enabled: bool) -> Self {
        Self {
            prefer_cost_efficient: enabled,
            ..self
        }
    }

    /// Returns a copy of the config allowing `retries` extra attempts per model.
    #[must_use]
    pub fn with_max_retries(self, retries: usize) -> Self {
        Self {
            max_retries_per_model: retries,
            ..self
        }
    }

    /// Returns a copy of the config with a per-model timeout set.
    #[must_use]
    pub fn with_per_model_timeout(self, timeout: Duration) -> Self {
        Self {
            per_model_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy of the config using `config` for every provider's circuit breaker.
    #[must_use]
    pub fn with_circuit_breaker(self, config: CircuitBreakerConfig) -> Self {
        Self {
            circuit_breaker: config,
            ..self
        }
    }
}
/// Failure modes of multi-provider routing.
///
/// NOTE(review): `MultiProvider::stream` in this module reports failures as
/// `ProviderError`, not as this type — confirm where these variants are produced.
#[derive(Debug, thiserror::Error)]
pub enum MultiProviderError {
/// Every candidate failed.
#[error("All providers exhausted")]
AllProvidersExhausted {
/// One `(candidate id, error)` pair per failed candidate; presumably the id is
/// a `provider/model` string — confirm against the producer.
errors: Vec<(String, ProviderError)>,
},
/// No registered provider can serve the named model.
#[error("No provider available for model: {0}")]
NoProviderForModel(String),
/// The provider's circuit breaker rejected the request.
#[error("Circuit breaker open: {provider} (retry after {retry_after:?})")]
CircuitBreakerOpen {
/// Name of the provider whose breaker is open.
provider: String,
/// Time to wait before the breaker may accept requests again.
retry_after: Duration,
},
/// The primary provider failed and no fallback models were configured.
#[error("No fallback models configured and primary provider failed")]
NoFallback,
/// No provider has been registered at all.
#[error("No provider registered")]
NoProviderRegistered,
}
impl MultiProviderError {
    /// Returns `true` when this is [`MultiProviderError::CircuitBreakerOpen`].
    pub fn is_circuit_breaker(&self) -> bool {
        // Only the circuit-breaker variant carries a retry-after hint.
        self.retry_after().is_some()
    }

    /// Returns the suggested wait before retrying, for circuit-breaker errors only.
    pub fn retry_after(&self) -> Option<Duration> {
        if let Self::CircuitBreakerOpen { retry_after, .. } = self {
            Some(*retry_after)
        } else {
            None
        }
    }
}
/// A [`Provider`] that fans a request out over several registered providers.
///
/// Combines complexity-based routing, a fallback chain, per-candidate retries and
/// one circuit breaker per registered provider.
pub struct MultiProvider {
    router: Arc<dyn ComplexityRouter>,
    /// Short type name of `router`, reported by [`MultiProvider::diagnostics`].
    router_type: &'static str,
    /// Registered providers, keyed by registration name.
    providers: HashMap<String, Arc<dyn Provider>>,
    fallback: FallbackChain,
    /// One circuit breaker per registered provider, keyed like `providers`.
    breakers: HashMap<String, Arc<ProviderCircuitBreaker>>,
    config: MultiProviderConfig,
}

impl MultiProvider {
    /// Creates a multi-provider with `config` and the built-in [`DefaultRouter`].
    pub fn new(config: MultiProviderConfig) -> Self {
        Self::assemble(config, Arc::new(DefaultRouter::new()), "DefaultRouter")
    }

    /// Creates a multi-provider with the default config and a custom router.
    pub fn with_router(router: impl ComplexityRouter + 'static) -> Self {
        Self::with_config_and_router(MultiProviderConfig::default(), router)
    }

    /// Creates a multi-provider with both a custom config and a custom router.
    pub fn with_config_and_router(
        config: MultiProviderConfig,
        router: impl ComplexityRouter + 'static,
    ) -> Self {
        let router_type = Self::short_type_name_of(&router);
        Self::assemble(config, Arc::new(router), router_type)
    }

    /// Shared constructor: no providers, default fallback chain, no breakers.
    fn assemble(
        config: MultiProviderConfig,
        router: Arc<dyn ComplexityRouter>,
        router_type: &'static str,
    ) -> Self {
        Self {
            router,
            router_type,
            providers: HashMap::new(),
            fallback: FallbackChain::default(),
            breakers: HashMap::new(),
            config,
        }
    }

    /// Unqualified name of `T` with generic arguments dropped
    /// (e.g. `DefaultRouter` instead of `my_crate::complexity_router::DefaultRouter`).
    ///
    /// `std::any::type_name` output is best-effort and meant for diagnostics only,
    /// which is exactly how this value is used.
    fn short_type_name_of<T>(_value: &T) -> &'static str {
        let full = std::any::type_name::<T>();
        let base = full.split('<').next().unwrap_or(full);
        base.rsplit("::").next().unwrap_or(base)
    }

    /// Builder-style: replaces the router and returns `self`.
    pub fn set_router(mut self, router: impl ComplexityRouter + 'static) -> Self {
        self.router_type = Self::short_type_name_of(&router);
        self.router = Arc::new(router);
        self
    }

    /// Builder-style: replaces the fallback chain and returns `self`.
    pub fn with_fallback(mut self, fallback: FallbackChain) -> Self {
        self.fallback = fallback;
        self
    }

    /// Replaces the fallback chain in place.
    pub fn set_fallback(&mut self, fallback: FallbackChain) {
        self.fallback = fallback;
    }

    /// Registers `provider` under `name`, giving it a fresh circuit breaker built
    /// from `config.circuit_breaker`. Re-registering a name replaces both the
    /// provider and its breaker.
    pub fn register_provider(&mut self, name: &str, provider: Arc<dyn Provider>) {
        let breaker = Arc::new(ProviderCircuitBreaker::new(
            name.to_string(),
            self.config.circuit_breaker.clone(),
        ));
        self.providers.insert(name.to_string(), provider);
        self.breakers.insert(name.to_string(), breaker);
    }

    /// Removes the provider and its breaker; returns `true` if either existed.
    pub fn unregister_provider(&mut self, name: &str) -> bool {
        let provider_removed = self.providers.remove(name).is_some();
        let breaker_removed = self.breakers.remove(name).is_some();
        provider_removed || breaker_removed
    }

    /// Looks up a registered provider by name.
    pub fn get_provider(&self, name: &str) -> Option<&Arc<dyn Provider>> {
        self.providers.get(name)
    }

    /// Returns a shared handle to the named provider's circuit breaker.
    pub fn get_breaker(&self, provider_name: &str) -> Option<Arc<ProviderCircuitBreaker>> {
        self.breakers.get(provider_name).cloned()
    }

    /// Names of all registered providers, in unspecified (HashMap) order.
    pub fn provider_names(&self) -> Vec<&str> {
        self.providers.keys().map(|s| s.as_str()).collect()
    }

    /// Diagnostics for every registered provider's circuit breaker.
    pub fn circuit_breaker_diagnostics(
        &self,
    ) -> Vec<crate::circuit_breaker::CircuitBreakerDiagnostics> {
        self.breakers.values().map(|b| b.diagnostics()).collect()
    }

    /// The router used for complexity-based model selection.
    pub fn router(&self) -> &Arc<dyn ComplexityRouter> {
        &self.router
    }

    /// The configured fallback chain.
    pub fn fallback(&self) -> &FallbackChain {
        &self.fallback
    }

    /// The active configuration.
    pub fn config(&self) -> &MultiProviderConfig {
        &self.config
    }

    /// Snapshot of the provider count, router, fallback length, routing flags
    /// and circuit-breaker states.
    pub fn diagnostics(&self) -> MultiProviderDiagnostics {
        MultiProviderDiagnostics {
            provider_count: self.providers.len(),
            // Reflects the router actually installed, not just the default one.
            router_type: self.router_type.to_string(),
            fallback_len: self.fallback.len(),
            auto_routing: self.config.auto_routing,
            prefer_cost_efficient: self.config.prefer_cost_efficient,
            circuit_breakers: self.circuit_breaker_diagnostics(),
        }
    }
}
/// Point-in-time snapshot produced by `MultiProvider::diagnostics`.
#[derive(Debug, Clone)]
pub struct MultiProviderDiagnostics {
/// Number of registered providers.
pub provider_count: usize,
/// Label describing the router in use.
pub router_type: String,
/// Length of the configured fallback chain.
pub fallback_len: usize,
/// Copy of `MultiProviderConfig::auto_routing`.
pub auto_routing: bool,
/// Copy of `MultiProviderConfig::prefer_cost_efficient`.
pub prefer_cost_efficient: bool,
/// One entry per registered provider's circuit breaker.
pub circuit_breakers: Vec<crate::circuit_breaker::CircuitBreakerDiagnostics>,
}
#[async_trait]
impl Provider for MultiProvider {
async fn stream(
&self,
model: &Model,
context: &Context,
options: Option<StreamOptions>,
) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
let candidates = self.build_candidate_list(model, context).await?;
let mut errors: Vec<(String, ProviderError)> = Vec::new();
for candidate in candidates {
let provider_name = &candidate.provider;
let candidate_model = candidate.model;
let Some(provider) = self.providers.get(provider_name) else {
continue;
};
if let Some(breaker) = self.breakers.get(provider_name) {
match breaker.allow_request() {
Ok(()) => {
}
Err(e) => {
tracing::debug!(
provider = %provider_name,
remaining = ?e.remaining,
"Circuit breaker open, skipping provider"
);
continue;
}
}
}
let mut retry_count = 0;
let max_retries = self.config.max_retries_per_model;
loop {
match provider
.stream(&candidate_model, context, options.clone())
.await
{
Ok(stream) => {
if let Some(breaker) = self.breakers.get(provider_name) {
breaker.record_success();
}
tracing::debug!(
provider = %provider_name,
model = %candidate_model.id,
"MultiProvider: stream successful"
);
return Ok(stream);
}
Err(e) => {
if e.is_retryable() && retry_count < max_retries {
retry_count += 1;
if let Some(breaker) = self.breakers.get(provider_name) {
breaker.record_failure();
}
tracing::debug!(
provider = %provider_name,
model = %candidate_model.id,
error = %e,
retry = retry_count,
"Retryable error, retrying"
);
continue;
}
if !e.is_retryable() {
tracing::warn!(
provider = %provider_name,
model = %candidate_model.id,
error = %e,
"Non-retryable error, returning immediately"
);
return Err(e);
}
tracing::debug!(
provider = %provider_name,
model = %candidate_model.id,
error = %e,
retries = retry_count,
"Max retries exceeded, trying next candidate"
);
errors.push((format!("{}/{}", provider_name, candidate_model.id), e));
break;
}
}
}
}
if errors.is_empty() {
if self.providers.is_empty() {
Err(ProviderError::UnknownProvider(
"multi-provider: no providers registered".to_string(),
))
} else {
Err(ProviderError::UnknownProvider(
"multi-provider: no model could be routed".to_string(),
))
}
} else {
Err(ProviderError::UnknownProvider(format!(
"multi-provider: all {} candidates exhausted",
errors.len()
)))
}
}
fn name(&self) -> &str {
"multi-provider"
}
}
/// One `(provider, model)` attempt produced by `MultiProvider::build_candidate_list`.
struct Candidate {
/// Registration name of the provider (a key of `MultiProvider::providers`).
provider: String,
/// Model handed to that provider's `stream` call.
model: Model,
}
impl MultiProvider {
    /// Builds the ordered, de-duplicated list of `(provider, model)` pairs that
    /// `stream` tries in turn.
    ///
    /// Order of preference:
    /// 1. models the router selects for the context's complexity (only when
    ///    `auto_routing` is enabled);
    /// 2. the incoming model — on its own provider when that is registered,
    ///    otherwise on a substitute picked by `substitute_provider_for`;
    /// 3. the fallback chain;
    /// 4. a provider default, only if nothing above produced a candidate.
    ///
    /// Only registered providers appear, and each `provider/model-id` pair appears
    /// at most once (first occurrence wins).
    ///
    /// # Errors
    ///
    /// Returns `ProviderError::UnknownProvider` when no provider is registered.
    async fn build_candidate_list(
        &self,
        incoming_model: &Model,
        context: &Context,
    ) -> Result<Vec<Candidate>, ProviderError> {
        if self.providers.is_empty() {
            return Err(ProviderError::UnknownProvider(
                "multi-provider: no providers registered".to_string(),
            ));
        }

        let mut candidates: Vec<Candidate> = Vec::new();
        let mut seen_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        // Appends a candidate unless its `provider/model-id` key was already queued.
        let add_candidate = |candidates: &mut Vec<Candidate>,
                             seen_ids: &mut std::collections::HashSet<String>,
                             provider: String,
                             model: Model| {
            if seen_ids.insert(format!("{}/{}", provider, model.id)) {
                candidates.push(Candidate { provider, model });
            }
        };

        if self.config.auto_routing {
            let complexity = self.router.classify(context);
            let router_models = self
                .router
                .route(complexity, self.config.prefer_cost_efficient);
            tracing::debug!(
                complexity = ?complexity,
                model_count = router_models.len(),
                "MultiProvider: router selected models for complexity"
            );
            for entry in router_models {
                if self.providers.contains_key(entry.provider) {
                    add_candidate(
                        &mut candidates,
                        &mut seen_ids,
                        entry.provider.to_string(),
                        self.resolve_entry_model(entry),
                    );
                }
            }
        }

        if self.providers.contains_key(&incoming_model.provider) {
            add_candidate(
                &mut candidates,
                &mut seen_ids,
                incoming_model.provider.clone(),
                incoming_model.clone(),
            );
        } else if let Some((provider_name, model)) =
            self.substitute_provider_for(&incoming_model.id)
        {
            add_candidate(&mut candidates, &mut seen_ids, provider_name, model);
        }

        for fallback_entry in self.fallback.iter() {
            if self.providers.contains_key(fallback_entry.provider) {
                add_candidate(
                    &mut candidates,
                    &mut seen_ids,
                    fallback_entry.provider.to_string(),
                    self.resolve_entry_model(fallback_entry),
                );
            }
        }

        if candidates.is_empty() {
            // `min` keeps the pick independent of HashMap iteration order.
            if let Some(provider_name) = self.providers.keys().min() {
                let model = self.default_model_for_provider(provider_name);
                add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
            }
        }

        tracing::debug!(
            candidate_count = candidates.len(),
            "MultiProvider: built candidate list"
        );
        Ok(candidates)
    }

    /// Picks a registered provider for `model_id` when the request names a provider
    /// that is not registered.
    ///
    /// A provider that actually knows `model_id` (model registry or model DB) wins;
    /// otherwise the alphabetically first provider gets a synthesized model.
    /// Providers are sorted so the choice does not depend on HashMap iteration order.
    /// Returns `None` only when no provider is registered.
    fn substitute_provider_for(&self, model_id: &str) -> Option<(String, Model)> {
        let mut names: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
            .iter()
            .find_map(|&name| {
                self.find_model_for_provider(name, model_id)
                    .map(|model| (name.to_string(), model))
            })
            .or_else(|| {
                names.first().map(|&name| {
                    (name.to_string(), self.construct_model_from_id(name, model_id))
                })
            })
    }

    /// Model for a router or fallback entry: the registry's model when there is one,
    /// otherwise one built from the static entry (which leaves `base_url` empty).
    fn resolve_entry_model(&self, entry: &ModelEntry) -> Model {
        match crate::model_registry::get_model(entry.provider, entry.id) {
            Some(registered) => registered.clone(),
            None => self.model_from_entry(entry),
        }
    }

    /// Converts a static model-DB entry into a `Model` with an empty `base_url`,
    /// no extra headers and no compat overrides.
    fn model_from_entry(&self, entry: &ModelEntry) -> Model {
        Model {
            id: entry.id.to_string(),
            name: entry.name.to_string(),
            api: entry.api,
            provider: entry.provider.to_string(),
            base_url: String::new(),
            reasoning: entry.reasoning,
            input: entry.input.to_vec(),
            cost: crate::types::Cost {
                input: entry.cost_input,
                output: entry.cost_output,
                cache_read: entry.cost_cache_read,
                cache_write: entry.cost_cache_write,
            },
            context_window: entry.context_window as usize,
            max_tokens: entry.max_tokens as usize,
            headers: HashMap::new(),
            compat: None,
        }
    }

    /// Looks up `model_id` for `provider_name` in the model registry, then the
    /// model DB. Returns `None` when neither knows the pair (no model is synthesized).
    fn find_model_for_provider(&self, provider_name: &str, model_id: &str) -> Option<Model> {
        if let Some(model) = crate::model_registry::get_model(provider_name, model_id) {
            return Some(model.clone());
        }
        crate::model_db::get_model_entry(provider_name, model_id)
            .map(|entry| self.model_from_entry(entry))
    }

    /// Builds a model for `provider`/`model_id`, using the model DB when it has the
    /// entry and otherwise guessing the API from the provider name with generic
    /// limits (128k context, 32k output tokens, text input only).
    fn construct_model_from_id(&self, provider: &str, model_id: &str) -> Model {
        if let Some(entry) = crate::model_db::get_model_entry(provider, model_id) {
            return self.model_from_entry(entry);
        }
        let api = match provider {
            "openai" | "openai-codex" | "opencode" | "opencode-go" => {
                crate::types::Api::OpenAiResponses
            }
            "anthropic" | "cloudflare-ai-gateway" => crate::types::Api::AnthropicMessages,
            "google" => crate::types::Api::GoogleGenerativeAi,
            "google-vertex" => crate::types::Api::GoogleVertex,
            "azure-openai" | "azure-openai-responses" => crate::types::Api::AzureOpenAiResponses,
            "amazon-bedrock" | "bedrock" => crate::types::Api::BedrockConverseStream,
            _ => crate::types::Api::OpenAiResponses,
        };
        Model {
            id: model_id.to_string(),
            name: model_id.to_string(),
            api,
            provider: provider.to_string(),
            base_url: String::new(),
            reasoning: false,
            input: vec![crate::types::InputModality::Text],
            cost: crate::types::Cost::default(),
            context_window: 128_000,
            max_tokens: 32_000,
            headers: HashMap::new(),
            compat: None,
        }
    }

    /// Default model for a provider.
    ///
    /// Tries a hard-coded preferred id for well-known providers, then the last
    /// model the model DB lists for the provider (for any provider, known or not),
    /// and finally a synthesized model with id `"default"`.
    fn default_model_for_provider(&self, provider_name: &str) -> Model {
        let preferred_id = match provider_name {
            "openai" => Some("gpt-4o-mini"),
            "anthropic" => Some("claude-sonnet-4-20250514"),
            "google" => Some("gemini-2.0-flash"),
            _ => None,
        };
        if let Some(entry) =
            preferred_id.and_then(|id| crate::model_db::get_model_entry(provider_name, id))
        {
            return self.model_from_entry(entry);
        }
        let provider_models = crate::model_db::get_provider_models(provider_name);
        if let Some(entry) = provider_models.last() {
            return self.model_from_entry(entry);
        }
        self.construct_model_from_id(provider_name, "default")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::Message;

    /// Provider stub for tests that only exercise registration and bookkeeping;
    /// its `stream` must never be called.
    struct MockProvider;

    #[async_trait]
    impl Provider for MockProvider {
        async fn stream(
            &self,
            _model: &Model,
            _context: &Context,
            _options: Option<StreamOptions>,
        ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
            unreachable!("Mock provider - not called in these tests")
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// A `MultiProvider` with the default config and one mock registered as "test".
    fn provider_with_mock() -> MultiProvider {
        let mut provider = MultiProvider::new(MultiProviderConfig::default());
        provider.register_provider("test", Arc::new(MockProvider));
        provider
    }

    fn create_test_context() -> Context {
        let mut ctx = Context::new();
        ctx.add_message(Message::User(crate::UserMessage::new(
            "Help me write a function to reverse a string".to_string(),
        )));
        ctx
    }

    #[test]
    fn test_config_defaults() {
        let config = MultiProviderConfig::default();
        assert!(config.auto_routing);
        assert!(config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 1);
        assert!(config.per_model_timeout.is_none());
    }

    #[test]
    fn test_config_builder() {
        let config = MultiProviderConfig::default()
            .with_auto_routing(false)
            .with_prefer_cost_efficient(false)
            .with_max_retries(3)
            .with_per_model_timeout(Duration::from_secs(30));
        assert!(!config.auto_routing);
        assert!(!config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 3);
        assert_eq!(config.per_model_timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_multi_provider_creation() {
        let provider = MultiProvider::new(MultiProviderConfig::default());
        assert_eq!(provider.name(), "multi-provider");
        assert!(provider.provider_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let provider = provider_with_mock();
        assert_eq!(provider.provider_names(), vec!["test"]);
        assert!(provider.get_provider("test").is_some());
        assert!(provider.get_breaker("test").is_some());
    }

    #[test]
    fn test_unregister_provider() {
        let mut provider = provider_with_mock();
        assert!(provider.unregister_provider("test"));
        assert!(provider.provider_names().is_empty());
        assert!(provider.get_provider("test").is_none());
        assert!(provider.get_breaker("test").is_none());
        // A second removal finds nothing to remove.
        assert!(!provider.unregister_provider("test"));
    }

    #[test]
    fn test_with_router() {
        let provider = MultiProvider::with_router(DefaultRouter::new());
        assert_eq!(provider.name(), "multi-provider");
    }

    #[test]
    fn test_with_fallback() {
        let fallback = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let provider = MultiProvider::new(MultiProviderConfig::default()).with_fallback(fallback);
        assert_eq!(provider.fallback().len(), 1);
    }

    #[test]
    fn test_circuit_breaker_diagnostics() {
        let provider = provider_with_mock();
        let diagnostics = provider.circuit_breaker_diagnostics();
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].provider, "test");
    }

    #[test]
    fn test_multi_provider_error_display() {
        let err = MultiProviderError::NoProviderForModel("gpt-4o".to_string());
        assert!(err.to_string().contains("gpt-4o"));
        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(err.to_string().contains("All providers exhausted"));
        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.to_string().contains("openai"));
        assert!(err.to_string().contains("10"));
    }

    #[test]
    fn test_multi_provider_error_helpers() {
        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.is_circuit_breaker());
        assert_eq!(err.retry_after(), Some(Duration::from_secs(10)));
        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(!err.is_circuit_breaker());
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn test_diagnostics() {
        let provider = provider_with_mock();
        let diag = provider.diagnostics();
        assert_eq!(diag.provider_count, 1);
        assert_eq!(diag.router_type, "DefaultRouter");
        assert!(diag.auto_routing);
        assert!(diag.prefer_cost_efficient);
        assert_eq!(diag.circuit_breakers.len(), 1);
    }

    #[test]
    fn test_router_classification() {
        use crate::Complexity;
        let provider = MultiProvider::with_router(DefaultRouter::new());
        let ctx = create_test_context();
        let complexity = provider.router().classify(&ctx);
        assert!(complexity >= Complexity::Simple);
    }
}