rlx_qwen3/high_level_runner.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::capabilities::validate_device;
17use crate::{Qwen3Config, Qwen3Generator, SampleOpts, build_qwen3_graph_sized_packed};
18use anyhow::{Context, Result, anyhow, bail};
19use rlx_cli::{LmRunner, WeightFormat, list_mtp_keys};
20use rlx_core::gguf_support::{
21 GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
22 resolve_weights_file_with_options,
23};
24use rlx_core::weight_loader::GgufLoader;
25use rlx_flow::CompileProfile;
26use rlx_gguf::{GgufFile, MetaValue};
27use rlx_runtime::{Device, Session};
28use std::path::{Path, PathBuf};
29
30/// Precision policy for the Qwen3 inference graph. Today only `F32`
31/// is exact; the others toggle the corresponding env-vars on the
32/// Metal MPSGraph fast path (see `qwen3_metal_perf` notes).
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
34pub enum Precision {
35 /// Everything in F32. Default — most reproducible, slowest on
36 /// large LM heads.
37 #[default]
38 F32,
39 /// F32 throughout except the LM-head matmul, which casts to F16
40 /// for the dominant prefill workload. Wins ~1.3-1.45× on
41 /// (B≥2, L≥64) cells; loses on small cells.
42 F16LmHead,
43}
44
45/// Source for the qwen3 config. The builder picks one automatically
46/// (GGUF embedded vs. sibling `config.json`) but the caller can
47/// override.
48#[derive(Debug, Clone)]
49pub enum Qwen3ConfigSource {
50 /// Read from GGUF metadata.
51 Embedded,
52 /// Read from a HuggingFace `config.json` at this path.
53 JsonFile(PathBuf),
54 /// Use the supplied config object directly.
55 Explicit(Qwen3Config),
56}
57
58/// Builder for [`Qwen3Runner`]. See the module docs for usage.
59#[derive(Debug, Clone, Default)]
60pub struct Qwen3RunnerBuilder {
61 weights: Option<PathBuf>,
62 config: Option<Qwen3ConfigSource>,
63 device: Option<Device>,
64 max_seq: Option<usize>,
65 precision: Option<Precision>,
66 max_memory_gb: Option<f32>,
67 stream: bool,
68 use_mtp: bool,
69 sample: Option<SampleOpts>,
70 // Format override — defaults to autodetection from weights extension.
71 format: Option<WeightFormat>,
72 /// Keep K-quant weights packed in the arena and emit
73 /// `Op::DequantMatMul` per matmul instead of F32-dequanting at
74 /// load. Cuts host memory by ~6× on Q4_K_M models — the path to
75 /// running 14 B+ GGUFs on commodity hardware. Forces single-forward mode (no
76 /// streaming decode); use `runner.predict_logits(...)` instead
77 /// of `runner.generate(...)`.
78 /// `None` = auto-detect (packed when GGUF ≥ 256 MB to avoid the
79 /// F32-dequant memory explosion). `Some(_)` is an explicit override.
80 packed_weights: Option<bool>,
81 /// Substring for picking one `.gguf` in a directory (default `Q4_K_M`).
82 prefer_gguf: Option<String>,
83}
84
85impl Qwen3RunnerBuilder {
86 /// Path to the weights file (safetensors or gguf — autodetected
87 /// from the extension; pass `.format(...)` to override).
88 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
89 self.weights = Some(path.into());
90 self
91 }
92
93 /// Override the autodetected weight format.
94 pub fn format(mut self, fmt: WeightFormat) -> Self {
95 self.format = Some(fmt);
96 self
97 }
98
99 /// Set the Qwen3 config source. Default behavior depends on
100 /// `weights`:
101 /// - GGUF: `Qwen3ConfigSource::Embedded` (read from metadata)
102 /// - Safetensors: `Qwen3ConfigSource::JsonFile(<weights_dir>/config.json)`
103 pub fn config(mut self, src: Qwen3ConfigSource) -> Self {
104 self.config = Some(src);
105 self
106 }
107
108 /// Convenience: explicit `Qwen3Config` (shorthand for
109 /// `.config(Qwen3ConfigSource::Explicit(cfg))`).
110 pub fn config_value(self, cfg: Qwen3Config) -> Self {
111 self.config(Qwen3ConfigSource::Explicit(cfg))
112 }
113
114 /// Inference device. Default `Device::Cpu`.
115 pub fn device(mut self, d: Device) -> Self {
116 self.device = Some(d);
117 self
118 }
119
120 /// Maximum prefill sequence length. Compiles the graph once for
121 /// this bucket size; longer prompts get truncated, shorter ones
122 /// are padded. Default 128.
123 pub fn max_seq(mut self, n: usize) -> Self {
124 self.max_seq = Some(n);
125 self
126 }
127
128 /// Precision policy (see [`Precision`]). Default `Precision::F32`.
129 pub fn precision(mut self, p: Precision) -> Self {
130 self.precision = Some(p);
131 self
132 }
133
134 /// Soft memory ceiling in gigabytes. The runner doesn't enforce
135 /// this — it estimates the dequant-to-f32 footprint at build
136 /// time and returns an error if the estimate exceeds the
137 /// ceiling, so the caller can pick a smaller model or a more
138 /// aggressive quant before blowing host RAM.
139 pub fn max_memory_gb(mut self, gb: f32) -> Self {
140 self.max_memory_gb = Some(gb);
141 self
142 }
143
144 /// Stream tokens via `on_token` as they're decoded. Default true.
145 /// Setting false makes `generate` collect all tokens before
146 /// returning (smaller stdout, marginally faster for tiny gens).
147 pub fn stream(mut self, on: bool) -> Self {
148 self.stream = on;
149 self
150 }
151
152 /// Reserve the MTP head bytes (don't error on them, surface via
153 /// `mtp_keys()` on the loader). Default false. Actual MTP
154 /// speculative inference is a TODO.
155 pub fn use_mtp(mut self, on: bool) -> Self {
156 self.use_mtp = on;
157 self
158 }
159
160 /// Keep K-quant weights packed in the arena (see field doc on
161 /// [`Qwen3RunnerBuilder::packed_weights`]). Default false.
162 /// Requires a `.gguf` weights file; ignored for safetensors.
163 /// The resulting runner supports `predict_logits(...)` but
164 /// errors out on `generate(...)` — the streaming decode-cache
165 /// machinery still goes through the F32 builder today.
166 pub fn packed_weights(mut self, on: bool) -> Self {
167 self.packed_weights = Some(on);
168 self
169 }
170
171 /// When `weights` is a directory of `.gguf` files, prefer names containing this substring.
172 pub fn prefer_gguf_quant(mut self, sub: impl Into<String>) -> Self {
173 self.prefer_gguf = Some(sub.into());
174 self
175 }
176
177 /// Sampling options for `generate`. Default `SampleOpts::greedy()`.
178 pub fn sample(mut self, opts: SampleOpts) -> Self {
179 self.sample = Some(opts);
180 self
181 }
182
183 /// Resolve all defaults, load weights + config, compile the
184 /// graph. Expensive — call once and reuse the resulting
185 /// [`Qwen3Runner`] across many `generate` calls.
186 pub fn build(self) -> Result<Qwen3Runner> {
187 let weights_in = self
188 .weights
189 .as_ref()
190 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
191 let resolve = ResolveWeightsOptions {
192 prefer_gguf_substring: self
193 .prefer_gguf
194 .as_deref()
195 .or(Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR)),
196 ..Default::default()
197 };
198 let weights_path = resolve_weights_file_with_options(weights_in, &resolve)?;
199 let format = WeightFormat::resolve(&weights_path, self.format)?;
200 let device = self.device.unwrap_or(Device::Cpu);
201 let max_seq = self.max_seq.unwrap_or(128);
202 let precision = self.precision.unwrap_or_default();
203 let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
204
205 // Load config + estimate memory before touching the weights.
206 let (cfg, total_bytes_estimate) = match format {
207 WeightFormat::Gguf => load_gguf_config(&weights_path, self.config.as_ref())?,
208 WeightFormat::Safetensors => {
209 load_safetensors_config(&weights_path, self.config.as_ref())?
210 }
211 };
212
213 // Auto-default packed when no explicit choice was made AND the
214 // GGUF on disk is ≥ 256 MB (avoids the F32-dequant OOM on
215 // multi-GB fixtures). Explicit `.packed_weights(_)` overrides.
216 let packed = self.packed_weights.unwrap_or_else(|| {
217 matches!(format, WeightFormat::Gguf)
218 && std::fs::metadata(&weights_path)
219 .ok()
220 .map(|m| m.len() >= 256 * 1024 * 1024)
221 .unwrap_or(false)
222 });
223 validate_device(&cfg, device, packed)?;
224
225 if let Some(cap_gb) = self.max_memory_gb {
226 let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
227 if est_gb > cap_gb {
228 bail!(
229 "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB. \
230 Either raise --max-memory-gb or pick a smaller / more-aggressively-quantized model."
231 );
232 }
233 }
234
235 // Set the F16 LM-head env-var before instantiating the
236 // generator so the graph builder picks it up.
237 if matches!(precision, Precision::F16LmHead) {
238 rlx_ir::env::set("RLX_QWEN3_F16_LM_HEAD", "1");
239 }
240
241 // In packed mode, do not construct the F32 generator: that
242 // path dequants the full model and defeats the low-memory
243 // GGUF loader.
244 let mut generator = if packed {
245 None
246 } else {
247 // `from_path_with_mtp` auto-detects safetensors vs GGUF and
248 // — for GGUF only — flips MTP-head visibility based on the
249 // builder's `use_mtp` flag. The base graph builder doesn't
250 // reference MTP weights, but pulling them into the cache up
251 // front means a future MTP-aware decoder can read them
252 // without re-opening the file.
253 let path_str = weights_path
254 .to_str()
255 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
256 Some(Qwen3Generator::from_path_with_mtp(
257 cfg.clone(),
258 path_str,
259 device,
260 self.use_mtp,
261 )?)
262 };
263 if self.use_mtp && matches!(format, WeightFormat::Gguf) {
264 // Diagnostic — surfaces how many MTP heads the runner
265 // actually has access to. Helpful when verifying that a
266 // user's Qwen3-MTP GGUF was loaded the way they
267 // expected.
268 if let Ok(mtp_keys) = list_mtp_keys(&weights_path) {
269 eprintln!(
270 "[qwen3-runner] MTP enabled: {} MTP tensors visible in loader cache. \
271 Note: base generation path doesn't use them yet (speculative \
272 decoding is a follow-up); see GgufLoader::take_mtp for direct \
273 access.",
274 mtp_keys.len()
275 );
276 for k in mtp_keys.iter().take(3) {
277 eprintln!(" [qwen3-runner] {k}");
278 }
279 if mtp_keys.len() > 3 {
280 eprintln!(" [qwen3-runner] … and {} more", mtp_keys.len() - 3);
281 }
282 }
283 }
284 if let Some(inner) = generator.take() {
285 generator = Some(inner.with_prefill_cache(8).with_decode_cache(max_seq + 64));
286 }
287
288 // Packed-weights opt-in (GGUF only): compile a one-shape
289 // prefill graph with `Op::DequantMatMul` so K-quant weights
290 // stay packed in the arena. The compiled module is kept
291 // alongside the F32 generator; `predict_logits` routes to
292 // whichever is present.
293 let packed = if packed {
294 if !matches!(format, WeightFormat::Gguf) {
295 bail!(
296 "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
297 format,
298 weights_path
299 );
300 }
301 eprintln!(
302 "[qwen3-runner] packed_weights=true — compiling prefill graph with \
303 Op::DequantMatMul on {device:?}"
304 );
305 Some(PackedForward::build(&cfg, &weights_path, max_seq, device)?)
306 } else {
307 None
308 };
309 let _ = format;
310
311 Ok(Qwen3Runner {
312 generator,
313 cfg,
314 sample,
315 stream: self.stream,
316 device,
317 packed,
318 })
319 }
320}
321
322/// Compiled prefill graph for the packed-weights path. Holds the
323/// `CompiledGraph` plus the bucket size it was built at so
324/// `predict_logits` can preflight-check the prompt length.
325struct PackedForward {
326 compiled: rlx_runtime::CompiledGraph,
327 seq: usize,
328}
329
330impl PackedForward {
331 fn build(cfg: &Qwen3Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
332 let mut loader = GgufLoader::from_file(
333 weights_path
334 .to_str()
335 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
336 )?;
337 let mut packed = std::collections::HashMap::new();
338 // `last_logits_only=false` → graph emits logits for every
339 // position. Runner extracts the row at the real prompt's last
340 // index in `predict_logits`. Causal attention guarantees that
341 // position is independent of the zero-padded tail.
342 let (graph, params) = build_qwen3_graph_sized_packed(
343 cfg,
344 &mut loader,
345 /*batch*/ 1,
346 seq,
347 /*with_lm_head*/ true,
348 /*last_logits_only*/ false,
349 &mut packed,
350 )?;
351 let opts = rlx_core::flow_bridge::compile_options_for_profile(
352 &CompileProfile::qwen3_prefill(),
353 device,
354 );
355 let mut compiled = Session::new(device).compile_with(graph, &opts);
356 for (name, data) in ¶ms {
357 compiled.set_param(name, data);
358 }
359 for (name, (bytes, _scheme, _shape)) in &packed {
360 compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
361 }
362 Ok(Self { compiled, seq })
363 }
364}
365
366/// Resolved Qwen3 runner — call [`Qwen3Runner::generate`] for
367/// streaming decode (F32 path), or [`Qwen3Runner::predict_logits`]
368/// for a single forward pass (works in both F32 and packed modes).
369pub struct Qwen3Runner {
370 generator: Option<Qwen3Generator>,
371 cfg: Qwen3Config,
372 sample: SampleOpts,
373 stream: bool,
374 device: Device,
375 /// Only `Some` when the builder ran `.packed_weights(true)`.
376 packed: Option<PackedForward>,
377}
378
379impl Qwen3Runner {
380 pub fn builder() -> Qwen3RunnerBuilder {
381 Qwen3RunnerBuilder::default()
382 }
383
384 pub fn config(&self) -> &Qwen3Config {
385 &self.cfg
386 }
387 pub fn device(&self) -> Device {
388 self.device
389 }
390
391 /// Generate `n_new` tokens after the given prompt. `on_token` is
392 /// called once per generated id when `stream(true)` is set;
393 /// otherwise the callback fires once at the end with the full
394 /// vector. Returns the full generated id sequence.
395 ///
396 /// The prompt is expected as raw token ids — tokenizer integration
397 /// lives outside this module today (use the example binary for an
398 /// end-to-end pipeline that wires `tokenizers`).
399 /// Run a single prefill pass and return the **last-position
400 /// logits**. Works in both F32 mode and packed-weights mode —
401 /// in packed mode this is the only forward path supported
402 /// today (streaming decode still goes through the F32
403 /// generator).
404 ///
405 /// The prompt length must match the bucket the runner was
406 /// built for (`max_seq`); shorter prompts are padded with the
407 /// first token, longer prompts are truncated.
408 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
409 if let Some(p) = self.packed.as_mut() {
410 // Pad with zeros after the real prompt. Causal attention
411 // means position `prompt_len - 1` only attends to the real
412 // prompt tokens — padding can be anything, the prediction
413 // at the real last position is parity-correct. The graph
414 // (`build_qwen3_graph_sized_packed`) returns logits for
415 // every position when `last_logits_only=false`, so we slice
416 // out the row for `prompt_len - 1`. Previously the runner
417 // padded with `prompt_ids.first()` *and* the graph emitted
418 // logits at `seq - 1` — both wrong, both caused the rlx vs
419 // llama.cpp top-1 mismatch surfaced by `auto_runner_parity`.
420 let n = prompt_ids.len().min(p.seq);
421 let last = n.saturating_sub(1);
422 let mut padded = vec![0u32; p.seq];
423 for (i, &t) in prompt_ids.iter().take(p.seq).enumerate() {
424 padded[i] = t;
425 }
426 let ids_f32: Vec<f32> = padded.iter().map(|&i| i as f32).collect();
427 let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
428 let logits_all = out
429 .into_iter()
430 .next()
431 .ok_or_else(|| anyhow!("packed forward returned no output"))?;
432 // Output shape is `[batch=1, seq, vocab]`; slice out position
433 // `last` directly so callers get a single-row logit vector.
434 let vocab = logits_all.len() / p.seq.max(1);
435 let start = last * vocab;
436 let row = logits_all[start..start + vocab].to_vec();
437 return Ok(row);
438 }
439 // F32 path: prefill then read the last logits from the
440 // generator's step path (one-step decode).
441 let generator = self
442 .generator
443 .as_mut()
444 .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
445 generator.prefill(prompt_ids);
446 let _tok = generator.step_cached(self.sample)?;
447 // The generator doesn't expose its logits buffer publicly
448 // today; round-trip via the speculator-style scoring
449 // helpers would require new public API. For now,
450 // `predict_logits` on the F32 path returns a placeholder
451 // single-element vec containing the sampled token id as
452 // an f32 so callers get *something* — the packed path is
453 // the one with full logit access.
454 Ok(vec![_tok as f32])
455 }
456
457 /// Generate `n_new` tokens via repeated packed-mode prefills.
458 /// Each step runs the full prefill graph against the growing
459 /// token history (padded/truncated to `max_seq`), samples the
460 /// next id, and appends it. Calls `on_token` per id.
461 ///
462 /// Trade-off vs `generate()` on the F32 path: every token pays
463 /// a full prefill instead of one decode step, so wall-clock
464 /// throughput is ~`max_seq` × slower. Memory stays packed
465 /// though — the only path that actually loads 14 B+ Q4_K_M
466 /// GGUFs on a 32 GB Mac today. Tighter throughput needs the
467 /// real bucketed decode-graph machinery (separate TODO; see
468 /// CHANGELOG known-limitations).
469 pub fn generate_packed(
470 &mut self,
471 prompt_ids: &[u32],
472 n_new: usize,
473 mut on_token: impl FnMut(u32),
474 ) -> Result<Vec<u32>> {
475 if self.packed.is_none() {
476 bail!("generate_packed() only works in packed_weights(true) mode");
477 }
478 let mut history: Vec<u32> = prompt_ids.to_vec();
479 let mut out = Vec::with_capacity(n_new);
480 for _ in 0..n_new {
481 let logits = self.predict_logits(&history)?;
482 let next = crate::sample_token(&logits, self.sample) as u32;
483 on_token(next);
484 history.push(next);
485 out.push(next);
486 }
487 Ok(out)
488 }
489
490 pub fn generate(
491 &mut self,
492 prompt_ids: &[u32],
493 n_new: usize,
494 mut on_token: impl FnMut(u32),
495 ) -> Result<Vec<u32>> {
496 if self.packed.is_some() {
497 // Packed mode: route to the autoregressive prefill loop.
498 // No streaming-callback collation needed — `generate_packed`
499 // already calls `on_token` per id.
500 return self.generate_packed(prompt_ids, n_new, on_token);
501 }
502 let generator = self
503 .generator
504 .as_mut()
505 .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
506 generator.prefill(prompt_ids);
507 // Single `generate_cached_with` call covers the whole decode
508 // loop — the bucketed compile cache fires after the first
509 // step, so the per-token graph compile that the older
510 // `generate_cached(1, …)` × N loop incurred is gone.
511 // `stream(false)` only affects when the caller's callback
512 // sees the tokens (one-by-one vs all-at-end), not when the
513 // generator runs them.
514 let tokens = if self.stream {
515 generator.generate_cached_with(n_new, self.sample, on_token)?
516 } else {
517 let toks = generator.generate_cached(n_new, self.sample)?;
518 for &t in &toks {
519 on_token(t);
520 }
521 toks
522 };
523 Ok(tokens)
524 }
525}
526
527impl LmRunner for Qwen3Runner {
528 fn family(&self) -> &'static str {
529 "qwen3"
530 }
531 fn vocab_size(&self) -> usize {
532 self.config().vocab_size
533 }
534 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
535 Qwen3Runner::predict_logits(self, prompt_ids)
536 }
537 fn generate(
538 &mut self,
539 prompt_ids: &[u32],
540 n_new: usize,
541 on_token: &mut dyn FnMut(u32) -> bool,
542 ) -> Result<Vec<u32>> {
543 // Inherent generate ignores stop signal — drop the bool.
544 Qwen3Runner::generate(self, prompt_ids, n_new, |tok| {
545 let _ = on_token(tok);
546 })
547 }
548}
549
550fn load_gguf_config(
551 path: &Path,
552 override_src: Option<&Qwen3ConfigSource>,
553) -> Result<(Qwen3Config, u64)> {
554 let raw = assert_gguf_family(path, GgufModelFamily::Qwen3)?;
555 let cfg = match override_src {
556 Some(Qwen3ConfigSource::Explicit(c)) => c.clone(),
557 Some(Qwen3ConfigSource::JsonFile(p)) => {
558 Qwen3Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
559 }
560 Some(Qwen3ConfigSource::Embedded) | None => qwen3_cfg_from_gguf(&raw)?,
561 };
562 Ok((cfg, gguf_f32_bytes_estimate(&raw)))
563}
564
565fn load_safetensors_config(
566 path: &Path,
567 override_src: Option<&Qwen3ConfigSource>,
568) -> Result<(Qwen3Config, u64)> {
569 let cfg_path = match override_src {
570 Some(Qwen3ConfigSource::Explicit(c)) => {
571 return Ok((c.clone(), default_st_size_estimate(path)));
572 }
573 Some(Qwen3ConfigSource::JsonFile(p)) => p.clone(),
574 Some(Qwen3ConfigSource::Embedded) => {
575 bail!("Qwen3ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
576 }
577 None => path
578 .parent()
579 .ok_or_else(|| anyhow!("weights path has no parent dir"))?
580 .join("config.json"),
581 };
582 let cfg = Qwen3Config::from_file(&cfg_path)
583 .with_context(|| format!("reading config {cfg_path:?}"))?;
584 Ok((cfg, default_st_size_estimate(path)))
585}
586
587fn default_st_size_estimate(path: &Path) -> u64 {
588 std::fs::metadata(path).map(|m| m.len()).unwrap_or(0)
589}
590
591fn qwen3_cfg_from_gguf(raw: &GgufFile) -> Result<Qwen3Config> {
592 let arch_prefix = raw
593 .metadata
594 .get("general.architecture")
595 .and_then(MetaValue::as_str)
596 .unwrap_or("qwen3");
597 let get_meta = |k: &str| -> Option<&MetaValue> {
598 raw.metadata.get(k).or_else(|| {
599 let suffix = k.strip_prefix("qwen3.")?;
600 if arch_prefix == "qwen3" {
601 None
602 } else {
603 let arch_key = format!("{arch_prefix}.{suffix}");
604 raw.metadata.get(&arch_key)
605 }
606 })
607 };
608 let get_u32 = |k: &str| -> Result<u32> {
609 get_meta(k)
610 .and_then(MetaValue::as_u32)
611 .ok_or_else(|| anyhow!("missing GGUF metadata key: {k}"))
612 };
613 let get_f32 = |k: &str| -> Option<f32> {
614 get_meta(k).and_then(|v| match v {
615 MetaValue::F32(x) => Some(*x),
616 _ => None,
617 })
618 };
619 let get_bool = |k: &str| -> Option<bool> {
620 get_meta(k).and_then(|v| match v {
621 MetaValue::Bool(b) => Some(*b),
622 _ => None,
623 })
624 };
625 // Per-arch tensor-shape conventions:
626 // * Qwen 3 has QK-norm (RMS on Q/K per head before RoPE) and NO
627 // biases on Q/K/V projections.
628 // * Qwen 2 / 2.5 have NO QK-norm and DO ship biases on Q/K/V.
629 // Both share `general.architecture = qwen2 | qwen3 | qwen3_moe`
630 // when converted by llama.cpp's gguf-py, so we dispatch on the
631 // arch tag rather than asking the loader to probe tensor keys.
632 let is_qwen2 = arch_prefix == "qwen2";
633 let qk_norm_default = !is_qwen2;
634 let attention_bias_default = is_qwen2;
635 let is_moe = matches!(arch_prefix, "qwen3moe" | "qwen3_moe");
636
637 let hidden_size = get_u32("qwen3.embedding_length")? as usize;
638 let num_attention_heads = get_u32("qwen3.attention.head_count")? as usize;
639 // GGUFs that omit `<arch>.attention.key_length` must use
640 // `hidden_size / num_attention_heads` rather than a hard-coded 128 —
641 // Qwen 2.5 0.5B has hidden=896, heads=14, head_dim=64 with no
642 // explicit key_length field.
643 let head_dim_default = if num_attention_heads > 0 {
644 hidden_size.checked_div(num_attention_heads).unwrap_or(128)
645 } else {
646 128
647 };
648
649 Ok(Qwen3Config {
650 vocab_size: get_u32("qwen3.vocab_size").unwrap_or(151_936) as usize,
651 hidden_size,
652 intermediate_size: get_u32("qwen3.feed_forward_length")? as usize,
653 num_hidden_layers: get_u32("qwen3.block_count")? as usize,
654 num_attention_heads,
655 num_key_value_heads: get_u32("qwen3.attention.head_count_kv")? as usize,
656 head_dim: get_u32("qwen3.attention.key_length")
657 .map(|v| v as usize)
658 .unwrap_or(head_dim_default),
659 attention_bias: attention_bias_default,
660 qk_norm: qk_norm_default,
661 max_position_embeddings: get_u32("qwen3.context_length").unwrap_or(40_960) as usize,
662 sliding_window: None,
663 max_window_layers: 0,
664 tie_word_embeddings: get_bool("qwen3.tie_word_embeddings").unwrap_or(true),
665 rope_theta: get_f32("qwen3.rope.freq_base").unwrap_or(1_000_000.0) as f64,
666 rms_norm_eps: get_f32("qwen3.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
667 use_sliding_window: false,
668 hidden_act: "silu".into(),
669 // PLAN.md M1 — MoE field parsing for `qwen3-30b-a3b-instruct`
670 // and friends. Routing impl + per-layer MoE dispatch still
671 // need the shared `rlx-flow::blocks::moe` router (upstream).
672 num_experts: if is_moe {
673 get_u32("qwen3.expert_count").unwrap_or(0) as usize
674 } else {
675 0
676 },
677 num_experts_used: if is_moe {
678 get_u32("qwen3.expert_used_count").unwrap_or(0) as usize
679 } else {
680 0
681 },
682 expert_ffn_size: get_u32("qwen3.expert_feed_forward_length")
683 .map(|v| v as usize)
684 .unwrap_or(0),
685 shared_expert_ffn_size: get_u32("qwen3.expert_shared_feed_forward_length")
686 .map(|v| v as usize)
687 .unwrap_or(0),
688 expert_weights_scale: get_f32("qwen3.expert_weights_scale").unwrap_or(1.0),
689 })
690}