Skip to main content

codetether_agent/provider/
models.rs

//! Model catalog for CodeTether, loaded from a remote models API.
//!
//! Fetches provider and model information (capabilities, costs, and limits) from a
//! JSON models endpoint, `https://models.dev/api.json` by default. The data is used to
//! enrich our internal model information and to provide better recommendations and
//! cost estimates.
//!
//! The catalog is fetched on demand and held in memory. It also integrates with the
//! secrets manager to check which providers have API keys configured, so that the
//! available models can be filtered accordingly.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Model cost information (per million tokens)
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelCost {
12    pub input: f64,
13    pub output: f64,
14    #[serde(default)]
15    pub cache_read: Option<f64>,
16    #[serde(default)]
17    pub cache_write: Option<f64>,
18    #[serde(default)]
19    pub reasoning: Option<f64>,
20}
21
22/// Model limits
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ModelLimit {
25    #[serde(default)]
26    pub context: u64,
27    #[serde(default)]
28    pub output: u64,
29}
30
31/// Model modalities
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ModelModalities {
34    #[serde(default)]
35    pub input: Vec<String>,
36    #[serde(default)]
37    pub output: Vec<String>,
38}
39
40/// Model information from the API
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ApiModelInfo {
43    pub id: String,
44    pub name: String,
45    #[serde(default)]
46    pub family: Option<String>,
47    #[serde(default)]
48    pub attachment: bool,
49    #[serde(default)]
50    pub reasoning: bool,
51    #[serde(default)]
52    pub tool_call: bool,
53    #[serde(default)]
54    pub structured_output: Option<bool>,
55    #[serde(default)]
56    pub temperature: Option<bool>,
57    #[serde(default)]
58    pub knowledge: Option<String>,
59    #[serde(default)]
60    pub release_date: Option<String>,
61    #[serde(default)]
62    pub last_updated: Option<String>,
63    #[serde(default)]
64    pub modalities: Option<ModelModalities>,
65    #[serde(default)]
66    pub open_weights: bool,
67    #[serde(default)]
68    pub cost: Option<ModelCost>,
69    #[serde(default)]
70    pub limit: Option<ModelLimit>,
71}
72
73/// Provider information
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ProviderInfo {
76    pub id: String,
77    #[serde(default)]
78    pub env: Vec<String>,
79    #[serde(default)]
80    pub npm: Option<String>,
81    #[serde(default)]
82    pub api: Option<String>,
83    pub name: String,
84    #[serde(default)]
85    pub doc: Option<String>,
86    #[serde(default)]
87    pub models: HashMap<String, ApiModelInfo>,
88}
89
90/// The full API response
91pub type ModelsApiResponse = HashMap<String, ProviderInfo>;
92
93/// Model catalog for looking up models
94#[derive(Debug, Clone, Default)]
95pub struct ModelCatalog {
96    providers: HashMap<String, ProviderInfo>,
97}
98
99#[allow(dead_code)]
100impl ModelCatalog {
101    /// Create an empty catalog
102    pub fn new() -> Self {
103        Self::default()
104    }
105
106    /// Fetch models from the CodeTether API
107    pub async fn fetch() -> anyhow::Result<Self> {
108        const MODELS_URL: &str = "https://models.dev/api.json";
109        tracing::info!("Fetching models from {}", MODELS_URL);
110        let response = reqwest::get(MODELS_URL).await?;
111        let providers: ModelsApiResponse = response.json().await?;
112        tracing::info!("Loaded {} providers", providers.len());
113        Ok(Self { providers })
114    }
115
116    /// Fetch models with a custom URL (for testing or alternate sources)
117    #[allow(dead_code)]
118    pub async fn fetch_from(url: &str) -> anyhow::Result<Self> {
119        let response = reqwest::get(url).await?;
120        let providers: ModelsApiResponse = response.json().await?;
121        Ok(Self { providers })
122    }
123
124    /// Check if a provider has an API key configured in HashiCorp Vault
125    ///
126    /// NOTE: This is a sync wrapper that checks the Vault cache.
127    /// For initial population, use `check_provider_api_key_async`.
128    pub fn provider_has_api_key(&self, provider_id: &str) -> bool {
129        // Check if we have it cached in the secrets manager
130        if let Some(manager) = crate::secrets::secrets_manager() {
131            // Sync check - only works if already cached
132            let cache = manager.cache.try_read();
133            if let Ok(cache) = cache {
134                return cache.contains_key(provider_id);
135            }
136        }
137        false
138    }
139
140    /// Async check if a provider has an API key in Vault
141    pub async fn check_provider_api_key_async(&self, provider_id: &str) -> bool {
142        crate::secrets::has_api_key(provider_id).await
143    }
144
145    /// Pre-load API key availability from Vault for all providers
146    pub async fn preload_available_providers(&self) -> Vec<String> {
147        let mut available = Vec::new();
148
149        if let Some(manager) = crate::secrets::secrets_manager() {
150            // List all configured providers from Vault
151            if let Ok(providers) = manager.list_configured_providers().await {
152                for provider_id in providers {
153                    // Verify each one actually has an API key
154                    if manager.has_api_key(&provider_id).await {
155                        available.push(provider_id);
156                    }
157                }
158            }
159        }
160
161        available
162    }
163
164    /// Get list of providers that have API keys configured (sync, uses cache)
165    pub fn available_providers(&self) -> Vec<&str> {
166        self.providers
167            .keys()
168            .filter(|id| self.provider_has_api_key(id))
169            .map(|s| s.as_str())
170            .collect()
171    }
172
173    /// Get list of providers that have API keys configured (async, checks Vault)
174    #[allow(dead_code)]
175    pub async fn available_providers_async(&self) -> Vec<String> {
176        let mut available = Vec::new();
177        for provider_id in self.providers.keys() {
178            if self.check_provider_api_key_async(provider_id).await {
179                available.push(provider_id.clone());
180            }
181        }
182        available
183    }
184
185    /// Get a provider by ID
186    pub fn get_provider(&self, provider_id: &str) -> Option<&ProviderInfo> {
187        self.providers.get(provider_id)
188    }
189
190    /// Get a provider by ID only if it has an API key
191    pub fn get_available_provider(&self, provider_id: &str) -> Option<&ProviderInfo> {
192        if self.provider_has_api_key(provider_id) {
193            self.providers.get(provider_id)
194        } else {
195            None
196        }
197    }
198
199    /// Get a model by provider and model ID
200    pub fn get_model(&self, provider_id: &str, model_id: &str) -> Option<&ApiModelInfo> {
201        self.providers
202            .get(provider_id)
203            .and_then(|p| p.models.get(model_id))
204    }
205
206    /// Get a model only if the provider has an API key
207    pub fn get_available_model(&self, provider_id: &str, model_id: &str) -> Option<&ApiModelInfo> {
208        if self.provider_has_api_key(provider_id) {
209            self.get_model(provider_id, model_id)
210        } else {
211            None
212        }
213    }
214
215    /// Find a model by ID across all providers (only available ones)
216    pub fn find_model(&self, model_id: &str) -> Option<(&str, &ApiModelInfo)> {
217        for (provider_id, provider) in &self.providers {
218            if !self.provider_has_api_key(provider_id) {
219                continue;
220            }
221            if let Some(model) = provider.models.get(model_id) {
222                return Some((provider_id, model));
223            }
224        }
225        None
226    }
227
228    /// Find a model across ALL providers (ignoring API key requirement)
229    pub fn find_model_any(&self, model_id: &str) -> Option<(&str, &ApiModelInfo)> {
230        for (provider_id, provider) in &self.providers {
231            if let Some(model) = provider.models.get(model_id) {
232                return Some((provider_id, model));
233            }
234        }
235        None
236    }
237
238    /// List all provider IDs (all, not filtered)
239    #[allow(dead_code)]
240    pub fn provider_ids(&self) -> Vec<&str> {
241        self.providers.keys().map(|s| s.as_str()).collect()
242    }
243
244    /// Get iterator over all providers and their info (unfiltered, no API key check)
245    pub fn all_providers(&self) -> &HashMap<String, ProviderInfo> {
246        &self.providers
247    }
248
249    /// List models for a provider (only if API key available)
250    pub fn models_for_provider(&self, provider_id: &str) -> Vec<&ApiModelInfo> {
251        if !self.provider_has_api_key(provider_id) {
252            return Vec::new();
253        }
254        self.providers
255            .get(provider_id)
256            .map(|p| p.models.values().collect())
257            .unwrap_or_default()
258    }
259
260    /// Find models with tool calling support (only from available providers)
261    pub fn tool_capable_models(&self) -> Vec<(&str, &ApiModelInfo)> {
262        let mut result = Vec::new();
263        for (provider_id, provider) in &self.providers {
264            if !self.provider_has_api_key(provider_id) {
265                continue;
266            }
267            for model in provider.models.values() {
268                if model.tool_call {
269                    result.push((provider_id.as_str(), model));
270                }
271            }
272        }
273        result
274    }
275
276    /// Find models with reasoning support (only from available providers)
277    pub fn reasoning_models(&self) -> Vec<(&str, &ApiModelInfo)> {
278        let mut result = Vec::new();
279        for (provider_id, provider) in &self.providers {
280            if !self.provider_has_api_key(provider_id) {
281                continue;
282            }
283            for model in provider.models.values() {
284                if model.reasoning {
285                    result.push((provider_id.as_str(), model));
286                }
287            }
288        }
289        result
290    }
291
292    /// Get recommended models for coding tasks
293    pub fn recommended_coding_models(&self) -> Vec<(&str, &ApiModelInfo)> {
294        let preferred_ids = [
295            "claude-sonnet-4-6",
296            "claude-sonnet-4-20250514",
297            "claude-opus-4-20250514",
298            "gpt-5-codex",
299            "gpt-5.1-codex",
300            "gpt-4o",
301            "gemini-3.1-pro-preview",
302            "gemini-2.5-pro",
303            "deepseek-v3.2",
304            "step-3.5-flash",
305            "glm-5",
306            "z-ai/glm-5",
307        ];
308
309        let mut result = Vec::new();
310        for model_id in preferred_ids {
311            if let Some((provider, model)) = self.find_model(model_id) {
312                result.push((provider, model));
313            }
314        }
315        result
316    }
317
318    /// Convert API model info to our internal ModelInfo format
319    #[allow(dead_code)]
320    pub fn to_model_info(&self, model: &ApiModelInfo, provider_id: &str) -> super::ModelInfo {
321        super::ModelInfo {
322            id: model.id.clone(),
323            name: model.name.clone(),
324            provider: provider_id.to_string(),
325            context_window: model
326                .limit
327                .as_ref()
328                .map(|l| l.context as usize)
329                .unwrap_or(128_000),
330            max_output_tokens: model.limit.as_ref().map(|l| l.output as usize),
331            supports_vision: model.attachment,
332            supports_tools: model.tool_call,
333            supports_streaming: true,
334            input_cost_per_million: model.cost.as_ref().map(|c| c.input),
335            output_cost_per_million: model.cost.as_ref().map(|c| c.output),
336        }
337    }
338}