use std::{collections::HashMap, sync::Arc};
use anyhow::Result;
use super::LlmProvider;
/// Name-keyed collection of LLM providers plus an optional model-alias table.
///
/// A registry starts empty (via [`Default`] or [`ProviderRegistry::new`]). Providers are
/// added with [`ProviderRegistry::register`], and aliases are installed with
/// [`ProviderRegistry::set_model_aliases`].
#[derive(Default)]
pub struct ProviderRegistry {
    /// Providers keyed by the name they were registered under (e.g. `"anthropic"`).
    providers: HashMap<String, Arc<dyn LlmProvider>>,
    /// Model string -> provider name overrides, consulted by `resolve_model`.
    model_aliases: HashMap<String, String>,
}
// Written by hand rather than derived. Presumably this is because `dyn LlmProvider` is not
// `Debug` (TODO confirm), so providers are shown by registered name only.
impl std::fmt::Debug for ProviderRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort both listings so the output is deterministic; HashMap order is unspecified.
        let mut providers: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        providers.sort_unstable();
        let aliases: std::collections::BTreeMap<&str, &str> = self
            .model_aliases
            .iter()
            .map(|(model, provider)| (model.as_str(), provider.as_str()))
            .collect();
        f.debug_struct("ProviderRegistry")
            .field("providers", &providers)
            .field("model_aliases", &aliases)
            .finish()
    }
}
impl ProviderRegistry {
    /// Creates an empty registry with no providers and no model aliases.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `provider` under `name`.
    ///
    /// If a provider is already registered under the same name, it is replaced.
    pub fn register(&mut self, name: impl Into<String>, provider: Arc<dyn LlmProvider>) {
        self.providers.insert(name.into(), provider);
    }

    /// Replaces the whole model-alias table.
    ///
    /// Keys are model strings exactly as callers pass them to
    /// [`resolve_model`](Self::resolve_model), and values are provider names. Any
    /// previously configured aliases are discarded.
    pub fn set_model_aliases(&mut self, aliases: HashMap<String, String>) {
        self.model_aliases = aliases;
    }

    /// Returns the provider registered under `name`.
    ///
    /// # Errors
    /// Returns an error if no provider is registered under `name`.
    pub fn get(&self, name: &str) -> Result<Arc<dyn LlmProvider>> {
        self.providers
            .get(name)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("provider not registered: {name}"))
    }

    /// Returns the names of all registered providers, sorted.
    ///
    /// The names are sorted so the listing is stable across runs; `HashMap` iteration
    /// order is unspecified.
    pub fn names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Splits a model string into `(provider, model_id)` without consulting a registry.
    ///
    /// An explicit `"provider/model"` form is split at the first `/`, so `"a/b/c"` yields
    /// `("a", "b/c")`. For a bare model id, the provider is guessed from the id's prefix,
    /// defaulting to `"openai"`.
    pub fn parse_model(model: &str) -> (&str, &str) {
        model
            .split_once('/')
            .unwrap_or_else(|| (infer_provider(model), model))
    }

    /// Resolves a model string to `(provider, model_id)`, honouring configured aliases.
    ///
    /// If `model` has an alias whose provider is registered, that provider is used. A
    /// leading `"<provider>/"` is stripped from the returned id when present. In every
    /// other case (no alias, or an alias pointing at an unregistered provider), this falls
    /// back to [`parse_model`](Self::parse_model).
    pub fn resolve_model<'a>(&'a self, model: &'a str) -> (&'a str, &'a str) {
        if let Some(provider) = self.model_aliases.get(model) {
            if self.providers.contains_key(provider) {
                // Strip "<provider>/" without allocating the prefix string.
                let model_id = model
                    .strip_prefix(provider.as_str())
                    .and_then(|rest| rest.strip_prefix('/'))
                    .unwrap_or(model);
                return (provider.as_str(), model_id);
            }
        }
        Self::parse_model(model)
    }
}
/// Guesses the provider for a bare model id, i.e. one without a `provider/` prefix.
///
/// Matching is an ASCII case-insensitive prefix test against a fixed table, and the first
/// matching entry wins. Ids that match no entry default to `"openai"`; this default also
/// covers OpenAI's own families (`gpt-*`, `o1-*`, `o3-*`, ...). The result is always a
/// string literal, so it is returned as `&'static str` rather than borrowing from `model`.
fn infer_provider(model: &str) -> &'static str {
    // (lowercase model-id prefix, provider name), checked in order.
    const PREFIXES: &[(&str, &str)] = &[
        ("claude", "anthropic"),
        ("gemini", "gemini"),
        ("deepseek", "deepseek"),
        ("qwen", "qwen"),
        ("glm", "zhipu"),
        ("chatglm", "zhipu"),
        ("moonshot", "kimi"),
        ("kimi", "kimi"),
        ("step", "stepfun"),
        ("grok", "xai"),
    ];

    let bytes = model.as_bytes();
    PREFIXES
        .iter()
        .find(|(prefix, _)| {
            // Byte-wise compare avoids allocating a lowercased copy of `model`.
            bytes.len() >= prefix.len()
                && bytes[..prefix.len()].eq_ignore_ascii_case(prefix.as_bytes())
        })
        .map_or("openai", |&(_, provider)| provider)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn parse_model_explicit_prefix() {
        assert_eq!(
            ProviderRegistry::parse_model("anthropic/claude-sonnet-4-5"),
            ("anthropic", "claude-sonnet-4-5")
        );
        assert_eq!(
            ProviderRegistry::parse_model("groq/llama-3-70b"),
            ("groq", "llama-3-70b")
        );
        assert_eq!(
            ProviderRegistry::parse_model("gemini/gemini-2.0-flash"),
            ("gemini", "gemini-2.0-flash")
        );
    }

    #[test]
    fn parse_model_splits_at_first_slash_only() {
        assert_eq!(
            ProviderRegistry::parse_model("openrouter/anthropic/claude-3-opus"),
            ("openrouter", "anthropic/claude-3-opus")
        );
    }

    #[test]
    fn parse_model_inferred_anthropic() {
        assert_eq!(
            ProviderRegistry::parse_model("claude-sonnet-4-5"),
            ("anthropic", "claude-sonnet-4-5")
        );
        assert_eq!(
            ProviderRegistry::parse_model("claude-3-5-sonnet-20241022"),
            ("anthropic", "claude-3-5-sonnet-20241022")
        );
    }

    #[test]
    fn parse_model_inferred_gemini() {
        assert_eq!(
            ProviderRegistry::parse_model("gemini-2.0-flash"),
            ("gemini", "gemini-2.0-flash")
        );
    }

    #[test]
    fn parse_model_inferred_openai() {
        assert_eq!(
            ProviderRegistry::parse_model("gpt-4o"),
            ("openai", "gpt-4o")
        );
        assert_eq!(
            ProviderRegistry::parse_model("o1-preview"),
            ("openai", "o1-preview")
        );
        assert_eq!(
            ProviderRegistry::parse_model("o3-mini"),
            ("openai", "o3-mini")
        );
    }

    #[test]
    fn parse_model_unknown_defaults_to_openai() {
        assert_eq!(
            ProviderRegistry::parse_model("some-unknown-model"),
            ("openai", "some-unknown-model")
        );
    }

    #[test]
    fn parse_model_chinese_providers() {
        assert_eq!(
            ProviderRegistry::parse_model("deepseek-chat"),
            ("deepseek", "deepseek-chat")
        );
        assert_eq!(
            ProviderRegistry::parse_model("qwen-turbo"),
            ("qwen", "qwen-turbo")
        );
        assert_eq!(ProviderRegistry::parse_model("glm-4"), ("zhipu", "glm-4"));
        assert_eq!(
            ProviderRegistry::parse_model("moonshot-v1-8k"),
            ("kimi", "moonshot-v1-8k")
        );
        assert_eq!(ProviderRegistry::parse_model("grok-2"), ("xai", "grok-2"));
    }

    #[test]
    fn parse_model_remaining_prefixes() {
        assert_eq!(
            ProviderRegistry::parse_model("chatglm3-6b"),
            ("zhipu", "chatglm3-6b")
        );
        assert_eq!(ProviderRegistry::parse_model("kimi-k2"), ("kimi", "kimi-k2"));
        assert_eq!(
            ProviderRegistry::parse_model("step-1v-8k"),
            ("stepfun", "step-1v-8k")
        );
    }

    #[test]
    fn parse_model_inference_is_case_insensitive() {
        assert_eq!(
            ProviderRegistry::parse_model("Claude-3-Opus"),
            ("anthropic", "Claude-3-Opus")
        );
        assert_eq!(
            ProviderRegistry::parse_model("DeepSeek-V3"),
            ("deepseek", "DeepSeek-V3")
        );
        assert_eq!(ProviderRegistry::parse_model("GPT-4o"), ("openai", "GPT-4o"));
    }

    #[test]
    fn resolve_model_without_aliases_matches_parse_model() {
        let registry = ProviderRegistry::new();
        for model in ["gpt-4o", "claude-sonnet-4-5", "groq/llama-3-70b"] {
            assert_eq!(
                registry.resolve_model(model),
                ProviderRegistry::parse_model(model)
            );
        }
    }

    #[test]
    fn resolve_model_alias_to_unregistered_provider_falls_back() {
        let mut registry = ProviderRegistry::new();
        registry.set_model_aliases(HashMap::from([(
            "claude-sonnet-4-5".to_string(),
            "bedrock".to_string(),
        )]));
        assert_eq!(
            registry.resolve_model("claude-sonnet-4-5"),
            ("anthropic", "claude-sonnet-4-5")
        );
    }

    #[test]
    fn empty_registry_has_no_names() {
        assert!(ProviderRegistry::new().names().is_empty());
    }

    #[test]
    fn get_unregistered_provider_errors() {
        let registry = ProviderRegistry::new();
        let err = registry
            .get("nope")
            .err()
            .expect("looking up an unregistered provider must fail");
        assert_eq!(err.to_string(), "provider not registered: nope");
    }

    #[test]
    fn debug_output_names_the_type() {
        let rendered = format!("{:?}", ProviderRegistry::new());
        assert!(rendered.starts_with("ProviderRegistry"));
    }
}