use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use std::sync::LazyLock;
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCost {
pub input: f64,
pub output: f64,
#[serde(default)]
pub cache_read: Option<f64>,
#[serde(default)]
pub cache_write: Option<f64>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelLimit {
pub context: u64,
pub output: u64,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelModalities {
pub input: Vec<String>,
pub output: Vec<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCapabilities {
#[serde(default)]
pub attachment: bool,
#[serde(default)]
pub reasoning: bool,
#[serde(default)]
pub temperature: bool,
#[serde(default)]
pub tool_call: bool,
#[serde(default = "default_true")]
pub streaming: bool,
}
fn default_true() -> bool {
true
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelDescriptor {
pub id: String,
pub name: String,
#[serde(default)]
pub family: Option<String>,
#[serde(default)]
pub release_date: Option<String>,
#[serde(default)]
pub cost: Option<ModelCost>,
#[serde(default)]
pub limit: Option<ModelLimit>,
#[serde(default)]
pub modalities: Option<ModelModalities>,
#[serde(default)]
pub capabilities: Option<ModelCapabilities>,
#[serde(default)]
pub knowledge: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ServiceDescriptor {
pub id: String,
#[serde(rename = "display_name")]
pub display_name: String,
pub description: String,
pub category: String,
pub family: String,
#[serde(rename = "auth_mode")]
pub auth_mode: String,
#[serde(rename = "key_var", skip_serializing_if = "String::is_empty", default)]
pub key_var: String,
#[serde(
rename = "literal_auth_token",
skip_serializing_if = "String::is_empty",
default
)]
pub literal_auth_token: String,
#[serde(rename = "base_url")]
pub base_url: String,
#[serde(rename = "default_model")]
pub default_model: String,
#[serde(rename = "model_tiers", default)]
pub model_tiers: HashMap<String, String>,
#[serde(rename = "model_choices", default)]
pub model_choices: Vec<ModelChoice>,
#[serde(rename = "test_url")]
pub test_url: String,
#[serde(default)]
pub setup: Vec<String>,
#[serde(default)]
pub usage: Vec<String>,
#[serde(default)]
pub api_base_url: Option<String>,
#[serde(default)]
pub npm_package: Option<String>,
#[serde(default)]
pub doc_url: Option<String>,
#[serde(default)]
pub models: Vec<ModelDescriptor>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelChoice {
pub id: String,
pub description: String,
}
#[derive(Debug, Deserialize)]
struct IndexPayload {
providers: Vec<ServiceDescriptor>,
}
#[derive(Debug, Clone)]
pub struct ProviderIndex {
entries: Vec<ServiceDescriptor>,
index: HashMap<String, usize>,
}
impl ProviderIndex {
fn from_payload(payload: IndexPayload) -> Self {
let index: HashMap<String, usize> = payload
.providers
.iter()
.enumerate()
.map(|(i, p)| (p.id.clone(), i))
.collect();
Self {
entries: payload.providers,
index,
}
}
pub fn embedded() -> &'static ProviderIndex {
&EMBEDDED
}
pub fn all(&self) -> &[ServiceDescriptor] {
&self.entries
}
pub fn ids(&self) -> Vec<String> {
self.entries.iter().map(|p| p.id.clone()).collect()
}
pub fn get(&self, id: &str) -> Option<&ServiceDescriptor> {
self.index.get(id).map(|&i| &self.entries[i])
}
pub fn categories(&self) -> Vec<String> {
self.entries
.iter()
.scan(HashSet::new(), |seen, p| {
Some(if seen.insert(p.category.clone()) {
Some(p.category.clone())
} else {
None
})
})
.flatten()
.collect()
}
pub fn providers_by_category(&self, category: &str) -> Vec<&ServiceDescriptor> {
self.entries
.iter()
.filter(|p| p.category == category)
.collect()
}
pub fn builtin_secret_keys(&self) -> HashSet<String> {
self.entries
.iter()
.filter(|p| !p.key_var.is_empty())
.map(|p| p.key_var.clone())
.collect()
}
pub fn models_for(&self, provider_id: &str) -> &[ModelDescriptor] {
self.get(provider_id)
.map(|p| p.models.as_slice())
.unwrap_or(&[])
}
pub fn find_model(&self, model_id: &str) -> Option<(&ServiceDescriptor, &ModelDescriptor)> {
self.entries
.iter()
.find_map(|p| p.models.iter().find(|m| m.id == model_id).map(|m| (p, m)))
}
pub fn estimate_cost(
&self,
model_id: &str,
prompt_tokens: u32,
completion_tokens: u32,
) -> Option<f64> {
let (_, model) = self.find_model(model_id)?;
let cost = model.cost.as_ref()?;
Some(
cost.input * prompt_tokens as f64 / 1_000_000.0
+ cost.output * completion_tokens as f64 / 1_000_000.0,
)
}
}
static EMBEDDED: LazyLock<ProviderIndex> = LazyLock::new(|| {
let raw = include_str!("catalog.json");
let payload: IndexPayload = serde_json::from_str(raw).expect("catalog.json is valid");
ProviderIndex::from_payload(payload)
});
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embedded_loads() {
let catalog = ProviderIndex::embedded();
assert!(!catalog.all().is_empty());
}
#[test]
fn test_get_known_provider() {
let catalog = ProviderIndex::embedded();
let p = catalog.get("zai").expect("zai should exist");
assert_eq!(p.id, "zai");
assert!(!p.base_url.is_empty());
assert!(!p.default_model.is_empty());
}
#[test]
fn test_get_unknown_returns_none() {
let catalog = ProviderIndex::embedded();
assert!(catalog.get("nonexistent_provider_xyz").is_none());
}
#[test]
fn test_categories_no_duplicates() {
let catalog = ProviderIndex::embedded();
let cats = catalog.categories();
let mut seen = HashSet::new();
for c in &cats {
assert!(seen.insert(c.clone()), "duplicate category: {}", c);
}
}
#[test]
fn test_builtin_secret_keys() {
let catalog = ProviderIndex::embedded();
let keys = catalog.builtin_secret_keys();
assert!(!keys.is_empty(), "should contain at least one secret key");
assert!(
keys.contains("ZAI_API_KEY"),
"should contain ZAI_API_KEY, got: {:?}",
keys
);
}
#[test]
fn test_providers_by_category() {
let catalog = ProviderIndex::embedded();
let cats = catalog.categories();
if let Some(cat) = cats.first() {
let providers = catalog.providers_by_category(cat);
assert!(!providers.is_empty());
for p in &providers {
assert_eq!(p.category, *cat);
}
}
}
#[test]
fn test_models_for_provider() {
let catalog = ProviderIndex::embedded();
let models = catalog.models_for("zai");
assert!(!models.is_empty(), "zai should have models");
assert!(!models[0].id.is_empty());
}
#[test]
fn test_models_for_unknown_provider() {
let catalog = ProviderIndex::embedded();
let models = catalog.models_for("nonexistent_provider_xyz");
assert!(models.is_empty());
}
#[test]
fn test_find_model() {
let catalog = ProviderIndex::embedded();
let (provider, model) = catalog.find_model("glm-5").expect("glm-5 should be found");
assert_eq!(model.id, "glm-5");
assert!(
provider.id == "zai" || provider.id == "zai-cn",
"glm-5 should belong to a Z.AI provider, got: {}",
provider.id
);
}
#[test]
fn test_find_model_unknown() {
let catalog = ProviderIndex::embedded();
assert!(catalog.find_model("nonexistent-model-xyz").is_none());
}
#[test]
fn test_model_has_pricing() {
let catalog = ProviderIndex::embedded();
let (_, model) = catalog.find_model("glm-5").expect("glm-5 should exist");
let cost = model.cost.as_ref().expect("glm-5 should have cost");
assert!(cost.input > 0.0, "input cost should be positive");
assert!(cost.output > 0.0, "output cost should be positive");
}
}