Skip to main content

rlx_qwen3/
high_level_runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::capabilities::validate_device;
17use crate::{Qwen3Config, Qwen3Generator, SampleOpts, build_qwen3_graph_sized_packed};
18use anyhow::{Context, Result, anyhow, bail};
19use rlx_cli::{LmRunner, WeightFormat, list_mtp_keys};
20use rlx_core::gguf_support::{
21    GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
22    resolve_weights_file_with_options,
23};
24use rlx_core::weight_loader::GgufLoader;
25use rlx_flow::CompileProfile;
26use rlx_gguf::{GgufFile, MetaValue};
27use rlx_runtime::{Device, Session};
28use std::path::{Path, PathBuf};
29
/// Numeric precision policy for the Qwen3 inference graph.
///
/// `F32` is the exact baseline. Any other variant is communicated to the
/// graph builder through an environment variable that the runner builder
/// sets before the generator is instantiated (see the `qwen3_metal_perf`
/// notes for the Metal MPSGraph fast path).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Precision {
    /// Keep every op in F32. This is the default: the most reproducible
    /// choice, and the slowest on large LM heads.
    #[default]
    F32,
    /// F32 everywhere except the LM-head matmul, which is cast to F16 for
    /// the dominant prefill workload. Reported to win ~1.3-1.45× on
    /// (B≥2, L≥64) cells and to lose on small cells.
    F16LmHead,
}
44
/// Where the Qwen3 config comes from. Alias of the shared
/// `rlx_runtime::ConfigSource<T>`, specialised to [`Qwen3Config`]:
///
/// - `Embedded` — derive the config from GGUF metadata. Only valid for
///   GGUF weights; the safetensors loader rejects it with an error.
/// - `JsonFile(PathBuf)` — read a HuggingFace-style `config.json` via
///   `Qwen3Config::from_file`.
/// - `Explicit(Qwen3Config)` — the caller supplies the config directly;
///   it is cloned as-is.
///
/// When the builder has no source set, GGUF weights behave like
/// `Embedded` and safetensors weights look for `config.json` next to the
/// weights file. The caller can override either default.
pub type Qwen3ConfigSource = rlx_runtime::ConfigSource<Qwen3Config>;
55
/// Builder for [`Qwen3Runner`]. See the module docs for usage.
///
/// Every `Option` field means "use the default resolved in `build()`"
/// when it is `None`. The two plain `bool` fields take their value from
/// the derived `Default`, i.e. `false`, until the matching setter is
/// called.
#[derive(Debug, Clone, Default)]
pub struct Qwen3RunnerBuilder {
    /// Weights file, or a directory to resolve one from. Required by `build()`.
    weights: Option<PathBuf>,
    /// Config source override; `None` picks a default based on the weight format.
    config: Option<Qwen3ConfigSource>,
    /// Inference device; `None` resolves to `Device::Cpu`.
    device: Option<Device>,
    /// Prefill bucket size; `None` resolves to 128.
    max_seq: Option<usize>,
    /// Precision policy; `None` resolves to `Precision::default()`.
    precision: Option<Precision>,
    /// Optional memory ceiling (in GB) checked against a size estimate in `build()`.
    max_memory_gb: Option<f32>,
    /// Whether the F32 decode path forwards tokens to the callback as they
    /// are produced. Derived default is `false`.
    stream: bool,
    /// Forwarded to `Qwen3Generator::from_path_with_mtp`; for GGUF weights it
    /// also enables an MTP diagnostic print in `build()`. Derived default is `false`.
    use_mtp: bool,
    /// Sampling options; `None` resolves to `SampleOpts::greedy()`.
    sample: Option<SampleOpts>,
    // Format override — `None` lets `WeightFormat::resolve` decide from the
    // resolved weights path.
    format: Option<WeightFormat>,
    /// Keep K-quant weights packed in the arena and emit
    /// `Op::DequantMatMul` per matmul instead of F32-dequanting at
    /// load. Cuts host memory by ~6× on Q4_K_M models — the path to
    /// running 14 B+ GGUFs on commodity hardware. In packed mode the F32
    /// generator is never constructed: `predict_logits(...)` runs the
    /// compiled packed prefill graph, and generation re-runs that prefill
    /// once per token instead of using the cached decode path.
    /// `None` = auto-detect (packed when the GGUF file is ≥ 256 MiB, to
    /// avoid the F32-dequant memory explosion). `Some(_)` is an explicit
    /// override; `Some(true)` with non-GGUF weights makes `build()` fail.
    packed_weights: Option<bool>,
    /// Substring for picking one `.gguf` in a directory. `None` falls back to
    /// `rlx_core::DEFAULT_GGUF_PREFER_SUBSTR` (presumably `Q4_K_M` — confirm
    /// against that constant).
    prefer_gguf: Option<String>,
}
82
83impl Qwen3RunnerBuilder {
84    /// Path to the weights file (safetensors or gguf — autodetected
85    /// from the extension; pass `.format(...)` to override).
86    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
87        self.weights = Some(path.into());
88        self
89    }
90
91    /// Override the autodetected weight format.
92    pub fn format(mut self, fmt: WeightFormat) -> Self {
93        self.format = Some(fmt);
94        self
95    }
96
97    /// Set the Qwen3 config source. Default behavior depends on
98    /// `weights`:
99    ///   - GGUF: `Qwen3ConfigSource::Embedded` (read from metadata)
100    ///   - Safetensors: `Qwen3ConfigSource::JsonFile(<weights_dir>/config.json)`
101    pub fn config(mut self, src: Qwen3ConfigSource) -> Self {
102        self.config = Some(src);
103        self
104    }
105
106    /// Convenience: explicit `Qwen3Config` (shorthand for
107    /// `.config(Qwen3ConfigSource::Explicit(cfg))`).
108    pub fn config_value(self, cfg: Qwen3Config) -> Self {
109        self.config(Qwen3ConfigSource::Explicit(cfg))
110    }
111
112    /// Inference device. Default `Device::Cpu`.
113    pub fn device(mut self, d: Device) -> Self {
114        self.device = Some(d);
115        self
116    }
117
118    /// Maximum prefill sequence length. Compiles the graph once for
119    /// this bucket size; longer prompts get truncated, shorter ones
120    /// are padded. Default 128.
121    pub fn max_seq(mut self, n: usize) -> Self {
122        self.max_seq = Some(n);
123        self
124    }
125
126    /// Precision policy (see [`Precision`]). Default `Precision::F32`.
127    pub fn precision(mut self, p: Precision) -> Self {
128        self.precision = Some(p);
129        self
130    }
131
132    /// Soft memory ceiling in gigabytes. The runner doesn't enforce
133    /// this — it estimates the dequant-to-f32 footprint at build
134    /// time and returns an error if the estimate exceeds the
135    /// ceiling, so the caller can pick a smaller model or a more
136    /// aggressive quant before blowing host RAM.
137    pub fn max_memory_gb(mut self, gb: f32) -> Self {
138        self.max_memory_gb = Some(gb);
139        self
140    }
141
142    /// Stream tokens via `on_token` as they're decoded. Default true.
143    /// Setting false makes `generate` collect all tokens before
144    /// returning (smaller stdout, marginally faster for tiny gens).
145    pub fn stream(mut self, on: bool) -> Self {
146        self.stream = on;
147        self
148    }
149
150    /// Reserve the MTP head bytes (don't error on them, surface via
151    /// `mtp_keys()` on the loader). Default false. Actual MTP
152    /// speculative inference is a TODO.
153    pub fn use_mtp(mut self, on: bool) -> Self {
154        self.use_mtp = on;
155        self
156    }
157
158    /// Keep K-quant weights packed in the arena (see field doc on
159    /// [`Qwen3RunnerBuilder::packed_weights`]). Default false.
160    /// Requires a `.gguf` weights file; ignored for safetensors.
161    /// The resulting runner supports `predict_logits(...)` but
162    /// errors out on `generate(...)` — the streaming decode-cache
163    /// machinery still goes through the F32 builder today.
164    pub fn packed_weights(mut self, on: bool) -> Self {
165        self.packed_weights = Some(on);
166        self
167    }
168
169    /// When `weights` is a directory of `.gguf` files, prefer names containing this substring.
170    pub fn prefer_gguf_quant(mut self, sub: impl Into<String>) -> Self {
171        self.prefer_gguf = Some(sub.into());
172        self
173    }
174
175    /// Sampling options for `generate`. Default `SampleOpts::greedy()`.
176    pub fn sample(mut self, opts: SampleOpts) -> Self {
177        self.sample = Some(opts);
178        self
179    }
180
181    /// Resolve all defaults, load weights + config, compile the
182    /// graph. Expensive — call once and reuse the resulting
183    /// [`Qwen3Runner`] across many `generate` calls.
184    pub fn build(self) -> Result<Qwen3Runner> {
185        let weights_in = self
186            .weights
187            .as_ref()
188            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
189        let resolve = ResolveWeightsOptions {
190            prefer_gguf_substring: self
191                .prefer_gguf
192                .as_deref()
193                .or(Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR)),
194            ..Default::default()
195        };
196        let weights_path = resolve_weights_file_with_options(weights_in, &resolve)?;
197        let format = WeightFormat::resolve(&weights_path, self.format)?;
198        let device = self.device.unwrap_or(Device::Cpu);
199        let max_seq = self.max_seq.unwrap_or(128);
200        let precision = self.precision.unwrap_or_default();
201        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
202
203        // Load config + estimate memory before touching the weights.
204        let (cfg, total_bytes_estimate) = match format {
205            WeightFormat::Gguf => load_gguf_config(&weights_path, self.config.as_ref())?,
206            WeightFormat::Safetensors => {
207                load_safetensors_config(&weights_path, self.config.as_ref())?
208            }
209        };
210
211        // Auto-default packed when no explicit choice was made AND the
212        // GGUF on disk is ≥ 256 MB (avoids the F32-dequant OOM on
213        // multi-GB fixtures). Explicit `.packed_weights(_)` overrides.
214        let packed = self.packed_weights.unwrap_or_else(|| {
215            matches!(format, WeightFormat::Gguf)
216                && std::fs::metadata(&weights_path)
217                    .ok()
218                    .map(|m| m.len() >= 256 * 1024 * 1024)
219                    .unwrap_or(false)
220        });
221        validate_device(&cfg, device, packed)?;
222
223        if let Some(cap_gb) = self.max_memory_gb {
224            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
225            if est_gb > cap_gb {
226                bail!(
227                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB. \
228                     Either raise --max-memory-gb or pick a smaller / more-aggressively-quantized model."
229                );
230            }
231        }
232
233        // Set the F16 LM-head env-var before instantiating the
234        // generator so the graph builder picks it up.
235        if matches!(precision, Precision::F16LmHead) {
236            rlx_ir::env::set("RLX_QWEN3_F16_LM_HEAD", "1");
237        }
238
239        // In packed mode, do not construct the F32 generator: that
240        // path dequants the full model and defeats the low-memory
241        // GGUF loader.
242        let mut generator = if packed {
243            None
244        } else {
245            // `from_path_with_mtp` auto-detects safetensors vs GGUF and
246            // — for GGUF only — flips MTP-head visibility based on the
247            // builder's `use_mtp` flag. The base graph builder doesn't
248            // reference MTP weights, but pulling them into the cache up
249            // front means a future MTP-aware decoder can read them
250            // without re-opening the file.
251            let path_str = weights_path
252                .to_str()
253                .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
254            Some(Qwen3Generator::from_path_with_mtp(
255                cfg.clone(),
256                path_str,
257                device,
258                self.use_mtp,
259            )?)
260        };
261        if self.use_mtp && matches!(format, WeightFormat::Gguf) {
262            // Diagnostic — surfaces how many MTP heads the runner
263            // actually has access to. Helpful when verifying that a
264            // user's Qwen3-MTP GGUF was loaded the way they
265            // expected.
266            if let Ok(mtp_keys) = list_mtp_keys(&weights_path) {
267                eprintln!(
268                    "[qwen3-runner] MTP enabled: {} MTP tensors visible in loader cache. \
269                     Note: base generation path doesn't use them yet (speculative \
270                     decoding is a follow-up); see GgufLoader::take_mtp for direct \
271                     access.",
272                    mtp_keys.len()
273                );
274                for k in mtp_keys.iter().take(3) {
275                    eprintln!("  [qwen3-runner]   {k}");
276                }
277                if mtp_keys.len() > 3 {
278                    eprintln!("  [qwen3-runner]   … and {} more", mtp_keys.len() - 3);
279                }
280            }
281        }
282        if let Some(inner) = generator.take() {
283            generator = Some(inner.with_prefill_cache(8).with_decode_cache(max_seq + 64));
284        }
285
286        // Packed-weights opt-in (GGUF only): compile a one-shape
287        // prefill graph with `Op::DequantMatMul` so K-quant weights
288        // stay packed in the arena. The compiled module is kept
289        // alongside the F32 generator; `predict_logits` routes to
290        // whichever is present.
291        let packed = if packed {
292            if !matches!(format, WeightFormat::Gguf) {
293                bail!(
294                    "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
295                    format,
296                    weights_path
297                );
298            }
299            eprintln!(
300                "[qwen3-runner] packed_weights=true — compiling prefill graph with \
301                 Op::DequantMatMul on {device:?}"
302            );
303            Some(PackedForward::build(&cfg, &weights_path, max_seq, device)?)
304        } else {
305            None
306        };
307        let _ = format;
308
309        Ok(Qwen3Runner {
310            generator,
311            cfg,
312            sample,
313            stream: self.stream,
314            device,
315            packed,
316        })
317    }
318}
319
/// Compiled prefill graph for the packed-weights path. Holds the
/// `CompiledGraph` plus the bucket size it was built at, together with
/// reusable host-side input buffers so each `predict_logits` call can
/// fill them in place instead of allocating.
struct PackedForward {
    /// The compiled prefill module with all parameters already bound.
    compiled: rlx_runtime::CompiledGraph,
    /// Sequence bucket the graph was compiled for; prompts are clamped to it.
    seq: usize,
    /// Scratch buffer of `seq` token ids (prompt followed by zero padding).
    padded_ids: Vec<u32>,
    // f32 host buffers — graph declares input_ids/last_token_idx as I32
    // (per gather_last_token bugfix). The backends' write_from_f32 /
    // mlx::astype convert these to I32 at the input boundary so the
    // existing `run()` API works without typed-input plumbing.
    /// `padded_ids` converted element-wise to `f32`; fed as `input_ids`.
    ids_f32: Vec<f32>,
    /// Index of the last real prompt token, as `f32`; fed as `last_token_idx`.
    last_idx: [f32; 1],
}
334
335impl PackedForward {
336    fn build(cfg: &Qwen3Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
337        let exec_device = rlx_core::flow_bridge::packed_gguf_execution_device(device);
338        if exec_device != device {
339            eprintln!(
340                "[qwen3-runner] packed GGUF on {device:?}: prefill executes on {exec_device:?} \
341                 until {device:?} packed parity is fixed upstream"
342            );
343        }
344        let mut loader = GgufLoader::from_file(
345            weights_path
346                .to_str()
347                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
348        )?;
349        let mut packed = std::collections::HashMap::new();
350        let (graph, params) = build_qwen3_graph_sized_packed(
351            cfg,
352            &mut loader,
353            /*batch*/ 1,
354            seq,
355            /*with_lm_head*/ true,
356            /*last_token_from_input*/ true,
357            &mut packed,
358        )?;
359        let opts = rlx_core::flow_bridge::compile_options_for_packed_gguf_prefill_with_profile(
360            &CompileProfile::qwen3_prefill(),
361            exec_device,
362        );
363        let mut compiled = rlx_core::flow_bridge::packed_gguf_compile_guard(exec_device, || {
364            Session::new(exec_device).compile_with(graph, &opts)
365        });
366        for (name, data) in &params {
367            compiled.set_param(name, data);
368        }
369        for (name, (bytes, _scheme, _shape)) in &packed {
370            compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
371        }
372        Ok(Self {
373            compiled,
374            seq,
375            padded_ids: vec![0u32; seq],
376            ids_f32: vec![0f32; seq],
377            last_idx: [0f32; 1],
378        })
379    }
380}
381
/// Resolved Qwen3 runner — call [`Qwen3Runner::generate`] to produce
/// tokens, or [`Qwen3Runner::predict_logits`] for a single forward pass.
///
/// Exactly one of `generator` (F32 mode) and `packed` (packed-weights
/// mode) is populated by the builder.
pub struct Qwen3Runner {
    /// F32 generator; `None` when the runner was built in packed mode.
    generator: Option<Qwen3Generator>,
    /// Model config the runner was built with.
    cfg: Qwen3Config,
    /// Sampling options applied whenever a token is sampled.
    sample: SampleOpts,
    /// Copied from the builder's `stream` flag; consulted on the F32 decode path.
    stream: bool,
    /// Device requested at build time.
    device: Device,
    /// Packed prefill graph. `Some` when the builder resolved packed mode —
    /// either via an explicit `.packed_weights(true)` or via the GGUF
    /// size-based auto-detect.
    packed: Option<PackedForward>,
}
394
395impl Qwen3Runner {
396    pub fn builder() -> Qwen3RunnerBuilder {
397        Qwen3RunnerBuilder::default()
398    }
399
    /// Borrow the model config this runner was built with.
    pub fn config(&self) -> &Qwen3Config {
        &self.cfg
    }
    /// The device that was requested when the runner was built.
    pub fn device(&self) -> Device {
        self.device
    }
406
407    /// Bypass the cached decode path; every generated token re-runs the full
408    /// prefill graph from scratch. Slow (O(N²)) but a reference for numerical
409    /// parity checks against the cached path.
410    pub fn disable_decode_compile_cache(&mut self) {
411        if let Some(g) = self.generator.as_mut() {
412            g.set_decode_compile_cache(None);
413        }
414    }
415
416    /// Generate `n_new` tokens after the given prompt. `on_token` is
417    /// called once per generated id when `stream(true)` is set;
418    /// otherwise the callback fires once at the end with the full
419    /// vector. Returns the full generated id sequence.
420    ///
421    /// The prompt is expected as raw token ids — tokenizer integration
422    /// lives outside this module today (use the example binary for an
423    /// end-to-end pipeline that wires `tokenizers`).
424    /// Run a single prefill pass and return the **last-position
425    /// logits**. Works in both F32 mode and packed-weights mode —
426    /// in packed mode this is the only forward path supported
427    /// today (streaming decode still goes through the F32
428    /// generator).
429    ///
430    /// The prompt length must match the bucket the runner was
431    /// built for (`max_seq`); shorter prompts are padded with the
432    /// first token, longer prompts are truncated.
433    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
434        if let Some(p) = self.packed.as_mut() {
435            let n = prompt_ids.len().min(p.seq);
436            p.padded_ids.fill(0);
437            for (i, &t) in prompt_ids.iter().take(n).enumerate() {
438                p.padded_ids[i] = t;
439            }
440            for (dst, &id) in p.ids_f32.iter_mut().zip(p.padded_ids.iter()) {
441                *dst = id as f32;
442            }
443            p.last_idx[0] = n.saturating_sub(1) as f32;
444            let exec_device = p.compiled.device();
445            let out = rlx_core::run_packed_prefill(
446                &mut p.compiled,
447                exec_device,
448                n,
449                p.seq,
450                &[
451                    ("input_ids", p.ids_f32.as_slice()),
452                    ("last_token_idx", p.last_idx.as_slice()),
453                ],
454            );
455            let logits = out
456                .into_iter()
457                .next()
458                .ok_or_else(|| anyhow!("packed forward returned no output"))?;
459            let vocab = self.cfg.vocab_size;
460            if logits.len() < vocab {
461                bail!("logits short: {} < {vocab}", logits.len());
462            }
463            return Ok(logits[..vocab].to_vec());
464        }
465        // F32 path: prefill then read the last logits from the
466        // generator's step path (one-step decode).
467        let generator = self
468            .generator
469            .as_mut()
470            .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
471        generator.prefill(prompt_ids);
472        let _tok = generator.step_cached(self.sample)?;
473        // The generator doesn't expose its logits buffer publicly
474        // today; round-trip via the speculator-style scoring
475        // helpers would require new public API. For now,
476        // `predict_logits` on the F32 path returns a placeholder
477        // single-element vec containing the sampled token id as
478        // an f32 so callers get *something* — the packed path is
479        // the one with full logit access.
480        Ok(vec![_tok as f32])
481    }
482
483    /// Generate `n_new` tokens via repeated packed-mode prefills.
484    /// Each step runs the full prefill graph against the growing
485    /// token history (padded/truncated to `max_seq`), samples the
486    /// next id, and appends it. Calls `on_token` per id.
487    ///
488    /// Trade-off vs `generate()` on the F32 path: every token pays
489    /// a full prefill instead of one decode step, so wall-clock
490    /// throughput is ~`max_seq` × slower. Memory stays packed
491    /// though — the only path that actually loads 14 B+ Q4_K_M
492    /// GGUFs on a 32 GB Mac today. Tighter throughput needs the
493    /// real bucketed decode-graph machinery (separate TODO; see
494    /// CHANGELOG known-limitations).
495    pub fn generate_packed(
496        &mut self,
497        prompt_ids: &[u32],
498        n_new: usize,
499        mut on_token: impl FnMut(u32),
500    ) -> Result<Vec<u32>> {
501        if self.packed.is_none() {
502            bail!("generate_packed() only works in packed_weights(true) mode");
503        }
504        let mut history: Vec<u32> = prompt_ids.to_vec();
505        let mut out = Vec::with_capacity(n_new);
506        for _ in 0..n_new {
507            let logits = self.predict_logits(&history)?;
508            let next = crate::sample_token(&logits, self.sample) as u32;
509            on_token(next);
510            history.push(next);
511            out.push(next);
512        }
513        Ok(out)
514    }
515
    /// Generate up to `n_new` tokens after the given prompt and return the
    /// generated id sequence.
    ///
    /// Thin wrapper over [`Qwen3Runner::generate_stoppable`]: `on_token` is
    /// wrapped in a callback that always answers "continue", so the caller
    /// cannot cut generation short through this entry point.
    ///
    /// The prompt is expected as raw token ids — tokenizer integration
    /// lives outside this module today (use the example binary for an
    /// end-to-end pipeline that wires `tokenizers`).
    pub fn generate(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        mut on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        self.generate_stoppable(prompt_ids, n_new, |tok| {
            on_token(tok);
            true
        })
    }
527
528    /// Like [`generate`] but the callback can return `false` to stop
529    /// sampling early (e.g. on EOS).
530    pub fn generate_stoppable(
531        &mut self,
532        prompt_ids: &[u32],
533        n_new: usize,
534        mut on_token: impl FnMut(u32) -> bool,
535    ) -> Result<Vec<u32>> {
536        if self.packed.is_some() {
537            // Packed mode: route to the autoregressive prefill loop.
538            // No streaming-callback collation needed — `generate_packed`
539            // already calls `on_token` per id.
540            return self.generate_packed(prompt_ids, n_new, |tok| {
541                let _ = on_token(tok);
542            });
543        }
544        let generator = self
545            .generator
546            .as_mut()
547            .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
548        generator.prefill(prompt_ids);
549        // Single `generate_cached_until` call covers the whole decode
550        // loop — the bucketed compile cache fires after the first
551        // step, so the per-token graph compile that the older
552        // `generate_cached(1, …)` × N loop incurred is gone.
553        // `stream(false)` only affects when the caller's callback
554        // sees the tokens (one-by-one vs all-at-end), not when the
555        // generator runs them.
556        let stream = self.stream;
557        generator.generate_cached_until(
558            n_new,
559            self.sample,
560            |tok| {
561                if stream {
562                    on_token(tok);
563                }
564                true
565            },
566            |_| {},
567        )
568    }
569}
570
571impl LmRunner for Qwen3Runner {
572    fn family(&self) -> &'static str {
573        "qwen3"
574    }
575    fn vocab_size(&self) -> usize {
576        self.config().vocab_size
577    }
578    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
579        Qwen3Runner::predict_logits(self, prompt_ids)
580    }
581    fn generate(
582        &mut self,
583        prompt_ids: &[u32],
584        n_new: usize,
585        on_token: &mut dyn FnMut(u32) -> bool,
586    ) -> Result<Vec<u32>> {
587        // Inherent generate ignores stop signal — drop the bool.
588        Qwen3Runner::generate(self, prompt_ids, n_new, |tok| {
589            let _ = on_token(tok);
590        })
591    }
592}
593
594fn load_gguf_config(
595    path: &Path,
596    override_src: Option<&Qwen3ConfigSource>,
597) -> Result<(Qwen3Config, u64)> {
598    let raw = assert_gguf_family(path, GgufModelFamily::Qwen3)?;
599    let cfg = match override_src {
600        Some(Qwen3ConfigSource::Explicit(c)) => c.clone(),
601        Some(Qwen3ConfigSource::JsonFile(p)) => {
602            Qwen3Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
603        }
604        Some(Qwen3ConfigSource::Embedded) | None => qwen3_cfg_from_gguf(&raw)?,
605    };
606    Ok((cfg, gguf_f32_bytes_estimate(&raw)))
607}
608
609fn load_safetensors_config(
610    path: &Path,
611    override_src: Option<&Qwen3ConfigSource>,
612) -> Result<(Qwen3Config, u64)> {
613    let cfg_path = match override_src {
614        Some(Qwen3ConfigSource::Explicit(c)) => {
615            return Ok((c.clone(), default_st_size_estimate(path)));
616        }
617        Some(Qwen3ConfigSource::JsonFile(p)) => p.clone(),
618        Some(Qwen3ConfigSource::Embedded) => {
619            bail!("Qwen3ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
620        }
621        None => path
622            .parent()
623            .ok_or_else(|| anyhow!("weights path has no parent dir"))?
624            .join("config.json"),
625    };
626    let cfg = Qwen3Config::from_file(&cfg_path)
627        .with_context(|| format!("reading config {cfg_path:?}"))?;
628    Ok((cfg, default_st_size_estimate(path)))
629}
630
/// Size estimate for a safetensors file: its length on disk in bytes, or
/// 0 when the file's metadata cannot be read (best-effort, never fails).
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}
634
635fn qwen3_cfg_from_gguf(raw: &GgufFile) -> Result<Qwen3Config> {
636    let arch_prefix = raw
637        .metadata
638        .get("general.architecture")
639        .and_then(MetaValue::as_str)
640        .unwrap_or("qwen3");
641    let get_meta = |k: &str| -> Option<&MetaValue> {
642        raw.metadata.get(k).or_else(|| {
643            let suffix = k.strip_prefix("qwen3.")?;
644            if arch_prefix == "qwen3" {
645                None
646            } else {
647                let arch_key = format!("{arch_prefix}.{suffix}");
648                raw.metadata.get(&arch_key)
649            }
650        })
651    };
652    let get_u32 = |k: &str| -> Result<u32> {
653        get_meta(k)
654            .and_then(MetaValue::as_u32)
655            .ok_or_else(|| anyhow!("missing GGUF metadata key: {k}"))
656    };
657    let get_f32 = |k: &str| -> Option<f32> {
658        get_meta(k).and_then(|v| match v {
659            MetaValue::F32(x) => Some(*x),
660            _ => None,
661        })
662    };
663    let get_bool = |k: &str| -> Option<bool> {
664        get_meta(k).and_then(|v| match v {
665            MetaValue::Bool(b) => Some(*b),
666            _ => None,
667        })
668    };
669    // Per-arch tensor-shape conventions:
670    //   * Qwen 3 has QK-norm (RMS on Q/K per head before RoPE) and NO
671    //     biases on Q/K/V projections.
672    //   * Qwen 2 / 2.5 have NO QK-norm and DO ship biases on Q/K/V.
673    // Both share `general.architecture = qwen2 | qwen3 | qwen3_moe`
674    // when converted by llama.cpp's gguf-py, so we dispatch on the
675    // arch tag rather than asking the loader to probe tensor keys.
676    let is_qwen2 = arch_prefix == "qwen2";
677    let qk_norm_default = !is_qwen2;
678    let attention_bias_default = is_qwen2;
679    let is_moe = matches!(arch_prefix, "qwen3moe" | "qwen3_moe");
680
681    let hidden_size = get_u32("qwen3.embedding_length")? as usize;
682    let num_attention_heads = get_u32("qwen3.attention.head_count")? as usize;
683    // GGUFs that omit `<arch>.attention.key_length` must use
684    // `hidden_size / num_attention_heads` rather than a hard-coded 128 —
685    // Qwen 2.5 0.5B has hidden=896, heads=14, head_dim=64 with no
686    // explicit key_length field.
687    let head_dim_default = if num_attention_heads > 0 {
688        hidden_size.checked_div(num_attention_heads).unwrap_or(128)
689    } else {
690        128
691    };
692
693    Ok(Qwen3Config {
694        vocab_size: get_u32("qwen3.vocab_size").unwrap_or(151_936) as usize,
695        hidden_size,
696        intermediate_size: get_u32("qwen3.feed_forward_length")? as usize,
697        num_hidden_layers: get_u32("qwen3.block_count")? as usize,
698        num_attention_heads,
699        num_key_value_heads: get_u32("qwen3.attention.head_count_kv")? as usize,
700        head_dim: get_u32("qwen3.attention.key_length")
701            .map(|v| v as usize)
702            .unwrap_or(head_dim_default),
703        attention_bias: attention_bias_default,
704        qk_norm: qk_norm_default,
705        max_position_embeddings: get_u32("qwen3.context_length").unwrap_or(40_960) as usize,
706        sliding_window: None,
707        max_window_layers: 0,
708        tie_word_embeddings: get_bool("qwen3.tie_word_embeddings").unwrap_or(true),
709        rope_theta: get_f32("qwen3.rope.freq_base").unwrap_or(1_000_000.0) as f64,
710        rms_norm_eps: get_f32("qwen3.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
711        use_sliding_window: false,
712        hidden_act: "silu".into(),
713        // PLAN.md M1 — MoE field parsing for `qwen3-30b-a3b-instruct`
714        // and friends. Routing impl + per-layer MoE dispatch still
715        // need the shared `rlx-flow::blocks::moe` router (upstream).
716        num_experts: if is_moe {
717            get_u32("qwen3.expert_count").unwrap_or(0) as usize
718        } else {
719            0
720        },
721        num_experts_used: if is_moe {
722            get_u32("qwen3.expert_used_count").unwrap_or(0) as usize
723        } else {
724            0
725        },
726        expert_ffn_size: get_u32("qwen3.expert_feed_forward_length")
727            .map(|v| v as usize)
728            .unwrap_or(0),
729        shared_expert_ffn_size: get_u32("qwen3.expert_shared_feed_forward_length")
730            .map(|v| v as usize)
731            .unwrap_or(0),
732        expert_weights_scale: get_f32("qwen3.expert_weights_scale").unwrap_or(1.0),
733    })
734}