Skip to main content

llm_kernel/provider/
catalog.rs

1use serde::Deserialize;
2use std::collections::{HashMap, HashSet};
3use std::sync::LazyLock;
4
5// ---------------------------------------------------------------------------
6// models.dev-compatible model descriptor types
7// ---------------------------------------------------------------------------
8
9/// Per-million-token pricing for a model.
10#[derive(Debug, Clone, Deserialize)]
11pub struct ModelCost {
12    /// Price per million input (prompt) tokens in USD.
13    pub input: f64,
14    /// Price per million output (completion) tokens in USD.
15    pub output: f64,
16    /// Price per million cache-read tokens, if the provider supports prompt caching.
17    #[serde(default)]
18    pub cache_read: Option<f64>,
19    /// Price per million cache-write tokens, if the provider supports prompt caching.
20    #[serde(default)]
21    pub cache_write: Option<f64>,
22}
23
24/// Token limits for a model.
25#[derive(Debug, Clone, Deserialize)]
26pub struct ModelLimit {
27    /// Maximum context window in tokens (prompt + completion).
28    pub context: u64,
29    /// Maximum output (completion) tokens per request.
30    pub output: u64,
31}
32
33/// Input/output modalities a model supports.
34#[derive(Debug, Clone, Deserialize)]
35pub struct ModelModalities {
36    /// Accepted input modalities (e.g. `["text", "image"]`).
37    pub input: Vec<String>,
38    /// Produced output modalities (e.g. `["text"]`).
39    pub output: Vec<String>,
40}
41
42/// Capability flags for a model.
43#[derive(Debug, Clone, Deserialize)]
44pub struct ModelCapabilities {
45    /// Whether the model accepts file/image attachments.
46    #[serde(default)]
47    pub attachment: bool,
48    /// Whether the model supports extended reasoning / chain-of-thought.
49    #[serde(default)]
50    pub reasoning: bool,
51    /// Whether the model accepts a `temperature` parameter.
52    #[serde(default)]
53    pub temperature: bool,
54    /// Whether the model supports tool/function calling.
55    #[serde(default)]
56    pub tool_call: bool,
57    /// Whether the model supports streaming responses (SSE).
58    #[serde(default = "default_true")]
59    pub streaming: bool,
60}
61
/// Serde default helper for boolean fields that should be `true` when the
/// field is absent from the input.
fn default_true() -> bool {
    true
}
65
66/// A model offered by a provider (models.dev-compatible).
67#[derive(Debug, Clone, Deserialize)]
68pub struct ModelDescriptor {
69    /// Unique model identifier (e.g. `"gpt-4o"`, `"claude-sonnet-4-6"`).
70    pub id: String,
71    /// Human-readable model name.
72    pub name: String,
73    /// Model family grouping (e.g. `"gpt-4"`, `"claude-3"`).
74    #[serde(default)]
75    pub family: Option<String>,
76    /// ISO 8601 date the model was released.
77    #[serde(default)]
78    pub release_date: Option<String>,
79    /// Pricing information per million tokens.
80    #[serde(default)]
81    pub cost: Option<ModelCost>,
82    /// Token limits for context and output.
83    #[serde(default)]
84    pub limit: Option<ModelLimit>,
85    /// Input and output modalities.
86    #[serde(default)]
87    pub modalities: Option<ModelModalities>,
88    /// Capability flags (tool calling, streaming, etc.).
89    #[serde(default)]
90    pub capabilities: Option<ModelCapabilities>,
91    /// Knowledge cutoff date (ISO 8601).
92    #[serde(default)]
93    pub knowledge: Option<String>,
94}
95
96// ---------------------------------------------------------------------------
97// Provider service descriptor
98// ---------------------------------------------------------------------------
99
100/// Describes an LLM provider service with all metadata needed to connect and use it.
101#[derive(Debug, Clone, Deserialize)]
102pub struct ServiceDescriptor {
103    /// Unique provider identifier (e.g. `"openai"`, `"anthropic"`).
104    pub id: String,
105    /// Human-readable display name.
106    #[serde(rename = "display_name")]
107    pub display_name: String,
108    /// Short description of the provider.
109    pub description: String,
110    /// Provider category (e.g. `"cloud"`, `"local"`).
111    pub category: String,
112    /// Provider family used to group related providers.
113    pub family: String,
114    /// Authentication mode: `"none"`, `"literal"`, or `"secret"`.
115    #[serde(rename = "auth_mode")]
116    pub auth_mode: String,
117    /// Environment variable name that holds the API key (empty if not required).
118    #[serde(rename = "key_var", skip_serializing_if = "String::is_empty", default)]
119    pub key_var: String,
120    /// Literal auth token embedded in the catalog (only set when `auth_mode = "literal"`).
121    #[serde(
122        rename = "literal_auth_token",
123        skip_serializing_if = "String::is_empty",
124        default
125    )]
126    pub literal_auth_token: String,
127    /// Base URL for the provider's web interface.
128    #[serde(rename = "base_url")]
129    pub base_url: String,
130    /// Default model ID used when no model override is specified.
131    #[serde(rename = "default_model")]
132    pub default_model: String,
133    /// Named model tiers mapping tier name → model ID (e.g. `"fast"` → `"gpt-4o-mini"`).
134    #[serde(rename = "model_tiers", default)]
135    pub model_tiers: HashMap<String, String>,
136    /// Legacy list of available model choices (claudy-specific).
137    #[serde(rename = "model_choices", default)]
138    pub model_choices: Vec<ModelChoice>,
139    /// URL used to test connectivity to the provider.
140    #[serde(rename = "test_url")]
141    pub test_url: String,
142    /// Setup instructions shown to the user during first-time configuration.
143    #[serde(default)]
144    pub setup: Vec<String>,
145    /// Usage examples shown to the user in the install wizard.
146    #[serde(default)]
147    pub usage: Vec<String>,
148
149    // models.dev-compatible fields
150    /// API base URL override (models.dev-compatible field).
151    #[serde(default)]
152    pub api_base_url: Option<String>,
153    /// npm package name (models.dev-compatible field, for AI coding tools).
154    #[serde(default)]
155    pub npm_package: Option<String>,
156    /// Link to provider documentation.
157    #[serde(default)]
158    pub doc_url: Option<String>,
159    /// Full list of models offered by this provider.
160    #[serde(default)]
161    pub models: Vec<ModelDescriptor>,
162}
163
164/// Legacy model choice (claudy-specific: id + description).
165/// Retained for backward compatibility with existing catalog.json entries.
166#[derive(Debug, Clone, Deserialize)]
167pub struct ModelChoice {
168    /// Model identifier.
169    pub id: String,
170    /// Short description of the model.
171    pub description: String,
172}
173
174#[derive(Debug, Deserialize)]
175struct IndexPayload {
176    providers: Vec<ServiceDescriptor>,
177}
178
179// ---------------------------------------------------------------------------
180// Provider index
181// ---------------------------------------------------------------------------
182
183/// Immutable provider catalog with O(1) lookup by id.
184///
185/// The catalog is compiled into the binary from `catalog.json` via `include_str!`.
186/// Access it through [`ProviderIndex::embedded()`].
187#[derive(Debug, Clone)]
188pub struct ProviderIndex {
189    entries: Vec<ServiceDescriptor>,
190    index: HashMap<String, usize>,
191}
192
193impl ProviderIndex {
194    fn from_payload(payload: IndexPayload) -> Self {
195        let index: HashMap<String, usize> = payload
196            .providers
197            .iter()
198            .enumerate()
199            .map(|(i, p)| (p.id.clone(), i))
200            .collect();
201        Self {
202            entries: payload.providers,
203            index,
204        }
205    }
206
207    /// Access the static catalog embedded at compile time.
208    pub fn embedded() -> &'static ProviderIndex {
209        &EMBEDDED
210    }
211
212    /// Return all providers in catalog order.
213    pub fn all(&self) -> &[ServiceDescriptor] {
214        &self.entries
215    }
216
217    /// Return all provider IDs.
218    pub fn ids(&self) -> Vec<String> {
219        self.entries.iter().map(|p| p.id.clone()).collect()
220    }
221
222    /// Look up a provider by ID. O(1).
223    pub fn get(&self, id: &str) -> Option<&ServiceDescriptor> {
224        self.index.get(id).map(|&i| &self.entries[i])
225    }
226
227    /// Unique categories in catalog order.
228    pub fn categories(&self) -> Vec<String> {
229        self.entries
230            .iter()
231            .scan(HashSet::new(), |seen, p| {
232                Some(if seen.insert(p.category.clone()) {
233                    Some(p.category.clone())
234                } else {
235                    None
236                })
237            })
238            .flatten()
239            .collect()
240    }
241
242    /// Filter providers by category.
243    pub fn providers_by_category(&self, category: &str) -> Vec<&ServiceDescriptor> {
244        self.entries
245            .iter()
246            .filter(|p| p.category == category)
247            .collect()
248    }
249
250    /// Collect all secret key variable names from providers that require one.
251    pub fn builtin_secret_keys(&self) -> HashSet<String> {
252        self.entries
253            .iter()
254            .filter(|p| !p.key_var.is_empty())
255            .map(|p| p.key_var.clone())
256            .collect()
257    }
258
259    /// Get models for a specific provider.
260    pub fn models_for(&self, provider_id: &str) -> &[ModelDescriptor] {
261        self.get(provider_id)
262            .map(|p| p.models.as_slice())
263            .unwrap_or(&[])
264    }
265
266    /// Find a model by ID across all providers.
267    /// Returns the first match (provider, model).
268    pub fn find_model(&self, model_id: &str) -> Option<(&ServiceDescriptor, &ModelDescriptor)> {
269        self.entries
270            .iter()
271            .find_map(|p| p.models.iter().find(|m| m.id == model_id).map(|m| (p, m)))
272    }
273
274    /// Estimate the USD cost of an LLM call given token counts.
275    ///
276    /// Looks up `model_id` across all providers and computes:
277    /// `(input_price * prompt_tokens + output_price * completion_tokens) / 1_000_000`
278    ///
279    /// Returns `None` if the model is not found or has no pricing data.
280    pub fn estimate_cost(
281        &self,
282        model_id: &str,
283        prompt_tokens: u32,
284        completion_tokens: u32,
285    ) -> Option<f64> {
286        let (_, model) = self.find_model(model_id)?;
287        let cost = model.cost.as_ref()?;
288        Some(
289            cost.input * prompt_tokens as f64 / 1_000_000.0
290                + cost.output * completion_tokens as f64 / 1_000_000.0,
291        )
292    }
293}
294
295/// Static catalog compiled into the binary from `catalog.json`.
296static EMBEDDED: LazyLock<ProviderIndex> = LazyLock::new(|| {
297    let raw = include_str!("catalog.json");
298    let payload: IndexPayload = serde_json::from_str(raw).expect("catalog.json is valid");
299    ProviderIndex::from_payload(payload)
300});
301
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded catalog deserializes and holds at least one provider.
    #[test]
    fn test_embedded_loads() {
        let providers = ProviderIndex::embedded().all();
        assert!(!providers.is_empty());
    }

    /// A provider known to be in catalog.json ("zai") resolves by id.
    #[test]
    fn test_get_known_provider() {
        let zai = ProviderIndex::embedded()
            .get("zai")
            .expect("zai should exist");
        assert_eq!(zai.id, "zai");
        assert!(!zai.base_url.is_empty());
        assert!(!zai.default_model.is_empty());
    }

    /// Unknown provider ids yield `None`.
    #[test]
    fn test_get_unknown_returns_none() {
        let lookup = ProviderIndex::embedded().get("nonexistent_provider_xyz");
        assert!(lookup.is_none());
    }

    /// `categories()` never reports the same category twice.
    #[test]
    fn test_categories_no_duplicates() {
        let categories = ProviderIndex::embedded().categories();
        let mut seen = HashSet::new();
        categories.iter().for_each(|category| {
            assert!(
                seen.insert(category.clone()),
                "duplicate category: {}",
                category
            );
        });
    }

    /// The secret-key set is non-empty and includes the Z.AI key variable.
    #[test]
    fn test_builtin_secret_keys() {
        let keys = ProviderIndex::embedded().builtin_secret_keys();
        assert!(!keys.is_empty(), "should contain at least one secret key");
        assert!(
            keys.contains("ZAI_API_KEY"),
            "should contain ZAI_API_KEY, got: {:?}",
            keys
        );
    }

    /// Filtering by the first category returns only providers of that category.
    #[test]
    fn test_providers_by_category() {
        let catalog = ProviderIndex::embedded();
        let categories = catalog.categories();
        let Some(first) = categories.first() else {
            return;
        };
        let matching = catalog.providers_by_category(first);
        assert!(!matching.is_empty());
        for provider in &matching {
            assert_eq!(provider.category, *first);
        }
    }

    /// A known provider exposes at least one model with a non-empty id.
    #[test]
    fn test_models_for_provider() {
        let models = ProviderIndex::embedded().models_for("zai");
        assert!(!models.is_empty(), "zai should have models");
        // First model should have an id
        assert!(!models[0].id.is_empty());
    }

    /// Unknown providers yield an empty model slice.
    #[test]
    fn test_models_for_unknown_provider() {
        let models = ProviderIndex::embedded().models_for("nonexistent_provider_xyz");
        assert!(models.is_empty());
    }

    /// `find_model` locates "glm-5" under one of the Z.AI providers.
    #[test]
    fn test_find_model() {
        let (provider, model) = ProviderIndex::embedded()
            .find_model("glm-5")
            .expect("glm-5 should be found");
        assert_eq!(model.id, "glm-5");
        assert!(
            matches!(provider.id.as_str(), "zai" | "zai-cn"),
            "glm-5 should belong to a Z.AI provider, got: {}",
            provider.id
        );
    }

    /// Unknown model ids yield `None`.
    #[test]
    fn test_find_model_unknown() {
        let lookup = ProviderIndex::embedded().find_model("nonexistent-model-xyz");
        assert!(lookup.is_none());
    }

    /// "glm-5" carries strictly positive input and output prices.
    #[test]
    fn test_model_has_pricing() {
        let (_, model) = ProviderIndex::embedded()
            .find_model("glm-5")
            .expect("glm-5 should exist");
        let pricing = model.cost.as_ref().expect("glm-5 should have cost");
        assert!(pricing.input > 0.0, "input cost should be positive");
        assert!(pricing.output > 0.0, "output cost should be positive");
    }
}