1use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::schema::*;
17use crate::InferenceError;
18
19#[derive(Debug, Clone, Default)]
21pub struct ModelFilter {
22 pub capabilities: Vec<ModelCapability>,
24 pub max_size_mb: Option<u64>,
26 pub max_latency_ms: Option<u64>,
28 pub max_cost_per_mtok: Option<f64>,
30 pub tags: Vec<String>,
32 pub provider: Option<String>,
34 pub local_only: bool,
36 pub available_only: bool,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ModelUpgrade {
44 pub from_id: String,
45 pub from_name: String,
46 pub to_id: String,
47 pub to_name: String,
48 pub reason: String,
49 pub target_runtime: Option<String>,
50 pub target_runtime_requirement: Option<String>,
51 pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
52 pub target_available: bool,
53 pub target_pullable: bool,
54 pub remove_old_supported: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ModelRuntimeRequirement {
59 pub name: String,
60 pub minimum_version: String,
61}
62
63pub struct UnifiedRegistry {
65 models_dir: PathBuf,
66 models: HashMap<String, ModelSchema>,
68 user_config_path: PathBuf,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct ModelUpgradeRule {
74 from_ids: Vec<String>,
75 to_id: String,
76 reason: String,
77 target_runtime: Option<String>,
78 target_runtime_requirement: Option<String>,
79 #[serde(default)]
80 minimum_runtimes: Vec<ModelRuntimeRequirement>,
81 #[serde(default = "default_remove_old_after_available")]
82 remove_old_after_available: bool,
83}
84
/// Serde default for `ModelUpgradeRule::remove_old_after_available`: rules that
/// omit the key allow removing the old model.
fn default_remove_old_after_available() -> bool {
    const REMOVE_BY_DEFAULT: bool = true;
    REMOVE_BY_DEFAULT
}
88
89fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
90 serde_json::from_str(include_str!("../assets/model-upgrades.json"))
91 .expect("built-in model-upgrades.json should parse")
92}
93
94impl UnifiedRegistry {
95 pub fn new(models_dir: PathBuf) -> Self {
96 let user_config_path = models_dir
97 .parent()
98 .unwrap_or(&models_dir)
99 .join("models.json");
100
101 let mut registry = Self {
102 models_dir,
103 models: HashMap::new(),
104 user_config_path,
105 };
106 registry.load_builtin_catalog();
107 registry.refresh_availability();
108 let _ = registry.load_user_config();
110 registry
111 }
112
113 pub fn register(&mut self, mut schema: ModelSchema) {
115 if schema.is_mlx() {
117 schema.available = if schema.tags.contains(&"speech".to_string()) {
118 speech_mlx_available()
119 } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
120 let mlx_dir = self.models_dir.join(&schema.name);
121 mlx_dir.join("config.json").exists()
122 || latest_huggingface_repo_snapshot(hf_repo).is_some()
123 } else {
124 let mlx_dir = self.models_dir.join(&schema.name);
125 mlx_dir.join("config.json").exists()
126 };
127 } else if schema.is_vllm_mlx() {
128 schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
130 } else if schema.is_local() {
131 let local_path = self.models_dir.join(&schema.name).join("model.gguf");
132 schema.available = local_path.exists();
133 } else if schema.is_remote() {
134 if let ModelSource::RemoteApi {
136 ref api_key_env, ..
137 } = schema.source
138 {
139 schema.available = std::env::var(api_key_env).is_ok();
140 }
141 }
142 info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
143 self.models.insert(schema.id.clone(), schema);
144 }
145
146 pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
148 let removed = self.models.remove(id);
149 if let Some(ref m) = removed {
150 info!(id = %m.id, "unregistered model");
151 }
152 removed
153 }
154
155 pub fn list(&self) -> Vec<&ModelSchema> {
157 let mut models: Vec<&ModelSchema> = self.models.values().collect();
158 models.sort_by(|a, b| a.id.cmp(&b.id));
159 models
160 }
161
162 pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
164 self.models
165 .values()
166 .filter(|m| {
167 if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
169 return false;
170 }
171 if let Some(max) = filter.max_size_mb {
173 if m.size_mb() > max && m.is_local() {
174 return false;
175 }
176 }
177 if let Some(max) = filter.max_latency_ms {
179 if let Some(p50) = m.performance.latency_p50_ms {
180 if p50 > max {
181 return false;
182 }
183 }
184 }
185 if let Some(max) = filter.max_cost_per_mtok {
187 if let Some(cost) = m.cost.output_per_mtok {
188 if cost > max {
189 return false;
190 }
191 }
192 }
193 if !filter.tags.iter().all(|t| m.tags.contains(t)) {
195 return false;
196 }
197 if let Some(ref p) = filter.provider {
199 if &m.provider != p {
200 return false;
201 }
202 }
203 if filter.local_only && !m.is_local() {
205 return false;
206 }
207 if filter.available_only && !m.available {
209 return false;
210 }
211 true
212 })
213 .collect()
214 }
215
216 pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
218 self.query(&ModelFilter {
219 capabilities: vec![cap],
220 ..Default::default()
221 })
222 }
223
224 pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
226 let mut upgrades = Vec::new();
227 for rule in model_upgrade_rules() {
228 let Some(from) = rule
229 .from_ids
230 .iter()
231 .find_map(|id| self.models.get(id.as_str()))
232 .filter(|schema| schema.available)
233 else {
234 continue;
235 };
236 let Some(to) = self.models.get(rule.to_id.as_str()) else {
237 continue;
238 };
239 upgrades.push(ModelUpgrade {
240 from_id: from.id.clone(),
241 from_name: from.name.clone(),
242 to_id: to.id.clone(),
243 to_name: to.name.clone(),
244 reason: rule.reason.clone(),
245 target_runtime: rule.target_runtime.clone(),
246 target_runtime_requirement: rule.target_runtime_requirement.clone(),
247 minimum_runtimes: rule.minimum_runtimes.clone(),
248 target_available: to.available,
249 target_pullable: matches!(
250 to.source,
251 ModelSource::Local { .. } | ModelSource::Mlx { .. }
252 ),
253 remove_old_supported: matches!(
254 from.source,
255 ModelSource::Local { .. } | ModelSource::Mlx { .. }
256 ) && rule.remove_old_after_available,
257 });
258 }
259 upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
260 upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
261 upgrades
262 }
263
264 pub fn get(&self, id: &str) -> Option<&ModelSchema> {
266 self.models.get(id)
267 }
268
269 pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
272 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
273 if !name.to_ascii_lowercase().ends_with("-mlx") {
274 if let Some(mlx_variant) = self
275 .models
276 .values()
277 .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
278 {
279 return Some(mlx_variant);
280 }
281 }
282
283 self.models
284 .values()
285 .find(|m| m.name.eq_ignore_ascii_case(name))
286 }
287
288 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
292 pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
293 if schema.is_mlx() || schema.is_vllm_mlx() {
295 return None;
296 }
297 if !matches!(schema.source, ModelSource::Local { .. }) {
299 return None;
300 }
301 let primary_cap = schema.capabilities.first()?;
303 self.models.values().find(|m| {
304 m.is_mlx()
305 && m.family == schema.family
306 && m.capabilities.contains(primary_cap)
307 })
308 }
309
310 pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
312 let schema = self
313 .get(id)
314 .or_else(|| self.find_by_name(id))
315 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
316
317 match &schema.source {
318 ModelSource::Local {
319 hf_repo,
320 hf_filename,
321 tokenizer_repo,
322 } => {
323 let model_dir = self.models_dir.join(&schema.name);
324 let model_path = model_dir.join("model.gguf");
325 let tokenizer_path = model_dir.join("tokenizer.json");
326
327 if model_path.exists() && tokenizer_path.exists() {
328 return Ok(model_dir);
329 }
330
331 std::fs::create_dir_all(&model_dir)?;
332
333 if !model_path.exists() {
334 info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
335 download_file(hf_repo, hf_filename, &model_path).await?;
336 }
337 if !tokenizer_path.exists() {
338 info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
339 download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
340 }
341
342 Ok(model_dir)
343 }
344 ModelSource::Mlx {
345 hf_repo,
346 hf_weight_file,
347 } => {
348 let model_dir = self.models_dir.join(&schema.name);
349 let config_path = model_dir.join("config.json");
350
351 if config_path.exists() {
352 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
353 info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
354 return Ok(model_dir);
355 }
356
357 if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
358 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
359 info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
360 return Ok(snapshot_dir);
361 }
362
363 if requires_full_mlx_snapshot(&schema) {
364 info!(
365 model = %schema.name,
366 repo = %hf_repo,
367 "downloading full MLX snapshot"
368 );
369 let (snapshot_dir, _files_downloaded) = download_hf_repo_snapshot(hf_repo).await?;
370 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
371 return Ok(snapshot_dir);
372 }
373
374 std::fs::create_dir_all(&model_dir)?;
375
376 info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
377
378 download_file(hf_repo, "config.json", &config_path).await?;
380 let tok_path = model_dir.join("tokenizer.json");
381 if !tok_path.exists() {
382 download_file(hf_repo, "tokenizer.json", &tok_path).await?;
383 }
384 let tok_config_path = model_dir.join("tokenizer_config.json");
385 if !tok_config_path.exists() {
386 let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
387 }
388
389 if let Some(ref wf) = hf_weight_file {
391 let wf_path = model_dir.join(wf);
392 if !wf_path.exists() {
393 download_file(hf_repo, wf, &wf_path).await?;
394 }
395 } else {
396 let single = model_dir.join("model.safetensors");
398 if !single.exists() {
399 match download_file(hf_repo, "model.safetensors", &single).await {
400 Ok(()) => {}
401 Err(_) => {
402 let index_path = model_dir.join("model.safetensors.index.json");
404 download_file(hf_repo, "model.safetensors.index.json", &index_path)
405 .await?;
406
407 let index_json: serde_json::Value =
408 serde_json::from_str(&std::fs::read_to_string(&index_path)?)
409 .map_err(|e| {
410 InferenceError::InferenceFailed(format!(
411 "parse index: {e}"
412 ))
413 })?;
414
415 if let Some(weight_map) =
416 index_json.get("weight_map").and_then(|m| m.as_object())
417 {
418 let mut files: std::collections::HashSet<String> =
419 std::collections::HashSet::new();
420 for filename in weight_map.values() {
421 if let Some(f) = filename.as_str() {
422 files.insert(f.to_string());
423 }
424 }
425 for file in &files {
426 let dest = model_dir.join(file);
427 if !dest.exists() {
428 info!(file = %file, "downloading weight shard");
429 download_file(hf_repo, file, &dest).await?;
430 }
431 }
432 }
433 }
434 }
435 }
436 }
437
438 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
439 Ok(model_dir)
440 }
441 _ => Err(InferenceError::InferenceFailed(format!(
442 "model {} is not local",
443 id
444 ))),
445 }
446 }
447
448 pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
450 let schema = self
451 .get(id)
452 .or_else(|| self.find_by_name(id))
453 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
454
455 let model_dir = self.models_dir.join(&schema.name);
456 if model_dir.exists() {
457 std::fs::remove_dir_all(&model_dir)?;
458 info!(model = %schema.name, "removed model");
459 }
460
461 match &schema.source {
462 ModelSource::Mlx { hf_repo, .. } => {
463 let repo_dir = huggingface_repo_dir(hf_repo);
464 if repo_dir.exists() {
465 std::fs::remove_dir_all(&repo_dir)?;
466 info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
467 }
468 }
469 ModelSource::Local {
470 hf_repo,
471 tokenizer_repo,
472 ..
473 } => {
474 for repo in [hf_repo, tokenizer_repo] {
475 let repo_dir = huggingface_repo_dir(repo);
476 if repo_dir.exists() {
477 std::fs::remove_dir_all(&repo_dir)?;
478 info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
479 }
480 }
481 }
482 _ => {}
483 }
484
485 let id = schema.id.clone();
487 if let Some(m) = self.models.get_mut(&id) {
488 m.available = false;
489 }
490 Ok(())
491 }
492
493 pub fn refresh_availability(&mut self) {
501 let models_dir = self.models_dir.clone();
502 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
505 let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
506 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
507 let mlx_vlm_cli_present = false;
508
509 for m in self.models.values_mut() {
510 match &m.source {
511 ModelSource::Mlx { hf_repo, .. } => {
512 let needs_mlx_vlm =
519 m.tags.iter().any(|t| t == "requires-mlx-vlm");
520
521 m.available = if needs_mlx_vlm {
522 mlx_vlm_cli_present
523 } else if m.tags.contains(&"speech".to_string()) {
524 speech_mlx_available()
525 } else {
526 let mlx_dir = models_dir.join(&m.name);
527 mlx_dir.join("config.json").exists()
528 || latest_huggingface_repo_snapshot(hf_repo).is_some()
529 };
530 }
531 ModelSource::Local { .. } => {
532 let local_path = models_dir.join(&m.name).join("model.gguf");
533 m.available = local_path.exists();
534 }
535 ModelSource::RemoteApi { api_key_env, .. } => {
536 m.available = std::env::var(api_key_env).is_ok();
537 }
538 ModelSource::Ollama { .. } => {
539 m.available = true;
541 }
542 ModelSource::VllmMlx { .. } => {
543 m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
546 }
548 ModelSource::Proprietary { auth, .. } => {
549 m.available = match auth {
551 crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
552 std::env::var(env_var).is_ok()
553 }
554 crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
555 std::env::var(env_var).is_ok()
556 }
557 crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
558 true
560 }
561 };
562 }
563 ModelSource::AppleFoundationModels { .. } => {
564 #[cfg(any(
571 all(target_os = "macos", target_arch = "aarch64"),
572 all(target_os = "ios", target_arch = "aarch64")
573 ))]
574 {
575 m.available = crate::backend::foundation_models::is_available();
576 }
577 #[cfg(not(any(
578 all(target_os = "macos", target_arch = "aarch64"),
579 all(target_os = "ios", target_arch = "aarch64")
580 )))]
581 {
582 m.available = false;
583 }
584 }
585 ModelSource::Delegated { .. } => {
586 m.available = crate::runner::current_inference_runner().is_some();
591 }
592 }
593 }
594 }
595
596 pub fn save_user_config(&self) -> Result<(), InferenceError> {
598 let user_models: Vec<&ModelSchema> = self
599 .models
600 .values()
601 .filter(|m| !m.tags.contains(&"builtin".to_string()))
602 .collect();
603
604 if user_models.is_empty() {
605 return Ok(());
606 }
607
608 let json = serde_json::to_string_pretty(&user_models)
609 .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
610 std::fs::write(&self.user_config_path, json)?;
611 Ok(())
612 }
613
614 pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
616 if !self.user_config_path.exists() {
617 return Ok(());
618 }
619
620 let json = std::fs::read_to_string(&self.user_config_path)?;
621 let models: Vec<ModelSchema> = serde_json::from_str(&json)
622 .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
623
624 for m in models {
625 self.register(m);
626 }
627 Ok(())
628 }
629
630 pub fn models_dir(&self) -> &Path {
632 &self.models_dir
633 }
634
635 fn load_builtin_catalog(&mut self) {
637 for schema in builtin_catalog() {
638 self.models.insert(schema.id.clone(), schema);
639 }
640 }
641}
642
643fn speech_mlx_available() -> bool {
644 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
647 { true }
648
649 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
651 {
652 let runtime_root = speech_runtime_root();
653 runtime_root.join("bin").join("mlx_audio.stt.generate").exists()
654 || runtime_root.join("bin").join("mlx_audio.tts.generate").exists()
655 }
656}
657
/// Resolves the directory that holds the speech runtime.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` wins. Otherwise the result is
/// `$HOME/.car/speech-runtime`, with `.` standing in for an unset `HOME`.
fn speech_runtime_root() -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(dir) if !dir.trim().is_empty() => PathBuf::from(dir),
        _ => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
670
671#[derive(Debug, Clone, Serialize, Deserialize)]
673pub struct ModelInfo {
674 pub id: String,
675 pub name: String,
676 pub provider: String,
677 pub capabilities: Vec<ModelCapability>,
678 pub param_count: String,
679 pub size_mb: u64,
680 pub context_length: usize,
681 pub available: bool,
682 pub is_local: bool,
683 #[serde(default)]
687 pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
688}
689
690impl From<&ModelSchema> for ModelInfo {
691 fn from(s: &ModelSchema) -> Self {
692 ModelInfo {
693 id: s.id.clone(),
694 name: s.name.clone(),
695 provider: s.provider.clone(),
696 capabilities: s.capabilities.clone(),
697 param_count: s.param_count.clone(),
698 size_mb: s.size_mb(),
699 context_length: s.context_length,
700 available: s.available,
701 is_local: s.is_local(),
702 public_benchmarks: s.public_benchmarks.clone(),
703 }
704 }
705}
706
707async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
709 let api = hf_hub::api::tokio::Api::new()
710 .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
711
712 let repo = api.model(repo.to_string());
713 let path = repo
714 .get(filename)
715 .await
716 .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
717
718 if dest.exists() {
719 return Ok(());
720 }
721
722 #[cfg(unix)]
724 {
725 if std::os::unix::fs::symlink(&path, dest).is_ok() {
726 return Ok(());
727 }
728 }
729
730 std::fs::copy(&path, dest)
731 .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
732 Ok(())
733}
734
735async fn ensure_auxiliary_mlx_files(
736 model_name: &str,
737 hf_repo: &str,
738 model_dir: &Path,
739) -> Result<(), InferenceError> {
740 if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
741 let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
742 if !t5_tokenizer_path.exists() {
743 std::fs::create_dir_all(
744 t5_tokenizer_path
745 .parent()
746 .ok_or_else(|| InferenceError::InferenceFailed("invalid tokenizer path".into()))?,
747 )?;
748 info!(
749 path = %t5_tokenizer_path.display(),
750 "downloading missing Flux tokenizer_2/tokenizer.json from base model"
751 );
752 download_file("Freepik/flux.1-lite-8B", "tokenizer_2/tokenizer.json", &t5_tokenizer_path)
753 .await?;
754 }
755 }
756 Ok(())
757}
758
759fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
760 match &schema.source {
761 ModelSource::Mlx { hf_repo, .. } => {
762 hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
763 || schema.family.starts_with("yume")
764 || schema.tags.iter().any(|tag| {
765 matches!(
766 tag.as_str(),
767 "wan2.2" | "ti2v" | "world-model" | "image-to-video"
768 )
769 })
770 }
771 _ => false,
772 }
773}
774
775fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
776 latest_huggingface_repo_snapshot(repo_id).is_some()
777}
778
/// Locates the `hub` directory of the Hugging Face cache.
///
/// Uses `$HF_HOME/hub` when `HF_HOME` is set, otherwise
/// `$HOME/.cache/huggingface/hub`, with `.` standing in for an unset `HOME`.
fn huggingface_cache_root() -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };
    hf_home.join("hub")
}
791
792fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
793 huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
794}
795
796fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
797 let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
798 .ok()?
799 .trim()
800 .to_string();
801 if sha.is_empty() {
802 return None;
803 }
804
805 let snapshot = repo_dir.join("snapshots").join(sha);
806 if snapshot_looks_ready(&snapshot) {
807 Some(snapshot)
808 } else {
809 None
810 }
811}
812
813fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
814 let repo_dir = huggingface_repo_dir(repo_id);
815 if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
816 return Some(snapshot);
817 }
818
819 let snapshots = repo_dir.join("snapshots");
820 let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
821 .ok()?
822 .filter_map(Result::ok)
823 .map(|e| e.path())
824 .filter(|p| p.is_dir() && snapshot_looks_ready(p))
825 .map(|path| {
826 let modified = path
827 .metadata()
828 .and_then(|metadata| metadata.modified())
829 .unwrap_or(SystemTime::UNIX_EPOCH);
830 (modified, path)
831 })
832 .collect();
833 candidates.sort();
834 candidates.pop().map(|(_, path)| path)
835}
836
837fn snapshot_looks_ready(path: &Path) -> bool {
838 if path.join("config.json").exists() || path.join("model_index.json").exists() {
839 return true;
840 }
841 snapshot_contains_ext(path, "safetensors")
842}
843
/// Recursively searches `root` for a file whose extension equals `ext`,
/// ignoring ASCII case.
///
/// Directories that cannot be read, and entries that fail to enumerate, are
/// treated as containing no match.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.filter_map(Result::ok) {
        let candidate = entry.path();
        let matched = if candidate.is_dir() {
            snapshot_contains_ext(&candidate, ext)
        } else {
            match candidate.extension().and_then(|value| value.to_str()) {
                Some(found) => found.eq_ignore_ascii_case(ext),
                None => false,
            }
        };
        if matched {
            return true;
        }
    }
    false
}
860
861async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
862 let api = hf_hub::api::tokio::ApiBuilder::from_env()
863 .with_progress(false)
864 .build()
865 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
866 let repo = api.model(repo_id.to_string());
867 let info = repo
868 .info()
869 .await
870 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
871
872 let snapshot_path = std::env::var("HF_HOME")
873 .map(PathBuf::from)
874 .unwrap_or_else(|_| {
875 std::env::var("HOME")
876 .map(PathBuf::from)
877 .unwrap_or_else(|_| PathBuf::from("."))
878 .join(".cache")
879 .join("huggingface")
880 })
881 .join("hub")
882 .join(format!("models--{}", repo_id.replace('/', "--")))
883 .join("snapshots")
884 .join(&info.sha);
885 let mut downloaded = 0usize;
886 for sibling in &info.siblings {
887 let local_path = snapshot_path.join(&sibling.rfilename);
888 if local_path.exists() {
889 downloaded += 1;
890 continue;
891 }
892 repo.download(&sibling.rfilename).await.map_err(|e| {
893 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
894 })?;
895 downloaded += 1;
896 }
897
898 Ok((snapshot_path, downloaded))
899}
900
901const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");
910
911static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> =
912 std::sync::LazyLock::new(|| {
913 serde_json::from_str(BUILTIN_CATALOG_JSON)
914 .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
915 });
916
917fn builtin_catalog() -> Vec<ModelSchema> {
918 BUILTIN_CATALOG.clone()
919}
920
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a registry rooted in a fresh temporary directory. The `TempDir`
    /// is returned so it stays alive for the duration of the test.
    fn test_registry() -> (UnifiedRegistry, TempDir) {
        let scratch = TempDir::new().unwrap();
        let models_dir = scratch.path().join("models");
        (UnifiedRegistry::new(models_dir), scratch)
    }

    /// A fresh registry holds exactly the built-in catalog.
    #[test]
    fn builtin_catalog_loads() {
        let (reg, _tmp) = test_registry();
        assert_eq!(reg.list().len(), builtin_catalog().len());
    }

    /// Models tagged `requires-mlx-vlm` mirror the presence of the mlx_vlm CLI.
    #[test]
    fn mlx_vlm_models_reflect_runtime_availability() {
        let (reg, _tmp) = test_registry();
        let tagged: Vec<&ModelSchema> = reg
            .list()
            .into_iter()
            .filter(|m| m.tags.iter().any(|t| t == "requires-mlx-vlm"))
            .collect();
        assert!(
            !tagged.is_empty(),
            "catalog should contain at least one model tagged \
             `requires-mlx-vlm` — otherwise this regression has \
             nothing to guard"
        );

        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        let expected = crate::backend::mlx_vlm_cli::is_available();
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        let expected = false;

        for m in tagged {
            assert_eq!(
                m.available, expected,
                "model {} `available` field should reflect \
                 mlx_vlm CLI presence (expected {expected}, got {})",
                m.id, m.available
            );
        }
    }

    /// The embedded catalog parses, is non-empty, and has unique ids.
    #[test]
    fn builtin_catalog_json_parses() {
        let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
            .expect("builtin_catalog.json must be valid ModelSchema array");
        assert!(
            !catalog.is_empty(),
            "embedded catalog has no entries — that's almost certainly wrong"
        );

        let mut seen_ids = std::collections::HashSet::new();
        for entry in &catalog {
            let first_time = seen_ids.insert(entry.id.clone());
            assert!(
                first_time,
                "duplicate id in builtin_catalog.json: {}",
                entry.id
            );
        }
    }

    /// Benchmarks survive schema -> `ModelInfo` -> JSON -> `ModelInfo`.
    #[test]
    fn public_benchmarks_round_trip_through_model_info() {
        use crate::schema::BenchmarkScore;

        let (mut reg, _tmp) = test_registry();
        let mut schema = reg
            .find_by_name("Qwen3-4B")
            .expect("catalog has Qwen3-4B")
            .clone();
        schema.id = "test/qwen3-4b-with-bench".into();
        schema.public_benchmarks = vec![
            BenchmarkScore {
                name: "MMLU-Pro".into(),
                score: 0.482,
                harness: Some("5-shot CoT".into()),
                source_url: Some("https://example.invalid/qwen3-4b-card".into()),
                measured_at: Some("2025-08-12".into()),
            },
            BenchmarkScore {
                name: "HumanEval".into(),
                score: 0.713,
                harness: Some("pass@1".into()),
                source_url: None,
                measured_at: None,
            },
        ];
        reg.register(schema);

        let stored = reg
            .get("test/qwen3-4b-with-bench")
            .expect("registered model is retrievable");
        let info = ModelInfo::from(stored);
        assert_eq!(info.public_benchmarks.len(), 2);

        let json = serde_json::to_string(&info).unwrap();
        for needle in ["\"public_benchmarks\"", "\"MMLU-Pro\"", "\"5-shot CoT\""] {
            assert!(json.contains(needle));
        }

        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.public_benchmarks.len(), 2);
        assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
        assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
    }

    /// Schemas written before `public_benchmarks` existed still deserialize.
    #[test]
    fn public_benchmarks_default_to_empty_when_absent_in_json() {
        let legacy_json = r#"{
            "id": "legacy/test:1",
            "name": "Legacy Test",
            "provider": "test",
            "family": "test",
            "version": "",
            "capabilities": ["generate"],
            "context_length": 4096,
            "param_count": "1B",
            "quantization": null,
            "performance": {},
            "cost": {},
            "source": { "type": "ollama", "model_tag": "legacy:1" },
            "tags": [],
            "supported_params": []
        }"#;
        let schema: ModelSchema = serde_json::from_str(legacy_json).unwrap();
        assert!(schema.public_benchmarks.is_empty());
    }

    /// Name lookup resolves to the platform-preferred variant.
    #[test]
    fn find_by_name() {
        let (reg, _tmp) = test_registry();
        let found = reg.find_by_name("Qwen3-4B").unwrap();
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        assert_eq!(found.id, "mlx/qwen3-4b:4bit");
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        assert_eq!(found.id, "qwen/qwen3-4b:q4_k_m");
        assert!(found.has_capability(ModelCapability::Code));
    }

    /// The catalog ships exactly two embedding models.
    #[test]
    fn query_by_capability() {
        let (reg, _tmp) = test_registry();
        let embed_models = reg.query_by_capability(ModelCapability::Embed);
        assert_eq!(embed_models.len(), 2);
        for expected_name in ["Qwen3-Embedding-0.6B", "Qwen3-Embedding-0.6B-MLX"] {
            assert!(embed_models
                .iter()
                .any(|model| model.name == expected_name));
        }
    }

    /// Combined capability / size / locality filtering.
    #[test]
    fn query_with_filter() {
        let (reg, _tmp) = test_registry();
        let filter = ModelFilter {
            capabilities: vec![ModelCapability::Code],
            max_size_mb: Some(3000),
            local_only: true,
            ..Default::default()
        };
        let code_small = reg.query(&filter);
        assert_eq!(code_small.len(), 4);
    }

    /// Registering a remote model leaves the model count and the
    /// reasoning + tool-use query size unchanged.
    #[test]
    fn register_remote() {
        let (mut reg, _tmp) = test_registry();
        let reasoning_filter = ModelFilter {
            capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
            ..Default::default()
        };
        let initial_len = reg.list().len();
        let initial_reasoning_len = reg.query(&reasoning_filter).len();

        let remote = ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_key_envs: vec![],
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec![],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        };

        reg.register(remote);
        assert_eq!(reg.list().len(), initial_len);

        let reasoning = reg.query(&reasoning_filter);
        assert_eq!(reasoning.len(), initial_reasoning_len);
    }

    /// Unregistering a known id removes exactly one model.
    #[test]
    fn unregister() {
        let (mut reg, _tmp) = test_registry();
        let initial_len = reg.list().len();
        let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
        assert!(removed.is_some());
        assert_eq!(reg.list().len(), initial_len - 1);
    }

    /// The catalog ships a fixed number of speech models.
    #[test]
    fn speech_models_are_curated() {
        let (reg, _tmp) = test_registry();
        let stt_count = reg
            .query_by_capability(ModelCapability::SpeechToText)
            .len();
        let tts_count = reg
            .query_by_capability(ModelCapability::TextToSpeech)
            .len();
        assert_eq!(stt_count, 2);
        assert_eq!(tts_count, 4);
    }

    /// Both Qwen3-8B variants advertise the same tool-use capabilities.
    #[test]
    fn qwen_8b_variants_keep_tool_use_consistent() {
        let (reg, _tmp) = test_registry();
        for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(model.has_capability(ModelCapability::ToolUse));
            assert!(model.has_capability(ModelCapability::MultiToolCall));
        }
    }

    /// On Apple Silicon, plain names resolve to their MLX siblings.
    #[test]
    fn mac_name_resolution_prefers_mlx_siblings() {
        // `reg` is only read inside the Apple Silicon block below, so it is
        // unused on every other target.
        #[allow(unused_variables)]
        let (reg, _tmp) = test_registry();
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            assert_eq!(
                reg.find_by_name("Qwen3-0.6B").unwrap().id,
                "mlx/qwen3-0.6b:6bit"
            );
            assert_eq!(
                reg.find_by_name("Qwen3-1.7B").unwrap().id,
                "mlx/qwen3-1.7b:3bit"
            );
            assert_eq!(
                reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
                "mlx/qwen3-embedding-0.6b:mxfp8"
            );
        }
    }

    /// Every listed remote multimodal model carries the Vision capability.
    #[test]
    fn remote_multimodal_models_are_curated_as_vision_capable() {
        let (reg, _tmp) = test_registry();
        let vision_models = [
            "claude-opus-4-7",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gpt-5.4",
            "gpt-5.4-mini",
            "o3",
            "o4-mini",
            "gpt-4.1-mini",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
        ];
        for name in vision_models {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(
                model.has_capability(ModelCapability::Vision),
                "{name} should be curated as vision-capable"
            );
        }
    }

    /// Qwen2.5-VL ids are gone from the catalog and Qwen3-VL routes as Vision.
    #[test]
    fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
        let (reg, _tmp) = test_registry();

        let stale_ids = [
            "mlx/qwen2.5-vl-3b:4bit",
            "mlx/qwen2.5-vl-7b:4bit",
            "mlx-vlm/qwen2.5-vl-3b:4bit",
            "mlx-vlm/qwen2.5-vl-7b:4bit",
            "vllm-mlx/qwen2.5-vl-3b:4bit",
        ];
        for id in stale_ids {
            assert!(
                reg.get(id).is_none(),
                "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
            );
        }

        let vision_ids: Vec<&str> = reg
            .query_by_capability(ModelCapability::Vision)
            .into_iter()
            .map(|model| model.id.as_str())
            .collect();
        for stale in stale_ids {
            assert!(
                !vision_ids.contains(&stale),
                "{stale} must not be reachable through the Vision capability index"
            );
        }
        assert!(
            vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
            "Qwen3-VL is the supported local VL family and must route as Vision"
        );
    }

    /// Gemini models advertise vision plus (multi-)tool use.
    #[test]
    fn gemini_models_are_curated_for_multimodal_tool_use() {
        let (reg, _tmp) = test_registry();
        for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
            let model = reg.find_by_name(name).expect("model should exist");
            for capability in [
                ModelCapability::Vision,
                ModelCapability::ToolUse,
                ModelCapability::MultiToolCall,
            ] {
                assert!(model.has_capability(capability));
            }
        }
    }

    /// Image / video generation models are present with the expected tags.
    #[test]
    fn visual_generation_models_are_curated() {
        let (reg, _tmp) = test_registry();
        assert_eq!(
            reg.query_by_capability(ModelCapability::ImageGeneration)
                .len(),
            1
        );
        assert_eq!(
            reg.query_by_capability(ModelCapability::VideoGeneration)
                .len(),
            2
        );

        let yume = reg
            .get("mlx/yume-1.5-5b-720p:q4")
            .expect("Yume MLX should be in the built-in catalog");
        assert!(yume.has_capability(ModelCapability::VideoGeneration));
        assert!(yume.tags.contains(&"text-to-video".to_string()));
        assert!(yume.tags.contains(&"image-to-video".to_string()));
        assert!(yume.tags.contains(&"world-model".to_string()));
    }
}