Skip to main content

rlx_qwen35/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level runner for Qwen3.5 / Qwen3.6 (qwen35 architecture).
17
18use crate::cache::{
19    Qwen35DecodeCache, advance_cache_from_decode_outputs, decode_step_feeds, last_token_indices,
20    pack_input_ids, seed_cache_from_outputs, zero_prompt_padding_kv, zero_recurrent_inputs,
21};
22use crate::capabilities::validate_device;
23use crate::config::Qwen35Config;
24use crate::encode_prompt_auto;
25use crate::lm_head::{
26    greedy_lm_head_argmax, lm_head_logits_row, sample_lm_cap, sample_lm_head_from_hidden,
27};
28use crate::moe_offload::{MoeOffloadState, build_moe_offload};
29use crate::moe_store::{build_moe_expert_store, moe_host_bind_from_store};
30use crate::rope::{mrope_prefill_feeds, mrope_slice_at_pos};
31use crate::vision::{
32    MmProjConfig, MmProjWeights, MultimodalPrefill, MultimodalPrompt, Qwen35VisionEncoder,
33    load_vision_encoder,
34};
35use crate::weights::Qwen35Weights;
36use crate::{
37    PackedParams, build_qwen35_decode_hir_dynamic_ext, build_qwen35_decode_hir_ext,
38    build_qwen35_hir_sized_ext, build_qwen35_prefill_cache_hir_dynamic_ext,
39    build_qwen35_prefill_hidden_cache_hir_dynamic_ext,
40};
41use rlx_runtime::MoeExpertStore;
42
43fn push_moe_residency(compiled: &mut rlx_runtime::CompiledGraph, layers: &[Vec<bool>]) {
44    let refs: Vec<&[bool]> = layers.iter().map(|m| m.as_slice()).collect();
45    compiled.set_moe_resident_experts_per_layer(&refs);
46}
47
48fn refresh_moe_from_capture(
49    mo: &mut MoeOffloadState,
50    store: Option<&MoeExpertStore>,
51    compiled: &mut rlx_runtime::CompiledGraph,
52    layer_indices: &[Vec<u32>],
53    denoise_step: usize,
54    is_prefill_block: bool,
55) -> bool {
56    let refreshed = if let Some(store) = store {
57        mo.refresh_from_capture_with_store(store, layer_indices, denoise_step, is_prefill_block)
58    } else {
59        mo.refresh_from_capture(layer_indices, denoise_step, is_prefill_block)
60    };
61    if refreshed {
62        push_moe_residency(compiled, &mo.per_layer_resident_masks());
63    }
64    refreshed
65}
66use crate::execution::{
67    Qwen35CompileCache, decode_config, get_or_specialize_hir_with_options, hidden_prefill_config,
68    prefill_config,
69};
70use crate::flow::{Qwen35PrefillCacheOpts, build_qwen35_prefill_cache_built};
71use crate::profile::{qwen35_profile_default, qwen35_profile_near_weights};
72use anyhow::{Context, Result, anyhow, bail};
73use rlx_core::flow_bridge::compile_options_from_profile;
74use rlx_core::gguf_support::{GgufModelFamily, assert_gguf_family, resolve_weights_file};
75use rlx_core::weight_loader::GgufLoader;
76use rlx_flow::ModelExecutionConfig;
77use rlx_flow::{CompileProfile, ExecutionPreset};
78use rlx_ir::CompilationMode;
79use rlx_ir::logical_kernel::KernelDispatchConfig;
80use rlx_qwen3::sampling::{SampleOpts, sample_token};
81use rlx_runtime::compile_cache::BucketedCompileCache;
82use rlx_runtime::{AotCache, CompileOptions, Device, Session};
83use std::cell::RefCell;
84use std::collections::HashMap;
85use std::path::PathBuf;
86use std::sync::Arc;
87use std::time::Instant;
88
89/// Source for the Qwen3.5 / 3.6 config. Mirrors `Qwen3ConfigSource`
90/// so callers using `Qwen35RunnerBuilder` have the same shape as
91/// `Qwen3RunnerBuilder` (PLAN.md M1).
92///
93/// `JsonFile` and `Explicit` are wired into the future safetensors
94/// load path (catalog rows: `qwen35-{4b,9b,27b}-hauhau-aggressive`).
95/// Until the safetensors load lands, `build()` errors with a clear
96/// M1 follow-up message when these variants are used.
97/// Alias of `rlx_runtime::ConfigSource<Qwen35Config>`. `JsonFile` /
98/// `Explicit` variants are wired into the future safetensors load path
99/// (catalog rows: `qwen35-{4b,9b,27b}-hauhau-aggressive`). Until that
100/// path lands, `build()` errors with a clear M1 follow-up message when
101/// these variants are used.
102pub type Qwen35ConfigSource = rlx_runtime::ConfigSource<Qwen35Config>;
103
104#[derive(Default, Debug)]
105pub struct Qwen35RunnerBuilder {
106    weights: Option<PathBuf>,
107    config: Option<Qwen35ConfigSource>,
108    device: Option<Device>,
109    max_seq: Option<usize>,
110    enable_mtp: bool,
111    last_logits_only: bool,
112    /// `None` = auto-detect (packed when GGUF ≥ 256 MB to avoid the
113    /// F32-dequant memory explosion — a 4B Q3_K_S file is ~2 GB on
114    /// disk but ~16 GB dense-F32). `Some(true)` / `Some(false)` are
115    /// explicit user overrides.
116    packed_weights: Option<bool>,
117    runtime_mrope: bool,
118    mrope_section_positions: Option<Vec<[usize; 4]>>,
119    batch: Option<usize>,
120    bucketed_decode: Option<bool>,
121    /// Emit/consume MTP logits on the prefill-cache + decode path (draft speculator).
122    mtp_logits_path: bool,
123    fast_mtp: bool,
124    /// Skip LM head in decode graphs; argmax on host (default: true).
125    fast_greedy_lm_head: Option<bool>,
126    /// Persist optimized LIR under this directory (warm-start / AOT).
127    aot_cache_dir: Option<PathBuf>,
128    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
129    dynamic_prefill: bool,
130    /// Compile decode once with `sym::PAST_SEQ`, specialize per prefix length.
131    dynamic_decode: bool,
132    inline_weights: Option<(Qwen35Config, Qwen35Weights)>,
133    /// Optional mmproj GGUF for VLM vision encoding.
134    mmproj: Option<PathBuf>,
135    /// Inline mmproj weights (tests; mutually exclusive with [`Self::mmproj`]).
136    inline_mmproj: Option<(crate::vision::MmProjConfig, crate::vision::MmProjWeights)>,
137    /// Override tier-1 prefill profile (else `qwen35.rlx.toml` or defaults).
138    prefill_profile: Option<CompileProfile>,
139    /// Override tier-1 decode profile.
140    decode_profile: Option<CompileProfile>,
141    /// TIDE-style cap on GPU-resident experts per MoE layer (`max_gpu_experts_per_layer`).
142    max_gpu_experts_per_layer: Option<usize>,
143    /// Unified RAM / VRAM budget for auto expert cap (optional).
144    moe_memory_budget_bytes: Option<usize>,
145    /// Refresh expert placement every N decode/denoise steps (TIDE `jump_steps` τ).
146    expert_refresh_every_decode_steps: Option<usize>,
147    /// TIDE `jump_steps` alias (preferred name).
148    jump_steps: Option<usize>,
149    /// TIDE `reserve_vram_gb` (default 1.5).
150    reserve_vram_gb: Option<f64>,
151    /// TIDE `collect_stats` on MoE forwards.
152    moe_collect_stats: bool,
153}
154
155impl Qwen35RunnerBuilder {
156    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
157        self.weights = Some(path.into());
158        self
159    }
160
161    /// Source for the Qwen3.5 / 3.6 config. Default
162    /// `Qwen35ConfigSource::Embedded` (GGUF metadata). PLAN.md M1
163    /// — `JsonFile` / `Explicit` reserve the API shape for the
164    /// safetensors load path; today they error in `build()` with a
165    /// follow-up message.
166    pub fn config(mut self, src: Qwen35ConfigSource) -> Self {
167        self.config = Some(src);
168        self
169    }
170
171    /// Convenience: explicit `Qwen35Config` (shorthand for
172    /// `.config(Qwen35ConfigSource::Explicit(cfg))`).
173    pub fn config_value(self, cfg: Qwen35Config) -> Self {
174        self.config(Qwen35ConfigSource::Explicit(cfg))
175    }
176    pub fn device(mut self, d: Device) -> Self {
177        self.device = Some(d);
178        self
179    }
180    pub fn max_seq(mut self, n: usize) -> Self {
181        self.max_seq = Some(n);
182        self
183    }
184    pub fn enable_mtp(mut self, on: bool) -> Self {
185        self.enable_mtp = on;
186        self
187    }
188    pub fn last_logits_only(mut self, on: bool) -> Self {
189        self.last_logits_only = on;
190        self
191    }
192    pub fn packed_weights(mut self, on: bool) -> Self {
193        self.packed_weights = Some(on);
194        self
195    }
196    /// Use runtime MRoPE cos/sin inputs instead of a baked table. Required
197    /// for multimodal prompts where section positions differ from `[p,p,p,0]`.
198    pub fn runtime_mrope(mut self, on: bool) -> Self {
199        self.runtime_mrope = on;
200        self
201    }
202    /// Per-token MRoPE section positions `[t,h,w,extra]` (length = prompt seq).
203    pub fn mrope_section_positions(mut self, positions: Vec<[usize; 4]>) -> Self {
204        self.mrope_section_positions = Some(positions);
205        self
206    }
207    /// Batch size for compiled graphs (default 1).
208    pub fn batch(mut self, n: usize) -> Self {
209        self.batch = Some(n);
210        self
211    }
212    /// Use power-of-two bucketed decode compile cache (default: true).
213    pub fn bucketed_decode(mut self, on: bool) -> Self {
214        self.bucketed_decode = Some(on);
215        self
216    }
217    /// Use MTP head logits on prefill-cache seeding and decode steps.
218    pub fn mtp_logits_path(mut self, on: bool) -> Self {
219        self.mtp_logits_path = on;
220        self
221    }
222    /// Trim MTP LM head to 32K vocab (llama.cpp FastMTP draft path).
223    pub fn fast_mtp(mut self, on: bool) -> Self {
224        self.fast_mtp = on;
225        self
226    }
227    /// Decode without graph LM head — host argmax over tied embedding (default on).
228    pub fn fast_greedy_lm_head(mut self, on: bool) -> Self {
229        self.fast_greedy_lm_head = Some(on);
230        self
231    }
232    /// Cache optimized LIR on disk for faster subsequent runs.
233    pub fn aot_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
234        self.aot_cache_dir = Some(path.into());
235        self
236    }
237    /// Use dynamic prefill specialization (symbolic seq; batch=1).
238    pub fn dynamic_prefill(mut self, on: bool) -> Self {
239        self.dynamic_prefill = on;
240        self
241    }
242    /// Use dynamic decode specialization (symbolic past_seq; batch=1).
243    pub fn dynamic_decode(mut self, on: bool) -> Self {
244        self.dynamic_decode = on;
245        self
246    }
247    /// Supply config + weights directly (tests/benches; no GGUF on disk).
248    pub fn inline_weights(mut self, cfg: Qwen35Config, weights: Qwen35Weights) -> Self {
249        self.inline_weights = Some((cfg, weights));
250        self
251    }
252
253    /// Load vision encoder weights from an mmproj GGUF (enables multimodal prefill).
254    pub fn mmproj(mut self, path: impl Into<PathBuf>) -> Self {
255        self.mmproj = Some(path.into());
256        self
257    }
258
259    /// Supply mmproj config + weights directly (tests; no mmproj GGUF on disk).
260    pub fn inline_mmproj(
261        mut self,
262        cfg: crate::vision::MmProjConfig,
263        weights: crate::vision::MmProjWeights,
264    ) -> Self {
265        self.inline_mmproj = Some((cfg, weights));
266        self
267    }
268
269    /// Override tier-1 compile profiles (skips `qwen35.rlx.toml` discovery).
270    pub fn with_compile_profiles(
271        mut self,
272        prefill: CompileProfile,
273        decode: CompileProfile,
274    ) -> Self {
275        self.prefill_profile = Some(prefill);
276        self.decode_profile = Some(decode);
277        self
278    }
279
280    /// Enable MoE expert offload (TIDE). Caps GPU-resident experts per layer; remainder on host.
281    pub fn max_gpu_experts_per_layer(mut self, n: usize) -> Self {
282        self.max_gpu_experts_per_layer = Some(n);
283        self
284    }
285
286    /// Memory budget for automatic expert cap (macOS: unified RAM when unset).
287    pub fn moe_memory_budget_bytes(mut self, bytes: usize) -> Self {
288        self.moe_memory_budget_bytes = Some(bytes);
289        self
290    }
291
292    /// Refresh expert GPU set every N decode steps (TIDE `jump_steps` for AR).
293    pub fn expert_refresh_every_decode_steps(mut self, n: usize) -> Self {
294        self.expert_refresh_every_decode_steps = Some(n);
295        self.jump_steps = Some(n);
296        self
297    }
298
299    /// TIDE `jump_steps` (τ): refresh expert placement every N denoise/decode steps.
300    pub fn jump_steps(mut self, n: usize) -> Self {
301        self.jump_steps = Some(n);
302        self.expert_refresh_every_decode_steps = Some(n);
303        self
304    }
305
306    /// TIDE `reserve_vram_gb` for GPU expert budget sizing (default 1.5).
307    pub fn reserve_vram_gb(mut self, gb: f64) -> Self {
308        self.reserve_vram_gb = Some(gb);
309        self
310    }
311
312    /// TIDE `collect_stats` — aggregate token/compute counters per forward.
313    pub fn moe_collect_stats(mut self, on: bool) -> Self {
314        self.moe_collect_stats = on;
315        self
316    }
317
318    /// TIDE `enable_predictive_expert_offload(max_gpu_experts_per_layer=…)`.
319    pub fn enable_predictive_expert_offload(mut self, max_gpu_experts_per_layer: usize) -> Self {
320        self.max_gpu_experts_per_layer = Some(max_gpu_experts_per_layer);
321        self
322    }
323
324    pub fn build(self) -> Result<Qwen35Runner> {
325        let device = self.device.unwrap_or(Device::Cpu);
326        let max_seq = self.max_seq.unwrap_or(128);
327        let batch = self.batch.unwrap_or(1);
328        if batch == 0 {
329            bail!("qwen35: batch must be >= 1");
330        }
331
332        // PLAN.md M1: safetensors load path for HauhauCS catalog rows
333        // (and any other Qwen3.5 / 3.6 HF safetensors checkpoint).
334        // When the caller picks `JsonFile(p)` or `Explicit(cfg)`, we
335        // detect a safetensors weights path, parse the HF config (or
336        // use the explicit one), open the file via the weight registry,
337        // wrap it in `HfTranslatingLoader` so GGUF-named lookups
338        // succeed against HF tensor names, and run the standard
339        // `Qwen35Weights::from_loader` drain.
340        if let Some(src) = self.config.as_ref()
341            && !matches!(src, Qwen35ConfigSource::Embedded)
342        {
343            let weights_path = self
344                .weights
345                .as_ref()
346                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?
347                .clone();
348            let resolved = resolve_weights_file(&weights_path)?;
349            let ext = resolved
350                .extension()
351                .and_then(|s| s.to_str())
352                .unwrap_or("")
353                .to_ascii_lowercase();
354            if ext == "gguf" {
355                bail!(
356                    "qwen35: Qwen35ConfigSource::{:?} supplied with a GGUF weights file at \
357                     {:?} — drop the config source (use the default Embedded) so the GGUF \
358                     metadata is the source of truth",
359                    src,
360                    resolved
361                );
362            }
363            if self.packed_weights == Some(true) {
364                bail!("qwen35: packed_weights requires GGUF; safetensors path is dequant-only");
365            }
366            let cfg = match src {
367                Qwen35ConfigSource::Embedded => unreachable!(),
368                Qwen35ConfigSource::JsonFile(p) => Qwen35Config::from_hf_config_json(p)
369                    .with_context(|| format!("qwen35: parse HF config {p:?}"))?,
370                Qwen35ConfigSource::Explicit(cfg) => cfg.clone(),
371            };
372            if self.enable_mtp && cfg.nextn_predict_layers == 0 {
373                bail!(
374                    "qwen35: enable_mtp(true) but config has \
375                     nextn_predict_layers=0 (no MTP heads to wire)"
376                );
377            }
378            validate_device(&cfg, device, false)?;
379            let path_str = resolved
380                .to_str()
381                .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
382            let inner_map = rlx_core::weight_map::WeightMap::from_file(path_str)
383                .with_context(|| format!("qwen35: load safetensors {resolved:?}"))?;
384            let mut loader = rlx_core::HfTranslatingLoader::new(inner_map);
385            let t = std::time::Instant::now();
386            let weights = Qwen35Weights::from_loader(&mut loader, &cfg)?;
387            eprintln!(
388                "[qwen35] read safetensors weights in {:.2?} \
389                 (layers={}, hidden={})",
390                t.elapsed(),
391                cfg.num_hidden_layers,
392                cfg.hidden_size,
393            );
394            return finish_build(
395                cfg,
396                weights,
397                resolved,
398                None,
399                device,
400                max_seq,
401                batch,
402                self.enable_mtp,
403                self.last_logits_only,
404                self.runtime_mrope,
405                self.mrope_section_positions,
406                self.bucketed_decode,
407                self.mtp_logits_path,
408                self.fast_mtp,
409                self.fast_greedy_lm_head.unwrap_or(true),
410                self.aot_cache_dir.clone(),
411                self.dynamic_prefill,
412                self.dynamic_decode,
413                self.mmproj.clone(),
414                self.inline_mmproj,
415                self.prefill_profile,
416                self.decode_profile,
417                self.max_gpu_experts_per_layer,
418                self.moe_memory_budget_bytes,
419                self.jump_steps.or(self.expert_refresh_every_decode_steps),
420                self.reserve_vram_gb.unwrap_or(1.5),
421                self.moe_collect_stats,
422            );
423        }
424
425        if let Some((cfg, weights)) = self.inline_weights {
426            if self.packed_weights == Some(true) {
427                bail!("qwen35: inline_weights and packed_weights are mutually exclusive");
428            }
429            if self.enable_mtp && cfg.nextn_predict_layers == 0 {
430                bail!(
431                    "qwen35: enable_mtp(true) but config has \
432                     nextn_predict_layers=0 (no MTP heads to wire)"
433                );
434            }
435            if self.mmproj.is_some() && self.inline_mmproj.is_some() {
436                bail!("qwen35: mmproj and inline_mmproj are mutually exclusive");
437            }
438            validate_device(&cfg, device, false)?;
439            return finish_build(
440                cfg,
441                weights,
442                PathBuf::new(),
443                None,
444                device,
445                max_seq,
446                batch,
447                self.enable_mtp,
448                self.last_logits_only,
449                self.runtime_mrope,
450                self.mrope_section_positions,
451                self.bucketed_decode,
452                self.mtp_logits_path,
453                self.fast_mtp,
454                self.fast_greedy_lm_head.unwrap_or(true),
455                self.aot_cache_dir.clone(),
456                self.dynamic_prefill,
457                self.dynamic_decode,
458                self.mmproj.clone(),
459                self.inline_mmproj,
460                self.prefill_profile,
461                self.decode_profile,
462                self.max_gpu_experts_per_layer,
463                self.moe_memory_budget_bytes,
464                self.jump_steps.or(self.expert_refresh_every_decode_steps),
465                self.reserve_vram_gb.unwrap_or(1.5),
466                self.moe_collect_stats,
467            );
468        }
469
470        let weights_path = resolve_weights_file(
471            &self
472                .weights
473                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
474        )?;
475        let _t_total = Instant::now();
476        let t = Instant::now();
477        let raw = assert_gguf_family(&weights_path, GgufModelFamily::Qwen35)?;
478        let mut loader = GgufLoader::from_file(
479            weights_path
480                .to_str()
481                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
482        )?;
483        loader.include_mtp(true);
484        let cfg = Qwen35Config::from_gguf(&raw)?;
485        eprintln!(
486            "[qwen35] loaded GGUF metadata in {:.2?} \
487             (layers={}, hidden={}, ssm_state={})",
488            t.elapsed(),
489            cfg.num_hidden_layers,
490            cfg.hidden_size,
491            cfg.ssm_state_size,
492        );
493
494        if self.enable_mtp && cfg.nextn_predict_layers == 0 {
495            bail!(
496                "qwen35: enable_mtp(true) but the file has \
497                 nextn_predict_layers=0 (no MTP heads to wire)"
498            );
499        }
500        // Resolve auto-default. Llama.cpp keeps K-quant tensors packed
501        // in memory and dequantises per block inside the matmul kernel —
502        // it never materialises a dense F32 weight matrix. Mirror that:
503        // when *any* tensor in the GGUF is a K-quant block format
504        // (Q2_K..Q8_K), force the packed path. Otherwise fall back to
505        // the size heuristic (≥ 256 MB → packed) for legacy quant
506        // formats. Explicit `.packed_weights(_)` overrides.
507        let packed = self.packed_weights.unwrap_or_else(|| {
508            if raw.tensors.values().any(|t| {
509                matches!(
510                    t.dtype,
511                    rlx_gguf::GgmlType::Q2K
512                        | rlx_gguf::GgmlType::Q3K
513                        | rlx_gguf::GgmlType::Q4K
514                        | rlx_gguf::GgmlType::Q5K
515                        | rlx_gguf::GgmlType::Q6K
516                        | rlx_gguf::GgmlType::Q8K
517                )
518            }) {
519                return true;
520            }
521            std::fs::metadata(&weights_path)
522                .ok()
523                .map(|m| m.len() >= 256 * 1024 * 1024)
524                .unwrap_or(false)
525        });
526        validate_device(&cfg, device, packed)?;
527
528        let t = Instant::now();
529        let weights = if packed {
530            Qwen35Weights::from_loader_packed(&mut loader, &cfg)?
531        } else {
532            Qwen35Weights::from_loader(&mut loader, &cfg)?
533        };
534        eprintln!(
535            "[qwen35] read weights ({}) in {:.2?}",
536            if packed { "packed" } else { "F32" },
537            t.elapsed(),
538        );
539
540        finish_build(
541            cfg,
542            weights,
543            weights_path,
544            Some(loader),
545            device,
546            max_seq,
547            batch,
548            self.enable_mtp,
549            self.last_logits_only,
550            self.runtime_mrope,
551            self.mrope_section_positions,
552            self.bucketed_decode,
553            self.mtp_logits_path,
554            self.fast_mtp,
555            self.fast_greedy_lm_head.unwrap_or(true),
556            self.aot_cache_dir.clone(),
557            self.dynamic_prefill,
558            self.dynamic_decode,
559            self.mmproj.clone(),
560            self.inline_mmproj,
561            self.prefill_profile,
562            self.decode_profile,
563            self.max_gpu_experts_per_layer,
564            self.moe_memory_budget_bytes,
565            self.jump_steps.or(self.expert_refresh_every_decode_steps),
566            self.reserve_vram_gb.unwrap_or(1.5),
567            self.moe_collect_stats,
568        )
569    }
570}
571
572fn make_qwen35_dyn_cache(
573    device: Device,
574    capacity: usize,
575    aot_cache_dir: Option<&std::path::Path>,
576) -> Qwen35CompileCache {
577    if let Some(dir) = aot_cache_dir {
578        Qwen35CompileCache::with_aot(device, capacity, dir)
579    } else {
580        Qwen35CompileCache::new(device, capacity)
581    }
582}
583
584/// Static prefill-cache compile via tier-0 [`BuiltModel`] + [`Qwen35CompileCache`].
585fn compile_static_prefill_cache(
586    cfg: &Qwen35Config,
587    weights: Qwen35Weights,
588    batch: usize,
589    max_seq: usize,
590    device: Device,
591    prefill_profile: &CompileProfile,
592    runtime_mrope: bool,
593    enable_mtp_head: bool,
594    fast_mtp: bool,
595    fast_greedy_lm_head: bool,
596    aot_cache_dir: Option<&std::path::Path>,
597) -> Result<(
598    rlx_runtime::CompiledGraph,
599    HashMap<String, Vec<f32>>,
600    PackedParams,
601)> {
602    let mut flow_opts = Qwen35PrefillCacheOpts::static_cache(batch, max_seq);
603    flow_opts.with_lm_head = !fast_greedy_lm_head;
604    flow_opts.runtime_mrope = runtime_mrope;
605    flow_opts.enable_mtp_head = enable_mtp_head;
606    flow_opts.fast_mtp = fast_mtp;
607    flow_opts.fast_greedy_lm_head = fast_greedy_lm_head;
608    flow_opts.profile = Some(prefill_profile.clone());
609
610    let (built, packed) = build_qwen35_prefill_cache_built(cfg, weights, &flow_opts)?;
611    let params = built.params().clone();
612    let config = prefill_config(batch, max_seq);
613    let compile_opts =
614        compile_options_from_profile(prefill_profile, device, KernelDispatchConfig::default());
615
616    let mut cache = match aot_cache_dir {
617        Some(dir) => Qwen35CompileCache::with_aot(device, 1, dir),
618        None => Qwen35CompileCache::new(device, 1),
619    };
620    let mut config = config;
621    if aot_cache_dir.is_some() {
622        config = config.with_compilation_mode(CompilationMode::Aot);
623    }
624    let built = built.with_execution_config(&config);
625    let compiled = cache.compile_built(built, &config, &compile_opts)?;
626    Ok((compiled, params, packed))
627}
628
629#[allow(clippy::too_many_arguments)]
630fn finish_build(
631    cfg: Qwen35Config,
632    weights: Qwen35Weights,
633    weights_path: PathBuf,
634    gguf_loader: Option<GgufLoader>,
635    device: Device,
636    max_seq: usize,
637    batch: usize,
638    enable_mtp: bool,
639    last_logits_only: bool,
640    runtime_mrope: bool,
641    mrope_section_positions: Option<Vec<[usize; 4]>>,
642    bucketed_decode: Option<bool>,
643    mtp_logits_path: bool,
644    fast_mtp: bool,
645    fast_greedy_lm_head: bool,
646    aot_cache_dir: Option<PathBuf>,
647    dynamic_prefill: bool,
648    dynamic_decode: bool,
649    mmproj_path: Option<PathBuf>,
650    inline_mmproj: Option<(MmProjConfig, MmProjWeights)>,
651    prefill_profile_override: Option<CompileProfile>,
652    decode_profile_override: Option<CompileProfile>,
653    max_gpu_experts_per_layer: Option<usize>,
654    moe_memory_budget_bytes: Option<usize>,
655    jump_steps: Option<usize>,
656    reserve_vram_gb: f64,
657    moe_collect_stats: bool,
658) -> Result<Qwen35Runner> {
659    let prefill_profile = prefill_profile_override.unwrap_or_else(|| {
660        if weights_path.as_os_str().is_empty() {
661            qwen35_profile_default(false)
662        } else {
663            qwen35_profile_near_weights(&weights_path, false)
664        }
665    });
666    let decode_profile = decode_profile_override.unwrap_or_else(|| {
667        if weights_path.as_os_str().is_empty() {
668            qwen35_profile_default(true)
669        } else {
670            qwen35_profile_near_weights(&weights_path, true)
671        }
672    });
673
674    if fast_mtp && !mtp_logits_path && !enable_mtp {
675        bail!("qwen35: fast_mtp requires enable_mtp(true) or mtp_logits_path(true)");
676    }
677    if mtp_logits_path && !enable_mtp {
678        bail!("qwen35: mtp_logits_path requires enable_mtp(true)");
679    }
680
681    if dynamic_prefill && batch != 1 {
682        bail!("qwen35: dynamic_prefill requires batch=1");
683    }
684    if dynamic_decode && batch != 1 {
685        bail!("qwen35: dynamic_decode requires batch=1");
686    }
687    if dynamic_decode && bucketed_decode.unwrap_or(true) {
688        eprintln!("[qwen35] dynamic_decode enabled — disabling bucketed decode cache");
689    }
690    let bucketed_decode = if dynamic_decode {
691        false
692    } else {
693        bucketed_decode.unwrap_or(true)
694    };
695
696    let vision_encoder = if let Some(ref path) = mmproj_path {
697        Some(load_vision_encoder(
698            path.to_str()
699                .ok_or_else(|| anyhow!("non-utf8 mmproj path"))?,
700            224,
701            224,
702        )?)
703    } else if let Some((vcfg, vweights)) = inline_mmproj {
704        Some(Qwen35VisionEncoder::from_parts(vcfg, vweights, 4, 4)?)
705    } else {
706        None
707    };
708    let runtime_mrope = runtime_mrope || vision_encoder.is_some();
709    if vision_encoder.is_some() && batch != 1 {
710        bail!("qwen35: VLM (mmproj) requires batch=1");
711    }
712    if vision_encoder.is_some() && !dynamic_prefill {
713        eprintln!("[qwen35] mmproj loaded — enabling dynamic prefill for variable multimodal seq");
714    }
715    let dynamic_prefill = dynamic_prefill || vision_encoder.is_some();
716
717    let t = Instant::now();
718    let aot = aot_cache_dir.as_ref().map(AotCache::new);
719    let (cache_params, cache_packed, mut prefill_cache, prefill_dynamic_cache) = if dynamic_prefill
720    {
721        let (_cache_hir, cache_params, cache_packed) = build_qwen35_prefill_cache_hir_dynamic_ext(
722            &cfg,
723            weights.clone(),
724            batch,
725            max_seq,
726            runtime_mrope,
727            mtp_logits_path,
728            fast_mtp,
729            fast_greedy_lm_head,
730        )?;
731        eprintln!(
732            "[qwen35] built prefill-cache IR in {:.2?} (params={}, packed={})",
733            t.elapsed(),
734            cache_params.len(),
735            cache_packed.len(),
736        );
737        eprintln!("[qwen35] dynamic prefill template ready (compile on first prompt)");
738        (
739            cache_params,
740            cache_packed,
741            None,
742            Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
743        )
744    } else {
745        let (compiled, cache_params, cache_packed) = compile_static_prefill_cache(
746            &cfg,
747            weights.clone(),
748            batch,
749            max_seq,
750            device,
751            &prefill_profile,
752            runtime_mrope,
753            mtp_logits_path,
754            fast_mtp,
755            fast_greedy_lm_head,
756            aot_cache_dir.as_deref(),
757        )?;
758        eprintln!(
759            "[qwen35] compiled prefill-cache via BuiltModel in {:.2?} (params={}, packed={})",
760            t.elapsed(),
761            cache_params.len(),
762            cache_packed.len(),
763        );
764        (cache_params, cache_packed, Some(compiled), None)
765    };
766
767    let (prefill_hidden_dynamic_cache, prefill_hidden_cache_params, prefill_hidden_cache_packed) =
768        if vision_encoder.is_some() {
769            let (hidden_hir, hidden_params, hidden_packed) =
770                build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
771                    &cfg,
772                    weights.clone(),
773                    batch,
774                    max_seq,
775                    runtime_mrope,
776                    mtp_logits_path,
777                    fast_mtp,
778                    fast_greedy_lm_head,
779                )?;
780            let _ = hidden_hir;
781            (
782                Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
783                hidden_params,
784                hidden_packed,
785            )
786        } else {
787            (None, HashMap::new(), HashMap::new())
788        };
789
790    let t = Instant::now();
791    if let Some(ref mut compiled) = prefill_cache {
792        for (name, data) in &cache_params {
793            compiled.set_param(name, data);
794        }
795    }
796
797    let decode_compile_cache = if bucketed_decode {
798        Some(BucketedCompileCache::power_of_two_ladder(
799            device,
800            1,
801            max_seq.max(1) as u64,
802        ))
803    } else {
804        None
805    };
806    let decode_dynamic_cache = if dynamic_decode {
807        Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref()))
808    } else {
809        None
810    };
811    let (decode_dynamic_params, decode_dynamic_packed) = if dynamic_decode {
812        let (_, p, packed) = build_qwen35_decode_hir_dynamic_ext(
813            &cfg,
814            weights.clone(),
815            batch,
816            max_seq,
817            mtp_logits_path,
818            fast_mtp,
819            fast_greedy_lm_head,
820        )?;
821        (p, packed)
822    } else {
823        (HashMap::new(), HashMap::new())
824    };
825
826    if dynamic_decode {
827        eprintln!("[qwen35] dynamic decode template ready (compile on first step)");
828    }
829
830    let moe_offload = build_moe_offload(
831        &cfg,
832        &weights,
833        max_gpu_experts_per_layer,
834        moe_memory_budget_bytes,
835        jump_steps,
836        reserve_vram_gb,
837        moe_collect_stats,
838    );
839    let moe_store = if moe_offload.is_some() {
840        build_moe_expert_store(&cfg, &weights).ok()
841    } else {
842        None
843    };
844    if let Some(ref mo) = moe_offload {
845        eprintln!(
846            "[qwen35] TIDE MoE offload: layers={} gpu_budget={}/{} jump_steps={} reserve_bytes={}",
847            mo.num_layers(),
848            mo.info.gpu_expert_budget_per_layer,
849            mo.pools[0].num_experts(),
850            mo.jump_steps,
851            mo.info.reserve_bytes,
852        );
853    }
854
855    let mut runner = Qwen35Runner {
856        compiled: None,
857        prefill_cache,
858        prefill_dynamic_cache,
859        prefill_hidden_dynamic_cache,
860        prefill_cache_params: cache_params,
861        prefill_cache_packed: cache_packed,
862        prefill_hidden_cache_params,
863        prefill_hidden_cache_packed,
864        decode_graphs: HashMap::new(),
865        decode_compile_cache,
866        decode_dynamic_cache,
867        predict_hir_cache: None,
868        decode_dynamic_params,
869        decode_dynamic_packed,
870        packed_bytes_cache: HashMap::new(),
871        cfg,
872        device,
873        batch,
874        max_seq,
875        last_logits_only,
876        enable_mtp,
877        mtp_logits_path,
878        fast_mtp,
879        fast_greedy_lm_head,
880        weights,
881        weights_path,
882        gguf_loader,
883        decode_cache: None,
884        runtime_mrope,
885        mrope_section_positions,
886        aot_cache: aot,
887        dynamic_prefill,
888        dynamic_decode,
889        vision_encoder,
890        mmproj_path,
891        prefill_profile,
892        decode_profile,
893        moe_offload,
894        moe_store,
895        moe_refresh_step: 0,
896    };
897
898    if let Some(ref mut compiled) = runner.prefill_cache {
899        upload_packed_opt(
900            compiled,
901            runner.gguf_loader.as_mut(),
902            &runner.prefill_cache_packed,
903            &mut runner.packed_bytes_cache,
904        )?;
905    }
906    eprintln!(
907        "[qwen35] uploaded prefill-cache {} F32 + {} packed params in {:.2?}",
908        runner.prefill_cache_params.len(),
909        runner.prefill_cache_packed.len(),
910        t.elapsed(),
911    );
912
913    runner.warm_decode_graphs()?;
914    runner.warm_predict_graph()?;
915    Ok(runner)
916}
917
918fn ensure_packed_cache(
919    loader: &mut GgufLoader,
920    packed: &PackedParams,
921    cache: &mut HashMap<String, Arc<[u8]>>,
922) -> Result<()> {
923    for (loader_key, _, _) in packed.values() {
924        if cache.contains_key(loader_key) {
925            continue;
926        }
927        let bytes = loader
928            .tensor_bytes_borrowed(loader_key)
929            .ok_or_else(|| anyhow!("packed cache: {loader_key} bytes missing"))?;
930        cache.insert(loader_key.clone(), Arc::from(bytes));
931    }
932    Ok(())
933}
934
935fn upload_packed_opt(
936    compiled: &mut rlx_runtime::CompiledGraph,
937    loader: Option<&mut GgufLoader>,
938    packed: &PackedParams,
939    cache: &mut HashMap<String, Arc<[u8]>>,
940) -> Result<()> {
941    if packed.is_empty() {
942        return Ok(());
943    }
944    let loader = loader
945        .ok_or_else(|| anyhow!("packed params require a GGUF loader (missing weights path)"))?;
946    ensure_packed_cache(loader, packed, cache)?;
947    for (param_name, (loader_key, _scheme, _shape)) in packed {
948        let bytes = cache
949            .get(loader_key)
950            .ok_or_else(|| anyhow!("packed upload: cache miss for {loader_key}"))?;
951        compiled.set_param_typed(param_name, bytes, rlx_ir::DType::U8);
952    }
953    Ok(())
954}
955
956#[allow(dead_code)]
957fn upload_decode_packed(
958    weights_path: &std::path::Path,
959    compiled: &mut rlx_runtime::CompiledGraph,
960    packed: &PackedParams,
961) -> Result<()> {
962    if packed.is_empty() {
963        return Ok(());
964    }
965    let path = weights_path
966        .to_str()
967        .filter(|p| !p.is_empty())
968        .ok_or_else(|| anyhow!("packed decode params require a GGUF weights path"))?;
969    let mut loader = GgufLoader::from_file(path)?;
970    loader.include_mtp(true);
971    upload_packed_opt(compiled, Some(&mut loader), packed, &mut HashMap::new())
972}
973
/// High-level Qwen3.5 runner. Owns the compiled prefill / predict / decode graphs,
/// their parameter tables, the decode cache, and optional vision and MoE-offload state.
pub struct Qwen35Runner {
    /// Predict (prefill-logits) graph; `None` until it is compiled lazily.
    compiled: Option<rlx_runtime::CompiledGraph>,
    /// Statically compiled prefill-cache graph (`None` when the dynamic prefill template is used).
    prefill_cache: Option<rlx_runtime::CompiledGraph>,
    /// Dynamic prefill-cache template, compiled on first prompt.
    prefill_dynamic_cache: Option<Qwen35CompileCache>,
    /// Dynamic hidden-state prefill template; only created when a vision encoder is loaded.
    prefill_hidden_dynamic_cache: Option<Qwen35CompileCache>,
    /// F32 parameters of the prefill-cache graph.
    prefill_cache_params: HashMap<String, Vec<f32>>,
    /// Packed (GGUF byte) parameters of the prefill-cache graph.
    prefill_cache_packed: PackedParams,
    /// F32 parameters of the hidden-state prefill graph (empty without vision).
    prefill_hidden_cache_params: HashMap<String, Vec<f32>>,
    /// Packed parameters of the hidden-state prefill graph (empty without vision).
    prefill_hidden_cache_packed: PackedParams,
    /// Concrete decode graphs keyed by the exact `past_seq` they were built for
    /// (used by the non-bucketed, non-dynamic decode path).
    decode_graphs: HashMap<usize, rlx_runtime::CompiledGraph>,
    /// Power-of-two bucketed decode graphs (`None` unless bucketed decode is enabled).
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Dynamic decode template (`None` unless dynamic decode is enabled).
    decode_dynamic_cache: Option<Qwen35CompileCache>,
    /// Predict / reprefill HIR (must not share a template with decode graphs).
    predict_hir_cache: Option<Qwen35CompileCache>,
    /// F32 parameters uploaded into each dynamic decode specialization.
    decode_dynamic_params: HashMap<String, Vec<f32>>,
    /// Packed parameters uploaded into each dynamic decode specialization.
    decode_dynamic_packed: PackedParams,
    /// Raw packed tensor bytes keyed by GGUF loader key, reused across all graph uploads.
    packed_bytes_cache: HashMap<String, Arc<[u8]>>,
    /// Model configuration.
    cfg: Qwen35Config,
    /// Device every graph is compiled for.
    device: Device,
    /// Batch size the graphs are built for.
    batch: usize,
    /// Maximum sequence length the graphs are sized for.
    max_seq: usize,
    /// Forwarded to the predict-graph builder
    /// (presumably limits predict output to the last token — confirm).
    last_logits_only: bool,
    /// Forwarded to the predict-graph builder to include the MTP head.
    enable_mtp: bool,
    /// Graph-builder flag forwarded to the prefill and decode HIR builders.
    mtp_logits_path: bool,
    /// Graph-builder flag forwarded to the prefill, predict and decode HIR builders.
    fast_mtp: bool,
    /// Graph-builder flag forwarded to the HIR builders and to decode-output handling.
    fast_greedy_lm_head: bool,
    /// Loaded model weights (cloned into every HIR build).
    weights: Qwen35Weights,
    /// Path of the weights file.
    weights_path: PathBuf,
    /// Source of packed tensor bytes; packed uploads fail when this is `None`.
    gguf_loader: Option<GgufLoader>,
    /// Decode KV / recurrent cache; cleared with `reset_decode_cache`.
    decode_cache: Option<Qwen35DecodeCache>,
    /// Forwarded to the prefill / predict HIR builders.
    runtime_mrope: bool,
    /// Per-token MRoPE section positions, when set (NOTE(review): confirm producer/consumer).
    mrope_section_positions: Option<Vec<[usize; 4]>>,
    /// When set, compilation switches to AOT mode and goes through this cache.
    aot_cache: Option<AotCache>,
    /// Dynamic prefill mode flag.
    dynamic_prefill: bool,
    /// When true, decode steps use the dynamic decode template.
    dynamic_decode: bool,
    /// mmproj vision encoder, when loaded at build time.
    vision_encoder: Option<Qwen35VisionEncoder>,
    /// Path to the mmproj GGUF, when configured.
    mmproj_path: Option<PathBuf>,
    /// Tier-1 compile profile for prefill / predict graphs.
    prefill_profile: CompileProfile,
    /// Tier-1 compile profile for decode graphs.
    decode_profile: CompileProfile,
    /// TIDE-style per-layer expert offload.
    moe_offload: Option<MoeOffloadState>,
    /// Host expert stacks (migration source; F32 MoE only).
    moe_store: Option<MoeExpertStore>,
    /// Decode-step counter for MoE refresh scheduling (TIDE τ).
    moe_refresh_step: usize,
}
1020
/// Logits returned by a prefill pass: trunk logits plus optional MTP logits.
#[derive(Debug, Clone)]
pub struct Qwen35PrefillSeed {
    /// Logits from the main model trunk (NOTE(review): layout assumed to be flat `f32` rows — confirm).
    pub trunk_logits: Vec<f32>,
    /// MTP-head logits, when the MTP path is active.
    pub mtp_logits: Option<Vec<f32>>,
}

/// Prefill result with the vocabulary width needed to interpret `logits`.
#[derive(Debug, Clone)]
pub struct Qwen35PrefillOutput {
    /// Flat logits buffer (presumably `vocab_size` values per row — confirm).
    pub logits: Vec<f32>,
    /// MTP-head logits, when the MTP path is active.
    pub mtp_logits: Option<Vec<f32>>,
    /// Vocabulary width associated with `logits`.
    pub vocab_size: usize,
}
1033
1034impl Qwen35Runner {
1035    pub fn builder() -> Qwen35RunnerBuilder {
1036        Qwen35RunnerBuilder::default()
1037    }
1038
1039    /// Whether an mmproj vision encoder (or its weights) is wired up,
1040    /// allowing [`Self::generate_multimodal`] to splice image embeddings
1041    /// into the prefill. Backs the `LmRunner::supports_multimodal` hook.
1042    pub fn has_mmproj(&self) -> bool {
1043        self.mmproj_path.is_some() || self.vision_encoder.is_some()
1044    }
1045
1046    /// Apply runner AOT settings and build compile options for a dynamic specialize path.
1047    fn execution_config(&self, config: ModelExecutionConfig) -> ModelExecutionConfig {
1048        if self.aot_cache.is_some() {
1049            config.with_compilation_mode(CompilationMode::Aot)
1050        } else {
1051            config
1052        }
1053    }
1054
1055    pub fn prefill_profile(&self) -> &CompileProfile {
1056        &self.prefill_profile
1057    }
1058
1059    pub fn decode_profile(&self) -> &CompileProfile {
1060        &self.decode_profile
1061    }
1062
1063    /// MoE offload state when enabled at build time.
1064    pub fn moe_offload(&self) -> Option<&MoeOffloadState> {
1065        self.moe_offload.as_ref()
1066    }
1067
1068    /// TIDE `enable_predictive_expert_offload` return payload (when offload is active).
1069    pub fn predictive_offload_info(&self) -> Option<&rlx_llada2::tide::PredictiveOffloadInfo> {
1070        self.moe_offload.as_ref().map(|m| &m.info)
1071    }
1072
1073    /// TIDE `get_offload_stats()` — pool promotions/demotions + last-forward residency (CPU).
1074    pub fn get_offload_stats(
1075        &self,
1076        residency: Option<&rlx_runtime::MoeResidencyStats>,
1077    ) -> rlx_llada2::tide::TideOffloadStats {
1078        self.moe_offload
1079            .as_ref()
1080            .map(|m| m.tide_offload_stats(residency))
1081            .unwrap_or_default()
1082    }
1083
1084    pub fn jump_steps(&self) -> usize {
1085        self.moe_offload.as_ref().map(|m| m.jump_steps).unwrap_or(1)
1086    }
1087
1088    pub fn predictive_offload_enabled(&self) -> bool {
1089        self.moe_offload
1090            .as_ref()
1091            .is_some_and(|m| m.predictive_enabled)
1092    }
1093
1094    pub fn moe_offload_mut(&mut self) -> Option<&mut MoeOffloadState> {
1095        self.moe_offload.as_mut()
1096    }
1097
1098    /// MoE refresh step index (TIDE `step` in block denoise / decode loop).
1099    pub fn moe_refresh_step(&self) -> usize {
1100        self.moe_refresh_step
1101    }
1102
1103    /// Enable TopK capture on a compiled graph (CPU; call once after compile).
1104    pub fn enable_moe_topk_on(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1105        if self.moe_offload.is_some() {
1106            compiled.enable_moe_topk_capture(self.cfg.num_experts);
1107        }
1108    }
1109
1110    /// Push per-layer TIDE residency masks into the compiled graph.
1111    pub fn sync_moe_residency(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1112        if let Some(mo) = &self.moe_offload {
1113            push_moe_residency(compiled, &mo.per_layer_resident_masks());
1114        }
1115    }
1116
1117    #[allow(dead_code)]
1118    fn moe_prepare_forward(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1119        self.bind_moe_host_weights();
1120        if self.moe_offload.is_some() {
1121            compiled.enable_moe_topk_capture(self.cfg.num_experts);
1122            self.sync_moe_residency(compiled);
1123        }
1124    }
1125
1126    #[allow(dead_code)]
1127    fn moe_finish_forward(
1128        &mut self,
1129        compiled: &mut rlx_runtime::CompiledGraph,
1130        denoise_step: usize,
1131        is_prefill_block: bool,
1132    ) -> bool {
1133        let Some(layers) = compiled.take_moe_topk_capture() else {
1134            return false;
1135        };
1136        let store = self.moe_store.clone();
1137        let Some(mo) = self.moe_offload.as_mut() else {
1138            return false;
1139        };
1140        let refreshed = if let Some(store) = store.as_ref() {
1141            mo.refresh_from_capture_with_store(store, &layers, denoise_step, is_prefill_block)
1142        } else {
1143            mo.refresh_from_capture(&layers, denoise_step, is_prefill_block)
1144        };
1145        if refreshed {
1146            push_moe_residency(compiled, &mo.per_layer_resident_masks());
1147        }
1148        refreshed
1149    }
1150
1151    /// Install per-expert host pointers for CPU GroupedMatMul fallback (TIDE).
1152    fn bind_moe_host_weights(&self) {
1153        if self.moe_offload.is_none() {
1154            rlx_cpu::moe_residency::bind_host_weights(None);
1155            return;
1156        }
1157        if let Some(store) = &self.moe_store {
1158            rlx_cpu::moe_residency::bind_host_weights(Some(moe_host_bind_from_store(store)));
1159        } else {
1160            rlx_cpu::moe_residency::bind_host_weights(None);
1161        }
1162    }
1163
1164    /// After forward: refresh pools from captured TopK and update graph mask.
1165    pub fn moe_offload_after_forward(&mut self, compiled: &mut rlx_runtime::CompiledGraph) -> bool {
1166        let Some(mo) = self.moe_offload.as_mut() else {
1167            return false;
1168        };
1169        let Some(layers) = compiled.take_moe_topk_capture() else {
1170            return false;
1171        };
1172        let refreshed = mo.refresh_from_capture(&layers, self.moe_refresh_step, false);
1173        if refreshed {
1174            self.sync_moe_residency(compiled);
1175        }
1176        self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1177        refreshed
1178    }
1179
1180    /// Manual refresh from flat expert indices (single shared indices for all layers).
1181    pub fn moe_refresh_after_forward(&mut self, expert_idx: &[u32]) -> bool {
1182        let Some(mo) = self.moe_offload.as_mut() else {
1183            return false;
1184        };
1185        let refresh = mo.pools[0].should_refresh(
1186            rlx_runtime::MoEExecMode::Reuse,
1187            self.moe_refresh_step,
1188            false,
1189        );
1190        if refresh {
1191            for pool in &mut mo.pools {
1192                pool.refresh_from_indices(expert_idx);
1193            }
1194        }
1195        self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1196        refresh
1197    }
1198
1199    /// Override tier-1 profiles after build (e.g. tests).
1200    pub fn with_compile_profiles(
1201        mut self,
1202        prefill: CompileProfile,
1203        decode: CompileProfile,
1204    ) -> Self {
1205        self.prefill_profile = prefill;
1206        self.decode_profile = decode;
1207        self
1208    }
1209
1210    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
1211        let profile = if decode {
1212            &self.decode_profile
1213        } else {
1214            &self.prefill_profile
1215        };
1216        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
1217    }
1218
1219    fn dyn_compile_options(&self, config: &ModelExecutionConfig) -> CompileOptions {
1220        let decode = matches!(config.preset, ExecutionPreset::Qwen35Decode);
1221        let mut opts = self.profile_compile_options(decode);
1222        opts.kernel_dispatch = config.component().kernel_dispatch;
1223        opts.dim_binding(config.dim_binding())
1224    }
1225
1226    fn bucketed_decode_compile_options(&self) -> CompileOptions {
1227        self.profile_compile_options(true)
1228    }
1229
1230    /// Compile a tier-0 prefill [`rlx_flow::BuiltModel`] through [`Qwen35CompileCache`].
1231    pub fn compile_prefill_built(
1232        &self,
1233        cache: &mut Qwen35CompileCache,
1234        built: rlx_flow::BuiltModel,
1235        batch: usize,
1236        seq: usize,
1237    ) -> Result<rlx_runtime::CompiledGraph> {
1238        let config = self.execution_config(prefill_config(batch, seq));
1239        let opts = self.dyn_compile_options(&config);
1240        cache.compile_built(built, &config, &opts)
1241    }
1242
1243    pub fn cfg(&self) -> &Qwen35Config {
1244        &self.cfg
1245    }
1246    pub fn device(&self) -> Device {
1247        self.device
1248    }
1249    pub fn max_seq(&self) -> usize {
1250        self.max_seq
1251    }
1252    pub fn lm_vocab_size(&self) -> usize {
1253        self.weights.lm_vocab_size(&self.cfg)
1254    }
1255
1256    /// True when an mmproj vision encoder was loaded at build time.
1257    pub fn has_vision(&self) -> bool {
1258        self.vision_encoder.is_some()
1259    }
1260
1261    /// Optional path to the mmproj GGUF (if configured).
1262    pub fn mmproj_path(&self) -> Option<&std::path::Path> {
1263        self.mmproj_path.as_deref()
1264    }
1265
1266    fn effective_vocab(&self, graph_vocab: usize) -> usize {
1267        self.lm_vocab_size().min(graph_vocab)
1268    }
1269
1270    fn compile_hir_for_config(
1271        &mut self,
1272        config: ModelExecutionConfig,
1273        aot_disk_key: &str,
1274        hir: rlx_ir::hir::HirModule,
1275    ) -> Result<rlx_runtime::CompiledGraph> {
1276        let config = self.execution_config(config);
1277        let opts = self.dyn_compile_options(&config);
1278        if let Some(aot) = self.aot_cache.as_ref() {
1279            return Ok(aot.compile_hir_cached(aot_disk_key, self.device, hir, &opts)?);
1280        }
1281        // Per-`past_seq` decode HIR is concrete (not symbolic). Sharing one
1282        // `ModelCompilePipeline` template across variants reuses the wrong graph.
1283        if config.preset == ExecutionPreset::Qwen35Decode {
1284            return Ok(Session::new(self.device).compile_hir_with(hir, &opts)?);
1285        }
1286        let cache = self
1287            .predict_hir_cache
1288            .get_or_insert_with(|| make_qwen35_dyn_cache(self.device, 64, None));
1289        let hir = hir;
1290        get_or_specialize_hir_with_options(cache, &config, || hir.clone(), &opts, |_| Ok(()))?;
1291        if self.device == Device::Cpu {
1292            let compiled = get_or_specialize_hir_with_options(
1293                cache,
1294                &config,
1295                || hir.clone(),
1296                &opts,
1297                |_| Ok(()),
1298            )?;
1299            return Ok(compiled.clone());
1300        }
1301        Ok(Session::new(self.device).compile_hir_with(hir, &opts)?)
1302    }
1303
1304    fn lm_loader(&self) -> Option<&GgufLoader> {
1305        self.gguf_loader.as_ref()
1306    }
1307
1308    fn argmax_batch_from_hidden(&self, hidden: &[f32]) -> Result<Vec<u32>> {
1309        let n_embd = self.cfg.hidden_size;
1310        let mut toks = Vec::with_capacity(self.batch);
1311        for b in 0..self.batch {
1312            let h = &hidden[b * n_embd..(b + 1) * n_embd];
1313            let (idx, _) = greedy_lm_head_argmax(&self.weights, &self.cfg, h, self.lm_loader())?;
1314            toks.push(idx);
1315        }
1316        Ok(toks)
1317    }
1318
1319    fn sample_batch_from_hidden(&self, hidden: &[f32], opts: SampleOpts) -> Result<Vec<u32>> {
1320        let n_embd = self.cfg.hidden_size;
1321        let mut toks = Vec::with_capacity(self.batch);
1322        for b in 0..self.batch {
1323            let h = &hidden[b * n_embd..(b + 1) * n_embd];
1324            toks.push(sample_lm_head_from_hidden(
1325                &self.weights,
1326                &self.cfg,
1327                h,
1328                self.lm_loader(),
1329                opts,
1330            )?);
1331        }
1332        Ok(toks)
1333    }
1334
    /// Run one decode step for `tokens` and return whatever
    /// `advance_cache_from_decode_outputs` produces from the graph outputs.
    ///
    /// Routing:
    /// - dynamic decode enabled -> [`Self::decode_step_dynamic_raw`];
    /// - a decode bucket covers `cache.past_seq` -> `decode_step_bucketed_raw`;
    /// - otherwise a concrete decode graph for exactly `cache.past_seq` is used, compiled
    ///   on first use and kept in `decode_graphs` under that length (this function
    ///   never evicts entries).
    ///
    /// On the concrete path MoE offload is handled inline. Host weights are bound; when
    /// offload is active, TopK capture and the current residency masks are installed
    /// before the run, and the capture refreshes the pools afterwards. `moe_refresh_step`
    /// is advanced on every call through this path, with or without MoE.
    ///
    /// NOTE(review): whether the first returned vector holds logits or hidden states
    /// depends on the graph flags (cf. `trunk_to_logits`) — confirm with callers.
    fn decode_step_trunk_raw(
        &mut self,
        cache: &mut Qwen35DecodeCache,
        tokens: &[u32],
        generated_per_row: &[usize],
    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
        if self.dynamic_decode {
            return self.decode_step_dynamic_raw(cache, tokens, generated_per_row);
        }
        let past_seq = cache.past_seq;
        // MRoPE cos/sin slices at position `past_seq`.
        let head_half = self.cfg.key_length / 2;
        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
        let use_bucket = self
            .decode_compile_cache
            .as_ref()
            .and_then(|c| c.bucket_for(past_seq as u64))
            .is_some();
        if use_bucket {
            self.decode_step_bucketed_raw(cache, tokens, generated_per_row, &cos, &sin)
        } else {
            let feeds_owned = decode_step_feeds(
                &self.cfg,
                cache,
                tokens,
                &cos,
                &sin,
                None,
                generated_per_row,
            )?;
            // Borrowed `(name, data)` view of the owned feeds, as `run` expects.
            let feeds: Vec<(&str, &[f32])> = feeds_owned
                .iter()
                .map(|(k, v)| (k.as_str(), v.as_slice()))
                .collect();
            // First time this exact `past_seq` is seen: build the concrete HIR, compile
            // it, upload F32 + packed params, and cache the graph.
            if !self.decode_graphs.contains_key(&past_seq) {
                let (hir, params, packed) = build_qwen35_decode_hir_ext(
                    &self.cfg,
                    self.weights.clone(),
                    self.batch,
                    past_seq,
                    false,
                    self.mtp_logits_path,
                    self.fast_mtp,
                    self.fast_greedy_lm_head,
                )?;
                let mut compiled = self.compile_hir_for_config(
                    decode_config(self.batch, past_seq),
                    &format!("decode_{past_seq}"),
                    hir,
                )?;
                for (name, data) in &params {
                    compiled.set_param(name, data);
                }
                upload_packed_opt(
                    &mut compiled,
                    self.gguf_loader.as_mut(),
                    &packed,
                    &mut self.packed_bytes_cache,
                )?;
                self.decode_graphs.insert(past_seq, compiled);
            }
            // Snapshot MoE state up front so the graph can be borrowed mutably below
            // without also borrowing `self.moe_offload`.
            let step = self.moe_refresh_step;
            let has_moe = self.moe_offload.is_some();
            let num_experts = self.cfg.num_experts;
            let moe_masks = self
                .moe_offload
                .as_ref()
                .map(|m| m.per_layer_resident_masks());
            self.bind_moe_host_weights();
            // The graph for `past_seq` was inserted above if missing, so the lookups
            // below cannot fail.
            let outs = {
                let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
                if has_moe {
                    compiled.enable_moe_topk_capture(num_experts);
                    if let Some(layers) = &moe_masks {
                        push_moe_residency(compiled, layers);
                    }
                }
                compiled.run(&feeds)
            };
            // Consume the TopK capture and refresh the expert pools; when the refresh
            // reports a change and a host store exists, re-apply the store to the graph.
            if has_moe {
                let layers = {
                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
                    compiled.take_moe_topk_capture()
                };
                if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
                    let store = self.moe_store.as_ref();
                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
                    if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
                        if let Some(store) = self.moe_store.as_ref() {
                            store.apply_to_compiled(compiled);
                        }
                    }
                }
            }
            self.moe_refresh_step = step.saturating_add(1);
            advance_cache_from_decode_outputs(
                &self.cfg,
                cache,
                outs,
                None,
                self.mtp_logits_path,
                false,
                self.fast_greedy_lm_head,
            )
        }
    }
1440
    /// Dynamic-template decode step. It specializes (or reuses) the dynamic decode graph
    /// for `cache.past_seq`, runs it, and returns the result of
    /// `advance_cache_from_decode_outputs` on the outputs.
    ///
    /// Unlike the concrete path in `decode_step_trunk_raw`, this path does no MoE
    /// offload bookkeeping and does not advance `moe_refresh_step`.
    ///
    /// # Errors
    /// Fails when dynamic decode has no cache, when specialization or parameter upload
    /// fails, or when feed construction / cache advancement fails.
    ///
    /// # Panics
    /// Panics if building the dynamic decode HIR fails inside the builder closure,
    /// which unwraps it with `expect`.
    fn decode_step_dynamic_raw(
        &mut self,
        cache: &mut Qwen35DecodeCache,
        tokens: &[u32],
        generated_per_row: &[usize],
    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
        let past_seq = cache.past_seq;
        let head_half = self.cfg.key_length / 2;
        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
        let feeds_owned = decode_step_feeds(
            &self.cfg,
            cache,
            tokens,
            &cos,
            &sin,
            None,
            generated_per_row,
        )?;
        let feeds: Vec<(&str, &[f32])> = feeds_owned
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_slice()))
            .collect();

        let config = self.execution_config(decode_config(self.batch, past_seq));
        let compile_opts = self.dyn_compile_options(&config);
        let dyn_cache = self
            .decode_dynamic_cache
            .as_mut()
            .ok_or_else(|| anyhow!("dynamic decode without cache"))?;
        // Copy / clone what the builder closure needs and take disjoint field borrows
        // for the setup closure, so neither closure borrows `self` as a whole while
        // `dyn_cache` is mutably borrowed.
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let max_seq = self.max_seq;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        let batch = self.batch;
        let decode_params = &self.decode_dynamic_params;
        let decode_packed = &self.decode_dynamic_packed;
        let gguf_loader = &mut self.gguf_loader;
        let packed_bytes_cache = &mut self.packed_bytes_cache;
        let compiled = get_or_specialize_hir_with_options(
            dyn_cache,
            &config,
            || {
                build_qwen35_decode_hir_dynamic_ext(
                    &cfg,
                    weights,
                    batch,
                    max_seq,
                    mtp_logits_path,
                    fast_mtp,
                    fast_greedy,
                )
                .expect("dynamic decode HIR")
                .0
            },
            &compile_opts,
            // Parameter setup for a specialization: F32 params then packed GGUF bytes
            // (presumably only invoked when a new specialization is compiled — confirm).
            |c| {
                for (name, data) in decode_params {
                    c.set_param(name, data);
                }
                upload_packed_opt(c, gguf_loader.as_mut(), decode_packed, packed_bytes_cache)
            },
        )?;
        let outs = compiled.run(&feeds);
        advance_cache_from_decode_outputs(
            &self.cfg,
            cache,
            outs,
            None,
            self.mtp_logits_path,
            false,
            self.fast_greedy_lm_head,
        )
    }
1516
1517    fn trunk_to_logits(&self, trunk: Vec<f32>, is_hidden: bool) -> Result<Vec<f32>> {
1518        if !is_hidden {
1519            return Ok(trunk);
1520        }
1521        let n_embd = self.cfg.hidden_size;
1522        let vocab = self.lm_vocab_size();
1523        let mut logits = Vec::with_capacity(self.batch * vocab);
1524        for b in 0..self.batch {
1525            let h = &trunk[b * n_embd..(b + 1) * n_embd];
1526            logits.extend(lm_head_logits_row(
1527                &self.weights,
1528                &self.cfg,
1529                h,
1530                self.lm_loader(),
1531            )?);
1532        }
1533        Ok(logits)
1534    }
1535
    /// Compile decode HIR for `key`'s bucket (if needed) and upload packed GGUF params once.
    ///
    /// Returns the bucket's upper bound (as reported by `ensure_hir_with_params`).
    ///
    /// # Errors
    /// Fails when bucketed decode has no cache, when `key` falls outside every decode
    /// bucket, or when the packed upload fails.
    ///
    /// # Panics
    /// Panics if building the decode HIR fails inside the builder closure, which
    /// unwraps it with `expect`.
    fn ensure_decode_bucket_compiled(&mut self, key: u64) -> Result<usize> {
        let decode_opts = self.bucketed_decode_compile_options();
        let cache_mut = self
            .decode_compile_cache
            .as_mut()
            .ok_or_else(|| anyhow!("bucketed decode without cache"))?;
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let batch = self.batch;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        // The builder closure can only hand back `(hir, params)`; the packed table is
        // stashed here so it can be uploaded after compilation. It stays `None` when
        // the closure did not run (presumably a cache hit for this bucket — confirm).
        let packed_slot = RefCell::new(None::<PackedParams>);
        let (upper, compiled) = cache_mut
            .ensure_hir_with_params(
                key,
                |upper| {
                    let (hir, params, packed) = build_qwen35_decode_hir_ext(
                        &cfg,
                        weights.clone(),
                        batch,
                        upper as usize,
                        true,
                        mtp_logits_path,
                        fast_mtp,
                        fast_greedy,
                    )
                    .expect("qwen35 decode HIR");
                    *packed_slot.borrow_mut() = Some(packed);
                    (hir, params)
                },
                &decode_opts,
            )
            .ok_or_else(|| anyhow!("past_seq {key} outside decode buckets"))?;
        // Upload packed bytes only for a graph whose HIR was just built above.
        if let Some(packed) = packed_slot.take() {
            if !packed.is_empty() {
                upload_packed_opt(
                    compiled,
                    self.gguf_loader.as_mut(),
                    &packed,
                    &mut self.packed_bytes_cache,
                )?;
            }
        }
        Ok(upper as usize)
    }
1583
1584    /// Pre-compile every decode bucket and upload packed weights once.
1585    fn warm_decode_graphs(&mut self) -> Result<()> {
1586        let upper_bounds: Vec<usize> = match self.decode_compile_cache.as_ref() {
1587            Some(cache) => cache.buckets().map(|r| (r.end - 1) as usize).collect(),
1588            None => return Ok(()),
1589        };
1590        let t = Instant::now();
1591        let total = upper_bounds.len();
1592        for upper in upper_bounds {
1593            self.ensure_decode_bucket_compiled(upper as u64)?;
1594        }
1595        if total > 0 {
1596            eprintln!(
1597                "[qwen35] warmed {total} decode bucket(s) in {:.2?}",
1598                t.elapsed()
1599            );
1600        }
1601        Ok(())
1602    }
1603
1604    /// Pre-compile the predict (prefill logits) graph at build time.
1605    fn warm_predict_graph(&mut self) -> Result<()> {
1606        if self.compiled.is_some() {
1607            return Ok(());
1608        }
1609        let t = Instant::now();
1610        self.ensure_predict_compiled()?;
1611        eprintln!("[qwen35] warmed predict graph in {:.2?}", t.elapsed());
1612        Ok(())
1613    }
1614
1615    /// Clear the decode KV / recurrent cache (e.g. before a fresh prefill in spec decode).
1616    pub fn reset_decode_cache(&mut self) {
1617        self.decode_cache = None;
1618    }
1619
1620    /// Snapshot the current decode cache (for two-phase MTP draft propose).
1621    pub fn decode_cache_checkpoint(&self) -> Option<Qwen35DecodeCache> {
1622        self.decode_cache.clone()
1623    }
1624
1625    /// Restore a decode cache snapshot (discards uncommitted draft steps).
1626    pub fn restore_decode_cache(&mut self, cache: Option<Qwen35DecodeCache>) {
1627        self.decode_cache = cache;
1628    }
1629
1630    /// Advance the decode cache by `tokens` without returning logits (MTP commit path).
1631    pub fn commit_decode_tokens(&mut self, tokens: &[u32]) -> Result<()> {
1632        for &tok in tokens {
1633            let _ = self.decode_get_logits(tok)?;
1634        }
1635        Ok(())
1636    }
1637
1638    fn ensure_predict_compiled(&mut self) -> Result<()> {
1639        if self.compiled.is_some() {
1640            return Ok(());
1641        }
1642        // RLX_QWEN35_DEBUG_LAYERS=1 makes the predict graph emit every
1643        // trunk layer's hidden state as an extra output (gathered at
1644        // last_token_idx → shape `[batch, n_embd]` per layer). Combined
1645        // with the per-output stats dump in `predict_logits_batch` this
1646        // is the bisection harness for "all-zero logits" symptoms: the
1647        // log shows which layer first emits zero/NaN. Requires
1648        // `last_logits_only=true` (the assertion is in the builder).
1649        let debug_layers = std::env::var("RLX_QWEN35_DEBUG_LAYERS")
1650            .map(|v| v == "1")
1651            .unwrap_or(false);
1652        let t = Instant::now();
1653        let (hir, params, packed) = build_qwen35_hir_sized_ext(
1654            &self.cfg,
1655            self.weights.clone(),
1656            self.batch,
1657            self.max_seq,
1658            true,
1659            self.last_logits_only,
1660            self.enable_mtp,
1661            false,
1662            None,
1663            self.runtime_mrope,
1664            self.fast_mtp,
1665            false,
1666            debug_layers,
1667        )?;
1668        eprintln!(
1669            "[qwen35] built predict IR (lazy) in {:.2?} (params={}, packed={})",
1670            t.elapsed(),
1671            params.len(),
1672            packed.len(),
1673        );
1674        let t = Instant::now();
1675        let mut compiled = self.compile_hir_for_config(
1676            prefill_config(self.batch, self.max_seq),
1677            "predict_logits",
1678            hir,
1679        )?;
1680        eprintln!(
1681            "[qwen35] compiled predict graph (lazy) in {:.2?}",
1682            t.elapsed()
1683        );
1684        let t = Instant::now();
1685        for (name, data) in &params {
1686            compiled.set_param(name, data);
1687        }
1688        if !packed.is_empty() {
1689            upload_packed_opt(
1690                &mut compiled,
1691                self.gguf_loader.as_mut(),
1692                &packed,
1693                &mut self.packed_bytes_cache,
1694            )?;
1695        }
1696        eprintln!(
1697            "[qwen35] uploaded predict {} F32 + {} packed params in {:.2?}",
1698            params.len(),
1699            packed.len(),
1700            t.elapsed(),
1701        );
1702        self.compiled = Some(compiled);
1703        Ok(())
1704    }
1705
1706    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillOutput> {
1707        let out = self
1708            .predict_logits_batch(&[prompt_ids.to_vec()])
1709            .map(|v| v.into_iter().next().unwrap())?;
1710        // Preflight guard: a degenerate all-zero (or all-equal) logits
1711        // tensor is the signature of a broken forward pass — a buggy
1712        // op writing zeros, a packed-K-quant dispatch mis-routed, an
1713        // arena slot being read from the wrong offset, etc. Fail fast
1714        // with a clear error rather than silently returning argmax =
1715        // vocab_size − 1.
1716        if !out.logits.is_empty() {
1717            let mut min = f32::INFINITY;
1718            let mut max = f32::NEG_INFINITY;
1719            for &v in &out.logits {
1720                if v.is_finite() {
1721                    if v < min {
1722                        min = v;
1723                    }
1724                    if v > max {
1725                        max = v;
1726                    }
1727                }
1728            }
1729            if (max - min).abs() < 1e-6 {
1730                bail!(
1731                    "qwen35: predict_logits returned degenerate output \
1732                     (min={min}, max={max}) — the forward pass produced \
1733                     all-equal logits, which indicates a broken op or a \
1734                     mis-routed weight tensor in the trunk. Re-run with \
1735                     RUST_LOG=debug to capture the offending layer."
1736                );
1737            }
1738        }
1739        Ok(out)
1740    }
1741
1742    /// Prefill forward for `batch` prompts (must equal `self.batch()`).
1743    /// Each row may have a different length; all are zero-padded to
1744    /// `max_seq` in the compiled graph.
1745    pub fn predict_logits_batch(
1746        &mut self,
1747        batch_prompts: &[Vec<u32>],
1748    ) -> Result<Vec<Qwen35PrefillOutput>> {
1749        if batch_prompts.len() != self.batch {
1750            bail!(
1751                "qwen35: expected {} prompts (batch={}), got {}",
1752                self.batch,
1753                self.batch,
1754                batch_prompts.len()
1755            );
1756        }
1757        let max_prompt = batch_prompts.iter().map(|p| p.len()).max().unwrap_or(0);
1758        if max_prompt > self.max_seq {
1759            bail!(
1760                "qwen35: prompt length {max_prompt} exceeds compiled max_seq={}",
1761                self.max_seq
1762            );
1763        }
1764        let padded = pack_input_ids(batch_prompts, self.max_seq)?;
1765        let prompt_lens: Vec<usize> = batch_prompts.iter().map(|p| p.len()).collect();
1766        let last_idx = last_token_indices(&prompt_lens);
1767
1768        let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", padded.as_slice())];
1769        if self.last_logits_only {
1770            feeds.push(("last_token_idx", last_idx.as_slice()));
1771        }
1772        let rope_owned = self.mrope_prefill_rope_feeds(max_prompt);
1773        for (name, data) in &rope_owned {
1774            feeds.push((name.as_str(), data.as_slice()));
1775        }
1776        self.ensure_predict_compiled()?;
1777        let outs = self.compiled.as_mut().unwrap().run(&feeds);
1778        if outs.is_empty() {
1779            bail!("qwen35: forward produced no outputs");
1780        }
1781        // RLX_QWEN35_DEBUG_LAYERS=1 enabled extra per-layer outputs in
1782        // the predict graph (see `ensure_predict_compiled`). Dump stats
1783        // here so the next debugger can locate which layer first emits
1784        // zero/NaN values without re-running the entire 18-min cycle.
1785        if std::env::var("RLX_QWEN35_DEBUG_LAYERS").as_deref() == Ok("1") {
1786            // Output layout (when debug_layers is on):
1787            //   outs[0] = logits           [batch, vocab]
1788            //   outs[1..] = trunk_layer_hiddens (one per decoder layer)
1789            //               each [batch, n_embd]
1790            let n_layers = self.cfg.num_hidden_layers - self.cfg.nextn_predict_layers;
1791            for i in 0..outs.len() {
1792                let v = &outs[i];
1793                let mut min = f32::INFINITY;
1794                let mut max = f32::NEG_INFINITY;
1795                let mut sum = 0.0f64;
1796                let mut nan = 0usize;
1797                let mut nnz = 0usize;
1798                for &x in v {
1799                    if x.is_nan() {
1800                        nan += 1;
1801                        continue;
1802                    }
1803                    sum += x as f64;
1804                    if x < min {
1805                        min = x;
1806                    }
1807                    if x > max {
1808                        max = x;
1809                    }
1810                    if x != 0.0 {
1811                        nnz += 1;
1812                    }
1813                }
1814                let mean = sum / v.len().max(1) as f64;
1815                let label = if i == 0 {
1816                    "logits".to_string()
1817                } else if i - 1 < n_layers {
1818                    format!("layer_{:02}", i - 1)
1819                } else {
1820                    format!("extra_{:02}", i - 1 - n_layers)
1821                };
1822                eprintln!(
1823                    "[qwen35][debug-layers] {label}: len={} nnz={} nan={} min={} max={} mean={:.6}",
1824                    v.len(),
1825                    nnz,
1826                    nan,
1827                    min,
1828                    max,
1829                    mean
1830                );
1831            }
1832        }
1833        let vocab_size = if self.last_logits_only {
1834            outs[0].len() / self.batch
1835        } else {
1836            outs[0].len() / (self.batch * self.max_seq)
1837        };
1838        let sample_vocab = self.effective_vocab(vocab_size);
1839        let mtp_logits = if self.enable_mtp && outs.len() >= 2 {
1840            Some(outs[1].clone())
1841        } else {
1842            None
1843        };
1844        let mut per_batch = Vec::with_capacity(self.batch);
1845        for b in 0..self.batch {
1846            let start = b * vocab_size;
1847            let mut row = outs[0][start..start + vocab_size].to_vec();
1848            row.truncate(sample_vocab);
1849            per_batch.push(Qwen35PrefillOutput {
1850                logits: row,
1851                mtp_logits: mtp_logits.as_ref().map(|m| {
1852                    let m_vocab = m.len() / self.batch.max(1);
1853                    let mut mv = m[b * m_vocab..(b + 1) * m_vocab].to_vec();
1854                    mv.truncate(sample_vocab);
1855                    mv
1856                }),
1857                vocab_size: sample_vocab,
1858            });
1859        }
1860        Ok(per_batch)
1861    }
1862
1863    /// Greedy autoregressive generation with decode-state caching (batch=1).
1864    pub fn generate<F>(&mut self, prompt_ids: &[u32], n_new: usize, on_token: F) -> Result<Vec<u32>>
1865    where
1866        F: FnMut(u32) -> bool,
1867    {
1868        self.generate_with_opts(prompt_ids, n_new, SampleOpts::greedy(), on_token)
1869    }
1870
1871    /// Autoregressive generation with sampling options (batch=1).
1872    pub fn generate_with_opts<F>(
1873        &mut self,
1874        prompt_ids: &[u32],
1875        n_new: usize,
1876        opts: SampleOpts,
1877        mut on_token: F,
1878    ) -> Result<Vec<u32>>
1879    where
1880        F: FnMut(u32) -> bool,
1881    {
1882        if self.batch != 1 {
1883            bail!(
1884                "qwen35::generate: runner batch={} — use generate_batch() instead",
1885                self.batch
1886            );
1887        }
1888        let generated = self
1889            .generate_batch_with_opts(&[prompt_ids.to_vec()], n_new, None, opts, |_, tok| {
1890                on_token(tok)
1891            })?
1892            .into_iter()
1893            .next()
1894            .unwrap_or_default();
1895        Ok(generated)
1896    }
1897
1898    /// Batched greedy generation. `prompts.len()` must equal `self.batch()`.
1899    pub fn generate_batch<F>(
1900        &mut self,
1901        prompts: &[Vec<u32>],
1902        n_new: usize,
1903        on_token: F,
1904    ) -> Result<Vec<Vec<u32>>>
1905    where
1906        F: FnMut(usize, u32) -> bool,
1907    {
1908        self.generate_batch_with_opts(prompts, n_new, None, SampleOpts::greedy(), on_token)
1909    }
1910
1911    /// Batched generation with per-row token limits and sampling.
1912    ///
1913    /// `n_new_per_row`: optional per-row max new tokens (defaults to `n_new`).
1914    pub fn generate_batch_with_opts<F>(
1915        &mut self,
1916        prompts: &[Vec<u32>],
1917        n_new: usize,
1918        n_new_per_row: Option<&[usize]>,
1919        opts: SampleOpts,
1920        mut on_token: F,
1921    ) -> Result<Vec<Vec<u32>>>
1922    where
1923        F: FnMut(usize, u32) -> bool,
1924    {
1925        if prompts.is_empty() {
1926            bail!("qwen35::generate_batch: prompts must be non-empty");
1927        }
1928        if prompts.len() != self.batch {
1929            bail!(
1930                "qwen35::generate_batch: expected {} prompts, got {}",
1931                self.batch,
1932                prompts.len()
1933            );
1934        }
1935        if let Some(limits) = n_new_per_row {
1936            if limits.len() != self.batch {
1937                bail!(
1938                    "qwen35::generate_batch: n_new_per_row len {} != batch {}",
1939                    limits.len(),
1940                    self.batch
1941                );
1942            }
1943        }
1944        for (i, p) in prompts.iter().enumerate() {
1945            if p.is_empty() {
1946                bail!("qwen35::generate_batch: prompt row {i} is empty");
1947            }
1948        }
1949
1950        self.decode_cache = None;
1951
1952        let _prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
1953        let row_limits: Vec<usize> = if let Some(limits) = n_new_per_row {
1954            limits.to_vec()
1955        } else {
1956            vec![n_new; self.batch]
1957        };
1958
1959        let (trunk, mut cache, _) = self.prefill_seed_decode_cache(prompts)?;
1960
1961        let mut generated: Vec<Vec<u32>> = vec![Vec::new(); self.batch];
1962        let mut active = vec![true; self.batch];
1963        let mut row_gen_count = vec![0usize; self.batch];
1964
1965        let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
1966            self.argmax_batch_from_hidden(&trunk)?
1967        } else if self.fast_greedy_lm_head
1968            && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
1969        {
1970            self.sample_batch_from_hidden(&trunk, opts)?
1971        } else {
1972            let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
1973            sample_logits_batch(&logits, self.lm_vocab_size(), self.batch, opts)
1974        };
1975        if n_new > 0 {
1976            for b in 0..self.batch {
1977                if row_gen_count[b] >= row_limits[b] {
1978                    active[b] = false;
1979                    continue;
1980                }
1981                let tok = next_tokens[b];
1982                generated[b].push(tok);
1983                row_gen_count[b] += 1;
1984                active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
1985            }
1986        }
1987
1988        for _ in 1..n_new {
1989            if !active.iter().any(|&a| a) {
1990                break;
1991            }
1992            if cache.past_seq >= self.max_seq - 1 {
1993                bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
1994            }
1995            next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen_count, opts)?;
1996            for b in 0..self.batch {
1997                if !active[b] || row_gen_count[b] >= row_limits[b] {
1998                    active[b] = false;
1999                    continue;
2000                }
2001                let tok = next_tokens[b];
2002                generated[b].push(tok);
2003                row_gen_count[b] += 1;
2004                active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
2005            }
2006            self.decode_cache = Some(cache.clone());
2007        }
2008        Ok(generated)
2009    }
2010
2011    /// Prefill and return trunk logits for the last prompt position.
2012    /// Seeds the decode cache so subsequent [`Self::decode_get_logits`] calls work.
2013    pub fn prefill_get_last_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2014        Ok(self.prefill_seed_for_decode(prompt_ids)?.trunk_logits)
2015    }
2016
2017    /// Prefill-cache seed + decode cache for spec decode / MTP draft paths.
2018    pub fn prefill_seed_for_decode(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillSeed> {
2019        if self.batch != 1 {
2020            bail!(
2021                "qwen35: prefill_seed_for_decode requires batch=1 (runner batch={})",
2022                self.batch
2023            );
2024        }
2025        let (trunk, _, mtp_logits) = self.prefill_seed_decode_cache(&[prompt_ids.to_vec()])?;
2026        Ok(Qwen35PrefillSeed {
2027            trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2028            mtp_logits,
2029        })
2030    }
2031
2032    /// Multimodal prefill: vision-encode `rgb`, splice into `prompt` at
2033    /// [`MEDIA_MARKER`](crate::MEDIA_MARKER), seed decode cache.
2034    pub fn prefill_multimodal(
2035        &mut self,
2036        prompt: &str,
2037        rgb: &[u8],
2038        img_w: usize,
2039        img_h: usize,
2040        tokenizer: Option<&std::path::Path>,
2041    ) -> Result<Qwen35PrefillSeed> {
2042        let (trunk, mtp_logits) =
2043            self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2044        Ok(Qwen35PrefillSeed {
2045            trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2046            mtp_logits,
2047        })
2048    }
2049
2050    /// Prefill from an already-assembled multimodal payload (tests / custom tokenizers).
2051    pub fn prefill_from_assembled(
2052        &mut self,
2053        prefill: MultimodalPrefill,
2054    ) -> Result<Qwen35PrefillSeed> {
2055        if self.batch != 1 {
2056            bail!(
2057                "qwen35: prefill_from_assembled requires batch=1 (runner batch={})",
2058                self.batch
2059            );
2060        }
2061        self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2062        let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2063        Ok(Qwen35PrefillSeed {
2064            trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2065            mtp_logits,
2066        })
2067    }
2068
2069    fn prefill_multimodal_trunk(
2070        &mut self,
2071        prompt: &str,
2072        rgb: &[u8],
2073        img_w: usize,
2074        img_h: usize,
2075        tokenizer: Option<&std::path::Path>,
2076    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2077        if self.batch != 1 {
2078            bail!(
2079                "qwen35: prefill_multimodal requires batch=1 (runner batch={})",
2080                self.batch
2081            );
2082        }
2083        let vision = {
2084            let enc = self
2085                .vision_encoder
2086                .as_mut()
2087                .ok_or_else(|| anyhow!("qwen35: prefill_multimodal requires .mmproj(...)"))?;
2088            enc.encode_rgb(rgb, img_w, img_h)?
2089        };
2090        if self.weights.token_embd.is_empty() {
2091            bail!("qwen35: multimodal prefill requires token_embd weights");
2092        }
2093        let weights_path = self.weights_path.as_path();
2094        if weights_path.as_os_str().is_empty() {
2095            bail!("qwen35: multimodal prefill requires a GGUF weights path (for tokenizer)");
2096        }
2097        let n_embd = self.cfg.hidden_size;
2098        let mm = MultimodalPrompt {
2099            prompt,
2100            vision: &vision,
2101        };
2102        let prefill = mm.assemble(
2103            |text| encode_prompt_auto(weights_path, tokenizer, text),
2104            &self.weights.token_embd,
2105            n_embd,
2106            0,
2107        )?;
2108        self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2109        let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2110        Ok((trunk, mtp_logits))
2111    }
2112
2113    /// Autoregressive generation from a multimodal prompt (batch=1).
2114    pub fn generate_multimodal_with_opts<F>(
2115        &mut self,
2116        prompt: &str,
2117        rgb: &[u8],
2118        img_w: usize,
2119        img_h: usize,
2120        tokenizer: Option<&std::path::Path>,
2121        n_new: usize,
2122        opts: SampleOpts,
2123        mut on_token: F,
2124    ) -> Result<Vec<u32>>
2125    where
2126        F: FnMut(u32) -> bool,
2127    {
2128        if self.batch != 1 {
2129            bail!(
2130                "qwen35: generate_multimodal requires batch=1 (runner batch={})",
2131                self.batch
2132            );
2133        }
2134        self.decode_cache = None;
2135        let (trunk, _) = self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2136        let mut cache = self
2137            .decode_cache
2138            .take()
2139            .ok_or_else(|| anyhow!("qwen35: multimodal prefill did not seed decode cache"))?;
2140        let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
2141            self.argmax_batch_from_hidden(&trunk)?
2142        } else if self.fast_greedy_lm_head
2143            && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
2144        {
2145            self.sample_batch_from_hidden(&trunk, opts)?
2146        } else {
2147            let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2148            sample_logits_batch(&logits, self.lm_vocab_size(), 1, opts)
2149        };
2150        let mut generated = Vec::new();
2151        if n_new > 0 {
2152            let tok = next_tokens[0];
2153            generated.push(tok);
2154            if !on_token(tok) {
2155                return Ok(generated);
2156            }
2157        }
2158        let row_gen = vec![0usize];
2159        for _ in 1..n_new {
2160            if cache.past_seq >= self.max_seq - 1 {
2161                bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
2162            }
2163            next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen, opts)?;
2164            let tok = next_tokens[0];
2165            generated.push(tok);
2166            self.decode_cache = Some(cache.clone());
2167            if !on_token(tok) {
2168                break;
2169            }
2170        }
2171        Ok(generated)
2172    }
2173
2174    /// Greedy multimodal generation (batch=1).
2175    pub fn generate_multimodal<F>(
2176        &mut self,
2177        prompt: &str,
2178        rgb: &[u8],
2179        img_w: usize,
2180        img_h: usize,
2181        tokenizer: Option<&std::path::Path>,
2182        n_new: usize,
2183        on_token: F,
2184    ) -> Result<Vec<u32>>
2185    where
2186        F: FnMut(u32) -> bool,
2187    {
2188        self.generate_multimodal_with_opts(
2189            prompt,
2190            rgb,
2191            img_w,
2192            img_h,
2193            tokenizer,
2194            n_new,
2195            SampleOpts::greedy(),
2196            on_token,
2197        )
2198    }
2199
/// Run the prefill-cache graph, seed decode state, and return the flattened
/// batch trunk output.
///
/// Prompt validation and feeds:
/// - Rejects a prompt count that differs from the runner batch, empty rows,
///   and rows longer than `max_seq`.
/// - Packs the prompts as `input_ids` and also feeds `last_token_idx`, zeroed
///   recurrent-state inputs, and the MRoPE rope feeds.
///
/// Graph selection:
/// - With `dynamic_prefill`, inputs are padded only to the longest prompt.
///   A graph specialized for that length is fetched from
///   `prefill_dynamic_cache`, or built, compiled and uploaded on a miss.
/// - Otherwise the static `prefill_cache` graph, padded to `max_seq`, is used.
///
/// When MoE offload is configured, top-k expert capture is enabled for the
/// run and the captured routing is passed to `refresh_moe_from_capture`.
/// When that returns `true`, the expert store is re-applied to the graph.
///
/// Returns `(trunk, cache, mtp_logits)` as produced by
/// `seed_cache_from_outputs`. The meaning of `trunk` depends on
/// `fast_greedy_lm_head`; callers convert it with `trunk_to_logits` or the
/// hidden-state LM-head helpers. A clone of `cache` is also stored in
/// `self.decode_cache`.
///
/// # Errors
/// Fails on mismatched, empty or over-long prompts; input packing; graph
/// specialization or upload; and output seeding.
///
/// # Panics
/// Panics if the selected prefill cache (`prefill_dynamic_cache` or
/// `prefill_cache`) is not populated, or if building the dynamic HIR fails.
fn prefill_seed_decode_cache(
    &mut self,
    prompts: &[Vec<u32>],
) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
    if prompts.len() != self.batch {
        bail!(
            "qwen35: expected {} prompts (batch={}), got {}",
            self.batch,
            self.batch,
            prompts.len()
        );
    }
    for (i, p) in prompts.iter().enumerate() {
        if p.is_empty() {
            bail!("qwen35: prompt row {i} is empty");
        }
    }

    let prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
    // Cannot panic: `prompts` holds `self.batch` non-empty rows (checked above).
    // NOTE(review): this assumes batch >= 1 — confirm the runner forbids batch=0.
    let seq = prompt_lens.iter().copied().max().unwrap();
    if seq > self.max_seq {
        bail!(
            "qwen35: prompt length {seq} exceeds compiled max_seq={}",
            self.max_seq
        );
    }

    // The dynamic graph is specialized per `seq`, so pad only to the longest
    // prompt. The static graph is fixed at `max_seq`.
    let input_ids = if self.dynamic_prefill {
        pack_input_ids(prompts, seq)?
    } else {
        pack_input_ids(prompts, self.max_seq)?
    };
    let last_idx = last_token_indices(&prompt_lens);

    let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
    feeds.push(("last_token_idx", last_idx.as_slice()));
    let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
    for (name, data) in &zero_in {
        feeds.push((name, data.as_slice()));
    }
    let rope_owned = self.mrope_prefill_rope_feeds(seq);
    for (name, data) in &rope_owned {
        feeds.push((name.as_str(), data.as_slice()));
    }

    // Read the MoE state up front. The graph branches below hold mutable
    // borrows of `self` fields and cannot consult `self.moe_offload` freely.
    let has_moe = self.moe_offload.is_some();
    let num_experts = self.cfg.num_experts;
    let moe_masks = self
        .moe_offload
        .as_ref()
        .map(|m| m.per_layer_resident_masks());
    self.bind_moe_host_weights();

    let outs = if self.dynamic_prefill {
        let config = self.execution_config(prefill_config(self.batch, seq));
        let compile_opts = self.dyn_compile_options(&config);
        let compiled = {
            let cache = self
                .prefill_dynamic_cache
                .as_mut()
                .expect("dynamic prefill cache");
            // Copy or borrow individual fields so the builder / upload closures
            // do not capture `self` while `cache` mutably borrows one of its
            // fields (disjoint field borrows).
            let cfg = self.cfg.clone();
            let weights = self.weights.clone();
            let runtime_mrope = self.runtime_mrope;
            let mtp_logits_path = self.mtp_logits_path;
            let fast_mtp = self.fast_mtp;
            let fast_greedy = self.fast_greedy_lm_head;
            let cache_params = &self.prefill_cache_params;
            let cache_packed = &self.prefill_cache_packed;
            let gguf_loader = &mut self.gguf_loader;
            let packed_bytes_cache = &mut self.packed_bytes_cache;
            get_or_specialize_hir_with_options(
                cache,
                &config,
                || {
                    // NOTE(review): batch is passed as 1 here while `config`
                    // uses `self.batch`. Confirm the dynamic builder treats
                    // these dims as trace placeholders.
                    build_qwen35_prefill_cache_hir_dynamic_ext(
                        &cfg,
                        weights,
                        1,
                        seq,
                        runtime_mrope,
                        mtp_logits_path,
                        fast_mtp,
                        fast_greedy,
                    )
                    .expect("dynamic prefill HIR")
                    .0
                },
                &compile_opts,
                |c| {
                    for (name, data) in cache_params {
                        c.set_param(name, data);
                    }
                    upload_packed_opt(c, gguf_loader.as_mut(), cache_packed, packed_bytes_cache)
                },
            )?
        };
        if has_moe {
            compiled.enable_moe_topk_capture(num_experts);
            if let Some(layers) = &moe_masks {
                push_moe_residency(compiled, layers);
            }
        }
        let outs = compiled.run(&feeds);
        // Unlike the static branch, the capture is drained here without
        // checking `has_moe`. Nothing is refreshed unless offload is configured.
        if let Some(layers) = compiled.take_moe_topk_capture() {
            if let Some(mo) = self.moe_offload.as_mut() {
                let store = self.moe_store.as_ref();
                if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
                    if let Some(store) = self.moe_store.as_ref() {
                        store.apply_to_compiled(compiled);
                    }
                }
            }
        }
        outs
    } else {
        let compiled = self.prefill_cache.as_mut().expect("static prefill cache");
        if has_moe {
            compiled.enable_moe_topk_capture(num_experts);
            if let Some(layers) = &moe_masks {
                push_moe_residency(compiled, layers);
            }
        }
        let outs = compiled.run(&feeds);
        let layers = if has_moe {
            compiled.take_moe_topk_capture()
        } else {
            None
        };
        if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
            let store = self.moe_store.as_ref();
            if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
                if let Some(store) = self.moe_store.as_ref() {
                    store.apply_to_compiled(compiled);
                }
            }
        }
        outs
    };
    let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
        &self.cfg,
        self.batch,
        seq,
        &prompt_lens,
        outs,
        self.mtp_logits_path,
        self.fast_greedy_lm_head,
    )?;
    zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
    // Keep a runner-owned copy so the cached decode entry points can continue
    // from this prefill.
    self.decode_cache = Some(cache.clone());
    Ok((trunk, cache, mtp_logits))
}
2353
/// VLM prefill-cache path: host-spliced hidden states + runtime MRoPE sections.
///
/// Feeds `prefill.hidden` (which must hold exactly `seq * hidden_size`
/// values, where `seq = prefill.seq.len()`) as `prefill_hidden`. It also
/// feeds:
/// - a single `last_token_idx` taken from `prefill.last_token_idx`;
/// - zeroed recurrent-state inputs;
/// - `input_ids` packed from `prefill.seq`, but only when `mtp_logits_path`
///   or `enable_mtp` is set;
/// - the MRoPE rope feeds.
///
/// The graph is fetched from `prefill_hidden_dynamic_cache`, specialized for
/// `seq`, or built, compiled and uploaded on a miss. Its outputs seed the
/// decode cache. A clone of the cache is stored in `self.decode_cache`, and
/// `(trunk, cache, mtp_logits)` is returned.
///
/// NOTE(review): this path does not check `self.batch` itself. It feeds one
/// `last_token_idx` and builds the HIR with batch 1, so it relies on callers
/// enforcing batch=1 (both visible callers do).
/// NOTE(review): unlike `prefill_seed_decode_cache`, this path does not call
/// `bind_moe_host_weights` or enable MoE top-k capture. Confirm that is
/// intended for MoE-offloaded models.
///
/// # Errors
/// Fails in any of these cases:
/// - `prefill.seq` is empty or longer than `max_seq`;
/// - the hidden buffer length does not equal `seq * hidden_size`;
/// - the hidden prefill cache is missing (mmproj not loaded?);
/// - input packing, specialization/upload, or output seeding fails.
///
/// # Panics
/// Panics if building the dynamic hidden-prefill HIR fails.
fn prefill_seed_from_hidden(
    &mut self,
    prefill: MultimodalPrefill,
) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
    let seq = prefill.seq.len();
    if seq == 0 {
        bail!("qwen35: multimodal prefill seq is empty");
    }
    if seq > self.max_seq {
        bail!(
            "qwen35: multimodal seq {seq} exceeds compiled max_seq={}",
            self.max_seq
        );
    }
    let n_embd = self.cfg.hidden_size;
    if prefill.hidden.len() != seq * n_embd {
        bail!(
            "qwen35: prefill hidden len {} != seq*n_embd {}*{}",
            prefill.hidden.len(),
            seq,
            n_embd
        );
    }

    let last_idx = vec![prefill.last_token_idx as f32];
    let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
    // Token ids are fed only when an MTP path is active.
    let input_ids = if self.mtp_logits_path || self.enable_mtp {
        Some(pack_input_ids(std::slice::from_ref(&prefill.seq), seq)?)
    } else {
        None
    };
    let mut feeds: Vec<(&str, &[f32])> = vec![("prefill_hidden", prefill.hidden.as_slice())];
    feeds.push(("last_token_idx", last_idx.as_slice()));
    for (name, data) in &zero_in {
        feeds.push((name, data.as_slice()));
    }
    if let Some(ref ids) = input_ids {
        feeds.push(("input_ids", ids.as_slice()));
    }
    let rope_owned = self.mrope_prefill_rope_feeds(seq);
    for (name, data) in &rope_owned {
        feeds.push((name.as_str(), data.as_slice()));
    }

    let config = self.execution_config(hidden_prefill_config(self.batch, seq));
    let compile_opts = self.dyn_compile_options(&config);
    let cache = self
        .prefill_hidden_dynamic_cache
        .as_mut()
        .ok_or_else(|| anyhow!("qwen35: hidden prefill cache missing (mmproj not loaded?)"))?;
    // Copy or borrow individual fields so the builder / upload closures do not
    // capture `self` while `cache` mutably borrows one of its fields.
    let cfg = self.cfg.clone();
    let weights = self.weights.clone();
    let runtime_mrope = self.runtime_mrope;
    let mtp_logits_path = self.mtp_logits_path;
    let fast_mtp = self.fast_mtp;
    let fast_greedy = self.fast_greedy_lm_head;
    let hidden_params = &self.prefill_hidden_cache_params;
    let hidden_packed = &self.prefill_hidden_cache_packed;
    let gguf_loader = &mut self.gguf_loader;
    let packed_bytes_cache = &mut self.packed_bytes_cache;
    let compiled = get_or_specialize_hir_with_options(
        cache,
        &config,
        || {
            build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
                &cfg,
                weights,
                1,
                seq,
                runtime_mrope,
                mtp_logits_path,
                fast_mtp,
                fast_greedy,
            )
            .expect("dynamic hidden prefill HIR")
            .0
        },
        &compile_opts,
        |c| {
            for (name, data) in hidden_params {
                c.set_param(name, data);
            }
            upload_packed_opt(c, gguf_loader.as_mut(), hidden_packed, packed_bytes_cache)
        },
    )?;
    let outs = compiled.run(&feeds);
    // Single row whose "prompt length" is the full spliced sequence.
    let prompt_lens = vec![seq];
    let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
        &self.cfg,
        self.batch,
        seq,
        &prompt_lens,
        outs,
        self.mtp_logits_path,
        self.fast_greedy_lm_head,
    )?;
    zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
    self.decode_cache = Some(cache.clone());
    Ok((trunk, cache, mtp_logits))
}
2455
2456    /// Single cached decode step returning trunk logits for `token`.
2457    pub fn decode_get_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2458        self.decode_forward_logits(token, false)
2459    }
2460
2461    /// Single cached decode step returning MTP head logits for `token`.
2462    pub fn decode_get_mtp_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2463        if !self.mtp_logits_path {
2464            bail!("qwen35: decode_get_mtp_logits requires mtp_logits_path(true)");
2465        }
2466        self.decode_forward_logits(token, true)
2467    }
2468
2469    /// Prefill and return optional MTP logits (requires `enable_mtp`).
2470    pub fn prefill_get_mtp_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2471        self.predict_logits(prompt_ids)?
2472            .mtp_logits
2473            .ok_or_else(|| anyhow!("qwen35: MTP logits unavailable (enable_mtp?)"))
2474    }
2475
2476    fn decode_step(
2477        &mut self,
2478        cache: &mut Qwen35DecodeCache,
2479        tokens: &[u32],
2480        generated_per_row: &[usize],
2481        opts: SampleOpts,
2482    ) -> Result<Vec<u32>> {
2483        if self.fast_greedy_lm_head {
2484            let vocab = self.lm_vocab_size();
2485            let (trunk, _mtp) = self.decode_step_trunk_raw(cache, tokens, generated_per_row)?;
2486            if opts.greedy {
2487                return self.argmax_batch_from_hidden(&trunk);
2488            }
2489            if sample_lm_cap(opts, vocab) < vocab {
2490                return self.sample_batch_from_hidden(&trunk, opts);
2491            }
2492        }
2493        let logits = self.decode_forward_logits_batch(cache, tokens, generated_per_row, false)?;
2494        Ok(sample_logits_batch(
2495            &logits,
2496            self.lm_vocab_size(),
2497            self.batch,
2498            opts,
2499        ))
2500    }
2501
2502    fn decode_forward_logits(&mut self, token: u32, want_mtp: bool) -> Result<Vec<f32>> {
2503        let mut cache = self
2504            .decode_cache
2505            .take()
2506            .ok_or_else(|| anyhow!("qwen35: decode requires seeded cache"))?;
2507        let row_gen = vec![0usize; self.batch];
2508        let logits = self.decode_forward_logits_batch(&mut cache, &[token], &row_gen, want_mtp)?;
2509        self.decode_cache = Some(cache);
2510        Ok(logits)
2511    }
2512
2513    fn decode_forward_logits_batch(
2514        &mut self,
2515        cache: &mut Qwen35DecodeCache,
2516        tokens: &[u32],
2517        generated_per_row: &[usize],
2518        want_mtp: bool,
2519    ) -> Result<Vec<f32>> {
2520        let past_seq = cache.past_seq;
2521        let head_half = self.cfg.key_length / 2;
2522        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
2523
2524        let use_bucket = self
2525            .decode_compile_cache
2526            .as_ref()
2527            .and_then(|c| c.bucket_for(past_seq as u64))
2528            .is_some();
2529
2530        if use_bucket {
2531            let (logits, mtp_logits) =
2532                self.decode_step_bucketed(cache, tokens, generated_per_row, &cos, &sin)?;
2533            if want_mtp {
2534                mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from bucketed graph"))
2535            } else {
2536                Ok(logits)
2537            }
2538        } else {
2539            let feeds_owned = decode_step_feeds(
2540                &self.cfg,
2541                cache,
2542                tokens,
2543                &cos,
2544                &sin,
2545                None,
2546                generated_per_row,
2547            )?;
2548            let feeds: Vec<(&str, &[f32])> = feeds_owned
2549                .iter()
2550                .map(|(k, v)| (k.as_str(), v.as_slice()))
2551                .collect();
2552            if !self.decode_graphs.contains_key(&past_seq) {
2553                let (hir, params, packed) = build_qwen35_decode_hir_ext(
2554                    &self.cfg,
2555                    self.weights.clone(),
2556                    self.batch,
2557                    past_seq,
2558                    false,
2559                    self.mtp_logits_path,
2560                    self.fast_mtp,
2561                    self.fast_greedy_lm_head,
2562                )?;
2563                let mut compiled = self.compile_hir_for_config(
2564                    decode_config(self.batch, past_seq),
2565                    &format!("decode_{past_seq}"),
2566                    hir,
2567                )?;
2568                for (name, data) in &params {
2569                    compiled.set_param(name, data);
2570                }
2571                upload_packed_opt(
2572                    &mut compiled,
2573                    self.gguf_loader.as_mut(),
2574                    &packed,
2575                    &mut self.packed_bytes_cache,
2576                )?;
2577                self.decode_graphs.insert(past_seq, compiled);
2578            }
2579            let step = self.moe_refresh_step;
2580            let has_moe = self.moe_offload.is_some();
2581            let num_experts = self.cfg.num_experts;
2582            let moe_masks = self
2583                .moe_offload
2584                .as_ref()
2585                .map(|m| m.per_layer_resident_masks());
2586            self.bind_moe_host_weights();
2587            let outs = {
2588                let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2589                if has_moe {
2590                    compiled.enable_moe_topk_capture(num_experts);
2591                    if let Some(layers) = &moe_masks {
2592                        push_moe_residency(compiled, layers);
2593                    }
2594                }
2595                compiled.run(&feeds)
2596            };
2597            if has_moe {
2598                let layers = {
2599                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2600                    compiled.take_moe_topk_capture()
2601                };
2602                if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2603                    let store = self.moe_store.as_ref();
2604                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2605                    if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
2606                        if let Some(store) = self.moe_store.as_ref() {
2607                            store.apply_to_compiled(compiled);
2608                        }
2609                    }
2610                }
2611            }
2612            self.moe_refresh_step = step.saturating_add(1);
2613            let (trunk, mtp_logits) = advance_cache_from_decode_outputs(
2614                &self.cfg,
2615                cache,
2616                outs,
2617                None,
2618                self.mtp_logits_path,
2619                want_mtp,
2620                self.fast_greedy_lm_head,
2621            )?;
2622            let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2623            if want_mtp {
2624                mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from decode graph"))
2625            } else {
2626                Ok(logits)
2627            }
2628        }
2629    }
2630
2631    fn decode_step_bucketed(
2632        &mut self,
2633        cache: &mut Qwen35DecodeCache,
2634        tokens: &[u32],
2635        generated_per_row: &[usize],
2636        cos: &[f32],
2637        sin: &[f32],
2638    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2639        let (trunk, mtp) =
2640            self.decode_step_bucketed_raw(cache, tokens, generated_per_row, cos, sin)?;
2641        let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2642        Ok((logits, mtp))
2643    }
2644
2645    fn decode_step_bucketed_raw(
2646        &mut self,
2647        cache: &mut Qwen35DecodeCache,
2648        tokens: &[u32],
2649        generated_per_row: &[usize],
2650        cos: &[f32],
2651        sin: &[f32],
2652    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2653        let past_seq = cache.past_seq;
2654        let upper = self.ensure_decode_bucket_compiled(past_seq as u64)?;
2655
2656        let feeds_owned = decode_step_feeds(
2657            &self.cfg,
2658            cache,
2659            tokens,
2660            cos,
2661            sin,
2662            Some(upper),
2663            generated_per_row,
2664        )?;
2665        let feeds: Vec<(&str, &[f32])> = feeds_owned
2666            .iter()
2667            .map(|(k, v)| (k.as_str(), v.as_slice()))
2668            .collect();
2669
2670        let decode_opts = self.bucketed_decode_compile_options();
2671        let cache_mut = self.decode_compile_cache.as_mut().unwrap();
2672        let (_u, compiled) = cache_mut
2673            .ensure_hir_with_params(
2674                past_seq as u64,
2675                |_| panic!("decode bucket must be compiled"),
2676                &decode_opts,
2677            )
2678            .expect("decode bucket missing after ensure");
2679        // Scale attention KV length to actual cache position, not bucket upper.
2680        compiled.set_active_extent(Some((past_seq + 1, upper + 1)));
2681        let outs = compiled.run(&feeds);
2682        compiled.set_active_extent(None);
2683        advance_cache_from_decode_outputs(
2684            &self.cfg,
2685            cache,
2686            outs,
2687            Some(upper),
2688            self.mtp_logits_path,
2689            self.mtp_logits_path,
2690            self.fast_greedy_lm_head,
2691        )
2692    }
2693
2694    fn mrope_prefill_rope_feeds(&self, seq: usize) -> Vec<(String, Vec<f32>)> {
2695        if !self.runtime_mrope {
2696            return Vec::new();
2697        }
2698        let head_half = self.cfg.key_length / 2;
2699        let sections = self.mrope_section_positions.as_deref();
2700        let (cos, sin) = mrope_prefill_feeds(&self.cfg, seq, sections, head_half);
2701        vec![("rope_cos".into(), cos), ("rope_sin".into(), sin)]
2702    }
2703}
2704
2705impl rlx_cli::LmRunner for Qwen35Runner {
2706    fn family(&self) -> &'static str {
2707        "qwen35"
2708    }
2709    fn vocab_size(&self) -> usize {
2710        self.lm_vocab_size()
2711    }
2712    fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>> {
2713        let out = Qwen35Runner::predict_logits(self, prompt_ids)?;
2714        Ok(out.logits)
2715    }
2716    fn generate(
2717        &mut self,
2718        prompt_ids: &[u32],
2719        n_new: usize,
2720        on_token: &mut dyn FnMut(u32) -> bool,
2721    ) -> anyhow::Result<Vec<u32>> {
2722        // Qwen35Runner::generate takes a bool-returning callback
2723        // (false = stop). The trait callback now has the same shape,
2724        // so just forward.
2725        Qwen35Runner::generate(self, prompt_ids, n_new, on_token)
2726    }
2727
2728    fn supports_multimodal(&self) -> bool {
2729        // True when an mmproj path or inline mmproj weights were
2730        // attached at builder time. The encoder is lazy-loaded inside
2731        // `generate_multimodal_with_opts`.
2732        self.has_mmproj()
2733    }
2734
2735    fn generate_multimodal(
2736        &mut self,
2737        prompt: &str,
2738        rgb: &[u8],
2739        img_w: usize,
2740        img_h: usize,
2741        tokenizer: Option<&std::path::Path>,
2742        n_new: usize,
2743        on_token: &mut dyn FnMut(u32) -> bool,
2744    ) -> anyhow::Result<Vec<u32>> {
2745        Qwen35Runner::generate_multimodal(
2746            self, prompt, rgb, img_w, img_h, tokenizer, n_new, on_token,
2747        )
2748    }
2749}
2750
2751fn sample_logits_batch(logits: &[f32], vocab: usize, batch: usize, opts: SampleOpts) -> Vec<u32> {
2752    let mut out = Vec::with_capacity(batch);
2753    for b in 0..batch {
2754        let row = &logits[b * vocab..(b + 1) * vocab];
2755        out.push(sample_token(row, opts) as u32);
2756    }
2757    out
2758}