Skip to main content

aptu_core/ai/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Centralized provider configuration registry.
4//!
5//! This module provides a static registry of all AI providers supported by Aptu,
6//! including their metadata, API endpoints, and available models.
7//!
//! It also provides runtime model validation infrastructure via the `ModelRegistry` trait,
//! implemented by `CachedModelRegistry`, which fetches model lists from the provider APIs
//! over HTTP and keeps them in a TTL-based file cache.
10//!
11//! # Examples
12//!
13//! ```
14//! use aptu_core::ai::registry::{get_provider, all_providers};
15//!
16//! // Get a specific provider
17//! let provider = get_provider("openrouter");
18//! assert!(provider.is_some());
19//!
20//! // Get all providers
21//! let providers = all_providers();
22//! assert_eq!(providers.len(), 6);
23//! ```
24
25use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static description of a single AI provider known to Aptu.
///
/// All fields are `'static` string slices, so a value is cheap to copy and can be
/// stored in the compile-time [`PROVIDERS`] table.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ProviderConfig {
    /// Lowercase identifier used to refer to the provider in config files.
    pub name: &'static str,

    /// Name shown to users in UI output.
    pub display_name: &'static str,

    /// API URL for the provider (the built-in entries store the chat-completions endpoint).
    pub api_url: &'static str,

    /// Name of the environment variable that holds the provider's API key.
    pub api_key_env: &'static str,
}
49
50// ============================================================================
51// Provider Registry
52// ============================================================================
53
54/// Static registry of all supported AI providers
55pub static PROVIDERS: &[ProviderConfig] = &[
56    ProviderConfig {
57        name: "gemini",
58        display_name: "Google Gemini",
59        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
60        api_key_env: "GEMINI_API_KEY",
61    },
62    ProviderConfig {
63        name: "openrouter",
64        display_name: "OpenRouter",
65        api_url: "https://openrouter.ai/api/v1/chat/completions",
66        api_key_env: "OPENROUTER_API_KEY",
67    },
68    ProviderConfig {
69        name: "groq",
70        display_name: "Groq",
71        api_url: "https://api.groq.com/openai/v1/chat/completions",
72        api_key_env: "GROQ_API_KEY",
73    },
74    ProviderConfig {
75        name: "cerebras",
76        display_name: "Cerebras",
77        api_url: "https://api.cerebras.ai/v1/chat/completions",
78        api_key_env: "CEREBRAS_API_KEY",
79    },
80    ProviderConfig {
81        name: "zenmux",
82        display_name: "Zenmux",
83        api_url: "https://zenmux.ai/api/v1/chat/completions",
84        api_key_env: "ZENMUX_API_KEY",
85    },
86    ProviderConfig {
87        name: "zai",
88        display_name: "Z.AI (Zhipu)",
89        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
90        api_key_env: "ZAI_API_KEY",
91    },
92];
93
94/// Retrieves a provider configuration by name.
95///
96/// # Arguments
97///
98/// * `name` - The provider name (case-sensitive, lowercase)
99///
100/// # Returns
101///
102/// Some(ProviderConfig) if found, None otherwise.
103///
104/// # Examples
105///
106/// ```
107/// use aptu_core::ai::registry::get_provider;
108///
109/// let provider = get_provider("openrouter");
110/// assert!(provider.is_some());
111/// assert_eq!(provider.unwrap().display_name, "OpenRouter");
112/// ```
113#[must_use]
114pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
115    PROVIDERS.iter().find(|p| p.name == name)
116}
117
118/// Returns all available providers.
119///
120/// # Returns
121///
122/// A slice of all `ProviderConfig` entries in the registry.
123///
124/// # Examples
125///
126/// ```
127/// use aptu_core::ai::registry::all_providers;
128///
129/// let providers = all_providers();
130/// assert_eq!(providers.len(), 6);
131/// ```
132#[must_use]
133pub fn all_providers() -> &'static [ProviderConfig] {
134    PROVIDERS
135}
136
137// ============================================================================
138// Runtime Model Validation
139// ============================================================================
140
141/// Error type for model registry operations.
142#[derive(Debug, Error)]
143pub enum RegistryError {
144    /// HTTP request failed.
145    #[error("HTTP request failed: {0}")]
146    HttpError(String),
147
148    /// Failed to parse API response.
149    #[error("Failed to parse API response: {0}")]
150    ParseError(String),
151
152    /// Provider not found.
153    #[error("Provider not found: {0}")]
154    ProviderNotFound(String),
155
156    /// Cache error.
157    #[error("Cache error: {0}")]
158    CacheError(String),
159
160    /// IO error.
161    #[error("IO error: {0}")]
162    IoError(#[from] std::io::Error),
163
164    /// Model validation error - invalid model ID.
165    #[error("Invalid model ID: {model_id}")]
166    ModelValidation {
167        /// The invalid model ID provided by the user.
168        model_id: String,
169    },
170}
171
172/// Model capability indicators.
173#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
174#[serde(rename_all = "snake_case")]
175pub enum Capability {
176    /// Model supports image/vision inputs.
177    Vision,
178    /// Model supports function/tool calling.
179    FunctionCalling,
180    /// Model has extended reasoning capabilities.
181    Reasoning,
182}
183
184/// Raw pricing information for a model (cost per token in USD).
185///
186/// `f64` is used because these values are display-only (never used for
187/// arithmetic or financial calculations). Precision matches what the API
188/// returns in its JSON responses. If cost estimation or budget tracking is
189/// added in the future, migrate to a decimal type such as `rust_decimal`.
190#[derive(Clone, Debug, Serialize, Deserialize)]
191pub struct PricingInfo {
192    /// Cost per prompt token in USD. None if unavailable.
193    pub prompt_per_token: Option<f64>,
194    /// Cost per completion token in USD. None if unavailable.
195    pub completion_per_token: Option<f64>,
196}
197
198/// Cached model information from API responses.
199#[derive(Clone, Debug, Serialize, Deserialize)]
200pub struct CachedModel {
201    /// Model identifier from the provider API.
202    pub id: String,
203    /// Human-readable model name.
204    pub name: Option<String>,
205    /// Whether the model is free to use.
206    pub is_free: Option<bool>,
207    /// Maximum context window size in tokens.
208    pub context_window: Option<u32>,
209    /// Provider name this model belongs to.
210    pub provider: String,
211    /// Model capabilities (e.g., `Vision`, `FunctionCalling`).
212    #[serde(default)]
213    pub capabilities: Vec<Capability>,
214    /// Pricing information for this model.
215    #[serde(default)]
216    pub pricing: Option<PricingInfo>,
217}
218
219/// Trait for runtime model validation and listing.
220#[async_trait]
221pub trait ModelRegistry: Send + Sync {
222    /// List all available models for a provider.
223    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;
224
225    /// Check if a model exists for a provider.
226    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;
227
228    /// Validate that a model ID exists for a provider.
229    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
230}
231
232/// Cached model registry with HTTP client and TTL support.
233pub struct CachedModelRegistry<'a> {
234    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
235    client: reqwest::Client,
236    token_provider: &'a dyn TokenProvider,
237}
238
239impl CachedModelRegistry<'_> {
240    /// Create a new cached model registry.
241    ///
242    /// # Arguments
243    ///
244    /// * `cache_dir` - Directory for storing cached model lists (None to disable caching)
245    /// * `ttl_seconds` - Time-to-live for cache entries (see `DEFAULT_MODEL_TTL_SECS`)
246    /// * `token_provider` - Token provider for API credentials
247    #[must_use]
248    pub fn new(
249        cache_dir: Option<PathBuf>,
250        ttl_seconds: u64,
251        token_provider: &dyn TokenProvider,
252    ) -> CachedModelRegistry<'_> {
253        let ttl = chrono::Duration::seconds(
254            ttl_seconds
255                .try_into()
256                .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
257        );
258        CachedModelRegistry {
259            cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
260            client: reqwest::Client::builder()
261                .timeout(std::time::Duration::from_secs(10))
262                .build()
263                .unwrap_or_else(|_| reqwest::Client::new()),
264            token_provider,
265        }
266    }
267
268    /// Parse `OpenRouter` API response into models.
269    fn parse_openrouter_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
270        data.get("data")
271            .and_then(|d| d.as_array())
272            .map(|arr| {
273                arr.iter()
274                    .filter_map(|m| {
275                        let pricing_obj = m.get("pricing");
276                        let prompt_per_token = pricing_obj
277                            .and_then(|p| p.get("prompt"))
278                            .and_then(|p| p.as_str())
279                            .and_then(|s| s.parse::<f64>().ok());
280                        let completion_per_token = pricing_obj
281                            .and_then(|p| p.get("completion"))
282                            .and_then(|p| p.as_str())
283                            .and_then(|s| s.parse::<f64>().ok());
284
285                        let is_free = match (prompt_per_token, completion_per_token) {
286                            (Some(prompt), Some(completion)) => {
287                                Some(prompt == 0.0 && completion == 0.0)
288                            }
289                            (Some(prompt), None) => Some(prompt == 0.0),
290                            _ => pricing_obj
291                                .and_then(|p| p.get("prompt"))
292                                .and_then(|p| p.as_str())
293                                .map(|p| p == "0"),
294                        };
295
296                        let pricing =
297                            if prompt_per_token.is_some() || completion_per_token.is_some() {
298                                Some(PricingInfo {
299                                    prompt_per_token,
300                                    completion_per_token,
301                                })
302                            } else {
303                                None
304                            };
305
306                        // Derive capabilities from architecture field defensively
307                        let arch = m.get("architecture");
308                        let capabilities = {
309                            // Check input_modalities array first
310                            let from_input_modalities = arch
311                                .and_then(|a| a.get("input_modalities"))
312                                .and_then(|im| im.as_array())
313                                .map(|arr| {
314                                    arr.iter().filter_map(|v| v.as_str()).any(|s| s == "image")
315                                });
316                            // Fall back to modalities string
317                            let from_modalities_str = arch
318                                .and_then(|a| a.get("modalities"))
319                                .and_then(|m| m.as_str())
320                                .map(|s| s.contains("image"));
321
322                            let has_vision = from_input_modalities
323                                .or(from_modalities_str)
324                                .unwrap_or(false);
325
326                            if has_vision {
327                                vec![Capability::Vision]
328                            } else {
329                                vec![]
330                            }
331                        };
332
333                        Some(CachedModel {
334                            id: m.get("id")?.as_str()?.to_string(),
335                            name: m.get("name").and_then(|n| n.as_str()).map(String::from),
336                            is_free,
337                            context_window: m
338                                .get("context_length")
339                                .and_then(serde_json::Value::as_u64)
340                                .and_then(|c| u32::try_from(c).ok()),
341                            provider: provider.to_string(),
342                            capabilities,
343                            pricing,
344                        })
345                    })
346                    .collect()
347            })
348            .unwrap_or_default()
349    }
350
351    /// Parse Gemini API response into models.
352    fn parse_gemini_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
353        data.get("models")
354            .and_then(|d| d.as_array())
355            .map(|arr| {
356                arr.iter()
357                    .filter_map(|m| {
358                        Some(CachedModel {
359                            id: m.get("name")?.as_str()?.to_string(),
360                            name: m
361                                .get("displayName")
362                                .and_then(|n| n.as_str())
363                                .map(String::from),
364                            is_free: None,
365                            context_window: m
366                                .get("inputTokenLimit")
367                                .and_then(serde_json::Value::as_u64)
368                                .and_then(|c| u32::try_from(c).ok()),
369                            provider: provider.to_string(),
370                            capabilities: vec![],
371                            pricing: None,
372                        })
373                    })
374                    .collect()
375            })
376            .unwrap_or_default()
377    }
378
379    /// Parse generic OpenAI-compatible API response into models.
380    fn parse_generic_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
381        data.get("data")
382            .and_then(|d| d.as_array())
383            .map(|arr| {
384                arr.iter()
385                    .filter_map(|m| {
386                        Some(CachedModel {
387                            id: m.get("id")?.as_str()?.to_string(),
388                            name: None,
389                            is_free: None,
390                            context_window: None,
391                            provider: provider.to_string(),
392                            capabilities: vec![],
393                            pricing: None,
394                        })
395                    })
396                    .collect()
397            })
398            .unwrap_or_default()
399    }
400
401    /// Fetch models from provider API.
402    async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
403        let url = match provider {
404            "openrouter" => "https://openrouter.ai/api/v1/models",
405            "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
406            "groq" => "https://api.groq.com/openai/v1/models",
407            "cerebras" => "https://api.cerebras.ai/v1/models",
408            "zenmux" => "https://zenmux.ai/api/v1/models",
409            "zai" => "https://api.z.ai/api/paas/v4/models",
410            _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
411        };
412
413        // Get API key from token provider
414        let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
415            RegistryError::HttpError(format!("No API key available for {provider}"))
416        })?;
417
418        // Build request incrementally with provider-specific authentication
419        let request = match provider {
420            "gemini" => {
421                // Gemini uses query parameter authentication
422                self.client
423                    .get(url)
424                    .query(&[("key", api_key.expose_secret())])
425            }
426            "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
427                // These providers use Bearer token authentication
428                self.client.get(url).header(
429                    "Authorization",
430                    format!("Bearer {}", api_key.expose_secret()),
431                )
432            }
433            _ => self.client.get(url),
434        };
435
436        let response = request
437            .send()
438            .await
439            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
440
441        let data = response
442            .json::<serde_json::Value>()
443            .await
444            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
445
446        // Parse based on provider API format
447        let models = match provider {
448            "openrouter" => Self::parse_openrouter_models(&data, provider),
449            "gemini" => Self::parse_gemini_models(&data, provider),
450            "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data, provider),
451            _ => vec![],
452        };
453
454        Ok(models)
455    }
456}
457
458#[async_trait]
459impl ModelRegistry for CachedModelRegistry<'_> {
460    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
461        // Try fresh cache first
462        if let Ok(Some(models)) = self.cache.get(provider) {
463            return Ok(models);
464        }
465
466        // Fetch from API with stale fallback
467        match self.fetch_from_api(provider).await {
468            Ok(models) => {
469                // Save to cache (ignore errors)
470                let _ = self.cache.set(provider, &models);
471                Ok(models)
472            }
473            Err(api_error) => {
474                // Try stale cache as fallback
475                match self.cache.get_stale(provider) {
476                    Ok(Some(models)) => {
477                        tracing::warn!(
478                            provider = provider,
479                            error = %api_error,
480                            "API request failed, returning stale cached models"
481                        );
482                        Ok(models)
483                    }
484                    _ => {
485                        // No stale cache available, return original API error
486                        Err(api_error)
487                    }
488                }
489            }
490        }
491    }
492
493    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
494        let models = self.list_models(provider).await?;
495        Ok(models.iter().any(|m| m.id == model_id))
496    }
497
498    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
499        if self.model_exists(provider, model_id).await? {
500            Ok(())
501        } else {
502            Err(RegistryError::ModelValidation {
503                model_id: model_id.to_string(),
504            })
505        }
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Asserts that `name` is registered with the expected display name and key variable.
    fn assert_registered(name: &str, display_name: &str, api_key_env: &str) {
        let provider = get_provider(name);
        assert!(provider.is_some());
        let provider = provider.unwrap();
        assert_eq!(provider.display_name, display_name);
        assert_eq!(provider.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_registered("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_registered("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_registered("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_registered("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        let mixed_case = get_provider("OpenRouter");
        assert!(
            mixed_case.is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        let count = all_providers().len();
        assert_eq!(count, 6, "Should have exactly 6 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = HashSet::new();
        for provider in all_providers() {
            // `insert` returns false when the name was already present.
            assert!(
                seen.insert(provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_registered("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_registered("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for provider in all_providers() {
            let uses_https = provider.api_url.starts_with("https://");
            assert!(
                uses_https,
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for provider in all_providers() {
            let has_env_var = !provider.api_key_env.is_empty();
            assert!(
                has_env_var,
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }

    #[test]
    fn test_parse_openrouter_models_with_pricing() {
        let payload = serde_json::json!({
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    },
                    "architecture": {
                        "input_modalities": ["text", "image"],
                        "output_modalities": ["text"]
                    }
                }
            ]
        });

        let parsed = CachedModelRegistry::parse_openrouter_models(&payload, "openrouter");
        assert_eq!(parsed.len(), 1);

        let model = &parsed[0];
        assert_eq!(model.id, "openai/gpt-4o");
        assert_eq!(model.is_free, Some(false));

        let pricing = model.pricing.as_ref().expect("pricing should be present");
        assert_eq!(pricing.prompt_per_token, Some(0.000_005));
        assert_eq!(pricing.completion_per_token, Some(0.000_015));

        assert!(model.capabilities.contains(&Capability::Vision));
    }

    #[test]
    fn test_parse_openrouter_models_missing_capabilities() {
        let payload = serde_json::json!({
            "data": [
                {
                    "id": "some/text-only-model",
                    "name": "Text Only",
                    "context_length": 32000,
                    "pricing": {
                        "prompt": "0",
                        "completion": "0"
                    }
                }
            ]
        });

        let parsed = CachedModelRegistry::parse_openrouter_models(&payload, "openrouter");
        assert_eq!(parsed.len(), 1);

        let model = &parsed[0];
        assert!(
            model.capabilities.is_empty(),
            "no vision if architecture missing"
        );
        assert_eq!(model.is_free, Some(true));
    }
}