1use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use std::time::{SystemTime, UNIX_EPOCH};
30use thiserror::Error;
31
32use crate::auth::TokenProvider;
33
/// Static description of a single model offered by a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelInfo {
    /// Human-readable model name.
    pub display_name: &'static str,
    /// Model identifier — presumably the value sent to the provider API (TODO confirm).
    pub identifier: &'static str,
    /// Whether the model can be used free of charge.
    pub is_free: bool,
    /// Context window size — presumably in tokens (TODO confirm).
    pub context_window: u32,
}
49
/// Static connection settings for one AI provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Short lookup key, matched exactly by `get_provider` (e.g. `"openrouter"`).
    pub name: &'static str,
    /// Human-readable provider name (e.g. `"OpenRouter"`).
    pub display_name: &'static str,
    /// URL of the provider's chat-completions endpoint.
    pub api_url: &'static str,
    /// Name of the environment variable expected to hold the provider's API key.
    pub api_key_env: &'static str,
}
65
/// Every AI provider known to this module, in the order `get_provider` searches them.
///
/// Each `api_url` is the provider's chat-completions endpoint (every URL here ends in
/// `/chat/completions`); `api_key_env` names the environment variable for its API key.
/// `name` values must be unique, which the unit tests check.
pub static PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig {
        name: "gemini",
        display_name: "Google Gemini",
        // Gemini is reached through its OpenAI-compatibility path (`/v1beta/openai/`).
        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        api_key_env: "GEMINI_API_KEY",
    },
    ProviderConfig {
        name: "openrouter",
        display_name: "OpenRouter",
        api_url: "https://openrouter.ai/api/v1/chat/completions",
        api_key_env: "OPENROUTER_API_KEY",
    },
    ProviderConfig {
        name: "groq",
        display_name: "Groq",
        api_url: "https://api.groq.com/openai/v1/chat/completions",
        api_key_env: "GROQ_API_KEY",
    },
    ProviderConfig {
        name: "cerebras",
        display_name: "Cerebras",
        api_url: "https://api.cerebras.ai/v1/chat/completions",
        api_key_env: "CEREBRAS_API_KEY",
    },
    ProviderConfig {
        name: "zenmux",
        display_name: "Zenmux",
        api_url: "https://zenmux.ai/api/v1/chat/completions",
        api_key_env: "ZENMUX_API_KEY",
    },
    ProviderConfig {
        name: "zai",
        display_name: "Z.AI (Zhipu)",
        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
        api_key_env: "ZAI_API_KEY",
    },
];
109
110#[must_use]
130pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
131 PROVIDERS.iter().find(|p| p.name == name)
132}
133
/// Returns every provider in [`PROVIDERS`], in declaration order.
#[must_use]
pub fn all_providers() -> &'static [ProviderConfig] {
    PROVIDERS
}
152
153#[derive(Debug, Error)]
159pub enum RegistryError {
160 #[error("HTTP request failed: {0}")]
162 HttpError(String),
163
164 #[error("Failed to parse API response: {0}")]
166 ParseError(String),
167
168 #[error("Provider not found: {0}")]
170 ProviderNotFound(String),
171
172 #[error("Cache error: {0}")]
174 CacheError(String),
175
176 #[error("IO error: {0}")]
178 IoError(#[from] std::io::Error),
179
180 #[error("Invalid model ID: {model_id}. Did you mean one of these?\n{}", .suggestions.join(", "))]
182 ModelValidation {
183 model_id: String,
185 suggestions: Vec<String>,
187 },
188}
189
190#[derive(Clone, Debug, Serialize, Deserialize)]
192pub struct CachedModel {
193 pub id: String,
195 pub name: Option<String>,
197 pub is_free: Option<bool>,
199 pub context_window: Option<u32>,
201}
202
203#[derive(Clone, Debug, Serialize, Deserialize)]
205pub struct CacheMetadata {
206 pub timestamp: u64,
208 pub ttl_seconds: u64,
210}
211
/// Source of per-provider model listings, plus validation helpers built on top of them.
///
/// Implementations must be `Send + Sync`. `CachedModelRegistry` in this module is an
/// implementation backed by an on-disk cache and the providers' HTTP APIs.
#[async_trait]
pub trait ModelRegistry: Send + Sync {
    /// Returns the models available from `provider`.
    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;

    /// Returns whether a model with ID `model_id` is listed for `provider`.
    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;

    /// Returns IDs of models listed for `provider` that resemble `model_id`,
    /// most similar first.
    async fn suggest_similar(
        &self,
        provider: &str,
        model_id: &str,
    ) -> Result<Vec<String>, RegistryError>;

    /// Returns `Ok(())` if `model_id` is listed for `provider`; otherwise an error
    /// (`CachedModelRegistry` returns `RegistryError::ModelValidation` with suggestions).
    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
}
233
/// [`ModelRegistry`] that caches each provider's model list as a JSON file on disk and
/// falls back to the provider's HTTP API when the cache is missing or expired.
pub struct CachedModelRegistry<'a> {
    /// Directory holding one `models_{provider}.json` file per provider.
    cache_dir: PathBuf,
    /// Freshness window, in seconds, recorded in every cache file this registry writes.
    ttl_seconds: u64,
    /// HTTP client used to query the providers' model-listing endpoints.
    client: reqwest::Client,
    /// Supplies the per-provider API key for listing requests.
    token_provider: &'a dyn TokenProvider,
}
241
242impl CachedModelRegistry<'_> {
243 #[must_use]
251 pub fn new(
252 cache_dir: PathBuf,
253 ttl_seconds: u64,
254 token_provider: &dyn TokenProvider,
255 ) -> CachedModelRegistry<'_> {
256 CachedModelRegistry {
257 cache_dir,
258 ttl_seconds,
259 client: reqwest::Client::builder()
260 .timeout(std::time::Duration::from_secs(10))
261 .build()
262 .unwrap_or_else(|_| reqwest::Client::new()),
263 token_provider,
264 }
265 }
266
267 fn cache_path(&self, provider: &str) -> PathBuf {
269 self.cache_dir.join(format!("models_{provider}.json"))
270 }
271
272 fn is_cache_valid(metadata: &CacheMetadata) -> bool {
274 let now = SystemTime::now()
275 .duration_since(UNIX_EPOCH)
276 .unwrap_or_default()
277 .as_secs();
278 now < metadata.timestamp + metadata.ttl_seconds
279 }
280
281 fn load_from_cache(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
283 let path = self.cache_path(provider);
284 if !path.exists() {
285 return Err(RegistryError::CacheError(
286 "Cache file not found".to_string(),
287 ));
288 }
289
290 let content = std::fs::read_to_string(&path)?;
291 let data: serde_json::Value =
292 serde_json::from_str(&content).map_err(|e| RegistryError::ParseError(e.to_string()))?;
293
294 if let Some(metadata) = data
296 .get("metadata")
297 .and_then(|m| serde_json::from_value::<CacheMetadata>(m.clone()).ok())
298 {
299 if Self::is_cache_valid(&metadata) {
300 return data
301 .get("models")
302 .and_then(|m| serde_json::from_value::<Vec<CachedModel>>(m.clone()).ok())
303 .ok_or_else(|| RegistryError::ParseError("Invalid cache format".to_string()));
304 }
305 return Err(RegistryError::CacheError("Cache expired".to_string()));
306 }
307
308 data.get("models")
310 .and_then(|m| serde_json::from_value::<Vec<CachedModel>>(m.clone()).ok())
311 .ok_or_else(|| RegistryError::ParseError("Invalid cache format".to_string()))
312 }
313
314 fn load_stale_cache(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
316 let path = self.cache_path(provider);
317 if !path.exists() {
318 return Err(RegistryError::CacheError(
319 "Cache file not found".to_string(),
320 ));
321 }
322
323 let content = std::fs::read_to_string(&path)?;
324 let data: serde_json::Value =
325 serde_json::from_str(&content).map_err(|e| RegistryError::ParseError(e.to_string()))?;
326
327 data.get("models")
329 .and_then(|m| serde_json::from_value::<Vec<CachedModel>>(m.clone()).ok())
330 .ok_or_else(|| RegistryError::ParseError("Invalid cache format".to_string()))
331 }
332
333 fn save_to_cache(&self, provider: &str, models: &[CachedModel]) -> Result<(), RegistryError> {
335 std::fs::create_dir_all(&self.cache_dir)?;
336
337 let now = SystemTime::now()
338 .duration_since(UNIX_EPOCH)
339 .unwrap_or_default()
340 .as_secs();
341
342 let cache_data = serde_json::json!({
343 "metadata": {
344 "timestamp": now,
345 "ttl_seconds": self.ttl_seconds,
346 },
347 "models": models,
348 });
349
350 let path = self.cache_path(provider);
351 std::fs::write(&path, cache_data.to_string())?;
352 Ok(())
353 }
354
355 fn parse_openrouter_models(data: &serde_json::Value) -> Vec<CachedModel> {
357 data.get("data")
358 .and_then(|d| d.as_array())
359 .map(|arr| {
360 arr.iter()
361 .filter_map(|m| {
362 Some(CachedModel {
363 id: m.get("id")?.as_str()?.to_string(),
364 name: m.get("name").and_then(|n| n.as_str()).map(String::from),
365 is_free: m
366 .get("pricing")
367 .and_then(|p| p.get("prompt"))
368 .and_then(|p| p.as_str())
369 .map(|p| p == "0"),
370 context_window: m
371 .get("context_length")
372 .and_then(serde_json::Value::as_u64)
373 .and_then(|c| u32::try_from(c).ok()),
374 })
375 })
376 .collect()
377 })
378 .unwrap_or_default()
379 }
380
381 fn parse_gemini_models(data: &serde_json::Value) -> Vec<CachedModel> {
383 data.get("models")
384 .and_then(|d| d.as_array())
385 .map(|arr| {
386 arr.iter()
387 .filter_map(|m| {
388 Some(CachedModel {
389 id: m.get("name")?.as_str()?.to_string(),
390 name: m
391 .get("displayName")
392 .and_then(|n| n.as_str())
393 .map(String::from),
394 is_free: None,
395 context_window: m
396 .get("inputTokenLimit")
397 .and_then(serde_json::Value::as_u64)
398 .and_then(|c| u32::try_from(c).ok()),
399 })
400 })
401 .collect()
402 })
403 .unwrap_or_default()
404 }
405
406 fn parse_generic_models(data: &serde_json::Value) -> Vec<CachedModel> {
408 data.get("data")
409 .and_then(|d| d.as_array())
410 .map(|arr| {
411 arr.iter()
412 .filter_map(|m| {
413 Some(CachedModel {
414 id: m.get("id")?.as_str()?.to_string(),
415 name: None,
416 is_free: None,
417 context_window: None,
418 })
419 })
420 .collect()
421 })
422 .unwrap_or_default()
423 }
424
425 async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
427 let url = match provider {
428 "openrouter" => "https://openrouter.ai/api/v1/models",
429 "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
430 "groq" => "https://api.groq.com/openai/v1/models",
431 "cerebras" => "https://api.cerebras.ai/v1/models",
432 "zenmux" => "https://zenmux.ai/api/v1/models",
433 "zai" => "https://api.z.ai/api/paas/v4/models",
434 _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
435 };
436
437 let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
439 RegistryError::HttpError(format!("No API key available for {provider}"))
440 })?;
441
442 let request = match provider {
444 "gemini" => {
445 self.client
447 .get(url)
448 .query(&[("key", api_key.expose_secret())])
449 }
450 "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
451 self.client.get(url).header(
453 "Authorization",
454 format!("Bearer {}", api_key.expose_secret()),
455 )
456 }
457 _ => self.client.get(url),
458 };
459
460 let response = request
461 .send()
462 .await
463 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
464
465 let data = response
466 .json::<serde_json::Value>()
467 .await
468 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
469
470 let models = match provider {
472 "openrouter" => Self::parse_openrouter_models(&data),
473 "gemini" => Self::parse_gemini_models(&data),
474 "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data),
475 _ => vec![],
476 };
477
478 Ok(models)
479 }
480}
481
482#[async_trait]
483impl ModelRegistry for CachedModelRegistry<'_> {
484 async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
485 if let Ok(models) = self.load_from_cache(provider) {
487 return Ok(models);
488 }
489
490 match self.fetch_from_api(provider).await {
492 Ok(models) => {
493 let _ = self.save_to_cache(provider, &models);
495 Ok(models)
496 }
497 Err(api_error) => {
498 match self.load_stale_cache(provider) {
500 Ok(models) => {
501 tracing::warn!(
502 provider = provider,
503 error = %api_error,
504 "API request failed, returning stale cached models"
505 );
506 Ok(models)
507 }
508 Err(_) => {
509 Err(api_error)
511 }
512 }
513 }
514 }
515 }
516
517 async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
518 let models = self.list_models(provider).await?;
519 Ok(models.iter().any(|m| m.id == model_id))
520 }
521
522 async fn suggest_similar(
523 &self,
524 provider: &str,
525 model_id: &str,
526 ) -> Result<Vec<String>, RegistryError> {
527 let models = self.list_models(provider).await?;
528
529 let mut scored_suggestions: Vec<(String, f64)> = models
531 .iter()
532 .map(|m| {
533 let score = strsim::jaro_winkler(&m.id, model_id);
534 (m.id.clone(), score)
535 })
536 .collect();
537
538 scored_suggestions
540 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
541 let suggestions: Vec<String> = scored_suggestions
542 .into_iter()
543 .take(5)
544 .map(|(id, _)| id)
545 .collect();
546
547 Ok(suggestions)
548 }
549
550 async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
551 if self.model_exists(provider, model_id).await? {
552 Ok(())
553 } else {
554 let suggestions = self.suggest_similar(provider, model_id).await?;
555 Err(RegistryError::ModelValidation {
556 model_id: model_id.to_string(),
557 suggestions,
558 })
559 }
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Token provider that never returns credentials; these tests never reach the network.
    struct MockTokenProvider;

    impl crate::auth::TokenProvider for MockTokenProvider {
        fn github_token(&self) -> Option<secrecy::SecretString> {
            None
        }
        fn ai_api_key(&self, _provider: &str) -> Option<secrecy::SecretString> {
            None
        }
    }

    /// Scratch directory unique to one test in one process. It is deleted on drop, which
    /// also happens when an assertion panics.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(label: &str) -> Self {
            // The PID keeps concurrent `cargo test` runs from sharing (and deleting) a
            // directory; the label keeps tests within one run apart.
            let path = std::env::temp_dir()
                .join(format!("aptu_test_{label}_{}", std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn sample_models() -> Vec<CachedModel> {
        vec![
            CachedModel {
                id: "test-model-1".to_string(),
                name: Some("Test Model 1".to_string()),
                is_free: Some(true),
                context_window: Some(4096),
            },
            CachedModel {
                id: "test-model-2".to_string(),
                name: Some("Test Model 2".to_string()),
                is_free: Some(false),
                context_window: Some(8192),
            },
        ]
    }

    #[test]
    fn test_load_stale_cache_ignores_ttl() {
        let dir = ScratchDir::new("stale_cache");
        let token_provider = MockTokenProvider;
        // A zero TTL expires the entry as soon as it is written, so no sleep is needed.
        let registry = CachedModelRegistry::new(dir.0.clone(), 0, &token_provider);
        registry
            .save_to_cache("test_provider", &sample_models())
            .expect("save_to_cache should succeed");

        assert!(
            matches!(
                registry.load_from_cache("test_provider"),
                Err(RegistryError::CacheError(_))
            ),
            "load_from_cache should reject an expired entry"
        );

        let loaded_models = registry
            .load_stale_cache("test_provider")
            .expect("load_stale_cache should succeed");
        assert_eq!(loaded_models.len(), 2);
        assert_eq!(loaded_models[0].id, "test-model-1");
        assert_eq!(loaded_models[1].id, "test-model-2");
    }

    #[test]
    fn test_load_from_cache_returns_fresh_entries() {
        let dir = ScratchDir::new("fresh_cache");
        let token_provider = MockTokenProvider;
        let registry = CachedModelRegistry::new(dir.0.clone(), 3600, &token_provider);
        registry
            .save_to_cache("test_provider", &sample_models())
            .expect("save_to_cache should succeed");

        let loaded = registry
            .load_from_cache("test_provider")
            .expect("a fresh cache should load");
        let ids: Vec<&str> = loaded.iter().map(|m| m.id.as_str()).collect();
        assert_eq!(ids, ["test-model-1", "test-model-2"]);
        assert_eq!(loaded[1].context_window, Some(8192));
    }

    #[test]
    fn test_load_from_cache_missing_file() {
        let dir = ScratchDir::new("missing_cache");
        let token_provider = MockTokenProvider;
        let registry = CachedModelRegistry::new(dir.0.clone(), 3600, &token_provider);
        assert!(matches!(
            registry.load_from_cache("test_provider"),
            Err(RegistryError::CacheError(_))
        ));
    }

    #[test]
    fn test_is_cache_valid_rejects_expired_metadata() {
        let metadata = CacheMetadata {
            timestamp: 0,
            ttl_seconds: 1,
        };
        assert!(!CachedModelRegistry::is_cache_valid(&metadata));
    }

    /// Asserts that `name` resolves to a provider with the given display name and key env var.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let provider = get_provider(name)
            .unwrap_or_else(|| panic!("provider {name} should be registered"));
        assert_eq!(provider.name, name);
        assert_eq!(provider.display_name, display_name);
        assert_eq!(provider.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        assert!(
            get_provider("OpenRouter").is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 6, "Should have exactly 6 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = std::collections::HashSet::new();
        for provider in all_providers() {
            assert!(
                seen.insert(provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for provider in all_providers() {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for provider in all_providers() {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }
}