Skip to main content

rlx_llama32/
generator.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Host-side generation loop for LLaMA-3.2.
//!
//! Two decoding paths share one generator and one public API:
//!
//! - [`Llama32Generator::step`] is the **naive** path: each call rebuilds
//!   the prefill graph for the full token history and runs it from scratch
//!   (O(N²) compute over N generated tokens).
//! - [`Llama32Generator::step_cached`] is the KV-cache path: the first call
//!   seeds per-layer K/V caches from the token history via
//!   prefill-with-cache, and each later call runs a decode-mode graph over
//!   just the newest token plus the cached past.
//!
//! Why the naive path is kept alongside the cached one:
//!   - It established the public API contract before the IR/kernel
//!     changes that the cached path needs had landed.
//!   - It runs end-to-end generation against a real checkpoint using only
//!     the prefill graph, which validates that graph numerically.
//!   - It is the reference oracle for the cached path's own
//!     numerical-parity test (cached vs recompute must match).
32
33use crate::builder::{
34    build_llama32_decode_hir_dynamic_ext, build_llama32_decode_hir_sized,
35    build_llama32_decode_hir_sized_ext, build_llama32_graph_sized_last_logits,
36    build_llama32_prefill_hir_dynamic_ext,
37};
38use crate::config::Llama32Config;
39use crate::rope::{resolve_inv_freq, rope_slice};
40use anyhow::{Context, Result};
41use rlx_core::flow_bridge::compile_options_from_profile;
42use rlx_core::weight_loader::WeightLoader;
43use rlx_core::weight_map::WeightMap;
44use rlx_flow::CompileProfile;
45use rlx_ir::DimBinding;
46use rlx_ir::logical_kernel::KernelDispatchConfig;
47use rlx_qwen3::sampling::{SampleOpts, sample_token};
48use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache, DynamicDimCompileCache};
49use rlx_runtime::{CompileOptions, Device, Session};
50use std::collections::{HashMap, HashSet};
51use std::path::Path;
52
53/// MPSGraph on Metal has two known gaps in `rlx-metal` 0.2.1:
54///
55/// 1. **Decode attention** — assumes `seq_q == seq_kv`; KV-cache decode breaks.
56/// 2. **Packed GGUF `DequantMatMul`** — lowers U8 quant bytes as F32 constants
57///    (`w_bytes.len() == k*n*4`), destroying parity vs CPU thunks.
58///
59/// Force the Metal thunk path for affected compiles until `rlx-metal` ≥ 0.2.2.
60fn metal_thunk_compile_guard<R, F>(device: Device, f: F) -> R
61where
62    F: FnOnce() -> R,
63{
64    if device == Device::Metal {
65        rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
66        let out = f();
67        rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
68        out
69    } else {
70        f()
71    }
72}
73
74fn metal_decode_compile_guard<R, F>(device: Device, decode: bool, f: F) -> R
75where
76    F: FnOnce() -> R,
77{
78    if decode {
79        metal_thunk_compile_guard(device, f)
80    } else {
81        f()
82    }
83}
84
/// KV cache carried between incremental decode steps.
///
/// Each per-layer `Vec<f32>` is a flat `[batch, past_seq, kv_proj_dim]`
/// tensor covering the first `past_seq` positions of the token history.
#[derive(Clone)]
struct KvCacheState {
    /// Cached keys, one flat buffer per transformer layer.
    layers_k: Vec<Vec<f32>>,
    /// Cached values, one flat buffer per transformer layer.
    layers_v: Vec<Vec<f32>>,
    /// Number of token positions currently held in every layer buffer.
    past_seq: usize,
}
93
94/// Stateful LLaMA-3.2 generation handle.
95///
96/// Holds the (config, weight bytes, token history) and rebuilds a
97/// prefill graph on each [`step`] call. Cheap to construct after
98/// initial weight load; tokens stay in-memory between calls.
99pub struct Llama32Generator {
100    cfg: Llama32Config,
101    /// Map of weight key → (f32 data, shape). Cloned on each step
102    /// into a fresh `WeightMap` because `WeightMap::take` is
103    /// destructive — see the cached-generator notes for the path
104    /// that avoids the clone.
105    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
106    tokens: Vec<u32>,
107    device: Device,
108    /// Populated lazily on the first `step_cached` call (seeded from
109    /// the prompt via prefill-with-cache); thereafter advanced by each
110    /// decode step.
111    cache: Option<KvCacheState>,
112    /// Per-key LRU compile cache for prefill graphs. Keyed by `seq`.
113    /// Set to `None` to disable (default for new instances; opt in via
114    /// [`Llama32Generator::with_prefill_cache`]).
115    prefill_compile_cache: Option<CompileCache>,
116    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
117    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
118    /// Bucketed compile cache for decode-mode graphs. Each bucket
119    /// holds one compiled graph specialized at its upper-bound
120    /// `past_seq`; the host pads `past_k`/`past_v` and supplies a
121    /// per-step mask so a single bucket serves every `past_seq` in
122    /// its range. Opt in via [`Llama32Generator::with_decode_cache`].
123    decode_compile_cache: Option<BucketedCompileCache>,
124    decode_dynamic_cache: Option<DynamicDimCompileCache>,
125    /// Tracks which decode buckets have had params attached. The
126    /// `BucketedCompileCache` API doesn't expose per-bucket compile
127    /// status, so we maintain it here to avoid double-loading params.
128    decode_loaded_buckets: HashSet<usize>,
129    /// Upper bound for dynamic HIR `SEQ` / `PAST_SEQ` (defaults to
130    /// [`Llama32Config::max_position_embeddings`]).
131    compile_seq_cap: Option<usize>,
132    /// Resolved RoPE inverse frequencies (includes Llama 3 scaling).
133    inv_freq: Vec<f64>,
134    /// Tier-1 compile profile for prefill graphs.
135    prefill_profile: CompileProfile,
136    /// Tier-1 compile profile for decode graphs.
137    decode_profile: CompileProfile,
138}
139
140impl Llama32Generator {
141    /// Construct from any [`WeightLoader`] — drains it into an
142    /// internal cache so the loader is free after this call.
143    pub fn from_loader(
144        cfg: Llama32Config,
145        loader: &mut dyn WeightLoader,
146        device: Device,
147    ) -> Result<Self> {
148        let keys = loader.remaining_keys();
149        let mut weights_cache = HashMap::with_capacity(keys.len());
150        for k in keys {
151            let v = loader
152                .take(&k)
153                .with_context(|| format!("draining weight {k}"))?;
154            // Normalize the cache key to the safetensors / HuggingFace
155            // naming convention so subsequent builder calls that ask
156            // for `model.embed_tokens.weight` (the canonical name baked
157            // into the llama32 builder) hit the cache whether the
158            // loader was safetensors-native or GGUF-native.
159            let canonical =
160                rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
161            weights_cache.insert(canonical, v);
162        }
163        let rope_factors = weights_cache
164            .get("rope_freqs.weight")
165            .map(|(d, _)| d.as_slice());
166        let inv_freq = resolve_inv_freq(&cfg, rope_factors);
167        Ok(Self {
168            cfg,
169            weights_cache,
170            tokens: Vec::new(),
171            device,
172            cache: None,
173            prefill_compile_cache: None,
174            prefill_dynamic_cache: None,
175            decode_compile_cache: None,
176            decode_dynamic_cache: None,
177            decode_loaded_buckets: HashSet::new(),
178            compile_seq_cap: None,
179            inv_freq,
180            prefill_profile: CompileProfile::llama32_prefill(),
181            decode_profile: CompileProfile::llama32_decode(),
182        })
183    }
184
185    fn compile_seq_cap(&self) -> usize {
186        self.compile_seq_cap
187            .unwrap_or(self.cfg.max_position_embeddings)
188    }
189
190    /// Cap symbolic seq / past-seq in dynamic compile paths. Use for models
191    /// with very large `max_position_embeddings` (e.g. 128k) when the runner
192    /// only needs a short window.
193    pub fn with_compile_seq_cap(mut self, cap: usize) -> Self {
194        self.compile_seq_cap = Some(cap.max(1));
195        self
196    }
197
198    /// Like [`Self::from_loader`] but loads tier-1 profiles from
199    /// `llama32.rlx.toml` in the weights directory when present.
200    pub fn from_loader_at(
201        cfg: Llama32Config,
202        loader: &mut dyn WeightLoader,
203        device: Device,
204        weights_path: &Path,
205    ) -> Result<Self> {
206        let mut g = Self::from_loader(cfg, loader, device)?;
207        g.prefill_profile = crate::llama32_profile_near_weights(weights_path, false);
208        g.decode_profile = crate::llama32_profile_near_weights(weights_path, true);
209        Ok(g)
210    }
211
212    /// Override tier-1 compile profiles explicitly.
213    pub fn with_compile_profiles(
214        mut self,
215        prefill: CompileProfile,
216        decode: CompileProfile,
217    ) -> Self {
218        self.prefill_profile = prefill;
219        self.decode_profile = decode;
220        self
221    }
222
223    pub fn prefill_profile(&self) -> &CompileProfile {
224        &self.prefill_profile
225    }
226
227    pub fn decode_profile(&self) -> &CompileProfile {
228        &self.decode_profile
229    }
230
231    fn profile_compile_options(&self, decode: bool) -> CompileOptions {
232        let profile = if decode {
233            &self.decode_profile
234        } else {
235            &self.prefill_profile
236        };
237        compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
238    }
239
240    fn compile_hir_profiled(
241        &self,
242        session: &Session,
243        hir: rlx_ir::hir::HirModule,
244        decode: bool,
245    ) -> Result<rlx_runtime::CompiledGraph> {
246        let opts = self.profile_compile_options(decode);
247        Ok(metal_decode_compile_guard(self.device, decode, || {
248            session.compile_hir_with(hir, &opts)
249        })?)
250    }
251
252    fn compile_graph_profiled(
253        &self,
254        session: &Session,
255        graph: rlx_ir::Graph,
256    ) -> Result<rlx_runtime::CompiledGraph> {
257        let opts = self.profile_compile_options(false);
258        Ok(session.compile_with(graph, &opts))
259    }
260
261    /// Enable the prefill compile cache with the given LRU capacity.
262    /// Useful when the same prompt length is used across multiple
263    /// generation runs — the second + Nth run skip the compile +
264    /// param-attach roundtrip (~30-50ms per call on CPU).
265    pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
266        self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
267        self.prefill_dynamic_cache = None;
268        self
269    }
270
271    /// Compile prefill once with `sym::SEQ`, specialize per prompt length.
272    pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
273        self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
274        self.prefill_compile_cache = None;
275        self
276    }
277
278    /// Enable the bucketed decode compile cache spanning past-seq
279    /// values in `[1, max_past]`. Buckets are power-of-two
280    /// `[1..2, 2..3, 3..5, 5..9, 9..17, …]`. Each bucket compiles
281    /// one graph at its upper bound; a steady-state generation loop
282    /// across `N` tokens compiles `O(log N)` graphs instead of `N`.
283    ///
284    /// Padding compute waste is bounded at 2×: actual `past_seq` is
285    /// at least half the bucket's upper bound (except possibly the
286    /// smallest bucket).
287    pub fn with_decode_cache(mut self, max_past: usize) -> Self {
288        let cache = BucketedCompileCache::power_of_two_ladder(
289            self.device,
290            /*min*/ 1,
291            max_past.max(1) as u64,
292        );
293        self.decode_compile_cache = Some(cache);
294        self.decode_dynamic_cache = None;
295        self.decode_loaded_buckets.clear();
296        self
297    }
298
299    /// Compile decode once with `sym::PAST_SEQ`, specialize per prefix length.
300    pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
301        self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
302        self.decode_compile_cache = None;
303        self.decode_loaded_buckets.clear();
304        self
305    }
306
307    /// Convenience: load weights from a safetensors or GGUF path
308    /// (dispatch by extension; see `rlx_core::weight_loader::load_from_path`).
309    pub fn from_path(cfg: Llama32Config, path: &str, device: Device) -> Result<Self> {
310        let mut loader = rlx_core::weight_loader::load_from_path(path)?;
311        Self::from_loader(cfg, loader.as_mut(), device)
312    }
313
314    /// Same as [`from_path`] but with MTP-head visibility control.
315    /// When `include_mtp=true` and the file is GGUF, MTP weights are
316    /// drained into the generator's cache alongside the base
317    /// weights. The base inference path still ignores them — they
318    /// sit in cache for a future MTP-aware decoder. Non-GGUF formats
319    /// silently ignore the flag (safetensors files publish all
320    /// tensors uniformly; downstream code distinguishes by name).
321    pub fn from_path_with_mtp(
322        cfg: Llama32Config,
323        path: &str,
324        device: Device,
325        include_mtp: bool,
326    ) -> Result<Self> {
327        // Branch on extension so we can flip the GGUF-specific
328        // visibility knob. Safetensors has no equivalent — it
329        // doesn't isolate MTP tensors at the loader level.
330        if path.ends_with(".gguf") {
331            let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
332            gguf.include_mtp(include_mtp);
333            Self::from_loader(cfg, &mut gguf, device)
334        } else {
335            Self::from_path(cfg, path, device)
336        }
337    }
338
339    /// Replace the token history with `prompt_ids`. Does not run the
340    /// model — the next [`step`] call processes the full sequence.
341    /// Clears any KV cache from a prior generation.
342    pub fn prefill(&mut self, prompt_ids: &[u32]) {
343        self.tokens.clear();
344        self.tokens.extend_from_slice(prompt_ids);
345        self.cache = None;
346    }
347
348    /// Run one prefill over the current token history and sample the
349    /// next token. The sampled token is appended to the history and
350    /// returned. Call repeatedly to generate.
351    pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
352        if self.tokens.is_empty() {
353            anyhow::bail!("step() called with empty token history; call prefill() first");
354        }
355        let seq = self.tokens.len();
356        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
357        let (graph, params) = build_llama32_graph_sized_last_logits(
358            &self.cfg, &mut wm, /*batch*/ 1, seq, /*with_kv_outputs*/ false,
359        )?;
360        let session = Session::new(self.device);
361        let mut compiled = self.compile_graph_profiled(&session, graph)?;
362        for (name, data) in &params {
363            compiled.set_param(name, data);
364        }
365        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
366        let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
367        let logits = outputs
368            .into_iter()
369            .next()
370            .context("compiled.run returned no outputs")?;
371
372        let vocab = self.cfg.vocab_size;
373        let expected = vocab;
374        if logits.len() < expected {
375            anyhow::bail!(
376                "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
377                logits.len(),
378                expected
379            );
380        }
381        // Last-logits graph returns [B=1, 1, vocab].
382        let last_row = &logits[..vocab];
383        let tok = sample_token(last_row, opts) as u32;
384        self.tokens.push(tok);
385        Ok(tok)
386    }
387
388    /// Run `n` steps and return the newly generated token ids
389    /// (excludes the prefill prompt).
390    pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
391        let start = self.tokens.len();
392        for _ in 0..n {
393            self.step(opts)?;
394        }
395        Ok(self.tokens[start..].to_vec())
396    }
397
398    /// Cached step: O(L) per token instead of O(L²). First call seeds
399    /// the KV cache from the prompt via prefill-with-cache; subsequent
400    /// calls run the decode-mode graph on just the last token + cached
401    /// past. Output is bit-identical to [`step`] modulo reduction
402    /// order in the SDPA kernel.
403    ///
404    /// Invariant after each call: `cache.past_seq == tokens.len() - 1`
405    /// (the just-sampled token is appended but not yet in the cache;
406    /// it becomes the input for the next decode step).
407    pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
408        if self.tokens.is_empty() {
409            anyhow::bail!("step_cached() called with empty token history; call prefill() first");
410        }
411        if self.cache.is_none() {
412            // The seed runs prefill, populates the cache, samples from
413            // the last position, and appends the token. Return that
414            // token directly — no decode step on this call.
415            let tok = self.seed_cache_from_prompt(opts)?;
416            return Ok(tok);
417        }
418        let cache = self.cache.as_ref().unwrap();
419        let past_seq = cache.past_seq;
420        // The token we feed into decode is whatever's after the cached
421        // prefix in `self.tokens`. After a prior cached step this is
422        // the just-sampled token; after seeding it's the same.
423        if self.tokens.len() <= past_seq {
424            anyhow::bail!(
425                "cache invariant violated: tokens.len() {} <= past_seq {}",
426                self.tokens.len(),
427                past_seq
428            );
429        }
430        let input_tok = self.tokens[past_seq];
431
432        // Branch: bucketed compile cache vs one-shot compile per step.
433        let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
434            self.decode_step_dynamic(past_seq, input_tok)?
435        } else if self.decode_compile_cache.is_some()
436            && self
437                .decode_compile_cache
438                .as_ref()
439                .unwrap()
440                .bucket_for(past_seq as u64)
441                .is_some()
442        {
443            self.decode_step_bucketed(past_seq, input_tok)?
444        } else {
445            self.decode_step_oneshot(past_seq, input_tok)?
446        };
447
448        let cache_mut = self.cache.as_mut().unwrap();
449        cache_mut.past_seq = past_seq + 1;
450        cache_mut.layers_k = new_k;
451        cache_mut.layers_v = new_v;
452
453        let vocab = self.cfg.vocab_size;
454        if logits.len() != vocab {
455            anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
456        }
457        let tok = sample_token(&logits, opts) as u32;
458        self.tokens.push(tok);
459        Ok(tok)
460    }
461
462    /// Decode path that compiles a fresh graph for the exact `past_seq`
463    /// every call. Slower but always-correct fallback.
464    #[allow(clippy::type_complexity)]
465    fn decode_step_oneshot(
466        &mut self,
467        past_seq: usize,
468        input_tok: u32,
469    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
470        let cache = self.cache.as_ref().unwrap();
471
472        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
473        let (hir, params) =
474            build_llama32_decode_hir_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
475        let session = Session::new(self.device);
476        let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
477        for (name, data) in &params {
478            compiled.set_param(name, data);
479        }
480
481        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
482        let input_ids_f32 = [input_tok as f32];
483        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
484            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
485            .collect();
486        let mut inputs: Vec<(&str, &[f32])> =
487            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
488        inputs.push(("input_ids", input_ids_f32.as_slice()));
489        inputs.push(("rope_cos", cos.as_slice()));
490        inputs.push(("rope_sin", sin.as_slice()));
491        for i in 0..self.cfg.num_hidden_layers {
492            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
493            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
494        }
495
496        let outputs = compiled.run(&inputs);
497        self.split_decode_outputs(outputs)
498    }
499
500    #[allow(clippy::type_complexity)]
501    fn decode_step_dynamic(
502        &mut self,
503        past_seq: usize,
504        input_tok: u32,
505    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
506        let cache = self.cache.as_ref().unwrap();
507        let binding = DimBinding::batch_past_seq(1, past_seq);
508        let opts = self
509            .profile_compile_options(true)
510            .dim_binding(binding.clone());
511        let max_past = self.compile_seq_cap();
512        let cache_dyn = self
513            .decode_dynamic_cache
514            .as_mut()
515            .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
516        let needs_upload = !cache_dyn.contains(past_seq as u64);
517        let cfg = self.cfg.clone();
518        let weights_cache = self.weights_cache.clone();
519        let device = self.device;
520        let compiled = cache_dyn.get_or_specialize(
521            past_seq as u64,
522            &binding,
523            || {
524                metal_decode_compile_guard(device, true, || {
525                    let mut wm = WeightMap::from_tensors(weights_cache);
526                    build_llama32_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
527                        .expect("dynamic decode HIR")
528                        .0
529                })
530            },
531            &opts,
532        )?;
533        if needs_upload {
534            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
535            let (_, params) =
536                build_llama32_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
537            for (name, data) in &params {
538                compiled.set_param(name, data);
539            }
540        }
541
542        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
543        let input_ids_f32 = [input_tok as f32];
544        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
545            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
546            .collect();
547        let mut inputs: Vec<(&str, &[f32])> =
548            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
549        inputs.push(("input_ids", input_ids_f32.as_slice()));
550        inputs.push(("rope_cos", cos.as_slice()));
551        inputs.push(("rope_sin", sin.as_slice()));
552        for i in 0..self.cfg.num_hidden_layers {
553            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
554            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
555        }
556        let outputs = compiled.run(&inputs);
557        self.split_decode_outputs(outputs)
558    }
559
560    /// Decode path using the bucketed compile cache. Compiles one graph
561    /// per bucket (instead of per `past_seq`), pads `past_k`/`past_v` to
562    /// the bucket's upper bound, and uses a custom mask to zero out the
563    /// padded K positions in attention. After running, slices the
564    /// `new_k`/`new_v` outputs back to `actual_past + 1` length so the
565    /// stored cache stays compact.
566    #[allow(clippy::type_complexity)]
567    fn decode_step_bucketed(
568        &mut self,
569        past_seq: usize,
570        input_tok: u32,
571    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
572        let cache_dec = self.decode_compile_cache.as_ref().unwrap();
573        let bucket_idx = cache_dec
574            .bucket_for(past_seq as u64)
575            .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside any bucket"))?;
576        let upper = cache_dec
577            .buckets()
578            .nth(bucket_idx)
579            .map(|r| r.end - 1)
580            .unwrap() as usize;
581
582        let kv_dim = self.cfg.kv_proj_dim();
583        let n_layers = self.cfg.num_hidden_layers;
584
585        // First-time-in-bucket: build the graph + compile + attach
586        // params, then mark the bucket as loaded. Subsequent calls skip
587        // all of this and just .run() the cached graph.
588        let needs_load = !self.decode_loaded_buckets.contains(&bucket_idx);
589        if needs_load {
590            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
591            let (hir, params) = build_llama32_decode_hir_sized_ext(
592                &self.cfg, &mut wm, /*batch*/ 1, upper, /*use_custom_mask*/ true,
593            )?;
594            {
595                let decode_opts = self.profile_compile_options(true);
596                let cache_mut = self.decode_compile_cache.as_mut().unwrap();
597                metal_decode_compile_guard(self.device, true, || {
598                    let (_u, compiled) = cache_mut
599                        .get_or_compile_hir_with_options(
600                            past_seq as u64,
601                            |_upper| hir,
602                            &decode_opts,
603                        )
604                        .expect("bucket must exist; we just looked it up");
605                    for (name, data) in &params {
606                        compiled.set_param(name, data);
607                    }
608                });
609            }
610            self.decode_loaded_buckets.insert(bucket_idx);
611        }
612
613        // Prepare host-side inputs.
614        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
615        let input_ids_f32 = [input_tok as f32];
616
617        // Mask: shape [1, upper + 1]. 1.0 at positions 0..(past_seq + 1),
618        // 0.0 at (past_seq + 1)..(upper + 1). Without the mask the padded
619        // zero rows would still steal softmax weight (e^0 = 1 per pad
620        // position) and silently scale the output down.
621        let mask_len = upper + 1;
622        let mut mask = vec![0.0f32; mask_len];
623        for v in mask.iter_mut().take(past_seq + 1) {
624            *v = 1.0;
625        }
626
627        // Pad past_k / past_v to length `upper`.
628        let padded_k: Vec<Vec<f32>> = (0..n_layers)
629            .map(|i| {
630                let src = &self.cache.as_ref().unwrap().layers_k[i];
631                let mut out = vec![0f32; upper * kv_dim];
632                out[..src.len()].copy_from_slice(src);
633                out
634            })
635            .collect();
636        let padded_v: Vec<Vec<f32>> = (0..n_layers)
637            .map(|i| {
638                let src = &self.cache.as_ref().unwrap().layers_v[i];
639                let mut out = vec![0f32; upper * kv_dim];
640                out[..src.len()].copy_from_slice(src);
641                out
642            })
643            .collect();
644
645        let key_strs: Vec<String> = (0..n_layers)
646            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
647            .collect();
648        let mut inputs: Vec<(&str, &[f32])> = Vec::with_capacity(4 + 2 * n_layers);
649        inputs.push(("input_ids", input_ids_f32.as_slice()));
650        inputs.push(("rope_cos", cos.as_slice()));
651        inputs.push(("rope_sin", sin.as_slice()));
652        inputs.push(("mask", mask.as_slice()));
653        for i in 0..n_layers {
654            inputs.push((&key_strs[2 * i], padded_k[i].as_slice()));
655            inputs.push((&key_strs[2 * i + 1], padded_v[i].as_slice()));
656        }
657
658        let cache_mut = self.decode_compile_cache.as_mut().unwrap();
659        let (_u, compiled) = cache_mut
660            .get_or_compile_hir(past_seq as u64, |_| {
661                unreachable!("bucket was just loaded above")
662            })
663            .unwrap();
664        let raw_outputs = compiled.run(&inputs);
665
666        // The graph emits new_k/new_v at length `upper + 1` (padded
667        // past + the new token). Slice each back to `past_seq + 1` so
668        // the stored cache only holds real positions.
669        let mut iter = raw_outputs.into_iter();
670        let logits = iter.next().context("bucketed decode logits missing")?;
671        let real_len = (past_seq + 1) * kv_dim;
672        let mut new_k = Vec::with_capacity(n_layers);
673        let mut new_v = Vec::with_capacity(n_layers);
674        for _ in 0..n_layers {
675            let k = iter.next().context("bucketed k missing")?;
676            let v = iter.next().context("bucketed v missing")?;
677            new_k.push(k[..real_len].to_vec());
678            new_v.push(v[..real_len].to_vec());
679        }
680        Ok((logits, new_k, new_v))
681    }
682
683    /// Run prefill-with-cache and return the raw outputs. Uses the
684    /// LRU `CompileCache` when enabled; otherwise compiles fresh each
685    /// call. Keyed by `seq` because graph shape is seq-specialized.
686    fn run_prefill_with_cache(
687        &mut self,
688        batch: usize,
689        seq: usize,
690        ids_f32: &[f32],
691    ) -> Result<Vec<Vec<f32>>> {
692        let compile_cap = self.compile_seq_cap();
693        let dynamic_prefill = self.prefill_dynamic_cache.is_some().then(|| {
694            let binding = DimBinding::batch_seq(batch, seq);
695            let opts = self
696                .profile_compile_options(false)
697                .dim_binding(binding.clone());
698            (binding, opts)
699        });
700        if let (Some(cache), Some((binding, opts))) = (
701            self.prefill_dynamic_cache.as_mut(),
702            dynamic_prefill.as_ref(),
703        ) {
704            let max_seq = compile_cap;
705            let needs_upload = !cache.contains(seq as u64);
706            let cfg = self.cfg.clone();
707            let weights_cache = self.weights_cache.clone();
708            let compiled = cache.get_or_specialize(
709                seq as u64,
710                binding,
711                || {
712                    let mut wm = WeightMap::from_tensors(weights_cache);
713                    build_llama32_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
714                        .expect("dynamic prefill HIR")
715                        .0
716                },
717                opts,
718            )?;
719            if needs_upload {
720                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
721                let (_, params) = build_llama32_prefill_hir_dynamic_ext(
722                    &self.cfg, &mut wm, batch, max_seq, true,
723                )?;
724                for (name, data) in &params {
725                    compiled.set_param(name, data);
726                }
727            }
728            let last_idx = vec![(seq - 1) as f32];
729            Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
730        } else if let Some(prefill_cache) = self.prefill_compile_cache.as_mut() {
731            let key = ((batch as u64) << 32) | (seq as u64);
732            if !prefill_cache.contains(key) {
733                let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
734                let (graph, params) = build_llama32_graph_sized_last_logits(
735                    &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
736                )?;
737                {
738                    let compiled = prefill_cache.get_or_compile(key, || graph);
739                    for (name, data) in &params {
740                        compiled.set_param(name, data);
741                    }
742                }
743            }
744            let compiled =
745                prefill_cache.get_or_compile(key, || unreachable!("just populated above"));
746            Ok(compiled.run(&[("input_ids", ids_f32)]))
747        } else {
748            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
749            let (graph, params) = build_llama32_graph_sized_last_logits(
750                &self.cfg, &mut wm, batch, seq, /*with_kv_outputs*/ true,
751            )?;
752            let session = Session::new(self.device);
753            let mut compiled = self.compile_graph_profiled(&session, graph)?;
754            for (name, data) in &params {
755                compiled.set_param(name, data);
756            }
757            Ok(compiled.run(&[("input_ids", ids_f32)]))
758        }
759    }
760
761    /// Split raw graph outputs (logits + per-layer K + per-layer V) into
762    /// (logits, layers_k, layers_v) for the one-shot decode path. The
763    /// bucketed path needs slicing too, so it doesn't reuse this.
764    #[allow(clippy::type_complexity)]
765    fn split_decode_outputs(
766        &self,
767        outputs: Vec<Vec<f32>>,
768    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
769        let n_layers = self.cfg.num_hidden_layers;
770        if outputs.len() != 1 + 2 * n_layers {
771            anyhow::bail!(
772                "decode graph produced {} outputs, expected {}",
773                outputs.len(),
774                1 + 2 * n_layers
775            );
776        }
777        let mut iter = outputs.into_iter();
778        let logits = iter.next().context("decode logits missing")?;
779        let mut layers_k = Vec::with_capacity(n_layers);
780        let mut layers_v = Vec::with_capacity(n_layers);
781        for _ in 0..n_layers {
782            layers_k.push(iter.next().context("decode k missing")?);
783            layers_v.push(iter.next().context("decode v missing")?);
784        }
785        Ok((logits, layers_k, layers_v))
786    }
787
788    /// Run `n` cached steps and return the newly generated tokens.
789    pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
790        self.generate_cached_with(n, opts, |_| {})
791    }
792
793    /// Same as [`generate_cached`] but invokes `on_token` once per
794    /// freshly sampled id, inside the decode loop. The whole `n` step
795    /// loop shares the bucketed compile cache — callers wanting a
796    /// streaming UI should prefer this to calling
797    /// `generate_cached(1, …)` `n` times (which forces a fresh
798    /// compile per token at the bucket boundaries).
799    pub fn generate_cached_with(
800        &mut self,
801        n: usize,
802        opts: SampleOpts,
803        mut on_token: impl FnMut(u32),
804    ) -> Result<Vec<u32>> {
805        let start = self.tokens.len();
806        for _ in 0..n {
807            let tok = self.step_cached(opts)?;
808            on_token(tok);
809        }
810        Ok(self.tokens[start..].to_vec())
811    }
812
813    /// Run prefill-with-cache on the current `self.tokens` (the
814    /// prompt), populate `self.cache`, sample the next token from the
815    /// last position's logits, and append it. Returns the sampled
816    /// token. Invariant after: `cache.past_seq == tokens.len() - 1`.
817    fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
818        let seq = self.tokens.len();
819        let batch = 1usize;
820        let kv_dim = self.cfg.kv_proj_dim();
821
822        let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
823        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
824        if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
825            anyhow::bail!(
826                "prefill-with-cache produced {} outputs, expected {}",
827                outputs.len(),
828                1 + 2 * self.cfg.num_hidden_layers
829            );
830        }
831        let expected_kv_len = batch * seq * kv_dim;
832        let mut iter = outputs.into_iter();
833        let logits = iter.next().context("prefill logits missing")?;
834        let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
835        let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
836        for layer in 0..self.cfg.num_hidden_layers {
837            let k = iter.next().context("prefill k missing")?;
838            let v = iter.next().context("prefill v missing")?;
839            if k.len() != expected_kv_len || v.len() != expected_kv_len {
840                anyhow::bail!(
841                    "layer {layer}: k.len={} v.len={} expected {}",
842                    k.len(),
843                    v.len(),
844                    expected_kv_len
845                );
846            }
847            layers_k.push(k);
848            layers_v.push(v);
849        }
850        self.cache = Some(KvCacheState {
851            past_seq: seq,
852            layers_k,
853            layers_v,
854        });
855
856        let vocab = self.cfg.vocab_size;
857        let needed = vocab;
858        if logits.len() < needed {
859            anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
860        }
861        let last_row = &logits[..vocab];
862        let tok = sample_token(last_row, opts) as u32;
863        self.tokens.push(tok);
864        Ok(tok)
865    }
866
867    /// Full token history (prompt + generated).
868    pub fn tokens(&self) -> &[u32] {
869        &self.tokens
870    }
871
872    pub fn config(&self) -> &Llama32Config {
873        &self.cfg
874    }
875
876    /// Low-level primitive: reset internal state, run prefill-with-cache
877    /// over `context`, and return the *last position's* logits row
878    /// (`P(next_token | context)`). Does NOT sample or append. The
879    /// internal `tokens` buffer is set to `context` and the KV cache
880    /// is populated to `past_seq = context.len()`.
881    ///
882    /// First row of logits after prefill-with-cache (no sampling).
883    pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
884        if context.is_empty() {
885            anyhow::bail!("prefill_get_last_logits: empty context");
886        }
887        self.tokens.clear();
888        self.tokens.extend_from_slice(context);
889        self.cache = None;
890
891        let seq = context.len();
892        let batch = 1usize;
893        let kv_dim = self.cfg.kv_proj_dim();
894
895        let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
896        let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
897        if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
898            anyhow::bail!(
899                "prefill_get_last_logits: got {} outputs, expected {}",
900                outputs.len(),
901                1 + 2 * self.cfg.num_hidden_layers
902            );
903        }
904        let expected_kv_len = batch * seq * kv_dim;
905        let mut iter = outputs.into_iter();
906        let logits = iter.next().context("logits missing")?;
907        let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
908        let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
909        for _ in 0..self.cfg.num_hidden_layers {
910            let k = iter.next().context("k missing")?;
911            let v = iter.next().context("v missing")?;
912            if k.len() != expected_kv_len || v.len() != expected_kv_len {
913                anyhow::bail!("kv length mismatch in prefill_get_last_logits");
914            }
915            layers_k.push(k);
916            layers_v.push(v);
917        }
918        self.cache = Some(KvCacheState {
919            past_seq: seq,
920            layers_k,
921            layers_v,
922        });
923
924        let vocab = self.cfg.vocab_size;
925        let needed = vocab;
926        if logits.len() < needed {
927            anyhow::bail!("logits short: {} < {}", logits.len(), needed);
928        }
929        Ok(logits[..vocab].to_vec())
930    }
931
932    /// Low-level primitive: run one decode step with the caller-
933    /// supplied input token (no sampling), advance the KV cache, and
934    /// return the resulting logits row `P(next | history ++ input)`.
935    /// Appends `input` to the `tokens` buffer so the invariant
936    /// `cache.past_seq == tokens.len()` holds after this call (note:
937    /// differs from `step_cached` invariant because this method does
938    /// not append a sampled token).
939    pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
940        let cache = self.cache.as_ref().ok_or_else(|| {
941            anyhow::anyhow!(
942                "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
943            )
944        })?;
945        let past_seq = cache.past_seq;
946
947        let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
948        let (hir, params) =
949            build_llama32_decode_hir_sized(&self.cfg, &mut wm, /*batch*/ 1, past_seq)?;
950        let session = Session::new(self.device);
951        let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
952        for (name, data) in &params {
953            compiled.set_param(name, data);
954        }
955
956        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
957        let input_ids_f32 = [input as f32];
958        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
959            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
960            .collect();
961        let mut inputs: Vec<(&str, &[f32])> =
962            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
963        inputs.push(("input_ids", input_ids_f32.as_slice()));
964        inputs.push(("rope_cos", cos.as_slice()));
965        inputs.push(("rope_sin", sin.as_slice()));
966        for i in 0..self.cfg.num_hidden_layers {
967            let pk = &cache.layers_k[i];
968            let pv = &cache.layers_v[i];
969            inputs.push((&key_strs[2 * i], pk.as_slice()));
970            inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
971        }
972
973        let outputs = compiled.run(&inputs);
974        if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
975            anyhow::bail!(
976                "decode_get_logits: got {} outputs, expected {}",
977                outputs.len(),
978                1 + 2 * self.cfg.num_hidden_layers
979            );
980        }
981        let mut iter = outputs.into_iter();
982        let logits = iter.next().context("logits missing")?;
983        let mut new_k = Vec::with_capacity(self.cfg.num_hidden_layers);
984        let mut new_v = Vec::with_capacity(self.cfg.num_hidden_layers);
985        for _ in 0..self.cfg.num_hidden_layers {
986            new_k.push(iter.next().context("k missing")?);
987            new_v.push(iter.next().context("v missing")?);
988        }
989
990        let cache_mut = self.cache.as_mut().unwrap();
991        cache_mut.past_seq = past_seq + 1;
992        cache_mut.layers_k = new_k;
993        cache_mut.layers_v = new_v;
994        self.tokens.push(input);
995
996        Ok(logits)
997    }
998}
999
1000/// Compute the single-row (cos, sin) RoPE slice for absolute position
1001/// `pos`. Matches the formula in the prefill builder so cached decode
1002/// and recompute prefill produce the same RoPE rotation.
1003fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
1004    rope_slice(inv_freq, pos)
1005}
1006
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Llama32Config;

    /// Minimal config: 2 layers, 4 query heads sharing 2 KV heads, and a
    /// 16-token vocab — small enough to compile and run quickly on CPU.
    fn tiny_cfg() -> Llama32Config {
        Llama32Config {
            vocab_size: 16,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 500_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            head_dim: Some(8),
            rope_scaling: None,
        }
    }

    /// Build a complete synthetic checkpoint for `cfg`, with every tensor
    /// filled by a deterministic pseudo-random pattern in [-0.5, 0.5).
    fn synthetic_weights(cfg: &Llama32Config) -> WeightMap {
        let h = cfg.hidden_size;
        let q_dim = cfg.q_proj_dim();
        let kv_dim = cfg.kv_proj_dim();
        let int_dim = cfg.intermediate_size;
        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        // Use a deterministic non-zero pattern so logits aren't all 0
        // (sampling on an all-zero row is undefined order).
        let pat = |n: usize, salt: u32| -> Vec<f32> {
            (0..n)
                .map(|i| {
                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
                    (x as f32 / (1u32 << 24) as f32) - 0.5
                })
                .collect()
        };
        t.insert(
            "model.embed_tokens.weight".into(),
            (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
        );
        for i in 0..cfg.num_hidden_layers {
            let lp = format!("model.layers.{i}");
            t.insert(
                format!("{lp}.input_layernorm.weight"),
                (pat(h, 100 + i as u32), vec![h]),
            );
            t.insert(
                format!("{lp}.post_attention_layernorm.weight"),
                (pat(h, 200 + i as u32), vec![h]),
            );
            t.insert(
                format!("{lp}.self_attn.q_proj.weight"),
                (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.k_proj.weight"),
                (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.v_proj.weight"),
                (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
            );
            t.insert(
                format!("{lp}.self_attn.o_proj.weight"),
                (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
            );
            t.insert(
                format!("{lp}.mlp.gate_proj.weight"),
                (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
            );
            t.insert(
                format!("{lp}.mlp.up_proj.weight"),
                (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
            );
            t.insert(
                format!("{lp}.mlp.down_proj.weight"),
                (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
            );
        }
        t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
        t.insert(
            "lm_head.weight".into(),
            (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
        );
        WeightMap::from_tensors(t)
    }

    #[test]
    fn generator_drains_loader_and_runs_one_step() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        assert_eq!(wm.len(), 0, "loader should be drained");
        gn.prefill(&[1, 2, 3]);
        let t = gn.step(SampleOpts::greedy()).unwrap();
        assert!((t as usize) < cfg.vocab_size);
        assert_eq!(gn.tokens().len(), 4);
    }

    #[test]
    fn generate_n_appends_n_tokens() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill(&[5, 6]);
        let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
        assert_eq!(new_tokens.len(), 3);
        assert_eq!(gn.tokens().len(), 5);
        for t in &new_tokens {
            assert!((*t as usize) < cfg.vocab_size);
        }
    }

    #[test]
    fn step_without_prefill_errors() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
        let r = gn.step(SampleOpts::greedy());
        assert!(r.is_err());
    }

    #[test]
    fn cached_matches_naive_on_greedy() {
        // The cached and naive paths must produce the same token
        // sequence given the same prompt + opts. This is the
        // load-bearing test for the KV-cache implementation: if the
        // decode-mode graph, the kernel's Lq!=Lk fix, the cache
        // wiring, or the RoPE position-slice is wrong, the sequences
        // diverge here.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut wm_n = synthetic_weights(&cfg);
        let mut gn_naive =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
        gn_naive.prefill(&prompt);
        let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();

        let mut wm_c = synthetic_weights(&cfg);
        let mut gn_cached =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
        gn_cached.prefill(&prompt);
        let cached_tokens = gn_cached
            .generate_cached(steps, SampleOpts::greedy())
            .unwrap();

        assert_eq!(
            cached_tokens, naive_tokens,
            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
        );
    }

    #[test]
    fn cached_step_advances_cache_invariant() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
        gn.prefill(&[1, 2, 3]);
        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
        // After seed: tokens.len() == 4, cache.past_seq == 3 (cache holds prompt).
        assert_eq!(gn.tokens().len(), 4);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 3);
        let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
        // After one decode: tokens.len() == 5, cache.past_seq == 4.
        assert_eq!(gn.tokens().len(), 5);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 4);
    }

    #[test]
    fn bucketed_decode_matches_oneshot() {
        // The bucketed compile-cache path (padded K/V + custom mask)
        // must produce the same token sequence as the one-shot
        // path. Load-bearing for the bucketed cache feature: if the
        // mask, padding, or output slicing is wrong, sequences
        // diverge here.
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_buc = synthetic_weights(&cfg);
        let mut gn_buc = Llama32Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
            .unwrap()
            .with_decode_cache(/*max_past*/ 32);
        gn_buc.prefill(&prompt);
        let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            bucketed_tokens, oneshot_tokens,
            "bucketed-cache decode diverged from one-shot decode — \
             mask, padding, or output-slice bug"
        );
    }

    #[test]
    fn prefill_compile_cache_does_not_change_output() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let mut wm_a = synthetic_weights(&cfg);
        let mut gn_a = Llama32Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
        gn_a.prefill(&prompt);
        let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();

        let mut wm_b = synthetic_weights(&cfg);
        let mut gn_b = Llama32Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
            .unwrap()
            .with_prefill_cache(/*capacity*/ 4);
        gn_b.prefill(&prompt);
        let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();

        assert_eq!(a, b, "enabling prefill_cache must not change output");
    }

    #[test]
    fn dynamic_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_decode_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic past_seq decode diverged from one-shot decode"
        );
    }

    #[test]
    fn dynamic_prefill_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic seq prefill diverged from one-shot prefill"
        );
    }

    #[test]
    fn dynamic_prefill_and_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let mut wm_one = synthetic_weights(&cfg);
        let mut gn_one =
            Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
        gn_one.prefill(&prompt);
        let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

        let mut wm_dyn = synthetic_weights(&cfg);
        let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(/*capacity*/ 8)
            .with_dynamic_decode_cache(/*capacity*/ 8);
        gn_dyn.prefill(&prompt);
        let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic prefill+decode diverged from one-shot path"
        );
    }

    #[test]
    fn greedy_is_deterministic_across_runs() {
        // Two generators built from independently constructed (but
        // identical, since `synthetic_weights` is deterministic)
        // checkpoints must produce identical greedy output.
        let cfg = tiny_cfg();
        let mk = || {
            let mut wm = synthetic_weights(&cfg);
            Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
        };
        let mut a = mk();
        let mut b = mk();
        a.prefill(&[1, 2, 3]);
        b.prefill(&[1, 2, 3]);
        let ta = a.generate(4, SampleOpts::greedy()).unwrap();
        let tb = b.generate(4, SampleOpts::greedy()).unwrap();
        assert_eq!(ta, tb);
    }

    #[test]
    fn prefill_get_last_logits_rejects_empty_context() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
        assert!(gn.prefill_get_last_logits(&[]).is_err());
    }

    #[test]
    fn decode_get_logits_requires_seeded_cache() {
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
        assert!(gn.decode_get_logits(1).is_err());
    }

    #[test]
    fn low_level_prefill_then_decode_tracks_history() {
        // prefill_get_last_logits: tokens == context, past_seq == len.
        // decode_get_logits: appends `input` and advances past_seq by one.
        let cfg = tiny_cfg();
        let mut wm = synthetic_weights(&cfg);
        let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();

        let logits = gn.prefill_get_last_logits(&[1, 2, 3]).unwrap();
        assert_eq!(logits.len(), cfg.vocab_size);
        assert_eq!(gn.tokens(), &[1u32, 2, 3][..]);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 3);

        let _logits = gn.decode_get_logits(4).unwrap();
        assert_eq!(gn.tokens(), &[1u32, 2, 3, 4][..]);
        assert_eq!(gn.cache.as_ref().unwrap().past_seq, 4);
    }
}