use std::collections::HashMap;
use serde::Deserialize;
use tracing::warn;
/// Lookup table from a collet provider identifier (left) to the
/// `litellm_provider` values (right) whose registry entries belong to it.
///
/// One collet provider can cover several registry provider names
/// (for example `azure-openai` covers both `azure` and `azure_ai`), and two
/// collet providers can share the same registry name (`ollama` and
/// `lm-studio` both map to `ollama`).
const PROVIDER_NAME_MAP: &[(&str, &[&str])] = &[
    ("zai-coding", &["zai"]),
    ("openai", &["openai"]),
    ("deepseek", &["deepseek"]),
    ("anthropic", &["anthropic"]),
    ("google-gemini", &["gemini"]),
    ("groq", &["groq"]),
    ("mistral", &["mistral"]),
    ("together-ai", &["together_ai"]),
    ("openrouter", &["openrouter"]),
    ("ollama", &["ollama"]),
    ("lm-studio", &["ollama"]),
    ("azure-openai", &["azure", "azure_ai"]),
    ("bedrock", &["bedrock", "bedrock_converse"]),
];
/// One model entry of the registry, in the normalized form exposed to callers.
#[derive(Debug, Clone)]
pub struct RegistryModel {
    /// Key of the entry in the source JSON object (e.g. `openai/gpt-4o`).
    pub key: String,
    /// Value of the entry's `litellm_provider` field.
    pub litellm_provider: String,
    /// Maximum number of input tokens, when the entry declares one.
    pub max_input_tokens: Option<u32>,
    /// Maximum number of output tokens, when the entry declares one.
    pub max_output_tokens: Option<u32>,
    /// Whether the entry is flagged as supporting function calling.
    pub supports_function_calling: bool,
    /// Whether the entry is flagged as supporting vision input.
    pub supports_vision: bool,
    /// Input cost per one million tokens, when known.
    pub input_cost_per_million: Option<f64>,
    /// Output cost per one million tokens, when known.
    pub output_cost_per_million: Option<f64>,
}
/// Wire-level shape of a single registry entry as it is read from JSON.
///
/// Every field is optional. The three token-limit fields are read through
/// `lenient_u32::deserialize` and fall back to their default (`None`) when
/// the field is absent.
#[derive(Deserialize)]
struct RawEntry {
    /// Provider name of the entry, if present.
    litellm_provider: Option<String>,

    /// Declared input-token limit.
    #[serde(default, deserialize_with = "lenient_u32::deserialize")]
    max_input_tokens: Option<u32>,

    /// Declared output-token limit.
    #[serde(default, deserialize_with = "lenient_u32::deserialize")]
    max_output_tokens: Option<u32>,

    /// Generic token limit carried by the `max_tokens` field.
    #[serde(default, deserialize_with = "lenient_u32::deserialize")]
    max_tokens: Option<u32>,

    /// Function-calling capability flag.
    supports_function_calling: Option<bool>,

    /// Vision capability flag.
    supports_vision: Option<bool>,

    /// Cost of a single input token.
    input_cost_per_token: Option<f64>,

    /// Cost of a single output token.
    output_cost_per_token: Option<f64>,
}
/// Serde `deserialize_with` helper for token-limit fields.
///
/// Yields `Some(n)` only when the JSON value is a number that fits in a
/// `u32`; strings, `null`, and numbers that cannot be represented all yield
/// `None` instead of failing the surrounding document.
mod lenient_u32 {
    use serde::Deserializer;

    /// Deserializes an optional `u32`, mapping unusable values to `None`.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error for value kinds the visitor does not
    /// handle (anything other than integers, floats, strings, and null/unit).
    pub fn deserialize<'de, D>(d: D) -> Result<Option<u32>, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct Visitor;

        impl serde::de::Visitor<'_> for Visitor {
            type Value = Option<u32>;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str("u32 or string")
            }

            fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
                // A checked conversion: `v as u32` would wrap values above
                // `u32::MAX` into an unrelated small number.
                Ok(u32::try_from(v).ok())
            }

            fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Self::Value, E> {
                // Negative and oversized values are not representable.
                Ok(u32::try_from(v).ok())
            }

            fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
                // Only accept values inside the `u32` range. A bare `as` cast
                // saturates, turning negatives/NaN into 0 and huge values
                // into `u32::MAX`. NaN fails the range test as well.
                let representable = (0.0..=f64::from(u32::MAX)).contains(&v);
                // In-range cast: only the fractional part is dropped.
                Ok(if representable { Some(v as u32) } else { None })
            }

            fn visit_str<E: serde::de::Error>(self, _: &str) -> Result<Self::Value, E> {
                // Textual values carry no usable limit.
                Ok(None)
            }

            fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
                Ok(None)
            }

            fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
                Ok(None)
            }
        }

        d.deserialize_any(Visitor)
    }
}
/// In-memory collection of [`RegistryModel`] entries.
///
/// Instances are created with `ModelRegistry::parse` (entries ordered by
/// ascending key) or `ModelRegistry::empty`.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// All known entries.
    models: Vec<RegistryModel>,
}
impl ModelRegistry {
    /// Builds a registry from the bytes of a JSON object that maps model keys
    /// to entries.
    ///
    /// The `sample_spec` key and every entry without a `litellm_provider` are
    /// skipped. Per-token costs are converted to per-million-token costs, and
    /// `max_output_tokens` falls back to `max_tokens` when absent. The
    /// resulting entries are sorted by key.
    ///
    /// Parsing never fails: invalid JSON is logged with `warn!` and produces
    /// an empty registry.
    pub fn parse(data: &[u8]) -> Self {
        let raw: HashMap<String, RawEntry> = match serde_json::from_slice(data) {
            Ok(map) => map,
            Err(e) => {
                warn!("failed to parse model registry JSON: {e}");
                return Self::empty();
            }
        };

        let mut models: Vec<RegistryModel> = raw
            .into_iter()
            .filter(|(key, _)| key != "sample_spec")
            .filter_map(|(key, entry)| {
                // Entries without a provider cannot be attributed; drop them.
                let provider = entry.litellm_provider?;
                Some(RegistryModel {
                    key,
                    litellm_provider: provider,
                    max_input_tokens: entry.max_input_tokens,
                    max_output_tokens: entry.max_output_tokens.or(entry.max_tokens),
                    supports_function_calling: entry.supports_function_calling.unwrap_or(false),
                    supports_vision: entry.supports_vision.unwrap_or(false),
                    input_cost_per_million: entry.input_cost_per_token.map(|c| c * 1_000_000.0),
                    output_cost_per_million: entry.output_cost_per_token.map(|c| c * 1_000_000.0),
                })
            })
            .collect();

        // Keys come from a map and are therefore unique, so an unstable sort
        // yields the same order as a stable one.
        models.sort_unstable_by(|a, b| a.key.cmp(&b.key));
        Self { models }
    }

    /// Creates a registry that contains no models.
    pub fn empty() -> Self {
        Self { models: Vec::new() }
    }

    /// Returns `true` when the registry holds no models.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }

    /// Returns the number of models in the registry.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// Returns every entry attributed to `collet_provider`, in registry
    /// order, without any de-duplication or filtering.
    ///
    /// The provider is translated through `resolve_provider`; when no
    /// translation exists, `collet_provider` is compared against
    /// `litellm_provider` directly.
    fn provider_entries(&self, collet_provider: &str) -> Vec<&RegistryModel> {
        let litellm_names = resolve_provider(collet_provider);
        self.models
            .iter()
            .filter(|m| {
                if litellm_names.is_empty() {
                    m.litellm_provider == collet_provider
                } else {
                    litellm_names.iter().any(|n| *n == m.litellm_provider)
                }
            })
            .collect()
    }

    /// Returns the models to list for `collet_provider`.
    ///
    /// Entries whose display name (see [`ModelRegistry::model_name`]) starts
    /// with `ft:` are omitted, and only the first entry for each display name
    /// is kept.
    pub fn models_for_provider(&self, collet_provider: &str) -> Vec<&RegistryModel> {
        // Display names borrow from the registry, so the set needs no
        // per-entry allocation.
        let mut seen = std::collections::HashSet::new();
        let mut listed = Vec::new();
        for model in self.provider_entries(collet_provider) {
            let display = Self::model_name(&model.key);
            if !display.starts_with("ft:") && seen.insert(display) {
                listed.push(model);
            }
        }
        listed
    }

    /// Returns the part of `key` after its last `/`, or the whole key when it
    /// contains no `/`.
    pub fn model_name(key: &str) -> &str {
        key.rsplit_once('/').map_or(key, |(_, name)| name)
    }

    /// Looks up a model of `collet_provider` by name.
    ///
    /// Matching is attempted in order of decreasing strictness:
    ///
    /// 1. a key equal to `model_name`;
    /// 2. a key ending in `/{model_name}`;
    /// 3. a listed model (see [`ModelRegistry::models_for_provider`]) whose
    ///    display name contains `model_name` or is contained in it.
    ///
    /// Stages 1 and 2 consider every entry of the provider, so an entry that
    /// is hidden from the de-duplicated listing can still be found by its
    /// exact key. An empty `model_name` never matches, and neither does an
    /// empty display name in stage 3, because an empty string is a substring
    /// of everything.
    pub fn find_model(&self, collet_provider: &str, model_name: &str) -> Option<&RegistryModel> {
        if model_name.is_empty() {
            return None;
        }

        let all = self.provider_entries(collet_provider);
        let direct = all
            .iter()
            .copied()
            .find(|m| m.key == model_name)
            .or_else(|| {
                let suffix = format!("/{model_name}");
                all.iter()
                    .copied()
                    .find(|m| m.key.ends_with(suffix.as_str()))
            });
        if direct.is_some() {
            return direct;
        }

        self.models_for_provider(collet_provider)
            .into_iter()
            .find(|m| {
                let display = Self::model_name(&m.key);
                !display.is_empty()
                    && (display.contains(model_name) || model_name.contains(display))
            })
    }

    /// Returns the distinct `litellm_provider` values present in the
    /// registry, sorted ascending.
    pub fn providers(&self) -> Vec<String> {
        self.models
            .iter()
            .map(|m| m.litellm_provider.clone())
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect()
    }
}
/// Translates a collet provider identifier into the registry provider names
/// listed for it in `PROVIDER_NAME_MAP`.
///
/// Returns an empty vector when the identifier has no entry in the table.
fn resolve_provider(collet_name: &str) -> Vec<&'static str> {
    PROVIDER_NAME_MAP
        .iter()
        .find(|(alias, _)| *alias == collet_name)
        .map(|(_, names)| names.to_vec())
        .unwrap_or_default()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON fixture with two usable entries, the `sample_spec` placeholder,
    /// and one entry that has no provider.
    fn sample_json() -> &'static [u8] {
        br#"{
"sample_spec": {
"litellm_provider": "sample",
"max_input_tokens": 100
},
"deepseek/deepseek-chat": {
"litellm_provider": "deepseek",
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"supports_function_calling": true,
"supports_vision": false,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000008
},
"openai/gpt-4o": {
"litellm_provider": "openai",
"max_input_tokens": 128000,
"max_tokens": 4096,
"supports_function_calling": true,
"supports_vision": true,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015
},
"no-provider-model": {
"max_input_tokens": 1000
}
}"#
    }

    /// Parses the shared fixture.
    fn registry() -> ModelRegistry {
        ModelRegistry::parse(sample_json())
    }

    /// Fetches the entry stored under `key`, panicking when it is missing.
    fn entry<'a>(reg: &'a ModelRegistry, key: &str) -> &'a RegistryModel {
        reg.models.iter().find(|m| m.key == key).unwrap()
    }

    #[test]
    fn parse_skips_sample_spec_and_missing_provider() {
        let reg = registry();
        assert_eq!(reg.models.len(), 2);
        for skipped in ["sample_spec", "no-provider-model"] {
            assert!(!reg.models.iter().any(|m| m.key == skipped));
        }
    }

    #[test]
    fn parse_sorts_by_key() {
        let reg = registry();
        let expected = ["deepseek/deepseek-chat", "openai/gpt-4o"];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(reg.models[idx].key, *want);
        }
    }

    #[test]
    fn cost_conversion() {
        let reg = registry();
        let ds = entry(&reg, "deepseek/deepseek-chat");

        let input = ds.input_cost_per_million.unwrap();
        assert!((input - 3.0).abs() < 1e-9, "expected 3.0, got {input}");

        let output = ds.output_cost_per_million.unwrap();
        assert!((output - 8.0).abs() < 1e-9, "expected 8.0, got {output}");
    }

    #[test]
    fn max_tokens_fallback() {
        let reg = registry();
        let gpt = entry(&reg, "openai/gpt-4o");
        assert_eq!(gpt.max_output_tokens, Some(4096));
    }

    #[test]
    fn models_for_provider_mapping() {
        let reg = registry();
        let listed = reg.models_for_provider("deepseek");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].key, "deepseek/deepseek-chat");
    }

    #[test]
    fn models_for_provider_direct_fallback() {
        let reg = registry();
        let listed = reg.models_for_provider("openai");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].key, "openai/gpt-4o");
    }

    #[test]
    fn find_model_exact_key() {
        let reg = registry();
        let found = reg.find_model("deepseek", "deepseek/deepseek-chat");
        assert!(found.is_some());
        assert_eq!(found.unwrap().key, "deepseek/deepseek-chat");
    }

    #[test]
    fn find_model_trailing_path() {
        let reg = registry();
        let found = reg.find_model("deepseek", "deepseek-chat");
        assert!(found.is_some(), "should match via trailing path component");
        assert_eq!(found.unwrap().key, "deepseek/deepseek-chat");
    }

    #[test]
    fn find_model_openai() {
        let reg = registry();
        let found = reg.find_model("openai", "gpt-4o");
        assert!(found.is_some());
        assert_eq!(found.unwrap().key, "openai/gpt-4o");
    }

    #[test]
    fn find_model_unknown_returns_none() {
        let reg = registry();
        assert!(reg.find_model("deepseek", "nonexistent-model").is_none());
    }

    #[test]
    fn model_name_extraction() {
        let cases = [
            ("deepseek/deepseek-chat", "deepseek-chat"),
            ("gpt-4o", "gpt-4o"),
            ("a/b/c", "c"),
        ];
        for (key, want) in cases {
            assert_eq!(ModelRegistry::model_name(key), want);
        }
    }

    #[test]
    fn providers_sorted() {
        let names = registry().providers();
        assert_eq!(names, vec!["deepseek", "openai"]);
    }

    #[test]
    fn empty_registry() {
        let reg = ModelRegistry::empty();
        assert!(reg.is_empty());
        assert!(reg.providers().is_empty());
    }

    #[test]
    fn invalid_json_returns_empty() {
        assert!(ModelRegistry::parse(b"not json").is_empty());
    }
}