Skip to main content

oxi_ai/
multi_provider.rs

//! MultiProvider — intelligent routing with fallback
//!
//! This module provides the core routing provider that ties together
//! ComplexityRouter, CircuitBreaker, and FallbackChain to implement
//! intelligent model selection with automatic failover.
//!
//! # Architecture
//!
//! MultiProvider orchestrates multiple components:
//! - **ComplexityRouter**: Classifies task complexity and selects appropriate models
//! - **CircuitBreaker**: Prevents cascading failures by tracking provider health
//! - **FallbackChain**: Provides ordered fallback when primary models fail
//!
//! # Priority Order (from design §8.3)
//!
//! When `auto_routing=true`:
//! 1. Router's best model (based on classified complexity)
//! 2. Incoming model (if registered and healthy)
//! 3. Fallback chain (if configured)
//!
//! When `auto_routing=false`:
//! 1. Incoming model (if registered and healthy)
//! 2. Fallback chain (if configured)
//!
//! # Error Handling
//!
//! - **Retryable errors** (429, 5xx, network, timeout): Record a failure on the provider's
//!   circuit breaker, retry the same model up to `max_retries_per_model` times, then try
//!   the next model
//! - **Non-retryable errors** (400, 401, 403, etc.): Return immediately without recording failure
//!
//! # Example
//!
//! ```ignore
//! use std::sync::Arc;
//! use oxi_ai::multi_provider::{MultiProvider, MultiProviderConfig};
//! use oxi_ai::fallback_chain::FallbackChain;
//!
//! let config = MultiProviderConfig::default();
//! let mut provider = MultiProvider::new(config);
//!
//! // Register providers
//! provider.register_provider("openai", Arc::new(openai_provider));
//! provider.register_provider("anthropic", Arc::new(anthropic_provider));
//!
//! // Set fallback chain (in place; `with_fallback` is the consuming builder variant)
//! let fallback = FallbackChain::from_ids(&[
//!     "anthropic/claude-sonnet-4-20250514",
//!     "openai/gpt-4o",
//! ])?;
//! provider.set_fallback(fallback);
//!
//! // Use like any Provider
//! let stream = provider.stream(&model, &context, None).await?;
//! ```
53
54use std::collections::HashMap;
55use std::sync::Arc;
56use std::time::Duration;
57
58use async_trait::async_trait;
59use futures::Stream;
60use std::pin::Pin;
61
62use crate::{
63    circuit_breaker::{CircuitBreakerConfig, ProviderCircuitBreaker},
64    complexity_router::{ComplexityRouter, DefaultRouter},
65    context::Context,
66    error::ProviderError,
67    fallback_chain::FallbackChain,
68    model_db::ModelEntry,
69    providers::{FallbackReason, Provider, ProviderEvent, StreamOptions},
70    Model,
71};
72
73// ============================================================================
74// Configuration
75// ============================================================================
76
77/// Configuration for MultiProvider behavior.
78///
79/// Controls auto-routing, cost preference, retry behavior, and circuit breaker settings.
80#[derive(Debug, Clone)]
81pub struct MultiProviderConfig {
82    /// Enable automatic complexity-based routing.
83    ///
84    /// When `true`, the router classifies task complexity and selects
85    /// appropriate models before falling back to the incoming model.
86    ///
87    /// Default: `true`
88    pub auto_routing: bool,
89
90    /// Prefer cost-efficient models when routing.
91    ///
92    /// When `true` and `auto_routing` is enabled, the router selects
93    /// cheaper models that still meet the complexity requirements.
94    ///
95    /// Default: `true`
96    pub prefer_cost_efficient: bool,
97
98    /// Maximum retries per model before giving up.
99    ///
100    /// For each model in the candidate list, we retry this many times
101    /// on retryable errors before moving to the next model.
102    ///
103    /// Default: `1`
104    pub max_retries_per_model: usize,
105
106    /// Per-model timeout for requests.
107    ///
108    /// If set, wraps the request in a timeout. If `None`, uses the
109    /// provider's default timeout.
110    ///
111    /// Default: `None`
112    pub per_model_timeout: Option<Duration>,
113
114    /// Circuit breaker configuration for all providers.
115    ///
116    /// Each registered provider gets its own circuit breaker instance
117    /// with this configuration.
118    ///
119    /// Default: `CircuitBreakerConfig::default()`
120    pub circuit_breaker: CircuitBreakerConfig,
121}
122
123impl Default for MultiProviderConfig {
124    fn default() -> Self {
125        Self {
126            auto_routing: true,
127            prefer_cost_efficient: true,
128            max_retries_per_model: 1,
129            per_model_timeout: None,
130            circuit_breaker: CircuitBreakerConfig::default(),
131        }
132    }
133}
134
135impl MultiProviderConfig {
136    /// Enable or disable automatic routing.
137    #[must_use]
138    pub fn with_auto_routing(mut self, enabled: bool) -> Self {
139        self.auto_routing = enabled;
140        self
141    }
142
143    /// Enable or disable cost-efficient preference.
144    #[must_use]
145    pub fn with_prefer_cost_efficient(mut self, enabled: bool) -> Self {
146        self.prefer_cost_efficient = enabled;
147        self
148    }
149
150    /// Set the maximum retries per model.
151    #[must_use]
152    pub fn with_max_retries(mut self, retries: usize) -> Self {
153        self.max_retries_per_model = retries;
154        self
155    }
156
157    /// Set the per-model timeout.
158    #[must_use]
159    pub fn with_per_model_timeout(mut self, timeout: Duration) -> Self {
160        self.per_model_timeout = Some(timeout);
161        self
162    }
163
164    /// Set the circuit breaker configuration.
165    #[must_use]
166    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
167        self.circuit_breaker = config;
168        self
169    }
170}
171
172// ============================================================================
173// Error Types
174// ============================================================================
175
176/// Errors that can occur in MultiProvider operations.
177#[derive(Debug, thiserror::Error)]
178pub enum MultiProviderError {
179    /// All providers in the candidate list have failed.
180    ///
181    /// Contains the list of errors from each provider for debugging.
182    #[error("All providers exhausted")]
183    AllProvidersExhausted {
184        /// Errors from each provider in order of attempt.
185        errors: Vec<(String, ProviderError)>,
186    },
187
188    /// No provider is registered that can handle the requested model.
189    #[error("No provider available for model: {0}")]
190    NoProviderForModel(String),
191
192    /// Circuit breaker is open for the provider.
193    ///
194    /// The provider should be retried after `retry_after` duration.
195    #[error("Circuit breaker open: {provider} (retry after {retry_after:?})")]
196    CircuitBreakerOpen {
197        /// Name of the provider whose circuit is open.
198        provider: String,
199        /// Duration to wait before retrying.
200        retry_after: Duration,
201    },
202
203    /// No fallback models configured and the primary provider failed.
204    #[error("No fallback models configured and primary provider failed")]
205    NoFallback,
206
207    /// No providers are registered with this MultiProvider.
208    #[error("No provider registered")]
209    NoProviderRegistered,
210}
211
212impl MultiProviderError {
213    /// Check if this is a circuit breaker error.
214    pub fn is_circuit_breaker(&self) -> bool {
215        matches!(self, Self::CircuitBreakerOpen { .. })
216    }
217
218    /// Get the retry duration if this is a circuit breaker error.
219    pub fn retry_after(&self) -> Option<Duration> {
220        match self {
221            Self::CircuitBreakerOpen { retry_after, .. } => Some(*retry_after),
222            _ => None,
223        }
224    }
225}
226
227// ============================================================================
228// MultiProvider
229// ============================================================================
230
231/// Intelligent routing provider with fallback support.
232///
233/// MultiProvider implements the `Provider` trait and provides automatic
234/// model selection based on task complexity, with circuit breaker protection
235/// and ordered fallback for resilience.
236///
237/// # Type Parameters
238///
239/// - `R`: The complexity router type (default: `DefaultRouter`)
240/// - `F`: The fallback chain type (default: `FallbackChain`)
241pub struct MultiProvider {
242    /// Router for complexity-based model selection.
243    router: Arc<dyn ComplexityRouter>,
244
245    /// Registered providers by name.
246    providers: HashMap<String, Arc<dyn Provider>>,
247
248    /// Fallback chain for ordered failover.
249    fallback: FallbackChain,
250
251    /// Circuit breakers for each provider.
252    breakers: HashMap<String, Arc<ProviderCircuitBreaker>>,
253
254    /// Configuration settings.
255    config: MultiProviderConfig,
256}
257
258impl MultiProvider {
259    /// Create a new MultiProvider with the given configuration.
260    ///
261    /// Uses `DefaultRouter` for complexity-based routing.
262    ///
263    /// # Example
264    ///
265    /// ```ignore
266    /// let config = MultiProviderConfig::default();
267    /// let provider = MultiProvider::new(config);
268    /// ```
269    pub fn new(config: MultiProviderConfig) -> Self {
270        Self {
271            router: Arc::new(DefaultRouter::new()),
272            providers: HashMap::new(),
273            fallback: FallbackChain::default(),
274            breakers: HashMap::new(),
275            config,
276        }
277    }
278
279    /// Create a new MultiProvider with a custom router.
280    ///
281    /// Allows using a custom implementation of `ComplexityRouter`.
282    ///
283    /// # Example
284    ///
285    /// ```ignore
286    /// let router = MyCustomRouter::new();
287    /// let provider = MultiProvider::with_router(router);
288    /// ```
289    pub fn with_router(router: impl ComplexityRouter + 'static) -> Self {
290        Self {
291            router: Arc::new(router),
292            providers: HashMap::new(),
293            fallback: FallbackChain::default(),
294            breakers: HashMap::new(),
295            config: MultiProviderConfig::default(),
296        }
297    }
298
299    /// Create a new MultiProvider with custom config and router.
300    ///
301    /// # Example
302    ///
303    /// ```ignore
304    /// let config = MultiProviderConfig::default()
305    ///     .with_auto_routing(false)
306    ///     .with_max_retries(2);
307    /// let router = DefaultRouter::new();
308    /// let provider = MultiProvider::with_config_and_router(config, router);
309    /// ```
310    pub fn with_config_and_router(
311        config: MultiProviderConfig,
312        router: impl ComplexityRouter + 'static,
313    ) -> Self {
314        Self {
315            router: Arc::new(router),
316            providers: HashMap::new(),
317            fallback: FallbackChain::default(),
318            breakers: HashMap::new(),
319            config,
320        }
321    }
322
323    /// Replace the router with a new implementation.
324    ///
325    /// # Example
326    ///
327    /// ```ignore
328    /// let provider = multi_provider.set_router(new_router);
329    /// ```
330    pub fn set_router(mut self, router: impl ComplexityRouter + 'static) -> Self {
331        self.router = Arc::new(router);
332        self
333    }
334
335    /// Set the fallback chain.
336    ///
337    /// The fallback chain is used when the primary model fails or is unavailable.
338    ///
339    /// # Example
340    ///
341    /// ```ignore
342    /// let fallback = FallbackChain::from_ids(&["anthropic/claude-sonnet-4"])?;
343    /// let provider = multi_provider.with_fallback(fallback);
344    /// ```
345    pub fn with_fallback(mut self, fallback: FallbackChain) -> Self {
346        self.fallback = fallback;
347        self
348    }
349
350    /// Set the fallback chain by reference.
351    pub fn set_fallback(&mut self, fallback: FallbackChain) {
352        self.fallback = fallback;
353    }
354
355    /// Register a provider with this MultiProvider.
356    ///
357    /// The provider can be referenced by name when calling `stream()`.
358    /// Each provider gets its own circuit breaker instance.
359    ///
360    /// # Arguments
361    ///
362    /// * `name` - Provider identifier (e.g., "openai", "anthropic")
363    /// * `provider` - The provider implementation
364    ///
365    /// # Example
366    ///
367    /// ```ignore
368    /// let openai_provider = Arc::new(OpenAiProvider::new());
369    /// multi_provider.register_provider("openai", openai_provider);
370    /// ```
371    pub fn register_provider(&mut self, name: &str, provider: Arc<dyn Provider>) {
372        // Create circuit breaker for this provider
373        let breaker = Arc::new(ProviderCircuitBreaker::new(
374            name.to_string(),
375            self.config.circuit_breaker.clone(),
376        ));
377
378        self.providers.insert(name.to_string(), provider);
379        self.breakers.insert(name.to_string(), breaker);
380    }
381
382    /// Unregister a provider.
383    ///
384    /// Removes the provider and its associated circuit breaker.
385    ///
386    /// # Arguments
387    ///
388    /// * `name` - Provider identifier to remove
389    ///
390    /// # Returns
391    ///
392    /// `true` if the provider was found and removed.
393    pub fn unregister_provider(&mut self, name: &str) -> bool {
394        let provider_removed = self.providers.remove(name).is_some();
395        let breaker_removed = self.breakers.remove(name).is_some();
396        provider_removed || breaker_removed
397    }
398
399    /// Get a provider by name.
400    ///
401    /// # Arguments
402    ///
403    /// * `name` - Provider identifier
404    ///
405    /// # Returns
406    ///
407    /// `Option` containing the provider if found.
408    pub fn get_provider(&self, name: &str) -> Option<&Arc<dyn Provider>> {
409        self.providers.get(name)
410    }
411
412    /// Get the circuit breaker for a provider.
413    ///
414    /// # Arguments
415    ///
416    /// * `provider_name` - Provider identifier
417    ///
418    /// # Returns
419    ///
420    /// `Arc<ProviderCircuitBreaker>` if the provider is registered.
421    pub fn get_breaker(&self, provider_name: &str) -> Option<Arc<ProviderCircuitBreaker>> {
422        self.breakers.get(provider_name).cloned()
423    }
424
425    /// Get all registered provider names.
426    pub fn provider_names(&self) -> Vec<&str> {
427        self.providers.keys().map(|s| s.as_str()).collect()
428    }
429
430    /// Get diagnostic information for all circuit breakers.
431    ///
432    /// # Returns
433    ///
434    /// Vector of diagnostics for each registered provider.
435    pub fn circuit_breaker_diagnostics(
436        &self,
437    ) -> Vec<crate::circuit_breaker::CircuitBreakerDiagnostics> {
438        self.breakers.values().map(|b| b.diagnostics()).collect()
439    }
440
441    /// Get the router used for complexity-based routing.
442    pub fn router(&self) -> &Arc<dyn ComplexityRouter> {
443        &self.router
444    }
445
446    /// Get a reference to the fallback chain.
447    pub fn fallback(&self) -> &FallbackChain {
448        &self.fallback
449    }
450
451    /// Get a reference to the configuration.
452    pub fn config(&self) -> &MultiProviderConfig {
453        &self.config
454    }
455
456    /// Get diagnostic summary of the multi-provider state.
457    pub fn diagnostics(&self) -> MultiProviderDiagnostics {
458        MultiProviderDiagnostics {
459            provider_count: self.providers.len(),
460            router_type: "DefaultRouter".to_string(),
461            fallback_len: self.fallback.len(),
462            auto_routing: self.config.auto_routing,
463            prefer_cost_efficient: self.config.prefer_cost_efficient,
464            circuit_breakers: self.circuit_breaker_diagnostics(),
465        }
466    }
467}
468
469// ============================================================================
470// Diagnostic Types
471// ============================================================================
472
473/// Diagnostic information about MultiProvider state.
474#[derive(Debug, Clone)]
475pub struct MultiProviderDiagnostics {
476    /// Number of registered providers.
477    pub provider_count: usize,
478    /// Type of router being used.
479    pub router_type: String,
480    /// Number of models in the fallback chain.
481    pub fallback_len: usize,
482    /// Whether auto-routing is enabled.
483    pub auto_routing: bool,
484    /// Whether cost-efficient models are preferred.
485    pub prefer_cost_efficient: bool,
486    /// Circuit breaker diagnostics for each provider.
487    pub circuit_breakers: Vec<crate::circuit_breaker::CircuitBreakerDiagnostics>,
488}
489
490// ============================================================================
491// Fallback Event Stream Wrapper
492// ============================================================================
493
494use futures::stream::Stream as StreamTrait;
495
496/// A wrapper stream that injects a `FallbackStart` event at the beginning,
497/// then forwards all subsequent events from the underlying stream.
498struct FallbackStream {
499    /// The injected fallback event (always emitted first).
500    fallback_event: ProviderEvent,
501    /// Whether the fallback event has been emitted yet.
502    emitted: bool,
503    /// The inner stream we're wrapping.
504    inner: Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>,
505}
506
507impl FallbackStream {
508    /// Create a new wrapper stream that will emit `FallbackStart` first.
509    fn new(
510        from_model: String,
511        to_model: String,
512        reason: FallbackReason,
513        inner: Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>,
514    ) -> Self {
515        Self {
516            fallback_event: ProviderEvent::FallbackStart {
517                from_model,
518                to_model,
519                reason,
520            },
521            emitted: false,
522            inner,
523        }
524    }
525}
526
527impl StreamTrait for FallbackStream {
528    type Item = ProviderEvent;
529
530    fn poll_next(
531        mut self: std::pin::Pin<&mut Self>,
532        cx: &mut std::task::Context<'_>,
533    ) -> std::task::Poll<Option<Self::Item>> {
534        // Emit the fallback event on the first poll
535        if !self.emitted {
536            self.emitted = true;
537            return std::task::Poll::Ready(Some(self.fallback_event.clone()));
538        }
539
540        // Then delegate to the inner stream
541        Stream::poll_next(self.inner.as_mut(), cx)
542    }
543}
544
545/// A wrapper stream that emits `FallbackExhausted` and then terminates.
546/// Used when all fallback candidates have been exhausted.
547struct FallbackExhaustedStream {
548    /// The exhausted event to emit.
549    exhausted_event: ProviderEvent,
550    /// Whether we've emitted the exhausted event.
551    emitted: bool,
552    /// The inner error stream (may emit additional error events before terminating).
553    inner: Option<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>>,
554}
555
556impl FallbackExhaustedStream {
557    /// Create a new wrapper stream that will emit `FallbackExhausted` first.
558    fn new(models_tried: Vec<String>, final_error: String) -> Self {
559        Self {
560            exhausted_event: ProviderEvent::FallbackExhausted {
561                models_tried,
562                final_error,
563            },
564            emitted: false,
565            inner: None,
566        }
567    }
568
569    /// Set the inner error stream to forward events from.
570    #[allow(dead_code)]
571    fn with_inner(mut self, inner: Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>) -> Self {
572        self.inner = Some(inner);
573        self
574    }
575}
576
577impl StreamTrait for FallbackExhaustedStream {
578    type Item = ProviderEvent;
579
580    fn poll_next(
581        mut self: std::pin::Pin<&mut Self>,
582        cx: &mut std::task::Context<'_>,
583    ) -> std::task::Poll<Option<Self::Item>> {
584        // Emit the exhausted event on the first poll
585        if !self.emitted {
586            self.emitted = true;
587            return std::task::Poll::Ready(Some(self.exhausted_event.clone()));
588        }
589
590        // Then forward from inner stream if present, otherwise terminate
591        if let Some(ref mut inner) = self.inner {
592            Stream::poll_next(inner.as_mut(), cx)
593        } else {
594            std::task::Poll::Ready(None)
595        }
596    }
597}
598
599/// Determine the fallback reason from a provider error.
600fn error_to_fallback_reason(error: &ProviderError) -> FallbackReason {
601    match error {
602        ProviderError::HttpError(429, _) => FallbackReason::RateLimit,
603        ProviderError::HttpError(code, _) if *code >= 500 => FallbackReason::ServerError,
604        ProviderError::HttpError(code, _) if *code == 401 || *code == 403 => {
605            FallbackReason::AuthError
606        }
607        ProviderError::RequestFailed(_) => FallbackReason::NetworkError,
608        ProviderError::Timeout => FallbackReason::NetworkError,
609        ProviderError::ContextOverflow => FallbackReason::ContextOverflow,
610        _ => FallbackReason::Unknown,
611    }
612}
613
614// ============================================================================
615// Provider Trait Implementation
616// ============================================================================
617
618#[async_trait]
619impl Provider for MultiProvider {
620    /// Stream assistant message events with intelligent routing.
621    ///
622    /// This method implements the priority order logic:
623    ///
624    /// 1. If `auto_routing=true`: classify complexity and select router's best model
625    /// 2. Try the incoming model (if registered and circuit breaker allows)
626    /// 3. Try fallback chain models in order
627    ///
628    /// For each candidate model:
629    /// - Check circuit breaker (skip if open)
630    /// - Call provider.stream()
631    /// - On retryable error: record failure, retry or move to next
632    /// - On non-retryable error: return immediately
633    /// - On success: record success to breaker, return stream
634    async fn stream(
635        &self,
636        model: &Model,
637        context: &Context,
638        options: Option<StreamOptions>,
639    ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
640        // Build candidate list based on priority order
641        let candidates = self.build_candidate_list(model, context).await?;
642
643        // Try each candidate in order
644        let mut errors: Vec<(String, ProviderError)> = Vec::new();
645        let mut current_candidate_idx: usize = 0;
646
647        while current_candidate_idx < candidates.len() {
648            let candidate = &candidates[current_candidate_idx];
649            let provider_name = &candidate.provider;
650            let candidate_model = candidate.model.clone();
651
652            // Get provider
653            let Some(provider) = self.providers.get(provider_name) else {
654                current_candidate_idx += 1;
655                continue;
656            };
657
658            // Check circuit breaker
659            if let Some(breaker) = self.breakers.get(provider_name) {
660                match breaker.allow_request() {
661                    Ok(()) => {
662                        // Circuit allows request, proceed
663                    }
664                    Err(e) => {
665                        // Circuit is open - skip this provider
666                        tracing::debug!(
667                            provider = %provider_name,
668                            remaining = ?e.remaining,
669                            "Circuit breaker open, skipping provider"
670                        );
671                        current_candidate_idx += 1;
672                        continue;
673                    }
674                }
675            }
676
677            // Try to stream with retries
678            let mut retry_count = 0;
679            let max_retries = self.config.max_retries_per_model;
680
681            loop {
682                match provider
683                    .stream(&candidate_model, context, options.clone())
684                    .await
685                {
686                    Ok(inner_stream) => {
687                        // Success! Record to circuit breaker
688                        if let Some(breaker) = self.breakers.get(provider_name) {
689                            breaker.record_success();
690                        }
691                        tracing::debug!(
692                            provider = %provider_name,
693                            model = %candidate_model.id,
694                            "MultiProvider: stream successful"
695                        );
696
697                        // Wrap stream with fallback event if we attempted a previous candidate
698                        if current_candidate_idx > 0 {
699                            let from_model = format!(
700                                "{}/{}",
701                                candidates[current_candidate_idx - 1].provider,
702                                candidates[current_candidate_idx - 1].model.id
703                            );
704                            let to_model = format!("{}/{}", provider_name, candidate_model.id);
705                            let reason = errors
706                                .last()
707                                .map(|(_, e)| error_to_fallback_reason(e))
708                                .unwrap_or(FallbackReason::Unknown);
709
710                            let wrapped =
711                                FallbackStream::new(from_model, to_model, reason, inner_stream);
712                            return Ok(Box::pin(wrapped) as Pin<Box<_>>);
713                        }
714
715                        return Ok(inner_stream);
716                    }
717                    Err(e) => {
718                        // Check if error is retryable
719                        if e.is_retryable() && retry_count < max_retries {
720                            // Retryable error - record failure and retry
721                            retry_count += 1;
722                            if let Some(breaker) = self.breakers.get(provider_name) {
723                                breaker.record_failure();
724                            }
725                            tracing::debug!(
726                                provider = %provider_name,
727                                model = %candidate_model.id,
728                                error = %e,
729                                retry = retry_count,
730                                "Retryable error, retrying"
731                            );
732                            continue;
733                        }
734
735                        // Non-retryable error or max retries exceeded
736                        if !e.is_retryable() {
737                            // Non-retryable errors (400, 401, 403, etc.) don't record failure
738                            // Return immediately - these won't be fixed by retrying
739                            tracing::warn!(
740                                provider = %provider_name,
741                                model = %candidate_model.id,
742                                error = %e,
743                                "Non-retryable error, returning immediately"
744                            );
745                            return Err(e);
746                        }
747
748                        // Max retries exceeded - try next candidate
749                        tracing::debug!(
750                            provider = %provider_name,
751                            model = %candidate_model.id,
752                            error = %e,
753                            retries = retry_count,
754                            "Max retries exceeded, trying next candidate"
755                        );
756                        errors.push((format!("{}/{}", provider_name, candidate_model.id), e));
757                        break;
758                    }
759                }
760            }
761
762            current_candidate_idx += 1;
763        }
764
765        // All candidates exhausted
766        if errors.is_empty() {
767            if self.providers.is_empty() {
768                Err(ProviderError::UnknownProvider(
769                    "multi-provider: no providers registered".to_string(),
770                ))
771            } else {
772                Err(ProviderError::UnknownProvider(
773                    "multi-provider: no model could be routed".to_string(),
774                ))
775            }
776        } else {
777            // Emit FallbackExhausted event
778            let models_tried: Vec<String> = errors.iter().map(|(m, _)| m.clone()).collect();
779            let final_error = errors
780                .last()
781                .map(|(_, e)| e.to_string())
782                .unwrap_or_else(|| "Unknown error".to_string());
783
784            tracing::warn!(
785                models_tried = ?models_tried,
786                error = %final_error,
787                "All fallback models exhausted"
788            );
789
790            let stream = FallbackExhaustedStream::new(models_tried, final_error);
791            Ok(Box::pin(stream) as Pin<Box<_>>)
792        }
793    }
794
795    /// Returns "multi-provider" as the provider name.
796    fn name(&self) -> &str {
797        "multi-provider"
798    }
799}
800
801// ============================================================================
802// Candidate List Building
803// ============================================================================
804
805/// A candidate model for streaming attempts.
806struct Candidate {
807    /// Provider name for this candidate.
808    provider: String,
809    /// Model to use with this provider.
810    model: Model,
811}
812
813impl MultiProvider {
814    /// Build the candidate list based on configuration and priority order.
815    ///
816    /// Priority order (from design §8.3):
817    /// - auto_routing=true → router's best model → incoming model → fallback chain
818    /// - auto_routing=false → incoming model → fallback chain
819    async fn build_candidate_list(
820        &self,
821        incoming_model: &Model,
822        context: &Context,
823    ) -> Result<Vec<Candidate>, ProviderError> {
824        let mut candidates: Vec<Candidate> = Vec::new();
825        let mut seen_ids: HashMap<String, ()> = HashMap::new();
826
827        // Helper to add candidate if not already added
828        let add_candidate = |candidates: &mut Vec<Candidate>,
829                             seen_ids: &mut HashMap<String, ()>,
830                             provider: String,
831                             model: Model| {
832            let id = format!("{}/{}", provider, model.id);
833            if seen_ids.insert(id, ()).is_none() {
834                candidates.push(Candidate { provider, model });
835            }
836        };
837
838        // 1. Auto-routing: get router's best model
839        if self.config.auto_routing {
840            let complexity = self.router.classify(context);
841            let router_models = self
842                .router
843                .route(complexity, self.config.prefer_cost_efficient);
844
845            tracing::debug!(
846                complexity = ?complexity,
847                model_count = router_models.len(),
848                "MultiProvider: router selected models for complexity"
849            );
850
851            for entry in router_models {
852                // Try to get the model from registry
853                if let Some(registered_model) =
854                    crate::model_registry::get_model(entry.provider, entry.id)
855                {
856                    if self.providers.contains_key(entry.provider) {
857                        add_candidate(
858                            &mut candidates,
859                            &mut seen_ids,
860                            entry.provider.to_string(),
861                            registered_model.clone(),
862                        );
863                    }
864                }
865
866                // Also construct from entry if not found in registry
867                if self.providers.contains_key(entry.provider) {
868                    let model = self.model_from_entry(entry);
869                    let id = format!("{}/{}", entry.provider, entry.id);
870                    if seen_ids.insert(id, ()).is_none() {
871                        candidates.push(Candidate {
872                            provider: entry.provider.to_string(),
873                            model,
874                        });
875                    }
876                }
877            }
878        }
879
880        // 2. Incoming model
881        if self.providers.contains_key(&incoming_model.provider) {
882            add_candidate(
883                &mut candidates,
884                &mut seen_ids,
885                incoming_model.provider.clone(),
886                incoming_model.clone(),
887            );
888        } else {
889            // Try to find a provider for this model
890            // Look through all providers to find one that handles this model type
891            for provider_name in self.providers.keys() {
892                // Check if the incoming model matches this provider's expected models
893                let model_id = &incoming_model.id;
894
895                // Try to get the model from registry
896                if let Some(model) = self.find_model_for_provider(provider_name, model_id) {
897                    add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
898                    break;
899                }
900            }
901        }
902
903        // 3. Fallback chain
904        for fallback_entry in self.fallback.iter() {
905            // Try registry first
906            if let Some(registered_model) =
907                crate::model_registry::get_model(fallback_entry.provider, fallback_entry.id)
908            {
909                if self.providers.contains_key(fallback_entry.provider) {
910                    add_candidate(
911                        &mut candidates,
912                        &mut seen_ids,
913                        fallback_entry.provider.to_string(),
914                        registered_model.clone(),
915                    );
916                }
917            } else if self.providers.contains_key(fallback_entry.provider) {
918                // Construct from entry
919                let model = self.model_from_entry(fallback_entry);
920                let id = format!("{}/{}", fallback_entry.provider, fallback_entry.id);
921                if seen_ids.insert(id, ()).is_none() {
922                    candidates.push(Candidate {
923                        provider: fallback_entry.provider.to_string(),
924                        model,
925                    });
926                }
927            }
928        }
929
930        // If no candidates found and providers exist, try using the first provider
931        if candidates.is_empty() && !self.providers.is_empty() {
932            // Use first available provider with a default model
933            let (provider_name, _provider) = self
934                .providers
935                .iter()
936                .next()
937                .expect("providers map is non-empty");
938            let model = self.default_model_for_provider(provider_name);
939            add_candidate(&mut candidates, &mut seen_ids, provider_name.clone(), model);
940        }
941
942        tracing::debug!(
943            candidate_count = candidates.len(),
944            "MultiProvider: built candidate list"
945        );
946
947        if candidates.is_empty() && self.providers.is_empty() {
948            return Err(ProviderError::UnknownProvider(
949                "multi-provider: no providers registered".to_string(),
950            ));
951        }
952
953        Ok(candidates)
954    }
955
956    /// Construct a Model from a ModelEntry.
957    fn model_from_entry(&self, entry: &ModelEntry) -> Model {
958        Model {
959            id: entry.id.to_string(),
960            name: entry.name.to_string(),
961            api: entry.api,
962            provider: entry.provider.to_string(),
963            base_url: String::new(), // Will be set by provider
964            reasoning: entry.reasoning,
965            input: entry.input.to_vec(),
966            cost: crate::types::Cost {
967                input: entry.cost_input,
968                output: entry.cost_output,
969                cache_read: entry.cost_cache_read,
970                cache_write: entry.cost_cache_write,
971            },
972            context_window: entry.context_window as usize,
973            max_tokens: entry.max_tokens as usize,
974            headers: HashMap::new(),
975            compat: None,
976        }
977    }
978
979    /// Find a model for a provider given a model ID.
980    fn find_model_for_provider(&self, provider_name: &str, model_id: &str) -> Option<Model> {
981        // Check registry
982        if let Some(model) = crate::model_registry::get_model(provider_name, model_id) {
983            return Some(model.clone());
984        }
985
986        // Check model_db
987        if let Some(entry) = crate::model_db::get_model_entry(provider_name, model_id) {
988            return Some(self.model_from_entry(entry));
989        }
990
991        // Construct from model_id
992        Some(self.construct_model_from_id(provider_name, model_id))
993    }
994
995    /// Construct a Model from just provider and model ID strings.
996    ///
997    /// Uses model_db to get actual metadata (context_window, cost, reasoning support, etc.).
998    /// Falls back to reasonable defaults if the model is not found in model_db.
999    fn construct_model_from_id(&self, provider: &str, model_id: &str) -> Model {
1000        // First, try to look up the model in model_db
1001        if let Some(entry) = crate::model_db::get_model_entry(provider, model_id) {
1002            return self.model_from_entry(entry);
1003        }
1004
1005        // Not in model_db: determine API type from provider name
1006        let api = match provider {
1007            "openai" | "openai-codex" | "opencode" | "opencode-go" => {
1008                crate::types::Api::OpenAiResponses
1009            }
1010            "anthropic" | "cloudflare-ai-gateway" => crate::types::Api::AnthropicMessages,
1011            "google" => crate::types::Api::GoogleGenerativeAi,
1012            "google-vertex" => crate::types::Api::GoogleVertex,
1013            "azure-openai" | "azure-openai-responses" => crate::types::Api::AzureOpenAiResponses,
1014            "amazon-bedrock" | "bedrock" => crate::types::Api::BedrockConverseStream,
1015            _ => crate::types::Api::OpenAiResponses,
1016        };
1017
1018        Model {
1019            id: model_id.to_string(),
1020            name: model_id.to_string(),
1021            api,
1022            provider: provider.to_string(),
1023            base_url: String::new(),
1024            reasoning: false,
1025            input: vec![crate::types::InputModality::Text],
1026            cost: crate::types::Cost::default(),
1027            context_window: 128_000,
1028            max_tokens: 32_000,
1029            headers: HashMap::new(),
1030            compat: None,
1031        }
1032    }
1033
1034    /// Get the default model for a provider.
1035    ///
1036    /// Uses model_db to look up the most capable model for each provider,
1037    /// with fallbacks for providers not in model_db.
1038    fn default_model_for_provider(&self, provider_name: &str) -> Model {
1039        // Define the preferred default model IDs for each major provider
1040        let default_model_id = match provider_name {
1041            "openai" => "gpt-4o-mini",
1042            "anthropic" => "claude-sonnet-4-20250514",
1043            "google" => "gemini-2.0-flash",
1044            _ => return self.construct_model_from_id(provider_name, "default"),
1045        };
1046
1047        // Try to get the model from model_db
1048        if let Some(entry) = crate::model_db::get_model_entry(provider_name, default_model_id) {
1049            return self.model_from_entry(entry);
1050        }
1051
1052        // Fallback: try to get the first/last model from model_db for this provider
1053        let provider_models = crate::model_db::get_provider_models(provider_name);
1054        if !provider_models.is_empty() {
1055            // Use the last model (typically the most capable/latest)
1056            if let Some(entry) = provider_models.last() {
1057                return self.model_from_entry(entry);
1058            }
1059        }
1060
1061        // Ultimate fallback: construct with sensible defaults
1062        self.construct_model_from_id(provider_name, "default")
1063    }
1064}
1065
1066// ============================================================================
1067// Tests
1068// ============================================================================
1069
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::Message;

    /// Provider stub for tests that only need *a* registered provider.
    /// Its `stream` must never be reached.
    struct MockProvider;

    #[async_trait]
    impl Provider for MockProvider {
        async fn stream(
            &self,
            _model: &Model,
            _context: &Context,
            _options: Option<StreamOptions>,
        ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
            unreachable!("Mock provider - not called in these tests")
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// Default-config `MultiProvider` with the mock registered under the name "test".
    fn provider_with_mock() -> MultiProvider {
        let mut provider = MultiProvider::new(MultiProviderConfig::default());
        provider.register_provider("test", Arc::new(MockProvider));
        provider
    }

    /// Drive a future to completion on the current thread without an async runtime.
    ///
    /// `build_candidate_list` contains no `.await`, so one poll is enough. A future
    /// that suspends is reported as a test failure rather than spun on.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        struct NoopWake;
        impl std::task::Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }

        let waker = std::task::Waker::from(Arc::new(NoopWake));
        let mut cx = std::task::Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        match std::future::Future::poll(fut.as_mut(), &mut cx) {
            std::task::Poll::Ready(output) => output,
            std::task::Poll::Pending => panic!("future unexpectedly suspended"),
        }
    }

    fn create_test_context() -> Context {
        let mut ctx = Context::new();
        ctx.add_message(Message::User(crate::UserMessage::new(
            "Help me write a function to reverse a string".to_string(),
        )));
        ctx
    }

    #[test]
    fn test_config_defaults() {
        let config = MultiProviderConfig::default();
        assert!(config.auto_routing);
        assert!(config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 1);
        assert!(config.per_model_timeout.is_none());
        // Circuit breaker config defaults are tested in circuit_breaker module
    }

    #[test]
    fn test_config_builder() {
        let config = MultiProviderConfig::default()
            .with_auto_routing(false)
            .with_prefer_cost_efficient(false)
            .with_max_retries(3)
            .with_per_model_timeout(Duration::from_secs(30));

        assert!(!config.auto_routing);
        assert!(!config.prefer_cost_efficient);
        assert_eq!(config.max_retries_per_model, 3);
        assert_eq!(config.per_model_timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_multi_provider_creation() {
        let provider = MultiProvider::new(MultiProviderConfig::default());

        assert_eq!(provider.name(), "multi-provider");
        assert!(provider.provider_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let provider = provider_with_mock();

        assert_eq!(provider.provider_names(), vec!["test"]);
        assert!(provider.get_provider("test").is_some());
        assert!(provider.get_breaker("test").is_some());
    }

    #[test]
    fn test_unregister_provider() {
        let mut provider = provider_with_mock();

        assert!(provider.unregister_provider("test"));
        assert!(provider.provider_names().is_empty());
        assert!(provider.get_provider("test").is_none());
    }

    #[test]
    fn test_with_router() {
        let provider = MultiProvider::with_router(DefaultRouter::new());

        assert_eq!(provider.name(), "multi-provider");
    }

    #[test]
    fn test_with_fallback() {
        let fallback = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let provider = MultiProvider::new(MultiProviderConfig::default()).with_fallback(fallback);

        assert_eq!(provider.fallback().len(), 1);
    }

    #[test]
    fn test_circuit_breaker_diagnostics() {
        let provider = provider_with_mock();

        let diagnostics = provider.circuit_breaker_diagnostics();
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].provider, "test");
    }

    #[test]
    fn test_multi_provider_error_display() {
        let err = MultiProviderError::NoProviderForModel("gpt-4o".to_string());
        assert!(err.to_string().contains("gpt-4o"));

        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(err.to_string().contains("All providers exhausted"));

        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.to_string().contains("openai"));
        assert!(err.to_string().contains("10"));
    }

    #[test]
    fn test_multi_provider_error_helpers() {
        let err = MultiProviderError::CircuitBreakerOpen {
            provider: "openai".to_string(),
            retry_after: Duration::from_secs(10),
        };
        assert!(err.is_circuit_breaker());
        assert_eq!(err.retry_after(), Some(Duration::from_secs(10)));

        let err = MultiProviderError::AllProvidersExhausted { errors: vec![] };
        assert!(!err.is_circuit_breaker());
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn test_diagnostics() {
        let provider = provider_with_mock();

        let diag = provider.diagnostics();
        assert_eq!(diag.provider_count, 1);
        assert!(diag.auto_routing);
        assert!(diag.prefer_cost_efficient);
        assert_eq!(diag.circuit_breakers.len(), 1);
    }

    #[test]
    fn test_router_classification() {
        use crate::Complexity;
        let provider = MultiProvider::with_router(DefaultRouter::new());

        let complexity = provider.router().classify(&create_test_context());

        // "Help me write a function to reverse a string" should classify as at least Simple.
        assert!(complexity >= Complexity::Simple);
    }

    #[test]
    fn test_candidate_list_requires_providers() {
        let provider = MultiProvider::new(MultiProviderConfig::default());
        let incoming = provider.construct_model_from_id("openai", "gpt-4o");

        let result = block_on(provider.build_candidate_list(&incoming, &create_test_context()));

        assert!(matches!(result, Err(ProviderError::UnknownProvider(_))));
    }

    #[test]
    fn test_candidate_list_prefers_incoming_and_deduplicates() {
        let fallback = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let mut provider =
            MultiProvider::new(MultiProviderConfig::default().with_auto_routing(false))
                .with_fallback(fallback);
        provider.register_provider("openai", Arc::new(MockProvider));
        let incoming = provider.construct_model_from_id("openai", "gpt-4o");

        let candidates = block_on(provider.build_candidate_list(&incoming, &create_test_context()))
            .expect("a provider is registered");

        // The incoming model comes first, and the identical fallback entry is not repeated.
        assert_eq!(candidates[0].provider, incoming.provider);
        assert_eq!(candidates[0].model.id, incoming.id);
        let matching = candidates
            .iter()
            .filter(|c| c.provider == "openai" && c.model.id == "gpt-4o")
            .count();
        assert_eq!(matching, 1);
    }
}