Skip to main content

rlx_qwen3/
generator.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Host-side generation loop for Qwen3.
17//!
//! `step()` is the **naive** path: each call rebuilds the prefill
//! graph for the full token history and runs it from scratch
//! (O(N²) compute over N generated tokens). `step_cached()` is the
//! KV-cache path. The two share one API shape, so callers don't have
//! to change anything when switching between them — only the internal
//! implementation swaps.
24//!
25//! Why ship the naive version first:
26//!   - Establishes the public API contract before the IR/kernel
27//!     changes that the cached version needs land.
28//!   - Lets you run end-to-end generation against a real checkpoint
29//!     today and validate the prefill graph is numerically correct.
30//!   - Provides a reference baseline for the cached version's own
31//!     numerical-parity test (cached vs recompute must match).
32
33use crate::builder::{
34    build_qwen3_decode_graph_sized, build_qwen3_decode_graph_sized_ext,
35    build_qwen3_graph_sized_last_logits,
36};
37use crate::capabilities::validate_device;
38use crate::config::Qwen3Config;
39use crate::profile::qwen3_profile_near_weights;
40use crate::sampling::{SampleOpts, sample_token};
41use anyhow::{Context, Result};
42use rlx_core::autoregressive::{
43    DecodeLogitsKv, KvCacheState, compile_cache_ensure_graph, kv_from_prefill_outputs,
44    prefill_cache_key, run_bucketed_kv_decode, split_decode_logits_kv,
45};
46use rlx_core::flow_bridge::compile_options_from_profile;
47use rlx_core::weight_loader::WeightLoader;
48use rlx_core::weight_map::WeightMap;
49use rlx_flow::CompileProfile;
50use rlx_ir::logical_kernel::KernelDispatchConfig;
51use rlx_runtime::attn_mask::bucket_decode_mask;
52use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
53use rlx_runtime::{CompileOptions, Device, Session};
54use std::collections::HashMap;
55use std::path::Path;
56
/// Stateful Qwen3 generation handle.
///
/// Holds the (config, weight bytes, token history) and rebuilds a
/// prefill graph on each [`step`] call. Cheap to construct after
/// initial weight load; tokens stay in-memory between calls. The
/// cached path (`step_cached`) additionally keeps a KV cache here.
pub struct Qwen3Generator {
    /// Model hyper-parameters passed to every graph builder call.
    cfg: Qwen3Config,
    /// Map of weight key → (f32 data, shape). Cloned on each step
    /// into a fresh `WeightMap` because `WeightMap::take` is
    /// destructive — see the cached-generator notes for the path
    /// that avoids the clone.
    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
    /// Full token history: prompt followed by every generated id.
    tokens: Vec<u32>,
    /// Device every graph is compiled for.
    device: Device,
    /// Populated lazily on the first `step_cached` call (seeded from
    /// the prompt via prefill-with-cache); thereafter advanced by each
    /// decode step. Reset to `None` by `prefill`.
    cache: Option<KvCacheState>,
    /// LRU compile cache for prefill graphs, keyed through
    /// `prefill_cache_key(batch, seq)`. `None` disables it.
    /// `from_loader` installs one with capacity 8;
    /// [`Qwen3Generator::with_prefill_cache`] replaces it with a
    /// caller-chosen capacity.
    prefill_compile_cache: Option<CompileCache>,
    /// Bucketed compile cache for decode-mode graphs. Each bucket
    /// holds one compiled graph specialized at its upper-bound
    /// `past_seq`; the host pads `past_k`/`past_v` and supplies a
    /// per-step mask so a single bucket serves every `past_seq` in
    /// its range. `from_loader` installs a power-of-two ladder by
    /// default; resize it via [`Qwen3Generator::with_decode_cache`]
    /// or remove it via `set_decode_compile_cache(None)`.
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Compile profile applied to prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile applied to decode graphs.
    decode_profile: CompileProfile,
}
88
89impl Qwen3Generator {
90    /// Construct from any [`WeightLoader`] — drains it into an
91    /// internal cache so the loader is free after this call.
92    pub fn from_loader(
93        cfg: Qwen3Config,
94        loader: &mut dyn WeightLoader,
95        device: Device,
96    ) -> Result<Self> {
97        validate_device(&cfg, device, false)?;
98        let keys = loader.remaining_keys();
99        let mut weights_cache = HashMap::with_capacity(keys.len());
100        for k in keys {
101            let v = loader
102                .take(&k)
103                .with_context(|| format!("draining weight {k}"))?;
104            // Normalize the cache key to the safetensors / HuggingFace
105            // naming convention so subsequent builder calls that ask
106            // for `model.embed_tokens.weight` (the canonical name baked
107            // into the qwen3 builder) hit the cache whether the
108            // loader was safetensors-native or GGUF-native.
109            let canonical =
110                rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
111            weights_cache.insert(canonical, v);
112        }
113        let max_past = cfg.max_position_embeddings.clamp(1, 4096);
114        Ok(Self {
115            cfg,
116            weights_cache,
117            tokens: Vec::new(),
118            device,
119            cache: None,
120            prefill_compile_cache: Some(CompileCache::new(device, 8)),
121            decode_compile_cache: Some(BucketedCompileCache::power_of_two_ladder(
122                device,
123                1,
124                max_past as u64,
125            )),
126            prefill_profile: CompileProfile::qwen3_prefill(),
127            decode_profile: CompileProfile::qwen3_decode(),
128        })
129    }
130
131    /// Like [`Self::from_loader`] but loads `qwen3.rlx.toml` from the weights directory when present.
132    pub fn from_loader_at(
133        cfg: Qwen3Config,
134        loader: &mut dyn WeightLoader,
135        device: Device,
136        weights_path: &Path,
137    ) -> Result<Self> {
138        let mut g = Self::from_loader(cfg, loader, device)?;
139        g.prefill_profile = qwen3_profile_near_weights(weights_path, false);
140        g.decode_profile = qwen3_profile_near_weights(weights_path, true);
141        Ok(g)
142    }
143
    /// Builder-style setter that replaces both compile profiles.
    ///
    /// `prefill` is applied to prefill graphs and `decode` to decode
    /// graphs (see `profile_compile_options`). Consumes and returns
    /// `self` so it can be chained after a constructor.
    pub fn with_compile_profiles(
        mut self,
        prefill: CompileProfile,
        decode: CompileProfile,
    ) -> Self {
        self.prefill_profile = prefill;
        self.decode_profile = decode;
        self
    }
153
    /// Compile profile currently used for prefill graphs.
    pub fn prefill_profile(&self) -> &CompileProfile {
        &self.prefill_profile
    }
157
    /// Compile profile currently used for decode graphs.
    pub fn decode_profile(&self) -> &CompileProfile {
        &self.decode_profile
    }
161
162    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
163        let profile = if decode {
164            &self.decode_profile
165        } else {
166            &self.prefill_profile
167        };
168        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
169    }
170
    /// Install a prefill compile cache with the given LRU capacity,
    /// replacing the one `from_loader` set up (capacity 8) together
    /// with anything it had already cached.
    /// Useful when the same prompt length is used across multiple
    /// generation runs — the second + Nth run skip the compile +
    /// param-attach roundtrip (~30-50ms per call on CPU).
    pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
        self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
        self
    }
179
    /// Override the bucketed decode compile cache after construction.
    /// Passing `None` makes [`Self::generate`] take the naive (O(N²))
    /// path, and makes `step_cached` compile a fresh decode graph on
    /// every call.
    pub fn set_decode_compile_cache(&mut self, cache: Option<BucketedCompileCache>) {
        self.decode_compile_cache = cache;
    }

    /// Enable the bucketed decode compile cache spanning past-seq
    /// values in `[1, max_past]` (a `max_past` of 0 is treated as 1).
    /// Buckets are power-of-two
    /// `[1..2, 2..3, 3..5, 5..9, 9..17, …]`. Each bucket compiles
    /// one graph at its upper bound; a steady-state generation loop
    /// across `N` tokens compiles `O(log N)` graphs instead of `N`.
    ///
    /// Padding compute waste is bounded at 2×: actual `past_seq` is
    /// at least half the bucket's upper bound (except possibly the
    /// smallest bucket).
    ///
    /// Replaces any decode cache installed earlier, including the
    /// default one from `from_loader`.
    pub fn with_decode_cache(mut self, max_past: usize) -> Self {
        let cache = BucketedCompileCache::power_of_two_ladder(
            self.device,
            /*min*/ 1,
            max_past.max(1) as u64,
        );
        self.decode_compile_cache = Some(cache);
        self
    }
204
    /// Convenience: load weights from a safetensors or GGUF path
    /// (dispatch by extension; see `rlx_core::weight_loader::load_from_path`).
    ///
    /// Delegates to [`Self::from_path_at`] with `"."` as the
    /// profile-discovery path, i.e. the current directory rather than
    /// the location of `path`.
    pub fn from_path(cfg: Qwen3Config, path: &str, device: Device) -> Result<Self> {
        Self::from_path_at(cfg, path, device, Path::new("."))
    }
210
211    /// Like [`Self::from_path`] with an explicit weights path for `qwen3.rlx.toml` discovery.
212    pub fn from_path_at(
213        cfg: Qwen3Config,
214        path: &str,
215        device: Device,
216        weights_path: &Path,
217    ) -> Result<Self> {
218        let mut loader = rlx_core::weight_loader::load_from_path(path)?;
219        Self::from_loader_at(cfg, loader.as_mut(), device, weights_path)
220    }
221
222    /// Same as [`from_path`] but with MTP-head visibility control.
223    /// When `include_mtp=true` and the file is GGUF, MTP weights are
224    /// drained into the generator's cache alongside the base
225    /// weights. The base inference path still ignores them — they
226    /// sit in cache for a future MTP-aware decoder. Non-GGUF formats
227    /// silently ignore the flag (safetensors files publish all
228    /// tensors uniformly; downstream code distinguishes by name).
229    pub fn from_path_with_mtp(
230        cfg: Qwen3Config,
231        path: &str,
232        device: Device,
233        include_mtp: bool,
234    ) -> Result<Self> {
235        // Branch on extension so we can flip the GGUF-specific
236        // visibility option. Safetensors has no equivalent — it
237        // doesn't isolate MTP tensors at the loader level.
238        if path.ends_with(".gguf") {
239            let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
240            gguf.include_mtp(include_mtp);
241            Self::from_loader(cfg, &mut gguf, device)
242        } else {
243            Self::from_path(cfg, path, device)
244        }
245    }
246
247    /// Replace the token history with `prompt_ids`. Does not run the
248    /// model — the next [`step`] call processes the full sequence.
249    /// Clears any KV cache from a prior generation.
250    pub fn prefill(&mut self, prompt_ids: &[u32]) {
251        self.tokens.clear();
252        self.tokens.extend_from_slice(prompt_ids);
253        self.cache = None;
254    }
255
256    /// Run one prefill over the current token history and sample the
257    /// next token. The sampled token is appended to the history and
258    /// returned. Call repeatedly to generate.
259    pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
260        if self.tokens.is_empty() {
261            anyhow::bail!("step() called with empty token history; call prefill() first");
262        }
263        let seq = self.tokens.len();
264        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
265        let (graph, params) = build_qwen3_graph_sized_last_logits(
266            &self.cfg, &mut wm, /*batch*/ 1, seq, /*with_kv_outputs*/ false,
267        )?;
268        let compile_opts = self.profile_compile_options(false);
269        let mut compiled = Session::new(self.device).compile_with(graph, &compile_opts);
270        for (name, data) in &params {
271            compiled.set_param(name, data);
272        }
273        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
274        let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
275        let logits = outputs
276            .into_iter()
277            .next()
278            .context("compiled.run returned no outputs")?;
279
280        let vocab = self.cfg.vocab_size;
281        let expected = vocab;
282        if logits.len() < expected {
283            anyhow::bail!(
284                "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
285                logits.len(),
286                expected
287            );
288        }
289        // Last-logits graph returns [B=1, 1, vocab].
290        let last_row = &logits[..vocab];
291        let tok = sample_token(last_row, opts) as u32;
292        self.tokens.push(tok);
293        Ok(tok)
294    }
295
296    /// Run `n` steps and return the newly generated token ids
297    /// (excludes the prefill prompt).
298    pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
299        if self.decode_compile_cache.is_some() {
300            return self.generate_cached(n, opts);
301        }
302        let start = self.tokens.len();
303        for _ in 0..n {
304            self.step(opts)?;
305        }
306        Ok(self.tokens[start..].to_vec())
307    }
308
309    /// Cached step: O(L) per token instead of O(L²). First call seeds
310    /// the KV cache from the prompt via prefill-with-cache; subsequent
311    /// calls run the decode-mode graph on just the last token + cached
312    /// past. Output is bit-identical to [`step`] modulo reduction
313    /// order in the SDPA kernel.
314    ///
315    /// Invariant after each call: `cache.past_seq == tokens.len() - 1`
316    /// (the just-sampled token is appended but not yet in the cache;
317    /// it becomes the input for the next decode step).
318    pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
319        if self.tokens.is_empty() {
320            anyhow::bail!("step_cached() called with empty token history; call prefill() first");
321        }
322        if self.cache.is_none() {
323            // The seed runs prefill, populates the cache, samples from
324            // the last position, and appends the token. Return that
325            // token directly — no decode step on this call.
326            let tok = self.seed_cache_from_prompt(opts)?;
327            return Ok(tok);
328        }
329        let cache = self.cache.as_ref().unwrap();
330        let past_seq = cache.past_len;
331        // The token we feed into decode is whatever's after the cached
332        // prefix in `self.tokens`. After a prior cached step this is
333        // the just-sampled token; after seeding it's the same.
334        if self.tokens.len() <= past_seq {
335            anyhow::bail!(
336                "cache invariant violated: tokens.len() {} <= past_seq {}",
337                self.tokens.len(),
338                past_seq
339            );
340        }
341        let input_tok = self.tokens[past_seq];
342
343        // Branch: bucketed compile cache vs one-shot compile per step.
344        let (logits, new_k, new_v) = if self.decode_compile_cache.is_some()
345            && self
346                .decode_compile_cache
347                .as_ref()
348                .unwrap()
349                .bucket_for(past_seq as u64)
350                .is_some()
351        {
352            self.decode_step_bucketed(past_seq, input_tok)?
353        } else {
354            self.decode_step_oneshot(past_seq, input_tok)?
355        };
356
357        let cache_mut = self.cache.as_mut().unwrap();
358        cache_mut.past_len = past_seq + 1;
359        cache_mut.layers_k = new_k;
360        cache_mut.layers_v = new_v;
361
362        let vocab = self.cfg.vocab_size;
363        if logits.len() != vocab {
364            anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
365        }
366        let tok = sample_token(&logits, opts) as u32;
367        self.tokens.push(tok);
368        Ok(tok)
369    }
370
371    /// Decode path that compiles a fresh graph for the exact `past_seq`
372    /// every call. Slower but always-correct fallback.
373    fn decode_step_oneshot(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
374        let cache = self.cache.as_ref().unwrap();
375
376        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
377        let (graph, params) =
378            build_qwen3_decode_graph_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
379        let opts = self.profile_compile_options(true);
380        let mut compiled = Session::new(self.device).compile_with(graph, &opts);
381        for (name, data) in &params {
382            compiled.set_param(name, data);
383        }
384
385        let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
386        let input_ids_f32 = [input_tok as f32];
387        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
388            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
389            .collect();
390        let mut inputs: Vec<(&str, &[f32])> =
391            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
392        inputs.push(("input_ids", input_ids_f32.as_slice()));
393        inputs.push(("rope_cos", cos.as_slice()));
394        inputs.push(("rope_sin", sin.as_slice()));
395        for i in 0..self.cfg.num_hidden_layers {
396            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
397            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
398        }
399
400        let outputs = compiled.run(&inputs);
401        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
402    }
403
404    fn decode_step_bucketed(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
405        let kv = self.cache.as_ref().unwrap().clone();
406        let kv_dim = self.cfg.kv_proj_dim();
407        let n_layers = self.cfg.num_hidden_layers;
408        let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
409        let input_ids_f32 = [input_tok as f32];
410        let decode_opts = self.profile_compile_options(true);
411        let upper = self
412            .decode_compile_cache
413            .as_ref()
414            .and_then(|cache_dec| {
415                cache_dec.bucket_for(past_seq as u64).map(|idx| {
416                    cache_dec
417                        .buckets()
418                        .nth(idx)
419                        .map(|r| (r.end - 1) as usize)
420                        .unwrap_or(past_seq)
421                })
422            })
423            .unwrap_or(past_seq);
424        let mask = bucket_decode_mask(past_seq, upper);
425        let fixed = [
426            CacheRunInput {
427                name: "input_ids",
428                data: &input_ids_f32,
429                row_inner: None,
430            },
431            CacheRunInput {
432                name: "rope_cos",
433                data: &cos,
434                row_inner: None,
435            },
436            CacheRunInput {
437                name: "rope_sin",
438                data: &sin,
439                row_inner: None,
440            },
441            CacheRunInput {
442                name: "mask",
443                data: &mask,
444                row_inner: None,
445            },
446        ];
447        let cfg = self.cfg.clone();
448        let weights = self.weights_cache.clone();
449        let cache_dec = self.decode_compile_cache.as_mut().unwrap();
450        run_bucketed_kv_decode(
451            cache_dec,
452            past_seq,
453            &kv,
454            kv_dim,
455            n_layers,
456            &fixed,
457            |upper| {
458                let mut wm = WeightMap::from_tensors(weights.clone());
459                build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
460                    .expect("qwen3 bucketed decode graph")
461            },
462            &decode_opts,
463        )
464    }
465
466    /// Run prefill-with-cache and return the raw outputs. Uses the
467    /// LRU `CompileCache` when enabled; otherwise compiles fresh each
468    /// call. Keyed by `seq` because graph shape is seq-specialized.
469    fn run_prefill_with_cache(
470        &mut self,
471        batch: usize,
472        seq: usize,
473        ids_f32: &[f32],
474    ) -> Result<Vec<Vec<f32>>> {
475        let prefill_opts = self.profile_compile_options(false);
476        if let Some(cache) = &mut self.prefill_compile_cache {
477            let key = prefill_cache_key(batch, seq);
478            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
479            let (graph, params) = build_qwen3_graph_sized_last_logits(
480                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
481            )?;
482            let compiled = compile_cache_ensure_graph(cache, key, graph, params, &prefill_opts);
483            Ok(compiled.run(&[("input_ids", ids_f32)]))
484        } else {
485            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
486            let (graph, params) = build_qwen3_graph_sized_last_logits(
487                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
488            )?;
489            let opts = self.profile_compile_options(false);
490            let mut compiled = Session::new(self.device).compile_with(graph, &opts);
491            for (name, data) in &params {
492                compiled.set_param(name, data);
493            }
494            Ok(compiled.run(&[("input_ids", ids_f32)]))
495        }
496    }
497
498    /// Run `n` cached steps and return the newly generated tokens.
499    pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
500        self.generate_cached_with(n, opts, |_| {})
501    }
502
503    /// Same as [`generate_cached`] but invokes `on_token` once per
504    /// freshly sampled id, inside the decode loop. The whole `n` step
505    /// loop shares the bucketed compile cache — callers wanting a
506    /// streaming UI should prefer this to calling
507    /// `generate_cached(1, …)` `n` times (which forces a fresh
508    /// compile per token at the bucket boundaries).
509    pub fn generate_cached_with(
510        &mut self,
511        n: usize,
512        opts: SampleOpts,
513        on_token: impl FnMut(u32),
514    ) -> Result<Vec<u32>> {
515        self.generate_cached_until(n, opts, |_| true, on_token)
516    }
517
518    /// Like [`generate_cached_with`] but stops early when `should_continue`
519    /// returns `false` after sampling a token.
520    pub fn generate_cached_until(
521        &mut self,
522        n: usize,
523        opts: SampleOpts,
524        mut should_continue: impl FnMut(u32) -> bool,
525        mut on_token: impl FnMut(u32),
526    ) -> Result<Vec<u32>> {
527        let start = self.tokens.len();
528        for _ in 0..n {
529            let tok = self.step_cached(opts)?;
530            on_token(tok);
531            if !should_continue(tok) {
532                break;
533            }
534        }
535        Ok(self.tokens[start..].to_vec())
536    }
537
538    /// Run prefill-with-cache on the current `self.tokens` (the
539    /// prompt), populate `self.cache`, sample the next token from the
540    /// last position's logits, and append it. Returns the sampled
541    /// token. Invariant after: `cache.past_seq == tokens.len() - 1`.
542    fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
543        let seq = self.tokens.len();
544        let batch = 1usize;
545        let kv_dim = self.cfg.kv_proj_dim();
546
547        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
548        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
549        let (logits, kv) =
550            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
551        self.cache = Some(kv);
552
553        let vocab = self.cfg.vocab_size;
554        let needed = vocab;
555        if logits.len() < needed {
556            anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
557        }
558        let last_row = &logits[..vocab];
559        let tok = sample_token(last_row, opts) as u32;
560        self.tokens.push(tok);
561        Ok(tok)
562    }
563
    /// Full token history (prompt + generated), in generation order.
    pub fn tokens(&self) -> &[u32] {
        &self.tokens
    }
568
    /// Model configuration this generator was constructed with.
    pub fn config(&self) -> &Qwen3Config {
        &self.cfg
    }
572
573    /// Low-level primitive: reset internal state, run prefill-with-cache
574    /// over `context`, and return the *last position's* logits row
575    /// (`P(next_token | context)`). Does NOT sample or append. The
576    /// internal `tokens` buffer is set to `context` and the KV cache
577    /// is populated to `past_seq = context.len()`.
578    ///
579    /// Used by [`crate::spec::Qwen3Speculator`] to compute the
580    /// first row of a `Speculator::verify` / `propose` result before
581    /// the decode loop runs.
582    pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
583        if context.is_empty() {
584            anyhow::bail!("prefill_get_last_logits: empty context");
585        }
586        self.tokens.clear();
587        self.tokens.extend_from_slice(context);
588        self.cache = None;
589
590        let seq = context.len();
591        let batch = 1usize;
592        let kv_dim = self.cfg.kv_proj_dim();
593
594        let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
595        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
596        let (logits, kv) =
597            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
598        self.cache = Some(kv);
599
600        let vocab = self.cfg.vocab_size;
601        let needed = vocab;
602        if logits.len() < needed {
603            anyhow::bail!("logits short: {} < {}", logits.len(), needed);
604        }
605        Ok(logits[..vocab].to_vec())
606    }
607
608    /// Low-level primitive: run one decode step with the caller-
609    /// supplied input token (no sampling), advance the KV cache, and
610    /// return the resulting logits row `P(next | history ++ input)`.
611    /// Appends `input` to the `tokens` buffer so the invariant
612    /// `cache.past_seq == tokens.len()` holds after this call (note:
613    /// differs from `step_cached` invariant because this method does
614    /// not append a sampled token).
615    pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
616        let cache = self.cache.as_ref().ok_or_else(|| {
617            anyhow::anyhow!(
618                "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
619            )
620        })?;
621        let past_seq = cache.past_len;
622
623        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
624        let (graph, params) =
625            build_qwen3_decode_graph_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
626        let opts = self.profile_compile_options(true);
627        let mut compiled = Session::new(self.device).compile_with(graph, &opts);
628        for (name, data) in &params {
629            compiled.set_param(name, data);
630        }
631
632        let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
633        let input_ids_f32 = [input as f32];
634        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
635            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
636            .collect();
637        let mut inputs: Vec<(&str, &[f32])> =
638            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
639        inputs.push(("input_ids", input_ids_f32.as_slice()));
640        inputs.push(("rope_cos", cos.as_slice()));
641        inputs.push(("rope_sin", sin.as_slice()));
642        for i in 0..self.cfg.num_hidden_layers {
643            let pk = &cache.layers_k[i];
644            let pv = &cache.layers_v[i];
645            inputs.push((&key_strs[2 * i], pk.as_slice()));
646            inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
647        }
648
649        let outputs = compiled.run(&inputs);
650        let (logits, new_k, new_v) = split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)?;
651
652        let cache_mut = self.cache.as_mut().unwrap();
653        cache_mut.past_len = past_seq + 1;
654        cache_mut.layers_k = new_k;
655        cache_mut.layers_v = new_v;
656        self.tokens.push(input);
657
658        Ok(logits)
659    }
660}
661
662/// Compute the single-row (cos, sin) RoPE slice for absolute position
663/// `pos`. Matches the formula in the prefill builder so cached decode
664/// and recompute prefill produce the same RoPE rotation.
665fn compute_rope_slice(cfg: &Qwen3Config, pos: usize) -> (Vec<f32>, Vec<f32>) {
666    let dh = cfg.head_dim;
667    let half = dh / 2;
668    let mut cos = vec![0f32; half];
669    let mut sin = vec![0f32; half];
670    for i in 0..half {
671        let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
672        let angle = pos as f64 * freq;
673        let (s, c) = angle.sin_cos();
674        cos[i] = c as f32;
675        sin[i] = s as f32;
676    }
677    (cos, sin)
678}
679
680#[cfg(test)]
681mod tests {
682    use super::*;
683    use crate::config::Qwen3Config;
684
    /// Smallest config the tests run against: a dense 2-layer model
    /// with vocab 16, hidden size 16, 4 query heads over 2 KV heads,
    /// head_dim 8 and a 16-position limit. QK-norm is on; sliding
    /// window is off; all expert (MoE) counts and sizes are zero.
    fn tiny_cfg() -> Qwen3Config {
        Qwen3Config {
            vocab_size: 16,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            qk_norm: true,
            sliding_window: None,
            max_window_layers: usize::MAX,
            use_sliding_window: false,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }
711
712    fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
713        let h = cfg.hidden_size;
714        let q_dim = cfg.q_proj_dim();
715        let kv_dim = cfg.kv_proj_dim();
716        let int_dim = cfg.intermediate_size;
717        let dh = cfg.head_dim;
718        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
719        // Use a deterministic non-zero pattern so logits aren't all 0
720        // (sampling on an all-zero row is undefined order).
721        let pat = |n: usize, salt: u32| -> Vec<f32> {
722            (0..n)
723                .map(|i| {
724                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
725                    (x as f32 / (1u32 << 24) as f32) - 0.5
726                })
727                .collect()
728        };
729        t.insert(
730            "model.embed_tokens.weight".into(),
731            (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
732        );
733        for i in 0..cfg.num_hidden_layers {
734            let lp = format!("model.layers.{i}");
735            t.insert(
736                format!("{lp}.input_layernorm.weight"),
737                (pat(h, 100 + i as u32), vec![h]),
738            );
739            t.insert(
740                format!("{lp}.post_attention_layernorm.weight"),
741                (pat(h, 200 + i as u32), vec![h]),
742            );
743            t.insert(
744                format!("{lp}.self_attn.q_proj.weight"),
745                (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
746            );
747            t.insert(
748                format!("{lp}.self_attn.k_proj.weight"),
749                (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
750            );
751            t.insert(
752                format!("{lp}.self_attn.v_proj.weight"),
753                (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
754            );
755            t.insert(
756                format!("{lp}.self_attn.o_proj.weight"),
757                (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
758            );
759            t.insert(
760                format!("{lp}.self_attn.q_norm.weight"),
761                (pat(dh, 700 + i as u32), vec![dh]),
762            );
763            t.insert(
764                format!("{lp}.self_attn.k_norm.weight"),
765                (pat(dh, 800 + i as u32), vec![dh]),
766            );
767            t.insert(
768                format!("{lp}.mlp.gate_proj.weight"),
769                (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
770            );
771            t.insert(
772                format!("{lp}.mlp.up_proj.weight"),
773                (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
774            );
775            t.insert(
776                format!("{lp}.mlp.down_proj.weight"),
777                (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
778            );
779        }
780        t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
781        t.insert(
782            "lm_head.weight".into(),
783            (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
784        );
785        WeightMap::from_tensors(t)
786    }
787
788    #[test]
789    fn generator_drains_loader_and_runs_one_step() {
790        let cfg = tiny_cfg();
791        let mut wm = synthetic_weights(&cfg);
792        let mut gn = Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
793        assert_eq!(wm.len(), 0, "loader should be drained");
794        gn.prefill(&[1, 2, 3]);
795        let t = gn.step(SampleOpts::greedy()).unwrap();
796        assert!((t as usize) < cfg.vocab_size);
797        assert_eq!(gn.tokens().len(), 4);
798    }
799
800    #[test]
801    fn generate_n_appends_n_tokens() {
802        let cfg = tiny_cfg();
803        let mut wm = synthetic_weights(&cfg);
804        let mut gn = Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
805        gn.prefill(&[5, 6]);
806        let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
807        assert_eq!(new_tokens.len(), 3);
808        assert_eq!(gn.tokens().len(), 5);
809        for t in &new_tokens {
810            assert!((*t as usize) < cfg.vocab_size);
811        }
812    }
813
814    #[test]
815    fn step_without_prefill_errors() {
816        let cfg = tiny_cfg();
817        let mut wm = synthetic_weights(&cfg);
818        let mut gn = Qwen3Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
819        let r = gn.step(SampleOpts::greedy());
820        assert!(r.is_err());
821    }
822
823    #[test]
824    fn cached_matches_naive_on_greedy() {
825        // The cached and naive paths must produce the same token
826        // sequence given the same prompt + opts. This is the
827        // load-bearing test for the KV-cache implementation: if the
828        // decode-mode graph, the kernel's Lq!=Lk fix, the cache
829        // wiring, or the RoPE position-slice is wrong, the sequences
830        // diverge here.
831        let cfg = tiny_cfg();
832        let prompt: Vec<u32> = vec![1, 2, 3, 5];
833        let steps = 4;
834
835        let mut wm_n = synthetic_weights(&cfg);
836        let mut gn_naive =
837            Qwen3Generator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
838        gn_naive.prefill_compile_cache = None;
839        gn_naive.decode_compile_cache = None;
840        gn_naive.prefill(&prompt);
841        let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();
842
843        let mut wm_c = synthetic_weights(&cfg);
844        let mut gn_cached =
845            Qwen3Generator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
846        gn_cached.prefill(&prompt);
847        let cached_tokens = gn_cached
848            .generate_cached(steps, SampleOpts::greedy())
849            .unwrap();
850
851        assert_eq!(
852            cached_tokens, naive_tokens,
853            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
854        );
855    }
856
857    #[test]
858    fn cached_step_advances_cache_invariant() {
859        let cfg = tiny_cfg();
860        let mut wm = synthetic_weights(&cfg);
861        let mut gn = Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
862        gn.prefill(&[1, 2, 3]);
863        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
864        // After seed: tokens.len() == 4, cache.past_seq == 3 (cache holds prompt).
865        assert_eq!(gn.tokens().len(), 4);
866        assert_eq!(gn.cache.as_ref().unwrap().past_len, 3);
867        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
868        // After one decode: tokens.len() == 5, cache.past_seq == 4.
869        assert_eq!(gn.tokens().len(), 5);
870        assert_eq!(gn.cache.as_ref().unwrap().past_len, 4);
871    }
872
873    #[test]
874    fn bucketed_decode_matches_oneshot() {
875        // The bucketed compile-cache path (padded K/V + custom mask)
876        // must produce the same token sequence as the one-shot
877        // path. Load-bearing for the bucketed cache feature: if the
878        // mask, padding, or output slicing is wrong, sequences
879        // diverge here.
880        let cfg = tiny_cfg();
881        let prompt: Vec<u32> = vec![1, 2, 3, 5];
882        let steps = 6;
883
884        let mut wm_one = synthetic_weights(&cfg);
885        let mut gn_one =
886            Qwen3Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
887        gn_one.prefill(&prompt);
888        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
889
890        let mut wm_buc = synthetic_weights(&cfg);
891        let mut gn_buc = Qwen3Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
892            .unwrap()
893            .with_decode_cache(/*max_past*/ 32);
894        gn_buc.prefill(&prompt);
895        let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();
896
897        assert_eq!(
898            bucketed_tokens, oneshot_tokens,
899            "bucketed-cache decode diverged from one-shot decode — \
900             mask, padding, or output-slice bug"
901        );
902    }
903
904    #[test]
905    fn bucketed_decode_q_proj_seq_is_one() {
906        use rlx_ir::Op;
907
908        let cfg = tiny_cfg();
909        let mut wm = synthetic_weights(&cfg);
910        let (graph, _) = build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, 4, true).unwrap();
911        for node in graph.nodes() {
912            if let Op::MatMul = &node.op {
913                let sh = graph.shape(node.id);
914                if sh.rank() == 3 && sh.dim(2).unwrap_static() == cfg.q_proj_dim() {
915                    assert_eq!(
916                        sh.dim(1).unwrap_static(),
917                        1,
918                        "decode q_proj matmul seq dim must be 1, got {sh} on node {}",
919                        node.id
920                    );
921                }
922            }
923        }
924
925        let fused = rlx_opt::CompilePipeline::new(rlx_opt::FusionTarget::Metal)
926            .with_assert_fusion_clean(false)
927            .compile_graph(graph)
928            .lir
929            .into_graph();
930        for node in fused.nodes() {
931            if let Op::Narrow { len, .. } = &node.op {
932                let sh = fused.shape(node.id);
933                if sh.rank() == 3 && *len == cfg.q_proj_dim() {
934                    assert_eq!(
935                        sh.dim(1).unwrap_static(),
936                        1,
937                        "fused decode q narrow seq dim must be 1, got {sh} on node {}",
938                        node.id
939                    );
940                }
941            }
942        }
943    }
944
945    #[test]
946    fn prefill_compile_cache_does_not_change_output() {
947        let cfg = tiny_cfg();
948        let prompt: Vec<u32> = vec![1, 2, 3, 5];
949        let mut wm_a = synthetic_weights(&cfg);
950        let mut gn_a = Qwen3Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
951        gn_a.prefill(&prompt);
952        let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();
953
954        let mut wm_b = synthetic_weights(&cfg);
955        let mut gn_b = Qwen3Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
956            .unwrap()
957            .with_prefill_cache(/*capacity*/ 4);
958        gn_b.prefill(&prompt);
959        let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();
960
961        assert_eq!(a, b, "enabling prefill_cache must not change output");
962    }
963
964    #[test]
965    fn greedy_is_deterministic_across_runs() {
966        let cfg = tiny_cfg();
967        let weights = synthetic_weights(&cfg);
968        let mk = || {
969            let mut wm = WeightMap::from_tensors(weights_as_hashmap(&weights));
970            Qwen3Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
971        };
972        let mut a = mk();
973        let mut b = mk();
974        a.prefill(&[1, 2, 3]);
975        b.prefill(&[1, 2, 3]);
976        let ta = a.generate(4, SampleOpts::greedy()).unwrap();
977        let tb = b.generate(4, SampleOpts::greedy()).unwrap();
978        assert_eq!(ta, tb);
979    }
980
981    fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
982        // Reconstruct the underlying map by re-running synthetic_weights
983        // — WeightMap doesn't expose its inner map. Sufficient for the
984        // determinism test since synthetic_weights is itself
985        // deterministic.
986        let _ = wm; // silence unused
987        let cfg = tiny_cfg();
988        let mut new = synthetic_weights(&cfg);
989        let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
990        let mut out = HashMap::new();
991        for k in keys {
992            out.insert(k.clone(), new.take(&k).unwrap());
993        }
994        out
995    }
996}