1use std::collections::HashMap;
55use std::sync::Arc;
56use std::time::Duration;
57
58use async_trait::async_trait;
59use futures::Stream;
60use std::pin::Pin;
61
62use crate::{
63 circuit_breaker::{CircuitBreakerConfig, ProviderCircuitBreaker},
64 complexity_router::{ComplexityRouter, DefaultRouter},
65 context::Context,
66 error::ProviderError,
67 fallback_chain::FallbackChain,
68 model_db::ModelEntry,
69 providers::{Provider, ProviderEvent, StreamOptions},
70 Model,
71};
72
73#[derive(Debug, Clone)]
81pub struct MultiProviderConfig {
82 pub auto_routing: bool,
89
90 pub prefer_cost_efficient: bool,
97
98 pub max_retries_per_model: usize,
105
106 pub per_model_timeout: Option<Duration>,
113
114 pub circuit_breaker: CircuitBreakerConfig,
121}
122
123impl Default for MultiProviderConfig {
124 fn default() -> Self {
125 Self {
126 auto_routing: true,
127 prefer_cost_efficient: true,
128 max_retries_per_model: 1,
129 per_model_timeout: None,
130 circuit_breaker: CircuitBreakerConfig::default(),
131 }
132 }
133}
134
135impl MultiProviderConfig {
136 #[must_use]
138 pub fn with_auto_routing(mut self, enabled: bool) -> Self {
139 self.auto_routing = enabled;
140 self
141 }
142
143 #[must_use]
145 pub fn with_prefer_cost_efficient(mut self, enabled: bool) -> Self {
146 self.prefer_cost_efficient = enabled;
147 self
148 }
149
150 #[must_use]
152 pub fn with_max_retries(mut self, retries: usize) -> Self {
153 self.max_retries_per_model = retries;
154 self
155 }
156
157 #[must_use]
159 pub fn with_per_model_timeout(mut self, timeout: Duration) -> Self {
160 self.per_model_timeout = Some(timeout);
161 self
162 }
163
164 #[must_use]
166 pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
167 self.circuit_breaker = config;
168 self
169 }
170}
171
172#[derive(Debug, thiserror::Error)]
178pub enum MultiProviderError {
179 #[error("All providers exhausted")]
183 AllProvidersExhausted {
184 errors: Vec<(String, ProviderError)>,
186 },
187
188 #[error("No provider available for model: {0}")]
190 NoProviderForModel(String),
191
192 #[error("Circuit breaker open: {provider} (retry after {retry_after:?})")]
196 CircuitBreakerOpen {
197 provider: String,
199 retry_after: Duration,
201 },
202
203 #[error("No fallback models configured and primary provider failed")]
205 NoFallback,
206
207 #[error("No provider registered")]
209 NoProviderRegistered,
210}
211
212impl MultiProviderError {
213 pub fn is_circuit_breaker(&self) -> bool {
215 matches!(self, Self::CircuitBreakerOpen { .. })
216 }
217
218 pub fn retry_after(&self) -> Option<Duration> {
220 match self {
221 Self::CircuitBreakerOpen { retry_after, .. } => Some(*retry_after),
222 _ => None,
223 }
224 }
225}
226
227pub struct MultiProvider {
242 router: Arc<dyn ComplexityRouter>,
244
245 providers: HashMap<String, Arc<dyn Provider>>,
247
248 fallback: FallbackChain,
250
251 breakers: HashMap<String, Arc<ProviderCircuitBreaker>>,
253
254 config: MultiProviderConfig,
256}
257
258impl MultiProvider {
259 pub fn new(config: MultiProviderConfig) -> Self {
270 Self {
271 router: Arc::new(DefaultRouter::new()),
272 providers: HashMap::new(),
273 fallback: FallbackChain::default(),
274 breakers: HashMap::new(),
275 config,
276 }
277 }
278
279 pub fn with_router(router: impl ComplexityRouter + 'static) -> Self {
290 Self {
291 router: Arc::new(router),
292 providers: HashMap::new(),
293 fallback: FallbackChain::default(),
294 breakers: HashMap::new(),
295 config: MultiProviderConfig::default(),
296 }
297 }
298
299 pub fn with_config_and_router(
311 config: MultiProviderConfig,
312 router: impl ComplexityRouter + 'static,
313 ) -> Self {
314 Self {
315 router: Arc::new(router),
316 providers: HashMap::new(),
317 fallback: FallbackChain::default(),
318 breakers: HashMap::new(),
319 config,
320 }
321 }
322
323 pub fn set_router(mut self, router: impl ComplexityRouter + 'static) -> Self {
331 self.router = Arc::new(router);
332 self
333 }
334
335 pub fn with_fallback(mut self, fallback: FallbackChain) -> Self {
346 self.fallback = fallback;
347 self
348 }
349
350 pub fn set_fallback(&mut self, fallback: FallbackChain) {
352 self.fallback = fallback;
353 }
354
355 pub fn register_provider(&mut self, name: &str, provider: Arc<dyn Provider>) {
372 let breaker = Arc::new(ProviderCircuitBreaker::new(
374 name.to_string(),
375 self.config.circuit_breaker.clone(),
376 ));
377
378 self.providers.insert(name.to_string(), provider);
379 self.breakers.insert(name.to_string(), breaker);
380 }
381
382 pub fn unregister_provider(&mut self, name: &str) -> bool {
394 let provider_removed = self.providers.remove(name).is_some();
395 let breaker_removed = self.breakers.remove(name).is_some();
396 provider_removed || breaker_removed
397 }
398
399 pub fn get_provider(&self, name: &str) -> Option<&Arc<dyn Provider>> {
409 self.providers.get(name)
410 }
411
412 pub fn get_breaker(&self, provider_name: &str) -> Option<Arc<ProviderCircuitBreaker>> {
422 self.breakers.get(provider_name).cloned()
423 }
424
425 pub fn provider_names(&self) -> Vec<&str> {
427 self.providers.keys().map(|s| s.as_str()).collect()
428 }
429
430 pub fn circuit_breaker_diagnostics(
436 &self,
437 ) -> Vec<crate::circuit_breaker::CircuitBreakerDiagnostics> {
438 self.breakers.values().map(|b| b.diagnostics()).collect()
439 }
440
441 pub fn router(&self) -> &Arc<dyn ComplexityRouter> {
443 &self.router
444 }
445
446 pub fn fallback(&self) -> &FallbackChain {
448 &self.fallback
449 }
450
451 pub fn config(&self) -> &MultiProviderConfig {
453 &self.config
454 }
455
456 pub fn diagnostics(&self) -> MultiProviderDiagnostics {
458 MultiProviderDiagnostics {
459 provider_count: self.providers.len(),
460 router_type: "DefaultRouter".to_string(),
461 fallback_len: self.fallback.len(),
462 auto_routing: self.config.auto_routing,
463 prefer_cost_efficient: self.config.prefer_cost_efficient,
464 circuit_breakers: self.circuit_breaker_diagnostics(),
465 }
466 }
467}
468
469#[derive(Debug, Clone)]
475pub struct MultiProviderDiagnostics {
476 pub provider_count: usize,
478 pub router_type: String,
480 pub fallback_len: usize,
482 pub auto_routing: bool,
484 pub prefer_cost_efficient: bool,
486 pub circuit_breakers: Vec<crate::circuit_breaker::CircuitBreakerDiagnostics>,
488}
489
490#[async_trait]
495impl Provider for MultiProvider {
496 async fn stream(
511 &self,
512 model: &Model,
513 context: &Context,
514 options: Option<StreamOptions>,
515 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
516 let candidates = self.build_candidate_list(model, context).await?;
518
519 let mut errors: Vec<(String, ProviderError)> = Vec::new();
521
522 for candidate in candidates {
523 let provider_name = &candidate.provider;
524 let candidate_model = candidate.model;
525
526 let Some(provider) = self.providers.get(provider_name) else {
528 continue;
530 };
531
532 if let Some(breaker) = self.breakers.get(provider_name) {
534 match breaker.allow_request() {
535 Ok(()) => {
536 }
538 Err(e) => {
539 tracing::debug!(
541 provider = %provider_name,
542 remaining = ?e.remaining,
543 "Circuit breaker open, skipping provider"
544 );
545 continue;
546 }
547 }
548 }
549
550 let mut retry_count = 0;
552 let max_retries = self.config.max_retries_per_model;
553
554 loop {
555 match provider
556 .stream(&candidate_model, context, options.clone())
557 .await
558 {
559 Ok(stream) => {
560 if let Some(breaker) = self.breakers.get(provider_name) {
562 breaker.record_success();
563 }
564 tracing::debug!(
565 provider = %provider_name,
566 model = %candidate_model.id,
567 "MultiProvider: stream successful"
568 );
569 return Ok(stream);
570 }
571 Err(e) => {
572 if e.is_retryable() && retry_count < max_retries {
574 retry_count += 1;
576 if let Some(breaker) = self.breakers.get(provider_name) {
577 breaker.record_failure();
578 }
579 tracing::debug!(
580 provider = %provider_name,
581 model = %candidate_model.id,
582 error = %e,
583 retry = retry_count,
584 "Retryable error, retrying"
585 );
586 continue;
587 }
588
589 if !e.is_retryable() {
591 tracing::warn!(
594 provider = %provider_name,
595 model = %candidate_model.id,
596 error = %e,
597 "Non-retryable error, returning immediately"
598 );
599 return Err(e);
600 }
601
602 tracing::debug!(
604 provider = %provider_name,
605 model = %candidate_model.id,
606 error = %e,
607 retries = retry_count,
608 "Max retries exceeded, trying next candidate"
609 );
610 errors.push((format!("{}/{}", provider_name, candidate_model.id), e));
611 break;
612 }
613 }
614 }
615 }
616
617 if errors.is_empty() {
619 if self.providers.is_empty() {
620 Err(ProviderError::UnknownProvider(
621 "multi-provider: no providers registered".to_string(),
622 ))
623 } else {
624 Err(ProviderError::UnknownProvider(
625 "multi-provider: no model could be routed".to_string(),
626 ))
627 }
628 } else {
629 Err(ProviderError::UnknownProvider(format!(
630 "multi-provider: all {} candidates exhausted",
631 errors.len()
632 )))
633 }
634 }
635
636 fn name(&self) -> &str {
638 "multi-provider"
639 }
640}
641
642struct Candidate {
648 provider: String,
650 model: Model,
652}
653
654impl MultiProvider {
655 async fn build_candidate_list(
661 &self,
662 incoming_model: &Model,
663 context: &Context,
664 ) -> Result<Vec<Candidate>, ProviderError> {
665 let mut candidates: Vec<Candidate> = Vec::new();
666 let mut seen_ids: HashMap<String, ()> = HashMap::new();
667
668 let add_candidate = |candidates: &mut Vec<Candidate>,
670 seen_ids: &mut HashMap<String, ()>,
671 provider: String,
672 model: Model| {
673 let id = format!("{}/{}", provider, model.id);
674 if seen_ids.insert(id, ()).is_none() {
675 candidates.push(Candidate { provider, model });
676 }
677 };
678
679 if self.config.auto_routing {
681 let complexity = self.router.classify(context);
682 let router_models = self
683 .router
684 .route(complexity, self.config.prefer_cost_efficient);
685
686 tracing::debug!(
687 complexity = ?complexity,
688 model_count = router_models.len(),
689 "MultiProvider: router selected models for complexity"
690 );
691
692 for entry in router_models {
693 if let Some(registered_model) =
695 crate::model_registry::get_model(entry.provider, entry.id)
696 {
697 if self.providers.contains_key(entry.provider) {
698 add_candidate(
699 &mut candidates,
700 &mut seen_ids,
701 entry.provider.to_string(),
702 registered_model.clone(),
703 );
704 }
705 }
706
707 if self.providers.contains_key(entry.provider) {
709 let model = self.model_from_entry(entry);
710 let id = format!("{}/{}", entry.provider, entry.id);
711 if seen_ids.insert(id, ()).is_none() {
712 candidates.push(Candidate {
713 provider: entry.provider.to_string(),
714 model,
715 });
716 }
717 }
718 }
719 }
720
721 if self.providers.contains_key(&incoming_model.provider) {
723 add_candidate(
724 &mut candidates,
725 &mut seen_ids,
726 incoming_model.provider.clone(),
727 incoming_model.clone(),
728 );
729 } else {
730 for provider_name in self.providers.keys() {
733 let model_id = &incoming_model.id;
735
736 if let Some(model) = self.find_model_for_provider(provider_name, model_id) {
738 add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
739 break;
740 }
741 }
742 }
743
744 for fallback_entry in self.fallback.iter() {
746 if let Some(registered_model) =
748 crate::model_registry::get_model(fallback_entry.provider, fallback_entry.id)
749 {
750 if self.providers.contains_key(fallback_entry.provider) {
751 add_candidate(
752 &mut candidates,
753 &mut seen_ids,
754 fallback_entry.provider.to_string(),
755 registered_model.clone(),
756 );
757 }
758 } else if self.providers.contains_key(fallback_entry.provider) {
759 let model = self.model_from_entry(fallback_entry);
761 let id = format!("{}/{}", fallback_entry.provider, fallback_entry.id);
762 if seen_ids.insert(id, ()).is_none() {
763 candidates.push(Candidate {
764 provider: fallback_entry.provider.to_string(),
765 model,
766 });
767 }
768 }
769 }
770
771 if candidates.is_empty() && !self.providers.is_empty() {
773 let (provider_name, _provider) = self
775 .providers
776 .iter()
777 .next()
778 .expect("providers map is non-empty");
779 let model = self.default_model_for_provider(provider_name);
780 add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
781 }
782
783 tracing::debug!(
784 candidate_count = candidates.len(),
785 "MultiProvider: built candidate list"
786 );
787
788 if candidates.is_empty() && self.providers.is_empty() {
789 return Err(ProviderError::UnknownProvider(
790 "multi-provider: no providers registered".to_string(),
791 ));
792 }
793
794 Ok(candidates)
795 }
796
797 fn model_from_entry(&self, entry: &ModelEntry) -> Model {
799 Model {
800 id: entry.id.to_string(),
801 name: entry.name.to_string(),
802 api: entry.api,
803 provider: entry.provider.to_string(),
804 base_url: String::new(), reasoning: entry.reasoning,
806 input: entry.input.to_vec(),
807 cost: crate::types::Cost {
808 input: entry.cost_input,
809 output: entry.cost_output,
810 cache_read: entry.cost_cache_read,
811 cache_write: entry.cost_cache_write,
812 },
813 context_window: entry.context_window as usize,
814 max_tokens: entry.max_tokens as usize,
815 headers: HashMap::new(),
816 compat: None,
817 }
818 }
819
820 fn find_model_for_provider(&self, provider_name: &str, model_id: &str) -> Option<Model> {
822 if let Some(model) = crate::model_registry::get_model(provider_name, model_id) {
824 return Some(model.clone());
825 }
826
827 if let Some(entry) = crate::model_db::get_model_entry(provider_name, model_id) {
829 return Some(self.model_from_entry(entry));
830 }
831
832 Some(self.construct_model_from_id(provider_name, model_id))
834 }
835
836 fn construct_model_from_id(&self, provider: &str, model_id: &str) -> Model {
841 if let Some(entry) = crate::model_db::get_model_entry(provider, model_id) {
843 return self.model_from_entry(entry);
844 }
845
846 let api = match provider {
848 "openai" | "openai-codex" | "opencode" | "opencode-go" => {
849 crate::types::Api::OpenAiResponses
850 }
851 "anthropic" | "cloudflare-ai-gateway" => crate::types::Api::AnthropicMessages,
852 "google" => crate::types::Api::GoogleGenerativeAi,
853 "google-vertex" => crate::types::Api::GoogleVertex,
854 "azure-openai" | "azure-openai-responses" => crate::types::Api::AzureOpenAiResponses,
855 "amazon-bedrock" | "bedrock" => crate::types::Api::BedrockConverseStream,
856 _ => crate::types::Api::OpenAiResponses,
857 };
858
859 Model {
860 id: model_id.to_string(),
861 name: model_id.to_string(),
862 api,
863 provider: provider.to_string(),
864 base_url: String::new(),
865 reasoning: false,
866 input: vec![crate::types::InputModality::Text],
867 cost: crate::types::Cost::default(),
868 context_window: 128_000,
869 max_tokens: 32_000,
870 headers: HashMap::new(),
871 compat: None,
872 }
873 }
874
875 fn default_model_for_provider(&self, provider_name: &str) -> Model {
880 let default_model_id = match provider_name {
882 "openai" => "gpt-4o-mini",
883 "anthropic" => "claude-sonnet-4-20250514",
884 "google" => "gemini-2.0-flash",
885 _ => return self.construct_model_from_id(provider_name, "default"),
886 };
887
888 if let Some(entry) = crate::model_db::get_model_entry(provider_name, default_model_id) {
890 return self.model_from_entry(entry);
891 }
892
893 let provider_models = crate::model_db::get_provider_models(provider_name);
895 if !provider_models.is_empty() {
896 if let Some(entry) = provider_models.last() {
898 return self.model_from_entry(entry);
899 }
900 }
901
902 self.construct_model_from_id(provider_name, "default")
904 }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::Message;

    /// Provider stand-in shared by all tests. None of them stream, so
    /// `stream` must never be reached.
    struct MockProvider;

    #[async_trait]
    impl Provider for MockProvider {
        async fn stream(
            &self,
            _model: &Model,
            _context: &Context,
            _options: Option<StreamOptions>,
        ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
            unreachable!("Mock provider - not called in this test")
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// A default-configured multi-provider with one mock registered as `name`.
    fn provider_with_mock(name: &str) -> MultiProvider {
        let mut provider = MultiProvider::new(MultiProviderConfig::default());
        provider.register_provider(name, Arc::new(MockProvider));
        provider
    }

    /// A context holding a single, simple user request.
    fn create_test_context() -> Context {
        let mut ctx = Context::new();
        ctx.add_message(Message::User(crate::UserMessage::new(
            "Help me write a function to reverse a string".to_string(),
        )));
        ctx
    }

    #[test]
    fn test_config_defaults() {
        let config = MultiProviderConfig::default();
        assert!(config.auto_routing);
        assert!(config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 1);
        assert!(config.per_model_timeout.is_none());
    }

    #[test]
    fn test_config_builder() {
        let config = MultiProviderConfig::default()
            .with_auto_routing(false)
            .with_prefer_cost_efficient(false)
            .with_max_retries(3)
            .with_per_model_timeout(Duration::from_secs(30));

        assert!(!config.auto_routing);
        assert!(!config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 3);
        assert_eq!(config.per_model_timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_multi_provider_creation() {
        let config = MultiProviderConfig::default();
        let provider = MultiProvider::new(config);

        assert_eq!(provider.name(), "multi-provider");
        assert!(provider.provider_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let provider = provider_with_mock("test");

        assert_eq!(provider.provider_names(), vec!["test"]);
        assert!(provider.get_provider("test").is_some());
        assert!(provider.get_breaker("test").is_some());
    }

    #[test]
    fn test_unregister_provider() {
        let mut provider = provider_with_mock("test");

        assert!(provider.unregister_provider("test"));
        assert!(provider.provider_names().is_empty());
        assert!(provider.get_provider("test").is_none());
        // The breaker must go away together with the provider.
        assert!(provider.get_breaker("test").is_none());
        // A second removal has nothing left to remove.
        assert!(!provider.unregister_provider("test"));
    }

    #[test]
    fn test_with_router() {
        let router = DefaultRouter::new();
        let provider = MultiProvider::with_router(router);

        assert_eq!(provider.name(), "multi-provider");
    }

    #[test]
    fn test_with_fallback() {
        let fallback = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let provider = MultiProvider::new(MultiProviderConfig::default()).with_fallback(fallback);

        assert_eq!(provider.fallback().len(), 1);
    }

    #[test]
    fn test_circuit_breaker_diagnostics() {
        let provider = provider_with_mock("test");

        let diagnostics = provider.circuit_breaker_diagnostics();
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].provider, "test");
    }

    #[test]
    fn test_multi_provider_error_display() {
        let err = MultiProviderError::NoProviderForModel("gpt-4o".to_string());
        assert!(err.to_string().contains("gpt-4o"));

        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(err.to_string().contains("All providers exhausted"));

        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.to_string().contains("openai"));
        assert!(err.to_string().contains("10"));
    }

    #[test]
    fn test_multi_provider_error_helpers() {
        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.is_circuit_breaker());
        assert_eq!(err.retry_after(), Some(Duration::from_secs(10)));

        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(!err.is_circuit_breaker());
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn test_diagnostics() {
        let provider = provider_with_mock("test");

        let diag = provider.diagnostics();
        assert_eq!(diag.provider_count, 1);
        assert!(diag.auto_routing);
        assert!(diag.prefer_cost_efficient);
        assert_eq!(diag.circuit_breakers.len(), 1);
    }

    #[test]
    fn test_router_classification() {
        use crate::Complexity;
        let router = DefaultRouter::new();
        let provider = MultiProvider::with_router(router);

        let ctx = create_test_context();
        let complexity = provider.router().classify(&ctx);

        assert!(complexity >= Complexity::Simple);
    }
}