Skip to main content

rlx_gemma/
generator.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Host-side generation loop for Gemma.
17//!
18//! This is the **naive** generator: each `step()` rebuilds the prefill
19//! graph for the full token history and runs it from scratch
20//! (O(N²) compute over N generated tokens). The API is shaped to
21//! match the upcoming KV-cache version exactly so callers don't have
22//! to change anything when the cached path lands — only the internal
23//! implementation swaps.
24//!
25//! Why ship the naive version first:
26//!   - Establishes the public API contract before the IR/kernel
27//!     changes that the cached version needs land.
28//!   - Lets you run end-to-end generation against a real checkpoint
29//!     today and validate the prefill graph is numerically correct.
30//!   - Provides a reference baseline for the cached version's own
31//!     numerical-parity test (cached vs recompute must match).
32
33use crate::builder::{
34    build_gemma_decode_graph_sized, build_gemma_decode_hir_dynamic_ext,
35    build_gemma_decode_hir_sized_ext, build_gemma_graph_sized_last_logits,
36    build_gemma_graph_sized_last_logits_hidden, build_gemma_prefill_hidden_hir_dynamic_ext,
37    build_gemma_prefill_hir_dynamic_ext,
38};
39use crate::config::GemmaConfig;
40use crate::rope::{resolve_inv_freq, rope_slice};
41use anyhow::{Context, Result};
42use rlx_core::autoregressive::{
43    KvCacheState, kv_from_prefill_outputs_per_layer, run_bucketed_kv_decode_hir_scratch,
44    split_decode_logits_kv,
45};
46use rlx_core::flow_bridge::compile_options_from_profile;
47use rlx_core::gpu_kv::{
48    GpuKvBinding, device_supports_gpu_kv, run_bucketed_kv_decode_gpu_hir, sync_gpu_kv_to_host,
49};
50use rlx_core::weight_loader::WeightLoader;
51use rlx_core::weight_map::WeightMap;
52use rlx_flow::CompileProfile;
53use rlx_ir::DimBinding;
54use rlx_ir::logical_kernel::KernelDispatchConfig;
55use rlx_qwen3::sampling::{SampleOpts, sample_token};
56use rlx_runtime::compile_cache::{
57    BucketedCompileCache, CacheRunInput, CompileCache, DynamicDimCompileCache,
58};
59use rlx_runtime::{CompileOptions, Device, Session};
60use std::collections::HashMap;
61use std::path::Path;
62
63/// Decode compile profile with backend-specific fixes (Metal: optional unfused GQA).
64pub fn decode_profile_for_device(device: Device) -> CompileProfile {
65    metal_decode_profile(device, CompileProfile::gemma_decode())
66}
67
/// Report whether `RLX_GEMMA_METAL_THUNK_DECODE` asks for the Metal thunk
/// decode path.
///
/// Returns `true` when the variable is set to `1`, or to `true` / `yes`
/// in any ASCII case. Unset, non-Unicode, or any other value yields
/// `false`. The caller (`metal_decode_compile_guard`) reacts by setting
/// `RLX_DISABLE_MPSGRAPH=1` around decode compilation.
fn metal_thunk_decode_requested() -> bool {
    match std::env::var("RLX_GEMMA_METAL_THUNK_DECODE") {
        Ok(value) => {
            value == "1"
                || value.eq_ignore_ascii_case("true")
                || value.eq_ignore_ascii_case("yes")
        }
        Err(_) => false,
    }
}
75
76pub(crate) fn metal_decode_compile_guard<R, F>(device: Device, decode: bool, f: F) -> R
77where
78    F: FnOnce() -> R,
79{
80    if decode && metal_thunk_decode_requested() {
81        if device == Device::Metal {
82            rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
83            let out = f();
84            rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
85            out
86        } else {
87            f()
88        }
89    } else {
90        f()
91    }
92}
93
/// Report whether `RLX_GEMMA_METAL_UNFUSED_DECODE` asks for unfused Metal
/// decode.
///
/// Returns `true` when the variable is set to `1`, or to `true` / `yes`
/// in any ASCII case; `false` when it is unset, not valid Unicode, or
/// holds anything else. `metal_decode_profile` uses the answer to turn
/// fusion off in the decode profile.
fn metal_unfused_decode_requested() -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];
    let Ok(raw) = std::env::var("RLX_GEMMA_METAL_UNFUSED_DECODE") else {
        return false;
    };
    TRUTHY.iter().any(|accepted| raw.eq_ignore_ascii_case(accepted))
}
100
101fn metal_decode_profile(device: Device, mut profile: CompileProfile) -> CompileProfile {
102    if device == Device::Metal && metal_unfused_decode_requested() {
103        profile.fusion.skip = true;
104        profile.backend.metal.skip_fusion = true;
105        profile.backend.metal.unfuse_regions = true;
106    }
107    profile
108}
109
110/// Stateful Gemma generation handle.
111///
112/// Holds the (config, weight bytes, token history) and rebuilds a
113/// prefill graph on each [`step`] call. Cheap to construct after
114/// initial weight load; tokens stay in-memory between calls.
115pub struct GemmaGenerator {
116    cfg: GemmaConfig,
117    /// Map of weight key → (f32 data, shape). Cloned on each step
118    /// into a fresh `WeightMap` because `WeightMap::take` is
119    /// destructive — see the cached-generator notes for the path
120    /// that avoids the clone.
121    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
122    tokens: Vec<u32>,
123    device: Device,
124    /// Populated lazily on the first `step_cached` call (seeded from
125    /// the prompt via prefill-with-cache); thereafter advanced by each
126    /// decode step.
127    cache: Option<KvCacheState>,
128    /// Per-key LRU compile cache for prefill graphs. Keyed by `seq`.
129    /// Set to `None` to disable (default for new instances; opt in via
130    /// [`GemmaGenerator::with_prefill_cache`]).
131    prefill_compile_cache: Option<CompileCache>,
132    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
133    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
134    /// Same as above but for fused `prefill_hidden` (multimodal).
135    embed_prefill_compile_cache: Option<CompileCache>,
136    embed_prefill_dynamic_cache: Option<DynamicDimCompileCache>,
137    /// Bucketed compile cache for decode-mode graphs. Each bucket
138    /// holds one compiled graph specialized at its upper-bound
139    /// `past_seq`; the host pads `past_k`/`past_v` and supplies a
140    /// per-step mask so a single bucket serves every `past_seq` in
141    /// its range. Opt in via [`GemmaGenerator::with_decode_cache`].
142    decode_compile_cache: Option<BucketedCompileCache>,
143    decode_dynamic_cache: Option<DynamicDimCompileCache>,
144    /// Resolved RoPE inverse frequencies (includes Llama 3 scaling).
145    inv_freq: Vec<f64>,
146    /// Tier-1 compile profile for prefill graphs.
147    prefill_profile: CompileProfile,
148    /// Tier-1 compile profile for decode graphs.
149    decode_profile: CompileProfile,
150    /// Fused prefill embeddings for the next cached prefill (multimodal).
151    pending_prefill_embeds: Option<Vec<f32>>,
152    pending_prefill_attn_bias: Option<Vec<f32>>,
153    /// GPU-resident K/V on Metal/MLX/CUDA decode (logits-only readback).
154    use_gpu_kv: bool,
155    gpu_kv_binding: GpuKvBinding,
156    /// Reused decode inputs (mask, padded K/V) to avoid per-step allocs.
157    decode_scratch: DecodeKvScratch,
158    decode_inputs: DecodeInputScratch,
159}
160
/// Reusable mask and RoPE buffers for bucketed decode, kept between steps
/// so the vectors are refilled in place instead of reallocated.
#[derive(Default)]
struct DecodeInputScratch {
    /// Attention mask with `upper + 1` entries (see [`Self::fill_mask`]).
    mask: Vec<f32>,
    /// Cosine of `pos * inv_freq[i]` for the current position.
    cos: Vec<f32>,
    /// Sine of `pos * inv_freq[i]` for the current position.
    sin: Vec<f32>,
}

impl DecodeInputScratch {
    /// Rebuild the mask for a bucket of size `upper` holding `past_seq`
    /// real entries.
    ///
    /// The mask gets exactly `upper + 1` entries: `1.0` for every index
    /// below `past_seq` and for the final index `upper`, `0.0` for the
    /// padding in between.
    fn fill_mask(&mut self, past_seq: usize, upper: usize) {
        // `clear` keeps the allocation, so steady-state steps do not
        // allocate once the buffer has reached the bucket size.
        self.mask.clear();
        self.mask.extend((0..=upper).map(|i| {
            if i < past_seq || i == upper {
                1.0
            } else {
                0.0
            }
        }));
    }

    /// Fill `cos` / `sin` with the rotary angles for position `pos`:
    /// entry `i` is the cosine / sine of `pos * inv_freq[i]`, computed in
    /// `f64` and narrowed to `f32`. Both buffers end up with
    /// `inv_freq.len()` entries.
    fn fill_rope(&mut self, inv_freq: &[f64], pos: usize) {
        let half = inv_freq.len();
        self.cos.resize(half, 0.0);
        self.sin.resize(half, 0.0);
        let slots = self.cos.iter_mut().zip(self.sin.iter_mut());
        for ((cos_out, sin_out), &freq) in slots.zip(inv_freq) {
            let (sin, cos) = (pos as f64 * freq).sin_cos();
            *cos_out = cos as f32;
            *sin_out = sin as f32;
        }
    }
}

/// Padded per-layer K/V upload buffers for bucketed decode, rebuilt only
/// when the bucket size or the per-layer layout changes.
#[derive(Default)]
struct DecodeKvScratch {
    /// One buffer per layer, `bucket_upper * kv_dim` floats each.
    padded_k: Vec<Vec<f32>>,
    /// Same layout as `padded_k`, for values.
    padded_v: Vec<Vec<f32>>,
    /// Bucket upper bound the buffers are currently sized for.
    bucket_upper: usize,
}

impl DecodeKvScratch {
    /// Make sure there is one zero-initialised buffer of
    /// `upper * kv_dims[i]` floats per layer in both `padded_k` and
    /// `padded_v`.
    ///
    /// Existing buffers (and their contents) are kept when they already
    /// match `upper` and every entry of `kv_dims`; otherwise both sets are
    /// reallocated and zeroed. Checking every buffer length, not just the
    /// layer count, keeps a changed per-layer dimension or a mismatched
    /// `padded_v` from silently reusing wrong-sized buffers.
    fn ensure_bucket(&mut self, upper: usize, kv_dims: &[usize]) {
        /// True when `bufs` has one buffer of `upper * d` floats per dim.
        fn sized_for(bufs: &[Vec<f32>], upper: usize, kv_dims: &[usize]) -> bool {
            bufs.len() == kv_dims.len()
                && bufs
                    .iter()
                    .zip(kv_dims)
                    .all(|(buf, &d)| buf.len() == upper * d)
        }

        if self.bucket_upper == upper
            && sized_for(&self.padded_k, upper, kv_dims)
            && sized_for(&self.padded_v, upper, kv_dims)
        {
            return;
        }
        self.bucket_upper = upper;
        self.padded_k = kv_dims.iter().map(|&d| vec![0.0_f32; upper * d]).collect();
        self.padded_v = kv_dims.iter().map(|&d| vec![0.0_f32; upper * d]).collect();
    }
}
210
211fn gemma_use_gpu_kv(device: Device) -> bool {
212    if !device_supports_gpu_kv(device) {
213        return false;
214    }
215    match std::env::var("RLX_GEMMA_GPU_KV").ok().as_deref() {
216        Some("0") | Some("false") | Some("no") => false,
217        Some("1") | Some("true") | Some("yes") => true,
218        // Host readback path is faster on Gemma Metal today; Whisper/Qwen3-TTS differ.
219        _ => false,
220    }
221}
222
223impl GemmaGenerator {
224    /// Construct from any [`WeightLoader`] — drains it into an
225    /// internal cache so the loader is free after this call.
226    pub fn from_loader(
227        cfg: GemmaConfig,
228        loader: &mut dyn WeightLoader,
229        device: Device,
230    ) -> Result<Self> {
231        let keys = loader.remaining_keys();
232        // Capture the arch up front so the cache-key normalization can
233        // pick the gemma2 reverse alias (4 distinct per-layer norms)
234        // over the generic Llama-flavored one (2 norms, ambiguous on
235        // `ffn_norm`). Owned string so we don't hold a borrow across
236        // the mutable `loader.take` calls below.
237        let arch_hint: Option<String> = loader.arch_hint().map(|s| s.to_string());
238        let mut weights_cache = HashMap::with_capacity(keys.len());
239        for k in keys {
240            let v = loader
241                .take(&k)
242                .with_context(|| format!("draining weight {k}"))?;
243            // Normalize the cache key to the safetensors / HuggingFace
244            // naming convention so subsequent builder calls that ask
245            // for `model.embed_tokens.weight` (the canonical name baked
246            // into the gemma builder) hit the cache whether the
247            // loader was safetensors-native or GGUF-native.
248            let canonical = match arch_hint.as_deref() {
249                Some(a) => rlx_core::weight_loader::gguf_to_hf_name_for_arch(&k, a)
250                    .unwrap_or_else(|| k.clone()),
251                None => rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone()),
252            };
253            weights_cache.insert(canonical, v);
254        }
255        let rope_factors = weights_cache
256            .get("rope_freqs.weight")
257            .map(|(d, _)| d.as_slice());
258        let inv_freq = resolve_inv_freq(&cfg, rope_factors);
259        Ok(Self {
260            cfg,
261            weights_cache,
262            tokens: Vec::new(),
263            device,
264            cache: None,
265            prefill_compile_cache: None,
266            prefill_dynamic_cache: None,
267            embed_prefill_compile_cache: None,
268            embed_prefill_dynamic_cache: None,
269            decode_compile_cache: None,
270            decode_dynamic_cache: None,
271            inv_freq,
272            prefill_profile: CompileProfile::gemma_prefill(),
273            decode_profile: metal_decode_profile(device, CompileProfile::gemma_decode()),
274            pending_prefill_embeds: None,
275            pending_prefill_attn_bias: None,
276            use_gpu_kv: gemma_use_gpu_kv(device),
277            gpu_kv_binding: GpuKvBinding::default(),
278            decode_scratch: DecodeKvScratch::default(),
279            decode_inputs: DecodeInputScratch::default(),
280        })
281    }
282
283    fn reset_gpu_kv_binding(&mut self) {
284        self.gpu_kv_binding = GpuKvBinding::default();
285    }
286
287    /// Like [`Self::from_loader`] but loads tier-1 profiles from
288    /// `gemma.rlx.toml` in the weights directory when present.
289    pub fn from_loader_at(
290        cfg: GemmaConfig,
291        loader: &mut dyn WeightLoader,
292        device: Device,
293        weights_path: &Path,
294    ) -> Result<Self> {
295        let mut g = Self::from_loader(cfg, loader, device)?;
296        g.prefill_profile = crate::gemma_profile_near_weights(weights_path, false);
297        g.decode_profile = metal_decode_profile(
298            device,
299            crate::gemma_profile_near_weights(weights_path, true),
300        );
301        Ok(g)
302    }
303
304    /// Override tier-1 compile profiles explicitly.
305    pub fn with_compile_profiles(
306        mut self,
307        prefill: CompileProfile,
308        decode: CompileProfile,
309    ) -> Self {
310        self.prefill_profile = prefill;
311        self.decode_profile = metal_decode_profile(self.device, decode);
312        self
313    }
314
315    pub fn prefill_profile(&self) -> &CompileProfile {
316        &self.prefill_profile
317    }
318
319    pub fn decode_profile(&self) -> &CompileProfile {
320        &self.decode_profile
321    }
322
323    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
324        let profile = if decode {
325            &self.decode_profile
326        } else {
327            &self.prefill_profile
328        };
329        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
330    }
331
332    fn compile_graph_profiled(
333        &self,
334        session: &Session,
335        graph: rlx_ir::Graph,
336    ) -> Result<rlx_runtime::CompiledGraph> {
337        let opts = self.profile_compile_options(false);
338        Ok(session.compile_with(graph, &opts))
339    }
340
341    fn compile_graph_profiled_decode(
342        &self,
343        session: &Session,
344        graph: rlx_ir::Graph,
345    ) -> Result<rlx_runtime::CompiledGraph> {
346        Ok(metal_decode_compile_guard(self.device, true, || {
347            session.compile_with(graph, &self.profile_compile_options(true))
348        }))
349    }
350
351    /// Enable the prefill compile cache with the given LRU capacity.
352    /// Useful when the same prompt length is used across multiple
353    /// generation runs — the second + Nth run skip the compile +
354    /// param-attach roundtrip (~30-50ms per call on CPU).
355    pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
356        self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
357        self.embed_prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
358        self.prefill_dynamic_cache = None;
359        self.embed_prefill_dynamic_cache = None;
360        self
361    }
362
363    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
364    pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
365        self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
366        self.embed_prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
367        self.prefill_compile_cache = None;
368        self.embed_prefill_compile_cache = None;
369        self
370    }
371
372    /// Enable the bucketed decode compile cache spanning past-seq
373    /// values in `[1, max_past]`. Buckets are power-of-two
374    /// `[1..2, 2..3, 3..5, 5..9, 9..17, …]`. Each bucket compiles
375    /// one graph at its upper bound; a steady-state generation loop
376    /// across `N` tokens compiles `O(log N)` graphs instead of `N`.
377    ///
378    /// Padding compute waste is bounded at 2×: actual `past_seq` is
379    /// at least half the bucket's upper bound (except possibly the
380    /// smallest bucket).
381    pub fn with_decode_cache(mut self, max_past: usize) -> Self {
382        let cache = BucketedCompileCache::power_of_two_ladder(
383            self.device,
384            /*min*/ 1,
385            max_past.max(1) as u64,
386        );
387        self.decode_compile_cache = Some(cache);
388        self.decode_dynamic_cache = None;
389        self
390    }
391
392    /// Compile decode once with `sym::PAST_SEQ`, specialize per prefix length.
393    pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
394        self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
395        self.decode_compile_cache = None;
396        self
397    }
398
/// Report whether `RLX_GEMMA_DYNAMIC_DECODE` opts into the dynamic decode
/// cache: `true` for `1`, or `true` / `yes` in any ASCII case; `false`
/// when unset, not valid Unicode, or anything else.
fn inference_dynamic_decode() -> bool {
    let Ok(flag) = std::env::var("RLX_GEMMA_DYNAMIC_DECODE") else {
        return false;
    };
    flag == "1" || ["true", "yes"].iter().any(|word| flag.eq_ignore_ascii_case(word))
}
404
405    /// Production inference caches: dynamic prefill (+ multimodal embed prefill)
406    /// and bucketed decode. Bucketed decode avoids the per-step specialize
407    /// overhead of `RLX_GEMMA_DYNAMIC_DECODE=1` (experimental; often much slower on Metal).
408    pub fn with_inference_caches(mut self, max_seq: usize) -> Self {
409        // Match decode buckets to the session horizon (+slack). Avoid a global
410        // 256-token floor — padded bucket attention on Metal is costly for chat.
411        let decode_horizon = max_seq.saturating_add(16).max(32);
412        self = self.with_dynamic_prefill_cache(16);
413        if Self::inference_dynamic_decode() {
414            self.with_dynamic_decode_cache(32)
415        } else {
416            self.with_decode_cache(decode_horizon)
417        }
418    }
419
420    /// Wait for in-flight Metal command buffers on all cached graphs.
421    /// Call between heavy inference phases to avoid MPS lifecycle warnings.
422    pub fn sync_device(&mut self) {
423        if let Some(c) = &mut self.prefill_compile_cache {
424            c.sync_all();
425        }
426        if let Some(c) = &mut self.embed_prefill_compile_cache {
427            c.sync_all();
428        }
429        if let Some(c) = &mut self.prefill_dynamic_cache {
430            c.sync_all();
431        }
432        if let Some(c) = &mut self.embed_prefill_dynamic_cache {
433            c.sync_all();
434        }
435        if let Some(c) = &mut self.decode_compile_cache {
436            c.sync_all();
437        }
438        if let Some(c) = &mut self.decode_dynamic_cache {
439            c.sync_all();
440        }
441        rlx_runtime::device_ext::drain_device(self.device);
442    }
443
444    /// Convenience: load weights from a safetensors or GGUF path
445    /// (dispatch by extension; see `rlx_core::weight_loader::load_from_path`).
446    pub fn from_path(cfg: GemmaConfig, path: &str, device: Device) -> Result<Self> {
447        let mut loader = rlx_core::weight_loader::load_from_path(path)?;
448        Self::from_loader(cfg, loader.as_mut(), device)
449    }
450
451    /// Same as [`from_path`] but with MTP-head visibility control.
452    /// When `include_mtp=true` and the file is GGUF, MTP weights are
453    /// drained into the generator's cache alongside the base
454    /// weights. The base inference path still ignores them — they
455    /// sit in cache for a future MTP-aware decoder. Non-GGUF formats
456    /// silently ignore the flag (safetensors files publish all
457    /// tensors uniformly; downstream code distinguishes by name).
458    pub fn from_path_with_mtp(
459        cfg: GemmaConfig,
460        path: &str,
461        device: Device,
462        include_mtp: bool,
463    ) -> Result<Self> {
464        // Branch on extension so we can flip the GGUF-specific
465        // visibility option. Safetensors has no equivalent — it
466        // doesn't isolate MTP tensors at the loader level.
467        if path.ends_with(".gguf") {
468            let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
469            gguf.include_mtp(include_mtp);
470            Self::from_loader(cfg, &mut gguf, device)
471        } else {
472            Self::from_path(cfg, path, device)
473        }
474    }
475
476    /// Replace the token history with `prompt_ids`. Does not run the
477    /// model — the next [`step`] call processes the full sequence.
478    /// Clears any KV cache from a prior generation.
479    pub fn prefill(&mut self, prompt_ids: &[u32]) {
480        self.tokens.clear();
481        self.tokens.extend_from_slice(prompt_ids);
482        self.cache = None;
483        self.reset_gpu_kv_binding();
484    }
485
486    /// Like [`prefill`], but the next cached prefill uses fused
487    /// `inputs_embeds` (`prefill_hidden`) instead of token lookup.
488    pub fn prefill_from_embeds(
489        &mut self,
490        prompt_ids: &[u32],
491        embeds: &[f32],
492        attn_bias: Option<Vec<f32>>,
493    ) -> Result<()> {
494        let h = self.cfg.hidden_size;
495        if embeds.len() != prompt_ids.len() * h {
496            anyhow::bail!(
497                "prefill_from_embeds: embeds len {} != {} tokens × hidden {}",
498                embeds.len(),
499                prompt_ids.len(),
500                h
501            );
502        }
503        if let Some(ref bias) = attn_bias {
504            let seq = prompt_ids.len();
505            let nh = self.cfg.num_attention_heads;
506            let expected = seq * seq * nh;
507            if bias.len() != expected {
508                anyhow::bail!(
509                    "prefill_from_embeds: attn_bias len {} != batch×heads×seq² ({expected})",
510                    bias.len()
511                );
512            }
513        }
514        self.prefill(prompt_ids);
515        self.pending_prefill_embeds = Some(embeds.to_vec());
516        self.pending_prefill_attn_bias = attn_bias;
517        Ok(())
518    }
519
520    /// Weight table for CPU-side embedding / fusion helpers.
521    pub fn weights_cache(&self) -> &HashMap<String, (Vec<f32>, Vec<usize>)> {
522        &self.weights_cache
523    }
524
525    /// Run one prefill over the current token history and sample the
526    /// next token. The sampled token is appended to the history and
527    /// returned. Call repeatedly to generate.
528    pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
529        if self.tokens.is_empty() {
530            anyhow::bail!("step() called with empty token history; call prefill() first");
531        }
532        let seq = self.tokens.len();
533        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
534        let (graph, params) = build_gemma_graph_sized_last_logits(
535            &self.cfg, &mut wm, /*batch*/ 1, seq, /*with_kv_outputs*/ false,
536        )?;
537        let session = Session::new(self.device);
538        let mut compiled = self.compile_graph_profiled(&session, graph)?;
539        for (name, data) in &params {
540            compiled.set_param(name, data);
541        }
542        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
543        let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
544        let logits = outputs
545            .into_iter()
546            .next()
547            .context("compiled.run returned no outputs")?;
548
549        let vocab = self.cfg.vocab_size;
550        let expected = vocab;
551        if logits.len() < expected {
552            anyhow::bail!(
553                "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
554                logits.len(),
555                expected
556            );
557        }
558        // Last-logits graph returns [B=1, 1, vocab].
559        let last_row = &logits[..vocab];
560        let tok = sample_token(last_row, opts) as u32;
561        self.tokens.push(tok);
562        Ok(tok)
563    }
564
565    /// Run `n` steps and return the newly generated token ids
566    /// (excludes the prefill prompt).
567    pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
568        let start = self.tokens.len();
569        for _ in 0..n {
570            self.step(opts)?;
571        }
572        Ok(self.tokens[start..].to_vec())
573    }
574
575    /// Cached step: O(L) per token instead of O(L²). First call seeds
576    /// the KV cache from the prompt via prefill-with-cache; subsequent
577    /// calls run the decode-mode graph on just the last token + cached
578    /// past. Output is bit-identical to [`step`] modulo reduction
579    /// order in the SDPA kernel.
580    ///
581    /// Invariant after each call: `cache.past_seq == tokens.len() - 1`
582    /// (the just-sampled token is appended but not yet in the cache;
583    /// it becomes the input for the next decode step).
584    pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
585        if self.tokens.is_empty() {
586            anyhow::bail!("step_cached() called with empty token history; call prefill() first");
587        }
588        if self.cache.is_none() {
589            // The seed runs prefill, populates the cache, samples from
590            // the last position, and appends the token. Return that
591            // token directly — no decode step on this call.
592            let tok = self.seed_cache_from_prompt(opts)?;
593            return Ok(tok);
594        }
595        let cache = self.cache.as_ref().unwrap();
596        let past_seq = cache.past_len;
597        if self.tokens.len() <= past_seq {
598            anyhow::bail!(
599                "cache invariant violated: tokens.len() {} <= past_len {}",
600                self.tokens.len(),
601                past_seq
602            );
603        }
604        let input_tok = self.tokens[past_seq];
605
606        let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
607            self.decode_step_dynamic(past_seq, input_tok)?
608        } else if self.decode_compile_cache.is_some()
609            && self
610                .decode_compile_cache
611                .as_ref()
612                .unwrap()
613                .bucket_for(past_seq as u64)
614                .is_some()
615        {
616            self.decode_step_bucketed(past_seq, input_tok)?
617        } else {
618            self.decode_step_oneshot(past_seq, input_tok)?
619        };
620
621        let cache_mut = self.cache.as_mut().unwrap();
622        cache_mut.past_len = past_seq + 1;
623        cache_mut.layers_k = new_k;
624        cache_mut.layers_v = new_v;
625
626        let vocab = self.cfg.vocab_size;
627        if logits.len() != vocab {
628            anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
629        }
630        let tok = sample_token(&logits, opts) as u32;
631        self.tokens.push(tok);
632        Ok(tok)
633    }
634
635    /// Decode path that compiles a fresh graph for the exact `past_seq`
636    /// every call. Slower but always-correct fallback.
637    #[allow(clippy::type_complexity)]
638    fn decode_step_oneshot(
639        &mut self,
640        past_seq: usize,
641        input_tok: u32,
642    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
643        let cache = self.cache.as_ref().unwrap();
644
645        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
646        let (graph, params) =
647            build_gemma_decode_graph_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
648        let session = Session::new(self.device);
649        let mut compiled = self.compile_graph_profiled_decode(&session, graph)?;
650        for (name, data) in &params {
651            compiled.set_param(name, data);
652        }
653
654        let input_ids_f32 = [input_tok as f32];
655        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
656            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
657            .collect();
658        let mut inputs: Vec<(&str, &[f32])> =
659            Vec::with_capacity(1 + 2 * self.cfg.num_hidden_layers);
660        inputs.push(("input_ids", input_ids_f32.as_slice()));
661        for i in 0..self.cfg.num_hidden_layers {
662            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
663            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
664        }
665
666        let outputs = compiled.run(&inputs);
667        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
668    }
669
670    #[allow(clippy::type_complexity)]
671    fn decode_step_dynamic(
672        &mut self,
673        past_seq: usize,
674        input_tok: u32,
675    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
676        let cache = self.cache.as_ref().unwrap();
677        let binding = DimBinding::batch_past_seq(1, past_seq);
678        let opts = self
679            .profile_compile_options(true)
680            .dim_binding(binding.clone());
681        let cache_dyn = self
682            .decode_dynamic_cache
683            .as_mut()
684            .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
685        let needs_upload = !cache_dyn.contains(past_seq as u64);
686        let cfg = self.cfg.clone();
687        let weights_cache = self.weights_cache.clone();
688        let max_past = self.cfg.max_position_embeddings;
689        let compiled = metal_decode_compile_guard(self.device, true, || {
690            cache_dyn.get_or_specialize(
691                past_seq as u64,
692                &binding,
693                || {
694                    let mut wm = WeightMap::from_tensors(weights_cache);
695                    build_gemma_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
696                        .expect("dynamic decode HIR")
697                        .0
698                },
699                &opts,
700            )
701        })?;
702        if needs_upload {
703            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
704            let (_, params) = build_gemma_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
705            for (name, data) in &params {
706                compiled.set_param(name, data);
707            }
708        }
709
710        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
711        let input_ids_f32 = [input_tok as f32];
712        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
713            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
714            .collect();
715        let mut inputs: Vec<(&str, &[f32])> =
716            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
717        inputs.push(("input_ids", input_ids_f32.as_slice()));
718        inputs.push(("rope_cos", cos.as_slice()));
719        inputs.push(("rope_sin", sin.as_slice()));
720        for i in 0..self.cfg.num_hidden_layers {
721            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
722            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
723        }
724        let outputs = compiled.run(&inputs);
725        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
726    }
727
    /// Run one decode step for `input_tok` at position `past_seq` through
    /// the bucketed decode compile cache.
    ///
    /// Returns `(logits, per-layer K, per-layer V)`. Two execution paths:
    ///   - GPU-resident KV (`use_gpu_kv` set and a decode compile cache
    ///     present): runs `run_bucketed_kv_decode_gpu_hir`, syncs the KV
    ///     back into `self.cache`, and returns clones of the host cache
    ///     layers.
    ///   - Otherwise: runs `run_bucketed_kv_decode_hir_scratch` with the
    ///     padded scratch buffers and returns its result directly.
    ///
    /// # Panics
    /// Panics if `self.cache` is `None`, or (on the non-GPU path) if
    /// `self.decode_compile_cache` is `None`; both are unwrapped. Also
    /// panics if building the bucketed decode HIR fails (`expect`).
    #[allow(clippy::type_complexity)]
    fn decode_step_bucketed(
        &mut self,
        past_seq: usize,
        input_tok: u32,
    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
        let kv_dims = self.per_layer_kv_dims();
        let n_layers = self.cfg.num_hidden_layers;
        let decode_opts = self.profile_compile_options(true);
        // `upper` is `end - 1` of the bucket range selected for `past_seq`.
        // Falls back to `past_seq` itself when there is no compile cache,
        // no bucket matches, or the bucket index cannot be resolved.
        let upper = self
            .decode_compile_cache
            .as_ref()
            .and_then(|cache_dec| {
                cache_dec.bucket_for(past_seq as u64).map(|idx| {
                    cache_dec
                        .buckets()
                        .nth(idx)
                        .map(|r| (r.end - 1) as usize)
                        .unwrap_or(past_seq)
                })
            })
            .unwrap_or(past_seq);

        // Prepare host-side scratch and per-step inputs for this bucket.
        self.decode_scratch.ensure_bucket(upper, &kv_dims);
        self.decode_inputs.fill_mask(past_seq, upper);
        self.decode_inputs.fill_rope(&self.inv_freq, past_seq);

        // Inputs shared by both paths; the token id is passed as f32.
        let input_ids_f32 = [input_tok as f32];
        let fixed = [
            CacheRunInput {
                name: "input_ids",
                data: &input_ids_f32,
                row_inner: None,
            },
            CacheRunInput {
                name: "rope_cos",
                data: &self.decode_inputs.cos,
                row_inner: None,
            },
            CacheRunInput {
                name: "rope_sin",
                data: &self.decode_inputs.sin,
                row_inner: None,
            },
            CacheRunInput {
                name: "mask",
                data: &self.decode_inputs.mask,
                row_inner: None,
            },
        ];

        if self.use_gpu_kv && self.decode_compile_cache.is_some() {
            let key = past_seq as u64;
            let upper_u = upper as u64;
            let prev_upper = self.gpu_kv_binding.upper;
            // NOTE(review): `prev_upper == 0` is treated as "no previous
            // bucket" here — presumably the reset state of the binding;
            // confirm against `reset_gpu_kv_binding`.
            let bucket_changed = prev_upper != 0 && prev_upper != upper_u;
            // False when nothing is compiled for `key` yet, or when the
            // compiled graph reports no GPU handle for "past_k_0".
            let handles_live = self
                .decode_compile_cache
                .as_mut()
                .and_then(|c| c.compiled_for_key_mut(key))
                .map(|cg| cg.has_gpu_handle("past_k_0"))
                .unwrap_or(false);
            // On `Device::Gpu` / `Device::Metal` the KV is refreshed on
            // every step; on other devices only when the bucket changed
            // or the handles are not live.
            let refresh_kv = matches!(self.device, Device::Gpu | Device::Metal)
                || bucket_changed
                || !handles_live;
            // Owned copies so the HIR-builder closure can be `move`.
            let cfg = self.cfg.clone();
            let weights = self.weights_cache.clone();
            let logits = {
                let cache_dec = self.decode_compile_cache.as_mut().unwrap();
                let cache_mut = self.cache.as_mut().unwrap();
                metal_decode_compile_guard(self.device, true, || {
                    run_bucketed_kv_decode_gpu_hir(
                        cache_dec,
                        key,
                        past_seq,
                        cache_mut,
                        &mut self.gpu_kv_binding,
                        self.cfg.kv_proj_dim(),
                        n_layers,
                        &fixed,
                        move |upper| {
                            let mut wm = WeightMap::from_tensors(weights.clone());
                            build_gemma_decode_hir_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
                                .expect("gemma bucketed decode HIR")
                        },
                        &decode_opts,
                        refresh_kv,
                    )
                })?
            };
            // Mirror the device-side KV into the host cache when a
            // compiled graph exists for this key.
            if let Some(compiled) = self
                .decode_compile_cache
                .as_mut()
                .and_then(|c| c.compiled_for_key_mut(key))
            {
                let cache_mut = self.cache.as_mut().unwrap();
                sync_gpu_kv_to_host(compiled, cache_mut, self.cfg.kv_proj_dim(), n_layers)?;
            }
            // Look ahead one position: if the next step lands in a
            // different bucket, drop the GPU KV binding now.
            let next_key = (past_seq + 1) as u64;
            let next_upper = self
                .decode_compile_cache
                .as_ref()
                .and_then(|cache| {
                    cache
                        .bucket_for(next_key)
                        .and_then(|idx| cache.buckets().nth(idx).map(|r| (r.end - 1) as usize))
                })
                .unwrap_or(upper);
            if next_upper != upper {
                self.reset_gpu_kv_binding();
            }
            // Despite the name, this is a shared borrow; the layers are
            // cloned out of the host cache for the return value.
            let cache_mut = self.cache.as_ref().unwrap();
            let new_k = cache_mut.layers_k.clone();
            let new_v = cache_mut.layers_v.clone();
            return Ok((logits, new_k, new_v));
        }

        // Host-KV path: requires both the decode compile cache and the
        // KV cache to be present (unwrapped below).
        let cfg = self.cfg.clone();
        let weights = self.weights_cache.clone();
        let cache_dec = self.decode_compile_cache.as_mut().unwrap();
        let kv_cache = self.cache.as_ref().unwrap();
        // Split-borrow the scratch buffers so they can be passed
        // alongside the other `self` fields.
        let DecodeKvScratch {
            padded_k, padded_v, ..
        } = &mut self.decode_scratch;
        metal_decode_compile_guard(self.device, true, || {
            run_bucketed_kv_decode_hir_scratch(
                cache_dec,
                past_seq,
                kv_cache,
                &kv_dims,
                n_layers,
                padded_k,
                padded_v,
                &fixed,
                |upper| {
                    let mut wm = WeightMap::from_tensors(weights.clone());
                    build_gemma_decode_hir_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
                        .expect("gemma bucketed decode HIR")
                },
                &decode_opts,
            )
        })
    }
871
872    /// Run prefill-with-cache and return the raw outputs. Uses the
873    /// LRU `CompileCache` when enabled; otherwise compiles fresh each
874    /// call. Keyed by `seq` because graph shape is seq-specialized.
875    #[allow(clippy::unnecessary_unwrap)]
876    fn run_prefill_with_cache(
877        &mut self,
878        batch: usize,
879        seq: usize,
880        ids_f32: &[f32],
881    ) -> Result<Vec<Vec<f32>>> {
882        if self.prefill_dynamic_cache.is_some() {
883            let binding = DimBinding::batch_seq(batch, seq);
884            let opts = compile_options_from_profile(
885                &self.prefill_profile,
886                self.device,
887                KernelDispatchConfig::default(),
888            )
889            .dim_binding(binding.clone());
890            let cache = self.prefill_dynamic_cache.as_mut().expect("checked");
891            let needs_upload = !cache.contains(seq as u64);
892            let cfg = self.cfg.clone();
893            let weights_cache = self.weights_cache.clone();
894            let max_seq = self.cfg.max_position_embeddings;
895            let compiled = cache.get_or_specialize(
896                seq as u64,
897                &binding,
898                || {
899                    let mut wm = WeightMap::from_tensors(weights_cache);
900                    build_gemma_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
901                        .expect("dynamic prefill HIR")
902                        .0
903                },
904                &opts,
905            )?;
906            if needs_upload {
907                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
908                let (_, params) =
909                    build_gemma_prefill_hir_dynamic_ext(&self.cfg, &mut wm, batch, max_seq, true)?;
910                for (name, data) in &params {
911                    compiled.set_param(name, data);
912                }
913            }
914            let last_idx = vec![(seq - 1) as f32];
915            Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
916        } else if self.prefill_compile_cache.is_some() {
917            let key = ((batch as u64) << 32) | (seq as u64);
918            let opts = self.profile_compile_options(false);
919            if !self.prefill_compile_cache.as_ref().unwrap().contains(key) {
920                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
921                let (graph, params) = build_gemma_graph_sized_last_logits(
922                    &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
923                )?;
924                {
925                    let compiled = self
926                        .prefill_compile_cache
927                        .as_mut()
928                        .unwrap()
929                        .get_or_compile_with_options(key, || graph, &opts);
930                    for (name, data) in &params {
931                        compiled.set_param(name, data);
932                    }
933                }
934            }
935            let compiled = self
936                .prefill_compile_cache
937                .as_mut()
938                .unwrap()
939                .get_or_compile_with_options(key, || unreachable!("just populated above"), &opts);
940            Ok(compiled.run(&[("input_ids", ids_f32)]))
941        } else {
942            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
943            let (graph, params) = build_gemma_graph_sized_last_logits(
944                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
945            )?;
946            let session = Session::new(self.device);
947            let opts = self.profile_compile_options(false);
948            let mut compiled = session.compile_with(graph, &opts);
949            for (name, data) in &params {
950                compiled.set_param(name, data);
951            }
952            Ok(compiled.run(&[("input_ids", ids_f32)]))
953        }
954    }
955
956    fn run_prefill_hidden_with_cache(
957        &mut self,
958        batch: usize,
959        seq: usize,
960        hidden: &[f32],
961        attn_bias: Option<&[f32]>,
962    ) -> Result<Vec<Vec<f32>>> {
963        if self.cfg.use_bidirectional_vision() && attn_bias.is_none() {
964            anyhow::bail!(
965                "multimodal prefill requires attn_bias when use_bidirectional_attention=vision"
966            );
967        }
968        let mut inputs: Vec<(&str, &[f32])> = vec![("prefill_hidden", hidden)];
969        if let Some(bias) = attn_bias {
970            inputs.push(("attn_bias", bias));
971        }
972        let embed_compile_opts = self.profile_compile_options(false);
973        if let Some(cache) = &mut self.embed_prefill_dynamic_cache {
974            let binding = DimBinding::batch_seq(batch, seq);
975            let opts = compile_options_from_profile(
976                &self.prefill_profile,
977                self.device,
978                KernelDispatchConfig::default(),
979            )
980            .dim_binding(binding.clone());
981            let needs_upload = !cache.contains(seq as u64);
982            let cfg = self.cfg.clone();
983            let weights_cache = self.weights_cache.clone();
984            let max_seq = self.cfg.max_position_embeddings;
985            let compiled = cache.get_or_specialize(
986                seq as u64,
987                &binding,
988                || {
989                    let mut wm = WeightMap::from_tensors(weights_cache);
990                    build_gemma_prefill_hidden_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
991                        .expect("dynamic hidden prefill HIR")
992                        .0
993                },
994                &opts,
995            )?;
996            if needs_upload {
997                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
998                let (_, params) = build_gemma_prefill_hidden_hir_dynamic_ext(
999                    &self.cfg, &mut wm, batch, max_seq, true,
1000                )?;
1001                for (name, data) in &params {
1002                    compiled.set_param(name, data);
1003                }
1004            }
1005            let last_idx = vec![(seq - 1) as f32];
1006            let mut dyn_inputs = inputs.clone();
1007            dyn_inputs.push(("last_token_idx", &last_idx));
1008            Ok(compiled.run(&dyn_inputs))
1009        } else if let Some(cache) = &mut self.embed_prefill_compile_cache {
1010            let key = ((batch as u64) << 32) | (seq as u64);
1011            let opts = &embed_compile_opts;
1012            if !cache.contains(key) {
1013                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
1014                let (graph, params) = build_gemma_graph_sized_last_logits_hidden(
1015                    &self.cfg, &mut wm, batch, seq, true,
1016                )?;
1017                {
1018                    let compiled = cache.get_or_compile_with_options(key, || graph, opts);
1019                    for (name, data) in &params {
1020                        compiled.set_param(name, data);
1021                    }
1022                }
1023            }
1024            let compiled = cache.get_or_compile_with_options(
1025                key,
1026                || unreachable!("just populated above"),
1027                opts,
1028            );
1029            Ok(compiled.run(&inputs))
1030        } else {
1031            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
1032            let (graph, params) =
1033                build_gemma_graph_sized_last_logits_hidden(&self.cfg, &mut wm, batch, seq, true)?;
1034            let session = Session::new(self.device);
1035            let opts = self.profile_compile_options(false);
1036            let mut compiled = session.compile_with(graph, &opts);
1037            for (name, data) in &params {
1038                compiled.set_param(name, data);
1039            }
1040            Ok(compiled.run(&inputs))
1041        }
1042    }
1043
    /// Run `n` cached steps after [`prefill_from_embeds`].
    ///
    /// Convenience wrapper: forwards to
    /// [`generate_from_embeds_with_bias`] with no attention bias
    /// (`None`) and returns its result unchanged.
    pub fn generate_from_embeds(
        &mut self,
        prompt_ids: &[u32],
        embeds: &[f32],
        n: usize,
        opts: SampleOpts,
    ) -> Result<Vec<u32>> {
        self.generate_from_embeds_with_bias(prompt_ids, embeds, None, n, opts)
    }
1054
    /// Prefill from `prompt_ids` / `embeds` with an optional attention
    /// bias, then run `n` cached steps.
    ///
    /// Calls `prefill_from_embeds(prompt_ids, embeds, attn_bias)` and, if
    /// that succeeds, returns the result of `generate_cached(n, opts)`.
    /// Any prefill error is propagated before generation starts.
    pub fn generate_from_embeds_with_bias(
        &mut self,
        prompt_ids: &[u32],
        embeds: &[f32],
        attn_bias: Option<Vec<f32>>,
        n: usize,
        opts: SampleOpts,
    ) -> Result<Vec<u32>> {
        self.prefill_from_embeds(prompt_ids, embeds, attn_bias)?;
        self.generate_cached(n, opts)
    }
1066
    /// Streaming variant of [`generate_from_embeds`].
    ///
    /// Forwards to [`generate_from_embeds_with_bias_and_callback`] with
    /// no attention bias (`None`), passing `on_token` through.
    pub fn generate_from_embeds_with(
        &mut self,
        prompt_ids: &[u32],
        embeds: &[f32],
        n: usize,
        opts: SampleOpts,
        on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        self.generate_from_embeds_with_bias_and_callback(
            prompt_ids, embeds, None, n, opts, on_token,
        )
    }
1080
    /// Prefill from `prompt_ids` / `embeds` with an optional attention
    /// bias, then run `n` cached steps with a per-token callback.
    ///
    /// Calls `prefill_from_embeds(prompt_ids, embeds, attn_bias)` and, if
    /// that succeeds, returns the result of
    /// `generate_cached_with(n, opts, on_token)`. Any prefill error is
    /// propagated before generation starts.
    pub fn generate_from_embeds_with_bias_and_callback(
        &mut self,
        prompt_ids: &[u32],
        embeds: &[f32],
        attn_bias: Option<Vec<f32>>,
        n: usize,
        opts: SampleOpts,
        on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        self.prefill_from_embeds(prompt_ids, embeds, attn_bias)?;
        self.generate_cached_with(n, opts, on_token)
    }
1093
    /// Run `n` cached steps and return the newly generated tokens.
    ///
    /// Convenience wrapper over [`generate_cached_with`] with a no-op
    /// per-token callback.
    pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
        self.generate_cached_with(n, opts, |_| {})
    }
1098
1099    /// Same as [`generate_cached`] but invokes `on_token` once per
1100    /// freshly sampled id, inside the decode loop. The whole `n` step
1101    /// loop shares the bucketed compile cache — callers wanting a
1102    /// streaming UI should prefer this to calling
1103    /// `generate_cached(1, …)` `n` times (which forces a fresh
1104    /// compile per token at the bucket boundaries).
1105    pub fn generate_cached_with(
1106        &mut self,
1107        n: usize,
1108        opts: SampleOpts,
1109        mut on_token: impl FnMut(u32),
1110    ) -> Result<Vec<u32>> {
1111        let start = self.tokens.len();
1112        for _ in 0..n {
1113            let tok = self.step_cached(opts)?;
1114            on_token(tok);
1115        }
1116        Ok(self.tokens[start..].to_vec())
1117    }
1118
1119    /// Run prefill-with-cache on the current `self.tokens` (the
1120    /// prompt), populate `self.cache`, sample the next token from the
1121    /// last position's logits, and append it. Returns the sampled
1122    /// token. Invariant after: `cache.past_seq == tokens.len() - 1`.
1123    fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
1124        let seq = self.tokens.len();
1125        let batch = 1usize;
1126        let kv_dims = self.per_layer_kv_dims();
1127
1128        let outputs = if let Some(embeds) = self.pending_prefill_embeds.take() {
1129            let bias = self.pending_prefill_attn_bias.take();
1130            self.run_prefill_hidden_with_cache(batch, seq, &embeds, bias.as_deref())?
1131        } else {
1132            let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
1133            self.run_prefill_with_cache(batch, seq, &ids_f32)?
1134        };
1135        let (logits, kv) = kv_from_prefill_outputs_per_layer(
1136            outputs,
1137            batch,
1138            seq,
1139            &kv_dims,
1140            self.cfg.num_hidden_layers,
1141        )?;
1142        self.cache = Some(kv);
1143
1144        let vocab = self.cfg.vocab_size;
1145        let needed = vocab;
1146        if logits.len() < needed {
1147            anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
1148        }
1149        let last_row = &logits[..vocab];
1150        let tok = sample_token(last_row, opts) as u32;
1151        self.tokens.push(tok);
1152        Ok(tok)
1153    }
1154
    /// Full token history (prompt + generated), borrowed from the
    /// generator's internal buffer.
    pub fn tokens(&self) -> &[u32] {
        &self.tokens
    }
1159
    /// Borrow the model configuration this generator holds.
    pub fn config(&self) -> &GemmaConfig {
        &self.cfg
    }
1163
1164    /// Low-level primitive: reset internal state, run prefill-with-cache
1165    /// over `context`, and return the *last position's* logits row
1166    /// (`P(next_token | context)`). Does NOT sample or append. The
1167    /// internal `tokens` buffer is set to `context` and the KV cache
1168    /// is populated to `past_seq = context.len()`.
1169    ///
1170    /// First row of logits after prefill-with-cache (no sampling).
1171    pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
1172        if context.is_empty() {
1173            anyhow::bail!("prefill_get_last_logits: empty context");
1174        }
1175        self.tokens.clear();
1176        self.tokens.extend_from_slice(context);
1177        self.cache = None;
1178        self.reset_gpu_kv_binding();
1179
1180        let seq = context.len();
1181        let batch = 1usize;
1182        let kv_dims = self.per_layer_kv_dims();
1183
1184        let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
1185        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
1186        let (logits, kv) = kv_from_prefill_outputs_per_layer(
1187            outputs,
1188            batch,
1189            seq,
1190            &kv_dims,
1191            self.cfg.num_hidden_layers,
1192        )?;
1193        self.cache = Some(kv);
1194
1195        let vocab = self.cfg.vocab_size;
1196        let needed = vocab;
1197        if logits.len() < needed {
1198            anyhow::bail!("logits short: {} < {}", logits.len(), needed);
1199        }
1200        Ok(logits[..vocab].to_vec())
1201    }
1202
1203    /// Low-level primitive: run one decode step with the caller-
1204    /// supplied input token (no sampling), advance the KV cache, and
1205    /// return the resulting logits row `P(next | history ++ input)`.
1206    /// Appends `input` to the `tokens` buffer so the invariant
1207    /// `cache.past_seq == tokens.len()` holds after this call (note:
1208    /// differs from `step_cached` invariant because this method does
1209    /// not append a sampled token).
1210    pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
1211        if self.cache.is_none() {
1212            anyhow::bail!(
1213                "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
1214            );
1215        }
1216        self.tokens.push(input);
1217        let seq = self.tokens.len();
1218        let batch = 1usize;
1219        let kv_dims = self.per_layer_kv_dims();
1220        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
1221        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
1222        let (logits, kv) = kv_from_prefill_outputs_per_layer(
1223            outputs,
1224            batch,
1225            seq,
1226            &kv_dims,
1227            self.cfg.num_hidden_layers,
1228        )?;
1229        self.cache = Some(kv);
1230        let vocab = self.cfg.vocab_size;
1231        Ok(logits[..vocab].to_vec())
1232    }
1233
1234    /// Per-layer KV dimensions (`num_kv_heads * head_dim` for each
1235    /// layer, accounting for Gemma 4's full-attention shape divergence).
1236    fn per_layer_kv_dims(&self) -> Vec<usize> {
1237        (0..self.cfg.num_hidden_layers)
1238            .map(|i| self.cfg.layer_num_kv_heads(i) * self.cfg.layer_head_dim(i))
1239            .collect()
1240    }
1241}
1242
1243impl Drop for GemmaGenerator {
1244    fn drop(&mut self) {
1245        if self.device == Device::Metal {
1246            self.sync_device();
1247        }
1248    }
1249}
1250
/// Compute the single-row (cos, sin) RoPE slice for absolute position
/// `pos`. Matches the formula in the prefill builder so cached decode
/// and recompute prefill produce the same RoPE rotation.
///
/// Thin wrapper that delegates directly to [`rope_slice`] with the
/// given inverse frequencies.
fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
    rope_slice(inv_freq, pos)
}
1257
1258#[cfg(test)]
1259mod tests {
1260    use super::*;
1261    use crate::config::GemmaConfig;
1262    use crate::rope::{build_rope_tables, resolve_inv_freq, rope_slice};
1263    use rlx_flow::CompileProfile;
1264
1265    fn tiny_cfg() -> GemmaConfig {
1266        let mut cfg = GemmaConfig::tiny_test();
1267        cfg.vocab_size = 16;
1268        cfg.head_dim = Some(8);
1269        cfg
1270    }
1271
1272    fn synthetic_tensors(cfg: &GemmaConfig) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
1273        let h = cfg.hidden_size;
1274        let q_dim = cfg.q_proj_dim();
1275        let kv_dim = cfg.kv_proj_dim();
1276        let int_dim = cfg.intermediate_size;
1277        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
1278        // Use a deterministic non-zero pattern so logits aren't all 0
1279        // (sampling on an all-zero row is undefined order).
1280        let pat = |n: usize, salt: u32| -> Vec<f32> {
1281            (0..n)
1282                .map(|i| {
1283                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
1284                    (x as f32 / (1u32 << 24) as f32) - 0.5
1285                })
1286                .collect()
1287        };
1288        t.insert(
1289            "model.embed_tokens.weight".into(),
1290            (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
1291        );
1292        for i in 0..cfg.num_hidden_layers {
1293            let lp = format!("model.layers.{i}");
1294            t.insert(
1295                format!("{lp}.input_layernorm.weight"),
1296                (pat(h, 100 + i as u32), vec![h]),
1297            );
1298            t.insert(
1299                format!("{lp}.post_attention_layernorm.weight"),
1300                (pat(h, 200 + i as u32), vec![h]),
1301            );
1302            t.insert(
1303                format!("{lp}.self_attn.q_proj.weight"),
1304                (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
1305            );
1306            t.insert(
1307                format!("{lp}.self_attn.k_proj.weight"),
1308                (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
1309            );
1310            t.insert(
1311                format!("{lp}.self_attn.v_proj.weight"),
1312                (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
1313            );
1314            t.insert(
1315                format!("{lp}.self_attn.o_proj.weight"),
1316                (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
1317            );
1318            t.insert(
1319                format!("{lp}.mlp.gate_proj.weight"),
1320                (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
1321            );
1322            t.insert(
1323                format!("{lp}.mlp.up_proj.weight"),
1324                (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
1325            );
1326            t.insert(
1327                format!("{lp}.mlp.down_proj.weight"),
1328                (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
1329            );
1330        }
1331        t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
1332        t.insert(
1333            "lm_head.weight".into(),
1334            (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
1335        );
1336        t
1337    }
1338
    /// Wrap the tensors from `synthetic_tensors(cfg)` in a `WeightMap`
    /// via `WeightMap::from_tensors`.
    fn synthetic_weights(cfg: &GemmaConfig) -> WeightMap {
        WeightMap::from_tensors(synthetic_tensors(cfg))
    }
1342
1343    #[test]
1344    fn generator_drains_loader_and_runs_one_step() {
1345        let cfg = tiny_cfg();
1346        let mut wm = synthetic_weights(&cfg);
1347        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1348        assert_eq!(wm.len(), 0, "loader should be drained");
1349        gn.prefill(&[1, 2, 3]);
1350        let t = gn.step(SampleOpts::greedy()).unwrap();
1351        assert!((t as usize) < cfg.vocab_size);
1352        assert_eq!(gn.tokens().len(), 4);
1353    }
1354
1355    #[test]
1356    fn generate_n_appends_n_tokens() {
1357        let cfg = tiny_cfg();
1358        let mut wm = synthetic_weights(&cfg);
1359        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1360        gn.prefill(&[5, 6]);
1361        let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
1362        assert_eq!(new_tokens.len(), 3);
1363        assert_eq!(gn.tokens().len(), 5);
1364        for t in &new_tokens {
1365            assert!((*t as usize) < cfg.vocab_size);
1366        }
1367    }
1368
1369    #[test]
1370    fn step_without_prefill_errors() {
1371        let cfg = tiny_cfg();
1372        let mut wm = synthetic_weights(&cfg);
1373        let mut gn = GemmaGenerator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
1374        let r = gn.step(SampleOpts::greedy());
1375        assert!(r.is_err());
1376    }
1377
1378    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
1379        a.iter()
1380            .zip(b.iter())
1381            .map(|(x, y)| (x - y).abs())
1382            .fold(0f32, f32::max)
1383    }
1384
1385    #[test]
1386    fn prefill_logits_unchanged_with_kv_export() {
1387        let cfg = tiny_cfg();
1388        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1389
1390        let mut wm_a = synthetic_weights(&cfg);
1391        let mut wm_b = synthetic_weights(&cfg);
1392        let (graph_a, params_a) =
1393            build_gemma_graph_sized_last_logits(&cfg, &mut wm_a, 1, 4, false).unwrap();
1394        let (graph_b, params_b) =
1395            build_gemma_graph_sized_last_logits(&cfg, &mut wm_b, 1, 4, true).unwrap();
1396        let session = Session::new(Device::Cpu);
1397        let opts = CompileOptions::new();
1398        let mut ca = session.compile_with(graph_a, &opts);
1399        let mut cb = session.compile_with(graph_b, &opts);
1400        for (n, d) in &params_a {
1401            ca.set_param(n, d);
1402        }
1403        for (n, d) in &params_b {
1404            cb.set_param(n, d);
1405        }
1406        let ids: Vec<f32> = prompt.iter().map(|&i| i as f32).collect();
1407        let la = ca.run(&[("input_ids", &ids)])[0].clone();
1408        let lb = cb.run(&[("input_ids", &ids)])[0].clone();
1409        let d = max_abs_diff(&la, &lb);
1410        assert!(d < 1e-5, "kv export changed prefill logits: max_abs={d:.6}");
1411    }
1412
1413    #[test]
1414    fn incremental_decode_logits_match_full_prefill() {
1415        let cfg = tiny_cfg();
1416        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1417
1418        let mut wm_a = synthetic_weights(&cfg);
1419        let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1420        let tok = gn_a
1421            .prefill_get_last_logits(&prompt)
1422            .map(|l| sample_token(&l, SampleOpts::greedy()) as u32)
1423            .unwrap();
1424
1425        let mut extended = prompt.clone();
1426        extended.push(tok);
1427
1428        let mut wm_b = synthetic_weights(&cfg);
1429        let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
1430        let full = gn_b.prefill_get_last_logits(&extended).unwrap();
1431
1432        let mut wm_c = synthetic_weights(&cfg);
1433        let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1434        gn_c.prefill_get_last_logits(&prompt).unwrap();
1435        let incremental = gn_c.decode_get_logits(tok).unwrap();
1436
1437        let d = max_abs_diff(&full, &incremental);
1438        assert!(
1439            d < 1e-2,
1440            "decode+KV vs full prefill max_abs={d:.6} (tok={tok})"
1441        );
1442    }
1443
    /// Run a KV-exporting prefill over `ids` and return the KV outputs.
    ///
    /// Shorthand for `run_prefill_kv_with_options` using the compile
    /// options from `kv_export_compile_options(true)`.
    fn run_prefill_kv(
        cfg: &GemmaConfig,
        wm: &mut WeightMap,
        seq: usize,
        ids: &[u32],
    ) -> Vec<Vec<f32>> {
        run_prefill_kv_with_options(cfg, wm, seq, ids, &kv_export_compile_options(true))
    }
1452
1453    fn kv_export_compile_options(prefill: bool) -> CompileOptions {
1454        let profile = if prefill {
1455            CompileProfile::gemma_prefill()
1456        } else {
1457            CompileProfile::gemma_decode()
1458        };
1459        compile_options_from_profile(&profile, Device::Cpu, KernelDispatchConfig::default())
1460    }
1461
1462    fn run_prefill_kv_with_options(
1463        cfg: &GemmaConfig,
1464        wm: &mut WeightMap,
1465        seq: usize,
1466        ids: &[u32],
1467        opts: &CompileOptions,
1468    ) -> Vec<Vec<f32>> {
1469        let ids_f32: Vec<f32> = ids.iter().map(|&i| i as f32).collect();
1470        let (graph, params) = build_gemma_graph_sized_last_logits(cfg, wm, 1, seq, true).unwrap();
1471        let session = Session::new(Device::Cpu);
1472        let mut compiled = session.compile_with(graph, opts);
1473        for (n, d) in &params {
1474            compiled.set_param(n, d);
1475        }
1476        let outputs = compiled.run(&[("input_ids", &ids_f32)]);
1477        let n_layers = cfg.num_hidden_layers;
1478        assert_eq!(outputs.len(), 1 + 2 * n_layers);
1479        let mut kv = Vec::with_capacity(2 * n_layers);
1480        let mut iter = outputs.into_iter().skip(1);
1481        for _ in 0..n_layers {
1482            kv.push(iter.next().unwrap());
1483            kv.push(iter.next().unwrap());
1484        }
1485        kv
1486    }
1487
1488    #[test]
1489    fn decode_graph_bakes_rope_slice_length() {
1490        let cfg = tiny_cfg();
1491        let past_seq = 4usize;
1492        let half = cfg.head_dim() / 2;
1493        let mut wm = synthetic_weights(&cfg);
1494        let (_, params) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
1495        let cos = params
1496            .get("decode.rope.cos")
1497            .expect("decode.rope.cos param");
1498        let sin = params
1499            .get("decode.rope.sin")
1500            .expect("decode.rope.sin param");
1501        assert_eq!(
1502            cos.len(),
1503            half,
1504            "cos param should be one row (half={half}), got {}",
1505            cos.len()
1506        );
1507        assert_eq!(sin.len(), half);
1508        for key in params.keys() {
1509            assert!(
1510                !key.starts_with("rope."),
1511                "decode graph must not include prefill rope table param {key}"
1512            );
1513        }
1514        let inv = resolve_inv_freq(&cfg, None);
1515        let (c_ref, s_ref) = rope_slice(&inv, past_seq);
1516        let d = max_abs_diff(cos, &c_ref) + max_abs_diff(sin, &s_ref);
1517        assert!(d < 1e-6, "baked rope mismatch: {d}");
1518    }
1519
1520    #[test]
1521    fn decode_graph_all_rope_use_baked_cos() {
1522        use rlx_ir::Op;
1523        let cfg = tiny_cfg();
1524        let mut wm = synthetic_weights(&cfg);
1525        let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, 4).unwrap();
1526        for node in graph.nodes() {
1527            if let Op::Rope { .. } = &node.op {
1528                let cos_id = node.inputs[1];
1529                let cos_node = &graph.node(cos_id);
1530                match &cos_node.op {
1531                    Op::Param { name } => assert_eq!(
1532                        name, "decode.rope.cos",
1533                        "decode RoPE must use baked decode.rope.cos, got {name}"
1534                    ),
1535                    other => panic!("decode RoPE cos input is {other:?}, expected Param"),
1536                }
1537            }
1538        }
1539    }
1540
1541    #[test]
1542    fn decode_graph_rope_cos_is_single_row() {
1543        use rlx_ir::Op;
1544        let cfg = tiny_cfg();
1545        let past_seq = 4usize;
1546        let half = cfg.head_dim() / 2;
1547        let mut wm = synthetic_weights(&cfg);
1548        let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
1549        let mut rope_cos_lens = Vec::new();
1550        for node in graph.nodes() {
1551            if let Op::Rope { .. } = &node.op {
1552                let cos_shape = &graph.node(node.inputs[1]).shape;
1553                let rows = if cos_shape.rank() >= 2 {
1554                    cos_shape.dim(0).unwrap_static()
1555                } else {
1556                    1
1557                };
1558                rope_cos_lens.push(rows);
1559            }
1560        }
1561        assert!(!rope_cos_lens.is_empty(), "decode graph has no RoPE nodes");
1562        for rows in &rope_cos_lens {
1563            assert_eq!(
1564                *rows, 1,
1565                "decode RoPE cos must be single-row [1, half], got {rows} rows"
1566            );
1567        }
1568        assert_eq!(half, cfg.head_dim() / 2);
1569    }
1570
1571    #[test]
1572    fn prefill_kv_matches_extended_prefix() {
1573        let cfg = tiny_cfg();
1574        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1575        let tok = 6u32;
1576        let mut extended = prompt.clone();
1577        extended.push(tok);
1578
1579        let mut wm_prompt = synthetic_weights(&cfg);
1580        let prompt_kv = run_prefill_kv(&cfg, &mut wm_prompt, 4, &prompt);
1581        let mut wm_ext = synthetic_weights(&cfg);
1582        let ext_kv = run_prefill_kv(&cfg, &mut wm_ext, 5, &extended);
1583
1584        let kv_dim = cfg.kv_proj_dim();
1585        for layer in 0..cfg.num_hidden_layers {
1586            let k_prompt = &prompt_kv[2 * layer];
1587            let k_ext = &ext_kv[2 * layer];
1588            let prefix_len = 4 * kv_dim;
1589            assert_eq!(k_prompt.len(), prefix_len);
1590            assert_eq!(k_ext.len(), 5 * kv_dim);
1591            let d = max_abs_diff(k_prompt, &k_ext[..prefix_len]);
1592            assert!(
1593                d < 1e-4,
1594                "layer {layer} prefill K prefix vs extended K max_abs={d:.6}"
1595            );
1596        }
1597    }
1598
1599    #[test]
1600    fn decode_rope_slice_matches_prefill_table_row() {
1601        let cfg = tiny_cfg();
1602        let inv = resolve_inv_freq(&cfg, None);
1603        let (cos_tab, sin_tab) = build_rope_tables(&inv, cfg.max_position_embeddings);
1604        let half = inv.len();
1605        for pos in [3usize, 4, 5] {
1606            let (c, s) = rope_slice(&inv, pos);
1607            let off = pos * half;
1608            let d = max_abs_diff(&c, &cos_tab[off..off + half])
1609                + max_abs_diff(&s, &sin_tab[off..off + half]);
1610            assert!(d < 1e-6, "rope_slice mismatch at pos {pos}: {d}");
1611        }
1612    }
1613
1614    #[test]
1615    fn prefill_kv_export_correct_with_fusion() {
1616        let cfg = tiny_cfg();
1617        let tok = 6u32;
1618        let ids = [1u32, 2, 3, 5, tok];
1619        let opts = kv_export_compile_options(true);
1620        let mut wm_one = synthetic_weights(&cfg);
1621        let one_kv = run_prefill_kv_with_options(&cfg, &mut wm_one, 1, &[tok], &opts);
1622        let mut wm_ext = synthetic_weights(&cfg);
1623        let ext_kv = run_prefill_kv_with_options(&cfg, &mut wm_ext, 5, &ids, &opts);
1624        let kv_dim = cfg.kv_proj_dim();
1625        let d = max_abs_diff(&ext_kv[1][4 * kv_dim..], &one_kv[1][..kv_dim]);
1626        assert!(d < 1e-4, "KV export mismatch with profile fusion: {d:.6}");
1627
1628        let mut wm_default = synthetic_weights(&cfg);
1629        let default_kv =
1630            run_prefill_kv_with_options(&cfg, &mut wm_default, 5, &ids, &CompileOptions::new());
1631        let d_default = max_abs_diff(&default_kv[1][4 * kv_dim..], &one_kv[1][..kv_dim]);
1632        assert!(
1633            d_default < 1e-4,
1634            "KV export mismatch with default fusion (got {d_default:.6})"
1635        );
1636    }
1637
1638    #[test]
1639    fn decode_oneshot_kv_suffix_matches_extended() {
1640        let cfg = tiny_cfg();
1641        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1642        let tok = 6u32;
1643        let mut extended = prompt.clone();
1644        extended.push(tok);
1645
1646        let opts = kv_export_compile_options(false);
1647        let mut wm_ext = synthetic_weights(&cfg);
1648        let ext_kv = run_prefill_kv_with_options(&cfg, &mut wm_ext, 5, &extended, &opts);
1649
1650        let mut wm = synthetic_weights(&cfg);
1651        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1652        gn.prefill_get_last_logits(&prompt).unwrap();
1653
1654        let mut wm_d = synthetic_weights(&cfg);
1655        let (graph, params) = build_gemma_decode_graph_sized(&cfg, &mut wm_d, 1, 4).unwrap();
1656        let session = Session::new(Device::Cpu);
1657        let mut compiled = session.compile_with(graph, &opts);
1658        for (n, d) in &params {
1659            compiled.set_param(n, d);
1660        }
1661        let cache = gn.cache.as_ref().unwrap();
1662        let key_strs: Vec<String> = (0..cfg.num_hidden_layers)
1663            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
1664            .collect();
1665        let input_ids = [tok as f32];
1666        let mut inputs: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
1667        for i in 0..cfg.num_hidden_layers {
1668            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
1669            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
1670        }
1671        let outputs = compiled.run(&inputs);
1672        let kv_dim = cfg.kv_proj_dim();
1673        let k_dec = &outputs[1][4 * kv_dim..];
1674
1675        let d = max_abs_diff(k_dec, &ext_kv[0][4 * kv_dim..]);
1676        assert!(
1677            d < 1e-3,
1678            "decode oneshot layer0 K suffix vs extended max_abs={d:.6}"
1679        );
1680    }
1681
1682    #[test]
1683    fn decode_logits_match_extended_prefill_after_one_token() {
1684        let cfg = tiny_cfg();
1685        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1686        let tok = 6u32;
1687
1688        let mut extended = prompt.clone();
1689        extended.push(tok);
1690
1691        let mut wm_a = synthetic_weights(&cfg);
1692        let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1693        let full = gn_a.prefill_get_last_logits(&extended).unwrap();
1694
1695        let mut wm_b = synthetic_weights(&cfg);
1696        let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
1697        gn_b.prefill_get_last_logits(&prompt).unwrap();
1698        let inc = gn_b.decode_get_logits(tok).unwrap();
1699
1700        let d = max_abs_diff(&full, &inc);
1701        assert!(d < 1e-2, "decode vs extended prefill max_abs={d:.6}");
1702    }
1703
1704    #[test]
1705    fn cached_second_token_matches_naive() {
1706        let cfg = tiny_cfg();
1707        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1708
1709        let mut wm_n = synthetic_weights(&cfg);
1710        let mut gn_n = GemmaGenerator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
1711        gn_n.prefill(&prompt);
1712        let n0 = gn_n.step(SampleOpts::greedy()).unwrap();
1713        let n1 = gn_n.step(SampleOpts::greedy()).unwrap();
1714
1715        let mut wm_c = synthetic_weights(&cfg);
1716        let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1717        gn_c.prefill(&prompt);
1718        let c = gn_c.generate_cached(2, SampleOpts::greedy()).unwrap();
1719
1720        assert_eq!(c[0], n0, "first generated token");
1721        assert_eq!(c[1], n1, "second generated token (decode step)");
1722    }
1723
1724    #[test]
1725    fn cached_matches_naive_on_greedy() {
1726        // The cached and naive paths must produce the same token
1727        // sequence given the same prompt + opts. This is the
1728        // load-bearing test for the KV-cache implementation: if the
1729        // decode-mode graph, the kernel's Lq!=Lk fix, the cache
1730        // wiring, or the RoPE position-slice is wrong, the sequences
1731        // diverge here.
1732        let cfg = tiny_cfg();
1733        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1734        let steps = 4;
1735
1736        let mut wm_n = synthetic_weights(&cfg);
1737        let mut gn_naive =
1738            GemmaGenerator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
1739        gn_naive.prefill(&prompt);
1740        let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();
1741
1742        let mut wm_c = synthetic_weights(&cfg);
1743        let mut gn_cached =
1744            GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1745        gn_cached.prefill(&prompt);
1746        let cached_tokens = gn_cached
1747            .generate_cached(steps, SampleOpts::greedy())
1748            .unwrap();
1749
1750        assert_eq!(
1751            cached_tokens, naive_tokens,
1752            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
1753        );
1754    }
1755
1756    #[test]
1757    fn cached_step_advances_cache_invariant() {
1758        let cfg = tiny_cfg();
1759        let mut wm = synthetic_weights(&cfg);
1760        let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1761        gn.prefill(&[1, 2, 3]);
1762        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
1763        // After seed: tokens.len() == 4, cache.past_seq == 3 (cache holds prompt).
1764        assert_eq!(gn.tokens().len(), 4);
1765        assert_eq!(gn.cache.as_ref().unwrap().past_len, 3);
1766        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
1767        // After one decode: tokens.len() == 5, cache.past_seq == 4.
1768        assert_eq!(gn.tokens().len(), 5);
1769        assert_eq!(gn.cache.as_ref().unwrap().past_len, 4);
1770    }
1771
1772    #[test]
1773    fn bucketed_decode_matches_oneshot() {
1774        // The bucketed compile-cache path (padded K/V + custom mask)
1775        // must produce the same token sequence as the one-shot
1776        // path. Load-bearing for the bucketed cache feature: if the
1777        // mask, padding, or output slicing is wrong, sequences
1778        // diverge here.
1779        let cfg = tiny_cfg();
1780        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1781        let steps = 6;
1782
1783        let mut wm_one = synthetic_weights(&cfg);
1784        let mut gn_one =
1785            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1786        gn_one.prefill(&prompt);
1787        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1788
1789        let mut wm_buc = synthetic_weights(&cfg);
1790        let mut gn_buc = GemmaGenerator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
1791            .unwrap()
1792            .with_decode_cache(/*max_past*/ 32);
1793        gn_buc.prefill(&prompt);
1794        let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();
1795
1796        assert_eq!(
1797            bucketed_tokens, oneshot_tokens,
1798            "bucketed-cache decode diverged from one-shot decode — \
1799             mask, padding, or output-slice bug"
1800        );
1801    }
1802
1803    #[test]
1804    fn prefill_compile_cache_does_not_change_output() {
1805        let cfg = tiny_cfg();
1806        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1807        let mut wm_a = synthetic_weights(&cfg);
1808        let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1809        gn_a.prefill(&prompt);
1810        let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();
1811
1812        let mut wm_b = synthetic_weights(&cfg);
1813        let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
1814            .unwrap()
1815            .with_prefill_cache(/*capacity*/ 4);
1816        gn_b.prefill(&prompt);
1817        let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();
1818
1819        assert_eq!(a, b, "enabling prefill_cache must not change output");
1820    }
1821
1822    #[test]
1823    fn dynamic_decode_matches_oneshot() {
1824        let cfg = tiny_cfg();
1825        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1826        let steps = 6;
1827
1828        let mut wm_one = synthetic_weights(&cfg);
1829        let mut gn_one =
1830            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1831        gn_one.prefill(&prompt);
1832        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1833
1834        let mut wm_dyn = synthetic_weights(&cfg);
1835        let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1836            .unwrap()
1837            .with_dynamic_decode_cache(/*capacity*/ 8);
1838        gn_dyn.prefill(&prompt);
1839        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1840
1841        assert_eq!(
1842            dynamic_tokens, oneshot_tokens,
1843            "dynamic past_seq decode diverged from one-shot decode"
1844        );
1845    }
1846
1847    #[test]
1848    fn dynamic_prefill_matches_oneshot() {
1849        let cfg = tiny_cfg();
1850        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1851        let steps = 4;
1852
1853        let mut wm_one = synthetic_weights(&cfg);
1854        let mut gn_one =
1855            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1856        gn_one.prefill(&prompt);
1857        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1858
1859        let mut wm_dyn = synthetic_weights(&cfg);
1860        let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1861            .unwrap()
1862            .with_dynamic_prefill_cache(/*capacity*/ 8);
1863        gn_dyn.prefill(&prompt);
1864        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1865
1866        assert_eq!(
1867            dynamic_tokens, oneshot_tokens,
1868            "dynamic seq prefill diverged from one-shot prefill"
1869        );
1870    }
1871
1872    #[test]
1873    fn dynamic_prefill_and_decode_matches_oneshot() {
1874        let cfg = tiny_cfg();
1875        let prompt: Vec<u32> = vec![1, 2, 3, 5];
1876        let steps = 6;
1877
1878        let mut wm_one = synthetic_weights(&cfg);
1879        let mut gn_one =
1880            GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1881        gn_one.prefill(&prompt);
1882        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1883
1884        let mut wm_dyn = synthetic_weights(&cfg);
1885        let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1886            .unwrap()
1887            .with_dynamic_prefill_cache(/*capacity*/ 8)
1888            .with_dynamic_decode_cache(/*capacity*/ 8);
1889        gn_dyn.prefill(&prompt);
1890        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1891
1892        assert_eq!(
1893            dynamic_tokens, oneshot_tokens,
1894            "dynamic prefill+decode diverged from one-shot path"
1895        );
1896    }
1897
1898    #[test]
1899    fn greedy_is_deterministic_across_runs() {
1900        let cfg = tiny_cfg();
1901        let weights = synthetic_weights(&cfg);
1902        let mk = || {
1903            let mut wm = WeightMap::from_tensors(weights_as_hashmap(&weights));
1904            GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
1905        };
1906        let mut a = mk();
1907        let mut b = mk();
1908        a.prefill(&[1, 2, 3]);
1909        b.prefill(&[1, 2, 3]);
1910        let ta = a.generate(4, SampleOpts::greedy()).unwrap();
1911        let tb = b.generate(4, SampleOpts::greedy()).unwrap();
1912        assert_eq!(ta, tb);
1913    }
1914
1915    fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
1916        // Reconstruct the underlying map by re-running synthetic_weights
1917        // — WeightMap doesn't expose its inner map. Sufficient for the
1918        // determinism test since synthetic_weights is itself
1919        // deterministic.
1920        let _ = wm; // silence unused
1921        let cfg = tiny_cfg();
1922        let mut new = synthetic_weights(&cfg);
1923        let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
1924        let mut out = HashMap::new();
1925        for k in keys {
1926            out.insert(k.clone(), new.take(&k).unwrap());
1927        }
1928        out
1929    }
1930}