Skip to main content

rlx_qwen3/
generator.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Host-side generation loop for Qwen3.
//!
//! Two generation paths share one [`Qwen3Generator`]:
//!
//! - **Naive** ([`Qwen3Generator::step`]): each call rebuilds the prefill
//!   graph for the full token history and runs it from scratch
//!   (O(N²) compute over N generated tokens).
//! - **KV-cached** ([`Qwen3Generator::step_cached`]): the first call runs
//!   prefill-with-cache over the prompt; every later call runs a
//!   decode-mode graph over just the last token plus the cached past.
//!
//! Both paths share the same call shape (`prefill`, then repeated steps
//! or `generate`), so callers can switch paths without changing anything
//! else.
//!
//! Why the naive path was shipped first, and why it is kept:
//!   - It established the public API contract before the IR/kernel
//!     changes that the cached path needed had landed.
//!   - It runs end-to-end generation against a real checkpoint and
//!     validates that the prefill graph is numerically correct.
//!   - It is the reference baseline for the cached path's own
//!     numerical-parity test (cached vs recompute must match).
32
33use crate::builder::{
34    build_qwen3_decode_graph_sized, build_qwen3_decode_graph_sized_ext,
35    build_qwen3_graph_sized_last_logits,
36};
37use crate::capabilities::validate_device;
38use crate::config::Qwen3Config;
39use crate::profile::qwen3_profile_near_weights;
40use crate::sampling::{SampleOpts, sample_token};
41use anyhow::{Context, Result};
42use rlx_core::autoregressive::{
43    DecodeLogitsKv, KvCacheState, compile_cache_ensure_graph, kv_from_prefill_outputs,
44    prefill_cache_key, run_bucketed_kv_decode, split_decode_logits_kv,
45};
46use rlx_core::flow_bridge::compile_options_from_profile;
47use rlx_core::weight_loader::WeightLoader;
48use rlx_core::weight_map::WeightMap;
49use rlx_flow::CompileProfile;
50use rlx_ir::logical_kernel::KernelDispatchConfig;
51use rlx_runtime::attn_mask::bucket_decode_mask;
52use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
53use rlx_runtime::{CompileOptions, Device, Session};
54use std::collections::HashMap;
55use std::path::Path;
56
57/// Stateful Qwen3 generation handle.
58///
59/// Holds the (config, weight bytes, token history) and rebuilds a
60/// prefill graph on each [`step`] call. Cheap to construct after
61/// initial weight load; tokens stay in-memory between calls.
62pub struct Qwen3Generator {
63    cfg: Qwen3Config,
64    /// Map of weight key → (f32 data, shape). Cloned on each step
65    /// into a fresh `WeightMap` because `WeightMap::take` is
66    /// destructive — see the cached-generator notes for the path
67    /// that avoids the clone.
68    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
69    tokens: Vec<u32>,
70    device: Device,
71    /// Populated lazily on the first `step_cached` call (seeded from
72    /// the prompt via prefill-with-cache); thereafter advanced by each
73    /// decode step.
74    cache: Option<KvCacheState>,
75    /// Per-key LRU compile cache for prefill graphs. Keyed by `seq`.
76    /// Set to `None` to disable (default for new instances; opt in via
77    /// [`Qwen3Generator::with_prefill_cache`]).
78    prefill_compile_cache: Option<CompileCache>,
79    /// Bucketed compile cache for decode-mode graphs. Each bucket
80    /// holds one compiled graph specialized at its upper-bound
81    /// `past_seq`; the host pads `past_k`/`past_v` and supplies a
82    /// per-step mask so a single bucket serves every `past_seq` in
83    /// its range. Opt in via [`Qwen3Generator::with_decode_cache`].
84    decode_compile_cache: Option<BucketedCompileCache>,
85    prefill_profile: CompileProfile,
86    decode_profile: CompileProfile,
87}
88
89impl Qwen3Generator {
90    /// Construct from any [`WeightLoader`] — drains it into an
91    /// internal cache so the loader is free after this call.
92    pub fn from_loader(
93        cfg: Qwen3Config,
94        loader: &mut dyn WeightLoader,
95        device: Device,
96    ) -> Result<Self> {
97        validate_device(&cfg, device, false)?;
98        let keys = loader.remaining_keys();
99        let mut weights_cache = HashMap::with_capacity(keys.len());
100        for k in keys {
101            let v = loader
102                .take(&k)
103                .with_context(|| format!("draining weight {k}"))?;
104            // Normalize the cache key to the safetensors / HuggingFace
105            // naming convention so subsequent builder calls that ask
106            // for `model.embed_tokens.weight` (the canonical name baked
107            // into the qwen3 builder) hit the cache whether the
108            // loader was safetensors-native or GGUF-native.
109            let canonical =
110                rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
111            weights_cache.insert(canonical, v);
112        }
113        let max_past = cfg.max_position_embeddings.clamp(1, 4096);
114        Ok(Self {
115            cfg,
116            weights_cache,
117            tokens: Vec::new(),
118            device,
119            cache: None,
120            prefill_compile_cache: Some(CompileCache::new(device, 8)),
121            decode_compile_cache: Some(BucketedCompileCache::power_of_two_ladder(
122                device,
123                1,
124                max_past as u64,
125            )),
126            prefill_profile: CompileProfile::qwen3_prefill(),
127            decode_profile: CompileProfile::qwen3_decode(),
128        })
129    }
130
131    /// Like [`Self::from_loader`] but loads `qwen3.rlx.toml` from the weights directory when present.
132    pub fn from_loader_at(
133        cfg: Qwen3Config,
134        loader: &mut dyn WeightLoader,
135        device: Device,
136        weights_path: &Path,
137    ) -> Result<Self> {
138        let mut g = Self::from_loader(cfg, loader, device)?;
139        g.prefill_profile = qwen3_profile_near_weights(weights_path, false);
140        g.decode_profile = qwen3_profile_near_weights(weights_path, true);
141        Ok(g)
142    }
143
144    pub fn with_compile_profiles(
145        mut self,
146        prefill: CompileProfile,
147        decode: CompileProfile,
148    ) -> Self {
149        self.prefill_profile = prefill;
150        self.decode_profile = decode;
151        self
152    }
153
154    pub fn prefill_profile(&self) -> &CompileProfile {
155        &self.prefill_profile
156    }
157
158    pub fn decode_profile(&self) -> &CompileProfile {
159        &self.decode_profile
160    }
161
162    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
163        let profile = if decode {
164            &self.decode_profile
165        } else {
166            &self.prefill_profile
167        };
168        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
169    }
170
171    /// Enable the prefill compile cache with the given LRU capacity.
172    /// Useful when the same prompt length is used across multiple
173    /// generation runs — the second + Nth run skip the compile +
174    /// param-attach roundtrip (~30-50ms per call on CPU).
175    pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
176        self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
177        self
178    }
179
180    /// Enable the bucketed decode compile cache spanning past-seq
181    /// values in `[1, max_past]`. Buckets are power-of-two
182    /// `[1..2, 2..3, 3..5, 5..9, 9..17, …]`. Each bucket compiles
183    /// one graph at its upper bound; a steady-state generation loop
184    /// across `N` tokens compiles `O(log N)` graphs instead of `N`.
185    ///
186    /// Padding compute waste is bounded at 2×: actual `past_seq` is
187    /// at least half the bucket's upper bound (except possibly the
188    /// smallest bucket).
189    pub fn with_decode_cache(mut self, max_past: usize) -> Self {
190        let cache = BucketedCompileCache::power_of_two_ladder(
191            self.device,
192            /*min*/ 1,
193            max_past.max(1) as u64,
194        );
195        self.decode_compile_cache = Some(cache);
196        self
197    }
198
199    /// Convenience: load weights from a safetensors or GGUF path
200    /// (dispatch by extension; see `rlx_core::weight_loader::load_from_path`).
201    pub fn from_path(cfg: Qwen3Config, path: &str, device: Device) -> Result<Self> {
202        Self::from_path_at(cfg, path, device, Path::new("."))
203    }
204
205    /// Like [`Self::from_path`] with an explicit weights path for `qwen3.rlx.toml` discovery.
206    pub fn from_path_at(
207        cfg: Qwen3Config,
208        path: &str,
209        device: Device,
210        weights_path: &Path,
211    ) -> Result<Self> {
212        let mut loader = rlx_core::weight_loader::load_from_path(path)?;
213        Self::from_loader_at(cfg, loader.as_mut(), device, weights_path)
214    }
215
216    /// Same as [`from_path`] but with MTP-head visibility control.
217    /// When `include_mtp=true` and the file is GGUF, MTP weights are
218    /// drained into the generator's cache alongside the base
219    /// weights. The base inference path still ignores them — they
220    /// sit in cache for a future MTP-aware decoder. Non-GGUF formats
221    /// silently ignore the flag (safetensors files publish all
222    /// tensors uniformly; downstream code distinguishes by name).
223    pub fn from_path_with_mtp(
224        cfg: Qwen3Config,
225        path: &str,
226        device: Device,
227        include_mtp: bool,
228    ) -> Result<Self> {
229        // Branch on extension so we can flip the GGUF-specific
230        // visibility option. Safetensors has no equivalent — it
231        // doesn't isolate MTP tensors at the loader level.
232        if path.ends_with(".gguf") {
233            let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
234            gguf.include_mtp(include_mtp);
235            Self::from_loader(cfg, &mut gguf, device)
236        } else {
237            Self::from_path(cfg, path, device)
238        }
239    }
240
241    /// Replace the token history with `prompt_ids`. Does not run the
242    /// model — the next [`step`] call processes the full sequence.
243    /// Clears any KV cache from a prior generation.
244    pub fn prefill(&mut self, prompt_ids: &[u32]) {
245        self.tokens.clear();
246        self.tokens.extend_from_slice(prompt_ids);
247        self.cache = None;
248    }
249
250    /// Run one prefill over the current token history and sample the
251    /// next token. The sampled token is appended to the history and
252    /// returned. Call repeatedly to generate.
253    pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
254        if self.tokens.is_empty() {
255            anyhow::bail!("step() called with empty token history; call prefill() first");
256        }
257        let seq = self.tokens.len();
258        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
259        let (graph, params) = build_qwen3_graph_sized_last_logits(
260            &self.cfg, &mut wm, /*batch*/ 1, seq, /*with_kv_outputs*/ false,
261        )?;
262        let compile_opts = self.profile_compile_options(false);
263        let mut compiled = Session::new(self.device).compile_with(graph, &compile_opts);
264        for (name, data) in &params {
265            compiled.set_param(name, data);
266        }
267        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
268        let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
269        let logits = outputs
270            .into_iter()
271            .next()
272            .context("compiled.run returned no outputs")?;
273
274        let vocab = self.cfg.vocab_size;
275        let expected = vocab;
276        if logits.len() < expected {
277            anyhow::bail!(
278                "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
279                logits.len(),
280                expected
281            );
282        }
283        // Last-logits graph returns [B=1, 1, vocab].
284        let last_row = &logits[..vocab];
285        let tok = sample_token(last_row, opts) as u32;
286        self.tokens.push(tok);
287        Ok(tok)
288    }
289
290    /// Run `n` steps and return the newly generated token ids
291    /// (excludes the prefill prompt).
292    pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
293        if self.decode_compile_cache.is_some() {
294            return self.generate_cached(n, opts);
295        }
296        let start = self.tokens.len();
297        for _ in 0..n {
298            self.step(opts)?;
299        }
300        Ok(self.tokens[start..].to_vec())
301    }
302
303    /// Cached step: O(L) per token instead of O(L²). First call seeds
304    /// the KV cache from the prompt via prefill-with-cache; subsequent
305    /// calls run the decode-mode graph on just the last token + cached
306    /// past. Output is bit-identical to [`step`] modulo reduction
307    /// order in the SDPA kernel.
308    ///
309    /// Invariant after each call: `cache.past_seq == tokens.len() - 1`
310    /// (the just-sampled token is appended but not yet in the cache;
311    /// it becomes the input for the next decode step).
312    pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
313        if self.tokens.is_empty() {
314            anyhow::bail!("step_cached() called with empty token history; call prefill() first");
315        }
316        if self.cache.is_none() {
317            // The seed runs prefill, populates the cache, samples from
318            // the last position, and appends the token. Return that
319            // token directly — no decode step on this call.
320            let tok = self.seed_cache_from_prompt(opts)?;
321            return Ok(tok);
322        }
323        let cache = self.cache.as_ref().unwrap();
324        let past_seq = cache.past_len;
325        // The token we feed into decode is whatever's after the cached
326        // prefix in `self.tokens`. After a prior cached step this is
327        // the just-sampled token; after seeding it's the same.
328        if self.tokens.len() <= past_seq {
329            anyhow::bail!(
330                "cache invariant violated: tokens.len() {} <= past_seq {}",
331                self.tokens.len(),
332                past_seq
333            );
334        }
335        let input_tok = self.tokens[past_seq];
336
337        // Branch: bucketed compile cache vs one-shot compile per step.
338        let (logits, new_k, new_v) = if self.decode_compile_cache.is_some()
339            && self
340                .decode_compile_cache
341                .as_ref()
342                .unwrap()
343                .bucket_for(past_seq as u64)
344                .is_some()
345        {
346            self.decode_step_bucketed(past_seq, input_tok)?
347        } else {
348            self.decode_step_oneshot(past_seq, input_tok)?
349        };
350
351        let cache_mut = self.cache.as_mut().unwrap();
352        cache_mut.past_len = past_seq + 1;
353        cache_mut.layers_k = new_k;
354        cache_mut.layers_v = new_v;
355
356        let vocab = self.cfg.vocab_size;
357        if logits.len() != vocab {
358            anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
359        }
360        let tok = sample_token(&logits, opts) as u32;
361        self.tokens.push(tok);
362        Ok(tok)
363    }
364
365    /// Decode path that compiles a fresh graph for the exact `past_seq`
366    /// every call. Slower but always-correct fallback.
367    fn decode_step_oneshot(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
368        let cache = self.cache.as_ref().unwrap();
369
370        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
371        let (graph, params) =
372            build_qwen3_decode_graph_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
373        let opts = self.profile_compile_options(true);
374        let mut compiled = Session::new(self.device).compile_with(graph, &opts);
375        for (name, data) in &params {
376            compiled.set_param(name, data);
377        }
378
379        let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
380        let input_ids_f32 = [input_tok as f32];
381        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
382            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
383            .collect();
384        let mut inputs: Vec<(&str, &[f32])> =
385            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
386        inputs.push(("input_ids", input_ids_f32.as_slice()));
387        inputs.push(("rope_cos", cos.as_slice()));
388        inputs.push(("rope_sin", sin.as_slice()));
389        for i in 0..self.cfg.num_hidden_layers {
390            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
391            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
392        }
393
394        let outputs = compiled.run(&inputs);
395        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
396    }
397
398    fn decode_step_bucketed(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
399        let kv = self.cache.as_ref().unwrap().clone();
400        let kv_dim = self.cfg.kv_proj_dim();
401        let n_layers = self.cfg.num_hidden_layers;
402        let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
403        let input_ids_f32 = [input_tok as f32];
404        let decode_opts = self.profile_compile_options(true);
405        let upper = self
406            .decode_compile_cache
407            .as_ref()
408            .and_then(|cache_dec| {
409                cache_dec.bucket_for(past_seq as u64).map(|idx| {
410                    cache_dec
411                        .buckets()
412                        .nth(idx)
413                        .map(|r| (r.end - 1) as usize)
414                        .unwrap_or(past_seq)
415                })
416            })
417            .unwrap_or(past_seq);
418        let mask = bucket_decode_mask(past_seq, upper);
419        let fixed = [
420            CacheRunInput {
421                name: "input_ids",
422                data: &input_ids_f32,
423                row_inner: None,
424            },
425            CacheRunInput {
426                name: "rope_cos",
427                data: &cos,
428                row_inner: None,
429            },
430            CacheRunInput {
431                name: "rope_sin",
432                data: &sin,
433                row_inner: None,
434            },
435            CacheRunInput {
436                name: "mask",
437                data: &mask,
438                row_inner: None,
439            },
440        ];
441        let cfg = self.cfg.clone();
442        let weights = self.weights_cache.clone();
443        let cache_dec = self.decode_compile_cache.as_mut().unwrap();
444        run_bucketed_kv_decode(
445            cache_dec,
446            past_seq,
447            &kv,
448            kv_dim,
449            n_layers,
450            &fixed,
451            |upper| {
452                let mut wm = WeightMap::from_tensors(weights.clone());
453                build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
454                    .expect("qwen3 bucketed decode graph")
455            },
456            &decode_opts,
457        )
458    }
459
460    /// Run prefill-with-cache and return the raw outputs. Uses the
461    /// LRU `CompileCache` when enabled; otherwise compiles fresh each
462    /// call. Keyed by `seq` because graph shape is seq-specialized.
463    fn run_prefill_with_cache(
464        &mut self,
465        batch: usize,
466        seq: usize,
467        ids_f32: &[f32],
468    ) -> Result<Vec<Vec<f32>>> {
469        let prefill_opts = self.profile_compile_options(false);
470        if let Some(cache) = &mut self.prefill_compile_cache {
471            let key = prefill_cache_key(batch, seq);
472            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
473            let (graph, params) = build_qwen3_graph_sized_last_logits(
474                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
475            )?;
476            let compiled = compile_cache_ensure_graph(cache, key, graph, params, &prefill_opts);
477            Ok(compiled.run(&[("input_ids", ids_f32)]))
478        } else {
479            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
480            let (graph, params) = build_qwen3_graph_sized_last_logits(
481                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
482            )?;
483            let opts = self.profile_compile_options(false);
484            let mut compiled = Session::new(self.device).compile_with(graph, &opts);
485            for (name, data) in &params {
486                compiled.set_param(name, data);
487            }
488            Ok(compiled.run(&[("input_ids", ids_f32)]))
489        }
490    }
491
492    /// Run `n` cached steps and return the newly generated tokens.
493    pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
494        self.generate_cached_with(n, opts, |_| {})
495    }
496
497    /// Same as [`generate_cached`] but invokes `on_token` once per
498    /// freshly sampled id, inside the decode loop. The whole `n` step
499    /// loop shares the bucketed compile cache — callers wanting a
500    /// streaming UI should prefer this to calling
501    /// `generate_cached(1, …)` `n` times (which forces a fresh
502    /// compile per token at the bucket boundaries).
503    pub fn generate_cached_with(
504        &mut self,
505        n: usize,
506        opts: SampleOpts,
507        mut on_token: impl FnMut(u32),
508    ) -> Result<Vec<u32>> {
509        let start = self.tokens.len();
510        for _ in 0..n {
511            let tok = self.step_cached(opts)?;
512            on_token(tok);
513        }
514        Ok(self.tokens[start..].to_vec())
515    }
516
517    /// Run prefill-with-cache on the current `self.tokens` (the
518    /// prompt), populate `self.cache`, sample the next token from the
519    /// last position's logits, and append it. Returns the sampled
520    /// token. Invariant after: `cache.past_seq == tokens.len() - 1`.
521    fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
522        let seq = self.tokens.len();
523        let batch = 1usize;
524        let kv_dim = self.cfg.kv_proj_dim();
525
526        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
527        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
528        let (logits, kv) =
529            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
530        self.cache = Some(kv);
531
532        let vocab = self.cfg.vocab_size;
533        let needed = vocab;
534        if logits.len() < needed {
535            anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
536        }
537        let last_row = &logits[..vocab];
538        let tok = sample_token(last_row, opts) as u32;
539        self.tokens.push(tok);
540        Ok(tok)
541    }
542
543    /// Full token history (prompt + generated).
544    pub fn tokens(&self) -> &[u32] {
545        &self.tokens
546    }
547
548    pub fn config(&self) -> &Qwen3Config {
549        &self.cfg
550    }
551
552    /// Low-level primitive: reset internal state, run prefill-with-cache
553    /// over `context`, and return the *last position's* logits row
554    /// (`P(next_token | context)`). Does NOT sample or append. The
555    /// internal `tokens` buffer is set to `context` and the KV cache
556    /// is populated to `past_seq = context.len()`.
557    ///
558    /// Used by [`crate::spec::Qwen3Speculator`] to compute the
559    /// first row of a `Speculator::verify` / `propose` result before
560    /// the decode loop runs.
561    pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
562        if context.is_empty() {
563            anyhow::bail!("prefill_get_last_logits: empty context");
564        }
565        self.tokens.clear();
566        self.tokens.extend_from_slice(context);
567        self.cache = None;
568
569        let seq = context.len();
570        let batch = 1usize;
571        let kv_dim = self.cfg.kv_proj_dim();
572
573        let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
574        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
575        let (logits, kv) =
576            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
577        self.cache = Some(kv);
578
579        let vocab = self.cfg.vocab_size;
580        let needed = vocab;
581        if logits.len() < needed {
582            anyhow::bail!("logits short: {} < {}", logits.len(), needed);
583        }
584        Ok(logits[..vocab].to_vec())
585    }
586
587    /// Low-level primitive: run one decode step with the caller-
588    /// supplied input token (no sampling), advance the KV cache, and
589    /// return the resulting logits row `P(next | history ++ input)`.
590    /// Appends `input` to the `tokens` buffer so the invariant
591    /// `cache.past_seq == tokens.len()` holds after this call (note:
592    /// differs from `step_cached` invariant because this method does
593    /// not append a sampled token).
594    pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
595        let cache = self.cache.as_ref().ok_or_else(|| {
596            anyhow::anyhow!(
597                "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
598            )
599        })?;
600        let past_seq = cache.past_len;
601
602        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
603        let (graph, params) =
604            build_qwen3_decode_graph_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
605        let opts = self.profile_compile_options(true);
606        let mut compiled = Session::new(self.device).compile_with(graph, &opts);
607        for (name, data) in &params {
608            compiled.set_param(name, data);
609        }
610
611        let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
612        let input_ids_f32 = [input as f32];
613        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
614            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
615            .collect();
616        let mut inputs: Vec<(&str, &[f32])> =
617            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
618        inputs.push(("input_ids", input_ids_f32.as_slice()));
619        inputs.push(("rope_cos", cos.as_slice()));
620        inputs.push(("rope_sin", sin.as_slice()));
621        for i in 0..self.cfg.num_hidden_layers {
622            let pk = &cache.layers_k[i];
623            let pv = &cache.layers_v[i];
624            inputs.push((&key_strs[2 * i], pk.as_slice()));
625            inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
626        }
627
628        let outputs = compiled.run(&inputs);
629        let (logits, new_k, new_v) = split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)?;
630
631        let cache_mut = self.cache.as_mut().unwrap();
632        cache_mut.past_len = past_seq + 1;
633        cache_mut.layers_k = new_k;
634        cache_mut.layers_v = new_v;
635        self.tokens.push(input);
636
637        Ok(logits)
638    }
639}
640
641/// Compute the single-row (cos, sin) RoPE slice for absolute position
642/// `pos`. Matches the formula in the prefill builder so cached decode
643/// and recompute prefill produce the same RoPE rotation.
644fn compute_rope_slice(cfg: &Qwen3Config, pos: usize) -> (Vec<f32>, Vec<f32>) {
645    let dh = cfg.head_dim;
646    let half = dh / 2;
647    let mut cos = vec![0f32; half];
648    let mut sin = vec![0f32; half];
649    for i in 0..half {
650        let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
651        let angle = pos as f64 * freq;
652        let (s, c) = angle.sin_cos();
653        cos[i] = c as f32;
654        sin[i] = s as f32;
655    }
656    (cos, sin)
657}
658
659#[cfg(test)]
660mod tests {
661    use super::*;
662    use crate::config::Qwen3Config;
663
664    fn tiny_cfg() -> Qwen3Config {
665        Qwen3Config {
666            vocab_size: 16,
667            hidden_size: 16,
668            intermediate_size: 32,
669            num_hidden_layers: 2,
670            num_attention_heads: 4,
671            num_key_value_heads: 2,
672            head_dim: 8,
673            max_position_embeddings: 16,
674            rms_norm_eps: 1e-6,
675            rope_theta: 1_000_000.0,
676            hidden_act: "silu".into(),
677            tie_word_embeddings: false,
678            attention_bias: false,
679            qk_norm: true,
680            sliding_window: None,
681            max_window_layers: usize::MAX,
682            use_sliding_window: false,
683            num_experts: 0,
684            num_experts_used: 0,
685            expert_ffn_size: 0,
686            shared_expert_ffn_size: 0,
687            expert_weights_scale: 1.0,
688        }
689    }
690
691    fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
692        let h = cfg.hidden_size;
693        let q_dim = cfg.q_proj_dim();
694        let kv_dim = cfg.kv_proj_dim();
695        let int_dim = cfg.intermediate_size;
696        let dh = cfg.head_dim;
697        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
698        // Use a deterministic non-zero pattern so logits aren't all 0
699        // (sampling on an all-zero row is undefined order).
700        let pat = |n: usize, salt: u32| -> Vec<f32> {
701            (0..n)
702                .map(|i| {
703                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
704                    (x as f32 / (1u32 << 24) as f32) - 0.5
705                })
706                .collect()
707        };
708        t.insert(
709            "model.embed_tokens.weight".into(),
710            (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
711        );
712        for i in 0..cfg.num_hidden_layers {
713            let lp = format!("model.layers.{i}");
714            t.insert(
715                format!("{lp}.input_layernorm.weight"),
716                (pat(h, 100 + i as u32), vec![h]),
717            );
718            t.insert(
719                format!("{lp}.post_attention_layernorm.weight"),
720                (pat(h, 200 + i as u32), vec![h]),
721            );
722            t.insert(
723                format!("{lp}.self_attn.q_proj.weight"),
724                (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
725            );
726            t.insert(
727                format!("{lp}.self_attn.k_proj.weight"),
728                (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
729            );
730            t.insert(
731                format!("{lp}.self_attn.v_proj.weight"),
732                (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
733            );
734            t.insert(
735                format!("{lp}.self_attn.o_proj.weight"),
736                (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
737            );
738            t.insert(
739                format!("{lp}.self_attn.q_norm.weight"),
740                (pat(dh, 700 + i as u32), vec![dh]),
741            );
742            t.insert(
743                format!("{lp}.self_attn.k_norm.weight"),
744                (pat(dh, 800 + i as u32), vec![dh]),
745            );
746            t.insert(
747                format!("{lp}.mlp.gate_proj.weight"),
748                (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
749            );
750            t.insert(
751                format!("{lp}.mlp.up_proj.weight"),
752                (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
753            );
754            t.insert(
755                format!("{lp}.mlp.down_proj.weight"),
756                (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
757            );
758        }
759        t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
760        t.insert(
761            "lm_head.weight".into(),
762            (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
763        );
764        WeightMap::from_tensors(t)
765    }
766
767    #[test]
768    fn generator_drains_loader_and_runs_one_step() {
769        let cfg = tiny_cfg();
770        let mut wm = synthetic_weights(&cfg);
771        let mut gn = Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
772        assert_eq!(wm.len(), 0, "loader should be drained");
773        gn.prefill(&[1, 2, 3]);
774        let t = gn.step(SampleOpts::greedy()).unwrap();
775        assert!((t as usize) < cfg.vocab_size);
776        assert_eq!(gn.tokens().len(), 4);
777    }
778
779    #[test]
780    fn generate_n_appends_n_tokens() {
781        let cfg = tiny_cfg();
782        let mut wm = synthetic_weights(&cfg);
783        let mut gn = Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
784        gn.prefill(&[5, 6]);
785        let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
786        assert_eq!(new_tokens.len(), 3);
787        assert_eq!(gn.tokens().len(), 5);
788        for t in &new_tokens {
789            assert!((*t as usize) < cfg.vocab_size);
790        }
791    }
792
793    #[test]
794    fn step_without_prefill_errors() {
795        let cfg = tiny_cfg();
796        let mut wm = synthetic_weights(&cfg);
797        let mut gn = Qwen3Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
798        let r = gn.step(SampleOpts::greedy());
799        assert!(r.is_err());
800    }
801
802    #[test]
803    fn cached_matches_naive_on_greedy() {
804        // The cached and naive paths must produce the same token
805        // sequence given the same prompt + opts. This is the
806        // load-bearing test for the KV-cache implementation: if the
807        // decode-mode graph, the kernel's Lq!=Lk fix, the cache
808        // wiring, or the RoPE position-slice is wrong, the sequences
809        // diverge here.
810        let cfg = tiny_cfg();
811        let prompt: Vec<u32> = vec![1, 2, 3, 5];
812        let steps = 4;
813
814        let mut wm_n = synthetic_weights(&cfg);
815        let mut gn_naive =
816            Qwen3Generator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
817        gn_naive.prefill_compile_cache = None;
818        gn_naive.decode_compile_cache = None;
819        gn_naive.prefill(&prompt);
820        let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();
821
822        let mut wm_c = synthetic_weights(&cfg);
823        let mut gn_cached =
824            Qwen3Generator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
825        gn_cached.prefill(&prompt);
826        let cached_tokens = gn_cached
827            .generate_cached(steps, SampleOpts::greedy())
828            .unwrap();
829
830        assert_eq!(
831            cached_tokens, naive_tokens,
832            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
833        );
834    }
835
836    #[test]
837    fn cached_step_advances_cache_invariant() {
838        let cfg = tiny_cfg();
839        let mut wm = synthetic_weights(&cfg);
840        let mut gn = Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
841        gn.prefill(&[1, 2, 3]);
842        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
843        // After seed: tokens.len() == 4, cache.past_seq == 3 (cache holds prompt).
844        assert_eq!(gn.tokens().len(), 4);
845        assert_eq!(gn.cache.as_ref().unwrap().past_len, 3);
846        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
847        // After one decode: tokens.len() == 5, cache.past_seq == 4.
848        assert_eq!(gn.tokens().len(), 5);
849        assert_eq!(gn.cache.as_ref().unwrap().past_len, 4);
850    }
851
852    #[test]
853    fn bucketed_decode_matches_oneshot() {
854        // The bucketed compile-cache path (padded K/V + custom mask)
855        // must produce the same token sequence as the one-shot
856        // path. Load-bearing for the bucketed cache feature: if the
857        // mask, padding, or output slicing is wrong, sequences
858        // diverge here.
859        let cfg = tiny_cfg();
860        let prompt: Vec<u32> = vec![1, 2, 3, 5];
861        let steps = 6;
862
863        let mut wm_one = synthetic_weights(&cfg);
864        let mut gn_one =
865            Qwen3Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
866        gn_one.prefill(&prompt);
867        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
868
869        let mut wm_buc = synthetic_weights(&cfg);
870        let mut gn_buc = Qwen3Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
871            .unwrap()
872            .with_decode_cache(/*max_past*/ 32);
873        gn_buc.prefill(&prompt);
874        let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();
875
876        assert_eq!(
877            bucketed_tokens, oneshot_tokens,
878            "bucketed-cache decode diverged from one-shot decode — \
879             mask, padding, or output-slice bug"
880        );
881    }
882
883    #[test]
884    fn bucketed_decode_q_proj_seq_is_one() {
885        use rlx_ir::Op;
886
887        let cfg = tiny_cfg();
888        let mut wm = synthetic_weights(&cfg);
889        let (graph, _) = build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, 4, true).unwrap();
890        for node in graph.nodes() {
891            if let Op::MatMul = &node.op {
892                let sh = graph.shape(node.id);
893                if sh.rank() == 3 && sh.dim(2).unwrap_static() == cfg.q_proj_dim() {
894                    assert_eq!(
895                        sh.dim(1).unwrap_static(),
896                        1,
897                        "decode q_proj matmul seq dim must be 1, got {sh} on node {}",
898                        node.id
899                    );
900                }
901            }
902        }
903
904        let fused = rlx_opt::CompilePipeline::new(rlx_opt::FusionTarget::Metal)
905            .with_assert_fusion_clean(false)
906            .compile_graph(graph)
907            .lir
908            .into_graph();
909        for node in fused.nodes() {
910            if let Op::Narrow { len, .. } = &node.op {
911                let sh = fused.shape(node.id);
912                if sh.rank() == 3 && *len == cfg.q_proj_dim() {
913                    assert_eq!(
914                        sh.dim(1).unwrap_static(),
915                        1,
916                        "fused decode q narrow seq dim must be 1, got {sh} on node {}",
917                        node.id
918                    );
919                }
920            }
921        }
922    }
923
924    #[test]
925    fn prefill_compile_cache_does_not_change_output() {
926        let cfg = tiny_cfg();
927        let prompt: Vec<u32> = vec![1, 2, 3, 5];
928        let mut wm_a = synthetic_weights(&cfg);
929        let mut gn_a = Qwen3Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
930        gn_a.prefill(&prompt);
931        let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();
932
933        let mut wm_b = synthetic_weights(&cfg);
934        let mut gn_b = Qwen3Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
935            .unwrap()
936            .with_prefill_cache(/*capacity*/ 4);
937        gn_b.prefill(&prompt);
938        let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();
939
940        assert_eq!(a, b, "enabling prefill_cache must not change output");
941    }
942
943    #[test]
944    fn greedy_is_deterministic_across_runs() {
945        let cfg = tiny_cfg();
946        let weights = synthetic_weights(&cfg);
947        let mk = || {
948            let mut wm = WeightMap::from_tensors(weights_as_hashmap(&weights));
949            Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
950        };
951        let mut a = mk();
952        let mut b = mk();
953        a.prefill(&[1, 2, 3]);
954        b.prefill(&[1, 2, 3]);
955        let ta = a.generate(4, SampleOpts::greedy()).unwrap();
956        let tb = b.generate(4, SampleOpts::greedy()).unwrap();
957        assert_eq!(ta, tb);
958    }
959
960    fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
961        // Reconstruct the underlying map by re-running synthetic_weights
962        // — WeightMap doesn't expose its inner map. Sufficient for the
963        // determinism test since synthetic_weights is itself
964        // deterministic.
965        let _ = wm; // silence unused
966        let cfg = tiny_cfg();
967        let mut new = synthetic_weights(&cfg);
968        let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
969        let mut out = HashMap::new();
970        for k in keys {
971            out.insert(k.clone(), new.take(&k).unwrap());
972        }
973        out
974    }
975}