use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use std::time::SystemTime;
use serde::{Deserialize, Serialize};
use crate::catalog::BuiltinModelEntry;
/// Maximum age of the cache file that still counts as fresh; `ttl()` falls back to
/// this when `OXI_MODELS_DEV_TTL` is unset or does not parse as a number of seconds.
const DEFAULT_TTL: Duration = Duration::from_secs(5 * 60);
/// Timeout configured on the HTTP client built by `live_fetch`.
const FETCH_TIMEOUT: Duration = Duration::from_secs(10);
/// Upper bound of the attempt loop in `live_fetch` (`0..FETCH_RETRIES`): this is the
/// total number of requests sent, not the number of retries after the first one.
const FETCH_RETRIES: u32 = 2;
/// Pause inserted between two consecutive fetch attempts.
const RETRY_BACKOFF: Duration = Duration::from_millis(200);
/// Base URL used when `OXI_MODELS_DEV_URL` provides no override; `live_fetch`
/// appends `/api.json` to it.
const DEFAULT_URL: &str = "https://models.dev";
/// Value sent in the `User-Agent` request header: `oxi/` followed by the crate version.
const USER_AGENT: &str = concat!("oxi/", env!("CARGO_PKG_VERSION"));
/// Parsed models.dev catalog: provider entries keyed by provider id.
///
/// `enrich` looks providers up with the id returned by `provider_map`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct MdCatalog(pub BTreeMap<String, MdProvider>);
/// One provider entry of the models.dev catalog.
///
/// `name`, `env` and `models` carry no `#[serde(default)]`, so they must be present
/// in the JSON for the catalog to deserialize; `npm` and `api` are optional.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdProvider {
/// Provider display name (the test snapshot uses "DeepSeek").
#[allow(dead_code)]
pub name: String,
/// List of strings from the `env` key (the test snapshot holds "DEEPSEEK_API_KEY");
/// presumably the environment variables the provider reads — TODO confirm.
#[allow(dead_code)]
pub env: Vec<String>,
/// Optional `npm` value (the test snapshot holds "@ai-sdk/openai-compatible").
#[serde(default)]
#[allow(dead_code)]
pub npm: Option<String>,
/// Optional `api` value (the test snapshot holds "https://api.deepseek.com").
#[serde(default)]
#[allow(dead_code)]
pub api: Option<String>,
/// Models offered by this provider, keyed by model id; `enrich` looks entries up
/// by `BuiltinModelEntry::id`.
pub models: BTreeMap<String, MdModel>,
}
/// One model entry of a models.dev provider.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdModel {
/// Model display name.
#[allow(dead_code)]
pub name: String,
/// Reasoning flag; `enrich` copies it onto the entry unless `reasoning_preserve`
/// says the entry's own value must be kept.
pub reasoning: bool,
/// Context/output limits; required for deserialization (no serde default).
pub limit: MdLimit,
/// Pricing block; `None` when the JSON has no `cost` key.
#[serde(default)]
pub cost: Option<MdCost>,
}
/// Limits of a model, deserialized as floating-point numbers.
///
/// `enrich` only applies a limit when it is greater than zero and converts it with
/// an `as u32` cast.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdLimit {
/// Written to `BuiltinModelEntry::context_window` by `enrich`.
pub context: f64,
/// Written to `BuiltinModelEntry::max_tokens` by `enrich`.
pub output: f64,
}
/// Pricing of a model as published in the catalog.
///
/// `enrich` copies each value verbatim onto the matching `cost_*` field of the
/// entry, but only when the value is greater than zero. The pricing unit is not
/// established by this file.
#[derive(Debug, Serialize, Deserialize)]
pub struct MdCost {
/// Copied to `cost_input`.
pub input: f64,
/// Copied to `cost_output`.
pub output: f64,
/// Copied to `cost_cache_read`; `None` when the key is absent.
#[serde(default)]
pub cache_read: Option<f64>,
/// Copied to `cost_cache_write`; `None` when the key is absent.
#[serde(default)]
pub cache_write: Option<f64>,
}
/// Translates an oxi provider id into the provider key used by the models.dev catalog.
///
/// Returns `None` for providers that have no catalog counterpart. Matching is exact
/// and case-sensitive.
pub fn provider_map(oxi_pid: &str) -> Option<&'static str> {
    // Providers whose oxi id is spelled exactly like the models.dev key.
    const SHARED_IDS: &[&str] = &[
        "anthropic",
        "google",
        "google-vertex",
        "google-vertex-anthropic",
        "openai",
        "openrouter",
        "deepseek",
        "groq",
        "xai",
        "mistral",
        "azure-cognitive-services",
        "amazon-bedrock",
        "togetherai",
        "cerebras",
        "deepinfra",
        "cloudflare-workers-ai",
        "cloudflare-ai-gateway",
        "huggingface",
        "moonshotai",
        "moonshotai-cn",
        "xiaomi",
        "xiaomi-token-plan",
        "minimax",
        "minimax-cn",
        "zai",
        "zai-cn",
        "github-copilot",
        "opencode",
        "opencode-go",
        "nvidia",
        "venice",
        "chutes",
        "alibaba",
        "ollama-cloud",
        "synthetic",
    ];

    // Aliases and renames: the oxi id differs from the models.dev key.
    let renamed = match oxi_pid {
        "anthropic-vertex" => Some("anthropic"),
        "openai-responses" | "openai-completions" | "openai-codex" => Some("openai"),
        "azure" => Some("azure-cognitive-services"),
        "bedrock" | "amazon-bedrock-mantle" => Some("amazon-bedrock"),
        "fireworks" => Some("fireworks-ai"),
        "together" => Some("togetherai"),
        "cloudflare" => Some("cloudflare-workers-ai"),
        "moonshot" => Some("moonshotai"),
        "kimi-coding" => Some("kimi-for-coding"),
        "zai-global" => Some("zai"),
        "zai-coding-global" | "zai-coding-cn" => Some("zai-coding-plan"),
        "vercel-ai-gateway" => Some("vercel"),
        "copilot" | "codex" => Some("github-copilot"),
        "novita" => Some("novita-ai"),
        "gmi" => Some("gmicloud"),
        "stepfun" => Some("stepfun-ai"),
        "qwen-portal" => Some("alibaba"),
        _ => None,
    };

    renamed.or_else(|| SHARED_IDS.iter().copied().find(|known| *known == oxi_pid))
}
/// Tells whether the entry's own `reasoning` flag must be kept instead of being
/// overwritten with the models.dev value.
///
/// Rules, checked in order:
/// 1. ids ending in `-TEE` are preserved for `chutes` and `together` only;
/// 2. for `together`, ids ending in `-tput` and a fixed list of ids are preserved;
/// 3. a fixed list of `groq` and `mistral` ids is preserved.
///
/// NOTE(review): the rules key on the provider id `together`, while `provider_map`
/// also accepts `togetherai`; entries using the latter spelling are never preserved —
/// confirm which id oxi actually uses.
pub fn reasoning_preserve(oxi_pid: &str, id: &str) -> bool {
    // The `-TEE` suffix decides on its own: no other rule applies to such ids.
    if id.ends_with("-TEE") {
        return oxi_pid == "chutes" || oxi_pid == "together";
    }

    match oxi_pid {
        "together" => {
            id.ends_with("-tput")
                || matches!(
                    id,
                    "Qwen/Qwen3-Coder-Next-FP8" | "Qwen/Qwen3.7-Max" | "deepseek-ai/DeepSeek-V3"
                )
        }
        "groq" => matches!(id, "groq/compound" | "groq/compound-mini"),
        "mistral" => id == "mistral-medium-latest",
        _ => false,
    }
}
/// Overlays models.dev data for `entry` onto the entry itself.
///
/// Nothing is changed when the provider has no models.dev mapping, when the mapped
/// provider is missing from `catalog`, or when the provider has no model with the
/// entry's id. Otherwise:
/// - each price is copied only when the catalog value is greater than zero, so an
///   existing non-zero price survives a zero or missing catalog price;
/// - context window and max output tokens are copied when greater than zero
///   (converted with an `as u32` cast);
/// - `reasoning` is copied unless `reasoning_preserve` protects the entry.
pub fn enrich(entry: &mut BuiltinModelEntry, catalog: &MdCatalog) {
    let Some(md_pid) = provider_map(&entry.provider) else {
        tracing::trace!(
            provider = %entry.provider,
            "models.dev: provider unmapped, skipping enrichment"
        );
        return;
    };
    let Some(model) = catalog
        .0
        .get(md_pid)
        .and_then(|provider| provider.models.get(&entry.id))
    else {
        return;
    };

    if let Some(cost) = &model.cost {
        if cost.input > 0.0 {
            entry.cost_input = cost.input;
        }
        if cost.output > 0.0 {
            entry.cost_output = cost.output;
        }
        if let Some(read) = cost.cache_read.filter(|v| *v > 0.0) {
            entry.cost_cache_read = read;
        }
        if let Some(write) = cost.cache_write.filter(|v| *v > 0.0) {
            entry.cost_cache_write = write;
        }
    }

    let limit = &model.limit;
    if limit.context > 0.0 {
        entry.context_window = limit.context as u32;
    }
    if limit.output > 0.0 {
        entry.max_tokens = limit.output as u32;
    }

    let keep_own_flag = reasoning_preserve(&entry.provider, &entry.id);
    if !keep_own_flag {
        entry.reasoning = model.reasoning;
    }
}
/// Process-wide catalog slot. It stays unset until `init_models_dev` stores a value:
/// `Some(catalog)` when a catalog was obtained, `None` when `fetch_with_fallback`
/// produced nothing.
static MODELS_DEV: OnceLock<Option<Arc<MdCatalog>>> = OnceLock::new();
/// Loads the models.dev catalog once and stores the outcome in `MODELS_DEV`.
///
/// Does nothing when a value has already been stored. A failed load is stored as
/// `None`, so it is not retried by later calls. If two calls race, both load but
/// only the first `set` wins; the other result is dropped.
pub async fn init_models_dev() {
    if MODELS_DEV.get().is_none() {
        let loaded = fetch_with_fallback().await.map(Arc::new);
        // `set` fails only when another caller stored a value first; that value is kept.
        let _ = MODELS_DEV.set(loaded);
    }
}
/// Returns the catalog stored by `init_models_dev`.
///
/// Yields `None` both before initialization and when initialization stored no catalog.
pub fn get() -> Option<&'static MdCatalog> {
    match MODELS_DEV.get() {
        Some(Some(catalog)) => Some(&**catalog),
        _ => None,
    }
}
/// Test-only hook, presumably meant to reset module state between tests.
///
/// NOTE(review): the body is empty, so calling this does not clear `MODELS_DEV`;
/// a value stored by `init_models_dev` stays in place for the rest of the process.
#[cfg(test)]
pub fn reset_for_tests() {
}
fn cache_path() -> Option<PathBuf> {
if let Ok(custom) = std::env::var("OXI_MODELS_DEV_CACHE_PATH")
&& !custom.is_empty()
{
return Some(PathBuf::from(custom));
}
Some(
dirs::home_dir()?
.join(".oxi")
.join("cache")
.join("models-dev.json"),
)
}
/// Reports whether models.dev enrichment is switched on.
///
/// Enrichment is on by default. It is switched off when `OXI_MODELS_DEV` is `0`,
/// `off` or `false`; the comparison ignores ASCII case and surrounding whitespace,
/// so `Off` or `False` work as well. Any other value (including an empty one)
/// leaves it on.
fn enabled() -> bool {
    match std::env::var("OXI_MODELS_DEV") {
        Ok(raw) => {
            let flag = raw.trim();
            let switched_off = flag == "0"
                || flag.eq_ignore_ascii_case("off")
                || flag.eq_ignore_ascii_case("false");
            !switched_off
        }
        // Unset or not valid Unicode: keep the default.
        Err(_) => true,
    }
}
/// Reports whether live fetching from models.dev is switched off.
///
/// Fetching is allowed by default. It is switched off when
/// `OXI_MODELS_DEV_DISABLE_FETCH` is `1` or `true`; the comparison ignores ASCII
/// case and surrounding whitespace, so `True` works as well. Any other value leaves
/// fetching allowed.
fn fetch_disabled() -> bool {
    match std::env::var("OXI_MODELS_DEV_DISABLE_FETCH") {
        Ok(raw) => {
            let flag = raw.trim();
            flag == "1" || flag.eq_ignore_ascii_case("true")
        }
        // Unset or not valid Unicode: keep the default.
        Err(_) => false,
    }
}
/// Returns the base URL of the models.dev service.
///
/// `OXI_MODELS_DEV_URL` overrides the default. An empty or whitespace-only override
/// is treated as unset (the same way `cache_path` ignores an empty override), because
/// it would otherwise produce the unusable request URL `/api.json`. Surrounding
/// whitespace is removed from the override.
fn models_url() -> String {
    std::env::var("OXI_MODELS_DEV_URL")
        .ok()
        .map(|raw| raw.trim().to_string())
        .filter(|url| !url.is_empty())
        .unwrap_or_else(|| DEFAULT_URL.to_string())
}
fn ttl() -> Duration {
std::env::var("OXI_MODELS_DEV_TTL")
.ok()
.and_then(|s| s.parse().ok())
.map(Duration::from_secs)
.unwrap_or(DEFAULT_TTL)
}
/// Obtains the catalog, preferring the cheapest source that works.
///
/// Order: nothing when enrichment is disabled; a fresh cache file; a live fetch
/// (unless fetching is disabled), whose result is also written to the cache; finally
/// whatever cache file exists, regardless of age.
async fn fetch_with_fallback() -> Option<MdCatalog> {
    if !enabled() {
        return None;
    }

    if let Some(fresh) = read_cache_if_fresh() {
        tracing::debug!("models.dev: using fresh cache");
        return Some(fresh);
    }

    let fetched = if fetch_disabled() {
        None
    } else {
        live_fetch().await
    };
    if let Some(live) = fetched {
        write_cache_atomic(&live);
        return Some(live);
    }

    let stale = read_cache_any();
    if stale.is_some() {
        tracing::debug!("models.dev: using stale cache (live fetch unavailable)");
    }
    stale
}
/// Reads the cache file, but only when its modification time is within `ttl()`.
///
/// Returns `None` when the path is unknown, the metadata cannot be read, the
/// modification time lies in the future, the file is too old, or it fails to load.
fn read_cache_if_fresh() -> Option<MdCatalog> {
    let path = cache_path()?;
    let modified = std::fs::metadata(&path)
        .and_then(|meta| meta.modified())
        .ok()?;
    // `duration_since` errs when `modified` is later than now; treat that as not fresh.
    let age = SystemTime::now().duration_since(modified).ok()?;
    if age <= ttl() {
        read_cache(&path)
    } else {
        None
    }
}
/// Reads the cache file regardless of its age.
fn read_cache_any() -> Option<MdCatalog> {
    cache_path().and_then(|path| read_cache(&path))
}
/// Loads and parses the catalog stored at `path`.
///
/// Returns `None` when the file cannot be read. A file that reads but does not
/// parse is reported with a warning and deleted, so it is not picked up again.
fn read_cache(path: &std::path::Path) -> Option<MdCatalog> {
    let raw = std::fs::read_to_string(path).ok()?;
    serde_json::from_str::<MdCatalog>(&raw)
        .map_err(|e| {
            tracing::warn!(error = %e, "models.dev: cache corrupt, ignoring");
            // Best effort: a failed delete only means the same warning next time.
            let _ = std::fs::remove_file(path);
        })
        .ok()
}
fn write_cache_atomic(catalog: &MdCatalog) {
let Some(path) = cache_path() else {
return;
};
let Some(parent) = path.parent() else {
return;
};
if std::fs::create_dir_all(parent).is_err() {
return;
}
let Ok(body) = serde_json::to_string(catalog) else {
return;
};
let tmp = path.with_file_name(format!("models-dev.json.{}.tmp", std::process::id()));
if std::fs::write(&tmp, &body).is_err() {
return;
}
if let Err(e) = std::fs::rename(&tmp, &path) {
tracing::debug!(error = %e, "models.dev: cache rename failed");
let _ = std::fs::remove_file(&tmp);
}
}
async fn live_fetch() -> Option<MdCatalog> {
let client = reqwest::Client::builder()
.timeout(FETCH_TIMEOUT)
.build()
.ok()?;
let url = format!("{}/api.json", models_url().trim_end_matches('/'));
for attempt in 0..FETCH_RETRIES {
match client
.get(&url)
.header("User-Agent", USER_AGENT)
.send()
.await
{
Ok(resp) if resp.status().is_success() => match resp.text().await {
Ok(body) => match serde_json::from_str::<MdCatalog>(&body) {
Ok(c) => {
tracing::debug!(
models = c.0.values().map(|p| p.models.len()).sum::<usize>(),
"models.dev: fetched"
);
return Some(c);
}
Err(e) => {
tracing::warn!(error = %e, "models.dev: parse failed");
return None;
}
},
Err(e) => {
tracing::warn!(error = %e, "models.dev: body read failed");
}
},
Ok(resp) => {
tracing::warn!(status = %resp.status(), "models.dev: non-success status");
}
Err(e) => {
tracing::warn!(error = %e, attempt, "models.dev: fetch failed");
}
}
if attempt + 1 < FETCH_RETRIES {
tokio::time::sleep(RETRY_BACKOFF).await;
}
}
None
}
// Unit tests for schema parsing, `enrich`, `provider_map` and JSON round-tripping.
// None of them touches the network or the process-wide `MODELS_DEV` slot.
#[cfg(test)]
mod tests {
use super::*;
// Builds a catalog holding exactly one provider with exactly one model.
// `cost` is `(input, output)`; cache prices are always left unset.
fn md(
provider: &str,
model_id: &str,
cost: Option<(f64, f64)>,
ctx: f64,
output: f64,
reasoning: bool,
) -> MdCatalog {
let mut cat = MdCatalog::default();
let m = MdModel {
name: model_id.to_string(),
reasoning,
limit: MdLimit {
context: ctx,
output,
},
cost: cost.map(|(i, o)| MdCost {
input: i,
output: o,
cache_read: None,
cache_write: None,
}),
};
let mut models = BTreeMap::new();
models.insert(model_id.to_string(), m);
cat.0.insert(
provider.to_string(),
MdProvider {
name: provider.to_string(),
env: vec![],
npm: None,
api: None,
models,
},
);
cat
}
// Builds a blank entry: all prices and limits zero, `reasoning` false.
fn entry(provider: &str, id: &str) -> BuiltinModelEntry {
BuiltinModelEntry {
id: id.to_string(),
name: id.to_string(),
api: "openai-completions".to_string(),
provider: provider.to_string(),
reasoning: false,
input: vec!["text".to_string()],
cost_input: 0.0,
cost_output: 0.0,
cost_cache_read: 0.0,
cost_cache_write: 0.0,
context_window: 0,
max_tokens: 0,
}
}
// The structs must accept a catalog snippet that carries extra keys they do not
// declare (`id`, `release_date`, `attachment`, ...) and integer-valued limits.
#[test]
fn schema_parses_snapshot() {
let json = r#"{
"deepseek": {
"id": "deepseek",
"name": "DeepSeek",
"env": ["DEEPSEEK_API_KEY"],
"npm": "@ai-sdk/openai-compatible",
"api": "https://api.deepseek.com",
"models": {
"deepseek-chat": {
"id": "deepseek-chat",
"name": "DeepSeek Chat",
"release_date": "2025-12-01",
"attachment": true,
"reasoning": false,
"tool_call": true,
"temperature": true,
"limit": { "context": 1000000, "output": 384000 },
"cost": { "input": 0.14, "output": 0.28, "cache_read": 0.0028 }
}
}
}
}"#;
let cat: MdCatalog = serde_json::from_str(json).unwrap();
let m = &cat.0["deepseek"].models["deepseek-chat"];
assert!((m.cost.as_ref().unwrap().input - 0.14).abs() < 1e-9);
assert_eq!(m.limit.context, 1000000.0);
assert_eq!(m.limit.output, 384000.0);
}
// A blank entry picks up prices and limits from the catalog.
#[test]
fn enrich_fills_missing_price() {
let cat = md(
"deepseek",
"deepseek-chat",
Some((0.14, 0.28)),
1000000.0,
384000.0,
false,
);
let mut e = entry("deepseek", "deepseek-chat");
enrich(&mut e, &cat);
assert!((e.cost_input - 0.14).abs() < 1e-9);
assert!((e.cost_output - 0.28).abs() < 1e-9);
assert_eq!(e.context_window, 1000000);
assert_eq!(e.max_tokens, 384000);
}
// A zero catalog price must not wipe an existing price, while a positive catalog
// limit still replaces the entry's limit.
#[test]
fn enrich_preserves_zero_price_when_md_zero() {
let cat = md(
"deepseek",
"free-model",
Some((0.0, 0.0)),
128000.0,
8192.0,
false,
);
let mut e = entry("deepseek", "free-model");
e.cost_input = 0.5; e.context_window = 64000;
enrich(&mut e, &cat);
assert_eq!(e.cost_input, 0.5, "non-zero Layer 1 price must survive");
assert_eq!(e.context_window, 128000);
}
// "ollama" has no `provider_map` entry, so the entry stays as it was.
#[test]
fn enrich_noop_for_unmapped_provider() {
let cat = md("deepseek", "x", None, 1000000.0, 384000.0, true);
let mut e = entry("ollama", "llama3");
enrich(&mut e, &cat);
assert_eq!(e.context_window, 0, "unmapped provider must be untouched");
}
// The provider is known but the model id is not in the catalog.
#[test]
fn enrich_noop_for_missing_model() {
let cat = md("deepseek", "deepseek-chat", None, 1000000.0, 384000.0, true);
let mut e = entry("deepseek", "deepseek-other");
enrich(&mut e, &cat);
assert_eq!(e.context_window, 0);
}
// Without a preserve rule the catalog's reasoning flag overwrites the entry's.
#[test]
fn enrich_updates_reasoning() {
let cat = md(
"openai",
"gpt-5-chat-latest",
None,
400000.0,
128000.0,
true,
);
let mut e = entry("openai", "gpt-5-chat-latest");
assert!(!e.reasoning);
enrich(&mut e, &cat);
assert!(e.reasoning, "reasoning should be copied from models.dev");
}
// The `-TEE` suffix rule of `reasoning_preserve` protects the entry's own flag.
#[test]
fn enrich_preserves_reasoning_for_tee_variant() {
let cat = md(
"chutes",
"Qwen/Qwen3-Coder-Next-TEE",
None,
131072.0,
32768.0,
false,
);
let mut e = entry("chutes", "Qwen/Qwen3-Coder-Next-TEE");
e.reasoning = true; enrich(&mut e, &cat);
assert!(e.reasoning, "TEE variant reasoning must be preserved");
}
// The explicit `("groq", "groq/compound")` rule protects the entry's own flag.
#[test]
fn enrich_preserves_reasoning_for_compound() {
let cat = md("groq", "groq/compound", None, 131072.0, 32768.0, false);
let mut e = entry("groq", "groq/compound");
e.reasoning = true;
enrich(&mut e, &cat);
assert!(e.reasoning);
}
// Spot checks: identity mapping, aliases collapsing onto one key, unknown ids.
#[test]
fn provider_map_collapse_rules() {
assert_eq!(provider_map("minimax-cn"), Some("minimax-cn"));
assert_eq!(provider_map("zai-coding-global"), Some("zai-coding-plan"));
assert_eq!(provider_map("moonshot"), Some("moonshotai"));
assert_eq!(provider_map("copilot"), Some("github-copilot"));
assert_eq!(provider_map("ollama"), None);
assert_eq!(provider_map("lmstudio"), None);
assert_eq!(provider_map("unknown-provider"), None);
}
// Serializes a catalog to a temp file and parses it back.
// NOTE(review): despite the name, this does not call `write_cache_atomic` or
// `read_cache`; it only verifies the serde round trip through a file.
#[test]
fn write_cache_roundtrips() {
let cat = md(
"deepseek",
"deepseek-chat",
Some((0.14, 0.28)),
1000000.0,
384000.0,
false,
);
let tmp = std::env::temp_dir().join(format!("oxi-md-test-{}.json", std::process::id()));
let body = serde_json::to_string(&cat).unwrap();
std::fs::write(&tmp, &body).unwrap();
let back: MdCatalog =
serde_json::from_str(&std::fs::read_to_string(&tmp).unwrap()).unwrap();
let _ = std::fs::remove_file(&tmp);
assert!(back.0.contains_key("deepseek"));
}
}