use std::collections::BTreeMap;
use std::future::Future;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use crate::error::SdkResult;
use crate::ports::catalog::{
CatalogEvent, CatalogModelEntry, CatalogProtocol, CatalogProviderEntry, CatalogSource,
ModelCatalog, RefreshOutcome,
};
/// Default freshness window for the on-disk cache (one hour); the default for
/// `CatalogConfig::mtime_window`.
const DEFAULT_MTIME_WINDOW: Duration = Duration::from_secs(60 * 60);
/// Per-request timeout applied to every `reqwest` client built in this file.
const FETCH_TIMEOUT: Duration = Duration::from_secs(10);
/// Total number of attempts made by `fetch_conditional` (the loop runs
/// `0..FETCH_RETRIES`, so this counts attempts, not extra retries).
const FETCH_RETRIES: u32 = 2;
/// Pause between two consecutive fetch attempts.
const RETRY_BACKOFF: Duration = Duration::from_millis(200);
/// Default base URL of the remote catalog; `/api.json` is appended when fetching.
const DEFAULT_URL: &str = "https://models.dev";
/// Default `User-Agent` header value, derived from the crate version at compile time.
const USER_AGENT: &str = concat!("oxi-sdk/", env!("CARGO_PKG_VERSION"));
/// Capacity of the broadcast channel that carries `CatalogEvent`s.
const BROADCAST_CAPACITY: usize = 16;
/// Settings for [`FileModelCatalog`]: where cached data lives and how the
/// remote catalog is fetched.
#[derive(Debug, Clone)]
pub struct CatalogConfig {
/// JSON cache of the last fetched catalog; read at init and rewritten by `refresh`.
pub cache_path: PathBuf,
/// File whose trimmed contents are sent as `If-None-Match` during `refresh`.
pub etag_path: PathBuf,
/// TOML file with user overrides (parsed as `OverrideFile`).
pub override_path: PathBuf,
/// Maximum age (by file mtime) for the cache to count as fresh.
pub mtime_window: Duration,
/// When `false`, `refresh` returns `Offline` without touching the network.
pub fetch_enabled: bool,
/// Base URL of the remote catalog; `/api.json` is appended to it.
pub models_dev_url: String,
/// Value of the `User-Agent` header sent with catalog fetches.
pub user_agent: String,
/// Base URLs probed at init via `GET {url}/v1/models` for locally served models.
pub local_discovery_urls: Vec<String>,
/// NOTE(review): not read by any code visible in this file; the default
/// points at the same file as `cache_path` — confirm intended use.
pub snapshot_path: PathBuf,
}
impl Default for CatalogConfig {
    /// Builds the default configuration rooted at the directory returned by
    /// `home_dir()`, falling back to a relative `.oxi` directory when that
    /// lookup fails. Fetching is enabled and no local discovery URLs are set.
    fn default() -> Self {
        let root = crate::ports::fs::path::home_dir().unwrap_or_else(|_| PathBuf::from(".oxi"));
        let cache_dir = root.join("cache");
        let cache_file = cache_dir.join("models-dev.json");

        Self {
            etag_path: cache_dir.join("models-dev.json.etag"),
            override_path: root.join("catalog").join("overrides.toml"),
            // The snapshot path defaults to the very same file as the cache.
            snapshot_path: cache_file.clone(),
            cache_path: cache_file,
            mtime_window: DEFAULT_MTIME_WINDOW,
            fetch_enabled: true,
            models_dev_url: String::from(DEFAULT_URL),
            user_agent: String::from(USER_AGENT),
            local_discovery_urls: Vec::new(),
        }
    }
}
/// Top-level shape of the remote/cached catalog JSON: a map from provider id
/// to that provider's record.
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct MdCatalog(pub BTreeMap<String, MdProvider>);
/// One provider record of the catalog JSON.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MdProvider {
/// Human-readable name; becomes the provider's `display_name`.
pub name: String,
/// Environment variable names; the first becomes `env_key`, the rest `extra_env_keys`.
pub env: Vec<String>,
/// Package identifier fed to `protocol_for` to pick the wire protocol.
#[serde(default)]
pub npm: Option<String>,
/// Becomes the provider's `base_url`.
#[serde(default)]
pub api: Option<String>,
/// Deserialized but not consumed by the code in this file.
#[serde(default)]
pub doc: Option<String>,
/// Models offered by this provider, keyed by model id.
pub models: BTreeMap<String, MdModel>,
}
/// One model record of the catalog JSON.
///
/// `name`, `reasoning` and `limit` are required on deserialization; every
/// other field falls back to its default when absent. `materialize` consumes
/// `name`, `reasoning`, `attachment`, `limit`, `cost`, `modalities`,
/// `release_date`, `status` and `provider`; the remaining fields are only
/// carried through (de)serialization.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MdModel {
pub name: String,
#[serde(default)]
pub family: Option<String>,
pub reasoning: bool,
#[serde(default)]
pub tool_call: bool,
/// Mapped to `supports_vision` on the catalog entry.
#[serde(default)]
pub attachment: bool,
#[serde(default)]
pub temperature: Option<bool>,
#[serde(default)]
pub structured_output: Option<bool>,
#[serde(default)]
pub knowledge: Option<String>,
#[serde(default)]
pub release_date: Option<String>,
#[serde(default)]
pub last_updated: Option<String>,
#[serde(default)]
pub open_weights: Option<bool>,
// Kept as raw JSON: no structure is assumed for this field here.
#[serde(default)]
pub interleaved: Option<serde_json::Value>,
#[serde(default)]
pub reasoning_options: Option<Vec<serde_json::Value>>,
pub limit: MdLimit,
#[serde(default)]
pub cost: Option<MdCost>,
#[serde(default)]
pub modalities: Option<MdModalities>,
#[serde(default)]
pub status: Option<String>,
/// Optional per-model override of the provider-level `npm`/`api` values.
#[serde(default)]
pub provider: Option<MdModelProvider>,
}
/// Per-model provider details that take precedence over the enclosing
/// provider's values when `materialize` builds a model entry.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MdModelProvider {
/// Package identifier used for protocol selection instead of the provider-level one.
#[serde(default)]
pub npm: Option<String>,
/// Model-specific base URL; an empty string is treated as absent by `materialize`.
#[serde(default)]
pub api: Option<String>,
}
/// Size limits of a model, stored as floating-point numbers exactly as they
/// are deserialized.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MdLimit {
/// Cast with `as u32` into the entry's `context_window`.
pub context: f64,
/// Deserialized but not consumed by the code in this file.
#[serde(default)]
pub input: Option<f64>,
/// Cast with `as u32` into the entry's `max_tokens`.
pub output: f64,
}
/// Pricing block of a model. Only `input`, `output`, `cache_read` and
/// `cache_write` are consumed by `materialize` (missing values become `0.0`).
// NOTE(review): the pricing unit is not visible in this file — confirm against the data source.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MdCost {
pub input: f64,
pub output: f64,
#[serde(default)]
pub cache_read: Option<f64>,
#[serde(default)]
pub cache_write: Option<f64>,
// Kept as raw JSON: no structure is assumed for the two fields below.
#[serde(default)]
pub tiers: Option<Vec<serde_json::Value>>,
#[serde(default)]
pub context_over_200k: Option<serde_json::Value>,
#[serde(default)]
pub reasoning: Option<f64>,
#[serde(default)]
pub input_audio: Option<f64>,
#[serde(default)]
pub output_audio: Option<f64>,
}
/// Input/output modality lists of a model.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MdModalities {
/// Used by `normalize_modalities`; a missing or empty list becomes `["text"]`.
#[serde(default)]
pub input: Option<Vec<String>>,
/// Deserialized but not consumed by the code in this file.
#[serde(default)]
pub output: Option<Vec<String>>,
}
/// Contents of the user override file (parsed from TOML): provider overrides
/// under the `provider` key and model overrides under the `model` key. Both
/// lists default to empty when absent.
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct OverrideFile {
#[serde(default)]
pub provider: Vec<OverrideProvider>,
#[serde(default)]
pub model: Vec<OverrideModel>,
}
/// A user override for one provider, matched against catalog providers by `id`.
/// Optional fields only overwrite the catalog value when they are present.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OverrideProvider {
pub id: String,
#[serde(default)]
pub display_name: Option<String>,
#[serde(default)]
pub base_url: Option<String>,
#[serde(default)]
pub env_key: Option<String>,
/// Header pairs; unlike the optional fields, this list always replaces the
/// provider's existing headers, even when it is empty.
#[serde(default)]
pub extra_headers: Vec<(String, String)>,
/// Overrides the provider's `default_enabled` flag.
#[serde(default)]
pub enabled: Option<bool>,
}
/// A user override for one model, matched by (`provider`, `id`). When no such
/// model exists, a new entry is created from these values with zero/`"text"`
/// defaults for everything not specified.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OverrideModel {
/// Provider id the model belongs to.
pub provider: String,
/// Model id within that provider.
pub id: String,
#[serde(default)]
pub name: Option<String>,
#[serde(default)]
pub cost_input: Option<f64>,
#[serde(default)]
pub cost_output: Option<f64>,
#[serde(default)]
pub context_window: Option<u32>,
#[serde(default)]
pub max_tokens: Option<u32>,
}
/// Decompresses and parses the gzip-compressed catalog snapshot that is
/// embedded into the binary at compile time.
///
/// Returns `None` when the payload cannot be gunzipped into UTF-8 text or the
/// text is not a valid `MdCatalog` JSON document.
fn load_snapshot() -> Option<MdCatalog> {
    const GZIPPED: &[u8] = include_bytes!("../../../../oxi-ai/data/catalog/_snapshot.json.gz");

    let mut text = String::new();
    flate2::read::GzDecoder::new(GZIPPED)
        .read_to_string(&mut text)
        .ok()?;
    serde_json::from_str(&text).ok()
}
/// Maps an `@ai-sdk/...` package identifier to the wire protocol used to talk
/// to that provider.
///
/// Anything outside the `@ai-sdk/` namespace, and any unrecognized package
/// inside it, falls back to `CatalogProtocol::OpenAiCompatible`.
pub(crate) fn protocol_for(npm: &str) -> CatalogProtocol {
    let Some(package) = npm.strip_prefix("@ai-sdk/") else {
        return CatalogProtocol::OpenAiCompatible;
    };
    match package {
        "anthropic" => CatalogProtocol::AnthropicMessages,
        "google" => CatalogProtocol::GoogleGenerativeAi,
        "google-vertex" | "google-vertex/anthropic" => CatalogProtocol::GoogleVertex,
        "mistral" => CatalogProtocol::MistralConversations,
        "azure" => CatalogProtocol::AzureOpenAiResponses,
        "amazon-bedrock" => CatalogProtocol::BedrockConverseStream,
        "openai" | "openai-compatible" => CatalogProtocol::OpenAiCompletions,
        _ => CatalogProtocol::OpenAiCompatible,
    }
}
/// Converts a raw catalog into provider entries plus per-provider model lists,
/// then layers `user_overrides` on top.
///
/// Every model entry produced from `catalog` is tagged
/// `CatalogSource::Embedded`. Providers that have no models get a provider
/// entry but no key in the returned model map.
pub(crate) fn materialize(
    catalog: &MdCatalog,
    user_overrides: &OverrideFile,
) -> (
    Vec<CatalogProviderEntry>,
    BTreeMap<String, Vec<CatalogModelEntry>>,
) {
    let mut provider_entries = Vec::new();
    let mut models_by_provider: BTreeMap<String, Vec<CatalogModelEntry>> = BTreeMap::new();

    for (provider_id, md_provider) in &catalog.0 {
        let provider_npm = md_provider.npm.as_deref().unwrap_or("");
        // First env var is the primary key, the remainder are alternates.
        let (env_key, extra_env_keys) = match md_provider.env.split_first() {
            Some((head, tail)) => (Some(head.clone()), tail.to_vec()),
            None => (None, Vec::new()),
        };
        provider_entries.push(CatalogProviderEntry {
            id: provider_id.clone(),
            display_name: md_provider.name.clone(),
            aliases: Vec::new(),
            protocol: protocol_for(provider_npm),
            env_key,
            extra_env_keys,
            base_url: md_provider.api.clone(),
            extra_headers: Vec::new(),
            category: String::new(),
            description: String::new(),
            default_enabled: true,
        });

        // Only create a map slot for providers that actually carry models.
        if md_provider.models.is_empty() {
            continue;
        }
        let bucket = models_by_provider.entry(provider_id.clone()).or_default();
        for (model_id, md_model) in &md_provider.models {
            let per_model = md_model.provider.as_ref();
            // A model-level package wins over the provider-level one.
            let npm = per_model
                .and_then(|p| p.npm.as_deref())
                .unwrap_or(provider_npm);
            let base_url = per_model
                .and_then(|p| p.api.as_ref())
                .filter(|api| !api.is_empty())
                .cloned();
            let cost = md_model.cost.as_ref();

            bucket.push(CatalogModelEntry {
                provider: provider_id.clone(),
                model_id: model_id.clone(),
                name: md_model.name.clone(),
                protocol: protocol_for(npm),
                source: CatalogSource::Embedded,
                base_url,
                reasoning: md_model.reasoning,
                supports_vision: md_model.attachment,
                cost_input: cost.map(|c| c.input).unwrap_or(0.0),
                cost_output: cost.map(|c| c.output).unwrap_or(0.0),
                cost_cache_read: cost.and_then(|c| c.cache_read).unwrap_or(0.0),
                cost_cache_write: cost.and_then(|c| c.cache_write).unwrap_or(0.0),
                context_window: md_model.limit.context as u32,
                max_tokens: md_model.limit.output as u32,
                input_modalities: normalize_modalities(&md_model.modalities),
                release_date: md_model.release_date.clone(),
                status: md_model.status.clone(),
            });
        }
    }

    apply_user_overrides(&mut provider_entries, &mut models_by_provider, user_overrides);
    (provider_entries, models_by_provider)
}
/// Returns the model's declared input modalities, or `["text"]` when the
/// modalities block is missing, has no `input` list, or the list is empty.
fn normalize_modalities(md: &Option<MdModalities>) -> Vec<String> {
    md.as_ref()
        .and_then(|modalities| modalities.input.as_ref())
        .filter(|declared| !declared.is_empty())
        .cloned()
        .unwrap_or_else(|| vec!["text".to_string()])
}
/// Applies user overrides in place.
///
/// * Provider overrides patch the provider with the same id (optional fields
///   only when present; `extra_headers` is always replaced) or append a new
///   `OpenAiCompatible` provider when the id is unknown.
/// * Model overrides patch the matching model and mark it
///   `CatalogSource::Override`, or append a new override-sourced model with
///   text-only, zero-cost defaults when no such model exists.
fn apply_user_overrides(
    providers: &mut Vec<CatalogProviderEntry>,
    models: &mut BTreeMap<String, Vec<CatalogModelEntry>>,
    overrides: &OverrideFile,
) {
    for patch in &overrides.provider {
        match providers.iter_mut().find(|p| p.id == patch.id) {
            Some(existing) => {
                if let Some(display_name) = &patch.display_name {
                    existing.display_name = display_name.clone();
                }
                if let Some(base_url) = &patch.base_url {
                    existing.base_url = Some(base_url.clone());
                }
                if let Some(env_key) = &patch.env_key {
                    existing.env_key = Some(env_key.clone());
                }
                existing.extra_headers = patch.extra_headers.clone();
                if let Some(enabled) = patch.enabled {
                    existing.default_enabled = enabled;
                }
            }
            None => providers.push(CatalogProviderEntry {
                id: patch.id.clone(),
                display_name: patch
                    .display_name
                    .clone()
                    .unwrap_or_else(|| patch.id.clone()),
                aliases: Vec::new(),
                protocol: CatalogProtocol::OpenAiCompatible,
                env_key: patch.env_key.clone(),
                extra_env_keys: Vec::new(),
                base_url: patch.base_url.clone(),
                extra_headers: patch.extra_headers.clone(),
                category: String::new(),
                description: String::new(),
                default_enabled: patch.enabled.unwrap_or(true),
            }),
        }
    }

    for patch in &overrides.model {
        // The provider bucket is created even when it did not exist before.
        let bucket = models.entry(patch.provider.clone()).or_default();
        match bucket.iter_mut().find(|m| m.model_id == patch.id) {
            Some(existing) => {
                if let Some(name) = &patch.name {
                    existing.name = name.clone();
                }
                if let Some(cost) = patch.cost_input {
                    existing.cost_input = cost;
                }
                if let Some(cost) = patch.cost_output {
                    existing.cost_output = cost;
                }
                if let Some(window) = patch.context_window {
                    existing.context_window = window;
                }
                if let Some(limit) = patch.max_tokens {
                    existing.max_tokens = limit;
                }
                existing.source = CatalogSource::Override;
            }
            None => bucket.push(CatalogModelEntry {
                provider: patch.provider.clone(),
                model_id: patch.id.clone(),
                name: patch.name.clone().unwrap_or_else(|| patch.id.clone()),
                protocol: CatalogProtocol::OpenAiCompatible,
                source: CatalogSource::Override,
                base_url: None,
                reasoning: false,
                supports_vision: false,
                cost_input: patch.cost_input.unwrap_or(0.0),
                cost_output: patch.cost_output.unwrap_or(0.0),
                cost_cache_read: 0.0,
                cost_cache_write: 0.0,
                context_window: patch.context_window.unwrap_or(0),
                max_tokens: patch.max_tokens.unwrap_or(0),
                input_modalities: vec!["text".to_string()],
                release_date: None,
                status: None,
            }),
        }
    }
}
/// In-memory catalog state guarded by the catalog's lock.
struct Snapshot {
/// Known providers, in insertion order.
providers: Vec<CatalogProviderEntry>,
/// Models indexed by provider id, then by model id.
models: BTreeMap<String, BTreeMap<String, CatalogModelEntry>>,
}
impl Snapshot {
fn empty() -> Self {
Self {
providers: Vec::new(),
models: BTreeMap::new(),
}
}
fn stats(&self) -> (usize, usize) {
let model_count = self.models.values().map(|m| m.len()).sum();
(self.providers.len(), model_count)
}
}
/// `ModelCatalog` implementation that combines the embedded snapshot, an
/// on-disk cache, user overrides, local discovery and remote refreshes.
pub struct FileModelCatalog {
/// Current catalog contents; shared so `refresh` futures can update it.
state: Arc<RwLock<Snapshot>>,
/// Sender side of the `CatalogEvent` broadcast channel.
tx: broadcast::Sender<CatalogEvent>,
/// Configuration captured at `init` time.
config: CatalogConfig,
}
impl std::fmt::Debug for FileModelCatalog {
    /// Prints provider/model counts and the fetch flag instead of dumping the
    /// whole catalog.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take the counts and release the read lock before formatting.
        let (provider_count, model_count) = self.state.read().stats();
        let mut out = f.debug_struct("FileModelCatalog");
        out.field("providers", &provider_count);
        out.field("models", &model_count);
        out.field("fetch_enabled", &self.config.fetch_enabled);
        out.finish_non_exhaustive()
    }
}
impl FileModelCatalog {
pub async fn init(config: CatalogConfig) -> std::io::Result<Arc<Self>> {
let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
let cat = Arc::new(Self {
state: Arc::new(RwLock::new(Snapshot::empty())),
tx,
config,
});
cat.load_snapshot_internal();
if cat.try_load_fresh_cache().await.is_none() {
tracing::debug!("catalog: cache stale or missing");
}
cat.apply_user_overrides_internal();
cat.discover_local_all().await;
if cat.config.fetch_enabled && !cat.is_cache_fresh_internal() {
let _ = cat.refresh().await;
}
Ok(cat)
}
#[allow(dead_code)]
pub(crate) fn tx(&self) -> &broadcast::Sender<CatalogEvent> {
&self.tx
}
fn load_snapshot_internal(&self) {
let Some(md) = load_snapshot() else {
tracing::warn!("catalog: embedded SNAP missing or corrupt");
return;
};
let overrides = OverrideFile::default();
let (providers, models) = materialize(&md, &overrides);
let mut snap = self.state.write();
snap.providers = providers;
snap.models = models
.into_iter()
.map(|(pid, list)| {
let map = list.into_iter().map(|e| (e.model_id.clone(), e)).collect();
(pid, map)
})
.collect();
}
async fn try_load_fresh_cache(&self) -> Option<()> {
let path = self.config.cache_path.clone();
let window = self.config.mtime_window;
let res = tokio::task::spawn_blocking(move || read_cache_if_fresh(&path, window))
.await
.ok()
.flatten();
match res {
Some(catalog) => {
let overrides = OverrideFile::default();
let (providers, models) = materialize(&catalog, &overrides);
let mut snap = self.state.write();
snap.providers = providers;
snap.models = models
.into_iter()
.map(|(pid, list)| {
(
pid,
list.into_iter().map(|e| (e.model_id.clone(), e)).collect(),
)
})
.collect();
Some(())
}
None => None,
}
}
fn is_cache_fresh_internal(&self) -> bool {
let meta = match std::fs::metadata(&self.config.cache_path) {
Ok(m) => m,
Err(_) => return false,
};
let modified = match meta.modified() {
Ok(t) => t,
Err(_) => return false,
};
let age = match SystemTime::now().duration_since(modified) {
Ok(d) => d,
Err(_) => return false,
};
age <= self.config.mtime_window
}
fn apply_user_overrides_internal(&self) {
let Ok(body) = std::fs::read_to_string(&self.config.override_path) else {
return;
};
let Ok(overrides) = toml::from_str::<OverrideFile>(&body) else {
tracing::warn!("catalog: invalid override TOML, ignoring");
return;
};
let mut snap = self.state.write();
let mut providers = snap.providers.clone();
let mut models_map = snap.models.clone();
for ovr in &overrides.provider {
if let Some(slot) = providers.iter_mut().find(|p| p.id == ovr.id) {
if let Some(d) = &ovr.display_name {
slot.display_name = d.clone();
}
if let Some(b) = &ovr.base_url {
slot.base_url = Some(b.clone());
}
if let Some(k) = &ovr.env_key {
slot.env_key = Some(k.clone());
}
slot.extra_headers = ovr.extra_headers.clone();
if let Some(en) = ovr.enabled {
slot.default_enabled = en;
}
}
}
for ovr in &overrides.model {
let entry = CatalogModelEntry {
provider: ovr.provider.clone(),
model_id: ovr.id.clone(),
name: ovr.name.clone().unwrap_or_else(|| ovr.id.clone()),
protocol: CatalogProtocol::OpenAiCompatible,
source: CatalogSource::Override,
base_url: None,
reasoning: false,
supports_vision: false,
cost_input: ovr.cost_input.unwrap_or(0.0),
cost_output: ovr.cost_output.unwrap_or(0.0),
cost_cache_read: 0.0,
cost_cache_write: 0.0,
context_window: ovr.context_window.unwrap_or(0),
max_tokens: ovr.max_tokens.unwrap_or(0),
input_modalities: vec!["text".to_string()],
release_date: None,
status: None,
};
let inner = models_map.entry(ovr.provider.clone()).or_default();
if let Some((_, slot)) = inner.iter_mut().find(|(_, m)| m.model_id == ovr.id) {
if let Some(n) = ovr.name.clone() {
slot.name = n;
}
if let Some(c) = ovr.cost_input {
slot.cost_input = c;
}
if let Some(c) = ovr.cost_output {
slot.cost_output = c;
}
if let Some(c) = ovr.context_window {
slot.context_window = c;
}
if let Some(m) = ovr.max_tokens {
slot.max_tokens = m;
}
slot.source = CatalogSource::Override;
} else {
inner.insert(ovr.id.clone(), entry);
}
}
snap.providers = providers;
snap.models = models_map;
let _ = self.tx.send(CatalogEvent::OverrideApplied {
path: self.config.override_path.clone(),
provider_overrides: overrides.provider.len(),
model_overrides: overrides.model.len(),
});
}
async fn discover_local_all(&self) {
if self.config.local_discovery_urls.is_empty() {
return;
}
let urls = self.config.local_discovery_urls.clone();
for base in urls {
match fetch_local_models(&base).await {
Ok(entries) if !entries.is_empty() => {
let count = entries.len();
let mut snap = self.state.write();
for entry in entries {
let inner = snap.models.entry(entry.provider.clone()).or_default();
inner.insert(entry.model_id.clone(), entry);
}
let _ = self.tx.send(CatalogEvent::LocalDiscovered {
base_url: base,
model_count: count,
});
}
Ok(_) => {}
Err(e) => {
tracing::debug!(error = %e, base = %base, "local discovery failed");
}
}
}
}
}
/// Reads and parses the cache file at `path`, provided its mtime is no older
/// than `window`.
///
/// Returns `None` when the file is missing, unreadable, stale, has an mtime in
/// the future, or does not parse. A file that fails to parse is deleted so it
/// is not retried on the next start.
fn read_cache_if_fresh(path: &Path, window: Duration) -> Option<MdCatalog> {
    let modified = std::fs::metadata(path).ok()?.modified().ok()?;
    let age = SystemTime::now().duration_since(modified).ok()?;
    if age > window {
        return None;
    }

    let body = std::fs::read_to_string(path).ok()?;
    let parsed = serde_json::from_str::<MdCatalog>(&body);
    if let Err(e) = &parsed {
        tracing::warn!(error = %e, "cache corrupt, ignoring");
        let _ = std::fs::remove_file(path);
    }
    parsed.ok()
}
/// Outcome of a successful conditional fetch of the remote catalog.
enum FetchResult {
/// The server returned a success status with a body that parsed as a catalog.
Updated(MdCatalog),
/// The server answered `304 Not Modified`.
NotModified,
}
async fn fetch_conditional(url: &str, etag: Option<&str>, user_agent: &str) -> Option<FetchResult> {
let client = reqwest::Client::builder()
.timeout(FETCH_TIMEOUT)
.build()
.ok()?;
let full = format!("{}/api.json", url.trim_end_matches('/'));
for attempt in 0..FETCH_RETRIES {
let mut req = client.get(&full).header("User-Agent", user_agent);
if let Some(e) = etag {
req = req.header("If-None-Match", e);
}
match req.send().await {
Ok(resp) => {
let status = resp.status();
if status.as_u16() == 304 {
return Some(FetchResult::NotModified);
}
if status.is_success() {
let body = resp.text().await.ok()?;
return serde_json::from_str::<MdCatalog>(&body)
.ok()
.map(FetchResult::Updated);
}
}
Err(e) => {
tracing::warn!(error = %e, attempt, "fetch failed");
}
}
if attempt + 1 < FETCH_RETRIES {
tokio::time::sleep(RETRY_BACKOFF).await;
}
}
None
}
/// Queries `GET {base_url}/v1/models` and turns every listed id into a
/// `CatalogSource::Local`, `OpenAiCompatible`, text-only model entry with zero
/// costs and limits.
///
/// The provider id is derived from `base_url` via `derive_local_provider`.
///
/// # Errors
/// Client construction, transport and JSON decoding failures are converted to
/// `std::io::Error` through `io_err`.
async fn fetch_local_models(base_url: &str) -> std::io::Result<Vec<CatalogModelEntry>> {
    /// Minimal view of the model-list response body.
    #[derive(Deserialize)]
    struct ModelList {
        data: Vec<ModelRef>,
    }
    #[derive(Deserialize)]
    struct ModelRef {
        id: String,
    }

    let root = base_url.trim_end_matches('/');
    let endpoint = format!("{}/v1/models", root);

    let http = reqwest::Client::builder()
        .timeout(FETCH_TIMEOUT)
        .build()
        .map_err(io_err)?;
    let response = http.get(&endpoint).send().await.map_err(io_err)?;
    let listing = response.json::<ModelList>().await.map_err(io_err)?;

    let provider_id = derive_local_provider(base_url);
    let mut entries = Vec::with_capacity(listing.data.len());
    for model in listing.data {
        entries.push(CatalogModelEntry {
            provider: provider_id.clone(),
            model_id: model.id.clone(),
            name: model.id,
            protocol: CatalogProtocol::OpenAiCompatible,
            source: CatalogSource::Local,
            base_url: Some(root.to_string()),
            reasoning: false,
            supports_vision: false,
            cost_input: 0.0,
            cost_output: 0.0,
            cost_cache_read: 0.0,
            cost_cache_write: 0.0,
            context_window: 0,
            max_tokens: 0,
            input_modalities: vec!["text".to_string()],
            release_date: None,
            status: None,
        });
    }
    Ok(entries)
}
/// Derives a provider id from a local endpoint URL by extracting its host.
///
/// One leading `http://` or `https://` scheme is removed, then everything
/// from the first `/`, `?` or `#` onwards is dropped, and finally the port is
/// cut off. Bracketed IPv6 literals (`[::1]:8080`) yield the address without
/// brackets. Returns `"local"` when no host can be extracted.
fn derive_local_provider(base_url: &str) -> String {
    let without_scheme = base_url
        .strip_prefix("http://")
        .or_else(|| base_url.strip_prefix("https://"))
        .unwrap_or(base_url);
    // Host and optional port, without any path, query or fragment.
    let authority = without_scheme.split(['/', '?', '#']).next().unwrap_or("");
    let host = match authority.strip_prefix('[') {
        // IPv6 literal: the host is everything up to the closing bracket.
        Some(bracketed) => bracketed.split(']').next().unwrap_or(""),
        None => authority.split(':').next().unwrap_or(""),
    };
    if host.is_empty() {
        "local".to_string()
    } else {
        host.to_string()
    }
}
/// Wraps any displayable error into an `std::io::Error` of kind `Other`,
/// keeping only its rendered message.
fn io_err<E: std::fmt::Display>(e: E) -> std::io::Error {
    let message = e.to_string();
    std::io::Error::new(std::io::ErrorKind::Other, message)
}
impl ModelCatalog for FileModelCatalog {
fn list_providers(&self) -> Pin<Box<dyn Future<Output = SdkResult<Vec<String>>> + Send + '_>> {
let snap = self.state.read();
let mut ids: Vec<String> = snap.providers.iter().map(|p| p.id.clone()).collect();
ids.sort();
Box::pin(async move { Ok(ids) })
}
fn get_provider(
&self,
provider_id: &str,
) -> Pin<Box<dyn Future<Output = SdkResult<Option<CatalogProviderEntry>>> + Send + '_>> {
let snap = self.state.read();
let entry = snap.providers.iter().find(|p| p.id == provider_id).cloned();
Box::pin(async move { Ok(entry) })
}
fn list_models(
&self,
provider_id: &str,
) -> Pin<Box<dyn Future<Output = SdkResult<Vec<CatalogModelEntry>>> + Send + '_>> {
let snap = self.state.read();
let list = snap
.models
.get(provider_id)
.map(|m| m.values().cloned().collect())
.unwrap_or_default();
Box::pin(async move { Ok(list) })
}
fn get_model(
&self,
provider_id: &str,
model_id: &str,
) -> Pin<Box<dyn Future<Output = SdkResult<Option<CatalogModelEntry>>> + Send + '_>> {
let snap = self.state.read();
let entry = snap
.models
.get(provider_id)
.and_then(|m| m.get(model_id))
.cloned();
Box::pin(async move { Ok(entry) })
}
fn search(
&self,
pattern: &str,
) -> Pin<Box<dyn Future<Output = SdkResult<Vec<CatalogModelEntry>>> + Send + '_>> {
let snap = self.state.read();
let lower = pattern.to_lowercase();
let out: Vec<CatalogModelEntry> = snap
.models
.values()
.flat_map(|m| m.values())
.filter(|e| {
e.model_id.to_lowercase().contains(&lower)
|| e.name.to_lowercase().contains(&lower)
|| e.provider.to_lowercase().contains(&lower)
})
.cloned()
.collect();
Box::pin(async move { Ok(out) })
}
fn model_count(&self) -> Pin<Box<dyn Future<Output = SdkResult<usize>> + Send + '_>> {
let snap = self.state.read();
let count: usize = snap.models.values().map(|m| m.len()).sum();
Box::pin(async move { Ok(count) })
}
fn refresh(&self) -> Pin<Box<dyn Future<Output = SdkResult<RefreshOutcome>> + Send + '_>> {
let state = Arc::clone(&self.state);
let tx = self.tx.clone();
let config = self.config.clone();
Box::pin(async move {
if !config.fetch_enabled {
return Ok(RefreshOutcome::Offline {
reason: "fetch_disabled",
});
}
if is_cache_fresh_static(&config.cache_path, config.mtime_window) {
return Ok(RefreshOutcome::Unchanged);
}
let etag = std::fs::read_to_string(&config.etag_path)
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
match fetch_conditional(&config.models_dev_url, etag.as_deref(), &config.user_agent)
.await
{
Some(FetchResult::Updated(md)) => {
let (providers, models) = materialize(&md, &OverrideFile::default());
let (pcount, mcount) = {
let mut snap = state.write();
snap.providers = providers;
snap.models = models
.into_iter()
.map(|(pid, list)| {
(
pid,
list.into_iter().map(|e| (e.model_id.clone(), e)).collect(),
)
})
.collect();
snap.stats()
};
if let Ok(body) = serde_json::to_string(&md) {
let _ = std::fs::create_dir_all(
config.cache_path.parent().unwrap_or(Path::new(".")),
);
let _ = std::fs::write(&config.cache_path, body);
}
let _ = filetime::set_file_mtime(
&config.cache_path,
filetime::FileTime::from_system_time(SystemTime::now()),
);
let _ = tx.send(CatalogEvent::Updated {
provider_count: pcount,
model_count: mcount,
});
Ok(RefreshOutcome::Updated {
provider_count: pcount,
model_count: mcount,
})
}
Some(FetchResult::NotModified) => {
let _ = filetime::set_file_mtime(
&config.cache_path,
filetime::FileTime::from_system_time(SystemTime::now()),
);
Ok(RefreshOutcome::Unchanged)
}
None => {
let (pcount, mcount) = state.read().stats();
let _ = tx.send(CatalogEvent::RefreshFailed {
reason: "network".into(),
provider_count: pcount,
model_count: mcount,
});
Ok(RefreshOutcome::Failed {
reason: "network".into(),
})
}
}
})
}
fn subscribe(&self) -> broadcast::Receiver<CatalogEvent> {
self.tx.subscribe()
}
fn list_providers_sync(&self) -> Vec<String> {
let snap = self.state.read();
let mut ids: Vec<String> = snap.providers.iter().map(|p| p.id.clone()).collect();
ids.sort();
ids
}
fn get_provider_sync(&self, provider_id: &str) -> Option<CatalogProviderEntry> {
let snap = self.state.read();
snap.providers.iter().find(|p| p.id == provider_id).cloned()
}
fn list_models_sync(&self, provider_id: &str) -> Vec<CatalogModelEntry> {
let snap = self.state.read();
snap.models
.get(provider_id)
.map(|m| m.values().cloned().collect())
.unwrap_or_default()
}
fn get_model_sync(&self, provider_id: &str, model_id: &str) -> Option<CatalogModelEntry> {
let snap = self.state.read();
snap.models
.get(provider_id)
.and_then(|m| m.get(model_id))
.cloned()
}
fn search_sync(&self, pattern: &str) -> Vec<CatalogModelEntry> {
let snap = self.state.read();
let lower = pattern.to_lowercase();
snap.models
.values()
.flat_map(|m| m.values())
.filter(|e| {
e.model_id.to_lowercase().contains(&lower)
|| e.name.to_lowercase().contains(&lower)
|| e.provider.to_lowercase().contains(&lower)
})
.cloned()
.collect()
}
fn model_count_sync(&self) -> usize {
let snap = self.state.read();
snap.models.values().map(|m| m.len()).sum()
}
}
/// Returns `true` when `path` exists and was modified no longer than `window`
/// ago. Any failure along the way — missing file, unavailable mtime, or an
/// mtime that lies in the future — counts as "not fresh".
fn is_cache_fresh_static(path: &Path, window: Duration) -> bool {
    std::fs::metadata(path)
        .and_then(|meta| meta.modified())
        .ok()
        .and_then(|modified| SystemTime::now().duration_since(modified).ok())
        .is_some_and(|age| age <= window)
}
// Unit tests for protocol mapping, default auth selection, the embedded
// snapshot, and override application.
#[cfg(test)]
mod tests {
use super::*;
use crate::AuthMethod;
// --- protocol_for: package identifier -> wire protocol ---
#[test]
fn protocol_for_anthropic() {
assert_eq!(
protocol_for("@ai-sdk/anthropic"),
CatalogProtocol::AnthropicMessages
);
}
#[test]
fn protocol_for_google() {
assert_eq!(
protocol_for("@ai-sdk/google"),
CatalogProtocol::GoogleGenerativeAi
);
}
// The explicit "openai-compatible" package maps to OpenAiCompletions, not
// to the OpenAiCompatible fallback used for unknown packages.
#[test]
fn protocol_for_openai_compat() {
assert_eq!(
protocol_for("@ai-sdk/openai-compatible"),
CatalogProtocol::OpenAiCompletions
);
}
#[test]
fn protocol_for_unknown_is_openai_compatible() {
assert_eq!(
protocol_for("some-new-sdk"),
CatalogProtocol::OpenAiCompatible
);
}
// An empty string is what `materialize` passes when no package is declared.
#[test]
fn protocol_for_empty_is_openai_compatible() {
assert_eq!(protocol_for(""), CatalogProtocol::OpenAiCompatible);
}
// --- CatalogProtocol::default_auth ---
#[test]
fn default_auth_for_anthropic_is_xapikey() {
assert_eq!(
CatalogProtocol::AnthropicMessages.default_auth(),
AuthMethod::XApiKey
);
}
#[test]
fn default_auth_for_azure_is_apikey() {
assert_eq!(
CatalogProtocol::AzureOpenAiResponses.default_auth(),
AuthMethod::ApiKey
);
}
// Despite the name, this also covers Bedrock in addition to the two Google protocols.
#[test]
fn default_auth_for_google_is_none() {
assert_eq!(
CatalogProtocol::GoogleVertex.default_auth(),
AuthMethod::None
);
assert_eq!(
CatalogProtocol::GoogleGenerativeAi.default_auth(),
AuthMethod::None
);
assert_eq!(
CatalogProtocol::BedrockConverseStream.default_auth(),
AuthMethod::None
);
}
#[test]
fn default_auth_for_openai_compat_is_bearer() {
assert_eq!(
CatalogProtocol::OpenAiCompletions.default_auth(),
AuthMethod::Bearer
);
assert_eq!(
CatalogProtocol::OpenAiCompatible.default_auth(),
AuthMethod::Bearer
);
assert_eq!(
CatalogProtocol::MistralConversations.default_auth(),
AuthMethod::Bearer
);
assert_eq!(
CatalogProtocol::OpenAiResponses.default_auth(),
AuthMethod::Bearer
);
}
// Checks the forward mapping only (no conversion back is exercised);
// OpenAiCompletions and OpenAiCompatible both map to the same `Api` variant.
#[test]
fn as_oxi_api_round_trip() {
use oxi_ai::Api;
assert_eq!(
CatalogProtocol::AnthropicMessages.as_oxi_api(),
Api::AnthropicMessages
);
assert_eq!(
CatalogProtocol::OpenAiCompletions.as_oxi_api(),
Api::OpenAiCompletions
);
assert_eq!(
CatalogProtocol::OpenAiCompatible.as_oxi_api(),
Api::OpenAiCompletions
);
assert_eq!(
CatalogProtocol::GoogleGenerativeAi.as_oxi_api(),
Api::GoogleGenerativeAi
);
}
// --- embedded snapshot ---
// Guards against an empty or truncated snapshot being embedded.
#[test]
fn snapshot_loads_and_has_expected_size() {
let catalog = load_snapshot().expect("SNAP must load");
assert!(!catalog.0.is_empty(), "SNAP should have providers");
let model_count: usize = catalog.0.values().map(|p| p.models.len()).sum();
assert!(
model_count > 1000,
"SNAP should have many models, got {model_count}"
);
}
#[test]
fn materialize_produces_nonzero_entries() {
let catalog = load_snapshot().expect("SNAP");
let (providers, models) = materialize(&catalog, &OverrideFile::default());
assert!(!providers.is_empty());
let count: usize = models.values().map(|v| v.len()).sum();
assert!(count > 0);
}
// --- overrides ---
// A model override patches only the fields it specifies and re-tags the
// entry as override-sourced; unspecified fields keep their original values.
#[test]
fn override_replaces_existing_model() {
let mut providers = vec![CatalogProviderEntry {
id: "test".into(),
display_name: "Original".into(),
aliases: vec![],
protocol: CatalogProtocol::OpenAiCompletions,
env_key: Some("TEST_KEY".into()),
extra_env_keys: vec![],
base_url: Some("https://api.test.com".into()),
extra_headers: vec![],
category: String::new(),
description: String::new(),
default_enabled: true,
}];
let mut models: BTreeMap<String, Vec<CatalogModelEntry>> = BTreeMap::new();
models.insert(
"test".into(),
vec![CatalogModelEntry {
provider: "test".into(),
model_id: "test-model".into(),
name: "Original".into(),
protocol: CatalogProtocol::OpenAiCompletions,
source: CatalogSource::Embedded,
base_url: None,
reasoning: false,
supports_vision: false,
cost_input: 0.0,
cost_output: 0.0,
cost_cache_read: 0.0,
cost_cache_write: 0.0,
context_window: 1000,
max_tokens: 100,
input_modalities: vec!["text".into()],
release_date: None,
status: None,
}],
);
let overrides = OverrideFile {
model: vec![OverrideModel {
provider: "test".into(),
id: "test-model".into(),
name: Some("Overridden".into()),
cost_input: Some(99.0),
cost_output: None,
context_window: None,
max_tokens: None,
}],
..Default::default()
};
apply_user_overrides(&mut providers, &mut models, &overrides);
let entry = models
.get("test")
.unwrap()
.iter()
.find(|m| m.model_id == "test-model")
.unwrap();
assert_eq!(entry.name, "Overridden");
assert_eq!(entry.source, CatalogSource::Override);
assert!((entry.cost_input - 99.0).abs() < 1e-9);
assert_eq!(entry.context_window, 1000, "untouched field kept");
}
}