1use std::collections::HashMap;
55use std::sync::Arc;
56use std::time::Duration;
57
58use async_trait::async_trait;
59use futures::Stream;
60use std::pin::Pin;
61
62use crate::{
63 circuit_breaker::{CircuitBreakerConfig, ProviderCircuitBreaker},
64 complexity_router::{ComplexityRouter, DefaultRouter},
65 context::Context,
66 error::ProviderError,
67 fallback_chain::FallbackChain,
68 model_db::ModelEntry,
69 providers::{FallbackReason, Provider, ProviderEvent, StreamOptions},
70 Model,
71};
72
/// Tunable settings for [`MultiProvider`].
///
/// Start from [`Default::default`] and adjust individual settings through the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct MultiProviderConfig {
    /// When `true`, the router classifies the request context and the models
    /// it selects are placed ahead of the requested model in the candidate list.
    pub auto_routing: bool,

    /// Passed to the router's `route` call as its second argument.
    pub prefer_cost_efficient: bool,

    /// Number of extra attempts a single candidate gets after a retryable
    /// error before the next candidate is tried.
    pub max_retries_per_model: usize,

    /// Optional time limit per model.
    ///
    /// NOTE(review): none of the visible code reads this field — confirm
    /// where (or whether) it is enforced.
    pub per_model_timeout: Option<Duration>,

    /// Settings cloned into the circuit breaker created for each registered
    /// provider.
    pub circuit_breaker: CircuitBreakerConfig,
}
122
123impl Default for MultiProviderConfig {
124 fn default() -> Self {
125 Self {
126 auto_routing: true,
127 prefer_cost_efficient: true,
128 max_retries_per_model: 1,
129 per_model_timeout: None,
130 circuit_breaker: CircuitBreakerConfig::default(),
131 }
132 }
133}
134
135impl MultiProviderConfig {
136 #[must_use]
138 pub fn with_auto_routing(mut self, enabled: bool) -> Self {
139 self.auto_routing = enabled;
140 self
141 }
142
143 #[must_use]
145 pub fn with_prefer_cost_efficient(mut self, enabled: bool) -> Self {
146 self.prefer_cost_efficient = enabled;
147 self
148 }
149
150 #[must_use]
152 pub fn with_max_retries(mut self, retries: usize) -> Self {
153 self.max_retries_per_model = retries;
154 self
155 }
156
157 #[must_use]
159 pub fn with_per_model_timeout(mut self, timeout: Duration) -> Self {
160 self.per_model_timeout = Some(timeout);
161 self
162 }
163
164 #[must_use]
166 pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
167 self.circuit_breaker = config;
168 self
169 }
170}
171
/// Error conditions specific to multi-provider routing.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum MultiProviderError {
    /// Every candidate was tried and none succeeded.
    ///
    /// Note that the collected `errors` are not part of the `Display` message.
    #[error("All providers exhausted")]
    AllProvidersExhausted {
        /// Pairs of a label and the error recorded for it.
        /// NOTE(review): presumably the label is `provider/model`, matching
        /// the labels built in `stream` — confirm at the construction site.
        errors: Vec<(String, ProviderError)>,
    },

    /// No provider is available for the model named by the payload.
    #[error("No provider available for model: {0}")]
    NoProviderForModel(String),

    /// The named provider's circuit breaker is open.
    #[error("Circuit breaker open: {provider} (retry after {retry_after:?})")]
    CircuitBreakerOpen {
        /// Name of the provider whose breaker is open.
        provider: String,
        /// Delay reported by [`MultiProviderError::retry_after`].
        retry_after: Duration,
    },

    /// The primary provider failed and no fallback models are configured.
    #[error("No fallback models configured and primary provider failed")]
    NoFallback,

    /// No provider has been registered.
    #[error("No provider registered")]
    NoProviderRegistered,
}
211
212impl MultiProviderError {
213 pub fn is_circuit_breaker(&self) -> bool {
215 matches!(self, Self::CircuitBreakerOpen { .. })
216 }
217
218 pub fn retry_after(&self) -> Option<Duration> {
220 match self {
221 Self::CircuitBreakerOpen { retry_after, .. } => Some(*retry_after),
222 _ => None,
223 }
224 }
225}
226
227pub struct MultiProvider {
242 router: Arc<dyn ComplexityRouter>,
244
245 providers: HashMap<String, Arc<dyn Provider>>,
247
248 fallback: FallbackChain,
250
251 breakers: HashMap<String, Arc<ProviderCircuitBreaker>>,
253
254 config: MultiProviderConfig,
256}
257
258impl MultiProvider {
259 pub fn new(config: MultiProviderConfig) -> Self {
270 Self {
271 router: Arc::new(DefaultRouter::new()),
272 providers: HashMap::new(),
273 fallback: FallbackChain::default(),
274 breakers: HashMap::new(),
275 config,
276 }
277 }
278
279 pub fn with_router(router: impl ComplexityRouter + 'static) -> Self {
290 Self {
291 router: Arc::new(router),
292 providers: HashMap::new(),
293 fallback: FallbackChain::default(),
294 breakers: HashMap::new(),
295 config: MultiProviderConfig::default(),
296 }
297 }
298
299 pub fn with_config_and_router(
311 config: MultiProviderConfig,
312 router: impl ComplexityRouter + 'static,
313 ) -> Self {
314 Self {
315 router: Arc::new(router),
316 providers: HashMap::new(),
317 fallback: FallbackChain::default(),
318 breakers: HashMap::new(),
319 config,
320 }
321 }
322
323 pub fn set_router(mut self, router: impl ComplexityRouter + 'static) -> Self {
331 self.router = Arc::new(router);
332 self
333 }
334
335 pub fn with_fallback(mut self, fallback: FallbackChain) -> Self {
346 self.fallback = fallback;
347 self
348 }
349
350 pub fn set_fallback(&mut self, fallback: FallbackChain) {
352 self.fallback = fallback;
353 }
354
355 pub fn register_provider(&mut self, name: &str, provider: Arc<dyn Provider>) {
372 let breaker = Arc::new(ProviderCircuitBreaker::new(
374 name.to_string(),
375 self.config.circuit_breaker.clone(),
376 ));
377
378 self.providers.insert(name.to_string(), provider);
379 self.breakers.insert(name.to_string(), breaker);
380 }
381
382 pub fn unregister_provider(&mut self, name: &str) -> bool {
394 let provider_removed = self.providers.remove(name).is_some();
395 let breaker_removed = self.breakers.remove(name).is_some();
396 provider_removed || breaker_removed
397 }
398
399 pub fn get_provider(&self, name: &str) -> Option<&Arc<dyn Provider>> {
409 self.providers.get(name)
410 }
411
412 pub fn get_breaker(&self, provider_name: &str) -> Option<Arc<ProviderCircuitBreaker>> {
422 self.breakers.get(provider_name).cloned()
423 }
424
425 pub fn provider_names(&self) -> Vec<&str> {
427 self.providers.keys().map(|s| s.as_str()).collect()
428 }
429
430 pub fn circuit_breaker_diagnostics(
436 &self,
437 ) -> Vec<crate::circuit_breaker::CircuitBreakerDiagnostics> {
438 self.breakers.values().map(|b| b.diagnostics()).collect()
439 }
440
441 pub fn router(&self) -> &Arc<dyn ComplexityRouter> {
443 &self.router
444 }
445
446 pub fn fallback(&self) -> &FallbackChain {
448 &self.fallback
449 }
450
451 pub fn config(&self) -> &MultiProviderConfig {
453 &self.config
454 }
455
456 pub fn diagnostics(&self) -> MultiProviderDiagnostics {
458 MultiProviderDiagnostics {
459 provider_count: self.providers.len(),
460 router_type: "DefaultRouter".to_string(),
461 fallback_len: self.fallback.len(),
462 auto_routing: self.config.auto_routing,
463 prefer_cost_efficient: self.config.prefer_cost_efficient,
464 circuit_breakers: self.circuit_breaker_diagnostics(),
465 }
466 }
467}
468
/// Snapshot returned by `MultiProvider::diagnostics`.
#[derive(Debug, Clone)]
pub struct MultiProviderDiagnostics {
    /// Number of registered providers.
    pub provider_count: usize,
    /// Name of the router implementation as reported by `diagnostics`.
    pub router_type: String,
    /// Length reported by the configured fallback chain.
    pub fallback_len: usize,
    /// Copy of the `auto_routing` configuration flag.
    pub auto_routing: bool,
    /// Copy of the `prefer_cost_efficient` configuration flag.
    pub prefer_cost_efficient: bool,
    /// One diagnostics entry per provider circuit breaker.
    pub circuit_breakers: Vec<crate::circuit_breaker::CircuitBreakerDiagnostics>,
}
489
490use futures::stream::Stream as StreamTrait;
495
/// Stream adapter that yields a single fallback notification event before
/// forwarding everything produced by the wrapped stream.
struct FallbackStream {
    /// Event yielded on the first poll.
    fallback_event: ProviderEvent,
    /// Set once `fallback_event` has been yielded.
    emitted: bool,
    /// Stream whose events are forwarded after the notification.
    inner: Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>,
}
506
507impl FallbackStream {
508 fn new(
510 from_model: String,
511 to_model: String,
512 reason: FallbackReason,
513 inner: Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>,
514 ) -> Self {
515 Self {
516 fallback_event: ProviderEvent::FallbackStart {
517 from_model,
518 to_model,
519 reason,
520 },
521 emitted: false,
522 inner,
523 }
524 }
525}
526
527impl StreamTrait for FallbackStream {
528 type Item = ProviderEvent;
529
530 fn poll_next(
531 mut self: std::pin::Pin<&mut Self>,
532 cx: &mut std::task::Context<'_>,
533 ) -> std::task::Poll<Option<Self::Item>> {
534 if !self.emitted {
536 self.emitted = true;
537 return std::task::Poll::Ready(Some(self.fallback_event.clone()));
538 }
539
540 Stream::poll_next(self.inner.as_mut(), cx)
542 }
543}
544
/// Stream that yields a single "fallback exhausted" event and then, if an
/// inner stream was attached, forwards that stream's events.
struct FallbackExhaustedStream {
    /// Event yielded on the first poll.
    exhausted_event: ProviderEvent,
    /// Set once `exhausted_event` has been yielded.
    emitted: bool,
    /// Optional stream polled after the event; when absent the stream ends.
    inner: Option<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>>,
}
555
556impl FallbackExhaustedStream {
557 fn new(models_tried: Vec<String>, final_error: String) -> Self {
559 Self {
560 exhausted_event: ProviderEvent::FallbackExhausted {
561 models_tried,
562 final_error,
563 },
564 emitted: false,
565 inner: None,
566 }
567 }
568
569 #[allow(dead_code)]
571 fn with_inner(mut self, inner: Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>) -> Self {
572 self.inner = Some(inner);
573 self
574 }
575}
576
577impl StreamTrait for FallbackExhaustedStream {
578 type Item = ProviderEvent;
579
580 fn poll_next(
581 mut self: std::pin::Pin<&mut Self>,
582 cx: &mut std::task::Context<'_>,
583 ) -> std::task::Poll<Option<Self::Item>> {
584 if !self.emitted {
586 self.emitted = true;
587 return std::task::Poll::Ready(Some(self.exhausted_event.clone()));
588 }
589
590 if let Some(ref mut inner) = self.inner {
592 Stream::poll_next(inner.as_mut(), cx)
593 } else {
594 std::task::Poll::Ready(None)
595 }
596 }
597}
598
599fn error_to_fallback_reason(error: &ProviderError) -> FallbackReason {
601 match error {
602 ProviderError::HttpError(429, _) => FallbackReason::RateLimit,
603 ProviderError::HttpError(code, _) if *code >= 500 => FallbackReason::ServerError,
604 ProviderError::HttpError(code, _) if *code == 401 || *code == 403 => {
605 FallbackReason::AuthError
606 }
607 ProviderError::RequestFailed(_) => FallbackReason::NetworkError,
608 ProviderError::Timeout => FallbackReason::NetworkError,
609 ProviderError::ContextOverflow => FallbackReason::ContextOverflow,
610 _ => FallbackReason::Unknown,
611 }
612}
613
614#[async_trait]
619impl Provider for MultiProvider {
620 async fn stream(
635 &self,
636 model: &Model,
637 context: &Context,
638 options: Option<StreamOptions>,
639 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
640 let candidates = self.build_candidate_list(model, context).await?;
642
643 let mut errors: Vec<(String, ProviderError)> = Vec::new();
645 let mut current_candidate_idx: usize = 0;
646
647 while current_candidate_idx < candidates.len() {
648 let candidate = &candidates[current_candidate_idx];
649 let provider_name = &candidate.provider;
650 let candidate_model = candidate.model.clone();
651
652 let Some(provider) = self.providers.get(provider_name) else {
654 current_candidate_idx += 1;
655 continue;
656 };
657
658 if let Some(breaker) = self.breakers.get(provider_name) {
660 match breaker.allow_request() {
661 Ok(()) => {
662 }
664 Err(e) => {
665 tracing::debug!(
667 provider = %provider_name,
668 remaining = ?e.remaining,
669 "Circuit breaker open, skipping provider"
670 );
671 current_candidate_idx += 1;
672 continue;
673 }
674 }
675 }
676
677 let mut retry_count = 0;
679 let max_retries = self.config.max_retries_per_model;
680
681 loop {
682 match provider
683 .stream(&candidate_model, context, options.clone())
684 .await
685 {
686 Ok(inner_stream) => {
687 if let Some(breaker) = self.breakers.get(provider_name) {
689 breaker.record_success();
690 }
691 tracing::debug!(
692 provider = %provider_name,
693 model = %candidate_model.id,
694 "MultiProvider: stream successful"
695 );
696
697 if current_candidate_idx > 0 {
699 let from_model = format!(
700 "{}/{}",
701 candidates[current_candidate_idx - 1].provider,
702 candidates[current_candidate_idx - 1].model.id
703 );
704 let to_model = format!("{}/{}", provider_name, candidate_model.id);
705 let reason = errors
706 .last()
707 .map(|(_, e)| error_to_fallback_reason(e))
708 .unwrap_or(FallbackReason::Unknown);
709
710 let wrapped =
711 FallbackStream::new(from_model, to_model, reason, inner_stream);
712 return Ok(Box::pin(wrapped) as Pin<Box<_>>);
713 }
714
715 return Ok(inner_stream);
716 }
717 Err(e) => {
718 if e.is_retryable() && retry_count < max_retries {
720 retry_count += 1;
722 if let Some(breaker) = self.breakers.get(provider_name) {
723 breaker.record_failure();
724 }
725 tracing::debug!(
726 provider = %provider_name,
727 model = %candidate_model.id,
728 error = %e,
729 retry = retry_count,
730 "Retryable error, retrying"
731 );
732 continue;
733 }
734
735 if !e.is_retryable() {
737 tracing::warn!(
740 provider = %provider_name,
741 model = %candidate_model.id,
742 error = %e,
743 "Non-retryable error, returning immediately"
744 );
745 return Err(e);
746 }
747
748 tracing::debug!(
750 provider = %provider_name,
751 model = %candidate_model.id,
752 error = %e,
753 retries = retry_count,
754 "Max retries exceeded, trying next candidate"
755 );
756 errors.push((format!("{}/{}", provider_name, candidate_model.id), e));
757 break;
758 }
759 }
760 }
761
762 current_candidate_idx += 1;
763 }
764
765 if errors.is_empty() {
767 if self.providers.is_empty() {
768 Err(ProviderError::UnknownProvider(
769 "multi-provider: no providers registered".to_string(),
770 ))
771 } else {
772 Err(ProviderError::UnknownProvider(
773 "multi-provider: no model could be routed".to_string(),
774 ))
775 }
776 } else {
777 let models_tried: Vec<String> = errors.iter().map(|(m, _)| m.clone()).collect();
779 let final_error = errors
780 .last()
781 .map(|(_, e)| e.to_string())
782 .unwrap_or_else(|| "Unknown error".to_string());
783
784 tracing::warn!(
785 models_tried = ?models_tried,
786 error = %final_error,
787 "All fallback models exhausted"
788 );
789
790 let stream = FallbackExhaustedStream::new(models_tried, final_error);
791 Ok(Box::pin(stream) as Pin<Box<_>>)
792 }
793 }
794
795 fn name(&self) -> &str {
797 "multi-provider"
798 }
799}
800
/// One entry of the ordered list of provider/model pairs that `stream` tries.
struct Candidate {
    /// Name used to look up the provider and its circuit breaker.
    provider: String,
    /// Model handed to that provider's `stream` call.
    model: Model,
}
812
813impl MultiProvider {
814 async fn build_candidate_list(
820 &self,
821 incoming_model: &Model,
822 context: &Context,
823 ) -> Result<Vec<Candidate>, ProviderError> {
824 let mut candidates: Vec<Candidate> = Vec::new();
825 let mut seen_ids: HashMap<String, ()> = HashMap::new();
826
827 let add_candidate = |candidates: &mut Vec<Candidate>,
829 seen_ids: &mut HashMap<String, ()>,
830 provider: String,
831 model: Model| {
832 let id = format!("{}/{}", provider, model.id);
833 if seen_ids.insert(id, ()).is_none() {
834 candidates.push(Candidate { provider, model });
835 }
836 };
837
838 if self.config.auto_routing {
840 let complexity = self.router.classify(context);
841 let router_models = self
842 .router
843 .route(complexity, self.config.prefer_cost_efficient);
844
845 tracing::debug!(
846 complexity = ?complexity,
847 model_count = router_models.len(),
848 "MultiProvider: router selected models for complexity"
849 );
850
851 for entry in router_models {
852 if let Some(registered_model) =
854 crate::model_registry::get_model(entry.provider, entry.id)
855 {
856 if self.providers.contains_key(entry.provider) {
857 add_candidate(
858 &mut candidates,
859 &mut seen_ids,
860 entry.provider.to_string(),
861 registered_model.clone(),
862 );
863 }
864 }
865
866 if self.providers.contains_key(entry.provider) {
868 let model = self.model_from_entry(entry);
869 let id = format!("{}/{}", entry.provider, entry.id);
870 if seen_ids.insert(id, ()).is_none() {
871 candidates.push(Candidate {
872 provider: entry.provider.to_string(),
873 model,
874 });
875 }
876 }
877 }
878 }
879
880 if self.providers.contains_key(&incoming_model.provider) {
882 add_candidate(
883 &mut candidates,
884 &mut seen_ids,
885 incoming_model.provider.clone(),
886 incoming_model.clone(),
887 );
888 } else {
889 for provider_name in self.providers.keys() {
892 let model_id = &incoming_model.id;
894
895 if let Some(model) = self.find_model_for_provider(provider_name, model_id) {
897 add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
898 break;
899 }
900 }
901 }
902
903 for fallback_entry in self.fallback.iter() {
905 if let Some(registered_model) =
907 crate::model_registry::get_model(fallback_entry.provider, fallback_entry.id)
908 {
909 if self.providers.contains_key(fallback_entry.provider) {
910 add_candidate(
911 &mut candidates,
912 &mut seen_ids,
913 fallback_entry.provider.to_string(),
914 registered_model.clone(),
915 );
916 }
917 } else if self.providers.contains_key(fallback_entry.provider) {
918 let model = self.model_from_entry(fallback_entry);
920 let id = format!("{}/{}", fallback_entry.provider, fallback_entry.id);
921 if seen_ids.insert(id, ()).is_none() {
922 candidates.push(Candidate {
923 provider: fallback_entry.provider.to_string(),
924 model,
925 });
926 }
927 }
928 }
929
930 if candidates.is_empty() && !self.providers.is_empty() {
932 let (provider_name, _provider) = self
934 .providers
935 .iter()
936 .next()
937 .expect("providers map is non-empty");
938 let model = self.default_model_for_provider(provider_name);
939 add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
940 }
941
942 tracing::debug!(
943 candidate_count = candidates.len(),
944 "MultiProvider: built candidate list"
945 );
946
947 if candidates.is_empty() && self.providers.is_empty() {
948 return Err(ProviderError::UnknownProvider(
949 "multi-provider: no providers registered".to_string(),
950 ));
951 }
952
953 Ok(candidates)
954 }
955
956 fn model_from_entry(&self, entry: &ModelEntry) -> Model {
958 Model {
959 id: entry.id.to_string(),
960 name: entry.name.to_string(),
961 api: entry.api,
962 provider: entry.provider.to_string(),
963 base_url: String::new(), reasoning: entry.reasoning,
965 input: entry.input.to_vec(),
966 cost: crate::types::Cost {
967 input: entry.cost_input,
968 output: entry.cost_output,
969 cache_read: entry.cost_cache_read,
970 cache_write: entry.cost_cache_write,
971 },
972 context_window: entry.context_window as usize,
973 max_tokens: entry.max_tokens as usize,
974 headers: HashMap::new(),
975 compat: None,
976 }
977 }
978
979 fn find_model_for_provider(&self, provider_name: &str, model_id: &str) -> Option<Model> {
981 if let Some(model) = crate::model_registry::get_model(provider_name, model_id) {
983 return Some(model.clone());
984 }
985
986 if let Some(entry) = crate::model_db::get_model_entry(provider_name, model_id) {
988 return Some(self.model_from_entry(entry));
989 }
990
991 Some(self.construct_model_from_id(provider_name, model_id))
993 }
994
995 fn construct_model_from_id(&self, provider: &str, model_id: &str) -> Model {
1000 if let Some(entry) = crate::model_db::get_model_entry(provider, model_id) {
1002 return self.model_from_entry(entry);
1003 }
1004
1005 let api = match provider {
1007 "openai" | "openai-codex" | "opencode" | "opencode-go" => {
1008 crate::types::Api::OpenAiResponses
1009 }
1010 "anthropic" | "cloudflare-ai-gateway" => crate::types::Api::AnthropicMessages,
1011 "google" => crate::types::Api::GoogleGenerativeAi,
1012 "google-vertex" => crate::types::Api::GoogleVertex,
1013 "azure-openai" | "azure-openai-responses" => crate::types::Api::AzureOpenAiResponses,
1014 "amazon-bedrock" | "bedrock" => crate::types::Api::BedrockConverseStream,
1015 _ => crate::types::Api::OpenAiResponses,
1016 };
1017
1018 Model {
1019 id: model_id.to_string(),
1020 name: model_id.to_string(),
1021 api,
1022 provider: provider.to_string(),
1023 base_url: String::new(),
1024 reasoning: false,
1025 input: vec![crate::types::InputModality::Text],
1026 cost: crate::types::Cost::default(),
1027 context_window: 128_000,
1028 max_tokens: 32_000,
1029 headers: HashMap::new(),
1030 compat: None,
1031 }
1032 }
1033
1034 fn default_model_for_provider(&self, provider_name: &str) -> Model {
1039 let default_model_id = match provider_name {
1041 "openai" => "gpt-4o-mini",
1042 "anthropic" => "claude-sonnet-4-20250514",
1043 "google" => "gemini-2.0-flash",
1044 _ => return self.construct_model_from_id(provider_name, "default"),
1045 };
1046
1047 if let Some(entry) = crate::model_db::get_model_entry(provider_name, default_model_id) {
1049 return self.model_from_entry(entry);
1050 }
1051
1052 let provider_models = crate::model_db::get_provider_models(provider_name);
1054 if !provider_models.is_empty() {
1055 if let Some(entry) = provider_models.last() {
1057 return self.model_from_entry(entry);
1058 }
1059 }
1060
1061 self.construct_model_from_id(provider_name, "default")
1063 }
1064}
1065
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::Message;

    /// Provider stub shared by all tests; its `stream` must never be called.
    struct MockProvider;

    #[async_trait]
    impl Provider for MockProvider {
        async fn stream(
            &self,
            _model: &Model,
            _context: &Context,
            _options: Option<StreamOptions>,
        ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
            unreachable!("Mock provider - not called in these tests")
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// A default-configured `MultiProvider` with one mock registered as `name`.
    fn provider_with_mock(name: &str) -> MultiProvider {
        let mut provider = MultiProvider::new(MultiProviderConfig::default());
        provider.register_provider(name, Arc::new(MockProvider));
        provider
    }

    fn create_test_context() -> Context {
        let mut ctx = Context::new();
        ctx.add_message(Message::User(crate::UserMessage::new(
            "Help me write a function to reverse a string".to_string(),
        )));
        ctx
    }

    #[test]
    fn test_config_defaults() {
        let config = MultiProviderConfig::default();
        assert!(config.auto_routing);
        assert!(config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 1);
        assert!(config.per_model_timeout.is_none());
    }

    #[test]
    fn test_config_builder() {
        let config = MultiProviderConfig::default()
            .with_auto_routing(false)
            .with_prefer_cost_efficient(false)
            .with_max_retries(3)
            .with_per_model_timeout(Duration::from_secs(30));

        assert!(!config.auto_routing);
        assert!(!config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 3);
        assert_eq!(config.per_model_timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_multi_provider_creation() {
        let config = MultiProviderConfig::default();
        let provider = MultiProvider::new(config);

        assert_eq!(provider.name(), "multi-provider");
        assert!(provider.provider_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let provider = provider_with_mock("test");

        assert_eq!(provider.provider_names(), vec!["test"]);
        assert!(provider.get_provider("test").is_some());
        assert!(provider.get_breaker("test").is_some());
    }

    #[test]
    fn test_unregister_provider() {
        let mut provider = provider_with_mock("test");

        assert!(provider.unregister_provider("test"));
        assert!(provider.provider_names().is_empty());
        assert!(provider.get_provider("test").is_none());
        assert!(provider.get_breaker("test").is_none());

        // Nothing is left to remove the second time.
        assert!(!provider.unregister_provider("test"));
    }

    #[test]
    fn test_with_router() {
        let router = DefaultRouter::new();
        let provider = MultiProvider::with_router(router);

        assert_eq!(provider.name(), "multi-provider");
    }

    #[test]
    fn test_with_fallback() {
        let fallback = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let provider = MultiProvider::new(MultiProviderConfig::default()).with_fallback(fallback);

        assert_eq!(provider.fallback().len(), 1);
    }

    #[test]
    fn test_circuit_breaker_diagnostics() {
        let provider = provider_with_mock("test");

        let diagnostics = provider.circuit_breaker_diagnostics();
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].provider, "test");
    }

    #[test]
    fn test_multi_provider_error_display() {
        let err = MultiProviderError::NoProviderForModel("gpt-4o".to_string());
        assert!(err.to_string().contains("gpt-4o"));

        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(err.to_string().contains("All providers exhausted"));

        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.to_string().contains("openai"));
        assert!(err.to_string().contains("10"));
    }

    #[test]
    fn test_multi_provider_error_helpers() {
        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.is_circuit_breaker());
        assert_eq!(err.retry_after(), Some(Duration::from_secs(10)));

        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(!err.is_circuit_breaker());
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn test_diagnostics() {
        let provider = provider_with_mock("test");

        let diag = provider.diagnostics();
        assert_eq!(diag.provider_count, 1);
        assert!(diag.auto_routing);
        assert!(diag.prefer_cost_efficient);
        assert_eq!(diag.circuit_breakers.len(), 1);
    }

    #[test]
    fn test_router_classification() {
        use crate::Complexity;
        let router = DefaultRouter::new();
        let provider = MultiProvider::with_router(router);

        let ctx = create_test_context();
        let complexity = provider.router().classify(&ctx);

        assert!(complexity >= Complexity::Simple);
    }
}