1use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::schema::*;
17use crate::InferenceError;
18
19#[derive(Debug, Clone, Default)]
21pub struct ModelFilter {
22 pub capabilities: Vec<ModelCapability>,
24 pub max_size_mb: Option<u64>,
26 pub max_latency_ms: Option<u64>,
28 pub max_cost_per_mtok: Option<f64>,
30 pub tags: Vec<String>,
32 pub provider: Option<String>,
34 pub local_only: bool,
36 pub available_only: bool,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ModelUpgrade {
44 pub from_id: String,
45 pub from_name: String,
46 pub to_id: String,
47 pub to_name: String,
48 pub reason: String,
49 pub target_runtime: Option<String>,
50 pub target_runtime_requirement: Option<String>,
51 pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
52 pub target_available: bool,
53 pub target_pullable: bool,
54 pub remove_old_supported: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ModelRuntimeRequirement {
59 pub name: String,
60 pub minimum_version: String,
61}
62
63pub struct UnifiedRegistry {
65 models_dir: PathBuf,
66 models: HashMap<String, ModelSchema>,
68 user_config_path: PathBuf,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct ModelUpgradeRule {
74 from_ids: Vec<String>,
75 to_id: String,
76 reason: String,
77 target_runtime: Option<String>,
78 target_runtime_requirement: Option<String>,
79 #[serde(default)]
80 minimum_runtimes: Vec<ModelRuntimeRequirement>,
81 #[serde(default = "default_remove_old_after_available")]
82 remove_old_after_available: bool,
83}
84
/// Serde default for `ModelUpgradeRule::remove_old_after_available`:
/// rules permit removing the old model unless they explicitly opt out.
fn default_remove_old_after_available() -> bool {
    true
}
88
89fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
90 serde_json::from_str(include_str!("../assets/model-upgrades.json"))
91 .expect("built-in model-upgrades.json should parse")
92}
93
94impl UnifiedRegistry {
95 pub fn new(models_dir: PathBuf) -> Self {
96 let user_config_path = models_dir
97 .parent()
98 .unwrap_or(&models_dir)
99 .join("models.json");
100
101 let mut registry = Self {
102 models_dir,
103 models: HashMap::new(),
104 user_config_path,
105 };
106 registry.load_builtin_catalog();
107 registry.refresh_availability();
108 let _ = registry.load_user_config();
110 registry
111 }
112
113 pub fn register(&mut self, mut schema: ModelSchema) {
115 if schema.is_mlx() {
117 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
128 {
129 schema.available = if schema.tags.contains(&"speech".to_string()) {
130 speech_mlx_available()
131 } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
132 let mlx_dir = self.models_dir.join(&schema.name);
138 mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
139 } else {
140 let mlx_dir = self.models_dir.join(&schema.name);
141 mlx_dir.join("config.json").exists()
142 };
143 }
144 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
145 {
146 schema.available = false;
147 }
148 } else if schema.is_vllm_mlx() {
149 schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
151 } else if schema.is_local() {
152 let local_path = self.models_dir.join(&schema.name).join("model.gguf");
153 schema.available = local_path.exists();
154 } else if schema.is_remote() {
155 if let ModelSource::RemoteApi {
157 ref api_key_env, ..
158 } = schema.source
159 {
160 schema.available = std::env::var(api_key_env).is_ok();
161 }
162 }
163 info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
164 self.models.insert(schema.id.clone(), schema);
165 }
166
167 pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
169 let removed = self.models.remove(id);
170 if let Some(ref m) = removed {
171 info!(id = %m.id, "unregistered model");
172 }
173 removed
174 }
175
176 pub fn list(&self) -> Vec<&ModelSchema> {
178 let mut models: Vec<&ModelSchema> = self.models.values().collect();
179 models.sort_by(|a, b| a.id.cmp(&b.id));
180 models
181 }
182
183 pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
185 self.models
186 .values()
187 .filter(|m| {
188 if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
190 return false;
191 }
192 if let Some(max) = filter.max_size_mb {
194 if m.size_mb() > max && m.is_local() {
195 return false;
196 }
197 }
198 if let Some(max) = filter.max_latency_ms {
200 if let Some(p50) = m.performance.latency_p50_ms {
201 if p50 > max {
202 return false;
203 }
204 }
205 }
206 if let Some(max) = filter.max_cost_per_mtok {
208 if let Some(cost) = m.cost.output_per_mtok {
209 if cost > max {
210 return false;
211 }
212 }
213 }
214 if !filter.tags.iter().all(|t| m.tags.contains(t)) {
216 return false;
217 }
218 if let Some(ref p) = filter.provider {
220 if &m.provider != p {
221 return false;
222 }
223 }
224 if filter.local_only && !m.is_local() {
226 return false;
227 }
228 if filter.available_only && !m.available {
230 return false;
231 }
232 true
233 })
234 .collect()
235 }
236
237 pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
239 self.query(&ModelFilter {
240 capabilities: vec![cap],
241 ..Default::default()
242 })
243 }
244
245 pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
247 let mut upgrades = Vec::new();
248 for rule in model_upgrade_rules() {
249 let Some(from) = rule
250 .from_ids
251 .iter()
252 .find_map(|id| self.models.get(id.as_str()))
253 .filter(|schema| schema.available)
254 else {
255 continue;
256 };
257 let Some(to) = self.models.get(rule.to_id.as_str()) else {
258 continue;
259 };
260 upgrades.push(ModelUpgrade {
261 from_id: from.id.clone(),
262 from_name: from.name.clone(),
263 to_id: to.id.clone(),
264 to_name: to.name.clone(),
265 reason: rule.reason.clone(),
266 target_runtime: rule.target_runtime.clone(),
267 target_runtime_requirement: rule.target_runtime_requirement.clone(),
268 minimum_runtimes: rule.minimum_runtimes.clone(),
269 target_available: to.available,
270 target_pullable: matches!(
271 to.source,
272 ModelSource::Local { .. } | ModelSource::Mlx { .. }
273 ),
274 remove_old_supported: matches!(
275 from.source,
276 ModelSource::Local { .. } | ModelSource::Mlx { .. }
277 ) && rule.remove_old_after_available,
278 });
279 }
280 upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
281 upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
282 upgrades
283 }
284
285 pub fn get(&self, id: &str) -> Option<&ModelSchema> {
287 self.models.get(id)
288 }
289
290 pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
293 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
294 if !name.to_ascii_lowercase().ends_with("-mlx") {
295 if let Some(mlx_variant) = self
296 .models
297 .values()
298 .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
299 {
300 return Some(mlx_variant);
301 }
302 }
303
304 self.models
305 .values()
306 .find(|m| m.name.eq_ignore_ascii_case(name))
307 }
308
309 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
313 pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
314 if schema.is_mlx() || schema.is_vllm_mlx() {
316 return None;
317 }
318 if !matches!(schema.source, ModelSource::Local { .. }) {
320 return None;
321 }
322 let primary_cap = schema.capabilities.first()?;
324 self.models.values().find(|m| {
325 m.is_mlx() && m.family == schema.family && m.capabilities.contains(primary_cap)
326 })
327 }
328
329 pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
331 let schema = self
332 .get(id)
333 .or_else(|| self.find_by_name(id))
334 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
335
336 match &schema.source {
337 ModelSource::Local {
338 hf_repo,
339 hf_filename,
340 tokenizer_repo,
341 } => {
342 let model_dir = self.models_dir.join(&schema.name);
343 let model_path = model_dir.join("model.gguf");
344 let tokenizer_path = model_dir.join("tokenizer.json");
345
346 if model_path.exists() && tokenizer_path.exists() {
347 return Ok(model_dir);
348 }
349
350 std::fs::create_dir_all(&model_dir)?;
351
352 if !model_path.exists() {
353 info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
354 download_file(hf_repo, hf_filename, &model_path).await?;
355 }
356 if !tokenizer_path.exists() {
357 info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
358 download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
359 }
360
361 Ok(model_dir)
362 }
363 ModelSource::Mlx {
364 hf_repo,
365 hf_weight_file,
366 } => {
367 let model_dir = self.models_dir.join(&schema.name);
368 let config_path = model_dir.join("config.json");
369
370 if config_path.exists() {
371 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
372 info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
373 return Ok(model_dir);
374 }
375
376 if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
377 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
378 info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
379 return Ok(snapshot_dir);
380 }
381
382 if requires_full_mlx_snapshot(&schema) {
383 info!(
384 model = %schema.name,
385 repo = %hf_repo,
386 "downloading full MLX snapshot"
387 );
388 let (snapshot_dir, _files_downloaded) =
389 download_hf_repo_snapshot(hf_repo).await?;
390 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
391 return Ok(snapshot_dir);
392 }
393
394 std::fs::create_dir_all(&model_dir)?;
395
396 info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
397
398 download_file(hf_repo, "config.json", &config_path).await?;
400 let tok_path = model_dir.join("tokenizer.json");
401 if !tok_path.exists() {
402 download_file(hf_repo, "tokenizer.json", &tok_path).await?;
403 }
404 let tok_config_path = model_dir.join("tokenizer_config.json");
405 if !tok_config_path.exists() {
406 let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
407 }
408
409 if let Some(ref wf) = hf_weight_file {
411 let wf_path = model_dir.join(wf);
412 if !wf_path.exists() {
413 download_file(hf_repo, wf, &wf_path).await?;
414 }
415 } else {
416 let single = model_dir.join("model.safetensors");
418 if !single.exists() {
419 match download_file(hf_repo, "model.safetensors", &single).await {
420 Ok(()) => {}
421 Err(_) => {
422 let index_path = model_dir.join("model.safetensors.index.json");
424 download_file(hf_repo, "model.safetensors.index.json", &index_path)
425 .await?;
426
427 let index_json: serde_json::Value =
428 serde_json::from_str(&std::fs::read_to_string(&index_path)?)
429 .map_err(|e| {
430 InferenceError::InferenceFailed(format!(
431 "parse index: {e}"
432 ))
433 })?;
434
435 if let Some(weight_map) =
436 index_json.get("weight_map").and_then(|m| m.as_object())
437 {
438 let mut files: std::collections::HashSet<String> =
439 std::collections::HashSet::new();
440 for filename in weight_map.values() {
441 if let Some(f) = filename.as_str() {
442 files.insert(f.to_string());
443 }
444 }
445 for file in &files {
446 let dest = model_dir.join(file);
447 if !dest.exists() {
448 info!(file = %file, "downloading weight shard");
449 download_file(hf_repo, file, &dest).await?;
450 }
451 }
452 }
453 }
454 }
455 }
456 }
457
458 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
459 Ok(model_dir)
460 }
461 _ => Err(InferenceError::InferenceFailed(format!(
462 "model {} is not local",
463 id
464 ))),
465 }
466 }
467
468 pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
470 let schema = self
471 .get(id)
472 .or_else(|| self.find_by_name(id))
473 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
474
475 let model_dir = self.models_dir.join(&schema.name);
476 if model_dir.exists() {
477 std::fs::remove_dir_all(&model_dir)?;
478 info!(model = %schema.name, "removed model");
479 }
480
481 match &schema.source {
482 ModelSource::Mlx { hf_repo, .. } => {
483 let repo_dir = huggingface_repo_dir(hf_repo);
484 if repo_dir.exists() {
485 std::fs::remove_dir_all(&repo_dir)?;
486 info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
487 }
488 }
489 ModelSource::Local {
490 hf_repo,
491 tokenizer_repo,
492 ..
493 } => {
494 for repo in [hf_repo, tokenizer_repo] {
495 let repo_dir = huggingface_repo_dir(repo);
496 if repo_dir.exists() {
497 std::fs::remove_dir_all(&repo_dir)?;
498 info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
499 }
500 }
501 }
502 _ => {}
503 }
504
505 let id = schema.id.clone();
507 if let Some(m) = self.models.get_mut(&id) {
508 m.available = false;
509 }
510 Ok(())
511 }
512
513 pub fn refresh_availability(&mut self) {
521 let models_dir = self.models_dir.clone();
525 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
530 let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
531 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
532 #[allow(unused_variables)]
533 let mlx_vlm_cli_present = false;
534
535 for m in self.models.values_mut() {
536 match &m.source {
537 ModelSource::Mlx { hf_repo, .. } => {
538 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
547 {
548 let needs_mlx_vlm = m.tags.iter().any(|t| t == "requires-mlx-vlm");
555
556 m.available = if needs_mlx_vlm {
557 mlx_vlm_cli_present
558 } else if m.tags.contains(&"speech".to_string()) {
559 speech_mlx_available()
560 } else {
561 let mlx_dir = models_dir.join(&m.name);
571 mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
572 };
573 }
574 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
575 {
576 let _ = hf_repo; m.available = false;
578 }
579 }
580 ModelSource::Local { .. } => {
581 let local_path = models_dir.join(&m.name).join("model.gguf");
582 m.available = local_path.exists();
583 }
584 ModelSource::RemoteApi { api_key_env, .. } => {
585 m.available = std::env::var(api_key_env).is_ok();
586 }
587 ModelSource::Ollama { .. } => {
588 m.available = true;
590 }
591 ModelSource::VllmMlx { .. } => {
592 m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
595 }
597 ModelSource::Proprietary { auth, .. } => {
598 m.available = match auth {
600 crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
601 std::env::var(env_var).is_ok()
602 }
603 crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
604 std::env::var(env_var).is_ok()
605 }
606 crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
607 true
609 }
610 };
611 }
612 ModelSource::AppleFoundationModels { .. } => {
613 #[cfg(any(
620 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
621 all(target_os = "ios", target_arch = "aarch64")
622 ))]
623 {
624 m.available = crate::backend::foundation_models::is_available();
625 }
626 #[cfg(not(any(
627 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
628 all(target_os = "ios", target_arch = "aarch64")
629 )))]
630 {
631 m.available = false;
632 }
633 }
634 ModelSource::Delegated { .. } => {
635 m.available = crate::runner::current_inference_runner().is_some();
640 }
641 }
642 }
643 }
644
645 pub fn save_user_config(&self) -> Result<(), InferenceError> {
647 let user_models: Vec<&ModelSchema> = self
648 .models
649 .values()
650 .filter(|m| !m.tags.contains(&"builtin".to_string()))
651 .collect();
652
653 if user_models.is_empty() {
654 return Ok(());
655 }
656
657 let json = serde_json::to_string_pretty(&user_models)
658 .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
659 std::fs::write(&self.user_config_path, json)?;
660 Ok(())
661 }
662
663 pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
665 if !self.user_config_path.exists() {
666 return Ok(());
667 }
668
669 let json = std::fs::read_to_string(&self.user_config_path)?;
670 let models: Vec<ModelSchema> = serde_json::from_str(&json)
671 .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
672
673 for m in models {
674 self.register(m);
675 }
676 Ok(())
677 }
678
679 pub fn models_dir(&self) -> &Path {
681 &self.models_dir
682 }
683
684 fn load_builtin_catalog(&mut self) {
686 for schema in builtin_catalog() {
687 self.models.insert(schema.id.clone(), schema);
688 }
689 }
690}
691
/// Reports whether MLX speech (STT/TTS) models can run on this host.
///
/// This function only exists on macOS/aarch64 with MLX enabled. On that
/// target, speech support is always reported as present.
///
/// NOTE(review): the probe for `mlx_audio.{stt,tts}.generate` under
/// `speech_runtime_root()/bin` sat behind the opposite `cfg`, so it could never
/// compile and has been dropped. Restore it unconditionally if that runtime
/// check is actually wanted.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn speech_mlx_available() -> bool {
    true
}
715
/// Directory of the external speech runtime.
///
/// Uses `CAR_SPEECH_RUNTIME_DIR` when it is set to a non-blank value,
/// otherwise `$HOME/.car/speech-runtime` (with `.` standing in for an unset
/// or non-UTF-8 `HOME`).
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn speech_runtime_root() -> PathBuf {
    let override_dir = std::env::var("CAR_SPEECH_RUNTIME_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty());
    match override_dir {
        Some(dir) => PathBuf::from(dir),
        None => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
729
730#[derive(Debug, Clone, Serialize, Deserialize)]
732pub struct ModelInfo {
733 pub id: String,
734 pub name: String,
735 pub provider: String,
736 pub capabilities: Vec<ModelCapability>,
737 pub param_count: String,
738 pub size_mb: u64,
739 pub context_length: usize,
740 pub available: bool,
741 pub is_local: bool,
742 #[serde(default)]
746 pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
747}
748
749impl From<&ModelSchema> for ModelInfo {
750 fn from(s: &ModelSchema) -> Self {
751 ModelInfo {
752 id: s.id.clone(),
753 name: s.name.clone(),
754 provider: s.provider.clone(),
755 capabilities: s.capabilities.clone(),
756 param_count: s.param_count.clone(),
757 size_mb: s.size_mb(),
758 context_length: s.context_length,
759 available: s.available,
760 is_local: s.is_local(),
761 public_benchmarks: s.public_benchmarks.clone(),
762 }
763 }
764}
765
766async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
768 let api = hf_hub::api::tokio::Api::new()
769 .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
770
771 let repo = api.model(repo.to_string());
772 let path = repo
773 .get(filename)
774 .await
775 .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
776
777 if dest.exists() {
778 return Ok(());
779 }
780
781 #[cfg(unix)]
783 {
784 if std::os::unix::fs::symlink(&path, dest).is_ok() {
785 return Ok(());
786 }
787 }
788
789 std::fs::copy(&path, dest)
790 .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
791 Ok(())
792}
793
794async fn ensure_auxiliary_mlx_files(
795 model_name: &str,
796 hf_repo: &str,
797 model_dir: &Path,
798) -> Result<(), InferenceError> {
799 if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
800 let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
801 if !t5_tokenizer_path.exists() {
802 std::fs::create_dir_all(t5_tokenizer_path.parent().ok_or_else(|| {
803 InferenceError::InferenceFailed("invalid tokenizer path".into())
804 })?)?;
805 info!(
806 path = %t5_tokenizer_path.display(),
807 "downloading missing Flux tokenizer_2/tokenizer.json from base model"
808 );
809 download_file(
810 "Freepik/flux.1-lite-8B",
811 "tokenizer_2/tokenizer.json",
812 &t5_tokenizer_path,
813 )
814 .await?;
815 }
816 }
817 Ok(())
818}
819
820fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
821 match &schema.source {
822 ModelSource::Mlx { hf_repo, .. } => {
823 hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
824 || schema.family.starts_with("yume")
825 || schema.tags.iter().any(|tag| {
826 matches!(
827 tag.as_str(),
828 "wan2.2" | "ti2v" | "world-model" | "image-to-video"
829 )
830 })
831 }
832 _ => false,
833 }
834}
835
836fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
837 latest_huggingface_repo_snapshot(repo_id).is_some()
838}
839
/// Hugging Face hub cache directory.
///
/// Resolves to `$HF_HOME/hub` when `HF_HOME` is set, else
/// `$HOME/.cache/huggingface/hub` (with `.` standing in for an unset `HOME`).
fn huggingface_cache_root() -> PathBuf {
    let base = match std::env::var("HF_HOME") {
        Ok(hf_home) => PathBuf::from(hf_home),
        Err(_) => {
            let home = std::env::var("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| PathBuf::from("."));
            home.join(".cache").join("huggingface")
        }
    };
    base.join("hub")
}
852
853fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
854 huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
855}
856
857fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
858 let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
859 .ok()?
860 .trim()
861 .to_string();
862 if sha.is_empty() {
863 return None;
864 }
865
866 let snapshot = repo_dir.join("snapshots").join(sha);
867 if snapshot_looks_ready(&snapshot) {
868 Some(snapshot)
869 } else {
870 None
871 }
872}
873
874fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
875 let repo_dir = huggingface_repo_dir(repo_id);
876 if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
877 return Some(snapshot);
878 }
879
880 let snapshots = repo_dir.join("snapshots");
881 let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
882 .ok()?
883 .filter_map(Result::ok)
884 .map(|e| e.path())
885 .filter(|p| p.is_dir() && snapshot_looks_ready(p))
886 .map(|path| {
887 let modified = path
888 .metadata()
889 .and_then(|metadata| metadata.modified())
890 .unwrap_or(SystemTime::UNIX_EPOCH);
891 (modified, path)
892 })
893 .collect();
894 candidates.sort();
895 candidates.pop().map(|(_, path)| path)
896}
897
898fn snapshot_looks_ready(path: &Path) -> bool {
899 if path.join("config.json").exists() || path.join("model_index.json").exists() {
900 return true;
901 }
902 snapshot_contains_ext(path, "safetensors")
903}
904
/// Searches `root` recursively for any file whose extension equals `ext`,
/// ignoring ASCII case.
///
/// Directories that cannot be read are skipped, and an unreadable `root`
/// yields `false`. Directory symlinks are followed.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let mut pending = vec![root.to_path_buf()];
    while let Some(dir) = pending.pop() {
        let Ok(entries) = std::fs::read_dir(&dir) else {
            continue;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                pending.push(path);
            } else if path
                .extension()
                .and_then(|value| value.to_str())
                .is_some_and(|value| value.eq_ignore_ascii_case(ext))
            {
                return true;
            }
        }
    }
    false
}
921
922async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
923 let api = hf_hub::api::tokio::ApiBuilder::from_env()
924 .with_progress(false)
925 .build()
926 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
927 let repo = api.model(repo_id.to_string());
928 let info = repo
929 .info()
930 .await
931 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
932
933 let snapshot_path = std::env::var("HF_HOME")
934 .map(PathBuf::from)
935 .unwrap_or_else(|_| {
936 std::env::var("HOME")
937 .map(PathBuf::from)
938 .unwrap_or_else(|_| PathBuf::from("."))
939 .join(".cache")
940 .join("huggingface")
941 })
942 .join("hub")
943 .join(format!("models--{}", repo_id.replace('/', "--")))
944 .join("snapshots")
945 .join(&info.sha);
946 let mut downloaded = 0usize;
947 for sibling in &info.siblings {
948 let local_path = snapshot_path.join(&sibling.rfilename);
949 if local_path.exists() {
950 downloaded += 1;
951 continue;
952 }
953 repo.download(&sibling.rfilename).await.map_err(|e| {
954 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
955 })?;
956 downloaded += 1;
957 }
958
959 Ok((snapshot_path, downloaded))
960}
961
962const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");
971
972static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> = std::sync::LazyLock::new(|| {
973 serde_json::from_str(BUILTIN_CATALOG_JSON)
974 .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
975});
976
977fn builtin_catalog() -> Vec<ModelSchema> {
978 BUILTIN_CATALOG.clone()
979}
980
981#[cfg(test)]
982mod tests {
983 use super::*;
984 use tempfile::TempDir;
985
986 fn test_registry() -> (UnifiedRegistry, TempDir) {
987 let tmp = TempDir::new().unwrap();
988 let reg = UnifiedRegistry::new(tmp.path().join("models"));
989 (reg, tmp)
990 }
991
992 #[test]
993 fn builtin_catalog_loads() {
994 let (reg, _tmp) = test_registry();
995 let all = reg.list();
996 assert_eq!(all.len(), builtin_catalog().len());
997 }
998
999 #[test]
1012 fn mlx_vlm_models_reflect_runtime_availability() {
1013 let (reg, _tmp) = test_registry();
1014 let mlx_vlm_models: Vec<&ModelSchema> = reg
1015 .list()
1016 .into_iter()
1017 .filter(|m| m.tags.iter().any(|t| t == "requires-mlx-vlm"))
1018 .collect();
1019 assert!(
1020 !mlx_vlm_models.is_empty(),
1021 "catalog should contain at least one model tagged \
1022 `requires-mlx-vlm` — otherwise this regression has \
1023 nothing to guard"
1024 );
1025
1026 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1027 let expected = crate::backend::mlx_vlm_cli::is_available();
1028 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1029 let expected = false;
1030
1031 for m in mlx_vlm_models {
1032 assert_eq!(
1033 m.available, expected,
1034 "model {} `available` field should reflect \
1035 mlx_vlm CLI presence (expected {expected}, got {})",
1036 m.id, m.available
1037 );
1038 }
1039 }
1040
1041 #[test]
1053 fn mlx_models_unavailable_on_non_mlx_targets() {
1054 let (reg, _tmp) = test_registry();
1055 let mlx_models: Vec<&ModelSchema> = reg
1056 .list()
1057 .into_iter()
1058 .filter(|m| {
1059 m.is_mlx()
1060 && !m.tags.iter().any(|t| t == "requires-mlx-vlm")
1065 && !m.tags.contains(&"speech".to_string())
1066 })
1067 .collect();
1068 assert!(
1069 !mlx_models.is_empty(),
1070 "catalog should contain at least one plain MLX model — \
1071 otherwise this F1 regression guard has nothing to guard"
1072 );
1073
1074 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1075 {
1076 let any_available = mlx_models.iter().any(|m| m.available);
1080 assert!(
1081 any_available,
1082 "on macOS arm64 with MLX enabled, at least one plain MLX \
1083 model with hf_repo should be available — none were"
1084 );
1085 }
1086 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1087 {
1088 for m in &mlx_models {
1091 assert!(
1092 !m.available,
1093 "MLX model {} is marked available on a non-MLX target — \
1094 the adaptive router will add it to fallback chains \
1095 and dispatch will fail (Parslee-ai/car#231 §7.1)",
1096 m.id
1097 );
1098 }
1099 }
1100 }
1101
1102 #[test]
1105 fn builtin_catalog_json_parses() {
1106 let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
1107 .expect("builtin_catalog.json must be valid ModelSchema array");
1108 assert!(
1109 !catalog.is_empty(),
1110 "embedded catalog has no entries — that's almost certainly wrong"
1111 );
1112
1113 let mut seen = std::collections::HashSet::new();
1114 for entry in &catalog {
1115 assert!(
1116 seen.insert(entry.id.clone()),
1117 "duplicate id in builtin_catalog.json: {}",
1118 entry.id
1119 );
1120 }
1121 }
1122
1123 #[test]
1124 fn public_benchmarks_round_trip_through_model_info() {
1125 use crate::schema::BenchmarkScore;
1126 let (mut reg, _tmp) = test_registry();
1127 let mut schema = reg
1128 .find_by_name("Qwen3-4B")
1129 .expect("catalog has Qwen3-4B")
1130 .clone();
1131 schema.id = "test/qwen3-4b-with-bench".into();
1132 schema.public_benchmarks = vec![
1133 BenchmarkScore {
1134 name: "MMLU-Pro".into(),
1135 score: 0.482,
1136 harness: Some("5-shot CoT".into()),
1137 source_url: Some("https://example.invalid/qwen3-4b-card".into()),
1138 measured_at: Some("2025-08-12".into()),
1139 },
1140 BenchmarkScore {
1141 name: "HumanEval".into(),
1142 score: 0.713,
1143 harness: Some("pass@1".into()),
1144 source_url: None,
1145 measured_at: None,
1146 },
1147 ];
1148 reg.register(schema);
1149
1150 let stored = reg
1151 .get("test/qwen3-4b-with-bench")
1152 .expect("registered model is retrievable");
1153 let info = ModelInfo::from(stored);
1154 assert_eq!(info.public_benchmarks.len(), 2);
1155
1156 let json = serde_json::to_string(&info).unwrap();
1158 assert!(json.contains("\"public_benchmarks\""));
1159 assert!(json.contains("\"MMLU-Pro\""));
1160 assert!(json.contains("\"5-shot CoT\""));
1161
1162 let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
1164 assert_eq!(decoded.public_benchmarks.len(), 2);
1165 assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
1166 assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
1167 }
1168
1169 #[test]
1170 fn public_benchmarks_default_to_empty_when_absent_in_json() {
1171 let legacy_json = r#"{
1174 "id": "legacy/test:1",
1175 "name": "Legacy Test",
1176 "provider": "test",
1177 "family": "test",
1178 "version": "",
1179 "capabilities": ["generate"],
1180 "context_length": 4096,
1181 "param_count": "1B",
1182 "quantization": null,
1183 "performance": {},
1184 "cost": {},
1185 "source": { "type": "ollama", "model_tag": "legacy:1" },
1186 "tags": [],
1187 "supported_params": []
1188 }"#;
1189 let schema: ModelSchema = serde_json::from_str(legacy_json).unwrap();
1190 assert!(schema.public_benchmarks.is_empty());
1191 }
1192
1193 #[test]
1194 fn find_by_name() {
1195 let (reg, _tmp) = test_registry();
1196 let m = reg.find_by_name("Qwen3-4B").unwrap();
1197 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1198 assert_eq!(m.id, "mlx/qwen3-4b:4bit");
1199 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1200 assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
1201 assert!(m.has_capability(ModelCapability::Code));
1202 }
1203
1204 #[test]
1205 fn query_by_capability() {
1206 let (reg, _tmp) = test_registry();
1207 let embed_models = reg.query_by_capability(ModelCapability::Embed);
1208 assert_eq!(embed_models.len(), 2);
1209 assert!(embed_models
1210 .iter()
1211 .any(|model| model.name == "Qwen3-Embedding-0.6B"));
1212 assert!(embed_models
1213 .iter()
1214 .any(|model| model.name == "Qwen3-Embedding-0.6B-MLX"));
1215 }
1216
1217 #[test]
1218 fn query_with_filter() {
1219 let (reg, _tmp) = test_registry();
1220 let code_small = reg.query(&ModelFilter {
1221 capabilities: vec![ModelCapability::Code],
1222 max_size_mb: Some(3000),
1223 local_only: true,
1224 ..Default::default()
1225 });
1226 assert_eq!(code_small.len(), 4);
1228 }
1229
1230 #[test]
1231 fn register_remote() {
1232 let (mut reg, _tmp) = test_registry();
1233 let initial_len = reg.list().len();
1234 let initial_reasoning_len = reg
1235 .query(&ModelFilter {
1236 capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
1237 ..Default::default()
1238 })
1239 .len();
1240 let remote = ModelSchema {
1241 id: "anthropic/claude-sonnet-4-6:latest".into(),
1242 name: "Claude Sonnet 4.6".into(),
1243 provider: "anthropic".into(),
1244 family: "claude-4".into(),
1245 version: "latest".into(),
1246 capabilities: vec![
1247 ModelCapability::Generate,
1248 ModelCapability::Code,
1249 ModelCapability::Reasoning,
1250 ModelCapability::ToolUse,
1251 ],
1252 context_length: 200000,
1253 param_count: String::new(),
1254 quantization: None,
1255 performance: PerformanceEnvelope {
1256 latency_p50_ms: Some(2000),
1257 ..Default::default()
1258 },
1259 cost: CostModel {
1260 input_per_mtok: Some(3.0),
1261 output_per_mtok: Some(15.0),
1262 ..Default::default()
1263 },
1264 source: ModelSource::RemoteApi {
1265 endpoint: "https://api.anthropic.com/v1/messages".into(),
1266 api_key_env: "ANTHROPIC_API_KEY".into(),
1267 api_key_envs: vec![],
1268 api_version: Some("2023-06-01".into()),
1269 protocol: ApiProtocol::Anthropic,
1270 },
1271 tags: vec![],
1272 supported_params: vec![],
1273 public_benchmarks: vec![],
1274 available: false,
1275 };
1276
1277 reg.register(remote);
1278 assert_eq!(reg.list().len(), initial_len);
1280
1281 let reasoning = reg.query(&ModelFilter {
1282 capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
1283 ..Default::default()
1284 });
1285 assert_eq!(reasoning.len(), initial_reasoning_len);
1287 }
1288
1289 #[test]
1290 fn unregister() {
1291 let (mut reg, _tmp) = test_registry();
1292 let initial_len = reg.list().len();
1293 let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
1294 assert!(removed.is_some());
1295 assert_eq!(reg.list().len(), initial_len - 1);
1296 }
1297
1298 #[test]
1299 fn speech_models_are_curated() {
1300 let (reg, _tmp) = test_registry();
1301 let stt = reg.query_by_capability(ModelCapability::SpeechToText);
1302 let tts = reg.query_by_capability(ModelCapability::TextToSpeech);
1303 assert_eq!(stt.len(), 2);
1304 assert_eq!(tts.len(), 4);
1305 }
1306
1307 #[test]
1308 fn qwen_8b_variants_keep_tool_use_consistent() {
1309 let (reg, _tmp) = test_registry();
1310 for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
1311 let model = reg.find_by_name(name).expect("model should exist");
1312 assert!(model.has_capability(ModelCapability::ToolUse));
1313 assert!(model.has_capability(ModelCapability::MultiToolCall));
1314 }
1315 }
1316
1317 #[test]
1318 fn mac_name_resolution_prefers_mlx_siblings() {
1319 #[allow(unused_variables)]
1322 let (reg, _tmp) = test_registry();
1323 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1324 {
1325 assert_eq!(
1326 reg.find_by_name("Qwen3-0.6B").unwrap().id,
1327 "mlx/qwen3-0.6b:6bit"
1328 );
1329 assert_eq!(
1330 reg.find_by_name("Qwen3-1.7B").unwrap().id,
1331 "mlx/qwen3-1.7b:3bit"
1332 );
1333 assert_eq!(
1334 reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
1335 "mlx/qwen3-embedding-0.6b:mxfp8"
1336 );
1337 }
1338 }
1339
1340 #[test]
1341 fn remote_multimodal_models_are_curated_as_vision_capable() {
1342 let (reg, _tmp) = test_registry();
1343 for name in [
1344 "claude-opus-4-7",
1345 "claude-opus-4-6",
1346 "claude-sonnet-4-6",
1347 "claude-haiku-4-5",
1348 "gpt-5.4",
1349 "gpt-5.4-mini",
1350 "o3",
1351 "o4-mini",
1352 "gpt-4.1-mini",
1353 "gemini-2.5-pro",
1354 "gemini-2.5-flash",
1355 ] {
1356 let model = reg.find_by_name(name).expect("model should exist");
1357 assert!(
1358 model.has_capability(ModelCapability::Vision),
1359 "{name} should be curated as vision-capable"
1360 );
1361 }
1362 }
1363
1364 #[test]
1365 fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
1366 let (reg, _tmp) = test_registry();
1367
1368 let stale_ids = [
1369 "mlx/qwen2.5-vl-3b:4bit",
1371 "mlx/qwen2.5-vl-7b:4bit",
1372 "mlx-vlm/qwen2.5-vl-3b:4bit",
1375 "mlx-vlm/qwen2.5-vl-7b:4bit",
1376 "vllm-mlx/qwen2.5-vl-3b:4bit",
1378 ];
1379 for id in stale_ids {
1380 assert!(
1381 reg.get(id).is_none(),
1382 "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
1383 );
1384 }
1385
1386 let vision_ids: Vec<&str> = reg
1387 .query_by_capability(ModelCapability::Vision)
1388 .into_iter()
1389 .map(|model| model.id.as_str())
1390 .collect();
1391 for stale in stale_ids {
1392 assert!(
1393 !vision_ids.contains(&stale),
1394 "{stale} must not be reachable through the Vision capability index"
1395 );
1396 }
1397 assert!(
1398 vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
1399 "Qwen3-VL is the supported local VL family and must route as Vision"
1400 );
1401 }
1402
1403 #[test]
1404 fn gemini_models_are_curated_for_multimodal_tool_use() {
1405 let (reg, _tmp) = test_registry();
1406 for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
1407 let model = reg.find_by_name(name).expect("model should exist");
1408 assert!(model.has_capability(ModelCapability::Vision));
1409 assert!(model.has_capability(ModelCapability::ToolUse));
1410 assert!(model.has_capability(ModelCapability::MultiToolCall));
1411 }
1412 }
1413
1414 #[test]
1415 fn visual_generation_models_are_curated() {
1416 let (reg, _tmp) = test_registry();
1417 assert_eq!(
1418 reg.query_by_capability(ModelCapability::ImageGeneration)
1419 .len(),
1420 1
1421 );
1422 assert_eq!(
1423 reg.query_by_capability(ModelCapability::VideoGeneration)
1424 .len(),
1425 2
1426 );
1427 let yume = reg
1428 .get("mlx/yume-1.5-5b-720p:q4")
1429 .expect("Yume MLX should be in the built-in catalog");
1430 assert!(yume.has_capability(ModelCapability::VideoGeneration));
1431 assert!(yume.tags.contains(&"text-to-video".to_string()));
1432 assert!(yume.tags.contains(&"image-to-video".to_string()));
1433 assert!(yume.tags.contains(&"world-model".to_string()));
1434 }
1435}