use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use std::time::SystemTime;
use serde::{Deserialize, Serialize};
use crate::Api;
use crate::catalog::provider::AuthMethod;
/// How long the on-disk cache counts as fresh (by file mtime) before the network is
/// consulted again. Overridable in whole seconds via `OXI_MODELS_DEV_MTIME_WINDOW`.
const DEFAULT_MTIME_WINDOW: Duration = Duration::from_secs(60 * 60);
/// Timeout configured on the HTTP client used for catalog requests.
const FETCH_TIMEOUT: Duration = Duration::from_secs(10);
/// Total number of HTTP attempts per fetch: the fetch loop iterates `0..FETCH_RETRIES`,
/// so this is one initial attempt plus `FETCH_RETRIES - 1` retries.
const FETCH_RETRIES: u32 = 2;
/// Fixed delay slept between consecutive fetch attempts.
const RETRY_BACKOFF: Duration = Duration::from_millis(200);
/// Base URL of the catalog service (`/api.json` is appended). Overridable via
/// `OXI_MODELS_DEV_URL`.
const DEFAULT_URL: &str = "https://models.dev";
/// `User-Agent` sent with every catalog request: `oxi/<crate version>`.
const USER_AGENT: &str = concat!("oxi/", env!("CARGO_PKG_VERSION"));
/// The full models.dev catalog, keyed by provider id (e.g. `"deepseek"`).
///
/// This type is both the shape of the `api.json` response and the on-disk cache format:
/// the cache is written with `serde_json::to_string` and parsed back into this type.
/// `BTreeMap` iterates in key order, so the serialized cache has a deterministic key order.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct MdCatalog(pub BTreeMap<String, MdProvider>);
/// One provider entry in the catalog.
///
/// `name`, `env` and `models` have no `#[serde(default)]`, so a provider entry missing any
/// of them makes the whole catalog fail to deserialize.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdProvider {
    /// Human-readable provider name (e.g. `"DeepSeek"`).
    #[allow(dead_code)]
    pub name: String,
    /// Environment variable names listed for the provider (e.g. `DEEPSEEK_API_KEY`).
    #[allow(dead_code)]
    pub env: Vec<String>,
    /// npm package identifier such as `@ai-sdk/openai-compatible`; `protocol_for` maps
    /// these identifiers to an `Api` / `AuthMethod` pair.
    #[serde(default)]
    #[allow(dead_code)]
    pub npm: Option<String>,
    /// API base URL (e.g. `https://api.deepseek.com`), when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub api: Option<String>,
    /// Presumably a documentation link — TODO confirm against the models.dev schema.
    #[serde(default)]
    #[allow(dead_code)]
    pub doc: Option<String>,
    /// Models offered by this provider, keyed by model id.
    pub models: BTreeMap<String, MdModel>,
}
/// One model entry within a provider.
///
/// Only `name` and `limit` are required. Every other field falls back to its default when
/// absent, so a single sparse entry cannot make the whole catalog fail to deserialize.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdModel {
    /// Human-readable model name (e.g. `"DeepSeek Chat"`).
    #[allow(dead_code)]
    pub name: String,
    /// Model family, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub family: Option<String>,
    /// Whether the model supports reasoning; `false` when absent, like `tool_call` and
    /// `attachment`. Without the default, one entry lacking this key would abort parsing
    /// of the entire catalog.
    #[serde(default)]
    pub reasoning: bool,
    /// Whether the model supports tool calls; `false` when absent.
    #[serde(default)]
    pub tool_call: bool,
    /// Whether the model accepts attachments; `false` when absent.
    #[serde(default)]
    pub attachment: bool,
    /// Temperature support flag, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub temperature: Option<bool>,
    /// Structured-output support flag, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub structured_output: Option<bool>,
    /// Presumably the knowledge cut-off date — TODO confirm format.
    #[serde(default)]
    #[allow(dead_code)]
    pub knowledge: Option<String>,
    /// Release date string as listed (e.g. `"2025-12-01"`).
    #[serde(default)]
    #[allow(dead_code)]
    pub release_date: Option<String>,
    /// Last-updated date string, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub last_updated: Option<String>,
    /// Open-weights flag, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub open_weights: Option<bool>,
    /// Kept as raw JSON; its shape is not modelled here.
    #[serde(default)]
    #[allow(dead_code)]
    pub interleaved: Option<serde_json::Value>,
    /// Reasoning configuration options, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub reasoning_options: Option<Vec<MdReasoningOption>>,
    /// Size limits; required.
    pub limit: MdLimit,
    /// Pricing; `None` when the entry has no `cost` object.
    #[serde(default)]
    pub cost: Option<MdCost>,
    /// Input/output modalities, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub modalities: Option<MdModalities>,
    /// Status string, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub status: Option<String>,
    /// Per-model `npm` / `api` values, when listed.
    #[serde(default)]
    pub provider: Option<MdModelProvider>,
}
/// Per-model `provider` block carrying its own `npm` / `api` values.
///
/// NOTE(review): presumably these override the enclosing `MdProvider`'s `npm` / `api` for
/// this one model — confirm at the call site.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdModelProvider {
    /// npm package identifier for this model, when listed.
    #[serde(default)]
    pub npm: Option<String>,
    /// API base URL for this model, when listed.
    #[serde(default)]
    pub api: Option<String>,
}
/// Size limits for a model. `context` and `output` are required; `input` is optional.
///
/// NOTE(review): these look like token counts (the snapshot test uses `context: 1000000`)
/// — confirm units. Integer JSON values deserialize into these `f64` fields.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdLimit {
    /// Context window size.
    pub context: f64,
    /// Separate input limit, when listed.
    #[serde(default)]
    pub input: Option<f64>,
    /// Maximum output size.
    pub output: f64,
}
/// Pricing for a model. `input` and `output` are required whenever a `cost` object is
/// present; everything else is optional.
///
/// NOTE(review): prices are presumably USD per million tokens (the snapshot test uses
/// `input: 0.14`) — confirm against the models.dev schema before doing arithmetic.
#[derive(Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MdCost {
    /// Input price.
    pub input: f64,
    /// Output price.
    pub output: f64,
    /// Cache-read price, when listed.
    #[serde(default)]
    pub cache_read: Option<f64>,
    /// Cache-write price, when listed.
    #[serde(default)]
    pub cache_write: Option<f64>,
    /// Tiered price entries, when listed.
    #[serde(default)]
    pub tiers: Option<Vec<MdCostTier>>,
    /// Prices under the `context_over_200k` key — presumably applied once the context
    /// exceeds 200k; confirm.
    #[serde(default)]
    pub context_over_200k: Option<MdCostTierData>,
    /// Reasoning price, when listed.
    #[serde(default)]
    pub reasoning: Option<f64>,
    /// Audio-input price, when listed.
    #[serde(default)]
    pub input_audio: Option<f64>,
    /// Audio-output price, when listed.
    #[serde(default)]
    pub output_audio: Option<f64>,
}
/// One element of `MdCost::tiers`: a price set plus the `tier` spec it belongs to.
#[derive(Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MdCostTier {
    /// Input price within this tier.
    pub input: f64,
    /// Output price within this tier.
    pub output: f64,
    /// Cache-read price within this tier, when listed.
    #[serde(default)]
    pub cache_read: Option<f64>,
    /// Cache-write price within this tier, when listed.
    #[serde(default)]
    pub cache_write: Option<f64>,
    /// Which tier these prices apply to; required.
    pub tier: MdTierSpec,
}
/// Identifies the tier an `MdCostTier` covers.
///
/// NOTE(review): the set of `kind` values and the unit of `size` are not visible here —
/// confirm against the models.dev schema.
#[derive(Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MdTierSpec {
    /// Read from / written to the JSON key `type` (a reserved word in Rust).
    #[serde(rename = "type")]
    pub kind: String,
    /// Tier size.
    pub size: f64,
}
/// A bare price set with no tier spec; used for `MdCost::context_over_200k`.
#[derive(Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MdCostTierData {
    /// Input price.
    pub input: f64,
    /// Output price.
    pub output: f64,
    /// Cache-read price, when listed.
    #[serde(default)]
    pub cache_read: Option<f64>,
    /// Cache-write price, when listed.
    #[serde(default)]
    pub cache_write: Option<f64>,
}
/// Input and output modality name lists for a model; either list may be absent.
#[derive(Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MdModalities {
    /// Accepted input modality names, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub input: Option<Vec<String>>,
    /// Produced output modality names, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub output: Option<Vec<String>>,
}
/// One reasoning configuration option listed for a model.
///
/// NOTE(review): the meaning of `kind`, `values` and `min` is not visible here — confirm
/// against the models.dev schema.
#[derive(Debug, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MdReasoningOption {
    /// Read from / written to the JSON key `type` (a reserved word in Rust).
    #[serde(rename = "type")]
    pub kind: String,
    /// Allowed values; individual array entries may be JSON `null`.
    #[serde(default)]
    #[allow(dead_code)]
    pub values: Option<Vec<Option<String>>>,
    /// Numeric lower bound, when listed.
    #[serde(default)]
    #[allow(dead_code)]
    pub min: Option<f64>,
}
/// Maps a models.dev npm package identifier to the wire protocol and auth scheme used to
/// talk to that provider.
///
/// Any identifier not listed below (for example `@ai-sdk/openai-compatible`) is treated as
/// an OpenAI-style completions endpoint with bearer-token auth.
pub fn protocol_for(npm: &str) -> (Api, AuthMethod) {
    match npm {
        "@ai-sdk/amazon-bedrock" => (Api::BedrockConverseStream, AuthMethod::None),
        "@ai-sdk/anthropic" => (Api::AnthropicMessages, AuthMethod::XApiKey),
        "@ai-sdk/azure" => (Api::AzureOpenAiResponses, AuthMethod::ApiKey),
        "@ai-sdk/google" => (Api::GoogleGenerativeAi, AuthMethod::None),
        // Both Vertex packages go through the same Vertex protocol.
        "@ai-sdk/google-vertex" => (Api::GoogleVertex, AuthMethod::None),
        "@ai-sdk/google-vertex/anthropic" => (Api::GoogleVertex, AuthMethod::None),
        "@ai-sdk/mistral" => (Api::MistralConversations, AuthMethod::Bearer),
        _ => (Api::OpenAiCompletions, AuthMethod::Bearer),
    }
}
/// Process-wide catalog slot, set at most once (by `init_models_dev`).
///
/// The outer `OnceLock` records whether initialization has run. The inner `Option` is `None`
/// when it ran but produced no catalog (integration disabled, or neither the network nor
/// the cache yielded one).
static MODELS_DEV: OnceLock<Option<Arc<MdCatalog>>> = OnceLock::new();
/// Loads the catalog into the process-wide slot on first call; later calls return at once.
///
/// The catalog comes from `fetch_with_fallback`. A failed load is recorded as `None` and is
/// not retried for the rest of the process. Two concurrent first callers may both fetch;
/// only the first `set` takes effect.
pub async fn init_models_dev() {
    if MODELS_DEV.get().is_none() {
        let catalog = fetch_with_fallback().await.map(Arc::new);
        // Losing the race to another initializer is harmless: its value is kept.
        let _ = MODELS_DEV.set(catalog);
    }
}
/// Returns the catalog loaded by `init_models_dev`.
///
/// `None` if initialization has not run yet, or if it ran but produced no catalog.
pub fn get() -> Option<&'static MdCatalog> {
    let slot = MODELS_DEV.get()?;
    slot.as_deref()
}
pub async fn refresh() -> bool {
if !enabled() || fetch_disabled() {
return false;
}
let etag = read_etag();
match live_fetch_conditional(etag.as_deref()).await {
Some(ConditionalResult::NotModified) => {
tracing::info!("models.dev: already up to date (304)");
touch_cache_mtime();
false
}
Some(ConditionalResult::Updated(c, new_etag)) => {
write_cache_atomic(&c);
if let Some(e) = new_etag {
write_etag(&e);
}
tracing::info!("models.dev: cache refreshed");
true
}
None => {
tracing::warn!("models.dev: refresh failed");
false
}
}
}
/// Test hook, compiled only under `cfg(test)`.
///
/// NOTE(review): currently a no-op. `MODELS_DEV` is a `std::sync::OnceLock` static, and
/// `OnceLock::take` needs `&mut self`, which a shared static cannot provide. Tests calling
/// this therefore still observe whatever catalog was loaded first.
#[cfg(test)]
pub fn reset_for_tests() {
}
fn cache_path() -> Option<PathBuf> {
if let Ok(custom) = std::env::var("OXI_MODELS_DEV_CACHE_PATH")
&& !custom.is_empty()
{
return Some(PathBuf::from(custom));
}
Some(
dirs::home_dir()?
.join(".oxi")
.join("cache")
.join("models-dev.json"),
)
}
/// Whether the models.dev integration is enabled at all.
///
/// Disabled when `OXI_MODELS_DEV` is `0`, `off` or `false`. The comparison ignores case and
/// surrounding whitespace, so `Off`/`False` behave like `OFF`/`FALSE`. Unset, empty,
/// non-UTF-8 or any other value leaves the integration enabled.
fn enabled() -> bool {
    match std::env::var("OXI_MODELS_DEV") {
        Ok(raw) => {
            let value = raw.trim();
            !(value == "0"
                || value.eq_ignore_ascii_case("off")
                || value.eq_ignore_ascii_case("false"))
        }
        Err(_) => true,
    }
}
/// Whether network fetches are disabled (cache-only mode).
///
/// True when `OXI_MODELS_DEV_DISABLE_FETCH` is `1` or `true`, ignoring case and surrounding
/// whitespace. Anything else, including unset or non-UTF-8, allows fetching.
fn fetch_disabled() -> bool {
    std::env::var("OXI_MODELS_DEV_DISABLE_FETCH").is_ok_and(|raw| {
        let value = raw.trim();
        value == "1" || value.eq_ignore_ascii_case("true")
    })
}
/// Base URL of the catalog service (`/api.json` is appended by the caller).
///
/// Uses `OXI_MODELS_DEV_URL` when it is set to a non-blank value, with surrounding
/// whitespace stripped; otherwise `DEFAULT_URL`. A blank override is ignored, just as
/// `cache_path` ignores an empty `OXI_MODELS_DEV_CACHE_PATH`. Without this, the request
/// URL would degenerate to a bare `/api.json`.
fn models_url() -> String {
    match std::env::var("OXI_MODELS_DEV_URL") {
        Ok(raw) if !raw.trim().is_empty() => raw.trim().to_string(),
        _ => DEFAULT_URL.to_string(),
    }
}
/// Freshness window for the on-disk cache.
///
/// Reads whole seconds from `OXI_MODELS_DEV_MTIME_WINDOW`; a missing or unparsable value
/// falls back to `DEFAULT_MTIME_WINDOW`.
fn mtime_window() -> Duration {
    let seconds = std::env::var("OXI_MODELS_DEV_MTIME_WINDOW")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok());
    match seconds {
        Some(secs) => Duration::from_secs(secs),
        None => DEFAULT_MTIME_WINDOW,
    }
}
/// Whether the mtime freshness check should be skipped, forcing a network revalidation.
///
/// True when `OXI_MODELS_DEV_FORCE_REFRESH` is `1` or `true`, ignoring case and surrounding
/// whitespace. Anything else, including unset or non-UTF-8, keeps the normal behaviour.
fn force_refresh() -> bool {
    std::env::var("OXI_MODELS_DEV_FORCE_REFRESH").is_ok_and(|raw| {
        let value = raw.trim();
        value == "1" || value.eq_ignore_ascii_case("true")
    })
}
/// Loads the catalog, preferring (in order) a fresh cache, the network, then a stale cache.
///
/// 1. Returns `None` immediately when the integration is disabled (`enabled()`).
/// 2. Unless `force_refresh()` is set, a cache younger than `mtime_window()` is used as-is.
/// 3. Unless `fetch_disabled()` is set, a conditional GET is sent with the stored ETag:
///    - a 304 revalidates the existing cache and bumps its mtime;
///    - a 304 with no usable cache clears the ETag and refetches unconditionally;
///    - a 2xx replaces the cache.
/// 4. If nothing above produced a catalog, any existing cache is used, however old.
async fn fetch_with_fallback() -> Option<MdCatalog> {
    // Persists a freshly downloaded catalog and keeps the ETag sidecar in step with it.
    fn store(catalog: &MdCatalog, etag: Option<String>) {
        write_cache_atomic(catalog);
        match etag {
            Some(e) => write_etag(&e),
            // No validator came with this body; an older ETag describes different content
            // and must not be replayed as `If-None-Match`.
            None => clear_etag(),
        }
    }

    if !enabled() {
        return None;
    }
    if !force_refresh()
        && let Some(c) = read_cache_if_fresh()
    {
        tracing::debug!("models.dev: using cache within mtime window");
        return Some(c);
    }
    if !fetch_disabled() {
        let etag = read_etag();
        match live_fetch_conditional(etag.as_deref()).await {
            Some(ConditionalResult::NotModified) => {
                if let Some(c) = read_cache_any() {
                    tracing::debug!("models.dev: 304 Not Modified, touching cache mtime");
                    touch_cache_mtime();
                    return Some(c);
                }
                tracing::warn!("models.dev: 304 received but cache missing — refetching");
                clear_etag();
                if let Some(ConditionalResult::Updated(c, new_etag)) =
                    live_fetch_conditional(None).await
                {
                    store(&c, new_etag);
                    return Some(c);
                }
            }
            Some(ConditionalResult::Updated(c, new_etag)) => {
                store(&c, new_etag);
                return Some(c);
            }
            None => {}
        }
    }
    if let Some(c) = read_cache_any() {
        tracing::debug!("models.dev: using stale cache (live fetch unavailable)");
        return Some(c);
    }
    None
}
/// Outcome of a conclusive conditional catalog request.
enum ConditionalResult {
    /// The server answered 304 Not Modified.
    NotModified,
    /// A 2xx response whose body parsed, plus the response's `ETag` header if it had one.
    Updated(MdCatalog, Option<String>),
}
/// Reads the cache only if its mtime lies within `mtime_window()` of now.
///
/// Returns `None` in any of these cases:
/// - there is no cache path, or the file or its mtime is unavailable;
/// - the mtime is in the future;
/// - the cache is older than the window;
/// - the cache does not parse.
fn read_cache_if_fresh() -> Option<MdCatalog> {
    let path = cache_path()?;
    let modified = std::fs::metadata(&path)
        .and_then(|meta| meta.modified())
        .ok()?;
    let age = SystemTime::now().duration_since(modified).ok()?;
    if age <= mtime_window() {
        read_cache(&path)
    } else {
        None
    }
}
/// Reads the cache regardless of its age.
///
/// `None` when there is no cache path or the cache is missing or unparsable.
fn read_cache_any() -> Option<MdCatalog> {
    cache_path().and_then(|path| read_cache(&path))
}
/// Reads and parses the cache file at `path`.
///
/// Returns `None` if the file cannot be read. If it exists but does not parse as an
/// `MdCatalog`, it is deleted together with its ETag sidecar
/// (`path.with_extension("json.etag")`). Keeping that ETag would let the next conditional
/// request come back 304 and "validate" content that no longer exists on disk. Both
/// removals are best-effort.
fn read_cache(path: &std::path::Path) -> Option<MdCatalog> {
    let body = std::fs::read_to_string(path).ok()?;
    match serde_json::from_str::<MdCatalog>(&body) {
        Ok(c) => Some(c),
        Err(e) => {
            tracing::warn!(error = %e, "models.dev: cache corrupt, ignoring");
            let _ = std::fs::remove_file(path);
            // Same sidecar naming as `etag_path`, derived from this specific cache file.
            let _ = std::fs::remove_file(path.with_extension("json.etag"));
            None
        }
    }
}
fn touch_cache_mtime() {
let Some(path) = cache_path() else { return };
let now = std::time::SystemTime::now();
let _ = filetime::set_file_mtime(&path, filetime::FileTime::from_system_time(now));
}
/// Path of the ETag sidecar, derived from the cache path by replacing its extension
/// (`models-dev.json` → `models-dev.json.etag`).
fn etag_path() -> Option<PathBuf> {
    cache_path().map(|base| base.with_extension("json.etag"))
}
/// Reads the stored ETag with surrounding whitespace trimmed.
///
/// `None` when the sidecar is missing, unreadable or blank.
fn read_etag() -> Option<String> {
    let raw = std::fs::read_to_string(etag_path()?).ok()?;
    Some(raw.trim())
        .filter(|tag| !tag.is_empty())
        .map(str::to_owned)
}
/// Atomically replaces the ETag sidecar with `etag` by writing a temp file and renaming it.
///
/// Best-effort: failures are ignored, but the temp file is removed whenever it could not be
/// renamed into place, so partial or orphaned temp files are not left behind. The temp name
/// carries the process id (like the cache's temp file), so concurrent processes never write
/// through the same temp path.
fn write_etag(etag: &str) {
    let Some(path) = etag_path() else { return };
    // `models-dev.json.etag` -> `models-dev.json.etag.<pid>.tmp`
    let tmp = path.with_extension(format!("etag.{}.tmp", std::process::id()));
    if std::fs::write(&tmp, etag).is_err() {
        let _ = std::fs::remove_file(&tmp);
        return;
    }
    if std::fs::rename(&tmp, &path).is_err() {
        let _ = std::fs::remove_file(&tmp);
    }
}
/// Deletes the ETag sidecar, if any, so the next catalog request is unconditional.
///
/// Best-effort: errors (including "not found") are ignored.
fn clear_etag() {
    if let Some(path) = etag_path() {
        let _ = std::fs::remove_file(path);
    }
}
fn write_cache_atomic(catalog: &MdCatalog) {
let Some(path) = cache_path() else {
return;
};
let Some(parent) = path.parent() else {
return;
};
if std::fs::create_dir_all(parent).is_err() {
return;
}
let Ok(body) = serde_json::to_string(catalog) else {
return;
};
let tmp = path.with_file_name(format!("models-dev.json.{}.tmp", std::process::id()));
if std::fs::write(&tmp, &body).is_err() {
return;
}
if let Err(e) = std::fs::rename(&tmp, &path) {
tracing::debug!(error = %e, "models.dev: cache rename failed");
let _ = std::fs::remove_file(&tmp);
}
}
/// Issues `GET {models_url()}/api.json`, sending `If-None-Match: etag` when one is given.
///
/// Makes up to `FETCH_RETRIES` attempts, sleeping `RETRY_BACKOFF` between them.
/// - Returns `NotModified` on a 304.
/// - Returns `Updated` on a 2xx whose body parses.
/// - Returns `None` immediately when the HTTP client cannot be built, or when a 2xx body
///   fails to parse (not retried).
/// - Returns `None` once every attempt has failed with a transport error, a body read error
///   or any other status.
async fn live_fetch_conditional(etag: Option<&str>) -> Option<ConditionalResult> {
    use std::ops::ControlFlow;

    // One request/response cycle. `Break` ends the whole fetch with the carried result;
    // `Continue` means this attempt failed in a retryable way.
    async fn attempt_once(
        client: &reqwest::Client,
        url: &str,
        etag: Option<&str>,
        attempt: u32,
    ) -> ControlFlow<Option<ConditionalResult>> {
        let mut request = client.get(url).header("User-Agent", USER_AGENT);
        if let Some(tag) = etag {
            request = request.header("If-None-Match", tag);
        }
        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => {
                tracing::warn!(error = %e, attempt, "models.dev: fetch failed");
                return ControlFlow::Continue(());
            }
        };
        let status = response.status();
        if status == reqwest::StatusCode::NOT_MODIFIED {
            tracing::debug!("models.dev: 304 Not Modified");
            return ControlFlow::Break(Some(ConditionalResult::NotModified));
        }
        if !status.is_success() {
            tracing::warn!(status = %status, "models.dev: non-success status");
            return ControlFlow::Continue(());
        }
        // Capture the validator before `text()` consumes the response.
        let new_etag = response
            .headers()
            .get(reqwest::header::ETAG)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        let body = match response.text().await {
            Ok(body) => body,
            Err(e) => {
                tracing::warn!(error = %e, "models.dev: body read failed");
                return ControlFlow::Continue(());
            }
        };
        // A body that arrived but does not parse will not improve on retry.
        ControlFlow::Break(match serde_json::from_str::<MdCatalog>(&body) {
            Ok(catalog) => {
                tracing::debug!(
                    models = catalog.0.values().map(|p| p.models.len()).sum::<usize>(),
                    "models.dev: fetched"
                );
                Some(ConditionalResult::Updated(catalog, new_etag))
            }
            Err(e) => {
                tracing::warn!(error = %e, "models.dev: parse failed");
                None
            }
        })
    }

    let client = reqwest::Client::builder()
        .timeout(FETCH_TIMEOUT)
        .build()
        .ok()?;
    let url = format!("{}/api.json", models_url().trim_end_matches('/'));
    for attempt in 0..FETCH_RETRIES {
        if let ControlFlow::Break(result) = attempt_once(&client, &url, etag, attempt).await {
            return result;
        }
        if attempt + 1 < FETCH_RETRIES {
            tokio::time::sleep(RETRY_BACKOFF).await;
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a catalog with one provider holding one model. Only the given fields are set;
    /// every optional field is `None` and every flag other than `reasoning` is `false`.
    fn md(
        provider: &str,
        model_id: &str,
        cost: Option<(f64, f64)>,
        ctx: f64,
        output: f64,
        reasoning: bool,
    ) -> MdCatalog {
        let mut cat = MdCatalog::default();
        let m = MdModel {
            name: model_id.to_string(),
            family: None,
            reasoning,
            tool_call: false,
            attachment: false,
            temperature: None,
            structured_output: None,
            knowledge: None,
            release_date: None,
            last_updated: None,
            open_weights: None,
            interleaved: None,
            reasoning_options: None,
            limit: MdLimit {
                context: ctx,
                input: None,
                output,
            },
            cost: cost.map(|(i, o)| MdCost {
                input: i,
                output: o,
                cache_read: None,
                cache_write: None,
                tiers: None,
                context_over_200k: None,
                reasoning: None,
                input_audio: None,
                output_audio: None,
            }),
            modalities: None,
            status: None,
            provider: None,
        };
        let mut models = BTreeMap::new();
        models.insert(model_id.to_string(), m);
        cat.0.insert(
            provider.to_string(),
            MdProvider {
                name: provider.to_string(),
                env: vec![],
                npm: None,
                api: None,
                doc: None,
                models,
            },
        );
        cat
    }

    /// A trimmed real-world provider entry parses, including unmodelled keys such as `id`.
    #[test]
    fn schema_parses_snapshot() {
        let json = r#"{
            "deepseek": {
                "id": "deepseek",
                "name": "DeepSeek",
                "env": ["DEEPSEEK_API_KEY"],
                "npm": "@ai-sdk/openai-compatible",
                "api": "https://api.deepseek.com",
                "models": {
                    "deepseek-chat": {
                        "id": "deepseek-chat",
                        "name": "DeepSeek Chat",
                        "release_date": "2025-12-01",
                        "attachment": true,
                        "reasoning": false,
                        "tool_call": true,
                        "temperature": true,
                        "limit": { "context": 1000000, "output": 384000 },
                        "cost": { "input": 0.14, "output": 0.28, "cache_read": 0.0028 }
                    }
                }
            }
        }"#;
        let cat: MdCatalog = serde_json::from_str(json).unwrap();
        let provider = &cat.0["deepseek"];
        assert_eq!(provider.npm.as_deref(), Some("@ai-sdk/openai-compatible"));
        let m = &provider.models["deepseek-chat"];
        let cost = m.cost.as_ref().unwrap();
        assert!((cost.input - 0.14).abs() < 1e-9);
        assert!((cost.cache_read.unwrap() - 0.0028).abs() < 1e-12);
        assert!(m.tool_call);
        assert!(m.attachment);
        assert!(!m.reasoning);
        assert_eq!(m.limit.context, 1000000.0);
        assert_eq!(m.limit.output, 384000.0);
    }

    /// Fields marked `#[serde(default)]` may be omitted without failing the parse.
    #[test]
    fn sparse_model_uses_defaults() {
        let json = r#"{
            "p": {
                "name": "P",
                "env": [],
                "models": {
                    "m": {
                        "name": "M",
                        "reasoning": true,
                        "limit": { "context": 8192, "output": 1024 }
                    }
                }
            }
        }"#;
        let cat: MdCatalog = serde_json::from_str(json).unwrap();
        let m = &cat.0["p"].models["m"];
        assert!(m.reasoning);
        assert!(!m.tool_call);
        assert!(!m.attachment);
        assert!(m.cost.is_none());
        assert!(m.limit.input.is_none());
        assert!(m.provider.is_none());
    }

    /// A catalog serialized to disk is read back intact by `read_cache`, the same reader
    /// used by the cache-loading paths.
    #[test]
    fn write_cache_roundtrips() {
        let cat = md(
            "deepseek",
            "deepseek-chat",
            Some((0.14, 0.28)),
            1000000.0,
            384000.0,
            false,
        );
        let tmp = std::env::temp_dir().join(format!("oxi-md-test-{}.json", std::process::id()));
        let body = serde_json::to_string(&cat).unwrap();
        std::fs::write(&tmp, &body).unwrap();
        let back = read_cache(&tmp);
        let _ = std::fs::remove_file(&tmp);
        let back = back.expect("a cache written from a valid catalog must parse");
        let m = &back.0["deepseek"].models["deepseek-chat"];
        let cost = m.cost.as_ref().expect("cost survives the round trip");
        assert!((cost.input - 0.14).abs() < 1e-9);
        assert!((cost.output - 0.28).abs() < 1e-9);
        assert_eq!(m.limit.context, 1000000.0);
        assert_eq!(m.limit.output, 384000.0);
        assert!(!m.reasoning);
    }
}