1use serde::Deserialize;
2use std::collections::{HashMap, HashSet};
3use std::sync::LazyLock;
4
/// Pricing figures attached to a catalog model.
///
/// NOTE(review): the unit is not established in this file — presumably a price per million
/// tokens as in the upstream models catalog; confirm against `catalog.json`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCost {
    /// Input price (required in the catalog entry).
    pub input: f64,
    /// Output price (required in the catalog entry).
    pub output: f64,
    /// Optional cache-read price; `None` when the catalog omits it.
    #[serde(default)]
    pub cache_read: Option<f64>,
    /// Optional cache-write price; `None` when the catalog omits it.
    #[serde(default)]
    pub cache_write: Option<f64>,
}
19
/// Size limits published for a model.
///
/// NOTE(review): presumably both values are token counts — confirm against `catalog.json`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelLimit {
    /// Maximum context size.
    pub context: u64,
    /// Maximum output size.
    pub output: u64,
}
26
/// Modalities a model accepts and produces, as free-form strings from the catalog.
///
/// NOTE(review): presumably values such as `"text"` or `"image"` — confirm against the catalog.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelModalities {
    /// Modalities accepted as input.
    pub input: Vec<String>,
    /// Modalities produced as output.
    pub output: Vec<String>,
}
33
/// Feature flags advertised for a model.
///
/// A flag missing from the catalog deserializes as `false`, except `streaming`, which
/// defaults to `true`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCapabilities {
    /// Whether the model accepts attachments.
    #[serde(default)]
    pub attachment: bool,
    /// Whether the model is flagged as a reasoning model.
    #[serde(default)]
    pub reasoning: bool,
    /// Whether the model accepts a temperature setting.
    #[serde(default)]
    pub temperature: bool,
    /// Whether the model supports tool calls.
    #[serde(default)]
    pub tool_call: bool,
    /// Whether the model supports streaming; assumed `true` when the catalog is silent.
    #[serde(default = "default_true")]
    pub streaming: bool,
}
48
/// Serde default helper returning `true`.
///
/// Referenced by `#[serde(default = "default_true")]` on `ModelCapabilities::streaming`, because
/// `bool::default()` would be `false`.
fn default_true() -> bool {
    true
}
52
/// One model offered by a provider in the catalog.
///
/// Only `id` and `name` are required; every other field becomes `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelDescriptor {
    /// Model identifier, matched exactly by `ProviderIndex::find_model`.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Optional model family label.
    #[serde(default)]
    pub family: Option<String>,
    /// Optional release date, kept as the raw catalog string (format not validated here).
    #[serde(default)]
    pub release_date: Option<String>,
    /// Optional pricing information.
    #[serde(default)]
    pub cost: Option<ModelCost>,
    /// Optional size limits.
    #[serde(default)]
    pub limit: Option<ModelLimit>,
    /// Optional input/output modalities.
    #[serde(default)]
    pub modalities: Option<ModelModalities>,
    /// Optional capability flags.
    #[serde(default)]
    pub capabilities: Option<ModelCapabilities>,
    /// Optional knowledge-cutoff string, kept as provided by the catalog.
    #[serde(default)]
    pub knowledge: Option<String>,
}
73
/// A provider (service) entry in the catalog: connection, auth and model metadata.
///
/// Required keys: `id`, `display_name`, `description`, `category`, `family`, `auth_mode`,
/// `base_url`, `default_model` and `test_url`. Everything else falls back to an empty value
/// (`""`, empty map/vec) or `None` when absent.
///
/// NOTE(review): only `Deserialize` is derived, so the `skip_serializing_if` attributes below
/// currently have no effect. The explicit `rename`s that equal the field names are redundant
/// unless a container-level `rename_all` exists outside this view — confirm before removing.
#[derive(Debug, Clone, Deserialize)]
pub struct ServiceDescriptor {
    /// Unique provider id, used as the lookup key by `ProviderIndex::get`.
    pub id: String,
    /// Human-readable provider name.
    #[serde(rename = "display_name")]
    pub display_name: String,
    /// Short provider description.
    pub description: String,
    /// Grouping label; see `ProviderIndex::categories`.
    pub category: String,
    /// Provider family label.
    pub family: String,
    /// Authentication mode, kept as the raw catalog string.
    #[serde(rename = "auth_mode")]
    pub auth_mode: String,
    /// Name of the secret holding the API key (e.g. `ZAI_API_KEY`); empty when not applicable.
    #[serde(rename = "key_var", skip_serializing_if = "String::is_empty", default)]
    pub key_var: String,
    /// Fixed auth token supplied by the catalog itself; empty when not applicable.
    #[serde(
        rename = "literal_auth_token",
        skip_serializing_if = "String::is_empty",
        default
    )]
    pub literal_auth_token: String,
    /// Base URL for requests to this provider.
    #[serde(rename = "base_url")]
    pub base_url: String,
    /// Model id used when none is chosen explicitly.
    #[serde(rename = "default_model")]
    pub default_model: String,
    /// Tier name -> model id mapping; empty when absent.
    #[serde(rename = "model_tiers", default)]
    pub model_tiers: HashMap<String, String>,
    /// Curated model choices with descriptions; empty when absent.
    #[serde(rename = "model_choices", default)]
    pub model_choices: Vec<ModelChoice>,
    /// URL used to test connectivity/credentials.
    #[serde(rename = "test_url")]
    pub test_url: String,
    /// Setup instructions, one entry per line/step; empty when absent.
    #[serde(default)]
    pub setup: Vec<String>,
    /// Usage notes, one entry per line/step; empty when absent.
    #[serde(default)]
    pub usage: Vec<String>,

    /// Optional alternate API base URL.
    #[serde(default)]
    pub api_base_url: Option<String>,
    /// Optional npm package name associated with the provider.
    #[serde(default)]
    pub npm_package: Option<String>,
    /// Optional documentation URL.
    #[serde(default)]
    pub doc_url: Option<String>,
    /// Detailed model descriptors; empty when absent.
    #[serde(default)]
    pub models: Vec<ModelDescriptor>,
}
122
/// A selectable model listed in `ServiceDescriptor::model_choices`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelChoice {
    /// Model identifier.
    pub id: String,
    /// Human-readable description of the choice.
    pub description: String,
}
130
/// Top-level shape of `catalog.json`: an object with a `providers` array.
#[derive(Debug, Deserialize)]
struct IndexPayload {
    /// Provider entries in catalog order.
    providers: Vec<ServiceDescriptor>,
}
135
/// In-memory provider catalog with O(1) lookup by provider id.
#[derive(Debug, Clone)]
pub struct ProviderIndex {
    /// Provider entries, in the order they appear in the catalog.
    entries: Vec<ServiceDescriptor>,
    /// Provider id -> position in `entries`.
    index: HashMap<String, usize>,
}
149
150impl ProviderIndex {
151 fn from_payload(payload: IndexPayload) -> Self {
152 let index: HashMap<String, usize> = payload
153 .providers
154 .iter()
155 .enumerate()
156 .map(|(i, p)| (p.id.clone(), i))
157 .collect();
158 Self {
159 entries: payload.providers,
160 index,
161 }
162 }
163
164 pub fn embedded() -> &'static ProviderIndex {
166 &EMBEDDED
167 }
168
169 pub fn all(&self) -> &[ServiceDescriptor] {
171 &self.entries
172 }
173
174 pub fn ids(&self) -> Vec<String> {
176 self.entries.iter().map(|p| p.id.clone()).collect()
177 }
178
179 pub fn get(&self, id: &str) -> Option<&ServiceDescriptor> {
181 self.index.get(id).map(|&i| &self.entries[i])
182 }
183
184 pub fn categories(&self) -> Vec<String> {
186 self.entries
187 .iter()
188 .scan(HashSet::new(), |seen, p| {
189 Some(if seen.insert(p.category.clone()) {
190 Some(p.category.clone())
191 } else {
192 None
193 })
194 })
195 .flatten()
196 .collect()
197 }
198
199 pub fn providers_by_category(&self, category: &str) -> Vec<&ServiceDescriptor> {
201 self.entries
202 .iter()
203 .filter(|p| p.category == category)
204 .collect()
205 }
206
207 pub fn builtin_secret_keys(&self) -> HashSet<String> {
209 self.entries
210 .iter()
211 .filter(|p| !p.key_var.is_empty())
212 .map(|p| p.key_var.clone())
213 .collect()
214 }
215
216 pub fn models_for(&self, provider_id: &str) -> &[ModelDescriptor] {
218 self.get(provider_id)
219 .map(|p| p.models.as_slice())
220 .unwrap_or(&[])
221 }
222
223 pub fn find_model(&self, model_id: &str) -> Option<(&ServiceDescriptor, &ModelDescriptor)> {
226 self.entries
227 .iter()
228 .find_map(|p| p.models.iter().find(|m| m.id == model_id).map(|m| (p, m)))
229 }
230}
231
232static EMBEDDED: LazyLock<ProviderIndex> = LazyLock::new(|| {
234 let raw = include_str!("catalog.json");
235 let payload: IndexPayload = serde_json::from_str(raw).expect("catalog.json is valid");
236 ProviderIndex::from_payload(payload)
237});
238
// Unit tests exercising the embedded catalog through the public `ProviderIndex` API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedded_loads() {
        let catalog = ProviderIndex::embedded();
        assert!(!catalog.all().is_empty());
    }

    #[test]
    fn test_get_known_provider() {
        let catalog = ProviderIndex::embedded();
        let p = catalog.get("zai").expect("zai should exist");
        assert_eq!(p.id, "zai");
        assert!(!p.base_url.is_empty());
        assert!(!p.default_model.is_empty());
    }

    #[test]
    fn test_get_unknown_returns_none() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.get("nonexistent_provider_xyz").is_none());
    }

    #[test]
    fn test_ids_match_entries_and_resolve() {
        let catalog = ProviderIndex::embedded();
        let ids = catalog.ids();
        assert_eq!(ids.len(), catalog.all().len());
        for (id, p) in ids.iter().zip(catalog.all()) {
            assert_eq!(id, &p.id);
            let found = catalog.get(id).expect("every listed id should resolve");
            assert_eq!(found.id, *id);
        }
    }

    #[test]
    fn test_categories_no_duplicates() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        let mut seen = HashSet::new();
        for c in &cats {
            assert!(seen.insert(c.clone()), "duplicate category: {}", c);
        }
    }

    #[test]
    fn test_builtin_secret_keys() {
        let catalog = ProviderIndex::embedded();
        let keys = catalog.builtin_secret_keys();
        assert!(!keys.is_empty(), "should contain at least one secret key");
        assert!(
            keys.contains("ZAI_API_KEY"),
            "should contain ZAI_API_KEY, got: {:?}",
            keys
        );
        assert!(!keys.contains(""), "empty key_var values must be filtered out");
    }

    #[test]
    fn test_providers_by_category() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        // Fail loudly instead of passing vacuously when no categories are found.
        assert!(!cats.is_empty(), "catalog should expose at least one category");
        let mut total = 0;
        for cat in &cats {
            let providers = catalog.providers_by_category(cat);
            assert!(!providers.is_empty(), "category {cat} has no providers");
            for p in &providers {
                assert_eq!(p.category, *cat);
            }
            total += providers.len();
        }
        // Categories are distinct and cover every provider, so together they partition `all()`.
        assert_eq!(total, catalog.all().len());
    }

    #[test]
    fn test_models_for_provider() {
        let catalog = ProviderIndex::embedded();
        let models = catalog.models_for("zai");
        assert!(!models.is_empty(), "zai should have models");
        assert!(
            models.iter().all(|m| !m.id.is_empty()),
            "every zai model should have a non-empty id"
        );
    }

    #[test]
    fn test_models_for_unknown_provider() {
        let catalog = ProviderIndex::embedded();
        let models = catalog.models_for("nonexistent_provider_xyz");
        assert!(models.is_empty());
    }

    #[test]
    fn test_find_model() {
        let catalog = ProviderIndex::embedded();
        let (provider, model) = catalog.find_model("glm-5").expect("glm-5 should be found");
        assert_eq!(model.id, "glm-5");
        assert!(
            provider.id == "zai" || provider.id == "zai-cn",
            "glm-5 should belong to a Z.AI provider, got: {}",
            provider.id
        );
    }

    #[test]
    fn test_find_model_unknown() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.find_model("nonexistent-model-xyz").is_none());
    }

    #[test]
    fn test_model_has_pricing() {
        let catalog = ProviderIndex::embedded();
        let (_, model) = catalog.find_model("glm-5").expect("glm-5 should exist");
        let cost = model.cost.as_ref().expect("glm-5 should have cost");
        assert!(cost.input > 0.0, "input cost should be positive");
        assert!(cost.output > 0.0, "output cost should be positive");
    }
}