use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Kind of API associated with a provider.
///
/// Variants are (de)serialized by serde using snake_case names (for example
/// `openai`). `Openai` is the value produced by `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ApiType {
/// Default variant. NOTE(review): presumably covers OpenAI-compatible APIs — confirm against callers.
#[default]
Openai,
/// NOTE(review): presumably the Anthropic API — confirm against callers.
Anthropic,
/// NOTE(review): presumably the Google API — confirm against callers.
Google,
/// NOTE(review): presumably any API not covered by the variants above — confirm against callers.
Other,
}
/// Description of a single model provider, (de)serializable with serde.
///
/// Fields annotated with `#[serde(default)]` may be omitted from the input and
/// fall back to their type's `Default`; fields without that attribute carry no
/// serde default in this definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderSpec {
/// Provider identifier. Compared exactly by `ProviderRegistry::find_by_name`
/// and used by `label()` as the fallback label.
pub name: String,
/// API kind; `ApiType::Openai` when omitted.
#[serde(default)]
pub api_type: ApiType,
/// Substrings that `ProviderRegistry::find_by_model` looks for in the
/// lowercased model name to select this provider.
pub keywords: Vec<String>,
/// NOTE(review): presumably the name of the environment variable holding the
/// provider's API key — not used in this file, confirm against callers.
pub env_key: String,
/// Human-readable name. `label()` prefers it over `name` and falls back to
/// `name` when this is empty.
pub display_name: String,
/// Optional default model. Read it through `default_model()`, which trims the
/// value and treats a blank string as absent.
#[serde(default)]
pub default_model: Option<String>,
/// NOTE(review): presumably a prefix applied to model names for LiteLLM-style
/// routing — not used in this file, confirm against callers.
pub litellm_prefix: String,
/// NOTE(review): presumably model-name prefixes for which `litellm_prefix`
/// is not applied — not used in this file, confirm against callers.
pub skip_prefixes: Vec<String>,
/// Pairs of strings. NOTE(review): presumably extra environment variables as
/// `(name, value)` — not used in this file, confirm against callers.
pub env_extras: Vec<(String, String)>,
/// NOTE(review): presumably the default base URL of the provider's API —
/// not used in this file, confirm against callers.
pub default_api_base: String,
/// Prompt-caching flag; `false` when omitted.
#[serde(default)]
pub supports_prompt_caching: bool,
/// List of model names; empty when omitted.
#[serde(default)]
pub models: Vec<String>,
/// Pairs of a string and a map of arbitrary JSON values. NOTE(review):
/// presumably per-model setting overrides keyed by model name or pattern —
/// not used in this file, confirm against callers.
pub model_overrides: Vec<(String, HashMap<String, serde_json::Value>)>,
}
impl ProviderSpec {
    /// Returns a human-readable label for this provider.
    ///
    /// `display_name` is returned unchanged when it contains at least one
    /// non-whitespace character. Otherwise the label is `name` with its first
    /// character upper-cased; an empty `name` yields an empty string.
    pub fn label(&self) -> String {
        // A whitespace-only display name would render as a blank label, so it
        // is treated like a missing one (mirrors how `default_model()` treats
        // blank values as absent).
        if !self.display_name.trim().is_empty() {
            return self.display_name.clone();
        }

        let mut chars = self.name.chars();
        match chars.next() {
            // `to_uppercase` can expand to several chars (e.g. 'ß' -> "SS"),
            // so chain its output with the untouched remainder of the name.
            Some(first) => first.to_uppercase().chain(chars).collect(),
            None => String::new(),
        }
    }

    /// Returns the configured default model, trimmed of surrounding whitespace.
    ///
    /// Returns `None` when no default model is set or when it is blank.
    pub fn default_model(&self) -> Option<&str> {
        self.default_model
            .as_deref()
            .map(str::trim)
            .filter(|value| !value.is_empty())
    }
}
/// Collection of known [`ProviderSpec`] entries that can be looked up by
/// provider name or by a keyword contained in a model name.
#[derive(Debug, Clone)]
pub struct ProviderRegistry {
    /// Registered specs. Lookups scan this list front to back and return the
    /// first match, so earlier entries take precedence.
    providers: Vec<ProviderSpec>,
}
impl ProviderRegistry {
    /// Creates a registry populated from the provider table embedded in the
    /// binary at compile time.
    ///
    /// # Panics
    ///
    /// Panics if the embedded `providers.yaml` cannot be deserialized into a
    /// list of [`ProviderSpec`]. The file is compiled in, so this indicates a
    /// build-time defect rather than a runtime condition.
    pub fn new() -> Self {
        Self {
            providers: Self::default_providers(),
        }
    }

    /// Returns every registered provider, in registration order.
    pub fn all(&self) -> &[ProviderSpec] {
        &self.providers
    }

    /// Returns the first provider with a keyword contained in `model`.
    ///
    /// Matching is case-insensitive. Empty keywords are ignored. Returns
    /// `None` when no provider matches.
    pub fn find_by_model(&self, model: &str) -> Option<&ProviderSpec> {
        let model_lower = model.to_lowercase();
        self.providers.iter().find(|spec| {
            spec.keywords.iter().any(|kw| {
                // The empty string is a substring of everything, so an empty
                // keyword would make its provider match every model.
                // Lowercasing the keyword as well keeps the comparison
                // case-insensitive on both sides; otherwise a keyword with an
                // uppercase letter could never match the lowercased model.
                !kw.is_empty() && model_lower.contains(kw.to_lowercase().as_str())
            })
        })
    }

    /// Returns the provider whose `name` equals `name` exactly
    /// (case-sensitive), or `None` if there is none.
    pub fn find_by_name(&self, name: &str) -> Option<&ProviderSpec> {
        self.providers.iter().find(|spec| spec.name == name)
    }

    /// Parses the `providers.yaml` file embedded next to this source file.
    ///
    /// Panics when the embedded YAML is not a valid list of [`ProviderSpec`].
    fn default_providers() -> Vec<ProviderSpec> {
        let yaml = include_str!("providers.yaml");
        serde_yaml::from_str(yaml).expect("Failed to parse default providers configuration")
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each model name must resolve to the provider whose keyword it contains,
    /// regardless of the model name's letter case.
    #[test]
    fn test_find_by_model() {
        let registry = ProviderRegistry::new();
        let expectations = [
            ("claude-3-opus", "anthropic"),
            ("deepseek-chat", "deepseek"),
            ("qwen-max", "dashscope"),
            ("MiniMax-M2.1", "minimax"),
        ];

        for &(model, provider) in &expectations {
            let resolved = registry
                .find_by_model(model)
                .map(|spec| spec.name.as_str());
            assert_eq!(resolved, Some(provider));
        }
    }

    /// Looking a provider up by its exact name returns the matching spec.
    #[test]
    fn test_find_by_name() {
        let registry = ProviderRegistry::new();
        let display_name = registry
            .find_by_name("anthropic")
            .map(|spec| spec.display_name.as_str());
        assert_eq!(display_name, Some("Anthropic"));
    }
}