Skip to main content

llm_kernel/provider/
catalog.rs

1use serde::Deserialize;
2use std::collections::{HashMap, HashSet};
3use std::sync::LazyLock;
4
5// ---------------------------------------------------------------------------
6// models.dev-compatible model descriptor types
7// ---------------------------------------------------------------------------
8
9/// Per-million-token pricing for a model.
10#[derive(Debug, Clone, Deserialize)]
11pub struct ModelCost {
12    pub input: f64,
13    pub output: f64,
14    #[serde(default)]
15    pub cache_read: Option<f64>,
16    #[serde(default)]
17    pub cache_write: Option<f64>,
18}
19
20/// Token limits for a model.
21#[derive(Debug, Clone, Deserialize)]
22pub struct ModelLimit {
23    pub context: u64,
24    pub output: u64,
25}
26
27/// Input/output modalities a model supports.
28#[derive(Debug, Clone, Deserialize)]
29pub struct ModelModalities {
30    pub input: Vec<String>,
31    pub output: Vec<String>,
32}
33
34/// Capability flags for a model.
35#[derive(Debug, Clone, Deserialize)]
36pub struct ModelCapabilities {
37    #[serde(default)]
38    pub attachment: bool,
39    #[serde(default)]
40    pub reasoning: bool,
41    #[serde(default)]
42    pub temperature: bool,
43    #[serde(default)]
44    pub tool_call: bool,
45    #[serde(default = "default_true")]
46    pub streaming: bool,
47}
48
49fn default_true() -> bool {
50    true
51}
52
53/// A model offered by a provider (models.dev-compatible).
54#[derive(Debug, Clone, Deserialize)]
55pub struct ModelDescriptor {
56    pub id: String,
57    pub name: String,
58    #[serde(default)]
59    pub family: Option<String>,
60    #[serde(default)]
61    pub release_date: Option<String>,
62    #[serde(default)]
63    pub cost: Option<ModelCost>,
64    #[serde(default)]
65    pub limit: Option<ModelLimit>,
66    #[serde(default)]
67    pub modalities: Option<ModelModalities>,
68    #[serde(default)]
69    pub capabilities: Option<ModelCapabilities>,
70    #[serde(default)]
71    pub knowledge: Option<String>,
72}
73
74// ---------------------------------------------------------------------------
75// Provider service descriptor
76// ---------------------------------------------------------------------------
77
78/// Describes an LLM provider service with all metadata needed to connect and use it.
79#[derive(Debug, Clone, Deserialize)]
80pub struct ServiceDescriptor {
81    pub id: String,
82    #[serde(rename = "display_name")]
83    pub display_name: String,
84    pub description: String,
85    pub category: String,
86    pub family: String,
87    #[serde(rename = "auth_mode")]
88    pub auth_mode: String,
89    #[serde(rename = "key_var", skip_serializing_if = "String::is_empty", default)]
90    pub key_var: String,
91    #[serde(
92        rename = "literal_auth_token",
93        skip_serializing_if = "String::is_empty",
94        default
95    )]
96    pub literal_auth_token: String,
97    #[serde(rename = "base_url")]
98    pub base_url: String,
99    #[serde(rename = "default_model")]
100    pub default_model: String,
101    #[serde(rename = "model_tiers", default)]
102    pub model_tiers: HashMap<String, String>,
103    #[serde(rename = "model_choices", default)]
104    pub model_choices: Vec<ModelChoice>,
105    #[serde(rename = "test_url")]
106    pub test_url: String,
107    #[serde(default)]
108    pub setup: Vec<String>,
109    #[serde(default)]
110    pub usage: Vec<String>,
111
112    // models.dev-compatible fields
113    #[serde(default)]
114    pub api_base_url: Option<String>,
115    #[serde(default)]
116    pub npm_package: Option<String>,
117    #[serde(default)]
118    pub doc_url: Option<String>,
119    #[serde(default)]
120    pub models: Vec<ModelDescriptor>,
121}
122
123/// Legacy model choice (claudy-specific: id + description).
124/// Retained for backward compatibility with existing catalog.json entries.
125#[derive(Debug, Clone, Deserialize)]
126pub struct ModelChoice {
127    pub id: String,
128    pub description: String,
129}
130
131#[derive(Debug, Deserialize)]
132struct IndexPayload {
133    providers: Vec<ServiceDescriptor>,
134}
135
136// ---------------------------------------------------------------------------
137// Provider index
138// ---------------------------------------------------------------------------
139
140/// Immutable provider catalog with O(1) lookup by id.
141///
142/// The catalog is compiled into the binary from `catalog.json` via `include_str!`.
143/// Access it through [`ProviderIndex::embedded()`].
144#[derive(Debug, Clone)]
145pub struct ProviderIndex {
146    entries: Vec<ServiceDescriptor>,
147    index: HashMap<String, usize>,
148}
149
150impl ProviderIndex {
151    fn from_payload(payload: IndexPayload) -> Self {
152        let index: HashMap<String, usize> = payload
153            .providers
154            .iter()
155            .enumerate()
156            .map(|(i, p)| (p.id.clone(), i))
157            .collect();
158        Self {
159            entries: payload.providers,
160            index,
161        }
162    }
163
164    /// Access the static catalog embedded at compile time.
165    pub fn embedded() -> &'static ProviderIndex {
166        &EMBEDDED
167    }
168
169    /// Return all providers in catalog order.
170    pub fn all(&self) -> &[ServiceDescriptor] {
171        &self.entries
172    }
173
174    /// Return all provider IDs.
175    pub fn ids(&self) -> Vec<String> {
176        self.entries.iter().map(|p| p.id.clone()).collect()
177    }
178
179    /// Look up a provider by ID. O(1).
180    pub fn get(&self, id: &str) -> Option<&ServiceDescriptor> {
181        self.index.get(id).map(|&i| &self.entries[i])
182    }
183
184    /// Unique categories in catalog order.
185    pub fn categories(&self) -> Vec<String> {
186        self.entries
187            .iter()
188            .scan(HashSet::new(), |seen, p| {
189                Some(if seen.insert(p.category.clone()) {
190                    Some(p.category.clone())
191                } else {
192                    None
193                })
194            })
195            .flatten()
196            .collect()
197    }
198
199    /// Filter providers by category.
200    pub fn providers_by_category(&self, category: &str) -> Vec<&ServiceDescriptor> {
201        self.entries
202            .iter()
203            .filter(|p| p.category == category)
204            .collect()
205    }
206
207    /// Collect all secret key variable names from providers that require one.
208    pub fn builtin_secret_keys(&self) -> HashSet<String> {
209        self.entries
210            .iter()
211            .filter(|p| !p.key_var.is_empty())
212            .map(|p| p.key_var.clone())
213            .collect()
214    }
215
216    /// Get models for a specific provider.
217    pub fn models_for(&self, provider_id: &str) -> &[ModelDescriptor] {
218        self.get(provider_id)
219            .map(|p| p.models.as_slice())
220            .unwrap_or(&[])
221    }
222
223    /// Find a model by ID across all providers.
224    /// Returns the first match (provider, model).
225    pub fn find_model(&self, model_id: &str) -> Option<(&ServiceDescriptor, &ModelDescriptor)> {
226        self.entries
227            .iter()
228            .find_map(|p| p.models.iter().find(|m| m.id == model_id).map(|m| (p, m)))
229    }
230}
231
232/// Static catalog compiled into the binary from `catalog.json`.
233static EMBEDDED: LazyLock<ProviderIndex> = LazyLock::new(|| {
234    let raw = include_str!("catalog.json");
235    let payload: IndexPayload = serde_json::from_str(raw).expect("catalog.json is valid");
236    ProviderIndex::from_payload(payload)
237});
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedded_loads() {
        let catalog = ProviderIndex::embedded();
        assert!(!catalog.all().is_empty());
    }

    #[test]
    fn test_get_known_provider() {
        let catalog = ProviderIndex::embedded();
        // catalog.json contains "zai" (first provider with key_var)
        let p = catalog.get("zai").expect("zai should exist");
        assert_eq!(p.id, "zai");
        assert!(!p.base_url.is_empty());
        assert!(!p.default_model.is_empty());
    }

    #[test]
    fn test_get_unknown_returns_none() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.get("nonexistent_provider_xyz").is_none());
    }

    #[test]
    fn test_get_finds_every_provider() {
        // The id index must cover every entry returned by `all()`.
        let catalog = ProviderIndex::embedded();
        for p in catalog.all() {
            let found = catalog
                .get(&p.id)
                .unwrap_or_else(|| panic!("provider {} should be reachable via get()", p.id));
            assert_eq!(found.id, p.id);
        }
    }

    #[test]
    fn test_provider_ids_unique() {
        // Duplicate ids would make `get` ambiguous; keep the catalog free of them.
        let catalog = ProviderIndex::embedded();
        let ids = catalog.ids();
        let mut seen = HashSet::new();
        for id in &ids {
            assert!(seen.insert(id.as_str()), "duplicate provider id: {}", id);
        }
    }

    #[test]
    fn test_ids_follow_catalog_order() {
        let catalog = ProviderIndex::embedded();
        let expected: Vec<String> = catalog.all().iter().map(|p| p.id.clone()).collect();
        assert_eq!(catalog.ids(), expected);
    }

    #[test]
    fn test_categories_no_duplicates() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        let mut seen = HashSet::new();
        for c in &cats {
            assert!(seen.insert(c.clone()), "duplicate category: {}", c);
        }
    }

    #[test]
    fn test_categories_cover_every_provider() {
        let catalog = ProviderIndex::embedded();
        let cats = catalog.categories();
        for p in catalog.all() {
            assert!(
                cats.contains(&p.category),
                "category {} of provider {} missing from categories()",
                p.category,
                p.id
            );
        }
    }

    #[test]
    fn test_builtin_secret_keys() {
        let catalog = ProviderIndex::embedded();
        let keys = catalog.builtin_secret_keys();
        assert!(!keys.is_empty(), "should contain at least one secret key");
        assert!(
            keys.contains("ZAI_API_KEY"),
            "should contain ZAI_API_KEY, got: {:?}",
            keys
        );
        assert!(!keys.contains(""), "empty key_var values must be filtered out");
    }

    #[test]
    fn test_providers_by_category() {
        let catalog = ProviderIndex::embedded();
        let mut total = 0;
        for cat in catalog.categories() {
            let providers = catalog.providers_by_category(&cat);
            assert!(!providers.is_empty(), "category {} has no providers", cat);
            for p in &providers {
                assert_eq!(p.category, cat);
            }
            total += providers.len();
        }
        // Categories partition the catalog: every provider is counted exactly once.
        assert_eq!(total, catalog.all().len());
    }

    #[test]
    fn test_models_for_provider() {
        let catalog = ProviderIndex::embedded();
        let models = catalog.models_for("zai");
        assert!(!models.is_empty(), "zai should have models");
        // First model should have an id
        assert!(!models[0].id.is_empty());
    }

    #[test]
    fn test_models_for_unknown_provider() {
        let catalog = ProviderIndex::embedded();
        let models = catalog.models_for("nonexistent_provider_xyz");
        assert!(models.is_empty());
    }

    #[test]
    fn test_find_model() {
        let catalog = ProviderIndex::embedded();
        let (provider, model) = catalog.find_model("glm-5").expect("glm-5 should be found");
        assert_eq!(model.id, "glm-5");
        assert!(
            provider.id == "zai" || provider.id == "zai-cn",
            "glm-5 should belong to a Z.AI provider, got: {}",
            provider.id
        );
    }

    #[test]
    fn test_find_model_returns_owning_provider() {
        // Every listed model must be findable, and the returned provider must list it.
        let catalog = ProviderIndex::embedded();
        for p in catalog.all() {
            for m in &p.models {
                let (owner, found) = catalog
                    .find_model(&m.id)
                    .unwrap_or_else(|| panic!("model {} should be found", m.id));
                assert_eq!(found.id, m.id);
                assert!(owner.models.iter().any(|x| x.id == m.id));
            }
        }
    }

    #[test]
    fn test_find_model_unknown() {
        let catalog = ProviderIndex::embedded();
        assert!(catalog.find_model("nonexistent-model-xyz").is_none());
    }

    #[test]
    fn test_model_has_pricing() {
        let catalog = ProviderIndex::embedded();
        let (_, model) = catalog.find_model("glm-5").expect("glm-5 should exist");
        let cost = model.cost.as_ref().expect("glm-5 should have cost");
        assert!(cost.input > 0.0, "input cost should be positive");
        assert!(cost.output > 0.0, "output cost should be positive");
    }
}