1use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static description of one provider entry in [`PROVIDERS`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Short, case-sensitive identifier used for lookups (e.g. `"gemini"`).
    pub name: &'static str,
    /// Human-readable provider name (e.g. `"Google Gemini"`).
    pub display_name: &'static str,
    /// Chat-completions endpoint URL for this provider.
    pub api_url: &'static str,
    /// Name of the environment variable associated with this provider's API key
    /// (e.g. `"GEMINI_API_KEY"`).
    pub api_key_env: &'static str,
}
49
/// Registry name of the Anthropic provider; used as that entry's `name` in [`PROVIDERS`].
pub const PROVIDER_ANTHROPIC: &str = "anthropic";
59
/// All providers known to this crate, in the order returned by [`all_providers`].
///
/// Every `api_url` here is a `/chat/completions` endpoint. Names must be unique;
/// [`get_provider`] returns the first match.
pub static PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig {
        name: "gemini",
        display_name: "Google Gemini",
        api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        api_key_env: "GEMINI_API_KEY",
    },
    ProviderConfig {
        name: "openrouter",
        display_name: "OpenRouter",
        api_url: "https://openrouter.ai/api/v1/chat/completions",
        api_key_env: "OPENROUTER_API_KEY",
    },
    ProviderConfig {
        name: "groq",
        display_name: "Groq",
        api_url: "https://api.groq.com/openai/v1/chat/completions",
        api_key_env: "GROQ_API_KEY",
    },
    ProviderConfig {
        name: "cerebras",
        display_name: "Cerebras",
        api_url: "https://api.cerebras.ai/v1/chat/completions",
        api_key_env: "CEREBRAS_API_KEY",
    },
    ProviderConfig {
        name: "zenmux",
        display_name: "Zenmux",
        api_url: "https://zenmux.ai/api/v1/chat/completions",
        api_key_env: "ZENMUX_API_KEY",
    },
    ProviderConfig {
        name: "zai",
        display_name: "Z.AI (Zhipu)",
        api_url: "https://api.z.ai/api/paas/v4/chat/completions",
        api_key_env: "ZAI_API_KEY",
    },
    ProviderConfig {
        name: PROVIDER_ANTHROPIC,
        display_name: "Anthropic",
        api_url: "https://api.anthropic.com/v1/chat/completions",
        api_key_env: "ANTHROPIC_API_KEY",
    },
];
105
106#[must_use]
126pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
127 PROVIDERS.iter().find(|p| p.name == name)
128}
129
/// Returns every registered provider, in [`PROVIDERS`] declaration order.
#[must_use]
pub fn all_providers() -> &'static [ProviderConfig] {
    PROVIDERS
}
148
/// Errors produced while listing or validating provider models.
#[derive(Debug, Error)]
pub enum RegistryError {
    /// A request to a provider could not be completed. Also returned when no API key is
    /// available for the provider.
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// A provider's response could not be interpreted.
    #[error("Failed to parse API response: {0}")]
    ParseError(String),

    /// The provider name is not one this registry can query.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),

    /// A cache operation failed.
    #[error("Cache error: {0}")]
    CacheError(String),

    /// An underlying I/O operation failed; `?` converts [`std::io::Error`] into this variant.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// `model_id` is not among the models listed for the requested provider.
    #[error("Invalid model ID: {model_id}")]
    ModelValidation {
        /// The rejected model identifier.
        model_id: String,
    },
}
183
/// A model feature flag, serialized in `snake_case` (e.g. `"function_calling"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    /// Accepts image input. The OpenRouter parser in this module derives it from the
    /// model's advertised input modalities.
    Vision,
    /// Function/tool calling. Not set by the parsers in this module.
    FunctionCalling,
    /// Reasoning support. Not set by the parsers in this module.
    Reasoning,
}
195
/// Per-token prices for a model, as reported by the provider.
///
/// NOTE(review): the currency is whatever the provider reports (presumably USD) — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PricingInfo {
    /// Price per prompt (input) token, if known.
    pub prompt_per_token: Option<f64>,
    /// Price per completion (output) token, if known.
    pub completion_per_token: Option<f64>,
}
209
/// One model entry from a provider's model-listing API, as stored in the model cache.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CachedModel {
    /// Provider-specific model identifier; model existence checks compare against this.
    pub id: String,
    /// Human-readable name, when the provider supplies one.
    pub name: Option<String>,
    /// Whether the model is free to use; `None` when no pricing signal was available.
    pub is_free: Option<bool>,
    /// Context/input limit reported by the provider (OpenRouter `context_length`,
    /// Gemini `inputTokenLimit`), if present and representable as `u32`.
    pub context_window: Option<u32>,
    /// Name of the provider this entry was fetched from (e.g. `"openrouter"`).
    pub provider: String,
    /// Detected capabilities; empty when absent from serialized data.
    #[serde(default)]
    pub capabilities: Vec<Capability>,
    /// Per-token pricing, when known; `None` when absent from serialized data.
    #[serde(default)]
    pub pricing: Option<PricingInfo>,
}
230
/// Lookup of the models each provider offers.
///
/// `provider` arguments are provider names such as `"openrouter"` or `"gemini"`.
#[async_trait]
pub trait ModelRegistry: Send + Sync {
    /// Returns the models available from `provider`.
    async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;

    /// Returns whether `model_id` is offered by `provider`.
    async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;

    /// Returns `Ok(())` if `model_id` is offered by `provider`, and an error otherwise.
    async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
}
243
/// [`ModelRegistry`] that queries provider HTTP APIs with a file cache in front.
///
/// Borrows a [`TokenProvider`] for API keys and therefore cannot outlive it.
/// Not compiled for `wasm32` targets.
#[cfg(not(target_arch = "wasm32"))]
pub struct CachedModelRegistry<'a> {
    /// File cache of model lists, keyed by provider name.
    cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
    /// HTTP client used for model-listing requests.
    client: reqwest::Client,
    /// Source of per-provider API keys.
    token_provider: &'a dyn TokenProvider,
}
251
252#[cfg(not(target_arch = "wasm32"))]
253impl CachedModelRegistry<'_> {
254 #[must_use]
262 pub fn new(
263 cache_dir: Option<PathBuf>,
264 ttl_seconds: u64,
265 token_provider: &dyn TokenProvider,
266 ) -> CachedModelRegistry<'_> {
267 let ttl = chrono::Duration::seconds(
268 ttl_seconds
269 .try_into()
270 .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
271 );
272 CachedModelRegistry {
273 cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
274 client: reqwest::Client::builder()
275 .timeout(std::time::Duration::from_secs(10))
276 .build()
277 .unwrap_or_else(|_| reqwest::Client::new()),
278 token_provider,
279 }
280 }
281
282 fn parse_openrouter_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
284 data.get("data")
285 .and_then(|d| d.as_array())
286 .map(|arr| {
287 arr.iter()
288 .filter_map(|m| {
289 let pricing_obj = m.get("pricing");
290 let prompt_per_token = pricing_obj
291 .and_then(|p| p.get("prompt"))
292 .and_then(|p| p.as_str())
293 .and_then(|s| s.parse::<f64>().ok());
294 let completion_per_token = pricing_obj
295 .and_then(|p| p.get("completion"))
296 .and_then(|p| p.as_str())
297 .and_then(|s| s.parse::<f64>().ok());
298
299 let is_free = match (prompt_per_token, completion_per_token) {
300 (Some(prompt), Some(completion)) => {
301 Some(prompt == 0.0 && completion == 0.0)
302 }
303 (Some(prompt), None) => Some(prompt == 0.0),
304 _ => pricing_obj
305 .and_then(|p| p.get("prompt"))
306 .and_then(|p| p.as_str())
307 .map(|p| p == "0"),
308 };
309
310 let pricing =
311 if prompt_per_token.is_some() || completion_per_token.is_some() {
312 Some(PricingInfo {
313 prompt_per_token,
314 completion_per_token,
315 })
316 } else {
317 None
318 };
319
320 let arch = m.get("architecture");
322 let capabilities = {
323 let from_input_modalities = arch
325 .and_then(|a| a.get("input_modalities"))
326 .and_then(|im| im.as_array())
327 .map(|arr| {
328 arr.iter().filter_map(|v| v.as_str()).any(|s| s == "image")
329 });
330 let from_modalities_str = arch
332 .and_then(|a| a.get("modalities"))
333 .and_then(|m| m.as_str())
334 .map(|s| s.contains("image"));
335
336 let has_vision = from_input_modalities
337 .or(from_modalities_str)
338 .unwrap_or(false);
339
340 if has_vision {
341 vec![Capability::Vision]
342 } else {
343 vec![]
344 }
345 };
346
347 Some(CachedModel {
348 id: m.get("id")?.as_str()?.to_string(),
349 name: m.get("name").and_then(|n| n.as_str()).map(String::from),
350 is_free,
351 context_window: m
352 .get("context_length")
353 .and_then(serde_json::Value::as_u64)
354 .and_then(|c| u32::try_from(c).ok()),
355 provider: provider.to_string(),
356 capabilities,
357 pricing,
358 })
359 })
360 .collect()
361 })
362 .unwrap_or_default()
363 }
364
365 fn parse_gemini_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
367 data.get("models")
368 .and_then(|d| d.as_array())
369 .map(|arr| {
370 arr.iter()
371 .filter_map(|m| {
372 Some(CachedModel {
373 id: m.get("name")?.as_str()?.to_string(),
374 name: m
375 .get("displayName")
376 .and_then(|n| n.as_str())
377 .map(String::from),
378 is_free: None,
379 context_window: m
380 .get("inputTokenLimit")
381 .and_then(serde_json::Value::as_u64)
382 .and_then(|c| u32::try_from(c).ok()),
383 provider: provider.to_string(),
384 capabilities: vec![],
385 pricing: None,
386 })
387 })
388 .collect()
389 })
390 .unwrap_or_default()
391 }
392
393 fn parse_generic_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
395 data.get("data")
396 .and_then(|d| d.as_array())
397 .map(|arr| {
398 arr.iter()
399 .filter_map(|m| {
400 Some(CachedModel {
401 id: m.get("id")?.as_str()?.to_string(),
402 name: None,
403 is_free: None,
404 context_window: None,
405 provider: provider.to_string(),
406 capabilities: vec![],
407 pricing: None,
408 })
409 })
410 .collect()
411 })
412 .unwrap_or_default()
413 }
414
415 async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
417 let url = match provider {
418 "openrouter" => "https://openrouter.ai/api/v1/models",
419 "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
420 "groq" => "https://api.groq.com/openai/v1/models",
421 "cerebras" => "https://api.cerebras.ai/v1/models",
422 "zenmux" => "https://zenmux.ai/api/v1/models",
423 "zai" => "https://api.z.ai/api/paas/v4/models",
424 _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
425 };
426
427 let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
429 RegistryError::HttpError(format!("No API key available for {provider}"))
430 })?;
431
432 let request = match provider {
434 "gemini" => {
435 self.client
437 .get(url)
438 .header("x-goog-api-key", api_key.expose_secret())
439 }
440 "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
441 self.client.get(url).header(
443 "Authorization",
444 format!("Bearer {}", api_key.expose_secret()),
445 )
446 }
447 _ => self.client.get(url),
448 };
449
450 let response = request
451 .send()
452 .await
453 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
454
455 let data = response
456 .json::<serde_json::Value>()
457 .await
458 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
459
460 let models = match provider {
462 "openrouter" => Self::parse_openrouter_models(&data, provider),
463 "gemini" => Self::parse_gemini_models(&data, provider),
464 "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data, provider),
465 _ => vec![],
466 };
467
468 Ok(models)
469 }
470}
471
472#[cfg(not(target_arch = "wasm32"))]
473#[async_trait]
474impl ModelRegistry for CachedModelRegistry<'_> {
475 async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
476 if let Ok(Some(models)) = self.cache.get(provider).await {
478 return Ok(models);
479 }
480
481 match self.fetch_from_api(provider).await {
483 Ok(models) => {
484 let _ = self.cache.set(provider, &models).await;
486 Ok(models)
487 }
488 Err(api_error) => {
489 match self.cache.get_stale(provider).await {
491 Ok(Some(models)) => {
492 tracing::warn!(
493 provider = provider,
494 error = %api_error,
495 "API request failed, returning stale cached models"
496 );
497 Ok(models)
498 }
499 _ => {
500 Err(api_error)
502 }
503 }
504 }
505 }
506 }
507
508 async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
509 let models = self.list_models(provider).await?;
510 Ok(models.iter().any(|m| m.id == model_id))
511 }
512
513 async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
514 if self.model_exists(provider, model_id).await? {
515 Ok(())
516 } else {
517 Err(RegistryError::ModelValidation {
518 model_id: model_id.to_string(),
519 })
520 }
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `name` resolves to a provider with the given display name and key variable.
    fn assert_provider(name: &str, display_name: &str, api_key_env: &str) {
        let provider = get_provider(name)
            .unwrap_or_else(|| panic!("provider {name:?} should be registered"));
        assert_eq!(provider.name, name);
        assert_eq!(provider.display_name, display_name);
        assert_eq!(provider.api_key_env, api_key_env);
    }

    #[test]
    fn test_get_provider_gemini() {
        assert_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        assert_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        assert_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        assert_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_zenmux() {
        assert_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        assert_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_get_provider_anthropic() {
        assert_eq!(PROVIDER_ANTHROPIC, "anthropic");
        assert_provider(PROVIDER_ANTHROPIC, "Anthropic", "ANTHROPIC_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        assert!(
            get_provider("OpenRouter").is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 7, "Should have exactly 7 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let mut seen = std::collections::HashSet::new();
        for provider in all_providers() {
            assert!(
                seen.insert(provider.name),
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for provider in all_providers() {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for provider in all_providers() {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }

    #[test]
    fn test_parse_openrouter_models_with_pricing() {
        let data = serde_json::json!({
            "data": [
                {
                    "id": "openai/gpt-4o",
                    "name": "GPT-4o",
                    "context_length": 128_000,
                    "pricing": {
                        "prompt": "0.000005",
                        "completion": "0.000015"
                    },
                    "architecture": {
                        "input_modalities": ["text", "image"],
                        "output_modalities": ["text"]
                    }
                }
            ]
        });

        let models = CachedModelRegistry::parse_openrouter_models(&data, "openrouter");
        assert_eq!(models.len(), 1);
        let m = &models[0];
        assert_eq!(m.id, "openai/gpt-4o");
        assert_eq!(m.is_free, Some(false));
        let pricing = m.pricing.as_ref().expect("pricing should be present");
        assert_eq!(pricing.prompt_per_token, Some(0.000_005));
        assert_eq!(pricing.completion_per_token, Some(0.000_015));
        assert!(m.capabilities.contains(&Capability::Vision));
    }

    #[test]
    fn test_parse_openrouter_models_missing_capabilities() {
        let data = serde_json::json!({
            "data": [
                {
                    "id": "some/text-only-model",
                    "name": "Text Only",
                    "context_length": 32000,
                    "pricing": {
                        "prompt": "0",
                        "completion": "0"
                    }
                }
            ]
        });

        let models = CachedModelRegistry::parse_openrouter_models(&data, "openrouter");
        assert_eq!(models.len(), 1);
        let m = &models[0];
        assert!(
            m.capabilities.is_empty(),
            "no vision if architecture missing"
        );
        assert_eq!(m.is_free, Some(true));
    }

    #[test]
    fn test_parse_openrouter_models_modalities_fallback_and_precedence() {
        let data = serde_json::json!({
            "data": [
                { "id": "string/vision", "architecture": { "modalities": "text+image->text" } },
                {
                    "id": "array/wins",
                    "architecture": {
                        "input_modalities": ["text"],
                        "modalities": "text+image->text"
                    }
                }
            ]
        });

        let models = CachedModelRegistry::parse_openrouter_models(&data, "openrouter");
        assert_eq!(models.len(), 2);
        assert_eq!(models[0].capabilities, vec![Capability::Vision]);
        assert_eq!(models[0].is_free, None, "no pricing means unknown");
        assert!(models[0].pricing.is_none());
        assert!(
            models[1].capabilities.is_empty(),
            "input_modalities takes precedence over the modalities string"
        );
    }

    #[test]
    fn test_parse_openrouter_models_skips_invalid_entries() {
        let data = serde_json::json!({
            "data": [
                { "name": "no id" },
                { "id": 42 },
                { "id": "ok/model", "context_length": 5_000_000_000_u64 }
            ]
        });

        let models = CachedModelRegistry::parse_openrouter_models(&data, "openrouter");
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].id, "ok/model");
        assert_eq!(models[0].provider, "openrouter");
        assert_eq!(models[0].context_window, None, "u32 overflow maps to None");

        let error_body = serde_json::json!({ "error": { "code": 401 } });
        assert!(CachedModelRegistry::parse_openrouter_models(&error_body, "openrouter").is_empty());
    }

    #[test]
    fn test_parse_gemini_models() {
        let data = serde_json::json!({
            "models": [
                {
                    "name": "models/gemini-2.0-flash",
                    "displayName": "Gemini 2.0 Flash",
                    "inputTokenLimit": 1_048_576
                },
                { "displayName": "missing name is skipped" }
            ]
        });

        let models = CachedModelRegistry::parse_gemini_models(&data, "gemini");
        assert_eq!(models.len(), 1);
        let m = &models[0];
        assert_eq!(m.id, "models/gemini-2.0-flash");
        assert_eq!(m.name.as_deref(), Some("Gemini 2.0 Flash"));
        assert_eq!(m.context_window, Some(1_048_576));
        assert_eq!(m.provider, "gemini");
        assert_eq!(m.is_free, None);
        assert!(m.pricing.is_none());

        let no_models = serde_json::json!({ "error": "bad key" });
        assert!(CachedModelRegistry::parse_gemini_models(&no_models, "gemini").is_empty());
    }

    #[test]
    fn test_parse_generic_models() {
        let data = serde_json::json!({
            "object": "list",
            "data": [
                { "id": "llama-3.1-8b-instant", "object": "model" },
                { "object": "model" }
            ]
        });

        let models = CachedModelRegistry::parse_generic_models(&data, "groq");
        assert_eq!(models.len(), 1);
        let m = &models[0];
        assert_eq!(m.id, "llama-3.1-8b-instant");
        assert_eq!(m.provider, "groq");
        assert!(m.name.is_none());
        assert!(m.capabilities.is_empty());

        let not_a_list = serde_json::json!({ "data": "oops" });
        assert!(CachedModelRegistry::parse_generic_models(&not_a_list, "groq").is_empty());
    }
}