use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
use serde::Deserialize;
use crate::auth::AuthStorage;
use crate::config::Config;
use crate::error::{Error, Result};
use crate::http::client::Client;
use crate::models::{ModelRegistry, default_models_path};
use crate::provider_metadata::{
ProviderRoutingDefaults, canonical_provider_id, provider_routing_defaults,
};
/// How long a successful live model listing is reused before the provider is queried again
/// (five minutes).
pub const MODEL_CACHE_TTL: Duration = Duration::from_secs(300);
/// Environment variable that, when set to `1`, `true`, `yes` or `on` (case-insensitive,
/// surrounding whitespace ignored), bypasses the in-process model cache for both reads and writes.
pub const DISABLE_CACHE_ENV: &str = "PI_DISABLE_MODEL_CACHE";
/// One cached provider listing: the model ids plus the instant they were stored, which the
/// lookup path compares against `MODEL_CACHE_TTL`.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Model ids exactly as handed to `cache_store`.
    models: Vec<String>,
    /// When the entry was written; its age decides whether the entry is still fresh.
    inserted: Instant,
}
/// Process-wide model cache, keyed by the normalized provider key, created lazily on first use.
fn cache() -> &'static Mutex<HashMap<String, CacheEntry>> {
    type Entries = HashMap<String, CacheEntry>;
    static CACHE: OnceLock<Mutex<Entries>> = OnceLock::new();
    CACHE.get_or_init(Default::default)
}
/// Returns `true` when `DISABLE_CACHE_ENV` holds a truthy value (`1`, `true`, `yes`, `on`;
/// case-insensitive, surrounding whitespace ignored). An unset or non-UTF-8 variable means the
/// cache stays enabled.
fn cache_disabled() -> bool {
    let Ok(raw) = std::env::var(DISABLE_CACHE_ENV) else {
        return false;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    ["1", "true", "yes", "on"].contains(&normalized.as_str())
}
/// Normalizes `provider` into the key used by the model cache.
///
/// Surrounding whitespace is removed *before* alias resolution, so a padded alias is resolved
/// exactly like its unpadded form. Providers without a canonical id fall back to the trimmed
/// input. The result is always ASCII-lowercased.
fn cache_key(provider: &str) -> String {
    let trimmed = provider.trim();
    canonical_provider_id(trimmed)
        .unwrap_or(trimmed)
        .to_ascii_lowercase()
}
/// Returns a copy of the cached model list for `key` if it exists and is younger than
/// `MODEL_CACHE_TTL`.
///
/// A poisoned lock, a missing key and an expired entry all yield `None`. Expired entries are
/// left in the map until they are overwritten or cleared.
fn cache_lookup(key: &str) -> Option<Vec<String>> {
    let entries = cache().lock().ok()?;
    entries
        .get(key)
        .filter(|entry| entry.inserted.elapsed() < MODEL_CACHE_TTL)
        .map(|entry| entry.models.clone())
}
/// Records `models` under `key`, timestamped now, replacing any previous entry for that key.
///
/// If the cache lock is poisoned, the write is skipped without reporting an error.
fn cache_store(key: String, models: Vec<String>) {
    let Ok(mut entries) = cache().lock() else {
        return;
    };
    let entry = CacheEntry {
        inserted: Instant::now(),
        models,
    };
    entries.insert(key, entry);
}
/// Drops every cached provider listing, so the next `fetch_provider_models` call for any
/// provider misses the cache and fetches live.
///
/// Does nothing if the cache lock is poisoned.
pub fn clear_model_cache() {
    let Ok(mut entries) = cache().lock() else {
        return;
    };
    entries.clear();
}
/// Returns the model ids available for `provider`, preferring a fresh cached listing.
///
/// When caching is enabled and a listing younger than `MODEL_CACHE_TTL` exists, it is returned
/// without network I/O. Otherwise a live fetch is performed. A non-empty result is cached. A
/// failed or empty fetch falls back to the static registry, and that fallback is not cached. In
/// the current implementation this always returns `Ok`.
pub async fn fetch_provider_models(provider: &str, api_key: &str) -> Result<Vec<String>> {
    let key = cache_key(provider);
    let cached = if cache_disabled() {
        None
    } else {
        cache_lookup(&key)
    };
    if let Some(models) = cached {
        tracing::debug!(provider = %key, count = models.len(), "model cache hit");
        return Ok(models);
    }
    fetch_and_cache(provider, &key, api_key).await
}
/// Forces a live fetch for `provider`, ignoring any cached listing.
///
/// A successful non-empty result still replaces the cached entry (unless caching is disabled
/// via `DISABLE_CACHE_ENV`). Later `fetch_provider_models` calls therefore see the refreshed list.
pub async fn refresh_provider_models(provider: &str, api_key: &str) -> Result<Vec<String>> {
    fetch_and_cache(provider, &cache_key(provider), api_key).await
}
/// Performs a live fetch for `provider` and, when caching is enabled, stores a non-empty
/// result under `key`.
///
/// A failed or empty live fetch is logged at `warn` level and answered with
/// `static_registry_models(provider)`. That fallback is deliberately not cached, so a later
/// lookup retries the live endpoint. Always returns `Ok`.
async fn fetch_and_cache(provider: &str, key: &str, api_key: &str) -> Result<Vec<String>> {
    let live = match fetch_live_models(provider, api_key).await {
        Ok(models) if models.is_empty() => {
            tracing::warn!(
                provider = %key,
                "live model fetch returned empty list; falling back to static registry (not cached)"
            );
            return Ok(static_registry_models(provider));
        }
        Ok(models) => models,
        Err(err) => {
            tracing::warn!(
                provider = %key,
                error = %err,
                "live model fetch failed; falling back to static registry (not cached)"
            );
            return Ok(static_registry_models(provider));
        }
    };
    if !cache_disabled() {
        cache_store(key.to_owned(), live.clone());
    }
    Ok(live)
}
/// Lists model ids for `provider` from the model registry loaded via
/// `ModelRegistry::load_for_listing`.
///
/// This is the fallback used when a live catalog fetch fails or comes back empty. `provider` is
/// trimmed first, matching how the cache key is normalized. Each registry entry matches when its
/// provider equals (ASCII case-insensitively) any of:
/// - the requested provider,
/// - the requested provider's canonical id, or
/// - its own canonical id compared with that same canonical id.
///
/// Returns a sorted, de-duplicated list. Returns an empty list when the auth storage cannot be
/// loaded.
pub fn static_registry_models(provider: &str) -> Vec<String> {
    // Without trimming, padded input (e.g. " openai") would match no registry entry even though
    // the cache key for the same input is normalized.
    let provider = provider.trim();
    // NOTE(review): the registry loader takes auth storage, so an unreadable auth file yields no
    // static models at all — confirm that is intended.
    let Ok(auth) = AuthStorage::load(Config::auth_path()) else {
        return Vec::new();
    };
    let models_path = Some(default_models_path(&Config::global_dir()));
    let registry = ModelRegistry::load_for_listing(&auth, models_path);
    let canonical = canonical_provider_id(provider).unwrap_or(provider);
    let mut ids: Vec<String> = registry
        .models()
        .iter()
        .filter(|entry| {
            let entry_provider = entry.model.provider.as_str();
            entry_provider.eq_ignore_ascii_case(provider)
                || entry_provider.eq_ignore_ascii_case(canonical)
                || canonical_provider_id(entry_provider)
                    .is_some_and(|c| c.eq_ignore_ascii_case(canonical))
        })
        .map(|entry| entry.model.id.clone())
        .collect();
    // Equal strings are indistinguishable, so an unstable sort produces the same output.
    ids.sort_unstable();
    ids.dedup();
    ids
}
/// Minimal shape of an OpenAI-compatible `GET /models` response body.
///
/// Only the `data` array is declared. The field name must match the JSON key, because no serde
/// rename is applied.
#[derive(Debug, Deserialize)]
struct OpenAiModelsResponse {
    data: Vec<OpenAiModelRow>,
}
/// One element of the `data` array in a `/models` response.
///
/// Only the model `id` is read. The field name must match the JSON key, because no serde rename
/// is applied.
#[derive(Debug, Deserialize)]
struct OpenAiModelRow {
    id: String,
}
/// Fetches the live model catalog for `provider` from its OpenAI-compatible `/models` endpoint.
///
/// `provider` is trimmed before its routing defaults are looked up, matching how the cache key
/// normalizes the same input. The request sends `Authorization: Bearer <api_key>` (with the key
/// trimmed) and `Accept: application/json`, and uses a 15-second timeout. On success the
/// returned ids are sorted and de-duplicated, and blank ids are dropped.
///
/// # Errors
///
/// Returns an API error when:
/// - `api_key` is blank;
/// - `provider` has no routing defaults;
/// - its base URL is not treated as OpenAI-compatible;
/// - the server answers with a non-2xx status (the message includes up to 200 characters of the
///   response body);
/// - the body does not parse as a `/models` response.
///
/// Failures from sending the request or reading a successful body are propagated with `?`.
async fn fetch_live_models(provider: &str, api_key: &str) -> Result<Vec<String>> {
    // Normalize once so the routing lookup, error messages and auth header all agree.
    let provider = provider.trim();
    let api_key = api_key.trim();
    if api_key.is_empty() {
        return Err(Error::api(
            "no api_key supplied; skipping live provider model fetch",
        ));
    }
    let defaults = provider_routing_defaults(provider).ok_or_else(|| {
        Error::api(format!(
            "provider {provider:?} has no routing defaults; cannot fetch /v1/models"
        ))
    })?;
    let url = openai_compat_models_url(&defaults).ok_or_else(|| {
        Error::api(format!(
            "provider {provider:?} base_url ({}) is not OpenAI-compatible /v1; \
             add a custom branch in fetch_live_models to support its catalog endpoint",
            defaults.base_url
        ))
    })?;
    let client = Client::new();
    let request = client
        .get(&url)
        .header("Authorization", format!("Bearer {api_key}"))
        .header("Accept", "application/json")
        .timeout(Duration::from_secs(15));
    let response = request.send().await?;
    let status = response.status();
    if !(200..300).contains(&status) {
        // Best effort: an unreadable error body just yields an empty snippet.
        let body = response.text().await.unwrap_or_default();
        let snippet: String = body.chars().take(200).collect();
        return Err(Error::api(format!(
            "provider {provider:?} returned HTTP {status} from {url}: {snippet}"
        )));
    }
    let body = response.text().await?;
    let parsed: OpenAiModelsResponse = serde_json::from_str(&body).map_err(|err| {
        Error::api(format!(
            "failed to parse /v1/models response for {provider:?}: {err}"
        ))
    })?;
    let mut ids: Vec<String> = parsed
        .data
        .into_iter()
        .map(|row| row.id)
        .filter(|id| !id.trim().is_empty())
        .collect();
    ids.sort();
    ids.dedup();
    Ok(ids)
}
/// Builds the `{base_url}/models` catalog URL for providers whose base URL is treated as
/// OpenAI-compatible.
///
/// Trailing slashes on `base_url` are ignored. Returns `None` when the trimmed base URL is empty,
/// ends with `/messages`, or contains either `/v1beta` or `googleapis.com`.
fn openai_compat_models_url(defaults: &ProviderRoutingDefaults) -> Option<String> {
    let base = defaults.base_url.trim_end_matches('/');
    let unsupported = base.is_empty()
        || base.ends_with("/messages")
        || base.contains("/v1beta")
        || base.contains("googleapis.com");
    (!unsupported).then(|| format!("{base}/models"))
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn cache_key_canonicalizes_aliases() {
        assert_eq!(cache_key("OpenAI"), "openai");
        assert_eq!(cache_key("openai"), "openai");
        // Surrounding whitespace must not leak into the cache key.
        assert_eq!(cache_key("  OpenAI  "), "openai");
    }
    #[test]
    fn openai_compat_url_for_openai() {
        let defaults = provider_routing_defaults("openai").expect("openai defaults");
        let url = openai_compat_models_url(&defaults).expect("openai is openai-compatible");
        assert_eq!(url, "https://api.openai.com/v1/models");
    }
    #[test]
    fn openai_compat_url_for_groq() {
        let defaults = provider_routing_defaults("groq").expect("groq defaults");
        let url = openai_compat_models_url(&defaults).expect("groq is openai-compatible");
        assert_eq!(url, "https://api.groq.com/openai/v1/models");
    }
    #[test]
    fn openai_compat_url_for_openrouter() {
        let defaults = provider_routing_defaults("openrouter").expect("openrouter defaults");
        let url = openai_compat_models_url(&defaults).expect("openrouter is openai-compatible");
        assert_eq!(url, "https://openrouter.ai/api/v1/models");
    }
    #[test]
    fn openai_compat_url_rejects_anthropic_messages_endpoint() {
        let defaults = provider_routing_defaults("anthropic").expect("anthropic defaults");
        assert!(openai_compat_models_url(&defaults).is_none());
    }
    #[test]
    fn empty_api_key_short_circuits() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("runtime");
        let err = rt.block_on(fetch_live_models("openai", "   ")).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("api_key"), "unexpected error: {msg}");
    }
    #[test]
    fn cache_round_trip() {
        clear_model_cache();
        let key = cache_key("openai");
        assert!(cache_lookup(&key).is_none(), "starts empty");
        cache_store(key.clone(), vec!["m-1".to_string(), "m-2".to_string()]);
        let hit = cache_lookup(&key).expect("fresh entry");
        assert_eq!(hit, vec!["m-1".to_string(), "m-2".to_string()]);
        clear_model_cache();
        assert!(cache_lookup(&key).is_none(), "cleared");
    }
    #[test]
    fn cache_lookup_ignores_expired_entries() {
        // A key no other test touches. A concurrent `clear_model_cache` from another test can
        // only make this lookup miss, which is also the expected outcome here.
        let key = "test-expired-entry-provider".to_string();
        // `Instant` cannot go below the platform's monotonic origin; skip if it would.
        let Some(stale) = Instant::now().checked_sub(MODEL_CACHE_TTL + Duration::from_secs(1))
        else {
            return;
        };
        cache().lock().expect("model cache lock").insert(
            key.clone(),
            CacheEntry {
                models: vec!["old".to_string()],
                inserted: stale,
            },
        );
        assert!(
            cache_lookup(&key).is_none(),
            "entries older than MODEL_CACHE_TTL must miss"
        );
    }
}