1use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static, compile-time description of a single model.
///
/// All string fields are `'static`, so values are cheap to copy and can live
/// in constant tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelInfo {
    /// Name of the model intended for display.
    pub display_name: &'static str,

    /// Identifier string of the model.
    pub identifier: &'static str,

    /// Whether the model is flagged as free.
    pub is_free: bool,

    /// Size of the model's context window (presumably in tokens — confirm).
    pub context_window: u32,
}
49
/// Static configuration for one AI provider.
///
/// Instances are stored in the [`PROVIDERS`] table and looked up by `name`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lookup key for the provider (matched exactly by [`get_provider`]).
    pub name: &'static str,

    /// Name of the provider intended for display.
    pub display_name: &'static str,

    /// URL of the provider's API endpoint.
    pub api_url: &'static str,

    /// Name of the environment variable associated with the provider's API key.
    pub api_key_env: &'static str,
}
65
/// Table of every provider known to this module.
///
/// Each entry pairs a lowercase lookup `name` with its display name, its
/// chat-completions endpoint URL and the environment variable named for its
/// API key. [`get_provider`] searches this table and [`all_providers`]
/// returns it unchanged.
pub static PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig {
        name: "gemini",
        display_name: "Google Gemini",
        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        api_key_env: "GEMINI_API_KEY",
    },
    ProviderConfig {
        name: "openrouter",
        display_name: "OpenRouter",
        api_url: "https://openrouter.ai/api/v1/chat/completions",
        api_key_env: "OPENROUTER_API_KEY",
    },
    ProviderConfig {
        name: "groq",
        display_name: "Groq",
        api_url: "https://api.groq.com/openai/v1/chat/completions",
        api_key_env: "GROQ_API_KEY",
    },
    ProviderConfig {
        name: "cerebras",
        display_name: "Cerebras",
        api_url: "https://api.cerebras.ai/v1/chat/completions",
        api_key_env: "CEREBRAS_API_KEY",
    },
    ProviderConfig {
        name: "zenmux",
        display_name: "Zenmux",
        api_url: "https://zenmux.ai/api/v1/chat/completions",
        api_key_env: "ZENMUX_API_KEY",
    },
    ProviderConfig {
        name: "zai",
        display_name: "Z.AI (Zhipu)",
        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
        api_key_env: "ZAI_API_KEY",
    },
];
109
110#[must_use]
130pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
131 PROVIDERS.iter().find(|p| p.name == name)
132}
133
/// Returns the complete [`PROVIDERS`] table as a slice.
#[must_use]
pub fn all_providers() -> &'static [ProviderConfig] {
    PROVIDERS
}
152
/// Errors produced by the model registry.
#[derive(Debug, Error)]
pub enum RegistryError {
    /// An HTTP request to a provider failed; the payload is a description of
    /// the failure. Also used in this file when no API key is available for
    /// the requested provider.
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// A provider response could not be parsed.
    #[error("Failed to parse API response: {0}")]
    ParseError(String),

    /// The requested provider name is not one this registry knows how to
    /// query; the payload is the name that was requested.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),

    /// A cache operation failed.
    #[error("Cache error: {0}")]
    CacheError(String),

    /// An underlying I/O operation failed (converted automatically via `From`).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// The given model id was not found in the provider's model list.
    #[error("Invalid model ID: {model_id}")]
    ModelValidation {
        /// The model id that failed validation.
        model_id: String,
    },
}
187
/// A model entry as stored in the cache and returned by [`ModelRegistry`].
///
/// Only `id` is mandatory; the remaining fields are `None` when the
/// provider's listing does not supply them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CachedModel {
    /// Model identifier as reported by the provider.
    pub id: String,
    /// Display name of the model, if the provider reports one.
    pub name: Option<String>,
    /// Whether the model is free, if that can be determined.
    pub is_free: Option<bool>,
    /// Context window size, if the provider reports one (presumably in tokens — confirm).
    pub context_window: Option<u32>,
}
200
/// Asynchronous source of per-provider model listings.
#[async_trait]
pub trait ModelRegistry: Send + Sync {
    /// Lists the models available from `provider`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the listing cannot be obtained.
    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;

    /// Reports whether `model_id` is among the models of `provider`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the listing cannot be obtained.
    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;

    /// Succeeds when `model_id` is a valid model of `provider`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the model is not valid or the listing
    /// cannot be obtained.
    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
}
213
/// [`ModelRegistry`] implementation that combines a file cache with the
/// providers' HTTP model-listing endpoints.
pub struct CachedModelRegistry<'a> {
    /// File cache of model lists; this file reads and writes it keyed by provider name.
    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
    /// HTTP client used to query the providers' model-listing endpoints.
    client: reqwest::Client,
    /// Source of the per-provider API keys.
    token_provider: &'a dyn TokenProvider,
}
220
221impl CachedModelRegistry<'_> {
222 #[must_use]
230 pub fn new(
231 cache_dir: Option<PathBuf>,
232 ttl_seconds: u64,
233 token_provider: &dyn TokenProvider,
234 ) -> CachedModelRegistry<'_> {
235 let ttl = chrono::Duration::seconds(
236 ttl_seconds
237 .try_into()
238 .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
239 );
240 CachedModelRegistry {
241 cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
242 client: reqwest::Client::builder()
243 .timeout(std::time::Duration::from_secs(10))
244 .build()
245 .unwrap_or_else(|_| reqwest::Client::new()),
246 token_provider,
247 }
248 }
249
250 fn parse_openrouter_models(data: &serde_json::Value) -> Vec<CachedModel> {
252 data.get("data")
253 .and_then(|d| d.as_array())
254 .map(|arr| {
255 arr.iter()
256 .filter_map(|m| {
257 Some(CachedModel {
258 id: m.get("id")?.as_str()?.to_string(),
259 name: m.get("name").and_then(|n| n.as_str()).map(String::from),
260 is_free: m
261 .get("pricing")
262 .and_then(|p| p.get("prompt"))
263 .and_then(|p| p.as_str())
264 .map(|p| p == "0"),
265 context_window: m
266 .get("context_length")
267 .and_then(serde_json::Value::as_u64)
268 .and_then(|c| u32::try_from(c).ok()),
269 })
270 })
271 .collect()
272 })
273 .unwrap_or_default()
274 }
275
276 fn parse_gemini_models(data: &serde_json::Value) -> Vec<CachedModel> {
278 data.get("models")
279 .and_then(|d| d.as_array())
280 .map(|arr| {
281 arr.iter()
282 .filter_map(|m| {
283 Some(CachedModel {
284 id: m.get("name")?.as_str()?.to_string(),
285 name: m
286 .get("displayName")
287 .and_then(|n| n.as_str())
288 .map(String::from),
289 is_free: None,
290 context_window: m
291 .get("inputTokenLimit")
292 .and_then(serde_json::Value::as_u64)
293 .and_then(|c| u32::try_from(c).ok()),
294 })
295 })
296 .collect()
297 })
298 .unwrap_or_default()
299 }
300
301 fn parse_generic_models(data: &serde_json::Value) -> Vec<CachedModel> {
303 data.get("data")
304 .and_then(|d| d.as_array())
305 .map(|arr| {
306 arr.iter()
307 .filter_map(|m| {
308 Some(CachedModel {
309 id: m.get("id")?.as_str()?.to_string(),
310 name: None,
311 is_free: None,
312 context_window: None,
313 })
314 })
315 .collect()
316 })
317 .unwrap_or_default()
318 }
319
320 async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
322 let url = match provider {
323 "openrouter" => "https://openrouter.ai/api/v1/models",
324 "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
325 "groq" => "https://api.groq.com/openai/v1/models",
326 "cerebras" => "https://api.cerebras.ai/v1/models",
327 "zenmux" => "https://zenmux.ai/api/v1/models",
328 "zai" => "https://api.z.ai/api/paas/v4/models",
329 _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
330 };
331
332 let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
334 RegistryError::HttpError(format!("No API key available for {provider}"))
335 })?;
336
337 let request = match provider {
339 "gemini" => {
340 self.client
342 .get(url)
343 .query(&[("key", api_key.expose_secret())])
344 }
345 "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
346 self.client.get(url).header(
348 "Authorization",
349 format!("Bearer {}", api_key.expose_secret()),
350 )
351 }
352 _ => self.client.get(url),
353 };
354
355 let response = request
356 .send()
357 .await
358 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
359
360 let data = response
361 .json::<serde_json::Value>()
362 .await
363 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
364
365 let models = match provider {
367 "openrouter" => Self::parse_openrouter_models(&data),
368 "gemini" => Self::parse_gemini_models(&data),
369 "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data),
370 _ => vec![],
371 };
372
373 Ok(models)
374 }
375}
376
377#[async_trait]
378impl ModelRegistry for CachedModelRegistry<'_> {
379 async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
380 if let Ok(Some(models)) = self.cache.get(provider) {
382 return Ok(models);
383 }
384
385 match self.fetch_from_api(provider).await {
387 Ok(models) => {
388 let _ = self.cache.set(provider, &models);
390 Ok(models)
391 }
392 Err(api_error) => {
393 match self.cache.get_stale(provider) {
395 Ok(Some(models)) => {
396 tracing::warn!(
397 provider = provider,
398 error = %api_error,
399 "API request failed, returning stale cached models"
400 );
401 Ok(models)
402 }
403 _ => {
404 Err(api_error)
406 }
407 }
408 }
409 }
410 }
411
412 async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
413 let models = self.list_models(provider).await?;
414 Ok(models.iter().any(|m| m.id == model_id))
415 }
416
417 async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
418 if self.model_exists(provider, model_id).await? {
419 Ok(())
420 } else {
421 Err(RegistryError::ModelValidation {
422 model_id: model_id.to_string(),
423 })
424 }
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Asserts that `name` is registered with the given display name and
    /// API key environment variable.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let Some(config) = get_provider(name) else {
            panic!("provider should be registered");
        };
        assert_eq!(config.display_name, display_name);
        assert_eq!(config.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        assert!(
            get_provider("OpenRouter").is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 6, "Should have exactly 6 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        // `insert` returns false when the name was already present.
        let mut seen = HashSet::new();
        for ProviderConfig { name, .. } in all_providers() {
            assert!(seen.insert(*name), "Duplicate provider name: {}", name);
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for ProviderConfig { name, api_url, .. } in all_providers() {
            assert!(
                api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for ProviderConfig {
            name, api_key_env, ..
        } in all_providers()
        {
            assert!(
                !api_key_env.is_empty(),
                "Provider {} should have API key env var",
                name
            );
        }
    }
}