Skip to main content

rlx_qwen35/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level runner for Qwen3.5 / Qwen3.6 (qwen35 architecture).
17
18use crate::cache::{
19    Qwen35DecodeCache, advance_cache_from_decode_outputs, decode_step_feeds, last_token_indices,
20    pack_input_ids, seed_cache_from_outputs, zero_prompt_padding_kv, zero_recurrent_inputs,
21};
22use crate::capabilities::validate_device;
23use crate::config::Qwen35Config;
24use crate::encode_prompt_auto;
25use crate::lm_head::{
26    greedy_lm_head_argmax, lm_head_logits_row, sample_lm_cap, sample_lm_head_from_hidden,
27};
28use crate::moe_offload::{MoeOffloadState, build_moe_offload};
29use crate::moe_store::{build_moe_expert_store, moe_host_bind_from_store};
30use crate::rope::{mrope_prefill_feeds, mrope_slice_at_pos};
31use crate::vision::{
32    MmProjConfig, MmProjWeights, MultimodalPrefill, MultimodalPrompt, Qwen35VisionEncoder,
33    load_vision_encoder,
34};
35use crate::weights::Qwen35Weights;
36use crate::{
37    PackedParams, build_qwen35_decode_hir_dynamic_ext, build_qwen35_decode_hir_ext,
38    build_qwen35_hir_sized_ext, build_qwen35_prefill_cache_hir_dynamic_ext,
39    build_qwen35_prefill_hidden_cache_hir_dynamic_ext,
40};
41use rlx_runtime::MoeExpertStore;
42
43fn push_moe_residency(compiled: &mut rlx_runtime::CompiledGraph, layers: &[Vec<bool>]) {
44    let refs: Vec<&[bool]> = layers.iter().map(|m| m.as_slice()).collect();
45    compiled.set_moe_resident_experts_per_layer(&refs);
46}
47
48fn refresh_moe_from_capture(
49    mo: &mut MoeOffloadState,
50    store: Option<&MoeExpertStore>,
51    compiled: &mut rlx_runtime::CompiledGraph,
52    layer_indices: &[Vec<u32>],
53    denoise_step: usize,
54    is_prefill_block: bool,
55) -> bool {
56    let refreshed = if let Some(store) = store {
57        mo.refresh_from_capture_with_store(store, layer_indices, denoise_step, is_prefill_block)
58    } else {
59        mo.refresh_from_capture(layer_indices, denoise_step, is_prefill_block)
60    };
61    if refreshed {
62        push_moe_residency(compiled, &mo.per_layer_resident_masks());
63    }
64    refreshed
65}
66use crate::execution::{
67    Qwen35CompileCache, decode_config, get_or_specialize_hir_with_options, hidden_prefill_config,
68    prefill_config,
69};
70use crate::flow::{Qwen35PrefillCacheOpts, build_qwen35_prefill_cache_built};
71use crate::profile::{qwen35_profile_default, qwen35_profile_near_weights};
72use anyhow::{Context, Result, anyhow, bail};
73use rlx_core::flow_bridge::compile_options_from_profile;
74use rlx_core::gguf_support::{GgufModelFamily, assert_gguf_family, resolve_weights_file};
75use rlx_core::weight_loader::GgufLoader;
76use rlx_flow::ModelExecutionConfig;
77use rlx_flow::{CompileProfile, ExecutionPreset};
78use rlx_ir::CompilationMode;
79use rlx_ir::logical_kernel::KernelDispatchConfig;
80use rlx_qwen3::sampling::{SampleOpts, sample_token};
81use rlx_runtime::compile_cache::BucketedCompileCache;
82use rlx_runtime::{AotCache, CompileOptions, Device, Session};
83use std::cell::RefCell;
84use std::collections::HashMap;
85use std::path::PathBuf;
86use std::sync::Arc;
87use std::time::Instant;
88
89/// Source for the Qwen3.5 / 3.6 config. Mirrors `Qwen3ConfigSource`
90/// so callers using `Qwen35RunnerBuilder` have the same shape as
91/// `Qwen3RunnerBuilder` (PLAN.md M1).
92///
93/// `JsonFile` and `Explicit` are wired into the future safetensors
94/// load path (catalog rows: `qwen35-{4b,9b,27b}-hauhau-aggressive`).
95/// Until the safetensors load lands, `build()` errors with a clear
96/// M1 follow-up message when these variants are used.
97#[derive(Debug, Clone, Default)]
98pub enum Qwen35ConfigSource {
99    /// Read from GGUF metadata (today's only working path).
100    #[default]
101    Embedded,
102    /// Read from a HuggingFace `config.json` at this path.
103    JsonFile(PathBuf),
104    /// Use the supplied config object directly.
105    Explicit(Qwen35Config),
106}
107
108#[derive(Default, Debug)]
109pub struct Qwen35RunnerBuilder {
110    weights: Option<PathBuf>,
111    config: Option<Qwen35ConfigSource>,
112    device: Option<Device>,
113    max_seq: Option<usize>,
114    enable_mtp: bool,
115    last_logits_only: bool,
116    /// `None` = auto-detect (packed when GGUF ≥ 256 MB to avoid the
117    /// F32-dequant memory explosion — a 4B Q3_K_S file is ~2 GB on
118    /// disk but ~16 GB dense-F32). `Some(true)` / `Some(false)` are
119    /// explicit user overrides.
120    packed_weights: Option<bool>,
121    runtime_mrope: bool,
122    mrope_section_positions: Option<Vec<[usize; 4]>>,
123    batch: Option<usize>,
124    bucketed_decode: Option<bool>,
125    /// Emit/consume MTP logits on the prefill-cache + decode path (draft speculator).
126    mtp_logits_path: bool,
127    fast_mtp: bool,
128    /// Skip LM head in decode graphs; argmax on host (default: true).
129    fast_greedy_lm_head: Option<bool>,
130    /// Persist optimized LIR under this directory (warm-start / AOT).
131    aot_cache_dir: Option<PathBuf>,
132    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
133    dynamic_prefill: bool,
134    /// Compile decode once with `sym::PAST_SEQ`, specialize per prefix length.
135    dynamic_decode: bool,
136    inline_weights: Option<(Qwen35Config, Qwen35Weights)>,
137    /// Optional mmproj GGUF for VLM vision encoding.
138    mmproj: Option<PathBuf>,
139    /// Inline mmproj weights (tests; mutually exclusive with [`Self::mmproj`]).
140    inline_mmproj: Option<(crate::vision::MmProjConfig, crate::vision::MmProjWeights)>,
141    /// Override tier-1 prefill profile (else `qwen35.rlx.toml` or defaults).
142    prefill_profile: Option<CompileProfile>,
143    /// Override tier-1 decode profile.
144    decode_profile: Option<CompileProfile>,
145    /// TIDE-style cap on GPU-resident experts per MoE layer (`max_gpu_experts_per_layer`).
146    max_gpu_experts_per_layer: Option<usize>,
147    /// Unified RAM / VRAM budget for auto expert cap (optional).
148    moe_memory_budget_bytes: Option<usize>,
149    /// Refresh expert placement every N decode/denoise steps (TIDE `jump_steps` τ).
150    expert_refresh_every_decode_steps: Option<usize>,
151    /// TIDE `jump_steps` alias (preferred name).
152    jump_steps: Option<usize>,
153    /// TIDE `reserve_vram_gb` (default 1.5).
154    reserve_vram_gb: Option<f64>,
155    /// TIDE `collect_stats` on MoE forwards.
156    moe_collect_stats: bool,
157}
158
159impl Qwen35RunnerBuilder {
160    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
161        self.weights = Some(path.into());
162        self
163    }
164
165    /// Source for the Qwen3.5 / 3.6 config. Default
166    /// `Qwen35ConfigSource::Embedded` (GGUF metadata). PLAN.md M1
167    /// — `JsonFile` / `Explicit` reserve the API shape for the
168    /// safetensors load path; today they error in `build()` with a
169    /// follow-up message.
170    pub fn config(mut self, src: Qwen35ConfigSource) -> Self {
171        self.config = Some(src);
172        self
173    }
174
175    /// Convenience: explicit `Qwen35Config` (shorthand for
176    /// `.config(Qwen35ConfigSource::Explicit(cfg))`).
177    pub fn config_value(self, cfg: Qwen35Config) -> Self {
178        self.config(Qwen35ConfigSource::Explicit(cfg))
179    }
180    pub fn device(mut self, d: Device) -> Self {
181        self.device = Some(d);
182        self
183    }
184    pub fn max_seq(mut self, n: usize) -> Self {
185        self.max_seq = Some(n);
186        self
187    }
188    pub fn enable_mtp(mut self, on: bool) -> Self {
189        self.enable_mtp = on;
190        self
191    }
192    pub fn last_logits_only(mut self, on: bool) -> Self {
193        self.last_logits_only = on;
194        self
195    }
196    pub fn packed_weights(mut self, on: bool) -> Self {
197        self.packed_weights = Some(on);
198        self
199    }
200    /// Use runtime MRoPE cos/sin inputs instead of a baked table. Required
201    /// for multimodal prompts where section positions differ from `[p,p,p,0]`.
202    pub fn runtime_mrope(mut self, on: bool) -> Self {
203        self.runtime_mrope = on;
204        self
205    }
206    /// Per-token MRoPE section positions `[t,h,w,extra]` (length = prompt seq).
207    pub fn mrope_section_positions(mut self, positions: Vec<[usize; 4]>) -> Self {
208        self.mrope_section_positions = Some(positions);
209        self
210    }
211    /// Batch size for compiled graphs (default 1).
212    pub fn batch(mut self, n: usize) -> Self {
213        self.batch = Some(n);
214        self
215    }
216    /// Use power-of-two bucketed decode compile cache (default: true).
217    pub fn bucketed_decode(mut self, on: bool) -> Self {
218        self.bucketed_decode = Some(on);
219        self
220    }
221    /// Use MTP head logits on prefill-cache seeding and decode steps.
222    pub fn mtp_logits_path(mut self, on: bool) -> Self {
223        self.mtp_logits_path = on;
224        self
225    }
226    /// Trim MTP LM head to 32K vocab (llama.cpp FastMTP draft path).
227    pub fn fast_mtp(mut self, on: bool) -> Self {
228        self.fast_mtp = on;
229        self
230    }
231    /// Decode without graph LM head — host argmax over tied embedding (default on).
232    pub fn fast_greedy_lm_head(mut self, on: bool) -> Self {
233        self.fast_greedy_lm_head = Some(on);
234        self
235    }
236    /// Cache optimized LIR on disk for faster subsequent runs.
237    pub fn aot_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
238        self.aot_cache_dir = Some(path.into());
239        self
240    }
241    /// Use dynamic prefill specialization (symbolic seq; batch=1).
242    pub fn dynamic_prefill(mut self, on: bool) -> Self {
243        self.dynamic_prefill = on;
244        self
245    }
246    /// Use dynamic decode specialization (symbolic past_seq; batch=1).
247    pub fn dynamic_decode(mut self, on: bool) -> Self {
248        self.dynamic_decode = on;
249        self
250    }
251    /// Supply config + weights directly (tests/benches; no GGUF on disk).
252    pub fn inline_weights(mut self, cfg: Qwen35Config, weights: Qwen35Weights) -> Self {
253        self.inline_weights = Some((cfg, weights));
254        self
255    }
256
257    /// Load vision encoder weights from an mmproj GGUF (enables multimodal prefill).
258    pub fn mmproj(mut self, path: impl Into<PathBuf>) -> Self {
259        self.mmproj = Some(path.into());
260        self
261    }
262
263    /// Supply mmproj config + weights directly (tests; no mmproj GGUF on disk).
264    pub fn inline_mmproj(
265        mut self,
266        cfg: crate::vision::MmProjConfig,
267        weights: crate::vision::MmProjWeights,
268    ) -> Self {
269        self.inline_mmproj = Some((cfg, weights));
270        self
271    }
272
273    /// Override tier-1 compile profiles (skips `qwen35.rlx.toml` discovery).
274    pub fn with_compile_profiles(
275        mut self,
276        prefill: CompileProfile,
277        decode: CompileProfile,
278    ) -> Self {
279        self.prefill_profile = Some(prefill);
280        self.decode_profile = Some(decode);
281        self
282    }
283
284    /// Enable MoE expert offload (TIDE). Caps GPU-resident experts per layer; remainder on host.
285    pub fn max_gpu_experts_per_layer(mut self, n: usize) -> Self {
286        self.max_gpu_experts_per_layer = Some(n);
287        self
288    }
289
290    /// Memory budget for automatic expert cap (macOS: unified RAM when unset).
291    pub fn moe_memory_budget_bytes(mut self, bytes: usize) -> Self {
292        self.moe_memory_budget_bytes = Some(bytes);
293        self
294    }
295
296    /// Refresh expert GPU set every N decode steps (TIDE `jump_steps` for AR).
297    pub fn expert_refresh_every_decode_steps(mut self, n: usize) -> Self {
298        self.expert_refresh_every_decode_steps = Some(n);
299        self.jump_steps = Some(n);
300        self
301    }
302
303    /// TIDE `jump_steps` (τ): refresh expert placement every N denoise/decode steps.
304    pub fn jump_steps(mut self, n: usize) -> Self {
305        self.jump_steps = Some(n);
306        self.expert_refresh_every_decode_steps = Some(n);
307        self
308    }
309
310    /// TIDE `reserve_vram_gb` for GPU expert budget sizing (default 1.5).
311    pub fn reserve_vram_gb(mut self, gb: f64) -> Self {
312        self.reserve_vram_gb = Some(gb);
313        self
314    }
315
316    /// TIDE `collect_stats` — aggregate token/compute counters per forward.
317    pub fn moe_collect_stats(mut self, on: bool) -> Self {
318        self.moe_collect_stats = on;
319        self
320    }
321
322    /// TIDE `enable_predictive_expert_offload(max_gpu_experts_per_layer=…)`.
323    pub fn enable_predictive_expert_offload(mut self, max_gpu_experts_per_layer: usize) -> Self {
324        self.max_gpu_experts_per_layer = Some(max_gpu_experts_per_layer);
325        self
326    }
327
328    pub fn build(self) -> Result<Qwen35Runner> {
329        let device = self.device.unwrap_or(Device::Cpu);
330        let max_seq = self.max_seq.unwrap_or(128);
331        let batch = self.batch.unwrap_or(1);
332        if batch == 0 {
333            bail!("qwen35: batch must be >= 1");
334        }
335
336        // PLAN.md M1: safetensors load path for HauhauCS catalog rows
337        // (and any other Qwen3.5 / 3.6 HF safetensors checkpoint).
338        // When the caller picks `JsonFile(p)` or `Explicit(cfg)`, we
339        // detect a safetensors weights path, parse the HF config (or
340        // use the explicit one), open the file via the weight registry,
341        // wrap it in `HfTranslatingLoader` so GGUF-named lookups
342        // succeed against HF tensor names, and run the standard
343        // `Qwen35Weights::from_loader` drain.
344        if let Some(src) = self.config.as_ref()
345            && !matches!(src, Qwen35ConfigSource::Embedded)
346        {
347            let weights_path = self
348                .weights
349                .as_ref()
350                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?
351                .clone();
352            let resolved = resolve_weights_file(&weights_path)?;
353            let ext = resolved
354                .extension()
355                .and_then(|s| s.to_str())
356                .unwrap_or("")
357                .to_ascii_lowercase();
358            if ext == "gguf" {
359                bail!(
360                    "qwen35: Qwen35ConfigSource::{:?} supplied with a GGUF weights file at \
361                     {:?} — drop the config source (use the default Embedded) so the GGUF \
362                     metadata is the source of truth",
363                    src,
364                    resolved
365                );
366            }
367            if self.packed_weights == Some(true) {
368                bail!("qwen35: packed_weights requires GGUF; safetensors path is dequant-only");
369            }
370            let cfg = match src {
371                Qwen35ConfigSource::Embedded => unreachable!(),
372                Qwen35ConfigSource::JsonFile(p) => Qwen35Config::from_hf_config_json(p)
373                    .with_context(|| format!("qwen35: parse HF config {p:?}"))?,
374                Qwen35ConfigSource::Explicit(cfg) => cfg.clone(),
375            };
376            if self.enable_mtp && cfg.nextn_predict_layers == 0 {
377                bail!(
378                    "qwen35: enable_mtp(true) but config has \
379                     nextn_predict_layers=0 (no MTP heads to wire)"
380                );
381            }
382            validate_device(&cfg, device, false)?;
383            let path_str = resolved
384                .to_str()
385                .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
386            let inner_map = rlx_core::weight_map::WeightMap::from_file(path_str)
387                .with_context(|| format!("qwen35: load safetensors {resolved:?}"))?;
388            let mut loader = rlx_core::HfTranslatingLoader::new(inner_map);
389            let t = std::time::Instant::now();
390            let weights = Qwen35Weights::from_loader(&mut loader, &cfg)?;
391            eprintln!(
392                "[qwen35] read safetensors weights in {:.2?} \
393                 (layers={}, hidden={})",
394                t.elapsed(),
395                cfg.num_hidden_layers,
396                cfg.hidden_size,
397            );
398            return finish_build(
399                cfg,
400                weights,
401                resolved,
402                None,
403                device,
404                max_seq,
405                batch,
406                self.enable_mtp,
407                self.last_logits_only,
408                self.runtime_mrope,
409                self.mrope_section_positions,
410                self.bucketed_decode,
411                self.mtp_logits_path,
412                self.fast_mtp,
413                self.fast_greedy_lm_head.unwrap_or(true),
414                self.aot_cache_dir.clone(),
415                self.dynamic_prefill,
416                self.dynamic_decode,
417                self.mmproj.clone(),
418                self.inline_mmproj,
419                self.prefill_profile,
420                self.decode_profile,
421                self.max_gpu_experts_per_layer,
422                self.moe_memory_budget_bytes,
423                self.jump_steps.or(self.expert_refresh_every_decode_steps),
424                self.reserve_vram_gb.unwrap_or(1.5),
425                self.moe_collect_stats,
426            );
427        }
428
429        if let Some((cfg, weights)) = self.inline_weights {
430            if self.packed_weights == Some(true) {
431                bail!("qwen35: inline_weights and packed_weights are mutually exclusive");
432            }
433            if self.enable_mtp && cfg.nextn_predict_layers == 0 {
434                bail!(
435                    "qwen35: enable_mtp(true) but config has \
436                     nextn_predict_layers=0 (no MTP heads to wire)"
437                );
438            }
439            if self.mmproj.is_some() && self.inline_mmproj.is_some() {
440                bail!("qwen35: mmproj and inline_mmproj are mutually exclusive");
441            }
442            validate_device(&cfg, device, false)?;
443            return finish_build(
444                cfg,
445                weights,
446                PathBuf::new(),
447                None,
448                device,
449                max_seq,
450                batch,
451                self.enable_mtp,
452                self.last_logits_only,
453                self.runtime_mrope,
454                self.mrope_section_positions,
455                self.bucketed_decode,
456                self.mtp_logits_path,
457                self.fast_mtp,
458                self.fast_greedy_lm_head.unwrap_or(true),
459                self.aot_cache_dir.clone(),
460                self.dynamic_prefill,
461                self.dynamic_decode,
462                self.mmproj.clone(),
463                self.inline_mmproj,
464                self.prefill_profile,
465                self.decode_profile,
466                self.max_gpu_experts_per_layer,
467                self.moe_memory_budget_bytes,
468                self.jump_steps.or(self.expert_refresh_every_decode_steps),
469                self.reserve_vram_gb.unwrap_or(1.5),
470                self.moe_collect_stats,
471            );
472        }
473
474        let weights_path = resolve_weights_file(
475            &self
476                .weights
477                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
478        )?;
479        let _t_total = Instant::now();
480        let t = Instant::now();
481        let raw = assert_gguf_family(&weights_path, GgufModelFamily::Qwen35)?;
482        let mut loader = GgufLoader::from_file(
483            weights_path
484                .to_str()
485                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
486        )?;
487        loader.include_mtp(true);
488        let cfg = Qwen35Config::from_gguf(&raw)?;
489        eprintln!(
490            "[qwen35] loaded GGUF metadata in {:.2?} \
491             (layers={}, hidden={}, ssm_state={})",
492            t.elapsed(),
493            cfg.num_hidden_layers,
494            cfg.hidden_size,
495            cfg.ssm_state_size,
496        );
497
498        if self.enable_mtp && cfg.nextn_predict_layers == 0 {
499            bail!(
500                "qwen35: enable_mtp(true) but the file has \
501                 nextn_predict_layers=0 (no MTP heads to wire)"
502            );
503        }
504        // Resolve auto-default. Llama.cpp keeps K-quant tensors packed
505        // in memory and dequantises per block inside the matmul kernel —
506        // it never materialises a dense F32 weight matrix. Mirror that:
507        // when *any* tensor in the GGUF is a K-quant block format
508        // (Q2_K..Q8_K), force the packed path. Otherwise fall back to
509        // the size heuristic (≥ 256 MB → packed) for legacy quant
510        // formats. Explicit `.packed_weights(_)` overrides.
511        let packed = self.packed_weights.unwrap_or_else(|| {
512            if raw.tensors.values().any(|t| {
513                matches!(
514                    t.dtype,
515                    rlx_gguf::GgmlType::Q2K
516                        | rlx_gguf::GgmlType::Q3K
517                        | rlx_gguf::GgmlType::Q4K
518                        | rlx_gguf::GgmlType::Q5K
519                        | rlx_gguf::GgmlType::Q6K
520                        | rlx_gguf::GgmlType::Q8K
521                )
522            }) {
523                return true;
524            }
525            std::fs::metadata(&weights_path)
526                .ok()
527                .map(|m| m.len() >= 256 * 1024 * 1024)
528                .unwrap_or(false)
529        });
530        validate_device(&cfg, device, packed)?;
531
532        let t = Instant::now();
533        let weights = if packed {
534            Qwen35Weights::from_loader_packed(&mut loader, &cfg)?
535        } else {
536            Qwen35Weights::from_loader(&mut loader, &cfg)?
537        };
538        eprintln!(
539            "[qwen35] read weights ({}) in {:.2?}",
540            if packed { "packed" } else { "F32" },
541            t.elapsed(),
542        );
543
544        finish_build(
545            cfg,
546            weights,
547            weights_path,
548            Some(loader),
549            device,
550            max_seq,
551            batch,
552            self.enable_mtp,
553            self.last_logits_only,
554            self.runtime_mrope,
555            self.mrope_section_positions,
556            self.bucketed_decode,
557            self.mtp_logits_path,
558            self.fast_mtp,
559            self.fast_greedy_lm_head.unwrap_or(true),
560            self.aot_cache_dir.clone(),
561            self.dynamic_prefill,
562            self.dynamic_decode,
563            self.mmproj.clone(),
564            self.inline_mmproj,
565            self.prefill_profile,
566            self.decode_profile,
567            self.max_gpu_experts_per_layer,
568            self.moe_memory_budget_bytes,
569            self.jump_steps.or(self.expert_refresh_every_decode_steps),
570            self.reserve_vram_gb.unwrap_or(1.5),
571            self.moe_collect_stats,
572        )
573    }
574}
575
576fn make_qwen35_dyn_cache(
577    device: Device,
578    capacity: usize,
579    aot_cache_dir: Option<&std::path::Path>,
580) -> Qwen35CompileCache {
581    if let Some(dir) = aot_cache_dir {
582        Qwen35CompileCache::with_aot(device, capacity, dir)
583    } else {
584        Qwen35CompileCache::new(device, capacity)
585    }
586}
587
588/// Static prefill-cache compile via tier-0 [`BuiltModel`] + [`Qwen35CompileCache`].
589fn compile_static_prefill_cache(
590    cfg: &Qwen35Config,
591    weights: Qwen35Weights,
592    batch: usize,
593    max_seq: usize,
594    device: Device,
595    prefill_profile: &CompileProfile,
596    runtime_mrope: bool,
597    enable_mtp_head: bool,
598    fast_mtp: bool,
599    fast_greedy_lm_head: bool,
600    aot_cache_dir: Option<&std::path::Path>,
601) -> Result<(
602    rlx_runtime::CompiledGraph,
603    HashMap<String, Vec<f32>>,
604    PackedParams,
605)> {
606    let mut flow_opts = Qwen35PrefillCacheOpts::static_cache(batch, max_seq);
607    flow_opts.with_lm_head = !fast_greedy_lm_head;
608    flow_opts.runtime_mrope = runtime_mrope;
609    flow_opts.enable_mtp_head = enable_mtp_head;
610    flow_opts.fast_mtp = fast_mtp;
611    flow_opts.fast_greedy_lm_head = fast_greedy_lm_head;
612    flow_opts.profile = Some(prefill_profile.clone());
613
614    let (built, packed) = build_qwen35_prefill_cache_built(cfg, weights, &flow_opts)?;
615    let params = built.params().clone();
616    let config = prefill_config(batch, max_seq);
617    let compile_opts =
618        compile_options_from_profile(prefill_profile, device, KernelDispatchConfig::default());
619
620    let mut cache = match aot_cache_dir {
621        Some(dir) => Qwen35CompileCache::with_aot(device, 1, dir),
622        None => Qwen35CompileCache::new(device, 1),
623    };
624    let mut config = config;
625    if aot_cache_dir.is_some() {
626        config = config.with_compilation_mode(CompilationMode::Aot);
627    }
628    let built = built.with_execution_config(&config);
629    let compiled = cache.compile_built(built, &config, &compile_opts)?;
630    Ok((compiled, params, packed))
631}
632
633#[allow(clippy::too_many_arguments)]
634fn finish_build(
635    cfg: Qwen35Config,
636    weights: Qwen35Weights,
637    weights_path: PathBuf,
638    gguf_loader: Option<GgufLoader>,
639    device: Device,
640    max_seq: usize,
641    batch: usize,
642    enable_mtp: bool,
643    last_logits_only: bool,
644    runtime_mrope: bool,
645    mrope_section_positions: Option<Vec<[usize; 4]>>,
646    bucketed_decode: Option<bool>,
647    mtp_logits_path: bool,
648    fast_mtp: bool,
649    fast_greedy_lm_head: bool,
650    aot_cache_dir: Option<PathBuf>,
651    dynamic_prefill: bool,
652    dynamic_decode: bool,
653    mmproj_path: Option<PathBuf>,
654    inline_mmproj: Option<(MmProjConfig, MmProjWeights)>,
655    prefill_profile_override: Option<CompileProfile>,
656    decode_profile_override: Option<CompileProfile>,
657    max_gpu_experts_per_layer: Option<usize>,
658    moe_memory_budget_bytes: Option<usize>,
659    jump_steps: Option<usize>,
660    reserve_vram_gb: f64,
661    moe_collect_stats: bool,
662) -> Result<Qwen35Runner> {
663    let prefill_profile = prefill_profile_override.unwrap_or_else(|| {
664        if weights_path.as_os_str().is_empty() {
665            qwen35_profile_default(false)
666        } else {
667            qwen35_profile_near_weights(&weights_path, false)
668        }
669    });
670    let decode_profile = decode_profile_override.unwrap_or_else(|| {
671        if weights_path.as_os_str().is_empty() {
672            qwen35_profile_default(true)
673        } else {
674            qwen35_profile_near_weights(&weights_path, true)
675        }
676    });
677
678    if fast_mtp && !mtp_logits_path && !enable_mtp {
679        bail!("qwen35: fast_mtp requires enable_mtp(true) or mtp_logits_path(true)");
680    }
681    if mtp_logits_path && !enable_mtp {
682        bail!("qwen35: mtp_logits_path requires enable_mtp(true)");
683    }
684
685    if dynamic_prefill && batch != 1 {
686        bail!("qwen35: dynamic_prefill requires batch=1");
687    }
688    if dynamic_decode && batch != 1 {
689        bail!("qwen35: dynamic_decode requires batch=1");
690    }
691    if dynamic_decode && bucketed_decode.unwrap_or(true) {
692        eprintln!("[qwen35] dynamic_decode enabled — disabling bucketed decode cache");
693    }
694    let bucketed_decode = if dynamic_decode {
695        false
696    } else {
697        bucketed_decode.unwrap_or(true)
698    };
699
700    let vision_encoder = if let Some(ref path) = mmproj_path {
701        Some(load_vision_encoder(
702            path.to_str()
703                .ok_or_else(|| anyhow!("non-utf8 mmproj path"))?,
704            224,
705            224,
706        )?)
707    } else if let Some((vcfg, vweights)) = inline_mmproj {
708        Some(Qwen35VisionEncoder::from_parts(vcfg, vweights, 4, 4)?)
709    } else {
710        None
711    };
712    let runtime_mrope = runtime_mrope || vision_encoder.is_some();
713    if vision_encoder.is_some() && batch != 1 {
714        bail!("qwen35: VLM (mmproj) requires batch=1");
715    }
716    if vision_encoder.is_some() && !dynamic_prefill {
717        eprintln!("[qwen35] mmproj loaded — enabling dynamic prefill for variable multimodal seq");
718    }
719    let dynamic_prefill = dynamic_prefill || vision_encoder.is_some();
720
721    let t = Instant::now();
722    let aot = aot_cache_dir.as_ref().map(AotCache::new);
723    let (cache_params, cache_packed, mut prefill_cache, prefill_dynamic_cache) = if dynamic_prefill
724    {
725        let (_cache_hir, cache_params, cache_packed) = build_qwen35_prefill_cache_hir_dynamic_ext(
726            &cfg,
727            weights.clone(),
728            batch,
729            max_seq,
730            runtime_mrope,
731            mtp_logits_path,
732            fast_mtp,
733            fast_greedy_lm_head,
734        )?;
735        eprintln!(
736            "[qwen35] built prefill-cache IR in {:.2?} (params={}, packed={})",
737            t.elapsed(),
738            cache_params.len(),
739            cache_packed.len(),
740        );
741        eprintln!("[qwen35] dynamic prefill template ready (compile on first prompt)");
742        (
743            cache_params,
744            cache_packed,
745            None,
746            Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
747        )
748    } else {
749        let (compiled, cache_params, cache_packed) = compile_static_prefill_cache(
750            &cfg,
751            weights.clone(),
752            batch,
753            max_seq,
754            device,
755            &prefill_profile,
756            runtime_mrope,
757            mtp_logits_path,
758            fast_mtp,
759            fast_greedy_lm_head,
760            aot_cache_dir.as_deref(),
761        )?;
762        eprintln!(
763            "[qwen35] compiled prefill-cache via BuiltModel in {:.2?} (params={}, packed={})",
764            t.elapsed(),
765            cache_params.len(),
766            cache_packed.len(),
767        );
768        (cache_params, cache_packed, Some(compiled), None)
769    };
770
771    let (prefill_hidden_dynamic_cache, prefill_hidden_cache_params, prefill_hidden_cache_packed) =
772        if vision_encoder.is_some() {
773            let (hidden_hir, hidden_params, hidden_packed) =
774                build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
775                    &cfg,
776                    weights.clone(),
777                    batch,
778                    max_seq,
779                    runtime_mrope,
780                    mtp_logits_path,
781                    fast_mtp,
782                    fast_greedy_lm_head,
783                )?;
784            let _ = hidden_hir;
785            (
786                Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref())),
787                hidden_params,
788                hidden_packed,
789            )
790        } else {
791            (None, HashMap::new(), HashMap::new())
792        };
793
794    let t = Instant::now();
795    if let Some(ref mut compiled) = prefill_cache {
796        for (name, data) in &cache_params {
797            compiled.set_param(name, data);
798        }
799    }
800
801    let decode_compile_cache = if bucketed_decode {
802        Some(BucketedCompileCache::power_of_two_ladder(
803            device,
804            1,
805            max_seq.max(1) as u64,
806        ))
807    } else {
808        None
809    };
810    let decode_dynamic_cache = if dynamic_decode {
811        Some(make_qwen35_dyn_cache(device, 32, aot_cache_dir.as_deref()))
812    } else {
813        None
814    };
815    let (decode_dynamic_params, decode_dynamic_packed) = if dynamic_decode {
816        let (_, p, packed) = build_qwen35_decode_hir_dynamic_ext(
817            &cfg,
818            weights.clone(),
819            batch,
820            max_seq,
821            mtp_logits_path,
822            fast_mtp,
823            fast_greedy_lm_head,
824        )?;
825        (p, packed)
826    } else {
827        (HashMap::new(), HashMap::new())
828    };
829
830    if dynamic_decode {
831        eprintln!("[qwen35] dynamic decode template ready (compile on first step)");
832    }
833
834    let moe_offload = build_moe_offload(
835        &cfg,
836        &weights,
837        max_gpu_experts_per_layer,
838        moe_memory_budget_bytes,
839        jump_steps,
840        reserve_vram_gb,
841        moe_collect_stats,
842    );
843    let moe_store = if moe_offload.is_some() {
844        build_moe_expert_store(&cfg, &weights).ok()
845    } else {
846        None
847    };
848    if let Some(ref mo) = moe_offload {
849        eprintln!(
850            "[qwen35] TIDE MoE offload: layers={} gpu_budget={}/{} jump_steps={} reserve_bytes={}",
851            mo.num_layers(),
852            mo.info.gpu_expert_budget_per_layer,
853            mo.pools[0].num_experts(),
854            mo.jump_steps,
855            mo.info.reserve_bytes,
856        );
857    }
858
859    let mut runner = Qwen35Runner {
860        compiled: None,
861        prefill_cache,
862        prefill_dynamic_cache,
863        prefill_hidden_dynamic_cache,
864        prefill_cache_params: cache_params,
865        prefill_cache_packed: cache_packed,
866        prefill_hidden_cache_params,
867        prefill_hidden_cache_packed,
868        decode_graphs: HashMap::new(),
869        decode_compile_cache,
870        decode_dynamic_cache,
871        predict_hir_cache: None,
872        decode_dynamic_params,
873        decode_dynamic_packed,
874        packed_bytes_cache: HashMap::new(),
875        cfg,
876        device,
877        batch,
878        max_seq,
879        last_logits_only,
880        enable_mtp,
881        mtp_logits_path,
882        fast_mtp,
883        fast_greedy_lm_head,
884        weights,
885        weights_path,
886        gguf_loader,
887        decode_cache: None,
888        runtime_mrope,
889        mrope_section_positions,
890        aot_cache: aot,
891        dynamic_prefill,
892        dynamic_decode,
893        vision_encoder,
894        mmproj_path,
895        prefill_profile,
896        decode_profile,
897        moe_offload,
898        moe_store,
899        moe_refresh_step: 0,
900    };
901
902    if let Some(ref mut compiled) = runner.prefill_cache {
903        upload_packed_opt(
904            compiled,
905            runner.gguf_loader.as_mut(),
906            &runner.prefill_cache_packed,
907            &mut runner.packed_bytes_cache,
908        )?;
909    }
910    eprintln!(
911        "[qwen35] uploaded prefill-cache {} F32 + {} packed params in {:.2?}",
912        runner.prefill_cache_params.len(),
913        runner.prefill_cache_packed.len(),
914        t.elapsed(),
915    );
916
917    runner.warm_decode_graphs()?;
918    runner.warm_predict_graph()?;
919    Ok(runner)
920}
921
922fn ensure_packed_cache(
923    loader: &mut GgufLoader,
924    packed: &PackedParams,
925    cache: &mut HashMap<String, Arc<[u8]>>,
926) -> Result<()> {
927    for (loader_key, _, _) in packed.values() {
928        if cache.contains_key(loader_key) {
929            continue;
930        }
931        let bytes = loader
932            .tensor_bytes_borrowed(loader_key)
933            .ok_or_else(|| anyhow!("packed cache: {loader_key} bytes missing"))?;
934        cache.insert(loader_key.clone(), Arc::from(bytes));
935    }
936    Ok(())
937}
938
939fn upload_packed_opt(
940    compiled: &mut rlx_runtime::CompiledGraph,
941    loader: Option<&mut GgufLoader>,
942    packed: &PackedParams,
943    cache: &mut HashMap<String, Arc<[u8]>>,
944) -> Result<()> {
945    if packed.is_empty() {
946        return Ok(());
947    }
948    let loader = loader
949        .ok_or_else(|| anyhow!("packed params require a GGUF loader (missing weights path)"))?;
950    ensure_packed_cache(loader, packed, cache)?;
951    for (param_name, (loader_key, _scheme, _shape)) in packed {
952        let bytes = cache
953            .get(loader_key)
954            .ok_or_else(|| anyhow!("packed upload: cache miss for {loader_key}"))?;
955        compiled.set_param_typed(param_name, bytes, rlx_ir::DType::U8);
956    }
957    Ok(())
958}
959
960#[allow(dead_code)]
961fn upload_decode_packed(
962    weights_path: &std::path::Path,
963    compiled: &mut rlx_runtime::CompiledGraph,
964    packed: &PackedParams,
965) -> Result<()> {
966    if packed.is_empty() {
967        return Ok(());
968    }
969    let path = weights_path
970        .to_str()
971        .filter(|p| !p.is_empty())
972        .ok_or_else(|| anyhow!("packed decode params require a GGUF weights path"))?;
973    let mut loader = GgufLoader::from_file(path)?;
974    loader.include_mtp(true);
975    upload_packed_opt(compiled, Some(&mut loader), packed, &mut HashMap::new())
976}
977
/// High-level runner for a Qwen3.5 / Qwen3.6 (`qwen35` architecture) model.
///
/// Owns the compiled prefill, decode, and lazily built predict graphs, the
/// weights and optional GGUF loader, the decode KV / recurrent cache, and
/// optional vision-encoder and MoE-offload state.
pub struct Qwen35Runner {
    /// Predict (prefill-logits) graph; `None` until `ensure_predict_compiled` runs.
    compiled: Option<rlx_runtime::CompiledGraph>,
    /// Statically compiled prefill-cache graph; `None` when the dynamic prefill
    /// template (`prefill_dynamic_cache`) is used instead.
    prefill_cache: Option<rlx_runtime::CompiledGraph>,
    /// Dynamic prefill-cache template, specialized on the first prompt.
    prefill_dynamic_cache: Option<Qwen35CompileCache>,
    /// Dynamic "prefill hidden cache" template; only built when a vision
    /// encoder is loaded.
    prefill_hidden_dynamic_cache: Option<Qwen35CompileCache>,
    /// F32 params for the prefill-cache graph.
    prefill_cache_params: HashMap<String, Vec<f32>>,
    /// Packed params (raw GGUF tensor bytes, uploaded as `U8`) for the prefill-cache graph.
    prefill_cache_packed: PackedParams,
    /// F32 params for the prefill-hidden-cache template.
    prefill_hidden_cache_params: HashMap<String, Vec<f32>>,
    /// Packed params for the prefill-hidden-cache template.
    prefill_hidden_cache_packed: PackedParams,
    /// Concrete decode graphs keyed by exact `past_seq` (non-bucketed, non-dynamic path).
    decode_graphs: HashMap<usize, rlx_runtime::CompiledGraph>,
    /// Bucketed decode graphs (power-of-two `past_seq` ladder) when bucketed decode is on.
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Dynamic decode template when `dynamic_decode` is set.
    decode_dynamic_cache: Option<Qwen35CompileCache>,
    /// Predict / reprefill HIR (must not share a template with decode graphs).
    predict_hir_cache: Option<Qwen35CompileCache>,
    /// F32 params uploaded into each dynamic decode specialization.
    decode_dynamic_params: HashMap<String, Vec<f32>>,
    /// Packed params uploaded into each dynamic decode specialization.
    decode_dynamic_packed: PackedParams,
    /// Raw packed GGUF tensor bytes keyed by loader key, shared across graph
    /// uploads so each tensor is read from the loader once.
    packed_bytes_cache: HashMap<String, Arc<[u8]>>,
    cfg: Qwen35Config,
    device: Device,
    batch: usize,
    max_seq: usize,
    last_logits_only: bool,
    enable_mtp: bool,
    mtp_logits_path: bool,
    fast_mtp: bool,
    fast_greedy_lm_head: bool,
    weights: Qwen35Weights,
    weights_path: PathBuf,
    /// GGUF loader used as the source of packed tensor bytes; required
    /// whenever a graph has packed params.
    gguf_loader: Option<GgufLoader>,
    /// Decode KV / recurrent cache; `None` before prefill or after a reset.
    decode_cache: Option<Qwen35DecodeCache>,
    runtime_mrope: bool,
    mrope_section_positions: Option<Vec<[usize; 4]>>,
    /// When set, HIR compiles use AOT compilation mode and the on-disk cache.
    aot_cache: Option<AotCache>,
    dynamic_prefill: bool,
    dynamic_decode: bool,
    vision_encoder: Option<Qwen35VisionEncoder>,
    mmproj_path: Option<PathBuf>,
    /// Tier-1 compile profile used for prefill / predict compiles.
    prefill_profile: CompileProfile,
    /// Tier-1 compile profile used for decode compiles.
    decode_profile: CompileProfile,
    /// TIDE-style per-layer expert offload.
    moe_offload: Option<MoeOffloadState>,
    /// Host expert stacks (migration source; F32 MoE only).
    moe_store: Option<MoeExpertStore>,
    /// Decode-step counter for MoE refresh scheduling (TIDE τ).
    moe_refresh_step: usize,
}
1024
/// Logits returned by a prefill pass.
///
/// NOTE(review): the producer is not visible in this chunk. Presumably this
/// seeds the first decode step; confirm the layout (likely row-major
/// `[batch, vocab]`).
#[derive(Debug, Clone)]
pub struct Qwen35PrefillSeed {
    /// Logits from the main (trunk) LM head.
    pub trunk_logits: Vec<f32>,
    /// Logits from the MTP head; presumably `None` when MTP is disabled — confirm.
    pub mtp_logits: Option<Vec<f32>>,
}
1030
/// Result of a prefill pass: logits plus the vocab width needed to interpret them.
///
/// NOTE(review): presumably `logits` is row-major with `vocab_size` columns
/// per row; confirm against the producer (not visible here).
#[derive(Debug, Clone)]
pub struct Qwen35PrefillOutput {
    /// Trunk LM-head logits.
    pub logits: Vec<f32>,
    /// MTP-head logits, when produced.
    pub mtp_logits: Option<Vec<f32>>,
    /// Number of logits per row in `logits` (and `mtp_logits`, if present) — confirm.
    pub vocab_size: usize,
}
1037
1038impl Qwen35Runner {
1039    pub fn builder() -> Qwen35RunnerBuilder {
1040        Qwen35RunnerBuilder::default()
1041    }
1042
1043    /// Whether an mmproj vision encoder (or its weights) is wired up,
1044    /// allowing [`Self::generate_multimodal`] to splice image embeddings
1045    /// into the prefill. Backs the `LmRunner::supports_multimodal` hook.
1046    pub fn has_mmproj(&self) -> bool {
1047        self.mmproj_path.is_some() || self.vision_encoder.is_some()
1048    }
1049
1050    /// Apply runner AOT settings and build compile options for a dynamic specialize path.
1051    fn execution_config(&self, config: ModelExecutionConfig) -> ModelExecutionConfig {
1052        if self.aot_cache.is_some() {
1053            config.with_compilation_mode(CompilationMode::Aot)
1054        } else {
1055            config
1056        }
1057    }
1058
1059    pub fn prefill_profile(&self) -> &CompileProfile {
1060        &self.prefill_profile
1061    }
1062
1063    pub fn decode_profile(&self) -> &CompileProfile {
1064        &self.decode_profile
1065    }
1066
1067    /// MoE offload state when enabled at build time.
1068    pub fn moe_offload(&self) -> Option<&MoeOffloadState> {
1069        self.moe_offload.as_ref()
1070    }
1071
1072    /// TIDE `enable_predictive_expert_offload` return payload (when offload is active).
1073    pub fn predictive_offload_info(&self) -> Option<&rlx_llada2::tide::PredictiveOffloadInfo> {
1074        self.moe_offload.as_ref().map(|m| &m.info)
1075    }
1076
1077    /// TIDE `get_offload_stats()` — pool promotions/demotions + last-forward residency (CPU).
1078    pub fn get_offload_stats(
1079        &self,
1080        residency: Option<&rlx_runtime::MoeResidencyStats>,
1081    ) -> rlx_llada2::tide::TideOffloadStats {
1082        self.moe_offload
1083            .as_ref()
1084            .map(|m| m.tide_offload_stats(residency))
1085            .unwrap_or_default()
1086    }
1087
1088    pub fn jump_steps(&self) -> usize {
1089        self.moe_offload.as_ref().map(|m| m.jump_steps).unwrap_or(1)
1090    }
1091
1092    pub fn predictive_offload_enabled(&self) -> bool {
1093        self.moe_offload
1094            .as_ref()
1095            .is_some_and(|m| m.predictive_enabled)
1096    }
1097
1098    pub fn moe_offload_mut(&mut self) -> Option<&mut MoeOffloadState> {
1099        self.moe_offload.as_mut()
1100    }
1101
1102    /// MoE refresh step index (TIDE `step` in block denoise / decode loop).
1103    pub fn moe_refresh_step(&self) -> usize {
1104        self.moe_refresh_step
1105    }
1106
1107    /// Enable TopK capture on a compiled graph (CPU; call once after compile).
1108    pub fn enable_moe_topk_on(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1109        if self.moe_offload.is_some() {
1110            compiled.enable_moe_topk_capture(self.cfg.num_experts);
1111        }
1112    }
1113
1114    /// Push per-layer TIDE residency masks into the compiled graph.
1115    pub fn sync_moe_residency(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1116        if let Some(mo) = &self.moe_offload {
1117            push_moe_residency(compiled, &mo.per_layer_resident_masks());
1118        }
1119    }
1120
1121    #[allow(dead_code)]
1122    fn moe_prepare_forward(&self, compiled: &mut rlx_runtime::CompiledGraph) {
1123        self.bind_moe_host_weights();
1124        if self.moe_offload.is_some() {
1125            compiled.enable_moe_topk_capture(self.cfg.num_experts);
1126            self.sync_moe_residency(compiled);
1127        }
1128    }
1129
1130    #[allow(dead_code)]
1131    fn moe_finish_forward(
1132        &mut self,
1133        compiled: &mut rlx_runtime::CompiledGraph,
1134        denoise_step: usize,
1135        is_prefill_block: bool,
1136    ) -> bool {
1137        let Some(layers) = compiled.take_moe_topk_capture() else {
1138            return false;
1139        };
1140        let store = self.moe_store.clone();
1141        let Some(mo) = self.moe_offload.as_mut() else {
1142            return false;
1143        };
1144        let refreshed = if let Some(store) = store.as_ref() {
1145            mo.refresh_from_capture_with_store(store, &layers, denoise_step, is_prefill_block)
1146        } else {
1147            mo.refresh_from_capture(&layers, denoise_step, is_prefill_block)
1148        };
1149        if refreshed {
1150            push_moe_residency(compiled, &mo.per_layer_resident_masks());
1151        }
1152        refreshed
1153    }
1154
1155    /// Install per-expert host pointers for CPU GroupedMatMul fallback (TIDE).
1156    fn bind_moe_host_weights(&self) {
1157        if self.moe_offload.is_none() {
1158            rlx_cpu::moe_residency::bind_host_weights(None);
1159            return;
1160        }
1161        if let Some(store) = &self.moe_store {
1162            rlx_cpu::moe_residency::bind_host_weights(Some(moe_host_bind_from_store(store)));
1163        } else {
1164            rlx_cpu::moe_residency::bind_host_weights(None);
1165        }
1166    }
1167
1168    /// After forward: refresh pools from captured TopK and update graph mask.
1169    pub fn moe_offload_after_forward(&mut self, compiled: &mut rlx_runtime::CompiledGraph) -> bool {
1170        let Some(mo) = self.moe_offload.as_mut() else {
1171            return false;
1172        };
1173        let Some(layers) = compiled.take_moe_topk_capture() else {
1174            return false;
1175        };
1176        let refreshed = mo.refresh_from_capture(&layers, self.moe_refresh_step, false);
1177        if refreshed {
1178            self.sync_moe_residency(compiled);
1179        }
1180        self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1181        refreshed
1182    }
1183
1184    /// Manual refresh from flat expert indices (single shared indices for all layers).
1185    pub fn moe_refresh_after_forward(&mut self, expert_idx: &[u32]) -> bool {
1186        let Some(mo) = self.moe_offload.as_mut() else {
1187            return false;
1188        };
1189        let refresh = mo.pools[0].should_refresh(
1190            rlx_runtime::MoEExecMode::Reuse,
1191            self.moe_refresh_step,
1192            false,
1193        );
1194        if refresh {
1195            for pool in &mut mo.pools {
1196                pool.refresh_from_indices(expert_idx);
1197            }
1198        }
1199        self.moe_refresh_step = self.moe_refresh_step.saturating_add(1);
1200        refresh
1201    }
1202
1203    /// Override tier-1 profiles after build (e.g. tests).
1204    pub fn with_compile_profiles(
1205        mut self,
1206        prefill: CompileProfile,
1207        decode: CompileProfile,
1208    ) -> Self {
1209        self.prefill_profile = prefill;
1210        self.decode_profile = decode;
1211        self
1212    }
1213
1214    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
1215        let profile = if decode {
1216            &self.decode_profile
1217        } else {
1218            &self.prefill_profile
1219        };
1220        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
1221    }
1222
1223    fn dyn_compile_options(&self, config: &ModelExecutionConfig) -> CompileOptions {
1224        let decode = matches!(config.preset, ExecutionPreset::Qwen35Decode);
1225        let mut opts = self.profile_compile_options(decode);
1226        opts.kernel_dispatch = config.component().kernel_dispatch;
1227        opts.dim_binding(config.dim_binding())
1228    }
1229
1230    fn bucketed_decode_compile_options(&self) -> CompileOptions {
1231        self.profile_compile_options(true)
1232    }
1233
1234    /// Compile a tier-0 prefill [`rlx_flow::BuiltModel`] through [`Qwen35CompileCache`].
1235    pub fn compile_prefill_built(
1236        &self,
1237        cache: &mut Qwen35CompileCache,
1238        built: rlx_flow::BuiltModel,
1239        batch: usize,
1240        seq: usize,
1241    ) -> Result<rlx_runtime::CompiledGraph> {
1242        let config = self.execution_config(prefill_config(batch, seq));
1243        let opts = self.dyn_compile_options(&config);
1244        cache.compile_built(built, &config, &opts)
1245    }
1246
1247    pub fn cfg(&self) -> &Qwen35Config {
1248        &self.cfg
1249    }
1250    pub fn device(&self) -> Device {
1251        self.device
1252    }
1253    pub fn max_seq(&self) -> usize {
1254        self.max_seq
1255    }
1256    pub fn lm_vocab_size(&self) -> usize {
1257        self.weights.lm_vocab_size(&self.cfg)
1258    }
1259
1260    /// True when an mmproj vision encoder was loaded at build time.
1261    pub fn has_vision(&self) -> bool {
1262        self.vision_encoder.is_some()
1263    }
1264
1265    /// Optional path to the mmproj GGUF (if configured).
1266    pub fn mmproj_path(&self) -> Option<&std::path::Path> {
1267        self.mmproj_path.as_deref()
1268    }
1269
1270    fn effective_vocab(&self, graph_vocab: usize) -> usize {
1271        self.lm_vocab_size().min(graph_vocab)
1272    }
1273
1274    fn compile_hir_for_config(
1275        &mut self,
1276        config: ModelExecutionConfig,
1277        aot_disk_key: &str,
1278        hir: rlx_ir::hir::HirModule,
1279    ) -> Result<rlx_runtime::CompiledGraph> {
1280        let config = self.execution_config(config);
1281        let opts = self.dyn_compile_options(&config);
1282        if let Some(aot) = self.aot_cache.as_ref() {
1283            return Ok(aot.compile_hir_cached(aot_disk_key, self.device, hir, &opts)?);
1284        }
1285        // Per-`past_seq` decode HIR is concrete (not symbolic). Sharing one
1286        // `ModelCompilePipeline` template across variants reuses the wrong graph.
1287        if config.preset == ExecutionPreset::Qwen35Decode {
1288            return Ok(Session::new(self.device).compile_hir_with(hir, &opts)?);
1289        }
1290        let cache = self
1291            .predict_hir_cache
1292            .get_or_insert_with(|| make_qwen35_dyn_cache(self.device, 64, None));
1293        let hir = hir;
1294        get_or_specialize_hir_with_options(cache, &config, || hir.clone(), &opts, |_| Ok(()))?;
1295        if self.device == Device::Cpu {
1296            let compiled = get_or_specialize_hir_with_options(
1297                cache,
1298                &config,
1299                || hir.clone(),
1300                &opts,
1301                |_| Ok(()),
1302            )?;
1303            return Ok(compiled.clone());
1304        }
1305        Ok(Session::new(self.device).compile_hir_with(hir, &opts)?)
1306    }
1307
1308    fn lm_loader(&self) -> Option<&GgufLoader> {
1309        self.gguf_loader.as_ref()
1310    }
1311
1312    fn argmax_batch_from_hidden(&self, hidden: &[f32]) -> Result<Vec<u32>> {
1313        let n_embd = self.cfg.hidden_size;
1314        let mut toks = Vec::with_capacity(self.batch);
1315        for b in 0..self.batch {
1316            let h = &hidden[b * n_embd..(b + 1) * n_embd];
1317            let (idx, _) = greedy_lm_head_argmax(&self.weights, &self.cfg, h, self.lm_loader())?;
1318            toks.push(idx);
1319        }
1320        Ok(toks)
1321    }
1322
1323    fn sample_batch_from_hidden(&self, hidden: &[f32], opts: SampleOpts) -> Result<Vec<u32>> {
1324        let n_embd = self.cfg.hidden_size;
1325        let mut toks = Vec::with_capacity(self.batch);
1326        for b in 0..self.batch {
1327            let h = &hidden[b * n_embd..(b + 1) * n_embd];
1328            toks.push(sample_lm_head_from_hidden(
1329                &self.weights,
1330                &self.cfg,
1331                h,
1332                self.lm_loader(),
1333                opts,
1334            )?);
1335        }
1336        Ok(toks)
1337    }
1338
    /// Run one trunk decode step for `tokens` and advance `cache`.
    ///
    /// Dispatch:
    /// 1. `dynamic_decode` set → [`Self::decode_step_dynamic_raw`].
    /// 2. A bucket in `decode_compile_cache` covers `cache.past_seq` →
    ///    `decode_step_bucketed_raw`.
    /// 3. Otherwise a concrete graph for this exact `past_seq` is built,
    ///    compiled, and given its params on first use, then memoized in
    ///    `decode_graphs` and run. With MoE offload active, the run's TopK
    ///    capture drives a pool refresh.
    ///
    /// NOTE(review): on path 3 a new graph is compiled and kept for every
    /// distinct `past_seq`. If `past_seq` advances every step (presumably via
    /// `advance_cache_from_decode_outputs`), `decode_graphs` grows by one
    /// entry per step — confirm this path is only used for short runs.
    ///
    /// Returns the value of `advance_cache_from_decode_outputs` (or the
    /// dynamic/bucketed helper). Presumably this is trunk output plus optional
    /// MTP logits — confirm.
    fn decode_step_trunk_raw(
        &mut self,
        cache: &mut Qwen35DecodeCache,
        tokens: &[u32],
        generated_per_row: &[usize],
    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
        if self.dynamic_decode {
            return self.decode_step_dynamic_raw(cache, tokens, generated_per_row);
        }
        let past_seq = cache.past_seq;
        let head_half = self.cfg.key_length / 2;
        // RoPE cos/sin tables sliced at position `past_seq`.
        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
        let use_bucket = self
            .decode_compile_cache
            .as_ref()
            .and_then(|c| c.bucket_for(past_seq as u64))
            .is_some();
        if use_bucket {
            self.decode_step_bucketed_raw(cache, tokens, generated_per_row, &cos, &sin)
        } else {
            let feeds_owned = decode_step_feeds(
                &self.cfg,
                cache,
                tokens,
                &cos,
                &sin,
                None,
                generated_per_row,
            )?;
            let feeds: Vec<(&str, &[f32])> = feeds_owned
                .iter()
                .map(|(k, v)| (k.as_str(), v.as_slice()))
                .collect();
            // First visit to this `past_seq`: build + compile the concrete
            // graph, upload F32 and packed params once, then memoize it.
            if !self.decode_graphs.contains_key(&past_seq) {
                let (hir, params, packed) = build_qwen35_decode_hir_ext(
                    &self.cfg,
                    self.weights.clone(),
                    self.batch,
                    past_seq,
                    false,
                    self.mtp_logits_path,
                    self.fast_mtp,
                    self.fast_greedy_lm_head,
                )?;
                let mut compiled = self.compile_hir_for_config(
                    decode_config(self.batch, past_seq),
                    &format!("decode_{past_seq}"),
                    hir,
                )?;
                for (name, data) in &params {
                    compiled.set_param(name, data);
                }
                upload_packed_opt(
                    &mut compiled,
                    self.gguf_loader.as_mut(),
                    &packed,
                    &mut self.packed_bytes_cache,
                )?;
                self.decode_graphs.insert(past_seq, compiled);
            }
            // Gather MoE state and bind host weights *before* taking `&mut`
            // into `decode_graphs`: `bind_moe_host_weights` borrows all of
            // `self`, which would conflict with that mutable borrow.
            let step = self.moe_refresh_step;
            let has_moe = self.moe_offload.is_some();
            let num_experts = self.cfg.num_experts;
            let moe_masks = self
                .moe_offload
                .as_ref()
                .map(|m| m.per_layer_resident_masks());
            self.bind_moe_host_weights();
            let outs = {
                let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
                if has_moe {
                    compiled.enable_moe_topk_capture(num_experts);
                    if let Some(layers) = &moe_masks {
                        push_moe_residency(compiled, layers);
                    }
                }
                compiled.run(&feeds)
            };
            if has_moe {
                let layers = {
                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
                    compiled.take_moe_topk_capture()
                };
                // Disjoint field borrows: `moe_offload` mutably, `moe_store`
                // and `decode_graphs` separately.
                if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
                    let store = self.moe_store.as_ref();
                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
                    // `refresh_moe_from_capture` is defined elsewhere; on `true`
                    // (presumably "residency changed") the host store's
                    // weights are applied to the graph.
                    if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
                        if let Some(store) = self.moe_store.as_ref() {
                            store.apply_to_compiled(compiled);
                        }
                    }
                }
            }
            // Advances on every non-bucketed static step, with or without MoE.
            self.moe_refresh_step = step.saturating_add(1);
            advance_cache_from_decode_outputs(
                &self.cfg,
                cache,
                outs,
                None,
                self.mtp_logits_path,
                false,
                self.fast_greedy_lm_head,
            )
        }
    }
1444
    /// Run one decode step through the dynamic decode template in
    /// `decode_dynamic_cache`, specializing it for this `past_seq` config when
    /// needed, and advance `cache`.
    ///
    /// Unlike the static path in `decode_step_trunk_raw`, this path does no
    /// MoE offload handling and does not touch `moe_refresh_step`.
    ///
    /// # Errors
    /// Fails if the dynamic cache is absent, feed construction fails, or
    /// specialization / param upload fails.
    ///
    /// # Panics
    /// If rebuilding the dynamic decode HIR fails (`.expect`). The builder
    /// already ran the same build with the same arguments when
    /// `dynamic_decode` was enabled.
    fn decode_step_dynamic_raw(
        &mut self,
        cache: &mut Qwen35DecodeCache,
        tokens: &[u32],
        generated_per_row: &[usize],
    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
        let past_seq = cache.past_seq;
        let head_half = self.cfg.key_length / 2;
        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
        let feeds_owned = decode_step_feeds(
            &self.cfg,
            cache,
            tokens,
            &cos,
            &sin,
            None,
            generated_per_row,
        )?;
        let feeds: Vec<(&str, &[f32])> = feeds_owned
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_slice()))
            .collect();

        let config = self.execution_config(decode_config(self.batch, past_seq));
        let compile_opts = self.dyn_compile_options(&config);
        let dyn_cache = self
            .decode_dynamic_cache
            .as_mut()
            .ok_or_else(|| anyhow!("dynamic decode without cache"))?;
        // Copy / split-borrow the fields the closures need. This lets them
        // capture locals rather than `self`, while `decode_dynamic_cache` stays
        // mutably borrowed through `dyn_cache`.
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let max_seq = self.max_seq;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        let batch = self.batch;
        let decode_params = &self.decode_dynamic_params;
        let decode_packed = &self.decode_dynamic_packed;
        let gguf_loader = &mut self.gguf_loader;
        let packed_bytes_cache = &mut self.packed_bytes_cache;
        // Second closure uploads F32 + packed params into the compiled graph;
        // presumably it only runs on a fresh specialization — confirm against
        // `get_or_specialize_hir_with_options`.
        let compiled = get_or_specialize_hir_with_options(
            dyn_cache,
            &config,
            || {
                build_qwen35_decode_hir_dynamic_ext(
                    &cfg,
                    weights,
                    batch,
                    max_seq,
                    mtp_logits_path,
                    fast_mtp,
                    fast_greedy,
                )
                .expect("dynamic decode HIR")
                .0
            },
            &compile_opts,
            |c| {
                for (name, data) in decode_params {
                    c.set_param(name, data);
                }
                upload_packed_opt(c, gguf_loader.as_mut(), decode_packed, packed_bytes_cache)
            },
        )?;
        let outs = compiled.run(&feeds);
        advance_cache_from_decode_outputs(
            &self.cfg,
            cache,
            outs,
            None,
            self.mtp_logits_path,
            false,
            self.fast_greedy_lm_head,
        )
    }
1520
1521    fn trunk_to_logits(&self, trunk: Vec<f32>, is_hidden: bool) -> Result<Vec<f32>> {
1522        if !is_hidden {
1523            return Ok(trunk);
1524        }
1525        let n_embd = self.cfg.hidden_size;
1526        let vocab = self.lm_vocab_size();
1527        let mut logits = Vec::with_capacity(self.batch * vocab);
1528        for b in 0..self.batch {
1529            let h = &trunk[b * n_embd..(b + 1) * n_embd];
1530            logits.extend(lm_head_logits_row(
1531                &self.weights,
1532                &self.cfg,
1533                h,
1534                self.lm_loader(),
1535            )?);
1536        }
1537        Ok(logits)
1538    }
1539
    /// Compile decode HIR for `key`'s bucket (if needed) and upload packed GGUF params once.
    ///
    /// The HIR is built for the bucket's upper `past_seq` bound. The returned
    /// `usize` is that upper bound.
    ///
    /// `ensure_hir_with_params`' builder closure can only return
    /// `(hir, params)`, so the packed-param table is passed out through the
    /// `packed_slot` `RefCell`. The slot stays `None` when the closure is not
    /// called (presumably: the bucket was already compiled), so packed bytes
    /// are uploaded only on the compile that created the graph. The F32
    /// `params` are handed to `ensure_hir_with_params` (presumably it uploads
    /// them — confirm).
    ///
    /// # Errors
    /// Fails if bucketed decode is disabled, `key` lies outside every bucket,
    /// or the packed upload fails.
    ///
    /// # Panics
    /// If building the decode HIR fails (`.expect` inside the closure).
    fn ensure_decode_bucket_compiled(&mut self, key: u64) -> Result<usize> {
        let decode_opts = self.bucketed_decode_compile_options();
        let cache_mut = self
            .decode_compile_cache
            .as_mut()
            .ok_or_else(|| anyhow!("bucketed decode without cache"))?;
        // Copy the inputs the builder closure needs so it does not borrow `self`.
        let cfg = self.cfg.clone();
        let weights = self.weights.clone();
        let batch = self.batch;
        let mtp_logits_path = self.mtp_logits_path;
        let fast_mtp = self.fast_mtp;
        let fast_greedy = self.fast_greedy_lm_head;
        let packed_slot = RefCell::new(None::<PackedParams>);
        let (upper, compiled) = cache_mut
            .ensure_hir_with_params(
                key,
                |upper| {
                    let (hir, params, packed) = build_qwen35_decode_hir_ext(
                        &cfg,
                        weights.clone(),
                        batch,
                        upper as usize,
                        true,
                        mtp_logits_path,
                        fast_mtp,
                        fast_greedy,
                    )
                    .expect("qwen35 decode HIR");
                    *packed_slot.borrow_mut() = Some(packed);
                    (hir, params)
                },
                &decode_opts,
            )
            .ok_or_else(|| anyhow!("past_seq {key} outside decode buckets"))?;
        if let Some(packed) = packed_slot.take() {
            if !packed.is_empty() {
                upload_packed_opt(
                    compiled,
                    self.gguf_loader.as_mut(),
                    &packed,
                    &mut self.packed_bytes_cache,
                )?;
            }
        }
        Ok(upper as usize)
    }
1587
1588    /// Pre-compile every decode bucket and upload packed weights once.
1589    fn warm_decode_graphs(&mut self) -> Result<()> {
1590        let upper_bounds: Vec<usize> = match self.decode_compile_cache.as_ref() {
1591            Some(cache) => cache.buckets().map(|r| (r.end - 1) as usize).collect(),
1592            None => return Ok(()),
1593        };
1594        let t = Instant::now();
1595        let total = upper_bounds.len();
1596        for upper in upper_bounds {
1597            self.ensure_decode_bucket_compiled(upper as u64)?;
1598        }
1599        if total > 0 {
1600            eprintln!(
1601                "[qwen35] warmed {total} decode bucket(s) in {:.2?}",
1602                t.elapsed()
1603            );
1604        }
1605        Ok(())
1606    }
1607
1608    /// Pre-compile the predict (prefill logits) graph at build time.
1609    fn warm_predict_graph(&mut self) -> Result<()> {
1610        if self.compiled.is_some() {
1611            return Ok(());
1612        }
1613        let t = Instant::now();
1614        self.ensure_predict_compiled()?;
1615        eprintln!("[qwen35] warmed predict graph in {:.2?}", t.elapsed());
1616        Ok(())
1617    }
1618
1619    /// Clear the decode KV / recurrent cache (e.g. before a fresh prefill in spec decode).
1620    pub fn reset_decode_cache(&mut self) {
1621        self.decode_cache = None;
1622    }
1623
1624    /// Snapshot the current decode cache (for two-phase MTP draft propose).
1625    pub fn decode_cache_checkpoint(&self) -> Option<Qwen35DecodeCache> {
1626        self.decode_cache.clone()
1627    }
1628
1629    /// Restore a decode cache snapshot (discards uncommitted draft steps).
1630    pub fn restore_decode_cache(&mut self, cache: Option<Qwen35DecodeCache>) {
1631        self.decode_cache = cache;
1632    }
1633
1634    /// Advance the decode cache by `tokens` without returning logits (MTP commit path).
1635    pub fn commit_decode_tokens(&mut self, tokens: &[u32]) -> Result<()> {
1636        for &tok in tokens {
1637            let _ = self.decode_get_logits(tok)?;
1638        }
1639        Ok(())
1640    }
1641
1642    fn ensure_predict_compiled(&mut self) -> Result<()> {
1643        if self.compiled.is_some() {
1644            return Ok(());
1645        }
1646        // RLX_QWEN35_DEBUG_LAYERS=1 makes the predict graph emit every
1647        // trunk layer's hidden state as an extra output (gathered at
1648        // last_token_idx → shape `[batch, n_embd]` per layer). Combined
1649        // with the per-output stats dump in `predict_logits_batch` this
1650        // is the bisection harness for "all-zero logits" symptoms: the
1651        // log shows which layer first emits zero/NaN. Requires
1652        // `last_logits_only=true` (the assertion is in the builder).
1653        let debug_layers = std::env::var("RLX_QWEN35_DEBUG_LAYERS")
1654            .map(|v| v == "1")
1655            .unwrap_or(false);
1656        let t = Instant::now();
1657        let (hir, params, packed) = build_qwen35_hir_sized_ext(
1658            &self.cfg,
1659            self.weights.clone(),
1660            self.batch,
1661            self.max_seq,
1662            true,
1663            self.last_logits_only,
1664            self.enable_mtp,
1665            false,
1666            None,
1667            self.runtime_mrope,
1668            self.fast_mtp,
1669            false,
1670            debug_layers,
1671        )?;
1672        eprintln!(
1673            "[qwen35] built predict IR (lazy) in {:.2?} (params={}, packed={})",
1674            t.elapsed(),
1675            params.len(),
1676            packed.len(),
1677        );
1678        let t = Instant::now();
1679        let mut compiled = self.compile_hir_for_config(
1680            prefill_config(self.batch, self.max_seq),
1681            "predict_logits",
1682            hir,
1683        )?;
1684        eprintln!(
1685            "[qwen35] compiled predict graph (lazy) in {:.2?}",
1686            t.elapsed()
1687        );
1688        let t = Instant::now();
1689        for (name, data) in &params {
1690            compiled.set_param(name, data);
1691        }
1692        if !packed.is_empty() {
1693            upload_packed_opt(
1694                &mut compiled,
1695                self.gguf_loader.as_mut(),
1696                &packed,
1697                &mut self.packed_bytes_cache,
1698            )?;
1699        }
1700        eprintln!(
1701            "[qwen35] uploaded predict {} F32 + {} packed params in {:.2?}",
1702            params.len(),
1703            packed.len(),
1704            t.elapsed(),
1705        );
1706        self.compiled = Some(compiled);
1707        Ok(())
1708    }
1709
1710    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillOutput> {
1711        let out = self
1712            .predict_logits_batch(&[prompt_ids.to_vec()])
1713            .map(|v| v.into_iter().next().unwrap())?;
1714        // Preflight guard: a degenerate all-zero (or all-equal) logits
1715        // tensor is the signature of a broken forward pass — a buggy
1716        // op writing zeros, a packed-K-quant dispatch mis-routed, an
1717        // arena slot being read from the wrong offset, etc. Fail fast
1718        // with a clear error rather than silently returning argmax =
1719        // vocab_size − 1.
1720        if !out.logits.is_empty() {
1721            let mut min = f32::INFINITY;
1722            let mut max = f32::NEG_INFINITY;
1723            for &v in &out.logits {
1724                if v.is_finite() {
1725                    if v < min {
1726                        min = v;
1727                    }
1728                    if v > max {
1729                        max = v;
1730                    }
1731                }
1732            }
1733            if (max - min).abs() < 1e-6 {
1734                bail!(
1735                    "qwen35: predict_logits returned degenerate output \
1736                     (min={min}, max={max}) — the forward pass produced \
1737                     all-equal logits, which indicates a broken op or a \
1738                     mis-routed weight tensor in the trunk. Re-run with \
1739                     RUST_LOG=debug to capture the offending layer."
1740                );
1741            }
1742        }
1743        Ok(out)
1744    }
1745
1746    /// Prefill forward for `batch` prompts (must equal `self.batch()`).
1747    /// Each row may have a different length; all are zero-padded to
1748    /// `max_seq` in the compiled graph.
1749    pub fn predict_logits_batch(
1750        &mut self,
1751        batch_prompts: &[Vec<u32>],
1752    ) -> Result<Vec<Qwen35PrefillOutput>> {
1753        if batch_prompts.len() != self.batch {
1754            bail!(
1755                "qwen35: expected {} prompts (batch={}), got {}",
1756                self.batch,
1757                self.batch,
1758                batch_prompts.len()
1759            );
1760        }
1761        let max_prompt = batch_prompts.iter().map(|p| p.len()).max().unwrap_or(0);
1762        if max_prompt > self.max_seq {
1763            bail!(
1764                "qwen35: prompt length {max_prompt} exceeds compiled max_seq={}",
1765                self.max_seq
1766            );
1767        }
1768        let padded = pack_input_ids(batch_prompts, self.max_seq)?;
1769        let prompt_lens: Vec<usize> = batch_prompts.iter().map(|p| p.len()).collect();
1770        let last_idx = last_token_indices(&prompt_lens);
1771
1772        let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", padded.as_slice())];
1773        if self.last_logits_only {
1774            feeds.push(("last_token_idx", last_idx.as_slice()));
1775        }
1776        let rope_owned = self.mrope_prefill_rope_feeds(max_prompt);
1777        for (name, data) in &rope_owned {
1778            feeds.push((name.as_str(), data.as_slice()));
1779        }
1780        self.ensure_predict_compiled()?;
1781        let outs = self.compiled.as_mut().unwrap().run(&feeds);
1782        if outs.is_empty() {
1783            bail!("qwen35: forward produced no outputs");
1784        }
1785        // RLX_QWEN35_DEBUG_LAYERS=1 enabled extra per-layer outputs in
1786        // the predict graph (see `ensure_predict_compiled`). Dump stats
1787        // here so the next debugger can locate which layer first emits
1788        // zero/NaN values without re-running the entire 18-min cycle.
1789        if std::env::var("RLX_QWEN35_DEBUG_LAYERS").as_deref() == Ok("1") {
1790            // Output layout (when debug_layers is on):
1791            //   outs[0] = logits           [batch, vocab]
1792            //   outs[1..] = trunk_layer_hiddens (one per decoder layer)
1793            //               each [batch, n_embd]
1794            let n_layers = self.cfg.num_hidden_layers - self.cfg.nextn_predict_layers;
1795            for i in 0..outs.len() {
1796                let v = &outs[i];
1797                let mut min = f32::INFINITY;
1798                let mut max = f32::NEG_INFINITY;
1799                let mut sum = 0.0f64;
1800                let mut nan = 0usize;
1801                let mut nnz = 0usize;
1802                for &x in v {
1803                    if x.is_nan() {
1804                        nan += 1;
1805                        continue;
1806                    }
1807                    sum += x as f64;
1808                    if x < min {
1809                        min = x;
1810                    }
1811                    if x > max {
1812                        max = x;
1813                    }
1814                    if x != 0.0 {
1815                        nnz += 1;
1816                    }
1817                }
1818                let mean = sum / v.len().max(1) as f64;
1819                let label = if i == 0 {
1820                    "logits".to_string()
1821                } else if i - 1 < n_layers {
1822                    format!("layer_{:02}", i - 1)
1823                } else {
1824                    format!("extra_{:02}", i - 1 - n_layers)
1825                };
1826                eprintln!(
1827                    "[qwen35][debug-layers] {label}: len={} nnz={} nan={} min={} max={} mean={:.6}",
1828                    v.len(),
1829                    nnz,
1830                    nan,
1831                    min,
1832                    max,
1833                    mean
1834                );
1835            }
1836        }
1837        let vocab_size = if self.last_logits_only {
1838            outs[0].len() / self.batch
1839        } else {
1840            outs[0].len() / (self.batch * self.max_seq)
1841        };
1842        let sample_vocab = self.effective_vocab(vocab_size);
1843        let mtp_logits = if self.enable_mtp && outs.len() >= 2 {
1844            Some(outs[1].clone())
1845        } else {
1846            None
1847        };
1848        let mut per_batch = Vec::with_capacity(self.batch);
1849        for b in 0..self.batch {
1850            let start = b * vocab_size;
1851            let mut row = outs[0][start..start + vocab_size].to_vec();
1852            row.truncate(sample_vocab);
1853            per_batch.push(Qwen35PrefillOutput {
1854                logits: row,
1855                mtp_logits: mtp_logits.as_ref().map(|m| {
1856                    let m_vocab = m.len() / self.batch.max(1);
1857                    let mut mv = m[b * m_vocab..(b + 1) * m_vocab].to_vec();
1858                    mv.truncate(sample_vocab);
1859                    mv
1860                }),
1861                vocab_size: sample_vocab,
1862            });
1863        }
1864        Ok(per_batch)
1865    }
1866
1867    /// Greedy autoregressive generation with decode-state caching (batch=1).
1868    pub fn generate<F>(&mut self, prompt_ids: &[u32], n_new: usize, on_token: F) -> Result<Vec<u32>>
1869    where
1870        F: FnMut(u32) -> bool,
1871    {
1872        self.generate_with_opts(prompt_ids, n_new, SampleOpts::greedy(), on_token)
1873    }
1874
1875    /// Autoregressive generation with sampling options (batch=1).
1876    pub fn generate_with_opts<F>(
1877        &mut self,
1878        prompt_ids: &[u32],
1879        n_new: usize,
1880        opts: SampleOpts,
1881        mut on_token: F,
1882    ) -> Result<Vec<u32>>
1883    where
1884        F: FnMut(u32) -> bool,
1885    {
1886        if self.batch != 1 {
1887            bail!(
1888                "qwen35::generate: runner batch={} — use generate_batch() instead",
1889                self.batch
1890            );
1891        }
1892        let generated = self
1893            .generate_batch_with_opts(&[prompt_ids.to_vec()], n_new, None, opts, |_, tok| {
1894                on_token(tok)
1895            })?
1896            .into_iter()
1897            .next()
1898            .unwrap_or_default();
1899        Ok(generated)
1900    }
1901
1902    /// Batched greedy generation. `prompts.len()` must equal `self.batch()`.
1903    pub fn generate_batch<F>(
1904        &mut self,
1905        prompts: &[Vec<u32>],
1906        n_new: usize,
1907        on_token: F,
1908    ) -> Result<Vec<Vec<u32>>>
1909    where
1910        F: FnMut(usize, u32) -> bool,
1911    {
1912        self.generate_batch_with_opts(prompts, n_new, None, SampleOpts::greedy(), on_token)
1913    }
1914
1915    /// Batched generation with per-row token limits and sampling.
1916    ///
1917    /// `n_new_per_row`: optional per-row max new tokens (defaults to `n_new`).
1918    pub fn generate_batch_with_opts<F>(
1919        &mut self,
1920        prompts: &[Vec<u32>],
1921        n_new: usize,
1922        n_new_per_row: Option<&[usize]>,
1923        opts: SampleOpts,
1924        mut on_token: F,
1925    ) -> Result<Vec<Vec<u32>>>
1926    where
1927        F: FnMut(usize, u32) -> bool,
1928    {
1929        if prompts.is_empty() {
1930            bail!("qwen35::generate_batch: prompts must be non-empty");
1931        }
1932        if prompts.len() != self.batch {
1933            bail!(
1934                "qwen35::generate_batch: expected {} prompts, got {}",
1935                self.batch,
1936                prompts.len()
1937            );
1938        }
1939        if let Some(limits) = n_new_per_row {
1940            if limits.len() != self.batch {
1941                bail!(
1942                    "qwen35::generate_batch: n_new_per_row len {} != batch {}",
1943                    limits.len(),
1944                    self.batch
1945                );
1946            }
1947        }
1948        for (i, p) in prompts.iter().enumerate() {
1949            if p.is_empty() {
1950                bail!("qwen35::generate_batch: prompt row {i} is empty");
1951            }
1952        }
1953
1954        self.decode_cache = None;
1955
1956        let _prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
1957        let row_limits: Vec<usize> = if let Some(limits) = n_new_per_row {
1958            limits.to_vec()
1959        } else {
1960            vec![n_new; self.batch]
1961        };
1962
1963        let (trunk, mut cache, _) = self.prefill_seed_decode_cache(prompts)?;
1964
1965        let mut generated: Vec<Vec<u32>> = vec![Vec::new(); self.batch];
1966        let mut active = vec![true; self.batch];
1967        let mut row_gen_count = vec![0usize; self.batch];
1968
1969        let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
1970            self.argmax_batch_from_hidden(&trunk)?
1971        } else if self.fast_greedy_lm_head
1972            && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
1973        {
1974            self.sample_batch_from_hidden(&trunk, opts)?
1975        } else {
1976            let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
1977            sample_logits_batch(&logits, self.lm_vocab_size(), self.batch, opts)
1978        };
1979        if n_new > 0 {
1980            for b in 0..self.batch {
1981                if row_gen_count[b] >= row_limits[b] {
1982                    active[b] = false;
1983                    continue;
1984                }
1985                let tok = next_tokens[b];
1986                generated[b].push(tok);
1987                row_gen_count[b] += 1;
1988                active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
1989            }
1990        }
1991
1992        for _ in 1..n_new {
1993            if !active.iter().any(|&a| a) {
1994                break;
1995            }
1996            if cache.past_seq >= self.max_seq - 1 {
1997                bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
1998            }
1999            next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen_count, opts)?;
2000            for b in 0..self.batch {
2001                if !active[b] || row_gen_count[b] >= row_limits[b] {
2002                    active[b] = false;
2003                    continue;
2004                }
2005                let tok = next_tokens[b];
2006                generated[b].push(tok);
2007                row_gen_count[b] += 1;
2008                active[b] = on_token(b, tok) && row_gen_count[b] < row_limits[b];
2009            }
2010            self.decode_cache = Some(cache.clone());
2011        }
2012        Ok(generated)
2013    }
2014
2015    /// Prefill and return trunk logits for the last prompt position.
2016    /// Seeds the decode cache so subsequent [`Self::decode_get_logits`] calls work.
2017    pub fn prefill_get_last_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2018        Ok(self.prefill_seed_for_decode(prompt_ids)?.trunk_logits)
2019    }
2020
2021    /// Prefill-cache seed + decode cache for spec decode / MTP draft paths.
2022    pub fn prefill_seed_for_decode(&mut self, prompt_ids: &[u32]) -> Result<Qwen35PrefillSeed> {
2023        if self.batch != 1 {
2024            bail!(
2025                "qwen35: prefill_seed_for_decode requires batch=1 (runner batch={})",
2026                self.batch
2027            );
2028        }
2029        let (trunk, _, mtp_logits) = self.prefill_seed_decode_cache(&[prompt_ids.to_vec()])?;
2030        Ok(Qwen35PrefillSeed {
2031            trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2032            mtp_logits,
2033        })
2034    }
2035
2036    /// Multimodal prefill: vision-encode `rgb`, splice into `prompt` at
2037    /// [`MEDIA_MARKER`](crate::MEDIA_MARKER), seed decode cache.
2038    pub fn prefill_multimodal(
2039        &mut self,
2040        prompt: &str,
2041        rgb: &[u8],
2042        img_w: usize,
2043        img_h: usize,
2044        tokenizer: Option<&std::path::Path>,
2045    ) -> Result<Qwen35PrefillSeed> {
2046        let (trunk, mtp_logits) =
2047            self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2048        Ok(Qwen35PrefillSeed {
2049            trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2050            mtp_logits,
2051        })
2052    }
2053
2054    /// Prefill from an already-assembled multimodal payload (tests / custom tokenizers).
2055    pub fn prefill_from_assembled(
2056        &mut self,
2057        prefill: MultimodalPrefill,
2058    ) -> Result<Qwen35PrefillSeed> {
2059        if self.batch != 1 {
2060            bail!(
2061                "qwen35: prefill_from_assembled requires batch=1 (runner batch={})",
2062                self.batch
2063            );
2064        }
2065        self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2066        let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2067        Ok(Qwen35PrefillSeed {
2068            trunk_logits: self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?,
2069            mtp_logits,
2070        })
2071    }
2072
2073    fn prefill_multimodal_trunk(
2074        &mut self,
2075        prompt: &str,
2076        rgb: &[u8],
2077        img_w: usize,
2078        img_h: usize,
2079        tokenizer: Option<&std::path::Path>,
2080    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2081        if self.batch != 1 {
2082            bail!(
2083                "qwen35: prefill_multimodal requires batch=1 (runner batch={})",
2084                self.batch
2085            );
2086        }
2087        let vision = {
2088            let enc = self
2089                .vision_encoder
2090                .as_mut()
2091                .ok_or_else(|| anyhow!("qwen35: prefill_multimodal requires .mmproj(...)"))?;
2092            enc.encode_rgb(rgb, img_w, img_h)?
2093        };
2094        if self.weights.token_embd.is_empty() {
2095            bail!("qwen35: multimodal prefill requires token_embd weights");
2096        }
2097        let weights_path = self.weights_path.as_path();
2098        if weights_path.as_os_str().is_empty() {
2099            bail!("qwen35: multimodal prefill requires a GGUF weights path (for tokenizer)");
2100        }
2101        let n_embd = self.cfg.hidden_size;
2102        let mm = MultimodalPrompt {
2103            prompt,
2104            vision: &vision,
2105        };
2106        let prefill = mm.assemble(
2107            |text| encode_prompt_auto(weights_path, tokenizer, text),
2108            &self.weights.token_embd,
2109            n_embd,
2110            0,
2111        )?;
2112        self.mrope_section_positions = Some(prefill.mrope_sections.clone());
2113        let (trunk, _, mtp_logits) = self.prefill_seed_from_hidden(prefill)?;
2114        Ok((trunk, mtp_logits))
2115    }
2116
2117    /// Autoregressive generation from a multimodal prompt (batch=1).
2118    pub fn generate_multimodal_with_opts<F>(
2119        &mut self,
2120        prompt: &str,
2121        rgb: &[u8],
2122        img_w: usize,
2123        img_h: usize,
2124        tokenizer: Option<&std::path::Path>,
2125        n_new: usize,
2126        opts: SampleOpts,
2127        mut on_token: F,
2128    ) -> Result<Vec<u32>>
2129    where
2130        F: FnMut(u32) -> bool,
2131    {
2132        if self.batch != 1 {
2133            bail!(
2134                "qwen35: generate_multimodal requires batch=1 (runner batch={})",
2135                self.batch
2136            );
2137        }
2138        self.decode_cache = None;
2139        let (trunk, _) = self.prefill_multimodal_trunk(prompt, rgb, img_w, img_h, tokenizer)?;
2140        let mut cache = self
2141            .decode_cache
2142            .take()
2143            .ok_or_else(|| anyhow!("qwen35: multimodal prefill did not seed decode cache"))?;
2144        let mut next_tokens = if self.fast_greedy_lm_head && opts.greedy {
2145            self.argmax_batch_from_hidden(&trunk)?
2146        } else if self.fast_greedy_lm_head
2147            && sample_lm_cap(opts, self.lm_vocab_size()) < self.lm_vocab_size()
2148        {
2149            self.sample_batch_from_hidden(&trunk, opts)?
2150        } else {
2151            let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2152            sample_logits_batch(&logits, self.lm_vocab_size(), 1, opts)
2153        };
2154        let mut generated = Vec::new();
2155        if n_new > 0 {
2156            let tok = next_tokens[0];
2157            generated.push(tok);
2158            if !on_token(tok) {
2159                return Ok(generated);
2160            }
2161        }
2162        let row_gen = vec![0usize];
2163        for _ in 1..n_new {
2164            if cache.past_seq >= self.max_seq - 1 {
2165                bail!("qwen35: decode cache reached max_seq={}", self.max_seq);
2166            }
2167            next_tokens = self.decode_step(&mut cache, &next_tokens, &row_gen, opts)?;
2168            let tok = next_tokens[0];
2169            generated.push(tok);
2170            self.decode_cache = Some(cache.clone());
2171            if !on_token(tok) {
2172                break;
2173            }
2174        }
2175        Ok(generated)
2176    }
2177
2178    /// Greedy multimodal generation (batch=1).
2179    pub fn generate_multimodal<F>(
2180        &mut self,
2181        prompt: &str,
2182        rgb: &[u8],
2183        img_w: usize,
2184        img_h: usize,
2185        tokenizer: Option<&std::path::Path>,
2186        n_new: usize,
2187        on_token: F,
2188    ) -> Result<Vec<u32>>
2189    where
2190        F: FnMut(u32) -> bool,
2191    {
2192        self.generate_multimodal_with_opts(
2193            prompt,
2194            rgb,
2195            img_w,
2196            img_h,
2197            tokenizer,
2198            n_new,
2199            SampleOpts::greedy(),
2200            on_token,
2201        )
2202    }
2203
2204    /// Run the prefill-cache graph, seed decode state, return flattened batch logits.
2205    fn prefill_seed_decode_cache(
2206        &mut self,
2207        prompts: &[Vec<u32>],
2208    ) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
2209        if prompts.len() != self.batch {
2210            bail!(
2211                "qwen35: expected {} prompts (batch={}), got {}",
2212                self.batch,
2213                self.batch,
2214                prompts.len()
2215            );
2216        }
2217        for (i, p) in prompts.iter().enumerate() {
2218            if p.is_empty() {
2219                bail!("qwen35: prompt row {i} is empty");
2220            }
2221        }
2222
2223        let prompt_lens: Vec<usize> = prompts.iter().map(|p| p.len()).collect();
2224        let seq = prompt_lens.iter().copied().max().unwrap();
2225        if seq > self.max_seq {
2226            bail!(
2227                "qwen35: prompt length {seq} exceeds compiled max_seq={}",
2228                self.max_seq
2229            );
2230        }
2231
2232        let input_ids = if self.dynamic_prefill {
2233            pack_input_ids(prompts, seq)?
2234        } else {
2235            pack_input_ids(prompts, self.max_seq)?
2236        };
2237        let last_idx = last_token_indices(&prompt_lens);
2238
2239        let mut feeds: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
2240        feeds.push(("last_token_idx", last_idx.as_slice()));
2241        let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
2242        for (name, data) in &zero_in {
2243            feeds.push((name, data.as_slice()));
2244        }
2245        let rope_owned = self.mrope_prefill_rope_feeds(seq);
2246        for (name, data) in &rope_owned {
2247            feeds.push((name.as_str(), data.as_slice()));
2248        }
2249
2250        let has_moe = self.moe_offload.is_some();
2251        let num_experts = self.cfg.num_experts;
2252        let moe_masks = self
2253            .moe_offload
2254            .as_ref()
2255            .map(|m| m.per_layer_resident_masks());
2256        self.bind_moe_host_weights();
2257
2258        let outs = if self.dynamic_prefill {
2259            let config = self.execution_config(prefill_config(self.batch, seq));
2260            let compile_opts = self.dyn_compile_options(&config);
2261            let compiled = {
2262                let cache = self
2263                    .prefill_dynamic_cache
2264                    .as_mut()
2265                    .expect("dynamic prefill cache");
2266                let cfg = self.cfg.clone();
2267                let weights = self.weights.clone();
2268                let runtime_mrope = self.runtime_mrope;
2269                let mtp_logits_path = self.mtp_logits_path;
2270                let fast_mtp = self.fast_mtp;
2271                let fast_greedy = self.fast_greedy_lm_head;
2272                let cache_params = &self.prefill_cache_params;
2273                let cache_packed = &self.prefill_cache_packed;
2274                let gguf_loader = &mut self.gguf_loader;
2275                let packed_bytes_cache = &mut self.packed_bytes_cache;
2276                get_or_specialize_hir_with_options(
2277                    cache,
2278                    &config,
2279                    || {
2280                        build_qwen35_prefill_cache_hir_dynamic_ext(
2281                            &cfg,
2282                            weights,
2283                            1,
2284                            seq,
2285                            runtime_mrope,
2286                            mtp_logits_path,
2287                            fast_mtp,
2288                            fast_greedy,
2289                        )
2290                        .expect("dynamic prefill HIR")
2291                        .0
2292                    },
2293                    &compile_opts,
2294                    |c| {
2295                        for (name, data) in cache_params {
2296                            c.set_param(name, data);
2297                        }
2298                        upload_packed_opt(c, gguf_loader.as_mut(), cache_packed, packed_bytes_cache)
2299                    },
2300                )?
2301            };
2302            if has_moe {
2303                compiled.enable_moe_topk_capture(num_experts);
2304                if let Some(layers) = &moe_masks {
2305                    push_moe_residency(compiled, layers);
2306                }
2307            }
2308            let outs = compiled.run(&feeds);
2309            if let Some(layers) = compiled.take_moe_topk_capture() {
2310                if let Some(mo) = self.moe_offload.as_mut() {
2311                    let store = self.moe_store.as_ref();
2312                    if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
2313                        if let Some(store) = self.moe_store.as_ref() {
2314                            store.apply_to_compiled(compiled);
2315                        }
2316                    }
2317                }
2318            }
2319            outs
2320        } else {
2321            let compiled = self.prefill_cache.as_mut().expect("static prefill cache");
2322            if has_moe {
2323                compiled.enable_moe_topk_capture(num_experts);
2324                if let Some(layers) = &moe_masks {
2325                    push_moe_residency(compiled, layers);
2326                }
2327            }
2328            let outs = compiled.run(&feeds);
2329            let layers = if has_moe {
2330                compiled.take_moe_topk_capture()
2331            } else {
2332                None
2333            };
2334            if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2335                let store = self.moe_store.as_ref();
2336                if refresh_moe_from_capture(mo, store, compiled, &layers, 0, true) {
2337                    if let Some(store) = self.moe_store.as_ref() {
2338                        store.apply_to_compiled(compiled);
2339                    }
2340                }
2341            }
2342            outs
2343        };
2344        let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
2345            &self.cfg,
2346            self.batch,
2347            seq,
2348            &prompt_lens,
2349            outs,
2350            self.mtp_logits_path,
2351            self.fast_greedy_lm_head,
2352        )?;
2353        zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
2354        self.decode_cache = Some(cache.clone());
2355        Ok((trunk, cache, mtp_logits))
2356    }
2357
2358    /// VLM prefill-cache path: host-spliced hidden states + runtime MRoPE sections.
2359    fn prefill_seed_from_hidden(
2360        &mut self,
2361        prefill: MultimodalPrefill,
2362    ) -> Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
2363        let seq = prefill.seq.len();
2364        if seq == 0 {
2365            bail!("qwen35: multimodal prefill seq is empty");
2366        }
2367        if seq > self.max_seq {
2368            bail!(
2369                "qwen35: multimodal seq {seq} exceeds compiled max_seq={}",
2370                self.max_seq
2371            );
2372        }
2373        let n_embd = self.cfg.hidden_size;
2374        if prefill.hidden.len() != seq * n_embd {
2375            bail!(
2376                "qwen35: prefill hidden len {} != seq*n_embd {}*{}",
2377                prefill.hidden.len(),
2378                seq,
2379                n_embd
2380            );
2381        }
2382
2383        let last_idx = vec![prefill.last_token_idx as f32];
2384        let zero_in = zero_recurrent_inputs(&self.cfg, self.batch);
2385        let input_ids = if self.mtp_logits_path || self.enable_mtp {
2386            Some(pack_input_ids(std::slice::from_ref(&prefill.seq), seq)?)
2387        } else {
2388            None
2389        };
2390        let mut feeds: Vec<(&str, &[f32])> = vec![("prefill_hidden", prefill.hidden.as_slice())];
2391        feeds.push(("last_token_idx", last_idx.as_slice()));
2392        for (name, data) in &zero_in {
2393            feeds.push((name, data.as_slice()));
2394        }
2395        if let Some(ref ids) = input_ids {
2396            feeds.push(("input_ids", ids.as_slice()));
2397        }
2398        let rope_owned = self.mrope_prefill_rope_feeds(seq);
2399        for (name, data) in &rope_owned {
2400            feeds.push((name.as_str(), data.as_slice()));
2401        }
2402
2403        let config = self.execution_config(hidden_prefill_config(self.batch, seq));
2404        let compile_opts = self.dyn_compile_options(&config);
2405        let cache = self
2406            .prefill_hidden_dynamic_cache
2407            .as_mut()
2408            .ok_or_else(|| anyhow!("qwen35: hidden prefill cache missing (mmproj not loaded?)"))?;
2409        let cfg = self.cfg.clone();
2410        let weights = self.weights.clone();
2411        let runtime_mrope = self.runtime_mrope;
2412        let mtp_logits_path = self.mtp_logits_path;
2413        let fast_mtp = self.fast_mtp;
2414        let fast_greedy = self.fast_greedy_lm_head;
2415        let hidden_params = &self.prefill_hidden_cache_params;
2416        let hidden_packed = &self.prefill_hidden_cache_packed;
2417        let gguf_loader = &mut self.gguf_loader;
2418        let packed_bytes_cache = &mut self.packed_bytes_cache;
2419        let compiled = get_or_specialize_hir_with_options(
2420            cache,
2421            &config,
2422            || {
2423                build_qwen35_prefill_hidden_cache_hir_dynamic_ext(
2424                    &cfg,
2425                    weights,
2426                    1,
2427                    seq,
2428                    runtime_mrope,
2429                    mtp_logits_path,
2430                    fast_mtp,
2431                    fast_greedy,
2432                )
2433                .expect("dynamic hidden prefill HIR")
2434                .0
2435            },
2436            &compile_opts,
2437            |c| {
2438                for (name, data) in hidden_params {
2439                    c.set_param(name, data);
2440                }
2441                upload_packed_opt(c, gguf_loader.as_mut(), hidden_packed, packed_bytes_cache)
2442            },
2443        )?;
2444        let outs = compiled.run(&feeds);
2445        let prompt_lens = vec![seq];
2446        let (trunk, mut cache, mtp_logits) = seed_cache_from_outputs(
2447            &self.cfg,
2448            self.batch,
2449            seq,
2450            &prompt_lens,
2451            outs,
2452            self.mtp_logits_path,
2453            self.fast_greedy_lm_head,
2454        )?;
2455        zero_prompt_padding_kv(&self.cfg, &mut cache, seq);
2456        self.decode_cache = Some(cache.clone());
2457        Ok((trunk, cache, mtp_logits))
2458    }
2459
2460    /// Single cached decode step returning trunk logits for `token`.
2461    pub fn decode_get_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2462        self.decode_forward_logits(token, false)
2463    }
2464
2465    /// Single cached decode step returning MTP head logits for `token`.
2466    pub fn decode_get_mtp_logits(&mut self, token: u32) -> Result<Vec<f32>> {
2467        if !self.mtp_logits_path {
2468            bail!("qwen35: decode_get_mtp_logits requires mtp_logits_path(true)");
2469        }
2470        self.decode_forward_logits(token, true)
2471    }
2472
2473    /// Prefill and return optional MTP logits (requires `enable_mtp`).
2474    pub fn prefill_get_mtp_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
2475        self.predict_logits(prompt_ids)?
2476            .mtp_logits
2477            .ok_or_else(|| anyhow!("qwen35: MTP logits unavailable (enable_mtp?)"))
2478    }
2479
2480    fn decode_step(
2481        &mut self,
2482        cache: &mut Qwen35DecodeCache,
2483        tokens: &[u32],
2484        generated_per_row: &[usize],
2485        opts: SampleOpts,
2486    ) -> Result<Vec<u32>> {
2487        if self.fast_greedy_lm_head {
2488            let vocab = self.lm_vocab_size();
2489            let (trunk, _mtp) = self.decode_step_trunk_raw(cache, tokens, generated_per_row)?;
2490            if opts.greedy {
2491                return self.argmax_batch_from_hidden(&trunk);
2492            }
2493            if sample_lm_cap(opts, vocab) < vocab {
2494                return self.sample_batch_from_hidden(&trunk, opts);
2495            }
2496        }
2497        let logits = self.decode_forward_logits_batch(cache, tokens, generated_per_row, false)?;
2498        Ok(sample_logits_batch(
2499            &logits,
2500            self.lm_vocab_size(),
2501            self.batch,
2502            opts,
2503        ))
2504    }
2505
2506    fn decode_forward_logits(&mut self, token: u32, want_mtp: bool) -> Result<Vec<f32>> {
2507        let mut cache = self
2508            .decode_cache
2509            .take()
2510            .ok_or_else(|| anyhow!("qwen35: decode requires seeded cache"))?;
2511        let row_gen = vec![0usize; self.batch];
2512        let logits = self.decode_forward_logits_batch(&mut cache, &[token], &row_gen, want_mtp)?;
2513        self.decode_cache = Some(cache);
2514        Ok(logits)
2515    }
2516
2517    fn decode_forward_logits_batch(
2518        &mut self,
2519        cache: &mut Qwen35DecodeCache,
2520        tokens: &[u32],
2521        generated_per_row: &[usize],
2522        want_mtp: bool,
2523    ) -> Result<Vec<f32>> {
2524        let past_seq = cache.past_seq;
2525        let head_half = self.cfg.key_length / 2;
2526        let (cos, sin) = mrope_slice_at_pos(&self.cfg, past_seq, head_half);
2527
2528        let use_bucket = self
2529            .decode_compile_cache
2530            .as_ref()
2531            .and_then(|c| c.bucket_for(past_seq as u64))
2532            .is_some();
2533
2534        if use_bucket {
2535            let (logits, mtp_logits) =
2536                self.decode_step_bucketed(cache, tokens, generated_per_row, &cos, &sin)?;
2537            if want_mtp {
2538                mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from bucketed graph"))
2539            } else {
2540                Ok(logits)
2541            }
2542        } else {
2543            let feeds_owned = decode_step_feeds(
2544                &self.cfg,
2545                cache,
2546                tokens,
2547                &cos,
2548                &sin,
2549                None,
2550                generated_per_row,
2551            )?;
2552            let feeds: Vec<(&str, &[f32])> = feeds_owned
2553                .iter()
2554                .map(|(k, v)| (k.as_str(), v.as_slice()))
2555                .collect();
2556            if !self.decode_graphs.contains_key(&past_seq) {
2557                let (hir, params, packed) = build_qwen35_decode_hir_ext(
2558                    &self.cfg,
2559                    self.weights.clone(),
2560                    self.batch,
2561                    past_seq,
2562                    false,
2563                    self.mtp_logits_path,
2564                    self.fast_mtp,
2565                    self.fast_greedy_lm_head,
2566                )?;
2567                let mut compiled = self.compile_hir_for_config(
2568                    decode_config(self.batch, past_seq),
2569                    &format!("decode_{past_seq}"),
2570                    hir,
2571                )?;
2572                for (name, data) in &params {
2573                    compiled.set_param(name, data);
2574                }
2575                upload_packed_opt(
2576                    &mut compiled,
2577                    self.gguf_loader.as_mut(),
2578                    &packed,
2579                    &mut self.packed_bytes_cache,
2580                )?;
2581                self.decode_graphs.insert(past_seq, compiled);
2582            }
2583            let step = self.moe_refresh_step;
2584            let has_moe = self.moe_offload.is_some();
2585            let num_experts = self.cfg.num_experts;
2586            let moe_masks = self
2587                .moe_offload
2588                .as_ref()
2589                .map(|m| m.per_layer_resident_masks());
2590            self.bind_moe_host_weights();
2591            let outs = {
2592                let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2593                if has_moe {
2594                    compiled.enable_moe_topk_capture(num_experts);
2595                    if let Some(layers) = &moe_masks {
2596                        push_moe_residency(compiled, layers);
2597                    }
2598                }
2599                compiled.run(&feeds)
2600            };
2601            if has_moe {
2602                let layers = {
2603                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2604                    compiled.take_moe_topk_capture()
2605                };
2606                if let (Some(mo), Some(layers)) = (self.moe_offload.as_mut(), layers) {
2607                    let store = self.moe_store.as_ref();
2608                    let compiled = self.decode_graphs.get_mut(&past_seq).unwrap();
2609                    if refresh_moe_from_capture(mo, store, compiled, &layers, step, false) {
2610                        if let Some(store) = self.moe_store.as_ref() {
2611                            store.apply_to_compiled(compiled);
2612                        }
2613                    }
2614                }
2615            }
2616            self.moe_refresh_step = step.saturating_add(1);
2617            let (trunk, mtp_logits) = advance_cache_from_decode_outputs(
2618                &self.cfg,
2619                cache,
2620                outs,
2621                None,
2622                self.mtp_logits_path,
2623                want_mtp,
2624                self.fast_greedy_lm_head,
2625            )?;
2626            let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2627            if want_mtp {
2628                mtp_logits.ok_or_else(|| anyhow!("mtp decode logits missing from decode graph"))
2629            } else {
2630                Ok(logits)
2631            }
2632        }
2633    }
2634
2635    fn decode_step_bucketed(
2636        &mut self,
2637        cache: &mut Qwen35DecodeCache,
2638        tokens: &[u32],
2639        generated_per_row: &[usize],
2640        cos: &[f32],
2641        sin: &[f32],
2642    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2643        let (trunk, mtp) =
2644            self.decode_step_bucketed_raw(cache, tokens, generated_per_row, cos, sin)?;
2645        let logits = self.trunk_to_logits(trunk, self.fast_greedy_lm_head)?;
2646        Ok((logits, mtp))
2647    }
2648
2649    fn decode_step_bucketed_raw(
2650        &mut self,
2651        cache: &mut Qwen35DecodeCache,
2652        tokens: &[u32],
2653        generated_per_row: &[usize],
2654        cos: &[f32],
2655        sin: &[f32],
2656    ) -> Result<(Vec<f32>, Option<Vec<f32>>)> {
2657        let past_seq = cache.past_seq;
2658        let upper = self.ensure_decode_bucket_compiled(past_seq as u64)?;
2659
2660        let feeds_owned = decode_step_feeds(
2661            &self.cfg,
2662            cache,
2663            tokens,
2664            cos,
2665            sin,
2666            Some(upper),
2667            generated_per_row,
2668        )?;
2669        let feeds: Vec<(&str, &[f32])> = feeds_owned
2670            .iter()
2671            .map(|(k, v)| (k.as_str(), v.as_slice()))
2672            .collect();
2673
2674        let decode_opts = self.bucketed_decode_compile_options();
2675        let cache_mut = self.decode_compile_cache.as_mut().unwrap();
2676        let (_u, compiled) = cache_mut
2677            .ensure_hir_with_params(
2678                past_seq as u64,
2679                |_| panic!("decode bucket must be compiled"),
2680                &decode_opts,
2681            )
2682            .expect("decode bucket missing after ensure");
2683        // Scale attention KV length to actual cache position, not bucket upper.
2684        compiled.set_active_extent(Some((past_seq + 1, upper + 1)));
2685        let outs = compiled.run(&feeds);
2686        compiled.set_active_extent(None);
2687        advance_cache_from_decode_outputs(
2688            &self.cfg,
2689            cache,
2690            outs,
2691            Some(upper),
2692            self.mtp_logits_path,
2693            self.mtp_logits_path,
2694            self.fast_greedy_lm_head,
2695        )
2696    }
2697
2698    fn mrope_prefill_rope_feeds(&self, seq: usize) -> Vec<(String, Vec<f32>)> {
2699        if !self.runtime_mrope {
2700            return Vec::new();
2701        }
2702        let head_half = self.cfg.key_length / 2;
2703        let sections = self.mrope_section_positions.as_deref();
2704        let (cos, sin) = mrope_prefill_feeds(&self.cfg, seq, sections, head_half);
2705        vec![("rope_cos".into(), cos), ("rope_sin".into(), sin)]
2706    }
2707}
2708
2709impl rlx_cli::LmRunner for Qwen35Runner {
2710    fn family(&self) -> &'static str {
2711        "qwen35"
2712    }
2713    fn vocab_size(&self) -> usize {
2714        self.lm_vocab_size()
2715    }
2716    fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>> {
2717        let out = Qwen35Runner::predict_logits(self, prompt_ids)?;
2718        Ok(out.logits)
2719    }
2720    fn generate(
2721        &mut self,
2722        prompt_ids: &[u32],
2723        n_new: usize,
2724        on_token: &mut dyn FnMut(u32) -> bool,
2725    ) -> anyhow::Result<Vec<u32>> {
2726        // Qwen35Runner::generate takes a bool-returning callback
2727        // (false = stop). The trait callback now has the same shape,
2728        // so just forward.
2729        Qwen35Runner::generate(self, prompt_ids, n_new, on_token)
2730    }
2731
2732    fn supports_multimodal(&self) -> bool {
2733        // True when an mmproj path or inline mmproj weights were
2734        // attached at builder time. The encoder is lazy-loaded inside
2735        // `generate_multimodal_with_opts`.
2736        self.has_mmproj()
2737    }
2738
2739    fn generate_multimodal(
2740        &mut self,
2741        prompt: &str,
2742        rgb: &[u8],
2743        img_w: usize,
2744        img_h: usize,
2745        tokenizer: Option<&std::path::Path>,
2746        n_new: usize,
2747        on_token: &mut dyn FnMut(u32) -> bool,
2748    ) -> anyhow::Result<Vec<u32>> {
2749        Qwen35Runner::generate_multimodal(
2750            self, prompt, rgb, img_w, img_h, tokenizer, n_new, on_token,
2751        )
2752    }
2753}
2754
2755fn sample_logits_batch(logits: &[f32], vocab: usize, batch: usize, opts: SampleOpts) -> Vec<u32> {
2756    let mut out = Vec::with_capacity(batch);
2757    for b in 0..batch {
2758        let row = &logits[b * vocab..(b + 1) * vocab];
2759        out.push(sample_token(row, opts) as u32);
2760    }
2761    out
2762}