Skip to main content

codetether_agent/provider/
models.rs

//! Model catalog from the CodeTether API.
//!
//! Fetches model information from <https://api.codetether.run/static/models/api.json>.

5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Default endpoint serving the model catalog as JSON (host + static path).
const MODELS_API_URL: &str = concat!("https://api.codetether.run", "/static/models/api.json");
9
10/// Model cost information (per million tokens)
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ModelCost {
13    pub input: f64,
14    pub output: f64,
15    #[serde(default)]
16    pub cache_read: Option<f64>,
17    #[serde(default)]
18    pub cache_write: Option<f64>,
19    #[serde(default)]
20    pub reasoning: Option<f64>,
21}
22
23/// Model limits
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ModelLimit {
26    #[serde(default)]
27    pub context: u64,
28    #[serde(default)]
29    pub output: u64,
30}
31
32/// Model modalities
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelModalities {
35    #[serde(default)]
36    pub input: Vec<String>,
37    #[serde(default)]
38    pub output: Vec<String>,
39}
40
41/// Model information from the API
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ApiModelInfo {
44    pub id: String,
45    pub name: String,
46    #[serde(default)]
47    pub family: Option<String>,
48    #[serde(default)]
49    pub attachment: bool,
50    #[serde(default)]
51    pub reasoning: bool,
52    #[serde(default)]
53    pub tool_call: bool,
54    #[serde(default)]
55    pub structured_output: Option<bool>,
56    #[serde(default)]
57    pub temperature: Option<bool>,
58    #[serde(default)]
59    pub knowledge: Option<String>,
60    #[serde(default)]
61    pub release_date: Option<String>,
62    #[serde(default)]
63    pub last_updated: Option<String>,
64    #[serde(default)]
65    pub modalities: Option<ModelModalities>,
66    #[serde(default)]
67    pub open_weights: bool,
68    #[serde(default)]
69    pub cost: Option<ModelCost>,
70    #[serde(default)]
71    pub limit: Option<ModelLimit>,
72}
73
74/// Provider information
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ProviderInfo {
77    pub id: String,
78    #[serde(default)]
79    pub env: Vec<String>,
80    #[serde(default)]
81    pub npm: Option<String>,
82    #[serde(default)]
83    pub api: Option<String>,
84    pub name: String,
85    #[serde(default)]
86    pub doc: Option<String>,
87    #[serde(default)]
88    pub models: HashMap<String, ApiModelInfo>,
89}
90
91/// The full API response
92pub type ModelsApiResponse = HashMap<String, ProviderInfo>;
93
94/// Model catalog for looking up models
95#[derive(Debug, Clone, Default)]
96pub struct ModelCatalog {
97    providers: HashMap<String, ProviderInfo>,
98}
99
100#[allow(dead_code)]
101impl ModelCatalog {
102    /// Create an empty catalog
103    pub fn new() -> Self {
104        Self::default()
105    }
106
107    /// Fetch models from the CodeTether API
108    pub async fn fetch() -> anyhow::Result<Self> {
109        tracing::info!("Fetching models from {}", MODELS_API_URL);
110        let response = reqwest::get(MODELS_API_URL).await?;
111        let providers: ModelsApiResponse = response.json().await?;
112        tracing::info!("Loaded {} providers", providers.len());
113        Ok(Self { providers })
114    }
115
116    /// Fetch models with a custom URL (for testing or alternate sources)
117    #[allow(dead_code)]
118    pub async fn fetch_from(url: &str) -> anyhow::Result<Self> {
119        let response = reqwest::get(url).await?;
120        let providers: ModelsApiResponse = response.json().await?;
121        Ok(Self { providers })
122    }
123
124    /// Check if a provider has an API key configured in HashiCorp Vault
125    ///
126    /// NOTE: This is a sync wrapper that checks the Vault cache.
127    /// For initial population, use `check_provider_api_key_async`.
128    pub fn provider_has_api_key(&self, provider_id: &str) -> bool {
129        // Check if we have it cached in the secrets manager
130        if let Some(manager) = crate::secrets::secrets_manager() {
131            // Sync check - only works if already cached
132            let cache = manager.cache.try_read();
133            if let Ok(cache) = cache {
134                return cache.contains_key(provider_id);
135            }
136        }
137        false
138    }
139
140    /// Async check if a provider has an API key in Vault
141    pub async fn check_provider_api_key_async(&self, provider_id: &str) -> bool {
142        crate::secrets::has_api_key(provider_id).await
143    }
144
145    /// Pre-load API key availability from Vault for all providers
146    pub async fn preload_available_providers(&self) -> Vec<String> {
147        let mut available = Vec::new();
148
149        if let Some(manager) = crate::secrets::secrets_manager() {
150            // List all configured providers from Vault
151            if let Ok(providers) = manager.list_configured_providers().await {
152                for provider_id in providers {
153                    // Verify each one actually has an API key
154                    if manager.has_api_key(&provider_id).await {
155                        available.push(provider_id);
156                    }
157                }
158            }
159        }
160
161        available
162    }
163
164    /// Get list of providers that have API keys configured (sync, uses cache)
165    pub fn available_providers(&self) -> Vec<&str> {
166        self.providers
167            .keys()
168            .filter(|id| self.provider_has_api_key(id))
169            .map(|s| s.as_str())
170            .collect()
171    }
172
173    /// Get list of providers that have API keys configured (async, checks Vault)
174    #[allow(dead_code)]
175    pub async fn available_providers_async(&self) -> Vec<String> {
176        let mut available = Vec::new();
177        for provider_id in self.providers.keys() {
178            if self.check_provider_api_key_async(provider_id).await {
179                available.push(provider_id.clone());
180            }
181        }
182        available
183    }
184
185    /// Get a provider by ID
186    pub fn get_provider(&self, provider_id: &str) -> Option<&ProviderInfo> {
187        self.providers.get(provider_id)
188    }
189
190    /// Get a provider by ID only if it has an API key
191    pub fn get_available_provider(&self, provider_id: &str) -> Option<&ProviderInfo> {
192        if self.provider_has_api_key(provider_id) {
193            self.providers.get(provider_id)
194        } else {
195            None
196        }
197    }
198
199    /// Get a model by provider and model ID
200    pub fn get_model(&self, provider_id: &str, model_id: &str) -> Option<&ApiModelInfo> {
201        self.providers
202            .get(provider_id)
203            .and_then(|p| p.models.get(model_id))
204    }
205
206    /// Get a model only if the provider has an API key
207    pub fn get_available_model(&self, provider_id: &str, model_id: &str) -> Option<&ApiModelInfo> {
208        if self.provider_has_api_key(provider_id) {
209            self.get_model(provider_id, model_id)
210        } else {
211            None
212        }
213    }
214
215    /// Find a model by ID across all providers (only available ones)
216    pub fn find_model(&self, model_id: &str) -> Option<(&str, &ApiModelInfo)> {
217        for (provider_id, provider) in &self.providers {
218            if !self.provider_has_api_key(provider_id) {
219                continue;
220            }
221            if let Some(model) = provider.models.get(model_id) {
222                return Some((provider_id, model));
223            }
224        }
225        None
226    }
227
228    /// Find a model across ALL providers (ignoring API key requirement)
229    pub fn find_model_any(&self, model_id: &str) -> Option<(&str, &ApiModelInfo)> {
230        for (provider_id, provider) in &self.providers {
231            if let Some(model) = provider.models.get(model_id) {
232                return Some((provider_id, model));
233            }
234        }
235        None
236    }
237
238    /// List all provider IDs (all, not filtered)
239    #[allow(dead_code)]
240    pub fn provider_ids(&self) -> Vec<&str> {
241        self.providers.keys().map(|s| s.as_str()).collect()
242    }
243
244    /// Get iterator over all providers and their info (unfiltered, no API key check)
245    pub fn all_providers(&self) -> &HashMap<String, ProviderInfo> {
246        &self.providers
247    }
248
249    /// List models for a provider (only if API key available)
250    pub fn models_for_provider(&self, provider_id: &str) -> Vec<&ApiModelInfo> {
251        if !self.provider_has_api_key(provider_id) {
252            return Vec::new();
253        }
254        self.providers
255            .get(provider_id)
256            .map(|p| p.models.values().collect())
257            .unwrap_or_default()
258    }
259
260    /// Find models with tool calling support (only from available providers)
261    pub fn tool_capable_models(&self) -> Vec<(&str, &ApiModelInfo)> {
262        let mut result = Vec::new();
263        for (provider_id, provider) in &self.providers {
264            if !self.provider_has_api_key(provider_id) {
265                continue;
266            }
267            for model in provider.models.values() {
268                if model.tool_call {
269                    result.push((provider_id.as_str(), model));
270                }
271            }
272        }
273        result
274    }
275
276    /// Find models with reasoning support (only from available providers)
277    pub fn reasoning_models(&self) -> Vec<(&str, &ApiModelInfo)> {
278        let mut result = Vec::new();
279        for (provider_id, provider) in &self.providers {
280            if !self.provider_has_api_key(provider_id) {
281                continue;
282            }
283            for model in provider.models.values() {
284                if model.reasoning {
285                    result.push((provider_id.as_str(), model));
286                }
287            }
288        }
289        result
290    }
291
292    /// Get recommended models for coding tasks
293    pub fn recommended_coding_models(&self) -> Vec<(&str, &ApiModelInfo)> {
294        let preferred_ids = [
295            "claude-sonnet-4-20250514",
296            "claude-opus-4-20250514",
297            "gpt-5-codex",
298            "gpt-5.1-codex",
299            "gpt-4o",
300            "gemini-2.5-pro",
301            "deepseek-v3.2",
302            "step-3.5-flash",
303            "glm-5",
304            "z-ai/glm-5",
305        ];
306
307        let mut result = Vec::new();
308        for model_id in preferred_ids {
309            if let Some((provider, model)) = self.find_model(model_id) {
310                result.push((provider, model));
311            }
312        }
313        result
314    }
315
316    /// Convert API model info to our internal ModelInfo format
317    #[allow(dead_code)]
318    pub fn to_model_info(&self, model: &ApiModelInfo, provider_id: &str) -> super::ModelInfo {
319        super::ModelInfo {
320            id: model.id.clone(),
321            name: model.name.clone(),
322            provider: provider_id.to_string(),
323            context_window: model
324                .limit
325                .as_ref()
326                .map(|l| l.context as usize)
327                .unwrap_or(128_000),
328            max_output_tokens: model.limit.as_ref().map(|l| l.output as usize),
329            supports_vision: model.attachment,
330            supports_tools: model.tool_call,
331            supports_streaming: true,
332            input_cost_per_million: model.cost.as_ref().map(|c| c.input),
333            output_cost_per_million: model.cost.as_ref().map(|c| c.output),
334        }
335    }
336}