1use crate::cache::{
19 Qwen35DecodeCache, advance_cache_from_decode_outputs, decode_step_feeds, last_token_indices,
20 pack_input_ids, seed_cache_from_outputs, zero_prompt_padding_kv, zero_recurrent_inputs,
21};
22use crate::capabilities::validate_device;
23use crate::config::Qwen35Config;
24use crate::encode_prompt_auto;
25use crate::lm_head::{
26 greedy_lm_head_argmax, lm_head_logits_row, sample_lm_cap, sample_lm_head_from_hidden,
27};
28use crate::moe_offload::{MoeOffloadState, build_moe_offload};
29use crate::moe_store::{build_moe_expert_store, moe_host_bind_from_store};
30use crate::rope::{mrope_prefill_feeds, mrope_slice_at_pos};
31use crate::vision::{
32 MmProjConfig, MmProjWeights, MultimodalPrefill, MultimodalPrompt, Qwen35VisionEncoder,
33 load_vision_encoder,
34};
35use crate::weights::Qwen35Weights;
36use crate::{
37 PackedParams, build_qwen35_decode_hir_dynamic_ext, build_qwen35_decode_hir_ext,
38 build_qwen35_hir_sized_ext, build_qwen35_prefill_cache_hir_dynamic_ext,
39 build_qwen35_prefill_hidden_cache_hir_dynamic_ext,
40};
41use rlx_runtime::MoeExpertStore;
42
43fn push_moe_residency(compiled: &mut rlx_runtime::CompiledGraph, layers: &[Vec<bool>]) {
44 let refs: Vec<&[bool]> = layers.iter().map(|m| m.as_slice()).collect();
45 compiled.set_moe_resident_experts_per_layer(&refs);
46}
47
48fn refresh_moe_from_capture(
49 mo: &mut MoeOffloadState,
50 store: Option<&MoeExpertStore>,
51 compiled: &mut rlx_runtime::CompiledGraph,
52 layer_indices: &[Vec<u32>],
53 denoise_step: usize,
54 is_prefill_block: bool,
55) -> bool {
56 let refreshed = if let Some(store) = store {
57 mo.refresh_from_capture_with_store(store, layer_indices, denoise_step, is_prefill_block)
58 } else {
59 mo.refresh_from_capture(layer_indices, denoise_step, is_prefill_block)
60 };
61 if refreshed {
62 push_moe_residency(compiled, &mo.per_layer_resident_masks());
63 }
64 refreshed
65}
66use crate::execution::{
67 Qwen35CompileCache, decode_config, get_or_specialize_hir_with_options, hidden_prefill_config,
68 prefill_config,
69};
70use crate::flow::{Qwen35PrefillCacheOpts, build_qwen35_prefill_cache_built};
71use crate::profile::{qwen35_profile_default, qwen35_profile_near_weights};
72use anyhow::{Context, Result, anyhow, bail};
73use rlx_core::flow_bridge::compile_options_from_profile;
74use rlx_core::gguf_support::{GgufModelFamily, assert_gguf_family, resolve_weights_file};
75use rlx_core::weight_loader::GgufLoader;
76use rlx_flow::ModelExecutionConfig;
77use rlx_flow::{CompileProfile, ExecutionPreset};
78use rlx_ir::CompilationMode;
79use rlx_ir::logical_kernel::KernelDispatchConfig;
80use rlx_qwen3::sampling::{SampleOpts, sample_token};
81use rlx_runtime::compile_cache::BucketedCompileCache;
82use rlx_runtime::{AotCache, CompileOptions, Device, Session};
83use std::cell::RefCell;
84use std::collections::HashMap;
85use std::path::PathBuf;
86use std::sync::Arc;
87use std::time::Instant;
88
89pub type Qwen35ConfigSource = rlx_runtime::ConfigSource<Qwen35Config>;
103
104#[derive(Default, Debug)]
105pub struct Qwen35RunnerBuilder {
106 weights: Option<PathBuf>,
107 config: Option<Qwen35ConfigSource>,
108 device: Option<Device>,
109 max_seq: Option<usize>,
110 enable_mtp: bool,
111 last_logits_only: bool,
112 packed_weights: Option<bool>,
117 runtime_mrope: bool,
118 mrope_section_positions: Option<Vec<[usize; 4]>>,
119 batch: Option<usize>,
120 bucketed_decode: Option<bool>,
121 mtp_logits_path: bool,
123 fast_mtp: bool,
124 fast_greedy_lm_head: Option<bool>,
126 aot_cache_dir: Option<PathBuf>,
128 dynamic_prefill: bool,
130 dynamic_decode: bool,
132 inline_weights: Option<(Qwen35Config, Qwen35Weights)>,
133 mmproj: Option<PathBuf>,
135 inline_mmproj: Option<(crate::vision::MmProjConfig, crate::vision::MmProjWeights)>,
137 prefill_profile: Option<CompileProfile>,
139 decode_profile: Option<CompileProfile>,
141 max_gpu_experts_per_layer: Option<usize>,
143 moe_memory_budget_bytes: Option<usize>,
145 expert_refresh_every_decode_steps: Option<usize>,
147 jump_steps: Option<usize>,
149 reserve_vram_gb: Option<f64>,
151 moe_collect_stats: bool,
153}
154
155impl Qwen35RunnerBuilder {
156 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
157 self.weights = Some(path.into());
158 self
159 }
160
161 pub fn config(mut self, src: Qwen35ConfigSource) -> Self {
167 self.config = Some(src);
168 self
169 }
170
171 pub fn config_value(self, cfg: Qwen35Config) -> Self {
174 self.config(Qwen35ConfigSource::Explicit(cfg))
175 }
176 pub fn device(mut self, d: Device) -> Self {
177 self.device = Some(d);
178 self
179 }
180 pub fn max_seq(mut self, n: usize) -> Self {
181 self.max_seq = Some(n);
182 self
183 }
184 pub fn enable_mtp(mut self, on: bool) -> Self {
185 self.enable_mtp = on;
186 self
187 }
188 pub fn last_logits_only(mut self, on: bool) -> Self {
189 self.last_logits_only = on;
190 self
191 }
192 pub fn packed_weights(mut self, on: bool) -> Self {
193 self.packed_weights = Some(on);
194 self
195 }
196 pub fn runtime_mrope(mut self, on: bool) -> Self {
199 self.runtime_mrope = on;
200 self
201 }
202 pub fn mrope_section_positions(mut self, positions: Vec<[usize; 4]>) -> Self {
204 self.mrope_section_positions = Some(positions);
205 self
206 }
207 pub fn batch(mut self, n: usize) -> Self {
209 self.batch = Some(n);
210 self
211 }
212 pub fn bucketed_decode(mut self, on: bool) -> Self {
214 self.bucketed_decode = Some(on);
215 self
216 }
217 pub fn mtp_logits_path(mut self, on: bool) -> Self {
219 self.mtp_logits_path = on;
220 self
221 }
222 pub fn fast_mtp(mut self, on: bool) -> Self {
224 self.fast_mtp = on;
225 self
226 }
227 pub fn fast_greedy_lm_head(mut self, on: bool) -> Self {
229 self.fast_greedy_lm_head = Some(on);
230 self
231 }
232 pub fn aot_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
234 self.aot_cache_dir = Some(path.into());
235 self
236 }
237 pub fn dynamic_prefill(mut self, on: bool) -> Self {
239 self.dynamic_prefill = on;
240 self
241 }
242 pub fn dynamic_decode(mut self, on: bool) -> Self {
244 self.dynamic_decode = on;
245 self
246 }
247 pub fn inline_weights(mut self, cfg: Qwen35Config, weights: Qwen35Weights) -> Self {
249 self.inline_weights = Some((cfg, weights));
250 self
251 }
252
253 pub fn mmproj(mut self, path: impl Into<PathBuf>) -> Self {
255 self.mmproj = Some(path.into());
256 self
257 }
258
259 pub fn inline_mmproj(
261 mut self,
262 cfg: crate::vision::MmProjConfig,
263 weights: crate::vision::MmProjWeights,
264 ) -> Self {
265 self.inline_mmproj = Some((cfg, weights));
266 self
267 }
268
269 pub fn with_compile_profiles(
271 mut self,
272 prefill: CompileProfile,
273 decode: CompileProfile,
274 ) -> Self {
275 self.prefill_profile = Some(prefill);
276 self.decode_profile = Some(decode);
277 self
278 }
279
280 pub fn max_gpu_experts_per_layer(mut self, n: usize) -> Self {
282 self.max_gpu_experts_per_layer = Some(n);
283 self
284 }
285
286 pub fn moe_memory_budget_bytes(mut self, bytes: usize) -> Self {
288 self.moe_memory_budget_bytes = Some(bytes);
289 self
290 }
291
292 pub fn expert_refresh_every_decode_steps(mut self, n: usize) -> Self {
294 self.expert_refresh_every_decode_steps = Some(n);
295 self.jump_steps = Some(n);
296 self
297 }
298
299 pub fn jump_steps(mut self, n: usize) -> Self {
301 self.jump_steps = Some(n);
302 self.expert_refresh_every_decode_steps = Some(n);
303 self
304 }
305
306 pub fn reserve_vram_gb(mut self, gb: f64) -> Self {
308 self.reserve_vram_gb = Some(gb);
309 self
310 }
311
312 pub fn moe_collect_stats(mut self, on: bool) -> Self {
314 self.moe_collect_stats = on;
315 self
316 }
317
318 pub fn enable_predictive_expert_offload(mut self, max_gpu_experts_per_layer: usize) -> Self {
320 self.max_gpu_experts_per_layer = Some(max_gpu_experts_per_layer);
321 self
322 }
323
324 pub fn build(self) -> Result<Qwen35Runner> {
325 let device = self.device.unwrap_or(Device::Cpu);
326 let max_seq = self.max_seq.unwrap_or(128);
327 let batch = self.batch.unwrap_or(1);
328 if batch == 0 {
329 bail!("qwen35: batch must be >= 1");
330 }
331
332 if let Some(src) = self.config.as_ref()
341 && !matches!(src, Qwen35ConfigSource::Embedded)
342 {
343 let weights_path = self
344 .weights
345 .as_ref()
346 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?
347 .clone();
348 let resolved = resolve_weights_file(&weights_path)?;
349 let ext = resolved
350 .extension()
351 .and_then(|s| s.to_str())
352 .unwrap_or("")
353 .to_ascii_lowercase();
354 if ext == "gguf" {
355 bail!(
356 "qwen35: Qwen35ConfigSource::{:?} supplied with a GGUF weights file at \
357 {:?} — drop the config source (use the default Embedded) so the GGUF \
358 metadata is the source of truth",
359 src,
360 resolved
361 );
362 }
363 if self.packed_weights == Some(true) {
364 bail!("qwen35: packed_weights requires GGUF; safetensors path is dequant-only");
365 }
366 let cfg = match src {
367 Qwen35ConfigSource::Embedded => unreachable!(),
368 Qwen35ConfigSource::JsonFile(p) => Qwen35Config::from_hf_config_json(p)
369 .with_context(|| format!("qwen35: parse HF config {p:?}"))?,
370 Qwen35ConfigSource::Explicit(cfg) => cfg.clone(),
371 };
372 if self.enable_mtp && cfg.nextn_predict_layers == 0 {
373 bail!(
374 "qwen35: enable_mtp(true) but config has \
375 nextn_predict_layers=0 (no MTP heads to wire)"
376 );
377 }
378 validate_device(&cfg, device, false)?;
379 let path_str = resolved
380 .to_str()
381 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
382 let inner_map = rlx_core::weight_map::WeightMap::from_file(path_str)
383 .with_context(|| format!("qwen35: load safetensors {resolved:?}"))?;
384 let mut loader = rlx_core::HfTranslatingLoader::new(inner_map);
385 let t = std::time::Instant::now();
386 let weights = Qwen35Weights::from_loader(&mut loader, &cfg)?;
387 eprintln!(
388 "[qwen35] read safetensors weights in {:.2?} \
389 (layers={}, hidden={})",
390 t.elapsed(),
391 cfg.num_hidden_layers,
392 cfg.hidden_size,
393 );
394 return finish_build(
395 cfg,
396 weights,
397 resolved,
398 None,
399 device,
400 max_seq,
401 batch,
402 self.enable_mtp,
403 self.last_logits_only,
404 self.runtime_mrope,
405 self.mrope_section_positions,
406 self.bucketed_decode,
407 self.mtp_logits_path,
408 self.fast_mtp,
409 self.fast_greedy_lm_head.unwrap_or(true),
410 self.aot_cache_dir.clone(),
411 self.dynamic_prefill,
412 self.dynamic_decode,
413 self.mmproj.clone(),
414 self.inline_mmproj,
415 self.prefill_profile,
416 self.decode_profile,
417 self.max_gpu_experts_per_layer,
418 self.moe_memory_budget_bytes,
419 self.jump_steps.or(self.expert_refresh_every_decode_steps),
420 self.reserve_vram_gb.unwrap_or(1.5),
421 self.moe_collect_stats,
422 );
423 }
424
425 if let Some((cfg, weights)) = self.inline_weights {
426 if self.packed_weights == Some(true) {
427 bail!("qwen35: inline_weights and packed_weights are mutually exclusive");
428 }
429 if self.enable_mtp && cfg.nextn_predict_layers == 0 {
430 bail!(
431 "qwen35: enable_mtp(true) but config has \
432 nextn_predict_layers=0 (no MTP heads to wire)"
433 );
434 }
435 if self.mmproj.is_some() && self.inline_mmproj.is_some() {
436 bail!("qwen35: mmproj and inline_mmproj are mutually exclusive");
437 }
438 validate_device(&cfg, device, false)?;
439 return finish_build(
440 cfg,
441 weights,
442 PathBuf::new(),
443 None,
444 device,
445 max_seq,
446 batch,
447 self.enable_mtp,
448 self.last_logits_only,
449 self.runtime_mrope,
450 self.mrope_section_positions,
451 self.bucketed_decode,
452 self.mtp_logits_path,
453 self.fast_mtp,
454 self.fast_greedy_lm_head.unwrap_or(true),
455 self.aot_cache_dir.clone(),
456 self.dynamic_prefill,
457 self.dynamic_decode,
458 self.mmproj.clone(),
459 self.inline_mmproj,
460 self.prefill_profile,
461 self.decode_profile,
462 self.max_gpu_experts_per_layer,
463 self.moe_memory_budget_bytes,
464 self.jump_steps.or(self.expert_refresh_every_decode_steps),
465 self.reserve_vram_gb.unwrap_or(1.5),
466 self.moe_collect_stats,
467 );
468 }
469
470 let weights_path = resolve_weights_file(
471 &self
472 .weights
473 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
474 )?;
475 let _t_total = Instant::now();
476 let t = Instant::now();
477 let raw = assert_gguf_family(&weights_path, GgufModelFamily::Qwen35)?;
478 let mut loader = GgufLoader::from_file(
479 weights_path
480 .to_str()
481 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
482 )?;
483 loader.include_mtp(true);
484 let cfg = Qwen35Config::from_gguf(&raw)?;
485 eprintln!(
486 "[qwen35] loaded GGUF metadata in {:.2?} \
487 (layers={}, hidden={}, ssm_state={})",
488 t.elapsed(),
489 cfg.num_hidden_layers,
490 cfg.hidden_size,
491 cfg.ssm_state_size,
492 );
493
494 if self.enable_mtp && cfg.nextn_predict_layers == 0 {
495 bail!(
496 "qwen35: enable_mtp(true) but the file has \
497 nextn_predict_layers=0 (no MTP heads to wire)"
498 );
499 }
500 let packed = self.packed_weights.unwrap_or_else(|| {
508 if raw.tensors.values().any(|t| {
509 matches!(
510 t.dtype,
511 rlx_gguf::GgmlType::Q2K
512 | rlx_gguf::GgmlType::Q3K
513 | rlx_gguf::GgmlType::Q4K
514 | rlx_gguf::GgmlType::Q5K
515 | rlx_gguf::GgmlType::Q6K
516 | rlx_gguf::GgmlType::Q8K
517 )
518 }) {
519 return true;
520 }
521 std::fs::metadata(&weights_path)
522 .ok()
523 .map(|m| m.len() >= 256 * 1024 * 1024)
524 .unwrap_or(false)
525 });
526 validate_device(&cfg, device, packed)?;
527
528 let t = Instant::now();
529 let weights = if packed {
530 Qwen35Weights::from_loader_packed(&mut loader, &cfg)?
531 } else {
532 Qwen35Weights::from_loader(&mut loader, &cfg)?
533 };
534 eprintln!(
535 "[qwen35] read weights ({}) in {:.2?}",
536 if packed { "packed" } else { "F32" },
537 t.elapsed(),
538 );
539
540 finish_build(
541 cfg,
542 weights,
543 weights_path,
544 Some(loader),
545 device,
546 max_seq,
547 batch,
548 self.enable_mtp,
549 self.last_logits_only,
550 self.runtime_mrope,
551 self.mrope_section_positions,
552 self.bucketed_decode,
553 self.mtp_logits_path,
554 self.fast_mtp,
555 self.fast_greedy_lm_head.unwrap_or(true),
556 self.aot_cache_dir.clone(),
557 self.dynamic_prefill,
558 self.dynamic_decode,
559 self.mmproj.clone(),
560 self.inline_mmproj,
561 self.prefill_profile,
562 self.decode_profile,
563 self.max_gpu_experts_per_layer,
564 self.moe_memory_budget_bytes,
565 self.jump_steps.or(self.expert_refresh_every_decode_steps),
566 self.reserve_vram_gb.unwrap_or(1.5),
567 self.moe_collect_stats,
568 )
569 }
570}
571
572fn make_qwen35_dyn_cache(
573 device: Device,
574 capacity: usize,
575 aot_cache_dir: Option<&std::path::Path>,
576) -> Qwen35CompileCache {
577 if let Some(dir) = aot_cache_dir {
578 Qwen35CompileCache::with_aot(device, capacity, dir)
579 } else {
580 Qwen35CompileCache::new(device, capacity)
581 }
582}
583
584fn compile_static_prefill_cache(
586 cfg: &Qwen35Config,
587 weights: Qwen35Weights,
588 batch: usize,
589 max_seq: usize,
590 device: Device,
591 prefill_profile: &CompileProfile,
592 runtime_mrope: bool,
593 enable_mtp_head: bool,
594 fast_mtp: bool,
595 fast_greedy_lm_head: bool,
596 aot_cache_dir: Option<&std::path::Path>,
597) -> Result<(
598 rlx_runtime::CompiledGraph,
599 HashMap<String, Vec<f32>>,
600 PackedParams,
601)> {
602 let mut flow_opts = Qwen35PrefillCacheOpts::static_cache(batch, max_seq);
603 flow_opts.with_lm_head = !fast_greedy_lm_head;
604 flow_opts.runtime_mrope = runtime_mrope;
605 flow_opts.enable_mtp_head = enable_mtp_head;
606 flow_opts.fast_mtp = fast_mtp;
607 flow_opts.fast_greedy_lm_head = fast_greedy_lm_head;
608 flow_opts.profile = Some(prefill_profile.clone());
609
610 let (built, packed) = build_qwen35_prefill_cache_built(cfg, weights, &flow_opts)?;
611 let params = built.params().clone();
612 let config = prefill_config(batch, max_seq);
613 let compile_opts =
614 compile_options_from_profile(prefill_profile, device, KernelDispatchConfig::default());
615
616 let mut cache = match aot_cache_dir {
617 Some(dir) => Qwen35CompileCache::with_aot(device, 1, dir),
618 None => Qwen35CompileCache::new(device, 1),
619 };
620 let mut config = config;
621 if aot_cache_dir.is_some() {
622 config = config.with_compilation_mode(CompilationMode::Aot);
623 }
624 let built = built.with_execution_config(&config);
625 let compiled = cache.compile_built(built, &config, &compile_opts)?;
626 Ok((compiled, params, packed))
627}
628
629#[allow(clippy::too_many_arguments)]
630fn finish_build(
631 cfg: Qwen35Config,
632 weights: Qwen35Weights,
633 weights_path: PathBuf,
634 gguf_loader: Option<GgufLoader>,
635 device: Device,
636 max_seq: usize,
637 batch: usize,
638 enable_mtp: bool,
639 last_logits_only: bool,
640 runtime_mrope: bool,
641 mrope_section_positions: Option<Vec<[usize; 4]>>,
642 bucketed_decode: Option<bool>,
643 mtp_logits_path: bool,
644 fast_mtp: bool,
645 fast_greedy_lm_head: bool,
646 aot_cache_dir: Option<PathBuf>,
647 dynamic_prefill: bool,
648 dynamic_decode: bool,
649 mmproj_path: Option<PathBuf>,
650 inline_mmproj: Option<(MmProjConfig, MmProjWeights)>,
651 prefill_profile_override: Option<CompileProfile>,
652 decode_profile_override: Option<CompileProfile>,
653 max_gpu_experts_per_layer: Option<usize>,
654 moe_memory_budget_bytes: Option<usize>,
655 jump_steps: Option<usize>,
656 reserve_vram_gb: f64,
657 moe_collect_stats: bool,
658) -> Result<Qwen35Runner> {
659 let prefill_profile = prefill_profile_override.unwrap_or_else(|| {
660 if weights_path.as_os_str().is_empty() {
661 qwen35_profile_default(false)
662 } else {
663 qwen35_profile_near_weights(&weights_path, false)
664 }
665 });
666 let decode_profile = decode_profile_override.unwrap_or_else(|| {
667 if weights_path.as_os_str().is_empty() {
668 qwen35_profile_default(true)
669 } else {
670 qwen35_profile_near_weights(&weights_path, true)
671 }
672 });
673
674 if fast_mtp && !mtp_logits_path && !enable_mtp {
675 bail!("qwen35: fast_mtp requires enable_mtp(true) or mtp_logits_path(true)");
676 }
677 if mtp_logits_path && !enable_mtp {
678 bail!("qwen35: mtp_logits_path requires enable_mtp(true)");
679 }
680
681 if dynamic_prefill && batch != 1 {
682 bail!("qwen35: dynamic_prefill requires batch=1");
683 }
684 if dynamic_decode && batch != 1 {
685 bail!("qwen35: dynamic_decode requires batch=1");
686 }
687 if dynamic_decode && bucketed_decode.unwrap_or(true) {
688 eprintln!("[qwen35] dynamic_decode enabled — disabling bucketed decode cache");
689 }
690 let bucketed_decode = if dynamic_decode {
691 false
692 } else {
693 bucketed_decode.unwrap_or(true)
694 };
695
696 let vision_encoder = if let Some(ref path) = mmproj_path {
697 Some(load_vision_encoder(
698 path.to_str()
699 .ok_or_else(|| anyhow!("non-utf8 mmproj path"))?,
700 224,
701 224,
702 )?)
703 } else if let Some((vcfg, vweights)) = inline_mmproj {
704 Some(Qwen35VisionEncoder::from_parts(vcfg, vweights, 4, 4)?)
705 } else {
706 None
707 };
708 let runtime_mrope = runtime_mrope || vision_encoder.is_some();
709 if vision_encoder.is_some() && batch != 1 {
710 bail!("qwen35: VLM (mmproj) requires batch=1");
711 }
712 if vision_encoder.is_some() && !dynamic_prefill {
713 eprintln!("[qwen35] mmproj loaded — enabling dynamic prefill for variable multimodal seq");
714 }
715 let dynamic_prefill = dynamic_prefill || vision_encoder.is_some();
716
717 let t = Instant::now();
718 let aot = aot_cache_dir.as_ref().map(AotCache::new);
719 let (cache_params, cache_packed, mut prefill_cache, prefill_dynamic_cache) = if dynamic_prefill
720 {
721 let (_cache_hir, cache_params, cache_packed) = build_qwen35_prefill_cache_hir_dynamic_ext(
722 &cfg,
723 weights.clone(),
724 batch,
725 max_seq,
726 runtime_mrope,
727 mtp_logits_path,
728 fast_mtp,
729 fast_greedy_lm_head,
730 )?;
731 eprintln!(
732 "[qwen35] built prefill-cache IR in {:.2?} (params={}, packed={})",
733 t.elapsed(),
734 cache_params.len(),
735 cache_packed.len(),
736 );
737 eprintln!("[qwen35] dynamic prefill template ready (compile on first prompt)");
738 (
739 cache_params,
740 cache_packed,
741 None,
742 Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
743 )
744 } else {
745 let (compiled, cache_params, cache_packed) = compile_static_prefill_cache(
746 &cfg,
747 weights.clone(),
748 batch,
749 max_seq,
750 device,
751 &prefill_profile,
752 runtime_mrope,
753 mtp_logits_path,
754 fast_mtp,
755 fast_greedy_lm_head,
756 aot_cache_dir.as_deref(),
757 )?;
758 eprintln!(
759 "[qwen35] compiled prefill-cache via BuiltModel in {:.2?} (params={}, packed={})",
760 t.elapsed(),
761 cache_params.len(),
762 cache_packed.len(),
763 );
764 (cache_params, cache_packed, Some(compiled), None)
765 };
766
767 let (prefill_hidden_dynamic_cache, prefill_hidden_cache_params, prefill_hidden_cache_packed) =
768 if vision_encoder.is_some() {
769 let (hidden_hir, hidden_params, hidden_packed) =
770 build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
771 &cfg,
772 weights.clone(),
773 batch,
774 max_seq,
775 runtime_mrope,
776 mtp_logits_path,
777 fast_mtp,
778 fast_greedy_lm_head,
779 )?;
780 let _ = hidden_hir;
781 (
782 Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
783 hidden_params,
784 hidden_packed,
785 )
786 } else {
787 (None, HashMap::new(), HashMap::new())
788 };
789
790 let t = Instant::now();
791 if let Some(ref mut compiled) = prefill_cache {
792 for (name, data) in &cache_params {
793 compiled.set_param(name, data);
794 }
795 }
796
797 let decode_compile_cache = if bucketed_decode {
798 Some(BucketedCompileCache::power_of_two_ladder(
799 device,
800 1,
801 max_seq.max(1) as u64,
802 ))
803 } else {
804 None
805 };
806 let decode_dynamic_cache = if dynamic_decode {
807 Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref()))
808 } else {
809 None
810 };
811 let (decode_dynamic_params, decode_dynamic_packed) = if dynamic_decode {
812 let (_, p, packed) = build_qwen35_decode_hir_dynamic_ext(
813 &cfg,
814 weights.clone(),
815 batch,
816 max_seq,
817 mtp_logits_path,
818 fast_mtp,
819 fast_greedy_lm_head,
820 )?;
821 (p, packed)
822 } else {
823 (HashMap::new(), HashMap::new())
824 };
825
826 if dynamic_decode {
827 eprintln!("[qwen35] dynamic decode template ready (compile on first step)");
828 }
829
830 let moe_offload = build_moe_offload(
831 &cfg,
832 &weights,
833 max_gpu_experts_per_layer,
834 moe_memory_budget_bytes,
835 jump_steps,
836 reserve_vram_gb,
837 moe_collect_stats,
838 );
839 let moe_store = if moe_offload.is_some() {
840 build_moe_expert_store(&cfg, &weights).ok()
841 } else {
842 None
843 };
844 if let Some(ref mo) = moe_offload {
845 eprintln!(
846 "[qwen35] TIDE MoE offload: layers={} gpu_budget={}/{} jump_steps={} reserve_bytes={}",
847 mo.num_layers(),
848 mo.info.gpu_expert_budget_per_layer,
849 mo.pools[0].num_experts(),
850 mo.jump_steps,
851 mo.info.reserve_bytes,
852 );
853 }
854
855 let mut runner = Qwen35Runner {
856 compiled: None,
857 prefill_cache,
858 prefill_dynamic_cache,
859 prefill_hidden_dynamic_cache,
860 prefill_cache_params: cache_params,
861 prefill_cache_packed: cache_packed,
862 prefill_hidden_cache_params,
863 prefill_hidden_cache_packed,
864 decode_graphs: HashMap::new(),
865 decode_compile_cache,
866 decode_dynamic_cache,
867 predict_hir_cache: None,
868 decode_dynamic_params,
869 decode_dynamic_packed,
870 packed_bytes_cache: HashMap::new(),
871 cfg,
872 device,
873 batch,
874 max_seq,
875 last_logits_only,
876 enable_mtp,
877 mtp_logits_path,
878 fast_mtp,
879 fast_greedy_lm_head,
880 weights,
881 weights_path,
882 gguf_loader,
883 decode_cache: None,
884 runtime_mrope,
885 mrope_section_positions,
886 aot_cache: aot,
887 dynamic_prefill,
888 dynamic_decode,
889 vision_encoder,
890 mmproj_path,
891 prefill_profile,
892 decode_profile,
893 moe_offload,
894 moe_store,
895 moe_refresh_step: 0,
896 };
897
898 if let Some(ref mut compiled) = runner.prefill_cache {
899 upload_packed_opt(
900 compiled,
901 runner.gguf_loader.as_mut(),
902 &runner.prefill_cache_packed,
903 &mut runner.packed_bytes_cache,
904 )?;
905 }
906 eprintln!(
907 "[qwen35] uploaded prefill-cache {} F32 + {} packed params in {:.2?}",
908 runner.prefill_cache_params.len(),
909 runner.prefill_cache_packed.len(),
910 t.elapsed(),
911 );
912
913 runner.warm_decode_graphs()?;
914 runner.warm_predict_graph()?;
915 Ok(runner)
916}
917
918fn ensure_packed_cache(
919 loader: &mut GgufLoader,
920 packed: &PackedParams,
921 cache: &mut HashMap<String, Arc<[u8]>>,
922) -> Result<()> {
923 for (loader_key, _, _) in packed.values() {
924 if cache.contains_key(loader_key) {
925 continue;
926 }
927 let bytes = loader
928 .tensor_bytes_borrowed(loader_key)
929 .ok_or_else(|| anyhow!("packed cache: {loader_key} bytes missing"))?;
930 cache.insert(loader_key.clone(), Arc::from(bytes));
931 }
932 Ok(())
933}
934
935fn upload_packed_opt(
936 compiled: &mut rlx_runtime::CompiledGraph,
937 loader: Option<&mut GgufLoader>,
938 packed: &PackedParams,
939 cache: &mut HashMap<String, Arc<[u8]>>,
940) -> Result<()> {
941 if packed.is_empty() {
942 return Ok(());
943 }
944 let loader = loader
945 .ok_or_else(|| anyhow!("packed params require a GGUF loader (missing weights path)"))?;
946 ensure_packed_cache(loader, packed, cache)?;
947 for (param_name, (loader_key, _scheme, _shape)) in packed {
948 let bytes = cache
949 .get(loader_key)
950 .ok_or_else(|| anyhow!("packed upload: cache miss for {loader_key}"))?;
951 compiled.set_param_typed(param_name, bytes, rlx_ir::DType::U8);
952 }
953 Ok(())
954}
955
956#[allow(dead_code)]
957fn upload_decode_packed(
958 weights_path: &std::path::Path,
959 compiled: &mut rlx_runtime::CompiledGraph,
960 packed: &PackedParams,
961) -> Result<()> {
962 if packed.is_empty() {
963 return Ok(());
964 }
965 let path = weights_path
966 .to_str()
967 .filter(|p| !p.is_empty())
968 .ok_or_else(|| anyhow!("packed decode params require a GGUF weights path"))?;
969 let mut loader = GgufLoader::from_file(path)?;
970 loader.include_mtp(true);
971 upload_packed_opt(compiled, Some(&mut loader), packed, &mut HashMap::new())
972}
973
974pub struct Qwen35Runner {
975 compiled: Option<rlx_runtime::CompiledGraph>,
976 prefill_cache: Option<rlx_runtime::CompiledGraph>,
977 prefill_dynamic_cache: Option<Qwen35CompileCache>,
978 prefill_hidden_dynamic_cache: Option<Qwen35CompileCache>,
979 prefill_cache_params: HashMap<String, Vec<f32>>,
980 prefill_cache_packed: PackedParams,
981 prefill_hidden_cache_params: HashMap<String, Vec<f32>>,
982 prefill_hidden_cache_packed: PackedParams,
983 decode_graphs: HashMap<usize, rlx_runtime::CompiledGraph>,
984 decode_compile_cache: Option<BucketedCompileCache>,
985 decode_dynamic_cache: Option<Qwen35CompileCache>,
986 predict_hir_cache: Option<Qwen35CompileCache>,
988 decode_dynamic_params: HashMap<String, Vec<f32>>,
989 decode_dynamic_packed: PackedParams,
990 packed_bytes_cache: HashMap<String, Arc<[u8]>>,
991 cfg: Qwen35Config,
992 device: Device,
993 batch: usize,
994 max_seq: usize,
995 last_logits_only: bool,
996 enable_mtp: bool,
997 mtp_logits_path: bool,
998 fast_mtp: bool,
999 fast_greedy_lm_head: bool,
1000 weights: Qwen35Weights,
1001 weights_path: PathBuf,
1002 gguf_loader: Option<GgufLoader>,
1003 decode_cache: Option<Qwen35DecodeCache>,
1004 runtime_mrope: bool,
1005 mrope_section_positions: Option<Vec<[usize; 4]>>,
1006 aot_cache: Option<AotCache>,
1007 dynamic_prefill: bool,
1008 dynamic_decode: bool,
1009 vision_encoder: Option<Qwen35VisionEncoder>,
1010 mmproj_path: Option<PathBuf>,
1011 prefill_profile: CompileProfile,
1012 decode_profile: CompileProfile,
1013 moe_offload: Option<MoeOffloadState>,
1015 moe_store: Option<MoeExpertStore>,
1017 moe_refresh_step: usize,
1019}
1020
/// Logits returned when seeding decode from a prefill (description from field
/// names — confirm against the producer).
#[derive(Debug, Clone)]
pub struct Qwen35PrefillSeed {
    /// Logits from the main (trunk) model.
    pub trunk_logits: Vec<f32>,
    /// Logits from the MTP head, when present.
    pub mtp_logits: Option<Vec<f32>>,
}
1026
/// Prefill result: trunk logits, optional MTP logits, and the vocabulary size
/// (presumably the row width of `logits` — confirm against the producer).
#[derive(Debug, Clone)]
pub struct Qwen35PrefillOutput {
    /// Trunk logits.
    pub logits: Vec<f32>,
    /// MTP-head logits, when present.
    pub mtp_logits: Option<Vec<f32>>,
    /// Vocabulary size associated with the logits.
    pub vocab_size: usize,
}
1033
1034impl Qwen35Runner {
1035 pub fn builder() -> Qwen35RunnerBuilder {
1036 Qwen35RunnerBuilder::default()
1037 }
1038
1039 pub fn has_mmproj(&self) -> bool {
1043 self.mmproj_path.is_some() || self.vision_encoder.is_some()
1044 }
1045
1046 fn execution_config(&self, config: ModelExecutionConfig) -> ModelExecutionConfig {
1048 if self.aot_cache.is_some() {
1049 config.with_compilation_mode(CompilationMode::Aot)
1050 } else {
1051 config
1052 }
1053 }
1054
1055 pub fn prefill_profile(&self) -> &CompileProfile {
1056 &self.prefill_profile
1057 }
1058
1059 pub fn decode_profile(&self) -> &CompileProfile {
1060 &self.decode_profile
1061 }
1062
1063 pub fn moe_offload(&self) -> Option<&MoeOffloadState> {
1065 self.moe_offload.as_ref()
1066 }
1067
1068 pub fn predictive_offload_info(&self) -> Option<&rlx_llada2::tide::PredictiveOffloadInfo> {
1070 self.moe_offload.as_ref().map(|m| &m.info)
1071 }
1072
1073 pub fn get_offload_stats(
1075 &self,
1076 residency: Option<&rlx_runtime::MoeResidencyStats>,
1077 ) -> rlx_llada2::tide::TideOffloadStats {
1078 self.moe_offload
1079 .as_ref()
1080 .map(|m| m.tide_offload_stats(residency))
1081 .unwrap_or_default()
1082 }
1083
1084 pub fn jump_steps(&self) -> usize {
1085 self.moe_offload.as_ref().map(|m| m.jump_steps).unwrap_or(1)
1086 }
1087
1088 pub fn predictive_offload_enabled(&self) -> bool {
1089 self.moe_offload
1090 .as_ref()
1091 .is_some_and(|m| m.predictive_enabled)
1092 }
1093
1094 pub fn moe_offload_mut(&mut self) -> Option<&mut MoeOffloadState> {
1095 self.moe_offload.as_mut()
1096 }
1097
1098 pub fn moe_refresh_step(&self) -> usize {
1100 self.moe_refresh_step
1101 }
1102
1103 pub fn enable_moe_topk_on(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1105 if self.moe_offload.is_some() {
1106 compiled.enable_moe_topk_capture(self.cfg.num_experts);
1107 }
1108 }
1109
1110 pub fn sync_moe_residency(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1112 if let Some(mo) = &self.moe_offload {
1113 push_moe_residency(compiled, &mo.per_layer_resident_masks());
1114 }
1115 }
1116
1117 #[allow(dead_code)]
1118 fn moe_prepare_forward(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1119 self.bind_moe_host_weights();
1120 if self.moe_offload.is_some() {
1121 compiled.enable_moe_topk_capture(self.cfg.num_experts);
1122 self.sync_moe_residency(compiled);
1123 }
1124 }
1125
1126 #[allow(dead_code)]
1127 fn moe_finish_forward(
1128 &mut self,
1129 compiled: &mut rlx_runtime::CompiledGraph,
1130 denoise_step: usize,
1131 is_prefill_block: bool,
1132 ) -> bool {
1133 let Some(layers) = compiled.take_moe_topk_capture() else {
1134 return false;
1135 };
1136 let store = self.moe_store.clone();
1137 let Some(mo) = self.moe_offload.as_mut() else {
1138 return false;
1139 };
1140 let refreshed = if let Some(store) = store.as_ref() {
1141 mo.refresh_from_capture_with_store(store, &layers, denoise_step, is_prefill_block)
1142 } else {
1143 mo.refresh_from_capture(&layers, denoise_step, is_prefill_block)
1144 };
1145 if refreshed {
1146 push_moe_residency(compiled, &mo.per_layer_resident_masks());
1147 }
1148 refreshed
1149 }
1150
1151 fn bind_moe_host_weights(&self) {
1153 if self.moe_offload.is_none() {
1154 rlx_cpu::moe_residency::bind_host_weights(None);
1155 return;
1156 }
1157 if let Some(store) = &self.moe_store {
1158 rlx_cpu::moe_residency::bind_host_weights(Some(moe_host_bind_from_store(store)));
1159 } else {
1160 rlx_cpu::moe_residency::bind_host_weights(None);
1161 }
1162 }
1163
1164 pub fn moe_offload_after_forward(&mut self, compiled: &mut rlx_runtime::CompiledGraph) -> bool {
1166 let Some(mo) = self.moe_offload.as_mut() else {
1167 return false;
1168 };
1169 let Some(layers) = compiled.take_moe_topk_capture() else {
1170 return false;
1171 };
1172 let refreshed = mo.refresh_from_capture(&layers, self.moe_refresh_step, false);
1173 if refreshed {
1174 self.sync_moe_residency(compiled);
1175 }
1176 self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1177 refreshed
1178 }
1179
1180 pub fn moe_refresh_after_forward(&mut self, expert_idx: &[u32]) -> bool {
1182 let Some(mo) = self.moe_offload.as_mut() else {
1183 return false;
1184 };
1185 let refresh = mo.pools[0].should_refresh(
1186 rlx_runtime::MoEExecMode::Reuse,
1187 self.moe_refresh_step,
1188 false,
1189 );
1190 if refresh {
1191 for pool in &mut mo.pools {
1192 pool.refresh_from_indices(expert_idx);
1193 }
1194 }
1195 self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1196 refresh
1197 }
1198
1199 pub fn with_compile_profiles(
1201 mut self,
1202 prefill: CompileProfile,
1203 decode: CompileProfile,
1204 ) -> Self {
1205 self.prefill_profile = prefill;
1206 self.decode_profile = decode;
1207 self
1208 }
1209
1210 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
1211 let profile = if decode {
1212 &self.decode_profile
1213 } else {
1214 &self.prefill_profile
1215 };
1216 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
1217 }
1218
1219 fn dyn_compile_options(&self, config: &ModelExecutionConfig) -> CompileOptions {
1220 let decode = matches!(config.preset, ExecutionPreset::Qwen35Decode);
1221 let mut opts = self.profile_compile_options(decode);
1222 opts.kernel_dispatch = config.component().kernel_dispatch;
1223 opts.dim_binding(config.dim_binding())
1224 }
1225
1226 fn bucketed_decode_compile_options(&self) -> CompileOptions {
1227 self.profile_compile_options(true)
1228 }
1229
1230 pub fn compile_prefill_built(
1232 &self,
1233 cache: &mut Qwen35CompileCache,
1234 built: rlx_flow::BuiltModel,
1235 batch: usize,
1236 seq: usize,
1237 ) -> Result<rlx_runtime::CompiledGraph> {
1238 let config = self.execution_config(prefill_config(batch, seq));
1239 let opts = self.dyn_compile_options(&config);
1240 cache.compile_built(built, &config, &opts)
1241 }
1242
1243 pub fn cfg(&self) -> &Qwen35Config {
1244 &self.cfg
1245 }
1246 pub fn device(&self) -> Device {
1247 self.device
1248 }
1249 pub fn max_seq(&self) -> usize {
1250 self.max_seq
1251 }
1252 pub fn lm_vocab_size(&self) -> usize {
1253 self.weights.lm_vocab_size(&self.cfg)
1254 }
1255
1256 pub fn has_vision(&self) -> bool {
1258 self.vision_encoder.is_some()
1259 }
1260
1261 pub fn mmproj_path(&self) -> Option<&std::path::Path> {
1263 self.mmproj_path.as_deref()
1264 }
1265
1266 fn effective_vocab(&self, graph_vocab: usize) -> usize {
1267 self.lm_vocab_size().min(graph_vocab)
1268 }
1269
1270 fn compile_hir_for_config(
1271 &mut self,
1272 config: ModelExecutionConfig,
1273 aot_disk_key: &str,
1274 hir: rlx_ir::hir::HirModule,
1275 ) -> Result<rlx_runtime::CompiledGraph> {
1276 let config = self.execution_config(config);
1277 let opts = self.dyn_compile_options(&config);
1278 if let Some(aot) = self.aot_cache.as_ref() {
1279 return Ok(aot.compile_hir_cached(aot_disk_key, self.device, hir, &opts)?);
1280 }
1281 if config.preset == ExecutionPreset::Qwen35Decode {
1284 return Ok(Session::new(self.device).compile_hir_with(hir, &opts)?);
1285 }
1286 let cache = self
1287 .predict_hir_cache
1288 .get_or_insert_with(|| make_qwen35_dyn_cache(self.device, 64, None));
1289 let hir = hir;
1290 get_or_specialize_hir_with_options(cache, &config, || hir.clone(), &opts, |_| Ok(()))?;
1291 if self.device == Device::Cpu {
1292 let compiled = get_or_specialize_hir_with_options(
1293 cache,
1294 &config,
1295 || hir.clone(),
1296 &opts,
1297 |_| Ok(()),
1298 )?;
1299 return Ok(compiled.clone());
1300 }
1301 Ok(Session::new(self.device).compile_hir_with(hir, &opts)?)
1302 }
1303
1304 fn lm_loader(&self) -> Option<&GgufLoader> {
1305 self.gguf_loader.as_ref()
1306 }
1307
1308 fn argmax_batch_from_hidden(&self, hidden: &[f32]) -> Result<Vec<u32>> {
1309 let n_embd = self.cfg.hidden_size;
1310 let mut toks = Vec::with_capacity(self.batch);
1311 for b in 0..self.batch {
1312 let h = &hidden[b * n_embd..(b + 1) * n_embd];
1313 let (idx, _) = greedy_lm_head_argmax(&self.weights, &self.cfg, h, self.lm_loader())?;
1314 toks.push(idx);
1315 }
1316 Ok(toks)
1317 }
1318
1319 fn sample_batch_from_hidden(&self, hidden: &[f32], opts: SampleOpts) -> Result<Vec<u32>> {
1320 let n_embd = self.cfg.hidden_size;
1321 let mut toks = Vec::with_capacity(self.batch);
1322 for b in 0..self.batch {
1323 let h = &hidden[b * n_embd..(b + 1) * n_embd];
1324 toks.push(sample_lm_head_from_hidden(
1325 &self.weights,
1326 &self.cfg,
1327 h,
1328 self.lm_loader(),
1329 opts,
1330 )?);
1331 }
1332 Ok(toks)
1333 }
1334
1335 fn decode_step_trunk_raw(
1336 &mut self,
1337 cache: &mut Qwen35DecodeCache,
1338 tokens: &[u32],
1339 generated_per_row: &[usize],
1340 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
1341 if self.dynamic_decode {
1342 return self.decode_step_dynamic_raw(cache, tokens, generated_per_row);
1343 }
1344 let past_seq = cache.past_seq;
1345 let head_half = self.cfg.key_length / 2;
1346 let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
1347 let use_bucket = self
1348 .decode_compile_cache
1349 .as_ref()
1350 .and_then(|c| c.bucket_for(past_seq as u64))
1351 .is_some();
1352 if use_bucket {
1353 self.decode_step_bucketed_raw(cache, tokens, generated_per_row, &cos, &sin)
1354 } else {
1355 let feeds_owned = decode_step_feeds(
1356 &self.cfg,
1357 cache,
1358 tokens,
1359 &cos,
1360 &sin,
1361 None,
1362 generated_per_row,
1363 )?;
1364 let feeds: Vec<(&str, &[f32])> = feeds_owned
1365 .iter()
1366 .map(|(k, v)| (k.as_str(), v.as_slice()))
1367 .collect();
1368 if !self.decode_graphs.contains_key(&past_seq) {
1369 let (hir, params, packed) = build_qwen35_decode_hir_ext(
1370 &self.cfg,
1371 self.weights.clone(),
1372 self.batch,
1373 past_seq,
1374 false,
1375 self.mtp_logits_path,
1376 self.fast_mtp,
1377 self.fast_greedy_lm_head,
1378 )?;
1379 let mut compiled = self.compile_hir_for_config(
1380 decode_config(self.batch, past_seq),
1381 &format!("decode_{past_seq}"),
1382 hir,
1383 )?;
1384 for (name, data) in ¶ms {
1385 compiled.set_param(name, data);
1386 }
1387 upload_packed_opt(
1388 &mut compiled,
1389 self.gguf_loader.as_mut(),
1390 &packed,
1391 &mut self.packed_bytes_cache,
1392 )?;
1393 self.decode_graphs.insert(past_seq, compiled);
1394 }
1395 let step = self.moe_refresh_step;
1396 let has_moe = self.moe_offload.is_some();
1397 let num_experts = self.cfg.num_experts;
1398 let moe_masks = self
1399 .moe_offload
1400 .as_ref()
1401 .map(|m| m.per_layer_resident_masks());
1402 self.bind_moe_host_weights();
1403 let outs = {
1404 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
1405 if has_moe {
1406 compiled.enable_moe_topk_capture(num_experts);
1407 if let Some(layers) = &moe_masks {
1408 push_moe_residency(compiled, layers);
1409 }
1410 }
1411 compiled.run(&feeds)
1412 };
1413 if has_moe {
1414 let layers = {
1415 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
1416 compiled.take_moe_topk_capture()
1417 };
1418 if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
1419 let store = self.moe_store.as_ref();
1420 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
1421 if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
1422 if let Some(store) = self.moe_store.as_ref() {
1423 store.apply_to_compiled(compiled);
1424 }
1425 }
1426 }
1427 }
1428 self.moe_refresh_step = step.saturating_add(1);
1429 advance_cache_from_decode_outputs(
1430 &self.cfg,
1431 cache,
1432 outs,
1433 None,
1434 self.mtp_logits_path,
1435 false,
1436 self.fast_greedy_lm_head,
1437 )
1438 }
1439 }
1440
1441 fn decode_step_dynamic_raw(
1442 &mut self,
1443 cache: &mut Qwen35DecodeCache,
1444 tokens: &[u32],
1445 generated_per_row: &[usize],
1446 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
1447 let past_seq = cache.past_seq;
1448 let head_half = self.cfg.key_length / 2;
1449 let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
1450 let feeds_owned = decode_step_feeds(
1451 &self.cfg,
1452 cache,
1453 tokens,
1454 &cos,
1455 &sin,
1456 None,
1457 generated_per_row,
1458 )?;
1459 let feeds: Vec<(&str, &[f32])> = feeds_owned
1460 .iter()
1461 .map(|(k, v)| (k.as_str(), v.as_slice()))
1462 .collect();
1463
1464 let config = self.execution_config(decode_config(self.batch, past_seq));
1465 let compile_opts = self.dyn_compile_options(&config);
1466 let dyn_cache = self
1467 .decode_dynamic_cache
1468 .as_mut()
1469 .ok_or_else(|| anyhow!("dynamic decode without cache"))?;
1470 let cfg = self.cfg.clone();
1471 let weights = self.weights.clone();
1472 let max_seq = self.max_seq;
1473 let mtp_logits_path = self.mtp_logits_path;
1474 let fast_mtp = self.fast_mtp;
1475 let fast_greedy = self.fast_greedy_lm_head;
1476 let batch = self.batch;
1477 let decode_params = &self.decode_dynamic_params;
1478 let decode_packed = &self.decode_dynamic_packed;
1479 let gguf_loader = &mut self.gguf_loader;
1480 let packed_bytes_cache = &mut self.packed_bytes_cache;
1481 let compiled = get_or_specialize_hir_with_options(
1482 dyn_cache,
1483 &config,
1484 || {
1485 build_qwen35_decode_hir_dynamic_ext(
1486 &cfg,
1487 weights,
1488 batch,
1489 max_seq,
1490 mtp_logits_path,
1491 fast_mtp,
1492 fast_greedy,
1493 )
1494 .expect("dynamic decode HIR")
1495 .0
1496 },
1497 &compile_opts,
1498 |c| {
1499 for (name, data) in decode_params {
1500 c.set_param(name, data);
1501 }
1502 upload_packed_opt(c, gguf_loader.as_mut(), decode_packed, packed_bytes_cache)
1503 },
1504 )?;
1505 let outs = compiled.run(&feeds);
1506 advance_cache_from_decode_outputs(
1507 &self.cfg,
1508 cache,
1509 outs,
1510 None,
1511 self.mtp_logits_path,
1512 false,
1513 self.fast_greedy_lm_head,
1514 )
1515 }
1516
1517 fn trunk_to_logits(&self, trunk: Vec<f32>, is_hidden: bool) -> Result<Vec<f32>> {
1518 if !is_hidden {
1519 return Ok(trunk);
1520 }
1521 let n_embd = self.cfg.hidden_size;
1522 let vocab = self.lm_vocab_size();
1523 let mut logits = Vec::with_capacity(self.batch * vocab);
1524 for b in 0..self.batch {
1525 let h = &trunk[b * n_embd..(b + 1) * n_embd];
1526 logits.extend(lm_head_logits_row(
1527 &self.weights,
1528 &self.cfg,
1529 h,
1530 self.lm_loader(),
1531 )?);
1532 }
1533 Ok(logits)
1534 }
1535
1536 fn ensure_decode_bucket_compiled(&mut self, key: u64) -> Result<usize> {
1538 let decode_opts = self.bucketed_decode_compile_options();
1539 let cache_mut = self
1540 .decode_compile_cache
1541 .as_mut()
1542 .ok_or_else(|| anyhow!("bucketed decode without cache"))?;
1543 let cfg = self.cfg.clone();
1544 let weights = self.weights.clone();
1545 let batch = self.batch;
1546 let mtp_logits_path = self.mtp_logits_path;
1547 let fast_mtp = self.fast_mtp;
1548 let fast_greedy = self.fast_greedy_lm_head;
1549 let packed_slot = RefCell::new(None::<PackedParams>);
1550 let (upper, compiled) = cache_mut
1551 .ensure_hir_with_params(
1552 key,
1553 |upper| {
1554 let (hir, params, packed) = build_qwen35_decode_hir_ext(
1555 &cfg,
1556 weights.clone(),
1557 batch,
1558 upper as usize,
1559 true,
1560 mtp_logits_path,
1561 fast_mtp,
1562 fast_greedy,
1563 )
1564 .expect("qwen35 decode HIR");
1565 *packed_slot.borrow_mut() = Some(packed);
1566 (hir, params)
1567 },
1568 &decode_opts,
1569 )
1570 .ok_or_else(|| anyhow!("past_seq {key} outside decode buckets"))?;
1571 if let Some(packed) = packed_slot.take() {
1572 if !packed.is_empty() {
1573 upload_packed_opt(
1574 compiled,
1575 self.gguf_loader.as_mut(),
1576 &packed,
1577 &mut self.packed_bytes_cache,
1578 )?;
1579 }
1580 }
1581 Ok(upper as usize)
1582 }
1583
1584 fn warm_decode_graphs(&mut self) -> Result<()> {
1586 let upper_bounds: Vec<usize> = match self.decode_compile_cache.as_ref() {
1587 Some(cache) => cache.buckets().map(|r| (r.end - 1) as usize).collect(),
1588 None => return Ok(()),
1589 };
1590 let t = Instant::now();
1591 let total = upper_bounds.len();
1592 for upper in upper_bounds {
1593 self.ensure_decode_bucket_compiled(upper as u64)?;
1594 }
1595 if total > 0 {
1596 eprintln!(
1597 "[qwen35] warmed {total} decode bucket(s) in {:.2?}",
1598 t.elapsed()
1599 );
1600 }
1601 Ok(())
1602 }
1603
1604 fn warm_predict_graph(&mut self) -> Result<()> {
1606 if self.compiled.is_some() {
1607 return Ok(());
1608 }
1609 let t = Instant::now();
1610 self.ensure_predict_compiled()?;
1611 eprintln!("[qwen35] warmed predict graph in {:.2?}", t.elapsed());
1612 Ok(())
1613 }
1614
1615 pub fn reset_decode_cache(&mut self) {
1617 self.decode_cache = None;
1618 }
1619
1620 pub fn decode_cache_checkpoint(&self) -> Option<Qwen35DecodeCache> {
1622 self.decode_cache.clone()
1623 }
1624
1625 pub fn restore_decode_cache(&mut self, cache: Option<Qwen35DecodeCache>) {
1627 self.decode_cache = cache;
1628 }
1629
1630 pub fn commit_decode_tokens(&mut self, tokens: &[u32]) -> Result<()> {
1632 for &tok in tokens {
1633 let _ = self.decode_get_logits(tok)?;
1634 }
1635 Ok(())
1636 }
1637
1638 fn ensure_predict_compiled(&mut self) -> Result<()> {
1639 if self.compiled.is_some() {
1640 return Ok(());
1641 }
1642 let debug_layers = std::env::var("RLX_QWEN35_DEBUG_LAYERS")
1650 .map(|v| v == "1")
1651 .unwrap_or(false);
1652 let t = Instant::now();
1653 let (hir, params, packed) = build_qwen35_hir_sized_ext(
1654 &self.cfg,
1655 self.weights.clone(),
1656 self.batch,
1657 self.max_seq,
1658 true,
1659 self.last_logits_only,
1660 self.enable_mtp,
1661 false,
1662 None,
1663 self.runtime_mrope,
1664 self.fast_mtp,
1665 false,
1666 debug_layers,
1667 )?;
1668 eprintln!(
1669 "[qwen35] built predict IR (lazy) in {:.2?} (params={}, packed={})",
1670 t.elapsed(),
1671 params.len(),
1672 packed.len(),
1673 );
1674 let t = Instant::now();
1675 let mut compiled = self.compile_hir_for_config(
1676 prefill_config(self.batch, self.max_seq),
1677 "predict_logits",
1678 hir,
1679 )?;
1680 eprintln!(
1681 "[qwen35] compiled predict graph (lazy) in {:.2?}",
1682 t.elapsed()
1683 );
1684 let t = Instant::now();
1685 for (name, data) in ¶ms {
1686 compiled.set_param(name, data);
1687 }
1688 if !packed.is_empty() {
1689 upload_packed_opt(
1690 &mut compiled,
1691 self.gguf_loader.as_mut(),
1692 &packed,
1693 &mut self.packed_bytes_cache,
1694 )?;
1695 }
1696 eprintln!(
1697 "[qwen35] uploaded predict {} F32 + {} packed params in {:.2?}",
1698 params.len(),
1699 packed.len(),
1700 t.elapsed(),
1701 );
1702 self.compiled = Some(compiled);
1703 Ok(())
1704 }
1705
1706 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillOutput> {
1707 let out = self
1708 .predict_logits_batch(&[prompt_ids.to_vec()])
1709 .map(|v| v.into_iter().next().unwrap())?;
1710 if !out.logits.is_empty() {
1717 let mut min = f32::INFINITY;
1718 let mut max = f32::NEG_INFINITY;
1719 for &v in &out.logits {
1720 if v.is_finite() {
1721 if v < min {
1722 min = v;
1723 }
1724 if v > max {
1725 max = v;
1726 }
1727 }
1728 }
1729 if (max - min).abs() < 1e-6 {
1730 bail!(
1731 "qwen35: predict_logits returned degenerate output \
1732 (min={min}, max={max}) — the forward pass produced \
1733 all-equal logits, which indicates a broken op or a \
1734 mis-routed weight tensor in the trunk. Re-run with \
1735 RUST_LOG=debug to capture the offending layer."
1736 );
1737 }
1738 }
1739 Ok(out)
1740 }
1741
1742 pub fn predict_logits_batch(
1746 &mut self,
1747 batch_prompts: &[Vec<u32>],
1748 ) -> Result<Vec<Qwen35PrefillOutput>> {
1749 if batch_prompts.len() != self.batch {
1750 bail!(
1751 "qwen35: expected {} prompts (batch={}), got {}",
1752 self.batch,
1753 self.batch,
1754 batch_prompts.len()
1755 );
1756 }
1757 let max_prompt = batch_prompts.iter().map(|p| p.len()).max().unwrap_or(0);
1758 if max_prompt > self.max_seq {
1759 bail!(
1760 "qwen35: prompt length {max_prompt} exceeds compiled max_seq={}",
1761 self.max_seq
1762 );
1763 }
1764 let padded = pack_input_ids(batch_prompts, self.max_seq)?;
1765 let prompt_lens: Vec<usize> = batch_prompts.iter().map(|p| p.len()).collect();
1766 let last_idx = last_token_indices(&prompt_lens);
1767
1768 let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", padded.as_slice())];
1769 if self.last_logits_only {
1770 feeds.push(("last_token_idx", last_idx.as_slice()));
1771 }
1772 let rope_owned = self.mrope_prefill_rope_feeds(max_prompt);
1773 for (name, data) in &rope_owned {
1774 feeds.push((name.as_str(), data.as_slice()));
1775 }
1776 self.ensure_predict_compiled()?;
1777 let outs = self.compiled.as_mut().unwrap().run(&feeds);
1778 if outs.is_empty() {
1779 bail!("qwen35: forward produced no outputs");
1780 }
1781 if std::env::var("RLX_QWEN35_DEBUG_LAYERS").as_deref() == Ok("1") {
1786 let n_layers = self.cfg.num_hidden_layers - self.cfg.nextn_predict_layers;
1791 for i in 0..outs.len() {
1792 let v = &outs[i];
1793 let mut min = f32::INFINITY;
1794 let mut max = f32::NEG_INFINITY;
1795 let mut sum = 0.0f64;
1796 let mut nan = 0usize;
1797 let mut nnz = 0usize;
1798 for &x in v {
1799 if x.is_nan() {
1800 nan += 1;
1801 continue;
1802 }
1803 sum += x as f64;
1804 if x < min {
1805 min = x;
1806 }
1807 if x > max {
1808 max = x;
1809 }
1810 if x != 0.0 {
1811 nnz += 1;
1812 }
1813 }
1814 let mean = sum / v.len().max(1) as f64;
1815 let label = if i == 0 {
1816 "logits".to_string()
1817 } else if i - 1 < n_layers {
1818 format!("layer_{:02}", i - 1)
1819 } else {
1820 format!("extra_{:02}", i - 1 - n_layers)
1821 };
1822 eprintln!(
1823 "[qwen35][debug-layers] {label}: len={} nnz={} nan={} min={} max={} mean={:.6}",
1824 v.len(),
1825 nnz,
1826 nan,
1827 min,
1828 max,
1829 mean
1830 );
1831 }
1832 }
1833 let vocab_size = if self.last_logits_only {
1834 outs[0].len() / self.batch
1835 } else {
1836 outs[0].len() / (self.batch * self.max_seq)
1837 };
1838 let sample_vocab = self.effective_vocab(vocab_size);
1839 let mtp_logits = if self.enable_mtp && outs.len() >= 2 {
1840 Some(outs[1].clone())
1841 } else {
1842 None
1843 };
1844 let mut per_batch = Vec::with_capacity(self.batch);
1845 for b in 0..self.batch {
1846 let start = b * vocab_size;
1847 let mut row = outs[0][start..start + vocab_size].to_vec();
1848 row.truncate(sample_vocab);
1849 per_batch.push(Qwen35PrefillOutput {
1850 logits: row,
1851 mtp_logits: mtp_logits.as_ref().map(|m| {
1852 let m_vocab = m.len() / self.batch.max(1);
1853 let mut mv = m[b * m_vocab..(b + 1) * m_vocab].to_vec();
1854 mv.truncate(sample_vocab);
1855 mv
1856 }),
1857 vocab_size: sample_vocab,
1858 });
1859 }
1860 Ok(per_batch)
1861 }
1862
1863 pub fn generate<F>(&mut self, prompt_ids: &[u32], n_new: usize, on_token: F) -> Result<Vec<u32>>
1865 where
1866 F: FnMut(u32) -> bool,
1867 {
1868 self.generate_with_opts(prompt_ids, n_new, SampleOpts::greedy(), on_token)
1869 }
1870
1871 pub fn generate_with_opts<F>(
1873 &mut self,
1874 prompt_ids: &[u32],
1875 n_new: usize,
1876 opts: SampleOpts,
1877 mut on_token: F,
1878 ) -> Result<Vec<u32>>
1879 where
1880 F: FnMut(u32) -> bool,
1881 {
1882 if self.batch != 1 {
1883 bail!(
1884 "qwen35::generate: runner batch={} — use generate_batch() instead",
1885 self.batch
1886 );
1887 }
1888 let generated = self
1889 .generate_batch_with_opts(&[prompt_ids.to_vec()], n_new, None, opts, |_, tok| {
1890 on_token(tok)
1891 })?
1892 .into_iter()
1893 .next()
1894 .unwrap_or_default();
1895 Ok(generated)
1896 }
1897
1898 pub fn generate_batch<F>(
1900 &mut self,
1901 prompts: &[Vec<u32>],
1902 n_new: usize,
1903 on_token: F,
1904 ) -> Result<Vec<Vec<u32>>>
1905 where
1906 F: FnMut(usize, u32) -> bool,
1907 {
1908 self.generate_batch_with_opts(prompts, n_new, None, SampleOpts::greedy(), on_token)
1909 }
1910
1911 pub fn generate_batch_with_opts<F>(
1915 &mut self,
1916 prompts: &[Vec<u32>],
1917 n_new: usize,
1918 n_new_per_row: Option<&[usize]>,
1919 opts: SampleOpts,
1920 mut on_token: F,
1921 ) -> Result<Vec<Vec<u32>>>
1922 where
1923 F: FnMut(usize, u32) -> bool,
1924 {
1925 if prompts.is_empty() {
1926 bail!("qwen35::generate_batch: prompts must be non-empty");
1927 }
1928 if prompts.len() != self.batch {
1929 bail!(
1930 "qwen35::generate_batch: expected {} prompts, got {}",
1931 self.batch,
1932 prompts.len()
1933 );
1934 }
1935 if let Some(limits) = n_new_per_row {
1936 if limits.len() != self.batch {
1937 bail!(
1938 "qwen35::generate_batch: n_new_per_row len {} != batch {}",
1939 limits.len(),
1940 self.batch
1941 );
1942 }
1943 }
1944 for (i, p) in prompts.iter().enumerate() {
1945 if p.is_empty() {
1946 bail!("qwen35::generate_batch: prompt row {i} is empty");
1947 }
1948 }
1949
1950 self.decode_cache = None;
1951
1952 let _prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
1953 let row_limits: Vec<usize> = if let Some(limits) = n_new_per_row {
1954 limits.to_vec()
1955 } else {
1956 vec![n_new; self.batch]
1957 };
1958
1959 let (trunk, mut cache, _) = self.prefill_seed_decode_cache(prompts)?;
1960
1961 let mut generated: Vec<Vec<u32>> = vec![Vec::new(); self.batch];
1962 let mut active = vec![true; self.batch];
1963 let mut row_gen_count = vec![0usize; self.batch];
1964
1965 let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
1966 self.argmax_batch_from_hidden(&trunk)?
1967 } else if self.fast_greedy_lm_head
1968 && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
1969 {
1970 self.sample_batch_from_hidden(&trunk, opts)?
1971 } else {
1972 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
1973 sample_logits_batch(&logits, self.lm_vocab_size(), self.batch, opts)
1974 };
1975 if n_new > 0 {
1976 for b in 0..self.batch {
1977 if row_gen_count[b] >= row_limits[b] {
1978 active[b] = false;
1979 continue;
1980 }
1981 let tok = next_tokens[b];
1982 generated[b].push(tok);
1983 row_gen_count[b] += 1;
1984 active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
1985 }
1986 }
1987
1988 for _ in 1..n_new {
1989 if !active.iter().any(|&a| a) {
1990 break;
1991 }
1992 if cache.past_seq >= self.max_seq - 1 {
1993 bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
1994 }
1995 next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen_count, opts)?;
1996 for b in 0..self.batch {
1997 if !active[b] || row_gen_count[b] >= row_limits[b] {
1998 active[b] = false;
1999 continue;
2000 }
2001 let tok = next_tokens[b];
2002 generated[b].push(tok);
2003 row_gen_count[b] += 1;
2004 active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
2005 }
2006 self.decode_cache = Some(cache.clone());
2007 }
2008 Ok(generated)
2009 }
2010
2011 pub fn prefill_get_last_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2014 Ok(self.prefill_seed_for_decode(prompt_ids)?.trunk_logits)
2015 }
2016
2017 pub fn prefill_seed_for_decode(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillSeed> {
2019 if self.batch != 1 {
2020 bail!(
2021 "qwen35: prefill_seed_for_decode requires batch=1 (runner batch={})",
2022 self.batch
2023 );
2024 }
2025 let (trunk, _, mtp_logits) = self.prefill_seed_decode_cache(&[prompt_ids.to_vec()])?;
2026 Ok(Qwen35PrefillSeed {
2027 trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2028 mtp_logits,
2029 })
2030 }
2031
2032 pub fn prefill_multimodal(
2035 &mut self,
2036 prompt: &str,
2037 rgb: &[u8],
2038 img_w: usize,
2039 img_h: usize,
2040 tokenizer: Option<&std::path::Path>,
2041 ) -> Result<Qwen35PrefillSeed> {
2042 let (trunk, mtp_logits) =
2043 self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2044 Ok(Qwen35PrefillSeed {
2045 trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2046 mtp_logits,
2047 })
2048 }
2049
2050 pub fn prefill_from_assembled(
2052 &mut self,
2053 prefill: MultimodalPrefill,
2054 ) -> Result<Qwen35PrefillSeed> {
2055 if self.batch != 1 {
2056 bail!(
2057 "qwen35: prefill_from_assembled requires batch=1 (runner batch={})",
2058 self.batch
2059 );
2060 }
2061 self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2062 let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2063 Ok(Qwen35PrefillSeed {
2064 trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2065 mtp_logits,
2066 })
2067 }
2068
2069 fn prefill_multimodal_trunk(
2070 &mut self,
2071 prompt: &str,
2072 rgb: &[u8],
2073 img_w: usize,
2074 img_h: usize,
2075 tokenizer: Option<&std::path::Path>,
2076 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2077 if self.batch != 1 {
2078 bail!(
2079 "qwen35: prefill_multimodal requires batch=1 (runner batch={})",
2080 self.batch
2081 );
2082 }
2083 let vision = {
2084 let enc = self
2085 .vision_encoder
2086 .as_mut()
2087 .ok_or_else(|| anyhow!("qwen35: prefill_multimodal requires .mmproj(...)"))?;
2088 enc.encode_rgb(rgb, img_w, img_h)?
2089 };
2090 if self.weights.token_embd.is_empty() {
2091 bail!("qwen35: multimodal prefill requires token_embd weights");
2092 }
2093 let weights_path = self.weights_path.as_path();
2094 if weights_path.as_os_str().is_empty() {
2095 bail!("qwen35: multimodal prefill requires a GGUF weights path (for tokenizer)");
2096 }
2097 let n_embd = self.cfg.hidden_size;
2098 let mm = MultimodalPrompt {
2099 prompt,
2100 vision: &vision,
2101 };
2102 let prefill = mm.assemble(
2103 |text| encode_prompt_auto(weights_path, tokenizer, text),
2104 &self.weights.token_embd,
2105 n_embd,
2106 0,
2107 )?;
2108 self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2109 let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2110 Ok((trunk, mtp_logits))
2111 }
2112
2113 pub fn generate_multimodal_with_opts<F>(
2115 &mut self,
2116 prompt: &str,
2117 rgb: &[u8],
2118 img_w: usize,
2119 img_h: usize,
2120 tokenizer: Option<&std::path::Path>,
2121 n_new: usize,
2122 opts: SampleOpts,
2123 mut on_token: F,
2124 ) -> Result<Vec<u32>>
2125 where
2126 F: FnMut(u32) -> bool,
2127 {
2128 if self.batch != 1 {
2129 bail!(
2130 "qwen35: generate_multimodal requires batch=1 (runner batch={})",
2131 self.batch
2132 );
2133 }
2134 self.decode_cache = None;
2135 let (trunk, _) = self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2136 let mut cache = self
2137 .decode_cache
2138 .take()
2139 .ok_or_else(|| anyhow!("qwen35: multimodal prefill did not seed decode cache"))?;
2140 let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
2141 self.argmax_batch_from_hidden(&trunk)?
2142 } else if self.fast_greedy_lm_head
2143 && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
2144 {
2145 self.sample_batch_from_hidden(&trunk, opts)?
2146 } else {
2147 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2148 sample_logits_batch(&logits, self.lm_vocab_size(), 1, opts)
2149 };
2150 let mut generated = Vec::new();
2151 if n_new > 0 {
2152 let tok = next_tokens[0];
2153 generated.push(tok);
2154 if !on_token(tok) {
2155 return Ok(generated);
2156 }
2157 }
2158 let row_gen = vec![0usize];
2159 for _ in 1..n_new {
2160 if cache.past_seq >= self.max_seq - 1 {
2161 bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
2162 }
2163 next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen, opts)?;
2164 let tok = next_tokens[0];
2165 generated.push(tok);
2166 self.decode_cache = Some(cache.clone());
2167 if !on_token(tok) {
2168 break;
2169 }
2170 }
2171 Ok(generated)
2172 }
2173
2174 pub fn generate_multimodal<F>(
2176 &mut self,
2177 prompt: &str,
2178 rgb: &[u8],
2179 img_w: usize,
2180 img_h: usize,
2181 tokenizer: Option<&std::path::Path>,
2182 n_new: usize,
2183 on_token: F,
2184 ) -> Result<Vec<u32>>
2185 where
2186 F: FnMut(u32) -> bool,
2187 {
2188 self.generate_multimodal_with_opts(
2189 prompt,
2190 rgb,
2191 img_w,
2192 img_h,
2193 tokenizer,
2194 n_new,
2195 SampleOpts::greedy(),
2196 on_token,
2197 )
2198 }
2199
2200 fn prefill_seed_decode_cache(
2202 &mut self,
2203 prompts: &[Vec<u32>],
2204 ) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
2205 if prompts.len() != self.batch {
2206 bail!(
2207 "qwen35: expected {} prompts (batch={}), got {}",
2208 self.batch,
2209 self.batch,
2210 prompts.len()
2211 );
2212 }
2213 for (i, p) in prompts.iter().enumerate() {
2214 if p.is_empty() {
2215 bail!("qwen35: prompt row {i} is empty");
2216 }
2217 }
2218
2219 let prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
2220 let seq = prompt_lens.iter().copied().max().unwrap();
2221 if seq > self.max_seq {
2222 bail!(
2223 "qwen35: prompt length {seq} exceeds compiled max_seq={}",
2224 self.max_seq
2225 );
2226 }
2227
2228 let input_ids = if self.dynamic_prefill {
2229 pack_input_ids(prompts, seq)?
2230 } else {
2231 pack_input_ids(prompts, self.max_seq)?
2232 };
2233 let last_idx = last_token_indices(&prompt_lens);
2234
2235 let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
2236 feeds.push(("last_token_idx", last_idx.as_slice()));
2237 let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
2238 for (name, data) in &zero_in {
2239 feeds.push((name, data.as_slice()));
2240 }
2241 let rope_owned = self.mrope_prefill_rope_feeds(seq);
2242 for (name, data) in &rope_owned {
2243 feeds.push((name.as_str(), data.as_slice()));
2244 }
2245
2246 let has_moe = self.moe_offload.is_some();
2247 let num_experts = self.cfg.num_experts;
2248 let moe_masks = self
2249 .moe_offload
2250 .as_ref()
2251 .map(|m| m.per_layer_resident_masks());
2252 self.bind_moe_host_weights();
2253
2254 let outs = if self.dynamic_prefill {
2255 let config = self.execution_config(prefill_config(self.batch, seq));
2256 let compile_opts = self.dyn_compile_options(&config);
2257 let compiled = {
2258 let cache = self
2259 .prefill_dynamic_cache
2260 .as_mut()
2261 .expect("dynamic prefill cache");
2262 let cfg = self.cfg.clone();
2263 let weights = self.weights.clone();
2264 let runtime_mrope = self.runtime_mrope;
2265 let mtp_logits_path = self.mtp_logits_path;
2266 let fast_mtp = self.fast_mtp;
2267 let fast_greedy = self.fast_greedy_lm_head;
2268 let cache_params = &self.prefill_cache_params;
2269 let cache_packed = &self.prefill_cache_packed;
2270 let gguf_loader = &mut self.gguf_loader;
2271 let packed_bytes_cache = &mut self.packed_bytes_cache;
2272 get_or_specialize_hir_with_options(
2273 cache,
2274 &config,
2275 || {
2276 build_qwen35_prefill_cache_hir_dynamic_ext(
2277 &cfg,
2278 weights,
2279 1,
2280 seq,
2281 runtime_mrope,
2282 mtp_logits_path,
2283 fast_mtp,
2284 fast_greedy,
2285 )
2286 .expect("dynamic prefill HIR")
2287 .0
2288 },
2289 &compile_opts,
2290 |c| {
2291 for (name, data) in cache_params {
2292 c.set_param(name, data);
2293 }
2294 upload_packed_opt(c, gguf_loader.as_mut(), cache_packed, packed_bytes_cache)
2295 },
2296 )?
2297 };
2298 if has_moe {
2299 compiled.enable_moe_topk_capture(num_experts);
2300 if let Some(layers) = &moe_masks {
2301 push_moe_residency(compiled, layers);
2302 }
2303 }
2304 let outs = compiled.run(&feeds);
2305 if let Some(layers) = compiled.take_moe_topk_capture() {
2306 if let Some(mo) = self.moe_offload.as_mut() {
2307 let store = self.moe_store.as_ref();
2308 if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
2309 if let Some(store) = self.moe_store.as_ref() {
2310 store.apply_to_compiled(compiled);
2311 }
2312 }
2313 }
2314 }
2315 outs
2316 } else {
2317 let compiled = self.prefill_cache.as_mut().expect("static prefill cache");
2318 if has_moe {
2319 compiled.enable_moe_topk_capture(num_experts);
2320 if let Some(layers) = &moe_masks {
2321 push_moe_residency(compiled, layers);
2322 }
2323 }
2324 let outs = compiled.run(&feeds);
2325 let layers = if has_moe {
2326 compiled.take_moe_topk_capture()
2327 } else {
2328 None
2329 };
2330 if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2331 let store = self.moe_store.as_ref();
2332 if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
2333 if let Some(store) = self.moe_store.as_ref() {
2334 store.apply_to_compiled(compiled);
2335 }
2336 }
2337 }
2338 outs
2339 };
2340 let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
2341 &self.cfg,
2342 self.batch,
2343 seq,
2344 &prompt_lens,
2345 outs,
2346 self.mtp_logits_path,
2347 self.fast_greedy_lm_head,
2348 )?;
2349 zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
2350 self.decode_cache = Some(cache.clone());
2351 Ok((trunk, cache, mtp_logits))
2352 }
2353
    /// Runs the hidden-state prefill graph on `prefill` and seeds the decode cache from its
    /// outputs.
    ///
    /// Unlike token-id prefill, the graph is fed `prefill.hidden` directly through the
    /// `prefill_hidden` input. Presumably these are prompt embeddings with vision features
    /// already merged in by the multimodal path; confirm against how `MultimodalPrefill` is
    /// built.
    ///
    /// Returns `(trunk, cache, mtp_logits)` as produced by `seed_cache_from_outputs`. A clone of
    /// the seeded cache is also stored in `self.decode_cache`.
    ///
    /// NOTE(review): the HIR is built with batch `1`, and `hidden` / `prompt_lens` describe a
    /// single sequence. However, `zero_recurrent_inputs`, `hidden_prefill_config` and
    /// `seed_cache_from_outputs` receive `self.batch`. This presumably requires
    /// `self.batch == 1`; confirm that callers guarantee it.
    ///
    /// # Errors
    /// * `prefill.seq` is empty, or longer than the compiled `max_seq`.
    /// * `prefill.hidden.len() != seq * cfg.hidden_size`.
    /// * `prefill_hidden_dynamic_cache` is `None` (error text suggests the mmproj was not loaded).
    /// * Packing `input_ids`, specialising/compiling the graph, uploading packed weights, or
    ///   decoding the graph outputs fails.
    ///
    /// # Panics
    /// Panics if building the dynamic hidden-prefill HIR fails. That `expect` runs inside the
    /// specialisation closure, only when a new specialisation is built.
    fn prefill_seed_from_hidden(
        &mut self,
        prefill: MultimodalPrefill,
    ) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
        let seq = prefill.seq.len();
        if seq == 0 {
            bail!("qwen35: multimodal prefill seq is empty");
        }
        if seq > self.max_seq {
            bail!(
                "qwen35: multimodal seq {seq} exceeds compiled max_seq={}",
                self.max_seq
            );
        }
        // The hidden buffer must hold exactly one embedding row per sequence position.
        let n_embd = self.cfg.hidden_size;
        if prefill.hidden.len() != seq * n_embd {
            bail!(
                "qwen35: prefill hidden len {} != seq*n_embd {}*{}",
                prefill.hidden.len(),
                seq,
                n_embd
            );
        }

        // Every graph input is an f32 slice, so the token index is passed as a 1-element f32.
        let last_idx = vec![prefill.last_token_idx as f32];
        let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
        // Token ids are only fed when an MTP head is in use (`mtp_logits_path` / `enable_mtp`).
        let input_ids = if self.mtp_logits_path || self.enable_mtp {
            Some(pack_input_ids(std::slice::from_ref(&prefill.seq), seq)?)
        } else {
            None
        };
        let mut feeds: Vec<(&str, &[f32])> = vec![("prefill_hidden", prefill.hidden.as_slice())];
        feeds.push(("last_token_idx", last_idx.as_slice()));
        for (name, data) in &zero_in {
            feeds.push((name, data.as_slice()));
        }
        if let Some(ref ids) = input_ids {
            feeds.push(("input_ids", ids.as_slice()));
        }
        // Empty unless runtime M-RoPE is enabled.
        let rope_owned = self.mrope_prefill_rope_feeds(seq);
        for (name, data) in &rope_owned {
            feeds.push((name.as_str(), data.as_slice()));
        }

        let config = self.execution_config(hidden_prefill_config(self.batch, seq));
        let compile_opts = self.dyn_compile_options(&config);
        let cache = self
            .prefill_hidden_dynamic_cache
            .as_mut()
            .ok_or_else(|| anyhow!("qwen35: hidden prefill cache missing (mmproj not loaded?)"))?;
        // Copy/borrow everything the closures need as separate locals up front. `cache` holds a
        // mutable borrow of one field of `self`, so the closures must not capture `self` whole.
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let runtime_mrope = self.runtime_mrope;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        let hidden_params = &self.prefill_hidden_cache_params;
        let hidden_packed = &self.prefill_hidden_cache_packed;
        let gguf_loader = &mut self.gguf_loader;
        let packed_bytes_cache = &mut self.packed_bytes_cache;
        let compiled = get_or_specialize_hir_with_options(
            cache,
            &config,
            // HIR builder, invoked when this config needs a new specialisation.
            || {
                build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
                    &cfg,
                    weights,
                    // Batch 1: a single multimodal sequence.
                    1,
                    seq,
                    runtime_mrope,
                    mtp_logits_path,
                    fast_mtp,
                    fast_greedy,
                )
                .expect("dynamic hidden prefill HIR")
                .0
            },
            &compile_opts,
            // Parameter/weight uploader for a freshly compiled graph.
            |c| {
                for (name, data) in hidden_params {
                    c.set_param(name, data);
                }
                upload_packed_opt(c, gguf_loader.as_mut(), hidden_packed, packed_bytes_cache)
            },
        )?;
        let outs = compiled.run(&feeds);
        let prompt_lens = vec![seq];
        let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
            &self.cfg,
            self.batch,
            seq,
            &prompt_lens,
            outs,
            self.mtp_logits_path,
            self.fast_greedy_lm_head,
        )?;
        // Presumably clears KV entries beyond the prompt length `seq`; see the helper.
        zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
        self.decode_cache = Some(cache.clone());
        Ok((trunk, cache, mtp_logits))
    }
2455
2456 pub fn decode_get_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2458 self.decode_forward_logits(token, false)
2459 }
2460
2461 pub fn decode_get_mtp_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2463 if !self.mtp_logits_path {
2464 bail!("qwen35: decode_get_mtp_logits requires mtp_logits_path(true)");
2465 }
2466 self.decode_forward_logits(token, true)
2467 }
2468
2469 pub fn prefill_get_mtp_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2471 self.predict_logits(prompt_ids)?
2472 .mtp_logits
2473 .ok_or_else(|| anyhow!("qwen35: MTP logits unavailable (enable_mtp?)"))
2474 }
2475
2476 fn decode_step(
2477 &mut self,
2478 cache: &mut Qwen35DecodeCache,
2479 tokens: &[u32],
2480 generated_per_row: &[usize],
2481 opts: SampleOpts,
2482 ) -> Result<Vec<u32>> {
2483 if self.fast_greedy_lm_head {
2484 let vocab = self.lm_vocab_size();
2485 let (trunk, _mtp) = self.decode_step_trunk_raw(cache, tokens, generated_per_row)?;
2486 if opts.greedy {
2487 return self.argmax_batch_from_hidden(&trunk);
2488 }
2489 if sample_lm_cap(opts, vocab) < vocab {
2490 return self.sample_batch_from_hidden(&trunk, opts);
2491 }
2492 }
2493 let logits = self.decode_forward_logits_batch(cache, tokens, generated_per_row, false)?;
2494 Ok(sample_logits_batch(
2495 &logits,
2496 self.lm_vocab_size(),
2497 self.batch,
2498 opts,
2499 ))
2500 }
2501
2502 fn decode_forward_logits(&mut self, token: u32, want_mtp: bool) -> Result<Vec<f32>> {
2503 let mut cache = self
2504 .decode_cache
2505 .take()
2506 .ok_or_else(|| anyhow!("qwen35: decode requires seeded cache"))?;
2507 let row_gen = vec![0usize; self.batch];
2508 let logits = self.decode_forward_logits_batch(&mut cache, &[token], &row_gen, want_mtp)?;
2509 self.decode_cache = Some(cache);
2510 Ok(logits)
2511 }
2512
2513 fn decode_forward_logits_batch(
2514 &mut self,
2515 cache: &mut Qwen35DecodeCache,
2516 tokens: &[u32],
2517 generated_per_row: &[usize],
2518 want_mtp: bool,
2519 ) -> Result<Vec<f32>> {
2520 let past_seq = cache.past_seq;
2521 let head_half = self.cfg.key_length / 2;
2522 let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
2523
2524 let use_bucket = self
2525 .decode_compile_cache
2526 .as_ref()
2527 .and_then(|c| c.bucket_for(past_seq as u64))
2528 .is_some();
2529
2530 if use_bucket {
2531 let (logits, mtp_logits) =
2532 self.decode_step_bucketed(cache, tokens, generated_per_row, &cos, &sin)?;
2533 if want_mtp {
2534 mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from bucketed graph"))
2535 } else {
2536 Ok(logits)
2537 }
2538 } else {
2539 let feeds_owned = decode_step_feeds(
2540 &self.cfg,
2541 cache,
2542 tokens,
2543 &cos,
2544 &sin,
2545 None,
2546 generated_per_row,
2547 )?;
2548 let feeds: Vec<(&str, &[f32])> = feeds_owned
2549 .iter()
2550 .map(|(k, v)| (k.as_str(), v.as_slice()))
2551 .collect();
2552 if !self.decode_graphs.contains_key(&past_seq) {
2553 let (hir, params, packed) = build_qwen35_decode_hir_ext(
2554 &self.cfg,
2555 self.weights.clone(),
2556 self.batch,
2557 past_seq,
2558 false,
2559 self.mtp_logits_path,
2560 self.fast_mtp,
2561 self.fast_greedy_lm_head,
2562 )?;
2563 let mut compiled = self.compile_hir_for_config(
2564 decode_config(self.batch, past_seq),
2565 &format!("decode_{past_seq}"),
2566 hir,
2567 )?;
2568 for (name, data) in ¶ms {
2569 compiled.set_param(name, data);
2570 }
2571 upload_packed_opt(
2572 &mut compiled,
2573 self.gguf_loader.as_mut(),
2574 &packed,
2575 &mut self.packed_bytes_cache,
2576 )?;
2577 self.decode_graphs.insert(past_seq, compiled);
2578 }
2579 let step = self.moe_refresh_step;
2580 let has_moe = self.moe_offload.is_some();
2581 let num_experts = self.cfg.num_experts;
2582 let moe_masks = self
2583 .moe_offload
2584 .as_ref()
2585 .map(|m| m.per_layer_resident_masks());
2586 self.bind_moe_host_weights();
2587 let outs = {
2588 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2589 if has_moe {
2590 compiled.enable_moe_topk_capture(num_experts);
2591 if let Some(layers) = &moe_masks {
2592 push_moe_residency(compiled, layers);
2593 }
2594 }
2595 compiled.run(&feeds)
2596 };
2597 if has_moe {
2598 let layers = {
2599 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2600 compiled.take_moe_topk_capture()
2601 };
2602 if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2603 let store = self.moe_store.as_ref();
2604 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2605 if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
2606 if let Some(store) = self.moe_store.as_ref() {
2607 store.apply_to_compiled(compiled);
2608 }
2609 }
2610 }
2611 }
2612 self.moe_refresh_step = step.saturating_add(1);
2613 let (trunk, mtp_logits) = advance_cache_from_decode_outputs(
2614 &self.cfg,
2615 cache,
2616 outs,
2617 None,
2618 self.mtp_logits_path,
2619 want_mtp,
2620 self.fast_greedy_lm_head,
2621 )?;
2622 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2623 if want_mtp {
2624 mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from decode graph"))
2625 } else {
2626 Ok(logits)
2627 }
2628 }
2629 }
2630
2631 fn decode_step_bucketed(
2632 &mut self,
2633 cache: &mut Qwen35DecodeCache,
2634 tokens: &[u32],
2635 generated_per_row: &[usize],
2636 cos: &[f32],
2637 sin: &[f32],
2638 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2639 let (trunk, mtp) =
2640 self.decode_step_bucketed_raw(cache, tokens, generated_per_row, cos, sin)?;
2641 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2642 Ok((logits, mtp))
2643 }
2644
2645 fn decode_step_bucketed_raw(
2646 &mut self,
2647 cache: &mut Qwen35DecodeCache,
2648 tokens: &[u32],
2649 generated_per_row: &[usize],
2650 cos: &[f32],
2651 sin: &[f32],
2652 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2653 let past_seq = cache.past_seq;
2654 let upper = self.ensure_decode_bucket_compiled(past_seq as u64)?;
2655
2656 let feeds_owned = decode_step_feeds(
2657 &self.cfg,
2658 cache,
2659 tokens,
2660 cos,
2661 sin,
2662 Some(upper),
2663 generated_per_row,
2664 )?;
2665 let feeds: Vec<(&str, &[f32])> = feeds_owned
2666 .iter()
2667 .map(|(k, v)| (k.as_str(), v.as_slice()))
2668 .collect();
2669
2670 let decode_opts = self.bucketed_decode_compile_options();
2671 let cache_mut = self.decode_compile_cache.as_mut().unwrap();
2672 let (_u, compiled) = cache_mut
2673 .ensure_hir_with_params(
2674 past_seq as u64,
2675 |_| panic!("decode bucket must be compiled"),
2676 &decode_opts,
2677 )
2678 .expect("decode bucket missing after ensure");
2679 compiled.set_active_extent(Some((past_seq + 1, upper + 1)));
2681 let outs = compiled.run(&feeds);
2682 compiled.set_active_extent(None);
2683 advance_cache_from_decode_outputs(
2684 &self.cfg,
2685 cache,
2686 outs,
2687 Some(upper),
2688 self.mtp_logits_path,
2689 self.mtp_logits_path,
2690 self.fast_greedy_lm_head,
2691 )
2692 }
2693
2694 fn mrope_prefill_rope_feeds(&self, seq: usize) -> Vec<(String, Vec<f32>)> {
2695 if !self.runtime_mrope {
2696 return Vec::new();
2697 }
2698 let head_half = self.cfg.key_length / 2;
2699 let sections = self.mrope_section_positions.as_deref();
2700 let (cos, sin) = mrope_prefill_feeds(&self.cfg, seq, sections, head_half);
2701 vec![("rope_cos".into(), cos), ("rope_sin".into(), sin)]
2702 }
2703}
2704
2705impl rlx_cli::LmRunner for Qwen35Runner {
2706 fn family(&self) -> &'static str {
2707 "qwen35"
2708 }
2709 fn vocab_size(&self) -> usize {
2710 self.lm_vocab_size()
2711 }
2712 fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>> {
2713 let out = Qwen35Runner::predict_logits(self, prompt_ids)?;
2714 Ok(out.logits)
2715 }
2716 fn generate(
2717 &mut self,
2718 prompt_ids: &[u32],
2719 n_new: usize,
2720 on_token: &mut dyn FnMut(u32) -> bool,
2721 ) -> anyhow::Result<Vec<u32>> {
2722 Qwen35Runner::generate(self, prompt_ids, n_new, on_token)
2726 }
2727
2728 fn supports_multimodal(&self) -> bool {
2729 self.has_mmproj()
2733 }
2734
2735 fn generate_multimodal(
2736 &mut self,
2737 prompt: &str,
2738 rgb: &[u8],
2739 img_w: usize,
2740 img_h: usize,
2741 tokenizer: Option<&std::path::Path>,
2742 n_new: usize,
2743 on_token: &mut dyn FnMut(u32) -> bool,
2744 ) -> anyhow::Result<Vec<u32>> {
2745 Qwen35Runner::generate_multimodal(
2746 self, prompt, rgb, img_w, img_h, tokenizer, n_new, on_token,
2747 )
2748 }
2749}
2750
2751fn sample_logits_batch(logits: &[f32], vocab: usize, batch: usize, opts: SampleOpts) -> Vec<u32> {
2752 let mut out = Vec::with_capacity(batch);
2753 for b in 0..batch {
2754 let row = &logits[b * vocab..(b + 1) * vocab];
2755 out.push(sample_token(row, opts) as u32);
2756 }
2757 out
2758}