Skip to main content

car_inference/
registry.rs

1//! Unified model registry — local and remote models under one schema.
2//!
3//! Replaces the hardcoded `ModelRegistry` from `models.rs` with a schema-driven
4//! registry that treats all models as first-class typed resources. Users can
5//! register custom models (fine-tuned endpoints, private APIs) alongside the
6//! built-in catalog.
7
8use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::schema::*;
17use crate::InferenceError;
18
19/// Filter for querying the registry.
20#[derive(Debug, Clone, Default)]
21pub struct ModelFilter {
22    /// Required capabilities (model must have ALL of these).
23    pub capabilities: Vec<ModelCapability>,
24    /// Maximum on-disk / RAM size in MB.
25    pub max_size_mb: Option<u64>,
26    /// Maximum expected latency in ms (from declared envelope).
27    pub max_latency_ms: Option<u64>,
28    /// Maximum cost per 1M output tokens in USD.
29    pub max_cost_per_mtok: Option<f64>,
30    /// Required tags (model must have ALL of these).
31    pub tags: Vec<String>,
32    /// Filter by provider.
33    pub provider: Option<String>,
34    /// Only local models.
35    pub local_only: bool,
36    /// Only models that are currently available.
37    pub available_only: bool,
38}
39
40/// A curated replacement for a local model that is installed but no longer
41/// the preferred model in its line.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ModelUpgrade {
44    pub from_id: String,
45    pub from_name: String,
46    pub to_id: String,
47    pub to_name: String,
48    pub reason: String,
49    pub target_runtime: Option<String>,
50    pub target_runtime_requirement: Option<String>,
51    pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
52    pub target_available: bool,
53    pub target_pullable: bool,
54    pub remove_old_supported: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ModelRuntimeRequirement {
59    pub name: String,
60    pub minimum_version: String,
61}
62
63/// Unified registry of all known models.
64pub struct UnifiedRegistry {
65    models_dir: PathBuf,
66    /// All registered models, keyed by id.
67    models: HashMap<String, ModelSchema>,
68    /// User-added model config file path (~/.car/models.json).
69    user_config_path: PathBuf,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct ModelUpgradeRule {
74    from_ids: Vec<String>,
75    to_id: String,
76    reason: String,
77    target_runtime: Option<String>,
78    target_runtime_requirement: Option<String>,
79    #[serde(default)]
80    minimum_runtimes: Vec<ModelRuntimeRequirement>,
81    #[serde(default = "default_remove_old_after_available")]
82    remove_old_after_available: bool,
83}
84
/// Serde default for `ModelUpgradeRule::remove_old_after_available`: rules
/// that omit the field allow removal of the old model.
fn default_remove_old_after_available() -> bool {
    let removal_allowed_by_default = true;
    removal_allowed_by_default
}
88
89fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
90    serde_json::from_str(include_str!("../assets/model-upgrades.json"))
91        .expect("built-in model-upgrades.json should parse")
92}
93
94impl UnifiedRegistry {
95    pub fn new(models_dir: PathBuf) -> Self {
96        let user_config_path = models_dir
97            .parent()
98            .unwrap_or(&models_dir)
99            .join("models.json");
100
101        let mut registry = Self {
102            models_dir,
103            models: HashMap::new(),
104            user_config_path,
105        };
106        registry.load_builtin_catalog();
107        registry.refresh_availability();
108        // Load user config on top (silently ignore if missing)
109        let _ = registry.load_user_config();
110        registry
111    }
112
113    /// Register a model at runtime.
114    pub fn register(&mut self, mut schema: ModelSchema) {
115        // Check availability for local models
116        if schema.is_mlx() {
117            schema.available = if schema.tags.contains(&"speech".to_string()) {
118                speech_mlx_available()
119            } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
120                // Available if cached locally OR has an hf_repo —
121                // ensure_local() lazy-downloads on first use, so a
122                // declared hf_repo is "functionally available" the
123                // same way Ollama/RemoteApi entries are. Mirrors the
124                // refresh_availability() check below; see #164.
125                let mlx_dir = self.models_dir.join(&schema.name);
126                mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
127            } else {
128                let mlx_dir = self.models_dir.join(&schema.name);
129                mlx_dir.join("config.json").exists()
130            };
131        } else if schema.is_vllm_mlx() {
132            // vLLM-MLX: available if endpoint env var set or was manually marked available
133            schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
134        } else if schema.is_local() {
135            let local_path = self.models_dir.join(&schema.name).join("model.gguf");
136            schema.available = local_path.exists();
137        } else if schema.is_remote() {
138            // Remote models are assumed available if the env var exists
139            if let ModelSource::RemoteApi {
140                ref api_key_env, ..
141            } = schema.source
142            {
143                schema.available = std::env::var(api_key_env).is_ok();
144            }
145        }
146        info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
147        self.models.insert(schema.id.clone(), schema);
148    }
149
150    /// Unregister a model by id. Returns the removed schema if found.
151    pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
152        let removed = self.models.remove(id);
153        if let Some(ref m) = removed {
154            info!(id = %m.id, "unregistered model");
155        }
156        removed
157    }
158
159    /// List all models.
160    pub fn list(&self) -> Vec<&ModelSchema> {
161        let mut models: Vec<&ModelSchema> = self.models.values().collect();
162        models.sort_by(|a, b| a.id.cmp(&b.id));
163        models
164    }
165
166    /// Query models matching a filter.
167    pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
168        self.models
169            .values()
170            .filter(|m| {
171                // Capability check: model must have ALL required capabilities
172                if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
173                    return false;
174                }
175                // Size check
176                if let Some(max) = filter.max_size_mb {
177                    if m.size_mb() > max && m.is_local() {
178                        return false;
179                    }
180                }
181                // Latency check (declared envelope)
182                if let Some(max) = filter.max_latency_ms {
183                    if let Some(p50) = m.performance.latency_p50_ms {
184                        if p50 > max {
185                            return false;
186                        }
187                    }
188                }
189                // Cost check
190                if let Some(max) = filter.max_cost_per_mtok {
191                    if let Some(cost) = m.cost.output_per_mtok {
192                        if cost > max {
193                            return false;
194                        }
195                    }
196                }
197                // Tag check
198                if !filter.tags.iter().all(|t| m.tags.contains(t)) {
199                    return false;
200                }
201                // Provider check
202                if let Some(ref p) = filter.provider {
203                    if &m.provider != p {
204                        return false;
205                    }
206                }
207                // Local only
208                if filter.local_only && !m.is_local() {
209                    return false;
210                }
211                // Available only
212                if filter.available_only && !m.available {
213                    return false;
214                }
215                true
216            })
217            .collect()
218    }
219
220    /// Query models by a single capability.
221    pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
222        self.query(&ModelFilter {
223            capabilities: vec![cap],
224            ..Default::default()
225        })
226    }
227
228    /// Report installed local models with curated newer replacements.
229    pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
230        let mut upgrades = Vec::new();
231        for rule in model_upgrade_rules() {
232            let Some(from) = rule
233                .from_ids
234                .iter()
235                .find_map(|id| self.models.get(id.as_str()))
236                .filter(|schema| schema.available)
237            else {
238                continue;
239            };
240            let Some(to) = self.models.get(rule.to_id.as_str()) else {
241                continue;
242            };
243            upgrades.push(ModelUpgrade {
244                from_id: from.id.clone(),
245                from_name: from.name.clone(),
246                to_id: to.id.clone(),
247                to_name: to.name.clone(),
248                reason: rule.reason.clone(),
249                target_runtime: rule.target_runtime.clone(),
250                target_runtime_requirement: rule.target_runtime_requirement.clone(),
251                minimum_runtimes: rule.minimum_runtimes.clone(),
252                target_available: to.available,
253                target_pullable: matches!(
254                    to.source,
255                    ModelSource::Local { .. } | ModelSource::Mlx { .. }
256                ),
257                remove_old_supported: matches!(
258                    from.source,
259                    ModelSource::Local { .. } | ModelSource::Mlx { .. }
260                ) && rule.remove_old_after_available,
261            });
262        }
263        upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
264        upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
265        upgrades
266    }
267
268    /// Get a specific model by id.
269    pub fn get(&self, id: &str) -> Option<&ModelSchema> {
270        self.models.get(id)
271    }
272
273    /// Find a model by name (case-insensitive). For backward compatibility
274    /// with the old registry that used short names like "Qwen3-4B".
275    pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
276        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
277        if !name.to_ascii_lowercase().ends_with("-mlx") {
278            if let Some(mlx_variant) = self
279                .models
280                .values()
281                .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
282            {
283                return Some(mlx_variant);
284            }
285        }
286
287        self.models
288            .values()
289            .find(|m| m.name.eq_ignore_ascii_case(name))
290    }
291
292    /// On Apple Silicon, resolve a GGUF/Candle model to its MLX equivalent.
293    /// Returns the MLX model schema if one exists with the same family and
294    /// matching capabilities; otherwise returns None.
295    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
296    pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
297        // Already MLX — no redirect needed.
298        if schema.is_mlx() || schema.is_vllm_mlx() {
299            return None;
300        }
301        // Only redirect local GGUF models.
302        if !matches!(schema.source, ModelSource::Local { .. }) {
303            return None;
304        }
305        // Find an MLX model in the same family with at least the same primary capability.
306        let primary_cap = schema.capabilities.first()?;
307        self.models.values().find(|m| {
308            m.is_mlx() && m.family == schema.family && m.capabilities.contains(primary_cap)
309        })
310    }
311
312    /// Ensure a local model is downloaded, returning its local directory path.
313    pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
314        let schema = self
315            .get(id)
316            .or_else(|| self.find_by_name(id))
317            .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
318
319        match &schema.source {
320            ModelSource::Local {
321                hf_repo,
322                hf_filename,
323                tokenizer_repo,
324            } => {
325                let model_dir = self.models_dir.join(&schema.name);
326                let model_path = model_dir.join("model.gguf");
327                let tokenizer_path = model_dir.join("tokenizer.json");
328
329                if model_path.exists() && tokenizer_path.exists() {
330                    return Ok(model_dir);
331                }
332
333                std::fs::create_dir_all(&model_dir)?;
334
335                if !model_path.exists() {
336                    info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
337                    download_file(hf_repo, hf_filename, &model_path).await?;
338                }
339                if !tokenizer_path.exists() {
340                    info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
341                    download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
342                }
343
344                Ok(model_dir)
345            }
346            ModelSource::Mlx {
347                hf_repo,
348                hf_weight_file,
349            } => {
350                let model_dir = self.models_dir.join(&schema.name);
351                let config_path = model_dir.join("config.json");
352
353                if config_path.exists() {
354                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
355                    info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
356                    return Ok(model_dir);
357                }
358
359                if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
360                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
361                    info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
362                    return Ok(snapshot_dir);
363                }
364
365                if requires_full_mlx_snapshot(&schema) {
366                    info!(
367                        model = %schema.name,
368                        repo = %hf_repo,
369                        "downloading full MLX snapshot"
370                    );
371                    let (snapshot_dir, _files_downloaded) =
372                        download_hf_repo_snapshot(hf_repo).await?;
373                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
374                    return Ok(snapshot_dir);
375                }
376
377                std::fs::create_dir_all(&model_dir)?;
378
379                info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");
380
381                // Download config, tokenizer, and weight files
382                download_file(hf_repo, "config.json", &config_path).await?;
383                let tok_path = model_dir.join("tokenizer.json");
384                if !tok_path.exists() {
385                    download_file(hf_repo, "tokenizer.json", &tok_path).await?;
386                }
387                let tok_config_path = model_dir.join("tokenizer_config.json");
388                if !tok_config_path.exists() {
389                    let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
390                }
391
392                // Download weight files
393                if let Some(ref wf) = hf_weight_file {
394                    let wf_path = model_dir.join(wf);
395                    if !wf_path.exists() {
396                        download_file(hf_repo, wf, &wf_path).await?;
397                    }
398                } else {
399                    // Try single file first, then sharded
400                    let single = model_dir.join("model.safetensors");
401                    if !single.exists() {
402                        match download_file(hf_repo, "model.safetensors", &single).await {
403                            Ok(()) => {}
404                            Err(_) => {
405                                // Sharded: download index and then each shard
406                                let index_path = model_dir.join("model.safetensors.index.json");
407                                download_file(hf_repo, "model.safetensors.index.json", &index_path)
408                                    .await?;
409
410                                let index_json: serde_json::Value =
411                                    serde_json::from_str(&std::fs::read_to_string(&index_path)?)
412                                        .map_err(|e| {
413                                            InferenceError::InferenceFailed(format!(
414                                                "parse index: {e}"
415                                            ))
416                                        })?;
417
418                                if let Some(weight_map) =
419                                    index_json.get("weight_map").and_then(|m| m.as_object())
420                                {
421                                    let mut files: std::collections::HashSet<String> =
422                                        std::collections::HashSet::new();
423                                    for filename in weight_map.values() {
424                                        if let Some(f) = filename.as_str() {
425                                            files.insert(f.to_string());
426                                        }
427                                    }
428                                    for file in &files {
429                                        let dest = model_dir.join(file);
430                                        if !dest.exists() {
431                                            info!(file = %file, "downloading weight shard");
432                                            download_file(hf_repo, file, &dest).await?;
433                                        }
434                                    }
435                                }
436                            }
437                        }
438                    }
439                }
440
441                ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
442                Ok(model_dir)
443            }
444            _ => Err(InferenceError::InferenceFailed(format!(
445                "model {} is not local",
446                id
447            ))),
448        }
449    }
450
451    /// Remove a downloaded local model.
452    pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
453        let schema = self
454            .get(id)
455            .or_else(|| self.find_by_name(id))
456            .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
457
458        let model_dir = self.models_dir.join(&schema.name);
459        if model_dir.exists() {
460            std::fs::remove_dir_all(&model_dir)?;
461            info!(model = %schema.name, "removed model");
462        }
463
464        match &schema.source {
465            ModelSource::Mlx { hf_repo, .. } => {
466                let repo_dir = huggingface_repo_dir(hf_repo);
467                if repo_dir.exists() {
468                    std::fs::remove_dir_all(&repo_dir)?;
469                    info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
470                }
471            }
472            ModelSource::Local {
473                hf_repo,
474                tokenizer_repo,
475                ..
476            } => {
477                for repo in [hf_repo, tokenizer_repo] {
478                    let repo_dir = huggingface_repo_dir(repo);
479                    if repo_dir.exists() {
480                        std::fs::remove_dir_all(&repo_dir)?;
481                        info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
482                    }
483                }
484            }
485            _ => {}
486        }
487
488        // Update availability
489        let id = schema.id.clone();
490        if let Some(m) = self.models.get_mut(&id) {
491            m.available = false;
492        }
493        Ok(())
494    }
495
496    /// Refresh availability flags for all models.
497    ///
498    /// Runtime-true vs catalog-says: this is what closes the gap
499    /// `models.list_unified` callers rely on. If a model is listed
500    /// as `available: true` here, an `infer` call against it should
501    /// reach the backend, not bail with `UnsupportedMode { ...
502    /// "mlx-vlm CLI not found on PATH" }` (the #137 trap).
503    pub fn refresh_availability(&mut self) {
504        let models_dir = self.models_dir.clone();
505        // mlx-vlm CLI is the same probe call no matter which model
506        // requires it; do it once per refresh, not per-model.
507        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
508        let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
509        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
510        let mlx_vlm_cli_present = false;
511
512        for m in self.models.values_mut() {
513            match &m.source {
514                ModelSource::Mlx { hf_repo, .. } => {
515                    // Models tagged `requires-mlx-vlm` shell out to the
516                    // mlx_vlm Python CLI for image inference (#115).
517                    // If the CLI isn't on PATH, the runtime reaches it
518                    // anyway and bails — the registry MUST reflect that
519                    // by marking such entries unavailable until the
520                    // user installs `uv tool install mlx-vlm`. #137.
521                    let needs_mlx_vlm = m.tags.iter().any(|t| t == "requires-mlx-vlm");
522
523                    m.available = if needs_mlx_vlm {
524                        mlx_vlm_cli_present
525                    } else if m.tags.contains(&"speech".to_string()) {
526                        speech_mlx_available()
527                    } else {
528                        // Available if cached locally OR has an hf_repo —
529                        // the native MLX path's ensure_local() lazy-
530                        // downloads on first use, so a declared hf_repo
531                        // is "functionally available" the same way
532                        // Ollama and RemoteApi entries are ("should
533                        // work in principle" not "physically cached").
534                        // Closes #164: mlx/ltx-2.3:q4 reported
535                        // unavailable even though `car video` would
536                        // just download-and-run successfully.
537                        let mlx_dir = models_dir.join(&m.name);
538                        mlx_dir.join("config.json").exists() || !hf_repo.is_empty()
539                    };
540                }
541                ModelSource::Local { .. } => {
542                    let local_path = models_dir.join(&m.name).join("model.gguf");
543                    m.available = local_path.exists();
544                }
545                ModelSource::RemoteApi { api_key_env, .. } => {
546                    m.available = std::env::var(api_key_env).is_ok();
547                }
548                ModelSource::Ollama { .. } => {
549                    // Assume available; health check is async and done lazily
550                    m.available = true;
551                }
552                ModelSource::VllmMlx { .. } => {
553                    // vLLM-MLX availability checked via health endpoint lazily
554                    // Mark as available if VLLM_MLX_ENDPOINT env var is set or default endpoint assumed
555                    m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
556                    // preserve manual registration
557                }
558                ModelSource::Proprietary { auth, .. } => {
559                    // Check auth availability
560                    m.available = match auth {
561                        crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
562                            std::env::var(env_var).is_ok()
563                        }
564                        crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
565                            std::env::var(env_var).is_ok()
566                        }
567                        crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
568                            // OAuth2 availability determined at runtime by token provider
569                            true
570                        }
571                    };
572                }
573                ModelSource::AppleFoundationModels { .. } => {
574                    // Apple Silicon macOS 26+ AND iOS 26+ both expose
575                    // the FoundationModels framework. The shim's
576                    // runtime probe handles per-device availability
577                    // (Apple Intelligence may be off, the device may
578                    // be pre-A17, etc.); cfg-gating here just hides
579                    // the call on targets where the shim isn't built.
580                    #[cfg(any(
581                        all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
582                        all(target_os = "ios", target_arch = "aarch64")
583                    ))]
584                    {
585                        m.available = crate::backend::foundation_models::is_available();
586                    }
587                    #[cfg(not(any(
588                        all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)),
589                        all(target_os = "ios", target_arch = "aarch64")
590                    )))]
591                    {
592                        m.available = false;
593                    }
594                }
595                ModelSource::Delegated { .. } => {
596                    // Availability tracks whether a runner is registered.
597                    // Hosts call `registerInferenceRunner` (or its
598                    // language equivalent) at startup; until then the
599                    // model is unavailable.
600                    m.available = crate::runner::current_inference_runner().is_some();
601                }
602            }
603        }
604    }
605
606    /// Persist user-registered (non-builtin) models to disk.
607    pub fn save_user_config(&self) -> Result<(), InferenceError> {
608        let user_models: Vec<&ModelSchema> = self
609            .models
610            .values()
611            .filter(|m| !m.tags.contains(&"builtin".to_string()))
612            .collect();
613
614        if user_models.is_empty() {
615            return Ok(());
616        }
617
618        let json = serde_json::to_string_pretty(&user_models)
619            .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
620        std::fs::write(&self.user_config_path, json)?;
621        Ok(())
622    }
623
624    /// Load user-registered models from disk.
625    pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
626        if !self.user_config_path.exists() {
627            return Ok(());
628        }
629
630        let json = std::fs::read_to_string(&self.user_config_path)?;
631        let models: Vec<ModelSchema> = serde_json::from_str(&json)
632            .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
633
634        for m in models {
635            self.register(m);
636        }
637        Ok(())
638    }
639
640    /// Get the models directory path.
641    pub fn models_dir(&self) -> &Path {
642        &self.models_dir
643    }
644
645    /// Load the built-in Qwen3 catalog as ModelSchema objects.
646    fn load_builtin_catalog(&mut self) {
647        for schema in builtin_catalog() {
648            self.models.insert(schema.id.clone(), schema);
649        }
650    }
651}
652
653fn speech_mlx_available() -> bool {
654    // On Apple Silicon, speech uses native MLX backends — no Python CLI needed.
655    // Models are available if we're on the right platform (weights are downloaded on demand).
656    #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
657    {
658        true
659    }
660
661    // On other platforms, check for the Python mlx-audio CLI.
662    #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
663    {
664        let runtime_root = speech_runtime_root();
665        runtime_root
666            .join("bin")
667            .join("mlx_audio.stt.generate")
668            .exists()
669            || runtime_root
670                .join("bin")
671                .join("mlx_audio.tts.generate")
672                .exists()
673    }
674}
675
/// Resolve the directory that holds the speech runtime.
///
/// `CAR_SPEECH_RUNTIME_DIR` wins when it is set to a value that is not empty
/// or whitespace-only. Otherwise the result is `$HOME/.car/speech-runtime`,
/// with `.` standing in for `$HOME` when that variable cannot be read.
fn speech_runtime_root() -> PathBuf {
    let override_dir = std::env::var("CAR_SPEECH_RUNTIME_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty());

    match override_dir {
        Some(dir) => PathBuf::from(dir),
        None => {
            let home = match std::env::var("HOME") {
                Ok(home) => PathBuf::from(home),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
688
689/// Backward-compatible ModelInfo for listing (used by CLI and old callers).
690#[derive(Debug, Clone, Serialize, Deserialize)]
691pub struct ModelInfo {
692    pub id: String,
693    pub name: String,
694    pub provider: String,
695    pub capabilities: Vec<ModelCapability>,
696    pub param_count: String,
697    pub size_mb: u64,
698    pub context_length: usize,
699    pub available: bool,
700    pub is_local: bool,
701    /// Public benchmark scores carried straight through from `ModelSchema`.
702    /// The built-in catalog ships this empty; populating it is a curation
703    /// step (see `BenchmarkScore` in the schema for shape and conventions).
704    #[serde(default)]
705    pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
706}
707
708impl From<&ModelSchema> for ModelInfo {
709    fn from(s: &ModelSchema) -> Self {
710        ModelInfo {
711            id: s.id.clone(),
712            name: s.name.clone(),
713            provider: s.provider.clone(),
714            capabilities: s.capabilities.clone(),
715            param_count: s.param_count.clone(),
716            size_mb: s.size_mb(),
717            context_length: s.context_length,
718            available: s.available,
719            is_local: s.is_local(),
720            public_benchmarks: s.public_benchmarks.clone(),
721        }
722    }
723}
724
725/// Download a single file from a HuggingFace repo.
726async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
727    let api = hf_hub::api::tokio::Api::new()
728        .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
729
730    let repo = api.model(repo.to_string());
731    let path = repo
732        .get(filename)
733        .await
734        .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
735
736    if dest.exists() {
737        return Ok(());
738    }
739
740    // Try symlink first, fall back to copy
741    #[cfg(unix)]
742    {
743        if std::os::unix::fs::symlink(&path, dest).is_ok() {
744            return Ok(());
745        }
746    }
747
748    std::fs::copy(&path, dest)
749        .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
750    Ok(())
751}
752
753async fn ensure_auxiliary_mlx_files(
754    model_name: &str,
755    hf_repo: &str,
756    model_dir: &Path,
757) -> Result<(), InferenceError> {
758    if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
759        let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
760        if !t5_tokenizer_path.exists() {
761            std::fs::create_dir_all(t5_tokenizer_path.parent().ok_or_else(|| {
762                InferenceError::InferenceFailed("invalid tokenizer path".into())
763            })?)?;
764            info!(
765                path = %t5_tokenizer_path.display(),
766                "downloading missing Flux tokenizer_2/tokenizer.json from base model"
767            );
768            download_file(
769                "Freepik/flux.1-lite-8B",
770                "tokenizer_2/tokenizer.json",
771                &t5_tokenizer_path,
772            )
773            .await?;
774        }
775    }
776    Ok(())
777}
778
779fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
780    match &schema.source {
781        ModelSource::Mlx { hf_repo, .. } => {
782            hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
783                || schema.family.starts_with("yume")
784                || schema.tags.iter().any(|tag| {
785                    matches!(
786                        tag.as_str(),
787                        "wan2.2" | "ti2v" | "world-model" | "image-to-video"
788                    )
789                })
790        }
791        _ => false,
792    }
793}
794
795fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
796    latest_huggingface_repo_snapshot(repo_id).is_some()
797}
798
/// Locates the HuggingFace hub cache directory.
///
/// Resolution order: `$HF_HOME/hub`, then `$HOME/.cache/huggingface/hub`,
/// and finally `./.cache/huggingface/hub` when neither variable is readable.
fn huggingface_cache_root() -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(dir) => PathBuf::from(dir),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };
    hf_home.join("hub")
}
811
812fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
813    huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
814}
815
816fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
817    let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
818        .ok()?
819        .trim()
820        .to_string();
821    if sha.is_empty() {
822        return None;
823    }
824
825    let snapshot = repo_dir.join("snapshots").join(sha);
826    if snapshot_looks_ready(&snapshot) {
827        Some(snapshot)
828    } else {
829        None
830    }
831}
832
833fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
834    let repo_dir = huggingface_repo_dir(repo_id);
835    if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
836        return Some(snapshot);
837    }
838
839    let snapshots = repo_dir.join("snapshots");
840    let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
841        .ok()?
842        .filter_map(Result::ok)
843        .map(|e| e.path())
844        .filter(|p| p.is_dir() && snapshot_looks_ready(p))
845        .map(|path| {
846            let modified = path
847                .metadata()
848                .and_then(|metadata| metadata.modified())
849                .unwrap_or(SystemTime::UNIX_EPOCH);
850            (modified, path)
851        })
852        .collect();
853    candidates.sort();
854    candidates.pop().map(|(_, path)| path)
855}
856
857fn snapshot_looks_ready(path: &Path) -> bool {
858    if path.join("config.json").exists() || path.join("model_index.json").exists() {
859        return true;
860    }
861    snapshot_contains_ext(path, "safetensors")
862}
863
/// Recursively searches `root` for a file whose extension equals `ext`,
/// compared ASCII case-insensitively.
///
/// Unreadable directories and entries are treated as containing no match.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.filter_map(Result::ok) {
        let child = entry.path();
        let found = if child.is_dir() {
            snapshot_contains_ext(&child, ext)
        } else {
            matches!(
                child.extension().and_then(|value| value.to_str()),
                Some(value) if value.eq_ignore_ascii_case(ext)
            )
        };
        if found {
            return true;
        }
    }
    false
}
880
881async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
882    let api = hf_hub::api::tokio::ApiBuilder::from_env()
883        .with_progress(false)
884        .build()
885        .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
886    let repo = api.model(repo_id.to_string());
887    let info = repo
888        .info()
889        .await
890        .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
891
892    let snapshot_path = std::env::var("HF_HOME")
893        .map(PathBuf::from)
894        .unwrap_or_else(|_| {
895            std::env::var("HOME")
896                .map(PathBuf::from)
897                .unwrap_or_else(|_| PathBuf::from("."))
898                .join(".cache")
899                .join("huggingface")
900        })
901        .join("hub")
902        .join(format!("models--{}", repo_id.replace('/', "--")))
903        .join("snapshots")
904        .join(&info.sha);
905    let mut downloaded = 0usize;
906    for sibling in &info.siblings {
907        let local_path = snapshot_path.join(&sibling.rfilename);
908        if local_path.exists() {
909            downloaded += 1;
910            continue;
911        }
912        repo.download(&sibling.rfilename).await.map_err(|e| {
913            InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
914        })?;
915        downloaded += 1;
916    }
917
918    Ok((snapshot_path, downloaded))
919}
920
/// Raw text of the built-in catalog, embedded from `builtin_catalog.json`.
///
/// Adding, removing, or editing a model is a JSON-only change — Rust
/// source stays put. The JSON is embedded at compile time via
/// `include_str!`, parsed once into a `LazyLock`, and cloned on each
/// call to `builtin_catalog()`. A malformed JSON file fails the
/// `builtin_catalog_json_parses` test so the binary never ships unable
/// to load its own catalog.
const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");

/// Parsed form of `BUILTIN_CATALOG_JSON`, built lazily on first access.
///
/// Panics on first access if the embedded JSON does not deserialize into
/// a `Vec<ModelSchema>`.
static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> = std::sync::LazyLock::new(|| {
    serde_json::from_str(BUILTIN_CATALOG_JSON)
        .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
});
935
936fn builtin_catalog() -> Vec<ModelSchema> {
937    BUILTIN_CATALOG.clone()
938}
939
940#[cfg(test)]
941mod tests {
942    use super::*;
943    use tempfile::TempDir;
944
945    fn test_registry() -> (UnifiedRegistry, TempDir) {
946        let tmp = TempDir::new().unwrap();
947        let reg = UnifiedRegistry::new(tmp.path().join("models"));
948        (reg, tmp)
949    }
950
951    #[test]
952    fn builtin_catalog_loads() {
953        let (reg, _tmp) = test_registry();
954        let all = reg.list();
955        assert_eq!(all.len(), builtin_catalog().len());
956    }
957
958    /// #137: a model tagged `requires-mlx-vlm` must report
959    /// `available: true` if and only if `mlx_vlm_cli::is_available()`
960    /// returns true. Without this, registry consumers (FFI
961    /// `listModelsUnified`, the tray Models submenu, agent routing)
962    /// see a model as available, the user picks it, and inference
963    /// bails with `mlx-vlm CLI not found on PATH`.
964    ///
965    /// The probe is environmental — runs the same check on the host
966    /// the test executes on. CI usually doesn't have `mlx_vlm`
967    /// installed → expected unavailable; a dev box with it installed
968    /// → expected available. Either way, the registry tracks the
969    /// runtime probe.
970    #[test]
971    fn mlx_vlm_models_reflect_runtime_availability() {
972        let (reg, _tmp) = test_registry();
973        let mlx_vlm_models: Vec<&ModelSchema> = reg
974            .list()
975            .into_iter()
976            .filter(|m| m.tags.iter().any(|t| t == "requires-mlx-vlm"))
977            .collect();
978        assert!(
979            !mlx_vlm_models.is_empty(),
980            "catalog should contain at least one model tagged \
981             `requires-mlx-vlm` — otherwise this regression has \
982             nothing to guard"
983        );
984
985        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
986        let expected = crate::backend::mlx_vlm_cli::is_available();
987        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
988        let expected = false;
989
990        for m in mlx_vlm_models {
991            assert_eq!(
992                m.available, expected,
993                "model {} `available` field should reflect \
994                 mlx_vlm CLI presence (expected {expected}, got {})",
995                m.id, m.available
996            );
997        }
998    }
999
1000    /// Embedded JSON must parse cleanly — if it doesn't, the runtime
1001    /// would panic on first registry load. Catch it in CI instead.
1002    #[test]
1003    fn builtin_catalog_json_parses() {
1004        let catalog: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
1005            .expect("builtin_catalog.json must be valid ModelSchema array");
1006        assert!(
1007            !catalog.is_empty(),
1008            "embedded catalog has no entries — that's almost certainly wrong"
1009        );
1010
1011        let mut seen = std::collections::HashSet::new();
1012        for entry in &catalog {
1013            assert!(
1014                seen.insert(entry.id.clone()),
1015                "duplicate id in builtin_catalog.json: {}",
1016                entry.id
1017            );
1018        }
1019    }
1020
1021    #[test]
1022    fn public_benchmarks_round_trip_through_model_info() {
1023        use crate::schema::BenchmarkScore;
1024        let (mut reg, _tmp) = test_registry();
1025        let mut schema = reg
1026            .find_by_name("Qwen3-4B")
1027            .expect("catalog has Qwen3-4B")
1028            .clone();
1029        schema.id = "test/qwen3-4b-with-bench".into();
1030        schema.public_benchmarks = vec![
1031            BenchmarkScore {
1032                name: "MMLU-Pro".into(),
1033                score: 0.482,
1034                harness: Some("5-shot CoT".into()),
1035                source_url: Some("https://example.invalid/qwen3-4b-card".into()),
1036                measured_at: Some("2025-08-12".into()),
1037            },
1038            BenchmarkScore {
1039                name: "HumanEval".into(),
1040                score: 0.713,
1041                harness: Some("pass@1".into()),
1042                source_url: None,
1043                measured_at: None,
1044            },
1045        ];
1046        reg.register(schema);
1047
1048        let stored = reg
1049            .get("test/qwen3-4b-with-bench")
1050            .expect("registered model is retrievable");
1051        let info = ModelInfo::from(stored);
1052        assert_eq!(info.public_benchmarks.len(), 2);
1053
1054        // The serialized JSON shape is what the WS / FFI clients consume.
1055        let json = serde_json::to_string(&info).unwrap();
1056        assert!(json.contains("\"public_benchmarks\""));
1057        assert!(json.contains("\"MMLU-Pro\""));
1058        assert!(json.contains("\"5-shot CoT\""));
1059
1060        // Round-trip back through serde to confirm deserialization works.
1061        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
1062        assert_eq!(decoded.public_benchmarks.len(), 2);
1063        assert_eq!(decoded.public_benchmarks[0].name, "MMLU-Pro");
1064        assert_eq!(decoded.public_benchmarks[1].name, "HumanEval");
1065    }
1066
1067    #[test]
1068    fn public_benchmarks_default_to_empty_when_absent_in_json() {
1069        // Older user-config JSON written before this field existed must
1070        // still deserialize cleanly into the new ModelSchema shape.
1071        let legacy_json = r#"{
1072            "id": "legacy/test:1",
1073            "name": "Legacy Test",
1074            "provider": "test",
1075            "family": "test",
1076            "version": "",
1077            "capabilities": ["generate"],
1078            "context_length": 4096,
1079            "param_count": "1B",
1080            "quantization": null,
1081            "performance": {},
1082            "cost": {},
1083            "source": { "type": "ollama", "model_tag": "legacy:1" },
1084            "tags": [],
1085            "supported_params": []
1086        }"#;
1087        let schema: ModelSchema = serde_json::from_str(legacy_json).unwrap();
1088        assert!(schema.public_benchmarks.is_empty());
1089    }
1090
1091    #[test]
1092    fn find_by_name() {
1093        let (reg, _tmp) = test_registry();
1094        let m = reg.find_by_name("Qwen3-4B").unwrap();
1095        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1096        assert_eq!(m.id, "mlx/qwen3-4b:4bit");
1097        #[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
1098        assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
1099        assert!(m.has_capability(ModelCapability::Code));
1100    }
1101
1102    #[test]
1103    fn query_by_capability() {
1104        let (reg, _tmp) = test_registry();
1105        let embed_models = reg.query_by_capability(ModelCapability::Embed);
1106        assert_eq!(embed_models.len(), 2);
1107        assert!(embed_models
1108            .iter()
1109            .any(|model| model.name == "Qwen3-Embedding-0.6B"));
1110        assert!(embed_models
1111            .iter()
1112            .any(|model| model.name == "Qwen3-Embedding-0.6B-MLX"));
1113    }
1114
1115    #[test]
1116    fn query_with_filter() {
1117        let (reg, _tmp) = test_registry();
1118        let code_small = reg.query(&ModelFilter {
1119            capabilities: vec![ModelCapability::Code],
1120            max_size_mb: Some(3000),
1121            local_only: true,
1122            ..Default::default()
1123        });
1124        // Qwen3-1.7B, Qwen3-1.7B-MLX, Qwen3-4B, and Qwen3-4B-MLX fit and have Code capability.
1125        assert_eq!(code_small.len(), 4);
1126    }
1127
1128    #[test]
1129    fn register_remote() {
1130        let (mut reg, _tmp) = test_registry();
1131        let initial_len = reg.list().len();
1132        let initial_reasoning_len = reg
1133            .query(&ModelFilter {
1134                capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
1135                ..Default::default()
1136            })
1137            .len();
1138        let remote = ModelSchema {
1139            id: "anthropic/claude-sonnet-4-6:latest".into(),
1140            name: "Claude Sonnet 4.6".into(),
1141            provider: "anthropic".into(),
1142            family: "claude-4".into(),
1143            version: "latest".into(),
1144            capabilities: vec![
1145                ModelCapability::Generate,
1146                ModelCapability::Code,
1147                ModelCapability::Reasoning,
1148                ModelCapability::ToolUse,
1149            ],
1150            context_length: 200000,
1151            param_count: String::new(),
1152            quantization: None,
1153            performance: PerformanceEnvelope {
1154                latency_p50_ms: Some(2000),
1155                ..Default::default()
1156            },
1157            cost: CostModel {
1158                input_per_mtok: Some(3.0),
1159                output_per_mtok: Some(15.0),
1160                ..Default::default()
1161            },
1162            source: ModelSource::RemoteApi {
1163                endpoint: "https://api.anthropic.com/v1/messages".into(),
1164                api_key_env: "ANTHROPIC_API_KEY".into(),
1165                api_key_envs: vec![],
1166                api_version: Some("2023-06-01".into()),
1167                protocol: ApiProtocol::Anthropic,
1168            },
1169            tags: vec![],
1170            supported_params: vec![],
1171            public_benchmarks: vec![],
1172            available: false,
1173        };
1174
1175        reg.register(remote);
1176        // Same ID as builtin claude-sonnet-4-6 — replaces, count stays same
1177        assert_eq!(reg.list().len(), initial_len);
1178
1179        let reasoning = reg.query(&ModelFilter {
1180            capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
1181            ..Default::default()
1182        });
1183        // Replacing an existing remote slot should not change the reasoning/tool-use lineup size.
1184        assert_eq!(reasoning.len(), initial_reasoning_len);
1185    }
1186
1187    #[test]
1188    fn unregister() {
1189        let (mut reg, _tmp) = test_registry();
1190        let initial_len = reg.list().len();
1191        let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
1192        assert!(removed.is_some());
1193        assert_eq!(reg.list().len(), initial_len - 1);
1194    }
1195
1196    #[test]
1197    fn speech_models_are_curated() {
1198        let (reg, _tmp) = test_registry();
1199        let stt = reg.query_by_capability(ModelCapability::SpeechToText);
1200        let tts = reg.query_by_capability(ModelCapability::TextToSpeech);
1201        assert_eq!(stt.len(), 2);
1202        assert_eq!(tts.len(), 4);
1203    }
1204
1205    #[test]
1206    fn qwen_8b_variants_keep_tool_use_consistent() {
1207        let (reg, _tmp) = test_registry();
1208        for name in ["Qwen3-8B", "Qwen3-8B-MLX"] {
1209            let model = reg.find_by_name(name).expect("model should exist");
1210            assert!(model.has_capability(ModelCapability::ToolUse));
1211            assert!(model.has_capability(ModelCapability::MultiToolCall));
1212        }
1213    }
1214
1215    #[test]
1216    fn mac_name_resolution_prefers_mlx_siblings() {
1217        // Only used inside the aarch64-macos cfg below; non-mac targets
1218        // keep the test as a smoke compile.
1219        #[allow(unused_variables)]
1220        let (reg, _tmp) = test_registry();
1221        #[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
1222        {
1223            assert_eq!(
1224                reg.find_by_name("Qwen3-0.6B").unwrap().id,
1225                "mlx/qwen3-0.6b:6bit"
1226            );
1227            assert_eq!(
1228                reg.find_by_name("Qwen3-1.7B").unwrap().id,
1229                "mlx/qwen3-1.7b:3bit"
1230            );
1231            assert_eq!(
1232                reg.find_by_name("Qwen3-Embedding-0.6B").unwrap().id,
1233                "mlx/qwen3-embedding-0.6b:mxfp8"
1234            );
1235        }
1236    }
1237
1238    #[test]
1239    fn remote_multimodal_models_are_curated_as_vision_capable() {
1240        let (reg, _tmp) = test_registry();
1241        for name in [
1242            "claude-opus-4-7",
1243            "claude-opus-4-6",
1244            "claude-sonnet-4-6",
1245            "claude-haiku-4-5",
1246            "gpt-5.4",
1247            "gpt-5.4-mini",
1248            "o3",
1249            "o4-mini",
1250            "gpt-4.1-mini",
1251            "gemini-2.5-pro",
1252            "gemini-2.5-flash",
1253        ] {
1254            let model = reg.find_by_name(name).expect("model should exist");
1255            assert!(
1256                model.has_capability(ModelCapability::Vision),
1257                "{name} should be curated as vision-capable"
1258            );
1259        }
1260    }
1261
1262    #[test]
1263    fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
1264        let (reg, _tmp) = test_registry();
1265
1266        let stale_ids = [
1267            // Native MLX text tower can't tokenize images — never advertise.
1268            "mlx/qwen2.5-vl-3b:4bit",
1269            "mlx/qwen2.5-vl-7b:4bit",
1270            // Qwen2.5-VL is superseded by Qwen3-VL; drop the mlx-vlm CLI
1271            // catalog entries so callers route to the upgraded family.
1272            "mlx-vlm/qwen2.5-vl-3b:4bit",
1273            "mlx-vlm/qwen2.5-vl-7b:4bit",
1274            // Same supersession applies to the vLLM-MLX route.
1275            "vllm-mlx/qwen2.5-vl-3b:4bit",
1276        ];
1277        for id in stale_ids {
1278            assert!(
1279                reg.get(id).is_none(),
1280                "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
1281            );
1282        }
1283
1284        let vision_ids: Vec<&str> = reg
1285            .query_by_capability(ModelCapability::Vision)
1286            .into_iter()
1287            .map(|model| model.id.as_str())
1288            .collect();
1289        for stale in stale_ids {
1290            assert!(
1291                !vision_ids.contains(&stale),
1292                "{stale} must not be reachable through the Vision capability index"
1293            );
1294        }
1295        assert!(
1296            vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
1297            "Qwen3-VL is the supported local VL family and must route as Vision"
1298        );
1299    }
1300
1301    #[test]
1302    fn gemini_models_are_curated_for_multimodal_tool_use() {
1303        let (reg, _tmp) = test_registry();
1304        for name in ["gemini-2.5-pro", "gemini-2.5-flash"] {
1305            let model = reg.find_by_name(name).expect("model should exist");
1306            assert!(model.has_capability(ModelCapability::Vision));
1307            assert!(model.has_capability(ModelCapability::ToolUse));
1308            assert!(model.has_capability(ModelCapability::MultiToolCall));
1309        }
1310    }
1311
1312    #[test]
1313    fn visual_generation_models_are_curated() {
1314        let (reg, _tmp) = test_registry();
1315        assert_eq!(
1316            reg.query_by_capability(ModelCapability::ImageGeneration)
1317                .len(),
1318            1
1319        );
1320        assert_eq!(
1321            reg.query_by_capability(ModelCapability::VideoGeneration)
1322                .len(),
1323            2
1324        );
1325        let yume = reg
1326            .get("mlx/yume-1.5-5b-720p:q4")
1327            .expect("Yume MLX should be in the built-in catalog");
1328        assert!(yume.has_capability(ModelCapability::VideoGeneration));
1329        assert!(yume.tags.contains(&"text-to-video".to_string()));
1330        assert!(yume.tags.contains(&"image-to-video".to_string()));
1331        assert!(yume.tags.contains(&"world-model".to_string()));
1332    }
1333}