1use crate::builder::{
34 build_gemma_decode_graph_sized, build_gemma_decode_hir_dynamic_ext,
35 build_gemma_decode_hir_sized_ext, build_gemma_graph_sized_last_logits,
36 build_gemma_graph_sized_last_logits_hidden, build_gemma_prefill_hidden_hir_dynamic_ext,
37 build_gemma_prefill_hir_dynamic_ext,
38};
39use crate::config::GemmaConfig;
40use crate::rope::{resolve_inv_freq, rope_slice};
41use anyhow::{Context, Result};
42use rlx_core::autoregressive::{
43 KvCacheState, kv_from_prefill_outputs_per_layer, run_bucketed_kv_decode_hir_scratch,
44 split_decode_logits_kv,
45};
46use rlx_core::flow_bridge::compile_options_from_profile;
47use rlx_core::gpu_kv::{
48 GpuKvBinding, device_supports_gpu_kv, run_bucketed_kv_decode_gpu_hir, sync_gpu_kv_to_host,
49};
50use rlx_core::weight_loader::WeightLoader;
51use rlx_core::weight_map::WeightMap;
52use rlx_flow::CompileProfile;
53use rlx_ir::DimBinding;
54use rlx_ir::logical_kernel::KernelDispatchConfig;
55use rlx_qwen3::sampling::{SampleOpts, sample_token};
56use rlx_runtime::compile_cache::{
57 BucketedCompileCache, CacheRunInput, CompileCache, DynamicDimCompileCache,
58};
59use rlx_runtime::{CompileOptions, Device, Session};
60use std::collections::HashMap;
61use std::path::Path;
62
63pub fn decode_profile_for_device(device: Device) -> CompileProfile {
65 metal_decode_profile(device, CompileProfile::gemma_decode())
66}
67
/// True when `RLX_GEMMA_METAL_THUNK_DECODE` is set to `1`, `true` or `yes`
/// (the two words are matched case-insensitively). Unset or non-UTF-8 values read as false.
fn metal_thunk_decode_requested() -> bool {
    std::env::var("RLX_GEMMA_METAL_THUNK_DECODE")
        .map(|raw| matches!(raw.to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
        .unwrap_or(false)
}
75
76pub(crate) fn metal_decode_compile_guard<R, F>(device: Device, decode: bool, f: F) -> R
77where
78 F: FnOnce() -> R,
79{
80 if decode && metal_thunk_decode_requested() {
81 if device == Device::Metal {
82 rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
83 let out = f();
84 rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
85 out
86 } else {
87 f()
88 }
89 } else {
90 f()
91 }
92}
93
/// True when `RLX_GEMMA_METAL_UNFUSED_DECODE` is `1`, `true` or `yes`
/// (the two words are matched case-insensitively). Missing or non-UTF-8 values read as false.
fn metal_unfused_decode_requested() -> bool {
    match std::env::var("RLX_GEMMA_METAL_UNFUSED_DECODE") {
        Ok(raw) => matches!(raw.to_ascii_lowercase().as_str(), "1" | "true" | "yes"),
        Err(_) => false,
    }
}
100
101fn metal_decode_profile(device: Device, mut profile: CompileProfile) -> CompileProfile {
102 if device == Device::Metal && metal_unfused_decode_requested() {
103 profile.fusion.skip = true;
104 profile.backend.metal.skip_fusion = true;
105 profile.backend.metal.unfuse_regions = true;
106 }
107 profile
108}
109
/// Stateful Gemma text generator.
///
/// Owns the canonically named weights, the token history, the host KV cache and the
/// optional compile caches used by the prefill and decode paths.
pub struct GemmaGenerator {
    /// Model configuration.
    cfg: GemmaConfig,
    /// Every weight drained from the loader, keyed by canonical (HF-style) name.
    /// Values are `(data, dims)`; dims are presumably the tensor shape (TODO confirm).
    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
    /// Prompt tokens followed by the tokens appended while stepping.
    tokens: Vec<u32>,
    /// Backend every session and compile cache targets.
    device: Device,
    /// Host KV cache. `None` until a prefill seeds it; reset to `None` by `prefill`.
    cache: Option<KvCacheState>,
    /// Exact-shape prefill cache keyed on `(batch << 32) | seq` (`with_prefill_cache`).
    prefill_compile_cache: Option<CompileCache>,
    /// Dynamic-seq prefill cache (`with_dynamic_prefill_cache`); checked before the
    /// exact-shape cache.
    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
    /// Exact-shape cache for the hidden-state (embedding) prefill graph.
    embed_prefill_compile_cache: Option<CompileCache>,
    /// Dynamic-seq cache for the hidden-state (embedding) prefill graph.
    embed_prefill_dynamic_cache: Option<DynamicDimCompileCache>,
    /// Power-of-two bucketed decode cache (`with_decode_cache`).
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Dynamic past-length decode cache (`with_dynamic_decode_cache`); preferred over
    /// the bucketed cache by `step_cached` when present.
    decode_dynamic_cache: Option<DynamicDimCompileCache>,
    /// RoPE inverse frequencies, resolved at load time (using `rope_freqs.weight`
    /// factors when that tensor exists).
    inv_freq: Vec<f64>,
    /// Compile profile for prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile for decode graphs, already passed through `metal_decode_profile`.
    decode_profile: CompileProfile,
    /// Input embeddings staged by `prefill_from_embeds`, consumed by the next cache seed.
    pending_prefill_embeds: Option<Vec<f32>>,
    /// Optional attention bias staged together with `pending_prefill_embeds`.
    pending_prefill_attn_bias: Option<Vec<f32>>,
    /// Whether bucketed decode keeps KV device-resident (see `gemma_use_gpu_kv`).
    use_gpu_kv: bool,
    /// Device KV binding state; reset on a new prompt or when decode changes bucket.
    gpu_kv_binding: GpuKvBinding,
    /// Reusable padded KV buffers for the host bucketed-decode path.
    decode_scratch: DecodeKvScratch,
    /// Reusable mask / RoPE input buffers for bucketed decode.
    decode_inputs: DecodeInputScratch,
}
160
/// Per-step non-KV inputs for the bucketed decode graph, reused across steps.
#[derive(Default)]
struct DecodeInputScratch {
    /// `upper + 1` entries: 1.0 for positions `< past_seq` and for slot `upper`, else 0.0.
    mask: Vec<f32>,
    /// Cosine of `pos * inv_freq[i]` for each inverse frequency.
    cos: Vec<f32>,
    /// Sine of `pos * inv_freq[i]` for each inverse frequency.
    sin: Vec<f32>,
}
168
/// Padded per-layer K/V buffers for the host bucketed-decode path, reused while the
/// bucket stays the same.
#[derive(Default)]
struct DecodeKvScratch {
    /// One zero-initialised `upper * kv_dim` buffer per layer for keys.
    padded_k: Vec<Vec<f32>>,
    /// One zero-initialised `upper * kv_dim` buffer per layer for values.
    padded_v: Vec<Vec<f32>>,
    /// Bucket upper bound the buffers were last sized for (0 before first use).
    bucket_upper: usize,
}
176
177impl DecodeInputScratch {
178 fn fill_mask(&mut self, past_seq: usize, upper: usize) {
179 if self.mask.len() != upper + 1 {
180 self.mask.resize(upper + 1, 0.0);
181 }
182 for (i, m) in self.mask.iter_mut().enumerate().take(upper + 1) {
183 *m = if i < past_seq || i == upper { 1.0 } else { 0.0 };
184 }
185 }
186
187 fn fill_rope(&mut self, inv_freq: &[f64], pos: usize) {
188 let half = inv_freq.len();
189 self.cos.resize(half, 0.0);
190 self.sin.resize(half, 0.0);
191 for (i, &freq) in inv_freq.iter().enumerate() {
192 let angle = pos as f64 * freq;
193 let (s, c) = angle.sin_cos();
194 self.cos[i] = c as f32;
195 self.sin[i] = s as f32;
196 }
197 }
198}
199
200impl DecodeKvScratch {
201 fn ensure_bucket(&mut self, upper: usize, kv_dims: &[usize]) {
202 if self.bucket_upper == upper && self.padded_k.len() == kv_dims.len() {
203 return;
204 }
205 self.bucket_upper = upper;
206 self.padded_k = kv_dims.iter().map(|&d| vec![0.0_f32; upper * d]).collect();
207 self.padded_v = kv_dims.iter().map(|&d| vec![0.0_f32; upper * d]).collect();
208 }
209}
210
211fn gemma_use_gpu_kv(device: Device) -> bool {
212 if !device_supports_gpu_kv(device) {
213 return false;
214 }
215 match std::env::var("RLX_GEMMA_GPU_KV").ok().as_deref() {
216 Some("0") | Some("false") | Some("no") => false,
217 Some("1") | Some("true") | Some("yes") => true,
218 _ => false,
220 }
221}
222
223impl GemmaGenerator {
224 pub fn from_loader(
227 cfg: GemmaConfig,
228 loader: &mut dyn WeightLoader,
229 device: Device,
230 ) -> Result<Self> {
231 let keys = loader.remaining_keys();
232 let arch_hint: Option<String> = loader.arch_hint().map(|s| s.to_string());
238 let mut weights_cache = HashMap::with_capacity(keys.len());
239 for k in keys {
240 let v = loader
241 .take(&k)
242 .with_context(|| format!("draining weight {k}"))?;
243 let canonical = match arch_hint.as_deref() {
249 Some(a) => rlx_core::weight_loader::gguf_to_hf_name_for_arch(&k, a)
250 .unwrap_or_else(|| k.clone()),
251 None => rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone()),
252 };
253 weights_cache.insert(canonical, v);
254 }
255 let rope_factors = weights_cache
256 .get("rope_freqs.weight")
257 .map(|(d, _)| d.as_slice());
258 let inv_freq = resolve_inv_freq(&cfg, rope_factors);
259 Ok(Self {
260 cfg,
261 weights_cache,
262 tokens: Vec::new(),
263 device,
264 cache: None,
265 prefill_compile_cache: None,
266 prefill_dynamic_cache: None,
267 embed_prefill_compile_cache: None,
268 embed_prefill_dynamic_cache: None,
269 decode_compile_cache: None,
270 decode_dynamic_cache: None,
271 inv_freq,
272 prefill_profile: CompileProfile::gemma_prefill(),
273 decode_profile: metal_decode_profile(device, CompileProfile::gemma_decode()),
274 pending_prefill_embeds: None,
275 pending_prefill_attn_bias: None,
276 use_gpu_kv: gemma_use_gpu_kv(device),
277 gpu_kv_binding: GpuKvBinding::default(),
278 decode_scratch: DecodeKvScratch::default(),
279 decode_inputs: DecodeInputScratch::default(),
280 })
281 }
282
283 fn reset_gpu_kv_binding(&mut self) {
284 self.gpu_kv_binding = GpuKvBinding::default();
285 }
286
287 pub fn from_loader_at(
290 cfg: GemmaConfig,
291 loader: &mut dyn WeightLoader,
292 device: Device,
293 weights_path: &Path,
294 ) -> Result<Self> {
295 let mut g = Self::from_loader(cfg, loader, device)?;
296 g.prefill_profile = crate::gemma_profile_near_weights(weights_path, false);
297 g.decode_profile = metal_decode_profile(
298 device,
299 crate::gemma_profile_near_weights(weights_path, true),
300 );
301 Ok(g)
302 }
303
304 pub fn with_compile_profiles(
306 mut self,
307 prefill: CompileProfile,
308 decode: CompileProfile,
309 ) -> Self {
310 self.prefill_profile = prefill;
311 self.decode_profile = metal_decode_profile(self.device, decode);
312 self
313 }
314
315 pub fn prefill_profile(&self) -> &CompileProfile {
316 &self.prefill_profile
317 }
318
319 pub fn decode_profile(&self) -> &CompileProfile {
320 &self.decode_profile
321 }
322
323 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
324 let profile = if decode {
325 &self.decode_profile
326 } else {
327 &self.prefill_profile
328 };
329 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
330 }
331
332 fn compile_graph_profiled(
333 &self,
334 session: &Session,
335 graph: rlx_ir::Graph,
336 ) -> Result<rlx_runtime::CompiledGraph> {
337 let opts = self.profile_compile_options(false);
338 Ok(session.compile_with(graph, &opts))
339 }
340
341 fn compile_graph_profiled_decode(
342 &self,
343 session: &Session,
344 graph: rlx_ir::Graph,
345 ) -> Result<rlx_runtime::CompiledGraph> {
346 Ok(metal_decode_compile_guard(self.device, true, || {
347 session.compile_with(graph, &self.profile_compile_options(true))
348 }))
349 }
350
351 pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
356 self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
357 self.embed_prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
358 self.prefill_dynamic_cache = None;
359 self.embed_prefill_dynamic_cache = None;
360 self
361 }
362
363 pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
365 self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
366 self.embed_prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
367 self.prefill_compile_cache = None;
368 self.embed_prefill_compile_cache = None;
369 self
370 }
371
372 pub fn with_decode_cache(mut self, max_past: usize) -> Self {
382 let cache = BucketedCompileCache::power_of_two_ladder(
383 self.device,
384 1,
385 max_past.max(1) as u64,
386 );
387 self.decode_compile_cache = Some(cache);
388 self.decode_dynamic_cache = None;
389 self
390 }
391
392 pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
394 self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
395 self.decode_compile_cache = None;
396 self
397 }
398
399 fn inference_dynamic_decode() -> bool {
400 std::env::var("RLX_GEMMA_DYNAMIC_DECODE").is_ok_and(|v| {
401 v == "1" || v.eq_ignore_ascii_case("true") || v.eq_ignore_ascii_case("yes")
402 })
403 }
404
405 pub fn with_inference_caches(mut self, max_seq: usize) -> Self {
409 let decode_horizon = max_seq.saturating_add(16).max(32);
412 self = self.with_dynamic_prefill_cache(16);
413 if Self::inference_dynamic_decode() {
414 self.with_dynamic_decode_cache(32)
415 } else {
416 self.with_decode_cache(decode_horizon)
417 }
418 }
419
420 pub fn sync_device(&mut self) {
423 if let Some(c) = &mut self.prefill_compile_cache {
424 c.sync_all();
425 }
426 if let Some(c) = &mut self.embed_prefill_compile_cache {
427 c.sync_all();
428 }
429 if let Some(c) = &mut self.prefill_dynamic_cache {
430 c.sync_all();
431 }
432 if let Some(c) = &mut self.embed_prefill_dynamic_cache {
433 c.sync_all();
434 }
435 if let Some(c) = &mut self.decode_compile_cache {
436 c.sync_all();
437 }
438 if let Some(c) = &mut self.decode_dynamic_cache {
439 c.sync_all();
440 }
441 rlx_runtime::device_ext::drain_device(self.device);
442 }
443
444 pub fn from_path(cfg: GemmaConfig, path: &str, device: Device) -> Result<Self> {
447 let mut loader = rlx_core::weight_loader::load_from_path(path)?;
448 Self::from_loader(cfg, loader.as_mut(), device)
449 }
450
451 pub fn from_path_with_mtp(
459 cfg: GemmaConfig,
460 path: &str,
461 device: Device,
462 include_mtp: bool,
463 ) -> Result<Self> {
464 if path.ends_with(".gguf") {
468 let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
469 gguf.include_mtp(include_mtp);
470 Self::from_loader(cfg, &mut gguf, device)
471 } else {
472 Self::from_path(cfg, path, device)
473 }
474 }
475
476 pub fn prefill(&mut self, prompt_ids: &[u32]) {
480 self.tokens.clear();
481 self.tokens.extend_from_slice(prompt_ids);
482 self.cache = None;
483 self.reset_gpu_kv_binding();
484 }
485
486 pub fn prefill_from_embeds(
489 &mut self,
490 prompt_ids: &[u32],
491 embeds: &[f32],
492 attn_bias: Option<Vec<f32>>,
493 ) -> Result<()> {
494 let h = self.cfg.hidden_size;
495 if embeds.len() != prompt_ids.len() * h {
496 anyhow::bail!(
497 "prefill_from_embeds: embeds len {} != {} tokens × hidden {}",
498 embeds.len(),
499 prompt_ids.len(),
500 h
501 );
502 }
503 if let Some(ref bias) = attn_bias {
504 let seq = prompt_ids.len();
505 let nh = self.cfg.num_attention_heads;
506 let expected = seq * seq * nh;
507 if bias.len() != expected {
508 anyhow::bail!(
509 "prefill_from_embeds: attn_bias len {} != batch×heads×seq² ({expected})",
510 bias.len()
511 );
512 }
513 }
514 self.prefill(prompt_ids);
515 self.pending_prefill_embeds = Some(embeds.to_vec());
516 self.pending_prefill_attn_bias = attn_bias;
517 Ok(())
518 }
519
520 pub fn weights_cache(&self) -> &HashMap<String, (Vec<f32>, Vec<usize>)> {
522 &self.weights_cache
523 }
524
525 pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
529 if self.tokens.is_empty() {
530 anyhow::bail!("step() called with empty token history; call prefill() first");
531 }
532 let seq = self.tokens.len();
533 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
534 let (graph, params) = build_gemma_graph_sized_last_logits(
535 &self.cfg, &mut wm, 1, seq, false,
536 )?;
537 let session = Session::new(self.device);
538 let mut compiled = self.compile_graph_profiled(&session, graph)?;
539 for (name, data) in ¶ms {
540 compiled.set_param(name, data);
541 }
542 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
543 let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
544 let logits = outputs
545 .into_iter()
546 .next()
547 .context("compiled.run returned no outputs")?;
548
549 let vocab = self.cfg.vocab_size;
550 let expected = vocab;
551 if logits.len() < expected {
552 anyhow::bail!(
553 "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
554 logits.len(),
555 expected
556 );
557 }
558 let last_row = &logits[..vocab];
560 let tok = sample_token(last_row, opts) as u32;
561 self.tokens.push(tok);
562 Ok(tok)
563 }
564
565 pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
568 let start = self.tokens.len();
569 for _ in 0..n {
570 self.step(opts)?;
571 }
572 Ok(self.tokens[start..].to_vec())
573 }
574
575 pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
585 if self.tokens.is_empty() {
586 anyhow::bail!("step_cached() called with empty token history; call prefill() first");
587 }
588 if self.cache.is_none() {
589 let tok = self.seed_cache_from_prompt(opts)?;
593 return Ok(tok);
594 }
595 let cache = self.cache.as_ref().unwrap();
596 let past_seq = cache.past_len;
597 if self.tokens.len() <= past_seq {
598 anyhow::bail!(
599 "cache invariant violated: tokens.len() {} <= past_len {}",
600 self.tokens.len(),
601 past_seq
602 );
603 }
604 let input_tok = self.tokens[past_seq];
605
606 let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
607 self.decode_step_dynamic(past_seq, input_tok)?
608 } else if self.decode_compile_cache.is_some()
609 && self
610 .decode_compile_cache
611 .as_ref()
612 .unwrap()
613 .bucket_for(past_seq as u64)
614 .is_some()
615 {
616 self.decode_step_bucketed(past_seq, input_tok)?
617 } else {
618 self.decode_step_oneshot(past_seq, input_tok)?
619 };
620
621 let cache_mut = self.cache.as_mut().unwrap();
622 cache_mut.past_len = past_seq + 1;
623 cache_mut.layers_k = new_k;
624 cache_mut.layers_v = new_v;
625
626 let vocab = self.cfg.vocab_size;
627 if logits.len() != vocab {
628 anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
629 }
630 let tok = sample_token(&logits, opts) as u32;
631 self.tokens.push(tok);
632 Ok(tok)
633 }
634
635 #[allow(clippy::type_complexity)]
638 fn decode_step_oneshot(
639 &mut self,
640 past_seq: usize,
641 input_tok: u32,
642 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
643 let cache = self.cache.as_ref().unwrap();
644
645 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
646 let (graph, params) =
647 build_gemma_decode_graph_sized(&self.cfg, &mut wm, 1, past_seq)?;
648 let session = Session::new(self.device);
649 let mut compiled = self.compile_graph_profiled_decode(&session, graph)?;
650 for (name, data) in ¶ms {
651 compiled.set_param(name, data);
652 }
653
654 let input_ids_f32 = [input_tok as f32];
655 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
656 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
657 .collect();
658 let mut inputs: Vec<(&str, &[f32])> =
659 Vec::with_capacity(1 + 2 * self.cfg.num_hidden_layers);
660 inputs.push(("input_ids", input_ids_f32.as_slice()));
661 for i in 0..self.cfg.num_hidden_layers {
662 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
663 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
664 }
665
666 let outputs = compiled.run(&inputs);
667 split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
668 }
669
    /// Decode one token through the dynamic-dim decode cache.
    ///
    /// Specialises (or reuses) the dynamic decode graph for `past_seq` under the Metal
    /// thunk-decode guard. Weights are uploaded when this `past_seq` key was not yet in the
    /// cache. The graph is then run with the token id, the RoPE slice for position
    /// `past_seq`, and every layer's host KV. Returns `(logits, new_k, new_v)` as split by
    /// `split_decode_logits_kv`.
    ///
    /// # Errors
    /// Fails if no dynamic decode cache is installed, if specialisation fails, or if the
    /// parameter rebuild for the upload fails.
    ///
    /// # Panics
    /// Panics if `self.cache` is `None`, or if building the graph inside the
    /// specialisation closure fails.
    #[allow(clippy::type_complexity)]
    fn decode_step_dynamic(
        &mut self,
        past_seq: usize,
        input_tok: u32,
    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
        let cache = self.cache.as_ref().unwrap();
        let binding = DimBinding::batch_past_seq(1, past_seq);
        let opts = self
            .profile_compile_options(true)
            .dim_binding(binding.clone());
        let cache_dyn = self
            .decode_dynamic_cache
            .as_mut()
            .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
        // Must be read before `get_or_specialize`, which presumably inserts this key.
        let needs_upload = !cache_dyn.contains(past_seq as u64);
        let cfg = self.cfg.clone();
        let weights_cache = self.weights_cache.clone();
        // NOTE(review): the dynamic graph is built with max_position_embeddings as the
        // past-length bound — confirm against the builder's contract.
        let max_past = self.cfg.max_position_embeddings;
        let compiled = metal_decode_compile_guard(self.device, true, || {
            cache_dyn.get_or_specialize(
                past_seq as u64,
                &binding,
                || {
                    // Only the graph is used here; params are rebuilt below when uploading.
                    let mut wm = WeightMap::from_tensors(weights_cache);
                    build_gemma_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
                        .expect("dynamic decode HIR")
                        .0
                },
                &opts,
            )
        })?;
        if needs_upload {
            let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
            let (_, params) = build_gemma_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
            for (name, data) in &params {
                compiled.set_param(name, data);
            }
        }

        let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
        let input_ids_f32 = [input_tok as f32];
        // Pairwise layout: [past_k_0, past_v_0, past_k_1, past_v_1, ...].
        let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
            .collect();
        let mut inputs: Vec<(&str, &[f32])> =
            Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
        inputs.push(("input_ids", input_ids_f32.as_slice()));
        inputs.push(("rope_cos", cos.as_slice()));
        inputs.push(("rope_sin", sin.as_slice()));
        for i in 0..self.cfg.num_hidden_layers {
            inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
            inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
        }
        let outputs = compiled.run(&inputs);
        split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
    }
727
    /// Decode one token through the power-of-two bucketed decode cache.
    ///
    /// `upper` is the inclusive end (`range.end - 1`) of the bucket containing `past_seq`.
    /// The mask is 1.0 for positions `< past_seq` and for slot `upper`, and the RoPE inputs
    /// are filled for position `past_seq`. There are two execution paths:
    ///
    /// * GPU-resident KV (`use_gpu_kv`): runs via `run_bucketed_kv_decode_gpu_hir`, copies
    ///   the device KV back into `self.cache` with `sync_gpu_kv_to_host`, resets the GPU
    ///   binding when the next position falls into a different bucket, and returns clones
    ///   of the host KV layers.
    /// * Host KV: pads each layer into the reusable scratch buffers and runs via
    ///   `run_bucketed_kv_decode_hir_scratch`.
    ///
    /// # Panics
    /// Panics if `self.cache` or `self.decode_compile_cache` is `None`, or if building a
    /// bucket's decode HIR fails.
    #[allow(clippy::type_complexity)]
    fn decode_step_bucketed(
        &mut self,
        past_seq: usize,
        input_tok: u32,
    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
        let kv_dims = self.per_layer_kv_dims();
        let n_layers = self.cfg.num_hidden_layers;
        let decode_opts = self.profile_compile_options(true);
        // Inclusive upper bound of the bucket containing `past_seq`; falls back to
        // `past_seq` itself when no bucket matches.
        let upper = self
            .decode_compile_cache
            .as_ref()
            .and_then(|cache_dec| {
                cache_dec.bucket_for(past_seq as u64).map(|idx| {
                    cache_dec
                        .buckets()
                        .nth(idx)
                        .map(|r| (r.end - 1) as usize)
                        .unwrap_or(past_seq)
                })
            })
            .unwrap_or(past_seq);

        self.decode_scratch.ensure_bucket(upper, &kv_dims);
        self.decode_inputs.fill_mask(past_seq, upper);
        self.decode_inputs.fill_rope(&self.inv_freq, past_seq);

        let input_ids_f32 = [input_tok as f32];
        // Per-step, non-KV graph inputs shared by both execution paths.
        let fixed = [
            CacheRunInput {
                name: "input_ids",
                data: &input_ids_f32,
                row_inner: None,
            },
            CacheRunInput {
                name: "rope_cos",
                data: &self.decode_inputs.cos,
                row_inner: None,
            },
            CacheRunInput {
                name: "rope_sin",
                data: &self.decode_inputs.sin,
                row_inner: None,
            },
            CacheRunInput {
                name: "mask",
                data: &self.decode_inputs.mask,
                row_inner: None,
            },
        ];

        if self.use_gpu_kv && self.decode_compile_cache.is_some() {
            let key = past_seq as u64;
            let upper_u = upper as u64;
            let prev_upper = self.gpu_kv_binding.upper;
            let bucket_changed = prev_upper != 0 && prev_upper != upper_u;
            // Whether the compiled graph for this key already holds a device handle for
            // layer 0's K input.
            let handles_live = self
                .decode_compile_cache
                .as_mut()
                .and_then(|c| c.compiled_for_key_mut(key))
                .map(|cg| cg.has_gpu_handle("past_k_0"))
                .unwrap_or(false);
            // NOTE(review): on Gpu/Metal this is always true, so KV is refreshed every step
            // on those devices regardless of `bucket_changed` / `handles_live` — confirm
            // this is intended.
            let refresh_kv = matches!(self.device, Device::Gpu | Device::Metal)
                || bucket_changed
                || !handles_live;
            let cfg = self.cfg.clone();
            let weights = self.weights_cache.clone();
            let logits = {
                let cache_dec = self.decode_compile_cache.as_mut().unwrap();
                let cache_mut = self.cache.as_mut().unwrap();
                metal_decode_compile_guard(self.device, true, || {
                    run_bucketed_kv_decode_gpu_hir(
                        cache_dec,
                        key,
                        past_seq,
                        cache_mut,
                        &mut self.gpu_kv_binding,
                        self.cfg.kv_proj_dim(),
                        n_layers,
                        &fixed,
                        move |upper| {
                            let mut wm = WeightMap::from_tensors(weights.clone());
                            build_gemma_decode_hir_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
                                .expect("gemma bucketed decode HIR")
                        },
                        &decode_opts,
                        refresh_kv,
                    )
                })?
            };
            // Copy the device KV back into the host cache (presumably so `self.cache`
            // stays the source of truth for the caller).
            if let Some(compiled) = self
                .decode_compile_cache
                .as_mut()
                .and_then(|c| c.compiled_for_key_mut(key))
            {
                let cache_mut = self.cache.as_mut().unwrap();
                sync_gpu_kv_to_host(compiled, cache_mut, self.cfg.kv_proj_dim(), n_layers)?;
            }
            // Drop the device binding if the next position lands in a different bucket.
            let next_key = (past_seq + 1) as u64;
            let next_upper = self
                .decode_compile_cache
                .as_ref()
                .and_then(|cache| {
                    cache
                        .bucket_for(next_key)
                        .and_then(|idx| cache.buckets().nth(idx).map(|r| (r.end - 1) as usize))
                })
                .unwrap_or(upper);
            if next_upper != upper {
                self.reset_gpu_kv_binding();
            }
            let cache_mut = self.cache.as_ref().unwrap();
            let new_k = cache_mut.layers_k.clone();
            let new_v = cache_mut.layers_v.clone();
            return Ok((logits, new_k, new_v));
        }

        // Host-KV path: pad each layer into the reusable scratch buffers.
        let cfg = self.cfg.clone();
        let weights = self.weights_cache.clone();
        let cache_dec = self.decode_compile_cache.as_mut().unwrap();
        let kv_cache = self.cache.as_ref().unwrap();
        let DecodeKvScratch {
            padded_k, padded_v, ..
        } = &mut self.decode_scratch;
        metal_decode_compile_guard(self.device, true, || {
            run_bucketed_kv_decode_hir_scratch(
                cache_dec,
                past_seq,
                kv_cache,
                &kv_dims,
                n_layers,
                padded_k,
                padded_v,
                &fixed,
                |upper| {
                    let mut wm = WeightMap::from_tensors(weights.clone());
                    build_gemma_decode_hir_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
                        .expect("gemma bucketed decode HIR")
                },
                &decode_opts,
            )
        })
    }
871
872 #[allow(clippy::unnecessary_unwrap)]
876 fn run_prefill_with_cache(
877 &mut self,
878 batch: usize,
879 seq: usize,
880 ids_f32: &[f32],
881 ) -> Result<Vec<Vec<f32>>> {
882 if self.prefill_dynamic_cache.is_some() {
883 let binding = DimBinding::batch_seq(batch, seq);
884 let opts = compile_options_from_profile(
885 &self.prefill_profile,
886 self.device,
887 KernelDispatchConfig::default(),
888 )
889 .dim_binding(binding.clone());
890 let cache = self.prefill_dynamic_cache.as_mut().expect("checked");
891 let needs_upload = !cache.contains(seq as u64);
892 let cfg = self.cfg.clone();
893 let weights_cache = self.weights_cache.clone();
894 let max_seq = self.cfg.max_position_embeddings;
895 let compiled = cache.get_or_specialize(
896 seq as u64,
897 &binding,
898 || {
899 let mut wm = WeightMap::from_tensors(weights_cache);
900 build_gemma_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
901 .expect("dynamic prefill HIR")
902 .0
903 },
904 &opts,
905 )?;
906 if needs_upload {
907 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
908 let (_, params) =
909 build_gemma_prefill_hir_dynamic_ext(&self.cfg, &mut wm, batch, max_seq, true)?;
910 for (name, data) in ¶ms {
911 compiled.set_param(name, data);
912 }
913 }
914 let last_idx = vec![(seq - 1) as f32];
915 Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
916 } else if self.prefill_compile_cache.is_some() {
917 let key = ((batch as u64) << 32) | (seq as u64);
918 let opts = self.profile_compile_options(false);
919 if !self.prefill_compile_cache.as_ref().unwrap().contains(key) {
920 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
921 let (graph, params) = build_gemma_graph_sized_last_logits(
922 &self.cfg, &mut wm, batch, seq, true,
923 )?;
924 {
925 let compiled = self
926 .prefill_compile_cache
927 .as_mut()
928 .unwrap()
929 .get_or_compile_with_options(key, || graph, &opts);
930 for (name, data) in ¶ms {
931 compiled.set_param(name, data);
932 }
933 }
934 }
935 let compiled = self
936 .prefill_compile_cache
937 .as_mut()
938 .unwrap()
939 .get_or_compile_with_options(key, || unreachable!("just populated above"), &opts);
940 Ok(compiled.run(&[("input_ids", ids_f32)]))
941 } else {
942 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
943 let (graph, params) = build_gemma_graph_sized_last_logits(
944 &self.cfg, &mut wm, batch, seq, true,
945 )?;
946 let session = Session::new(self.device);
947 let opts = self.profile_compile_options(false);
948 let mut compiled = session.compile_with(graph, &opts);
949 for (name, data) in ¶ms {
950 compiled.set_param(name, data);
951 }
952 Ok(compiled.run(&[("input_ids", ids_f32)]))
953 }
954 }
955
956 fn run_prefill_hidden_with_cache(
957 &mut self,
958 batch: usize,
959 seq: usize,
960 hidden: &[f32],
961 attn_bias: Option<&[f32]>,
962 ) -> Result<Vec<Vec<f32>>> {
963 if self.cfg.use_bidirectional_vision() && attn_bias.is_none() {
964 anyhow::bail!(
965 "multimodal prefill requires attn_bias when use_bidirectional_attention=vision"
966 );
967 }
968 let mut inputs: Vec<(&str, &[f32])> = vec![("prefill_hidden", hidden)];
969 if let Some(bias) = attn_bias {
970 inputs.push(("attn_bias", bias));
971 }
972 let embed_compile_opts = self.profile_compile_options(false);
973 if let Some(cache) = &mut self.embed_prefill_dynamic_cache {
974 let binding = DimBinding::batch_seq(batch, seq);
975 let opts = compile_options_from_profile(
976 &self.prefill_profile,
977 self.device,
978 KernelDispatchConfig::default(),
979 )
980 .dim_binding(binding.clone());
981 let needs_upload = !cache.contains(seq as u64);
982 let cfg = self.cfg.clone();
983 let weights_cache = self.weights_cache.clone();
984 let max_seq = self.cfg.max_position_embeddings;
985 let compiled = cache.get_or_specialize(
986 seq as u64,
987 &binding,
988 || {
989 let mut wm = WeightMap::from_tensors(weights_cache);
990 build_gemma_prefill_hidden_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
991 .expect("dynamic hidden prefill HIR")
992 .0
993 },
994 &opts,
995 )?;
996 if needs_upload {
997 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
998 let (_, params) = build_gemma_prefill_hidden_hir_dynamic_ext(
999 &self.cfg, &mut wm, batch, max_seq, true,
1000 )?;
1001 for (name, data) in ¶ms {
1002 compiled.set_param(name, data);
1003 }
1004 }
1005 let last_idx = vec![(seq - 1) as f32];
1006 let mut dyn_inputs = inputs.clone();
1007 dyn_inputs.push(("last_token_idx", &last_idx));
1008 Ok(compiled.run(&dyn_inputs))
1009 } else if let Some(cache) = &mut self.embed_prefill_compile_cache {
1010 let key = ((batch as u64) << 32) | (seq as u64);
1011 let opts = &embed_compile_opts;
1012 if !cache.contains(key) {
1013 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
1014 let (graph, params) = build_gemma_graph_sized_last_logits_hidden(
1015 &self.cfg, &mut wm, batch, seq, true,
1016 )?;
1017 {
1018 let compiled = cache.get_or_compile_with_options(key, || graph, opts);
1019 for (name, data) in ¶ms {
1020 compiled.set_param(name, data);
1021 }
1022 }
1023 }
1024 let compiled = cache.get_or_compile_with_options(
1025 key,
1026 || unreachable!("just populated above"),
1027 opts,
1028 );
1029 Ok(compiled.run(&inputs))
1030 } else {
1031 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
1032 let (graph, params) =
1033 build_gemma_graph_sized_last_logits_hidden(&self.cfg, &mut wm, batch, seq, true)?;
1034 let session = Session::new(self.device);
1035 let opts = self.profile_compile_options(false);
1036 let mut compiled = session.compile_with(graph, &opts);
1037 for (name, data) in ¶ms {
1038 compiled.set_param(name, data);
1039 }
1040 Ok(compiled.run(&inputs))
1041 }
1042 }
1043
1044 pub fn generate_from_embeds(
1046 &mut self,
1047 prompt_ids: &[u32],
1048 embeds: &[f32],
1049 n: usize,
1050 opts: SampleOpts,
1051 ) -> Result<Vec<u32>> {
1052 self.generate_from_embeds_with_bias(prompt_ids, embeds, None, n, opts)
1053 }
1054
1055 pub fn generate_from_embeds_with_bias(
1056 &mut self,
1057 prompt_ids: &[u32],
1058 embeds: &[f32],
1059 attn_bias: Option<Vec<f32>>,
1060 n: usize,
1061 opts: SampleOpts,
1062 ) -> Result<Vec<u32>> {
1063 self.prefill_from_embeds(prompt_ids, embeds, attn_bias)?;
1064 self.generate_cached(n, opts)
1065 }
1066
1067 pub fn generate_from_embeds_with(
1069 &mut self,
1070 prompt_ids: &[u32],
1071 embeds: &[f32],
1072 n: usize,
1073 opts: SampleOpts,
1074 on_token: impl FnMut(u32),
1075 ) -> Result<Vec<u32>> {
1076 self.generate_from_embeds_with_bias_and_callback(
1077 prompt_ids, embeds, None, n, opts, on_token,
1078 )
1079 }
1080
1081 pub fn generate_from_embeds_with_bias_and_callback(
1082 &mut self,
1083 prompt_ids: &[u32],
1084 embeds: &[f32],
1085 attn_bias: Option<Vec<f32>>,
1086 n: usize,
1087 opts: SampleOpts,
1088 on_token: impl FnMut(u32),
1089 ) -> Result<Vec<u32>> {
1090 self.prefill_from_embeds(prompt_ids, embeds, attn_bias)?;
1091 self.generate_cached_with(n, opts, on_token)
1092 }
1093
1094 pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
1096 self.generate_cached_with(n, opts, |_| {})
1097 }
1098
1099 pub fn generate_cached_with(
1106 &mut self,
1107 n: usize,
1108 opts: SampleOpts,
1109 mut on_token: impl FnMut(u32),
1110 ) -> Result<Vec<u32>> {
1111 let start = self.tokens.len();
1112 for _ in 0..n {
1113 let tok = self.step_cached(opts)?;
1114 on_token(tok);
1115 }
1116 Ok(self.tokens[start..].to_vec())
1117 }
1118
1119 fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
1124 let seq = self.tokens.len();
1125 let batch = 1usize;
1126 let kv_dims = self.per_layer_kv_dims();
1127
1128 let outputs = if let Some(embeds) = self.pending_prefill_embeds.take() {
1129 let bias = self.pending_prefill_attn_bias.take();
1130 self.run_prefill_hidden_with_cache(batch, seq, &embeds, bias.as_deref())?
1131 } else {
1132 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
1133 self.run_prefill_with_cache(batch, seq, &ids_f32)?
1134 };
1135 let (logits, kv) = kv_from_prefill_outputs_per_layer(
1136 outputs,
1137 batch,
1138 seq,
1139 &kv_dims,
1140 self.cfg.num_hidden_layers,
1141 )?;
1142 self.cache = Some(kv);
1143
1144 let vocab = self.cfg.vocab_size;
1145 let needed = vocab;
1146 if logits.len() < needed {
1147 anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
1148 }
1149 let last_row = &logits[..vocab];
1150 let tok = sample_token(last_row, opts) as u32;
1151 self.tokens.push(tok);
1152 Ok(tok)
1153 }
1154
1155 pub fn tokens(&self) -> &[u32] {
1157 &self.tokens
1158 }
1159
1160 pub fn config(&self) -> &GemmaConfig {
1161 &self.cfg
1162 }
1163
1164 pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
1172 if context.is_empty() {
1173 anyhow::bail!("prefill_get_last_logits: empty context");
1174 }
1175 self.tokens.clear();
1176 self.tokens.extend_from_slice(context);
1177 self.cache = None;
1178 self.reset_gpu_kv_binding();
1179
1180 let seq = context.len();
1181 let batch = 1usize;
1182 let kv_dims = self.per_layer_kv_dims();
1183
1184 let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
1185 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
1186 let (logits, kv) = kv_from_prefill_outputs_per_layer(
1187 outputs,
1188 batch,
1189 seq,
1190 &kv_dims,
1191 self.cfg.num_hidden_layers,
1192 )?;
1193 self.cache = Some(kv);
1194
1195 let vocab = self.cfg.vocab_size;
1196 let needed = vocab;
1197 if logits.len() < needed {
1198 anyhow::bail!("logits short: {} < {}", logits.len(), needed);
1199 }
1200 Ok(logits[..vocab].to_vec())
1201 }
1202
1203 pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
1211 if self.cache.is_none() {
1212 anyhow::bail!(
1213 "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
1214 );
1215 }
1216 self.tokens.push(input);
1217 let seq = self.tokens.len();
1218 let batch = 1usize;
1219 let kv_dims = self.per_layer_kv_dims();
1220 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
1221 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
1222 let (logits, kv) = kv_from_prefill_outputs_per_layer(
1223 outputs,
1224 batch,
1225 seq,
1226 &kv_dims,
1227 self.cfg.num_hidden_layers,
1228 )?;
1229 self.cache = Some(kv);
1230 let vocab = self.cfg.vocab_size;
1231 Ok(logits[..vocab].to_vec())
1232 }
1233
1234 fn per_layer_kv_dims(&self) -> Vec<usize> {
1237 (0..self.cfg.num_hidden_layers)
1238 .map(|i| self.cfg.layer_num_kv_heads(i) * self.cfg.layer_head_dim(i))
1239 .collect()
1240 }
1241}
1242
1243impl Drop for GemmaGenerator {
1244 fn drop(&mut self) {
1245 if self.device == Device::Metal {
1246 self.sync_device();
1247 }
1248 }
1249}
1250
1251fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
1255 rope_slice(inv_freq, pos)
1256}
1257
1258#[cfg(test)]
1259mod tests {
1260 use super::*;
1261 use crate::config::GemmaConfig;
1262 use crate::rope::{build_rope_tables, resolve_inv_freq, rope_slice};
1263 use rlx_flow::CompileProfile;
1264
1265 fn tiny_cfg() -> GemmaConfig {
1266 let mut cfg = GemmaConfig::tiny_test();
1267 cfg.vocab_size = 16;
1268 cfg.head_dim = Some(8);
1269 cfg
1270 }
1271
1272 fn synthetic_tensors(cfg: &GemmaConfig) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
1273 let h = cfg.hidden_size;
1274 let q_dim = cfg.q_proj_dim();
1275 let kv_dim = cfg.kv_proj_dim();
1276 let int_dim = cfg.intermediate_size;
1277 let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
1278 let pat = |n: usize, salt: u32| -> Vec<f32> {
1281 (0..n)
1282 .map(|i| {
1283 let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
1284 (x as f32 / (1u32 << 24) as f32) - 0.5
1285 })
1286 .collect()
1287 };
1288 t.insert(
1289 "model.embed_tokens.weight".into(),
1290 (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
1291 );
1292 for i in 0..cfg.num_hidden_layers {
1293 let lp = format!("model.layers.{i}");
1294 t.insert(
1295 format!("{lp}.input_layernorm.weight"),
1296 (pat(h, 100 + i as u32), vec![h]),
1297 );
1298 t.insert(
1299 format!("{lp}.post_attention_layernorm.weight"),
1300 (pat(h, 200 + i as u32), vec![h]),
1301 );
1302 t.insert(
1303 format!("{lp}.self_attn.q_proj.weight"),
1304 (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
1305 );
1306 t.insert(
1307 format!("{lp}.self_attn.k_proj.weight"),
1308 (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
1309 );
1310 t.insert(
1311 format!("{lp}.self_attn.v_proj.weight"),
1312 (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
1313 );
1314 t.insert(
1315 format!("{lp}.self_attn.o_proj.weight"),
1316 (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
1317 );
1318 t.insert(
1319 format!("{lp}.mlp.gate_proj.weight"),
1320 (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
1321 );
1322 t.insert(
1323 format!("{lp}.mlp.up_proj.weight"),
1324 (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
1325 );
1326 t.insert(
1327 format!("{lp}.mlp.down_proj.weight"),
1328 (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
1329 );
1330 }
1331 t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
1332 t.insert(
1333 "lm_head.weight".into(),
1334 (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
1335 );
1336 t
1337 }
1338
1339 fn synthetic_weights(cfg: &GemmaConfig) -> WeightMap {
1340 WeightMap::from_tensors(synthetic_tensors(cfg))
1341 }
1342
1343 #[test]
1344 fn generator_drains_loader_and_runs_one_step() {
1345 let cfg = tiny_cfg();
1346 let mut wm = synthetic_weights(&cfg);
1347 let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1348 assert_eq!(wm.len(), 0, "loader should be drained");
1349 gn.prefill(&[1, 2, 3]);
1350 let t = gn.step(SampleOpts::greedy()).unwrap();
1351 assert!((t as usize) < cfg.vocab_size);
1352 assert_eq!(gn.tokens().len(), 4);
1353 }
1354
1355 #[test]
1356 fn generate_n_appends_n_tokens() {
1357 let cfg = tiny_cfg();
1358 let mut wm = synthetic_weights(&cfg);
1359 let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1360 gn.prefill(&[5, 6]);
1361 let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
1362 assert_eq!(new_tokens.len(), 3);
1363 assert_eq!(gn.tokens().len(), 5);
1364 for t in &new_tokens {
1365 assert!((*t as usize) < cfg.vocab_size);
1366 }
1367 }
1368
1369 #[test]
1370 fn step_without_prefill_errors() {
1371 let cfg = tiny_cfg();
1372 let mut wm = synthetic_weights(&cfg);
1373 let mut gn = GemmaGenerator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
1374 let r = gn.step(SampleOpts::greedy());
1375 assert!(r.is_err());
1376 }
1377
1378 fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
1379 a.iter()
1380 .zip(b.iter())
1381 .map(|(x, y)| (x - y).abs())
1382 .fold(0f32, f32::max)
1383 }
1384
1385 #[test]
1386 fn prefill_logits_unchanged_with_kv_export() {
1387 let cfg = tiny_cfg();
1388 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1389
1390 let mut wm_a = synthetic_weights(&cfg);
1391 let mut wm_b = synthetic_weights(&cfg);
1392 let (graph_a, params_a) =
1393 build_gemma_graph_sized_last_logits(&cfg, &mut wm_a, 1, 4, false).unwrap();
1394 let (graph_b, params_b) =
1395 build_gemma_graph_sized_last_logits(&cfg, &mut wm_b, 1, 4, true).unwrap();
1396 let session = Session::new(Device::Cpu);
1397 let opts = CompileOptions::new();
1398 let mut ca = session.compile_with(graph_a, &opts);
1399 let mut cb = session.compile_with(graph_b, &opts);
1400 for (n, d) in ¶ms_a {
1401 ca.set_param(n, d);
1402 }
1403 for (n, d) in ¶ms_b {
1404 cb.set_param(n, d);
1405 }
1406 let ids: Vec<f32> = prompt.iter().map(|&i| i as f32).collect();
1407 let la = ca.run(&[("input_ids", &ids)])[0].clone();
1408 let lb = cb.run(&[("input_ids", &ids)])[0].clone();
1409 let d = max_abs_diff(&la, &lb);
1410 assert!(d < 1e-5, "kv export changed prefill logits: max_abs={d:.6}");
1411 }
1412
1413 #[test]
1414 fn incremental_decode_logits_match_full_prefill() {
1415 let cfg = tiny_cfg();
1416 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1417
1418 let mut wm_a = synthetic_weights(&cfg);
1419 let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1420 let tok = gn_a
1421 .prefill_get_last_logits(&prompt)
1422 .map(|l| sample_token(&l, SampleOpts::greedy()) as u32)
1423 .unwrap();
1424
1425 let mut extended = prompt.clone();
1426 extended.push(tok);
1427
1428 let mut wm_b = synthetic_weights(&cfg);
1429 let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
1430 let full = gn_b.prefill_get_last_logits(&extended).unwrap();
1431
1432 let mut wm_c = synthetic_weights(&cfg);
1433 let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1434 gn_c.prefill_get_last_logits(&prompt).unwrap();
1435 let incremental = gn_c.decode_get_logits(tok).unwrap();
1436
1437 let d = max_abs_diff(&full, &incremental);
1438 assert!(
1439 d < 1e-2,
1440 "decode+KV vs full prefill max_abs={d:.6} (tok={tok})"
1441 );
1442 }
1443
1444 fn run_prefill_kv(
1445 cfg: &GemmaConfig,
1446 wm: &mut WeightMap,
1447 seq: usize,
1448 ids: &[u32],
1449 ) -> Vec<Vec<f32>> {
1450 run_prefill_kv_with_options(cfg, wm, seq, ids, &kv_export_compile_options(true))
1451 }
1452
1453 fn kv_export_compile_options(prefill: bool) -> CompileOptions {
1454 let profile = if prefill {
1455 CompileProfile::gemma_prefill()
1456 } else {
1457 CompileProfile::gemma_decode()
1458 };
1459 compile_options_from_profile(&profile, Device::Cpu, KernelDispatchConfig::default())
1460 }
1461
1462 fn run_prefill_kv_with_options(
1463 cfg: &GemmaConfig,
1464 wm: &mut WeightMap,
1465 seq: usize,
1466 ids: &[u32],
1467 opts: &CompileOptions,
1468 ) -> Vec<Vec<f32>> {
1469 let ids_f32: Vec<f32> = ids.iter().map(|&i| i as f32).collect();
1470 let (graph, params) = build_gemma_graph_sized_last_logits(cfg, wm, 1, seq, true).unwrap();
1471 let session = Session::new(Device::Cpu);
1472 let mut compiled = session.compile_with(graph, opts);
1473 for (n, d) in ¶ms {
1474 compiled.set_param(n, d);
1475 }
1476 let outputs = compiled.run(&[("input_ids", &ids_f32)]);
1477 let n_layers = cfg.num_hidden_layers;
1478 assert_eq!(outputs.len(), 1 + 2 * n_layers);
1479 let mut kv = Vec::with_capacity(2 * n_layers);
1480 let mut iter = outputs.into_iter().skip(1);
1481 for _ in 0..n_layers {
1482 kv.push(iter.next().unwrap());
1483 kv.push(iter.next().unwrap());
1484 }
1485 kv
1486 }
1487
1488 #[test]
1489 fn decode_graph_bakes_rope_slice_length() {
1490 let cfg = tiny_cfg();
1491 let past_seq = 4usize;
1492 let half = cfg.head_dim() / 2;
1493 let mut wm = synthetic_weights(&cfg);
1494 let (_, params) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
1495 let cos = params
1496 .get("decode.rope.cos")
1497 .expect("decode.rope.cos param");
1498 let sin = params
1499 .get("decode.rope.sin")
1500 .expect("decode.rope.sin param");
1501 assert_eq!(
1502 cos.len(),
1503 half,
1504 "cos param should be one row (half={half}), got {}",
1505 cos.len()
1506 );
1507 assert_eq!(sin.len(), half);
1508 for key in params.keys() {
1509 assert!(
1510 !key.starts_with("rope."),
1511 "decode graph must not include prefill rope table param {key}"
1512 );
1513 }
1514 let inv = resolve_inv_freq(&cfg, None);
1515 let (c_ref, s_ref) = rope_slice(&inv, past_seq);
1516 let d = max_abs_diff(cos, &c_ref) + max_abs_diff(sin, &s_ref);
1517 assert!(d < 1e-6, "baked rope mismatch: {d}");
1518 }
1519
1520 #[test]
1521 fn decode_graph_all_rope_use_baked_cos() {
1522 use rlx_ir::Op;
1523 let cfg = tiny_cfg();
1524 let mut wm = synthetic_weights(&cfg);
1525 let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, 4).unwrap();
1526 for node in graph.nodes() {
1527 if let Op::Rope { .. } = &node.op {
1528 let cos_id = node.inputs[1];
1529 let cos_node = &graph.node(cos_id);
1530 match &cos_node.op {
1531 Op::Param { name } => assert_eq!(
1532 name, "decode.rope.cos",
1533 "decode RoPE must use baked decode.rope.cos, got {name}"
1534 ),
1535 other => panic!("decode RoPE cos input is {other:?}, expected Param"),
1536 }
1537 }
1538 }
1539 }
1540
1541 #[test]
1542 fn decode_graph_rope_cos_is_single_row() {
1543 use rlx_ir::Op;
1544 let cfg = tiny_cfg();
1545 let past_seq = 4usize;
1546 let half = cfg.head_dim() / 2;
1547 let mut wm = synthetic_weights(&cfg);
1548 let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
1549 let mut rope_cos_lens = Vec::new();
1550 for node in graph.nodes() {
1551 if let Op::Rope { .. } = &node.op {
1552 let cos_shape = &graph.node(node.inputs[1]).shape;
1553 let rows = if cos_shape.rank() >= 2 {
1554 cos_shape.dim(0).unwrap_static()
1555 } else {
1556 1
1557 };
1558 rope_cos_lens.push(rows);
1559 }
1560 }
1561 assert!(!rope_cos_lens.is_empty(), "decode graph has no RoPE nodes");
1562 for rows in &rope_cos_lens {
1563 assert_eq!(
1564 *rows, 1,
1565 "decode RoPE cos must be single-row [1, half], got {rows} rows"
1566 );
1567 }
1568 assert_eq!(half, cfg.head_dim() / 2);
1569 }
1570
1571 #[test]
1572 fn prefill_kv_matches_extended_prefix() {
1573 let cfg = tiny_cfg();
1574 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1575 let tok = 6u32;
1576 let mut extended = prompt.clone();
1577 extended.push(tok);
1578
1579 let mut wm_prompt = synthetic_weights(&cfg);
1580 let prompt_kv = run_prefill_kv(&cfg, &mut wm_prompt, 4, &prompt);
1581 let mut wm_ext = synthetic_weights(&cfg);
1582 let ext_kv = run_prefill_kv(&cfg, &mut wm_ext, 5, &extended);
1583
1584 let kv_dim = cfg.kv_proj_dim();
1585 for layer in 0..cfg.num_hidden_layers {
1586 let k_prompt = &prompt_kv[2 * layer];
1587 let k_ext = &ext_kv[2 * layer];
1588 let prefix_len = 4 * kv_dim;
1589 assert_eq!(k_prompt.len(), prefix_len);
1590 assert_eq!(k_ext.len(), 5 * kv_dim);
1591 let d = max_abs_diff(k_prompt, &k_ext[..prefix_len]);
1592 assert!(
1593 d < 1e-4,
1594 "layer {layer} prefill K prefix vs extended K max_abs={d:.6}"
1595 );
1596 }
1597 }
1598
1599 #[test]
1600 fn decode_rope_slice_matches_prefill_table_row() {
1601 let cfg = tiny_cfg();
1602 let inv = resolve_inv_freq(&cfg, None);
1603 let (cos_tab, sin_tab) = build_rope_tables(&inv, cfg.max_position_embeddings);
1604 let half = inv.len();
1605 for pos in [3usize, 4, 5] {
1606 let (c, s) = rope_slice(&inv, pos);
1607 let off = pos * half;
1608 let d = max_abs_diff(&c, &cos_tab[off..off + half])
1609 + max_abs_diff(&s, &sin_tab[off..off + half]);
1610 assert!(d < 1e-6, "rope_slice mismatch at pos {pos}: {d}");
1611 }
1612 }
1613
1614 #[test]
1615 fn prefill_kv_export_correct_with_fusion() {
1616 let cfg = tiny_cfg();
1617 let tok = 6u32;
1618 let ids = [1u32, 2, 3, 5, tok];
1619 let opts = kv_export_compile_options(true);
1620 let mut wm_one = synthetic_weights(&cfg);
1621 let one_kv = run_prefill_kv_with_options(&cfg, &mut wm_one, 1, &[tok], &opts);
1622 let mut wm_ext = synthetic_weights(&cfg);
1623 let ext_kv = run_prefill_kv_with_options(&cfg, &mut wm_ext, 5, &ids, &opts);
1624 let kv_dim = cfg.kv_proj_dim();
1625 let d = max_abs_diff(&ext_kv[1][4 * kv_dim..], &one_kv[1][..kv_dim]);
1626 assert!(d < 1e-4, "KV export mismatch with profile fusion: {d:.6}");
1627
1628 let mut wm_default = synthetic_weights(&cfg);
1629 let default_kv =
1630 run_prefill_kv_with_options(&cfg, &mut wm_default, 5, &ids, &CompileOptions::new());
1631 let d_default = max_abs_diff(&default_kv[1][4 * kv_dim..], &one_kv[1][..kv_dim]);
1632 assert!(
1633 d_default < 1e-4,
1634 "KV export mismatch with default fusion (got {d_default:.6})"
1635 );
1636 }
1637
1638 #[test]
1639 fn decode_oneshot_kv_suffix_matches_extended() {
1640 let cfg = tiny_cfg();
1641 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1642 let tok = 6u32;
1643 let mut extended = prompt.clone();
1644 extended.push(tok);
1645
1646 let opts = kv_export_compile_options(false);
1647 let mut wm_ext = synthetic_weights(&cfg);
1648 let ext_kv = run_prefill_kv_with_options(&cfg, &mut wm_ext, 5, &extended, &opts);
1649
1650 let mut wm = synthetic_weights(&cfg);
1651 let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1652 gn.prefill_get_last_logits(&prompt).unwrap();
1653
1654 let mut wm_d = synthetic_weights(&cfg);
1655 let (graph, params) = build_gemma_decode_graph_sized(&cfg, &mut wm_d, 1, 4).unwrap();
1656 let session = Session::new(Device::Cpu);
1657 let mut compiled = session.compile_with(graph, &opts);
1658 for (n, d) in ¶ms {
1659 compiled.set_param(n, d);
1660 }
1661 let cache = gn.cache.as_ref().unwrap();
1662 let key_strs: Vec<String> = (0..cfg.num_hidden_layers)
1663 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
1664 .collect();
1665 let input_ids = [tok as f32];
1666 let mut inputs: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
1667 for i in 0..cfg.num_hidden_layers {
1668 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
1669 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
1670 }
1671 let outputs = compiled.run(&inputs);
1672 let kv_dim = cfg.kv_proj_dim();
1673 let k_dec = &outputs[1][4 * kv_dim..];
1674
1675 let d = max_abs_diff(k_dec, &ext_kv[0][4 * kv_dim..]);
1676 assert!(
1677 d < 1e-3,
1678 "decode oneshot layer0 K suffix vs extended max_abs={d:.6}"
1679 );
1680 }
1681
1682 #[test]
1683 fn decode_logits_match_extended_prefill_after_one_token() {
1684 let cfg = tiny_cfg();
1685 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1686 let tok = 6u32;
1687
1688 let mut extended = prompt.clone();
1689 extended.push(tok);
1690
1691 let mut wm_a = synthetic_weights(&cfg);
1692 let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1693 let full = gn_a.prefill_get_last_logits(&extended).unwrap();
1694
1695 let mut wm_b = synthetic_weights(&cfg);
1696 let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
1697 gn_b.prefill_get_last_logits(&prompt).unwrap();
1698 let inc = gn_b.decode_get_logits(tok).unwrap();
1699
1700 let d = max_abs_diff(&full, &inc);
1701 assert!(d < 1e-2, "decode vs extended prefill max_abs={d:.6}");
1702 }
1703
1704 #[test]
1705 fn cached_second_token_matches_naive() {
1706 let cfg = tiny_cfg();
1707 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1708
1709 let mut wm_n = synthetic_weights(&cfg);
1710 let mut gn_n = GemmaGenerator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
1711 gn_n.prefill(&prompt);
1712 let n0 = gn_n.step(SampleOpts::greedy()).unwrap();
1713 let n1 = gn_n.step(SampleOpts::greedy()).unwrap();
1714
1715 let mut wm_c = synthetic_weights(&cfg);
1716 let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1717 gn_c.prefill(&prompt);
1718 let c = gn_c.generate_cached(2, SampleOpts::greedy()).unwrap();
1719
1720 assert_eq!(c[0], n0, "first generated token");
1721 assert_eq!(c[1], n1, "second generated token (decode step)");
1722 }
1723
1724 #[test]
1725 fn cached_matches_naive_on_greedy() {
1726 let cfg = tiny_cfg();
1733 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1734 let steps = 4;
1735
1736 let mut wm_n = synthetic_weights(&cfg);
1737 let mut gn_naive =
1738 GemmaGenerator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
1739 gn_naive.prefill(&prompt);
1740 let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();
1741
1742 let mut wm_c = synthetic_weights(&cfg);
1743 let mut gn_cached =
1744 GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1745 gn_cached.prefill(&prompt);
1746 let cached_tokens = gn_cached
1747 .generate_cached(steps, SampleOpts::greedy())
1748 .unwrap();
1749
1750 assert_eq!(
1751 cached_tokens, naive_tokens,
1752 "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
1753 );
1754 }
1755
1756 #[test]
1757 fn cached_step_advances_cache_invariant() {
1758 let cfg = tiny_cfg();
1759 let mut wm = synthetic_weights(&cfg);
1760 let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1761 gn.prefill(&[1, 2, 3]);
1762 let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
1763 assert_eq!(gn.tokens().len(), 4);
1765 assert_eq!(gn.cache.as_ref().unwrap().past_len, 3);
1766 let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
1767 assert_eq!(gn.tokens().len(), 5);
1769 assert_eq!(gn.cache.as_ref().unwrap().past_len, 4);
1770 }
1771
1772 #[test]
1773 fn bucketed_decode_matches_oneshot() {
1774 let cfg = tiny_cfg();
1780 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1781 let steps = 6;
1782
1783 let mut wm_one = synthetic_weights(&cfg);
1784 let mut gn_one =
1785 GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1786 gn_one.prefill(&prompt);
1787 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1788
1789 let mut wm_buc = synthetic_weights(&cfg);
1790 let mut gn_buc = GemmaGenerator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
1791 .unwrap()
1792 .with_decode_cache(32);
1793 gn_buc.prefill(&prompt);
1794 let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();
1795
1796 assert_eq!(
1797 bucketed_tokens, oneshot_tokens,
1798 "bucketed-cache decode diverged from one-shot decode — \
1799 mask, padding, or output-slice bug"
1800 );
1801 }
1802
1803 #[test]
1804 fn prefill_compile_cache_does_not_change_output() {
1805 let cfg = tiny_cfg();
1806 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1807 let mut wm_a = synthetic_weights(&cfg);
1808 let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1809 gn_a.prefill(&prompt);
1810 let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();
1811
1812 let mut wm_b = synthetic_weights(&cfg);
1813 let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
1814 .unwrap()
1815 .with_prefill_cache(4);
1816 gn_b.prefill(&prompt);
1817 let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();
1818
1819 assert_eq!(a, b, "enabling prefill_cache must not change output");
1820 }
1821
1822 #[test]
1823 fn dynamic_decode_matches_oneshot() {
1824 let cfg = tiny_cfg();
1825 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1826 let steps = 6;
1827
1828 let mut wm_one = synthetic_weights(&cfg);
1829 let mut gn_one =
1830 GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1831 gn_one.prefill(&prompt);
1832 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1833
1834 let mut wm_dyn = synthetic_weights(&cfg);
1835 let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1836 .unwrap()
1837 .with_dynamic_decode_cache(8);
1838 gn_dyn.prefill(&prompt);
1839 let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1840
1841 assert_eq!(
1842 dynamic_tokens, oneshot_tokens,
1843 "dynamic past_seq decode diverged from one-shot decode"
1844 );
1845 }
1846
1847 #[test]
1848 fn dynamic_prefill_matches_oneshot() {
1849 let cfg = tiny_cfg();
1850 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1851 let steps = 4;
1852
1853 let mut wm_one = synthetic_weights(&cfg);
1854 let mut gn_one =
1855 GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1856 gn_one.prefill(&prompt);
1857 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1858
1859 let mut wm_dyn = synthetic_weights(&cfg);
1860 let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1861 .unwrap()
1862 .with_dynamic_prefill_cache(8);
1863 gn_dyn.prefill(&prompt);
1864 let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1865
1866 assert_eq!(
1867 dynamic_tokens, oneshot_tokens,
1868 "dynamic seq prefill diverged from one-shot prefill"
1869 );
1870 }
1871
1872 #[test]
1873 fn dynamic_prefill_and_decode_matches_oneshot() {
1874 let cfg = tiny_cfg();
1875 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1876 let steps = 6;
1877
1878 let mut wm_one = synthetic_weights(&cfg);
1879 let mut gn_one =
1880 GemmaGenerator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1881 gn_one.prefill(&prompt);
1882 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1883
1884 let mut wm_dyn = synthetic_weights(&cfg);
1885 let mut gn_dyn = GemmaGenerator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1886 .unwrap()
1887 .with_dynamic_prefill_cache(8)
1888 .with_dynamic_decode_cache(8);
1889 gn_dyn.prefill(&prompt);
1890 let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1891
1892 assert_eq!(
1893 dynamic_tokens, oneshot_tokens,
1894 "dynamic prefill+decode diverged from one-shot path"
1895 );
1896 }
1897
1898 #[test]
1899 fn greedy_is_deterministic_across_runs() {
1900 let cfg = tiny_cfg();
1901 let weights = synthetic_weights(&cfg);
1902 let mk = || {
1903 let mut wm = WeightMap::from_tensors(weights_as_hashmap(&weights));
1904 GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
1905 };
1906 let mut a = mk();
1907 let mut b = mk();
1908 a.prefill(&[1, 2, 3]);
1909 b.prefill(&[1, 2, 3]);
1910 let ta = a.generate(4, SampleOpts::greedy()).unwrap();
1911 let tb = b.generate(4, SampleOpts::greedy()).unwrap();
1912 assert_eq!(ta, tb);
1913 }
1914
1915 fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
1916 let _ = wm; let cfg = tiny_cfg();
1922 let mut new = synthetic_weights(&cfg);
1923 let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
1924 let mut out = HashMap::new();
1925 for k in keys {
1926 out.insert(k.clone(), new.take(&k).unwrap());
1927 }
1928 out
1929 }
1930}