1use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static description of one chat-completions provider.
///
/// Instances live in the `PROVIDERS` table and are handed out as
/// `&'static` references, so every field is a `&'static str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lookup key (e.g. `"gemini"`); `get_provider` compares it
    /// case-sensitively.
    pub name: &'static str,

    /// Human-readable provider name (e.g. `"Google Gemini"`).
    pub display_name: &'static str,

    /// Chat-completions endpoint URL for this provider.
    pub api_url: &'static str,

    /// Name of the environment variable associated with this provider's
    /// API key (e.g. `"GEMINI_API_KEY"`); this module only stores the name.
    pub api_key_env: &'static str,
}
49
50pub const PROVIDER_ANTHROPIC: &str = "anthropic";
59
60pub static PROVIDERS: &[ProviderConfig] = &[
62 ProviderConfig {
63 name: "gemini",
64 display_name: "Google Gemini",
65 api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
66 api_key_env: "GEMINI_API_KEY",
67 },
68 ProviderConfig {
69 name: "openrouter",
70 display_name: "OpenRouter",
71 api_url: "https://openrouter.ai/api/v1/chat/completions",
72 api_key_env: "OPENROUTER_API_KEY",
73 },
74 ProviderConfig {
75 name: "groq",
76 display_name: "Groq",
77 api_url: "https://api.groq.com/openai/v1/chat/completions",
78 api_key_env: "GROQ_API_KEY",
79 },
80 ProviderConfig {
81 name: "cerebras",
82 display_name: "Cerebras",
83 api_url: "https://api.cerebras.ai/v1/chat/completions",
84 api_key_env: "CEREBRAS_API_KEY",
85 },
86 ProviderConfig {
87 name: "zenmux",
88 display_name: "Zenmux",
89 api_url: "https://zenmux.ai/api/v1/chat/completions",
90 api_key_env: "ZENMUX_API_KEY",
91 },
92 ProviderConfig {
93 name: "zai",
94 display_name: "Z.AI (Zhipu)",
95 api_url: "https://api.z.ai/api/paas/v4/chat/completions",
96 api_key_env: "ZAI_API_KEY",
97 },
98 ProviderConfig {
99 name: PROVIDER_ANTHROPIC,
100 display_name: "Anthropic",
101 api_url: "https://api.anthropic.com/v1/chat/completions",
102 api_key_env: "ANTHROPIC_API_KEY",
103 },
104];
105
106#[must_use]
126pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
127 PROVIDERS.iter().find(|p| p.name == name)
128}
129
130#[must_use]
145pub fn all_providers() -> &'static [ProviderConfig] {
146 PROVIDERS
147}
148
149#[derive(Debug, Error)]
155pub enum RegistryError {
156 #[error("HTTP request failed: {0}")]
158 HttpError(String),
159
160 #[error("Failed to parse API response: {0}")]
162 ParseError(String),
163
164 #[error("Provider not found: {0}")]
166 ProviderNotFound(String),
167
168 #[error("Cache error: {0}")]
170 CacheError(String),
171
172 #[error("IO error: {0}")]
174 IoError(#[from] std::io::Error),
175
176 #[error("Invalid model ID: {model_id}")]
178 ModelValidation {
179 model_id: String,
181 },
182}
183
184#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
186#[serde(rename_all = "snake_case")]
187pub enum Capability {
188 Vision,
190 FunctionCalling,
192 Reasoning,
194}
195
196#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct PricingInfo {
204 pub prompt_per_token: Option<f64>,
206 pub completion_per_token: Option<f64>,
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize)]
212pub struct CachedModel {
213 pub id: String,
215 pub name: Option<String>,
217 pub is_free: Option<bool>,
219 pub context_window: Option<u32>,
221 pub provider: String,
223 #[serde(default)]
225 pub capabilities: Vec<Capability>,
226 #[serde(default)]
228 pub pricing: Option<PricingInfo>,
229}
230
/// Async source of per-provider model catalogs.
///
/// Implementors must be `Send + Sync`; methods are made object-safe async
/// via `async_trait`.
#[async_trait]
pub trait ModelRegistry: Send + Sync {
    /// Returns the models currently offered by `provider`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the list cannot be obtained.
    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;

    /// Reports whether `provider` offers `model_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the model list cannot be obtained.
    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;

    /// Succeeds when `provider` offers `model_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistryError`] when the model is not offered or the list
    /// cannot be obtained (`CachedModelRegistry` uses
    /// `RegistryError::ModelValidation` for the former).
    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
}
243
/// [`ModelRegistry`] that queries provider HTTP APIs and keeps the results
/// in a file-backed cache (`FileCacheImpl`).
pub struct CachedModelRegistry<'a> {
    /// Model lists keyed by provider name.
    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
    /// HTTP client used for the model-listing requests.
    client: reqwest::Client,
    /// Supplies per-provider API keys; borrowed for the registry's lifetime.
    token_provider: &'a dyn TokenProvider,
}
250
251impl CachedModelRegistry<'_> {
252 #[must_use]
260 pub fn new(
261 cache_dir: Option<PathBuf>,
262 ttl_seconds: u64,
263 token_provider: &dyn TokenProvider,
264 ) -> CachedModelRegistry<'_> {
265 let ttl = chrono::Duration::seconds(
266 ttl_seconds
267 .try_into()
268 .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
269 );
270 CachedModelRegistry {
271 cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
272 client: reqwest::Client::builder()
273 .timeout(std::time::Duration::from_secs(10))
274 .build()
275 .unwrap_or_else(|_| reqwest::Client::new()),
276 token_provider,
277 }
278 }
279
280 fn parse_openrouter_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
282 data.get("data")
283 .and_then(|d| d.as_array())
284 .map(|arr| {
285 arr.iter()
286 .filter_map(|m| {
287 let pricing_obj = m.get("pricing");
288 let prompt_per_token = pricing_obj
289 .and_then(|p| p.get("prompt"))
290 .and_then(|p| p.as_str())
291 .and_then(|s| s.parse::<f64>().ok());
292 let completion_per_token = pricing_obj
293 .and_then(|p| p.get("completion"))
294 .and_then(|p| p.as_str())
295 .and_then(|s| s.parse::<f64>().ok());
296
297 let is_free = match (prompt_per_token, completion_per_token) {
298 (Some(prompt), Some(completion)) => {
299 Some(prompt == 0.0 && completion == 0.0)
300 }
301 (Some(prompt), None) => Some(prompt == 0.0),
302 _ => pricing_obj
303 .and_then(|p| p.get("prompt"))
304 .and_then(|p| p.as_str())
305 .map(|p| p == "0"),
306 };
307
308 let pricing =
309 if prompt_per_token.is_some() || completion_per_token.is_some() {
310 Some(PricingInfo {
311 prompt_per_token,
312 completion_per_token,
313 })
314 } else {
315 None
316 };
317
318 let arch = m.get("architecture");
320 let capabilities = {
321 let from_input_modalities = arch
323 .and_then(|a| a.get("input_modalities"))
324 .and_then(|im| im.as_array())
325 .map(|arr| {
326 arr.iter().filter_map(|v| v.as_str()).any(|s| s == "image")
327 });
328 let from_modalities_str = arch
330 .and_then(|a| a.get("modalities"))
331 .and_then(|m| m.as_str())
332 .map(|s| s.contains("image"));
333
334 let has_vision = from_input_modalities
335 .or(from_modalities_str)
336 .unwrap_or(false);
337
338 if has_vision {
339 vec![Capability::Vision]
340 } else {
341 vec![]
342 }
343 };
344
345 Some(CachedModel {
346 id: m.get("id")?.as_str()?.to_string(),
347 name: m.get("name").and_then(|n| n.as_str()).map(String::from),
348 is_free,
349 context_window: m
350 .get("context_length")
351 .and_then(serde_json::Value::as_u64)
352 .and_then(|c| u32::try_from(c).ok()),
353 provider: provider.to_string(),
354 capabilities,
355 pricing,
356 })
357 })
358 .collect()
359 })
360 .unwrap_or_default()
361 }
362
363 fn parse_gemini_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
365 data.get("models")
366 .and_then(|d| d.as_array())
367 .map(|arr| {
368 arr.iter()
369 .filter_map(|m| {
370 Some(CachedModel {
371 id: m.get("name")?.as_str()?.to_string(),
372 name: m
373 .get("displayName")
374 .and_then(|n| n.as_str())
375 .map(String::from),
376 is_free: None,
377 context_window: m
378 .get("inputTokenLimit")
379 .and_then(serde_json::Value::as_u64)
380 .and_then(|c| u32::try_from(c).ok()),
381 provider: provider.to_string(),
382 capabilities: vec![],
383 pricing: None,
384 })
385 })
386 .collect()
387 })
388 .unwrap_or_default()
389 }
390
391 fn parse_generic_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
393 data.get("data")
394 .and_then(|d| d.as_array())
395 .map(|arr| {
396 arr.iter()
397 .filter_map(|m| {
398 Some(CachedModel {
399 id: m.get("id")?.as_str()?.to_string(),
400 name: None,
401 is_free: None,
402 context_window: None,
403 provider: provider.to_string(),
404 capabilities: vec![],
405 pricing: None,
406 })
407 })
408 .collect()
409 })
410 .unwrap_or_default()
411 }
412
413 async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
415 let url = match provider {
416 "openrouter" => "https://openrouter.ai/api/v1/models",
417 "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
418 "groq" => "https://api.groq.com/openai/v1/models",
419 "cerebras" => "https://api.cerebras.ai/v1/models",
420 "zenmux" => "https://zenmux.ai/api/v1/models",
421 "zai" => "https://api.z.ai/api/paas/v4/models",
422 _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
423 };
424
425 let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
427 RegistryError::HttpError(format!("No API key available for {provider}"))
428 })?;
429
430 let request = match provider {
432 "gemini" => {
433 self.client
435 .get(url)
436 .header("x-goog-api-key", api_key.expose_secret())
437 }
438 "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
439 self.client.get(url).header(
441 "Authorization",
442 format!("Bearer {}", api_key.expose_secret()),
443 )
444 }
445 _ => self.client.get(url),
446 };
447
448 let response = request
449 .send()
450 .await
451 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
452
453 let data = response
454 .json::<serde_json::Value>()
455 .await
456 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
457
458 let models = match provider {
460 "openrouter" => Self::parse_openrouter_models(&data, provider),
461 "gemini" => Self::parse_gemini_models(&data, provider),
462 "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data, provider),
463 _ => vec![],
464 };
465
466 Ok(models)
467 }
468}
469
470#[async_trait]
471impl ModelRegistry for CachedModelRegistry<'_> {
472 async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
473 if let Ok(Some(models)) = self.cache.get(provider).await {
475 return Ok(models);
476 }
477
478 match self.fetch_from_api(provider).await {
480 Ok(models) => {
481 let _ = self.cache.set(provider, &models).await;
483 Ok(models)
484 }
485 Err(api_error) => {
486 match self.cache.get_stale(provider).await {
488 Ok(Some(models)) => {
489 tracing::warn!(
490 provider = provider,
491 error = %api_error,
492 "API request failed, returning stale cached models"
493 );
494 Ok(models)
495 }
496 _ => {
497 Err(api_error)
499 }
500 }
501 }
502 }
503 }
504
505 async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
506 let models = self.list_models(provider).await?;
507 Ok(models.iter().any(|m| m.id == model_id))
508 }
509
510 async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
511 if self.model_exists(provider, model_id).await? {
512 Ok(())
513 } else {
514 Err(RegistryError::ModelValidation {
515 model_id: model_id.to_string(),
516 })
517 }
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `name` is registered with the given display name and key variable.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let provider =
            get_provider(name).unwrap_or_else(|| panic!("provider {name} should be registered"));
        assert_eq!(provider.display_name, display_name);
        assert_eq!(provider.api_key_env, api_key_env);
    }

    /// Parses an OpenRouter payload that must contain exactly one model.
    fn parse_single_openrouter(data: &serde_json::Value) -> CachedModel {
        let mut models = CachedModelRegistry::parse_openrouter_models(data, "openrouter");
        assert_eq!(models.len(), 1);
        models.remove(0)
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        assert!(
            get_provider("OpenRouter").is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 7, "Should have exactly 7 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = std::collections::HashSet::new();
        for provider in all_providers() {
            assert!(
                seen.insert(provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for provider in all_providers() {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for provider in all_providers() {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }

    #[test]
    fn test_parse_openrouter_models_with_pricing() {
        let data = serde_json::json!({
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128_000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    },
                    "architecture": {
                        "input_modalities": ["text", "image"],
                        "output_modalities": ["text"]
                    }
                }
            ]
        });

        let model = parse_single_openrouter(&data);
        assert_eq!(model.id, "openai/gpt-4o");
        assert_eq!(model.is_free, Some(false));
        let pricing = model.pricing.as_ref().expect("pricing should be present");
        assert_eq!(pricing.prompt_per_token, Some(0.000_005));
        assert_eq!(pricing.completion_per_token, Some(0.000_015));
        assert!(model.capabilities.contains(&Capability::Vision));
    }

    #[test]
    fn test_parse_openrouter_models_missing_capabilities() {
        let data = serde_json::json!({
            "data": [
                {
                    "id": "some/text-only-model",
                    "name": "Text Only",
                    "context_length": 32000,
                    "pricing": {
                        "prompt": "0",
                        "completion": "0"
                    }
                }
            ]
        });

        let model = parse_single_openrouter(&data);
        assert!(
            model.capabilities.is_empty(),
            "no vision if architecture missing"
        );
        assert_eq!(model.is_free, Some(true));
    }
}