1use crate::capabilities::validate_device;
17use crate::{Qwen3Config, Qwen3Generator, SampleOpts, build_qwen3_graph_sized_packed};
18use anyhow::{Context, Result, anyhow, bail};
19use rlx_cli::{LmRunner, WeightFormat, list_mtp_keys};
20use rlx_core::gguf_support::{
21 GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
22 resolve_weights_file_with_options,
23};
24use rlx_core::weight_loader::GgufLoader;
25use rlx_flow::CompileProfile;
26use rlx_gguf::{GgufFile, MetaValue};
27use rlx_runtime::{Device, Session};
28use std::path::{Path, PathBuf};
29
/// Precision choice accepted by [`Qwen3RunnerBuilder::precision`].
///
/// The builder falls back to the `#[default]` variant when the caller does
/// not pick one.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Precision {
    /// Default variant; `build()` performs no extra precision setup for it.
    #[default]
    F32,
    /// Makes `build()` set the `RLX_QWEN3_F16_LM_HEAD` flag to `"1"` through
    /// `rlx_ir::env::set` before the model is constructed.
    F16LmHead,
}
44
/// Where the runner obtains its [`Qwen3Config`]: `rlx_runtime::ConfigSource`
/// specialised to the Qwen3 config type. The loaders in this file match on
/// its `Explicit`, `JsonFile` and `Embedded` variants.
pub type Qwen3ConfigSource = rlx_runtime::ConfigSource<Qwen3Config>;
55
/// Collects the options for a [`Qwen3Runner`]; every field starts unset and
/// `build()` substitutes the defaults described on each field.
#[derive(Debug, Clone, Default)]
pub struct Qwen3RunnerBuilder {
    /// Weights location (required). `build()` errors when it is missing and
    /// otherwise resolves it with `resolve_weights_file_with_options`.
    weights: Option<PathBuf>,
    /// Optional config override. When absent, GGUF files use their embedded
    /// metadata and safetensors use `config.json` next to the weights.
    config: Option<Qwen3ConfigSource>,
    /// Target device; `build()` falls back to `Device::Cpu`.
    device: Option<Device>,
    /// Sequence budget; `build()` falls back to 128. Used for the packed
    /// prefill graph size and (plus 64) for the decode cache.
    max_seq: Option<usize>,
    /// Precision policy; `build()` falls back to `Precision::default()`.
    precision: Option<Precision>,
    /// Optional cap, in GiB, compared against the estimated weight bytes;
    /// `build()` bails when the estimate exceeds it.
    max_memory_gb: Option<f32>,
    /// In non-packed mode the generation callback is only invoked when this
    /// is `true`.
    stream: bool,
    /// Forwarded to `Qwen3Generator::from_path_with_mtp`; for GGUF weights
    /// `build()` also logs the MTP tensor keys it can list.
    use_mtp: bool,
    /// Sampling options; `build()` falls back to `SampleOpts::greedy()`.
    sample: Option<SampleOpts>,
    /// Explicit weight-format hint handed to `WeightFormat::resolve`.
    format: Option<WeightFormat>,
    /// Packed-weights switch. When unset, `build()` enables it for GGUF
    /// files of at least 256 MiB.
    packed_weights: Option<bool>,
    /// Preferred GGUF file-name substring; `build()` falls back to
    /// `rlx_core::DEFAULT_GGUF_PREFER_SUBSTR`.
    prefer_gguf: Option<String>,
}
82
83impl Qwen3RunnerBuilder {
84 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
87 self.weights = Some(path.into());
88 self
89 }
90
91 pub fn format(mut self, fmt: WeightFormat) -> Self {
93 self.format = Some(fmt);
94 self
95 }
96
97 pub fn config(mut self, src: Qwen3ConfigSource) -> Self {
102 self.config = Some(src);
103 self
104 }
105
106 pub fn config_value(self, cfg: Qwen3Config) -> Self {
109 self.config(Qwen3ConfigSource::Explicit(cfg))
110 }
111
112 pub fn device(mut self, d: Device) -> Self {
114 self.device = Some(d);
115 self
116 }
117
118 pub fn max_seq(mut self, n: usize) -> Self {
122 self.max_seq = Some(n);
123 self
124 }
125
126 pub fn precision(mut self, p: Precision) -> Self {
128 self.precision = Some(p);
129 self
130 }
131
132 pub fn max_memory_gb(mut self, gb: f32) -> Self {
138 self.max_memory_gb = Some(gb);
139 self
140 }
141
142 pub fn stream(mut self, on: bool) -> Self {
146 self.stream = on;
147 self
148 }
149
150 pub fn use_mtp(mut self, on: bool) -> Self {
154 self.use_mtp = on;
155 self
156 }
157
158 pub fn packed_weights(mut self, on: bool) -> Self {
165 self.packed_weights = Some(on);
166 self
167 }
168
169 pub fn prefer_gguf_quant(mut self, sub: impl Into<String>) -> Self {
171 self.prefer_gguf = Some(sub.into());
172 self
173 }
174
175 pub fn sample(mut self, opts: SampleOpts) -> Self {
177 self.sample = Some(opts);
178 self
179 }
180
181 pub fn build(self) -> Result<Qwen3Runner> {
185 let weights_in = self
186 .weights
187 .as_ref()
188 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
189 let resolve = ResolveWeightsOptions {
190 prefer_gguf_substring: self
191 .prefer_gguf
192 .as_deref()
193 .or(Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR)),
194 ..Default::default()
195 };
196 let weights_path = resolve_weights_file_with_options(weights_in, &resolve)?;
197 let format = WeightFormat::resolve(&weights_path, self.format)?;
198 let device = self.device.unwrap_or(Device::Cpu);
199 let max_seq = self.max_seq.unwrap_or(128);
200 let precision = self.precision.unwrap_or_default();
201 let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
202
203 let (cfg, total_bytes_estimate) = match format {
205 WeightFormat::Gguf => load_gguf_config(&weights_path, self.config.as_ref())?,
206 WeightFormat::Safetensors => {
207 load_safetensors_config(&weights_path, self.config.as_ref())?
208 }
209 };
210
211 let packed = self.packed_weights.unwrap_or_else(|| {
215 matches!(format, WeightFormat::Gguf)
216 && std::fs::metadata(&weights_path)
217 .ok()
218 .map(|m| m.len() >= 256 * 1024 * 1024)
219 .unwrap_or(false)
220 });
221 validate_device(&cfg, device, packed)?;
222
223 if let Some(cap_gb) = self.max_memory_gb {
224 let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
225 if est_gb > cap_gb {
226 bail!(
227 "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB. \
228 Either raise --max-memory-gb or pick a smaller / more-aggressively-quantized model."
229 );
230 }
231 }
232
233 if matches!(precision, Precision::F16LmHead) {
236 rlx_ir::env::set("RLX_QWEN3_F16_LM_HEAD", "1");
237 }
238
239 let mut generator = if packed {
243 None
244 } else {
245 let path_str = weights_path
252 .to_str()
253 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
254 Some(Qwen3Generator::from_path_with_mtp(
255 cfg.clone(),
256 path_str,
257 device,
258 self.use_mtp,
259 )?)
260 };
261 if self.use_mtp && matches!(format, WeightFormat::Gguf) {
262 if let Ok(mtp_keys) = list_mtp_keys(&weights_path) {
267 eprintln!(
268 "[qwen3-runner] MTP enabled: {} MTP tensors visible in loader cache. \
269 Note: base generation path doesn't use them yet (speculative \
270 decoding is a follow-up); see GgufLoader::take_mtp for direct \
271 access.",
272 mtp_keys.len()
273 );
274 for k in mtp_keys.iter().take(3) {
275 eprintln!(" [qwen3-runner] {k}");
276 }
277 if mtp_keys.len() > 3 {
278 eprintln!(" [qwen3-runner] … and {} more", mtp_keys.len() - 3);
279 }
280 }
281 }
282 if let Some(inner) = generator.take() {
283 generator = Some(inner.with_prefill_cache(8).with_decode_cache(max_seq + 64));
284 }
285
286 let packed = if packed {
292 if !matches!(format, WeightFormat::Gguf) {
293 bail!(
294 "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
295 format,
296 weights_path
297 );
298 }
299 eprintln!(
300 "[qwen3-runner] packed_weights=true — compiling prefill graph with \
301 Op::DequantMatMul on {device:?}"
302 );
303 Some(PackedForward::build(&cfg, &weights_path, max_seq, device)?)
304 } else {
305 None
306 };
307 let _ = format;
308
309 Ok(Qwen3Runner {
310 generator,
311 cfg,
312 sample,
313 stream: self.stream,
314 device,
315 packed,
316 })
317 }
318}
319
/// State for the packed-weights prefill path: the compiled graph plus the
/// scratch buffers that are refilled on every `predict_logits` call.
struct PackedForward {
    /// Graph compiled by `PackedForward::build`, with parameters already set.
    compiled: rlx_runtime::CompiledGraph,
    /// Sequence length the graph was built for; longer prompts are cut to it.
    seq: usize,
    /// Prompt token ids, zero-padded to `seq` entries.
    padded_ids: Vec<u32>,
    /// `padded_ids` converted to `f32`; passed as the `"input_ids"` input.
    ids_f32: Vec<f32>,
    /// Index of the last real prompt token as `f32`; passed as the
    /// `"last_token_idx"` input.
    last_idx: [f32; 1],
}
334
335impl PackedForward {
336 fn build(cfg: &Qwen3Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
337 let exec_device = rlx_core::flow_bridge::packed_gguf_execution_device(device);
338 if exec_device != device {
339 eprintln!(
340 "[qwen3-runner] packed GGUF on {device:?}: prefill executes on {exec_device:?} \
341 until {device:?} packed parity is fixed upstream"
342 );
343 }
344 let mut loader = GgufLoader::from_file(
345 weights_path
346 .to_str()
347 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
348 )?;
349 let mut packed = std::collections::HashMap::new();
350 let (graph, params) = build_qwen3_graph_sized_packed(
351 cfg,
352 &mut loader,
353 1,
354 seq,
355 true,
356 true,
357 &mut packed,
358 )?;
359 let opts = rlx_core::flow_bridge::compile_options_for_packed_gguf_prefill_with_profile(
360 &CompileProfile::qwen3_prefill(),
361 exec_device,
362 );
363 let mut compiled = rlx_core::flow_bridge::packed_gguf_compile_guard(exec_device, || {
364 Session::new(exec_device).compile_with(graph, &opts)
365 });
366 for (name, data) in ¶ms {
367 compiled.set_param(name, data);
368 }
369 for (name, (bytes, _scheme, _shape)) in &packed {
370 compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
371 }
372 Ok(Self {
373 compiled,
374 seq,
375 padded_ids: vec![0u32; seq],
376 ids_f32: vec![0f32; seq],
377 last_idx: [0f32; 1],
378 })
379 }
380}
381
/// A ready-to-run Qwen3 model, created through [`Qwen3RunnerBuilder::build`].
///
/// `build()` populates exactly one of `generator` and `packed`.
pub struct Qwen3Runner {
    /// Generator used in non-packed mode; `None` in packed mode.
    generator: Option<Qwen3Generator>,
    /// Model configuration the runner was built with.
    cfg: Qwen3Config,
    /// Sampling options applied to every generated token.
    sample: SampleOpts,
    /// When `false`, the non-packed generation path does not invoke the
    /// per-token callback.
    stream: bool,
    /// Device requested at build time.
    device: Device,
    /// Compiled packed prefill state; `None` in non-packed mode.
    packed: Option<PackedForward>,
}
394
395impl Qwen3Runner {
396 pub fn builder() -> Qwen3RunnerBuilder {
397 Qwen3RunnerBuilder::default()
398 }
399
400 pub fn config(&self) -> &Qwen3Config {
401 &self.cfg
402 }
403 pub fn device(&self) -> Device {
404 self.device
405 }
406
407 pub fn disable_decode_compile_cache(&mut self) {
411 if let Some(g) = self.generator.as_mut() {
412 g.set_decode_compile_cache(None);
413 }
414 }
415
416 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
434 if let Some(p) = self.packed.as_mut() {
435 let n = prompt_ids.len().min(p.seq);
436 p.padded_ids.fill(0);
437 for (i, &t) in prompt_ids.iter().take(n).enumerate() {
438 p.padded_ids[i] = t;
439 }
440 for (dst, &id) in p.ids_f32.iter_mut().zip(p.padded_ids.iter()) {
441 *dst = id as f32;
442 }
443 p.last_idx[0] = n.saturating_sub(1) as f32;
444 let exec_device = p.compiled.device();
445 let out = rlx_core::run_packed_prefill(
446 &mut p.compiled,
447 exec_device,
448 n,
449 p.seq,
450 &[
451 ("input_ids", p.ids_f32.as_slice()),
452 ("last_token_idx", p.last_idx.as_slice()),
453 ],
454 );
455 let logits = out
456 .into_iter()
457 .next()
458 .ok_or_else(|| anyhow!("packed forward returned no output"))?;
459 let vocab = self.cfg.vocab_size;
460 if logits.len() < vocab {
461 bail!("logits short: {} < {vocab}", logits.len());
462 }
463 return Ok(logits[..vocab].to_vec());
464 }
465 let generator = self
468 .generator
469 .as_mut()
470 .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
471 generator.prefill(prompt_ids);
472 let _tok = generator.step_cached(self.sample)?;
473 Ok(vec![_tok as f32])
481 }
482
483 pub fn generate_packed(
496 &mut self,
497 prompt_ids: &[u32],
498 n_new: usize,
499 mut on_token: impl FnMut(u32),
500 ) -> Result<Vec<u32>> {
501 if self.packed.is_none() {
502 bail!("generate_packed() only works in packed_weights(true) mode");
503 }
504 let mut history: Vec<u32> = prompt_ids.to_vec();
505 let mut out = Vec::with_capacity(n_new);
506 for _ in 0..n_new {
507 let logits = self.predict_logits(&history)?;
508 let next = crate::sample_token(&logits, self.sample) as u32;
509 on_token(next);
510 history.push(next);
511 out.push(next);
512 }
513 Ok(out)
514 }
515
516 pub fn generate(
517 &mut self,
518 prompt_ids: &[u32],
519 n_new: usize,
520 mut on_token: impl FnMut(u32),
521 ) -> Result<Vec<u32>> {
522 self.generate_stoppable(prompt_ids, n_new, |tok| {
523 on_token(tok);
524 true
525 })
526 }
527
528 pub fn generate_stoppable(
531 &mut self,
532 prompt_ids: &[u32],
533 n_new: usize,
534 mut on_token: impl FnMut(u32) -> bool,
535 ) -> Result<Vec<u32>> {
536 if self.packed.is_some() {
537 return self.generate_packed(prompt_ids, n_new, |tok| {
541 let _ = on_token(tok);
542 });
543 }
544 let generator = self
545 .generator
546 .as_mut()
547 .ok_or_else(|| anyhow!("F32 generator is not available in packed_weights mode"))?;
548 generator.prefill(prompt_ids);
549 let stream = self.stream;
557 generator.generate_cached_until(
558 n_new,
559 self.sample,
560 |tok| {
561 if stream {
562 on_token(tok);
563 }
564 true
565 },
566 |_| {},
567 )
568 }
569}
570
571impl LmRunner for Qwen3Runner {
572 fn family(&self) -> &'static str {
573 "qwen3"
574 }
575 fn vocab_size(&self) -> usize {
576 self.config().vocab_size
577 }
578 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
579 Qwen3Runner::predict_logits(self, prompt_ids)
580 }
581 fn generate(
582 &mut self,
583 prompt_ids: &[u32],
584 n_new: usize,
585 on_token: &mut dyn FnMut(u32) -> bool,
586 ) -> Result<Vec<u32>> {
587 Qwen3Runner::generate(self, prompt_ids, n_new, |tok| {
589 let _ = on_token(tok);
590 })
591 }
592}
593
594fn load_gguf_config(
595 path: &Path,
596 override_src: Option<&Qwen3ConfigSource>,
597) -> Result<(Qwen3Config, u64)> {
598 let raw = assert_gguf_family(path, GgufModelFamily::Qwen3)?;
599 let cfg = match override_src {
600 Some(Qwen3ConfigSource::Explicit(c)) => c.clone(),
601 Some(Qwen3ConfigSource::JsonFile(p)) => {
602 Qwen3Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
603 }
604 Some(Qwen3ConfigSource::Embedded) | None => qwen3_cfg_from_gguf(&raw)?,
605 };
606 Ok((cfg, gguf_f32_bytes_estimate(&raw)))
607}
608
609fn load_safetensors_config(
610 path: &Path,
611 override_src: Option<&Qwen3ConfigSource>,
612) -> Result<(Qwen3Config, u64)> {
613 let cfg_path = match override_src {
614 Some(Qwen3ConfigSource::Explicit(c)) => {
615 return Ok((c.clone(), default_st_size_estimate(path)));
616 }
617 Some(Qwen3ConfigSource::JsonFile(p)) => p.clone(),
618 Some(Qwen3ConfigSource::Embedded) => {
619 bail!("Qwen3ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
620 }
621 None => path
622 .parent()
623 .ok_or_else(|| anyhow!("weights path has no parent dir"))?
624 .join("config.json"),
625 };
626 let cfg = Qwen3Config::from_file(&cfg_path)
627 .with_context(|| format!("reading config {cfg_path:?}"))?;
628 Ok((cfg, default_st_size_estimate(path)))
629}
630
/// Returns the on-disk size of `path` in bytes, or `0` when its metadata
/// cannot be read (best-effort: the error is intentionally dropped).
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}
634
635fn qwen3_cfg_from_gguf(raw: &GgufFile) -> Result<Qwen3Config> {
636 let arch_prefix = raw
637 .metadata
638 .get("general.architecture")
639 .and_then(MetaValue::as_str)
640 .unwrap_or("qwen3");
641 let get_meta = |k: &str| -> Option<&MetaValue> {
642 raw.metadata.get(k).or_else(|| {
643 let suffix = k.strip_prefix("qwen3.")?;
644 if arch_prefix == "qwen3" {
645 None
646 } else {
647 let arch_key = format!("{arch_prefix}.{suffix}");
648 raw.metadata.get(&arch_key)
649 }
650 })
651 };
652 let get_u32 = |k: &str| -> Result<u32> {
653 get_meta(k)
654 .and_then(MetaValue::as_u32)
655 .ok_or_else(|| anyhow!("missing GGUF metadata key: {k}"))
656 };
657 let get_f32 = |k: &str| -> Option<f32> {
658 get_meta(k).and_then(|v| match v {
659 MetaValue::F32(x) => Some(*x),
660 _ => None,
661 })
662 };
663 let get_bool = |k: &str| -> Option<bool> {
664 get_meta(k).and_then(|v| match v {
665 MetaValue::Bool(b) => Some(*b),
666 _ => None,
667 })
668 };
669 let is_qwen2 = arch_prefix == "qwen2";
677 let qk_norm_default = !is_qwen2;
678 let attention_bias_default = is_qwen2;
679 let is_moe = matches!(arch_prefix, "qwen3moe" | "qwen3_moe");
680
681 let hidden_size = get_u32("qwen3.embedding_length")? as usize;
682 let num_attention_heads = get_u32("qwen3.attention.head_count")? as usize;
683 let head_dim_default = if num_attention_heads > 0 {
688 hidden_size.checked_div(num_attention_heads).unwrap_or(128)
689 } else {
690 128
691 };
692
693 Ok(Qwen3Config {
694 vocab_size: get_u32("qwen3.vocab_size").unwrap_or(151_936) as usize,
695 hidden_size,
696 intermediate_size: get_u32("qwen3.feed_forward_length")? as usize,
697 num_hidden_layers: get_u32("qwen3.block_count")? as usize,
698 num_attention_heads,
699 num_key_value_heads: get_u32("qwen3.attention.head_count_kv")? as usize,
700 head_dim: get_u32("qwen3.attention.key_length")
701 .map(|v| v as usize)
702 .unwrap_or(head_dim_default),
703 attention_bias: attention_bias_default,
704 qk_norm: qk_norm_default,
705 max_position_embeddings: get_u32("qwen3.context_length").unwrap_or(40_960) as usize,
706 sliding_window: None,
707 max_window_layers: 0,
708 tie_word_embeddings: get_bool("qwen3.tie_word_embeddings").unwrap_or(true),
709 rope_theta: get_f32("qwen3.rope.freq_base").unwrap_or(1_000_000.0) as f64,
710 rms_norm_eps: get_f32("qwen3.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
711 use_sliding_window: false,
712 hidden_act: "silu".into(),
713 num_experts: if is_moe {
717 get_u32("qwen3.expert_count").unwrap_or(0) as usize
718 } else {
719 0
720 },
721 num_experts_used: if is_moe {
722 get_u32("qwen3.expert_used_count").unwrap_or(0) as usize
723 } else {
724 0
725 },
726 expert_ffn_size: get_u32("qwen3.expert_feed_forward_length")
727 .map(|v| v as usize)
728 .unwrap_or(0),
729 shared_expert_ffn_size: get_u32("qwen3.expert_shared_feed_forward_length")
730 .map(|v| v as usize)
731 .unwrap_or(0),
732 expert_weights_scale: get_f32("qwen3.expert_weights_scale").unwrap_or(1.0),
733 })
734}