1use crate::cache::{
19 Qwen35DecodeCache, advance_cache_from_decode_outputs, decode_step_feeds, last_token_indices,
20 pack_input_ids, seed_cache_from_outputs, zero_prompt_padding_kv, zero_recurrent_inputs,
21};
22use crate::capabilities::validate_device;
23use crate::config::Qwen35Config;
24use crate::encode_prompt_auto;
25use crate::lm_head::{
26 greedy_lm_head_argmax, lm_head_logits_row, sample_lm_cap, sample_lm_head_from_hidden,
27};
28use crate::moe_offload::{MoeOffloadState, build_moe_offload};
29use crate::moe_store::{build_moe_expert_store, moe_host_bind_from_store};
30use crate::rope::{mrope_prefill_feeds, mrope_slice_at_pos};
31use crate::vision::{
32 MmProjConfig, MmProjWeights, MultimodalPrefill, MultimodalPrompt, Qwen35VisionEncoder,
33 load_vision_encoder,
34};
35use crate::weights::Qwen35Weights;
36use crate::{
37 PackedParams, build_qwen35_decode_hir_dynamic_ext, build_qwen35_decode_hir_ext,
38 build_qwen35_hir_sized_ext, build_qwen35_prefill_cache_hir_dynamic_ext,
39 build_qwen35_prefill_hidden_cache_hir_dynamic_ext,
40};
41use rlx_runtime::MoeExpertStore;
42
43fn push_moe_residency(compiled: &mut rlx_runtime::CompiledGraph, layers: &[Vec<bool>]) {
44 let refs: Vec<&[bool]> = layers.iter().map(|m| m.as_slice()).collect();
45 compiled.set_moe_resident_experts_per_layer(&refs);
46}
47
48fn refresh_moe_from_capture(
49 mo: &mut MoeOffloadState,
50 store: Option<&MoeExpertStore>,
51 compiled: &mut rlx_runtime::CompiledGraph,
52 layer_indices: &[Vec<u32>],
53 denoise_step: usize,
54 is_prefill_block: bool,
55) -> bool {
56 let refreshed = if let Some(store) = store {
57 mo.refresh_from_capture_with_store(store, layer_indices, denoise_step, is_prefill_block)
58 } else {
59 mo.refresh_from_capture(layer_indices, denoise_step, is_prefill_block)
60 };
61 if refreshed {
62 push_moe_residency(compiled, &mo.per_layer_resident_masks());
63 }
64 refreshed
65}
66use crate::execution::{
67 Qwen35CompileCache, decode_config, get_or_specialize_hir_with_options, hidden_prefill_config,
68 prefill_config,
69};
70use crate::flow::{Qwen35PrefillCacheOpts, build_qwen35_prefill_cache_built};
71use crate::profile::{qwen35_profile_default, qwen35_profile_near_weights};
72use anyhow::{Context, Result, anyhow, bail};
73use rlx_core::flow_bridge::compile_options_from_profile;
74use rlx_core::gguf_support::{GgufModelFamily, assert_gguf_family, resolve_weights_file};
75use rlx_core::weight_loader::GgufLoader;
76use rlx_flow::ModelExecutionConfig;
77use rlx_flow::{CompileProfile, ExecutionPreset};
78use rlx_ir::CompilationMode;
79use rlx_ir::logical_kernel::KernelDispatchConfig;
80use rlx_qwen3::sampling::{SampleOpts, sample_token};
81use rlx_runtime::compile_cache::BucketedCompileCache;
82use rlx_runtime::{AotCache, CompileOptions, Device, Session};
83use std::cell::RefCell;
84use std::collections::HashMap;
85use std::path::PathBuf;
86use std::sync::Arc;
87use std::time::Instant;
88
/// Where `Qwen35RunnerBuilder::build` takes the model config from.
///
/// Any source other than `Embedded` selects the safetensors loading path.
/// `build` rejects such a source when the weights file has a `.gguf`
/// extension.
#[derive(Debug, Clone, Default)]
pub enum Qwen35ConfigSource {
    /// Default: the GGUF metadata of the weights file is the source of truth
    /// (`Qwen35Config::from_gguf`), unless in-memory weights were supplied
    /// via `inline_weights`.
    #[default]
    Embedded,
    /// A Hugging Face `config.json`, parsed with `Qwen35Config::from_hf_config_json`.
    JsonFile(PathBuf),
    /// A config value used as-is.
    Explicit(Qwen35Config),
}
107
/// Builder for [`Qwen35Runner`]; every option is optional and finalised in `build`.
#[derive(Default, Debug)]
pub struct Qwen35RunnerBuilder {
    /// Weights file path (required unless `inline_weights` is set).
    weights: Option<PathBuf>,
    /// Config source; `None`/`Embedded` means the GGUF metadata supplies it.
    config: Option<Qwen35ConfigSource>,
    /// Target device; `build` defaults it to `Device::Cpu`.
    device: Option<Device>,
    /// Maximum sequence length; `build` defaults it to 128.
    max_seq: Option<usize>,
    /// Wire the MTP heads; rejected when the config has `nextn_predict_layers == 0`.
    enable_mtp: bool,
    /// Forwarded unchanged to the runner.
    last_logits_only: bool,
    /// Force (`Some(true)`) or forbid (`Some(false)`) packed GGUF weights;
    /// `None` auto-detects on the GGUF path.
    packed_weights: Option<bool>,
    /// Runtime M-RoPE feeds; forced on when a vision encoder is configured.
    runtime_mrope: bool,
    /// Forwarded unchanged to the runner.
    mrope_section_positions: Option<Vec<[usize; 4]>>,
    /// Batch size; `build` defaults it to 1 and rejects 0.
    batch: Option<usize>,
    /// Bucketed decode compile cache; defaults to on, forced off by `dynamic_decode`.
    bucketed_decode: Option<bool>,
    /// Requires `enable_mtp`.
    mtp_logits_path: bool,
    /// Requires `enable_mtp` (directly or through `mtp_logits_path`).
    fast_mtp: bool,
    /// Defaults to on in `build`.
    fast_greedy_lm_head: Option<bool>,
    /// Directory for an on-disk AOT compile cache; also switches to `CompilationMode::Aot`.
    aot_cache_dir: Option<PathBuf>,
    /// Dynamic-shape prefill (requires batch 1; forced on with a vision encoder).
    dynamic_prefill: bool,
    /// Dynamic-shape decode (requires batch 1).
    dynamic_decode: bool,
    /// In-memory config + weights used instead of loading from `weights`.
    inline_weights: Option<(Qwen35Config, Qwen35Weights)>,
    /// Vision projector file; mutually exclusive with `inline_mmproj`.
    mmproj: Option<PathBuf>,
    /// In-memory vision projector; mutually exclusive with `mmproj`.
    inline_mmproj: Option<(crate::vision::MmProjConfig, crate::vision::MmProjWeights)>,
    /// Prefill compile-profile override (otherwise derived from the weights path).
    prefill_profile: Option<CompileProfile>,
    /// Decode compile-profile override (otherwise derived from the weights path).
    decode_profile: Option<CompileProfile>,
    /// Forwarded to `build_moe_offload`.
    max_gpu_experts_per_layer: Option<usize>,
    /// Forwarded to `build_moe_offload`.
    moe_memory_budget_bytes: Option<usize>,
    /// Alias of `jump_steps`; both setters keep the two fields equal.
    expert_refresh_every_decode_steps: Option<usize>,
    /// Forwarded to `build_moe_offload` (preferred over the alias above).
    jump_steps: Option<usize>,
    /// VRAM reserve in GiB; `build` defaults it to 1.5.
    reserve_vram_gb: Option<f64>,
    /// Forwarded to `build_moe_offload`.
    moe_collect_stats: bool,
}
158
159impl Qwen35RunnerBuilder {
160 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
161 self.weights = Some(path.into());
162 self
163 }
164
165 pub fn config(mut self, src: Qwen35ConfigSource) -> Self {
171 self.config = Some(src);
172 self
173 }
174
175 pub fn config_value(self, cfg: Qwen35Config) -> Self {
178 self.config(Qwen35ConfigSource::Explicit(cfg))
179 }
180 pub fn device(mut self, d: Device) -> Self {
181 self.device = Some(d);
182 self
183 }
184 pub fn max_seq(mut self, n: usize) -> Self {
185 self.max_seq = Some(n);
186 self
187 }
188 pub fn enable_mtp(mut self, on: bool) -> Self {
189 self.enable_mtp = on;
190 self
191 }
192 pub fn last_logits_only(mut self, on: bool) -> Self {
193 self.last_logits_only = on;
194 self
195 }
196 pub fn packed_weights(mut self, on: bool) -> Self {
197 self.packed_weights = Some(on);
198 self
199 }
200 pub fn runtime_mrope(mut self, on: bool) -> Self {
203 self.runtime_mrope = on;
204 self
205 }
206 pub fn mrope_section_positions(mut self, positions: Vec<[usize; 4]>) -> Self {
208 self.mrope_section_positions = Some(positions);
209 self
210 }
211 pub fn batch(mut self, n: usize) -> Self {
213 self.batch = Some(n);
214 self
215 }
216 pub fn bucketed_decode(mut self, on: bool) -> Self {
218 self.bucketed_decode = Some(on);
219 self
220 }
221 pub fn mtp_logits_path(mut self, on: bool) -> Self {
223 self.mtp_logits_path = on;
224 self
225 }
226 pub fn fast_mtp(mut self, on: bool) -> Self {
228 self.fast_mtp = on;
229 self
230 }
231 pub fn fast_greedy_lm_head(mut self, on: bool) -> Self {
233 self.fast_greedy_lm_head = Some(on);
234 self
235 }
236 pub fn aot_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
238 self.aot_cache_dir = Some(path.into());
239 self
240 }
241 pub fn dynamic_prefill(mut self, on: bool) -> Self {
243 self.dynamic_prefill = on;
244 self
245 }
246 pub fn dynamic_decode(mut self, on: bool) -> Self {
248 self.dynamic_decode = on;
249 self
250 }
251 pub fn inline_weights(mut self, cfg: Qwen35Config, weights: Qwen35Weights) -> Self {
253 self.inline_weights = Some((cfg, weights));
254 self
255 }
256
257 pub fn mmproj(mut self, path: impl Into<PathBuf>) -> Self {
259 self.mmproj = Some(path.into());
260 self
261 }
262
263 pub fn inline_mmproj(
265 mut self,
266 cfg: crate::vision::MmProjConfig,
267 weights: crate::vision::MmProjWeights,
268 ) -> Self {
269 self.inline_mmproj = Some((cfg, weights));
270 self
271 }
272
273 pub fn with_compile_profiles(
275 mut self,
276 prefill: CompileProfile,
277 decode: CompileProfile,
278 ) -> Self {
279 self.prefill_profile = Some(prefill);
280 self.decode_profile = Some(decode);
281 self
282 }
283
284 pub fn max_gpu_experts_per_layer(mut self, n: usize) -> Self {
286 self.max_gpu_experts_per_layer = Some(n);
287 self
288 }
289
290 pub fn moe_memory_budget_bytes(mut self, bytes: usize) -> Self {
292 self.moe_memory_budget_bytes = Some(bytes);
293 self
294 }
295
296 pub fn expert_refresh_every_decode_steps(mut self, n: usize) -> Self {
298 self.expert_refresh_every_decode_steps = Some(n);
299 self.jump_steps = Some(n);
300 self
301 }
302
303 pub fn jump_steps(mut self, n: usize) -> Self {
305 self.jump_steps = Some(n);
306 self.expert_refresh_every_decode_steps = Some(n);
307 self
308 }
309
310 pub fn reserve_vram_gb(mut self, gb: f64) -> Self {
312 self.reserve_vram_gb = Some(gb);
313 self
314 }
315
316 pub fn moe_collect_stats(mut self, on: bool) -> Self {
318 self.moe_collect_stats = on;
319 self
320 }
321
322 pub fn enable_predictive_expert_offload(mut self, max_gpu_experts_per_layer: usize) -> Self {
324 self.max_gpu_experts_per_layer = Some(max_gpu_experts_per_layer);
325 self
326 }
327
328 pub fn build(self) -> Result<Qwen35Runner> {
329 let device = self.device.unwrap_or(Device::Cpu);
330 let max_seq = self.max_seq.unwrap_or(128);
331 let batch = self.batch.unwrap_or(1);
332 if batch == 0 {
333 bail!("qwen35: batch must be >= 1");
334 }
335
336 if let Some(src) = self.config.as_ref()
345 && !matches!(src, Qwen35ConfigSource::Embedded)
346 {
347 let weights_path = self
348 .weights
349 .as_ref()
350 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?
351 .clone();
352 let resolved = resolve_weights_file(&weights_path)?;
353 let ext = resolved
354 .extension()
355 .and_then(|s| s.to_str())
356 .unwrap_or("")
357 .to_ascii_lowercase();
358 if ext == "gguf" {
359 bail!(
360 "qwen35: Qwen35ConfigSource::{:?} supplied with a GGUF weights file at \
361 {:?} — drop the config source (use the default Embedded) so the GGUF \
362 metadata is the source of truth",
363 src,
364 resolved
365 );
366 }
367 if self.packed_weights == Some(true) {
368 bail!("qwen35: packed_weights requires GGUF; safetensors path is dequant-only");
369 }
370 let cfg = match src {
371 Qwen35ConfigSource::Embedded => unreachable!(),
372 Qwen35ConfigSource::JsonFile(p) => Qwen35Config::from_hf_config_json(p)
373 .with_context(|| format!("qwen35: parse HF config {p:?}"))?,
374 Qwen35ConfigSource::Explicit(cfg) => cfg.clone(),
375 };
376 if self.enable_mtp && cfg.nextn_predict_layers == 0 {
377 bail!(
378 "qwen35: enable_mtp(true) but config has \
379 nextn_predict_layers=0 (no MTP heads to wire)"
380 );
381 }
382 validate_device(&cfg, device, false)?;
383 let path_str = resolved
384 .to_str()
385 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
386 let inner_map = rlx_core::weight_map::WeightMap::from_file(path_str)
387 .with_context(|| format!("qwen35: load safetensors {resolved:?}"))?;
388 let mut loader = rlx_core::HfTranslatingLoader::new(inner_map);
389 let t = std::time::Instant::now();
390 let weights = Qwen35Weights::from_loader(&mut loader, &cfg)?;
391 eprintln!(
392 "[qwen35] read safetensors weights in {:.2?} \
393 (layers={}, hidden={})",
394 t.elapsed(),
395 cfg.num_hidden_layers,
396 cfg.hidden_size,
397 );
398 return finish_build(
399 cfg,
400 weights,
401 resolved,
402 None,
403 device,
404 max_seq,
405 batch,
406 self.enable_mtp,
407 self.last_logits_only,
408 self.runtime_mrope,
409 self.mrope_section_positions,
410 self.bucketed_decode,
411 self.mtp_logits_path,
412 self.fast_mtp,
413 self.fast_greedy_lm_head.unwrap_or(true),
414 self.aot_cache_dir.clone(),
415 self.dynamic_prefill,
416 self.dynamic_decode,
417 self.mmproj.clone(),
418 self.inline_mmproj,
419 self.prefill_profile,
420 self.decode_profile,
421 self.max_gpu_experts_per_layer,
422 self.moe_memory_budget_bytes,
423 self.jump_steps.or(self.expert_refresh_every_decode_steps),
424 self.reserve_vram_gb.unwrap_or(1.5),
425 self.moe_collect_stats,
426 );
427 }
428
429 if let Some((cfg, weights)) = self.inline_weights {
430 if self.packed_weights == Some(true) {
431 bail!("qwen35: inline_weights and packed_weights are mutually exclusive");
432 }
433 if self.enable_mtp && cfg.nextn_predict_layers == 0 {
434 bail!(
435 "qwen35: enable_mtp(true) but config has \
436 nextn_predict_layers=0 (no MTP heads to wire)"
437 );
438 }
439 if self.mmproj.is_some() && self.inline_mmproj.is_some() {
440 bail!("qwen35: mmproj and inline_mmproj are mutually exclusive");
441 }
442 validate_device(&cfg, device, false)?;
443 return finish_build(
444 cfg,
445 weights,
446 PathBuf::new(),
447 None,
448 device,
449 max_seq,
450 batch,
451 self.enable_mtp,
452 self.last_logits_only,
453 self.runtime_mrope,
454 self.mrope_section_positions,
455 self.bucketed_decode,
456 self.mtp_logits_path,
457 self.fast_mtp,
458 self.fast_greedy_lm_head.unwrap_or(true),
459 self.aot_cache_dir.clone(),
460 self.dynamic_prefill,
461 self.dynamic_decode,
462 self.mmproj.clone(),
463 self.inline_mmproj,
464 self.prefill_profile,
465 self.decode_profile,
466 self.max_gpu_experts_per_layer,
467 self.moe_memory_budget_bytes,
468 self.jump_steps.or(self.expert_refresh_every_decode_steps),
469 self.reserve_vram_gb.unwrap_or(1.5),
470 self.moe_collect_stats,
471 );
472 }
473
474 let weights_path = resolve_weights_file(
475 &self
476 .weights
477 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
478 )?;
479 let _t_total = Instant::now();
480 let t = Instant::now();
481 let raw = assert_gguf_family(&weights_path, GgufModelFamily::Qwen35)?;
482 let mut loader = GgufLoader::from_file(
483 weights_path
484 .to_str()
485 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
486 )?;
487 loader.include_mtp(true);
488 let cfg = Qwen35Config::from_gguf(&raw)?;
489 eprintln!(
490 "[qwen35] loaded GGUF metadata in {:.2?} \
491 (layers={}, hidden={}, ssm_state={})",
492 t.elapsed(),
493 cfg.num_hidden_layers,
494 cfg.hidden_size,
495 cfg.ssm_state_size,
496 );
497
498 if self.enable_mtp && cfg.nextn_predict_layers == 0 {
499 bail!(
500 "qwen35: enable_mtp(true) but the file has \
501 nextn_predict_layers=0 (no MTP heads to wire)"
502 );
503 }
504 let packed = self.packed_weights.unwrap_or_else(|| {
512 if raw.tensors.values().any(|t| {
513 matches!(
514 t.dtype,
515 rlx_gguf::GgmlType::Q2K
516 | rlx_gguf::GgmlType::Q3K
517 | rlx_gguf::GgmlType::Q4K
518 | rlx_gguf::GgmlType::Q5K
519 | rlx_gguf::GgmlType::Q6K
520 | rlx_gguf::GgmlType::Q8K
521 )
522 }) {
523 return true;
524 }
525 std::fs::metadata(&weights_path)
526 .ok()
527 .map(|m| m.len() >= 256 * 1024 * 1024)
528 .unwrap_or(false)
529 });
530 validate_device(&cfg, device, packed)?;
531
532 let t = Instant::now();
533 let weights = if packed {
534 Qwen35Weights::from_loader_packed(&mut loader, &cfg)?
535 } else {
536 Qwen35Weights::from_loader(&mut loader, &cfg)?
537 };
538 eprintln!(
539 "[qwen35] read weights ({}) in {:.2?}",
540 if packed { "packed" } else { "F32" },
541 t.elapsed(),
542 );
543
544 finish_build(
545 cfg,
546 weights,
547 weights_path,
548 Some(loader),
549 device,
550 max_seq,
551 batch,
552 self.enable_mtp,
553 self.last_logits_only,
554 self.runtime_mrope,
555 self.mrope_section_positions,
556 self.bucketed_decode,
557 self.mtp_logits_path,
558 self.fast_mtp,
559 self.fast_greedy_lm_head.unwrap_or(true),
560 self.aot_cache_dir.clone(),
561 self.dynamic_prefill,
562 self.dynamic_decode,
563 self.mmproj.clone(),
564 self.inline_mmproj,
565 self.prefill_profile,
566 self.decode_profile,
567 self.max_gpu_experts_per_layer,
568 self.moe_memory_budget_bytes,
569 self.jump_steps.or(self.expert_refresh_every_decode_steps),
570 self.reserve_vram_gb.unwrap_or(1.5),
571 self.moe_collect_stats,
572 )
573 }
574}
575
576fn make_qwen35_dyn_cache(
577 device: Device,
578 capacity: usize,
579 aot_cache_dir: Option<&std::path::Path>,
580) -> Qwen35CompileCache {
581 if let Some(dir) = aot_cache_dir {
582 Qwen35CompileCache::with_aot(device, capacity, dir)
583 } else {
584 Qwen35CompileCache::new(device, capacity)
585 }
586}
587
588fn compile_static_prefill_cache(
590 cfg: &Qwen35Config,
591 weights: Qwen35Weights,
592 batch: usize,
593 max_seq: usize,
594 device: Device,
595 prefill_profile: &CompileProfile,
596 runtime_mrope: bool,
597 enable_mtp_head: bool,
598 fast_mtp: bool,
599 fast_greedy_lm_head: bool,
600 aot_cache_dir: Option<&std::path::Path>,
601) -> Result<(
602 rlx_runtime::CompiledGraph,
603 HashMap<String, Vec<f32>>,
604 PackedParams,
605)> {
606 let mut flow_opts = Qwen35PrefillCacheOpts::static_cache(batch, max_seq);
607 flow_opts.with_lm_head = !fast_greedy_lm_head;
608 flow_opts.runtime_mrope = runtime_mrope;
609 flow_opts.enable_mtp_head = enable_mtp_head;
610 flow_opts.fast_mtp = fast_mtp;
611 flow_opts.fast_greedy_lm_head = fast_greedy_lm_head;
612 flow_opts.profile = Some(prefill_profile.clone());
613
614 let (built, packed) = build_qwen35_prefill_cache_built(cfg, weights, &flow_opts)?;
615 let params = built.params().clone();
616 let config = prefill_config(batch, max_seq);
617 let compile_opts =
618 compile_options_from_profile(prefill_profile, device, KernelDispatchConfig::default());
619
620 let mut cache = match aot_cache_dir {
621 Some(dir) => Qwen35CompileCache::with_aot(device, 1, dir),
622 None => Qwen35CompileCache::new(device, 1),
623 };
624 let mut config = config;
625 if aot_cache_dir.is_some() {
626 config = config.with_compilation_mode(CompilationMode::Aot);
627 }
628 let built = built.with_execution_config(&config);
629 let compiled = cache.compile_built(built, &config, &compile_opts)?;
630 Ok((compiled, params, packed))
631}
632
633#[allow(clippy::too_many_arguments)]
634fn finish_build(
635 cfg: Qwen35Config,
636 weights: Qwen35Weights,
637 weights_path: PathBuf,
638 gguf_loader: Option<GgufLoader>,
639 device: Device,
640 max_seq: usize,
641 batch: usize,
642 enable_mtp: bool,
643 last_logits_only: bool,
644 runtime_mrope: bool,
645 mrope_section_positions: Option<Vec<[usize; 4]>>,
646 bucketed_decode: Option<bool>,
647 mtp_logits_path: bool,
648 fast_mtp: bool,
649 fast_greedy_lm_head: bool,
650 aot_cache_dir: Option<PathBuf>,
651 dynamic_prefill: bool,
652 dynamic_decode: bool,
653 mmproj_path: Option<PathBuf>,
654 inline_mmproj: Option<(MmProjConfig, MmProjWeights)>,
655 prefill_profile_override: Option<CompileProfile>,
656 decode_profile_override: Option<CompileProfile>,
657 max_gpu_experts_per_layer: Option<usize>,
658 moe_memory_budget_bytes: Option<usize>,
659 jump_steps: Option<usize>,
660 reserve_vram_gb: f64,
661 moe_collect_stats: bool,
662) -> Result<Qwen35Runner> {
663 let prefill_profile = prefill_profile_override.unwrap_or_else(|| {
664 if weights_path.as_os_str().is_empty() {
665 qwen35_profile_default(false)
666 } else {
667 qwen35_profile_near_weights(&weights_path, false)
668 }
669 });
670 let decode_profile = decode_profile_override.unwrap_or_else(|| {
671 if weights_path.as_os_str().is_empty() {
672 qwen35_profile_default(true)
673 } else {
674 qwen35_profile_near_weights(&weights_path, true)
675 }
676 });
677
678 if fast_mtp && !mtp_logits_path && !enable_mtp {
679 bail!("qwen35: fast_mtp requires enable_mtp(true) or mtp_logits_path(true)");
680 }
681 if mtp_logits_path && !enable_mtp {
682 bail!("qwen35: mtp_logits_path requires enable_mtp(true)");
683 }
684
685 if dynamic_prefill && batch != 1 {
686 bail!("qwen35: dynamic_prefill requires batch=1");
687 }
688 if dynamic_decode && batch != 1 {
689 bail!("qwen35: dynamic_decode requires batch=1");
690 }
691 if dynamic_decode && bucketed_decode.unwrap_or(true) {
692 eprintln!("[qwen35] dynamic_decode enabled — disabling bucketed decode cache");
693 }
694 let bucketed_decode = if dynamic_decode {
695 false
696 } else {
697 bucketed_decode.unwrap_or(true)
698 };
699
700 let vision_encoder = if let Some(ref path) = mmproj_path {
701 Some(load_vision_encoder(
702 path.to_str()
703 .ok_or_else(|| anyhow!("non-utf8 mmproj path"))?,
704 224,
705 224,
706 )?)
707 } else if let Some((vcfg, vweights)) = inline_mmproj {
708 Some(Qwen35VisionEncoder::from_parts(vcfg, vweights, 4, 4)?)
709 } else {
710 None
711 };
712 let runtime_mrope = runtime_mrope || vision_encoder.is_some();
713 if vision_encoder.is_some() && batch != 1 {
714 bail!("qwen35: VLM (mmproj) requires batch=1");
715 }
716 if vision_encoder.is_some() && !dynamic_prefill {
717 eprintln!("[qwen35] mmproj loaded — enabling dynamic prefill for variable multimodal seq");
718 }
719 let dynamic_prefill = dynamic_prefill || vision_encoder.is_some();
720
721 let t = Instant::now();
722 let aot = aot_cache_dir.as_ref().map(AotCache::new);
723 let (cache_params, cache_packed, mut prefill_cache, prefill_dynamic_cache) = if dynamic_prefill
724 {
725 let (_cache_hir, cache_params, cache_packed) = build_qwen35_prefill_cache_hir_dynamic_ext(
726 &cfg,
727 weights.clone(),
728 batch,
729 max_seq,
730 runtime_mrope,
731 mtp_logits_path,
732 fast_mtp,
733 fast_greedy_lm_head,
734 )?;
735 eprintln!(
736 "[qwen35] built prefill-cache IR in {:.2?} (params={}, packed={})",
737 t.elapsed(),
738 cache_params.len(),
739 cache_packed.len(),
740 );
741 eprintln!("[qwen35] dynamic prefill template ready (compile on first prompt)");
742 (
743 cache_params,
744 cache_packed,
745 None,
746 Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
747 )
748 } else {
749 let (compiled, cache_params, cache_packed) = compile_static_prefill_cache(
750 &cfg,
751 weights.clone(),
752 batch,
753 max_seq,
754 device,
755 &prefill_profile,
756 runtime_mrope,
757 mtp_logits_path,
758 fast_mtp,
759 fast_greedy_lm_head,
760 aot_cache_dir.as_deref(),
761 )?;
762 eprintln!(
763 "[qwen35] compiled prefill-cache via BuiltModel in {:.2?} (params={}, packed={})",
764 t.elapsed(),
765 cache_params.len(),
766 cache_packed.len(),
767 );
768 (cache_params, cache_packed, Some(compiled), None)
769 };
770
771 let (prefill_hidden_dynamic_cache, prefill_hidden_cache_params, prefill_hidden_cache_packed) =
772 if vision_encoder.is_some() {
773 let (hidden_hir, hidden_params, hidden_packed) =
774 build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
775 &cfg,
776 weights.clone(),
777 batch,
778 max_seq,
779 runtime_mrope,
780 mtp_logits_path,
781 fast_mtp,
782 fast_greedy_lm_head,
783 )?;
784 let _ = hidden_hir;
785 (
786 Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
787 hidden_params,
788 hidden_packed,
789 )
790 } else {
791 (None, HashMap::new(), HashMap::new())
792 };
793
794 let t = Instant::now();
795 if let Some(ref mut compiled) = prefill_cache {
796 for (name, data) in &cache_params {
797 compiled.set_param(name, data);
798 }
799 }
800
801 let decode_compile_cache = if bucketed_decode {
802 Some(BucketedCompileCache::power_of_two_ladder(
803 device,
804 1,
805 max_seq.max(1) as u64,
806 ))
807 } else {
808 None
809 };
810 let decode_dynamic_cache = if dynamic_decode {
811 Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref()))
812 } else {
813 None
814 };
815 let (decode_dynamic_params, decode_dynamic_packed) = if dynamic_decode {
816 let (_, p, packed) = build_qwen35_decode_hir_dynamic_ext(
817 &cfg,
818 weights.clone(),
819 batch,
820 max_seq,
821 mtp_logits_path,
822 fast_mtp,
823 fast_greedy_lm_head,
824 )?;
825 (p, packed)
826 } else {
827 (HashMap::new(), HashMap::new())
828 };
829
830 if dynamic_decode {
831 eprintln!("[qwen35] dynamic decode template ready (compile on first step)");
832 }
833
834 let moe_offload = build_moe_offload(
835 &cfg,
836 &weights,
837 max_gpu_experts_per_layer,
838 moe_memory_budget_bytes,
839 jump_steps,
840 reserve_vram_gb,
841 moe_collect_stats,
842 );
843 let moe_store = if moe_offload.is_some() {
844 build_moe_expert_store(&cfg, &weights).ok()
845 } else {
846 None
847 };
848 if let Some(ref mo) = moe_offload {
849 eprintln!(
850 "[qwen35] TIDE MoE offload: layers={} gpu_budget={}/{} jump_steps={} reserve_bytes={}",
851 mo.num_layers(),
852 mo.info.gpu_expert_budget_per_layer,
853 mo.pools[0].num_experts(),
854 mo.jump_steps,
855 mo.info.reserve_bytes,
856 );
857 }
858
859 let mut runner = Qwen35Runner {
860 compiled: None,
861 prefill_cache,
862 prefill_dynamic_cache,
863 prefill_hidden_dynamic_cache,
864 prefill_cache_params: cache_params,
865 prefill_cache_packed: cache_packed,
866 prefill_hidden_cache_params,
867 prefill_hidden_cache_packed,
868 decode_graphs: HashMap::new(),
869 decode_compile_cache,
870 decode_dynamic_cache,
871 predict_hir_cache: None,
872 decode_dynamic_params,
873 decode_dynamic_packed,
874 packed_bytes_cache: HashMap::new(),
875 cfg,
876 device,
877 batch,
878 max_seq,
879 last_logits_only,
880 enable_mtp,
881 mtp_logits_path,
882 fast_mtp,
883 fast_greedy_lm_head,
884 weights,
885 weights_path,
886 gguf_loader,
887 decode_cache: None,
888 runtime_mrope,
889 mrope_section_positions,
890 aot_cache: aot,
891 dynamic_prefill,
892 dynamic_decode,
893 vision_encoder,
894 mmproj_path,
895 prefill_profile,
896 decode_profile,
897 moe_offload,
898 moe_store,
899 moe_refresh_step: 0,
900 };
901
902 if let Some(ref mut compiled) = runner.prefill_cache {
903 upload_packed_opt(
904 compiled,
905 runner.gguf_loader.as_mut(),
906 &runner.prefill_cache_packed,
907 &mut runner.packed_bytes_cache,
908 )?;
909 }
910 eprintln!(
911 "[qwen35] uploaded prefill-cache {} F32 + {} packed params in {:.2?}",
912 runner.prefill_cache_params.len(),
913 runner.prefill_cache_packed.len(),
914 t.elapsed(),
915 );
916
917 runner.warm_decode_graphs()?;
918 runner.warm_predict_graph()?;
919 Ok(runner)
920}
921
922fn ensure_packed_cache(
923 loader: &mut GgufLoader,
924 packed: &PackedParams,
925 cache: &mut HashMap<String, Arc<[u8]>>,
926) -> Result<()> {
927 for (loader_key, _, _) in packed.values() {
928 if cache.contains_key(loader_key) {
929 continue;
930 }
931 let bytes = loader
932 .tensor_bytes_borrowed(loader_key)
933 .ok_or_else(|| anyhow!("packed cache: {loader_key} bytes missing"))?;
934 cache.insert(loader_key.clone(), Arc::from(bytes));
935 }
936 Ok(())
937}
938
939fn upload_packed_opt(
940 compiled: &mut rlx_runtime::CompiledGraph,
941 loader: Option<&mut GgufLoader>,
942 packed: &PackedParams,
943 cache: &mut HashMap<String, Arc<[u8]>>,
944) -> Result<()> {
945 if packed.is_empty() {
946 return Ok(());
947 }
948 let loader = loader
949 .ok_or_else(|| anyhow!("packed params require a GGUF loader (missing weights path)"))?;
950 ensure_packed_cache(loader, packed, cache)?;
951 for (param_name, (loader_key, _scheme, _shape)) in packed {
952 let bytes = cache
953 .get(loader_key)
954 .ok_or_else(|| anyhow!("packed upload: cache miss for {loader_key}"))?;
955 compiled.set_param_typed(param_name, bytes, rlx_ir::DType::U8);
956 }
957 Ok(())
958}
959
960#[allow(dead_code)]
961fn upload_decode_packed(
962 weights_path: &std::path::Path,
963 compiled: &mut rlx_runtime::CompiledGraph,
964 packed: &PackedParams,
965) -> Result<()> {
966 if packed.is_empty() {
967 return Ok(());
968 }
969 let path = weights_path
970 .to_str()
971 .filter(|p| !p.is_empty())
972 .ok_or_else(|| anyhow!("packed decode params require a GGUF weights path"))?;
973 let mut loader = GgufLoader::from_file(path)?;
974 loader.include_mtp(true);
975 upload_packed_opt(compiled, Some(&mut loader), packed, &mut HashMap::new())
976}
977
/// A Qwen3.5 model compiled for prefill and decode; built by
/// [`Qwen35RunnerBuilder::build`].
pub struct Qwen35Runner {
    /// NOTE(review): set to `None` by `finish_build`; its use is not visible here.
    compiled: Option<rlx_runtime::CompiledGraph>,
    /// Static `[batch, max_seq]` prefill-cache graph; `None` in dynamic-prefill mode.
    prefill_cache: Option<rlx_runtime::CompiledGraph>,
    /// Compile cache for dynamic prefill graphs; `Some` only in dynamic-prefill mode.
    prefill_dynamic_cache: Option<Qwen35CompileCache>,
    /// Compile cache for the hidden-state prefill graph; `Some` only with a vision encoder.
    prefill_hidden_dynamic_cache: Option<Qwen35CompileCache>,
    /// F32 params of the prefill-cache graph.
    prefill_cache_params: HashMap<String, Vec<f32>>,
    /// Packed (raw GGUF byte) params of the prefill-cache graph.
    prefill_cache_packed: PackedParams,
    /// F32 params of the hidden-state prefill graph (empty without a vision encoder).
    prefill_hidden_cache_params: HashMap<String, Vec<f32>>,
    /// Packed params of the hidden-state prefill graph (empty without a vision encoder).
    prefill_hidden_cache_packed: PackedParams,
    /// Compiled decode graphs; keyed by a `usize`, presumably a sequence
    /// length/bucket. TODO confirm.
    decode_graphs: HashMap<usize, rlx_runtime::CompiledGraph>,
    /// Power-of-two bucketed decode compile cache; `Some` when bucketed decode is on.
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Compile cache for dynamic decode graphs; `Some` only in dynamic-decode mode.
    decode_dynamic_cache: Option<Qwen35CompileCache>,
    /// NOTE(review): `None` at construction; presumably populated by
    /// `warm_predict_graph`. Confirm.
    predict_hir_cache: Option<Qwen35CompileCache>,
    /// F32 params for dynamic decode graphs (empty unless dynamic decode).
    decode_dynamic_params: HashMap<String, Vec<f32>>,
    /// Packed params for dynamic decode graphs (empty unless dynamic decode).
    decode_dynamic_packed: PackedParams,
    /// Raw GGUF tensor bytes keyed by loader tensor name, shared across packed uploads.
    packed_bytes_cache: HashMap<String, Arc<[u8]>>,
    /// Model config.
    cfg: Qwen35Config,
    /// Device every graph is compiled for.
    device: Device,
    /// Batch size fixed at build time.
    batch: usize,
    /// Maximum sequence length fixed at build time.
    max_seq: usize,
    /// Forwarded from the builder.
    last_logits_only: bool,
    /// Whether MTP heads were requested.
    enable_mtp: bool,
    /// Whether the MTP logits path is compiled in.
    mtp_logits_path: bool,
    /// Fast MTP flag forwarded to graph builders.
    fast_mtp: bool,
    /// Fast greedy LM-head flag forwarded to graph builders.
    fast_greedy_lm_head: bool,
    /// Loaded model weights.
    weights: Qwen35Weights,
    /// Weights file path; empty for inline weights.
    weights_path: PathBuf,
    /// Open GGUF loader; `Some` only on the GGUF loading path.
    gguf_loader: Option<GgufLoader>,
    /// NOTE(review): `None` at construction; its lifecycle is not visible here.
    decode_cache: Option<Qwen35DecodeCache>,
    /// Runtime M-RoPE feeds (forced on when a vision encoder is present).
    runtime_mrope: bool,
    /// Forwarded from the builder.
    mrope_section_positions: Option<Vec<[usize; 4]>>,
    /// `Some` when an AOT cache dir was configured; makes `execution_config`
    /// select `CompilationMode::Aot`.
    aot_cache: Option<AotCache>,
    /// Dynamic-shape prefill (forced on when a vision encoder is present).
    dynamic_prefill: bool,
    /// Dynamic-shape decode.
    dynamic_decode: bool,
    /// Vision encoder loaded from `mmproj` or `inline_mmproj`.
    vision_encoder: Option<Qwen35VisionEncoder>,
    /// Path the vision projector was loaded from, if any.
    mmproj_path: Option<PathBuf>,
    /// Compile profile for prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile for decode graphs.
    decode_profile: CompileProfile,
    /// MoE expert-offload state; `Some` when `build_moe_offload` enabled offload.
    moe_offload: Option<MoeOffloadState>,
    /// Host-side expert store; built only with offload enabled (`None` if building it failed).
    moe_store: Option<MoeExpertStore>,
    /// Step counter incremented by `moe_offload_after_forward` / `moe_refresh_after_forward`.
    moe_refresh_step: usize,
}
1024
/// Logits produced by a prefill.
///
/// NOTE(review): not constructed in the visible part of this file; the field
/// meanings below are inferred from their names. Confirm.
#[derive(Debug, Clone)]
pub struct Qwen35PrefillSeed {
    /// Presumably logits from the main (trunk) LM head.
    pub trunk_logits: Vec<f32>,
    /// Presumably logits from the MTP head, when MTP is enabled.
    pub mtp_logits: Option<Vec<f32>>,
}
1030
/// Result of a prefill call.
///
/// NOTE(review): not constructed in the visible part of this file; the field
/// meanings below are inferred from their names. Confirm.
#[derive(Debug, Clone)]
pub struct Qwen35PrefillOutput {
    /// Flattened logits; presumably laid out in rows of `vocab_size`.
    pub logits: Vec<f32>,
    /// MTP-head logits, when MTP is enabled.
    pub mtp_logits: Option<Vec<f32>>,
    /// Row width of `logits`.
    pub vocab_size: usize,
}
1037
1038impl Qwen35Runner {
1039 pub fn builder() -> Qwen35RunnerBuilder {
1040 Qwen35RunnerBuilder::default()
1041 }
1042
1043 pub fn has_mmproj(&self) -> bool {
1047 self.mmproj_path.is_some() || self.vision_encoder.is_some()
1048 }
1049
1050 fn execution_config(&self, config: ModelExecutionConfig) -> ModelExecutionConfig {
1052 if self.aot_cache.is_some() {
1053 config.with_compilation_mode(CompilationMode::Aot)
1054 } else {
1055 config
1056 }
1057 }
1058
1059 pub fn prefill_profile(&self) -> &CompileProfile {
1060 &self.prefill_profile
1061 }
1062
1063 pub fn decode_profile(&self) -> &CompileProfile {
1064 &self.decode_profile
1065 }
1066
1067 pub fn moe_offload(&self) -> Option<&MoeOffloadState> {
1069 self.moe_offload.as_ref()
1070 }
1071
1072 pub fn predictive_offload_info(&self) -> Option<&rlx_llada2::tide::PredictiveOffloadInfo> {
1074 self.moe_offload.as_ref().map(|m| &m.info)
1075 }
1076
1077 pub fn get_offload_stats(
1079 &self,
1080 residency: Option<&rlx_runtime::MoeResidencyStats>,
1081 ) -> rlx_llada2::tide::TideOffloadStats {
1082 self.moe_offload
1083 .as_ref()
1084 .map(|m| m.tide_offload_stats(residency))
1085 .unwrap_or_default()
1086 }
1087
1088 pub fn jump_steps(&self) -> usize {
1089 self.moe_offload.as_ref().map(|m| m.jump_steps).unwrap_or(1)
1090 }
1091
1092 pub fn predictive_offload_enabled(&self) -> bool {
1093 self.moe_offload
1094 .as_ref()
1095 .is_some_and(|m| m.predictive_enabled)
1096 }
1097
1098 pub fn moe_offload_mut(&mut self) -> Option<&mut MoeOffloadState> {
1099 self.moe_offload.as_mut()
1100 }
1101
1102 pub fn moe_refresh_step(&self) -> usize {
1104 self.moe_refresh_step
1105 }
1106
1107 pub fn enable_moe_topk_on(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1109 if self.moe_offload.is_some() {
1110 compiled.enable_moe_topk_capture(self.cfg.num_experts);
1111 }
1112 }
1113
1114 pub fn sync_moe_residency(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1116 if let Some(mo) = &self.moe_offload {
1117 push_moe_residency(compiled, &mo.per_layer_resident_masks());
1118 }
1119 }
1120
1121 #[allow(dead_code)]
1122 fn moe_prepare_forward(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1123 self.bind_moe_host_weights();
1124 if self.moe_offload.is_some() {
1125 compiled.enable_moe_topk_capture(self.cfg.num_experts);
1126 self.sync_moe_residency(compiled);
1127 }
1128 }
1129
1130 #[allow(dead_code)]
1131 fn moe_finish_forward(
1132 &mut self,
1133 compiled: &mut rlx_runtime::CompiledGraph,
1134 denoise_step: usize,
1135 is_prefill_block: bool,
1136 ) -> bool {
1137 let Some(layers) = compiled.take_moe_topk_capture() else {
1138 return false;
1139 };
1140 let store = self.moe_store.clone();
1141 let Some(mo) = self.moe_offload.as_mut() else {
1142 return false;
1143 };
1144 let refreshed = if let Some(store) = store.as_ref() {
1145 mo.refresh_from_capture_with_store(store, &layers, denoise_step, is_prefill_block)
1146 } else {
1147 mo.refresh_from_capture(&layers, denoise_step, is_prefill_block)
1148 };
1149 if refreshed {
1150 push_moe_residency(compiled, &mo.per_layer_resident_masks());
1151 }
1152 refreshed
1153 }
1154
1155 fn bind_moe_host_weights(&self) {
1157 if self.moe_offload.is_none() {
1158 rlx_cpu::moe_residency::bind_host_weights(None);
1159 return;
1160 }
1161 if let Some(store) = &self.moe_store {
1162 rlx_cpu::moe_residency::bind_host_weights(Some(moe_host_bind_from_store(store)));
1163 } else {
1164 rlx_cpu::moe_residency::bind_host_weights(None);
1165 }
1166 }
1167
1168 pub fn moe_offload_after_forward(&mut self, compiled: &mut rlx_runtime::CompiledGraph) -> bool {
1170 let Some(mo) = self.moe_offload.as_mut() else {
1171 return false;
1172 };
1173 let Some(layers) = compiled.take_moe_topk_capture() else {
1174 return false;
1175 };
1176 let refreshed = mo.refresh_from_capture(&layers, self.moe_refresh_step, false);
1177 if refreshed {
1178 self.sync_moe_residency(compiled);
1179 }
1180 self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1181 refreshed
1182 }
1183
1184 pub fn moe_refresh_after_forward(&mut self, expert_idx: &[u32]) -> bool {
1186 let Some(mo) = self.moe_offload.as_mut() else {
1187 return false;
1188 };
1189 let refresh = mo.pools[0].should_refresh(
1190 rlx_runtime::MoEExecMode::Reuse,
1191 self.moe_refresh_step,
1192 false,
1193 );
1194 if refresh {
1195 for pool in &mut mo.pools {
1196 pool.refresh_from_indices(expert_idx);
1197 }
1198 }
1199 self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1200 refresh
1201 }
1202
1203 pub fn with_compile_profiles(
1205 mut self,
1206 prefill: CompileProfile,
1207 decode: CompileProfile,
1208 ) -> Self {
1209 self.prefill_profile = prefill;
1210 self.decode_profile = decode;
1211 self
1212 }
1213
1214 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
1215 let profile = if decode {
1216 &self.decode_profile
1217 } else {
1218 &self.prefill_profile
1219 };
1220 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
1221 }
1222
1223 fn dyn_compile_options(&self, config: &ModelExecutionConfig) -> CompileOptions {
1224 let decode = matches!(config.preset, ExecutionPreset::Qwen35Decode);
1225 let mut opts = self.profile_compile_options(decode);
1226 opts.kernel_dispatch = config.component().kernel_dispatch;
1227 opts.dim_binding(config.dim_binding())
1228 }
1229
1230 fn bucketed_decode_compile_options(&self) -> CompileOptions {
1231 self.profile_compile_options(true)
1232 }
1233
1234 pub fn compile_prefill_built(
1236 &self,
1237 cache: &mut Qwen35CompileCache,
1238 built: rlx_flow::BuiltModel,
1239 batch: usize,
1240 seq: usize,
1241 ) -> Result<rlx_runtime::CompiledGraph> {
1242 let config = self.execution_config(prefill_config(batch, seq));
1243 let opts = self.dyn_compile_options(&config);
1244 cache.compile_built(built, &config, &opts)
1245 }
1246
1247 pub fn cfg(&self) -> &Qwen35Config {
1248 &self.cfg
1249 }
1250 pub fn device(&self) -> Device {
1251 self.device
1252 }
1253 pub fn max_seq(&self) -> usize {
1254 self.max_seq
1255 }
1256 pub fn lm_vocab_size(&self) -> usize {
1257 self.weights.lm_vocab_size(&self.cfg)
1258 }
1259
1260 pub fn has_vision(&self) -> bool {
1262 self.vision_encoder.is_some()
1263 }
1264
1265 pub fn mmproj_path(&self) -> Option<&std::path::Path> {
1267 self.mmproj_path.as_deref()
1268 }
1269
1270 fn effective_vocab(&self, graph_vocab: usize) -> usize {
1271 self.lm_vocab_size().min(graph_vocab)
1272 }
1273
1274 fn compile_hir_for_config(
1275 &mut self,
1276 config: ModelExecutionConfig,
1277 aot_disk_key: &str,
1278 hir: rlx_ir::hir::HirModule,
1279 ) -> Result<rlx_runtime::CompiledGraph> {
1280 let config = self.execution_config(config);
1281 let opts = self.dyn_compile_options(&config);
1282 if let Some(aot) = self.aot_cache.as_ref() {
1283 return Ok(aot.compile_hir_cached(aot_disk_key, self.device, hir, &opts)?);
1284 }
1285 if config.preset == ExecutionPreset::Qwen35Decode {
1288 return Ok(Session::new(self.device).compile_hir_with(hir, &opts)?);
1289 }
1290 let cache = self
1291 .predict_hir_cache
1292 .get_or_insert_with(|| make_qwen35_dyn_cache(self.device, 64, None));
1293 let hir = hir;
1294 get_or_specialize_hir_with_options(cache, &config, || hir.clone(), &opts, |_| Ok(()))?;
1295 if self.device == Device::Cpu {
1296 let compiled = get_or_specialize_hir_with_options(
1297 cache,
1298 &config,
1299 || hir.clone(),
1300 &opts,
1301 |_| Ok(()),
1302 )?;
1303 return Ok(compiled.clone());
1304 }
1305 Ok(Session::new(self.device).compile_hir_with(hir, &opts)?)
1306 }
1307
1308 fn lm_loader(&self) -> Option<&GgufLoader> {
1309 self.gguf_loader.as_ref()
1310 }
1311
1312 fn argmax_batch_from_hidden(&self, hidden: &[f32]) -> Result<Vec<u32>> {
1313 let n_embd = self.cfg.hidden_size;
1314 let mut toks = Vec::with_capacity(self.batch);
1315 for b in 0..self.batch {
1316 let h = &hidden[b * n_embd..(b + 1) * n_embd];
1317 let (idx, _) = greedy_lm_head_argmax(&self.weights, &self.cfg, h, self.lm_loader())?;
1318 toks.push(idx);
1319 }
1320 Ok(toks)
1321 }
1322
1323 fn sample_batch_from_hidden(&self, hidden: &[f32], opts: SampleOpts) -> Result<Vec<u32>> {
1324 let n_embd = self.cfg.hidden_size;
1325 let mut toks = Vec::with_capacity(self.batch);
1326 for b in 0..self.batch {
1327 let h = &hidden[b * n_embd..(b + 1) * n_embd];
1328 toks.push(sample_lm_head_from_hidden(
1329 &self.weights,
1330 &self.cfg,
1331 h,
1332 self.lm_loader(),
1333 opts,
1334 )?);
1335 }
1336 Ok(toks)
1337 }
1338
1339 fn decode_step_trunk_raw(
1340 &mut self,
1341 cache: &mut Qwen35DecodeCache,
1342 tokens: &[u32],
1343 generated_per_row: &[usize],
1344 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
1345 if self.dynamic_decode {
1346 return self.decode_step_dynamic_raw(cache, tokens, generated_per_row);
1347 }
1348 let past_seq = cache.past_seq;
1349 let head_half = self.cfg.key_length / 2;
1350 let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
1351 let use_bucket = self
1352 .decode_compile_cache
1353 .as_ref()
1354 .and_then(|c| c.bucket_for(past_seq as u64))
1355 .is_some();
1356 if use_bucket {
1357 self.decode_step_bucketed_raw(cache, tokens, generated_per_row, &cos, &sin)
1358 } else {
1359 let feeds_owned = decode_step_feeds(
1360 &self.cfg,
1361 cache,
1362 tokens,
1363 &cos,
1364 &sin,
1365 None,
1366 generated_per_row,
1367 )?;
1368 let feeds: Vec<(&str, &[f32])> = feeds_owned
1369 .iter()
1370 .map(|(k, v)| (k.as_str(), v.as_slice()))
1371 .collect();
1372 if !self.decode_graphs.contains_key(&past_seq) {
1373 let (hir, params, packed) = build_qwen35_decode_hir_ext(
1374 &self.cfg,
1375 self.weights.clone(),
1376 self.batch,
1377 past_seq,
1378 false,
1379 self.mtp_logits_path,
1380 self.fast_mtp,
1381 self.fast_greedy_lm_head,
1382 )?;
1383 let mut compiled = self.compile_hir_for_config(
1384 decode_config(self.batch, past_seq),
1385 &format!("decode_{past_seq}"),
1386 hir,
1387 )?;
1388 for (name, data) in ¶ms {
1389 compiled.set_param(name, data);
1390 }
1391 upload_packed_opt(
1392 &mut compiled,
1393 self.gguf_loader.as_mut(),
1394 &packed,
1395 &mut self.packed_bytes_cache,
1396 )?;
1397 self.decode_graphs.insert(past_seq, compiled);
1398 }
1399 let step = self.moe_refresh_step;
1400 let has_moe = self.moe_offload.is_some();
1401 let num_experts = self.cfg.num_experts;
1402 let moe_masks = self
1403 .moe_offload
1404 .as_ref()
1405 .map(|m| m.per_layer_resident_masks());
1406 self.bind_moe_host_weights();
1407 let outs = {
1408 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
1409 if has_moe {
1410 compiled.enable_moe_topk_capture(num_experts);
1411 if let Some(layers) = &moe_masks {
1412 push_moe_residency(compiled, layers);
1413 }
1414 }
1415 compiled.run(&feeds)
1416 };
1417 if has_moe {
1418 let layers = {
1419 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
1420 compiled.take_moe_topk_capture()
1421 };
1422 if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
1423 let store = self.moe_store.as_ref();
1424 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
1425 if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
1426 if let Some(store) = self.moe_store.as_ref() {
1427 store.apply_to_compiled(compiled);
1428 }
1429 }
1430 }
1431 }
1432 self.moe_refresh_step = step.saturating_add(1);
1433 advance_cache_from_decode_outputs(
1434 &self.cfg,
1435 cache,
1436 outs,
1437 None,
1438 self.mtp_logits_path,
1439 false,
1440 self.fast_greedy_lm_head,
1441 )
1442 }
1443 }
1444
    /// Runs one decode step through the dynamic-shape decode graph cache and advances
    /// `cache` via `advance_cache_from_decode_outputs`.
    ///
    /// The graph for `decode_config(self.batch, past_seq)` comes from
    /// `get_or_specialize_hir_with_options` on `decode_dynamic_cache`. That call receives a
    /// builder closure (HIR from `build_qwen35_decode_hir_dynamic_ext`) and an initialiser
    /// closure (sets `decode_dynamic_params` and uploads `decode_dynamic_packed`). Both
    /// presumably run only when a new specialisation is created.
    ///
    /// NOTE(review): unlike the per-`past_seq` path in `decode_step_trunk_raw`, this path does
    /// no MoE offload bookkeeping and does not advance `moe_refresh_step`. Confirm that is
    /// intended.
    ///
    /// # Errors
    /// Fails when `decode_dynamic_cache` is missing, or when feed construction, graph
    /// specialisation/initialisation or cache advancement fails.
    ///
    /// # Panics
    /// Panics if building the dynamic decode HIR fails; the builder closure `expect`s it.
    fn decode_step_dynamic_raw(
        &mut self,
        cache: &mut Qwen35DecodeCache,
        tokens: &[u32],
        generated_per_row: &[usize],
    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
        let past_seq = cache.past_seq;
        let head_half = self.cfg.key_length / 2;
        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
        let feeds_owned = decode_step_feeds(
            &self.cfg,
            cache,
            tokens,
            &cos,
            &sin,
            None,
            generated_per_row,
        )?;
        let feeds: Vec<(&str, &[f32])> = feeds_owned
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_slice()))
            .collect();

        // Config and options must be computed (via `&self`) before the mutable borrow of
        // `decode_dynamic_cache` below.
        let config = self.execution_config(decode_config(self.batch, past_seq));
        let compile_opts = self.dyn_compile_options(&config);
        let dyn_cache = self
            .decode_dynamic_cache
            .as_mut()
            .ok_or_else(|| anyhow!("dynamic decode without cache"))?;
        // Copy or borrow, field by field, everything the closures need. This lets them
        // coexist with the mutable borrow of `decode_dynamic_cache` (disjoint fields).
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let max_seq = self.max_seq;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        let batch = self.batch;
        let decode_params = &self.decode_dynamic_params;
        let decode_packed = &self.decode_dynamic_packed;
        let gguf_loader = &mut self.gguf_loader;
        let packed_bytes_cache = &mut self.packed_bytes_cache;
        let compiled = get_or_specialize_hir_with_options(
            dyn_cache,
            &config,
            // Builder: produces the HIR module; a build failure panics here because the
            // closure must return a module rather than a `Result`.
            || {
                build_qwen35_decode_hir_dynamic_ext(
                    &cfg,
                    weights,
                    batch,
                    max_seq,
                    mtp_logits_path,
                    fast_mtp,
                    fast_greedy,
                )
                .expect("dynamic decode HIR")
                .0
            },
            &compile_opts,
            // Initialiser: sets the F32 params and uploads the packed params onto the graph.
            |c| {
                for (name, data) in decode_params {
                    c.set_param(name, data);
                }
                upload_packed_opt(c, gguf_loader.as_mut(), decode_packed, packed_bytes_cache)
            },
        )?;
        let outs = compiled.run(&feeds);
        advance_cache_from_decode_outputs(
            &self.cfg,
            cache,
            outs,
            None,
            self.mtp_logits_path,
            false,
            self.fast_greedy_lm_head,
        )
    }
1520
1521 fn trunk_to_logits(&self, trunk: Vec<f32>, is_hidden: bool) -> Result<Vec<f32>> {
1522 if !is_hidden {
1523 return Ok(trunk);
1524 }
1525 let n_embd = self.cfg.hidden_size;
1526 let vocab = self.lm_vocab_size();
1527 let mut logits = Vec::with_capacity(self.batch * vocab);
1528 for b in 0..self.batch {
1529 let h = &trunk[b * n_embd..(b + 1) * n_embd];
1530 logits.extend(lm_head_logits_row(
1531 &self.weights,
1532 &self.cfg,
1533 h,
1534 self.lm_loader(),
1535 )?);
1536 }
1537 Ok(logits)
1538 }
1539
    /// Ensures the bucketed decode graph whose bucket covers `key` (a `past_seq`) is compiled,
    /// and returns the bucket's upper bound as reported by `ensure_hir_with_params`.
    ///
    /// The builder closure returns `(hir, params)` to `ensure_hir_with_params` and stashes the
    /// packed params in `packed_slot`. Those packed params are uploaded here afterwards, and
    /// only if the slot was filled and non-empty. The slot presumably stays `None` when the
    /// bucket was already compiled (builder not invoked); TODO confirm.
    ///
    /// # Errors
    /// Fails when there is no `decode_compile_cache`, when `key` is outside every bucket, or
    /// when uploading the packed params fails.
    ///
    /// # Panics
    /// Panics if building the decode HIR fails; the builder closure `expect`s it.
    fn ensure_decode_bucket_compiled(&mut self, key: u64) -> Result<usize> {
        let decode_opts = self.bucketed_decode_compile_options();
        let cache_mut = self
            .decode_compile_cache
            .as_mut()
            .ok_or_else(|| anyhow!("bucketed decode without cache"))?;
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let batch = self.batch;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        // Interior mutability lets the builder closure hand the packed params back out.
        // Presumably `ensure_hir_with_params` takes a non-`FnMut` closure; confirm.
        let packed_slot = RefCell::new(None::<PackedParams>);
        let (upper, compiled) = cache_mut
            .ensure_hir_with_params(
                key,
                |upper| {
                    // Build for the bucket's upper bound supplied by the cache.
                    let (hir, params, packed) = build_qwen35_decode_hir_ext(
                        &cfg,
                        weights.clone(),
                        batch,
                        upper as usize,
                        true,
                        mtp_logits_path,
                        fast_mtp,
                        fast_greedy,
                    )
                    .expect("qwen35 decode HIR");
                    *packed_slot.borrow_mut() = Some(packed);
                    (hir, params)
                },
                &decode_opts,
            )
            .ok_or_else(|| anyhow!("past_seq {key} outside decode buckets"))?;
        if let Some(packed) = packed_slot.take() {
            if !packed.is_empty() {
                upload_packed_opt(
                    compiled,
                    self.gguf_loader.as_mut(),
                    &packed,
                    &mut self.packed_bytes_cache,
                )?;
            }
        }
        Ok(upper as usize)
    }
1587
1588 fn warm_decode_graphs(&mut self) -> Result<()> {
1590 let upper_bounds: Vec<usize> = match self.decode_compile_cache.as_ref() {
1591 Some(cache) => cache.buckets().map(|r| (r.end - 1) as usize).collect(),
1592 None => return Ok(()),
1593 };
1594 let t = Instant::now();
1595 let total = upper_bounds.len();
1596 for upper in upper_bounds {
1597 self.ensure_decode_bucket_compiled(upper as u64)?;
1598 }
1599 if total > 0 {
1600 eprintln!(
1601 "[qwen35] warmed {total} decode bucket(s) in {:.2?}",
1602 t.elapsed()
1603 );
1604 }
1605 Ok(())
1606 }
1607
1608 fn warm_predict_graph(&mut self) -> Result<()> {
1610 if self.compiled.is_some() {
1611 return Ok(());
1612 }
1613 let t = Instant::now();
1614 self.ensure_predict_compiled()?;
1615 eprintln!("[qwen35] warmed predict graph in {:.2?}", t.elapsed());
1616 Ok(())
1617 }
1618
1619 pub fn reset_decode_cache(&mut self) {
1621 self.decode_cache = None;
1622 }
1623
1624 pub fn decode_cache_checkpoint(&self) -> Option<Qwen35DecodeCache> {
1626 self.decode_cache.clone()
1627 }
1628
1629 pub fn restore_decode_cache(&mut self, cache: Option<Qwen35DecodeCache>) {
1631 self.decode_cache = cache;
1632 }
1633
1634 pub fn commit_decode_tokens(&mut self, tokens: &[u32]) -> Result<()> {
1636 for &tok in tokens {
1637 let _ = self.decode_get_logits(tok)?;
1638 }
1639 Ok(())
1640 }
1641
1642 fn ensure_predict_compiled(&mut self) -> Result<()> {
1643 if self.compiled.is_some() {
1644 return Ok(());
1645 }
1646 let debug_layers = std::env::var("RLX_QWEN35_DEBUG_LAYERS")
1654 .map(|v| v == "1")
1655 .unwrap_or(false);
1656 let t = Instant::now();
1657 let (hir, params, packed) = build_qwen35_hir_sized_ext(
1658 &self.cfg,
1659 self.weights.clone(),
1660 self.batch,
1661 self.max_seq,
1662 true,
1663 self.last_logits_only,
1664 self.enable_mtp,
1665 false,
1666 None,
1667 self.runtime_mrope,
1668 self.fast_mtp,
1669 false,
1670 debug_layers,
1671 )?;
1672 eprintln!(
1673 "[qwen35] built predict IR (lazy) in {:.2?} (params={}, packed={})",
1674 t.elapsed(),
1675 params.len(),
1676 packed.len(),
1677 );
1678 let t = Instant::now();
1679 let mut compiled = self.compile_hir_for_config(
1680 prefill_config(self.batch, self.max_seq),
1681 "predict_logits",
1682 hir,
1683 )?;
1684 eprintln!(
1685 "[qwen35] compiled predict graph (lazy) in {:.2?}",
1686 t.elapsed()
1687 );
1688 let t = Instant::now();
1689 for (name, data) in ¶ms {
1690 compiled.set_param(name, data);
1691 }
1692 if !packed.is_empty() {
1693 upload_packed_opt(
1694 &mut compiled,
1695 self.gguf_loader.as_mut(),
1696 &packed,
1697 &mut self.packed_bytes_cache,
1698 )?;
1699 }
1700 eprintln!(
1701 "[qwen35] uploaded predict {} F32 + {} packed params in {:.2?}",
1702 params.len(),
1703 packed.len(),
1704 t.elapsed(),
1705 );
1706 self.compiled = Some(compiled);
1707 Ok(())
1708 }
1709
1710 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillOutput> {
1711 let out = self
1712 .predict_logits_batch(&[prompt_ids.to_vec()])
1713 .map(|v| v.into_iter().next().unwrap())?;
1714 if !out.logits.is_empty() {
1721 let mut min = f32::INFINITY;
1722 let mut max = f32::NEG_INFINITY;
1723 for &v in &out.logits {
1724 if v.is_finite() {
1725 if v < min {
1726 min = v;
1727 }
1728 if v > max {
1729 max = v;
1730 }
1731 }
1732 }
1733 if (max - min).abs() < 1e-6 {
1734 bail!(
1735 "qwen35: predict_logits returned degenerate output \
1736 (min={min}, max={max}) — the forward pass produced \
1737 all-equal logits, which indicates a broken op or a \
1738 mis-routed weight tensor in the trunk. Re-run with \
1739 RUST_LOG=debug to capture the offending layer."
1740 );
1741 }
1742 }
1743 Ok(out)
1744 }
1745
1746 pub fn predict_logits_batch(
1750 &mut self,
1751 batch_prompts: &[Vec<u32>],
1752 ) -> Result<Vec<Qwen35PrefillOutput>> {
1753 if batch_prompts.len() != self.batch {
1754 bail!(
1755 "qwen35: expected {} prompts (batch={}), got {}",
1756 self.batch,
1757 self.batch,
1758 batch_prompts.len()
1759 );
1760 }
1761 let max_prompt = batch_prompts.iter().map(|p| p.len()).max().unwrap_or(0);
1762 if max_prompt > self.max_seq {
1763 bail!(
1764 "qwen35: prompt length {max_prompt} exceeds compiled max_seq={}",
1765 self.max_seq
1766 );
1767 }
1768 let padded = pack_input_ids(batch_prompts, self.max_seq)?;
1769 let prompt_lens: Vec<usize> = batch_prompts.iter().map(|p| p.len()).collect();
1770 let last_idx = last_token_indices(&prompt_lens);
1771
1772 let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", padded.as_slice())];
1773 if self.last_logits_only {
1774 feeds.push(("last_token_idx", last_idx.as_slice()));
1775 }
1776 let rope_owned = self.mrope_prefill_rope_feeds(max_prompt);
1777 for (name, data) in &rope_owned {
1778 feeds.push((name.as_str(), data.as_slice()));
1779 }
1780 self.ensure_predict_compiled()?;
1781 let outs = self.compiled.as_mut().unwrap().run(&feeds);
1782 if outs.is_empty() {
1783 bail!("qwen35: forward produced no outputs");
1784 }
1785 if std::env::var("RLX_QWEN35_DEBUG_LAYERS").as_deref() == Ok("1") {
1790 let n_layers = self.cfg.num_hidden_layers - self.cfg.nextn_predict_layers;
1795 for i in 0..outs.len() {
1796 let v = &outs[i];
1797 let mut min = f32::INFINITY;
1798 let mut max = f32::NEG_INFINITY;
1799 let mut sum = 0.0f64;
1800 let mut nan = 0usize;
1801 let mut nnz = 0usize;
1802 for &x in v {
1803 if x.is_nan() {
1804 nan += 1;
1805 continue;
1806 }
1807 sum += x as f64;
1808 if x < min {
1809 min = x;
1810 }
1811 if x > max {
1812 max = x;
1813 }
1814 if x != 0.0 {
1815 nnz += 1;
1816 }
1817 }
1818 let mean = sum / v.len().max(1) as f64;
1819 let label = if i == 0 {
1820 "logits".to_string()
1821 } else if i - 1 < n_layers {
1822 format!("layer_{:02}", i - 1)
1823 } else {
1824 format!("extra_{:02}", i - 1 - n_layers)
1825 };
1826 eprintln!(
1827 "[qwen35][debug-layers] {label}: len={} nnz={} nan={} min={} max={} mean={:.6}",
1828 v.len(),
1829 nnz,
1830 nan,
1831 min,
1832 max,
1833 mean
1834 );
1835 }
1836 }
1837 let vocab_size = if self.last_logits_only {
1838 outs[0].len() / self.batch
1839 } else {
1840 outs[0].len() / (self.batch * self.max_seq)
1841 };
1842 let sample_vocab = self.effective_vocab(vocab_size);
1843 let mtp_logits = if self.enable_mtp && outs.len() >= 2 {
1844 Some(outs[1].clone())
1845 } else {
1846 None
1847 };
1848 let mut per_batch = Vec::with_capacity(self.batch);
1849 for b in 0..self.batch {
1850 let start = b * vocab_size;
1851 let mut row = outs[0][start..start + vocab_size].to_vec();
1852 row.truncate(sample_vocab);
1853 per_batch.push(Qwen35PrefillOutput {
1854 logits: row,
1855 mtp_logits: mtp_logits.as_ref().map(|m| {
1856 let m_vocab = m.len() / self.batch.max(1);
1857 let mut mv = m[b * m_vocab..(b + 1) * m_vocab].to_vec();
1858 mv.truncate(sample_vocab);
1859 mv
1860 }),
1861 vocab_size: sample_vocab,
1862 });
1863 }
1864 Ok(per_batch)
1865 }
1866
1867 pub fn generate<F>(&mut self, prompt_ids: &[u32], n_new: usize, on_token: F) -> Result<Vec<u32>>
1869 where
1870 F: FnMut(u32) -> bool,
1871 {
1872 self.generate_with_opts(prompt_ids, n_new, SampleOpts::greedy(), on_token)
1873 }
1874
1875 pub fn generate_with_opts<F>(
1877 &mut self,
1878 prompt_ids: &[u32],
1879 n_new: usize,
1880 opts: SampleOpts,
1881 mut on_token: F,
1882 ) -> Result<Vec<u32>>
1883 where
1884 F: FnMut(u32) -> bool,
1885 {
1886 if self.batch != 1 {
1887 bail!(
1888 "qwen35::generate: runner batch={} — use generate_batch() instead",
1889 self.batch
1890 );
1891 }
1892 let generated = self
1893 .generate_batch_with_opts(&[prompt_ids.to_vec()], n_new, None, opts, |_, tok| {
1894 on_token(tok)
1895 })?
1896 .into_iter()
1897 .next()
1898 .unwrap_or_default();
1899 Ok(generated)
1900 }
1901
1902 pub fn generate_batch<F>(
1904 &mut self,
1905 prompts: &[Vec<u32>],
1906 n_new: usize,
1907 on_token: F,
1908 ) -> Result<Vec<Vec<u32>>>
1909 where
1910 F: FnMut(usize, u32) -> bool,
1911 {
1912 self.generate_batch_with_opts(prompts, n_new, None, SampleOpts::greedy(), on_token)
1913 }
1914
1915 pub fn generate_batch_with_opts<F>(
1919 &mut self,
1920 prompts: &[Vec<u32>],
1921 n_new: usize,
1922 n_new_per_row: Option<&[usize]>,
1923 opts: SampleOpts,
1924 mut on_token: F,
1925 ) -> Result<Vec<Vec<u32>>>
1926 where
1927 F: FnMut(usize, u32) -> bool,
1928 {
1929 if prompts.is_empty() {
1930 bail!("qwen35::generate_batch: prompts must be non-empty");
1931 }
1932 if prompts.len() != self.batch {
1933 bail!(
1934 "qwen35::generate_batch: expected {} prompts, got {}",
1935 self.batch,
1936 prompts.len()
1937 );
1938 }
1939 if let Some(limits) = n_new_per_row {
1940 if limits.len() != self.batch {
1941 bail!(
1942 "qwen35::generate_batch: n_new_per_row len {} != batch {}",
1943 limits.len(),
1944 self.batch
1945 );
1946 }
1947 }
1948 for (i, p) in prompts.iter().enumerate() {
1949 if p.is_empty() {
1950 bail!("qwen35::generate_batch: prompt row {i} is empty");
1951 }
1952 }
1953
1954 self.decode_cache = None;
1955
1956 let _prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
1957 let row_limits: Vec<usize> = if let Some(limits) = n_new_per_row {
1958 limits.to_vec()
1959 } else {
1960 vec![n_new; self.batch]
1961 };
1962
1963 let (trunk, mut cache, _) = self.prefill_seed_decode_cache(prompts)?;
1964
1965 let mut generated: Vec<Vec<u32>> = vec![Vec::new(); self.batch];
1966 let mut active = vec![true; self.batch];
1967 let mut row_gen_count = vec![0usize; self.batch];
1968
1969 let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
1970 self.argmax_batch_from_hidden(&trunk)?
1971 } else if self.fast_greedy_lm_head
1972 && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
1973 {
1974 self.sample_batch_from_hidden(&trunk, opts)?
1975 } else {
1976 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
1977 sample_logits_batch(&logits, self.lm_vocab_size(), self.batch, opts)
1978 };
1979 if n_new > 0 {
1980 for b in 0..self.batch {
1981 if row_gen_count[b] >= row_limits[b] {
1982 active[b] = false;
1983 continue;
1984 }
1985 let tok = next_tokens[b];
1986 generated[b].push(tok);
1987 row_gen_count[b] += 1;
1988 active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
1989 }
1990 }
1991
1992 for _ in 1..n_new {
1993 if !active.iter().any(|&a| a) {
1994 break;
1995 }
1996 if cache.past_seq >= self.max_seq - 1 {
1997 bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
1998 }
1999 next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen_count, opts)?;
2000 for b in 0..self.batch {
2001 if !active[b] || row_gen_count[b] >= row_limits[b] {
2002 active[b] = false;
2003 continue;
2004 }
2005 let tok = next_tokens[b];
2006 generated[b].push(tok);
2007 row_gen_count[b] += 1;
2008 active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
2009 }
2010 self.decode_cache = Some(cache.clone());
2011 }
2012 Ok(generated)
2013 }
2014
2015 pub fn prefill_get_last_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2018 Ok(self.prefill_seed_for_decode(prompt_ids)?.trunk_logits)
2019 }
2020
2021 pub fn prefill_seed_for_decode(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillSeed> {
2023 if self.batch != 1 {
2024 bail!(
2025 "qwen35: prefill_seed_for_decode requires batch=1 (runner batch={})",
2026 self.batch
2027 );
2028 }
2029 let (trunk, _, mtp_logits) = self.prefill_seed_decode_cache(&[prompt_ids.to_vec()])?;
2030 Ok(Qwen35PrefillSeed {
2031 trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2032 mtp_logits,
2033 })
2034 }
2035
2036 pub fn prefill_multimodal(
2039 &mut self,
2040 prompt: &str,
2041 rgb: &[u8],
2042 img_w: usize,
2043 img_h: usize,
2044 tokenizer: Option<&std::path::Path>,
2045 ) -> Result<Qwen35PrefillSeed> {
2046 let (trunk, mtp_logits) =
2047 self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2048 Ok(Qwen35PrefillSeed {
2049 trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2050 mtp_logits,
2051 })
2052 }
2053
2054 pub fn prefill_from_assembled(
2056 &mut self,
2057 prefill: MultimodalPrefill,
2058 ) -> Result<Qwen35PrefillSeed> {
2059 if self.batch != 1 {
2060 bail!(
2061 "qwen35: prefill_from_assembled requires batch=1 (runner batch={})",
2062 self.batch
2063 );
2064 }
2065 self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2066 let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2067 Ok(Qwen35PrefillSeed {
2068 trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2069 mtp_logits,
2070 })
2071 }
2072
2073 fn prefill_multimodal_trunk(
2074 &mut self,
2075 prompt: &str,
2076 rgb: &[u8],
2077 img_w: usize,
2078 img_h: usize,
2079 tokenizer: Option<&std::path::Path>,
2080 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2081 if self.batch != 1 {
2082 bail!(
2083 "qwen35: prefill_multimodal requires batch=1 (runner batch={})",
2084 self.batch
2085 );
2086 }
2087 let vision = {
2088 let enc = self
2089 .vision_encoder
2090 .as_mut()
2091 .ok_or_else(|| anyhow!("qwen35: prefill_multimodal requires .mmproj(...)"))?;
2092 enc.encode_rgb(rgb, img_w, img_h)?
2093 };
2094 if self.weights.token_embd.is_empty() {
2095 bail!("qwen35: multimodal prefill requires token_embd weights");
2096 }
2097 let weights_path = self.weights_path.as_path();
2098 if weights_path.as_os_str().is_empty() {
2099 bail!("qwen35: multimodal prefill requires a GGUF weights path (for tokenizer)");
2100 }
2101 let n_embd = self.cfg.hidden_size;
2102 let mm = MultimodalPrompt {
2103 prompt,
2104 vision: &vision,
2105 };
2106 let prefill = mm.assemble(
2107 |text| encode_prompt_auto(weights_path, tokenizer, text),
2108 &self.weights.token_embd,
2109 n_embd,
2110 0,
2111 )?;
2112 self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2113 let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2114 Ok((trunk, mtp_logits))
2115 }
2116
2117 pub fn generate_multimodal_with_opts<F>(
2119 &mut self,
2120 prompt: &str,
2121 rgb: &[u8],
2122 img_w: usize,
2123 img_h: usize,
2124 tokenizer: Option<&std::path::Path>,
2125 n_new: usize,
2126 opts: SampleOpts,
2127 mut on_token: F,
2128 ) -> Result<Vec<u32>>
2129 where
2130 F: FnMut(u32) -> bool,
2131 {
2132 if self.batch != 1 {
2133 bail!(
2134 "qwen35: generate_multimodal requires batch=1 (runner batch={})",
2135 self.batch
2136 );
2137 }
2138 self.decode_cache = None;
2139 let (trunk, _) = self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2140 let mut cache = self
2141 .decode_cache
2142 .take()
2143 .ok_or_else(|| anyhow!("qwen35: multimodal prefill did not seed decode cache"))?;
2144 let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
2145 self.argmax_batch_from_hidden(&trunk)?
2146 } else if self.fast_greedy_lm_head
2147 && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
2148 {
2149 self.sample_batch_from_hidden(&trunk, opts)?
2150 } else {
2151 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2152 sample_logits_batch(&logits, self.lm_vocab_size(), 1, opts)
2153 };
2154 let mut generated = Vec::new();
2155 if n_new > 0 {
2156 let tok = next_tokens[0];
2157 generated.push(tok);
2158 if !on_token(tok) {
2159 return Ok(generated);
2160 }
2161 }
2162 let row_gen = vec![0usize];
2163 for _ in 1..n_new {
2164 if cache.past_seq >= self.max_seq - 1 {
2165 bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
2166 }
2167 next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen, opts)?;
2168 let tok = next_tokens[0];
2169 generated.push(tok);
2170 self.decode_cache = Some(cache.clone());
2171 if !on_token(tok) {
2172 break;
2173 }
2174 }
2175 Ok(generated)
2176 }
2177
2178 pub fn generate_multimodal<F>(
2180 &mut self,
2181 prompt: &str,
2182 rgb: &[u8],
2183 img_w: usize,
2184 img_h: usize,
2185 tokenizer: Option<&std::path::Path>,
2186 n_new: usize,
2187 on_token: F,
2188 ) -> Result<Vec<u32>>
2189 where
2190 F: FnMut(u32) -> bool,
2191 {
2192 self.generate_multimodal_with_opts(
2193 prompt,
2194 rgb,
2195 img_w,
2196 img_h,
2197 tokenizer,
2198 n_new,
2199 SampleOpts::greedy(),
2200 on_token,
2201 )
2202 }
2203
2204 fn prefill_seed_decode_cache(
2206 &mut self,
2207 prompts: &[Vec<u32>],
2208 ) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
2209 if prompts.len() != self.batch {
2210 bail!(
2211 "qwen35: expected {} prompts (batch={}), got {}",
2212 self.batch,
2213 self.batch,
2214 prompts.len()
2215 );
2216 }
2217 for (i, p) in prompts.iter().enumerate() {
2218 if p.is_empty() {
2219 bail!("qwen35: prompt row {i} is empty");
2220 }
2221 }
2222
2223 let prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
2224 let seq = prompt_lens.iter().copied().max().unwrap();
2225 if seq > self.max_seq {
2226 bail!(
2227 "qwen35: prompt length {seq} exceeds compiled max_seq={}",
2228 self.max_seq
2229 );
2230 }
2231
2232 let input_ids = if self.dynamic_prefill {
2233 pack_input_ids(prompts, seq)?
2234 } else {
2235 pack_input_ids(prompts, self.max_seq)?
2236 };
2237 let last_idx = last_token_indices(&prompt_lens);
2238
2239 let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
2240 feeds.push(("last_token_idx", last_idx.as_slice()));
2241 let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
2242 for (name, data) in &zero_in {
2243 feeds.push((name, data.as_slice()));
2244 }
2245 let rope_owned = self.mrope_prefill_rope_feeds(seq);
2246 for (name, data) in &rope_owned {
2247 feeds.push((name.as_str(), data.as_slice()));
2248 }
2249
2250 let has_moe = self.moe_offload.is_some();
2251 let num_experts = self.cfg.num_experts;
2252 let moe_masks = self
2253 .moe_offload
2254 .as_ref()
2255 .map(|m| m.per_layer_resident_masks());
2256 self.bind_moe_host_weights();
2257
2258 let outs = if self.dynamic_prefill {
2259 let config = self.execution_config(prefill_config(self.batch, seq));
2260 let compile_opts = self.dyn_compile_options(&config);
2261 let compiled = {
2262 let cache = self
2263 .prefill_dynamic_cache
2264 .as_mut()
2265 .expect("dynamic prefill cache");
2266 let cfg = self.cfg.clone();
2267 let weights = self.weights.clone();
2268 let runtime_mrope = self.runtime_mrope;
2269 let mtp_logits_path = self.mtp_logits_path;
2270 let fast_mtp = self.fast_mtp;
2271 let fast_greedy = self.fast_greedy_lm_head;
2272 let cache_params = &self.prefill_cache_params;
2273 let cache_packed = &self.prefill_cache_packed;
2274 let gguf_loader = &mut self.gguf_loader;
2275 let packed_bytes_cache = &mut self.packed_bytes_cache;
2276 get_or_specialize_hir_with_options(
2277 cache,
2278 &config,
2279 || {
2280 build_qwen35_prefill_cache_hir_dynamic_ext(
2281 &cfg,
2282 weights,
2283 1,
2284 seq,
2285 runtime_mrope,
2286 mtp_logits_path,
2287 fast_mtp,
2288 fast_greedy,
2289 )
2290 .expect("dynamic prefill HIR")
2291 .0
2292 },
2293 &compile_opts,
2294 |c| {
2295 for (name, data) in cache_params {
2296 c.set_param(name, data);
2297 }
2298 upload_packed_opt(c, gguf_loader.as_mut(), cache_packed, packed_bytes_cache)
2299 },
2300 )?
2301 };
2302 if has_moe {
2303 compiled.enable_moe_topk_capture(num_experts);
2304 if let Some(layers) = &moe_masks {
2305 push_moe_residency(compiled, layers);
2306 }
2307 }
2308 let outs = compiled.run(&feeds);
2309 if let Some(layers) = compiled.take_moe_topk_capture() {
2310 if let Some(mo) = self.moe_offload.as_mut() {
2311 let store = self.moe_store.as_ref();
2312 if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
2313 if let Some(store) = self.moe_store.as_ref() {
2314 store.apply_to_compiled(compiled);
2315 }
2316 }
2317 }
2318 }
2319 outs
2320 } else {
2321 let compiled = self.prefill_cache.as_mut().expect("static prefill cache");
2322 if has_moe {
2323 compiled.enable_moe_topk_capture(num_experts);
2324 if let Some(layers) = &moe_masks {
2325 push_moe_residency(compiled, layers);
2326 }
2327 }
2328 let outs = compiled.run(&feeds);
2329 let layers = if has_moe {
2330 compiled.take_moe_topk_capture()
2331 } else {
2332 None
2333 };
2334 if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2335 let store = self.moe_store.as_ref();
2336 if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
2337 if let Some(store) = self.moe_store.as_ref() {
2338 store.apply_to_compiled(compiled);
2339 }
2340 }
2341 }
2342 outs
2343 };
2344 let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
2345 &self.cfg,
2346 self.batch,
2347 seq,
2348 &prompt_lens,
2349 outs,
2350 self.mtp_logits_path,
2351 self.fast_greedy_lm_head,
2352 )?;
2353 zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
2354 self.decode_cache = Some(cache.clone());
2355 Ok((trunk, cache, mtp_logits))
2356 }
2357
2358 fn prefill_seed_from_hidden(
2360 &mut self,
2361 prefill: MultimodalPrefill,
2362 ) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
2363 let seq = prefill.seq.len();
2364 if seq == 0 {
2365 bail!("qwen35: multimodal prefill seq is empty");
2366 }
2367 if seq > self.max_seq {
2368 bail!(
2369 "qwen35: multimodal seq {seq} exceeds compiled max_seq={}",
2370 self.max_seq
2371 );
2372 }
2373 let n_embd = self.cfg.hidden_size;
2374 if prefill.hidden.len() != seq * n_embd {
2375 bail!(
2376 "qwen35: prefill hidden len {} != seq*n_embd {}*{}",
2377 prefill.hidden.len(),
2378 seq,
2379 n_embd
2380 );
2381 }
2382
2383 let last_idx = vec![prefill.last_token_idx as f32];
2384 let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
2385 let input_ids = if self.mtp_logits_path || self.enable_mtp {
2386 Some(pack_input_ids(std::slice::from_ref(&prefill.seq), seq)?)
2387 } else {
2388 None
2389 };
2390 let mut feeds: Vec<(&str, &[f32])> = vec![("prefill_hidden", prefill.hidden.as_slice())];
2391 feeds.push(("last_token_idx", last_idx.as_slice()));
2392 for (name, data) in &zero_in {
2393 feeds.push((name, data.as_slice()));
2394 }
2395 if let Some(ref ids) = input_ids {
2396 feeds.push(("input_ids", ids.as_slice()));
2397 }
2398 let rope_owned = self.mrope_prefill_rope_feeds(seq);
2399 for (name, data) in &rope_owned {
2400 feeds.push((name.as_str(), data.as_slice()));
2401 }
2402
2403 let config = self.execution_config(hidden_prefill_config(self.batch, seq));
2404 let compile_opts = self.dyn_compile_options(&config);
2405 let cache = self
2406 .prefill_hidden_dynamic_cache
2407 .as_mut()
2408 .ok_or_else(|| anyhow!("qwen35: hidden prefill cache missing (mmproj not loaded?)"))?;
2409 let cfg = self.cfg.clone();
2410 let weights = self.weights.clone();
2411 let runtime_mrope = self.runtime_mrope;
2412 let mtp_logits_path = self.mtp_logits_path;
2413 let fast_mtp = self.fast_mtp;
2414 let fast_greedy = self.fast_greedy_lm_head;
2415 let hidden_params = &self.prefill_hidden_cache_params;
2416 let hidden_packed = &self.prefill_hidden_cache_packed;
2417 let gguf_loader = &mut self.gguf_loader;
2418 let packed_bytes_cache = &mut self.packed_bytes_cache;
2419 let compiled = get_or_specialize_hir_with_options(
2420 cache,
2421 &config,
2422 || {
2423 build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
2424 &cfg,
2425 weights,
2426 1,
2427 seq,
2428 runtime_mrope,
2429 mtp_logits_path,
2430 fast_mtp,
2431 fast_greedy,
2432 )
2433 .expect("dynamic hidden prefill HIR")
2434 .0
2435 },
2436 &compile_opts,
2437 |c| {
2438 for (name, data) in hidden_params {
2439 c.set_param(name, data);
2440 }
2441 upload_packed_opt(c, gguf_loader.as_mut(), hidden_packed, packed_bytes_cache)
2442 },
2443 )?;
2444 let outs = compiled.run(&feeds);
2445 let prompt_lens = vec![seq];
2446 let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
2447 &self.cfg,
2448 self.batch,
2449 seq,
2450 &prompt_lens,
2451 outs,
2452 self.mtp_logits_path,
2453 self.fast_greedy_lm_head,
2454 )?;
2455 zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
2456 self.decode_cache = Some(cache.clone());
2457 Ok((trunk, cache, mtp_logits))
2458 }
2459
2460 pub fn decode_get_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2462 self.decode_forward_logits(token, false)
2463 }
2464
2465 pub fn decode_get_mtp_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2467 if !self.mtp_logits_path {
2468 bail!("qwen35: decode_get_mtp_logits requires mtp_logits_path(true)");
2469 }
2470 self.decode_forward_logits(token, true)
2471 }
2472
2473 pub fn prefill_get_mtp_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2475 self.predict_logits(prompt_ids)?
2476 .mtp_logits
2477 .ok_or_else(|| anyhow!("qwen35: MTP logits unavailable (enable_mtp?)"))
2478 }
2479
2480 fn decode_step(
2481 &mut self,
2482 cache: &mut Qwen35DecodeCache,
2483 tokens: &[u32],
2484 generated_per_row: &[usize],
2485 opts: SampleOpts,
2486 ) -> Result<Vec<u32>> {
2487 if self.fast_greedy_lm_head {
2488 let vocab = self.lm_vocab_size();
2489 let (trunk, _mtp) = self.decode_step_trunk_raw(cache, tokens, generated_per_row)?;
2490 if opts.greedy {
2491 return self.argmax_batch_from_hidden(&trunk);
2492 }
2493 if sample_lm_cap(opts, vocab) < vocab {
2494 return self.sample_batch_from_hidden(&trunk, opts);
2495 }
2496 }
2497 let logits = self.decode_forward_logits_batch(cache, tokens, generated_per_row, false)?;
2498 Ok(sample_logits_batch(
2499 &logits,
2500 self.lm_vocab_size(),
2501 self.batch,
2502 opts,
2503 ))
2504 }
2505
2506 fn decode_forward_logits(&mut self, token: u32, want_mtp: bool) -> Result<Vec<f32>> {
2507 let mut cache = self
2508 .decode_cache
2509 .take()
2510 .ok_or_else(|| anyhow!("qwen35: decode requires seeded cache"))?;
2511 let row_gen = vec![0usize; self.batch];
2512 let logits = self.decode_forward_logits_batch(&mut cache, &[token], &row_gen, want_mtp)?;
2513 self.decode_cache = Some(cache);
2514 Ok(logits)
2515 }
2516
2517 fn decode_forward_logits_batch(
2518 &mut self,
2519 cache: &mut Qwen35DecodeCache,
2520 tokens: &[u32],
2521 generated_per_row: &[usize],
2522 want_mtp: bool,
2523 ) -> Result<Vec<f32>> {
2524 let past_seq = cache.past_seq;
2525 let head_half = self.cfg.key_length / 2;
2526 let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
2527
2528 let use_bucket = self
2529 .decode_compile_cache
2530 .as_ref()
2531 .and_then(|c| c.bucket_for(past_seq as u64))
2532 .is_some();
2533
2534 if use_bucket {
2535 let (logits, mtp_logits) =
2536 self.decode_step_bucketed(cache, tokens, generated_per_row, &cos, &sin)?;
2537 if want_mtp {
2538 mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from bucketed graph"))
2539 } else {
2540 Ok(logits)
2541 }
2542 } else {
2543 let feeds_owned = decode_step_feeds(
2544 &self.cfg,
2545 cache,
2546 tokens,
2547 &cos,
2548 &sin,
2549 None,
2550 generated_per_row,
2551 )?;
2552 let feeds: Vec<(&str, &[f32])> = feeds_owned
2553 .iter()
2554 .map(|(k, v)| (k.as_str(), v.as_slice()))
2555 .collect();
2556 if !self.decode_graphs.contains_key(&past_seq) {
2557 let (hir, params, packed) = build_qwen35_decode_hir_ext(
2558 &self.cfg,
2559 self.weights.clone(),
2560 self.batch,
2561 past_seq,
2562 false,
2563 self.mtp_logits_path,
2564 self.fast_mtp,
2565 self.fast_greedy_lm_head,
2566 )?;
2567 let mut compiled = self.compile_hir_for_config(
2568 decode_config(self.batch, past_seq),
2569 &format!("decode_{past_seq}"),
2570 hir,
2571 )?;
2572 for (name, data) in ¶ms {
2573 compiled.set_param(name, data);
2574 }
2575 upload_packed_opt(
2576 &mut compiled,
2577 self.gguf_loader.as_mut(),
2578 &packed,
2579 &mut self.packed_bytes_cache,
2580 )?;
2581 self.decode_graphs.insert(past_seq, compiled);
2582 }
2583 let step = self.moe_refresh_step;
2584 let has_moe = self.moe_offload.is_some();
2585 let num_experts = self.cfg.num_experts;
2586 let moe_masks = self
2587 .moe_offload
2588 .as_ref()
2589 .map(|m| m.per_layer_resident_masks());
2590 self.bind_moe_host_weights();
2591 let outs = {
2592 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2593 if has_moe {
2594 compiled.enable_moe_topk_capture(num_experts);
2595 if let Some(layers) = &moe_masks {
2596 push_moe_residency(compiled, layers);
2597 }
2598 }
2599 compiled.run(&feeds)
2600 };
2601 if has_moe {
2602 let layers = {
2603 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2604 compiled.take_moe_topk_capture()
2605 };
2606 if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2607 let store = self.moe_store.as_ref();
2608 let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2609 if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
2610 if let Some(store) = self.moe_store.as_ref() {
2611 store.apply_to_compiled(compiled);
2612 }
2613 }
2614 }
2615 }
2616 self.moe_refresh_step = step.saturating_add(1);
2617 let (trunk, mtp_logits) = advance_cache_from_decode_outputs(
2618 &self.cfg,
2619 cache,
2620 outs,
2621 None,
2622 self.mtp_logits_path,
2623 want_mtp,
2624 self.fast_greedy_lm_head,
2625 )?;
2626 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2627 if want_mtp {
2628 mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from decode graph"))
2629 } else {
2630 Ok(logits)
2631 }
2632 }
2633 }
2634
2635 fn decode_step_bucketed(
2636 &mut self,
2637 cache: &mut Qwen35DecodeCache,
2638 tokens: &[u32],
2639 generated_per_row: &[usize],
2640 cos: &[f32],
2641 sin: &[f32],
2642 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2643 let (trunk, mtp) =
2644 self.decode_step_bucketed_raw(cache, tokens, generated_per_row, cos, sin)?;
2645 let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2646 Ok((logits, mtp))
2647 }
2648
2649 fn decode_step_bucketed_raw(
2650 &mut self,
2651 cache: &mut Qwen35DecodeCache,
2652 tokens: &[u32],
2653 generated_per_row: &[usize],
2654 cos: &[f32],
2655 sin: &[f32],
2656 ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2657 let past_seq = cache.past_seq;
2658 let upper = self.ensure_decode_bucket_compiled(past_seq as u64)?;
2659
2660 let feeds_owned = decode_step_feeds(
2661 &self.cfg,
2662 cache,
2663 tokens,
2664 cos,
2665 sin,
2666 Some(upper),
2667 generated_per_row,
2668 )?;
2669 let feeds: Vec<(&str, &[f32])> = feeds_owned
2670 .iter()
2671 .map(|(k, v)| (k.as_str(), v.as_slice()))
2672 .collect();
2673
2674 let decode_opts = self.bucketed_decode_compile_options();
2675 let cache_mut = self.decode_compile_cache.as_mut().unwrap();
2676 let (_u, compiled) = cache_mut
2677 .ensure_hir_with_params(
2678 past_seq as u64,
2679 |_| panic!("decode bucket must be compiled"),
2680 &decode_opts,
2681 )
2682 .expect("decode bucket missing after ensure");
2683 compiled.set_active_extent(Some((past_seq + 1, upper + 1)));
2685 let outs = compiled.run(&feeds);
2686 compiled.set_active_extent(None);
2687 advance_cache_from_decode_outputs(
2688 &self.cfg,
2689 cache,
2690 outs,
2691 Some(upper),
2692 self.mtp_logits_path,
2693 self.mtp_logits_path,
2694 self.fast_greedy_lm_head,
2695 )
2696 }
2697
2698 fn mrope_prefill_rope_feeds(&self, seq: usize) -> Vec<(String, Vec<f32>)> {
2699 if !self.runtime_mrope {
2700 return Vec::new();
2701 }
2702 let head_half = self.cfg.key_length / 2;
2703 let sections = self.mrope_section_positions.as_deref();
2704 let (cos, sin) = mrope_prefill_feeds(&self.cfg, seq, sections, head_half);
2705 vec![("rope_cos".into(), cos), ("rope_sin".into(), sin)]
2706 }
2707}
2708
2709impl rlx_cli::LmRunner for Qwen35Runner {
2710 fn family(&self) -> &'static str {
2711 "qwen35"
2712 }
2713 fn vocab_size(&self) -> usize {
2714 self.lm_vocab_size()
2715 }
2716 fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>> {
2717 let out = Qwen35Runner::predict_logits(self, prompt_ids)?;
2718 Ok(out.logits)
2719 }
2720 fn generate(
2721 &mut self,
2722 prompt_ids: &[u32],
2723 n_new: usize,
2724 on_token: &mut dyn FnMut(u32) -> bool,
2725 ) -> anyhow::Result<Vec<u32>> {
2726 Qwen35Runner::generate(self, prompt_ids, n_new, on_token)
2730 }
2731
2732 fn supports_multimodal(&self) -> bool {
2733 self.has_mmproj()
2737 }
2738
2739 fn generate_multimodal(
2740 &mut self,
2741 prompt: &str,
2742 rgb: &[u8],
2743 img_w: usize,
2744 img_h: usize,
2745 tokenizer: Option<&std::path::Path>,
2746 n_new: usize,
2747 on_token: &mut dyn FnMut(u32) -> bool,
2748 ) -> anyhow::Result<Vec<u32>> {
2749 Qwen35Runner::generate_multimodal(
2750 self, prompt, rgb, img_w, img_h, tokenizer, n_new, on_token,
2751 )
2752 }
2753}
2754
2755fn sample_logits_batch(logits: &[f32], vocab: usize, batch: usize, opts: SampleOpts) -> Vec<u32> {
2756 let mut out = Vec::with_capacity(batch);
2757 for b in 0..batch {
2758 let row = &logits[b * vocab..(b + 1) * vocab];
2759 out.push(sample_token(row, opts) as u32);
2760 }
2761 out
2762}