Skip to main content

rlx_cli/
auto_dispatch.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Auto-dispatch: pick a registered model runner from a weights path.
17//!
18//! `auto_runner_name(path)` resolves the path (file or directory), sniffs
19//! the model family (GGUF `general.architecture` for `.gguf`, sidecar
20//! `config.json` `model_type` for safetensors), and maps it to the short
21//! runner name a callsite registered with [`register_cli`](crate::register_cli)
22//! (e.g. `"qwen3"`, `"gemma"`).
23//!
24//! `auto_dispatch(path, args)` is a one-shot: sniff, look up, run.
25//!
26//! Used by `skill` so callers don't need to hardcode `Qwen3Runner` vs
27//! `GemmaRunner` per family.
28
29use anyhow::{Context, Result, anyhow, bail};
30use rlx_core::gguf_config::{
31    DINOV2_GGUF_ARCHES, FLUX_GGUF_ARCHES, SAM_GGUF_ARCHES, SAM2_GGUF_ARCHES, SAM3_GGUF_ARCHES,
32    VJEPA2_GGUF_ARCHES, W2V_BERT_GGUF_ARCHES,
33};
34use rlx_core::gguf_support::{
35    gguf_architecture_from_path, gguf_family_for_arch, resolve_weights_file,
36};
37use std::path::{Path, PathBuf};
38
39use crate::registry::run_registered;
40
41/// Entry point for an `rlx-run auto WEIGHTS [args...]` subcommand.
42///
43/// Treats the first positional as the weights path (file or directory),
44/// sniffs the runner, and forwards the remaining args to it. The
45/// canonical wiring is `register_cli("auto", "...", rlx_cli::run_auto)`
46/// in the multiplexer.
47pub fn run_auto(args: &[String]) -> Result<()> {
48    let Some(first) = args.first() else {
49        bail!(
50            "auto: expected WEIGHTS path as the first argument\n\
51             usage: rlx-run auto <weights-path> [runner-args...]"
52        );
53    };
54    if matches!(first.as_str(), "-h" | "--help" | "help") {
55        println!(
56            "rlx-run auto — sniff a GGUF / safetensors file and dispatch to the right runner\n\
57             \n\
58             USAGE:\n  rlx-run auto <weights-path> [runner-args...]\n\
59             \n\
60             The first argument is forwarded as the runner's --weights value;\n\
61             remaining arguments are passed through unchanged."
62        );
63        return Ok(());
64    }
65    let path = Path::new(first);
66    let sniff = auto_sniff(path)?;
67    eprintln!(
68        "[rlx-run auto] {} → runner `{}` (from {:?})",
69        sniff.path.display(),
70        sniff.runner_name,
71        sniff.from
72    );
73    // Re-build argv: most per-family runners take `--weights PATH`. If the
74    // caller already passed --weights, don't double it; otherwise inject.
75    let rest: Vec<String> = args[1..].to_vec();
76    let has_weights_flag = rest
77        .iter()
78        .any(|a| a == "--weights" || a.starts_with("--weights="));
79    let mut forwarded: Vec<String> = Vec::with_capacity(rest.len() + 2);
80    if !has_weights_flag {
81        forwarded.push("--weights".into());
82        forwarded.push(sniff.path.display().to_string());
83    }
84    forwarded.extend(rest);
85    match run_registered(sniff.runner_name, &forwarded)? {
86        Some(()) => Ok(()),
87        None => bail!(
88            "auto: runner `{}` not registered (sniffed from {:?}); register it via \
89             `register_cli` in your binary's main",
90            sniff.runner_name,
91            sniff.from
92        ),
93    }
94}
95
/// Source the sniffer used to identify the model family.
///
/// The payload is the raw tag that was looked up, kept for diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SniffedFrom {
    /// `general.architecture` value read from a `.gguf` file.
    GgufArch(String),
    /// `model_type` value read from a sidecar `config.json`.
    SafetensorsConfig(String),
}

/// Result of sniffing a weights path.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole sniff
/// results instead of matching field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SniffedRunner {
    /// Concrete file we sniffed (after resolving a directory).
    pub path: PathBuf,
    /// Short runner name as registered with `register_cli`.
    pub runner_name: &'static str,
    /// Where the sniff came from — useful for diagnostics.
    pub from: SniffedFrom,
}
115
/// Metadata for a model family RLX can identify but cannot run yet.
///
/// Produced by [`known_unimplemented_arch`]; error messages embed the family
/// name, the PLAN.md milestone that unblocks it, and a short explanation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct UnimplementedArch {
    /// Human-readable family name, e.g. `"Mistral 3.5"`.
    pub family: &'static str,
    /// Milestone tag in PLAN.md, e.g. `"M4"`.
    pub milestone: &'static str,
    /// Short explanation shown to the user.
    pub note: &'static str,
}
128
/// Family-level metadata referenced by [`KNOWN_UNIMPLEMENTED`]. Static so
/// the phf map can hold `&'static UnimplementedArch`.
///
/// One entry can back several map keys (e.g. both `mistral3` and `mistral4`
/// point at `MISTRAL`), so each family's text lives in exactly one place.
mod families {
    use super::UnimplementedArch;
    pub static MISTRAL: UnimplementedArch = UnimplementedArch {
        family: "Mistral 3+ / Ministral",
        milestone: "M4",
        note: "Llama-shaped with newer RoPE; share `rlx-llama-base` per PLAN.md M4",
    };
    pub static PHI: UnimplementedArch = UnimplementedArch {
        family: "Phi 3 / Phi 4",
        milestone: "M4",
        note: "Phi3/4 share llama.cpp arch tag — PLAN.md M4",
    };
    pub static PHIMOE: UnimplementedArch = UnimplementedArch {
        family: "Phi MoE",
        milestone: "M4 + M5",
        note: "Phi + MoE routing; depends on shared MoE block — PLAN.md M4/M5",
    };
    pub static BONSAI: UnimplementedArch = UnimplementedArch {
        family: "Bonsai",
        milestone: "M4",
        note: "Llama-shaped; HF model_type only — usually ships as llama GGUF — PLAN.md M4",
    };
    pub static OMNICODER: UnimplementedArch = UnimplementedArch {
        family: "OmniCoder",
        milestone: "M4",
        note: "Qwen3-coder shaped — PLAN.md M4 (often tagged `qwen3` in GGUF)",
    };
    // NOTE(review): MiniMax-M2 is reported to use full attention (Lightning
    // Attention was the MiniMax-01 / M1 design) — confirm this note's wording.
    pub static MINIMAX: UnimplementedArch = UnimplementedArch {
        family: "MiniMax M2",
        milestone: "M5",
        note: "Lightning Attention; depends on `rlx-ssm` upstream — PLAN.md M5",
    };
    pub static GLM: UnimplementedArch = UnimplementedArch {
        family: "GLM 4 / 5",
        milestone: "M5",
        note: "GLM RoPE + RMSNorm placement — PLAN.md M5",
    };
    pub static GLM_MOE: UnimplementedArch = UnimplementedArch {
        family: "GLM 4 MoE",
        milestone: "M5",
        note: "GLM + MoE routing — PLAN.md M5",
    };
    pub static GPT_OSS: UnimplementedArch = UnimplementedArch {
        family: "gpt-oss",
        milestone: "M5",
        note: "OpenAI gpt-oss — confirm arch shape — PLAN.md M5",
    };
    pub static NEMOTRON: UnimplementedArch = UnimplementedArch {
        family: "Nemotron",
        milestone: "M5",
        note: "Dense Nemotron arch — PLAN.md M5",
    };
    pub static NEMOTRON_H: UnimplementedArch = UnimplementedArch {
        family: "Nemotron-H",
        milestone: "M5",
        note: "Mamba+attention hybrid; depends on `rlx-ssm` upstream — PLAN.md M5/M7",
    };
    // No `KNOWN_UNIMPLEMENTED` key points here any more: the plain LFM 2 / 2.5
    // tags are routed to `rlx-lfm`'s runner through `gguf_family_for_arch`,
    // so this static is unreferenced — hence the `dead_code` allow.
    // TODO(review): delete it if nothing is going to reference it again.
    #[allow(dead_code)]
    pub static LFM: UnimplementedArch = UnimplementedArch {
        family: "LFM 2 / 2.5",
        milestone: "M5",
        note: "Liquid Foundation Models with custom SSM layers — PLAN.md M5",
    };
    pub static LFM_MOE: UnimplementedArch = UnimplementedArch {
        family: "LFM 2 MoE",
        milestone: "M5",
        note: "LFM + MoE — PLAN.md M5",
    };
    pub static QWEN3_MOE: UnimplementedArch = UnimplementedArch {
        family: "Qwen3 MoE",
        milestone: "M5",
        note: "Qwen3 + MoE routing block — PLAN.md M5 (often loadable via qwen3 runner once MoE lands)",
    };
    pub static QWEN3_NEXT: UnimplementedArch = UnimplementedArch {
        family: "Qwen3-Next",
        milestone: "M5",
        note: "Qwen3-Next variant — confirm arch deltas vs qwen3 — PLAN.md M5",
    };
    // The trailing `\` continues the string literal; the next line's leading
    // whitespace is not part of the note.
    pub static GEMMA3: UnimplementedArch = UnimplementedArch {
        family: "Gemma 3",
        milestone: "M2",
        note: "Gemma 3 (270m / 4b / 12b / 27b) adds per-layer sliding window + new RoPE — \
               needs rlx-gemma config branch — PLAN.md M2",
    };
    pub static GEMMA3N: UnimplementedArch = UnimplementedArch {
        family: "Gemma 3n",
        milestone: "M2",
        note: "Gemma 3n (mobile/edge Matformer variant) — PLAN.md M2",
    };
    pub static GEMMA4: UnimplementedArch = UnimplementedArch {
        family: "Gemma 4",
        milestone: "M2",
        note: "Gemma 4 (flagship + edge E2B/E4B + MoE A4B) — PLAN.md M2 flagship",
    };
    pub static QWEN3_VL: UnimplementedArch = UnimplementedArch {
        family: "Qwen3-VL",
        milestone: "M7",
        note: "vision tower + projector + LM (dense or MoE) — PLAN.md M7",
    };
    pub static QWEN3_MTP: UnimplementedArch = UnimplementedArch {
        family: "Qwen3 / Qwen3.6 + MTP",
        milestone: "M6",
        note: "multi-token-prediction draft heads — PLAN.md M6",
    };
    pub static LLADA: UnimplementedArch = UnimplementedArch {
        family: "LLaDA / LLaDA MoE (text-only)",
        milestone: "M5",
        note: "dense LLaDA arch in llama.cpp; rlx-llada2 currently targets the diffusion runner — PLAN.md M5",
    };
    pub static GRANITE: UnimplementedArch = UnimplementedArch {
        family: "Granite (IBM)",
        milestone: "M4",
        note: "Llama-shaped — PLAN.md M4",
    };
    pub static DEEPSEEK: UnimplementedArch = UnimplementedArch {
        family: "DeepSeek 2",
        milestone: "M5",
        note: "MoE + MLA attention — needs MoE block + MLA primitive — PLAN.md M5",
    };
    pub static COHERE: UnimplementedArch = UnimplementedArch {
        family: "Command-R / Cohere",
        milestone: "M4",
        note: "Llama-shaped — PLAN.md M4",
    };
}
256
/// Catalog families we know about but haven't implemented yet.
///
/// The keys are the **actual** GGUF `general.architecture` strings llama.cpp
/// uses (`src/llama-arch.cpp::LLM_ARCH_NAMES`) plus their HF `model_type`
/// aliases when those differ. Notably:
///
/// * Mistral 1/2 and Qwen 2.5 ship as `general.architecture = llama` /
///   `qwen2` respectively — they don't have their own llama.cpp arch tag.
///   Those tags route to the existing `llama32` / `qwen3` runners and are
///   *not* listed here.
/// * Mistral 3+ ships as `mistral3` / `mistral4` (real tags).
/// * Phi-4 ships as `phi3` (Phi-4 reuses the Phi-3 arch in llama.cpp).
///
/// Both GGUF arch tags and HF `model_type` values are accepted so
/// downstream callers don't keep two parallel lists.
///
/// Lookups are exact, case-sensitive key matches. [`auto_sniff`] only
/// consults this table after [`arch_runner_name`] / [`model_type_runner_name`]
/// returned `None`, so a key that also maps to a runner there is shadowed
/// and never produces the milestone hint.
///
/// NOTE(review): `gemma3`, `gemma3n` and `qwen3moe` are listed here, but
/// `model_type_runner_name` maps those same strings to the `gemma` / `qwen3`
/// runners, so safetensors checkpoints with those `model_type`s never reach
/// this table — confirm that asymmetry is intended.
static KNOWN_UNIMPLEMENTED: phf::Map<&'static str, &'static UnimplementedArch> = phf::phf_map! {
    // Mistral / Ministral (real llama.cpp tags)
    "mistral3" => &families::MISTRAL,
    "mistral4" => &families::MISTRAL,
    // Phi family — Llama32Family accepts the arch tag, but the GGUF
    // tensor-name remap for `phi3`/`phi4` (e.g. `blk.*.attn_q.weight`
    // → `model.layers.*.self_attn.q_proj.weight`) is M4 follow-up.
    "phi3" => &families::PHI,
    "phi4" => &families::PHI,
    "phimoe" => &families::PHIMOE,
    // Catalog HF model_type aliases — same remap gap as phi3.
    "bonsai" => &families::BONSAI,
    "omnicoder" => &families::OMNICODER,
    // Hybrid / SSM families
    "minimax-m2" => &families::MINIMAX,
    "minimax_m2" => &families::MINIMAX,
    "minimax" => &families::MINIMAX,
    "glm4" => &families::GLM,
    "glm5" => &families::GLM,
    "chatglm" => &families::GLM,
    "glm4moe" => &families::GLM_MOE,
    "gpt-oss" => &families::GPT_OSS,
    "gpt_oss" => &families::GPT_OSS,
    "nemotron" => &families::NEMOTRON,
    "nemotron_h" => &families::NEMOTRON_H,
    "nemotron_h_moe" => &families::NEMOTRON_H,
    // lfm2 / lfm / lfm25 / lfm2_5 are now routed through `rlx-lfm`'s
    // `LfmRunner` via `gguf_family_for_arch` → `GgufModelFamily::Lfm`.
    // Only the MoE variant remains unimplemented.
    "lfm2moe" => &families::LFM_MOE,
    // Qwen variants we don't run yet
    "qwen3moe" => &families::QWEN3_MOE,
    "qwen3next" => &families::QWEN3_NEXT,
    // Gemma 3+ — rlx-gemma currently targets gemma/gemma2 only.
    "gemma3" => &families::GEMMA3,
    "gemma3n" => &families::GEMMA3N,
    "gemma4" => &families::GEMMA4,
    "gemma4moe" => &families::GEMMA4,
    // Vision-language and multi-token-prediction variants (several spellings).
    "qwen3vl" => &families::QWEN3_VL,
    "qwen3vlmoe" => &families::QWEN3_VL,
    "qwen3_vl" => &families::QWEN3_VL,
    "qwen3-vl" => &families::QWEN3_VL,
    "qwen3_mtp" => &families::QWEN3_MTP,
    "qwen3-mtp" => &families::QWEN3_MTP,
    "qwen36_mtp" => &families::QWEN3_MTP,
    // Other catalog-adjacent families
    "llada" => &families::LLADA,
    "llada-moe" => &families::LLADA,
    "granite" => &families::GRANITE,
    "granitemoe" => &families::GRANITE,
    "granitehybrid" => &families::GRANITE,
    "deepseek2" => &families::DEEPSEEK,
    "deepseek2-ocr" => &families::DEEPSEEK,
    "command-r" => &families::COHERE,
    "cohere2" => &families::COHERE,
};
328
329/// Look up an arch / model_type in the unimplemented-families table.
330pub fn known_unimplemented_arch(arch_or_model_type: &str) -> Option<UnimplementedArch> {
331    KNOWN_UNIMPLEMENTED.get(arch_or_model_type).map(|p| **p)
332}
333
334/// Snapshot of every (key, family) pair currently in the unimplemented
335/// table — useful for `rlx-run check --list-unimplemented` style tooling.
336pub fn known_unimplemented_keys() -> impl Iterator<Item = (&'static str, &'static UnimplementedArch)>
337{
338    KNOWN_UNIMPLEMENTED.entries().map(|(k, v)| (*k, *v))
339}
340
341/// Map a GGUF `general.architecture` tag to the short runner name.
342///
343/// Returns `None` for embed-only families (`bert`, `nomic-bert`, …) which
344/// aren't currently exposed through the `rlx-run` dispatch table, and for
345/// catalog families that aren't implemented yet — those get a richer error
346/// via [`known_unimplemented_arch`] when sniffed.
347pub fn arch_runner_name(arch: &str) -> Option<&'static str> {
348    if let Some(fam) = gguf_family_for_arch(arch) {
349        return Some(fam.runner_name());
350    }
351    if FLUX_GGUF_ARCHES.contains(&arch) {
352        return Some("flux2");
353    }
354    if DINOV2_GGUF_ARCHES.contains(&arch) {
355        return Some("dinov2");
356    }
357    if VJEPA2_GGUF_ARCHES.contains(&arch) {
358        return Some("vjepa2");
359    }
360    if SAM3_GGUF_ARCHES.contains(&arch) {
361        return Some("sam3");
362    }
363    if SAM2_GGUF_ARCHES.contains(&arch) {
364        return Some("sam2");
365    }
366    if SAM_GGUF_ARCHES.contains(&arch) {
367        return Some("sam1");
368    }
369    if W2V_BERT_GGUF_ARCHES.contains(&arch) {
370        return Some("wav2vec2-bert");
371    }
372    None
373}
374
/// Map an HF `config.json` `model_type` value to a short runner name.
///
/// HF naming differs from GGUF tags — `model_type: "llama"` covers Llama
/// 2 / 3 / 3.x, `qwen3` covers Qwen3 and Qwen3 MoE, etc. Matching is exact
/// and case-sensitive; any other value returns `None`.
pub fn model_type_runner_name(model_type: &str) -> Option<&'static str> {
    // (runner, accepted `model_type` spellings). The alias sets are disjoint,
    // so the first hit is the only hit.
    //
    // `qwen2` is deliberately absent, and it is not in `KNOWN_UNIMPLEMENTED`
    // either, so a safetensors `model_type: "qwen2"` gets the generic
    // "no registered rlx runner" error. The historical rationale was that
    // rlx-qwen3 lacked the Qwen 2 layout (q/k/v bias, no QK-norm). GGUF `qwen2`
    // is now routed to the qwen3 runner through `arch_runner_name`.
    // NOTE(review): confirm whether safetensors `qwen2` can follow. Upstream
    // Qwen2.5 configs appear to use `model_type: "qwen2"`, which the
    // `qwen25`-style aliases below would not catch — verify.
    //
    // NOTE(review): `gemma3` / `gemma3n` / `qwen3moe` map to runners here while
    // `KNOWN_UNIMPLEMENTED` lists the same strings as not yet implemented —
    // confirm the safetensors runners really handle them.
    const RUNNER_ALIASES: &[(&str, &[&str])] = &[
        (
            "qwen3",
            &[
                "qwen3", "qwen3_moe", "qwen3moe", "qwen25", "qwen2_5", "qwen2.5", "qwen251",
                "qwen2_5_1",
            ],
        ),
        // Qwen3.5 and Qwen3.6 both run through the qwen35 trunk (PLAN.md M1).
        (
            "qwen35",
            &[
                "qwen35", "qwen3_5", "qwen35_moe", "qwen35moe", "qwen36", "qwen3_6", "qwen36_moe",
                "qwen36moe",
            ],
        ),
        ("llama32", &["llama", "llama2", "llama3"]),
        ("gemma", &["gemma", "gemma2", "gemma3", "gemma3n"]),
        ("dinov2", &["dinov2", "dinov2_with_registers"]),
        ("vjepa2", &["vjepa2", "vjepa"]),
        ("sam1", &["sam", "sam_vit", "mobile-sam", "mobile_sam"]),
        ("sam2", &["sam2"]),
        ("sam3", &["sam3"]),
        ("whisper", &["whisper"]),
        (
            "wav2vec2-bert",
            &["wav2vec2-bert", "wav2vec2_bert", "w2v-bert", "w2v_bert"],
        ),
        ("flux2", &["flux", "flux2"]),
    ];

    RUNNER_ALIASES
        .iter()
        .find(|(_, aliases)| aliases.contains(&model_type))
        .map(|&(runner, _)| runner)
}
402
403/// Sniff `model_type` from the `config.json` next to a safetensors file.
404fn read_model_type_from_sidecar(path: &Path) -> Result<Option<String>> {
405    let dir = path
406        .parent()
407        .ok_or_else(|| anyhow!("safetensors path {path:?} has no parent dir"))?;
408    let cfg = dir.join("config.json");
409    if !cfg.is_file() {
410        return Ok(None);
411    }
412    let bytes = std::fs::read(&cfg).with_context(|| format!("reading {cfg:?}"))?;
413    let v: serde_json::Value =
414        serde_json::from_slice(&bytes).with_context(|| format!("parsing {cfg:?}"))?;
415    Ok(v.get("model_type")
416        .and_then(serde_json::Value::as_str)
417        .map(str::to_owned))
418}
419
420/// Resolve `path` to a single weight file, then sniff the runner.
421pub fn auto_sniff(path: &Path) -> Result<SniffedRunner> {
422    let file = resolve_weights_file(path)?;
423    let ext = file.extension().and_then(|s| s.to_str()).unwrap_or("");
424    match ext {
425        "gguf" => {
426            let arch = gguf_architecture_from_path(&file)?;
427            let runner = arch_runner_name(&arch).ok_or_else(|| {
428                if let Some(u) = known_unimplemented_arch(&arch) {
429                    anyhow!(
430                        "{file:?}: GGUF architecture `{arch}` is {} ({}) — not yet implemented in rlx-models. {}",
431                        u.family, u.milestone, u.note
432                    )
433                } else {
434                    anyhow!(
435                        "{file:?}: GGUF architecture `{arch}` has no registered rlx runner; \
436                         see `rlx-run` for supported families"
437                    )
438                }
439            })?;
440            Ok(SniffedRunner {
441                path: file,
442                runner_name: runner,
443                from: SniffedFrom::GgufArch(arch),
444            })
445        }
446        "safetensors" => {
447            let model_type = read_model_type_from_sidecar(&file)?.ok_or_else(|| {
448                anyhow!("{file:?}: no `model_type` in sidecar config.json (auto-dispatch needs it)")
449            })?;
450            let runner = model_type_runner_name(&model_type).ok_or_else(|| {
451                if let Some(u) = known_unimplemented_arch(&model_type) {
452                    anyhow!(
453                        "{file:?}: safetensors model_type `{model_type}` is {} ({}) — not yet implemented in rlx-models. {}",
454                        u.family, u.milestone, u.note
455                    )
456                } else {
457                    anyhow!(
458                        "{file:?}: safetensors model_type `{model_type}` has no registered rlx runner"
459                    )
460                }
461            })?;
462            Ok(SniffedRunner {
463                path: file,
464                runner_name: runner,
465                from: SniffedFrom::SafetensorsConfig(model_type),
466            })
467        }
468        other => {
469            bail!("{file:?}: unsupported extension `.{other}` (expected .gguf or .safetensors)")
470        }
471    }
472}
473
474/// Sniff `path` and return only the runner short name.
475pub fn auto_runner_name(path: &Path) -> Result<&'static str> {
476    Ok(auto_sniff(path)?.runner_name)
477}
478
479/// Sniff `path`, look up its runner in the registry, and run it with `args`.
480///
481/// `args` should be the per-runner argv *without* the leading subcommand.
482/// Returns the runner name that was dispatched to.
483pub fn auto_dispatch(path: &Path, args: &[String]) -> Result<&'static str> {
484    let sniff = auto_sniff(path)?;
485    match run_registered(sniff.runner_name, args)? {
486        Some(()) => Ok(sniff.runner_name),
487        None => bail!(
488            "runner `{}` not registered (sniffed from {:?}); register it via \
489             `register_cli` before calling auto_dispatch",
490            sniff.runner_name,
491            sniff.from
492        ),
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Append a GGUF string: little-endian `u64` byte length, then the bytes.
    fn push_gguf_str(buf: &mut Vec<u8>, s: &str) {
        buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
    }

    /// Build a minimal GGUF v3 file in memory.
    ///
    /// It holds one `general.architecture = <arch>` string KV and one
    /// 4-element F32 tensor named `w`. The data section is padded to
    /// `rlx_gguf::DEFAULT_ALIGNMENT`.
    fn tiny_gguf_bytes(arch: &str) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes()); // version
        buf.extend_from_slice(&1u64.to_le_bytes()); // tensor count
        buf.extend_from_slice(&1u64.to_le_bytes()); // kv count
        push_gguf_str(&mut buf, "general.architecture");
        buf.extend_from_slice(&8u32.to_le_bytes()); // value type 8 = string
        push_gguf_str(&mut buf, arch);
        // Tensor info: name, n_dims = 1, dims = [4], type = F32, offset = 0.
        push_gguf_str(&mut buf, "w");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(rlx_gguf::GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for _ in 0..4 {
            buf.extend_from_slice(&1.0f32.to_le_bytes());
        }
        buf
    }

    /// Temp path unique to this test process, so concurrent `cargo test`
    /// invocations sharing one temp dir don't race on the same fixture.
    fn temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("rlx_auto_dispatch_{}_{name}", std::process::id()))
    }

    /// Deletes a fixture file or directory on drop — including when an
    /// assertion panics — so failed runs don't leave files behind.
    struct Cleanup(PathBuf);

    impl Drop for Cleanup {
        fn drop(&mut self) {
            if self.0.is_dir() {
                std::fs::remove_dir_all(&self.0).ok();
            } else {
                std::fs::remove_file(&self.0).ok();
            }
        }
    }

    #[test]
    fn arch_runner_maps_lm_families() {
        assert_eq!(arch_runner_name("qwen3"), Some("qwen3"));
        // qwen2 now routes to the qwen3 runner — the runner reads
        // attention_bias + qk_norm from the GGUF arch tag and emits
        // the right per-layer math.
        assert_eq!(arch_runner_name("qwen2"), Some("qwen3"));
        assert_eq!(arch_runner_name("qwen35"), Some("qwen35"));
        assert_eq!(arch_runner_name("qwen35moe"), Some("qwen35"));
        // Qwen3.6 reuses the qwen35 trunk (PLAN.md M1). qwen36_mtp still
        // routes through known_unimplemented_arch — base qwen36 routes
        // here so unsloth/Qwen3.6-27B-GGUF (no MTP) just works.
        assert_eq!(arch_runner_name("qwen36"), Some("qwen35"));
        assert_eq!(arch_runner_name("qwen36moe"), Some("qwen35"));
        // Qwen 2.5 / 2.5.1 ship as `qwen2` arch tag; explicit short
        // tags also route to the qwen3 runner (PLAN.md M4).
        assert_eq!(arch_runner_name("qwen25"), Some("qwen3"));
        assert_eq!(arch_runner_name("qwen2_5"), Some("qwen3"));
        assert_eq!(arch_runner_name("llama"), Some("llama32"));
        assert_eq!(arch_runner_name("gemma"), Some("gemma"));
        assert_eq!(arch_runner_name("gemma2"), Some("gemma"));
    }

    #[test]
    fn arch_runner_maps_vision_and_diffusion() {
        assert_eq!(arch_runner_name("dinov2"), Some("dinov2"));
        assert_eq!(arch_runner_name("sam"), Some("sam1"));
        assert_eq!(arch_runner_name("mobile-sam"), Some("sam1"));
        assert_eq!(arch_runner_name("sam2"), Some("sam2"));
        assert_eq!(arch_runner_name("sam3"), Some("sam3"));
        assert_eq!(arch_runner_name("flux"), Some("flux2"));
        assert_eq!(arch_runner_name("vjepa2"), Some("vjepa2"));
        assert_eq!(arch_runner_name("w2v-bert"), Some("wav2vec2-bert"));
    }

    #[test]
    fn arch_runner_returns_none_for_embed_and_unknown() {
        // Embed families aren't in the rlx-run dispatch table today.
        assert_eq!(arch_runner_name("bert"), None);
        assert_eq!(arch_runner_name("nomic-bert"), None);
        assert_eq!(arch_runner_name("totally-fake-arch"), None);
    }

    #[test]
    fn known_unimplemented_covers_plan_families() {
        let expected = [
            // M4 — Llama-shaped (real llama.cpp tags)
            ("mistral3", "M4"),
            ("phi3", "M4"),
            ("phi4", "M4"),
            ("bonsai", "M4"),
            // M5 — MoE / SSM
            ("minimax-m2", "M5"),
            ("glm4", "M5"),
            ("nemotron_h", "M5"),
            // M6 — MTP
            ("qwen3_mtp", "M6"),
            // M7 — VL
            ("qwen3vl", "M7"),
        ];
        for (tag, milestone) in expected {
            assert_eq!(
                known_unimplemented_arch(tag).map(|u| u.milestone),
                Some(milestone),
                "milestone for `{tag}`"
            );
        }
        // Implemented or unknown — plain `mistral` is NOT a llama.cpp arch
        // tag (Mistral 1/2 use `llama`), so it should not be flagged.
        assert_eq!(known_unimplemented_arch("qwen3"), None);
        assert_eq!(known_unimplemented_arch("mistral"), None);
        assert_eq!(known_unimplemented_arch("totally-fake"), None);
    }

    #[test]
    fn auto_sniff_error_points_at_milestone_for_known_unimplemented() {
        // Build a tiny mistral3 GGUF and check the error message.
        let path = temp_path("mistral3_hint.gguf");
        let _cleanup = Cleanup(path.clone());
        std::fs::write(&path, tiny_gguf_bytes("mistral3")).unwrap();
        let err = auto_sniff(&path).expect_err("should error");
        let s = format!("{err:#}");
        assert!(s.contains("Mistral"), "expected family name in error: {s}");
        assert!(s.contains("M4"), "expected milestone tag in error: {s}");
    }

    #[test]
    fn model_type_runner_maps_known() {
        assert_eq!(model_type_runner_name("qwen3"), Some("qwen3"));
        assert_eq!(model_type_runner_name("qwen3_moe"), Some("qwen3"));
        assert_eq!(model_type_runner_name("llama"), Some("llama32"));
        assert_eq!(model_type_runner_name("gemma3"), Some("gemma"));
        assert_eq!(
            model_type_runner_name("dinov2_with_registers"),
            Some("dinov2")
        );
        assert_eq!(model_type_runner_name("whisper"), Some("whisper"));
        assert_eq!(model_type_runner_name("unknown"), None);
    }

    /// Writes a minimal GGUF file to a temp path, then verifies auto_sniff
    /// picks the right runner name from `general.architecture`.
    #[test]
    fn auto_sniff_reads_gguf_arch() {
        let path = temp_path("sniff.gguf");
        let _cleanup = Cleanup(path.clone());
        std::fs::write(&path, tiny_gguf_bytes("qwen3")).unwrap();
        let sniff = auto_sniff(&path).expect("sniff");
        assert_eq!(sniff.runner_name, "qwen3");
        match sniff.from {
            SniffedFrom::GgufArch(a) => assert_eq!(a, "qwen3"),
            other => panic!("wrong sniff source: {other:?}"),
        }
    }

    /// Register a fake runner under a known name, ask `run_auto` to dispatch
    /// to it, and capture what argv it received.
    #[test]
    fn run_auto_injects_weights_flag_when_missing() {
        use crate::registry::{ModelRunner, register_runner};
        use std::sync::{Mutex, OnceLock};

        static CAPTURED: OnceLock<Mutex<Vec<String>>> = OnceLock::new();
        fn captured() -> &'static Mutex<Vec<String>> {
            CAPTURED.get_or_init(|| Mutex::new(Vec::new()))
        }

        struct Capture;
        impl ModelRunner for Capture {
            fn name(&self) -> &'static str {
                "qwen3"
            }
            fn description(&self) -> &'static str {
                "test capture"
            }
            fn run(&self, args: &[String]) -> Result<()> {
                *captured().lock().unwrap() = args.to_vec();
                Ok(())
            }
        }
        register_runner(Box::new(Capture));

        // Build a minimal qwen3 GGUF in a temp dir.
        let dir = temp_path("run_auto");
        std::fs::create_dir_all(&dir).unwrap();
        let _cleanup = Cleanup(dir.clone());
        let path = dir.join("model.gguf");
        std::fs::write(&path, tiny_gguf_bytes("qwen3")).unwrap();

        // Caller passed no --weights → run_auto must inject it.
        run_auto(&[path.display().to_string(), "--prompt".into(), "hi".into()]).unwrap();
        let got = captured().lock().unwrap().clone();
        assert_eq!(
            got,
            vec![
                "--weights".to_string(),
                path.display().to_string(),
                "--prompt".into(),
                "hi".into()
            ]
        );

        // Caller already passed --weights → don't inject again.
        run_auto(&[
            path.display().to_string(),
            "--weights".into(),
            "/other/path".into(),
            "--prompt".into(),
            "hi".into(),
        ])
        .unwrap();
        let got = captured().lock().unwrap().clone();
        assert_eq!(
            got,
            vec![
                "--weights".to_string(),
                "/other/path".into(),
                "--prompt".into(),
                "hi".into(),
            ]
        );
    }

    #[test]
    fn auto_sniff_reads_safetensors_sidecar() {
        let dir = temp_path("sidecar");
        std::fs::create_dir_all(&dir).unwrap();
        let _cleanup = Cleanup(dir.clone());
        let cfg = dir.join("config.json");
        std::fs::write(&cfg, br#"{"model_type":"llama"}"#).unwrap();
        let st = dir.join("model.safetensors");
        // Empty file is fine — sniffer never opens the safetensors payload.
        std::fs::write(&st, b"").unwrap();
        let sniff = auto_sniff(&st).expect("sniff");
        assert_eq!(sniff.runner_name, "llama32");
    }
}