Skip to main content

aptu_core/ai/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Centralized provider configuration registry.
4//!
5//! This module provides a static registry of all AI providers supported by Aptu,
6//! including their metadata, API endpoints, and available models.
7//!
//! It also provides runtime model validation infrastructure via the async `ModelRegistry`
//! trait, implemented by `CachedModelRegistry`, which fetches model lists over HTTP and caches them.
10//!
11//! # Examples
12//!
13//! ```
14//! use aptu_core::ai::registry::{get_provider, all_providers};
15//!
16//! // Get a specific provider
17//! let provider = get_provider("openrouter");
18//! assert!(provider.is_some());
19//!
20//! // Get all providers
21//! let providers = all_providers();
22//! assert_eq!(providers.len(), 7);
23//! ```
24
25use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static description of a single AI provider.
///
/// Every field is a `&'static str`, so a value is `Copy` and can be stored
/// directly in a `static` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lowercase identifier used to refer to the provider in config files.
    pub name: &'static str,

    /// Provider name as shown to users in the UI.
    pub display_name: &'static str,

    /// Base URL of the provider's API.
    pub api_url: &'static str,

    /// Name of the environment variable holding the API key.
    pub api_key_env: &'static str,
}
49
50// ============================================================================
51// Provider Registry
52// ============================================================================
53
/// Identifier of the Anthropic provider.
///
/// Exists so that code comparing against the provider name can share one
/// definition instead of repeating the `"anthropic"` literal.
pub const PROVIDER_ANTHROPIC: &str = "anthropic";
59
60/// Static registry of all supported AI providers
61pub static PROVIDERS: &[ProviderConfig] = &[
62    ProviderConfig {
63        name: "gemini",
64        display_name: "Google Gemini",
65        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
66        api_key_env: "GEMINI_API_KEY",
67    },
68    ProviderConfig {
69        name: "openrouter",
70        display_name: "OpenRouter",
71        api_url: "https://openrouter.ai/api/v1/chat/completions",
72        api_key_env: "OPENROUTER_API_KEY",
73    },
74    ProviderConfig {
75        name: "groq",
76        display_name: "Groq",
77        api_url: "https://api.groq.com/openai/v1/chat/completions",
78        api_key_env: "GROQ_API_KEY",
79    },
80    ProviderConfig {
81        name: "cerebras",
82        display_name: "Cerebras",
83        api_url: "https://api.cerebras.ai/v1/chat/completions",
84        api_key_env: "CEREBRAS_API_KEY",
85    },
86    ProviderConfig {
87        name: "zenmux",
88        display_name: "Zenmux",
89        api_url: "https://zenmux.ai/api/v1/chat/completions",
90        api_key_env: "ZENMUX_API_KEY",
91    },
92    ProviderConfig {
93        name: "zai",
94        display_name: "Z.AI (Zhipu)",
95        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
96        api_key_env: "ZAI_API_KEY",
97    },
98    ProviderConfig {
99        name: PROVIDER_ANTHROPIC,
100        display_name: "Anthropic",
101        api_url: "https://api.anthropic.com/v1/chat/completions",
102        api_key_env: "ANTHROPIC_API_KEY",
103    },
104];
105
106/// Retrieves a provider configuration by name.
107///
108/// # Arguments
109///
110/// * `name` - The provider name (case-sensitive, lowercase)
111///
112/// # Returns
113///
114/// Some(ProviderConfig) if found, None otherwise.
115///
116/// # Examples
117///
118/// ```
119/// use aptu_core::ai::registry::get_provider;
120///
121/// let provider = get_provider("openrouter");
122/// assert!(provider.is_some());
123/// assert_eq!(provider.unwrap().display_name, "OpenRouter");
124/// ```
125#[must_use]
126pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
127    PROVIDERS.iter().find(|p| p.name == name)
128}
129
130/// Returns all available providers.
131///
132/// # Returns
133///
134/// A slice of all `ProviderConfig` entries in the registry.
135///
136/// # Examples
137///
138/// ```
139/// use aptu_core::ai::registry::all_providers;
140///
141/// let providers = all_providers();
142/// assert_eq!(providers.len(), 7);
143/// ```
144#[must_use]
145pub fn all_providers() -> &'static [ProviderConfig] {
146    PROVIDERS
147}
148
149// ============================================================================
150// Runtime Model Validation
151// ============================================================================
152
153/// Error type for model registry operations.
154#[derive(Debug, Error)]
155pub enum RegistryError {
156    /// HTTP request failed.
157    #[error("HTTP request failed: {0}")]
158    HttpError(String),
159
160    /// Failed to parse API response.
161    #[error("Failed to parse API response: {0}")]
162    ParseError(String),
163
164    /// Provider not found.
165    #[error("Provider not found: {0}")]
166    ProviderNotFound(String),
167
168    /// Cache error.
169    #[error("Cache error: {0}")]
170    CacheError(String),
171
172    /// IO error.
173    #[error("IO error: {0}")]
174    IoError(#[from] std::io::Error),
175
176    /// Model validation error - invalid model ID.
177    #[error("Invalid model ID: {model_id}")]
178    ModelValidation {
179        /// The invalid model ID provided by the user.
180        model_id: String,
181    },
182}
183
184/// Model capability indicators.
185#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
186#[serde(rename_all = "snake_case")]
187pub enum Capability {
188    /// Model supports image/vision inputs.
189    Vision,
190    /// Model supports function/tool calling.
191    FunctionCalling,
192    /// Model has extended reasoning capabilities.
193    Reasoning,
194}
195
196/// Raw pricing information for a model (cost per token in USD).
197///
198/// `f64` is used because these values are display-only (never used for
199/// arithmetic or financial calculations). Precision matches what the API
200/// returns in its JSON responses. If cost estimation or budget tracking is
201/// added in the future, migrate to a decimal type such as `rust_decimal`.
202#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct PricingInfo {
204    /// Cost per prompt token in USD. None if unavailable.
205    pub prompt_per_token: Option<f64>,
206    /// Cost per completion token in USD. None if unavailable.
207    pub completion_per_token: Option<f64>,
208}
209
210/// Cached model information from API responses.
211#[derive(Clone, Debug, Serialize, Deserialize)]
212pub struct CachedModel {
213    /// Model identifier from the provider API.
214    pub id: String,
215    /// Human-readable model name.
216    pub name: Option<String>,
217    /// Whether the model is free to use.
218    pub is_free: Option<bool>,
219    /// Maximum context window size in tokens.
220    pub context_window: Option<u32>,
221    /// Provider name this model belongs to.
222    pub provider: String,
223    /// Model capabilities (e.g., `Vision`, `FunctionCalling`).
224    #[serde(default)]
225    pub capabilities: Vec<Capability>,
226    /// Pricing information for this model.
227    #[serde(default)]
228    pub pricing: Option<PricingInfo>,
229}
230
231/// Trait for runtime model validation and listing.
232#[async_trait]
233pub trait ModelRegistry: Send + Sync {
234    /// List all available models for a provider.
235    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;
236
237    /// Check if a model exists for a provider.
238    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;
239
240    /// Validate that a model ID exists for a provider.
241    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
242}
243
244/// Cached model registry with HTTP client and TTL support.
245pub struct CachedModelRegistry<'a> {
246    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
247    client: reqwest::Client,
248    token_provider: &'a dyn TokenProvider,
249}
250
251impl CachedModelRegistry<'_> {
252    /// Create a new cached model registry.
253    ///
254    /// # Arguments
255    ///
256    /// * `cache_dir` - Directory for storing cached model lists (None to disable caching)
257    /// * `ttl_seconds` - Time-to-live for cache entries (see `DEFAULT_MODEL_TTL_SECS`)
258    /// * `token_provider` - Token provider for API credentials
259    #[must_use]
260    pub fn new(
261        cache_dir: Option<PathBuf>,
262        ttl_seconds: u64,
263        token_provider: &dyn TokenProvider,
264    ) -> CachedModelRegistry<'_> {
265        let ttl = chrono::Duration::seconds(
266            ttl_seconds
267                .try_into()
268                .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
269        );
270        CachedModelRegistry {
271            cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
272            client: reqwest::Client::builder()
273                .timeout(std::time::Duration::from_secs(10))
274                .build()
275                .unwrap_or_else(|_| reqwest::Client::new()),
276            token_provider,
277        }
278    }
279
280    /// Parse `OpenRouter` API response into models.
281    fn parse_openrouter_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
282        data.get("data")
283            .and_then(|d| d.as_array())
284            .map(|arr| {
285                arr.iter()
286                    .filter_map(|m| {
287                        let pricing_obj = m.get("pricing");
288                        let prompt_per_token = pricing_obj
289                            .and_then(|p| p.get("prompt"))
290                            .and_then(|p| p.as_str())
291                            .and_then(|s| s.parse::<f64>().ok());
292                        let completion_per_token = pricing_obj
293                            .and_then(|p| p.get("completion"))
294                            .and_then(|p| p.as_str())
295                            .and_then(|s| s.parse::<f64>().ok());
296
297                        let is_free = match (prompt_per_token, completion_per_token) {
298                            (Some(prompt), Some(completion)) => {
299                                Some(prompt == 0.0 && completion == 0.0)
300                            }
301                            (Some(prompt), None) => Some(prompt == 0.0),
302                            _ => pricing_obj
303                                .and_then(|p| p.get("prompt"))
304                                .and_then(|p| p.as_str())
305                                .map(|p| p == "0"),
306                        };
307
308                        let pricing =
309                            if prompt_per_token.is_some() || completion_per_token.is_some() {
310                                Some(PricingInfo {
311                                    prompt_per_token,
312                                    completion_per_token,
313                                })
314                            } else {
315                                None
316                            };
317
318                        // Derive capabilities from architecture field defensively
319                        let arch = m.get("architecture");
320                        let capabilities = {
321                            // Check input_modalities array first
322                            let from_input_modalities = arch
323                                .and_then(|a| a.get("input_modalities"))
324                                .and_then(|im| im.as_array())
325                                .map(|arr| {
326                                    arr.iter().filter_map(|v| v.as_str()).any(|s| s == "image")
327                                });
328                            // Fall back to modalities string
329                            let from_modalities_str = arch
330                                .and_then(|a| a.get("modalities"))
331                                .and_then(|m| m.as_str())
332                                .map(|s| s.contains("image"));
333
334                            let has_vision = from_input_modalities
335                                .or(from_modalities_str)
336                                .unwrap_or(false);
337
338                            if has_vision {
339                                vec![Capability::Vision]
340                            } else {
341                                vec![]
342                            }
343                        };
344
345                        Some(CachedModel {
346                            id: m.get("id")?.as_str()?.to_string(),
347                            name: m.get("name").and_then(|n| n.as_str()).map(String::from),
348                            is_free,
349                            context_window: m
350                                .get("context_length")
351                                .and_then(serde_json::Value::as_u64)
352                                .and_then(|c| u32::try_from(c).ok()),
353                            provider: provider.to_string(),
354                            capabilities,
355                            pricing,
356                        })
357                    })
358                    .collect()
359            })
360            .unwrap_or_default()
361    }
362
363    /// Parse Gemini API response into models.
364    fn parse_gemini_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
365        data.get("models")
366            .and_then(|d| d.as_array())
367            .map(|arr| {
368                arr.iter()
369                    .filter_map(|m| {
370                        Some(CachedModel {
371                            id: m.get("name")?.as_str()?.to_string(),
372                            name: m
373                                .get("displayName")
374                                .and_then(|n| n.as_str())
375                                .map(String::from),
376                            is_free: None,
377                            context_window: m
378                                .get("inputTokenLimit")
379                                .and_then(serde_json::Value::as_u64)
380                                .and_then(|c| u32::try_from(c).ok()),
381                            provider: provider.to_string(),
382                            capabilities: vec![],
383                            pricing: None,
384                        })
385                    })
386                    .collect()
387            })
388            .unwrap_or_default()
389    }
390
391    /// Parse generic OpenAI-compatible API response into models.
392    fn parse_generic_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
393        data.get("data")
394            .and_then(|d| d.as_array())
395            .map(|arr| {
396                arr.iter()
397                    .filter_map(|m| {
398                        Some(CachedModel {
399                            id: m.get("id")?.as_str()?.to_string(),
400                            name: None,
401                            is_free: None,
402                            context_window: None,
403                            provider: provider.to_string(),
404                            capabilities: vec![],
405                            pricing: None,
406                        })
407                    })
408                    .collect()
409            })
410            .unwrap_or_default()
411    }
412
413    /// Fetch models from provider API.
414    async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
415        let url = match provider {
416            "openrouter" => "https://openrouter.ai/api/v1/models",
417            "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
418            "groq" => "https://api.groq.com/openai/v1/models",
419            "cerebras" => "https://api.cerebras.ai/v1/models",
420            "zenmux" => "https://zenmux.ai/api/v1/models",
421            "zai" => "https://api.z.ai/api/paas/v4/models",
422            _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
423        };
424
425        // Get API key from token provider
426        let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
427            RegistryError::HttpError(format!("No API key available for {provider}"))
428        })?;
429
430        // Build request incrementally with provider-specific authentication
431        let request = match provider {
432            "gemini" => {
433                // Gemini uses header authentication
434                self.client
435                    .get(url)
436                    .header("x-goog-api-key", api_key.expose_secret())
437            }
438            "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
439                // These providers use Bearer token authentication
440                self.client.get(url).header(
441                    "Authorization",
442                    format!("Bearer {}", api_key.expose_secret()),
443                )
444            }
445            _ => self.client.get(url),
446        };
447
448        let response = request
449            .send()
450            .await
451            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
452
453        let data = response
454            .json::<serde_json::Value>()
455            .await
456            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
457
458        // Parse based on provider API format
459        let models = match provider {
460            "openrouter" => Self::parse_openrouter_models(&data, provider),
461            "gemini" => Self::parse_gemini_models(&data, provider),
462            "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data, provider),
463            _ => vec![],
464        };
465
466        Ok(models)
467    }
468}
469
470#[async_trait]
471impl ModelRegistry for CachedModelRegistry<'_> {
472    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
473        // Try fresh cache first
474        if let Ok(Some(models)) = self.cache.get(provider).await {
475            return Ok(models);
476        }
477
478        // Fetch from API with stale fallback
479        match self.fetch_from_api(provider).await {
480            Ok(models) => {
481                // Save to cache (ignore errors)
482                let _ = self.cache.set(provider, &models).await;
483                Ok(models)
484            }
485            Err(api_error) => {
486                // Try stale cache as fallback
487                match self.cache.get_stale(provider).await {
488                    Ok(Some(models)) => {
489                        tracing::warn!(
490                            provider = provider,
491                            error = %api_error,
492                            "API request failed, returning stale cached models"
493                        );
494                        Ok(models)
495                    }
496                    _ => {
497                        // No stale cache available, return original API error
498                        Err(api_error)
499                    }
500                }
501            }
502        }
503    }
504
505    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
506        let models = self.list_models(provider).await?;
507        Ok(models.iter().any(|m| m.id == model_id))
508    }
509
510    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
511        if self.model_exists(provider, model_id).await? {
512            Ok(())
513        } else {
514            Err(RegistryError::ModelValidation {
515                model_id: model_id.to_string(),
516            })
517        }
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Asserts that `name` is registered with the given display name and API
    /// key environment variable.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let config = get_provider(name)
            .unwrap_or_else(|| panic!("provider {name} should be registered"));
        assert_eq!(config.display_name, display_name);
        assert_eq!(config.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        assert!(
            get_provider("OpenRouter").is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 7, "Should have exactly 7 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = HashSet::new();
        for provider in all_providers() {
            // `insert` returns false when the name was already present.
            assert!(
                seen.insert(provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for provider in all_providers() {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for provider in all_providers() {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }

    #[test]
    fn test_parse_openrouter_models_with_pricing() {
        let response = serde_json::json!({
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128_000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    },
                    "architecture": {
                        "input_modalities": ["text", "image"],
                        "output_modalities": ["text"]
                    }
                }
            ]
        });

        let parsed = CachedModelRegistry::parse_openrouter_models(&response, "openrouter");
        assert_eq!(parsed.len(), 1);

        let model = &parsed[0];
        assert_eq!(model.id, "openai/gpt-4o");
        assert_eq!(model.is_free, Some(false));

        let pricing = model.pricing.as_ref().expect("pricing should be present");
        assert_eq!(pricing.prompt_per_token, Some(0.000_005));
        assert_eq!(pricing.completion_per_token, Some(0.000_015));

        assert!(model.capabilities.contains(&Capability::Vision));
    }

    #[test]
    fn test_parse_openrouter_models_missing_capabilities() {
        let response = serde_json::json!({
            "data": [
                {
                    "id": "some/text-only-model",
                    "name": "Text Only",
                    "context_length": 32000,
                    "pricing": {
                        "prompt": "0",
                        "completion": "0"
                    }
                }
            ]
        });

        let parsed = CachedModelRegistry::parse_openrouter_models(&response, "openrouter");
        assert_eq!(parsed.len(), 1);

        let model = &parsed[0];
        assert!(
            model.capabilities.is_empty(),
            "no vision if architecture missing"
        );
        assert_eq!(model.is_free, Some(true));
    }
}