use crate::schema::reasoning_params;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::SystemTime;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::schema::*;
use crate::InferenceError;
/// Criteria for narrowing a catalog query.
///
/// Empty vectors / `None` fields are treated as "no constraint" by
/// `UnifiedRegistry::query`; `Default` therefore matches every model.
#[derive(Debug, Clone, Default)]
pub struct ModelFilter {
/// Capabilities the model must all advertise.
pub capabilities: Vec<ModelCapability>,
/// Upper bound on on-disk size; only enforced for local models.
pub max_size_mb: Option<u64>,
/// Upper bound on p50 latency; models with no latency data pass.
pub max_latency_ms: Option<u64>,
/// Upper bound on output cost per million tokens; models with no cost data pass.
pub max_cost_per_mtok: Option<f64>,
/// Tags the model must all carry.
pub tags: Vec<String>,
/// Exact provider name to match, if set.
pub provider: Option<String>,
/// When true, only models whose source is local qualify.
pub local_only: bool,
/// When true, only models currently marked available qualify.
pub available_only: bool,
}
/// A recommended migration from an installed model to its successor,
/// produced by `UnifiedRegistry::available_upgrades`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpgrade {
/// Id of the currently-installed (superseded) model.
pub from_id: String,
/// Human-readable name of the superseded model.
pub from_name: String,
/// Id of the recommended replacement.
pub to_id: String,
/// Human-readable name of the replacement.
pub to_name: String,
/// Why the upgrade is recommended (copied from the rule asset).
pub reason: String,
/// Runtime the replacement targets, if the rule names one.
pub target_runtime: Option<String>,
/// Free-form runtime requirement string from the rule, if any.
pub target_runtime_requirement: Option<String>,
/// Minimum runtime versions the replacement needs.
pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
/// Whether the replacement is already usable on this machine.
pub target_available: bool,
/// Whether the replacement can be pulled locally (Local or Mlx source).
pub target_pullable: bool,
/// Whether removing the old model is supported (old model is Local/Mlx
/// and the rule opted in to removal).
pub remove_old_supported: bool,
}
/// A named runtime plus the minimum version an upgrade target requires.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRuntimeRequirement {
/// Runtime name (semantics defined by the upgrade-rule asset).
pub name: String,
/// Minimum acceptable version, as a plain string.
pub minimum_version: String,
}
/// Central model catalog: built-in entries plus user-registered models,
/// with per-source availability tracking and local download management.
pub struct UnifiedRegistry {
/// Directory holding managed local model files (one subdirectory per model name).
models_dir: PathBuf,
/// All known models, keyed by schema id.
models: HashMap<String, ModelSchema>,
/// Path to the user's `models.json` overlay (sibling of `models_dir`).
user_config_path: PathBuf,
}
/// One rule from the embedded `model-upgrades.json` asset: any installed
/// model in `from_ids` should be upgraded to `to_id`.
#[derive(Debug, Clone, Deserialize)]
struct ModelUpgradeRule {
/// Candidate superseded model ids; the first one present and available wins.
from_ids: Vec<String>,
/// Id of the recommended replacement model.
to_id: String,
/// Human-readable justification surfaced to the user.
reason: String,
/// Runtime the replacement targets, if any.
target_runtime: Option<String>,
/// Free-form runtime requirement description, if any.
target_runtime_requirement: Option<String>,
/// Minimum runtime versions; absent in JSON means empty.
#[serde(default)]
minimum_runtimes: Vec<ModelRuntimeRequirement>,
/// Whether the old model may be removed once the target is available;
/// defaults to true when the JSON omits the field.
#[serde(default = "default_remove_old_after_available")]
remove_old_after_available: bool,
}
/// Serde default for `ModelUpgradeRule::remove_old_after_available`:
/// unless a rule explicitly opts out, the superseded model may be removed
/// once its replacement is present.
fn default_remove_old_after_available() -> bool {
    true
}
/// Parses the upgrade rules embedded in the binary at compile time.
///
/// A malformed asset is a build/packaging bug, so parsing failure panics
/// via `expect` rather than returning a runtime error.
fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
serde_json::from_str(include_str!("../assets/model-upgrades.json"))
.expect("built-in model-upgrades.json should parse")
}
impl UnifiedRegistry {
/// Creates a registry rooted at `models_dir`.
///
/// The user overlay file is `models.json` in the parent of `models_dir`
/// (or inside `models_dir` itself when it has no parent). Construction
/// loads the built-in catalog, recomputes availability, and then merges
/// the user config best-effort — the error is deliberately discarded so a
/// corrupt user file cannot block startup.
pub fn new(models_dir: PathBuf) -> Self {
let user_config_path = models_dir
.parent()
.unwrap_or(&models_dir)
.join("models.json");
let mut registry = Self {
models_dir,
models: HashMap::new(),
user_config_path,
};
registry.load_builtin_catalog();
registry.refresh_availability();
let _ = registry.load_user_config();
registry
}
/// Inserts (or replaces, keyed by id) a model, recomputing its
/// `available` flag from its source kind before storing it.
pub fn register(&mut self, mut schema: ModelSchema) {
if schema.is_mlx() {
// Speech-tagged MLX models defer to the speech runtime probe;
// other MLX models count as available when either the managed
// directory or a cached HF snapshot holds their files.
schema.available = if schema.tags.contains(&"speech".to_string()) {
speech_mlx_available()
} else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
let mlx_dir = self.models_dir.join(&schema.name);
mlx_dir.join("config.json").exists()
|| latest_huggingface_repo_snapshot(hf_repo).is_some()
} else {
let mlx_dir = self.models_dir.join(&schema.name);
mlx_dir.join("config.json").exists()
};
} else if schema.is_vllm_mlx() {
// Sticky OR: an incoming `available = true` is preserved even
// without the endpoint env var. NOTE(review): this differs from
// the other branches, which overwrite — confirm it is intended.
schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
} else if schema.is_local() {
// GGUF models are available iff the weights file is on disk.
let local_path = self.models_dir.join(&schema.name).join("model.gguf");
schema.available = local_path.exists();
} else if schema.is_remote() {
// Remote APIs are available iff their key env var is set.
if let ModelSource::RemoteApi {
ref api_key_env, ..
} = schema.source
{
schema.available = std::env::var(api_key_env).is_ok();
}
}
info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
self.models.insert(schema.id.clone(), schema);
}
/// Removes the model with the given id, returning it if it was present.
pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
let removed = self.models.remove(id);
if let Some(ref m) = removed {
info!(id = %m.id, "unregistered model");
}
removed
}
/// All models, sorted by id for stable ordering.
pub fn list(&self) -> Vec<&ModelSchema> {
let mut models: Vec<&ModelSchema> = self.models.values().collect();
models.sort_by(|a, b| a.id.cmp(&b.id));
models
}
/// Returns every model satisfying all set constraints in `filter`.
/// Unset constraints match everything; size is only enforced for local
/// models, and missing latency/cost metrics do not disqualify a model.
pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
self.models
.values()
.filter(|m| {
if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
return false;
}
if let Some(max) = filter.max_size_mb {
// Size caps only make sense for models stored locally.
if m.size_mb() > max && m.is_local() {
return false;
}
}
if let Some(max) = filter.max_latency_ms {
// Models without latency data pass this filter.
if let Some(p50) = m.performance.latency_p50_ms {
if p50 > max {
return false;
}
}
}
if let Some(max) = filter.max_cost_per_mtok {
// Models without cost data pass this filter.
if let Some(cost) = m.cost.output_per_mtok {
if cost > max {
return false;
}
}
}
if !filter.tags.iter().all(|t| m.tags.contains(t)) {
return false;
}
if let Some(ref p) = filter.provider {
if &m.provider != p {
return false;
}
}
if filter.local_only && !m.is_local() {
return false;
}
if filter.available_only && !m.available {
return false;
}
true
})
.collect()
}
/// Convenience wrapper: query with a single required capability.
pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
self.query(&ModelFilter {
capabilities: vec![cap],
..Default::default()
})
}
/// Evaluates the embedded upgrade rules against the current catalog.
///
/// For each rule, the first `from_ids` entry that is registered AND
/// currently available is the upgrade source; the target must at least
/// be registered. Results are sorted and de-duplicated on
/// (from_id, to_id).
pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
let mut upgrades = Vec::new();
for rule in model_upgrade_rules() {
let Some(from) = rule
.from_ids
.iter()
.find_map(|id| self.models.get(id.as_str()))
.filter(|schema| schema.available)
else {
continue;
};
let Some(to) = self.models.get(rule.to_id.as_str()) else {
continue;
};
upgrades.push(ModelUpgrade {
from_id: from.id.clone(),
from_name: from.name.clone(),
to_id: to.id.clone(),
to_name: to.name.clone(),
reason: rule.reason.clone(),
target_runtime: rule.target_runtime.clone(),
target_runtime_requirement: rule.target_runtime_requirement.clone(),
minimum_runtimes: rule.minimum_runtimes.clone(),
target_available: to.available,
// Only Local/Mlx sources can be pulled onto this machine.
target_pullable: matches!(
to.source,
ModelSource::Local { .. } | ModelSource::Mlx { .. }
),
// Removal requires both a removable (local) source and the
// rule's opt-in.
remove_old_supported: matches!(
from.source,
ModelSource::Local { .. } | ModelSource::Mlx { .. }
) && rule.remove_old_after_available,
});
}
upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
upgrades
}
/// Looks up a model by exact id.
pub fn get(&self, id: &str) -> Option<&ModelSchema> {
self.models.get(id)
}
/// Case-insensitive lookup by display name.
///
/// On Apple Silicon, a `{name}-MLX` sibling is preferred over the plain
/// name so callers transparently get the MLX build.
pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
if !name.to_ascii_lowercase().ends_with("-mlx") {
if let Some(mlx_variant) = self
.models
.values()
.find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
{
return Some(mlx_variant);
}
}
self.models
.values()
.find(|m| m.name.eq_ignore_ascii_case(name))
}
/// For a local GGUF model, finds an MLX model of the same family that
/// shares the schema's primary (first-listed) capability, if any.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
if schema.is_mlx() || schema.is_vllm_mlx() {
return None;
}
if !matches!(schema.source, ModelSource::Local { .. }) {
return None;
}
let primary_cap = schema.capabilities.first()?;
self.models.values().find(|m| {
m.is_mlx()
&& m.family == schema.family
&& m.capabilities.contains(primary_cap)
})
}
/// Ensures the model's files exist locally, downloading anything missing,
/// and returns the directory to load from.
///
/// `id` may be either a schema id or a display name. Only `Local` (GGUF)
/// and `Mlx` sources are downloadable; other sources return
/// `InferenceFailed`.
///
/// # Errors
/// `ModelNotFound` for unknown ids; download/IO failures propagate.
pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
let schema = self
.get(id)
.or_else(|| self.find_by_name(id))
.ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
match &schema.source {
ModelSource::Local {
hf_repo,
hf_filename,
tokenizer_repo,
} => {
// GGUF layout: <models_dir>/<name>/{model.gguf, tokenizer.json},
// each fetched independently so partial downloads resume.
let model_dir = self.models_dir.join(&schema.name);
let model_path = model_dir.join("model.gguf");
let tokenizer_path = model_dir.join("tokenizer.json");
if model_path.exists() && tokenizer_path.exists() {
return Ok(model_dir);
}
std::fs::create_dir_all(&model_dir)?;
if !model_path.exists() {
info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
download_file(hf_repo, hf_filename, &model_path).await?;
}
if !tokenizer_path.exists() {
info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
}
Ok(model_dir)
}
ModelSource::Mlx {
hf_repo,
hf_weight_file,
} => {
let model_dir = self.models_dir.join(&schema.name);
let config_path = model_dir.join("config.json");
// Preference order: managed dir, cached HF snapshot, fresh download.
if config_path.exists() {
ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
return Ok(model_dir);
}
if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
return Ok(snapshot_dir);
}
std::fs::create_dir_all(&model_dir)?;
info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
download_file(hf_repo, "config.json", &config_path).await?;
let tok_path = model_dir.join("tokenizer.json");
if !tok_path.exists() {
download_file(hf_repo, "tokenizer.json", &tok_path).await?;
}
// tokenizer_config.json is optional — failure is ignored.
let tok_config_path = model_dir.join("tokenizer_config.json");
if !tok_config_path.exists() {
let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
}
if let Some(ref wf) = hf_weight_file {
// The schema names a specific weight file: fetch just that.
let wf_path = model_dir.join(wf);
if !wf_path.exists() {
download_file(hf_repo, wf, &wf_path).await?;
}
} else {
// No explicit weight file: try single-file safetensors first,
// then fall back to the sharded layout via the index file.
let single = model_dir.join("model.safetensors");
if !single.exists() {
match download_file(hf_repo, "model.safetensors", &single).await {
Ok(()) => {}
Err(_) => {
let index_path = model_dir.join("model.safetensors.index.json");
download_file(hf_repo, "model.safetensors.index.json", &index_path)
.await?;
let index_json: serde_json::Value =
serde_json::from_str(&std::fs::read_to_string(&index_path)?)
.map_err(|e| {
InferenceError::InferenceFailed(format!(
"parse index: {e}"
))
})?;
if let Some(weight_map) =
index_json.get("weight_map").and_then(|m| m.as_object())
{
// Collect into a set: many tensors map to the
// same shard, each shard is fetched once.
let mut files: std::collections::HashSet<String> =
std::collections::HashSet::new();
for filename in weight_map.values() {
if let Some(f) = filename.as_str() {
files.insert(f.to_string());
}
}
for file in &files {
let dest = model_dir.join(file);
if !dest.exists() {
info!(file = %file, "downloading weight shard");
download_file(hf_repo, file, &dest).await?;
}
}
}
}
}
}
}
ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
Ok(model_dir)
}
_ => Err(InferenceError::InferenceFailed(format!(
"model {} is not local",
id
))),
}
}
/// Deletes a model's managed directory and associated Hugging Face cache
/// entries, then marks the model unavailable (it stays registered).
///
/// # Errors
/// `ModelNotFound` for unknown ids; filesystem removal failures propagate.
pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
let schema = self
.get(id)
.or_else(|| self.find_by_name(id))
.ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
let model_dir = self.models_dir.join(&schema.name);
if model_dir.exists() {
std::fs::remove_dir_all(&model_dir)?;
info!(model = %schema.name, "removed model");
}
match &schema.source {
ModelSource::Mlx { hf_repo, .. } => {
let repo_dir = huggingface_repo_dir(hf_repo);
if repo_dir.exists() {
std::fs::remove_dir_all(&repo_dir)?;
info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
}
}
ModelSource::Local {
hf_repo,
tokenizer_repo,
..
} => {
// GGUF models may cache two repos (weights + tokenizer).
for repo in [hf_repo, tokenizer_repo] {
let repo_dir = huggingface_repo_dir(repo);
if repo_dir.exists() {
std::fs::remove_dir_all(&repo_dir)?;
info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
}
}
}
_ => {}
}
// Re-borrow mutably only after the shared borrow of `schema` ends.
let id = schema.id.clone();
if let Some(m) = self.models.get_mut(&id) {
m.available = false;
}
Ok(())
}
/// Recomputes `available` for every registered model from its source
/// kind (filesystem presence, env vars, runtime probes).
pub fn refresh_availability(&mut self) {
let models_dir = self.models_dir.clone();
// The mlx_vlm CLI probe is evaluated once per refresh, not per model.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
let mlx_vlm_cli_present = false;
for m in self.models.values_mut() {
match &m.source {
ModelSource::Mlx { .. } => {
let needs_mlx_vlm =
m.tags.iter().any(|t| t == "requires-mlx-vlm");
m.available = if needs_mlx_vlm {
mlx_vlm_cli_present
} else if m.tags.contains(&"speech".to_string()) {
speech_mlx_available()
} else {
// NOTE(review): unlike `register`, this branch does not
// consult the HF snapshot cache — confirm the asymmetry
// is intentional.
let mlx_dir = models_dir.join(&m.name);
mlx_dir.join("config.json").exists()
};
}
ModelSource::Local { .. } => {
let local_path = models_dir.join(&m.name).join("model.gguf");
m.available = local_path.exists();
}
ModelSource::RemoteApi { api_key_env, .. } => {
m.available = std::env::var(api_key_env).is_ok();
}
ModelSource::Ollama { .. } => {
// Ollama reachability is not probed here; assumed present.
m.available = true;
}
ModelSource::VllmMlx { .. } => {
// Sticky OR, matching the `register` branch.
m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
}
ModelSource::Proprietary { auth, .. } => {
m.available = match auth {
crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
std::env::var(env_var).is_ok()
}
crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
std::env::var(env_var).is_ok()
}
// OAuth flows can be initiated on demand, so always available.
crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
true
}
};
}
ModelSource::AppleFoundationModels { .. } => {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
m.available = crate::backend::foundation_models::is_available();
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
m.available = false;
}
}
}
}
}
/// Persists non-builtin models to the user config file.
///
/// NOTE(review): when no user models remain, the function returns early
/// without touching an existing file — previously saved entries are not
/// cleared. Confirm this is the intended behavior.
pub fn save_user_config(&self) -> Result<(), InferenceError> {
let user_models: Vec<&ModelSchema> = self
.models
.values()
.filter(|m| !m.tags.contains(&"builtin".to_string()))
.collect();
if user_models.is_empty() {
return Ok(());
}
let json = serde_json::to_string_pretty(&user_models)
.map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
std::fs::write(&self.user_config_path, json)?;
Ok(())
}
/// Loads and registers models from the user config file, if present.
/// Each entry goes through `register`, so availability is recomputed.
pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
if !self.user_config_path.exists() {
return Ok(());
}
let json = std::fs::read_to_string(&self.user_config_path)?;
let models: Vec<ModelSchema> = serde_json::from_str(&json)
.map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
for m in models {
self.register(m);
}
Ok(())
}
/// Root directory for managed local model files.
pub fn models_dir(&self) -> &Path {
&self.models_dir
}
/// Inserts the built-in catalog entries directly (no availability
/// recomputation — `refresh_availability` handles that afterwards).
fn load_builtin_catalog(&mut self) {
for schema in builtin_catalog() {
self.models.insert(schema.id.clone(), schema);
}
}
}
/// Reports whether the MLX speech runtime appears usable on this machine.
fn speech_mlx_available() -> bool {
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        // Apple Silicon: speech support is assumed to be present.
        true
    }
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    {
        // Elsewhere, look for either speech CLI under the runtime's bin dir.
        let bin = speech_runtime_root().join("bin");
        ["mlx_audio.stt.generate", "mlx_audio.tts.generate"]
            .iter()
            .any(|tool| bin.join(tool).exists())
    }
}
/// Directory holding the speech runtime: `CAR_SPEECH_RUNTIME_DIR` when set
/// to a non-blank value, otherwise `$HOME/.car/speech-runtime` (with `.`
/// standing in for `$HOME` when unset).
fn speech_runtime_root() -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        // An explicit override wins, provided it is not blank.
        Ok(dir) if !dir.trim().is_empty() => PathBuf::from(dir),
        _ => std::env::var("HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| PathBuf::from("."))
            .join(".car")
            .join("speech-runtime"),
    }
}
/// Lightweight, serializable summary of a model, derived from `ModelSchema`
/// via the `From` impl below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Schema id.
pub id: String,
/// Display name.
pub name: String,
/// Provider name.
pub provider: String,
/// Advertised capabilities.
pub capabilities: Vec<ModelCapability>,
/// Parameter count as a display string (e.g. "4B").
pub param_count: String,
/// On-disk size in megabytes (from `ModelSchema::size_mb`).
pub size_mb: u64,
/// Context window in tokens.
pub context_length: usize,
/// Whether the model is currently usable.
pub available: bool,
/// Whether the model's source is local.
pub is_local: bool,
/// Published benchmark scores; empty when absent in older JSON.
#[serde(default)]
pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
}
impl From<&ModelSchema> for ModelInfo {
fn from(s: &ModelSchema) -> Self {
ModelInfo {
id: s.id.clone(),
name: s.name.clone(),
provider: s.provider.clone(),
capabilities: s.capabilities.clone(),
param_count: s.param_count.clone(),
size_mb: s.size_mb(),
context_length: s.context_length,
available: s.available,
is_local: s.is_local(),
public_benchmarks: s.public_benchmarks.clone(),
}
}
}
/// Fetches `filename` from the Hugging Face `repo` into `dest`.
///
/// The file lands in the shared HF cache first; `dest` then becomes a
/// symlink into the cache where the platform allows it, falling back to a
/// full copy. Returns immediately when `dest` already exists.
///
/// # Errors
/// `InferenceError::DownloadFailed` when the API client cannot be built,
/// the fetch fails, or the final copy into `dest` fails.
async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
    // Short-circuit before any network work — the original checked this
    // only after downloading, wasting a fetch when dest already existed.
    if dest.exists() {
        return Ok(());
    }
    let api = hf_hub::api::tokio::Api::new()
        .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
    let api_repo = api.model(repo.to_string());
    let path = api_repo
        .get(filename)
        .await
        // Name the repo and file instead of the old "(unknown)" placeholder.
        .map_err(|e| InferenceError::DownloadFailed(format!("{repo}/{filename}: {e}")))?;
    #[cfg(unix)]
    {
        // Symlinking into the HF cache avoids duplicating large weight files.
        if std::os::unix::fs::symlink(&path, dest).is_ok() {
            return Ok(());
        }
    }
    std::fs::copy(&path, dest)
        .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
    Ok(())
}
/// Patches up known-incomplete MLX conversions after download.
///
/// Currently only the Flux Lite 8B MLX build needs help: its conversion is
/// missing `tokenizer_2/tokenizer.json`, which is fetched from the base
/// (non-MLX) model repo. All other models are returned untouched.
async fn ensure_auxiliary_mlx_files(
    model_name: &str,
    hf_repo: &str,
    model_dir: &Path,
) -> Result<(), InferenceError> {
    let is_flux_lite = hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4"
        || model_name == "Flux-1.lite-8B-MLX-Q4";
    if !is_flux_lite {
        return Ok(());
    }
    let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
    if t5_tokenizer_path.exists() {
        return Ok(());
    }
    let parent = t5_tokenizer_path
        .parent()
        .ok_or_else(|| InferenceError::InferenceFailed("invalid tokenizer path".into()))?;
    std::fs::create_dir_all(parent)?;
    info!(
        path = %t5_tokenizer_path.display(),
        "downloading missing Flux tokenizer_2/tokenizer.json from base model"
    );
    download_file("Freepik/flux.1-lite-8B", "tokenizer_2/tokenizer.json", &t5_tokenizer_path)
        .await
}
/// True when a usable local Hugging Face cache snapshot exists for `repo_id`.
// NOTE(review): no callers are visible in this file — confirm this helper is
// still used elsewhere before treating it as live code.
fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
latest_huggingface_repo_snapshot(repo_id).is_some()
}
/// Root of the local Hugging Face hub cache: `$HF_HOME/hub` when set,
/// otherwise `$HOME/.cache/huggingface/hub`, falling back to the current
/// directory when `HOME` is also unset.
fn huggingface_cache_root() -> PathBuf {
    let base = match std::env::var("HF_HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => std::env::var("HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| PathBuf::from("."))
            .join(".cache")
            .join("huggingface"),
    };
    base.join("hub")
}
/// Cache directory for one repo, following the HF hub naming scheme
/// (`models--{org}--{name}` — slashes become double dashes).
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
    let sanitized = repo_id.replace('/', "--");
    huggingface_cache_root().join(format!("models--{sanitized}"))
}
/// Resolves a named ref (e.g. "main") inside a repo cache dir to its
/// snapshot directory, returning it only when the snapshot looks complete.
fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
    // The refs file contains the snapshot's commit sha.
    let contents = std::fs::read_to_string(repo_dir.join("refs").join(name)).ok()?;
    let sha = contents.trim();
    if sha.is_empty() {
        return None;
    }
    let snapshot = repo_dir.join("snapshots").join(sha);
    snapshot_looks_ready(&snapshot).then_some(snapshot)
}
/// Finds the best local cache snapshot for `repo_id`: the one the "main"
/// ref points at when usable, otherwise the most recently modified
/// complete-looking snapshot directory.
fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
    let repo_dir = huggingface_repo_dir(repo_id);
    // Fast path: "main" resolves straight to a ready snapshot.
    if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
        return Some(snapshot);
    }
    // Otherwise scan all snapshot dirs and keep the newest usable one
    // (mtime falls back to the epoch when unreadable).
    std::fs::read_dir(repo_dir.join("snapshots"))
        .ok()?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| path.is_dir() && snapshot_looks_ready(path))
        .map(|path| {
            let modified = path
                .metadata()
                .and_then(|meta| meta.modified())
                .unwrap_or(SystemTime::UNIX_EPOCH);
            (modified, path)
        })
        .max()
        .map(|(_, path)| path)
}
/// Heuristic completeness check for a cache snapshot: a top-level config
/// (LLM layout) or model index (diffusion layout) marks it ready;
/// otherwise probe recursively for any safetensors shard.
fn snapshot_looks_ready(path: &Path) -> bool {
    let has_marker = ["config.json", "model_index.json"]
        .iter()
        .any(|marker| path.join(marker).exists());
    has_marker || snapshot_contains_ext(path, "safetensors")
}
/// Recursively walks `root`, returning true as soon as any file carries the
/// given extension (compared case-insensitively). Unreadable directories
/// count as not containing it.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };
    entries.filter_map(Result::ok).any(|entry| {
        let path = entry.path();
        if path.is_dir() {
            return snapshot_contains_ext(&path, ext);
        }
        matches!(
            path.extension().and_then(|value| value.to_str()),
            Some(value) if value.eq_ignore_ascii_case(ext)
        )
    })
}
/// Downloads every file of a repo's current revision into the local HF
/// cache, returning the snapshot directory and the number of files
/// accounted for.
///
/// Note: the returned count includes files that already existed locally,
/// so it is "files present", not "files fetched".
///
/// # Errors
/// `InferenceError::DownloadFailed` when the API client cannot be built,
/// the repo info fetch fails, or any individual file download fails.
async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
let api = hf_hub::api::tokio::ApiBuilder::from_env()
.with_progress(false)
.build()
.map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
let repo = api.model(repo_id.to_string());
let info = repo
.info()
.await
.map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
// NOTE(review): this re-implements huggingface_cache_root()/
// huggingface_repo_dir() inline; if the cache layout changes, the two
// copies can drift. Consider delegating to those helpers.
let snapshot_path = std::env::var("HF_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| {
std::env::var("HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("."))
.join(".cache")
.join("huggingface")
})
.join("hub")
.join(format!("models--{}", repo_id.replace('/', "--")))
.join("snapshots")
.join(&info.sha);
let mut downloaded = 0usize;
for sibling in &info.siblings {
let local_path = snapshot_path.join(&sibling.rfilename);
// Skip files already materialized in the cache (still counted).
if local_path.exists() {
downloaded += 1;
continue;
}
repo.download(&sibling.rfilename).await.map_err(|e| {
InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
})?;
downloaded += 1;
}
Ok((snapshot_path, downloaded))
}
// Built-in model catalog, embedded into the binary at compile time.
const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");
// Parsed once lazily; a malformed asset is a packaging bug surfaced as a
// panic on first access.
static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> =
std::sync::LazyLock::new(|| {
serde_json::from_str(BUILTIN_CATALOG_JSON)
.expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
});
/// Returns a fresh owned copy of the built-in catalog entries.
fn builtin_catalog() -> Vec<ModelSchema> {
BUILTIN_CATALOG.clone()
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Builds a registry rooted in a fresh temp dir; the TempDir must be kept
// alive for the registry's lifetime, hence the tuple return.
fn test_registry() -> (UnifiedRegistry, TempDir) {
let tmp = TempDir::new().unwrap();
let reg = UnifiedRegistry::new(tmp.path().join("models"));
(reg, tmp)
}
// A fresh registry holds exactly the built-in catalog entries.
#[test]
fn builtin_catalog_loads() {
let (reg, _tmp) = test_registry();
let all = reg.list();
assert_eq!(all.len(), builtin_catalog().len());
}
// `requires-mlx-vlm` models must mirror the CLI probe on every platform.
#[test]
fn mlx_vlm_models_reflect_runtime_availability() {
let (reg, _tmp) = test_registry();
let mlx_vlm_models: Vec<&ModelSchema> = reg
.list()
.into_iter()
.filter(|m| m.tags.iter().any(|t| t == "requires-mlx-vlm"))
.collect();
assert!(
!mlx_vlm_models.is_empty(),
"catalog should contain at least one model tagged \
`requires-mlx-vlm` — otherwise this regression has \
nothing to guard"
);
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let expected = crate::backend::mlx_vlm_cli::is_available();
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
let expected = false;
for m in mlx_vlm_models {
assert_eq!(
m.available, expected,
"model {} `available` field should reflect \
mlx_vlm CLI presence (expected {expected}, got {})",
m.id, m.available
);
}
}
// The embedded JSON must parse, be non-empty, and contain unique ids.
#[test]
fn builtin_catalog_json_parses() {
let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
.expect("builtin_catalog.json must be valid ModelSchema array");
assert!(
!catalog.is_empty(),
"embedded catalog has no entries — that's almost certainly wrong"
);
let mut seen = std::collections::HashSet::new();
for entry in &catalog {
assert!(
seen.insert(entry.id.clone()),
"duplicate id in builtin_catalog.json: {}",
entry.id
);
}
}
// Benchmark scores must survive schema -> ModelInfo -> JSON -> ModelInfo.
#[test]
fn public_benchmarks_round_trip_through_model_info() {
use crate::schema::BenchmarkScore;
let (mut reg, _tmp) = test_registry();
let mut schema = reg
.find_by_name("Qwen3-4B")
.expect("catalog has Qwen3-4B")
.clone();
schema.id = "test/qwen3-4b-with-bench".into();
schema.public_benchmarks = vec![
BenchmarkScore {
name: "MMLU-Pro".into(),
score: 0.482,
harness: Some("5-shot CoT".into()),
source_url: Some("https://example.invalid/qwen3-4b-card".into()),
measured_at: Some("2025-08-12".into()),
},
BenchmarkScore {
name: "HumanEval".into(),
score: 0.713,
harness: Some("pass@1".into()),
source_url: None,
measured_at: None,
},
];
reg.register(schema);
let stored = reg
.get("test/qwen3-4b-with-bench")
.expect("registered model is retrievable");
let info = ModelInfo::from(stored);
assert_eq!(info.public_benchmarks.len(), 2);
let json = serde_json::to_string(&info).unwrap();
assert!(json.contains("\"public_benchmarks\""));
assert!(json.contains("\"MMLU-Pro\""));
assert!(json.contains("\"5-shot CoT\""));
let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.public_benchmarks.len(), 2);
assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
}
// Older schema JSON without the field must deserialize to an empty vec
// (guards the #[serde(default)] on public_benchmarks).
#[test]
fn public_benchmarks_default_to_empty_when_absent_in_json() {
let legacy_json = r#"{
"id": "legacy/test:1",
"name": "Legacy Test",
"provider": "test",
"family": "test",
"version": "",
"capabilities": ["generate"],
"context_length": 4096,
"param_count": "1B",
"quantization": null,
"performance": {},
"cost": {},
"source": { "type": "ollama", "model_tag": "legacy:1" },
"tags": [],
"supported_params": []
}"#;
let schema: ModelSchema = serde_json::from_str(legacy_json).unwrap();
assert!(schema.public_benchmarks.is_empty());
}
// Name lookup resolves to the MLX variant on Apple Silicon, GGUF elsewhere.
#[test]
fn find_by_name() {
let (reg, _tmp) = test_registry();
let m = reg.find_by_name("Qwen3-4B").unwrap();
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
assert_eq!(m.id, "mlx/qwen3-4b:4bit");
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
assert!(m.has_capability(ModelCapability::Code));
}
// The catalog carries exactly the expected pair of embedding models.
#[test]
fn query_by_capability() {
let (reg, _tmp) = test_registry();
let embed_models = reg.query_by_capability(ModelCapability::Embed);
assert_eq!(embed_models.len(), 2);
assert!(embed_models
.iter()
.any(|model| model.name == "Qwen3-Embedding-0.6B"));
assert!(embed_models
.iter()
.any(|model| model.name == "Qwen3-Embedding-0.6B-MLX"));
}
// Combined capability + size + locality filter pins a known catalog count.
#[test]
fn query_with_filter() {
let (reg, _tmp) = test_registry();
let code_small = reg.query(&ModelFilter {
capabilities: vec![ModelCapability::Code],
max_size_mb: Some(3000),
local_only: true,
..Default::default()
});
assert_eq!(code_small.len(), 4);
}
// Registering a remote model with an id already in the catalog replaces
// it in place: total count and query counts stay unchanged.
#[test]
fn register_remote() {
let (mut reg, _tmp) = test_registry();
let initial_len = reg.list().len();
let initial_reasoning_len = reg
.query(&ModelFilter {
capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
..Default::default()
})
.len();
let remote = ModelSchema {
id: "anthropic/claude-sonnet-4-6:latest".into(),
name: "Claude Sonnet 4.6".into(),
provider: "anthropic".into(),
family: "claude-4".into(),
version: "latest".into(),
capabilities: vec![
ModelCapability::Generate,
ModelCapability::Code,
ModelCapability::Reasoning,
ModelCapability::ToolUse,
],
context_length: 200000,
param_count: String::new(),
quantization: None,
performance: PerformanceEnvelope {
latency_p50_ms: Some(2000),
..Default::default()
},
cost: CostModel {
input_per_mtok: Some(3.0),
output_per_mtok: Some(15.0),
..Default::default()
},
source: ModelSource::RemoteApi {
endpoint: "https://api.anthropic.com/v1/messages".into(),
api_key_env: "ANTHROPIC_API_KEY".into(),
api_key_envs: vec![],
api_version: Some("2023-06-01".into()),
protocol: ApiProtocol::Anthropic,
},
tags: vec![],
supported_params: vec![],
public_benchmarks: vec![],
available: false,
};
reg.register(remote);
assert_eq!(reg.list().len(), initial_len);
let reasoning = reg.query(&ModelFilter {
capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
..Default::default()
});
assert_eq!(reasoning.len(), initial_reasoning_len);
}
// Unregistering an existing model returns it and shrinks the listing.
#[test]
fn unregister() {
let (mut reg, _tmp) = test_registry();
let initial_len = reg.list().len();
let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
assert!(removed.is_some());
assert_eq!(reg.list().len(), initial_len - 1);
}
// Pins the curated counts of speech models in the catalog.
#[test]
fn speech_models_are_curated() {
let (reg, _tmp) = test_registry();
let stt = reg.query_by_capability(ModelCapability::SpeechToText);
let tts = reg.query_by_capability(ModelCapability::TextToSpeech);
assert_eq!(stt.len(), 2);
assert_eq!(tts.len(), 4);
}
// Both the GGUF and MLX 8B variants must advertise tool-use capabilities.
#[test]
fn qwen_8b_variants_keep_tool_use_consistent() {
let (reg, _tmp) = test_registry();
for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
let model = reg.find_by_name(name).expect("model should exist");
assert!(model.has_capability(ModelCapability::ToolUse));
assert!(model.has_capability(ModelCapability::MultiToolCall));
}
}
// On Apple Silicon, plain names must resolve to their MLX siblings.
#[test]
fn mac_name_resolution_prefers_mlx_siblings() {
#[allow(unused_variables)]
let (reg, _tmp) = test_registry();
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
assert_eq!(reg.find_by_name("Qwen3-0.6B").unwrap().id, "mlx/qwen3-0.6b:6bit");
assert_eq!(reg.find_by_name("Qwen3-1.7B").unwrap().id, "mlx/qwen3-1.7b:3bit");
assert_eq!(
reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
"mlx/qwen3-embedding-0.6b:mxfp8"
);
}
}
// Every curated remote multimodal model must be marked Vision-capable.
#[test]
fn remote_multimodal_models_are_curated_as_vision_capable() {
let (reg, _tmp) = test_registry();
for name in [
"claude-opus-4-7",
"claude-opus-4-6",
"claude-sonnet-4-6",
"claude-haiku-4-5",
"gpt-5.4",
"gpt-5.4-mini",
"o3",
"o4-mini",
"gpt-4.1-mini",
"gemini-2.5-pro",
"gemini-2.5-flash",
] {
let model = reg.find_by_name(name).expect("model should exist");
assert!(
model.has_capability(ModelCapability::Vision),
"{name} should be curated as vision-capable"
);
}
}
// Superseded Qwen2.5-VL ids must be absent, and Qwen3-VL must route as Vision.
#[test]
fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
let (reg, _tmp) = test_registry();
let stale_ids = [
"mlx/qwen2.5-vl-3b:4bit",
"mlx/qwen2.5-vl-7b:4bit",
"mlx-vlm/qwen2.5-vl-3b:4bit",
"mlx-vlm/qwen2.5-vl-7b:4bit",
"vllm-mlx/qwen2.5-vl-3b:4bit",
];
for id in stale_ids {
assert!(
reg.get(id).is_none(),
"{id} is superseded by Qwen3-VL; the catalog must not advertise it"
);
}
let vision_ids: Vec<&str> = reg
.query_by_capability(ModelCapability::Vision)
.into_iter()
.map(|model| model.id.as_str())
.collect();
for stale in stale_ids {
assert!(
!vision_ids.contains(&stale),
"{stale} must not be reachable through the Vision capability index"
);
}
assert!(
vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
"Qwen3-VL is the supported local VL family and must route as Vision"
);
}
// Gemini entries must be curated with vision and multi-tool capabilities.
#[test]
fn gemini_models_are_curated_for_multimodal_tool_use() {
let (reg, _tmp) = test_registry();
for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
let model = reg.find_by_name(name).expect("model should exist");
assert!(model.has_capability(ModelCapability::Vision));
assert!(model.has_capability(ModelCapability::ToolUse));
assert!(model.has_capability(ModelCapability::MultiToolCall));
}
}
// Exactly one image-generation and one video-generation model are curated.
#[test]
fn visual_generation_models_are_curated() {
let (reg, _tmp) = test_registry();
assert_eq!(
reg.query_by_capability(ModelCapability::ImageGeneration).len(),
1
);
assert_eq!(
reg.query_by_capability(ModelCapability::VideoGeneration).len(),
1
);
}
}