Skip to main content

rlx_gemma/
generator.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Host-side generation loop for Gemma.
//!
//! [`GemmaGenerator::step`] is the **naive** generator: each call rebuilds
//! the prefill graph for the full token history and runs it from scratch
//! (O(N²) compute over N generated tokens). The KV-cache path
//! ([`GemmaGenerator::step_cached`] and friends) lives alongside it with the
//! same API shape (`step` ↔ `step_cached`, `generate` ↔ `generate_cached`),
//! so callers can switch between the two without restructuring their loop.
//!
//! Why keep the naive version:
//!   - It established the public API contract before the IR/kernel
//!     changes that the cached version needed had landed.
//!   - It runs end-to-end generation against a real checkpoint and
//!     validates that the prefill graph is numerically correct.
//!   - It provides a reference baseline for the cached version's own
//!     numerical-parity test (cached vs. recompute must match).
32
33use crate::builder::{
34    build_gemma_decode_graph_sized, build_gemma_decode_hir_dynamic_ext,
35    build_gemma_decode_hir_sized_ext, build_gemma_graph_sized_last_logits,
36    build_gemma_prefill_hir_dynamic_ext,
37};
38use crate::config::GemmaConfig;
39use crate::rope::{resolve_inv_freq, rope_slice};
40use anyhow::{Context, Result};
41use rlx_core::autoregressive::{
42    KvCacheState, kv_from_prefill_outputs, run_bucketed_kv_decode_hir, split_decode_logits_kv,
43};
44use rlx_core::flow_bridge::compile_options_from_profile;
45use rlx_core::weight_loader::WeightLoader;
46use rlx_core::weight_map::WeightMap;
47use rlx_flow::CompileProfile;
48use rlx_ir::DimBinding;
49use rlx_ir::logical_kernel::KernelDispatchConfig;
50use rlx_qwen3::sampling::{SampleOpts, sample_token};
51use rlx_runtime::attn_mask::bucket_decode_mask;
52use rlx_runtime::compile_cache::{
53    BucketedCompileCache, CacheRunInput, CompileCache, DynamicDimCompileCache,
54};
55use rlx_runtime::{CompileOptions, Device, Session};
56use std::collections::HashMap;
57use std::path::Path;
58
59/// Decode compile profile with backend-specific fixes (Metal: unfused GQA path).
60pub fn decode_profile_for_device(device: Device) -> CompileProfile {
61    metal_safe_decode_profile(device, CompileProfile::gemma_decode())
62}
63
64/// MPSGraph rejects fused GQA reshapes in decode (KV concat + `repeat_kv`).
65fn metal_safe_decode_profile(device: Device, mut profile: CompileProfile) -> CompileProfile {
66    if device == Device::Metal {
67        profile.fusion.skip = true;
68        profile.backend.metal.skip_fusion = true;
69        profile.backend.metal.unfuse_regions = true;
70    }
71    profile
72}
73
/// Stateful Gemma generation handle.
///
/// Holds the config, host-side copies of every model weight, and the token
/// history. Two generation paths share this handle:
///
/// - [`GemmaGenerator::step`] / [`GemmaGenerator::generate`] rebuild and
///   recompile a prefill graph over the full history on every call.
/// - [`GemmaGenerator::step_cached`] / [`GemmaGenerator::generate_cached`]
///   seed a KV cache from the prompt once, then run one-token decode graphs.
///
/// Weights are loaded up front and the token history stays in memory between
/// calls. Compile caches for prefill and decode are opt-in via the `with_*`
/// builder methods; all of them start disabled.
pub struct GemmaGenerator {
    cfg: GemmaConfig,
    /// Canonical (HuggingFace-style) weight name → (f32 data, shape).
    /// Cloned into a fresh `WeightMap` whenever a graph is built, because
    /// `WeightMap::take` is destructive.
    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
    /// Prompt followed by every token generated (or fed) since the last reset.
    tokens: Vec<u32>,
    /// Backend every graph is compiled for and run on.
    device: Device,
    /// KV cache for the cached path.
    ///
    /// `None` until one of these calls populates it:
    /// - the first `step_cached` (seeded from the prompt via
    ///   prefill-with-cache);
    /// - `prefill_get_last_logits`.
    ///
    /// Each decode step advances it, and `prefill` resets it to `None`.
    cache: Option<KvCacheState>,
    /// Keyed LRU compile cache for prefill graphs.
    ///
    /// Graphs are keyed by `(batch << 32) | seq`. `None` (the default)
    /// disables it; opt in via [`GemmaGenerator::with_prefill_cache`], which
    /// also clears `prefill_dynamic_cache`.
    prefill_compile_cache: Option<CompileCache>,
    /// Compile prefill once with `sym::SEQ`, then specialize per prompt
    /// length.
    ///
    /// Opt in via [`GemmaGenerator::with_dynamic_prefill_cache`]. When set,
    /// it takes precedence over `prefill_compile_cache`.
    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
    /// Bucketed compile cache for decode-mode graphs.
    ///
    /// Each bucket holds one compiled graph specialized at its upper-bound
    /// `past_seq`. Per step, `past_k`/`past_v` are padded and a mask is
    /// supplied, so a single bucket serves every `past_seq` in its range.
    /// Opt in via [`GemmaGenerator::with_decode_cache`].
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Compile decode once with `sym::PAST_SEQ`, then specialize per prefix
    /// length.
    ///
    /// Opt in via [`GemmaGenerator::with_dynamic_decode_cache`]. When set,
    /// `step_cached` uses it in preference to `decode_compile_cache`.
    decode_dynamic_cache: Option<DynamicDimCompileCache>,
    /// RoPE inverse frequencies, resolved once by `resolve_inv_freq`
    /// (including any `rope_freqs.weight` scaling factors).
    inv_freq: Vec<f64>,
    /// Tier-1 compile profile for prefill graphs.
    prefill_profile: CompileProfile,
    /// Tier-1 compile profile for decode graphs, with the Metal fixes from
    /// `metal_safe_decode_profile` already applied.
    decode_profile: CompileProfile,
}
112
113impl GemmaGenerator {
114    /// Construct from any [`WeightLoader`] — drains it into an
115    /// internal cache so the loader is free after this call.
116    pub fn from_loader(
117        cfg: GemmaConfig,
118        loader: &mut dyn WeightLoader,
119        device: Device,
120    ) -> Result<Self> {
121        let keys = loader.remaining_keys();
122        // Capture the arch up front so the cache-key normalization can
123        // pick the gemma2 reverse alias (4 distinct per-layer norms)
124        // over the generic Llama-flavored one (2 norms, ambiguous on
125        // `ffn_norm`). Owned string so we don't hold a borrow across
126        // the mutable `loader.take` calls below.
127        let arch_hint: Option<String> = loader.arch_hint().map(|s| s.to_string());
128        let mut weights_cache = HashMap::with_capacity(keys.len());
129        for k in keys {
130            let v = loader
131                .take(&k)
132                .with_context(|| format!("draining weight {k}"))?;
133            // Normalize the cache key to the safetensors / HuggingFace
134            // naming convention so subsequent builder calls that ask
135            // for `model.embed_tokens.weight` (the canonical name baked
136            // into the gemma builder) hit the cache whether the
137            // loader was safetensors-native or GGUF-native.
138            let canonical = match arch_hint.as_deref() {
139                Some(a) => rlx_core::weight_loader::gguf_to_hf_name_for_arch(&k, a)
140                    .unwrap_or_else(|| k.clone()),
141                None => rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone()),
142            };
143            weights_cache.insert(canonical, v);
144        }
145        let rope_factors = weights_cache
146            .get("rope_freqs.weight")
147            .map(|(d, _)| d.as_slice());
148        let inv_freq = resolve_inv_freq(&cfg, rope_factors);
149        Ok(Self {
150            cfg,
151            weights_cache,
152            tokens: Vec::new(),
153            device,
154            cache: None,
155            prefill_compile_cache: None,
156            prefill_dynamic_cache: None,
157            decode_compile_cache: None,
158            decode_dynamic_cache: None,
159            inv_freq,
160            prefill_profile: CompileProfile::gemma_prefill(),
161            decode_profile: metal_safe_decode_profile(device, CompileProfile::gemma_decode()),
162        })
163    }
164
165    /// Like [`Self::from_loader`] but loads tier-1 profiles from
166    /// `gemma.rlx.toml` in the weights directory when present.
167    pub fn from_loader_at(
168        cfg: GemmaConfig,
169        loader: &mut dyn WeightLoader,
170        device: Device,
171        weights_path: &Path,
172    ) -> Result<Self> {
173        let mut g = Self::from_loader(cfg, loader, device)?;
174        g.prefill_profile = crate::gemma_profile_near_weights(weights_path, false);
175        g.decode_profile = metal_safe_decode_profile(
176            device,
177            crate::gemma_profile_near_weights(weights_path, true),
178        );
179        Ok(g)
180    }
181
182    /// Override tier-1 compile profiles explicitly.
183    pub fn with_compile_profiles(
184        mut self,
185        prefill: CompileProfile,
186        decode: CompileProfile,
187    ) -> Self {
188        self.prefill_profile = prefill;
189        self.decode_profile = metal_safe_decode_profile(self.device, decode);
190        self
191    }
192
193    pub fn prefill_profile(&self) -> &CompileProfile {
194        &self.prefill_profile
195    }
196
197    pub fn decode_profile(&self) -> &CompileProfile {
198        &self.decode_profile
199    }
200
201    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
202        let profile = if decode {
203            &self.decode_profile
204        } else {
205            &self.prefill_profile
206        };
207        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
208    }
209
210    fn compile_graph_profiled(
211        &self,
212        session: &Session,
213        graph: rlx_ir::Graph,
214    ) -> Result<rlx_runtime::CompiledGraph> {
215        let opts = self.profile_compile_options(false);
216        Ok(session.compile_with(graph, &opts))
217    }
218
219    fn compile_graph_profiled_decode(
220        &self,
221        session: &Session,
222        graph: rlx_ir::Graph,
223    ) -> Result<rlx_runtime::CompiledGraph> {
224        Ok(session.compile_with(graph, &self.profile_compile_options(true)))
225    }
226
227    /// Enable the prefill compile cache with the given LRU capacity.
228    /// Useful when the same prompt length is used across multiple
229    /// generation runs — the second + Nth run skip the compile +
230    /// param-attach roundtrip (~30-50ms per call on CPU).
231    pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
232        self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
233        self.prefill_dynamic_cache = None;
234        self
235    }
236
237    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
238    pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
239        self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
240        self.prefill_compile_cache = None;
241        self
242    }
243
244    /// Enable the bucketed decode compile cache spanning past-seq
245    /// values in `[1, max_past]`. Buckets are power-of-two
246    /// `[1..2, 2..3, 3..5, 5..9, 9..17, …]`. Each bucket compiles
247    /// one graph at its upper bound; a steady-state generation loop
248    /// across `N` tokens compiles `O(log N)` graphs instead of `N`.
249    ///
250    /// Padding compute waste is bounded at 2×: actual `past_seq` is
251    /// at least half the bucket's upper bound (except possibly the
252    /// smallest bucket).
253    pub fn with_decode_cache(mut self, max_past: usize) -> Self {
254        let cache = BucketedCompileCache::power_of_two_ladder(
255            self.device,
256            /*min*/ 1,
257            max_past.max(1) as u64,
258        );
259        self.decode_compile_cache = Some(cache);
260        self.decode_dynamic_cache = None;
261        self
262    }
263
264    /// Compile decode once with `sym::PAST_SEQ`, specialize per prefix length.
265    pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
266        self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
267        self.decode_compile_cache = None;
268        self
269    }
270
271    /// Convenience: load weights from a safetensors or GGUF path
272    /// (dispatch by extension; see `rlx_core::weight_loader::load_from_path`).
273    pub fn from_path(cfg: GemmaConfig, path: &str, device: Device) -> Result<Self> {
274        let mut loader = rlx_core::weight_loader::load_from_path(path)?;
275        Self::from_loader(cfg, loader.as_mut(), device)
276    }
277
278    /// Same as [`from_path`] but with MTP-head visibility control.
279    /// When `include_mtp=true` and the file is GGUF, MTP weights are
280    /// drained into the generator's cache alongside the base
281    /// weights. The base inference path still ignores them — they
282    /// sit in cache for a future MTP-aware decoder. Non-GGUF formats
283    /// silently ignore the flag (safetensors files publish all
284    /// tensors uniformly; downstream code distinguishes by name).
285    pub fn from_path_with_mtp(
286        cfg: GemmaConfig,
287        path: &str,
288        device: Device,
289        include_mtp: bool,
290    ) -> Result<Self> {
291        // Branch on extension so we can flip the GGUF-specific
292        // visibility option. Safetensors has no equivalent — it
293        // doesn't isolate MTP tensors at the loader level.
294        if path.ends_with(".gguf") {
295            let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
296            gguf.include_mtp(include_mtp);
297            Self::from_loader(cfg, &mut gguf, device)
298        } else {
299            Self::from_path(cfg, path, device)
300        }
301    }
302
303    /// Replace the token history with `prompt_ids`. Does not run the
304    /// model — the next [`step`] call processes the full sequence.
305    /// Clears any KV cache from a prior generation.
306    pub fn prefill(&mut self, prompt_ids: &[u32]) {
307        self.tokens.clear();
308        self.tokens.extend_from_slice(prompt_ids);
309        self.cache = None;
310    }
311
312    /// Run one prefill over the current token history and sample the
313    /// next token. The sampled token is appended to the history and
314    /// returned. Call repeatedly to generate.
315    pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
316        if self.tokens.is_empty() {
317            anyhow::bail!("step() called with empty token history; call prefill() first");
318        }
319        let seq = self.tokens.len();
320        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
321        let (graph, params) = build_gemma_graph_sized_last_logits(
322            &self.cfg, &mut wm, /*batch*/ 1, seq, /*with_kv_outputs*/ false,
323        )?;
324        let session = Session::new(self.device);
325        let mut compiled = self.compile_graph_profiled(&session, graph)?;
326        for (name, data) in &params {
327            compiled.set_param(name, data);
328        }
329        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
330        let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
331        let logits = outputs
332            .into_iter()
333            .next()
334            .context("compiled.run returned no outputs")?;
335
336        let vocab = self.cfg.vocab_size;
337        let expected = vocab;
338        if logits.len() < expected {
339            anyhow::bail!(
340                "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
341                logits.len(),
342                expected
343            );
344        }
345        // Last-logits graph returns [B=1, 1, vocab].
346        let last_row = &logits[..vocab];
347        let tok = sample_token(last_row, opts) as u32;
348        self.tokens.push(tok);
349        Ok(tok)
350    }
351
352    /// Run `n` steps and return the newly generated token ids
353    /// (excludes the prefill prompt).
354    pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
355        let start = self.tokens.len();
356        for _ in 0..n {
357            self.step(opts)?;
358        }
359        Ok(self.tokens[start..].to_vec())
360    }
361
362    /// Cached step: O(L) per token instead of O(L²). First call seeds
363    /// the KV cache from the prompt via prefill-with-cache; subsequent
364    /// calls run the decode-mode graph on just the last token + cached
365    /// past. Output is bit-identical to [`step`] modulo reduction
366    /// order in the SDPA kernel.
367    ///
368    /// Invariant after each call: `cache.past_seq == tokens.len() - 1`
369    /// (the just-sampled token is appended but not yet in the cache;
370    /// it becomes the input for the next decode step).
371    pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
372        if self.tokens.is_empty() {
373            anyhow::bail!("step_cached() called with empty token history; call prefill() first");
374        }
375        if self.cache.is_none() {
376            // The seed runs prefill, populates the cache, samples from
377            // the last position, and appends the token. Return that
378            // token directly — no decode step on this call.
379            let tok = self.seed_cache_from_prompt(opts)?;
380            return Ok(tok);
381        }
382        let cache = self.cache.as_ref().unwrap();
383        let past_seq = cache.past_len;
384        if self.tokens.len() <= past_seq {
385            anyhow::bail!(
386                "cache invariant violated: tokens.len() {} <= past_len {}",
387                self.tokens.len(),
388                past_seq
389            );
390        }
391        let input_tok = self.tokens[past_seq];
392
393        let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
394            self.decode_step_dynamic(past_seq, input_tok)?
395        } else if self.decode_compile_cache.is_some()
396            && self
397                .decode_compile_cache
398                .as_ref()
399                .unwrap()
400                .bucket_for(past_seq as u64)
401                .is_some()
402        {
403            self.decode_step_bucketed(past_seq, input_tok)?
404        } else {
405            self.decode_step_oneshot(past_seq, input_tok)?
406        };
407
408        let cache_mut = self.cache.as_mut().unwrap();
409        cache_mut.past_len = past_seq + 1;
410        cache_mut.layers_k = new_k;
411        cache_mut.layers_v = new_v;
412
413        let vocab = self.cfg.vocab_size;
414        if logits.len() != vocab {
415            anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
416        }
417        let tok = sample_token(&logits, opts) as u32;
418        self.tokens.push(tok);
419        Ok(tok)
420    }
421
422    /// Decode path that compiles a fresh graph for the exact `past_seq`
423    /// every call. Slower but always-correct fallback.
424    #[allow(clippy::type_complexity)]
425    fn decode_step_oneshot(
426        &mut self,
427        past_seq: usize,
428        input_tok: u32,
429    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
430        let cache = self.cache.as_ref().unwrap();
431
432        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
433        let (graph, params) =
434            build_gemma_decode_graph_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
435        let session = Session::new(self.device);
436        let mut compiled = self.compile_graph_profiled_decode(&session, graph)?;
437        for (name, data) in &params {
438            compiled.set_param(name, data);
439        }
440
441        let input_ids_f32 = [input_tok as f32];
442        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
443            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
444            .collect();
445        let mut inputs: Vec<(&str, &[f32])> =
446            Vec::with_capacity(1 + 2 * self.cfg.num_hidden_layers);
447        inputs.push(("input_ids", input_ids_f32.as_slice()));
448        for i in 0..self.cfg.num_hidden_layers {
449            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
450            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
451        }
452
453        let outputs = compiled.run(&inputs);
454        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
455    }
456
457    #[allow(clippy::type_complexity)]
458    fn decode_step_dynamic(
459        &mut self,
460        past_seq: usize,
461        input_tok: u32,
462    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
463        let cache = self.cache.as_ref().unwrap();
464        let binding = DimBinding::batch_past_seq(1, past_seq);
465        let opts = self
466            .profile_compile_options(true)
467            .dim_binding(binding.clone());
468        let cache_dyn = self
469            .decode_dynamic_cache
470            .as_mut()
471            .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
472        let needs_upload = !cache_dyn.contains(past_seq as u64);
473        let cfg = self.cfg.clone();
474        let weights_cache = self.weights_cache.clone();
475        let max_past = self.cfg.max_position_embeddings;
476        let compiled = cache_dyn.get_or_specialize(
477            past_seq as u64,
478            &binding,
479            || {
480                let mut wm = WeightMap::from_tensors(weights_cache);
481                build_gemma_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
482                    .expect("dynamic decode HIR")
483                    .0
484            },
485            &opts,
486        )?;
487        if needs_upload {
488            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
489            let (_, params) = build_gemma_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
490            for (name, data) in &params {
491                compiled.set_param(name, data);
492            }
493        }
494
495        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
496        let input_ids_f32 = [input_tok as f32];
497        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
498            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
499            .collect();
500        let mut inputs: Vec<(&str, &[f32])> =
501            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
502        inputs.push(("input_ids", input_ids_f32.as_slice()));
503        inputs.push(("rope_cos", cos.as_slice()));
504        inputs.push(("rope_sin", sin.as_slice()));
505        for i in 0..self.cfg.num_hidden_layers {
506            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
507            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
508        }
509        let outputs = compiled.run(&inputs);
510        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
511    }
512
513    #[allow(clippy::type_complexity)]
514    fn decode_step_bucketed(
515        &mut self,
516        past_seq: usize,
517        input_tok: u32,
518    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
519        let kv = self.cache.as_ref().unwrap().clone();
520        let kv_dim = self.cfg.kv_proj_dim();
521        let n_layers = self.cfg.num_hidden_layers;
522        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
523        let input_ids_f32 = [input_tok as f32];
524        let decode_opts = self.profile_compile_options(true);
525        let upper = self
526            .decode_compile_cache
527            .as_ref()
528            .and_then(|cache_dec| {
529                cache_dec.bucket_for(past_seq as u64).map(|idx| {
530                    cache_dec
531                        .buckets()
532                        .nth(idx)
533                        .map(|r| (r.end - 1) as usize)
534                        .unwrap_or(past_seq)
535                })
536            })
537            .unwrap_or(past_seq);
538        let mask = bucket_decode_mask(past_seq, upper);
539        let fixed = [
540            CacheRunInput {
541                name: "input_ids",
542                data: &input_ids_f32,
543                row_inner: None,
544            },
545            CacheRunInput {
546                name: "rope_cos",
547                data: &cos,
548                row_inner: None,
549            },
550            CacheRunInput {
551                name: "rope_sin",
552                data: &sin,
553                row_inner: None,
554            },
555            CacheRunInput {
556                name: "mask",
557                data: &mask,
558                row_inner: None,
559            },
560        ];
561        let cfg = self.cfg.clone();
562        let weights = self.weights_cache.clone();
563        let cache_dec = self.decode_compile_cache.as_mut().unwrap();
564        run_bucketed_kv_decode_hir(
565            cache_dec,
566            past_seq,
567            &kv,
568            kv_dim,
569            n_layers,
570            &fixed,
571            |upper| {
572                let mut wm = WeightMap::from_tensors(weights.clone());
573                build_gemma_decode_hir_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
574                    .expect("gemma bucketed decode HIR")
575            },
576            &decode_opts,
577        )
578    }
579
580    /// Run prefill-with-cache and return the raw outputs. Uses the
581    /// LRU `CompileCache` when enabled; otherwise compiles fresh each
582    /// call. Keyed by `seq` because graph shape is seq-specialized.
583    #[allow(clippy::unnecessary_unwrap)]
584    fn run_prefill_with_cache(
585        &mut self,
586        batch: usize,
587        seq: usize,
588        ids_f32: &[f32],
589    ) -> Result<Vec<Vec<f32>>> {
590        if self.prefill_dynamic_cache.is_some() {
591            let binding = DimBinding::batch_seq(batch, seq);
592            let opts = compile_options_from_profile(
593                &self.prefill_profile,
594                self.device,
595                KernelDispatchConfig::default(),
596            )
597            .dim_binding(binding.clone());
598            let cache = self.prefill_dynamic_cache.as_mut().expect("checked");
599            let needs_upload = !cache.contains(seq as u64);
600            let cfg = self.cfg.clone();
601            let weights_cache = self.weights_cache.clone();
602            let max_seq = self.cfg.max_position_embeddings;
603            let compiled = cache.get_or_specialize(
604                seq as u64,
605                &binding,
606                || {
607                    let mut wm = WeightMap::from_tensors(weights_cache);
608                    build_gemma_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
609                        .expect("dynamic prefill HIR")
610                        .0
611                },
612                &opts,
613            )?;
614            if needs_upload {
615                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
616                let (_, params) =
617                    build_gemma_prefill_hir_dynamic_ext(&self.cfg, &mut wm, batch, max_seq, true)?;
618                for (name, data) in &params {
619                    compiled.set_param(name, data);
620                }
621            }
622            let last_idx = vec![(seq - 1) as f32];
623            Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
624        } else if self.prefill_compile_cache.is_some() {
625            let key = ((batch as u64) << 32) | (seq as u64);
626            let opts = self.profile_compile_options(false);
627            if !self.prefill_compile_cache.as_ref().unwrap().contains(key) {
628                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
629                let (graph, params) = build_gemma_graph_sized_last_logits(
630                    &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
631                )?;
632                {
633                    let compiled = self
634                        .prefill_compile_cache
635                        .as_mut()
636                        .unwrap()
637                        .get_or_compile_with_options(key, || graph, &opts);
638                    for (name, data) in &params {
639                        compiled.set_param(name, data);
640                    }
641                }
642            }
643            let compiled = self
644                .prefill_compile_cache
645                .as_mut()
646                .unwrap()
647                .get_or_compile_with_options(key, || unreachable!("just populated above"), &opts);
648            Ok(compiled.run(&[("input_ids", ids_f32)]))
649        } else {
650            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
651            let (graph, params) = build_gemma_graph_sized_last_logits(
652                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
653            )?;
654            let session = Session::new(self.device);
655            let opts = self.profile_compile_options(false);
656            let mut compiled = session.compile_with(graph, &opts);
657            for (name, data) in &params {
658                compiled.set_param(name, data);
659            }
660            Ok(compiled.run(&[("input_ids", ids_f32)]))
661        }
662    }
663
664    /// Run `n` cached steps and return the newly generated tokens.
665    pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
666        self.generate_cached_with(n, opts, |_| {})
667    }
668
669    /// Same as [`generate_cached`] but invokes `on_token` once per
670    /// freshly sampled id, inside the decode loop. The whole `n` step
671    /// loop shares the bucketed compile cache — callers wanting a
672    /// streaming UI should prefer this to calling
673    /// `generate_cached(1, …)` `n` times (which forces a fresh
674    /// compile per token at the bucket boundaries).
675    pub fn generate_cached_with(
676        &mut self,
677        n: usize,
678        opts: SampleOpts,
679        mut on_token: impl FnMut(u32),
680    ) -> Result<Vec<u32>> {
681        let start = self.tokens.len();
682        for _ in 0..n {
683            let tok = self.step_cached(opts)?;
684            on_token(tok);
685        }
686        Ok(self.tokens[start..].to_vec())
687    }
688
689    /// Run prefill-with-cache on the current `self.tokens` (the
690    /// prompt), populate `self.cache`, sample the next token from the
691    /// last position's logits, and append it. Returns the sampled
692    /// token. Invariant after: `cache.past_seq == tokens.len() - 1`.
693    fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
694        let seq = self.tokens.len();
695        let batch = 1usize;
696        let kv_dim = self.cfg.kv_proj_dim();
697
698        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
699        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
700        let (logits, kv) =
701            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
702        self.cache = Some(kv);
703
704        let vocab = self.cfg.vocab_size;
705        let needed = vocab;
706        if logits.len() < needed {
707            anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
708        }
709        let last_row = &logits[..vocab];
710        let tok = sample_token(last_row, opts) as u32;
711        self.tokens.push(tok);
712        Ok(tok)
713    }
714
715    /// Full token history (prompt + generated).
716    pub fn tokens(&self) -> &[u32] {
717        &self.tokens
718    }
719
720    pub fn config(&self) -> &GemmaConfig {
721        &self.cfg
722    }
723
724    /// Low-level primitive: reset internal state, run prefill-with-cache
725    /// over `context`, and return the *last position's* logits row
726    /// (`P(next_token | context)`). Does NOT sample or append. The
727    /// internal `tokens` buffer is set to `context` and the KV cache
728    /// is populated to `past_seq = context.len()`.
729    ///
730    /// First row of logits after prefill-with-cache (no sampling).
731    pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
732        if context.is_empty() {
733            anyhow::bail!("prefill_get_last_logits: empty context");
734        }
735        self.tokens.clear();
736        self.tokens.extend_from_slice(context);
737        self.cache = None;
738
739        let seq = context.len();
740        let batch = 1usize;
741        let kv_dim = self.cfg.kv_proj_dim();
742
743        let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
744        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
745        let (logits, kv) =
746            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
747        self.cache = Some(kv);
748
749        let vocab = self.cfg.vocab_size;
750        let needed = vocab;
751        if logits.len() < needed {
752            anyhow::bail!("logits short: {} < {}", logits.len(), needed);
753        }
754        Ok(logits[..vocab].to_vec())
755    }
756
757    /// Low-level primitive: run one decode step with the caller-
758    /// supplied input token (no sampling), advance the KV cache, and
759    /// return the resulting logits row `P(next | history ++ input)`.
760    /// Appends `input` to the `tokens` buffer so the invariant
761    /// `cache.past_seq == tokens.len()` holds after this call (note:
762    /// differs from `step_cached` invariant because this method does
763    /// not append a sampled token).
764    pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
765        if self.cache.is_none() {
766            anyhow::bail!(
767                "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
768            );
769        }
770        self.tokens.push(input);
771        let seq = self.tokens.len();
772        let batch = 1usize;
773        let kv_dim = self.cfg.kv_proj_dim();
774        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
775        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
776        let (logits, kv) =
777            kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
778        self.cache = Some(kv);
779        let vocab = self.cfg.vocab_size;
780        Ok(logits[..vocab].to_vec())
781    }
782}
783
784/// Compute the single-row (cos, sin) RoPE slice for absolute position
785/// `pos`. Matches the formula in the prefill builder so cached decode
786/// and recompute prefill produce the same RoPE rotation.
787fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
788    rope_slice(inv_freq, pos)
789}
790
/// Unit tests for `GemmaGenerator` and the Gemma graph builders it drives.
///
/// Every test runs on `Device::Cpu` against a tiny deterministic synthetic
/// checkpoint (`synthetic_weights`). The tests cover naive vs. cached
/// generation parity, KV-export correctness, decode-graph RoPE baking, and
/// the optional prefill/decode compile caches.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GemmaConfig;
    use crate::rope::{build_rope_tables, resolve_inv_freq, rope_slice};
    use rlx_flow::CompileProfile;

    /// `GemmaConfig::tiny_test()` with a 16-token vocab and head_dim 8.
    fn tiny_cfg() -> GemmaConfig {
        let mut cfg = GemmaConfig::tiny_test();
        cfg.vocab_size = 16;
        cfg.head_dim = Some(8);
        cfg
    }

    /// Deterministic pseudo-random tensors, in [-0.5, 0.5), for every
    /// `model.*` / `lm_head.*` weight name the loader consumes, shaped
    /// from `cfg`.
    fn synthetic_tensors(cfg: &GemmaConfig) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
        let h = cfg.hidden_size;
        let q_dim = cfg.q_proj_dim();
        let kv_dim = cfg.kv_proj_dim();
        let int_dim = cfg.intermediate_size;
        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        // Use a deterministic non-zero pattern so logits aren't all 0
        // (sampling on an all-zero row is undefined order).
        let pat = |n: usize, salt: u32| -> Vec<f32> {
            (0..n)
                .map(|i| {
                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
                    (x as f32 / (1u32 << 24) as f32) - 0.5
                })
                .collect()
        };
        t.insert(
            "model.embed_tokens.weight".into(),
            (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
        );
        for i in 0..cfg.num_hidden_layers {
            let lp = format!("model.layers.{i}");
            t.insert(
                format!("{lp}.input_layernorm.weight"),
                (pat(h, 100 + i as u32), vec![h]),
            );
            t.insert(
                format!("{lp}.post_attention_layernorm.weight"),
                (pat(h, 200 + i as u32), vec![h]),
            );
            t.insert(
                format!("{lp}.self_attn.q_proj.weight"),
                (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.k_proj.weight"),
                (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.v_proj.weight"),
                (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.o_proj.weight"),
                (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
            );
            t.insert(
                format!("{lp}.mlp.gate_proj.weight"),
                (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
            );
            t.insert(
                format!("{lp}.mlp.up_proj.weight"),
                (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
            );
            t.insert(
                format!("{lp}.mlp.down_proj.weight"),
                (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
            );
        }
        t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
        t.insert(
            "lm_head.weight".into(),
            (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
        );
        t
    }

    /// `synthetic_tensors` wrapped in a fresh `WeightMap`.
    fn synthetic_weights(cfg: &GemmaConfig) -> WeightMap {
        WeightMap::from_tensors(synthetic_tensors(cfg))
    }

    #[test]
    fn generator_drains_loader_and_runs_one_step() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        assert_eq!(wm.len(), 0, "loader should be drained");
        gn.prefill(&[1, 2, 3]);
        let t = gn.step(SampleOpts::greedy()).unwrap();
        assert!((t as usize) < cfg.vocab_size);
        assert_eq!(gn.tokens().len(), 4);
    }

    #[test]
    fn generate_n_appends_n_tokens() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill(&[5, 6]);
        let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
        assert_eq!(new_tokens.len(), 3);
        assert_eq!(gn.tokens().len(), 5);
        for t in &new_tokens {
            assert!((*t as usize) < cfg.vocab_size);
        }
    }

    #[test]
    fn step_without_prefill_errors() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = GemmaGenerator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
        let r = gn.step(SampleOpts::greedy());
        assert!(r.is_err());
    }

    /// Maximum elementwise absolute difference between two equal-length
    /// slices. Panics on a length mismatch so a comparison can never pass
    /// vacuously because `zip` truncated the longer side.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(
            a.len(),
            b.len(),
            "max_abs_diff: length mismatch ({} vs {})",
            a.len(),
            b.len()
        );
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, f32::max)
    }

    #[test]
    fn prefill_logits_unchanged_with_kv_export() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let seq = prompt.len();

        let mut wm_a = synthetic_weights(&cfg);
        let mut wm_b = synthetic_weights(&cfg);
        let (graph_a, params_a) =
            build_gemma_graph_sized_last_logits(&cfg, &mut wm_a, 1, seq, false).unwrap();
        let (graph_b, params_b) =
            build_gemma_graph_sized_last_logits(&cfg, &mut wm_b, 1, seq, true).unwrap();
        let session = Session::new(Device::Cpu);
        let opts = CompileOptions::new();
        let mut ca = session.compile_with(graph_a, &opts);
        let mut cb = session.compile_with(graph_b, &opts);
        for (n, d) in &params_a {
            ca.set_param(n, d);
        }
        for (n, d) in &params_b {
            cb.set_param(n, d);
        }
        let ids: Vec<f32> = prompt.iter().map(|&i| i as f32).collect();
        let la = ca.run(&[("input_ids", &ids)])[0].clone();
        let lb = cb.run(&[("input_ids", &ids)])[0].clone();
        let d = max_abs_diff(&la, &lb);
        assert!(d < 1e-5, "kv export changed prefill logits: max_abs={d:.6}");
    }

    #[test]
    fn incremental_decode_logits_match_full_prefill() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];

        let mut wm_a = synthetic_weights(&cfg);
        let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
        let tok = gn_a
            .prefill_get_last_logits(&prompt)
            .map(|l| sample_token(&l, SampleOpts::greedy()) as u32)
            .unwrap();

        let mut extended = prompt.clone();
        extended.push(tok);

        let mut wm_b = synthetic_weights(&cfg);
        let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
        let full = gn_b.prefill_get_last_logits(&extended).unwrap();

        let mut wm_c = synthetic_weights(&cfg);
        let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
        gn_c.prefill_get_last_logits(&prompt).unwrap();
        let incremental = gn_c.decode_get_logits(tok).unwrap();

        let d = max_abs_diff(&full, &incremental);
        assert!(
            d < 1e-2,
            "decode+KV vs full prefill max_abs={d:.6} (tok={tok})"
        );
    }

    /// Run a KV-exporting prefill graph with the `gemma_prefill` profile
    /// and return the per-layer K/V outputs (logits dropped).
    fn run_prefill_kv(
        cfg: &GemmaConfig,
        wm: &mut WeightMap,
        seq: usize,
        ids: &[u32],
    ) -> Vec<Vec<f32>> {
        run_prefill_kv_with_options(cfg, wm, seq, ids, &kv_export_compile_options(true))
    }

    /// CPU compile options from the Gemma prefill (`true`) or decode
    /// (`false`) compile profile.
    fn kv_export_compile_options(prefill: bool) -> CompileOptions {
        let profile = if prefill {
            CompileProfile::gemma_prefill()
        } else {
            CompileProfile::gemma_decode()
        };
        compile_options_from_profile(&profile, Device::Cpu, KernelDispatchConfig::default())
    }

    /// Build, compile, and run a KV-exporting last-logits prefill graph of
    /// length `seq` over `ids`. Returns the `2 * num_hidden_layers` K/V
    /// outputs in graph output order, without the leading logits output.
    fn run_prefill_kv_with_options(
        cfg: &GemmaConfig,
        wm: &mut WeightMap,
        seq: usize,
        ids: &[u32],
        opts: &CompileOptions,
    ) -> Vec<Vec<f32>> {
        assert_eq!(ids.len(), seq, "run_prefill_kv_with_options: ids.len() != seq");
        let ids_f32: Vec<f32> = ids.iter().map(|&i| i as f32).collect();
        let (graph, params) = build_gemma_graph_sized_last_logits(cfg, wm, 1, seq, true).unwrap();
        let session = Session::new(Device::Cpu);
        let mut compiled = session.compile_with(graph, opts);
        for (n, d) in &params {
            compiled.set_param(n, d);
        }
        let outputs = compiled.run(&[("input_ids", &ids_f32)]);
        let n_layers = cfg.num_hidden_layers;
        assert_eq!(outputs.len(), 1 + 2 * n_layers);
        let mut kv = Vec::with_capacity(2 * n_layers);
        let mut iter = outputs.into_iter().skip(1);
        for _ in 0..n_layers {
            kv.push(iter.next().unwrap());
            kv.push(iter.next().unwrap());
        }
        kv
    }

    #[test]
    fn decode_graph_bakes_rope_slice_length() {
        let cfg = tiny_cfg();
        let past_seq = 4usize;
        let half = cfg.head_dim() / 2;
        let mut wm = synthetic_weights(&cfg);
        let (_, params) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
        let cos = params
            .get("decode.rope.cos")
            .expect("decode.rope.cos param");
        let sin = params
            .get("decode.rope.sin")
            .expect("decode.rope.sin param");
        assert_eq!(
            cos.len(),
            half,
            "cos param should be one row (half={half}), got {}",
            cos.len()
        );
        assert_eq!(sin.len(), half);
        for key in params.keys() {
            assert!(
                !key.starts_with("rope."),
                "decode graph must not include prefill rope table param {key}"
            );
        }
        let inv = resolve_inv_freq(&cfg, None);
        let (c_ref, s_ref) = rope_slice(&inv, past_seq);
        let d = max_abs_diff(cos, &c_ref) + max_abs_diff(sin, &s_ref);
        assert!(d < 1e-6, "baked rope mismatch: {d}");
    }

    #[test]
    fn decode_graph_all_rope_use_baked_cos() {
        use rlx_ir::Op;
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, 4).unwrap();
        for node in graph.nodes() {
            if let Op::Rope { .. } = &node.op {
                let cos_id = node.inputs[1];
                let cos_node = &graph.node(cos_id);
                match &cos_node.op {
                    Op::Param { name } => assert_eq!(
                        name, "decode.rope.cos",
                        "decode RoPE must use baked decode.rope.cos, got {name}"
                    ),
                    other => panic!("decode RoPE cos input is {other:?}, expected Param"),
                }
            }
        }
    }

    #[test]
    fn decode_graph_rope_cos_is_single_row() {
        use rlx_ir::Op;
        let cfg = tiny_cfg();
        let past_seq = 4usize;
        let mut wm = synthetic_weights(&cfg);
        let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
        let mut rope_cos_lens = Vec::new();
        for node in graph.nodes() {
            if let Op::Rope { .. } = &node.op {
                let cos_shape = &graph.node(node.inputs[1]).shape;
                // A rank-1 cos tensor is trivially a single row.
                let rows = if cos_shape.rank() >= 2 {
                    cos_shape.dim(0).unwrap_static()
                } else {
                    1
                };
                rope_cos_lens.push(rows);
            }
        }
        assert!(!rope_cos_lens.is_empty(), "decode graph has no RoPE nodes");
        for rows in &rope_cos_lens {
            assert_eq!(
                *rows, 1,
                "decode RoPE cos must be single-row [1, half], got {rows} rows"
            );
        }
    }

    #[test]
    fn prefill_kv_matches_extended_prefix() {
        // Causal attention means the K and V rows for the prompt positions
        // must not depend on tokens appended after them.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let tok = 6u32;
        let mut extended = prompt.clone();
        extended.push(tok);

        let mut wm_prompt = synthetic_weights(&cfg);
        let prompt_kv = run_prefill_kv(&cfg, &mut wm_prompt, prompt.len(), &prompt);
        let mut wm_ext = synthetic_weights(&cfg);
        let ext_kv = run_prefill_kv(&cfg, &mut wm_ext, extended.len(), &extended);

        let kv_dim = cfg.kv_proj_dim();
        let prefix_len = prompt.len() * kv_dim;
        for layer in 0..cfg.num_hidden_layers {
            for (kind, idx) in [("K", 2 * layer), ("V", 2 * layer + 1)] {
                let t_prompt = &prompt_kv[idx];
                let t_ext = &ext_kv[idx];
                assert_eq!(t_prompt.len(), prefix_len);
                assert_eq!(t_ext.len(), extended.len() * kv_dim);
                let d = max_abs_diff(t_prompt, &t_ext[..prefix_len]);
                assert!(
                    d < 1e-4,
                    "layer {layer} prefill {kind} prefix vs extended {kind} max_abs={d:.6}"
                );
            }
        }
    }

    #[test]
    fn decode_rope_slice_matches_prefill_table_row() {
        let cfg = tiny_cfg();
        let inv = resolve_inv_freq(&cfg, None);
        let (cos_tab, sin_tab) = build_rope_tables(&inv, cfg.max_position_embeddings);
        let half = inv.len();
        for pos in [3usize, 4, 5] {
            let (c, s) = rope_slice(&inv, pos);
            let off = pos * half;
            let d = max_abs_diff(&c, &cos_tab[off..off + half])
                + max_abs_diff(&s, &sin_tab[off..off + half]);
            assert!(d < 1e-6, "rope_slice mismatch at pos {pos}: {d}");
        }
    }

    #[test]
    fn prefill_kv_export_correct_with_fusion() {
        // Layer-0 V has no RoPE and no cross-position mixing, so the last
        // row of a 5-token run must equal a 1-token run over `tok` alone.
        let cfg = tiny_cfg();
        let tok = 6u32;
        let ids = [1u32, 2, 3, 5, tok];
        let opts = kv_export_compile_options(true);
        let mut wm_one = synthetic_weights(&cfg);
        let one_kv = run_prefill_kv_with_options(&cfg, &mut wm_one, 1, &[tok], &opts);
        let mut wm_ext = synthetic_weights(&cfg);
        let ext_kv = run_prefill_kv_with_options(&cfg, &mut wm_ext, ids.len(), &ids, &opts);
        let kv_dim = cfg.kv_proj_dim();
        let last = (ids.len() - 1) * kv_dim;
        let d = max_abs_diff(&ext_kv[1][last..], &one_kv[1][..kv_dim]);
        assert!(d < 1e-4, "KV export mismatch with profile fusion: {d:.6}");

        let mut wm_default = synthetic_weights(&cfg);
        let default_kv = run_prefill_kv_with_options(
            &cfg,
            &mut wm_default,
            ids.len(),
            &ids,
            &CompileOptions::new(),
        );
        let d_default = max_abs_diff(&default_kv[1][last..], &one_kv[1][..kv_dim]);
        assert!(
            d_default < 1e-4,
            "KV export mismatch with default fusion (got {d_default:.6})"
        );
    }

    #[test]
    fn decode_oneshot_kv_suffix_matches_extended() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let tok = 6u32;
        let mut extended = prompt.clone();
        extended.push(tok);

        let opts = kv_export_compile_options(false);
        let mut wm_ext = synthetic_weights(&cfg);
        let ext_kv =
            run_prefill_kv_with_options(&cfg, &mut wm_ext, extended.len(), &extended, &opts);

        let mut wm = synthetic_weights(&cfg);
        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill_get_last_logits(&prompt).unwrap();

        let mut wm_d = synthetic_weights(&cfg);
        let (graph, params) =
            build_gemma_decode_graph_sized(&cfg, &mut wm_d, 1, prompt.len()).unwrap();
        let session = Session::new(Device::Cpu);
        let mut compiled = session.compile_with(graph, &opts);
        for (n, d) in &params {
            compiled.set_param(n, d);
        }
        let cache = gn.cache.as_ref().unwrap();
        let key_strs: Vec<String> = (0..cfg.num_hidden_layers)
            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
            .collect();
        let input_ids = [tok as f32];
        let mut inputs: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
        for i in 0..cfg.num_hidden_layers {
            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
        }
        let outputs = compiled.run(&inputs);
        let kv_dim = cfg.kv_proj_dim();
        let suffix = prompt.len() * kv_dim;
        let k_dec = &outputs[1][suffix..];

        let d = max_abs_diff(k_dec, &ext_kv[0][suffix..]);
        assert!(
            d < 1e-3,
            "decode oneshot layer0 K suffix vs extended max_abs={d:.6}"
        );
    }

    #[test]
    fn decode_logits_match_extended_prefill_after_one_token() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let tok = 6u32;

        let mut extended = prompt.clone();
        extended.push(tok);

        let mut wm_a = synthetic_weights(&cfg);
        let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
        let full = gn_a.prefill_get_last_logits(&extended).unwrap();

        let mut wm_b = synthetic_weights(&cfg);
        let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
        gn_b.prefill_get_last_logits(&prompt).unwrap();
        let inc = gn_b.decode_get_logits(tok).unwrap();

        let d = max_abs_diff(&full, &inc);
        assert!(d < 1e-2, "decode vs extended prefill max_abs={d:.6}");
    }

    #[test]
    fn cached_second_token_matches_naive() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];

        let mut wm_n = synthetic_weights(&cfg);
        let mut gn_n = GemmaGenerator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
        gn_n.prefill(&prompt);
        let n0 = gn_n.step(SampleOpts::greedy()).unwrap();
        let n1 = gn_n.step(SampleOpts::greedy()).unwrap();

        let mut wm_c = synthetic_weights(&cfg);
        let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
        gn_c.prefill(&prompt);
        let c = gn_c.generate_cached(2, SampleOpts::greedy()).unwrap();

        assert_eq!(c[0], n0, "first generated token");
        assert_eq!(c[1], n1, "second generated token (decode step)");
    }

    #[test]
    fn cached_matches_naive_on_greedy() {
        // The cached and naive paths must produce the same token
        // sequence given the same prompt + opts. This is the
        // load-bearing test for the KV-cache implementation: if the
        // decode-mode graph, the kernel's Lq!=Lk fix, the cache
        // wiring, or the RoPE position-slice is wrong, the sequences
        // diverge here.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut wm_n = synthetic_weights(&cfg);
        let mut gn_naive =
            GemmaGenerator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
        gn_naive.prefill(&prompt);
        let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();

        let mut wm_c = synthetic_weights(&cfg);
        let mut gn_cached =
            GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
        gn_cached.prefill(&prompt);
        let cached_tokens = gn_cached
            .generate_cached(steps, SampleOpts::greedy())
            .unwrap();

        assert_eq!(
            cached_tokens, naive_tokens,
            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
        );
    }

    #[test]
    fn cached_step_advances_cache_invariant() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill(&[1, 2, 3]);
        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
        // After seed: tokens.len() == 4, cache.past_len == 3 (cache holds prompt).
        assert_eq!(gn.tokens().len(), 4);
        assert_eq!(gn.cache.as_ref().unwrap().past_len, 3);
        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
        // After one decode: tokens.len() == 5, cache.past_len == 4.
        assert_eq!(gn.tokens().len(), 5);
        assert_eq!(gn.cache.as_ref().unwrap().past_len, 4);
    }

    #[test]
    fn bucketed_decode_matches_oneshot() {
        // The bucketed compile-cache path (padded K/V + custom mask)
        // must produce the same token sequence as the one-shot
        // path. Load-bearing for the bucketed cache feature: if the
        // mask, padding, or output slicing is wrong, sequences
        // diverge here.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_buc = synthetic_weights(&cfg);
        let mut gn_buc = GemmaGenerator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
            .unwrap()
            .with_decode_cache(/*max_past*/ 32);
        gn_buc.prefill(&prompt);
        let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            bucketed_tokens, oneshot_tokens,
            "bucketed-cache decode diverged from one-shot decode — \
             mask, padding, or output-slice bug"
        );
    }

    #[test]
    fn prefill_compile_cache_does_not_change_output() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let mut wm_a = synthetic_weights(&cfg);
        let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
        gn_a.prefill(&prompt);
        let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();

        let mut wm_b = synthetic_weights(&cfg);
        let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
            .unwrap()
            .with_prefill_cache(/*capacity*/ 4);
        gn_b.prefill(&prompt);
        let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();

        assert_eq!(a, b, "enabling prefill_cache must not change output");
    }

    #[test]
    fn dynamic_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_decode_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic past_seq decode diverged from one-shot decode"
        );
    }

    #[test]
    fn dynamic_prefill_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic seq prefill diverged from one-shot prefill"
        );
    }

    #[test]
    fn dynamic_prefill_and_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(/*capacity*/ 8)
            .with_dynamic_decode_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic prefill+decode diverged from one-shot path"
        );
    }

    #[test]
    fn greedy_is_deterministic_across_runs() {
        // Two independently constructed generators over identical
        // (deterministic) synthetic weights must emit identical greedy
        // sequences.
        let cfg = tiny_cfg();
        let mk = || {
            let mut wm = synthetic_weights(&cfg);
            GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
        };
        let mut a = mk();
        let mut b = mk();
        a.prefill(&[1, 2, 3]);
        b.prefill(&[1, 2, 3]);
        let ta = a.generate(4, SampleOpts::greedy()).unwrap();
        let tb = b.generate(4, SampleOpts::greedy()).unwrap();
        assert_eq!(ta, tb);
    }
}