1use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::download::{DownloadEvent, ProgressSink};
17use crate::schema::*;
18use crate::InferenceError;
19
20#[derive(Debug, Clone, Default)]
22pub struct ModelFilter {
23 pub capabilities: Vec<ModelCapability>,
25 pub max_size_mb: Option<u64>,
27 pub max_latency_ms: Option<u64>,
29 pub max_cost_per_mtok: Option<f64>,
31 pub tags: Vec<String>,
33 pub provider: Option<String>,
35 pub local_only: bool,
37 pub available_only: bool,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelUpgrade {
45 pub from_id: String,
46 pub from_name: String,
47 pub to_id: String,
48 pub to_name: String,
49 pub reason: String,
50 pub target_runtime: Option<String>,
51 pub target_runtime_requirement: Option<String>,
52 pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
53 pub target_available: bool,
54 pub target_pullable: bool,
55 pub remove_old_supported: bool,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ModelRuntimeRequirement {
60 pub name: String,
61 pub minimum_version: String,
62}
63
64pub struct UnifiedRegistry {
66 models_dir: PathBuf,
67 models: HashMap<String, ModelSchema>,
69 user_config_path: PathBuf,
71}
72
73#[derive(Debug, Clone, Deserialize)]
74struct ModelUpgradeRule {
75 from_ids: Vec<String>,
76 to_id: String,
77 reason: String,
78 target_runtime: Option<String>,
79 target_runtime_requirement: Option<String>,
80 #[serde(default)]
81 minimum_runtimes: Vec<ModelRuntimeRequirement>,
82 #[serde(default = "default_remove_old_after_available")]
83 remove_old_after_available: bool,
84}
85
/// Serde default for `ModelUpgradeRule::remove_old_after_available`.
///
/// Rules that do not say otherwise allow the old model to be removed.
const fn default_remove_old_after_available() -> bool {
    true
}
89
90fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
91 serde_json::from_str(include_str!("../assets/model-upgrades.json"))
92 .expect("built-in model-upgrades.json should parse")
93}
94
95impl UnifiedRegistry {
96 pub fn new(models_dir: PathBuf) -> Self {
97 let user_config_path = models_dir
98 .parent()
99 .unwrap_or(&models_dir)
100 .join("models.json");
101
102 let mut registry = Self {
103 models_dir,
104 models: HashMap::new(),
105 user_config_path,
106 };
107 registry.load_builtin_catalog();
108 registry.refresh_availability();
109 let _ = registry.load_user_config();
111 registry
112 }
113
114 pub fn register(&mut self, mut schema: ModelSchema) {
116 if schema.is_mlx() {
118 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
129 {
130 schema.available = if schema.tags.contains(&"speech".to_string()) {
131 speech_mlx_available()
132 } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
133 let mlx_dir = self.models_dir.join(&schema.name);
139 mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
140 } else {
141 let mlx_dir = self.models_dir.join(&schema.name);
142 mlx_dir.join("config.json").exists()
143 };
144 }
145 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
146 {
147 schema.available = false;
148 }
149 } else if schema.is_vllm_mlx() {
150 schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
152 } else if schema.is_local() {
153 let local_path = self.models_dir.join(&schema.name).join("model.gguf");
154 schema.available = local_path.exists();
155 } else if schema.is_remote() {
156 if let ModelSource::RemoteApi {
158 ref api_key_env, ..
159 } = schema.source
160 {
161 schema.available = std::env::var(api_key_env).is_ok();
162 }
163 }
164 info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
165 self.models.insert(schema.id.clone(), schema);
166 }
167
168 pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
170 let removed = self.models.remove(id);
171 if let Some(ref m) = removed {
172 info!(id = %m.id, "unregistered model");
173 }
174 removed
175 }
176
177 pub fn list(&self) -> Vec<&ModelSchema> {
179 let mut models: Vec<&ModelSchema> = self.models.values().collect();
180 models.sort_by(|a, b| a.id.cmp(&b.id));
181 models
182 }
183
184 pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
186 self.models
187 .values()
188 .filter(|m| {
189 if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
191 return false;
192 }
193 if let Some(max) = filter.max_size_mb {
195 if m.size_mb() > max && m.is_local() {
196 return false;
197 }
198 }
199 if let Some(max) = filter.max_latency_ms {
201 if let Some(p50) = m.performance.latency_p50_ms {
202 if p50 > max {
203 return false;
204 }
205 }
206 }
207 if let Some(max) = filter.max_cost_per_mtok {
209 if let Some(cost) = m.cost.output_per_mtok {
210 if cost > max {
211 return false;
212 }
213 }
214 }
215 if !filter.tags.iter().all(|t| m.tags.contains(t)) {
217 return false;
218 }
219 if let Some(ref p) = filter.provider {
221 if &m.provider != p {
222 return false;
223 }
224 }
225 if filter.local_only && !m.is_local() {
227 return false;
228 }
229 if filter.available_only && !m.available {
231 return false;
232 }
233 true
234 })
235 .collect()
236 }
237
238 pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
240 self.query(&ModelFilter {
241 capabilities: vec![cap],
242 ..Default::default()
243 })
244 }
245
246 pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
248 let mut upgrades = Vec::new();
249 for rule in model_upgrade_rules() {
250 let Some(from) = rule
251 .from_ids
252 .iter()
253 .find_map(|id| self.models.get(id.as_str()))
254 .filter(|schema| schema.available)
255 else {
256 continue;
257 };
258 let Some(to) = self.models.get(rule.to_id.as_str()) else {
259 continue;
260 };
261 upgrades.push(ModelUpgrade {
262 from_id: from.id.clone(),
263 from_name: from.name.clone(),
264 to_id: to.id.clone(),
265 to_name: to.name.clone(),
266 reason: rule.reason.clone(),
267 target_runtime: rule.target_runtime.clone(),
268 target_runtime_requirement: rule.target_runtime_requirement.clone(),
269 minimum_runtimes: rule.minimum_runtimes.clone(),
270 target_available: to.available,
271 target_pullable: matches!(
272 to.source,
273 ModelSource::Local { .. } | ModelSource::Mlx { .. }
274 ),
275 remove_old_supported: matches!(
276 from.source,
277 ModelSource::Local { .. } | ModelSource::Mlx { .. }
278 ) && rule.remove_old_after_available,
279 });
280 }
281 upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
282 upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
283 upgrades
284 }
285
286 pub fn get(&self, id: &str) -> Option<&ModelSchema> {
288 self.models.get(id)
289 }
290
291 pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
294 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
295 if !name.to_ascii_lowercase().ends_with("-mlx") {
296 if let Some(mlx_variant) = self
297 .models
298 .values()
299 .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
300 {
301 return Some(mlx_variant);
302 }
303 }
304
305 self.models
306 .values()
307 .find(|m| m.name.eq_ignore_ascii_case(name))
308 }
309
310 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
314 pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
315 if schema.is_mlx() || schema.is_vllm_mlx() {
317 return None;
318 }
319 if !matches!(schema.source, ModelSource::Local { .. }) {
321 return None;
322 }
323 let primary_cap = schema.capabilities.first()?;
325 self.models.values().find(|m| {
326 m.is_mlx() && m.family == schema.family && m.capabilities.contains(primary_cap)
327 })
328 }
329
330 pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
332 self.ensure_local_with_progress(id, &ProgressSink::none())
333 .await
334 }
335
336 pub async fn ensure_local_with_progress(
341 &self,
342 id: &str,
343 sink: &ProgressSink,
344 ) -> Result<PathBuf, InferenceError> {
345 let schema = self
346 .get(id)
347 .or_else(|| self.find_by_name(id))
348 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
349 let model_name = schema.name.clone();
350 let model_id = schema.id.clone();
351 let needed_mb = schema.size_mb();
352 let model_dir = self.models_dir.join(&schema.name);
353
354 let _guard = crate::download::acquire_model_lock(&model_id).await;
356
357 if let Err(e) = crate::download::check_disk_space(&model_dir, needed_mb) {
359 sink.emit(DownloadEvent::Failed { error: e.clone() });
360 return Err(InferenceError::DownloadFailed(e));
361 }
362
363 sink.emit(DownloadEvent::Started {
364 model: model_name.clone(),
365 total_files: 0,
366 total_mb: needed_mb,
367 });
368 let result = self.ensure_local_inner(id, sink).await;
369 match &result {
370 Ok(_) => sink.emit(DownloadEvent::Completed { model: model_name }),
371 Err(e) => sink.emit(DownloadEvent::Failed {
372 error: e.to_string(),
373 }),
374 }
375 result
376 }
377
378 async fn ensure_local_inner(
379 &self,
380 id: &str,
381 sink: &ProgressSink,
382 ) -> Result<PathBuf, InferenceError> {
383 let schema = self
384 .get(id)
385 .or_else(|| self.find_by_name(id))
386 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
387
388 match &schema.source {
389 ModelSource::Local {
390 hf_repo,
391 hf_filename,
392 tokenizer_repo,
393 } => {
394 let model_dir = self.models_dir.join(&schema.name);
395 let model_path = model_dir.join("model.gguf");
396 let tokenizer_path = model_dir.join("tokenizer.json");
397
398 if model_path.exists() && tokenizer_path.exists() {
399 return Ok(model_dir);
400 }
401
402 std::fs::create_dir_all(&model_dir)?;
403
404 if !model_path.exists() {
405 info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
406 sink.emit(DownloadEvent::FileStarted {
407 filename: "model weights".into(),
408 index: 1,
409 total_files: 2,
410 size_mb: schema.size_mb(),
411 });
412 download_file(hf_repo, hf_filename, &model_path).await?;
413 sink.emit(DownloadEvent::FileCompleted {
414 filename: "model weights".into(),
415 });
416 }
417 if !tokenizer_path.exists() {
418 info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
419 sink.emit(DownloadEvent::FileStarted {
420 filename: "tokenizer".into(),
421 index: 2,
422 total_files: 2,
423 size_mb: 0,
424 });
425 download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
426 sink.emit(DownloadEvent::FileCompleted {
427 filename: "tokenizer".into(),
428 });
429 }
430
431 Ok(model_dir)
432 }
433 ModelSource::Mlx {
434 hf_repo,
435 hf_weight_file,
436 } => {
437 let model_dir = self.models_dir.join(&schema.name);
438 let config_path = model_dir.join("config.json");
439
440 if config_path.exists() {
441 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
442 info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
443 return Ok(model_dir);
444 }
445
446 if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
447 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
448 info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
449 return Ok(snapshot_dir);
450 }
451
452 if requires_full_mlx_snapshot(&schema) {
453 info!(
454 model = %schema.name,
455 repo = %hf_repo,
456 "downloading full MLX snapshot"
457 );
458 let (snapshot_dir, _files_downloaded) =
459 download_hf_repo_snapshot(hf_repo).await?;
460 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
461 return Ok(snapshot_dir);
462 }
463
464 std::fs::create_dir_all(&model_dir)?;
465
466 info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
467
468 emit_file(sink, "config", 0, schema.size_mb());
472 download_file(hf_repo, "config.json", &config_path).await?;
473 let tok_path = model_dir.join("tokenizer.json");
474 if !tok_path.exists() {
475 emit_file(sink, "tokenizer", 0, 0);
476 download_file(hf_repo, "tokenizer.json", &tok_path).await?;
477 }
478 let tok_config_path = model_dir.join("tokenizer_config.json");
479 if !tok_config_path.exists() {
480 let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
481 }
482
483 if let Some(ref wf) = hf_weight_file {
485 let wf_path = model_dir.join(wf);
486 if !wf_path.exists() {
487 emit_file(sink, "model weights", 0, schema.size_mb());
488 download_file(hf_repo, wf, &wf_path).await?;
489 }
490 } else {
491 let single = model_dir.join("model.safetensors");
493 if !single.exists() {
494 emit_file(sink, "model weights", 0, schema.size_mb());
495 match download_file(hf_repo, "model.safetensors", &single).await {
496 Ok(()) => {}
497 Err(_) => {
498 let index_path = model_dir.join("model.safetensors.index.json");
500 download_file(hf_repo, "model.safetensors.index.json", &index_path)
501 .await?;
502
503 let index_json: serde_json::Value =
504 serde_json::from_str(&std::fs::read_to_string(&index_path)?)
505 .map_err(|e| {
506 InferenceError::InferenceFailed(format!(
507 "parse index: {e}"
508 ))
509 })?;
510
511 if let Some(weight_map) =
512 index_json.get("weight_map").and_then(|m| m.as_object())
513 {
514 let mut files: std::collections::HashSet<String> =
515 std::collections::HashSet::new();
516 for filename in weight_map.values() {
517 if let Some(f) = filename.as_str() {
518 files.insert(f.to_string());
519 }
520 }
521 let shard_total = files.len() as u32;
522 for (i, file) in files.iter().enumerate() {
523 let dest = model_dir.join(file);
524 if !dest.exists() {
525 info!(file = %file, "downloading weight shard");
526 sink.emit(DownloadEvent::FileStarted {
527 filename: format!("weights part {}", i + 1),
528 index: (i + 1) as u32,
529 total_files: shard_total,
530 size_mb: 0,
531 });
532 download_file(hf_repo, file, &dest).await?;
533 sink.emit(DownloadEvent::FileCompleted {
534 filename: format!("weights part {}", i + 1),
535 });
536 }
537 }
538 }
539 }
540 }
541 }
542 }
543
544 ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
545 Ok(model_dir)
546 }
547 _ => Err(InferenceError::InferenceFailed(format!(
548 "model {} is not local",
549 id
550 ))),
551 }
552 }
553
554 pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
556 let schema = self
557 .get(id)
558 .or_else(|| self.find_by_name(id))
559 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
560
561 let model_dir = self.models_dir.join(&schema.name);
562 if model_dir.exists() {
563 std::fs::remove_dir_all(&model_dir)?;
564 info!(model = %schema.name, "removed model");
565 }
566
567 match &schema.source {
568 ModelSource::Mlx { hf_repo, .. } => {
569 let repo_dir = huggingface_repo_dir(hf_repo);
570 if repo_dir.exists() {
571 std::fs::remove_dir_all(&repo_dir)?;
572 info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
573 }
574 }
575 ModelSource::Local {
576 hf_repo,
577 tokenizer_repo,
578 ..
579 } => {
580 for repo in [hf_repo, tokenizer_repo] {
581 let repo_dir = huggingface_repo_dir(repo);
582 if repo_dir.exists() {
583 std::fs::remove_dir_all(&repo_dir)?;
584 info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
585 }
586 }
587 }
588 _ => {}
589 }
590
591 let id = schema.id.clone();
593 if let Some(m) = self.models.get_mut(&id) {
594 m.available = false;
595 }
596 Ok(())
597 }
598
599 pub fn refresh_availability(&mut self) {
607 let models_dir = self.models_dir.clone();
611 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
616 let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
617 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
618 #[allow(unused_variables)]
619 let mlx_vlm_cli_present = false;
620
621 for m in self.models.values_mut() {
622 match &m.source {
623 ModelSource::Mlx { hf_repo, .. } => {
624 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
633 {
634 let needs_mlx_vlm = m.tags.iter().any(|t| t == "requires-mlx-vlm");
641
642 m.available = if needs_mlx_vlm {
643 mlx_vlm_cli_present
644 } else if m.tags.contains(&"speech".to_string()) {
645 speech_mlx_available()
646 } else {
647 let mlx_dir = models_dir.join(&m.name);
657 mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
658 };
659 }
660 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
661 {
662 let _ = hf_repo; m.available = false;
664 }
665 }
666 ModelSource::Local { .. } => {
667 let local_path = models_dir.join(&m.name).join("model.gguf");
668 m.available = local_path.exists();
669 }
670 ModelSource::RemoteApi { api_key_env, .. } => {
671 m.available = std::env::var(api_key_env).is_ok();
672 }
673 ModelSource::Ollama { .. } => {
674 m.available = true;
676 }
677 ModelSource::VllmMlx { .. } => {
678 m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
681 }
683 ModelSource::Proprietary { auth, .. } => {
684 m.available = match auth {
686 crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
687 std::env::var(env_var).is_ok()
688 }
689 crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
690 std::env::var(env_var).is_ok()
691 }
692 crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
693 true
695 }
696 };
697 }
698 ModelSource::AppleFoundationModels { .. } => {
699 #[cfg(any(
706 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
707 all(target_os = "ios", target_arch = "aarch64")
708 ))]
709 {
710 m.available = crate::backend::foundation_models::is_available();
711 }
712 #[cfg(not(any(
713 all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
714 all(target_os = "ios", target_arch = "aarch64")
715 )))]
716 {
717 m.available = false;
718 }
719 }
720 ModelSource::Delegated { .. } => {
721 m.available = crate::runner::current_inference_runner().is_some();
726 }
727 }
728 }
729 }
730
731 pub fn save_user_config(&self) -> Result<(), InferenceError> {
733 let user_models: Vec<&ModelSchema> = self
734 .models
735 .values()
736 .filter(|m| !m.tags.contains(&"builtin".to_string()))
737 .collect();
738
739 if user_models.is_empty() {
740 return Ok(());
741 }
742
743 let json = serde_json::to_string_pretty(&user_models)
744 .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
745 std::fs::write(&self.user_config_path, json)?;
746 Ok(())
747 }
748
749 pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
751 if !self.user_config_path.exists() {
752 return Ok(());
753 }
754
755 let json = std::fs::read_to_string(&self.user_config_path)?;
756 let models: Vec<ModelSchema> = serde_json::from_str(&json)
757 .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
758
759 for m in models {
760 self.register(m);
761 }
762 Ok(())
763 }
764
765 pub fn models_dir(&self) -> &Path {
767 &self.models_dir
768 }
769
770 fn load_builtin_catalog(&mut self) {
772 for schema in builtin_catalog() {
773 self.models.insert(schema.id.clone(), schema);
774 }
775 }
776}
777
/// Reports whether MLX speech (STT/TTS) models can run on this host.
///
/// The function is only compiled on MLX-capable targets (macOS arm64 without
/// `car_skip_mlx`), and there the first inner branch is selected, so it returns `true`.
/// The second branch probes the speech runtime's `bin/` directory for the `mlx_audio`
/// STT/TTS entry points. That branch carries the negated gate, so it is never compiled
/// while this function itself is.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn speech_mlx_available() -> bool {
    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    {
        true
    }

    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    {
        let bin_dir = speech_runtime_root().join("bin");
        ["mlx_audio.stt.generate", "mlx_audio.tts.generate"]
            .iter()
            .any(|entry_point| bin_dir.join(entry_point).exists())
    }
}
801
/// Directory of the speech runtime.
///
/// This is `$CAR_SPEECH_RUNTIME_DIR` when that variable is set to a non-blank value.
/// Otherwise it is `$HOME/.car/speech-runtime`, with `.` standing in for an unset `HOME`.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
#[allow(dead_code)] // its only caller sits in a cfg branch that is not compiled on this target
fn speech_runtime_root() -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(dir) if !dir.trim().is_empty() => PathBuf::from(dir),
        _ => {
            let home = std::env::var("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| PathBuf::from("."));
            home.join(".car").join("speech-runtime")
        }
    }
}
816
817#[derive(Debug, Clone, Serialize, Deserialize)]
819pub struct ModelInfo {
820 pub id: String,
821 pub name: String,
822 pub provider: String,
823 pub capabilities: Vec<ModelCapability>,
824 pub param_count: String,
825 pub size_mb: u64,
826 pub context_length: usize,
827 pub available: bool,
828 pub is_local: bool,
829 #[serde(default)]
833 pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
834}
835
836impl From<&ModelSchema> for ModelInfo {
837 fn from(s: &ModelSchema) -> Self {
838 ModelInfo {
839 id: s.id.clone(),
840 name: s.name.clone(),
841 provider: s.provider.clone(),
842 capabilities: s.capabilities.clone(),
843 param_count: s.param_count.clone(),
844 size_mb: s.size_mb(),
845 context_length: s.context_length,
846 available: s.available,
847 is_local: s.is_local(),
848 public_benchmarks: s.public_benchmarks.clone(),
849 }
850 }
851}
852
853fn emit_file(sink: &ProgressSink, name: &str, index: u32, size_mb: u64) {
858 sink.emit(DownloadEvent::FileStarted {
859 filename: name.to_string(),
860 index,
861 total_files: 0,
862 size_mb,
863 });
864}
865
866async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
868 let api = hf_hub::api::tokio::Api::new()
869 .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
870
871 let repo = api.model(repo.to_string());
872 let path = repo
873 .get(filename)
874 .await
875 .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
876
877 if dest.exists() {
878 return Ok(());
879 }
880
881 #[cfg(unix)]
883 {
884 if std::os::unix::fs::symlink(&path, dest).is_ok() {
885 return Ok(());
886 }
887 }
888
889 std::fs::copy(&path, dest)
890 .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
891 Ok(())
892}
893
894async fn ensure_auxiliary_mlx_files(
895 model_name: &str,
896 hf_repo: &str,
897 model_dir: &Path,
898) -> Result<(), InferenceError> {
899 if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
900 let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
901 if !t5_tokenizer_path.exists() {
902 std::fs::create_dir_all(t5_tokenizer_path.parent().ok_or_else(|| {
903 InferenceError::InferenceFailed("invalid tokenizer path".into())
904 })?)?;
905 info!(
906 path = %t5_tokenizer_path.display(),
907 "downloading missing Flux tokenizer_2/tokenizer.json from base model"
908 );
909 download_file(
910 "Freepik/flux.1-lite-8B",
911 "tokenizer_2/tokenizer.json",
912 &t5_tokenizer_path,
913 )
914 .await?;
915 }
916 }
917 Ok(())
918}
919
920fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
921 match &schema.source {
922 ModelSource::Mlx { hf_repo, .. } => {
923 hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
924 || schema.family.starts_with("yume")
925 || schema.tags.iter().any(|tag| {
926 matches!(
927 tag.as_str(),
928 "wan2.2" | "ti2v" | "world-model" | "image-to-video"
929 )
930 })
931 }
932 _ => false,
933 }
934}
935
936#[allow(dead_code)] fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
938 latest_huggingface_repo_snapshot(repo_id).is_some()
939}
940
/// Root of the Hugging Face hub cache.
///
/// This is `$HF_HOME/hub` when `HF_HOME` is set. Otherwise it is
/// `$HOME/.cache/huggingface/hub`, with `.` standing in for an unset `HOME`.
fn huggingface_cache_root() -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| PathBuf::from("."));
            home.join(".cache").join("huggingface")
        }
    };
    hf_home.join("hub")
}
953
954fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
955 huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
956}
957
958fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
959 let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
960 .ok()?
961 .trim()
962 .to_string();
963 if sha.is_empty() {
964 return None;
965 }
966
967 let snapshot = repo_dir.join("snapshots").join(sha);
968 if snapshot_looks_ready(&snapshot) {
969 Some(snapshot)
970 } else {
971 None
972 }
973}
974
975fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
976 let repo_dir = huggingface_repo_dir(repo_id);
977 if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
978 return Some(snapshot);
979 }
980
981 let snapshots = repo_dir.join("snapshots");
982 let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
983 .ok()?
984 .filter_map(Result::ok)
985 .map(|e| e.path())
986 .filter(|p| p.is_dir() && snapshot_looks_ready(p))
987 .map(|path| {
988 let modified = path
989 .metadata()
990 .and_then(|metadata| metadata.modified())
991 .unwrap_or(SystemTime::UNIX_EPOCH);
992 (modified, path)
993 })
994 .collect();
995 candidates.sort();
996 candidates.pop().map(|(_, path)| path)
997}
998
999fn snapshot_looks_ready(path: &Path) -> bool {
1000 if path.join("config.json").exists() || path.join("model_index.json").exists() {
1001 return true;
1002 }
1003 snapshot_contains_ext(path, "safetensors")
1004}
1005
/// Reports whether any file under `root`, searched recursively, has extension `ext`.
///
/// The extension comparison is ASCII case-insensitive. Missing or unreadable directories
/// count as "not found". Directories are detected with `Path::is_dir`, which follows
/// symlinks.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    match std::fs::read_dir(root) {
        Err(_) => false,
        Ok(entries) => entries
            .flatten()
            .map(|entry| entry.path())
            .any(|path| {
                if path.is_dir() {
                    return snapshot_contains_ext(&path, ext);
                }
                match path.extension().and_then(|found| found.to_str()) {
                    Some(found) => found.eq_ignore_ascii_case(ext),
                    None => false,
                }
            }),
    }
}
1022
1023async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
1024 let api = hf_hub::api::tokio::ApiBuilder::from_env()
1025 .with_progress(false)
1026 .build()
1027 .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
1028 let repo = api.model(repo_id.to_string());
1029 let info = repo
1030 .info()
1031 .await
1032 .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
1033
1034 let snapshot_path = std::env::var("HF_HOME")
1035 .map(PathBuf::from)
1036 .unwrap_or_else(|_| {
1037 std::env::var("HOME")
1038 .map(PathBuf::from)
1039 .unwrap_or_else(|_| PathBuf::from("."))
1040 .join(".cache")
1041 .join("huggingface")
1042 })
1043 .join("hub")
1044 .join(format!("models--{}", repo_id.replace('/', "--")))
1045 .join("snapshots")
1046 .join(&info.sha);
1047 let mut downloaded = 0usize;
1048 for sibling in &info.siblings {
1049 let local_path = snapshot_path.join(&sibling.rfilename);
1050 if local_path.exists() {
1051 downloaded += 1;
1052 continue;
1053 }
1054 repo.download(&sibling.rfilename).await.map_err(|e| {
1055 InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
1056 })?;
1057 downloaded += 1;
1058 }
1059
1060 Ok((snapshot_path, downloaded))
1061}
1062
1063const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");
1072
1073static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> = std::sync::LazyLock::new(|| {
1074 serde_json::from_str(BUILTIN_CATALOG_JSON)
1075 .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
1076});
1077
1078fn builtin_catalog() -> Vec<ModelSchema> {
1079 BUILTIN_CATALOG.clone()
1080}
1081
1082#[cfg(test)]
1083mod tests {
1084 use super::*;
1085 use tempfile::TempDir;
1086
1087 fn test_registry() -> (UnifiedRegistry, TempDir) {
1088 let tmp = TempDir::new().unwrap();
1089 let reg = UnifiedRegistry::new(tmp.path().join("models"));
1090 (reg, tmp)
1091 }
1092
#[test]
fn builtin_catalog_loads() {
    // With no models.json present, the registry holds exactly the embedded catalog.
    let (reg, _tmp) = test_registry();
    let expected = builtin_catalog().len();
    assert_eq!(reg.list().len(), expected);
}
1099
#[test]
fn mlx_vlm_models_reflect_runtime_availability() {
    // Models tagged `requires-mlx-vlm` must track the mlx_vlm CLI probe, not file presence.
    let (reg, _tmp) = test_registry();
    let tagged: Vec<&ModelSchema> = reg
        .list()
        .into_iter()
        .filter(|model| model.tags.iter().any(|tag| tag == "requires-mlx-vlm"))
        .collect();
    assert!(
        !tagged.is_empty(),
        "catalog should contain at least one model tagged \
         `requires-mlx-vlm` — otherwise this regression has \
         nothing to guard"
    );

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    let expected = crate::backend::mlx_vlm_cli::is_available();
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    let expected = false;

    for model in tagged {
        assert_eq!(
            model.available, expected,
            "model {} `available` field should reflect \
             mlx_vlm CLI presence (expected {expected}, got {})",
            model.id, model.available
        );
    }
}
1141
#[test]
fn mlx_models_unavailable_on_non_mlx_targets() {
    // "Plain" MLX models leave out mlx_vlm and speech models; those follow their own probes.
    let (reg, _tmp) = test_registry();
    let plain_mlx: Vec<&ModelSchema> = reg
        .list()
        .into_iter()
        .filter(|model| {
            model.is_mlx()
                && !model.tags.iter().any(|tag| tag == "requires-mlx-vlm")
                && !model.tags.iter().any(|tag| tag == "speech")
        })
        .collect();
    assert!(
        !plain_mlx.is_empty(),
        "catalog should contain at least one plain MLX model — \
         otherwise this F1 regression guard has nothing to guard"
    );

    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
    {
        // Catalog entries carry a Hugging Face repo, so at least one counts as pullable.
        let any_available = plain_mlx.iter().any(|model| model.available);
        assert!(
            any_available,
            "on macOS arm64 with MLX enabled, at least one plain MLX \
             model with hf_repo should be available — none were"
        );
    }
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
    {
        for model in &plain_mlx {
            assert!(
                !model.available,
                "MLX model {} is marked available on a non-MLX target — \
                 the adaptive router will add it to fallback chains \
                 and dispatch will fail (Parslee-ai/car#231 §7.1)",
                model.id
            );
        }
    }
}
1202
#[test]
fn builtin_catalog_json_parses() {
    // Parse the embedded JSON directly, so a malformed catalog fails here with a precise
    // message rather than panicking inside the shared LazyLock during another test.
    let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
        .expect("builtin_catalog.json must be valid ModelSchema array");
    assert!(
        !catalog.is_empty(),
        "embedded catalog has no entries — that's almost certainly wrong"
    );

    let mut seen_ids = std::collections::HashSet::new();
    for entry in &catalog {
        let first_occurrence = seen_ids.insert(entry.id.as_str());
        assert!(
            first_occurrence,
            "duplicate id in builtin_catalog.json: {}",
            entry.id
        );
    }
}
1223
#[test]
fn public_benchmarks_round_trip_through_model_info() {
    use crate::schema::BenchmarkScore;

    let (mut reg, _tmp) = test_registry();
    let mut schema = reg
        .find_by_name("Qwen3-4B")
        .expect("catalog has Qwen3-4B")
        .clone();
    schema.id = "test/qwen3-4b-with-bench".into();

    let mmlu_pro = BenchmarkScore {
        name: "MMLU-Pro".into(),
        score: 0.482,
        harness: Some("5-shot CoT".into()),
        source_url: Some("https://example.invalid/qwen3-4b-card".into()),
        measured_at: Some("2025-08-12".into()),
    };
    let human_eval = BenchmarkScore {
        name: "HumanEval".into(),
        score: 0.713,
        harness: Some("pass@1".into()),
        source_url: None,
        measured_at: None,
    };
    schema.public_benchmarks = vec![mmlu_pro, human_eval];
    reg.register(schema);

    let stored = reg
        .get("test/qwen3-4b-with-bench")
        .expect("registered model is retrievable");
    let info = ModelInfo::from(stored);
    assert_eq!(info.public_benchmarks.len(), 2);

    // The benchmarks must appear in the serialized JSON under their own key, with values intact.
    let json = serde_json::to_string(&info).unwrap();
    assert!(json.contains("\"public_benchmarks\""));
    assert!(json.contains("\"MMLU-Pro\""));
    assert!(json.contains("\"5-shot CoT\""));

    // Decoding must preserve both entries and their order.
    let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
    assert_eq!(decoded.public_benchmarks.len(), 2);
    assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
    assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
}
1269
#[test]
fn public_benchmarks_default_to_empty_when_absent_in_json() {
    // A schema serialized before `public_benchmarks` existed must still deserialize.
    let legacy_json = r#"{
        "id": "legacy/test:1",
        "name": "Legacy Test",
        "provider": "test",
        "family": "test",
        "version": "",
        "capabilities": ["generate"],
        "context_length": 4096,
        "param_count": "1B",
        "quantization": null,
        "performance": {},
        "cost": {},
        "source": { "type": "ollama", "model_tag": "legacy:1" },
        "tags": [],
        "supported_params": []
    }"#;
    let parsed: ModelSchema = serde_json::from_str(legacy_json).unwrap();
    assert!(parsed.public_benchmarks.is_empty());
}
1293
1294 #[test]
1295 fn find_by_name() {
1296 let (reg, _tmp) = test_registry();
1297 let m = reg.find_by_name("Qwen3-4B").unwrap();
1298 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1299 assert_eq!(m.id, "mlx/qwen3-4b:4bit");
1300 #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1301 assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
1302 assert!(m.has_capability(ModelCapability::Code));
1303 }
1304
1305 #[test]
1306 fn query_by_capability() {
1307 let (reg, _tmp) = test_registry();
1308 let embed_models = reg.query_by_capability(ModelCapability::Embed);
1309 assert_eq!(embed_models.len(), 2);
1310 assert!(embed_models
1311 .iter()
1312 .any(|model| model.name == "Qwen3-Embedding-0.6B"));
1313 assert!(embed_models
1314 .iter()
1315 .any(|model| model.name == "Qwen3-Embedding-0.6B-MLX"));
1316 }
1317
1318 #[test]
1319 fn query_with_filter() {
1320 let (reg, _tmp) = test_registry();
1321 let code_small = reg.query(&ModelFilter {
1322 capabilities: vec![ModelCapability::Code],
1323 max_size_mb: Some(3000),
1324 local_only: true,
1325 ..Default::default()
1326 });
1327 assert_eq!(code_small.len(), 4);
1329 }
1330
1331 #[test]
1332 fn register_remote() {
1333 let (mut reg, _tmp) = test_registry();
1334 let initial_len = reg.list().len();
1335 let initial_reasoning_len = reg
1336 .query(&ModelFilter {
1337 capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
1338 ..Default::default()
1339 })
1340 .len();
1341 let remote = ModelSchema {
1342 id: "anthropic/claude-sonnet-4-6:latest".into(),
1343 name: "Claude Sonnet 4.6".into(),
1344 provider: "anthropic".into(),
1345 family: "claude-4".into(),
1346 version: "latest".into(),
1347 capabilities: vec![
1348 ModelCapability::Generate,
1349 ModelCapability::Code,
1350 ModelCapability::Reasoning,
1351 ModelCapability::ToolUse,
1352 ],
1353 context_length: 200000,
1354 param_count: String::new(),
1355 quantization: None,
1356 performance: PerformanceEnvelope {
1357 latency_p50_ms: Some(2000),
1358 ..Default::default()
1359 },
1360 cost: CostModel {
1361 input_per_mtok: Some(3.0),
1362 output_per_mtok: Some(15.0),
1363 ..Default::default()
1364 },
1365 source: ModelSource::RemoteApi {
1366 endpoint: "https://api.anthropic.com/v1/messages".into(),
1367 api_key_env: "ANTHROPIC_API_KEY".into(),
1368 api_key_envs: vec![],
1369 api_version: Some("2023-06-01".into()),
1370 protocol: ApiProtocol::Anthropic,
1371 },
1372 tags: vec![],
1373 supported_params: vec![],
1374 public_benchmarks: vec![],
1375 trust_tier: crate::schema::TrustTier::Curated,
1376 deprecated: false,
1377 available: false,
1378 };
1379
1380 reg.register(remote);
1381 assert_eq!(reg.list().len(), initial_len);
1383
1384 let reasoning = reg.query(&ModelFilter {
1385 capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
1386 ..Default::default()
1387 });
1388 assert_eq!(reasoning.len(), initial_reasoning_len);
1390 }
1391
1392 #[test]
1393 fn unregister() {
1394 let (mut reg, _tmp) = test_registry();
1395 let initial_len = reg.list().len();
1396 let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
1397 assert!(removed.is_some());
1398 assert_eq!(reg.list().len(), initial_len - 1);
1399 }
1400
1401 #[test]
1402 fn speech_models_are_curated() {
1403 let (reg, _tmp) = test_registry();
1404 let stt = reg.query_by_capability(ModelCapability::SpeechToText);
1405 let tts = reg.query_by_capability(ModelCapability::TextToSpeech);
1406 assert_eq!(stt.len(), 2);
1407 assert_eq!(tts.len(), 4);
1408 }
1409
1410 #[test]
1411 fn qwen_8b_variants_keep_tool_use_consistent() {
1412 let (reg, _tmp) = test_registry();
1413 for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
1414 let model = reg.find_by_name(name).expect("model should exist");
1415 assert!(model.has_capability(ModelCapability::ToolUse));
1416 assert!(model.has_capability(ModelCapability::MultiToolCall));
1417 }
1418 }
1419
1420 #[test]
1421 fn mac_name_resolution_prefers_mlx_siblings() {
1422 #[allow(unused_variables)]
1425 let (reg, _tmp) = test_registry();
1426 #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1427 {
1428 assert_eq!(
1429 reg.find_by_name("Qwen3-0.6B").unwrap().id,
1430 "mlx/qwen3-0.6b:6bit"
1431 );
1432 assert_eq!(
1433 reg.find_by_name("Qwen3-1.7B").unwrap().id,
1434 "mlx/qwen3-1.7b:3bit"
1435 );
1436 assert_eq!(
1437 reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
1438 "mlx/qwen3-embedding-0.6b:mxfp8"
1439 );
1440 }
1441 }
1442
1443 #[test]
1444 fn remote_multimodal_models_are_curated_as_vision_capable() {
1445 let (reg, _tmp) = test_registry();
1446 for name in [
1447 "claude-opus-4-7",
1448 "claude-opus-4-6",
1449 "claude-sonnet-4-6",
1450 "claude-haiku-4-5",
1451 "gpt-5.4",
1452 "gpt-5.4-mini",
1453 "o3",
1454 "o4-mini",
1455 "gpt-4.1-mini",
1456 "gemini-2.5-pro",
1457 "gemini-2.5-flash",
1458 ] {
1459 let model = reg.find_by_name(name).expect("model should exist");
1460 assert!(
1461 model.has_capability(ModelCapability::Vision),
1462 "{name} should be curated as vision-capable"
1463 );
1464 }
1465 }
1466
1467 #[test]
1468 fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
1469 let (reg, _tmp) = test_registry();
1470
1471 let stale_ids = [
1472 "mlx/qwen2.5-vl-3b:4bit",
1474 "mlx/qwen2.5-vl-7b:4bit",
1475 "mlx-vlm/qwen2.5-vl-3b:4bit",
1478 "mlx-vlm/qwen2.5-vl-7b:4bit",
1479 "vllm-mlx/qwen2.5-vl-3b:4bit",
1481 ];
1482 for id in stale_ids {
1483 assert!(
1484 reg.get(id).is_none(),
1485 "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
1486 );
1487 }
1488
1489 let vision_ids: Vec<&str> = reg
1490 .query_by_capability(ModelCapability::Vision)
1491 .into_iter()
1492 .map(|model| model.id.as_str())
1493 .collect();
1494 for stale in stale_ids {
1495 assert!(
1496 !vision_ids.contains(&stale),
1497 "{stale} must not be reachable through the Vision capability index"
1498 );
1499 }
1500 assert!(
1501 vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
1502 "Qwen3-VL is the supported local VL family and must route as Vision"
1503 );
1504 }
1505
1506 #[test]
1507 fn gemini_models_are_curated_for_multimodal_tool_use() {
1508 let (reg, _tmp) = test_registry();
1509 for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
1510 let model = reg.find_by_name(name).expect("model should exist");
1511 assert!(model.has_capability(ModelCapability::Vision));
1512 assert!(model.has_capability(ModelCapability::ToolUse));
1513 assert!(model.has_capability(ModelCapability::MultiToolCall));
1514 }
1515 }
1516
1517 #[test]
1518 fn visual_generation_models_are_curated() {
1519 let (reg, _tmp) = test_registry();
1520 assert_eq!(
1521 reg.query_by_capability(ModelCapability::ImageGeneration)
1522 .len(),
1523 1
1524 );
1525 assert_eq!(
1526 reg.query_by_capability(ModelCapability::VideoGeneration)
1527 .len(),
1528 2
1529 );
1530 let yume = reg
1531 .get("mlx/yume-1.5-5b-720p:q4")
1532 .expect("Yume MLX should be in the built-in catalog");
1533 assert!(yume.has_capability(ModelCapability::VideoGeneration));
1534 assert!(yume.tags.contains(&"text-to-video".to_string()));
1535 assert!(yume.tags.contains(&"image-to-video".to_string()));
1536 assert!(yume.tags.contains(&"world-model".to_string()));
1537 }
1538}