Skip to main content

rlx_qwen3/
high_level_runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::capabilities::validate_device;
17use crate::{Qwen3Config, Qwen3Generator, SampleOpts, build_qwen3_graph_sized_packed};
18use anyhow::{Context, Result, anyhow, bail};
19use rlx_cli::{LmRunner, WeightFormat, list_mtp_keys};
20use rlx_core::gguf_support::{
21    GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
22    resolve_weights_file_with_options,
23};
24use rlx_core::weight_loader::GgufLoader;
25use rlx_flow::CompileProfile;
26use rlx_gguf::{GgufFile, MetaValue};
27use rlx_runtime::{Device, Session};
28use std::path::{Path, PathBuf};
29
/// Numeric precision policy applied when building the Qwen3 inference
/// graph. `F32` is the only exact mode; `F16LmHead` sets the
/// `RLX_QWEN3_F16_LM_HEAD` env-var consumed by the Metal MPSGraph fast
/// path (see `qwen3_metal_perf` notes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Precision {
    /// Full F32 everywhere (the default). Most reproducible; slowest
    /// when the LM head is large.
    #[default]
    F32,
    /// F32 for everything except the LM-head matmul, which is cast to
    /// F16 for the dominant prefill workload. Measured ~1.3-1.45× faster
    /// on (B≥2, L≥64) cells, but slower on small cells.
    F16LmHead,
}
44
45/// Source for the qwen3 config. The builder picks one automatically
46/// (GGUF embedded vs. sibling `config.json`) but the caller can
47/// override.
48#[derive(Debug, Clone)]
49pub enum Qwen3ConfigSource {
50    /// Read from GGUF metadata.
51    Embedded,
52    /// Read from a HuggingFace `config.json` at this path.
53    JsonFile(PathBuf),
54    /// Use the supplied config object directly.
55    Explicit(Qwen3Config),
56}
57
58/// Builder for [`Qwen3Runner`]. See the module docs for usage.
59#[derive(Debug, Clone, Default)]
60pub struct Qwen3RunnerBuilder {
61    weights: Option<PathBuf>,
62    config: Option<Qwen3ConfigSource>,
63    device: Option<Device>,
64    max_seq: Option<usize>,
65    precision: Option<Precision>,
66    max_memory_gb: Option<f32>,
67    stream: bool,
68    use_mtp: bool,
69    sample: Option<SampleOpts>,
70    // Format override — defaults to autodetection from weights extension.
71    format: Option<WeightFormat>,
72    /// Keep K-quant weights packed in the arena and emit
73    /// `Op::DequantMatMul` per matmul instead of F32-dequanting at
74    /// load. Cuts host memory by ~6× on Q4_K_M models — the path to
75    /// running 14 B+ GGUFs on commodity hardware. Forces single-forward mode (no
76    /// streaming decode); use `runner.predict_logits(...)` instead
77    /// of `runner.generate(...)`.
78    /// `None` = auto-detect (packed when GGUF ≥ 256 MB to avoid the
79    /// F32-dequant memory explosion). `Some(_)` is an explicit override.
80    packed_weights: Option<bool>,
81    /// Substring for picking one `.gguf` in a directory (default `Q4_K_M`).
82    prefer_gguf: Option<String>,
83}
84
85impl Qwen3RunnerBuilder {
86    /// Path to the weights file (safetensors or gguf — autodetected
87    /// from the extension; pass `.format(...)` to override).
88    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
89        self.weights = Some(path.into());
90        self
91    }
92
93    /// Override the autodetected weight format.
94    pub fn format(mut self, fmt: WeightFormat) -> Self {
95        self.format = Some(fmt);
96        self
97    }
98
99    /// Set the Qwen3 config source. Default behavior depends on
100    /// `weights`:
101    ///   - GGUF: `Qwen3ConfigSource::Embedded` (read from metadata)
102    ///   - Safetensors: `Qwen3ConfigSource::JsonFile(<weights_dir>/config.json)`
103    pub fn config(mut self, src: Qwen3ConfigSource) -> Self {
104        self.config = Some(src);
105        self
106    }
107
108    /// Convenience: explicit `Qwen3Config` (shorthand for
109    /// `.config(Qwen3ConfigSource::Explicit(cfg))`).
110    pub fn config_value(self, cfg: Qwen3Config) -> Self {
111        self.config(Qwen3ConfigSource::Explicit(cfg))
112    }
113
114    /// Inference device. Default `Device::Cpu`.
115    pub fn device(mut self, d: Device) -> Self {
116        self.device = Some(d);
117        self
118    }
119
120    /// Maximum prefill sequence length. Compiles the graph once for
121    /// this bucket size; longer prompts get truncated, shorter ones
122    /// are padded. Default 128.
123    pub fn max_seq(mut self, n: usize) -> Self {
124        self.max_seq = Some(n);
125        self
126    }
127
128    /// Precision policy (see [`Precision`]). Default `Precision::F32`.
129    pub fn precision(mut self, p: Precision) -> Self {
130        self.precision = Some(p);
131        self
132    }
133
134    /// Soft memory ceiling in gigabytes. The runner doesn't enforce
135    /// this — it estimates the dequant-to-f32 footprint at build
136    /// time and returns an error if the estimate exceeds the
137    /// ceiling, so the caller can pick a smaller model or a more
138    /// aggressive quant before blowing host RAM.
139    pub fn max_memory_gb(mut self, gb: f32) -> Self {
140        self.max_memory_gb = Some(gb);
141        self
142    }
143
144    /// Stream tokens via `on_token` as they're decoded. Default true.
145    /// Setting false makes `generate` collect all tokens before
146    /// returning (smaller stdout, marginally faster for tiny gens).
147    pub fn stream(mut self, on: bool) -> Self {
148        self.stream = on;
149        self
150    }
151
152    /// Reserve the MTP head bytes (don't error on them, surface via
153    /// `mtp_keys()` on the loader). Default false. Actual MTP
154    /// speculative inference is a TODO.
155    pub fn use_mtp(mut self, on: bool) -> Self {
156        self.use_mtp = on;
157        self
158    }
159
160    /// Keep K-quant weights packed in the arena (see field doc on
161    /// [`Qwen3RunnerBuilder::packed_weights`]). Default false.
162    /// Requires a `.gguf` weights file; ignored for safetensors.
163    /// The resulting runner supports `predict_logits(...)` but
164    /// errors out on `generate(...)` — the streaming decode-cache
165    /// machinery still goes through the F32 builder today.
166    pub fn packed_weights(mut self, on: bool) -> Self {
167        self.packed_weights = Some(on);
168        self
169    }
170
171    /// When `weights` is a directory of `.gguf` files, prefer names containing this substring.
172    pub fn prefer_gguf_quant(mut self, sub: impl Into<String>) -> Self {
173        self.prefer_gguf = Some(sub.into());
174        self
175    }
176
177    /// Sampling options for `generate`. Default `SampleOpts::greedy()`.
178    pub fn sample(mut self, opts: SampleOpts) -> Self {
179        self.sample = Some(opts);
180        self
181    }
182
183    /// Resolve all defaults, load weights + config, compile the
184    /// graph. Expensive — call once and reuse the resulting
185    /// [`Qwen3Runner`] across many `generate` calls.
186    pub fn build(self) -> Result<Qwen3Runner> {
187        let weights_in = self
188            .weights
189            .as_ref()
190            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
191        let resolve = ResolveWeightsOptions {
192            prefer_gguf_substring: self
193                .prefer_gguf
194                .as_deref()
195                .or(Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR)),
196            ..Default::default()
197        };
198        let weights_path = resolve_weights_file_with_options(weights_in, &resolve)?;
199        let format = WeightFormat::resolve(&weights_path, self.format)?;
200        let device = self.device.unwrap_or(Device::Cpu);
201        let max_seq = self.max_seq.unwrap_or(128);
202        let precision = self.precision.unwrap_or_default();
203        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
204
205        // Load config + estimate memory before touching the weights.
206        let (cfg, total_bytes_estimate) = match format {
207            WeightFormat::Gguf => load_gguf_config(&weights_path, self.config.as_ref())?,
208            WeightFormat::Safetensors => {
209                load_safetensors_config(&weights_path, self.config.as_ref())?
210            }
211        };
212
213        // Auto-default packed when no explicit choice was made AND the
214        // GGUF on disk is ≥ 256 MB (avoids the F32-dequant OOM on
215        // multi-GB fixtures). Explicit `.packed_weights(_)` overrides.
216        let packed = self.packed_weights.unwrap_or_else(|| {
217            matches!(format, WeightFormat::Gguf)
218                && std::fs::metadata(&weights_path)
219                    .ok()
220                    .map(|m| m.len() >= 256 * 1024 * 1024)
221                    .unwrap_or(false)
222        });
223        validate_device(&cfg, device, packed)?;
224
225        if let Some(cap_gb) = self.max_memory_gb {
226            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
227            if est_gb > cap_gb {
228                bail!(
229                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB. \
230                     Either raise --max-memory-gb or pick a smaller / more-aggressively-quantized model."
231                );
232            }
233        }
234
235        // Set the F16 LM-head env-var before instantiating the
236        // generator so the graph builder picks it up.
237        if matches!(precision, Precision::F16LmHead) {
238            rlx_ir::env::set("RLX_QWEN3_F16_LM_HEAD", "1");
239        }
240
241        // In packed mode, do not construct the F32 generator: that
242        // path dequants the full model and defeats the low-memory
243        // GGUF loader.
244        let mut generator = if packed {
245            None
246        } else {
247            // `from_path_with_mtp` auto-detects safetensors vs GGUF and
248            // — for GGUF only — flips MTP-head visibility based on the
249            // builder's `use_mtp` flag. The base graph builder doesn't
250            // reference MTP weights, but pulling them into the cache up
251            // front means a future MTP-aware decoder can read them
252            // without re-opening the file.
253            let path_str = weights_path
254                .to_str()
255                .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
256            Some(Qwen3Generator::from_path_with_mtp(
257                cfg.clone(),
258                path_str,
259                device,
260                self.use_mtp,
261            )?)
262        };
263        if self.use_mtp && matches!(format, WeightFormat::Gguf) {
264            // Diagnostic — surfaces how many MTP heads the runner
265            // actually has access to. Helpful when verifying that a
266            // user's Qwen3-MTP GGUF was loaded the way they
267            // expected.
268            if let Ok(mtp_keys) = list_mtp_keys(&weights_path) {
269                eprintln!(
270                    "[qwen3-runner] MTP enabled: {} MTP tensors visible in loader cache. \
271                     Note: base generation path doesn't use them yet (speculative \
272                     decoding is a follow-up); see GgufLoader::take_mtp for direct \
273                     access.",
274                    mtp_keys.len()
275                );
276                for k in mtp_keys.iter().take(3) {
277                    eprintln!("  [qwen3-runner]   {k}");
278                }
279                if mtp_keys.len() > 3 {
280                    eprintln!("  [qwen3-runner]   … and {} more", mtp_keys.len() - 3);
281                }
282            }
283        }
284        if let Some(inner) = generator.take() {
285            generator = Some(inner.with_prefill_cache(8).with_decode_cache(max_seq + 64));
286        }
287
288        // Packed-weights opt-in (GGUF only): compile a one-shape
289        // prefill graph with `Op::DequantMatMul` so K-quant weights
290        // stay packed in the arena. The compiled module is kept
291        // alongside the F32 generator; `predict_logits` routes to
292        // whichever is present.
293        let packed = if packed {
294            if !matches!(format, WeightFormat::Gguf) {
295                bail!(
296                    "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
297                    format,
298                    weights_path
299                );
300            }
301            eprintln!(
302                "[qwen3-runner] packed_weights=true — compiling prefill graph with \
303                 Op::DequantMatMul on {device:?}"
304            );
305            Some(PackedForward::build(&cfg, &weights_path, max_seq, device)?)
306        } else {
307            None
308        };
309        let _ = format;
310
311        Ok(Qwen3Runner {
312            generator,
313            cfg,
314            sample,
315            stream: self.stream,
316            device,
317            packed,
318        })
319    }
320}
321
322/// Compiled prefill graph for the packed-weights path. Holds the
323/// `CompiledGraph` plus the bucket size it was built at so
324/// `predict_logits` can preflight-check the prompt length.
325struct PackedForward {
326    compiled: rlx_runtime::CompiledGraph,
327    seq: usize,
328}
329
330impl PackedForward {
331    fn build(cfg: &Qwen3Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
332        let mut loader = GgufLoader::from_file(
333            weights_path
334                .to_str()
335                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
336        )?;
337        let mut packed = std::collections::HashMap::new();
338        // `last_logits_only=false` → graph emits logits for every
339        // position. Runner extracts the row at the real prompt's last
340        // index in `predict_logits`. Causal attention guarantees that
341        // position is independent of the zero-padded tail.
342        let (graph, params) = build_qwen3_graph_sized_packed(
343            cfg,
344            &mut loader,
345            /*batch*/ 1,
346            seq,
347            /*with_lm_head*/ true,
348            /*last_logits_only*/ false,
349            &mut packed,
350        )?;
351        let opts = rlx_core::flow_bridge::compile_options_for_profile(
352            &CompileProfile::qwen3_prefill(),
353            device,
354        );
355        let mut compiled = Session::new(device).compile_with(graph, &opts);
356        for (name, data) in &params {
357            compiled.set_param(name, data);
358        }
359        for (name, (bytes, _scheme, _shape)) in &packed {
360            compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
361        }
362        Ok(Self { compiled, seq })
363    }
364}
365
366/// Resolved Qwen3 runner — call [`Qwen3Runner::generate`] for
367/// streaming decode (F32 path), or [`Qwen3Runner::predict_logits`]
368/// for a single forward pass (works in both F32 and packed modes).
369pub struct Qwen3Runner {
370    generator: Option<Qwen3Generator>,
371    cfg: Qwen3Config,
372    sample: SampleOpts,
373    stream: bool,
374    device: Device,
375    /// Only `Some` when the builder ran `.packed_weights(true)`.
376    packed: Option<PackedForward>,
377}
378
379impl Qwen3Runner {
380    pub fn builder() -> Qwen3RunnerBuilder {
381        Qwen3RunnerBuilder::default()
382    }
383
384    pub fn config(&self) -> &Qwen3Config {
385        &self.cfg
386    }
387    pub fn device(&self) -> Device {
388        self.device
389    }
390
391    /// Generate `n_new` tokens after the given prompt. `on_token` is
392    /// called once per generated id when `stream(true)` is set;
393    /// otherwise the callback fires once at the end with the full
394    /// vector. Returns the full generated id sequence.
395    ///
396    /// The prompt is expected as raw token ids — tokenizer integration
397    /// lives outside this module today (use the example binary for an
398    /// end-to-end pipeline that wires `tokenizers`).
399    /// Run a single prefill pass and return the **last-position
400    /// logits**. Works in both F32 mode and packed-weights mode —
401    /// in packed mode this is the only forward path supported
402    /// today (streaming decode still goes through the F32
403    /// generator).
404    ///
405    /// The prompt length must match the bucket the runner was
406    /// built for (`max_seq`); shorter prompts are padded with the
407    /// first token, longer prompts are truncated.
408    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
409        if let Some(p) = self.packed.as_mut() {
410            // Pad with zeros after the real prompt. Causal attention
411            // means position `prompt_len - 1` only attends to the real
412            // prompt tokens — padding can be anything, the prediction
413            // at the real last position is parity-correct. The graph
414            // (`build_qwen3_graph_sized_packed`) returns logits for
415            // every position when `last_logits_only=false`, so we slice
416            // out the row for `prompt_len - 1`. Previously the runner
417            // padded with `prompt_ids.first()` *and* the graph emitted
418            // logits at `seq - 1` — both wrong, both caused the rlx vs
419            // llama.cpp top-1 mismatch surfaced by `auto_runner_parity`.
420            let n = prompt_ids.len().min(p.seq);
421            let last = n.saturating_sub(1);
422            let mut padded = vec![0u32; p.seq];
423            for (i, &t) in prompt_ids.iter().take(p.seq).enumerate() {
424                padded[i] = t;
425            }
426            let ids_f32: Vec<f32> = padded.iter().map(|&i| i as f32).collect();
427            let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
428            let logits_all = out
429                .into_iter()
430                .next()
431                .ok_or_else(|| anyhow!("packed forward returned no output"))?;
432            // Output shape is `[batch=1, seq, vocab]`; slice out position
433            // `last` directly so callers get a single-row logit vector.
434            let vocab = logits_all.len() / p.seq.max(1);
435            let start = last * vocab;
436            let row = logits_all[start..start + vocab].to_vec();
437            return Ok(row);
438        }
439        // F32 path: prefill then read the last logits from the
440        // generator's step path (one-step decode).
441        let generator = self
442            .generator
443            .as_mut()
444            .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
445        generator.prefill(prompt_ids);
446        let _tok = generator.step_cached(self.sample)?;
447        // The generator doesn't expose its logits buffer publicly
448        // today; round-trip via the speculator-style scoring
449        // helpers would require new public API. For now,
450        // `predict_logits` on the F32 path returns a placeholder
451        // single-element vec containing the sampled token id as
452        // an f32 so callers get *something* — the packed path is
453        // the one with full logit access.
454        Ok(vec![_tok as f32])
455    }
456
457    /// Generate `n_new` tokens via repeated packed-mode prefills.
458    /// Each step runs the full prefill graph against the growing
459    /// token history (padded/truncated to `max_seq`), samples the
460    /// next id, and appends it. Calls `on_token` per id.
461    ///
462    /// Trade-off vs `generate()` on the F32 path: every token pays
463    /// a full prefill instead of one decode step, so wall-clock
464    /// throughput is ~`max_seq` × slower. Memory stays packed
465    /// though — the only path that actually loads 14 B+ Q4_K_M
466    /// GGUFs on a 32 GB Mac today. Tighter throughput needs the
467    /// real bucketed decode-graph machinery (separate TODO; see
468    /// CHANGELOG known-limitations).
469    pub fn generate_packed(
470        &mut self,
471        prompt_ids: &[u32],
472        n_new: usize,
473        mut on_token: impl FnMut(u32),
474    ) -> Result<Vec<u32>> {
475        if self.packed.is_none() {
476            bail!("generate_packed() only works in packed_weights(true) mode");
477        }
478        let mut history: Vec<u32> = prompt_ids.to_vec();
479        let mut out = Vec::with_capacity(n_new);
480        for _ in 0..n_new {
481            let logits = self.predict_logits(&history)?;
482            let next = crate::sample_token(&logits, self.sample) as u32;
483            on_token(next);
484            history.push(next);
485            out.push(next);
486        }
487        Ok(out)
488    }
489
490    pub fn generate(
491        &mut self,
492        prompt_ids: &[u32],
493        n_new: usize,
494        mut on_token: impl FnMut(u32),
495    ) -> Result<Vec<u32>> {
496        if self.packed.is_some() {
497            // Packed mode: route to the autoregressive prefill loop.
498            // No streaming-callback collation needed — `generate_packed`
499            // already calls `on_token` per id.
500            return self.generate_packed(prompt_ids, n_new, on_token);
501        }
502        let generator = self
503            .generator
504            .as_mut()
505            .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
506        generator.prefill(prompt_ids);
507        // Single `generate_cached_with` call covers the whole decode
508        // loop — the bucketed compile cache fires after the first
509        // step, so the per-token graph compile that the older
510        // `generate_cached(1, …)` × N loop incurred is gone.
511        // `stream(false)` only affects when the caller's callback
512        // sees the tokens (one-by-one vs all-at-end), not when the
513        // generator runs them.
514        let tokens = if self.stream {
515            generator.generate_cached_with(n_new, self.sample, on_token)?
516        } else {
517            let toks = generator.generate_cached(n_new, self.sample)?;
518            for &t in &toks {
519                on_token(t);
520            }
521            toks
522        };
523        Ok(tokens)
524    }
525}
526
527impl LmRunner for Qwen3Runner {
528    fn family(&self) -> &'static str {
529        "qwen3"
530    }
531    fn vocab_size(&self) -> usize {
532        self.config().vocab_size
533    }
534    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
535        Qwen3Runner::predict_logits(self, prompt_ids)
536    }
537    fn generate(
538        &mut self,
539        prompt_ids: &[u32],
540        n_new: usize,
541        on_token: &mut dyn FnMut(u32) -> bool,
542    ) -> Result<Vec<u32>> {
543        // Inherent generate ignores stop signal — drop the bool.
544        Qwen3Runner::generate(self, prompt_ids, n_new, |tok| {
545            let _ = on_token(tok);
546        })
547    }
548}
549
550fn load_gguf_config(
551    path: &Path,
552    override_src: Option<&Qwen3ConfigSource>,
553) -> Result<(Qwen3Config, u64)> {
554    let raw = assert_gguf_family(path, GgufModelFamily::Qwen3)?;
555    let cfg = match override_src {
556        Some(Qwen3ConfigSource::Explicit(c)) => c.clone(),
557        Some(Qwen3ConfigSource::JsonFile(p)) => {
558            Qwen3Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
559        }
560        Some(Qwen3ConfigSource::Embedded) | None => qwen3_cfg_from_gguf(&raw)?,
561    };
562    Ok((cfg, gguf_f32_bytes_estimate(&raw)))
563}
564
565fn load_safetensors_config(
566    path: &Path,
567    override_src: Option<&Qwen3ConfigSource>,
568) -> Result<(Qwen3Config, u64)> {
569    let cfg_path = match override_src {
570        Some(Qwen3ConfigSource::Explicit(c)) => {
571            return Ok((c.clone(), default_st_size_estimate(path)));
572        }
573        Some(Qwen3ConfigSource::JsonFile(p)) => p.clone(),
574        Some(Qwen3ConfigSource::Embedded) => {
575            bail!("Qwen3ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
576        }
577        None => path
578            .parent()
579            .ok_or_else(|| anyhow!("weights path has no parent dir"))?
580            .join("config.json"),
581    };
582    let cfg = Qwen3Config::from_file(&cfg_path)
583        .with_context(|| format!("reading config {cfg_path:?}"))?;
584    Ok((cfg, default_st_size_estimate(path)))
585}
586
/// On-disk size of `path` in bytes, used as the safetensors memory
/// estimate. Returns 0 when the file's metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}
590
591fn qwen3_cfg_from_gguf(raw: &GgufFile) -> Result<Qwen3Config> {
592    let arch_prefix = raw
593        .metadata
594        .get("general.architecture")
595        .and_then(MetaValue::as_str)
596        .unwrap_or("qwen3");
597    let get_meta = |k: &str| -> Option<&MetaValue> {
598        raw.metadata.get(k).or_else(|| {
599            let suffix = k.strip_prefix("qwen3.")?;
600            if arch_prefix == "qwen3" {
601                None
602            } else {
603                let arch_key = format!("{arch_prefix}.{suffix}");
604                raw.metadata.get(&arch_key)
605            }
606        })
607    };
608    let get_u32 = |k: &str| -> Result<u32> {
609        get_meta(k)
610            .and_then(MetaValue::as_u32)
611            .ok_or_else(|| anyhow!("missing GGUF metadata key: {k}"))
612    };
613    let get_f32 = |k: &str| -> Option<f32> {
614        get_meta(k).and_then(|v| match v {
615            MetaValue::F32(x) => Some(*x),
616            _ => None,
617        })
618    };
619    let get_bool = |k: &str| -> Option<bool> {
620        get_meta(k).and_then(|v| match v {
621            MetaValue::Bool(b) => Some(*b),
622            _ => None,
623        })
624    };
625    // Per-arch tensor-shape conventions:
626    //   * Qwen 3 has QK-norm (RMS on Q/K per head before RoPE) and NO
627    //     biases on Q/K/V projections.
628    //   * Qwen 2 / 2.5 have NO QK-norm and DO ship biases on Q/K/V.
629    // Both share `general.architecture = qwen2 | qwen3 | qwen3_moe`
630    // when converted by llama.cpp's gguf-py, so we dispatch on the
631    // arch tag rather than asking the loader to probe tensor keys.
632    let is_qwen2 = arch_prefix == "qwen2";
633    let qk_norm_default = !is_qwen2;
634    let attention_bias_default = is_qwen2;
635    let is_moe = matches!(arch_prefix, "qwen3moe" | "qwen3_moe");
636
637    let hidden_size = get_u32("qwen3.embedding_length")? as usize;
638    let num_attention_heads = get_u32("qwen3.attention.head_count")? as usize;
639    // GGUFs that omit `<arch>.attention.key_length` must use
640    // `hidden_size / num_attention_heads` rather than a hard-coded 128 —
641    // Qwen 2.5 0.5B has hidden=896, heads=14, head_dim=64 with no
642    // explicit key_length field.
643    let head_dim_default = if num_attention_heads > 0 {
644        hidden_size.checked_div(num_attention_heads).unwrap_or(128)
645    } else {
646        128
647    };
648
649    Ok(Qwen3Config {
650        vocab_size: get_u32("qwen3.vocab_size").unwrap_or(151_936) as usize,
651        hidden_size,
652        intermediate_size: get_u32("qwen3.feed_forward_length")? as usize,
653        num_hidden_layers: get_u32("qwen3.block_count")? as usize,
654        num_attention_heads,
655        num_key_value_heads: get_u32("qwen3.attention.head_count_kv")? as usize,
656        head_dim: get_u32("qwen3.attention.key_length")
657            .map(|v| v as usize)
658            .unwrap_or(head_dim_default),
659        attention_bias: attention_bias_default,
660        qk_norm: qk_norm_default,
661        max_position_embeddings: get_u32("qwen3.context_length").unwrap_or(40_960) as usize,
662        sliding_window: None,
663        max_window_layers: 0,
664        tie_word_embeddings: get_bool("qwen3.tie_word_embeddings").unwrap_or(true),
665        rope_theta: get_f32("qwen3.rope.freq_base").unwrap_or(1_000_000.0) as f64,
666        rms_norm_eps: get_f32("qwen3.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
667        use_sliding_window: false,
668        hidden_act: "silu".into(),
669        // PLAN.md M1 — MoE field parsing for `qwen3-30b-a3b-instruct`
670        // and friends. Routing impl + per-layer MoE dispatch still
671        // need the shared `rlx-flow::blocks::moe` router (upstream).
672        num_experts: if is_moe {
673            get_u32("qwen3.expert_count").unwrap_or(0) as usize
674        } else {
675            0
676        },
677        num_experts_used: if is_moe {
678            get_u32("qwen3.expert_used_count").unwrap_or(0) as usize
679        } else {
680            0
681        },
682        expert_ffn_size: get_u32("qwen3.expert_feed_forward_length")
683            .map(|v| v as usize)
684            .unwrap_or(0),
685        shared_expert_ffn_size: get_u32("qwen3.expert_shared_feed_forward_length")
686            .map(|v| v as usize)
687            .unwrap_or(0),
688        expert_weights_scale: get_f32("qwen3.expert_weights_scale").unwrap_or(1.0),
689    })
690}