1use crate::builder::{
34 build_gemma_decode_graph_sized, build_gemma_decode_hir_dynamic_ext,
35 build_gemma_decode_hir_sized_ext, build_gemma_graph_sized_last_logits,
36 build_gemma_prefill_hir_dynamic_ext,
37};
38use crate::config::GemmaConfig;
39use crate::rope::{resolve_inv_freq, rope_slice};
40use anyhow::{Context, Result};
41use rlx_core::autoregressive::{
42 KvCacheState, kv_from_prefill_outputs, run_bucketed_kv_decode_hir, split_decode_logits_kv,
43};
44use rlx_core::flow_bridge::compile_options_from_profile;
45use rlx_core::weight_loader::WeightLoader;
46use rlx_core::weight_map::WeightMap;
47use rlx_flow::CompileProfile;
48use rlx_ir::DimBinding;
49use rlx_ir::logical_kernel::KernelDispatchConfig;
50use rlx_qwen3::sampling::{SampleOpts, sample_token};
51use rlx_runtime::attn_mask::bucket_decode_mask;
52use rlx_runtime::compile_cache::{
53 BucketedCompileCache, CacheRunInput, CompileCache, DynamicDimCompileCache,
54};
55use rlx_runtime::{CompileOptions, Device, Session};
56use std::collections::HashMap;
57use std::path::Path;
58
59pub fn decode_profile_for_device(device: Device) -> CompileProfile {
61 metal_safe_decode_profile(device, CompileProfile::gemma_decode())
62}
63
64fn metal_safe_decode_profile(device: Device, mut profile: CompileProfile) -> CompileProfile {
66 if device == Device::Metal {
67 profile.fusion.skip = true;
68 profile.backend.metal.skip_fusion = true;
69 profile.backend.metal.unfuse_regions = true;
70 }
71 profile
72}
73
/// Gemma text generator: owns the model weights, the token history, and the
/// optional compile / KV caches used by the two generation paths.
///
/// `step` / `generate` rebuild, recompile and run the whole sequence for every
/// token. `step_cached` / `generate_cached` seed a KV cache with one prefill
/// and then run single-token decode graphs.
pub struct GemmaGenerator {
    /// Model hyper-parameters.
    cfg: GemmaConfig,
    /// Every tensor drained from the loader, keyed by canonical name and
    /// stored as `(data, shape)`. Cloned into a fresh `WeightMap` whenever a
    /// graph is built.
    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
    /// Prompt tokens followed by every token sampled so far.
    tokens: Vec<u32>,
    /// Device used for sessions and compile caches.
    device: Device,
    /// KV cache covering the first `past_len` entries of `tokens`; `None`
    /// until a prefill seeds it (and reset to `None` by `prefill`).
    cache: Option<KvCacheState>,
    /// Static-shape prefill compile cache keyed by `(batch, seq)`. The
    /// builder methods keep it mutually exclusive with `prefill_dynamic_cache`.
    prefill_compile_cache: Option<CompileCache>,
    /// Dynamic-dim prefill compile cache keyed by sequence length.
    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
    /// Bucketed decode compile cache (power-of-two ladder over past length).
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Dynamic-dim decode compile cache keyed by past length. The builder
    /// methods keep it mutually exclusive with `decode_compile_cache`.
    decode_dynamic_cache: Option<DynamicDimCompileCache>,
    /// RoPE inverse frequencies used to build per-position cos/sin rows.
    inv_freq: Vec<f64>,
    /// Compile profile for prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile for decode graphs; always passed through
    /// `metal_safe_decode_profile`, so fusion is disabled on Metal.
    decode_profile: CompileProfile,
}
112
113impl GemmaGenerator {
114 pub fn from_loader(
117 cfg: GemmaConfig,
118 loader: &mut dyn WeightLoader,
119 device: Device,
120 ) -> Result<Self> {
121 let keys = loader.remaining_keys();
122 let arch_hint: Option<String> = loader.arch_hint().map(|s| s.to_string());
128 let mut weights_cache = HashMap::with_capacity(keys.len());
129 for k in keys {
130 let v = loader
131 .take(&k)
132 .with_context(|| format!("draining weight {k}"))?;
133 let canonical = match arch_hint.as_deref() {
139 Some(a) => rlx_core::weight_loader::gguf_to_hf_name_for_arch(&k, a)
140 .unwrap_or_else(|| k.clone()),
141 None => rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone()),
142 };
143 weights_cache.insert(canonical, v);
144 }
145 let rope_factors = weights_cache
146 .get("rope_freqs.weight")
147 .map(|(d, _)| d.as_slice());
148 let inv_freq = resolve_inv_freq(&cfg, rope_factors);
149 Ok(Self {
150 cfg,
151 weights_cache,
152 tokens: Vec::new(),
153 device,
154 cache: None,
155 prefill_compile_cache: None,
156 prefill_dynamic_cache: None,
157 decode_compile_cache: None,
158 decode_dynamic_cache: None,
159 inv_freq,
160 prefill_profile: CompileProfile::gemma_prefill(),
161 decode_profile: metal_safe_decode_profile(device, CompileProfile::gemma_decode()),
162 })
163 }
164
165 pub fn from_loader_at(
168 cfg: GemmaConfig,
169 loader: &mut dyn WeightLoader,
170 device: Device,
171 weights_path: &Path,
172 ) -> Result<Self> {
173 let mut g = Self::from_loader(cfg, loader, device)?;
174 g.prefill_profile = crate::gemma_profile_near_weights(weights_path, false);
175 g.decode_profile = metal_safe_decode_profile(
176 device,
177 crate::gemma_profile_near_weights(weights_path, true),
178 );
179 Ok(g)
180 }
181
182 pub fn with_compile_profiles(
184 mut self,
185 prefill: CompileProfile,
186 decode: CompileProfile,
187 ) -> Self {
188 self.prefill_profile = prefill;
189 self.decode_profile = metal_safe_decode_profile(self.device, decode);
190 self
191 }
192
193 pub fn prefill_profile(&self) -> &CompileProfile {
194 &self.prefill_profile
195 }
196
197 pub fn decode_profile(&self) -> &CompileProfile {
198 &self.decode_profile
199 }
200
201 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
202 let profile = if decode {
203 &self.decode_profile
204 } else {
205 &self.prefill_profile
206 };
207 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
208 }
209
210 fn compile_graph_profiled(
211 &self,
212 session: &Session,
213 graph: rlx_ir::Graph,
214 ) -> Result<rlx_runtime::CompiledGraph> {
215 let opts = self.profile_compile_options(false);
216 Ok(session.compile_with(graph, &opts))
217 }
218
219 fn compile_graph_profiled_decode(
220 &self,
221 session: &Session,
222 graph: rlx_ir::Graph,
223 ) -> Result<rlx_runtime::CompiledGraph> {
224 Ok(session.compile_with(graph, &self.profile_compile_options(true)))
225 }
226
227 pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
232 self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
233 self.prefill_dynamic_cache = None;
234 self
235 }
236
237 pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
239 self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
240 self.prefill_compile_cache = None;
241 self
242 }
243
244 pub fn with_decode_cache(mut self, max_past: usize) -> Self {
254 let cache = BucketedCompileCache::power_of_two_ladder(
255 self.device,
256 1,
257 max_past.max(1) as u64,
258 );
259 self.decode_compile_cache = Some(cache);
260 self.decode_dynamic_cache = None;
261 self
262 }
263
264 pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
266 self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
267 self.decode_compile_cache = None;
268 self
269 }
270
271 pub fn from_path(cfg: GemmaConfig, path: &str, device: Device) -> Result<Self> {
274 let mut loader = rlx_core::weight_loader::load_from_path(path)?;
275 Self::from_loader(cfg, loader.as_mut(), device)
276 }
277
278 pub fn from_path_with_mtp(
286 cfg: GemmaConfig,
287 path: &str,
288 device: Device,
289 include_mtp: bool,
290 ) -> Result<Self> {
291 if path.ends_with(".gguf") {
295 let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
296 gguf.include_mtp(include_mtp);
297 Self::from_loader(cfg, &mut gguf, device)
298 } else {
299 Self::from_path(cfg, path, device)
300 }
301 }
302
303 pub fn prefill(&mut self, prompt_ids: &[u32]) {
307 self.tokens.clear();
308 self.tokens.extend_from_slice(prompt_ids);
309 self.cache = None;
310 }
311
312 pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
316 if self.tokens.is_empty() {
317 anyhow::bail!("step() called with empty token history; call prefill() first");
318 }
319 let seq = self.tokens.len();
320 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
321 let (graph, params) = build_gemma_graph_sized_last_logits(
322 &self.cfg, &mut wm, 1, seq, false,
323 )?;
324 let session = Session::new(self.device);
325 let mut compiled = self.compile_graph_profiled(&session, graph)?;
326 for (name, data) in ¶ms {
327 compiled.set_param(name, data);
328 }
329 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
330 let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
331 let logits = outputs
332 .into_iter()
333 .next()
334 .context("compiled.run returned no outputs")?;
335
336 let vocab = self.cfg.vocab_size;
337 let expected = vocab;
338 if logits.len() < expected {
339 anyhow::bail!(
340 "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
341 logits.len(),
342 expected
343 );
344 }
345 let last_row = &logits[..vocab];
347 let tok = sample_token(last_row, opts) as u32;
348 self.tokens.push(tok);
349 Ok(tok)
350 }
351
352 pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
355 let start = self.tokens.len();
356 for _ in 0..n {
357 self.step(opts)?;
358 }
359 Ok(self.tokens[start..].to_vec())
360 }
361
362 pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
372 if self.tokens.is_empty() {
373 anyhow::bail!("step_cached() called with empty token history; call prefill() first");
374 }
375 if self.cache.is_none() {
376 let tok = self.seed_cache_from_prompt(opts)?;
380 return Ok(tok);
381 }
382 let cache = self.cache.as_ref().unwrap();
383 let past_seq = cache.past_len;
384 if self.tokens.len() <= past_seq {
385 anyhow::bail!(
386 "cache invariant violated: tokens.len() {} <= past_len {}",
387 self.tokens.len(),
388 past_seq
389 );
390 }
391 let input_tok = self.tokens[past_seq];
392
393 let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
394 self.decode_step_dynamic(past_seq, input_tok)?
395 } else if self.decode_compile_cache.is_some()
396 && self
397 .decode_compile_cache
398 .as_ref()
399 .unwrap()
400 .bucket_for(past_seq as u64)
401 .is_some()
402 {
403 self.decode_step_bucketed(past_seq, input_tok)?
404 } else {
405 self.decode_step_oneshot(past_seq, input_tok)?
406 };
407
408 let cache_mut = self.cache.as_mut().unwrap();
409 cache_mut.past_len = past_seq + 1;
410 cache_mut.layers_k = new_k;
411 cache_mut.layers_v = new_v;
412
413 let vocab = self.cfg.vocab_size;
414 if logits.len() != vocab {
415 anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
416 }
417 let tok = sample_token(&logits, opts) as u32;
418 self.tokens.push(tok);
419 Ok(tok)
420 }
421
422 #[allow(clippy::type_complexity)]
425 fn decode_step_oneshot(
426 &mut self,
427 past_seq: usize,
428 input_tok: u32,
429 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
430 let cache = self.cache.as_ref().unwrap();
431
432 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
433 let (graph, params) =
434 build_gemma_decode_graph_sized(&self.cfg, &mut wm, 1, past_seq)?;
435 let session = Session::new(self.device);
436 let mut compiled = self.compile_graph_profiled_decode(&session, graph)?;
437 for (name, data) in ¶ms {
438 compiled.set_param(name, data);
439 }
440
441 let input_ids_f32 = [input_tok as f32];
442 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
443 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
444 .collect();
445 let mut inputs: Vec<(&str, &[f32])> =
446 Vec::with_capacity(1 + 2 * self.cfg.num_hidden_layers);
447 inputs.push(("input_ids", input_ids_f32.as_slice()));
448 for i in 0..self.cfg.num_hidden_layers {
449 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
450 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
451 }
452
453 let outputs = compiled.run(&inputs);
454 split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
455 }
456
457 #[allow(clippy::type_complexity)]
458 fn decode_step_dynamic(
459 &mut self,
460 past_seq: usize,
461 input_tok: u32,
462 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
463 let cache = self.cache.as_ref().unwrap();
464 let binding = DimBinding::batch_past_seq(1, past_seq);
465 let opts = self
466 .profile_compile_options(true)
467 .dim_binding(binding.clone());
468 let cache_dyn = self
469 .decode_dynamic_cache
470 .as_mut()
471 .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
472 let needs_upload = !cache_dyn.contains(past_seq as u64);
473 let cfg = self.cfg.clone();
474 let weights_cache = self.weights_cache.clone();
475 let max_past = self.cfg.max_position_embeddings;
476 let compiled = cache_dyn.get_or_specialize(
477 past_seq as u64,
478 &binding,
479 || {
480 let mut wm = WeightMap::from_tensors(weights_cache);
481 build_gemma_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
482 .expect("dynamic decode HIR")
483 .0
484 },
485 &opts,
486 )?;
487 if needs_upload {
488 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
489 let (_, params) = build_gemma_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
490 for (name, data) in ¶ms {
491 compiled.set_param(name, data);
492 }
493 }
494
495 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
496 let input_ids_f32 = [input_tok as f32];
497 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
498 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
499 .collect();
500 let mut inputs: Vec<(&str, &[f32])> =
501 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
502 inputs.push(("input_ids", input_ids_f32.as_slice()));
503 inputs.push(("rope_cos", cos.as_slice()));
504 inputs.push(("rope_sin", sin.as_slice()));
505 for i in 0..self.cfg.num_hidden_layers {
506 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
507 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
508 }
509 let outputs = compiled.run(&inputs);
510 split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
511 }
512
513 #[allow(clippy::type_complexity)]
514 fn decode_step_bucketed(
515 &mut self,
516 past_seq: usize,
517 input_tok: u32,
518 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
519 let kv = self.cache.as_ref().unwrap().clone();
520 let kv_dim = self.cfg.kv_proj_dim();
521 let n_layers = self.cfg.num_hidden_layers;
522 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
523 let input_ids_f32 = [input_tok as f32];
524 let decode_opts = self.profile_compile_options(true);
525 let upper = self
526 .decode_compile_cache
527 .as_ref()
528 .and_then(|cache_dec| {
529 cache_dec.bucket_for(past_seq as u64).map(|idx| {
530 cache_dec
531 .buckets()
532 .nth(idx)
533 .map(|r| (r.end - 1) as usize)
534 .unwrap_or(past_seq)
535 })
536 })
537 .unwrap_or(past_seq);
538 let mask = bucket_decode_mask(past_seq, upper);
539 let fixed = [
540 CacheRunInput {
541 name: "input_ids",
542 data: &input_ids_f32,
543 row_inner: None,
544 },
545 CacheRunInput {
546 name: "rope_cos",
547 data: &cos,
548 row_inner: None,
549 },
550 CacheRunInput {
551 name: "rope_sin",
552 data: &sin,
553 row_inner: None,
554 },
555 CacheRunInput {
556 name: "mask",
557 data: &mask,
558 row_inner: None,
559 },
560 ];
561 let cfg = self.cfg.clone();
562 let weights = self.weights_cache.clone();
563 let cache_dec = self.decode_compile_cache.as_mut().unwrap();
564 run_bucketed_kv_decode_hir(
565 cache_dec,
566 past_seq,
567 &kv,
568 kv_dim,
569 n_layers,
570 &fixed,
571 |upper| {
572 let mut wm = WeightMap::from_tensors(weights.clone());
573 build_gemma_decode_hir_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
574 .expect("gemma bucketed decode HIR")
575 },
576 &decode_opts,
577 )
578 }
579
580 #[allow(clippy::unnecessary_unwrap)]
584 fn run_prefill_with_cache(
585 &mut self,
586 batch: usize,
587 seq: usize,
588 ids_f32: &[f32],
589 ) -> Result<Vec<Vec<f32>>> {
590 if self.prefill_dynamic_cache.is_some() {
591 let binding = DimBinding::batch_seq(batch, seq);
592 let opts = compile_options_from_profile(
593 &self.prefill_profile,
594 self.device,
595 KernelDispatchConfig::default(),
596 )
597 .dim_binding(binding.clone());
598 let cache = self.prefill_dynamic_cache.as_mut().expect("checked");
599 let needs_upload = !cache.contains(seq as u64);
600 let cfg = self.cfg.clone();
601 let weights_cache = self.weights_cache.clone();
602 let max_seq = self.cfg.max_position_embeddings;
603 let compiled = cache.get_or_specialize(
604 seq as u64,
605 &binding,
606 || {
607 let mut wm = WeightMap::from_tensors(weights_cache);
608 build_gemma_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
609 .expect("dynamic prefill HIR")
610 .0
611 },
612 &opts,
613 )?;
614 if needs_upload {
615 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
616 let (_, params) =
617 build_gemma_prefill_hir_dynamic_ext(&self.cfg, &mut wm, batch, max_seq, true)?;
618 for (name, data) in ¶ms {
619 compiled.set_param(name, data);
620 }
621 }
622 let last_idx = vec![(seq - 1) as f32];
623 Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
624 } else if self.prefill_compile_cache.is_some() {
625 let key = ((batch as u64) << 32) | (seq as u64);
626 let opts = self.profile_compile_options(false);
627 if !self.prefill_compile_cache.as_ref().unwrap().contains(key) {
628 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
629 let (graph, params) = build_gemma_graph_sized_last_logits(
630 &self.cfg, &mut wm, batch, seq, true,
631 )?;
632 {
633 let compiled = self
634 .prefill_compile_cache
635 .as_mut()
636 .unwrap()
637 .get_or_compile_with_options(key, || graph, &opts);
638 for (name, data) in ¶ms {
639 compiled.set_param(name, data);
640 }
641 }
642 }
643 let compiled = self
644 .prefill_compile_cache
645 .as_mut()
646 .unwrap()
647 .get_or_compile_with_options(key, || unreachable!("just populated above"), &opts);
648 Ok(compiled.run(&[("input_ids", ids_f32)]))
649 } else {
650 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
651 let (graph, params) = build_gemma_graph_sized_last_logits(
652 &self.cfg, &mut wm, batch, seq, true,
653 )?;
654 let session = Session::new(self.device);
655 let opts = self.profile_compile_options(false);
656 let mut compiled = session.compile_with(graph, &opts);
657 for (name, data) in ¶ms {
658 compiled.set_param(name, data);
659 }
660 Ok(compiled.run(&[("input_ids", ids_f32)]))
661 }
662 }
663
664 pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
666 self.generate_cached_with(n, opts, |_| {})
667 }
668
669 pub fn generate_cached_with(
676 &mut self,
677 n: usize,
678 opts: SampleOpts,
679 mut on_token: impl FnMut(u32),
680 ) -> Result<Vec<u32>> {
681 let start = self.tokens.len();
682 for _ in 0..n {
683 let tok = self.step_cached(opts)?;
684 on_token(tok);
685 }
686 Ok(self.tokens[start..].to_vec())
687 }
688
689 fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
694 let seq = self.tokens.len();
695 let batch = 1usize;
696 let kv_dim = self.cfg.kv_proj_dim();
697
698 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
699 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
700 let (logits, kv) =
701 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
702 self.cache = Some(kv);
703
704 let vocab = self.cfg.vocab_size;
705 let needed = vocab;
706 if logits.len() < needed {
707 anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
708 }
709 let last_row = &logits[..vocab];
710 let tok = sample_token(last_row, opts) as u32;
711 self.tokens.push(tok);
712 Ok(tok)
713 }
714
715 pub fn tokens(&self) -> &[u32] {
717 &self.tokens
718 }
719
720 pub fn config(&self) -> &GemmaConfig {
721 &self.cfg
722 }
723
724 pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
732 if context.is_empty() {
733 anyhow::bail!("prefill_get_last_logits: empty context");
734 }
735 self.tokens.clear();
736 self.tokens.extend_from_slice(context);
737 self.cache = None;
738
739 let seq = context.len();
740 let batch = 1usize;
741 let kv_dim = self.cfg.kv_proj_dim();
742
743 let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
744 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
745 let (logits, kv) =
746 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
747 self.cache = Some(kv);
748
749 let vocab = self.cfg.vocab_size;
750 let needed = vocab;
751 if logits.len() < needed {
752 anyhow::bail!("logits short: {} < {}", logits.len(), needed);
753 }
754 Ok(logits[..vocab].to_vec())
755 }
756
757 pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
765 if self.cache.is_none() {
766 anyhow::bail!(
767 "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
768 );
769 }
770 self.tokens.push(input);
771 let seq = self.tokens.len();
772 let batch = 1usize;
773 let kv_dim = self.cfg.kv_proj_dim();
774 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
775 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
776 let (logits, kv) =
777 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
778 self.cache = Some(kv);
779 let vocab = self.cfg.vocab_size;
780 Ok(logits[..vocab].to_vec())
781 }
782}
783
784fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
788 rope_slice(inv_freq, pos)
789}
790
791#[cfg(test)]
792mod tests {
793 use super::*;
794 use crate::config::GemmaConfig;
795 use crate::rope::{build_rope_tables, resolve_inv_freq, rope_slice};
796 use rlx_flow::CompileProfile;
797
798 fn tiny_cfg() -> GemmaConfig {
799 let mut cfg = GemmaConfig::tiny_test();
800 cfg.vocab_size = 16;
801 cfg.head_dim = Some(8);
802 cfg
803 }
804
805 fn synthetic_tensors(cfg: &GemmaConfig) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
806 let h = cfg.hidden_size;
807 let q_dim = cfg.q_proj_dim();
808 let kv_dim = cfg.kv_proj_dim();
809 let int_dim = cfg.intermediate_size;
810 let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
811 let pat = |n: usize, salt: u32| -> Vec<f32> {
814 (0..n)
815 .map(|i| {
816 let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
817 (x as f32 / (1u32 << 24) as f32) - 0.5
818 })
819 .collect()
820 };
821 t.insert(
822 "model.embed_tokens.weight".into(),
823 (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
824 );
825 for i in 0..cfg.num_hidden_layers {
826 let lp = format!("model.layers.{i}");
827 t.insert(
828 format!("{lp}.input_layernorm.weight"),
829 (pat(h, 100 + i as u32), vec![h]),
830 );
831 t.insert(
832 format!("{lp}.post_attention_layernorm.weight"),
833 (pat(h, 200 + i as u32), vec![h]),
834 );
835 t.insert(
836 format!("{lp}.self_attn.q_proj.weight"),
837 (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
838 );
839 t.insert(
840 format!("{lp}.self_attn.k_proj.weight"),
841 (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
842 );
843 t.insert(
844 format!("{lp}.self_attn.v_proj.weight"),
845 (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
846 );
847 t.insert(
848 format!("{lp}.self_attn.o_proj.weight"),
849 (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
850 );
851 t.insert(
852 format!("{lp}.mlp.gate_proj.weight"),
853 (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
854 );
855 t.insert(
856 format!("{lp}.mlp.up_proj.weight"),
857 (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
858 );
859 t.insert(
860 format!("{lp}.mlp.down_proj.weight"),
861 (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
862 );
863 }
864 t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
865 t.insert(
866 "lm_head.weight".into(),
867 (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
868 );
869 t
870 }
871
872 fn synthetic_weights(cfg: &GemmaConfig) -> WeightMap {
873 WeightMap::from_tensors(synthetic_tensors(cfg))
874 }
875
876 #[test]
877 fn generator_drains_loader_and_runs_one_step() {
878 let cfg = tiny_cfg();
879 let mut wm = synthetic_weights(&cfg);
880 let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
881 assert_eq!(wm.len(), 0, "loader should be drained");
882 gn.prefill(&[1, 2, 3]);
883 let t = gn.step(SampleOpts::greedy()).unwrap();
884 assert!((t as usize) < cfg.vocab_size);
885 assert_eq!(gn.tokens().len(), 4);
886 }
887
888 #[test]
889 fn generate_n_appends_n_tokens() {
890 let cfg = tiny_cfg();
891 let mut wm = synthetic_weights(&cfg);
892 let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
893 gn.prefill(&[5, 6]);
894 let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
895 assert_eq!(new_tokens.len(), 3);
896 assert_eq!(gn.tokens().len(), 5);
897 for t in &new_tokens {
898 assert!((*t as usize) < cfg.vocab_size);
899 }
900 }
901
902 #[test]
903 fn step_without_prefill_errors() {
904 let cfg = tiny_cfg();
905 let mut wm = synthetic_weights(&cfg);
906 let mut gn = GemmaGenerator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
907 let r = gn.step(SampleOpts::greedy());
908 assert!(r.is_err());
909 }
910
911 fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
912 a.iter()
913 .zip(b.iter())
914 .map(|(x, y)| (x - y).abs())
915 .fold(0f32, f32::max)
916 }
917
918 #[test]
919 fn prefill_logits_unchanged_with_kv_export() {
920 let cfg = tiny_cfg();
921 let prompt: Vec<u32> = vec![1, 2, 3, 5];
922
923 let mut wm_a = synthetic_weights(&cfg);
924 let mut wm_b = synthetic_weights(&cfg);
925 let (graph_a, params_a) =
926 build_gemma_graph_sized_last_logits(&cfg, &mut wm_a, 1, 4, false).unwrap();
927 let (graph_b, params_b) =
928 build_gemma_graph_sized_last_logits(&cfg, &mut wm_b, 1, 4, true).unwrap();
929 let session = Session::new(Device::Cpu);
930 let opts = CompileOptions::new();
931 let mut ca = session.compile_with(graph_a, &opts);
932 let mut cb = session.compile_with(graph_b, &opts);
933 for (n, d) in ¶ms_a {
934 ca.set_param(n, d);
935 }
936 for (n, d) in ¶ms_b {
937 cb.set_param(n, d);
938 }
939 let ids: Vec<f32> = prompt.iter().map(|&i| i as f32).collect();
940 let la = ca.run(&[("input_ids", &ids)])[0].clone();
941 let lb = cb.run(&[("input_ids", &ids)])[0].clone();
942 let d = max_abs_diff(&la, &lb);
943 assert!(d < 1e-5, "kv export changed prefill logits: max_abs={d:.6}");
944 }
945
946 #[test]
947 fn incremental_decode_logits_match_full_prefill() {
948 let cfg = tiny_cfg();
949 let prompt: Vec<u32> = vec![1, 2, 3, 5];
950
951 let mut wm_a = synthetic_weights(&cfg);
952 let mut gn_a = GemmaGenerator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
953 let tok = gn_a
954 .prefill_get_last_logits(&prompt)
955 .map(|l| sample_token(&l, SampleOpts::greedy()) as u32)
956 .unwrap();
957
958 let mut extended = prompt.clone();
959 extended.push(tok);
960
961 let mut wm_b = synthetic_weights(&cfg);
962 let mut gn_b = GemmaGenerator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu).unwrap();
963 let full = gn_b.prefill_get_last_logits(&extended).unwrap();
964
965 let mut wm_c = synthetic_weights(&cfg);
966 let mut gn_c = GemmaGenerator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
967 gn_c.prefill_get_last_logits(&prompt).unwrap();
968 let incremental = gn_c.decode_get_logits(tok).unwrap();
969
970 let d = max_abs_diff(&full, &incremental);
971 assert!(
972 d < 1e-2,
973 "decode+KV vs full prefill max_abs={d:.6} (tok={tok})"
974 );
975 }
976
977 fn run_prefill_kv(
978 cfg: &GemmaConfig,
979 wm: &mut WeightMap,
980 seq: usize,
981 ids: &[u32],
982 ) -> Vec<Vec<f32>> {
983 run_prefill_kv_with_options(cfg, wm, seq, ids, &kv_export_compile_options(true))
984 }
985
986 fn kv_export_compile_options(prefill: bool) -> CompileOptions {
987 let profile = if prefill {
988 CompileProfile::gemma_prefill()
989 } else {
990 CompileProfile::gemma_decode()
991 };
992 compile_options_from_profile(&profile, Device::Cpu, KernelDispatchConfig::default())
993 }
994
995 fn run_prefill_kv_with_options(
996 cfg: &GemmaConfig,
997 wm: &mut WeightMap,
998 seq: usize,
999 ids: &[u32],
1000 opts: &CompileOptions,
1001 ) -> Vec<Vec<f32>> {
1002 let ids_f32: Vec<f32> = ids.iter().map(|&i| i as f32).collect();
1003 let (graph, params) = build_gemma_graph_sized_last_logits(cfg, wm, 1, seq, true).unwrap();
1004 let session = Session::new(Device::Cpu);
1005 let mut compiled = session.compile_with(graph, opts);
1006 for (n, d) in ¶ms {
1007 compiled.set_param(n, d);
1008 }
1009 let outputs = compiled.run(&[("input_ids", &ids_f32)]);
1010 let n_layers = cfg.num_hidden_layers;
1011 assert_eq!(outputs.len(), 1 + 2 * n_layers);
1012 let mut kv = Vec::with_capacity(2 * n_layers);
1013 let mut iter = outputs.into_iter().skip(1);
1014 for _ in 0..n_layers {
1015 kv.push(iter.next().unwrap());
1016 kv.push(iter.next().unwrap());
1017 }
1018 kv
1019 }
1020
1021 #[test]
1022 fn decode_graph_bakes_rope_slice_length() {
1023 let cfg = tiny_cfg();
1024 let past_seq = 4usize;
1025 let half = cfg.head_dim() / 2;
1026 let mut wm = synthetic_weights(&cfg);
1027 let (_, params) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
1028 let cos = params
1029 .get("decode.rope.cos")
1030 .expect("decode.rope.cos param");
1031 let sin = params
1032 .get("decode.rope.sin")
1033 .expect("decode.rope.sin param");
1034 assert_eq!(
1035 cos.len(),
1036 half,
1037 "cos param should be one row (half={half}), got {}",
1038 cos.len()
1039 );
1040 assert_eq!(sin.len(), half);
1041 for key in params.keys() {
1042 assert!(
1043 !key.starts_with("rope."),
1044 "decode graph must not include prefill rope table param {key}"
1045 );
1046 }
1047 let inv = resolve_inv_freq(&cfg, None);
1048 let (c_ref, s_ref) = rope_slice(&inv, past_seq);
1049 let d = max_abs_diff(cos, &c_ref) + max_abs_diff(sin, &s_ref);
1050 assert!(d < 1e-6, "baked rope mismatch: {d}");
1051 }
1052
1053 #[test]
1054 fn decode_graph_all_rope_use_baked_cos() {
1055 use rlx_ir::Op;
1056 let cfg = tiny_cfg();
1057 let mut wm = synthetic_weights(&cfg);
1058 let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, 4).unwrap();
1059 for node in graph.nodes() {
1060 if let Op::Rope { .. } = &node.op {
1061 let cos_id = node.inputs[1];
1062 let cos_node = &graph.node(cos_id);
1063 match &cos_node.op {
1064 Op::Param { name } => assert_eq!(
1065 name, "decode.rope.cos",
1066 "decode RoPE must use baked decode.rope.cos, got {name}"
1067 ),
1068 other => panic!("decode RoPE cos input is {other:?}, expected Param"),
1069 }
1070 }
1071 }
1072 }
1073
1074 #[test]
1075 fn decode_graph_rope_cos_is_single_row() {
1076 use rlx_ir::Op;
1077 let cfg = tiny_cfg();
1078 let past_seq = 4usize;
1079 let half = cfg.head_dim() / 2;
1080 let mut wm = synthetic_weights(&cfg);
1081 let (graph, _) = build_gemma_decode_graph_sized(&cfg, &mut wm, 1, past_seq).unwrap();
1082 let mut rope_cos_lens = Vec::new();
1083 for node in graph.nodes() {
1084 if let Op::Rope { .. } = &node.op {
1085 let cos_shape = &graph.node(node.inputs[1]).shape;
1086 let rows = if cos_shape.rank() >= 2 {
1087 cos_shape.dim(0).unwrap_static()
1088 } else {
1089 1
1090 };
1091 rope_cos_lens.push(rows);
1092 }
1093 }
1094 assert!(!rope_cos_lens.is_empty(), "decode graph has no RoPE nodes");
1095 for rows in &rope_cos_lens {
1096 assert_eq!(
1097 *rows, 1,
1098 "decode RoPE cos must be single-row [1, half], got {rows} rows"
1099 );
1100 }
1101 assert_eq!(half, cfg.head_dim() / 2);
1102 }
1103
1104 #[test]
1105 fn prefill_kv_matches_extended_prefix() {
1106 let cfg = tiny_cfg();
1107 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1108 let tok = 6u32;
1109 let mut extended = prompt.clone();
1110 extended.push(tok);
1111
1112 let mut wm_prompt = synthetic_weights(&cfg);
1113 let prompt_kv = run_prefill_kv(&cfg, &mut wm_prompt, 4, &prompt);
1114 let mut wm_ext = synthetic_weights(&cfg);
1115 let ext_kv = run_prefill_kv(&cfg, &mut wm_ext, 5, &extended);
1116
1117 let kv_dim = cfg.kv_proj_dim();
1118 for layer in 0..cfg.num_hidden_layers {
1119 let k_prompt = &prompt_kv[2 * layer];
1120 let k_ext = &ext_kv[2 * layer];
1121 let prefix_len = 4 * kv_dim;
1122 assert_eq!(k_prompt.len(), prefix_len);
1123 assert_eq!(k_ext.len(), 5 * kv_dim);
1124 let d = max_abs_diff(k_prompt, &k_ext[..prefix_len]);
1125 assert!(
1126 d < 1e-4,
1127 "layer {layer} prefill K prefix vs extended K max_abs={d:.6}"
1128 );
1129 }
1130 }
1131
1132 #[test]
1133 fn decode_rope_slice_matches_prefill_table_row() {
1134 let cfg = tiny_cfg();
1135 let inv = resolve_inv_freq(&cfg, None);
1136 let (cos_tab, sin_tab) = build_rope_tables(&inv, cfg.max_position_embeddings);
1137 let half = inv.len();
1138 for pos in [3usize, 4, 5] {
1139 let (c, s) = rope_slice(&inv, pos);
1140 let off = pos * half;
1141 let d = max_abs_diff(&c, &cos_tab[off..off + half])
1142 + max_abs_diff(&s, &sin_tab[off..off + half]);
1143 assert!(d < 1e-6, "rope_slice mismatch at pos {pos}: {d}");
1144 }
1145 }
1146
/// The layer-0 V row exported for the last token of a 5-token prefill must
/// match the V row obtained by prefilling that token alone. This is checked
/// both with the KV-export profile (fusion enabled) and with default options.
#[test]
fn prefill_kv_export_correct_with_fusion() {
    let cfg = tiny_cfg();
    let tok = 6u32;
    let ids = [1u32, 2, 3, 5, tok];
    let fused = kv_export_compile_options(true);

    // Reference: the token prefilled on its own, a single exported row.
    let mut wm_single = synthetic_weights(&cfg);
    let single = run_prefill_kv_with_options(&cfg, &mut wm_single, 1, &[tok], &fused);

    let mut wm_fused = synthetic_weights(&cfg);
    let fused_kv = run_prefill_kv_with_options(&cfg, &mut wm_fused, 5, &ids, &fused);

    let kv_dim = cfg.kv_proj_dim();
    let reference = &single[1][..kv_dim];
    let last_row = 4 * kv_dim;

    let d = max_abs_diff(&fused_kv[1][last_row..], reference);
    assert!(d < 1e-4, "KV export mismatch with profile fusion: {d:.6}");

    let mut wm_plain = synthetic_weights(&cfg);
    let plain_kv =
        run_prefill_kv_with_options(&cfg, &mut wm_plain, 5, &ids, &CompileOptions::new());
    let d_default = max_abs_diff(&plain_kv[1][last_row..], reference);
    assert!(
        d_default < 1e-4,
        "KV export mismatch with default fusion (got {d_default:.6})"
    );
}
1170
/// Runs the sized one-shot decode graph (past length 4) on top of the KV
/// cache left by a prompt prefill. Checks that the layer-0 K row it emits
/// for the new token matches the row a full prefill of the extended
/// sequence exports at the same position.
#[test]
fn decode_oneshot_kv_suffix_matches_extended() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let tok = 6u32;
    let mut extended = prompt.clone();
    extended.push(tok);

    let opts = kv_export_compile_options(false);
    let mut wm_ext = synthetic_weights(&cfg);
    let ext_kv = run_prefill_kv_with_options(&cfg, &mut wm_ext, 5, &extended, &opts);

    // Populate `gn.cache` with the prompt's KV.
    let mut wm = synthetic_weights(&cfg);
    let mut gn = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
    gn.prefill_get_last_logits(&prompt).unwrap();

    let mut wm_d = synthetic_weights(&cfg);
    let (graph, params) = build_gemma_decode_graph_sized(&cfg, &mut wm_d, 1, 4).unwrap();
    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile_with(graph, &opts);
    // Borrow the parameter list. This was previously corrupted to `¶ms`
    // (an HTML-entity mangling of `&params`), which did not compile.
    for (n, d) in &params {
        compiled.set_param(n, d);
    }

    let cache = gn.cache.as_ref().unwrap();
    let key_strs: Vec<String> = (0..cfg.num_hidden_layers)
        .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
        .collect();
    let input_ids = [tok as f32];
    let mut inputs: Vec<(&str, &[f32])> = vec![("input_ids", input_ids.as_slice())];
    for i in 0..cfg.num_hidden_layers {
        inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
        inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
    }
    let outputs = compiled.run(&inputs);

    // outputs[1] is treated as layer-0 K. Its row at position 4 is the new token.
    let kv_dim = cfg.kv_proj_dim();
    let k_dec = &outputs[1][4 * kv_dim..];

    let d = max_abs_diff(k_dec, &ext_kv[0][4 * kv_dim..]);
    assert!(
        d < 1e-3,
        "decode oneshot layer0 K suffix vs extended max_abs={d:.6}"
    );
}
1214
/// Prefilling `prompt` and then decoding one token must give (approximately)
/// the same last-position logits as prefilling `prompt + [tok]` directly.
#[test]
fn decode_logits_match_extended_prefill_after_one_token() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let tok = 6u32;
    let extended = [&prompt[..], &[tok][..]].concat();

    // Reference: a single prefill over the extended sequence.
    let mut weights_full = synthetic_weights(&cfg);
    let mut full_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut weights_full, Device::Cpu).unwrap();
    let reference = full_gen.prefill_get_last_logits(&extended).unwrap();

    // Incremental: prefill the prompt, then one decode step.
    let mut weights_inc = synthetic_weights(&cfg);
    let mut inc_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut weights_inc, Device::Cpu).unwrap();
    inc_gen.prefill_get_last_logits(&prompt).unwrap();
    let incremental = inc_gen.decode_get_logits(tok).unwrap();

    let err = max_abs_diff(&reference, &incremental);
    assert!(err < 1e-2, "decode vs extended prefill max_abs={err:.6}");
}
1236
/// The first two greedy tokens from `generate_cached(2, ..)` must equal two
/// consecutive uncached `step` calls from the same prompt.
#[test]
fn cached_second_token_matches_naive() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];

    // Reference path: two uncached steps.
    let mut naive_weights = synthetic_weights(&cfg);
    let mut naive =
        GemmaGenerator::from_loader(cfg.clone(), &mut naive_weights, Device::Cpu).unwrap();
    naive.prefill(&prompt);
    let first = naive.step(SampleOpts::greedy()).unwrap();
    let second = naive.step(SampleOpts::greedy()).unwrap();

    // Path under test: KV-cached generation of the same two tokens.
    let mut cached_weights = synthetic_weights(&cfg);
    let mut cached =
        GemmaGenerator::from_loader(cfg.clone(), &mut cached_weights, Device::Cpu).unwrap();
    cached.prefill(&prompt);
    let got = cached.generate_cached(2, SampleOpts::greedy()).unwrap();

    assert_eq!(got[0], first, "first generated token");
    assert_eq!(got[1], second, "second generated token (decode step)");
}
1256
/// Greedy decoding with the KV cache must emit exactly the same token
/// sequence as the uncached reference `generate` path.
#[test]
fn cached_matches_naive_on_greedy() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 4;

    let mut reference_weights = synthetic_weights(&cfg);
    let mut reference =
        GemmaGenerator::from_loader(cfg.clone(), &mut reference_weights, Device::Cpu).unwrap();
    reference.prefill(&prompt);
    let expected = reference.generate(steps, SampleOpts::greedy()).unwrap();

    let mut cached_weights = synthetic_weights(&cfg);
    let mut cached =
        GemmaGenerator::from_loader(cfg.clone(), &mut cached_weights, Device::Cpu).unwrap();
    cached.prefill(&prompt);
    let actual = cached.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        actual, expected,
        "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
    );
}
1288
/// Each `step_cached` call appends one token and advances the cache's
/// `past_len` by one. After prefilling 3 tokens, this test expects
/// (tokens, past_len) of (4, 3) after the first step and (5, 4) after the
/// second.
#[test]
fn cached_step_advances_cache_invariant() {
    let cfg = tiny_cfg();
    let mut wm = synthetic_weights(&cfg);
    let mut generator = GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
    generator.prefill(&[1, 2, 3]);
    for (want_tokens, want_past) in [(4, 3), (5, 4)] {
        let _ = generator.step_cached(SampleOpts::greedy()).unwrap();
        assert_eq!(generator.tokens().len(), want_tokens);
        assert_eq!(generator.cache.as_ref().unwrap().past_len, want_past);
    }
}
1304
/// A generator configured with `with_decode_cache(32)` must produce the
/// same greedy tokens as one without it.
#[test]
fn bucketed_decode_matches_oneshot() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 6;

    let mut plain_weights = synthetic_weights(&cfg);
    let mut plain =
        GemmaGenerator::from_loader(cfg.clone(), &mut plain_weights, Device::Cpu).unwrap();
    plain.prefill(&prompt);
    let baseline = plain.generate_cached(steps, SampleOpts::greedy()).unwrap();

    let mut bucket_weights = synthetic_weights(&cfg);
    let mut bucketed =
        GemmaGenerator::from_loader(cfg.clone(), &mut bucket_weights, Device::Cpu)
            .unwrap()
            .with_decode_cache(32);
    bucketed.prefill(&prompt);
    let candidate = bucketed.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        candidate, baseline,
        "bucketed-cache decode diverged from one-shot decode — \
         mask, padding, or output-slice bug"
    );
}
1335
/// Enabling `with_prefill_cache(4)` must not alter the greedy tokens
/// generated after prefilling the same prompt.
#[test]
fn prefill_compile_cache_does_not_change_output() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];

    let mut weights_plain = synthetic_weights(&cfg);
    let mut plain =
        GemmaGenerator::from_loader(cfg.clone(), &mut weights_plain, Device::Cpu).unwrap();
    plain.prefill(&prompt);
    let without_cache = plain.generate_cached(4, SampleOpts::greedy()).unwrap();

    let mut weights_cached = synthetic_weights(&cfg);
    let mut cached =
        GemmaGenerator::from_loader(cfg.clone(), &mut weights_cached, Device::Cpu)
            .unwrap()
            .with_prefill_cache(4);
    cached.prefill(&prompt);
    let with_cache = cached.generate_cached(4, SampleOpts::greedy()).unwrap();

    assert_eq!(
        without_cache, with_cache,
        "enabling prefill_cache must not change output"
    );
}
1354
/// Decoding through `with_dynamic_decode_cache(8)` must produce the same
/// greedy tokens as the default decode path.
#[test]
fn dynamic_decode_matches_oneshot() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 6;

    let mut baseline_weights = synthetic_weights(&cfg);
    let mut baseline_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut baseline_weights, Device::Cpu).unwrap();
    baseline_gen.prefill(&prompt);
    let expected = baseline_gen.generate_cached(steps, SampleOpts::greedy()).unwrap();

    let mut dynamic_weights = synthetic_weights(&cfg);
    let mut dynamic_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut dynamic_weights, Device::Cpu)
            .unwrap()
            .with_dynamic_decode_cache(8);
    dynamic_gen.prefill(&prompt);
    let observed = dynamic_gen.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        observed, expected,
        "dynamic past_seq decode diverged from one-shot decode"
    );
}
1379
/// Prefilling through `with_dynamic_prefill_cache(8)` must lead to the same
/// greedy tokens as the default prefill path.
#[test]
fn dynamic_prefill_matches_oneshot() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 4;

    let mut baseline_weights = synthetic_weights(&cfg);
    let mut baseline_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut baseline_weights, Device::Cpu).unwrap();
    baseline_gen.prefill(&prompt);
    let expected = baseline_gen.generate_cached(steps, SampleOpts::greedy()).unwrap();

    let mut dynamic_weights = synthetic_weights(&cfg);
    let mut dynamic_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut dynamic_weights, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(8);
    dynamic_gen.prefill(&prompt);
    let observed = dynamic_gen.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        observed, expected,
        "dynamic seq prefill diverged from one-shot prefill"
    );
}
1404
/// Enabling both dynamic caches (prefill and decode, each configured with 8)
/// must still reproduce the default path's greedy tokens.
#[test]
fn dynamic_prefill_and_decode_matches_oneshot() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 6;

    let mut baseline_weights = synthetic_weights(&cfg);
    let mut baseline_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut baseline_weights, Device::Cpu).unwrap();
    baseline_gen.prefill(&prompt);
    let expected = baseline_gen.generate_cached(steps, SampleOpts::greedy()).unwrap();

    let mut dynamic_weights = synthetic_weights(&cfg);
    let mut dynamic_gen =
        GemmaGenerator::from_loader(cfg.clone(), &mut dynamic_weights, Device::Cpu)
            .unwrap()
            .with_dynamic_prefill_cache(8)
            .with_dynamic_decode_cache(8);
    dynamic_gen.prefill(&prompt);
    let observed = dynamic_gen.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        observed, expected,
        "dynamic prefill+decode diverged from one-shot path"
    );
}
1430
/// Two generators built from identical weights and fed the same prompt must
/// produce identical greedy (uncached) token sequences.
#[test]
fn greedy_is_deterministic_across_runs() {
    let cfg = tiny_cfg();
    let weights = synthetic_weights(&cfg);
    let build = || {
        let mut wm = WeightMap::from_tensors(weights_as_hashmap(&weights));
        GemmaGenerator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
    };
    let mut runs = [build(), build()];
    // Prefill both before generating on either, matching the original ordering.
    for generator in runs.iter_mut() {
        generator.prefill(&[1, 2, 3]);
    }
    let outputs: Vec<_> = runs
        .iter_mut()
        .map(|generator| generator.generate(4, SampleOpts::greedy()).unwrap())
        .collect();
    assert_eq!(outputs[0], outputs[1]);
}
1447
1448 fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
1449 let _ = wm; let cfg = tiny_cfg();
1455 let mut new = synthetic_weights(&cfg);
1456 let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
1457 let mut out = HashMap::new();
1458 for k in keys {
1459 out.insert(k.clone(), new.take(&k).unwrap());
1460 }
1461 out
1462 }
1463}