1use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::schema::*;
17use crate::InferenceError;
18
/// Constraints for `UnifiedRegistry::query`.
///
/// A model matches only when every constraint holds. The `Default` value (empty lists,
/// `None`, `false`) places no constraint and therefore matches every model.
#[derive(Debug, Clone, Default)]
pub struct ModelFilter {
    /// Capabilities the model must have — all of them.
    pub capabilities: Vec<ModelCapability>,
    /// Upper bound on `size_mb()`. Only enforced for models where `is_local()` is true.
    pub max_size_mb: Option<u64>,
    /// Upper bound on `performance.latency_p50_ms`. Models with no latency figure are
    /// not excluded.
    pub max_latency_ms: Option<u64>,
    /// Upper bound on `cost.output_per_mtok`. Models with no output cost are not
    /// excluded; input cost is not considered.
    pub max_cost_per_mtok: Option<f64>,
    /// Tags the model must carry — all of them.
    pub tags: Vec<String>,
    /// Exact (case-sensitive) provider name the model must have.
    pub provider: Option<String>,
    /// Only match models for which `is_local()` is true.
    pub local_only: bool,
    /// Only match models whose `available` flag is set.
    pub available_only: bool,
}
39
/// A suggested replacement of one registered model by another, as reported by
/// `UnifiedRegistry::available_upgrades`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpgrade {
    /// Id of the model being replaced.
    pub from_id: String,
    /// Display name of the model being replaced.
    pub from_name: String,
    /// Id of the replacement model.
    pub to_id: String,
    /// Display name of the replacement model.
    pub to_name: String,
    /// Human-readable explanation copied from the upgrade rule.
    pub reason: String,
    /// Runtime the replacement targets, copied from the upgrade rule (if any).
    pub target_runtime: Option<String>,
    /// Free-form runtime requirement text copied from the upgrade rule (if any).
    pub target_runtime_requirement: Option<String>,
    /// Minimum runtime versions copied from the upgrade rule.
    pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
    /// Whether the replacement model is currently marked available.
    pub target_available: bool,
    /// Whether the replacement has a downloadable source (`Local` or `Mlx`).
    pub target_pullable: bool,
    /// Whether the old model's files can be removed: its source is `Local` or `Mlx`
    /// and the rule allows removal once the replacement is available.
    pub remove_old_supported: bool,
}
56
/// A runtime and the minimum version of it that an upgrade needs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRuntimeRequirement {
    /// Runtime name.
    pub name: String,
    /// Minimum version, kept as an opaque string (format is defined by the upgrade table —
    /// NOTE(review): confirm whether callers compare it semantically).
    pub minimum_version: String,
}
62
/// In-memory catalog of every known model (built-in plus user-registered), with
/// helpers to query it, download local weights and persist user additions.
pub struct UnifiedRegistry {
    /// Directory holding managed per-model subdirectories (`<models_dir>/<model name>`).
    models_dir: PathBuf,
    /// All known models keyed by `ModelSchema::id`.
    models: HashMap<String, ModelSchema>,
    /// Location of the user catalog (`models.json` next to `models_dir`).
    user_config_path: PathBuf,
}
71
72#[derive(Debug, Clone, Deserialize)]
73struct ModelUpgradeRule {
74 from_ids: Vec<String>,
75 to_id: String,
76 reason: String,
77 target_runtime: Option<String>,
78 target_runtime_requirement: Option<String>,
79 #[serde(default)]
80 minimum_runtimes: Vec<ModelRuntimeRequirement>,
81 #[serde(default = "default_remove_old_after_available")]
82 remove_old_after_available: bool,
83}
84
85fn default_remove_old_after_available() -> bool {
86 true
87}
88
89fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
90 serde_json::from_str(include_str!("../assets/model-upgrades.json"))
91 .expect("built-in model-upgrades.json should parse")
92}
93
94impl UnifiedRegistry {
95 pub fn new(models_dir: PathBuf) -> Self {
96 let user_config_path = models_dir
97 .parent()
98 .unwrap_or(&models_dir)
99 .join("models.json");
100
101 let mut registry = Self {
102 models_dir,
103 models: HashMap::new(),
104 user_config_path,
105 };
106 registry.load_builtin_catalog();
107 registry.refresh_availability();
108 let _ = registry.load_user_config();
110 registry
111 }
112
113 pub fn register(&mut self, mut schema: ModelSchema) {
115 if schema.is_mlx() {
117 schema.available = if schema.tags.contains(&"speech".to_string()) {
118 speech_mlx_available()
119 } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
120 let mlx_dir = self.models_dir.join(&schema.name);
126 mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
127 } else {
128 let mlx_dir = self.models_dir.join(&schema.name);
129 mlx_dir.join("config.json").exists()
130 };
131 } else if schema.is_vllm_mlx() {
132 schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
134 } else if schema.is_local() {
135 let local_path = self.models_dir.join(&schema.name).join("model.gguf");
136 schema.available = local_path.exists();
137 } else if schema.is_remote() {
138 if let ModelSource::RemoteApi {
140 ref api_key_env, ..
141 } = schema.source
142 {
143 schema.available = std::env::var(api_key_env).is_ok();
144 }
145 }
146 info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
147 self.models.insert(schema.id.clone(), schema);
148 }
149
150 pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
152 let removed = self.models.remove(id);
153 if let Some(ref m) = removed {
154 info!(id = %m.id, "unregistered model");
155 }
156 removed
157 }
158
159 pub fn list(&self) -> Vec<&ModelSchema> {
161 let mut models: Vec<&ModelSchema> = self.models.values().collect();
162 models.sort_by(|a, b| a.id.cmp(&b.id));
163 models
164 }
165
166 pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
168 self.models
169 .values()
170 .filter(|m| {
171 if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
173 return false;
174 }
175 if let Some(max) = filter.max_size_mb {
177 if m.size_mb() > max && m.is_local() {
178 return false;
179 }
180 }
181 if let Some(max) = filter.max_latency_ms {
183 if let Some(p50) = m.performance.latency_p50_ms {
184 if p50 > max {
185 return false;
186 }
187 }
188 }
189 if let Some(max) = filter.max_cost_per_mtok {
191 if let Some(cost) = m.cost.output_per_mtok {
192 if cost > max {
193 return false;
194 }
195 }
196 }
197 if !filter.tags.iter().all(|t| m.tags.contains(t)) {
199 return false;
200 }
201 if let Some(ref p) = filter.provider {
203 if &m.provider != p {
204 return false;
205 }
206 }
207 if filter.local_only && !m.is_local() {
209 return false;
210 }
211 if filter.available_only && !m.available {
213 return false;
214 }
215 true
216 })
217 .collect()
218 }
219
220 pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
222 self.query(&ModelFilter {
223 capabilities: vec![cap],
224 ..Default::default()
225 })
226 }
227
228 pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
230 let mut upgrades = Vec::new();
231 for rule in model_upgrade_rules() {
232 let Some(from) = rule
233 .from_ids
234 .iter()
235 .find_map(|id| self.models.get(id.as_str()))
236 .filter(|schema| schema.available)
237 else {
238 continue;
239 };
240 let Some(to) = self.models.get(rule.to_id.as_str()) else {
241 continue;
242 };
243 upgrades.push(ModelUpgrade {
244 from_id: from.id.clone(),
245 from_name: from.name.clone(),
246 to_id: to.id.clone(),
247 to_name: to.name.clone(),
248 reason: rule.reason.clone(),
249 target_runtime: rule.target_runtime.clone(),
250 target_runtime_requirement: rule.target_runtime_requirement.clone(),
251 minimum_runtimes: rule.minimum_runtimes.clone(),
252 target_available: to.available,
253 target_pullable: matches!(
254 to.source,
255 ModelSource::Local { .. } | ModelSource::Mlx { .. }
256 ),
257 remove_old_supported: matches!(
258 from.source,
259 ModelSource::Local { .. } | ModelSource::Mlx { .. }
260 ) && rule.remove_old_after_available,
261 });
262 }
263 upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
264 upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
265 upgrades
266 }
267
268 pub fn get(&self, id: &str) -> Option<&ModelSchema> {
270 self.models.get(id)
271 }
272
273 pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
276 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
277 if !name.to_ascii_lowercase().ends_with("-mlx") {
278 if let Some(mlx_variant) = self
279 .models
280 .values()
281 .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
282 {
283 return Some(mlx_variant);
284 }
285 }
286
287 self.models
288 .values()
289 .find(|m| m.name.eq_ignore_ascii_case(name))
290 }
291
292 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
296 pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
297 if schema.is_mlx() || schema.is_vllm_mlx() {
299 return None;
300 }
301 if !matches!(schema.source, ModelSource::Local { .. }) {
303 return None;
304 }
305 let primary_cap = schema.capabilities.first()?;
307 self.models.values().find(|m| {
308 m.is_mlx() && m.family == schema.family && m.capabilities.contains(primary_cap)
309 })
310 }
311
312 pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
314 let schema = self
315 .get(id)
316 .or_else(|| self.find_by_name(id))
317 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
318
319 match &schema.source {
320 ModelSource::Local {
321 hf_repo,
322 hf_filename,
323 tokenizer_repo,
324 } => {
325 let model_dir = self.models_dir.join(&schema.name);
326 let model_path = model_dir.join("model.gguf");
327 let tokenizer_path = model_dir.join("tokenizer.json");
328
329 if model_path.exists() && tokenizer_path.exists() {
330 return Ok(model_dir);
331 }
332
333 std::fs::create_dir_all(&model_dir)?;
334
335 if !model_path.exists() {
336 info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
337 download_file(hf_repo, hf_filename, &model_path).await?;
338 }
339 if !tokenizer_path.exists() {
340 info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
341 download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
342 }
343
344 Ok(model_dir)
345 }
346 ModelSource::Mlx {
347 hf_repo,
348 hf_weight_file,
349 } => {
350 let model_dir = self.models_dir.join(&schema.name);
351 let config_path = model_dir.join("config.json");
352
353 if config_path.exists() {
354 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
355 info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
356 return Ok(model_dir);
357 }
358
359 if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
360 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
361 info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
362 return Ok(snapshot_dir);
363 }
364
365 if requires_full_mlx_snapshot(&schema) {
366 info!(
367 model = %schema.name,
368 repo = %hf_repo,
369 "downloading full MLX snapshot"
370 );
371 let (snapshot_dir, _files_downloaded) =
372 download_hf_repo_snapshot(hf_repo).await?;
373 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
374 return Ok(snapshot_dir);
375 }
376
377 std::fs::create_dir_all(&model_dir)?;
378
379 info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
380
381 download_file(hf_repo, "config.json", &config_path).await?;
383 let tok_path = model_dir.join("tokenizer.json");
384 if !tok_path.exists() {
385 download_file(hf_repo, "tokenizer.json", &tok_path).await?;
386 }
387 let tok_config_path = model_dir.join("tokenizer_config.json");
388 if !tok_config_path.exists() {
389 let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
390 }
391
392 if let Some(ref wf) = hf_weight_file {
394 let wf_path = model_dir.join(wf);
395 if !wf_path.exists() {
396 download_file(hf_repo, wf, &wf_path).await?;
397 }
398 } else {
399 let single = model_dir.join("model.safetensors");
401 if !single.exists() {
402 match download_file(hf_repo, "model.safetensors", &single).await {
403 Ok(()) => {}
404 Err(_) => {
405 let index_path = model_dir.join("model.safetensors.index.json");
407 download_file(hf_repo, "model.safetensors.index.json", &index_path)
408 .await?;
409
410 let index_json: serde_json::Value =
411 serde_json::from_str(&std::fs::read_to_string(&index_path)?)
412 .map_err(|e| {
413 InferenceError::InferenceFailed(format!(
414 "parse index: {e}"
415 ))
416 })?;
417
418 if let Some(weight_map) =
419 index_json.get("weight_map").and_then(|m| m.as_object())
420 {
421 let mut files: std::collections::HashSet<String> =
422 std::collections::HashSet::new();
423 for filename in weight_map.values() {
424 if let Some(f) = filename.as_str() {
425 files.insert(f.to_string());
426 }
427 }
428 for file in &files {
429 let dest = model_dir.join(file);
430 if !dest.exists() {
431 info!(file = %file, "downloading weight shard");
432 download_file(hf_repo, file, &dest).await?;
433 }
434 }
435 }
436 }
437 }
438 }
439 }
440
441 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
442 Ok(model_dir)
443 }
444 _ => Err(InferenceError::InferenceFailed(format!(
445 "model {} is not local",
446 id
447 ))),
448 }
449 }
450
451 pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
453 let schema = self
454 .get(id)
455 .or_else(|| self.find_by_name(id))
456 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
457
458 let model_dir = self.models_dir.join(&schema.name);
459 if model_dir.exists() {
460 std::fs::remove_dir_all(&model_dir)?;
461 info!(model = %schema.name, "removed model");
462 }
463
464 match &schema.source {
465 ModelSource::Mlx { hf_repo, .. } => {
466 let repo_dir = huggingface_repo_dir(hf_repo);
467 if repo_dir.exists() {
468 std::fs::remove_dir_all(&repo_dir)?;
469 info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
470 }
471 }
472 ModelSource::Local {
473 hf_repo,
474 tokenizer_repo,
475 ..
476 } => {
477 for repo in [hf_repo, tokenizer_repo] {
478 let repo_dir = huggingface_repo_dir(repo);
479 if repo_dir.exists() {
480 std::fs::remove_dir_all(&repo_dir)?;
481 info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
482 }
483 }
484 }
485 _ => {}
486 }
487
488 let id = schema.id.clone();
490 if let Some(m) = self.models.get_mut(&id) {
491 m.available = false;
492 }
493 Ok(())
494 }
495
496 pub fn refresh_availability(&mut self) {
504 let models_dir = self.models_dir.clone();
505 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
508 let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
509 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
510 let mlx_vlm_cli_present = false;
511
512 for m in self.models.values_mut() {
513 match &m.source {
514 ModelSource::Mlx { hf_repo, .. } => {
515 let needs_mlx_vlm = m.tags.iter().any(|t| t == "requires-mlx-vlm");
522
523 m.available = if needs_mlx_vlm {
524 mlx_vlm_cli_present
525 } else if m.tags.contains(&"speech".to_string()) {
526 speech_mlx_available()
527 } else {
528 let mlx_dir = models_dir.join(&m.name);
538 mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
539 };
540 }
541 ModelSource::Local { .. } => {
542 let local_path = models_dir.join(&m.name).join("model.gguf");
543 m.available = local_path.exists();
544 }
545 ModelSource::RemoteApi { api_key_env, .. } => {
546 m.available = std::env::var(api_key_env).is_ok();
547 }
548 ModelSource::Ollama { .. } => {
549 m.available = true;
551 }
552 ModelSource::VllmMlx { .. } => {
553 m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
556 }
558 ModelSource::Proprietary { auth, .. } => {
559 m.available = match auth {
561 crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
562 std::env::var(env_var).is_ok()
563 }
564 crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
565 std::env::var(env_var).is_ok()
566 }
567 crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
568 true
570 }
571 };
572 }
573 ModelSource::AppleFoundationModels { .. } => {
574 #[cfg(any(
581 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
582 all(target_os = "ios", target_arch = "aarch64")
583 ))]
584 {
585 m.available = crate::backend::foundation_models::is_available();
586 }
587 #[cfg(not(any(
588 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
589 all(target_os = "ios", target_arch = "aarch64")
590 )))]
591 {
592 m.available = false;
593 }
594 }
595 ModelSource::Delegated { .. } => {
596 m.available = crate::runner::current_inference_runner().is_some();
601 }
602 }
603 }
604 }
605
606 pub fn save_user_config(&self) -> Result<(), InferenceError> {
608 let user_models: Vec<&ModelSchema> = self
609 .models
610 .values()
611 .filter(|m| !m.tags.contains(&"builtin".to_string()))
612 .collect();
613
614 if user_models.is_empty() {
615 return Ok(());
616 }
617
618 let json = serde_json::to_string_pretty(&user_models)
619 .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
620 std::fs::write(&self.user_config_path, json)?;
621 Ok(())
622 }
623
624 pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
626 if !self.user_config_path.exists() {
627 return Ok(());
628 }
629
630 let json = std::fs::read_to_string(&self.user_config_path)?;
631 let models: Vec<ModelSchema> = serde_json::from_str(&json)
632 .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
633
634 for m in models {
635 self.register(m);
636 }
637 Ok(())
638 }
639
640 pub fn models_dir(&self) -> &Path {
642 &self.models_dir
643 }
644
645 fn load_builtin_catalog(&mut self) {
647 for schema in builtin_catalog() {
648 self.models.insert(schema.id.clone(), schema);
649 }
650 }
651}
652
653fn speech_mlx_available() -> bool {
654 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
657 {
658 true
659 }
660
661 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
663 {
664 let runtime_root = speech_runtime_root();
665 runtime_root
666 .join("bin")
667 .join("mlx_audio.stt.generate")
668 .exists()
669 || runtime_root
670 .join("bin")
671 .join("mlx_audio.tts.generate")
672 .exists()
673 }
674}
675
/// Root directory of the speech runtime.
///
/// Uses `CAR_SPEECH_RUNTIME_DIR` when it is set to something other than whitespace.
/// Otherwise uses `$HOME/.car/speech-runtime`, with `.` standing in for an unset `HOME`.
fn speech_runtime_root() -> PathBuf {
    let override_dir = std::env::var("CAR_SPEECH_RUNTIME_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty());

    match override_dir {
        Some(dir) => PathBuf::from(dir),
        None => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
688
/// Flattened, serializable summary of a `ModelSchema`, built via `From<&ModelSchema>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub id: String,
    pub name: String,
    pub provider: String,
    pub capabilities: Vec<ModelCapability>,
    pub param_count: String,
    /// Value of `ModelSchema::size_mb()` at conversion time.
    pub size_mb: u64,
    pub context_length: usize,
    pub available: bool,
    /// Value of `ModelSchema::is_local()` at conversion time.
    pub is_local: bool,
    /// Published benchmark scores. Deserializes to an empty list when the field is
    /// absent, so older payloads still parse.
    #[serde(default)]
    pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
}
707
708impl From<&ModelSchema> for ModelInfo {
709 fn from(s: &ModelSchema) -> Self {
710 ModelInfo {
711 id: s.id.clone(),
712 name: s.name.clone(),
713 provider: s.provider.clone(),
714 capabilities: s.capabilities.clone(),
715 param_count: s.param_count.clone(),
716 size_mb: s.size_mb(),
717 context_length: s.context_length,
718 available: s.available,
719 is_local: s.is_local(),
720 public_benchmarks: s.public_benchmarks.clone(),
721 }
722 }
723}
724
725async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
727 let api = hf_hub::api::tokio::Api::new()
728 .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
729
730 let repo = api.model(repo.to_string());
731 let path = repo
732 .get(filename)
733 .await
734 .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
735
736 if dest.exists() {
737 return Ok(());
738 }
739
740 #[cfg(unix)]
742 {
743 if std::os::unix::fs::symlink(&path, dest).is_ok() {
744 return Ok(());
745 }
746 }
747
748 std::fs::copy(&path, dest)
749 .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
750 Ok(())
751}
752
753async fn ensure_auxiliary_mlx_files(
754 model_name: &str,
755 hf_repo: &str,
756 model_dir: &Path,
757) -> Result<(), InferenceError> {
758 if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
759 let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
760 if !t5_tokenizer_path.exists() {
761 std::fs::create_dir_all(t5_tokenizer_path.parent().ok_or_else(|| {
762 InferenceError::InferenceFailed("invalid tokenizer path".into())
763 })?)?;
764 info!(
765 path = %t5_tokenizer_path.display(),
766 "downloading missing Flux tokenizer_2/tokenizer.json from base model"
767 );
768 download_file(
769 "Freepik/flux.1-lite-8B",
770 "tokenizer_2/tokenizer.json",
771 &t5_tokenizer_path,
772 )
773 .await?;
774 }
775 }
776 Ok(())
777}
778
779fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
780 match &schema.source {
781 ModelSource::Mlx { hf_repo, .. } => {
782 hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
783 || schema.family.starts_with("yume")
784 || schema.tags.iter().any(|tag| {
785 matches!(
786 tag.as_str(),
787 "wan2.2" | "ti2v" | "world-model" | "image-to-video"
788 )
789 })
790 }
791 _ => false,
792 }
793}
794
795fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
796 latest_huggingface_repo_snapshot(repo_id).is_some()
797}
798
/// The Hugging Face hub cache directory.
///
/// This is `$HF_HOME/hub`, or `$HOME/.cache/huggingface/hub` when `HF_HOME` is
/// unset, with `.` standing in for an unset `HOME`.
fn huggingface_cache_root() -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| PathBuf::from("."));
            home.join(".cache").join("huggingface")
        }
    };
    hf_home.join("hub")
}
811
812fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
813 huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
814}
815
816fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
817 let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
818 .ok()?
819 .trim()
820 .to_string();
821 if sha.is_empty() {
822 return None;
823 }
824
825 let snapshot = repo_dir.join("snapshots").join(sha);
826 if snapshot_looks_ready(&snapshot) {
827 Some(snapshot)
828 } else {
829 None
830 }
831}
832
833fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
834 let repo_dir = huggingface_repo_dir(repo_id);
835 if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
836 return Some(snapshot);
837 }
838
839 let snapshots = repo_dir.join("snapshots");
840 let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
841 .ok()?
842 .filter_map(Result::ok)
843 .map(|e| e.path())
844 .filter(|p| p.is_dir() && snapshot_looks_ready(p))
845 .map(|path| {
846 let modified = path
847 .metadata()
848 .and_then(|metadata| metadata.modified())
849 .unwrap_or(SystemTime::UNIX_EPOCH);
850 (modified, path)
851 })
852 .collect();
853 candidates.sort();
854 candidates.pop().map(|(_, path)| path)
855}
856
857fn snapshot_looks_ready(path: &Path) -> bool {
858 if path.join("config.json").exists() || path.join("model_index.json").exists() {
859 return true;
860 }
861 snapshot_contains_ext(path, "safetensors")
862}
863
/// Whether any file under `root` (searched recursively) has extension `ext`,
/// compared ASCII-case-insensitively.
///
/// Unreadable or missing directories count as containing nothing. Directories are
/// identified with `DirEntry::file_type`, which does not follow symlinks. Symlinked
/// directories are therefore not descended into, so symlink cycles cannot make the
/// walk revisit trees. Symlinked *files* (the usual Hugging Face cache layout) are
/// still matched by name.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let Ok(entries) = std::fs::read_dir(root) else {
        return false;
    };
    entries.filter_map(Result::ok).any(|entry| {
        let path = entry.path();
        match entry.file_type() {
            Ok(file_type) if file_type.is_dir() => snapshot_contains_ext(&path, ext),
            _ => path
                .extension()
                .and_then(|value| value.to_str())
                .is_some_and(|value| value.eq_ignore_ascii_case(ext)),
        }
    })
}
880
881async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
882 let api = hf_hub::api::tokio::ApiBuilder::from_env()
883 .with_progress(false)
884 .build()
885 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
886 let repo = api.model(repo_id.to_string());
887 let info = repo
888 .info()
889 .await
890 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
891
892 let snapshot_path = std::env::var("HF_HOME")
893 .map(PathBuf::from)
894 .unwrap_or_else(|_| {
895 std::env::var("HOME")
896 .map(PathBuf::from)
897 .unwrap_or_else(|_| PathBuf::from("."))
898 .join(".cache")
899 .join("huggingface")
900 })
901 .join("hub")
902 .join(format!("models--{}", repo_id.replace('/', "--")))
903 .join("snapshots")
904 .join(&info.sha);
905 let mut downloaded = 0usize;
906 for sibling in &info.siblings {
907 let local_path = snapshot_path.join(&sibling.rfilename);
908 if local_path.exists() {
909 downloaded += 1;
910 continue;
911 }
912 repo.download(&sibling.rfilename).await.map_err(|e| {
913 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
914 })?;
915 downloaded += 1;
916 }
917
918 Ok((snapshot_path, downloaded))
919}
920
921const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");
930
931static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> = std::sync::LazyLock::new(|| {
932 serde_json::from_str(BUILTIN_CATALOG_JSON)
933 .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
934});
935
936fn builtin_catalog() -> Vec<ModelSchema> {
937 BUILTIN_CATALOG.clone()
938}
939
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a registry over a fresh temporary directory. The `TempDir` is returned so
    /// the caller keeps it alive; dropping it deletes the directory.
    fn test_registry() -> (UnifiedRegistry, TempDir) {
        let tmp = TempDir::new().unwrap();
        let reg = UnifiedRegistry::new(tmp.path().join("models"));
        (reg, tmp)
    }

    /// With no user `models.json` in the temp dir, the registry holds exactly the
    /// built-in catalog.
    #[test]
    fn builtin_catalog_loads() {
        let (reg, _tmp) = test_registry();
        let all = reg.list();
        assert_eq!(all.len(), builtin_catalog().len());
    }

    /// Models tagged `requires-mlx-vlm` must report availability equal to the mlx_vlm
    /// CLI probe (always false off Apple-silicon MLX builds).
    #[test]
    fn mlx_vlm_models_reflect_runtime_availability() {
        let (reg, _tmp) = test_registry();
        let mlx_vlm_models: Vec<&ModelSchema> = reg
            .list()
            .into_iter()
            .filter(|m| m.tags.iter().any(|t| t == "requires-mlx-vlm"))
            .collect();
        assert!(
            !mlx_vlm_models.is_empty(),
            "catalog should contain at least one model tagged \
             `requires-mlx-vlm` — otherwise this regression has \
             nothing to guard"
        );

        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let expected = crate::backend::mlx_vlm_cli::is_available();
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let expected = false;

        for m in mlx_vlm_models {
            assert_eq!(
                m.available, expected,
                "model {} `available` field should reflect \
                 mlx_vlm CLI presence (expected {expected}, got {})",
                m.id, m.available
            );
        }
    }

    /// The embedded catalog parses, is non-empty, and has unique ids. Duplicate ids
    /// would silently overwrite each other in the registry map.
    #[test]
    fn builtin_catalog_json_parses() {
        let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
            .expect("builtin_catalog.json must be valid ModelSchema array");
        assert!(
            !catalog.is_empty(),
            "embedded catalog has no entries — that's almost certainly wrong"
        );

        let mut seen = std::collections::HashSet::new();
        for entry in &catalog {
            assert!(
                seen.insert(entry.id.clone()),
                "duplicate id in builtin_catalog.json: {}",
                entry.id
            );
        }
    }

    /// Benchmarks attached to a registered schema survive the conversion to
    /// `ModelInfo` and a JSON round trip, in order.
    #[test]
    fn public_benchmarks_round_trip_through_model_info() {
        use crate::schema::BenchmarkScore;
        let (mut reg, _tmp) = test_registry();
        let mut schema = reg
            .find_by_name("Qwen3-4B")
            .expect("catalog has Qwen3-4B")
            .clone();
        schema.id = "test/qwen3-4b-with-bench".into();
        schema.public_benchmarks = vec![
            BenchmarkScore {
                name: "MMLU-Pro".into(),
                score: 0.482,
                harness: Some("5-shot CoT".into()),
                source_url: Some("https://example.invalid/qwen3-4b-card".into()),
                measured_at: Some("2025-08-12".into()),
            },
            BenchmarkScore {
                name: "HumanEval".into(),
                score: 0.713,
                harness: Some("pass@1".into()),
                source_url: None,
                measured_at: None,
            },
        ];
        reg.register(schema);

        let stored = reg
            .get("test/qwen3-4b-with-bench")
            .expect("registered model is retrievable");
        let info = ModelInfo::from(stored);
        assert_eq!(info.public_benchmarks.len(), 2);

        let json = serde_json::to_string(&info).unwrap();
        // The serialized form must carry the field name and the benchmark contents.
        assert!(json.contains("\"public_benchmarks\""));
        assert!(json.contains("\"MMLU-Pro\""));
        assert!(json.contains("\"5-shot CoT\""));

        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
        // Order is preserved through the round trip.
        assert_eq!(decoded.public_benchmarks.len(), 2);
        assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
        assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
    }

    /// A schema JSON written before `public_benchmarks` existed still parses, with an
    /// empty benchmark list.
    #[test]
    fn public_benchmarks_default_to_empty_when_absent_in_json() {
        let legacy_json = r#"{
            "id": "legacy/test:1",
            "name": "Legacy Test",
            "provider": "test",
            "family": "test",
            "version": "",
            "capabilities": ["generate"],
            "context_length": 4096,
            "param_count": "1B",
            "quantization": null,
            "performance": {},
            "cost": {},
            "source": { "type": "ollama", "model_tag": "legacy:1" },
            "tags": [],
            "supported_params": []
        }"#;
        let schema: ModelSchema = serde_json::from_str(legacy_json).unwrap();
        assert!(schema.public_benchmarks.is_empty());
    }

    /// Name lookup resolves to the MLX sibling on Apple-silicon MLX builds and to the
    /// GGUF entry elsewhere.
    #[test]
    fn find_by_name() {
        let (reg, _tmp) = test_registry();
        let m = reg.find_by_name("Qwen3-4B").unwrap();
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        assert_eq!(m.id, "mlx/qwen3-4b:4bit");
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
        assert!(m.has_capability(ModelCapability::Code));
    }

    /// The catalog advertises exactly two embedding models: the GGUF and MLX variants
    /// of Qwen3-Embedding-0.6B.
    #[test]
    fn query_by_capability() {
        let (reg, _tmp) = test_registry();
        let embed_models = reg.query_by_capability(ModelCapability::Embed);
        assert_eq!(embed_models.len(), 2);
        assert!(embed_models
            .iter()
            .any(|model| model.name == "Qwen3-Embedding-0.6B"));
        assert!(embed_models
            .iter()
            .any(|model| model.name == "Qwen3-Embedding-0.6B-MLX"));
    }

    /// Combined capability, size and locality constraints; the expected count is tied
    /// to the current catalog contents.
    #[test]
    fn query_with_filter() {
        let (reg, _tmp) = test_registry();
        let code_small = reg.query(&ModelFilter {
            capabilities: vec![ModelCapability::Code],
            max_size_mb: Some(3000),
            local_only: true,
            ..Default::default()
        });
        assert_eq!(code_small.len(), 4);
    }

    /// Registering a schema with an id already in the catalog replaces the entry. The
    /// assertions expect this id to be built in, so neither the total count nor the
    /// Reasoning+ToolUse count changes.
    #[test]
    fn register_remote() {
        let (mut reg, _tmp) = test_registry();
        let initial_len = reg.list().len();
        let initial_reasoning_len = reg
            .query(&ModelFilter {
                capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
                ..Default::default()
            })
            .len();
        let remote = ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_key_envs: vec![],
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec![],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        };

        reg.register(remote);
        assert_eq!(reg.list().len(), initial_len);

        let reasoning = reg.query(&ModelFilter {
            capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
            ..Default::default()
        });
        assert_eq!(reasoning.len(), initial_reasoning_len);
    }

    /// Unregistering a built-in id removes exactly one entry.
    #[test]
    fn unregister() {
        let (mut reg, _tmp) = test_registry();
        let initial_len = reg.list().len();
        let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
        assert!(removed.is_some());
        assert_eq!(reg.list().len(), initial_len - 1);
    }

    /// Pins the curated number of speech-to-text and text-to-speech models.
    #[test]
    fn speech_models_are_curated() {
        let (reg, _tmp) = test_registry();
        let stt = reg.query_by_capability(ModelCapability::SpeechToText);
        let tts = reg.query_by_capability(ModelCapability::TextToSpeech);
        assert_eq!(stt.len(), 2);
        assert_eq!(tts.len(), 4);
    }

    /// The GGUF and MLX Qwen3-8B variants must agree on tool-use capabilities.
    #[test]
    fn qwen_8b_variants_keep_tool_use_consistent() {
        let (reg, _tmp) = test_registry();
        for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(model.has_capability(ModelCapability::ToolUse));
            assert!(model.has_capability(ModelCapability::MultiToolCall));
        }
    }

    /// On Apple-silicon MLX builds, plain names resolve to their specific MLX siblings.
    #[test]
    fn mac_name_resolution_prefers_mlx_siblings() {
        // `reg` is only read inside the Apple-silicon cfg block below, so it is unused
        // on other targets.
        #[allow(unused_variables)]
        let (reg, _tmp) = test_registry();
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            assert_eq!(
                reg.find_by_name("Qwen3-0.6B").unwrap().id,
                "mlx/qwen3-0.6b:6bit"
            );
            assert_eq!(
                reg.find_by_name("Qwen3-1.7B").unwrap().id,
                "mlx/qwen3-1.7b:3bit"
            );
            assert_eq!(
                reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
                "mlx/qwen3-embedding-0.6b:mxfp8"
            );
        }
    }

    /// Every listed hosted multimodal model must be curated as vision-capable.
    #[test]
    fn remote_multimodal_models_are_curated_as_vision_capable() {
        let (reg, _tmp) = test_registry();
        for name in [
            "claude-opus-4-7",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gpt-5.4",
            "gpt-5.4-mini",
            "o3",
            "o4-mini",
            "gpt-4.1-mini",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
        ] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(
                model.has_capability(ModelCapability::Vision),
                "{name} should be curated as vision-capable"
            );
        }
    }

    /// The retired Qwen2.5-VL ids are neither registered nor reachable via the Vision
    /// index, while a Qwen3-VL entry is.
    #[test]
    fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
        let (reg, _tmp) = test_registry();

        let stale_ids = [
            "mlx/qwen2.5-vl-3b:4bit",
            "mlx/qwen2.5-vl-7b:4bit",
            "mlx-vlm/qwen2.5-vl-3b:4bit",
            "mlx-vlm/qwen2.5-vl-7b:4bit",
            "vllm-mlx/qwen2.5-vl-3b:4bit",
        ];
        for id in stale_ids {
            assert!(
                reg.get(id).is_none(),
                "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
            );
        }

        let vision_ids: Vec<&str> = reg
            .query_by_capability(ModelCapability::Vision)
            .into_iter()
            .map(|model| model.id.as_str())
            .collect();
        for stale in stale_ids {
            assert!(
                !vision_ids.contains(&stale),
                "{stale} must not be reachable through the Vision capability index"
            );
        }
        assert!(
            vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
            "Qwen3-VL is the supported local VL family and must route as Vision"
        );
    }

    /// Gemini models must advertise vision plus single and multi tool calling.
    #[test]
    fn gemini_models_are_curated_for_multimodal_tool_use() {
        let (reg, _tmp) = test_registry();
        for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(model.has_capability(ModelCapability::Vision));
            assert!(model.has_capability(ModelCapability::ToolUse));
            assert!(model.has_capability(ModelCapability::MultiToolCall));
        }
    }

    /// Pins the image/video generation model counts and the Yume entry's capability
    /// and tags.
    #[test]
    fn visual_generation_models_are_curated() {
        let (reg, _tmp) = test_registry();
        assert_eq!(
            reg.query_by_capability(ModelCapability::ImageGeneration)
                .len(),
            1
        );
        assert_eq!(
            reg.query_by_capability(ModelCapability::VideoGeneration)
                .len(),
            2
        );
        let yume = reg
            .get("mlx/yume-1.5-5b-720p:q4")
            .expect("Yume MLX should be in the built-in catalog");
        assert!(yume.has_capability(ModelCapability::VideoGeneration));
        assert!(yume.tags.contains(&"text-to-video".to_string()));
        assert!(yume.tags.contains(&"image-to-video".to_string()));
        assert!(yume.tags.contains(&"world-model".to_string()));
    }
}