1use async_trait::async_trait;
26use secrecy::ExposeSecret;
27use serde::{Deserialize, Serialize};
28use std::path::PathBuf;
29use thiserror::Error;
30
31use crate::auth::TokenProvider;
32use crate::cache::FileCache;
33
/// Static description of one supported AI provider.
///
/// Every field is a `&'static str`, so a value can live in the `PROVIDERS`
/// table and be copied around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lookup key for the provider (e.g. `"gemini"`); `get_provider` compares
    /// against this field with exact string equality.
    pub name: &'static str,

    /// Human-readable provider name (e.g. `"Google Gemini"`).
    pub display_name: &'static str,

    /// URL of the provider's chat-completions endpoint.
    pub api_url: &'static str,

    /// Name of the environment variable for the provider's API key
    /// (e.g. `"GEMINI_API_KEY"`).
    pub api_key_env: &'static str,
}
49
50pub static PROVIDERS: &[ProviderConfig] = &[
56 ProviderConfig {
57 name: "gemini",
58 display_name: "Google Gemini",
59 api_url: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
60 api_key_env: "GEMINI_API_KEY",
61 },
62 ProviderConfig {
63 name: "openrouter",
64 display_name: "OpenRouter",
65 api_url: "https://openrouter.ai/api/v1/chat/completions",
66 api_key_env: "OPENROUTER_API_KEY",
67 },
68 ProviderConfig {
69 name: "groq",
70 display_name: "Groq",
71 api_url: "https://api.groq.com/openai/v1/chat/completions",
72 api_key_env: "GROQ_API_KEY",
73 },
74 ProviderConfig {
75 name: "cerebras",
76 display_name: "Cerebras",
77 api_url: "https://api.cerebras.ai/v1/chat/completions",
78 api_key_env: "CEREBRAS_API_KEY",
79 },
80 ProviderConfig {
81 name: "zenmux",
82 display_name: "Zenmux",
83 api_url: "https://zenmux.ai/api/v1/chat/completions",
84 api_key_env: "ZENMUX_API_KEY",
85 },
86 ProviderConfig {
87 name: "zai",
88 display_name: "Z.AI (Zhipu)",
89 api_url: "https://api.z.ai/api/paas/v4/chat/completions",
90 api_key_env: "ZAI_API_KEY",
91 },
92];
93
94#[must_use]
114pub fn get_provider(name: &str) -> Option<&'static ProviderConfig> {
115 PROVIDERS.iter().find(|p| p.name == name)
116}
117
118#[must_use]
133pub fn all_providers() -> &'static [ProviderConfig] {
134 PROVIDERS
135}
136
137#[derive(Debug, Error)]
143pub enum RegistryError {
144 #[error("HTTP request failed: {0}")]
146 HttpError(String),
147
148 #[error("Failed to parse API response: {0}")]
150 ParseError(String),
151
152 #[error("Provider not found: {0}")]
154 ProviderNotFound(String),
155
156 #[error("Cache error: {0}")]
158 CacheError(String),
159
160 #[error("IO error: {0}")]
162 IoError(#[from] std::io::Error),
163
164 #[error("Invalid model ID: {model_id}")]
166 ModelValidation {
167 model_id: String,
169 },
170}
171
172#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
174#[serde(rename_all = "snake_case")]
175pub enum Capability {
176 Vision,
178 FunctionCalling,
180 Reasoning,
182}
183
184#[derive(Clone, Debug, Serialize, Deserialize)]
191pub struct PricingInfo {
192 pub prompt_per_token: Option<f64>,
194 pub completion_per_token: Option<f64>,
196}
197
198#[derive(Clone, Debug, Serialize, Deserialize)]
200pub struct CachedModel {
201 pub id: String,
203 pub name: Option<String>,
205 pub is_free: Option<bool>,
207 pub context_window: Option<u32>,
209 pub provider: String,
211 #[serde(default)]
213 pub capabilities: Vec<Capability>,
214 #[serde(default)]
216 pub pricing: Option<PricingInfo>,
217}
218
219#[async_trait]
221pub trait ModelRegistry: Send + Sync {
222 async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError>;
224
225 async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError>;
227
228 async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError>;
230}
231
232pub struct CachedModelRegistry<'a> {
234 cache: crate::cache::FileCacheImpl<Vec<CachedModel>>,
235 client: reqwest::Client,
236 token_provider: &'a dyn TokenProvider,
237}
238
239impl CachedModelRegistry<'_> {
240 #[must_use]
248 pub fn new(
249 cache_dir: Option<PathBuf>,
250 ttl_seconds: u64,
251 token_provider: &dyn TokenProvider,
252 ) -> CachedModelRegistry<'_> {
253 let ttl = chrono::Duration::seconds(
254 ttl_seconds
255 .try_into()
256 .unwrap_or(crate::cache::DEFAULT_MODEL_TTL_SECS.cast_signed()),
257 );
258 CachedModelRegistry {
259 cache: crate::cache::FileCacheImpl::with_dir(cache_dir, "models", ttl),
260 client: reqwest::Client::builder()
261 .timeout(std::time::Duration::from_secs(10))
262 .build()
263 .unwrap_or_else(|_| reqwest::Client::new()),
264 token_provider,
265 }
266 }
267
268 fn parse_openrouter_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
270 data.get("data")
271 .and_then(|d| d.as_array())
272 .map(|arr| {
273 arr.iter()
274 .filter_map(|m| {
275 let pricing_obj = m.get("pricing");
276 let prompt_per_token = pricing_obj
277 .and_then(|p| p.get("prompt"))
278 .and_then(|p| p.as_str())
279 .and_then(|s| s.parse::<f64>().ok());
280 let completion_per_token = pricing_obj
281 .and_then(|p| p.get("completion"))
282 .and_then(|p| p.as_str())
283 .and_then(|s| s.parse::<f64>().ok());
284
285 let is_free = match (prompt_per_token, completion_per_token) {
286 (Some(prompt), Some(completion)) => {
287 Some(prompt == 0.0 && completion == 0.0)
288 }
289 (Some(prompt), None) => Some(prompt == 0.0),
290 _ => pricing_obj
291 .and_then(|p| p.get("prompt"))
292 .and_then(|p| p.as_str())
293 .map(|p| p == "0"),
294 };
295
296 let pricing =
297 if prompt_per_token.is_some() || completion_per_token.is_some() {
298 Some(PricingInfo {
299 prompt_per_token,
300 completion_per_token,
301 })
302 } else {
303 None
304 };
305
306 let arch = m.get("architecture");
308 let capabilities = {
309 let from_input_modalities = arch
311 .and_then(|a| a.get("input_modalities"))
312 .and_then(|im| im.as_array())
313 .map(|arr| {
314 arr.iter().filter_map(|v| v.as_str()).any(|s| s == "image")
315 });
316 let from_modalities_str = arch
318 .and_then(|a| a.get("modalities"))
319 .and_then(|m| m.as_str())
320 .map(|s| s.contains("image"));
321
322 let has_vision = from_input_modalities
323 .or(from_modalities_str)
324 .unwrap_or(false);
325
326 if has_vision {
327 vec![Capability::Vision]
328 } else {
329 vec![]
330 }
331 };
332
333 Some(CachedModel {
334 id: m.get("id")?.as_str()?.to_string(),
335 name: m.get("name").and_then(|n| n.as_str()).map(String::from),
336 is_free,
337 context_window: m
338 .get("context_length")
339 .and_then(serde_json::Value::as_u64)
340 .and_then(|c| u32::try_from(c).ok()),
341 provider: provider.to_string(),
342 capabilities,
343 pricing,
344 })
345 })
346 .collect()
347 })
348 .unwrap_or_default()
349 }
350
351 fn parse_gemini_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
353 data.get("models")
354 .and_then(|d| d.as_array())
355 .map(|arr| {
356 arr.iter()
357 .filter_map(|m| {
358 Some(CachedModel {
359 id: m.get("name")?.as_str()?.to_string(),
360 name: m
361 .get("displayName")
362 .and_then(|n| n.as_str())
363 .map(String::from),
364 is_free: None,
365 context_window: m
366 .get("inputTokenLimit")
367 .and_then(serde_json::Value::as_u64)
368 .and_then(|c| u32::try_from(c).ok()),
369 provider: provider.to_string(),
370 capabilities: vec![],
371 pricing: None,
372 })
373 })
374 .collect()
375 })
376 .unwrap_or_default()
377 }
378
379 fn parse_generic_models(data: &serde_json::Value, provider: &str) -> Vec<CachedModel> {
381 data.get("data")
382 .and_then(|d| d.as_array())
383 .map(|arr| {
384 arr.iter()
385 .filter_map(|m| {
386 Some(CachedModel {
387 id: m.get("id")?.as_str()?.to_string(),
388 name: None,
389 is_free: None,
390 context_window: None,
391 provider: provider.to_string(),
392 capabilities: vec![],
393 pricing: None,
394 })
395 })
396 .collect()
397 })
398 .unwrap_or_default()
399 }
400
401 async fn fetch_from_api(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
403 let url = match provider {
404 "openrouter" => "https://openrouter.ai/api/v1/models",
405 "gemini" => "https://generativelanguage.googleapis.com/v1beta/models",
406 "groq" => "https://api.groq.com/openai/v1/models",
407 "cerebras" => "https://api.cerebras.ai/v1/models",
408 "zenmux" => "https://zenmux.ai/api/v1/models",
409 "zai" => "https://api.z.ai/api/paas/v4/models",
410 _ => return Err(RegistryError::ProviderNotFound(provider.to_string())),
411 };
412
413 let api_key = self.token_provider.ai_api_key(provider).ok_or_else(|| {
415 RegistryError::HttpError(format!("No API key available for {provider}"))
416 })?;
417
418 let request = match provider {
420 "gemini" => {
421 self.client
423 .get(url)
424 .query(&[("key", api_key.expose_secret())])
425 }
426 "openrouter" | "groq" | "cerebras" | "zenmux" | "zai" => {
427 self.client.get(url).header(
429 "Authorization",
430 format!("Bearer {}", api_key.expose_secret()),
431 )
432 }
433 _ => self.client.get(url),
434 };
435
436 let response = request
437 .send()
438 .await
439 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
440
441 let data = response
442 .json::<serde_json::Value>()
443 .await
444 .map_err(|e| RegistryError::HttpError(e.to_string()))?;
445
446 let models = match provider {
448 "openrouter" => Self::parse_openrouter_models(&data, provider),
449 "gemini" => Self::parse_gemini_models(&data, provider),
450 "groq" | "cerebras" | "zenmux" | "zai" => Self::parse_generic_models(&data, provider),
451 _ => vec![],
452 };
453
454 Ok(models)
455 }
456}
457
458#[async_trait]
459impl ModelRegistry for CachedModelRegistry<'_> {
460 async fn list_models(&self, provider: &str) -> Result<Vec<CachedModel>, RegistryError> {
461 if let Ok(Some(models)) = self.cache.get(provider) {
463 return Ok(models);
464 }
465
466 match self.fetch_from_api(provider).await {
468 Ok(models) => {
469 let _ = self.cache.set(provider, &models);
471 Ok(models)
472 }
473 Err(api_error) => {
474 match self.cache.get_stale(provider) {
476 Ok(Some(models)) => {
477 tracing::warn!(
478 provider = provider,
479 error = %api_error,
480 "API request failed, returning stale cached models"
481 );
482 Ok(models)
483 }
484 _ => {
485 Err(api_error)
487 }
488 }
489 }
490 }
491 }
492
493 async fn model_exists(&self, provider: &str, model_id: &str) -> Result<bool, RegistryError> {
494 let models = self.list_models(provider).await?;
495 Ok(models.iter().any(|m| m.id == model_id))
496 }
497
498 async fn validate_model(&self, provider: &str, model_id: &str) -> Result<(), RegistryError> {
499 if self.model_exists(provider, model_id).await? {
500 Ok(())
501 } else {
502 Err(RegistryError::ModelValidation {
503 model_id: model_id.to_string(),
504 })
505 }
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `name` resolves to a provider with the given display name
    /// and API-key environment variable.
    fn check_provider(name: &str, display_name: &str, api_key_env: &str) {
        let found = get_provider(name);
        assert!(found.is_some());
        let config = found.unwrap();
        assert_eq!(config.display_name, display_name);
        assert_eq!(config.api_key_env, api_key_env);
    }

    /// Wraps one model object in the `{"data": [...]}` envelope, parses it
    /// with the OpenRouter parser and returns the single resulting model.
    fn parse_single_openrouter_model(model: serde_json::Value) -> CachedModel {
        let payload = serde_json::json!({ "data": [model] });
        let mut parsed = CachedModelRegistry::parse_openrouter_models(&payload, "openrouter");
        assert_eq!(parsed.len(), 1);
        parsed.remove(0)
    }

    #[test]
    fn test_get_provider_gemini() {
        check_provider("gemini", "Google Gemini", "GEMINI_API_KEY");
    }

    #[test]
    fn test_get_provider_openrouter() {
        check_provider("openrouter", "OpenRouter", "OPENROUTER_API_KEY");
    }

    #[test]
    fn test_get_provider_groq() {
        check_provider("groq", "Groq", "GROQ_API_KEY");
    }

    #[test]
    fn test_get_provider_cerebras() {
        check_provider("cerebras", "Cerebras", "CEREBRAS_API_KEY");
    }

    #[test]
    fn test_get_provider_not_found() {
        assert!(get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_provider_case_sensitive() {
        let mixed_case = get_provider("OpenRouter");
        assert!(
            mixed_case.is_none(),
            "Provider lookup should be case-sensitive"
        );
    }

    #[test]
    fn test_all_providers_count() {
        assert_eq!(all_providers().len(), 6, "Should have exactly 6 providers");
    }

    #[test]
    fn test_all_providers_have_unique_names() {
        let providers = all_providers();
        // Compare each entry against everything declared before it.
        for (index, provider) in providers.iter().enumerate() {
            let seen_before = providers[..index]
                .iter()
                .any(|earlier| earlier.name == provider.name);
            assert!(
                !seen_before,
                "Duplicate provider name: {}",
                provider.name
            );
        }
    }

    #[test]
    fn test_get_provider_zenmux() {
        check_provider("zenmux", "Zenmux", "ZENMUX_API_KEY");
    }

    #[test]
    fn test_get_provider_zai() {
        check_provider("zai", "Z.AI (Zhipu)", "ZAI_API_KEY");
    }

    #[test]
    fn test_provider_api_urls_valid() {
        for provider in all_providers() {
            assert!(
                provider.api_url.starts_with("https://"),
                "Provider {} API URL should use HTTPS",
                provider.name
            );
        }
    }

    #[test]
    fn test_provider_api_key_env_not_empty() {
        for provider in all_providers() {
            assert!(
                !provider.api_key_env.is_empty(),
                "Provider {} should have API key env var",
                provider.name
            );
        }
    }

    #[test]
    fn test_parse_openrouter_models_with_pricing() {
        let model = parse_single_openrouter_model(serde_json::json!({
            "id": "openai/gpt-4o",
            "name": "GPT-4o",
            "context_length": 128000,
            "pricing": {
                "prompt": "0.000005",
                "completion": "0.000015"
            },
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"]
            }
        }));

        assert_eq!(model.id, "openai/gpt-4o");
        assert_eq!(model.is_free, Some(false));
        let pricing = model.pricing.as_ref().expect("pricing should be present");
        assert_eq!(pricing.prompt_per_token, Some(0.000_005));
        assert_eq!(pricing.completion_per_token, Some(0.000_015));
        assert!(model.capabilities.contains(&Capability::Vision));
    }

    #[test]
    fn test_parse_openrouter_models_missing_capabilities() {
        let model = parse_single_openrouter_model(serde_json::json!({
            "id": "some/text-only-model",
            "name": "Text Only",
            "context_length": 32000,
            "pricing": {
                "prompt": "0",
                "completion": "0"
            }
        }));

        assert!(
            model.capabilities.is_empty(),
            "no vision if architecture missing"
        );
        assert_eq!(model.is_free, Some(true));
    }
}