Skip to main content

rlx_llama32/
generator.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// (license header truncated — see workspace root.)
9
//! Host-side generation loop for LLaMA-3.2.
//!
//! [`Llama32Generator`] exposes two generation paths behind one API shape:
//!
//!   - **Naive** (`step` / `generate`): each step rebuilds the prefill
//!     graph for the full token history and runs it from scratch
//!     (O(N²) compute over N generated tokens).
//!   - **KV-cached** (`step_cached` / `generate_cached`): the prompt is
//!     prefilled once to seed a per-layer KV cache, then each step runs a
//!     decode graph on just the newest token plus the cached past.
//!
//! Both paths take the same sampling options and return the sampled
//! token, so callers can switch between them without other changes.
//!
//! Why the naive path exists and is kept:
//!   - It established the public API contract before the IR/kernel
//!     changes the cached path needed had landed.
//!   - It runs end-to-end generation against a real checkpoint using only
//!     the prefill graph, which validates that graph numerically.
//!   - It is the reference oracle for the cached path's own
//!     numerical-parity test (cached vs recompute must match).
26
27use crate::builder::{
28    build_llama32_decode_hir_dynamic_ext, build_llama32_decode_hir_sized,
29    build_llama32_decode_hir_sized_ext, build_llama32_graph_sized_last_logits,
30    build_llama32_prefill_hir_dynamic_ext,
31};
32use crate::config::Llama32Config;
33use crate::rope::{resolve_inv_freq, rope_slice};
34use anyhow::{Context, Result};
35use rlx_core::flow_bridge::compile_options_from_profile;
36use rlx_core::weight_loader::WeightLoader;
37use rlx_core::weight_map::WeightMap;
38use rlx_flow::CompileProfile;
39use rlx_ir::DimBinding;
40use rlx_ir::logical_kernel::KernelDispatchConfig;
41use rlx_qwen3::sampling::{SampleOpts, sample_token};
42use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache, DynamicDimCompileCache};
43use rlx_runtime::{CompileOptions, Device, Session};
44use std::collections::{HashMap, HashSet};
45use std::path::Path;
46
/// Snapshot of the incremental-decoding KV cache.
///
/// `layers_k[i]` / `layers_v[i]` hold layer `i`'s keys and values as one
/// flat `[batch, past_seq, kv_proj_dim]` buffer, and `past_seq` is the
/// number of token positions those buffers currently cover.
#[derive(Clone)]
struct KvCacheState {
    /// Number of cached token positions.
    past_seq: usize,
    /// Per-layer cached keys.
    layers_k: Vec<Vec<f32>>,
    /// Per-layer cached values.
    layers_v: Vec<Vec<f32>>,
}
55
56/// Stateful LLaMA-3.2 generation handle.
57///
58/// Holds the (config, weight bytes, token history) and rebuilds a
59/// prefill graph on each [`step`] call. Cheap to construct after
60/// initial weight load; tokens stay in-memory between calls.
61pub struct Llama32Generator {
62    cfg: Llama32Config,
63    /// Map of weight key → (f32 data, shape). Cloned on each step
64    /// into a fresh `WeightMap` because `WeightMap::take` is
65    /// destructive — see the cached-generator notes for the path
66    /// that avoids the clone.
67    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
68    tokens: Vec<u32>,
69    device: Device,
70    /// Populated lazily on the first `step_cached` call (seeded from
71    /// the prompt via prefill-with-cache); thereafter advanced by each
72    /// decode step.
73    cache: Option<KvCacheState>,
74    /// Per-key LRU compile cache for prefill graphs. Keyed by `seq`.
75    /// Set to `None` to disable (default for new instances; opt in via
76    /// [`Llama32Generator::with_prefill_cache`]).
77    prefill_compile_cache: Option<CompileCache>,
78    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
79    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
80    /// Bucketed compile cache for decode-mode graphs. Each bucket
81    /// holds one compiled graph specialized at its upper-bound
82    /// `past_seq`; the host pads `past_k`/`past_v` and supplies a
83    /// per-step mask so a single bucket serves every `past_seq` in
84    /// its range. Opt in via [`Llama32Generator::with_decode_cache`].
85    decode_compile_cache: Option<BucketedCompileCache>,
86    decode_dynamic_cache: Option<DynamicDimCompileCache>,
87    /// Tracks which decode buckets have had params attached. The
88    /// `BucketedCompileCache` API doesn't expose per-bucket compile
89    /// status, so we maintain it here to avoid double-loading params.
90    decode_loaded_buckets: HashSet<usize>,
91    /// Resolved RoPE inverse frequencies (includes Llama 3 scaling).
92    inv_freq: Vec<f64>,
93    /// Tier-1 compile profile for prefill graphs.
94    prefill_profile: CompileProfile,
95    /// Tier-1 compile profile for decode graphs.
96    decode_profile: CompileProfile,
97}
98
99impl Llama32Generator {
100    /// Construct from any [`WeightLoader`] — drains it into an
101    /// internal cache so the loader is free after this call.
102    pub fn from_loader(
103        cfg: Llama32Config,
104        loader: &mut dyn WeightLoader,
105        device: Device,
106    ) -> Result<Self> {
107        let keys = loader.remaining_keys();
108        let mut weights_cache = HashMap::with_capacity(keys.len());
109        for k in keys {
110            let v = loader
111                .take(&k)
112                .with_context(|| format!("draining weight {k}"))?;
113            // Normalize the cache key to the safetensors / HuggingFace
114            // naming convention so subsequent builder calls that ask
115            // for `model.embed_tokens.weight` (the canonical name baked
116            // into the llama32 builder) hit the cache whether the
117            // loader was safetensors-native or GGUF-native.
118            let canonical =
119                rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
120            weights_cache.insert(canonical, v);
121        }
122        let rope_factors = weights_cache
123            .get("rope_freqs.weight")
124            .map(|(d, _)| d.as_slice());
125        let inv_freq = resolve_inv_freq(&cfg, rope_factors);
126        Ok(Self {
127            cfg,
128            weights_cache,
129            tokens: Vec::new(),
130            device,
131            cache: None,
132            prefill_compile_cache: None,
133            prefill_dynamic_cache: None,
134            decode_compile_cache: None,
135            decode_dynamic_cache: None,
136            decode_loaded_buckets: HashSet::new(),
137            inv_freq,
138            prefill_profile: CompileProfile::llama32_prefill(),
139            decode_profile: CompileProfile::llama32_decode(),
140        })
141    }
142
143    /// Like [`Self::from_loader`] but loads tier-1 profiles from
144    /// `llama32.rlx.toml` in the weights directory when present.
145    pub fn from_loader_at(
146        cfg: Llama32Config,
147        loader: &mut dyn WeightLoader,
148        device: Device,
149        weights_path: &Path,
150    ) -> Result<Self> {
151        let mut g = Self::from_loader(cfg, loader, device)?;
152        g.prefill_profile = crate::llama32_profile_near_weights(weights_path, false);
153        g.decode_profile = crate::llama32_profile_near_weights(weights_path, true);
154        Ok(g)
155    }
156
157    /// Override tier-1 compile profiles explicitly.
158    pub fn with_compile_profiles(
159        mut self,
160        prefill: CompileProfile,
161        decode: CompileProfile,
162    ) -> Self {
163        self.prefill_profile = prefill;
164        self.decode_profile = decode;
165        self
166    }
167
168    pub fn prefill_profile(&self) -> &CompileProfile {
169        &self.prefill_profile
170    }
171
172    pub fn decode_profile(&self) -> &CompileProfile {
173        &self.decode_profile
174    }
175
176    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
177        let profile = if decode {
178            &self.decode_profile
179        } else {
180            &self.prefill_profile
181        };
182        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
183    }
184
185    fn compile_hir_profiled(
186        &self,
187        session: &Session,
188        hir: rlx_ir::hir::HirModule,
189        decode: bool,
190    ) -> Result<rlx_runtime::CompiledGraph> {
191        let opts = self.profile_compile_options(decode);
192        Ok(session.compile_hir_with(hir, &opts)?)
193    }
194
195    fn compile_graph_profiled(
196        &self,
197        session: &Session,
198        graph: rlx_ir::Graph,
199    ) -> Result<rlx_runtime::CompiledGraph> {
200        let opts = self.profile_compile_options(false);
201        Ok(session.compile_with(graph, &opts))
202    }
203
204    /// Enable the prefill compile cache with the given LRU capacity.
205    /// Useful when the same prompt length is used across multiple
206    /// generation runs — the second + Nth run skip the compile +
207    /// param-attach roundtrip (~30-50ms per call on CPU).
208    pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
209        self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
210        self.prefill_dynamic_cache = None;
211        self
212    }
213
214    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
215    pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
216        self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
217        self.prefill_compile_cache = None;
218        self
219    }
220
221    /// Enable the bucketed decode compile cache spanning past-seq
222    /// values in `[1, max_past]`. Buckets are power-of-two
223    /// `[1..2, 2..3, 3..5, 5..9, 9..17, …]`. Each bucket compiles
224    /// one graph at its upper bound; a steady-state generation loop
225    /// across `N` tokens compiles `O(log N)` graphs instead of `N`.
226    ///
227    /// Padding compute waste is bounded at 2×: actual `past_seq` is
228    /// at least half the bucket's upper bound (except possibly the
229    /// smallest bucket).
230    pub fn with_decode_cache(mut self, max_past: usize) -> Self {
231        let cache = BucketedCompileCache::power_of_two_ladder(
232            self.device,
233            /*min*/ 1,
234            max_past.max(1) as u64,
235        );
236        self.decode_compile_cache = Some(cache);
237        self.decode_dynamic_cache = None;
238        self.decode_loaded_buckets.clear();
239        self
240    }
241
242    /// Compile decode once with `sym::PAST_SEQ`, specialize per prefix length.
243    pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
244        self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
245        self.decode_compile_cache = None;
246        self.decode_loaded_buckets.clear();
247        self
248    }
249
250    /// Convenience: load weights from a safetensors or GGUF path
251    /// (dispatch by extension; see `rlx_core::weight_loader::load_from_path`).
252    pub fn from_path(cfg: Llama32Config, path: &str, device: Device) -> Result<Self> {
253        let mut loader = rlx_core::weight_loader::load_from_path(path)?;
254        Self::from_loader(cfg, loader.as_mut(), device)
255    }
256
257    /// Same as [`from_path`] but with MTP-head visibility control.
258    /// When `include_mtp=true` and the file is GGUF, MTP weights are
259    /// drained into the generator's cache alongside the base
260    /// weights. The base inference path still ignores them — they
261    /// sit in cache for a future MTP-aware decoder. Non-GGUF formats
262    /// silently ignore the flag (safetensors files publish all
263    /// tensors uniformly; downstream code distinguishes by name).
264    pub fn from_path_with_mtp(
265        cfg: Llama32Config,
266        path: &str,
267        device: Device,
268        include_mtp: bool,
269    ) -> Result<Self> {
270        // Branch on extension so we can flip the GGUF-specific
271        // visibility knob. Safetensors has no equivalent — it
272        // doesn't isolate MTP tensors at the loader level.
273        if path.ends_with(".gguf") {
274            let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
275            gguf.include_mtp(include_mtp);
276            Self::from_loader(cfg, &mut gguf, device)
277        } else {
278            Self::from_path(cfg, path, device)
279        }
280    }
281
282    /// Replace the token history with `prompt_ids`. Does not run the
283    /// model — the next [`step`] call processes the full sequence.
284    /// Clears any KV cache from a prior generation.
285    pub fn prefill(&mut self, prompt_ids: &[u32]) {
286        self.tokens.clear();
287        self.tokens.extend_from_slice(prompt_ids);
288        self.cache = None;
289    }
290
291    /// Run one prefill over the current token history and sample the
292    /// next token. The sampled token is appended to the history and
293    /// returned. Call repeatedly to generate.
294    pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
295        if self.tokens.is_empty() {
296            anyhow::bail!("step() called with empty token history; call prefill() first");
297        }
298        let seq = self.tokens.len();
299        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
300        let (graph, params) = build_llama32_graph_sized_last_logits(
301            &self.cfg, &mut wm, /*batch*/ 1, seq, /*with_kv_outputs*/ false,
302        )?;
303        let session = Session::new(self.device);
304        let mut compiled = self.compile_graph_profiled(&session, graph)?;
305        for (name, data) in &params {
306            compiled.set_param(name, data);
307        }
308        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
309        let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
310        let logits = outputs
311            .into_iter()
312            .next()
313            .context("compiled.run returned no outputs")?;
314
315        let vocab = self.cfg.vocab_size;
316        let expected = vocab;
317        if logits.len() < expected {
318            anyhow::bail!(
319                "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
320                logits.len(),
321                expected
322            );
323        }
324        // Last-logits graph returns [B=1, 1, vocab].
325        let last_row = &logits[..vocab];
326        let tok = sample_token(last_row, opts) as u32;
327        self.tokens.push(tok);
328        Ok(tok)
329    }
330
331    /// Run `n` steps and return the newly generated token ids
332    /// (excludes the prefill prompt).
333    pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
334        let start = self.tokens.len();
335        for _ in 0..n {
336            self.step(opts)?;
337        }
338        Ok(self.tokens[start..].to_vec())
339    }
340
341    /// Cached step: O(L) per token instead of O(L²). First call seeds
342    /// the KV cache from the prompt via prefill-with-cache; subsequent
343    /// calls run the decode-mode graph on just the last token + cached
344    /// past. Output is bit-identical to [`step`] modulo reduction
345    /// order in the SDPA kernel.
346    ///
347    /// Invariant after each call: `cache.past_seq == tokens.len() - 1`
348    /// (the just-sampled token is appended but not yet in the cache;
349    /// it becomes the input for the next decode step).
350    pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
351        if self.tokens.is_empty() {
352            anyhow::bail!("step_cached() called with empty token history; call prefill() first");
353        }
354        if self.cache.is_none() {
355            // The seed runs prefill, populates the cache, samples from
356            // the last position, and appends the token. Return that
357            // token directly — no decode step on this call.
358            let tok = self.seed_cache_from_prompt(opts)?;
359            return Ok(tok);
360        }
361        let cache = self.cache.as_ref().unwrap();
362        let past_seq = cache.past_seq;
363        // The token we feed into decode is whatever's after the cached
364        // prefix in `self.tokens`. After a prior cached step this is
365        // the just-sampled token; after seeding it's the same.
366        if self.tokens.len() <= past_seq {
367            anyhow::bail!(
368                "cache invariant violated: tokens.len() {} <= past_seq {}",
369                self.tokens.len(),
370                past_seq
371            );
372        }
373        let input_tok = self.tokens[past_seq];
374
375        // Branch: bucketed compile cache vs one-shot compile per step.
376        let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
377            self.decode_step_dynamic(past_seq, input_tok)?
378        } else if self.decode_compile_cache.is_some()
379            && self
380                .decode_compile_cache
381                .as_ref()
382                .unwrap()
383                .bucket_for(past_seq as u64)
384                .is_some()
385        {
386            self.decode_step_bucketed(past_seq, input_tok)?
387        } else {
388            self.decode_step_oneshot(past_seq, input_tok)?
389        };
390
391        let cache_mut = self.cache.as_mut().unwrap();
392        cache_mut.past_seq = past_seq + 1;
393        cache_mut.layers_k = new_k;
394        cache_mut.layers_v = new_v;
395
396        let vocab = self.cfg.vocab_size;
397        if logits.len() != vocab {
398            anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
399        }
400        let tok = sample_token(&logits, opts) as u32;
401        self.tokens.push(tok);
402        Ok(tok)
403    }
404
405    /// Decode path that compiles a fresh graph for the exact `past_seq`
406    /// every call. Slower but always-correct fallback.
407    #[allow(clippy::type_complexity)]
408    fn decode_step_oneshot(
409        &mut self,
410        past_seq: usize,
411        input_tok: u32,
412    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
413        let cache = self.cache.as_ref().unwrap();
414
415        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
416        let (hir, params) =
417            build_llama32_decode_hir_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
418        let session = Session::new(self.device);
419        let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
420        for (name, data) in &params {
421            compiled.set_param(name, data);
422        }
423
424        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
425        let input_ids_f32 = [input_tok as f32];
426        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
427            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
428            .collect();
429        let mut inputs: Vec<(&str, &[f32])> =
430            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
431        inputs.push(("input_ids", input_ids_f32.as_slice()));
432        inputs.push(("rope_cos", cos.as_slice()));
433        inputs.push(("rope_sin", sin.as_slice()));
434        for i in 0..self.cfg.num_hidden_layers {
435            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
436            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
437        }
438
439        let outputs = compiled.run(&inputs);
440        self.split_decode_outputs(outputs)
441    }
442
443    #[allow(clippy::type_complexity)]
444    fn decode_step_dynamic(
445        &mut self,
446        past_seq: usize,
447        input_tok: u32,
448    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
449        let cache = self.cache.as_ref().unwrap();
450        let binding = DimBinding::batch_past_seq(1, past_seq);
451        let opts = self
452            .profile_compile_options(true)
453            .dim_binding(binding.clone());
454        let cache_dyn = self
455            .decode_dynamic_cache
456            .as_mut()
457            .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
458        let needs_upload = !cache_dyn.contains(past_seq as u64);
459        let cfg = self.cfg.clone();
460        let weights_cache = self.weights_cache.clone();
461        let max_past = self.cfg.max_position_embeddings;
462        let compiled = cache_dyn.get_or_specialize(
463            past_seq as u64,
464            &binding,
465            || {
466                let mut wm = WeightMap::from_tensors(weights_cache);
467                build_llama32_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
468                    .expect("dynamic decode HIR")
469                    .0
470            },
471            &opts,
472        )?;
473        if needs_upload {
474            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
475            let (_, params) =
476                build_llama32_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
477            for (name, data) in &params {
478                compiled.set_param(name, data);
479            }
480        }
481
482        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
483        let input_ids_f32 = [input_tok as f32];
484        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
485            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
486            .collect();
487        let mut inputs: Vec<(&str, &[f32])> =
488            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
489        inputs.push(("input_ids", input_ids_f32.as_slice()));
490        inputs.push(("rope_cos", cos.as_slice()));
491        inputs.push(("rope_sin", sin.as_slice()));
492        for i in 0..self.cfg.num_hidden_layers {
493            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
494            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
495        }
496        let outputs = compiled.run(&inputs);
497        self.split_decode_outputs(outputs)
498    }
499
500    /// Decode path using the bucketed compile cache. Compiles one graph
501    /// per bucket (instead of per `past_seq`), pads `past_k`/`past_v` to
502    /// the bucket's upper bound, and uses a custom mask to zero out the
503    /// padded K positions in attention. After running, slices the
504    /// `new_k`/`new_v` outputs back to `actual_past + 1` length so the
505    /// stored cache stays compact.
506    #[allow(clippy::type_complexity)]
507    fn decode_step_bucketed(
508        &mut self,
509        past_seq: usize,
510        input_tok: u32,
511    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
512        let cache_dec = self.decode_compile_cache.as_ref().unwrap();
513        let bucket_idx = cache_dec
514            .bucket_for(past_seq as u64)
515            .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside any bucket"))?;
516        let upper = cache_dec
517            .buckets()
518            .nth(bucket_idx)
519            .map(|r| r.end - 1)
520            .unwrap() as usize;
521
522        let kv_dim = self.cfg.kv_proj_dim();
523        let n_layers = self.cfg.num_hidden_layers;
524
525        // First-time-in-bucket: build the graph + compile + attach
526        // params, then mark the bucket as loaded. Subsequent calls skip
527        // all of this and just .run() the cached graph.
528        let needs_load = !self.decode_loaded_buckets.contains(&bucket_idx);
529        if needs_load {
530            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
531            let (hir, params) = build_llama32_decode_hir_sized_ext(
532                &self.cfg, &mut wm, /*batch*/ 1, upper, /*use_custom_mask*/ true,
533            )?;
534            {
535                let decode_opts = self.profile_compile_options(true);
536                let cache_mut = self.decode_compile_cache.as_mut().unwrap();
537                let (_u, compiled) = cache_mut
538                    .get_or_compile_hir_with_options(past_seq as u64, |_upper| hir, &decode_opts)
539                    .expect("bucket must exist; we just looked it up");
540                for (name, data) in &params {
541                    compiled.set_param(name, data);
542                }
543            }
544            self.decode_loaded_buckets.insert(bucket_idx);
545        }
546
547        // Prepare host-side inputs.
548        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
549        let input_ids_f32 = [input_tok as f32];
550
551        // Mask: shape [1, upper + 1]. 1.0 at positions 0..(past_seq + 1),
552        // 0.0 at (past_seq + 1)..(upper + 1). Without the mask the padded
553        // zero rows would still steal softmax weight (e^0 = 1 per pad
554        // position) and silently scale the output down.
555        let mask_len = upper + 1;
556        let mut mask = vec![0.0f32; mask_len];
557        for v in mask.iter_mut().take(past_seq + 1) {
558            *v = 1.0;
559        }
560
561        // Pad past_k / past_v to length `upper`.
562        let padded_k: Vec<Vec<f32>> = (0..n_layers)
563            .map(|i| {
564                let src = &self.cache.as_ref().unwrap().layers_k[i];
565                let mut out = vec![0f32; upper * kv_dim];
566                out[..src.len()].copy_from_slice(src);
567                out
568            })
569            .collect();
570        let padded_v: Vec<Vec<f32>> = (0..n_layers)
571            .map(|i| {
572                let src = &self.cache.as_ref().unwrap().layers_v[i];
573                let mut out = vec![0f32; upper * kv_dim];
574                out[..src.len()].copy_from_slice(src);
575                out
576            })
577            .collect();
578
579        let key_strs: Vec<String> = (0..n_layers)
580            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
581            .collect();
582        let mut inputs: Vec<(&str, &[f32])> = Vec::with_capacity(4 + 2 * n_layers);
583        inputs.push(("input_ids", input_ids_f32.as_slice()));
584        inputs.push(("rope_cos", cos.as_slice()));
585        inputs.push(("rope_sin", sin.as_slice()));
586        inputs.push(("mask", mask.as_slice()));
587        for i in 0..n_layers {
588            inputs.push((&key_strs[2 * i], padded_k[i].as_slice()));
589            inputs.push((&key_strs[2 * i + 1], padded_v[i].as_slice()));
590        }
591
592        let cache_mut = self.decode_compile_cache.as_mut().unwrap();
593        let (_u, compiled) = cache_mut
594            .get_or_compile_hir(past_seq as u64, |_| {
595                unreachable!("bucket was just loaded above")
596            })
597            .unwrap();
598        let raw_outputs = compiled.run(&inputs);
599
600        // The graph emits new_k/new_v at length `upper + 1` (padded
601        // past + the new token). Slice each back to `past_seq + 1` so
602        // the stored cache only holds real positions.
603        let mut iter = raw_outputs.into_iter();
604        let logits = iter.next().context("bucketed decode logits missing")?;
605        let real_len = (past_seq + 1) * kv_dim;
606        let mut new_k = Vec::with_capacity(n_layers);
607        let mut new_v = Vec::with_capacity(n_layers);
608        for _ in 0..n_layers {
609            let k = iter.next().context("bucketed k missing")?;
610            let v = iter.next().context("bucketed v missing")?;
611            new_k.push(k[..real_len].to_vec());
612            new_v.push(v[..real_len].to_vec());
613        }
614        Ok((logits, new_k, new_v))
615    }
616
617    /// Run prefill-with-cache and return the raw outputs. Uses the
618    /// LRU `CompileCache` when enabled; otherwise compiles fresh each
619    /// call. Keyed by `seq` because graph shape is seq-specialized.
620    fn run_prefill_with_cache(
621        &mut self,
622        batch: usize,
623        seq: usize,
624        ids_f32: &[f32],
625    ) -> Result<Vec<Vec<f32>>> {
626        let dynamic_prefill = self.prefill_dynamic_cache.is_some().then(|| {
627            let binding = DimBinding::batch_seq(batch, seq);
628            let opts = self
629                .profile_compile_options(false)
630                .dim_binding(binding.clone());
631            (binding, opts)
632        });
633        if let (Some(cache), Some((binding, opts))) = (
634            self.prefill_dynamic_cache.as_mut(),
635            dynamic_prefill.as_ref(),
636        ) {
637            let needs_upload = !cache.contains(seq as u64);
638            let cfg = self.cfg.clone();
639            let weights_cache = self.weights_cache.clone();
640            let max_seq = self.cfg.max_position_embeddings;
641            let compiled = cache.get_or_specialize(
642                seq as u64,
643                binding,
644                || {
645                    let mut wm = WeightMap::from_tensors(weights_cache);
646                    build_llama32_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
647                        .expect("dynamic prefill HIR")
648                        .0
649                },
650                opts,
651            )?;
652            if needs_upload {
653                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
654                let (_, params) = build_llama32_prefill_hir_dynamic_ext(
655                    &self.cfg, &mut wm, batch, max_seq, true,
656                )?;
657                for (name, data) in &params {
658                    compiled.set_param(name, data);
659                }
660            }
661            let last_idx = vec![(seq - 1) as f32];
662            Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
663        } else if let Some(prefill_cache) = self.prefill_compile_cache.as_mut() {
664            let key = ((batch as u64) << 32) | (seq as u64);
665            if !prefill_cache.contains(key) {
666                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
667                let (graph, params) = build_llama32_graph_sized_last_logits(
668                    &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
669                )?;
670                {
671                    let compiled = prefill_cache.get_or_compile(key, || graph);
672                    for (name, data) in &params {
673                        compiled.set_param(name, data);
674                    }
675                }
676            }
677            let compiled =
678                prefill_cache.get_or_compile(key, || unreachable!("just populated above"));
679            Ok(compiled.run(&[("input_ids", ids_f32)]))
680        } else {
681            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
682            let (graph, params) = build_llama32_graph_sized_last_logits(
683                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
684            )?;
685            let session = Session::new(self.device);
686            let mut compiled = self.compile_graph_profiled(&session, graph)?;
687            for (name, data) in &params {
688                compiled.set_param(name, data);
689            }
690            Ok(compiled.run(&[("input_ids", ids_f32)]))
691        }
692    }
693
694    /// Split raw graph outputs (logits + per-layer K + per-layer V) into
695    /// (logits, layers_k, layers_v) for the one-shot decode path. The
696    /// bucketed path needs slicing too, so it doesn't reuse this.
697    #[allow(clippy::type_complexity)]
698    fn split_decode_outputs(
699        &self,
700        outputs: Vec<Vec<f32>>,
701    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
702        let n_layers = self.cfg.num_hidden_layers;
703        if outputs.len() != 1 + 2 * n_layers {
704            anyhow::bail!(
705                "decode graph produced {} outputs, expected {}",
706                outputs.len(),
707                1 + 2 * n_layers
708            );
709        }
710        let mut iter = outputs.into_iter();
711        let logits = iter.next().context("decode logits missing")?;
712        let mut layers_k = Vec::with_capacity(n_layers);
713        let mut layers_v = Vec::with_capacity(n_layers);
714        for _ in 0..n_layers {
715            layers_k.push(iter.next().context("decode k missing")?);
716            layers_v.push(iter.next().context("decode v missing")?);
717        }
718        Ok((logits, layers_k, layers_v))
719    }
720
721    /// Run `n` cached steps and return the newly generated tokens.
722    pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
723        self.generate_cached_with(n, opts, |_| {})
724    }
725
726    /// Same as [`generate_cached`] but invokes `on_token` once per
727    /// freshly sampled id, inside the decode loop. The whole `n` step
728    /// loop shares the bucketed compile cache — callers wanting a
729    /// streaming UI should prefer this to calling
730    /// `generate_cached(1, …)` `n` times (which forces a fresh
731    /// compile per token at the bucket boundaries).
732    pub fn generate_cached_with(
733        &mut self,
734        n: usize,
735        opts: SampleOpts,
736        mut on_token: impl FnMut(u32),
737    ) -> Result<Vec<u32>> {
738        let start = self.tokens.len();
739        for _ in 0..n {
740            let tok = self.step_cached(opts)?;
741            on_token(tok);
742        }
743        Ok(self.tokens[start..].to_vec())
744    }
745
746    /// Run prefill-with-cache on the current `self.tokens` (the
747    /// prompt), populate `self.cache`, sample the next token from the
748    /// last position's logits, and append it. Returns the sampled
749    /// token. Invariant after: `cache.past_seq == tokens.len() - 1`.
750    fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
751        let seq = self.tokens.len();
752        let batch = 1usize;
753        let kv_dim = self.cfg.kv_proj_dim();
754
755        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
756        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
757        if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
758            anyhow::bail!(
759                "prefill-with-cache produced {} outputs, expected {}",
760                outputs.len(),
761                1 + 2 * self.cfg.num_hidden_layers
762            );
763        }
764        let expected_kv_len = batch * seq * kv_dim;
765        let mut iter = outputs.into_iter();
766        let logits = iter.next().context("prefill logits missing")?;
767        let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
768        let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
769        for layer in 0..self.cfg.num_hidden_layers {
770            let k = iter.next().context("prefill k missing")?;
771            let v = iter.next().context("prefill v missing")?;
772            if k.len() != expected_kv_len || v.len() != expected_kv_len {
773                anyhow::bail!(
774                    "layer {layer}: k.len={} v.len={} expected {}",
775                    k.len(),
776                    v.len(),
777                    expected_kv_len
778                );
779            }
780            layers_k.push(k);
781            layers_v.push(v);
782        }
783        self.cache = Some(KvCacheState {
784            past_seq: seq,
785            layers_k,
786            layers_v,
787        });
788
789        let vocab = self.cfg.vocab_size;
790        let needed = vocab;
791        if logits.len() < needed {
792            anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
793        }
794        let last_row = &logits[..vocab];
795        let tok = sample_token(last_row, opts) as u32;
796        self.tokens.push(tok);
797        Ok(tok)
798    }
799
800    /// Full token history (prompt + generated).
801    pub fn tokens(&self) -> &[u32] {
802        &self.tokens
803    }
804
805    pub fn config(&self) -> &Llama32Config {
806        &self.cfg
807    }
808
809    /// Low-level primitive: reset internal state, run prefill-with-cache
810    /// over `context`, and return the *last position's* logits row
811    /// (`P(next_token | context)`). Does NOT sample or append. The
812    /// internal `tokens` buffer is set to `context` and the KV cache
813    /// is populated to `past_seq = context.len()`.
814    ///
815    /// First row of logits after prefill-with-cache (no sampling).
816    pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
817        if context.is_empty() {
818            anyhow::bail!("prefill_get_last_logits: empty context");
819        }
820        self.tokens.clear();
821        self.tokens.extend_from_slice(context);
822        self.cache = None;
823
824        let seq = context.len();
825        let batch = 1usize;
826        let kv_dim = self.cfg.kv_proj_dim();
827
828        let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
829        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
830        if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
831            anyhow::bail!(
832                "prefill_get_last_logits: got {} outputs, expected {}",
833                outputs.len(),
834                1 + 2 * self.cfg.num_hidden_layers
835            );
836        }
837        let expected_kv_len = batch * seq * kv_dim;
838        let mut iter = outputs.into_iter();
839        let logits = iter.next().context("logits missing")?;
840        let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
841        let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
842        for _ in 0..self.cfg.num_hidden_layers {
843            let k = iter.next().context("k missing")?;
844            let v = iter.next().context("v missing")?;
845            if k.len() != expected_kv_len || v.len() != expected_kv_len {
846                anyhow::bail!("kv length mismatch in prefill_get_last_logits");
847            }
848            layers_k.push(k);
849            layers_v.push(v);
850        }
851        self.cache = Some(KvCacheState {
852            past_seq: seq,
853            layers_k,
854            layers_v,
855        });
856
857        let vocab = self.cfg.vocab_size;
858        let needed = vocab;
859        if logits.len() < needed {
860            anyhow::bail!("logits short: {} < {}", logits.len(), needed);
861        }
862        Ok(logits[..vocab].to_vec())
863    }
864
865    /// Low-level primitive: run one decode step with the caller-
866    /// supplied input token (no sampling), advance the KV cache, and
867    /// return the resulting logits row `P(next | history ++ input)`.
868    /// Appends `input` to the `tokens` buffer so the invariant
869    /// `cache.past_seq == tokens.len()` holds after this call (note:
870    /// differs from `step_cached` invariant because this method does
871    /// not append a sampled token).
872    pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
873        let cache = self.cache.as_ref().ok_or_else(|| {
874            anyhow::anyhow!(
875                "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
876            )
877        })?;
878        let past_seq = cache.past_seq;
879
880        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
881        let (hir, params) =
882            build_llama32_decode_hir_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
883        let session = Session::new(self.device);
884        let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
885        for (name, data) in &params {
886            compiled.set_param(name, data);
887        }
888
889        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
890        let input_ids_f32 = [input as f32];
891        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
892            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
893            .collect();
894        let mut inputs: Vec<(&str, &[f32])> =
895            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
896        inputs.push(("input_ids", input_ids_f32.as_slice()));
897        inputs.push(("rope_cos", cos.as_slice()));
898        inputs.push(("rope_sin", sin.as_slice()));
899        for i in 0..self.cfg.num_hidden_layers {
900            let pk = &cache.layers_k[i];
901            let pv = &cache.layers_v[i];
902            inputs.push((&key_strs[2 * i], pk.as_slice()));
903            inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
904        }
905
906        let outputs = compiled.run(&inputs);
907        if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
908            anyhow::bail!(
909                "decode_get_logits: got {} outputs, expected {}",
910                outputs.len(),
911                1 + 2 * self.cfg.num_hidden_layers
912            );
913        }
914        let mut iter = outputs.into_iter();
915        let logits = iter.next().context("logits missing")?;
916        let mut new_k = Vec::with_capacity(self.cfg.num_hidden_layers);
917        let mut new_v = Vec::with_capacity(self.cfg.num_hidden_layers);
918        for _ in 0..self.cfg.num_hidden_layers {
919            new_k.push(iter.next().context("k missing")?);
920            new_v.push(iter.next().context("v missing")?);
921        }
922
923        let cache_mut = self.cache.as_mut().unwrap();
924        cache_mut.past_seq = past_seq + 1;
925        cache_mut.layers_k = new_k;
926        cache_mut.layers_v = new_v;
927        self.tokens.push(input);
928
929        Ok(logits)
930    }
931}
932
933/// Compute the single-row (cos, sin) RoPE slice for absolute position
934/// `pos`. Matches the formula in the prefill builder so cached decode
935/// and recompute prefill produce the same RoPE rotation.
936fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
937    rope_slice(inv_freq, pos)
938}
939
// Unit tests run the generator end-to-end on CPU against a tiny,
// deterministic synthetic checkpoint (`tiny_cfg` + `synthetic_weights`).
// Most are parity checks: each optional compile-cache / dynamic-shape path
// must produce exactly the same greedy tokens as the reference path.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Llama32Config;

    /// Smallest config that still has more than one decoder layer and
    /// fewer KV heads (2) than query heads (4).
    fn tiny_cfg() -> Llama32Config {
        Llama32Config {
            vocab_size: 16,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 500_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            head_dim: Some(8),
            rope_scaling: None,
        }
    }

    /// Deterministic weights for every tensor name the builder loads for
    /// `cfg`. Repeated calls with the same `cfg` yield identical maps.
    fn synthetic_weights(cfg: &Llama32Config) -> WeightMap {
        let h = cfg.hidden_size;
        let q_dim = cfg.q_proj_dim();
        let kv_dim = cfg.kv_proj_dim();
        let int_dim = cfg.intermediate_size;
        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        // Use a deterministic non-zero pattern so logits aren't all 0
        // (sampling on an all-zero row is undefined order).
        let pat = |n: usize, salt: u32| -> Vec<f32> {
            (0..n)
                .map(|i| {
                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
                    (x as f32 / (1u32 << 24) as f32) - 0.5
                })
                .collect()
        };
        t.insert(
            "model.embed_tokens.weight".into(),
            (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
        );
        for i in 0..cfg.num_hidden_layers {
            let lp = format!("model.layers.{i}");
            t.insert(
                format!("{lp}.input_layernorm.weight"),
                (pat(h, 100 + i as u32), vec![h]),
            );
            t.insert(
                format!("{lp}.post_attention_layernorm.weight"),
                (pat(h, 200 + i as u32), vec![h]),
            );
            t.insert(
                format!("{lp}.self_attn.q_proj.weight"),
                (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.k_proj.weight"),
                (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.v_proj.weight"),
                (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.o_proj.weight"),
                (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
            );
            t.insert(
                format!("{lp}.mlp.gate_proj.weight"),
                (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
            );
            t.insert(
                format!("{lp}.mlp.up_proj.weight"),
                (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
            );
            t.insert(
                format!("{lp}.mlp.down_proj.weight"),
                (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
            );
        }
        t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
        t.insert(
            "lm_head.weight".into(),
            (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
        );
        WeightMap::from_tensors(t)
    }

    #[test]
    fn generator_drains_loader_and_runs_one_step() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        assert_eq!(wm.len(), 0, "loader should be drained");
        gn.prefill(&[1, 2, 3]);
        let t = gn.step(SampleOpts::greedy()).unwrap();
        assert!((t as usize) < cfg.vocab_size);
        assert_eq!(gn.tokens().len(), 4);
    }

    #[test]
    fn generate_n_appends_n_tokens() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill(&[5, 6]);
        let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
        assert_eq!(new_tokens.len(), 3);
        assert_eq!(gn.tokens().len(), 5);
        for t in &new_tokens {
            assert!((*t as usize) < cfg.vocab_size);
        }
    }

    #[test]
    fn step_without_prefill_errors() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
        let r = gn.step(SampleOpts::greedy());
        assert!(r.is_err());
    }

    #[test]
    fn cached_matches_naive_on_greedy() {
        // The cached and naive paths must produce the same token
        // sequence given the same prompt + opts. This is the
        // load-bearing test for the KV-cache implementation: if the
        // decode-mode graph, the kernel's Lq!=Lk fix, the cache
        // wiring, or the RoPE position-slice is wrong, the sequences
        // diverge here.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut wm_n = synthetic_weights(&cfg);
        let mut gn_naive =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
        gn_naive.prefill(&prompt);
        let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();

        let mut wm_c = synthetic_weights(&cfg);
        let mut gn_cached =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
        gn_cached.prefill(&prompt);
        let cached_tokens = gn_cached
            .generate_cached(steps, SampleOpts::greedy())
            .unwrap();

        assert_eq!(
            cached_tokens, naive_tokens,
            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
        );
    }

    #[test]
    fn cached_step_advances_cache_invariant() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill(&[1, 2, 3]);
        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
        // After seed: tokens.len() == 4, cache.past_seq == 3 (cache holds prompt).
        assert_eq!(gn.tokens().len(), 4);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 3);
        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
        // After one decode: tokens.len() == 5, cache.past_seq == 4.
        assert_eq!(gn.tokens().len(), 5);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 4);
    }

    #[test]
    fn low_level_primitives_track_context() {
        // `prefill_get_last_logits` / `decode_get_logits` must keep
        // `cache.past_seq == tokens.len()` and return one vocab-sized row.
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();

        assert!(
            gn.prefill_get_last_logits(&[]).is_err(),
            "empty context must be rejected"
        );

        let logits = gn.prefill_get_last_logits(&[1, 2, 3]).unwrap();
        assert_eq!(logits.len(), cfg.vocab_size);
        assert_eq!(gn.tokens(), &[1u32, 2, 3][..]);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 3);

        let _ = gn.decode_get_logits(4).unwrap();
        assert_eq!(gn.tokens(), &[1u32, 2, 3, 4][..]);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 4);
    }

    #[test]
    fn bucketed_decode_matches_oneshot() {
        // The bucketed compile-cache path (padded K/V + custom mask)
        // must produce the same token sequence as the one-shot
        // path. Load-bearing for the bucketed cache feature: if the
        // mask, padding, or output slicing is wrong, sequences
        // diverge here.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_buc = synthetic_weights(&cfg);
        let mut gn_buc = Llama32Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
            .unwrap()
            .with_decode_cache(/*max_past*/ 32);
        gn_buc.prefill(&prompt);
        let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            bucketed_tokens, oneshot_tokens,
            "bucketed-cache decode diverged from one-shot decode — \
             mask, padding, or output-slice bug"
        );
    }

    #[test]
    fn prefill_compile_cache_does_not_change_output() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let mut wm_a = synthetic_weights(&cfg);
        let mut gn_a = Llama32Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
        gn_a.prefill(&prompt);
        let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();

        let mut wm_b = synthetic_weights(&cfg);
        let mut gn_b = Llama32Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
            .unwrap()
            .with_prefill_cache(/*capacity*/ 4);
        gn_b.prefill(&prompt);
        let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();

        assert_eq!(a, b, "enabling prefill_cache must not change output");
    }

    #[test]
    fn dynamic_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_decode_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic past_seq decode diverged from one-shot decode"
        );
    }

    #[test]
    fn dynamic_prefill_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic seq prefill diverged from one-shot prefill"
        );
    }

    #[test]
    fn dynamic_prefill_and_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(/*capacity*/ 8)
            .with_dynamic_decode_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic prefill+decode diverged from one-shot path"
        );
    }

    #[test]
    fn greedy_is_deterministic_across_runs() {
        // Two independent generators built from identical synthetic
        // weights must produce identical greedy sequences. Each build
        // regenerates the weights from the same `cfg`, since
        // `synthetic_weights` is deterministic and the loader is drained.
        let cfg = tiny_cfg();
        let mk = || {
            let mut wm = synthetic_weights(&cfg);
            Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
        };
        let mut a = mk();
        let mut b = mk();
        a.prefill(&[1, 2, 3]);
        b.prefill(&[1, 2, 3]);
        let ta = a.generate(4, SampleOpts::greedy()).unwrap();
        let tb = b.generate(4, SampleOpts::greedy()).unwrap();
        assert_eq!(ta, tb);
    }
}