use anyhow::Result;
use std::sync::Arc;
use oxi_agent::{ProviderResolver, ToolRegistry};
use oxi_ai::{Model, ModelRegistry, Provider, ProviderRegistry};
use crate::agent_builder::AgentBuilder;
use crate::multi_provider::{MultiProviderBuilder, RoutingConfig};
/// Facade that bundles the provider, model and tool registries.
///
/// Every registry is stored behind an `Arc`; `providers_arc`, `models_arc`
/// and `tools` hand out clones of those handles.
pub struct Oxi {
/// Provider registry; `create_provider` checks it first via `get_custom`.
providers: Arc<ProviderRegistry>,
/// Model registry queried by `resolve_model`.
models: Arc<ModelRegistry>,
/// Tool registry returned (as a cloned handle) by `tools`.
tools: Arc<ToolRegistry>,
/// When `true`, `create_provider` falls back to
/// `oxi_ai::create_builtin_provider` for names the registry does not know.
include_builtins: bool,
}
impl Oxi {
    /// Starts an [`AgentBuilder`] for `config`.
    ///
    /// The returned builder borrows this instance for its whole lifetime.
    pub fn agent(&self, config: oxi_agent::AgentConfig) -> AgentBuilder<'_> {
        AgentBuilder::new(self, config)
    }

    /// Borrows the provider registry.
    pub fn providers(&self) -> &ProviderRegistry {
        &self.providers
    }

    /// Borrows the model registry.
    pub fn models(&self) -> &ModelRegistry {
        &self.models
    }

    /// Returns a new shared handle to the tool registry.
    pub fn tools(&self) -> Arc<ToolRegistry> {
        Arc::clone(&self.tools)
    }

    /// Looks up a model by identifier.
    ///
    /// `model_id` is either `provider/model` or a bare model name. Only the
    /// first `/` separates the two parts, so the model part may itself contain
    /// slashes. A bare name is looked up under the `"anthropic"` provider.
    ///
    /// # Errors
    ///
    /// Returns an error when the model registry has no entry for the resolved
    /// provider/model pair.
    pub fn resolve_model(&self, model_id: &str) -> Result<Model> {
        // `split_once` cuts at the first '/' without allocating a `Vec`;
        // identifiers without a slash fall back to the default provider.
        let (provider, model) = model_id
            .split_once('/')
            .unwrap_or(("anthropic", model_id));

        self.models
            .lookup(provider, model)
            .ok_or_else(|| anyhow::anyhow!("Model '{}' not found", model_id))
    }

    /// Resolves a provider by name.
    ///
    /// The registry is asked first (`get_custom`), so a registered provider
    /// takes precedence over a built-in one with the same name. Built-ins are
    /// only consulted when this instance was built with them enabled.
    ///
    /// # Errors
    ///
    /// Returns an error when neither source knows `name`.
    pub fn create_provider(&self, name: &str) -> Result<Arc<dyn Provider>> {
        if let Some(custom) = self.providers.get_custom(name) {
            return Ok(custom);
        }

        if self.include_builtins {
            if let Some(builtin) = oxi_ai::create_builtin_provider(name) {
                return Ok(Arc::from(builtin));
            }
        }

        Err(anyhow::anyhow!("Provider '{}' not found", name))
    }

    /// Returns a new shared handle to the provider registry.
    pub fn providers_arc(&self) -> Arc<ProviderRegistry> {
        Arc::clone(&self.providers)
    }

    /// Returns a new shared handle to the model registry.
    pub fn models_arc(&self) -> Arc<ModelRegistry> {
        Arc::clone(&self.models)
    }

    /// Reports whether `create_provider` may fall back to built-in providers.
    pub fn has_builtins(&self) -> bool {
        self.include_builtins
    }
}
impl ProviderResolver for Oxi {
fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
self.create_provider(name).ok()
}
fn resolve_model(&self, model_id: &str) -> Option<Model> {
self.resolve_model(model_id).ok()
}
}
/// Builder for [`Oxi`]; owns the registries until `build` wraps them in `Arc`s.
pub struct OxiBuilder {
/// Provider registry filled by `provider`, `provider_factory` and
/// `enable_routing`.
providers: ProviderRegistry,
/// Model registry filled by `model`; replaced wholesale by `with_builtins`.
models: ModelRegistry,
/// Tool registry filled by `tool`.
tools: ToolRegistry,
/// Set by `with_builtins`; copied into the built `Oxi` unchanged.
include_builtins: bool,
}
impl OxiBuilder {
    /// Creates a builder with empty registries and built-ins disabled.
    pub fn new() -> Self {
        let providers = ProviderRegistry::new();
        let models = ModelRegistry::new();
        let tools = ToolRegistry::new();
        Self {
            providers,
            models,
            tools,
            include_builtins: false,
        }
    }

    /// Enables built-in providers and switches to the static model registry.
    ///
    /// The model registry is *replaced* by `ModelRegistry::from_static()`, so
    /// any model registered through [`OxiBuilder::model`] before this call is
    /// discarded. Call this first, then add custom models.
    pub fn with_builtins(self) -> Self {
        Self {
            models: ModelRegistry::from_static(),
            include_builtins: true,
            ..self
        }
    }

    /// Registers the provider `p` under `name`.
    pub fn provider(self, name: &str, p: impl Provider + 'static) -> Self {
        // NOTE(review): `register` is called without a `mut` binding, so the
        // registry presumably mutates through `&self` — confirm in oxi_ai.
        self.providers.register(name, p);
        self
    }

    /// Registers `tool` in the tool registry.
    pub fn tool(self, tool: impl oxi_agent::AgentTool + 'static) -> Self {
        self.tools.register(tool);
        self
    }

    /// Registers `factory` under `name` in the provider registry.
    pub fn provider_factory(
        self,
        name: &str,
        factory: impl Fn() -> anyhow::Result<Arc<dyn Provider>> + Send + Sync + 'static,
    ) -> Self {
        self.providers.register_factory(name, factory);
        self
    }

    /// Registers `model` in the model registry.
    pub fn model(self, model: Model) -> Self {
        self.models.register(model);
        self
    }

    /// Builds a multi-provider from the providers registered so far and
    /// registers it under the name `"multi"`.
    ///
    /// Only providers that the registry returns from `get_custom` at the time
    /// of this call take part; providers added afterwards are not included.
    /// If building the multi-provider fails, the error is dropped and nothing
    /// is registered (best effort — the builder is returned either way).
    pub fn enable_routing(self, config: RoutingConfig) -> Self {
        // Snapshot the registered providers before configuring the router.
        let provider_names: Vec<String> = self.providers.names();
        let routable: Vec<(String, Arc<dyn Provider>)> = provider_names
            .into_iter()
            .filter_map(|name| {
                let provider = self.providers.get_custom(&name)?;
                Some((name, provider))
            })
            .collect();

        // Apply routing options first, then attach the providers.
        let mut multi = MultiProviderBuilder::new();
        if config.auto_routing {
            multi = multi.enable_auto_routing();
        }
        if config.prefer_cost_efficient {
            multi = multi.prefer_cost_efficient();
        }
        if let Some(router) = config.router {
            multi = multi.with_router_boxed(router);
        }
        for (name, provider) in routable {
            multi = multi.provider(&name, provider);
        }

        // A failed build is deliberately ignored; see the doc comment.
        if let Ok(routed) = multi.build() {
            self.providers.register_arc("multi", routed);
        }
        self
    }

    /// Consumes the builder and wraps each registry in an `Arc`.
    pub fn build(self) -> Oxi {
        let Self {
            providers,
            models,
            tools,
            include_builtins,
        } = self;

        Oxi {
            providers: Arc::new(providers),
            models: Arc::new(models),
            tools: Arc::new(tools),
            include_builtins,
        }
    }
}
impl Default for OxiBuilder {
fn default() -> Self {
Self::new()
}
}