Skip to main content

aptu_core/ai/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Centralized provider configuration registry.
4//!
5//! This module provides a static registry of all AI providers supported by Aptu,
6//! including their metadata, API endpoints, and available models.
7//!
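//! It also provides runtime model validation infrastructure via the `ModelRegistry` trait,
//! implemented by `CachedModelRegistry` (not built for `wasm32`), which fetches model
//! lists from provider APIs over HTTP and keeps them in a file cache with a TTL.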
10//!
11//! # Examples
12//!
13//! ```
14//! use aptu_core::ai::registry::{get_provider, all_providers};
15//!
16//! // Get a specific provider
17//! let provider = get_provider("openrouter");
18//! assert!(provider.is_some());
19//!
20//! // Get all providers
21//! let providers = all_providers();
22//! assert_eq!(providers.len(), 7);
23//! ```
24
25use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static description of one AI provider known to Aptu.
///
/// All fields are `&'static str`, so the struct is `Copy` and can be stored in
/// a `static` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lowercase identifier used to look the provider up (e.g. in config files).
    pub name: &'static str,

    /// Name shown to users in UI output.
    pub display_name: &'static str,

    /// API URL for this provider (registry entries hold the full
    /// chat-completions endpoint).
    pub api_url: &'static str,

    /// Name of the environment variable that holds the API key.
    pub api_key_env: &'static str,
}
49
50// ============================================================================
51// Provider Registry
52// ============================================================================
53
/// Canonical registry name of the Anthropic provider.
///
/// Prefer this constant over repeating the `"anthropic"` literal so that
/// comparisons against the provider name stay in sync with the registry entry.
pub const PROVIDER_ANTHROPIC: &'static str = "anthropic";
59
60/// Static registry of all supported AI providers
61pub static PROVIDERS: &[ProviderConfig] = &[
62    ProviderConfig {
63        name: "gemini",
64        display_name: "Google Gemini",
65        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
66        api_key_env: "GEMINI_API_KEY",
67    },
68    ProviderConfig {
69        name: "openrouter",
70        display_name: "OpenRouter",
71        api_url: "https://openrouter.ai/api/v1/chat/completions",
72        api_key_env: "OPENROUTER_API_KEY",
73    },
74    ProviderConfig {
75        name: "groq",
76        display_name: "Groq",
77        api_url: "https://api.groq.com/openai/v1/chat/completions",
78        api_key_env: "GROQ_API_KEY",
79    },
80    ProviderConfig {
81        name: "cerebras",
82        display_name: "Cerebras",
83        api_url: "https://api.cerebras.ai/v1/chat/completions",
84        api_key_env: "CEREBRAS_API_KEY",
85    },
86    ProviderConfig {
87        name: "zenmux",
88        display_name: "Zenmux",
89        api_url: "https://zenmux.ai/api/v1/chat/completions",
90        api_key_env: "ZENMUX_API_KEY",
91    },
92    ProviderConfig {
93        name: "zai",
94        display_name: "Z.AI (Zhipu)",
95        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
96        api_key_env: "ZAI_API_KEY",
97    },
98    ProviderConfig {
99        name: PROVIDER_ANTHROPIC,
100        display_name: "Anthropic",
101        api_url: "https://api.anthropic.com/v1/chat/completions",
102        api_key_env: "ANTHROPIC_API_KEY",
103    },
104];
105
106/// Retrieves a provider configuration by name.
107///
108/// # Arguments
109///
110/// * `name` - The provider name (case-sensitive, lowercase)
111///
112/// # Returns
113///
114/// Some(ProviderConfig) if found, None otherwise.
115///
116/// # Examples
117///
118/// ```
119/// use aptu_core::ai::registry::get_provider;
120///
121/// let provider = get_provider("openrouter");
122/// assert!(provider.is_some());
123/// assert_eq!(provider.unwrap().display_name, "OpenRouter");
124/// ```
125#[must_use]
126pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
127    PROVIDERS.iter().find(|p| p.name == name)
128}
129
130/// Returns all available providers.
131///
132/// # Returns
133///
134/// A slice of all `ProviderConfig` entries in the registry.
135///
136/// # Examples
137///
138/// ```
139/// use aptu_core::ai::registry::all_providers;
140///
141/// let providers = all_providers();
142/// assert_eq!(providers.len(), 7);
143/// ```
144#[must_use]
145pub fn all_providers() -> &'static [ProviderConfig] {
146    PROVIDERS
147}
148
149// ============================================================================
150// Runtime Model Validation
151// ============================================================================
152
153/// Error type for model registry operations.
154#[derive(Debug, Error)]
155pub enum RegistryError {
156    /// HTTP request failed.
157    #[error("HTTP request failed: {0}")]
158    HttpError(String),
159
160    /// Failed to parse API response.
161    #[error("Failed to parse API response: {0}")]
162    ParseError(String),
163
164    /// Provider not found.
165    #[error("Provider not found: {0}")]
166    ProviderNotFound(String),
167
168    /// Cache error.
169    #[error("Cache error: {0}")]
170    CacheError(String),
171
172    /// IO error.
173    #[error("IO error: {0}")]
174    IoError(#[from] std::io::Error),
175
176    /// Model validation error - invalid model ID.
177    #[error("Invalid model ID: {model_id}")]
178    ModelValidation {
179        /// The invalid model ID provided by the user.
180        model_id: String,
181    },
182}
183
184/// Model capability indicators.
185#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
186#[serde(rename_all = "snake_case")]
187pub enum Capability {
188    /// Model supports image/vision inputs.
189    Vision,
190    /// Model supports function/tool calling.
191    FunctionCalling,
192    /// Model has extended reasoning capabilities.
193    Reasoning,
194}
195
196/// Raw pricing information for a model (cost per token in USD).
197///
198/// `f64` is used because these values are display-only (never used for
199/// arithmetic or financial calculations). Precision matches what the API
200/// returns in its JSON responses. If cost estimation or budget tracking is
201/// added in the future, migrate to a decimal type such as `rust_decimal`.
202#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct PricingInfo {
204    /// Cost per prompt token in USD. None if unavailable.
205    pub prompt_per_token: Option<f64>,
206    /// Cost per completion token in USD. None if unavailable.
207    pub completion_per_token: Option<f64>,
208}
209
210/// Cached model information from API responses.
211#[derive(Clone, Debug, Serialize, Deserialize)]
212pub struct CachedModel {
213    /// Model identifier from the provider API.
214    pub id: String,
215    /// Human-readable model name.
216    pub name: Option<String>,
217    /// Whether the model is free to use.
218    pub is_free: Option<bool>,
219    /// Maximum context window size in tokens.
220    pub context_window: Option<u32>,
221    /// Provider name this model belongs to.
222    pub provider: String,
223    /// Model capabilities (e.g., `Vision`, `FunctionCalling`).
224    #[serde(default)]
225    pub capabilities: Vec<Capability>,
226    /// Pricing information for this model.
227    #[serde(default)]
228    pub pricing: Option<PricingInfo>,
229}
230
231/// Trait for runtime model validation and listing.
232#[async_trait]
233pub trait ModelRegistry: Send + Sync {
234    /// List all available models for a provider.
235    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;
236
237    /// Check if a model exists for a provider.
238    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;
239
240    /// Validate that a model ID exists for a provider.
241    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
242}
243
244/// Cached model registry with HTTP client and TTL support.
245#[cfg(not(target_arch = "wasm32"))]
246pub struct CachedModelRegistry<'a> {
247    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
248    client: reqwest::Client,
249    token_provider: &'a dyn TokenProvider,
250}
251
252#[cfg(not(target_arch = "wasm32"))]
253impl CachedModelRegistry<'_> {
254    /// Create a new cached model registry.
255    ///
256    /// # Arguments
257    ///
258    /// * `cache_dir` - Directory for storing cached model lists (None to disable caching)
259    /// * `ttl_seconds` - Time-to-live for cache entries (see `DEFAULT_MODEL_TTL_SECS`)
260    /// * `token_provider` - Token provider for API credentials
261    #[must_use]
262    pub fn new(
263        cache_dir: Option<PathBuf>,
264        ttl_seconds: u64,
265        token_provider: &dyn TokenProvider,
266    ) -> CachedModelRegistry<'_> {
267        let ttl = chrono::Duration::seconds(
268            ttl_seconds
269                .try_into()
270                .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
271        );
272        CachedModelRegistry {
273            cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
274            client: reqwest::Client::builder()
275                .timeout(std::time::Duration::from_secs(10))
276                .build()
277                .unwrap_or_else(|_| reqwest::Client::new()),
278            token_provider,
279        }
280    }
281
282    /// Parse `OpenRouter` API response into models.
283    fn parse_openrouter_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
284        data.get("data")
285            .and_then(|d| d.as_array())
286            .map(|arr| {
287                arr.iter()
288                    .filter_map(|m| {
289                        let pricing_obj = m.get("pricing");
290                        let prompt_per_token = pricing_obj
291                            .and_then(|p| p.get("prompt"))
292                            .and_then(|p| p.as_str())
293                            .and_then(|s| s.parse::<f64>().ok());
294                        let completion_per_token = pricing_obj
295                            .and_then(|p| p.get("completion"))
296                            .and_then(|p| p.as_str())
297                            .and_then(|s| s.parse::<f64>().ok());
298
299                        let is_free = match (prompt_per_token, completion_per_token) {
300                            (Some(prompt), Some(completion)) => {
301                                Some(prompt == 0.0 && completion == 0.0)
302                            }
303                            (Some(prompt), None) => Some(prompt == 0.0),
304                            _ => pricing_obj
305                                .and_then(|p| p.get("prompt"))
306                                .and_then(|p| p.as_str())
307                                .map(|p| p == "0"),
308                        };
309
310                        let pricing =
311                            if prompt_per_token.is_some() || completion_per_token.is_some() {
312                                Some(PricingInfo {
313                                    prompt_per_token,
314                                    completion_per_token,
315                                })
316                            } else {
317                                None
318                            };
319
320                        // Derive capabilities from architecture field defensively
321                        let arch = m.get("architecture");
322                        let capabilities = {
323                            // Check input_modalities array first
324                            let from_input_modalities = arch
325                                .and_then(|a| a.get("input_modalities"))
326                                .and_then(|im| im.as_array())
327                                .map(|arr| {
328                                    arr.iter().filter_map(|v| v.as_str()).any(|s| s == "image")
329                                });
330                            // Fall back to modalities string
331                            let from_modalities_str = arch
332                                .and_then(|a| a.get("modalities"))
333                                .and_then(|m| m.as_str())
334                                .map(|s| s.contains("image"));
335
336                            let has_vision = from_input_modalities
337                                .or(from_modalities_str)
338                                .unwrap_or(false);
339
340                            if has_vision {
341                                vec![Capability::Vision]
342                            } else {
343                                vec![]
344                            }
345                        };
346
347                        Some(CachedModel {
348                            id: m.get("id")?.as_str()?.to_string(),
349                            name: m.get("name").and_then(|n| n.as_str()).map(String::from),
350                            is_free,
351                            context_window: m
352                                .get("context_length")
353                                .and_then(serde_json::Value::as_u64)
354                                .and_then(|c| u32::try_from(c).ok()),
355                            provider: provider.to_string(),
356                            capabilities,
357                            pricing,
358                        })
359                    })
360                    .collect()
361            })
362            .unwrap_or_default()
363    }
364
365    /// Parse Gemini API response into models.
366    fn parse_gemini_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
367        data.get("models")
368            .and_then(|d| d.as_array())
369            .map(|arr| {
370                arr.iter()
371                    .filter_map(|m| {
372                        Some(CachedModel {
373                            id: m.get("name")?.as_str()?.to_string(),
374                            name: m
375                                .get("displayName")
376                                .and_then(|n| n.as_str())
377                                .map(String::from),
378                            is_free: None,
379                            context_window: m
380                                .get("inputTokenLimit")
381                                .and_then(serde_json::Value::as_u64)
382                                .and_then(|c| u32::try_from(c).ok()),
383                            provider: provider.to_string(),
384                            capabilities: vec![],
385                            pricing: None,
386                        })
387                    })
388                    .collect()
389            })
390            .unwrap_or_default()
391    }
392
393    /// Parse generic OpenAI-compatible API response into models.
394    fn parse_generic_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
395        data.get("data")
396            .and_then(|d| d.as_array())
397            .map(|arr| {
398                arr.iter()
399                    .filter_map(|m| {
400                        Some(CachedModel {
401                            id: m.get("id")?.as_str()?.to_string(),
402                            name: None,
403                            is_free: None,
404                            context_window: None,
405                            provider: provider.to_string(),
406                            capabilities: vec![],
407                            pricing: None,
408                        })
409                    })
410                    .collect()
411            })
412            .unwrap_or_default()
413    }
414
415    /// Fetch models from provider API.
416    async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
417        let url = match provider {
418            "openrouter" => "https://openrouter.ai/api/v1/models",
419            "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
420            "groq" => "https://api.groq.com/openai/v1/models",
421            "cerebras" => "https://api.cerebras.ai/v1/models",
422            "zenmux" => "https://zenmux.ai/api/v1/models",
423            "zai" => "https://api.z.ai/api/paas/v4/models",
424            _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
425        };
426
427        // Get API key from token provider
428        let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
429            RegistryError::HttpError(format!("No API key available for {provider}"))
430        })?;
431
432        // Build request incrementally with provider-specific authentication
433        let request = match provider {
434            "gemini" => {
435                // Gemini uses header authentication
436                self.client
437                    .get(url)
438                    .header("x-goog-api-key", api_key.expose_secret())
439            }
440            "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
441                // These providers use Bearer token authentication
442                self.client.get(url).header(
443                    "Authorization",
444                    format!("Bearer {}", api_key.expose_secret()),
445                )
446            }
447            _ => self.client.get(url),
448        };
449
450        let response = request
451            .send()
452            .await
453            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
454
455        let data = response
456            .json::<serde_json::Value>()
457            .await
458            .map_err(|e| RegistryError::HttpError(e.to_string()))?;
459
460        // Parse based on provider API format
461        let models = match provider {
462            "openrouter" => Self::parse_openrouter_models(&data, provider),
463            "gemini" => Self::parse_gemini_models(&data, provider),
464            "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data, provider),
465            _ => vec![],
466        };
467
468        Ok(models)
469    }
470}
471
472#[cfg(not(target_arch = "wasm32"))]
473#[async_trait]
474impl ModelRegistry for CachedModelRegistry<'_> {
475    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
476        // Try fresh cache first
477        if let Ok(Some(models)) = self.cache.get(provider).await {
478            return Ok(models);
479        }
480
481        // Fetch from API with stale fallback
482        match self.fetch_from_api(provider).await {
483            Ok(models) => {
484                // Save to cache (ignore errors)
485                let _ = self.cache.set(provider, &models).await;
486                Ok(models)
487            }
488            Err(api_error) => {
489                // Try stale cache as fallback
490                match self.cache.get_stale(provider).await {
491                    Ok(Some(models)) => {
492                        tracing::warn!(
493                            provider = provider,
494                            error = %api_error,
495                            "API request failed, returning stale cached models"
496                        );
497                        Ok(models)
498                    }
499                    _ => {
500                        // No stale cache available, return original API error
501                        Err(api_error)
502                    }
503                }
504            }
505        }
506    }
507
508    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
509        let models = self.list_models(provider).await?;
510        Ok(models.iter().any(|m| m.id == model_id))
511    }
512
513    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
514        if self.model_exists(provider, model_id).await? {
515            Ok(())
516        } else {
517            Err(RegistryError::ModelValidation {
518                model_id: model_id.to_string(),
519            })
520        }
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared body of the per-provider lookup tests: the provider must exist
    /// and expose the expected display name and API-key variable.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let lookup = get_provider(name);
        assert!(lookup.is_some());
        let config = lookup.unwrap();
        assert_eq!(config.display_name, display_name);
        assert_eq!(config.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        let mixed_case = get_provider("OpenRouter");
        assert!(
            mixed_case.is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 7, "Should have exactly 7 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = std::collections::HashSet::new();
        for config in all_providers() {
            // `insert` returns false when the name was already present.
            assert!(
                seen.insert(config.name),
                "Duplicate provider name: {}",
                config.name
            );
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for config in all_providers() {
            assert!(
                config.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                config.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for config in all_providers() {
            assert!(
                !config.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                config.name
            );
        }
    }

    #[test]
    fn test_parse_openrouter_models_with_pricing() {
        let response = serde_json::json!({
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128_000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    },
                    "architecture": {
                        "input_modalities": ["text", "image"],
                        "output_modalities": ["text"]
                    }
                }
            ]
        });

        let parsed = CachedModelRegistry::parse_openrouter_models(&response, "openrouter");
        assert_eq!(parsed.len(), 1);

        let model = &parsed[0];
        assert_eq!(model.id, "openai/gpt-4o");
        assert_eq!(model.is_free, Some(false));
        assert!(model.capabilities.contains(&Capability::Vision));

        let pricing = model.pricing.as_ref().expect("pricing should be present");
        assert_eq!(pricing.prompt_per_token, Some(0.000_005));
        assert_eq!(pricing.completion_per_token, Some(0.000_015));
    }

    #[test]
    fn test_parse_openrouter_models_missing_capabilities() {
        let response = serde_json::json!({
            "data": [
                {
                    "id": "some/text-only-model",
                    "name": "Text Only",
                    "context_length": 32000,
                    "pricing": {
                        "prompt": "0",
                        "completion": "0"
                    }
                }
            ]
        });

        let parsed = CachedModelRegistry::parse_openrouter_models(&response, "openrouter");
        assert_eq!(parsed.len(), 1);

        let model = &parsed[0];
        assert!(
            model.capabilities.is_empty(),
            "no vision if architecture missing"
        );
        assert_eq!(model.is_free, Some(true));
    }
}