1use serde::Deserialize;
2use std::collections::{HashMap, HashSet};
3use std::sync::LazyLock;
4
/// Per-model pricing as read from the catalog.
///
/// [`ProviderIndex::estimate_cost`] treats `input` and `output` as prices per
/// one million tokens. NOTE(review): the currency is not stated in this file —
/// presumably USD; confirm against `catalog.json`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCost {
    /// Price per million prompt (input) tokens. Required in the catalog.
    pub input: f64,
    /// Price per million completion (output) tokens. Required in the catalog.
    pub output: f64,
    /// Presumably the price per million cache-read tokens; `None` when the
    /// catalog omits it. Not used by `estimate_cost`.
    #[serde(default)]
    pub cache_read: Option<f64>,
    /// Presumably the price per million cache-write tokens; `None` when the
    /// catalog omits it. Not used by `estimate_cost`.
    #[serde(default)]
    pub cache_write: Option<f64>,
}
23
/// Size limits advertised for a model.
///
/// Both fields are required when a `limit` object is present.
/// NOTE(review): units are not stated in this file — presumably tokens.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelLimit {
    /// Maximum context window (presumably in tokens).
    pub context: u64,
    /// Maximum output length (presumably in tokens).
    pub output: u64,
}
32
/// Input and output modalities a model supports.
///
/// Both lists are required when a `modalities` object is present.
/// NOTE(review): the values are free-form strings (presumably names such as
/// `"text"` or `"image"`); no validation happens here.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelModalities {
    /// Modalities accepted as input.
    pub input: Vec<String>,
    /// Modalities the model can produce.
    pub output: Vec<String>,
}
41
/// Feature flags for a model.
///
/// Every flag is optional in the catalog. A missing flag deserializes to
/// `false`, except `streaming`, which defaults to `true`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCapabilities {
    /// Whether the model accepts attachments; `false` when omitted.
    #[serde(default)]
    pub attachment: bool,
    /// Whether the model is a reasoning model; `false` when omitted.
    #[serde(default)]
    pub reasoning: bool,
    /// Whether a temperature setting is supported; `false` when omitted.
    #[serde(default)]
    pub temperature: bool,
    /// Whether tool/function calling is supported; `false` when omitted.
    #[serde(default)]
    pub tool_call: bool,
    /// Whether streaming responses are supported; `true` when omitted.
    #[serde(default = "default_true")]
    pub streaming: bool,
}
61
/// Serde `default` hook for boolean fields that should be `true` when the key
/// is absent (serde's built-in `bool` default is `false`).
fn default_true() -> bool {
    true
}
65
/// One model offered by a provider.
///
/// Only `id` and `name` are required; every other field is `None` when the
/// catalog omits it.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelDescriptor {
    /// Model identifier; matched exactly by [`ProviderIndex::find_model`].
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Optional model family label.
    #[serde(default)]
    pub family: Option<String>,
    /// Release date, kept as the unparsed catalog string.
    #[serde(default)]
    pub release_date: Option<String>,
    /// Pricing; `None` means [`ProviderIndex::estimate_cost`] returns `None`.
    #[serde(default)]
    pub cost: Option<ModelCost>,
    /// Context/output limits, if published.
    #[serde(default)]
    pub limit: Option<ModelLimit>,
    /// Supported input/output modalities, if published.
    #[serde(default)]
    pub modalities: Option<ModelModalities>,
    /// Feature flags, if published.
    #[serde(default)]
    pub capabilities: Option<ModelCapabilities>,
    /// NOTE(review): presumably a knowledge-cutoff date string; the format is
    /// not validated here.
    #[serde(default)]
    pub knowledge: Option<String>,
}
95
96#[derive(Debug, Clone, Deserialize)]
102pub struct ServiceDescriptor {
103 pub id: String,
105 #[serde(rename = "display_name")]
107 pub display_name: String,
108 pub description: String,
110 pub category: String,
112 pub family: String,
114 #[serde(rename = "auth_mode")]
116 pub auth_mode: String,
117 #[serde(rename = "key_var", skip_serializing_if = "String::is_empty", default)]
119 pub key_var: String,
120 #[serde(
122 rename = "literal_auth_token",
123 skip_serializing_if = "String::is_empty",
124 default
125 )]
126 pub literal_auth_token: String,
127 #[serde(rename = "base_url")]
129 pub base_url: String,
130 #[serde(rename = "default_model")]
132 pub default_model: String,
133 #[serde(rename = "model_tiers", default)]
135 pub model_tiers: HashMap<String, String>,
136 #[serde(rename = "model_choices", default)]
138 pub model_choices: Vec<ModelChoice>,
139 #[serde(rename = "test_url")]
141 pub test_url: String,
142 #[serde(default)]
144 pub setup: Vec<String>,
145 #[serde(default)]
147 pub usage: Vec<String>,
148
149 #[serde(default)]
152 pub api_base_url: Option<String>,
153 #[serde(default)]
155 pub npm_package: Option<String>,
156 #[serde(default)]
158 pub doc_url: Option<String>,
159 #[serde(default)]
161 pub models: Vec<ModelDescriptor>,
162}
163
/// A selectable model listed in [`ServiceDescriptor::model_choices`].
#[derive(Debug, Clone, Deserialize)]
pub struct ModelChoice {
    /// Model identifier.
    pub id: String,
    /// Human-readable description of the choice.
    pub description: String,
}
173
/// Top-level shape of `catalog.json`: an object whose `providers` key holds
/// the provider list.
#[derive(Debug, Deserialize)]
struct IndexPayload {
    /// Providers in catalog order.
    providers: Vec<ServiceDescriptor>,
}
178
/// Queryable view over the provider catalog.
///
/// Providers are kept in catalog order. An id → position map provides
/// constant-time lookup by provider id.
#[derive(Debug, Clone)]
pub struct ProviderIndex {
    /// Providers in the order they appear in the catalog.
    entries: Vec<ServiceDescriptor>,
    /// Provider id → position in `entries`.
    index: HashMap<String, usize>,
}
192
193impl ProviderIndex {
194 fn from_payload(payload: IndexPayload) -> Self {
195 let index: HashMap<String, usize> = payload
196 .providers
197 .iter()
198 .enumerate()
199 .map(|(i, p)| (p.id.clone(), i))
200 .collect();
201 Self {
202 entries: payload.providers,
203 index,
204 }
205 }
206
207 pub fn embedded() -> &'static ProviderIndex {
209 &EMBEDDED
210 }
211
212 pub fn all(&self) -> &[ServiceDescriptor] {
214 &self.entries
215 }
216
217 pub fn ids(&self) -> Vec<String> {
219 self.entries.iter().map(|p| p.id.clone()).collect()
220 }
221
222 pub fn get(&self, id: &str) -> Option<&ServiceDescriptor> {
224 self.index.get(id).map(|&i| &self.entries[i])
225 }
226
227 pub fn categories(&self) -> Vec<String> {
229 self.entries
230 .iter()
231 .scan(HashSet::new(), |seen, p| {
232 Some(if seen.insert(p.category.clone()) {
233 Some(p.category.clone())
234 } else {
235 None
236 })
237 })
238 .flatten()
239 .collect()
240 }
241
242 pub fn providers_by_category(&self, category: &str) -> Vec<&ServiceDescriptor> {
244 self.entries
245 .iter()
246 .filter(|p| p.category == category)
247 .collect()
248 }
249
250 pub fn builtin_secret_keys(&self) -> HashSet<String> {
252 self.entries
253 .iter()
254 .filter(|p| !p.key_var.is_empty())
255 .map(|p| p.key_var.clone())
256 .collect()
257 }
258
259 pub fn models_for(&self, provider_id: &str) -> &[ModelDescriptor] {
261 self.get(provider_id)
262 .map(|p| p.models.as_slice())
263 .unwrap_or(&[])
264 }
265
266 pub fn find_model(&self, model_id: &str) -> Option<(&ServiceDescriptor, &ModelDescriptor)> {
269 self.entries
270 .iter()
271 .find_map(|p| p.models.iter().find(|m| m.id == model_id).map(|m| (p, m)))
272 }
273
274 pub fn estimate_cost(
281 &self,
282 model_id: &str,
283 prompt_tokens: u32,
284 completion_tokens: u32,
285 ) -> Option<f64> {
286 let (_, model) = self.find_model(model_id)?;
287 let cost = model.cost.as_ref()?;
288 Some(
289 cost.input * prompt_tokens as f64 / 1_000_000.0
290 + cost.output * completion_tokens as f64 / 1_000_000.0,
291 )
292 }
293}
294
295static EMBEDDED: LazyLock<ProviderIndex> = LazyLock::new(|| {
297 let raw = include_str!("catalog.json");
298 let payload: IndexPayload = serde_json::from_str(raw).expect("catalog.json is valid");
299 ProviderIndex::from_payload(payload)
300});
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative tolerance for comparing computed prices.
    const PRICE_EPSILON: f64 = 1e-9;

    #[test]
    fn test_embedded_loads() {
        let catalog = ProviderIndex::embedded();
        assert!(!catalog.all().is_empty());
    }

    #[test]
    fn test_get_known_provider() {
        let catalog = ProviderIndex::embedded();
        let p = catalog.get("zai").expect("zai should exist");
        assert_eq!(p.id, "zai");
        assert!(!p.base_url.is_empty());
        assert!(!p.default_model.is_empty());
    }

    #[test]
    fn test_get_unknown_returns_none() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.get("nonexistent_provider_xyz").is_none());
    }

    #[test]
    fn test_every_id_resolves_via_get() {
        let catalog = ProviderIndex::embedded();
        for id in catalog.ids() {
            let p = catalog.get(&id).expect("every listed id should resolve");
            assert_eq!(p.id, id);
        }
    }

    #[test]
    fn test_categories_no_duplicates() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        let mut seen = HashSet::new();
        for c in &cats {
            assert!(seen.insert(c.clone()), "duplicate category: {}", c);
        }
    }

    #[test]
    fn test_categories_follow_catalog_order_and_cover_all() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        // First-appearance order: the first category is the first provider's.
        assert_eq!(cats.first(), catalog.all().first().map(|p| &p.category));
        // Distinct categories partition the providers.
        let total: usize = cats
            .iter()
            .map(|c| catalog.providers_by_category(c).len())
            .sum();
        assert_eq!(total, catalog.all().len());
    }

    #[test]
    fn test_builtin_secret_keys() {
        let catalog = ProviderIndex::embedded();
        let keys = catalog.builtin_secret_keys();
        assert!(!keys.is_empty(), "should contain at least one secret key");
        assert!(
            keys.contains("ZAI_API_KEY"),
            "should contain ZAI_API_KEY, got: {:?}",
            keys
        );
    }

    #[test]
    fn test_providers_by_category() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        // Guard against a vacuous pass when no categories are returned.
        assert!(!cats.is_empty(), "catalog should expose at least one category");
        for cat in &cats {
            let providers = catalog.providers_by_category(cat);
            assert!(!providers.is_empty(), "category {} has no providers", cat);
            for p in &providers {
                assert_eq!(p.category, *cat);
            }
        }
    }

    #[test]
    fn test_models_for_provider() {
        let catalog = ProviderIndex::embedded();
        let models = catalog.models_for("zai");
        assert!(!models.is_empty(), "zai should have models");
        assert!(!models[0].id.is_empty());
    }

    #[test]
    fn test_models_for_unknown_provider() {
        let catalog = ProviderIndex::embedded();
        let models = catalog.models_for("nonexistent_provider_xyz");
        assert!(models.is_empty());
    }

    #[test]
    fn test_find_model() {
        let catalog = ProviderIndex::embedded();
        let (provider, model) = catalog.find_model("glm-5").expect("glm-5 should be found");
        assert_eq!(model.id, "glm-5");
        assert!(
            provider.id == "zai" || provider.id == "zai-cn",
            "glm-5 should belong to a Z.AI provider, got: {}",
            provider.id
        );
    }

    #[test]
    fn test_find_model_unknown() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.find_model("nonexistent-model-xyz").is_none());
    }

    #[test]
    fn test_model_has_pricing() {
        let catalog = ProviderIndex::embedded();
        let (_, model) = catalog.find_model("glm-5").expect("glm-5 should exist");
        let cost = model.cost.as_ref().expect("glm-5 should have cost");
        assert!(cost.input > 0.0, "input cost should be positive");
        assert!(cost.output > 0.0, "output cost should be positive");
    }

    #[test]
    fn test_estimate_cost_matches_pricing() {
        let catalog = ProviderIndex::embedded();
        let (_, model) = catalog.find_model("glm-5").expect("glm-5 should exist");
        let cost = model.cost.as_ref().expect("glm-5 should have cost");
        let estimate = catalog
            .estimate_cost("glm-5", 1_000_000, 2_000_000)
            .expect("glm-5 should be priced");
        // Prices are per million tokens: 1M prompt + 2M completion tokens.
        let expected = cost.input + 2.0 * cost.output;
        assert!(
            (estimate - expected).abs() <= PRICE_EPSILON * expected.max(1.0),
            "expected {}, got {}",
            expected,
            estimate
        );
    }

    #[test]
    fn test_estimate_cost_zero_tokens_is_free() {
        let catalog = ProviderIndex::embedded();
        assert_eq!(catalog.estimate_cost("glm-5", 0, 0), Some(0.0));
    }

    #[test]
    fn test_estimate_cost_unknown_model() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.estimate_cost("nonexistent-model-xyz", 1, 1).is_none());
    }
}