Skip to main content

aptu_core/ai/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Centralized provider configuration registry.
4//!
5//! This module provides a static registry of all AI providers supported by Aptu,
6//! including their metadata, API endpoints, and available models.
7//!
//! It also provides runtime model validation via the `ModelRegistry` trait, implemented by
//! `CachedModelRegistry`, which fetches model lists from each provider's API asynchronously
//! and keeps them in a file-backed cache with a configurable TTL.
10//!
11//! # Examples
12//!
13//! ```
14//! use aptu_core::ai::registry::{get_provider, all_providers};
15//!
16//! // Get a specific provider
17//! let provider = get_provider("openrouter");
18//! assert!(provider.is_some());
19//!
20//! // Get all providers
21//! let providers = all_providers();
22//! assert_eq!(providers.len(), 6);
23//! ```
24
25use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Descriptive metadata for one AI model offered by a provider.
///
/// Every string field borrows `'static` data, which is why the type can be `Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ModelInfo {
    /// Name shown to users in the UI.
    pub display_name: &'static str,

    /// Model identifier sent to the provider's API in requests.
    pub identifier: &'static str,

    /// `true` when the model can be used at no cost.
    pub is_free: bool,

    /// Largest context the model accepts, measured in tokens.
    pub context_window: u32,
}
49
/// Static description of one AI provider: how it is named and where/how to call it.
///
/// All fields are `'static` string slices, so a `ProviderConfig` is freely copyable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lowercase identifier used in configuration files and registry lookups.
    pub name: &'static str,

    /// Provider name as presented to users in the UI.
    pub display_name: &'static str,

    /// Base API URL used when talking to this provider.
    pub api_url: &'static str,

    /// Name of the environment variable expected to hold the API key.
    pub api_key_env: &'static str,
}
65
66// ============================================================================
67// Provider Registry
68// ============================================================================
69
/// Static registry of all supported AI providers.
///
/// `get_provider` returns the first entry whose `name` matches exactly, so every `name`
/// must be unique. `all_providers` exposes this slice directly, so entry order is visible
/// to callers.
///
/// Every `api_url` here is the provider's chat-completions endpoint.
/// NOTE(review): the model-listing endpoints used by `CachedModelRegistry` are hard-coded
/// separately in `fetch_from_api`; when adding a provider here, update that table too.
pub static PROVIDERS: &[ProviderConfig] = &[
    // Reached through Google's OpenAI-compatible endpoint (note the `/openai/` path segment).
    ProviderConfig {
        name: "gemini",
        display_name: "Google Gemini",
        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        api_key_env: "GEMINI_API_KEY",
    },
    ProviderConfig {
        name: "openrouter",
        display_name: "OpenRouter",
        api_url: "https://openrouter.ai/api/v1/chat/completions",
        api_key_env: "OPENROUTER_API_KEY",
    },
    ProviderConfig {
        name: "groq",
        display_name: "Groq",
        api_url: "https://api.groq.com/openai/v1/chat/completions",
        api_key_env: "GROQ_API_KEY",
    },
    ProviderConfig {
        name: "cerebras",
        display_name: "Cerebras",
        api_url: "https://api.cerebras.ai/v1/chat/completions",
        api_key_env: "CEREBRAS_API_KEY",
    },
    ProviderConfig {
        name: "zenmux",
        display_name: "Zenmux",
        api_url: "https://zenmux.ai/api/v1/chat/completions",
        api_key_env: "ZENMUX_API_KEY",
    },
    ProviderConfig {
        name: "zai",
        display_name: "Z.AI (Zhipu)",
        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
        api_key_env: "ZAI_API_KEY",
    },
];
109
110/// Retrieves a provider configuration by name.
111///
112/// # Arguments
113///
114/// * `name` - The provider name (case-sensitive, lowercase)
115///
116/// # Returns
117///
118/// Some(ProviderConfig) if found, None otherwise.
119///
120/// # Examples
121///
122/// ```
123/// use aptu_core::ai::registry::get_provider;
124///
125/// let provider = get_provider("openrouter");
126/// assert!(provider.is_some());
127/// assert_eq!(provider.unwrap().display_name, "OpenRouter");
128/// ```
129#[must_use]
130pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
131    PROVIDERS.iter().find(|p| p.name == name)
132}
133
134/// Returns all available providers.
135///
136/// # Returns
137///
138/// A slice of all `ProviderConfig` entries in the registry.
139///
140/// # Examples
141///
142/// ```
143/// use aptu_core::ai::registry::all_providers;
144///
145/// let providers = all_providers();
146/// assert_eq!(providers.len(), 6);
147/// ```
148#[must_use]
149pub fn all_providers() -> &'static [ProviderConfig] {
150    PROVIDERS
151}
152
153// ============================================================================
154// Runtime Model Validation
155// ============================================================================
156
157/// Error type for model registry operations.
158#[derive(Debug, Error)]
159pub enum RegistryError {
160    /// HTTP request failed.
161    #[error("HTTP request failed: {0}")]
162    HttpError(String),
163
164    /// Failed to parse API response.
165    #[error("Failed to parse API response: {0}")]
166    ParseError(String),
167
168    /// Provider not found.
169    #[error("Provider not found: {0}")]
170    ProviderNotFound(String),
171
172    /// Cache error.
173    #[error("Cache error: {0}")]
174    CacheError(String),
175
176    /// IO error.
177    #[error("IO error: {0}")]
178    IoError(#[from] std::io::Error),
179
180    /// Model validation error - invalid model ID.
181    #[error("Invalid model ID: {model_id}")]
182    ModelValidation {
183        /// The invalid model ID provided by the user.
184        model_id: String,
185    },
186}
187
188/// Cached model information from API responses.
189#[derive(Clone, Debug, Serialize, Deserialize)]
190pub struct CachedModel {
191    /// Model identifier from the provider API.
192    pub id: String,
193    /// Human-readable model name.
194    pub name: Option<String>,
195    /// Whether the model is free to use.
196    pub is_free: Option<bool>,
197    /// Maximum context window size in tokens.
198    pub context_window: Option<u32>,
199}
200
201/// Trait for runtime model validation and listing.
202#[async_trait]
203pub trait ModelRegistry: Send + Sync {
204    /// List all available models for a provider.
205    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;
206
207    /// Check if a model exists for a provider.
208    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;
209
210    /// Validate that a model ID exists for a provider.
211    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
212}
213
214/// Cached model registry with HTTP client and TTL support.
215pub struct CachedModelRegistry<'a> {
216    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
217    client: reqwest::Client,
218    token_provider: &'a dyn TokenProvider,
219}
220
221impl CachedModelRegistry<'_> {
222    /// Create a new cached model registry.
223    ///
224    /// # Arguments
225    ///
226    /// * `cache_dir` - Directory for storing cached model lists (None to disable caching)
227    /// * `ttl_seconds` - Time-to-live for cache entries (see `DEFAULT_MODEL_TTL_SECS`)
228    /// * `token_provider` - Token provider for API credentials
229    #[must_use]
230    pub fn new(
231        cache_dir: Option<PathBuf>,
232        ttl_seconds: u64,
233        token_provider: &dyn TokenProvider,
234    ) -> CachedModelRegistry<'_> {
235        let ttl = chrono::Duration::seconds(
236            ttl_seconds
237                .try_into()
238                .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
239        );
240        CachedModelRegistry {
241            cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
242            client: reqwest::Client::builder()
243                .timeout(std::time::Duration::from_secs(10))
244                .build()
245                .unwrap_or_else(|_| reqwest::Client::new()),
246            token_provider,
247        }
248    }
249
250    /// Parse `OpenRouter` API response into models.
251    fn parse_openrouter_models(data: &serde_json::Value) -> Vec<CachedModel> {
252        data.get("data")
253            .and_then(|d| d.as_array())
254            .map(|arr| {
255                arr.iter()
256                    .filter_map(|m| {
257                        Some(CachedModel {
258                            id: m.get("id")?.as_str()?.to_string(),
259                            name: m.get("name").and_then(|n| n.as_str()).map(String::from),
260                            is_free: m
261                                .get("pricing")
262                                .and_then(|p| p.get("prompt"))
263                                .and_then(|p| p.as_str())
264                                .map(|p| p == "0"),
265                            context_window: m
266                                .get("context_length")
267                                .and_then(serde_json::Value::as_u64)
268                                .and_then(|c| u32::try_from(c).ok()),
269                        })
270                    })
271                    .collect()
272            })
273            .unwrap_or_default()
274    }
275
276    /// Parse Gemini API response into models.
277    fn parse_gemini_models(data: &serde_json::Value) -> Vec<CachedModel> {
278        data.get("models")
279            .and_then(|d| d.as_array())
280            .map(|arr| {
281                arr.iter()
282                    .filter_map(|m| {
283                        Some(CachedModel {
284                            id: m.get("name")?.as_str()?.to_string(),
285                            name: m
286                                .get("displayName")
287                                .and_then(|n| n.as_str())
288                                .map(String::from),
289                            is_free: None,
290                            context_window: m
291                                .get("inputTokenLimit")
292                                .and_then(serde_json::Value::as_u64)
293                                .and_then(|c| u32::try_from(c).ok()),
294                        })
295                    })
296                    .collect()
297            })
298            .unwrap_or_default()
299    }
300
301    /// Parse generic OpenAI-compatible API response into models.
302    fn parse_generic_models(data: &serde_json::Value) -> Vec<CachedModel> {
303        data.get("data")
304            .and_then(|d| d.as_array())
305            .map(|arr| {
306                arr.iter()
307                    .filter_map(|m| {
308                        Some(CachedModel {
309                            id: m.get("id")?.as_str()?.to_string(),
310                            name: None,
311                            is_free: None,
312                            context_window: None,
313                        })
314                    })
315                    .collect()
316            })
317            .unwrap_or_default()
318    }
319
320    /// Fetch models from provider API.
321    async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
322        let url = match provider {
323            "openrouter" => "https://openrouter.ai/api/v1/models",
324            "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
325            "groq" => "https://api.groq.com/openai/v1/models",
326            "cerebras" => "https://api.cerebras.ai/v1/models",
327            "zenmux" => "https://zenmux.ai/api/v1/models",
328            "zai" => "https://api.z.ai/api/paas/v4/models",
329            _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
330        };
331
332        // Get API key from token provider
333        let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
334            RegistryError::HttpError(format!("No API key available for {provider}"))
335        })?;
336
337        // Build request incrementally with provider-specific authentication
338        let request = match provider {
339            "gemini" => {
340                // Gemini uses query parameter authentication
341                self.client
342                    .get(url)
343                    .query(&[("key", api_key.expose_secret())])
344            }
345            "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
346                // These providers use Bearer token authentication
347                self.client.get(url).header(
348                    "Authorization",
349                    format!("Bearer {}", api_key.expose_secret()),
350                )
351            }
352            _ => self.client.get(url),
353        };
354
355        let response = request
356            .send()
357            .await
358            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
359
360        let data = response
361            .json::<serde_json::Value>()
362            .await
363            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
364
365        // Parse based on provider API format
366        let models = match provider {
367            "openrouter" => Self::parse_openrouter_models(&data),
368            "gemini" => Self::parse_gemini_models(&data),
369            "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data),
370            _ => vec![],
371        };
372
373        Ok(models)
374    }
375}
376
377#[async_trait]
378impl ModelRegistry for CachedModelRegistry<'_> {
379    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
380        // Try fresh cache first
381        if let Ok(Some(models)) = self.cache.get(provider) {
382            return Ok(models);
383        }
384
385        // Fetch from API with stale fallback
386        match self.fetch_from_api(provider).await {
387            Ok(models) => {
388                // Save to cache (ignore errors)
389                let _ = self.cache.set(provider, &models);
390                Ok(models)
391            }
392            Err(api_error) => {
393                // Try stale cache as fallback
394                match self.cache.get_stale(provider) {
395                    Ok(Some(models)) => {
396                        tracing::warn!(
397                            provider = provider,
398                            error = %api_error,
399                            "API request failed, returning stale cached models"
400                        );
401                        Ok(models)
402                    }
403                    _ => {
404                        // No stale cache available, return original API error
405                        Err(api_error)
406                    }
407                }
408            }
409        }
410    }
411
412    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
413        let models = self.list_models(provider).await?;
414        Ok(models.iter().any(|m| m.id == model_id))
415    }
416
417    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
418        if self.model_exists(provider, model_id).await? {
419            Ok(())
420        } else {
421            Err(RegistryError::ModelValidation {
422                model_id: model_id.to_string(),
423            })
424        }
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `name` is registered with the given display name and API-key variable.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let Some(provider) = get_provider(name) else {
            panic!("expected provider {name} to be registered");
        };
        assert_eq!(provider.display_name, display_name);
        assert_eq!(provider.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        assert!(
            get_provider("OpenRouter").is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 6, "Should have exactly 6 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = std::collections::HashSet::new();
        for provider in all_providers() {
            // `insert` returns false when the name was already present.
            assert!(
                seen.insert(provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        all_providers().iter().for_each(|provider| {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        });
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        all_providers().iter().for_each(|provider| {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        });
    }
}