use std::collections::HashMap;
use std::sync::Arc;
use nexo_config::types::agents::ModelConfig;
use nexo_config::types::llm::{LlmConfig, LlmProviderConfig, RetryConfig};
use crate::anthropic::AnthropicFactory;
use crate::client::LlmClient;
use crate::deepseek::DeepSeekFactory;
use crate::gemini::GeminiFactory;
use crate::minimax::MiniMaxClient;
use crate::openai_compat::OpenAiClient;
/// Constructs [`LlmClient`]s for one named LLM provider.
///
/// Implementations are stored in [`LlmRegistry`], keyed by
/// [`name`](LlmProviderFactory::name). The `Send + Sync` supertraits make
/// `Box<dyn LlmProviderFactory>` itself `Send + Sync`.
pub trait LlmProviderFactory: Send + Sync {
    /// Registry key for this provider.
    ///
    /// [`LlmRegistry::build`] looks up both the factory and the provider's
    /// entry in `LlmConfig::providers` using `ModelConfig::provider`, so this
    /// name must match the key used in both places for a build to succeed.
    fn name(&self) -> &str;
    /// Builds a client for `model` from this provider's configuration.
    ///
    /// `retry` is passed by value; [`LlmRegistry::build`] supplies a clone of
    /// `LlmConfig::retry`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; any error is returned unchanged by
    /// [`LlmRegistry::build`].
    fn build(
        &self,
        provider_cfg: &LlmProviderConfig,
        model: &str,
        retry: RetryConfig,
    ) -> anyhow::Result<Arc<dyn LlmClient>>;
}
/// Name-keyed collection of [`LlmProviderFactory`] implementations.
///
/// An agent selects a provider by name (`ModelConfig::provider`). The registry
/// resolves that name to a factory and to the matching entry in
/// `LlmConfig::providers`, then asks the factory to build the client.
#[derive(Default)]
pub struct LlmRegistry {
    /// Factories keyed by [`LlmProviderFactory::name`].
    factories: HashMap<String, Box<dyn LlmProviderFactory>>,
}

impl LlmRegistry {
    /// Creates an empty registry with no providers registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a registry pre-populated with the built-in provider factories
    /// (MiniMax, OpenAI-compatible, Anthropic, Gemini and DeepSeek).
    ///
    /// # Panics
    ///
    /// Panics if two built-in factories report the same name, which would be a
    /// programming error in this crate rather than a runtime condition.
    pub fn with_builtins() -> Self {
        let mut r = Self::new();
        r.register(Box::new(MiniMaxFactory))
            .expect("builtin minimax factory");
        r.register(Box::new(OpenAiFactory))
            .expect("builtin openai factory");
        r.register(Box::new(AnthropicFactory))
            .expect("builtin anthropic factory");
        r.register(Box::new(GeminiFactory))
            .expect("builtin gemini factory");
        r.register(Box::new(DeepSeekFactory))
            .expect("builtin deepseek factory");
        r
    }

    /// Adds `factory` under the key returned by its [`LlmProviderFactory::name`].
    ///
    /// # Errors
    ///
    /// Returns an error if a factory with the same name is already registered.
    /// In that case the existing factory is kept and `factory` is dropped.
    pub fn register(&mut self, factory: Box<dyn LlmProviderFactory>) -> anyhow::Result<()> {
        use std::collections::hash_map::Entry;

        // One hash lookup via the entry API instead of `contains_key` + `insert`.
        match self.factories.entry(factory.name().to_string()) {
            Entry::Occupied(existing) => {
                anyhow::bail!("LLM provider '{}' already registered", existing.key())
            }
            Entry::Vacant(slot) => {
                slot.insert(factory);
                Ok(())
            }
        }
    }

    /// Returns the registered provider names in ascending order.
    pub fn names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.factories.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Builds the client selected by `agent_model`.
    ///
    /// The factory and the provider configuration are both looked up by
    /// `agent_model.provider`. The factory then receives that configuration,
    /// `agent_model.model`, and a clone of `llm_cfg.retry`.
    ///
    /// # Errors
    ///
    /// - No factory is registered under the provider name. The message lists
    ///   the registered providers.
    /// - The provider has no entry in `llm_cfg.providers`. The message lists
    ///   the configured providers.
    /// - The factory's own `build` fails.
    pub fn build(
        &self,
        llm_cfg: &LlmConfig,
        agent_model: &ModelConfig,
    ) -> anyhow::Result<Arc<dyn LlmClient>> {
        let provider = &agent_model.provider;
        let factory = self.factories.get(provider).ok_or_else(|| {
            anyhow::anyhow!(
                "LLM provider '{}' not registered (known: {:?})",
                provider,
                self.names()
            )
        })?;
        let provider_cfg = llm_cfg.providers.get(provider).ok_or_else(|| {
            // Mirror the "not registered" error: list what *is* configured,
            // sorted so the message is deterministic.
            let mut configured: Vec<&str> =
                llm_cfg.providers.keys().map(String::as_str).collect();
            configured.sort_unstable();
            anyhow::anyhow!(
                "LLM provider '{}' not present in config.providers (configured: {:?})",
                provider,
                configured
            )
        })?;
        factory.build(provider_cfg, &agent_model.model, llm_cfg.retry.clone())
    }
}
/// Built-in factory registered under the name `"minimax"`, producing
/// [`MiniMaxClient`] instances.
pub struct MiniMaxFactory;

impl LlmProviderFactory for MiniMaxFactory {
    fn name(&self) -> &str {
        "minimax"
    }

    /// Constructs a [`MiniMaxClient`] and returns it behind an `Arc`.
    ///
    /// `MiniMaxClient::new` returns the client directly rather than a
    /// `Result`, so this method always yields `Ok`.
    fn build(
        &self,
        provider_cfg: &LlmProviderConfig,
        model: &str,
        retry: RetryConfig,
    ) -> anyhow::Result<Arc<dyn LlmClient>> {
        let client: Arc<dyn LlmClient> = Arc::new(MiniMaxClient::new(provider_cfg, model, retry));
        Ok(client)
    }
}
/// Built-in factory registered under the name `"openai"`, producing
/// [`OpenAiClient`] instances.
pub struct OpenAiFactory;

impl LlmProviderFactory for OpenAiFactory {
    fn name(&self) -> &str {
        "openai"
    }

    /// Constructs an [`OpenAiClient`] and returns it behind an `Arc`.
    ///
    /// `OpenAiClient::new` returns the client directly rather than a
    /// `Result`, so this method always yields `Ok`.
    fn build(
        &self,
        provider_cfg: &LlmProviderConfig,
        model: &str,
        retry: RetryConfig,
    ) -> anyhow::Result<Arc<dyn LlmClient>> {
        let client: Arc<dyn LlmClient> = Arc::new(OpenAiClient::new(provider_cfg, model, retry));
        Ok(client)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nexo_config::types::llm::RateLimitConfig;
    use std::collections::HashMap;

    /// Placeholder provider config: dummy key and a host under the reserved
    /// `.invalid` TLD.
    fn provider_cfg() -> LlmProviderConfig {
        let rate_limit = RateLimitConfig {
            requests_per_second: 1.0,
            quota_alert_threshold: Some(100),
        };
        LlmProviderConfig {
            api_key: "k".into(),
            group_id: None,
            base_url: "http://example.invalid".into(),
            rate_limit,
            auth: None,
            api_flavor: None,
            embedding_model: None,
            safety_settings: None,
        }
    }

    /// `LlmConfig` with exactly one configured provider and a single-attempt
    /// retry policy.
    fn llm_cfg(provider_name: &str) -> LlmConfig {
        let retry = RetryConfig {
            max_attempts: 1,
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            backoff_multiplier: 1.0,
        };
        LlmConfig {
            providers: HashMap::from([(provider_name.to_string(), provider_cfg())]),
            retry,
            context_optimization: Default::default(),
        }
    }

    /// Shorthand for an agent `ModelConfig`.
    fn model(provider: &str, name: &str) -> ModelConfig {
        ModelConfig {
            provider: provider.into(),
            model: name.into(),
        }
    }

    /// Runs `build` on a builtin registry and returns the error, panicking
    /// if the build unexpectedly succeeded.
    fn build_err(cfg: &LlmConfig, model: &ModelConfig) -> anyhow::Error {
        match LlmRegistry::with_builtins().build(cfg, model) {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        }
    }

    #[test]
    fn builtins_present() {
        let registry = LlmRegistry::with_builtins();
        let names = registry.names();
        for expected in ["minimax", "openai", "anthropic", "gemini", "deepseek"] {
            assert!(names.contains(&expected), "builtin provider '{expected}' missing");
        }
    }

    #[test]
    fn duplicate_register_errors() {
        let mut registry = LlmRegistry::with_builtins();
        let err = registry
            .register(Box::new(MiniMaxFactory))
            .expect_err("expected duplicate error");
        assert!(err.to_string().contains("already registered"));
    }

    #[test]
    fn build_unknown_provider_errors() {
        let err = build_err(&llm_cfg("minimax"), &model("nope", "x"));
        assert!(err.to_string().contains("not registered"));
    }

    #[test]
    fn build_provider_missing_in_config_errors() {
        // "openai" has a factory, but only "minimax" is configured.
        let err = build_err(&llm_cfg("minimax"), &model("openai", "gpt-x"));
        assert!(err.to_string().contains("config.providers"));
    }

    #[test]
    fn build_minimax_returns_client() {
        let registry = LlmRegistry::with_builtins();
        let client = registry
            .build(&llm_cfg("minimax"), &model("minimax", "m1"))
            .expect("client");
        assert_eq!(client.provider(), "minimax");
    }
}