1use anyhow::{Context, Result, anyhow, bail};
30use rlx_core::gguf_config::{
31 DINOV2_GGUF_ARCHES, FLUX_GGUF_ARCHES, SAM_GGUF_ARCHES, SAM2_GGUF_ARCHES, SAM3_GGUF_ARCHES,
32 VJEPA2_GGUF_ARCHES, W2V_BERT_GGUF_ARCHES,
33};
34use rlx_core::gguf_support::{
35 gguf_architecture_from_path, gguf_family_for_arch, resolve_weights_file,
36};
37use std::path::{Path, PathBuf};
38
39use crate::registry::run_registered;
40
41pub fn run_auto(args: &[String]) -> Result<()> {
48 let Some(first) = args.first() else {
49 bail!(
50 "auto: expected WEIGHTS path as the first argument\n\
51 usage: rlx-run auto <weights-path> [runner-args...]"
52 );
53 };
54 if matches!(first.as_str(), "-h" | "--help" | "help") {
55 println!(
56 "rlx-run auto — sniff a GGUF / safetensors file and dispatch to the right runner\n\
57 \n\
58 USAGE:\n rlx-run auto <weights-path> [runner-args...]\n\
59 \n\
60 The first argument is forwarded as the runner's --weights value;\n\
61 remaining arguments are passed through unchanged."
62 );
63 return Ok(());
64 }
65 let path = Path::new(first);
66 let sniff = auto_sniff(path)?;
67 eprintln!(
68 "[rlx-run auto] {} → runner `{}` (from {:?})",
69 sniff.path.display(),
70 sniff.runner_name,
71 sniff.from
72 );
73 let rest: Vec<String> = args[1..].to_vec();
76 let has_weights_flag = rest
77 .iter()
78 .any(|a| a == "--weights" || a.starts_with("--weights="));
79 let mut forwarded: Vec<String> = Vec::with_capacity(rest.len() + 2);
80 if !has_weights_flag {
81 forwarded.push("--weights".into());
82 forwarded.push(sniff.path.display().to_string());
83 }
84 forwarded.extend(rest);
85 match run_registered(sniff.runner_name, &forwarded)? {
86 Some(()) => Ok(()),
87 None => bail!(
88 "auto: runner `{}` not registered (sniffed from {:?}); register it via \
89 `register_cli` in your binary's main",
90 sniff.runner_name,
91 sniff.from
92 ),
93 }
94}
95
/// Which piece of file metadata `auto_sniff` used to choose a runner.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SniffedFrom {
    /// The architecture string read from a GGUF file's header.
    GgufArch(String),
    /// The `model_type` string read from the `config.json` next to a
    /// `.safetensors` file.
    SafetensorsConfig(String),
}

/// Outcome of `auto_sniff`: the concrete weights file and the runner to use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SniffedRunner {
    /// Weights file after `resolve_weights_file` (may differ from the input path).
    pub path: PathBuf,
    /// Registry name of the runner, as passed to `run_registered`.
    pub runner_name: &'static str,
    /// The metadata value `runner_name` was derived from.
    pub from: SniffedFrom,
}
115
/// Roadmap entry for an architecture tag or `model_type` that rlx recognises
/// but has no runner for yet. `auto_sniff` folds its fields into the
/// "not yet implemented" error it returns.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct UnimplementedArch {
    /// Human-readable family name shown to the user (e.g. `"Gemma 3"`).
    pub family: &'static str,
    /// Roadmap milestone tag (e.g. `"M4"`).
    pub milestone: &'static str,
    /// Free-form explanation of what is missing.
    pub note: &'static str,
}
128
mod families {
    //! Static roadmap entries for model families that have no rlx runner yet.
    //!
    //! Each entry is referenced by one or more keys in `KNOWN_UNIMPLEMENTED`.
    //! Aliases such as `phi3` / `phi4` share a single entry, so every spelling
    //! produces the same error text.
    use super::UnimplementedArch;
    pub static MISTRAL: UnimplementedArch = UnimplementedArch {
        family: "Mistral 3+ / Ministral",
        milestone: "M4",
        note: "Llama-shaped with newer RoPE; share `rlx-llama-base` per PLAN.md M4",
    };
    pub static PHI: UnimplementedArch = UnimplementedArch {
        family: "Phi 3 / Phi 4",
        milestone: "M4",
        note: "Phi3/4 share llama.cpp arch tag — PLAN.md M4",
    };
    pub static PHIMOE: UnimplementedArch = UnimplementedArch {
        family: "Phi MoE",
        milestone: "M4 + M5",
        note: "Phi + MoE routing; depends on shared MoE block — PLAN.md M4/M5",
    };
    pub static BONSAI: UnimplementedArch = UnimplementedArch {
        family: "Bonsai",
        milestone: "M4",
        note: "Llama-shaped; HF model_type only — usually ships as llama GGUF — PLAN.md M4",
    };
    pub static OMNICODER: UnimplementedArch = UnimplementedArch {
        family: "OmniCoder",
        milestone: "M4",
        note: "Qwen3-coder shaped — PLAN.md M4 (often tagged `qwen3` in GGUF)",
    };
    pub static MINIMAX: UnimplementedArch = UnimplementedArch {
        family: "MiniMax M2",
        milestone: "M5",
        note: "Lightning Attention; depends on `rlx-ssm` upstream — PLAN.md M5",
    };
    pub static GLM: UnimplementedArch = UnimplementedArch {
        family: "GLM 4 / 5",
        milestone: "M5",
        note: "GLM RoPE + RMSNorm placement — PLAN.md M5",
    };
    pub static GLM_MOE: UnimplementedArch = UnimplementedArch {
        family: "GLM 4 MoE",
        milestone: "M5",
        note: "GLM + MoE routing — PLAN.md M5",
    };
    pub static GPT_OSS: UnimplementedArch = UnimplementedArch {
        family: "gpt-oss",
        milestone: "M5",
        note: "OpenAI gpt-oss — confirm arch shape — PLAN.md M5",
    };
    pub static NEMOTRON: UnimplementedArch = UnimplementedArch {
        family: "Nemotron",
        milestone: "M5",
        note: "Dense Nemotron arch — PLAN.md M5",
    };
    pub static NEMOTRON_H: UnimplementedArch = UnimplementedArch {
        family: "Nemotron-H",
        milestone: "M5",
        note: "Mamba+attention hybrid; depends on `rlx-ssm` upstream — PLAN.md M5/M7",
    };
    // No key in `KNOWN_UNIMPLEMENTED` points here (only `lfm2moe` -> LFM_MOE),
    // so without the allow this static would trip the dead-code lint.
    // NOTE(review): presumably plain LFM2 is handled by a real runner elsewhere
    // and this entry is kept for reference — confirm before deleting it.
    #[allow(dead_code)]
    pub static LFM: UnimplementedArch = UnimplementedArch {
        family: "LFM 2 / 2.5",
        milestone: "M5",
        note: "Liquid Foundation Models with custom SSM layers — PLAN.md M5",
    };
    pub static LFM_MOE: UnimplementedArch = UnimplementedArch {
        family: "LFM 2 MoE",
        milestone: "M5",
        note: "LFM + MoE — PLAN.md M5",
    };
    pub static QWEN3_MOE: UnimplementedArch = UnimplementedArch {
        family: "Qwen3 MoE",
        milestone: "M5",
        note: "Qwen3 + MoE routing block — PLAN.md M5 (often loadable via qwen3 runner once MoE lands)",
    };
    pub static QWEN3_NEXT: UnimplementedArch = UnimplementedArch {
        family: "Qwen3-Next",
        milestone: "M5",
        note: "Qwen3-Next variant — confirm arch deltas vs qwen3 — PLAN.md M5",
    };
    pub static GEMMA3: UnimplementedArch = UnimplementedArch {
        family: "Gemma 3",
        milestone: "M2",
        note: "Gemma 3 (270m / 4b / 12b / 27b) adds per-layer sliding window + new RoPE — \
               needs rlx-gemma config branch — PLAN.md M2",
    };
    pub static GEMMA3N: UnimplementedArch = UnimplementedArch {
        family: "Gemma 3n",
        milestone: "M2",
        note: "Gemma 3n (mobile/edge Matformer variant) — PLAN.md M2",
    };
    pub static GEMMA4: UnimplementedArch = UnimplementedArch {
        family: "Gemma 4",
        milestone: "M2",
        note: "Gemma 4 (flagship + edge E2B/E4B + MoE A4B) — PLAN.md M2 flagship",
    };
    pub static QWEN3_VL: UnimplementedArch = UnimplementedArch {
        family: "Qwen3-VL",
        milestone: "M7",
        note: "vision tower + projector + LM (dense or MoE) — PLAN.md M7",
    };
    pub static QWEN3_MTP: UnimplementedArch = UnimplementedArch {
        family: "Qwen3 / Qwen3.6 + MTP",
        milestone: "M6",
        note: "multi-token-prediction draft heads — PLAN.md M6",
    };
    pub static LLADA: UnimplementedArch = UnimplementedArch {
        family: "LLaDA / LLaDA MoE (text-only)",
        milestone: "M5",
        note: "dense LLaDA arch in llama.cpp; rlx-llada2 currently targets the diffusion runner — PLAN.md M5",
    };
    pub static GRANITE: UnimplementedArch = UnimplementedArch {
        family: "Granite (IBM)",
        milestone: "M4",
        note: "Llama-shaped — PLAN.md M4",
    };
    pub static DEEPSEEK: UnimplementedArch = UnimplementedArch {
        family: "DeepSeek 2",
        milestone: "M5",
        note: "MoE + MLA attention — needs MoE block + MLA primitive — PLAN.md M5",
    };
    pub static COHERE: UnimplementedArch = UnimplementedArch {
        family: "Command-R / Cohere",
        milestone: "M4",
        note: "Llama-shaped — PLAN.md M4",
    };
}
256
/// Exact-match table from GGUF architecture tags / HF `model_type` strings to
/// the roadmap entry reported for families rlx cannot run yet.
///
/// `auto_sniff` consults this table only after `arch_runner_name` (GGUF) or
/// `model_type_runner_name` (safetensors) has returned `None`. An entry here
/// therefore never overrides a working runner mapping; it only makes the error
/// message more specific.
///
/// NOTE(review): `gemma3` / `gemma3n` are also mapped to the `gemma` runner by
/// `model_type_runner_name`, so their entries here are never reached on the
/// safetensors path. Confirm which behaviour is intended.
static KNOWN_UNIMPLEMENTED: phf::Map<&'static str, &'static UnimplementedArch> = phf::phf_map! {
    // Several spellings (hyphen / underscore / version suffix) map to the same
    // entry so every alias yields identical error text.
    "mistral3" => &families::MISTRAL,
    "mistral4" => &families::MISTRAL,
    "phi3" => &families::PHI,
    "phi4" => &families::PHI,
    "phimoe" => &families::PHIMOE,
    "bonsai" => &families::BONSAI,
    "omnicoder" => &families::OMNICODER,
    "minimax-m2" => &families::MINIMAX,
    "minimax_m2" => &families::MINIMAX,
    "minimax" => &families::MINIMAX,
    "glm4" => &families::GLM,
    "glm5" => &families::GLM,
    "chatglm" => &families::GLM,
    "glm4moe" => &families::GLM_MOE,
    "gpt-oss" => &families::GPT_OSS,
    "gpt_oss" => &families::GPT_OSS,
    "nemotron" => &families::NEMOTRON,
    "nemotron_h" => &families::NEMOTRON_H,
    "nemotron_h_moe" => &families::NEMOTRON_H,
    "lfm2moe" => &families::LFM_MOE,
    "qwen3moe" => &families::QWEN3_MOE,
    "qwen3next" => &families::QWEN3_NEXT,
    "gemma3" => &families::GEMMA3,
    "gemma3n" => &families::GEMMA3N,
    "gemma4" => &families::GEMMA4,
    "gemma4moe" => &families::GEMMA4,
    "qwen3vl" => &families::QWEN3_VL,
    "qwen3vlmoe" => &families::QWEN3_VL,
    "qwen3_vl" => &families::QWEN3_VL,
    "qwen3-vl" => &families::QWEN3_VL,
    "qwen3_mtp" => &families::QWEN3_MTP,
    "qwen3-mtp" => &families::QWEN3_MTP,
    "qwen36_mtp" => &families::QWEN3_MTP,
    "llada" => &families::LLADA,
    "llada-moe" => &families::LLADA,
    "granite" => &families::GRANITE,
    "granitemoe" => &families::GRANITE,
    "granitehybrid" => &families::GRANITE,
    "deepseek2" => &families::DEEPSEEK,
    "deepseek2-ocr" => &families::DEEPSEEK,
    "command-r" => &families::COHERE,
    "cohere2" => &families::COHERE,
};
328
329pub fn known_unimplemented_arch(arch_or_model_type: &str) -> Option<UnimplementedArch> {
331 KNOWN_UNIMPLEMENTED.get(arch_or_model_type).map(|p| **p)
332}
333
334pub fn known_unimplemented_keys() -> impl Iterator<Item = (&'static str, &'static UnimplementedArch)>
337{
338 KNOWN_UNIMPLEMENTED.entries().map(|(k, v)| (*k, *v))
339}
340
341pub fn arch_runner_name(arch: &str) -> Option<&'static str> {
348 if let Some(fam) = gguf_family_for_arch(arch) {
349 return Some(fam.runner_name());
350 }
351 if FLUX_GGUF_ARCHES.contains(&arch) {
352 return Some("flux2");
353 }
354 if DINOV2_GGUF_ARCHES.contains(&arch) {
355 return Some("dinov2");
356 }
357 if VJEPA2_GGUF_ARCHES.contains(&arch) {
358 return Some("vjepa2");
359 }
360 if SAM3_GGUF_ARCHES.contains(&arch) {
361 return Some("sam3");
362 }
363 if SAM2_GGUF_ARCHES.contains(&arch) {
364 return Some("sam2");
365 }
366 if SAM_GGUF_ARCHES.contains(&arch) {
367 return Some("sam1");
368 }
369 if W2V_BERT_GGUF_ARCHES.contains(&arch) {
370 return Some("wav2vec2-bert");
371 }
372 None
373}
374
/// Maps a Hugging Face `config.json` `model_type` (plus a few spelling
/// aliases) to the registry name of the rlx runner that loads it.
///
/// Returns `None` for unknown types. `auto_sniff` then falls back to
/// `known_unimplemented_arch` for a more specific error.
pub fn model_type_runner_name(model_type: &str) -> Option<&'static str> {
    match model_type {
        // `qwen2` is the model_type Qwen2 / Qwen2.5 checkpoints ship with. The
        // GGUF path already routes arch `qwen2` to the qwen3 runner, so the
        // safetensors path must agree.
        "qwen2" | "qwen3" | "qwen3_moe" | "qwen3moe" | "qwen25" | "qwen2_5" | "qwen2.5"
        | "qwen251" | "qwen2_5_1" => Some("qwen3"),
        "qwen35" | "qwen3_5" | "qwen35_moe" | "qwen35moe" => Some("qwen35"),
        "qwen36" | "qwen3_6" | "qwen36_moe" | "qwen36moe" => Some("qwen35"),
        "llama" | "llama2" | "llama3" => Some("llama32"),
        "gemma" | "gemma2" | "gemma3" | "gemma3n" => Some("gemma"),
        "dinov2" | "dinov2_with_registers" => Some("dinov2"),
        "vjepa2" | "vjepa" => Some("vjepa2"),
        "sam" | "sam_vit" | "mobile-sam" | "mobile_sam" => Some("sam1"),
        "sam2" => Some("sam2"),
        "sam3" => Some("sam3"),
        "whisper" => Some("whisper"),
        "wav2vec2-bert" | "wav2vec2_bert" | "w2v-bert" | "w2v_bert" => Some("wav2vec2-bert"),
        "flux" | "flux2" => Some("flux2"),
        _ => None,
    }
}
402
403fn read_model_type_from_sidecar(path: &Path) -> Result<Option<String>> {
405 let dir = path
406 .parent()
407 .ok_or_else(|| anyhow!("safetensors path {path:?} has no parent dir"))?;
408 let cfg = dir.join("config.json");
409 if !cfg.is_file() {
410 return Ok(None);
411 }
412 let bytes = std::fs::read(&cfg).with_context(|| format!("reading {cfg:?}"))?;
413 let v: serde_json::Value =
414 serde_json::from_slice(&bytes).with_context(|| format!("parsing {cfg:?}"))?;
415 Ok(v.get("model_type")
416 .and_then(serde_json::Value::as_str)
417 .map(str::to_owned))
418}
419
420pub fn auto_sniff(path: &Path) -> Result<SniffedRunner> {
422 let file = resolve_weights_file(path)?;
423 let ext = file.extension().and_then(|s| s.to_str()).unwrap_or("");
424 match ext {
425 "gguf" => {
426 let arch = gguf_architecture_from_path(&file)?;
427 let runner = arch_runner_name(&arch).ok_or_else(|| {
428 if let Some(u) = known_unimplemented_arch(&arch) {
429 anyhow!(
430 "{file:?}: GGUF architecture `{arch}` is {} ({}) — not yet implemented in rlx-models. {}",
431 u.family, u.milestone, u.note
432 )
433 } else {
434 anyhow!(
435 "{file:?}: GGUF architecture `{arch}` has no registered rlx runner; \
436 see `rlx-run` for supported families"
437 )
438 }
439 })?;
440 Ok(SniffedRunner {
441 path: file,
442 runner_name: runner,
443 from: SniffedFrom::GgufArch(arch),
444 })
445 }
446 "safetensors" => {
447 let model_type = read_model_type_from_sidecar(&file)?.ok_or_else(|| {
448 anyhow!("{file:?}: no `model_type` in sidecar config.json (auto-dispatch needs it)")
449 })?;
450 let runner = model_type_runner_name(&model_type).ok_or_else(|| {
451 if let Some(u) = known_unimplemented_arch(&model_type) {
452 anyhow!(
453 "{file:?}: safetensors model_type `{model_type}` is {} ({}) — not yet implemented in rlx-models. {}",
454 u.family, u.milestone, u.note
455 )
456 } else {
457 anyhow!(
458 "{file:?}: safetensors model_type `{model_type}` has no registered rlx runner"
459 )
460 }
461 })?;
462 Ok(SniffedRunner {
463 path: file,
464 runner_name: runner,
465 from: SniffedFrom::SafetensorsConfig(model_type),
466 })
467 }
468 other => {
469 bail!("{file:?}: unsupported extension `.{other}` (expected .gguf or .safetensors)")
470 }
471 }
472}
473
474pub fn auto_runner_name(path: &Path) -> Result<&'static str> {
476 Ok(auto_sniff(path)?.runner_name)
477}
478
479pub fn auto_dispatch(path: &Path, args: &[String]) -> Result<&'static str> {
484 let sniff = auto_sniff(path)?;
485 match run_registered(sniff.runner_name, args)? {
486 Some(()) => Ok(sniff.runner_name),
487 None => bail!(
488 "runner `{}` not registered (sniffed from {:?}); register it via \
489 `register_cli` before calling auto_dispatch",
490 sniff.runner_name,
491 sniff.from
492 ),
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory GGUF v3 file. It has one `general.architecture`
    /// string KV set to `arch` and a single 4-element F32 tensor named `w`.
    fn minimal_gguf(arch: &str) -> Vec<u8> {
        // GGUF strings: little-endian u64 byte length, then the raw bytes.
        fn put_str(buf: &mut Vec<u8>, s: &str) {
            buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
            buf.extend_from_slice(s.as_bytes());
        }
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes()); // format version
        buf.extend_from_slice(&1u64.to_le_bytes()); // tensor count
        buf.extend_from_slice(&1u64.to_le_bytes()); // metadata KV count
        put_str(&mut buf, "general.architecture");
        buf.extend_from_slice(&8u32.to_le_bytes()); // value type tag: string
        put_str(&mut buf, arch);
        // Tensor info: name, n_dims, dims, ggml type, data offset.
        put_str(&mut buf, "w");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(rlx_gguf::GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());
        // Tensor data begins at the next alignment boundary.
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for _ in 0..4 {
            buf.extend_from_slice(&1.0f32.to_le_bytes());
        }
        buf
    }

    /// Temp-dir path prefixed with this process's id. Concurrent or repeated
    /// test runs therefore never share, or clobber, each other's fixtures.
    fn unique_temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}_{name}", std::process::id()))
    }

    /// Writes `minimal_gguf(arch)` to a unique temp file and returns its path.
    fn write_temp_gguf(name: &str, arch: &str) -> PathBuf {
        let path = unique_temp_path(name);
        std::fs::write(&path, minimal_gguf(arch)).unwrap();
        path
    }

    #[test]
    fn arch_runner_maps_lm_families() {
        assert_eq!(arch_runner_name("qwen3"), Some("qwen3"));
        assert_eq!(arch_runner_name("qwen2"), Some("qwen3"));
        assert_eq!(arch_runner_name("qwen35"), Some("qwen35"));
        assert_eq!(arch_runner_name("qwen35moe"), Some("qwen35"));
        assert_eq!(arch_runner_name("qwen36"), Some("qwen35"));
        assert_eq!(arch_runner_name("qwen36moe"), Some("qwen35"));
        assert_eq!(arch_runner_name("qwen25"), Some("qwen3"));
        assert_eq!(arch_runner_name("qwen2_5"), Some("qwen3"));
        assert_eq!(arch_runner_name("llama"), Some("llama32"));
        assert_eq!(arch_runner_name("gemma"), Some("gemma"));
        assert_eq!(arch_runner_name("gemma2"), Some("gemma"));
    }

    #[test]
    fn arch_runner_maps_vision_and_diffusion() {
        assert_eq!(arch_runner_name("dinov2"), Some("dinov2"));
        assert_eq!(arch_runner_name("sam"), Some("sam1"));
        assert_eq!(arch_runner_name("mobile-sam"), Some("sam1"));
        assert_eq!(arch_runner_name("sam2"), Some("sam2"));
        assert_eq!(arch_runner_name("sam3"), Some("sam3"));
        assert_eq!(arch_runner_name("flux"), Some("flux2"));
        assert_eq!(arch_runner_name("vjepa2"), Some("vjepa2"));
        assert_eq!(arch_runner_name("w2v-bert"), Some("wav2vec2-bert"));
    }

    #[test]
    fn arch_runner_returns_none_for_embed_and_unknown() {
        assert_eq!(arch_runner_name("bert"), None);
        assert_eq!(arch_runner_name("nomic-bert"), None);
        assert_eq!(arch_runner_name("totally-fake-arch"), None);
    }

    #[test]
    fn known_unimplemented_covers_plan_families() {
        let milestone = |key: &str| known_unimplemented_arch(key).map(|u| u.milestone);
        assert_eq!(milestone("mistral3"), Some("M4"));
        assert_eq!(milestone("phi3"), Some("M4"));
        assert_eq!(milestone("phi4"), Some("M4"));
        assert_eq!(milestone("bonsai"), Some("M4"));
        assert_eq!(milestone("minimax-m2"), Some("M5"));
        assert_eq!(milestone("glm4"), Some("M5"));
        assert_eq!(milestone("nemotron_h"), Some("M5"));
        assert_eq!(milestone("qwen3_mtp"), Some("M6"));
        assert_eq!(milestone("qwen3vl"), Some("M7"));
        assert_eq!(known_unimplemented_arch("qwen3"), None);
        assert_eq!(known_unimplemented_arch("mistral"), None);
        assert_eq!(known_unimplemented_arch("totally-fake"), None);
    }

    #[test]
    fn auto_sniff_error_points_at_milestone_for_known_unimplemented() {
        let path = write_temp_gguf("rlx_auto_dispatch_mistral3_hint.gguf", "mistral3");
        let result = auto_sniff(&path);
        // Clean up before asserting so a failure doesn't leak the fixture.
        std::fs::remove_file(&path).ok();
        let err = result.expect_err("should error");
        let s = format!("{err:#}");
        assert!(s.contains("Mistral"), "expected family name in error: {s}");
        assert!(s.contains("M4"), "expected milestone tag in error: {s}");
    }

    #[test]
    fn model_type_runner_maps_known() {
        assert_eq!(model_type_runner_name("qwen3"), Some("qwen3"));
        assert_eq!(model_type_runner_name("qwen3_moe"), Some("qwen3"));
        assert_eq!(model_type_runner_name("llama"), Some("llama32"));
        assert_eq!(model_type_runner_name("gemma3"), Some("gemma"));
        assert_eq!(
            model_type_runner_name("dinov2_with_registers"),
            Some("dinov2")
        );
        assert_eq!(model_type_runner_name("whisper"), Some("whisper"));
        assert_eq!(model_type_runner_name("unknown"), None);
    }

    #[test]
    fn auto_sniff_reads_gguf_arch() {
        let path = write_temp_gguf("rlx_auto_dispatch_sniff.gguf", "qwen3");
        let result = auto_sniff(&path);
        std::fs::remove_file(&path).ok();
        let sniff = result.expect("sniff");
        assert_eq!(sniff.runner_name, "qwen3");
        match sniff.from {
            SniffedFrom::GgufArch(a) => assert_eq!(a, "qwen3"),
            other => panic!("wrong sniff source: {other:?}"),
        }
    }

    #[test]
    fn run_auto_injects_weights_flag_when_missing() {
        use crate::registry::{ModelRunner, register_runner};
        use std::sync::{Mutex, OnceLock};

        static CAPTURED: OnceLock<Mutex<Vec<String>>> = OnceLock::new();
        fn captured() -> &'static Mutex<Vec<String>> {
            CAPTURED.get_or_init(|| Mutex::new(Vec::new()))
        }

        /// Fake `qwen3` runner that records the argv it was invoked with.
        struct Capture;
        impl ModelRunner for Capture {
            fn name(&self) -> &'static str {
                "qwen3"
            }
            fn description(&self) -> &'static str {
                "test capture"
            }
            fn run(&self, args: &[String]) -> Result<()> {
                *captured().lock().unwrap() = args.to_vec();
                Ok(())
            }
        }
        register_runner(Box::new(Capture));

        let dir = unique_temp_path("rlx_auto_dispatch_run_auto");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("model.gguf");
        std::fs::write(&path, minimal_gguf("qwen3")).unwrap();

        // No --weights among the runner args: the sniffed path is injected.
        run_auto(&[path.display().to_string(), "--prompt".into(), "hi".into()]).unwrap();
        let got = captured().lock().unwrap().clone();
        assert_eq!(
            got,
            vec![
                "--weights".to_string(),
                path.display().to_string(),
                "--prompt".into(),
                "hi".into()
            ]
        );

        // An explicit --weights is forwarded untouched and nothing is injected.
        run_auto(&[
            path.display().to_string(),
            "--weights".into(),
            "/other/path".into(),
            "--prompt".into(),
            "hi".into(),
        ])
        .unwrap();
        let got = captured().lock().unwrap().clone();
        assert_eq!(
            got,
            vec![
                "--weights".to_string(),
                "/other/path".into(),
                "--prompt".into(),
                "hi".into(),
            ]
        );

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn auto_sniff_reads_safetensors_sidecar() {
        let dir = unique_temp_path("rlx_auto_dispatch_sidecar");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("config.json"), br#"{"model_type":"llama"}"#).unwrap();
        let st = dir.join("model.safetensors");
        std::fs::write(&st, b"").unwrap();
        let result = auto_sniff(&st);
        std::fs::remove_dir_all(&dir).ok();
        let sniff = result.expect("sniff");
        assert_eq!(sniff.runner_name, "llama32");
    }
}