Skip to main content

car_inference/
registry.rs

1//! Unified model registry — local and remote models under one schema.
2//!
3//! Replaces the hardcoded `ModelRegistry` from `models.rs` with a schema-driven
4//! registry that treats all models as first-class typed resources. Users can
5//! register custom models (fine-tuned endpoints, private APIs) alongside the
6//! built-in catalog.
7
8use crate::schema::reasoning_params;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14use tracing::info;
15
16use crate::schema::*;
17use crate::InferenceError;
18
/// Filter for querying the registry.
///
/// All criteria are combined with logical AND by `UnifiedRegistry::query`.
/// A default-constructed filter (empty lists, `None` limits, flags off)
/// places no restriction on the result.
#[derive(Debug, Clone, Default)]
pub struct ModelFilter {
    /// Required capabilities (model must have ALL of these).
    pub capabilities: Vec<ModelCapability>,
    /// Maximum on-disk / RAM size in MB.
    ///
    /// `query` only enforces this limit on models for which `is_local()` is
    /// true; other models pass regardless of their reported size.
    pub max_size_mb: Option<u64>,
    /// Maximum expected latency in ms (from declared envelope).
    ///
    /// Compared against the declared p50 latency. Models that declare no
    /// p50 latency are not rejected by this criterion.
    pub max_latency_ms: Option<u64>,
    /// Maximum cost per 1M output tokens in USD.
    ///
    /// Models that declare no output cost are not rejected by this criterion.
    pub max_cost_per_mtok: Option<f64>,
    /// Required tags (model must have ALL of these).
    pub tags: Vec<String>,
    /// Filter by provider (exact string match).
    pub provider: Option<String>,
    /// Only local models.
    pub local_only: bool,
    /// Only models that are currently available.
    pub available_only: bool,
}
39
/// A curated replacement for a local model that is installed but no longer
/// the preferred model in its line.
///
/// Produced by `UnifiedRegistry::available_upgrades` by pairing a built-in
/// upgrade rule with the matching registry entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpgrade {
    /// Registry id of the installed model being superseded.
    pub from_id: String,
    /// Name of the installed model being superseded.
    pub from_name: String,
    /// Registry id of the recommended replacement.
    pub to_id: String,
    /// Name of the recommended replacement.
    pub to_name: String,
    /// Explanation copied from the upgrade rule.
    pub reason: String,
    /// Copied verbatim from the upgrade rule.
    // NOTE(review): presumably names the runtime the target model needs — confirm against model-upgrades.json.
    pub target_runtime: Option<String>,
    /// Copied verbatim from the upgrade rule.
    // NOTE(review): presumably a human-readable requirement for `target_runtime` — confirm.
    pub target_runtime_requirement: Option<String>,
    /// Runtime/version pairs copied from the upgrade rule.
    pub minimum_runtimes: Vec<ModelRuntimeRequirement>,
    /// The replacement's `available` flag at the time the report was built.
    pub target_available: bool,
    /// True when the replacement's source is `ModelSource::Local` or
    /// `ModelSource::Mlx`.
    pub target_pullable: bool,
    /// True when the installed model's source is `ModelSource::Local` or
    /// `ModelSource::Mlx` AND the rule has `remove_old_after_available` set.
    pub remove_old_supported: bool,
}
56
/// A named runtime together with the minimum version an upgrade expects.
///
/// Deserialized from the built-in upgrade rules and passed through unchanged
/// in `ModelUpgrade::minimum_runtimes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRuntimeRequirement {
    /// Runtime identifier.
    pub name: String,
    /// Minimum version, kept as an opaque string.
    // NOTE(review): no version parsing or comparison is visible in this file — confirm the expected format.
    pub minimum_version: String,
}
62
/// Unified registry of all known models.
///
/// Holds the built-in catalog plus any user-registered models, keyed by
/// model id, and tracks where local model files and the user config live.
pub struct UnifiedRegistry {
    /// Directory under which each local model gets a `<model name>/` subdirectory.
    models_dir: PathBuf,
    /// All registered models, keyed by id.
    models: HashMap<String, ModelSchema>,
    /// User-added model config file path (~/.car/models.json).
    ///
    /// Computed in `new` as `models.json` next to `models_dir` (i.e. in its
    /// parent directory, falling back to `models_dir` itself).
    user_config_path: PathBuf,
}
71
/// One entry of the built-in `model-upgrades.json` asset: a set of superseded
/// model ids and the id that replaces them.
#[derive(Debug, Clone, Deserialize)]
struct ModelUpgradeRule {
    /// Ids of models this rule supersedes.
    from_ids: Vec<String>,
    /// Id of the replacement model.
    to_id: String,
    /// Explanation surfaced in `ModelUpgrade::reason`.
    reason: String,
    /// Passed through to `ModelUpgrade::target_runtime`.
    target_runtime: Option<String>,
    /// Passed through to `ModelUpgrade::target_runtime_requirement`.
    target_runtime_requirement: Option<String>,
    /// Passed through to `ModelUpgrade::minimum_runtimes`; empty when omitted.
    #[serde(default)]
    minimum_runtimes: Vec<ModelRuntimeRequirement>,
    /// Whether removing the old model is offered; `true` when omitted.
    #[serde(default = "default_remove_old_after_available")]
    remove_old_after_available: bool,
}
84
/// Serde default for `ModelUpgradeRule::remove_old_after_available`:
/// rules that omit the field allow removing the old model.
fn default_remove_old_after_available() -> bool {
    true
}
88
/// Parse the upgrade rules embedded at compile time from
/// `../assets/model-upgrades.json`.
///
/// The asset is re-parsed on every call.
///
/// # Panics
///
/// Panics if the embedded JSON does not deserialize into
/// `Vec<ModelUpgradeRule>`; that indicates a broken build asset rather than
/// a runtime condition.
fn model_upgrade_rules() -> Vec<ModelUpgradeRule> {
    serde_json::from_str(include_str!("../assets/model-upgrades.json"))
        .expect("built-in model-upgrades.json should parse")
}
93
94impl UnifiedRegistry {
95    pub fn new(models_dir: PathBuf) -> Self {
96        let user_config_path = models_dir
97            .parent()
98            .unwrap_or(&models_dir)
99            .join("models.json");
100
101        let mut registry = Self {
102            models_dir,
103            models: HashMap::new(),
104            user_config_path,
105        };
106        registry.load_builtin_catalog();
107        registry.refresh_availability();
108        // Load user config on top (silently ignore if missing)
109        let _ = registry.load_user_config();
110        registry
111    }
112
113    /// Register a model at runtime.
114    pub fn register(&mut self, mut schema: ModelSchema) {
115        // Check availability for local models
116        if schema.is_mlx() {
117            schema.available = if schema.tags.contains(&"speech".to_string()) {
118                speech_mlx_available()
119            } else if let ModelSource::Mlx { ref hf_repo, .. } = schema.source {
120                let mlx_dir = self.models_dir.join(&schema.name);
121                mlx_dir.join("config.json").exists()
122                    || latest_huggingface_repo_snapshot(hf_repo).is_some()
123            } else {
124                let mlx_dir = self.models_dir.join(&schema.name);
125                mlx_dir.join("config.json").exists()
126            };
127        } else if schema.is_vllm_mlx() {
128            // vLLM-MLX: available if endpoint env var set or was manually marked available
129            schema.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || schema.available;
130        } else if schema.is_local() {
131            let local_path = self.models_dir.join(&schema.name).join("model.gguf");
132            schema.available = local_path.exists();
133        } else if schema.is_remote() {
134            // Remote models are assumed available if the env var exists
135            if let ModelSource::RemoteApi {
136                ref api_key_env, ..
137            } = schema.source
138            {
139                schema.available = std::env::var(api_key_env).is_ok();
140            }
141        }
142        info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
143        self.models.insert(schema.id.clone(), schema);
144    }
145
146    /// Unregister a model by id. Returns the removed schema if found.
147    pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
148        let removed = self.models.remove(id);
149        if let Some(ref m) = removed {
150            info!(id = %m.id, "unregistered model");
151        }
152        removed
153    }
154
155    /// List all models.
156    pub fn list(&self) -> Vec<&ModelSchema> {
157        let mut models: Vec<&ModelSchema> = self.models.values().collect();
158        models.sort_by(|a, b| a.id.cmp(&b.id));
159        models
160    }
161
162    /// Query models matching a filter.
163    pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
164        self.models
165            .values()
166            .filter(|m| {
167                // Capability check: model must have ALL required capabilities
168                if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
169                    return false;
170                }
171                // Size check
172                if let Some(max) = filter.max_size_mb {
173                    if m.size_mb() > max && m.is_local() {
174                        return false;
175                    }
176                }
177                // Latency check (declared envelope)
178                if let Some(max) = filter.max_latency_ms {
179                    if let Some(p50) = m.performance.latency_p50_ms {
180                        if p50 > max {
181                            return false;
182                        }
183                    }
184                }
185                // Cost check
186                if let Some(max) = filter.max_cost_per_mtok {
187                    if let Some(cost) = m.cost.output_per_mtok {
188                        if cost > max {
189                            return false;
190                        }
191                    }
192                }
193                // Tag check
194                if !filter.tags.iter().all(|t| m.tags.contains(t)) {
195                    return false;
196                }
197                // Provider check
198                if let Some(ref p) = filter.provider {
199                    if &m.provider != p {
200                        return false;
201                    }
202                }
203                // Local only
204                if filter.local_only && !m.is_local() {
205                    return false;
206                }
207                // Available only
208                if filter.available_only && !m.available {
209                    return false;
210                }
211                true
212            })
213            .collect()
214    }
215
216    /// Query models by a single capability.
217    pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
218        self.query(&ModelFilter {
219            capabilities: vec![cap],
220            ..Default::default()
221        })
222    }
223
224    /// Report installed local models with curated newer replacements.
225    pub fn available_upgrades(&self) -> Vec<ModelUpgrade> {
226        let mut upgrades = Vec::new();
227        for rule in model_upgrade_rules() {
228            let Some(from) = rule
229                .from_ids
230                .iter()
231                .find_map(|id| self.models.get(id.as_str()))
232                .filter(|schema| schema.available)
233            else {
234                continue;
235            };
236            let Some(to) = self.models.get(rule.to_id.as_str()) else {
237                continue;
238            };
239            upgrades.push(ModelUpgrade {
240                from_id: from.id.clone(),
241                from_name: from.name.clone(),
242                to_id: to.id.clone(),
243                to_name: to.name.clone(),
244                reason: rule.reason.clone(),
245                target_runtime: rule.target_runtime.clone(),
246                target_runtime_requirement: rule.target_runtime_requirement.clone(),
247                minimum_runtimes: rule.minimum_runtimes.clone(),
248                target_available: to.available,
249                target_pullable: matches!(
250                    to.source,
251                    ModelSource::Local { .. } | ModelSource::Mlx { .. }
252                ),
253                remove_old_supported: matches!(
254                    from.source,
255                    ModelSource::Local { .. } | ModelSource::Mlx { .. }
256                ) && rule.remove_old_after_available,
257            });
258        }
259        upgrades.sort_by(|a, b| a.from_id.cmp(&b.from_id).then(a.to_id.cmp(&b.to_id)));
260        upgrades.dedup_by(|a, b| a.from_id == b.from_id && a.to_id == b.to_id);
261        upgrades
262    }
263
    /// Get a specific model by id.
    ///
    /// This is an exact lookup on the registry key; it returns `None` when
    /// nothing is registered under `id`.
    pub fn get(&self, id: &str) -> Option<&ModelSchema> {
        self.models.get(id)
    }
268
269    /// Find a model by name (case-insensitive). For backward compatibility
270    /// with the old registry that used short names like "Qwen3-4B".
271    pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
272        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
273        if !name.to_ascii_lowercase().ends_with("-mlx") {
274            if let Some(mlx_variant) = self
275                .models
276                .values()
277                .find(|m| m.name.eq_ignore_ascii_case(&format!("{name}-MLX")))
278            {
279                return Some(mlx_variant);
280            }
281        }
282
283        self.models
284            .values()
285            .find(|m| m.name.eq_ignore_ascii_case(name))
286    }
287
288    /// On Apple Silicon, resolve a GGUF/Candle model to its MLX equivalent.
289    /// Returns the MLX model schema if one exists with the same family and
290    /// matching capabilities; otherwise returns None.
291    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
292    pub fn resolve_mlx_equivalent(&self, schema: &ModelSchema) -> Option<&ModelSchema> {
293        // Already MLX — no redirect needed.
294        if schema.is_mlx() || schema.is_vllm_mlx() {
295            return None;
296        }
297        // Only redirect local GGUF models.
298        if !matches!(schema.source, ModelSource::Local { .. }) {
299            return None;
300        }
301        // Find an MLX model in the same family with at least the same primary capability.
302        let primary_cap = schema.capabilities.first()?;
303        self.models.values().find(|m| {
304            m.is_mlx()
305                && m.family == schema.family
306                && m.capabilities.contains(primary_cap)
307        })
308    }
309
    /// Ensure a local model is downloaded, returning its local directory path.
    ///
    /// `id` is resolved first as a registry id and then, failing that, as a
    /// model name via `find_by_name`.
    ///
    /// The returned path is the managed directory `<models_dir>/<model name>`,
    /// or — for MLX models served from the Hugging Face cache — a snapshot
    /// directory.
    ///
    /// # Errors
    ///
    /// - `InferenceError::ModelNotFound` when nothing matches `id`.
    /// - `InferenceError::InferenceFailed` when the model's source is neither
    ///   `ModelSource::Local` nor `ModelSource::Mlx`, or when a sharded
    ///   weight index cannot be parsed.
    /// - Any error propagated from directory creation, file reads, or the
    ///   download helpers.
    pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
        let schema = self
            .get(id)
            .or_else(|| self.find_by_name(id))
            .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;

        match &schema.source {
            // GGUF-style local model: one weights file plus a tokenizer.
            ModelSource::Local {
                hf_repo,
                hf_filename,
                tokenizer_repo,
            } => {
                let model_dir = self.models_dir.join(&schema.name);
                let model_path = model_dir.join("model.gguf");
                let tokenizer_path = model_dir.join("tokenizer.json");

                // Fast path: both files are already in place.
                if model_path.exists() && tokenizer_path.exists() {
                    return Ok(model_dir);
                }

                std::fs::create_dir_all(&model_dir)?;

                // Fetch only whichever of the two files is missing.
                if !model_path.exists() {
                    info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
                    download_file(hf_repo, hf_filename, &model_path).await?;
                }
                if !tokenizer_path.exists() {
                    info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
                    download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
                }

                Ok(model_dir)
            }
            ModelSource::Mlx {
                hf_repo,
                hf_weight_file,
            } => {
                let model_dir = self.models_dir.join(&schema.name);
                let config_path = model_dir.join("config.json");

                // 1) Managed directory: the presence of config.json is the
                //    only readiness check performed here.
                // NOTE(review): config.json is also the first file written by
                // the per-file download below, so an interrupted download
                // may leave a directory that this check later accepts —
                // confirm whether callers tolerate that.
                if config_path.exists() {
                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
                    info!(model = %schema.name, path = %model_dir.display(), "using managed local MLX model");
                    return Ok(model_dir);
                }

                // 2) An existing Hugging Face cache snapshot is used in place.
                if let Some(snapshot_dir) = latest_huggingface_repo_snapshot(hf_repo) {
                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
                    info!(model = %schema.name, path = %snapshot_dir.display(), "using cached MLX snapshot");
                    return Ok(snapshot_dir);
                }

                // 3) Some models need the whole repo rather than a known
                //    handful of files.
                if requires_full_mlx_snapshot(&schema) {
                    info!(
                        model = %schema.name,
                        repo = %hf_repo,
                        "downloading full MLX snapshot"
                    );
                    let (snapshot_dir, _files_downloaded) = download_hf_repo_snapshot(hf_repo).await?;
                    ensure_auxiliary_mlx_files(&schema.name, hf_repo, &snapshot_dir).await?;
                    return Ok(snapshot_dir);
                }

                // 4) Otherwise download file-by-file into the managed directory.
                std::fs::create_dir_all(&model_dir)?;

                info!(model = %schema.name, repo = %hf_repo, "downloading MLX model");

                // Download config, tokenizer, and weight files
                download_file(hf_repo, "config.json", &config_path).await?;
                let tok_path = model_dir.join("tokenizer.json");
                if !tok_path.exists() {
                    download_file(hf_repo, "tokenizer.json", &tok_path).await?;
                }
                let tok_config_path = model_dir.join("tokenizer_config.json");
                if !tok_config_path.exists() {
                    // Best-effort: a failure to fetch tokenizer_config.json is ignored.
                    let _ = download_file(hf_repo, "tokenizer_config.json", &tok_config_path).await;
                }

                // Download weight files
                if let Some(ref wf) = hf_weight_file {
                    // The schema names the weight file explicitly.
                    let wf_path = model_dir.join(wf);
                    if !wf_path.exists() {
                        download_file(hf_repo, wf, &wf_path).await?;
                    }
                } else {
                    // Try single file first, then sharded
                    let single = model_dir.join("model.safetensors");
                    if !single.exists() {
                        match download_file(hf_repo, "model.safetensors", &single).await {
                            Ok(()) => {}
                            // Any failure of the single-file download falls
                            // through to the sharded layout.
                            Err(_) => {
                                // Sharded: download index and then each shard
                                let index_path = model_dir.join("model.safetensors.index.json");
                                download_file(hf_repo, "model.safetensors.index.json", &index_path)
                                    .await?;

                                let index_json: serde_json::Value =
                                    serde_json::from_str(&std::fs::read_to_string(&index_path)?)
                                        .map_err(|e| {
                                            InferenceError::InferenceFailed(format!(
                                                "parse index: {e}"
                                            ))
                                        })?;

                                // `weight_map` values are shard file names; many
                                // entries share a shard, so collect the distinct
                                // names first. An index without a `weight_map`
                                // object downloads no shards and is not an error.
                                if let Some(weight_map) =
                                    index_json.get("weight_map").and_then(|m| m.as_object())
                                {
                                    let mut files: std::collections::HashSet<String> =
                                        std::collections::HashSet::new();
                                    for filename in weight_map.values() {
                                        if let Some(f) = filename.as_str() {
                                            files.insert(f.to_string());
                                        }
                                    }
                                    for file in &files {
                                        let dest = model_dir.join(file);
                                        if !dest.exists() {
                                            info!(file = %file, "downloading weight shard");
                                            download_file(hf_repo, file, &dest).await?;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                ensure_auxiliary_mlx_files(&schema.name, hf_repo, &model_dir).await?;
                Ok(model_dir)
            }
            // Every other source kind has nothing to download.
            _ => Err(InferenceError::InferenceFailed(format!(
                "model {} is not local",
                id
            ))),
        }
    }
447
448    /// Remove a downloaded local model.
449    pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
450        let schema = self
451            .get(id)
452            .or_else(|| self.find_by_name(id))
453            .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
454
455        let model_dir = self.models_dir.join(&schema.name);
456        if model_dir.exists() {
457            std::fs::remove_dir_all(&model_dir)?;
458            info!(model = %schema.name, "removed model");
459        }
460
461        match &schema.source {
462            ModelSource::Mlx { hf_repo, .. } => {
463                let repo_dir = huggingface_repo_dir(hf_repo);
464                if repo_dir.exists() {
465                    std::fs::remove_dir_all(&repo_dir)?;
466                    info!(model = %schema.name, repo = %hf_repo, "removed Hugging Face cache");
467                }
468            }
469            ModelSource::Local {
470                hf_repo,
471                tokenizer_repo,
472                ..
473            } => {
474                for repo in [hf_repo, tokenizer_repo] {
475                    let repo_dir = huggingface_repo_dir(repo);
476                    if repo_dir.exists() {
477                        std::fs::remove_dir_all(&repo_dir)?;
478                        info!(model = %schema.name, repo = %repo, "removed Hugging Face cache");
479                    }
480                }
481            }
482            _ => {}
483        }
484
485        // Update availability
486        let id = schema.id.clone();
487        if let Some(m) = self.models.get_mut(&id) {
488            m.available = false;
489        }
490        Ok(())
491    }
492
493    /// Refresh availability flags for all models.
494    ///
495    /// Runtime-true vs catalog-says: this is what closes the gap
496    /// `models.list_unified` callers rely on. If a model is listed
497    /// as `available: true` here, an `infer` call against it should
498    /// reach the backend, not bail with `UnsupportedMode { ...
499    /// "mlx-vlm CLI not found on PATH" }` (the #137 trap).
500    pub fn refresh_availability(&mut self) {
501        let models_dir = self.models_dir.clone();
502        // mlx-vlm CLI is the same probe call no matter which model
503        // requires it; do it once per refresh, not per-model.
504        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
505        let mlx_vlm_cli_present = crate::backend::mlx_vlm_cli::is_available();
506        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
507        let mlx_vlm_cli_present = false;
508
509        for m in self.models.values_mut() {
510            match &m.source {
511                ModelSource::Mlx { hf_repo, .. } => {
512                    // Models tagged `requires-mlx-vlm` shell out to the
513                    // mlx_vlm Python CLI for image inference (#115).
514                    // If the CLI isn't on PATH, the runtime reaches it
515                    // anyway and bails — the registry MUST reflect that
516                    // by marking such entries unavailable until the
517                    // user installs `uv tool install mlx-vlm`. #137.
518                    let needs_mlx_vlm =
519                        m.tags.iter().any(|t| t == "requires-mlx-vlm");
520
521                    m.available = if needs_mlx_vlm {
522                        mlx_vlm_cli_present
523                    } else if m.tags.contains(&"speech".to_string()) {
524                        speech_mlx_available()
525                    } else {
526                        let mlx_dir = models_dir.join(&m.name);
527                        mlx_dir.join("config.json").exists()
528                            || latest_huggingface_repo_snapshot(hf_repo).is_some()
529                    };
530                }
531                ModelSource::Local { .. } => {
532                    let local_path = models_dir.join(&m.name).join("model.gguf");
533                    m.available = local_path.exists();
534                }
535                ModelSource::RemoteApi { api_key_env, .. } => {
536                    m.available = std::env::var(api_key_env).is_ok();
537                }
538                ModelSource::Ollama { .. } => {
539                    // Assume available; health check is async and done lazily
540                    m.available = true;
541                }
542                ModelSource::VllmMlx { .. } => {
543                    // vLLM-MLX availability checked via health endpoint lazily
544                    // Mark as available if VLLM_MLX_ENDPOINT env var is set or default endpoint assumed
545                    m.available = std::env::var("VLLM_MLX_ENDPOINT").is_ok() || m.available;
546                    // preserve manual registration
547                }
548                ModelSource::Proprietary { auth, .. } => {
549                    // Check auth availability
550                    m.available = match auth {
551                        crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
552                            std::env::var(env_var).is_ok()
553                        }
554                        crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
555                            std::env::var(env_var).is_ok()
556                        }
557                        crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
558                            // OAuth2 availability determined at runtime by token provider
559                            true
560                        }
561                    };
562                }
563                ModelSource::AppleFoundationModels { .. } => {
564                    // Apple Silicon macOS 26+ AND iOS 26+ both expose
565                    // the FoundationModels framework. The shim's
566                    // runtime probe handles per-device availability
567                    // (Apple Intelligence may be off, the device may
568                    // be pre-A17, etc.); cfg-gating here just hides
569                    // the call on targets where the shim isn't built.
570                    #[cfg(any(
571                        all(target_os = "macos", target_arch = "aarch64"),
572                        all(target_os = "ios", target_arch = "aarch64")
573                    ))]
574                    {
575                        m.available = crate::backend::foundation_models::is_available();
576                    }
577                    #[cfg(not(any(
578                        all(target_os = "macos", target_arch = "aarch64"),
579                        all(target_os = "ios", target_arch = "aarch64")
580                    )))]
581                    {
582                        m.available = false;
583                    }
584                }
585                ModelSource::Delegated { .. } => {
586                    // Availability tracks whether a runner is registered.
587                    // Hosts call `registerInferenceRunner` (or its
588                    // language equivalent) at startup; until then the
589                    // model is unavailable.
590                    m.available = crate::runner::current_inference_runner().is_some();
591                }
592            }
593        }
594    }
595
596    /// Persist user-registered (non-builtin) models to disk.
597    pub fn save_user_config(&self) -> Result<(), InferenceError> {
598        let user_models: Vec<&ModelSchema> = self
599            .models
600            .values()
601            .filter(|m| !m.tags.contains(&"builtin".to_string()))
602            .collect();
603
604        if user_models.is_empty() {
605            return Ok(());
606        }
607
608        let json = serde_json::to_string_pretty(&user_models)
609            .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
610        std::fs::write(&self.user_config_path, json)?;
611        Ok(())
612    }
613
614    /// Load user-registered models from disk.
615    pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
616        if !self.user_config_path.exists() {
617            return Ok(());
618        }
619
620        let json = std::fs::read_to_string(&self.user_config_path)?;
621        let models: Vec<ModelSchema> = serde_json::from_str(&json)
622            .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
623
624        for m in models {
625            self.register(m);
626        }
627        Ok(())
628    }
629
630    /// Get the models directory path.
631    pub fn models_dir(&self) -> &Path {
632        &self.models_dir
633    }
634
635    /// Load the built-in Qwen3 catalog as ModelSchema objects.
636    fn load_builtin_catalog(&mut self) {
637        for schema in builtin_catalog() {
638            self.models.insert(schema.id.clone(), schema);
639        }
640    }
641}
642
643fn speech_mlx_available() -> bool {
644    // On Apple Silicon, speech uses native MLX backends — no Python CLI needed.
645    // Models are available if we're on the right platform (weights are downloaded on demand).
646    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
647    { true }
648
649    // On other platforms, check for the Python mlx-audio CLI.
650    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
651    {
652        let runtime_root = speech_runtime_root();
653        runtime_root.join("bin").join("mlx_audio.stt.generate").exists()
654            || runtime_root.join("bin").join("mlx_audio.tts.generate").exists()
655    }
656}
657
/// Root directory of the speech runtime.
///
/// Uses `CAR_SPEECH_RUNTIME_DIR` verbatim when it is set to something other
/// than whitespace; otherwise `$HOME/.car/speech-runtime`, with `.` standing
/// in for `$HOME` when that variable is unavailable.
fn speech_runtime_root() -> PathBuf {
    match std::env::var("CAR_SPEECH_RUNTIME_DIR") {
        Ok(configured) if !configured.trim().is_empty() => PathBuf::from(configured),
        _ => {
            let home = match std::env::var("HOME") {
                Ok(dir) => PathBuf::from(dir),
                Err(_) => PathBuf::from("."),
            };
            home.join(".car").join("speech-runtime")
        }
    }
}
670
/// Backward-compatible ModelInfo for listing (used by CLI and old callers).
///
/// A flattened, serializable summary built from a `ModelSchema` via
/// `From<&ModelSchema>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Registry id of the model.
    pub id: String,
    /// Model name.
    pub name: String,
    /// Provider string copied from the schema.
    pub provider: String,
    /// Capabilities copied from the schema.
    pub capabilities: Vec<ModelCapability>,
    /// Parameter count as the schema's free-form string.
    pub param_count: String,
    /// Size in MB as reported by the schema's `size_mb()`.
    pub size_mb: u64,
    /// Context length copied from the schema.
    // NOTE(review): presumably measured in tokens — confirm against ModelSchema.
    pub context_length: usize,
    /// The schema's `available` flag at conversion time.
    pub available: bool,
    /// Result of the schema's `is_local()` at conversion time.
    pub is_local: bool,
    /// Public benchmark scores carried straight through from `ModelSchema`.
    /// The built-in catalog ships this empty; populating it is a curation
    /// step (see `BenchmarkScore` in the schema for shape and conventions).
    #[serde(default)]
    pub public_benchmarks: Vec<crate::schema::BenchmarkScore>,
}
689
690impl From<&ModelSchema> for ModelInfo {
691    fn from(s: &ModelSchema) -> Self {
692        ModelInfo {
693            id: s.id.clone(),
694            name: s.name.clone(),
695            provider: s.provider.clone(),
696            capabilities: s.capabilities.clone(),
697            param_count: s.param_count.clone(),
698            size_mb: s.size_mb(),
699            context_length: s.context_length,
700            available: s.available,
701            is_local: s.is_local(),
702            public_benchmarks: s.public_benchmarks.clone(),
703        }
704    }
705}
706
707/// Download a single file from a HuggingFace repo.
708async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
709    let api = hf_hub::api::tokio::Api::new()
710        .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
711
712    let repo = api.model(repo.to_string());
713    let path = repo
714        .get(filename)
715        .await
716        .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
717
718    if dest.exists() {
719        return Ok(());
720    }
721
722    // Try symlink first, fall back to copy
723    #[cfg(unix)]
724    {
725        if std::os::unix::fs::symlink(&path, dest).is_ok() {
726            return Ok(());
727        }
728    }
729
730    std::fs::copy(&path, dest)
731        .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
732    Ok(())
733}
734
735async fn ensure_auxiliary_mlx_files(
736    model_name: &str,
737    hf_repo: &str,
738    model_dir: &Path,
739) -> Result<(), InferenceError> {
740    if hf_repo == "mlx-community/Flux-1.lite-8B-MLX-Q4" || model_name == "Flux-1.lite-8B-MLX-Q4" {
741        let t5_tokenizer_path = model_dir.join("tokenizer_2").join("tokenizer.json");
742        if !t5_tokenizer_path.exists() {
743            std::fs::create_dir_all(
744                t5_tokenizer_path
745                    .parent()
746                    .ok_or_else(|| InferenceError::InferenceFailed("invalid tokenizer path".into()))?,
747            )?;
748            info!(
749                path = %t5_tokenizer_path.display(),
750                "downloading missing Flux tokenizer_2/tokenizer.json from base model"
751            );
752            download_file("Freepik/flux.1-lite-8B", "tokenizer_2/tokenizer.json", &t5_tokenizer_path)
753                .await?;
754        }
755    }
756    Ok(())
757}
758
759fn requires_full_mlx_snapshot(schema: &ModelSchema) -> bool {
760    match &schema.source {
761        ModelSource::Mlx { hf_repo, .. } => {
762            hf_repo == "ckurasek/Yume-1.5-5B-720P-MLX-4bit"
763                || schema.family.starts_with("yume")
764                || schema.tags.iter().any(|tag| {
765                    matches!(
766                        tag.as_str(),
767                        "wan2.2" | "ti2v" | "world-model" | "image-to-video"
768                    )
769                })
770        }
771        _ => false,
772    }
773}
774
775fn huggingface_repo_has_snapshot(repo_id: &str) -> bool {
776    latest_huggingface_repo_snapshot(repo_id).is_some()
777}
778
/// Directory of the HuggingFace hub cache.
///
/// Resolves to `$HF_HOME/hub` when `HF_HOME` is set, otherwise to
/// `$HOME/.cache/huggingface/hub`, with `.` standing in for `HOME` when that
/// variable is unavailable too.
fn huggingface_cache_root() -> PathBuf {
    let hf_home = match std::env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(dir) => PathBuf::from(dir),
                Err(_) => PathBuf::from("."),
            };
            home.join(".cache").join("huggingface")
        }
    };
    hf_home.join("hub")
}
791
792fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
793    huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
794}
795
796fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
797    let sha = std::fs::read_to_string(repo_dir.join("refs").join(name))
798        .ok()?
799        .trim()
800        .to_string();
801    if sha.is_empty() {
802        return None;
803    }
804
805    let snapshot = repo_dir.join("snapshots").join(sha);
806    if snapshot_looks_ready(&snapshot) {
807        Some(snapshot)
808    } else {
809        None
810    }
811}
812
813fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
814    let repo_dir = huggingface_repo_dir(repo_id);
815    if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
816        return Some(snapshot);
817    }
818
819    let snapshots = repo_dir.join("snapshots");
820    let mut candidates: Vec<(SystemTime, PathBuf)> = std::fs::read_dir(snapshots)
821        .ok()?
822        .filter_map(Result::ok)
823        .map(|e| e.path())
824        .filter(|p| p.is_dir() && snapshot_looks_ready(p))
825        .map(|path| {
826            let modified = path
827                .metadata()
828                .and_then(|metadata| metadata.modified())
829                .unwrap_or(SystemTime::UNIX_EPOCH);
830            (modified, path)
831        })
832        .collect();
833    candidates.sort();
834    candidates.pop().map(|(_, path)| path)
835}
836
837fn snapshot_looks_ready(path: &Path) -> bool {
838    if path.join("config.json").exists() || path.join("model_index.json").exists() {
839        return true;
840    }
841    snapshot_contains_ext(path, "safetensors")
842}
843
/// Recursively search `root` for a file whose extension equals `ext`.
///
/// The comparison ignores ASCII case. Unreadable directories and entries
/// count as "not found" rather than as errors.
fn snapshot_contains_ext(root: &Path, ext: &str) -> bool {
    let entries = match std::fs::read_dir(root) {
        Ok(entries) => entries,
        Err(_) => return false,
    };

    for entry in entries.filter_map(Result::ok) {
        let candidate = entry.path();
        let found = if candidate.is_dir() {
            snapshot_contains_ext(&candidate, ext)
        } else {
            candidate
                .extension()
                .and_then(|os| os.to_str())
                .is_some_and(|found_ext| found_ext.eq_ignore_ascii_case(ext))
        };
        if found {
            return true;
        }
    }
    false
}
860
861async fn download_hf_repo_snapshot(repo_id: &str) -> Result<(PathBuf, usize), InferenceError> {
862    let api = hf_hub::api::tokio::ApiBuilder::from_env()
863        .with_progress(false)
864        .build()
865        .map_err(|e| InferenceError::DownloadFailed(format!("init hf api: {e}")))?;
866    let repo = api.model(repo_id.to_string());
867    let info = repo
868        .info()
869        .await
870        .map_err(|e| InferenceError::DownloadFailed(format!("{repo_id}: {e}")))?;
871
872    let snapshot_path = std::env::var("HF_HOME")
873        .map(PathBuf::from)
874        .unwrap_or_else(|_| {
875            std::env::var("HOME")
876                .map(PathBuf::from)
877                .unwrap_or_else(|_| PathBuf::from("."))
878                .join(".cache")
879                .join("huggingface")
880        })
881        .join("hub")
882        .join(format!("models--{}", repo_id.replace('/', "--")))
883        .join("snapshots")
884        .join(&info.sha);
885    let mut downloaded = 0usize;
886    for sibling in &info.siblings {
887        let local_path = snapshot_path.join(&sibling.rfilename);
888        if local_path.exists() {
889            downloaded += 1;
890            continue;
891        }
892        repo.download(&sibling.rfilename).await.map_err(|e| {
893            InferenceError::DownloadFailed(format!("{repo_id}/{}: {e}", sibling.rfilename))
894        })?;
895        downloaded += 1;
896    }
897
898    Ok((snapshot_path, downloaded))
899}
900
901/// Built-in catalog parsed from `builtin_catalog.json`.
902///
903/// Adding, removing, or editing a model is a JSON-only change — Rust
904/// source stays put. The JSON is embedded at compile time via
905/// `include_str!`, parsed once into a `LazyLock`, and cloned on each
906/// call. A malformed JSON file fails the integration test
907/// `builtin_catalog_json_parses` so the binary never ships unable
908/// to load its own catalog.
909const BUILTIN_CATALOG_JSON: &str = include_str!("builtin_catalog.json");
910
911static BUILTIN_CATALOG: std::sync::LazyLock<Vec<ModelSchema>> =
912    std::sync::LazyLock::new(|| {
913        serde_json::from_str(BUILTIN_CATALOG_JSON)
914            .expect("builtin_catalog.json failed to parse — fix the JSON, not this code")
915    });
916
917fn builtin_catalog() -> Vec<ModelSchema> {
918    BUILTIN_CATALOG.clone()
919}
920
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a registry whose models directory lives inside a fresh
    /// temporary directory. The `TempDir` handle is returned as well so the
    /// caller decides how long it is kept.
    fn test_registry() -> (UnifiedRegistry, TempDir) {
        let scratch = TempDir::new().unwrap();
        let models_dir = scratch.path().join("models");
        (UnifiedRegistry::new(models_dir), scratch)
    }

    /// A fresh registry lists exactly as many models as the built-in catalog.
    #[test]
    fn builtin_catalog_loads() {
        let (reg, _tmp) = test_registry();
        let listed = reg.list();
        assert_eq!(listed.len(), builtin_catalog().len());
    }

    /// #137: every model tagged `requires-mlx-vlm` must report
    /// `available: true` exactly when `mlx_vlm_cli::is_available()` does.
    /// Otherwise registry consumers (FFI `listModelsUnified`, the tray
    /// Models submenu, agent routing) show the model as available, the
    /// user picks it, and inference bails with
    /// `mlx-vlm CLI not found on PATH`.
    ///
    /// The probe is environmental: it runs the same check on whatever host
    /// executes the test. CI usually has no `mlx_vlm` installed, so the
    /// models are expected to be unavailable there; a dev box with it
    /// installed expects them available. In both cases the registry must
    /// track the runtime probe.
    #[test]
    fn mlx_vlm_models_reflect_runtime_availability() {
        let (reg, _tmp) = test_registry();

        let mut tagged: Vec<&ModelSchema> = Vec::new();
        for model in reg.list() {
            if model.tags.iter().any(|tag| tag == "requires-mlx-vlm") {
                tagged.push(model);
            }
        }
        assert!(
            !tagged.is_empty(),
            "catalog should contain at least one model tagged \
             `requires-mlx-vlm` — otherwise this regression has \
             nothing to guard"
        );

        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        let expected = crate::backend::mlx_vlm_cli::is_available();
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        let expected = false;

        for model in tagged {
            assert_eq!(
                model.available, expected,
                "model {} `available` field should reflect \
                 mlx_vlm CLI presence (expected {expected}, got {})",
                model.id, model.available
            );
        }
    }

    /// The embedded JSON has to parse cleanly — otherwise the runtime would
    /// panic on the first registry load. Catch that in CI instead, and
    /// reject duplicate ids while at it.
    #[test]
    fn builtin_catalog_json_parses() {
        let parsed: Vec<ModelSchema> = serde_json::from_str(BUILTIN_CATALOG_JSON)
            .expect("builtin_catalog.json must be valid ModelSchema array");
        assert!(
            !parsed.is_empty(),
            "embedded catalog has no entries — that's almost certainly wrong"
        );

        let mut ids = std::collections::HashSet::new();
        for model in &parsed {
            let first_occurrence = ids.insert(model.id.as_str());
            assert!(
                first_occurrence,
                "duplicate id in builtin_catalog.json: {}",
                model.id
            );
        }
    }

    #[test]
    fn public_benchmarks_round_trip_through_model_info() {
        use crate::schema::BenchmarkScore;

        let (mut reg, _tmp) = test_registry();

        // Start from a catalog entry and give the copy its own id and scores.
        let mut candidate = reg
            .find_by_name("Qwen3-4B")
            .expect("catalog has Qwen3-4B")
            .clone();
        candidate.id = "test/qwen3-4b-with-bench".into();

        let mmlu_pro = BenchmarkScore {
            name: "MMLU-Pro".into(),
            score: 0.482,
            harness: Some("5-shot CoT".into()),
            source_url: Some("https://example.invalid/qwen3-4b-card".into()),
            measured_at: Some("2025-08-12".into()),
        };
        let human_eval = BenchmarkScore {
            name: "HumanEval".into(),
            score: 0.713,
            harness: Some("pass@1".into()),
            source_url: None,
            measured_at: None,
        };
        candidate.public_benchmarks = vec![mmlu_pro, human_eval];
        reg.register(candidate);

        let stored = reg
            .get("test/qwen3-4b-with-bench")
            .expect("registered model is retrievable");
        let summary = ModelInfo::from(stored);
        assert_eq!(summary.public_benchmarks.len(), 2);

        // WS / FFI clients consume the serialized JSON shape.
        let encoded = serde_json::to_string(&summary).unwrap();
        assert!(encoded.contains("\"public_benchmarks\""));
        assert!(encoded.contains("\"MMLU-Pro\""));
        assert!(encoded.contains("\"5-shot CoT\""));

        // Deserializing the same JSON must give the scores back in order.
        let round_tripped: ModelInfo = serde_json::from_str(&encoded).unwrap();
        assert_eq!(round_tripped.public_benchmarks.len(), 2);
        assert_eq!(round_tripped.public_benchmarks[0].name, "MMLU-Pro");
        assert_eq!(round_tripped.public_benchmarks[1].name, "HumanEval");
    }

    #[test]
    fn public_benchmarks_default_to_empty_when_absent_in_json() {
        // User-config JSON written before `public_benchmarks` existed has to
        // keep deserializing into the current ModelSchema shape.
        let legacy_json = r#"{
            "id": "legacy/test:1",
            "name": "Legacy Test",
            "provider": "test",
            "family": "test",
            "version": "",
            "capabilities": ["generate"],
            "context_length": 4096,
            "param_count": "1B",
            "quantization": null,
            "performance": {},
            "cost": {},
            "source": { "type": "ollama", "model_tag": "legacy:1" },
            "tags": [],
            "supported_params": []
        }"#;
        let parsed: ModelSchema = serde_json::from_str(legacy_json).unwrap();
        assert!(parsed.public_benchmarks.is_empty());
    }

    #[test]
    fn find_by_name() {
        let (reg, _tmp) = test_registry();
        let model = reg.find_by_name("Qwen3-4B").unwrap();

        // The expected id differs between Apple Silicon and everything else.
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        let expected_id = "mlx/qwen3-4b:4bit";
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        let expected_id = "qwen/qwen3-4b:q4_k_m";

        assert_eq!(model.id, expected_id);
        assert!(model.has_capability(ModelCapability::Code));
    }

    #[test]
    fn query_by_capability() {
        let (reg, _tmp) = test_registry();
        let embedders = reg.query_by_capability(ModelCapability::Embed);
        assert_eq!(embedders.len(), 2);
        for wanted in ["Qwen3-Embedding-0.6B", "Qwen3-Embedding-0.6B-MLX"] {
            assert!(embedders.iter().any(|model| model.name == wanted));
        }
    }

    #[test]
    fn query_with_filter() {
        let (reg, _tmp) = test_registry();
        let small_local_coders = ModelFilter {
            capabilities: vec![ModelCapability::Code],
            max_size_mb: Some(3000),
            local_only: true,
            ..Default::default()
        };
        let matches = reg.query(&small_local_coders);
        // Expected matches: Qwen3-1.7B, Qwen3-1.7B-MLX, Qwen3-4B, and
        // Qwen3-4B-MLX — they fit the size cap and have the Code capability.
        assert_eq!(matches.len(), 4);
    }

    #[test]
    fn register_remote() {
        let (mut reg, _tmp) = test_registry();

        let reasoning_with_tools = ModelFilter {
            capabilities: vec![ModelCapability::Reasoning, ModelCapability::ToolUse],
            ..Default::default()
        };
        let total_before = reg.list().len();
        let reasoning_before = reg.query(&reasoning_with_tools).len();

        let source = ModelSource::RemoteApi {
            endpoint: "https://api.anthropic.com/v1/messages".into(),
            api_key_env: "ANTHROPIC_API_KEY".into(),
            api_key_envs: vec![],
            api_version: Some("2023-06-01".into()),
            protocol: ApiProtocol::Anthropic,
        };
        let performance = PerformanceEnvelope {
            latency_p50_ms: Some(2000),
            ..Default::default()
        };
        let cost = CostModel {
            input_per_mtok: Some(3.0),
            output_per_mtok: Some(15.0),
            ..Default::default()
        };
        let remote = ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance,
            cost,
            source,
            tags: vec![],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        };

        reg.register(remote);
        // The id matches the builtin claude-sonnet-4-6 entry, so this is a
        // replacement and the total count does not move.
        assert_eq!(reg.list().len(), total_before);

        // Swapping out an existing remote slot must leave the size of the
        // reasoning/tool-use lineup unchanged as well.
        let reasoning_after = reg.query(&reasoning_with_tools);
        assert_eq!(reasoning_after.len(), reasoning_before);
    }

    #[test]
    fn unregister() {
        let (mut reg, _tmp) = test_registry();
        let count_before = reg.list().len();
        let dropped = reg.unregister("qwen/qwen3-0.6b:q8_0");
        assert!(dropped.is_some());
        assert_eq!(reg.list().len(), count_before - 1);
    }

    #[test]
    fn speech_models_are_curated() {
        let (reg, _tmp) = test_registry();
        let speech_to_text = reg.query_by_capability(ModelCapability::SpeechToText);
        let text_to_speech = reg.query_by_capability(ModelCapability::TextToSpeech);
        assert_eq!(speech_to_text.len(), 2);
        assert_eq!(text_to_speech.len(), 4);
    }

    #[test]
    fn qwen_8b_variants_keep_tool_use_consistent() {
        let (reg, _tmp) = test_registry();
        for variant in ["Qwen3-8B", "Qwen3-8B-MLX"] {
            let entry = reg.find_by_name(variant).expect("model should exist");
            assert!(entry.has_capability(ModelCapability::ToolUse));
            assert!(entry.has_capability(ModelCapability::MultiToolCall));
        }
    }

    #[test]
    fn mac_name_resolution_prefers_mlx_siblings() {
        // `reg` is only read inside the aarch64-macos block below; on other
        // targets the test is kept as a smoke compile.
        #[allow(unused_variables)]
        let (reg, _tmp) = test_registry();
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            let resolved_id = |name: &str| reg.find_by_name(name).unwrap().id.clone();
            assert_eq!(resolved_id("Qwen3-0.6B"), "mlx/qwen3-0.6b:6bit");
            assert_eq!(resolved_id("Qwen3-1.7B"), "mlx/qwen3-1.7b:3bit");
            assert_eq!(
                resolved_id("Qwen3-Embedding-0.6B"),
                "mlx/qwen3-embedding-0.6b:mxfp8"
            );
        }
    }

    #[test]
    fn remote_multimodal_models_are_curated_as_vision_capable() {
        let (reg, _tmp) = test_registry();
        let vision_capable_remotes = [
            "claude-opus-4-7",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gpt-5.4",
            "gpt-5.4-mini",
            "o3",
            "o4-mini",
            "gpt-4.1-mini",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
        ];
        for name in vision_capable_remotes {
            let entry = reg.find_by_name(name).expect("model should exist");
            assert!(
                entry.has_capability(ModelCapability::Vision),
                "{name} should be curated as vision-capable"
            );
        }
    }

    #[test]
    fn qwen25vl_entries_are_replaced_by_qwen3vl_in_builtin_catalog() {
        let (reg, _tmp) = test_registry();

        let stale_ids = [
            // The native MLX text tower cannot tokenize images, so these
            // must never be advertised.
            "mlx/qwen2.5-vl-3b:4bit",
            "mlx/qwen2.5-vl-7b:4bit",
            // Qwen3-VL supersedes Qwen2.5-VL; the mlx-vlm CLI catalog
            // entries are dropped so callers route to the upgraded family.
            "mlx-vlm/qwen2.5-vl-3b:4bit",
            "mlx-vlm/qwen2.5-vl-7b:4bit",
            // The vLLM-MLX route is superseded in the same way.
            "vllm-mlx/qwen2.5-vl-3b:4bit",
        ];

        // Neither a direct lookup ...
        for id in stale_ids {
            assert!(
                reg.get(id).is_none(),
                "{id} is superseded by Qwen3-VL; the catalog must not advertise it"
            );
        }

        // ... nor the Vision capability query may surface a stale id.
        let mut vision_ids: Vec<&str> = Vec::new();
        for model in reg.query_by_capability(ModelCapability::Vision) {
            vision_ids.push(model.id.as_str());
        }
        for stale in stale_ids {
            assert!(
                !vision_ids.contains(&stale),
                "{stale} must not be reachable through the Vision capability index"
            );
        }
        assert!(
            vision_ids.contains(&"mlx-vlm/qwen3-vl-2b:bf16"),
            "Qwen3-VL is the supported local VL family and must route as Vision"
        );
    }

    #[test]
    fn gemini_models_are_curated_for_multimodal_tool_use() {
        let (reg, _tmp) = test_registry();
        for gemini in ["gemini-2.5-pro", "gemini-2.5-flash"] {
            let entry = reg.find_by_name(gemini).expect("model should exist");
            assert!(entry.has_capability(ModelCapability::Vision));
            assert!(entry.has_capability(ModelCapability::ToolUse));
            assert!(entry.has_capability(ModelCapability::MultiToolCall));
        }
    }

    #[test]
    fn visual_generation_models_are_curated() {
        let (reg, _tmp) = test_registry();

        let image_generators = reg.query_by_capability(ModelCapability::ImageGeneration);
        assert_eq!(image_generators.len(), 1);
        let video_generators = reg.query_by_capability(ModelCapability::VideoGeneration);
        assert_eq!(video_generators.len(), 2);

        let yume = reg
            .get("mlx/yume-1.5-5b-720p:q4")
            .expect("Yume MLX should be in the built-in catalog");
        assert!(yume.has_capability(ModelCapability::VideoGeneration));
        for required_tag in ["text-to-video", "image-to-video", "world-model"] {
            assert!(yume.tags.iter().any(|tag| tag == required_tag));
        }
    }
}