Skip to main content

car_inference/
registry.rs

1//! Unified model registry — local and remote models under one schema.
2//!
3//! Replaces the hardcoded `ModelRegistry` from `models.rs` with a schema-driven
4//! registry that treats all models as first-class typed resources. Users can
5//! register custom models (fine-tuned endpoints, private APIs) alongside the
6//! built-in catalog.
7
8use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::schema::*;
17use crate::InferenceError;
18
19/// Filter for querying the registry.
20#[derive(Debug, Clone, Default)]
21pub struct ModelFilter {
22    /// Required capabilities (model must have ALL of these).
23    pub capabilities: Vec<ModelCapability>,
24    /// Maximum on-disk / RAM size in MB.
25    pub max_size_mb: Option<u64>,
26    /// Maximum expected latency in ms (from declared envelope).
27    pub max_latency_ms: Option<u64>,
28    /// Maximum cost per 1M output tokens in USD.
29    pub max_cost_per_mtok: Option<f64>,
30    /// Required tags (model must have ALL of these).
31    pub tags: Vec<String>,
32    /// Filter by provider.
33    pub provider: Option<String>,
34    /// Only local models.
35    pub local_only: bool,
36    /// Only models that are currently available.
37    pub available_only: bool,
38}
39
40/// A curated replacement for a local model that is installed but no longer
41/// the preferred model in its line.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ModelUpgrade {
44    pub from_id: String,
45    pub from_name: String,
46    pub to_id: String,
47    pub to_name: String,
48    pub reason: String,
49    pub target_runtime: Option<String>,
50    pub target_runtime_requirement: Option<String>,
51    pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
52    pub target_available: bool,
53    pub target_pullable: bool,
54    pub remove_old_supported: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ModelRuntimeRequirement {
59    pub name: String,
60    pub minimum_version: String,
61}
62
63/// Unified registry of all known models.
64pub struct UnifiedRegistry {
65    models_dir: PathBuf,
66    /// All registered models, keyed by id.
67    models: HashMap<String, ModelSchema>,
68    /// User-added model config file path (~/.car/models.json).
69    user_config_path: PathBuf,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct ModelUpgradeRule {
74    from_ids: Vec<String>,
75    to_id: String,
76    reason: String,
77    target_runtime: Option<String>,
78    target_runtime_requirement: Option<String>,
79    #[serde(default)]
80    minimum_runtimes: Vec<ModelRuntimeRequirement>,
81    #[serde(default = "default_remove_old_after_available")]
82    remove_old_after_available: bool,
83}
84
/// Serde default for `ModelUpgradeRule::remove_old_after_available`: rules that
/// omit the key allow removing the old model.
fn default_remove_old_after_available() -> bool {
    true
}
88
89fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
90    serde_json::from_str(include_str!("../assets/model-upgrades.json"))
91        .expect("built-in model-upgrades.json should parse")
92}
93
94impl UnifiedRegistry {
95    pub fn new(models_dir: PathBuf) -> Self {
96        let user_config_path = models_dir
97            .parent()
98            .unwrap_or(&models_dir)
99            .join("models.json");
100
101        let mut registry = Self {
102            models_dir,
103            models: HashMap::new(),
104            user_config_path,
105        };
106        registry.load_builtin_catalog();
107        registry.refresh_availability();
108        // Load user config on top (silently ignore if missing)
109        let _ = registry.load_user_config();
110        registry
111    }
112
113    /// Register a model at runtime.
114    pub fn register(&mut self, mut schema: ModelSchema) {
115        // Check availability for local models
116        if schema.is_mlx() {
117            schema.available = if schema.tags.contains(&"speech".to_string()) {
118                speech_mlx_available()
119            } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
120                // Available if cached locally OR has an hf_repo —
121                // ensure_local() lazy-downloads on first use, so a
122                // declared hf_repo is "functionally available" the
123                // same way Ollama/RemoteApi entries are. Mirrors the
124                // refresh_availability() check below; see #164.
125                let mlx_dir = self.models_dir.join(&schema.name);
126                mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
127            } else {
128                let mlx_dir = self.models_dir.join(&schema.name);
129                mlx_dir.join("config.json").exists()
130            };
131        } else if schema.is_vllm_mlx() {
132            // vLLM-MLX: available if endpoint env var set or was manually marked available
133            schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
134        } else if schema.is_local() {
135            let local_path = self.models_dir.join(&schema.name).join("model.gguf");
136            schema.available = local_path.exists();
137        } else if schema.is_remote() {
138            // Remote models are assumed available if the env var exists
139            if let ModelSource::RemoteApi {
140                ref api_key_env, ..
141            } = schema.source
142            {
143                schema.available = std::env::var(api_key_env).is_ok();
144            }
145        }
146        info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
147        self.models.insert(schema.id.clone(), schema);
148    }
149
150    /// Unregister a model by id. Returns the removed schema if found.
151    pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
152        let removed = self.models.remove(id);
153        if let Some(ref m) = removed {
154            info!(id = %m.id, "unregistered model");
155        }
156        removed
157    }
158
159    /// List all models.
160    pub fn list(&self) -> Vec<&ModelSchema> {
161        let mut models: Vec<&ModelSchema> = self.models.values().collect();
162        models.sort_by(|a, b| a.id.cmp(&b.id));
163        models
164    }
165
166    /// Query models matching a filter.
167    pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
168        self.models
169            .values()
170            .filter(|m| {
171                // Capability check: model must have ALL required capabilities
172                if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
173                    return false;
174                }
175                // Size check
176                if let Some(max) = filter.max_size_mb {
177                    if m.size_mb() > max && m.is_local() {
178                        return false;
179                    }
180                }
181                // Latency check (declared envelope)
182                if let Some(max) = filter.max_latency_ms {
183                    if let Some(p50) = m.performance.latency_p50_ms {
184                        if p50 > max {
185                            return false;
186                        }
187                    }
188                }
189                // Cost check
190                if let Some(max) = filter.max_cost_per_mtok {
191                    if let Some(cost) = m.cost.output_per_mtok {
192                        if cost > max {
193                            return false;
194                        }
195                    }
196                }
197                // Tag check
198                if !filter.tags.iter().all(|t| m.tags.contains(t)) {
199                    return false;
200                }
201                // Provider check
202                if let Some(ref p) = filter.provider {
203                    if &m.provider != p {
204                        return false;
205                    }
206                }
207                // Local only
208                if filter.local_only && !m.is_local() {
209                    return false;
210                }
211                // Available only
212                if filter.available_only && !m.available {
213                    return false;
214                }
215                true
216            })
217            .collect()
218    }
219
220    /// Query models by a single capability.
221    pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
222        self.query(&ModelFilter {
223            capabilities: vec![cap],
224            ..Default::default()
225        })
226    }
227
228    /// Report installed local models with curated newer replacements.
229    pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
230        let mut upgrades = Vec::new();
231        for rule in model_upgrade_rules() {
232            let Some(from) = rule
233                .from_ids
234                .iter()
235                .find_map(|id| self.models.get(id.as_str()))
236                .filter(|schema| schema.available)
237            else {
238                continue;
239            };
240            let Some(to) = self.models.get(rule.to_id.as_str()) else {
241                continue;
242            };
243            upgrades.push(ModelUpgrade {
244                from_id: from.id.clone(),
245                from_name: from.name.clone(),
246                to_id: to.id.clone(),
247                to_name: to.name.clone(),
248                reason: rule.reason.clone(),
249                target_runtime: rule.target_runtime.clone(),
250                target_runtime_requirement: rule.target_runtime_requirement.clone(),
251                minimum_runtimes: rule.minimum_runtimes.clone(),
252                target_available: to.available,
253                target_pullable: matches!(
254                    to.source,
255                    ModelSource::Local { .. } | ModelSource::Mlx { .. }
256                ),
257                remove_old_supported: matches!(
258                    from.source,
259                    ModelSource::Local { .. } | ModelSource::Mlx { .. }
260                ) && rule.remove_old_after_available,
261            });
262        }
263        upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
264        upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
265        upgrades
266    }
267
268    /// Get a specific model by id.
269    pub fn get(&self, id: &str) -> Option<&ModelSchema> {
270        self.models.get(id)
271    }
272
273    /// Find a model by name (case-insensitive). For backward compatibility
274    /// with the old registry that used short names like "Qwen3-4B".
275    pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
276        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
277        if !name.to_ascii_lowercase().ends_with("-mlx") {
278            if let Some(mlx_variant) = self
279                .models
280                .values()
281                .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
282            {
283                return Some(mlx_variant);
284            }
285        }
286
287        self.models
288            .values()
289            .find(|m| m.name.eq_ignore_ascii_case(name))
290    }
291
292    /// On Apple Silicon, resolve a GGUF/Candle model to its MLX equivalent.
293    /// Returns the MLX model schema if one exists with the same family and
294    /// matching capabilities; otherwise returns None.
295    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
296    pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
297        // Already MLX — no redirect needed.
298        if schema.is_mlx() || schema.is_vllm_mlx() {
299            return None;
300        }
301        // Only redirect local GGUF models.
302        if !matches!(schema.source, ModelSource::Local { .. }) {
303            return None;
304        }
305        // Find an MLX model in the same family with at least the same primary capability.
306        let primary_cap = schema.capabilities.first()?;
307        self.models.values().find(|m| {
308            m.is_mlx()
309                && m.family == schema.family
310                && m.capabilities.contains(primary_cap)
311        })
312    }
313
314    /// Ensure a local model is downloaded, returning its local directory path.
315    pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
316        let schema = self
317            .get(id)
318            .or_else(|| self.find_by_name(id))
319            .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
320
321        match &schema.source {
322            ModelSource::Local {
323                hf_repo,
324                hf_filename,
325                tokenizer_repo,
326            } => {
327                let model_dir = self.models_dir.join(&schema.name);
328                let model_path = model_dir.join("model.gguf");
329                let tokenizer_path = model_dir.join("tokenizer.json");
330
331                if model_path.exists() && tokenizer_path.exists() {
332                    return Ok(model_dir);
333                }
334
335                std::fs::create_dir_all(&model_dir)?;
336
337                if !model_path.exists() {
338                    info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
339                    download_file(hf_repo, hf_filename, &model_path).await?;
340                }
341                if !tokenizer_path.exists() {
342                    info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
343                    download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
344                }
345
346                Ok(model_dir)
347            }
348            ModelSource::Mlx {
349                hf_repo,
350                hf_weight_file,
351            } => {
352                let model_dir = self.models_dir.join(&schema.name);
353                let config_path = model_dir.join("config.json");
354
355                if config_path.exists() {
356                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
357                    info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
358                    return Ok(model_dir);
359                }
360
361                if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
362                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
363                    info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
364                    return Ok(snapshot_dir);
365                }
366
367                if requires_full_mlx_snapshot(&schema) {
368                    info!(
369                        model = %schema.name,
370                        repo = %hf_repo,
371                        "downloading full MLX snapshot"
372                    );
373                    let (snapshot_dir, _files_downloaded) = download_hf_repo_snapshot(hf_repo).await?;
374                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
375                    return Ok(snapshot_dir);
376                }
377
378                std::fs::create_dir_all(&model_dir)?;
379
380                info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
381
382                // Download config, tokenizer, and weight files
383                download_file(hf_repo, "config.json", &config_path).await?;
384                let tok_path = model_dir.join("tokenizer.json");
385                if !tok_path.exists() {
386                    download_file(hf_repo, "tokenizer.json", &tok_path).await?;
387                }
388                let tok_config_path = model_dir.join("tokenizer_config.json");
389                if !tok_config_path.exists() {
390                    let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
391                }
392
393                // Download weight files
394                if let Some(ref wf) = hf_weight_file {
395                    let wf_path = model_dir.join(wf);
396                    if !wf_path.exists() {
397                        download_file(hf_repo, wf, &wf_path).await?;
398                    }
399                } else {
400                    // Try single file first, then sharded
401                    let single = model_dir.join("model.safetensors");
402                    if !single.exists() {
403                        match download_file(hf_repo, "model.safetensors", &single).await {
404                            Ok(()) => {}
405                            Err(_) => {
406                                // Sharded: download index and then each shard
407                                let index_path = model_dir.join("model.safetensors.index.json");
408                                download_file(hf_repo, "model.safetensors.index.json", &index_path)
409                                    .await?;
410
411                                let index_json: serde_json::Value =
412                                    serde_json::from_str(&std::fs::read_to_string(&index_path)?)
413                                        .map_err(|e| {
414                                            InferenceError::InferenceFailed(format!(
415                                                "parse index: {e}"
416                                            ))
417                                        })?;
418
419                                if let Some(weight_map) =
420                                    index_json.get("weight_map").and_then(|m| m.as_object())
421                                {
422                                    let mut files: std::collections::HashSet<String> =
423                                        std::collections::HashSet::new();
424                                    for filename in weight_map.values() {
425                                        if let Some(f) = filename.as_str() {
426                                            files.insert(f.to_string());
427                                        }
428                                    }
429                                    for file in &files {
430                                        let dest = model_dir.join(file);
431                                        if !dest.exists() {
432                                            info!(file = %file, "downloading weight shard");
433                                            download_file(hf_repo, file, &dest).await?;
434                                        }
435                                    }
436                                }
437                            }
438                        }
439                    }
440                }
441
442                ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
443                Ok(model_dir)
444            }
445            _ => Err(InferenceError::InferenceFailed(format!(
446                "model {} is not local",
447                id
448            ))),
449        }
450    }
451
452    /// Remove a downloaded local model.
453    pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
454        let schema = self
455            .get(id)
456            .or_else(|| self.find_by_name(id))
457            .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
458
459        let model_dir = self.models_dir.join(&schema.name);
460        if model_dir.exists() {
461            std::fs::remove_dir_all(&model_dir)?;
462            info!(model = %schema.name, "removed model");
463        }
464
465        match &schema.source {
466            ModelSource::Mlx { hf_repo, .. } => {
467                let repo_dir = huggingface_repo_dir(hf_repo);
468                if repo_dir.exists() {
469                    std::fs::remove_dir_all(&repo_dir)?;
470                    info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
471                }
472            }
473            ModelSource::Local {
474                hf_repo,
475                tokenizer_repo,
476                ..
477            } => {
478                for repo in [hf_repo, tokenizer_repo] {
479                    let repo_dir = huggingface_repo_dir(repo);
480                    if repo_dir.exists() {
481                        std::fs::remove_dir_all(&repo_dir)?;
482                        info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
483                    }
484                }
485            }
486            _ => {}
487        }
488
489        // Update availability
490        let id = schema.id.clone();
491        if let Some(m) = self.models.get_mut(&id) {
492            m.available = false;
493        }
494        Ok(())
495    }
496
497    /// Refresh availability flags for all models.
498    ///
499    /// Runtime-true vs catalog-says: this is what closes the gap
500    /// `models.list_unified` callers rely on. If a model is listed
501    /// as `available: true` here, an `infer` call against it should
502    /// reach the backend, not bail with `UnsupportedMode { ...
503    /// "mlx-vlm CLI not found on PATH" }` (the #137 trap).
504    pub fn refresh_availability(&mut self) {
505        let models_dir = self.models_dir.clone();
506        // mlx-vlm CLI is the same probe call no matter which model
507        // requires it; do it once per refresh, not per-model.
508        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
509        let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
510        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
511        let mlx_vlm_cli_present = false;
512
513        for m in self.models.values_mut() {
514            match &m.source {
515                ModelSource::Mlx { hf_repo, .. } => {
516                    // Models tagged `requires-mlx-vlm` shell out to the
517                    // mlx_vlm Python CLI for image inference (#115).
518                    // If the CLI isn't on PATH, the runtime reaches it
519                    // anyway and bails — the registry MUST reflect that
520                    // by marking such entries unavailable until the
521                    // user installs `uv tool install mlx-vlm`. #137.
522                    let needs_mlx_vlm =
523                        m.tags.iter().any(|t| t == "requires-mlx-vlm");
524
525                    m.available = if needs_mlx_vlm {
526                        mlx_vlm_cli_present
527                    } else if m.tags.contains(&"speech".to_string()) {
528                        speech_mlx_available()
529                    } else {
530                        // Available if cached locally OR has an hf_repo —
531                        // the native MLX path's ensure_local() lazy-
532                        // downloads on first use, so a declared hf_repo
533                        // is "functionally available" the same way
534                        // Ollama and RemoteApi entries are ("should
535                        // work in principle" not "physically cached").
536                        // Closes #164: mlx/ltx-2.3:q4 reported
537                        // unavailable even though `car video` would
538                        // just download-and-run successfully.
539                        let mlx_dir = models_dir.join(&m.name);
540                        mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
541                    };
542                }
543                ModelSource::Local { .. } => {
544                    let local_path = models_dir.join(&m.name).join("model.gguf");
545                    m.available = local_path.exists();
546                }
547                ModelSource::RemoteApi { api_key_env, .. } => {
548                    m.available = std::env::var(api_key_env).is_ok();
549                }
550                ModelSource::Ollama { .. } => {
551                    // Assume available; health check is async and done lazily
552                    m.available = true;
553                }
554                ModelSource::VllmMlx { .. } => {
555                    // vLLM-MLX availability checked via health endpoint lazily
556                    // Mark as available if VLLM_MLX_ENDPOINT env var is set or default endpoint assumed
557                    m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
558                    // preserve manual registration
559                }
560                ModelSource::Proprietary { auth, .. } => {
561                    // Check auth availability
562                    m.available = match auth {
563                        crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
564                            std::env::var(env_var).is_ok()
565                        }
566                        crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
567                            std::env::var(env_var).is_ok()
568                        }
569                        crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
570                            // OAuth2 availability determined at runtime by token provider
571                            true
572                        }
573                    };
574                }
575                ModelSource::AppleFoundationModels { .. } => {
576                    // Apple Silicon macOS 26+ AND iOS 26+ both expose
577                    // the FoundationModels framework. The shim's
578                    // runtime probe handles per-device availability
579                    // (Apple Intelligence may be off, the device may
580                    // be pre-A17, etc.); cfg-gating here just hides
581                    // the call on targets where the shim isn't built.
582                    #[cfg(any(
583                        all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
584                        all(target_os = "ios", target_arch = "aarch64")
585                    ))]
586                    {
587                        m.available = crate::backend::foundation_models::is_available();
588                    }
589                    #[cfg(not(any(
590                        all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
591                        all(target_os = "ios", target_arch = "aarch64")
592                    )))]
593                    {
594                        m.available = false;
595                    }
596                }
597                ModelSource::Delegated { .. } => {
598                    // Availability tracks whether a runner is registered.
599                    // Hosts call `registerInferenceRunner` (or its
600                    // language equivalent) at startup; until then the
601                    // model is unavailable.
602                    m.available = crate::runner::current_inference_runner().is_some();
603                }
604            }
605        }
606    }
607
608    /// Persist user-registered (non-builtin) models to disk.
609    pub fn save_user_config(&self) -> Result<(), InferenceError> {
610        let user_models: Vec<&ModelSchema> = self
611            .models
612            .values()
613            .filter(|m| !m.tags.contains(&"builtin".to_string()))
614            .collect();
615
616        if user_models.is_empty() {
617            return Ok(());
618        }
619
620        let json = serde_json::to_string_pretty(&user_models)
621            .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
622        std::fs::write(&self.user_config_path, json)?;
623        Ok(())
624    }
625
626    /// Load user-registered models from disk.
627    pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
628        if !self.user_config_path.exists() {
629            return Ok(());
630        }
631
632        let json = std::fs::read_to_string(&self.user_config_path)?;
633        let models: Vec<ModelSchema> = serde_json::from_str(&json)
634            .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
635
636        for m in models {
637            self.register(m);
638        }
639        Ok(())
640    }
641
642    /// Get the models directory path.
643    pub fn models_dir(&self) -> &Path {
644        &self.models_dir
645    }
646
647    /// Load the built-in Qwen3 catalog as ModelSchema objects.
648    fn load_builtin_catalog(&mut self) {
649        for schema in builtin_catalog() {
650            self.models.insert(schema.id.clone(), schema);
651        }
652    }
653}
654
655fn speech_mlx_available() -> bool {
656    // On Apple Silicon, speech uses native MLX backends — no Python CLI needed.
657    // Models are available if we're on the right platform (weights are downloaded on demand).
658    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
659    { true }
660
661    // On other platforms, check for the Python mlx-audio CLI.
662    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
663    {
664        let runtime_root = speech_runtime_root();
665        runtime_root.join("bin").join("mlx_audio.stt.generate").exists()
666            || runtime_root.join("bin").join("mlx_audio.tts.generate").exists()
667    }
668}
669
/// Locate the speech runtime directory.
///
/// A non-blank `CAR_SPEECH_RUNTIME_DIR` is used as-is. Otherwise the result is
/// `$HOME/.car/speech-runtime`, with `.` standing in for `$HOME` when that
/// variable cannot be read.
fn speech_runtime_root() -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(dir) if !dir.trim().is_empty() => PathBuf::from(dir),
        _ => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
682
683/// Backward-compatible ModelInfo for listing (used by CLI and old callers).
684#[derive(Debug, Clone, Serialize, Deserialize)]
685pub struct ModelInfo {
686    pub id: String,
687    pub name: String,
688    pub provider: String,
689    pub capabilities: Vec<ModelCapability>,
690    pub param_count: String,
691    pub size_mb: u64,
692    pub context_length: usize,
693    pub available: bool,
694    pub is_local: bool,
695    /// Public benchmark scores carried straight through from `ModelSchema`.
696    /// The built-in catalog ships this empty; populating it is a curation
697    /// step (see `BenchmarkScore` in the schema for shape and conventions).
698    #[serde(default)]
699    pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
700}
701
702impl From<&ModelSchema> for ModelInfo {
703    fn from(s: &ModelSchema) -> Self {
704        ModelInfo {
705            id: s.id.clone(),
706            name: s.name.clone(),
707            provider: s.provider.clone(),
708            capabilities: s.capabilities.clone(),
709            param_count: s.param_count.clone(),
710            size_mb: s.size_mb(),
711            context_length: s.context_length,
712            available: s.available,
713            is_local: s.is_local(),
714            public_benchmarks: s.public_benchmarks.clone(),
715        }
716    }
717}
718
719/// Download a single file from a HuggingFace repo.
720async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
721    let api = hf_hub::api::tokio::Api::new()
722        .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
723
724    let repo = api.model(repo.to_string());
725    let path = repo
726        .get(filename)
727        .await
728        .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
729
730    if dest.exists() {
731        return Ok(());
732    }
733
734    // Try symlink first, fall back to copy
735    #[cfg(unix)]
736    {
737        if std::os::unix::fs::symlink(&path, dest).is_ok() {
738            return Ok(());
739        }
740    }
741
742    std::fs::copy(&path, dest)
743        .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
744    Ok(())
745}
746
747async fn ensure_auxiliary_mlx_files(
748    model_name: &str,
749    hf_repo: &str,
750    model_dir: &Path,
751) -> Result<(), InferenceError> {
752    if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
753        let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
754        if !t5_tokenizer_path.exists() {
755            std::fs::create_dir_all(
756                t5_tokenizer_path
757                    .parent()
758                    .ok_or_else(|| InferenceError::InferenceFailed("invalid tokenizer path".into()))?,
759            )?;
760            info!(
761                path = %t5_tokenizer_path.display(),
762                "downloading missing Flux tokenizer_2/tokenizer.json from base model"
763            );
764            download_file("Freepik/flux.1-lite-8B", "tokenizer_2/tokenizer.json", &t5_tokenizer_path)
765                .await?;
766        }
767    }
768    Ok(())
769}
770
771fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
772    match &schema.source {
773        ModelSource::Mlx { hf_repo, .. } => {
774            hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
775                || schema.family.starts_with("yume")
776                || schema.tags.iter().any(|tag| {
777                    matches!(
778                        tag.as_str(),
779                        "wan2.2" | "ti2v" | "world-model" | "image-to-video"
780                    )
781                })
782        }
783        _ => false,
784    }
785}
786
787fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
788    latest_huggingface_repo_snapshot(repo_id).is_some()
789}
790
/// Resolves the HuggingFace hub cache directory.
///
/// Uses `$HF_HOME/hub` when `HF_HOME` is set; otherwise
/// `$HOME/.cache/huggingface/hub`, with `.` standing in for `HOME` when that
/// variable is unavailable as well.
fn huggingface_cache_root() -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(dir) => PathBuf::from(dir),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };
    hf_home.join("hub")
}
803
804fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
805    huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
806}
807
808fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
809    let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
810        .ok()?
811        .trim()
812        .to_string();
813    if sha.is_empty() {
814        return None;
815    }
816
817    let snapshot = repo_dir.join("snapshots").join(sha);
818    if snapshot_looks_ready(&snapshot) {
819        Some(snapshot)
820    } else {
821        None
822    }
823}
824
825fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
826    let repo_dir = huggingface_repo_dir(repo_id);
827    if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
828        return Some(snapshot);
829    }
830
831    let snapshots = repo_dir.join("snapshots");
832    let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
833        .ok()?
834        .filter_map(Result::ok)
835        .map(|e| e.path())
836        .filter(|p| p.is_dir() && snapshot_looks_ready(p))
837        .map(|path| {
838            let modified = path
839                .metadata()
840                .and_then(|metadata| metadata.modified())
841                .unwrap_or(SystemTime::UNIX_EPOCH);
842            (modified, path)
843        })
844        .collect();
845    candidates.sort();
846    candidates.pop().map(|(_, path)| path)
847}
848
849fn snapshot_looks_ready(path: &Path) -> bool {
850    if path.join("config.json").exists() || path.join("model_index.json").exists() {
851        return true;
852    }
853    snapshot_contains_ext(path, "safetensors")
854}
855
/// Recursively searches `root` for a file whose extension equals `ext`,
/// ignoring ASCII case.
///
/// Unreadable directories and entries are treated as containing nothing.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.flatten() {
        let candidate = entry.path();
        let found = if candidate.is_dir() {
            snapshot_contains_ext(&candidate, ext)
        } else {
            match candidate.extension().and_then(|raw| raw.to_str()) {
                Some(found_ext) => found_ext.eq_ignore_ascii_case(ext),
                None => false,
            }
        };
        if found {
            return true;
        }
    }
    false
}
872
873async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
874    let api = hf_hub::api::tokio::ApiBuilder::from_env()
875        .with_progress(false)
876        .build()
877        .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
878    let repo = api.model(repo_id.to_string());
879    let info = repo
880        .info()
881        .await
882        .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
883
884    let snapshot_path = std::env::var("HF_HOME")
885        .map(PathBuf::from)
886        .unwrap_or_else(|_| {
887            std::env::var("HOME")
888                .map(PathBuf::from)
889                .unwrap_or_else(|_| PathBuf::from("."))
890                .join(".cache")
891                .join("huggingface")
892        })
893        .join("hub")
894        .join(format!("models--{}", repo_id.replace('/', "--")))
895        .join("snapshots")
896        .join(&info.sha);
897    let mut downloaded = 0usize;
898    for sibling in &info.siblings {
899        let local_path = snapshot_path.join(&sibling.rfilename);
900        if local_path.exists() {
901            downloaded += 1;
902            continue;
903        }
904        repo.download(&sibling.rfilename).await.map_err(|e| {
905            InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
906        })?;
907        downloaded += 1;
908    }
909
910    Ok((snapshot_path, downloaded))
911}
912
/// Built-in catalog parsed from `builtin_catalog.json`.
///
/// Adding, removing, or editing a model is a JSON-only change — Rust
/// source stays put. The JSON is embedded at compile time via
/// `include_str!`, parsed once into a `LazyLock`, and cloned on each
/// call. A malformed JSON file fails the integration test
/// `builtin_catalog_json_parses` so the binary never ships unable
/// to load its own catalog.
const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");

/// Parsed form of `BUILTIN_CATALOG_JSON`, built lazily on first access.
///
/// Initialization panics if the embedded JSON does not deserialize into a
/// `Vec<ModelSchema>`.
static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> =
    std::sync::LazyLock::new(|| {
        serde_json::from_str(BUILTIN_CATALOG_JSON)
            .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
    });
928
929fn builtin_catalog() -> Vec<ModelSchema> {
930    BUILTIN_CATALOG.clone()
931}
932
// Unit tests for the registry. Most of them pin the contents of the
// built-in catalog (exact ids, names, and counts), so a catalog edit is
// expected to require matching updates here.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `UnifiedRegistry` rooted at `<tempdir>/models`. The `TempDir`
    /// is returned alongside so each test keeps it alive for its own duration.
    fn test_registry() -> (UnifiedRegistry, TempDir) {
        let tmp = TempDir::new().unwrap();
        let reg = UnifiedRegistry::new(tmp.path().join("models"));
        (reg, tmp)
    }

    /// A fresh registry lists exactly as many models as the built-in catalog.
    #[test]
    fn builtin_catalog_loads() {
        let (reg, _tmp) = test_registry();
        let all = reg.list();
        assert_eq!(all.len(), builtin_catalog().len());
    }

    /// #137: a model tagged `requires-mlx-vlm` must report
    /// `available: true` if and only if `mlx_vlm_cli::is_available()`
    /// returns true. Without this, registry consumers (FFI
    /// `listModelsUnified`, the tray Models submenu, agent routing)
    /// see a model as available, the user picks it, and inference
    /// bails with `mlx-vlm CLI not found on PATH`.
    ///
    /// The probe is environmental — runs the same check on the host
    /// the test executes on. CI usually doesn't have `mlx_vlm`
    /// installed → expected unavailable; a dev box with it installed
    /// → expected available. Either way, the registry tracks the
    /// runtime probe.
    #[test]
    fn mlx_vlm_models_reflect_runtime_availability() {
        let (reg, _tmp) = test_registry();
        let mlx_vlm_models: Vec<&ModelSchema> = reg
            .list()
            .into_iter()
            .filter(|m| m.tags.iter().any(|t| t == "requires-mlx-vlm"))
            .collect();
        assert!(
            !mlx_vlm_models.is_empty(),
            "catalog should contain at least one model tagged \
             `requires-mlx-vlm` — otherwise this regression has \
             nothing to guard"
        );

        // Off Apple-silicon macOS (or with `car_skip_mlx` set) the test
        // expects these models to be reported as unavailable.
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        let expected = crate::backend::mlx_vlm_cli::is_available();
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        let expected = false;

        for m in mlx_vlm_models {
            assert_eq!(
                m.available, expected,
                "model {} `available` field should reflect \
                 mlx_vlm CLI presence (expected {expected}, got {})",
                m.id, m.available
            );
        }
    }

    /// Embedded JSON must parse cleanly — if it doesn't, the runtime
    /// would panic on first registry load. Catch it in CI instead.
    #[test]
    fn builtin_catalog_json_parses() {
        let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
            .expect("builtin_catalog.json must be valid ModelSchema array");
        assert!(
            !catalog.is_empty(),
            "embedded catalog has no entries — that's almost certainly wrong"
        );

        // Ids must be unique across the embedded catalog.
        let mut seen = std::collections::HashSet::new();
        for entry in &catalog {
            assert!(
                seen.insert(entry.id.clone()),
                "duplicate id in builtin_catalog.json: {}",
                entry.id
            );
        }
    }

    /// Benchmarks attached to a registered schema survive the `ModelInfo`
    /// conversion and a JSON serialize/deserialize round trip.
    #[test]
    fn public_benchmarks_round_trip_through_model_info() {
        use crate::schema::BenchmarkScore;
        let (mut reg, _tmp) = test_registry();
        // Start from a catalog entry and re-register it under a test-only id.
        let mut schema = reg
            .find_by_name("Qwen3-4B")
            .expect("catalog has Qwen3-4B")
            .clone();
        schema.id = "test/qwen3-4b-with-bench".into();
        schema.public_benchmarks = vec![
            BenchmarkScore {
                name: "MMLU-Pro".into(),
                score: 0.482,
                harness: Some("5-shot CoT".into()),
                source_url: Some("https://example.invalid/qwen3-4b-card".into()),
                measured_at: Some("2025-08-12".into()),
            },
            BenchmarkScore {
                name: "HumanEval".into(),
                score: 0.713,
                harness: Some("pass@1".into()),
                source_url: None,
                measured_at: None,
            },
        ];
        reg.register(schema);

        let stored = reg
            .get("test/qwen3-4b-with-bench")
            .expect("registered model is retrievable");
        let info = ModelInfo::from(stored);
        assert_eq!(info.public_benchmarks.len(), 2);

        // The serialized JSON shape is what the WS / FFI clients consume.
        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("\"public_benchmarks\""));
        assert!(json.contains("\"MMLU-Pro\""));
        assert!(json.contains("\"5-shot CoT\""));

        // Round-trip back through serde to confirm deserialization works.
        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.public_benchmarks.len(), 2);
        assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
        assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
    }

    /// A schema JSON document without a `public_benchmarks` key deserializes
    /// with an empty benchmark list.
    #[test]
    fn public_benchmarks_default_to_empty_when_absent_in_json() {
        // Older user-config JSON written before this field existed must
        // still deserialize cleanly into the new ModelSchema shape.
        let legacy_json = r#"{
            "id": "legacy/test:1",
            "name": "Legacy Test",
            "provider": "test",
            "family": "test",
            "version": "",
            "capabilities": ["generate"],
            "context_length": 4096,
            "param_count": "1B",
            "quantization": null,
            "performance": {},
            "cost": {},
            "source": { "type": "ollama", "model_tag": "legacy:1" },
            "tags": [],
            "supported_params": []
        }"#;
        let schema: ModelSchema = serde_json::from_str(legacy_json).unwrap();
        assert!(schema.public_benchmarks.is_empty());
    }

    /// Looking up `Qwen3-4B` by name yields the `mlx/...` entry on
    /// Apple-silicon macOS builds (unless `car_skip_mlx` is set) and the
    /// `qwen/...` entry on every other target.
    #[test]
    fn find_by_name() {
        let (reg, _tmp) = test_registry();
        let m = reg.find_by_name("Qwen3-4B").unwrap();
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        assert_eq!(m.id, "mlx/qwen3-4b:4bit");
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
        assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
        assert!(m.has_capability(ModelCapability::Code));
    }

    /// Exactly two built-in models advertise `Embed`: the Qwen3 embedding
    /// model and its MLX sibling.
    #[test]
    fn query_by_capability() {
        let (reg, _tmp) = test_registry();
        let embed_models = reg.query_by_capability(ModelCapability::Embed);
        assert_eq!(embed_models.len(), 2);
        assert!(embed_models
            .iter()
            .any(|model| model.name == "Qwen3-Embedding-0.6B"));
        assert!(embed_models
            .iter()
            .any(|model| model.name == "Qwen3-Embedding-0.6B-MLX"));
    }

    /// Combining capability, size, and local-only constraints in one
    /// `ModelFilter` narrows the catalog to the expected four models.
    #[test]
    fn query_with_filter() {
        let (reg, _tmp) = test_registry();
        let code_small = reg.query(&ModelFilter {
            capabilities: vec![ModelCapability::Code],
            max_size_mb: Some(3000),
            local_only: true,
            ..Default::default()
        });
        // Qwen3-1.7B, Qwen3-1.7B-MLX, Qwen3-4B, and Qwen3-4B-MLX fit and have Code capability.
        assert_eq!(code_small.len(), 4);
    }

    /// Registering a schema whose id matches a built-in entry replaces that
    /// entry: neither the total count nor the Reasoning+ToolUse query size
    /// changes.
    #[test]
    fn register_remote() {
        let (mut reg, _tmp) = test_registry();
        let initial_len = reg.list().len();
        let initial_reasoning_len = reg
            .query(&ModelFilter {
                capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
                ..Default::default()
            })
            .len();
        let remote = ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_key_envs: vec![],
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec![],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        };

        reg.register(remote);
        // Same ID as builtin claude-sonnet-4-6 — replaces, count stays same
        assert_eq!(reg.list().len(), initial_len);

        let reasoning = reg.query(&ModelFilter {
            capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
            ..Default::default()
        });
        // Replacing an existing remote slot should not change the reasoning/tool-use lineup size.
        assert_eq!(reasoning.len(), initial_reasoning_len);
    }

    /// Unregistering a built-in id returns the removed entry and shrinks the
    /// listing by one.
    #[test]
    fn unregister() {
        let (mut reg, _tmp) = test_registry();
        let initial_len = reg.list().len();
        let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
        assert!(removed.is_some());
        assert_eq!(reg.list().len(), initial_len - 1);
    }

    /// Pins the number of built-in speech models: two speech-to-text and
    /// four text-to-speech.
    #[test]
    fn speech_models_are_curated() {
        let (reg, _tmp) = test_registry();
        let stt = reg.query_by_capability(ModelCapability::SpeechToText);
        let tts = reg.query_by_capability(ModelCapability::TextToSpeech);
        assert_eq!(stt.len(), 2);
        assert_eq!(tts.len(), 4);
    }

    /// Both Qwen3-8B catalog names must resolve to a model with `ToolUse`
    /// and `MultiToolCall`.
    #[test]
    fn qwen_8b_variants_keep_tool_use_consistent() {
        let (reg, _tmp) = test_registry();
        for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(model.has_capability(ModelCapability::ToolUse));
            assert!(model.has_capability(ModelCapability::MultiToolCall));
        }
    }

    /// On Apple-silicon macOS builds, plain model names resolve to their
    /// `mlx/...` catalog entries.
    #[test]
    fn mac_name_resolution_prefers_mlx_siblings() {
        // Only used inside the aarch64-macos cfg below; non-mac targets
        // keep the test as a smoke compile.
        #[allow(unused_variables)]
        let (reg, _tmp) = test_registry();
        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
        {
            assert_eq!(reg.find_by_name("Qwen3-0.6B").unwrap().id, "mlx/qwen3-0.6b:6bit");
            assert_eq!(reg.find_by_name("Qwen3-1.7B").unwrap().id, "mlx/qwen3-1.7b:3bit");
            assert_eq!(
                reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
                "mlx/qwen3-embedding-0.6b:mxfp8"
            );
        }
    }

    /// Every listed remote model must exist in the catalog and carry the
    /// `Vision` capability.
    #[test]
    fn remote_multimodal_models_are_curated_as_vision_capable() {
        let (reg, _tmp) = test_registry();
        for name in [
            "claude-opus-4-7",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gpt-5.4",
            "gpt-5.4-mini",
            "o3",
            "o4-mini",
            "gpt-4.1-mini",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
        ] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(
                model.has_capability(ModelCapability::Vision),
                "{name} should be curated as vision-capable"
            );
        }
    }

    /// The Qwen2.5-VL ids must be absent from the catalog, both by direct
    /// lookup and through the `Vision` capability query, while the Qwen3-VL
    /// entry must be present in that query.
    #[test]
    fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
        let (reg, _tmp) = test_registry();

        let stale_ids = [
            // Native MLX text tower can't tokenize images — never advertise.
            "mlx/qwen2.5-vl-3b:4bit",
            "mlx/qwen2.5-vl-7b:4bit",
            // Qwen2.5-VL is superseded by Qwen3-VL; drop the mlx-vlm CLI
            // catalog entries so callers route to the upgraded family.
            "mlx-vlm/qwen2.5-vl-3b:4bit",
            "mlx-vlm/qwen2.5-vl-7b:4bit",
            // Same supersession applies to the vLLM-MLX route.
            "vllm-mlx/qwen2.5-vl-3b:4bit",
        ];
        for id in stale_ids {
            assert!(
                reg.get(id).is_none(),
                "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
            );
        }

        let vision_ids: Vec<&str> = reg
            .query_by_capability(ModelCapability::Vision)
            .into_iter()
            .map(|model| model.id.as_str())
            .collect();
        for stale in stale_ids {
            assert!(
                !vision_ids.contains(&stale),
                "{stale} must not be reachable through the Vision capability index"
            );
        }
        assert!(
            vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
            "Qwen3-VL is the supported local VL family and must route as Vision"
        );
    }

    /// Both Gemini 2.5 entries must carry `Vision`, `ToolUse`, and
    /// `MultiToolCall`.
    #[test]
    fn gemini_models_are_curated_for_multimodal_tool_use() {
        let (reg, _tmp) = test_registry();
        for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
            let model = reg.find_by_name(name).expect("model should exist");
            assert!(model.has_capability(ModelCapability::Vision));
            assert!(model.has_capability(ModelCapability::ToolUse));
            assert!(model.has_capability(ModelCapability::MultiToolCall));
        }
    }

    /// Pins the number of image- and video-generation models and checks the
    /// capability and tags of the Yume MLX entry.
    #[test]
    fn visual_generation_models_are_curated() {
        let (reg, _tmp) = test_registry();
        assert_eq!(
            reg.query_by_capability(ModelCapability::ImageGeneration).len(),
            1
        );
        assert_eq!(
            reg.query_by_capability(ModelCapability::VideoGeneration).len(),
            2
        );
        let yume = reg
            .get("mlx/yume-1.5-5b-720p:q4")
            .expect("Yume MLX should be in the built-in catalog");
        assert!(yume.has_capability(ModelCapability::VideoGeneration));
        assert!(yume.tags.contains(&"text-to-video".to_string()));
        assert!(yume.tags.contains(&"image-to-video".to_string()));
        assert!(yume.tags.contains(&"world-model".to_string()));
    }
}