aptu_core/ai/registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Centralized provider configuration registry.
4//!
5//! This module provides a static registry of all AI providers supported by Aptu,
6//! including their metadata, API endpoints, and available models.
7//!
8//! It also provides runtime model validation infrastructure via the `ModelRegistry` trait
9//! with a simple sync implementation using static model lists.
10//!
11//! # Examples
12//!
13//! ```
14//! use aptu_core::ai::registry::{get_provider, all_providers};
15//!
16//! // Get a specific provider
17//! let provider = get_provider("openrouter");
18//! assert!(provider.is_some());
19//!
20//! // Get all providers
21//! let providers = all_providers();
22//! assert_eq!(providers.len(), 6);
23//! ```
24
25use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use std::time::{SystemTime, UNIX_EPOCH};
30use thiserror::Error;
31
32use crate::auth::TokenProvider;
33
/// Metadata for a single AI model.
///
/// All string fields are `&'static str`, so instances can be declared in
/// compile-time `static` tables (mirroring how [`ProviderConfig`] entries
/// are declared in [`PROVIDERS`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ModelInfo {
    /// Human-readable model name for UI display
    pub display_name: &'static str,

    /// Provider-specific model identifier used in API requests
    pub identifier: &'static str,

    /// Whether this model is free to use
    pub is_free: bool,

    /// Maximum context window size in tokens
    pub context_window: u32,
}
49
/// Configuration for an AI provider.
///
/// Entries live in the static [`PROVIDERS`] table and are looked up by
/// [`get_provider`] using the `name` field as the key.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Provider identifier (lowercase, used in config files)
    pub name: &'static str,

    /// Human-readable provider name for UI display
    pub display_name: &'static str,

    /// API base URL for this provider
    pub api_url: &'static str,

    /// Environment variable name for API key
    pub api_key_env: &'static str,
}
65
// ============================================================================
// Provider Registry
// ============================================================================

/// Static registry of all supported AI providers.
///
/// Lookup is by the lowercase `name` field via [`get_provider`]; the order of
/// entries is also the order returned by [`all_providers`].
pub static PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig {
        name: "gemini",
        display_name: "Google Gemini",
        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        api_key_env: "GEMINI_API_KEY",
    },
    ProviderConfig {
        name: "openrouter",
        display_name: "OpenRouter",
        api_url: "https://openrouter.ai/api/v1/chat/completions",
        api_key_env: "OPENROUTER_API_KEY",
    },
    ProviderConfig {
        name: "groq",
        display_name: "Groq",
        api_url: "https://api.groq.com/openai/v1/chat/completions",
        api_key_env: "GROQ_API_KEY",
    },
    ProviderConfig {
        name: "cerebras",
        display_name: "Cerebras",
        api_url: "https://api.cerebras.ai/v1/chat/completions",
        api_key_env: "CEREBRAS_API_KEY",
    },
    ProviderConfig {
        name: "zenmux",
        display_name: "Zenmux",
        api_url: "https://zenmux.ai/api/v1/chat/completions",
        api_key_env: "ZENMUX_API_KEY",
    },
    ProviderConfig {
        name: "zai",
        display_name: "Z.AI (Zhipu)",
        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
        api_key_env: "ZAI_API_KEY",
    },
];
109
110/// Retrieves a provider configuration by name.
111///
112/// # Arguments
113///
114/// * `name` - The provider name (case-sensitive, lowercase)
115///
116/// # Returns
117///
118/// Some(ProviderConfig) if found, None otherwise.
119///
120/// # Examples
121///
122/// ```
123/// use aptu_core::ai::registry::get_provider;
124///
125/// let provider = get_provider("openrouter");
126/// assert!(provider.is_some());
127/// assert_eq!(provider.unwrap().display_name, "OpenRouter");
128/// ```
129#[must_use]
130pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
131    PROVIDERS.iter().find(|p| p.name == name)
132}
133
134/// Returns all available providers.
135///
136/// # Returns
137///
138/// A slice of all `ProviderConfig` entries in the registry.
139///
140/// # Examples
141///
142/// ```
143/// use aptu_core::ai::registry::all_providers;
144///
145/// let providers = all_providers();
146/// assert_eq!(providers.len(), 6);
147/// ```
148#[must_use]
149pub fn all_providers() -> &'static [ProviderConfig] {
150    PROVIDERS
151}
152
// ============================================================================
// Runtime Model Validation
// ============================================================================

/// Error type for model registry operations.
///
/// Variants separate network, parsing, cache, and validation failures so
/// callers can distinguish recoverable conditions (e.g. an expired cache)
/// from hard failures.
#[derive(Debug, Error)]
pub enum RegistryError {
    /// HTTP request failed.
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// Failed to parse API response.
    #[error("Failed to parse API response: {0}")]
    ParseError(String),

    /// Provider not found.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),

    /// Cache error (missing file or expired entry).
    #[error("Cache error: {0}")]
    CacheError(String),

    /// IO error. `#[from]` lets `?` convert `std::io::Error` automatically.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// Model validation error - invalid model ID with suggestions.
    #[error("Invalid model ID: {model_id}. Did you mean one of these?\n{}", .suggestions.join(", "))]
    ModelValidation {
        /// The invalid model ID provided by the user.
        model_id: String,
        /// Suggested valid model IDs based on fuzzy matching.
        suggestions: Vec<String>,
    },
}
189
/// Cached model information from API responses.
///
/// Optional fields are `None` when a provider's API does not report them
/// (e.g. generic OpenAI-compatible listings only expose the model `id`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CachedModel {
    /// Model identifier from the provider API.
    pub id: String,
    /// Human-readable model name.
    pub name: Option<String>,
    /// Whether the model is free to use.
    pub is_free: Option<bool>,
    /// Maximum context window size in tokens.
    pub context_window: Option<u32>,
}
202
/// Cache metadata for TTL validation.
///
/// Stored alongside the model list in each cache file; an entry is fresh
/// while `now < timestamp + ttl_seconds`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CacheMetadata {
    /// Unix timestamp when the cache entry was created.
    pub timestamp: u64,
    /// Time-to-live in seconds.
    pub ttl_seconds: u64,
}
211
/// Trait for runtime model validation and listing.
#[async_trait]
pub trait ModelRegistry: Send + Sync {
    /// List all available models for a provider.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the model list cannot be obtained.
    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;

    /// Check if a model exists for a provider.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the model list cannot be obtained.
    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;

    /// Suggest similar models when a model is not found.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the model list cannot be obtained.
    async fn suggest_similar(
        &self,
        provider: &str,
        model_id: &str,
    ) -> Result<Vec<String>, RegistryError>;

    /// Validate that a model ID exists for a provider.
    ///
    /// Returns an error with fuzzy-matched suggestions if the model is not found.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::ModelValidation`] when the model is unknown,
    /// or any error from fetching the model list.
    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
}
233
/// Cached model registry with HTTP client and TTL support.
///
/// Fetches per-provider model lists over HTTP and caches them on disk,
/// borrowing a [`TokenProvider`] for the lifetime `'a` to obtain API keys.
pub struct CachedModelRegistry<'a> {
    /// Directory holding per-provider cache files (`models_<provider>.json`).
    cache_dir: PathBuf,
    /// Time-to-live, in seconds, written into new cache entries.
    ttl_seconds: u64,
    /// HTTP client used for model-list requests.
    client: reqwest::Client,
    /// Source of API credentials for each provider.
    token_provider: &'a dyn TokenProvider,
}
241
242impl CachedModelRegistry<'_> {
243    /// Create a new cached model registry.
244    ///
245    /// # Arguments
246    ///
247    /// * `cache_dir` - Directory for storing cached model lists
248    /// * `ttl_seconds` - Time-to-live for cache entries (default: 86400 = 24 hours)
249    /// * `token_provider` - Token provider for API credentials
250    #[must_use]
251    pub fn new(
252        cache_dir: PathBuf,
253        ttl_seconds: u64,
254        token_provider: &dyn TokenProvider,
255    ) -> CachedModelRegistry<'_> {
256        CachedModelRegistry {
257            cache_dir,
258            ttl_seconds,
259            client: reqwest::Client::builder()
260                .timeout(std::time::Duration::from_secs(10))
261                .build()
262                .unwrap_or_else(|_| reqwest::Client::new()),
263            token_provider,
264        }
265    }
266
267    /// Get the cache file path for a provider.
268    fn cache_path(&self, provider: &str) -> PathBuf {
269        self.cache_dir.join(format!("models_{provider}.json"))
270    }
271
272    /// Check if cache is still valid.
273    fn is_cache_valid(metadata: &CacheMetadata) -> bool {
274        let now = SystemTime::now()
275            .duration_since(UNIX_EPOCH)
276            .unwrap_or_default()
277            .as_secs();
278        now < metadata.timestamp + metadata.ttl_seconds
279    }
280
281    /// Load models from cache, respecting TTL.
282    fn load_from_cache(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
283        let path = self.cache_path(provider);
284        if !path.exists() {
285            return Err(RegistryError::CacheError(
286                "Cache file not found".to_string(),
287            ));
288        }
289
290        let content = std::fs::read_to_string(&path)?;
291        let data: serde_json::Value =
292            serde_json::from_str(&content).map_err(|e| RegistryError::ParseError(e.to_string()))?;
293
294        // Check TTL and extract models in one step
295        if let Some(metadata) = data
296            .get("metadata")
297            .and_then(|m| serde_json::from_value::<CacheMetadata>(m.clone()).ok())
298        {
299            if Self::is_cache_valid(&metadata) {
300                return data
301                    .get("models")
302                    .and_then(|m| serde_json::from_value::<Vec<CachedModel>>(m.clone()).ok())
303                    .ok_or_else(|| RegistryError::ParseError("Invalid cache format".to_string()));
304            }
305            return Err(RegistryError::CacheError("Cache expired".to_string()));
306        }
307
308        // Extract models if no metadata
309        data.get("models")
310            .and_then(|m| serde_json::from_value::<Vec<CachedModel>>(m.clone()).ok())
311            .ok_or_else(|| RegistryError::ParseError("Invalid cache format".to_string()))
312    }
313
314    /// Load models from cache regardless of TTL (stale fallback).
315    fn load_stale_cache(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
316        let path = self.cache_path(provider);
317        if !path.exists() {
318            return Err(RegistryError::CacheError(
319                "Cache file not found".to_string(),
320            ));
321        }
322
323        let content = std::fs::read_to_string(&path)?;
324        let data: serde_json::Value =
325            serde_json::from_str(&content).map_err(|e| RegistryError::ParseError(e.to_string()))?;
326
327        // Extract models regardless of TTL
328        data.get("models")
329            .and_then(|m| serde_json::from_value::<Vec<CachedModel>>(m.clone()).ok())
330            .ok_or_else(|| RegistryError::ParseError("Invalid cache format".to_string()))
331    }
332
333    /// Save models to cache.
334    fn save_to_cache(&self, provider: &str, models: &[CachedModel]) -> Result<(), RegistryError> {
335        std::fs::create_dir_all(&self.cache_dir)?;
336
337        let now = SystemTime::now()
338            .duration_since(UNIX_EPOCH)
339            .unwrap_or_default()
340            .as_secs();
341
342        let cache_data = serde_json::json!({
343            "metadata": {
344                "timestamp": now,
345                "ttl_seconds": self.ttl_seconds,
346            },
347            "models": models,
348        });
349
350        let path = self.cache_path(provider);
351        std::fs::write(&path, cache_data.to_string())?;
352        Ok(())
353    }
354
355    /// Parse `OpenRouter` API response into models.
356    fn parse_openrouter_models(data: &serde_json::Value) -> Vec<CachedModel> {
357        data.get("data")
358            .and_then(|d| d.as_array())
359            .map(|arr| {
360                arr.iter()
361                    .filter_map(|m| {
362                        Some(CachedModel {
363                            id: m.get("id")?.as_str()?.to_string(),
364                            name: m.get("name").and_then(|n| n.as_str()).map(String::from),
365                            is_free: m
366                                .get("pricing")
367                                .and_then(|p| p.get("prompt"))
368                                .and_then(|p| p.as_str())
369                                .map(|p| p == "0"),
370                            context_window: m
371                                .get("context_length")
372                                .and_then(serde_json::Value::as_u64)
373                                .and_then(|c| u32::try_from(c).ok()),
374                        })
375                    })
376                    .collect()
377            })
378            .unwrap_or_default()
379    }
380
381    /// Parse Gemini API response into models.
382    fn parse_gemini_models(data: &serde_json::Value) -> Vec<CachedModel> {
383        data.get("models")
384            .and_then(|d| d.as_array())
385            .map(|arr| {
386                arr.iter()
387                    .filter_map(|m| {
388                        Some(CachedModel {
389                            id: m.get("name")?.as_str()?.to_string(),
390                            name: m
391                                .get("displayName")
392                                .and_then(|n| n.as_str())
393                                .map(String::from),
394                            is_free: None,
395                            context_window: m
396                                .get("inputTokenLimit")
397                                .and_then(serde_json::Value::as_u64)
398                                .and_then(|c| u32::try_from(c).ok()),
399                        })
400                    })
401                    .collect()
402            })
403            .unwrap_or_default()
404    }
405
406    /// Parse generic OpenAI-compatible API response into models.
407    fn parse_generic_models(data: &serde_json::Value) -> Vec<CachedModel> {
408        data.get("data")
409            .and_then(|d| d.as_array())
410            .map(|arr| {
411                arr.iter()
412                    .filter_map(|m| {
413                        Some(CachedModel {
414                            id: m.get("id")?.as_str()?.to_string(),
415                            name: None,
416                            is_free: None,
417                            context_window: None,
418                        })
419                    })
420                    .collect()
421            })
422            .unwrap_or_default()
423    }
424
425    /// Fetch models from provider API.
426    async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
427        let url = match provider {
428            "openrouter" => "https://openrouter.ai/api/v1/models",
429            "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
430            "groq" => "https://api.groq.com/openai/v1/models",
431            "cerebras" => "https://api.cerebras.ai/v1/models",
432            "zenmux" => "https://zenmux.ai/api/v1/models",
433            "zai" => "https://api.z.ai/api/paas/v4/models",
434            _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
435        };
436
437        // Get API key from token provider
438        let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
439            RegistryError::HttpError(format!("No API key available for {provider}"))
440        })?;
441
442        // Build request incrementally with provider-specific authentication
443        let request = match provider {
444            "gemini" => {
445                // Gemini uses query parameter authentication
446                self.client
447                    .get(url)
448                    .query(&[("key", api_key.expose_secret())])
449            }
450            "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
451                // These providers use Bearer token authentication
452                self.client.get(url).header(
453                    "Authorization",
454                    format!("Bearer {}", api_key.expose_secret()),
455                )
456            }
457            _ => self.client.get(url),
458        };
459
460        let response = request
461            .send()
462            .await
463            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
464
465        let data = response
466            .json::<serde_json::Value>()
467            .await
468            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
469
470        // Parse based on provider API format
471        let models = match provider {
472            "openrouter" => Self::parse_openrouter_models(&data),
473            "gemini" => Self::parse_gemini_models(&data),
474            "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data),
475            _ => vec![],
476        };
477
478        Ok(models)
479    }
480}
481
482#[async_trait]
483impl ModelRegistry for CachedModelRegistry<'_> {
484    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
485        // Try fresh cache first
486        if let Ok(models) = self.load_from_cache(provider) {
487            return Ok(models);
488        }
489
490        // Fetch from API with stale fallback
491        match self.fetch_from_api(provider).await {
492            Ok(models) => {
493                // Save to cache (ignore errors)
494                let _ = self.save_to_cache(provider, &models);
495                Ok(models)
496            }
497            Err(api_error) => {
498                // Try stale cache as fallback
499                match self.load_stale_cache(provider) {
500                    Ok(models) => {
501                        tracing::warn!(
502                            provider = provider,
503                            error = %api_error,
504                            "API request failed, returning stale cached models"
505                        );
506                        Ok(models)
507                    }
508                    Err(_) => {
509                        // No stale cache available, return original API error
510                        Err(api_error)
511                    }
512                }
513            }
514        }
515    }
516
517    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
518        let models = self.list_models(provider).await?;
519        Ok(models.iter().any(|m| m.id == model_id))
520    }
521
522    async fn suggest_similar(
523        &self,
524        provider: &str,
525        model_id: &str,
526    ) -> Result<Vec<String>, RegistryError> {
527        let models = self.list_models(provider).await?;
528
529        // Use jaro_winkler fuzzy matching to score similarity
530        let mut scored_suggestions: Vec<(String, f64)> = models
531            .iter()
532            .map(|m| {
533                let score = strsim::jaro_winkler(&m.id, model_id);
534                (m.id.clone(), score)
535            })
536            .collect();
537
538        // Sort by similarity score (descending) and take top 5
539        scored_suggestions
540            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
541        let suggestions: Vec<String> = scored_suggestions
542            .into_iter()
543            .take(5)
544            .map(|(id, _)| id)
545            .collect();
546
547        Ok(suggestions)
548    }
549
550    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
551        if self.model_exists(provider, model_id).await? {
552            Ok(())
553        } else {
554            let suggestions = self.suggest_similar(provider, model_id).await?;
555            Err(RegistryError::ModelValidation {
556                model_id: model_id.to_string(),
557                suggestions,
558            })
559        }
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_stale_cache_ignores_ttl() {
        // Arrange: write a cache entry with a TTL of 0 seconds, which is
        // expired the moment it is written (`is_cache_valid` requires
        // `now < timestamp + ttl`) — no real-time sleep needed.
        // The process id keeps concurrent test runs from sharing a directory.
        let temp_dir =
            std::env::temp_dir().join(format!("aptu_test_stale_cache_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&temp_dir);

        // Minimal token provider: the cache paths under test never need keys.
        #[allow(clippy::items_after_statements)]
        struct MockTokenProvider;
        #[allow(clippy::items_after_statements)]
        impl crate::auth::TokenProvider for MockTokenProvider {
            fn github_token(&self) -> Option<secrecy::SecretString> {
                None
            }
            fn ai_api_key(&self, _provider: &str) -> Option<secrecy::SecretString> {
                None
            }
        }

        let mock_provider = MockTokenProvider;
        let registry = CachedModelRegistry::new(temp_dir.clone(), 0, &mock_provider); // 0 second TTL

        let models = vec![
            CachedModel {
                id: "test-model-1".to_string(),
                name: Some("Test Model 1".to_string()),
                is_free: Some(true),
                context_window: Some(4096),
            },
            CachedModel {
                id: "test-model-2".to_string(),
                name: Some("Test Model 2".to_string()),
                is_free: Some(false),
                context_window: Some(8192),
            },
        ];

        // Save to cache (entry is immediately expired thanks to the 0s TTL)
        let _ = registry.save_to_cache("test_provider", &models);

        // Sanity check: the TTL-respecting loader must reject the entry.
        assert!(
            registry.load_from_cache("test_provider").is_err(),
            "load_from_cache should report the 0-TTL entry as expired"
        );

        // Act: load stale cache (should succeed despite expired TTL)
        let result = registry.load_stale_cache("test_provider");

        // Assert
        assert!(result.is_ok(), "load_stale_cache should succeed");
        let loaded_models = result.unwrap();
        assert_eq!(loaded_models.len(), 2);
        assert_eq!(loaded_models[0].id, "test-model-1");
        assert_eq!(loaded_models[1].id, "test-model-2");

        // Cleanup
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    #[test]
    fn test_get_provider_gemini() {
        let provider = get_provider("gemini");
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, "Google Gemini");
        assert_eq!(provider.api_key_env, "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        let provider = get_provider("openrouter");
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, "OpenRouter");
        assert_eq!(provider.api_key_env, "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        let provider = get_provider("groq");
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, "Groq");
        assert_eq!(provider.api_key_env, "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        let provider = get_provider("cerebras");
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, "Cerebras");
        assert_eq!(provider.api_key_env, "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        let provider = get_provider("nonexistent");
        assert!(provider.is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        let provider = get_provider("OpenRouter");
        assert!(
            provider.is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        let providers = all_providers();
        assert_eq!(providers.len(), 6, "Should have exactly 6 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let providers = all_providers();
        let mut names = Vec::new();
        for provider in providers {
            assert!(
                !names.contains(&provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
            names.push(provider.name);
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        let provider = get_provider("zenmux");
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, "Zenmux");
        assert_eq!(provider.api_key_env, "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        let provider = get_provider("zai");
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, "Z.AI (Zhipu)");
        assert_eq!(provider.api_key_env, "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        let providers = all_providers();
        for provider in providers {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        let providers = all_providers();
        for provider in providers {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }
}