use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use oxi_ai::{
circuit_breaker::CircuitBreakerConfig, fallback_chain::FallbackChain,
multi_provider::MultiProviderConfig, ComplexityRouter, MultiProvider, Provider,
};
/// Routing settings: whether automatic routing is enabled, whether
/// cost-efficient models are preferred, and an optional custom router.
///
/// `Clone` is implemented by hand because `Box<dyn ComplexityRouter>` is not
/// cloneable; a cloned `RoutingConfig` keeps the two flags but has `router: None`.
pub struct RoutingConfig {
    /// Enables automatic routing (defaults to `true`).
    pub auto_routing: bool,
    /// Prefer cost-efficient models when routing (defaults to `true`).
    pub prefer_cost_efficient: bool,
    /// Optional custom complexity router (defaults to `None`).
    pub router: Option<Box<dyn ComplexityRouter>>,
}
/// Duplicates the routing flags of a [`RoutingConfig`].
///
/// The custom router is deliberately NOT carried over: a
/// `Box<dyn ComplexityRouter>` cannot be cloned, so every clone has
/// `router: None`. Re-attach a router with `with_router` if the copy needs one.
impl Clone for RoutingConfig {
    fn clone(&self) -> Self {
        // Only the `Copy` flags come from `self`; the boxed router is dropped.
        Self {
            router: None,
            ..*self
        }
    }
}
/// Debug-formats a [`RoutingConfig`].
///
/// The router is a trait object with no `Debug` bound, so only its presence
/// is shown: `Some("<dyn ComplexityRouter>")` when a router is set and `None`
/// otherwise. Previously a placeholder was printed even when no router was
/// configured, which misreported the config's state.
impl fmt::Debug for RoutingConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RoutingConfig")
            .field("auto_routing", &self.auto_routing)
            .field("prefer_cost_efficient", &self.prefer_cost_efficient)
            .field(
                "router",
                &self.router.as_ref().map(|_| "<dyn ComplexityRouter>"),
            )
            .finish()
    }
}
impl Default for RoutingConfig {
fn default() -> Self {
Self {
auto_routing: true,
prefer_cost_efficient: true,
router: None,
}
}
}
/// Consuming builder-style setters for [`RoutingConfig`].
impl RoutingConfig {
    /// Creates a configuration with the [`Default`] settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the config with automatic routing set to `enabled`.
    pub fn auto_routing(self, enabled: bool) -> Self {
        Self {
            auto_routing: enabled,
            ..self
        }
    }

    /// Returns the config with the cost-efficiency preference set to `enabled`.
    pub fn prefer_cost_efficient(self, enabled: bool) -> Self {
        Self {
            prefer_cost_efficient: enabled,
            ..self
        }
    }

    /// Returns the config with `router` installed, replacing any previous router.
    pub fn with_router(self, router: impl ComplexityRouter + 'static) -> Self {
        let boxed: Box<dyn ComplexityRouter> = Box::new(router);
        Self {
            router: Some(boxed),
            ..self
        }
    }
}
/// Builder that assembles a [`MultiProvider`] from named providers, an optional
/// complexity router, a fallback chain and a [`MultiProviderConfig`].
pub struct MultiProviderBuilder {
    /// Router installed on the `MultiProvider` at build time, if any.
    router: Option<Box<dyn ComplexityRouter>>,
    /// Providers keyed by name; adding the same name twice keeps the last one.
    providers: HashMap<String, Arc<dyn Provider>>,
    /// Fallback chain; `build` applies it only when it is non-empty.
    fallback_chain: FallbackChain,
    /// Settings passed to `MultiProvider::new` at build time.
    config: MultiProviderConfig,
}
impl MultiProviderBuilder {
pub fn new() -> Self {
Self {
router: None,
providers: HashMap::new(),
fallback_chain: FallbackChain::default(),
config: MultiProviderConfig::default(),
}
}
pub fn provider(mut self, name: &str, provider: Arc<dyn Provider>) -> Self {
self.providers.insert(name.to_string(), provider);
self
}
pub fn with_fallbacks(self, ids: &[&str]) -> Self {
let fallback = FallbackChain::from_ids(ids).unwrap_or_else(|_| FallbackChain::default());
self.with_fallback_chain(fallback)
}
pub fn with_fallback_chain(mut self, fallback: FallbackChain) -> Self {
self.fallback_chain = fallback;
self
}
pub fn with_router(mut self, router: impl ComplexityRouter + 'static) -> Self {
self.router = Some(Box::new(router));
self
}
pub fn with_router_boxed(mut self, router: Box<dyn ComplexityRouter>) -> Self {
self.router = Some(router);
self
}
pub fn prefer_cost_efficient(mut self) -> Self {
self.config.prefer_cost_efficient = true;
self
}
pub fn enable_auto_routing(mut self) -> Self {
self.config.auto_routing = true;
self
}
pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
self.config.circuit_breaker = config;
self
}
pub fn build(self) -> anyhow::Result<Arc<dyn Provider>> {
let mut mp = MultiProvider::new(self.config);
if let Some(router) = self.router {
struct BoxedRouter(Box<dyn ComplexityRouter>);
impl ComplexityRouter for BoxedRouter {
fn classify(&self, context: &oxi_ai::Context) -> oxi_ai::Complexity {
self.0.classify(context)
}
fn route(
&self,
complexity: oxi_ai::Complexity,
prefer_cost_efficient: bool,
) -> Vec<&'static oxi_ai::model_db::ModelEntry> {
self.0.route(complexity, prefer_cost_efficient)
}
}
mp = mp.set_router(BoxedRouter(router));
}
for (name, provider) in self.providers {
mp.register_provider(&name, provider);
}
if !self.fallback_chain.is_empty() {
mp = mp.with_fallback(self.fallback_chain);
}
if mp.provider_names().is_empty() {
anyhow::bail!("MultiProvider requires at least one provider");
}
Ok(Arc::new(mp))
}
}
impl Default for MultiProviderBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_routing_config_default() {
        let config = RoutingConfig::default();
        assert!(config.auto_routing);
        assert!(config.prefer_cost_efficient);
        assert!(config.router.is_none());
    }

    #[test]
    fn test_routing_config_builder() {
        let config = RoutingConfig::new()
            .auto_routing(true)
            .prefer_cost_efficient(false);
        assert!(config.auto_routing);
        assert!(!config.prefer_cost_efficient);
        assert!(config.router.is_none());
    }

    /// `Clone` must keep both flags; the router is never carried over.
    #[test]
    fn test_routing_config_clone_keeps_flags() {
        let original = RoutingConfig::new()
            .auto_routing(false)
            .prefer_cost_efficient(false);
        let copy = original.clone();
        assert!(!copy.auto_routing);
        assert!(!copy.prefer_cost_efficient);
        assert!(copy.router.is_none());
    }

    #[test]
    fn test_routing_config_debug_lists_flags() {
        let text = format!("{:?}", RoutingConfig::default());
        assert!(text.contains("auto_routing: true"));
        assert!(text.contains("prefer_cost_efficient: true"));
    }

    #[test]
    fn test_builder_new() {
        let builder = MultiProviderBuilder::new();
        assert!(builder.config.auto_routing);
    }

    #[test]
    fn test_builder_prefer_cost_efficient() {
        // Start from `false` so the assertion proves the setter flipped it,
        // rather than depending on `MultiProviderConfig::default()`.
        let mut builder = MultiProviderBuilder::new();
        builder.config.prefer_cost_efficient = false;
        let builder = builder.prefer_cost_efficient();
        assert!(builder.config.prefer_cost_efficient);
    }

    #[test]
    fn test_builder_enable_auto_routing() {
        // The default is already `true` (see `test_builder_new`), so force
        // `false` first; otherwise this test could never fail.
        let mut builder = MultiProviderBuilder::new();
        builder.config.auto_routing = false;
        let builder = builder.enable_auto_routing();
        assert!(builder.config.auto_routing);
    }

    #[test]
    fn test_builder_with_fallback_chain() {
        let fallback = FallbackChain::from_ids(&["openai/gpt-4o"]).unwrap();
        let builder = MultiProviderBuilder::new().with_fallback_chain(fallback);
        assert!(!builder.fallback_chain.is_empty());
    }

    #[test]
    fn test_builder_with_fallbacks_ids() {
        let builder = MultiProviderBuilder::new().with_fallbacks(&["openai/gpt-4o"]);
        assert!(!builder.fallback_chain.is_empty());
    }
}