1use crate::builder::{
28 build_llama32_decode_hir_dynamic_ext, build_llama32_decode_hir_sized,
29 build_llama32_decode_hir_sized_ext, build_llama32_graph_sized_last_logits,
30 build_llama32_prefill_hir_dynamic_ext,
31};
32use crate::config::Llama32Config;
33use crate::rope::{resolve_inv_freq, rope_slice};
34use anyhow::{Context, Result};
35use rlx_core::flow_bridge::compile_options_from_profile;
36use rlx_core::weight_loader::WeightLoader;
37use rlx_core::weight_map::WeightMap;
38use rlx_flow::CompileProfile;
39use rlx_ir::DimBinding;
40use rlx_ir::logical_kernel::KernelDispatchConfig;
41use rlx_qwen3::sampling::{SampleOpts, sample_token};
42use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache, DynamicDimCompileCache};
43use rlx_runtime::{CompileOptions, Device, Session};
44use std::collections::{HashMap, HashSet};
45use std::path::Path;
46
/// Key/value state carried from one cached decode step to the next.
///
/// The generator replaces `layers_k` / `layers_v` wholesale after every decode
/// step and bumps `past_seq` by one.
#[derive(Clone)]
struct KvCacheState {
    /// Number of token positions already folded into the per-layer buffers.
    past_seq: usize,
    /// One flat key buffer per layer, indexed by layer number.
    layers_k: Vec<Vec<f32>>,
    /// One flat value buffer per layer, indexed by layer number.
    layers_v: Vec<Vec<f32>>,
}
55
56pub struct Llama32Generator {
62 cfg: Llama32Config,
63 weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
68 tokens: Vec<u32>,
69 device: Device,
70 cache: Option<KvCacheState>,
74 prefill_compile_cache: Option<CompileCache>,
78 prefill_dynamic_cache: Option<DynamicDimCompileCache>,
80 decode_compile_cache: Option<BucketedCompileCache>,
86 decode_dynamic_cache: Option<DynamicDimCompileCache>,
87 decode_loaded_buckets: HashSet<usize>,
91 inv_freq: Vec<f64>,
93 prefill_profile: CompileProfile,
95 decode_profile: CompileProfile,
97}
98
99impl Llama32Generator {
100 pub fn from_loader(
103 cfg: Llama32Config,
104 loader: &mut dyn WeightLoader,
105 device: Device,
106 ) -> Result<Self> {
107 let keys = loader.remaining_keys();
108 let mut weights_cache = HashMap::with_capacity(keys.len());
109 for k in keys {
110 let v = loader
111 .take(&k)
112 .with_context(|| format!("draining weight {k}"))?;
113 let canonical =
119 rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
120 weights_cache.insert(canonical, v);
121 }
122 let rope_factors = weights_cache
123 .get("rope_freqs.weight")
124 .map(|(d, _)| d.as_slice());
125 let inv_freq = resolve_inv_freq(&cfg, rope_factors);
126 Ok(Self {
127 cfg,
128 weights_cache,
129 tokens: Vec::new(),
130 device,
131 cache: None,
132 prefill_compile_cache: None,
133 prefill_dynamic_cache: None,
134 decode_compile_cache: None,
135 decode_dynamic_cache: None,
136 decode_loaded_buckets: HashSet::new(),
137 inv_freq,
138 prefill_profile: CompileProfile::llama32_prefill(),
139 decode_profile: CompileProfile::llama32_decode(),
140 })
141 }
142
143 pub fn from_loader_at(
146 cfg: Llama32Config,
147 loader: &mut dyn WeightLoader,
148 device: Device,
149 weights_path: &Path,
150 ) -> Result<Self> {
151 let mut g = Self::from_loader(cfg, loader, device)?;
152 g.prefill_profile = crate::llama32_profile_near_weights(weights_path, false);
153 g.decode_profile = crate::llama32_profile_near_weights(weights_path, true);
154 Ok(g)
155 }
156
157 pub fn with_compile_profiles(
159 mut self,
160 prefill: CompileProfile,
161 decode: CompileProfile,
162 ) -> Self {
163 self.prefill_profile = prefill;
164 self.decode_profile = decode;
165 self
166 }
167
168 pub fn prefill_profile(&self) -> &CompileProfile {
169 &self.prefill_profile
170 }
171
172 pub fn decode_profile(&self) -> &CompileProfile {
173 &self.decode_profile
174 }
175
176 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
177 let profile = if decode {
178 &self.decode_profile
179 } else {
180 &self.prefill_profile
181 };
182 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
183 }
184
185 fn compile_hir_profiled(
186 &self,
187 session: &Session,
188 hir: rlx_ir::hir::HirModule,
189 decode: bool,
190 ) -> Result<rlx_runtime::CompiledGraph> {
191 let opts = self.profile_compile_options(decode);
192 Ok(session.compile_hir_with(hir, &opts)?)
193 }
194
195 fn compile_graph_profiled(
196 &self,
197 session: &Session,
198 graph: rlx_ir::Graph,
199 ) -> Result<rlx_runtime::CompiledGraph> {
200 let opts = self.profile_compile_options(false);
201 Ok(session.compile_with(graph, &opts))
202 }
203
204 pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
209 self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
210 self.prefill_dynamic_cache = None;
211 self
212 }
213
214 pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
216 self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
217 self.prefill_compile_cache = None;
218 self
219 }
220
221 pub fn with_decode_cache(mut self, max_past: usize) -> Self {
231 let cache = BucketedCompileCache::power_of_two_ladder(
232 self.device,
233 1,
234 max_past.max(1) as u64,
235 );
236 self.decode_compile_cache = Some(cache);
237 self.decode_dynamic_cache = None;
238 self.decode_loaded_buckets.clear();
239 self
240 }
241
242 pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
244 self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
245 self.decode_compile_cache = None;
246 self.decode_loaded_buckets.clear();
247 self
248 }
249
250 pub fn from_path(cfg: Llama32Config, path: &str, device: Device) -> Result<Self> {
253 let mut loader = rlx_core::weight_loader::load_from_path(path)?;
254 Self::from_loader(cfg, loader.as_mut(), device)
255 }
256
257 pub fn from_path_with_mtp(
265 cfg: Llama32Config,
266 path: &str,
267 device: Device,
268 include_mtp: bool,
269 ) -> Result<Self> {
270 if path.ends_with(".gguf") {
274 let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
275 gguf.include_mtp(include_mtp);
276 Self::from_loader(cfg, &mut gguf, device)
277 } else {
278 Self::from_path(cfg, path, device)
279 }
280 }
281
282 pub fn prefill(&mut self, prompt_ids: &[u32]) {
286 self.tokens.clear();
287 self.tokens.extend_from_slice(prompt_ids);
288 self.cache = None;
289 }
290
291 pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
295 if self.tokens.is_empty() {
296 anyhow::bail!("step() called with empty token history; call prefill() first");
297 }
298 let seq = self.tokens.len();
299 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
300 let (graph, params) = build_llama32_graph_sized_last_logits(
301 &self.cfg, &mut wm, 1, seq, false,
302 )?;
303 let session = Session::new(self.device);
304 let mut compiled = self.compile_graph_profiled(&session, graph)?;
305 for (name, data) in ¶ms {
306 compiled.set_param(name, data);
307 }
308 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
309 let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
310 let logits = outputs
311 .into_iter()
312 .next()
313 .context("compiled.run returned no outputs")?;
314
315 let vocab = self.cfg.vocab_size;
316 let expected = vocab;
317 if logits.len() < expected {
318 anyhow::bail!(
319 "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
320 logits.len(),
321 expected
322 );
323 }
324 let last_row = &logits[..vocab];
326 let tok = sample_token(last_row, opts) as u32;
327 self.tokens.push(tok);
328 Ok(tok)
329 }
330
331 pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
334 let start = self.tokens.len();
335 for _ in 0..n {
336 self.step(opts)?;
337 }
338 Ok(self.tokens[start..].to_vec())
339 }
340
341 pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
351 if self.tokens.is_empty() {
352 anyhow::bail!("step_cached() called with empty token history; call prefill() first");
353 }
354 if self.cache.is_none() {
355 let tok = self.seed_cache_from_prompt(opts)?;
359 return Ok(tok);
360 }
361 let cache = self.cache.as_ref().unwrap();
362 let past_seq = cache.past_seq;
363 if self.tokens.len() <= past_seq {
367 anyhow::bail!(
368 "cache invariant violated: tokens.len() {} <= past_seq {}",
369 self.tokens.len(),
370 past_seq
371 );
372 }
373 let input_tok = self.tokens[past_seq];
374
375 let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
377 self.decode_step_dynamic(past_seq, input_tok)?
378 } else if self.decode_compile_cache.is_some()
379 && self
380 .decode_compile_cache
381 .as_ref()
382 .unwrap()
383 .bucket_for(past_seq as u64)
384 .is_some()
385 {
386 self.decode_step_bucketed(past_seq, input_tok)?
387 } else {
388 self.decode_step_oneshot(past_seq, input_tok)?
389 };
390
391 let cache_mut = self.cache.as_mut().unwrap();
392 cache_mut.past_seq = past_seq + 1;
393 cache_mut.layers_k = new_k;
394 cache_mut.layers_v = new_v;
395
396 let vocab = self.cfg.vocab_size;
397 if logits.len() != vocab {
398 anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
399 }
400 let tok = sample_token(&logits, opts) as u32;
401 self.tokens.push(tok);
402 Ok(tok)
403 }
404
405 #[allow(clippy::type_complexity)]
408 fn decode_step_oneshot(
409 &mut self,
410 past_seq: usize,
411 input_tok: u32,
412 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
413 let cache = self.cache.as_ref().unwrap();
414
415 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
416 let (hir, params) =
417 build_llama32_decode_hir_sized(&self.cfg, &mut wm, 1, past_seq)?;
418 let session = Session::new(self.device);
419 let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
420 for (name, data) in ¶ms {
421 compiled.set_param(name, data);
422 }
423
424 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
425 let input_ids_f32 = [input_tok as f32];
426 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
427 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
428 .collect();
429 let mut inputs: Vec<(&str, &[f32])> =
430 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
431 inputs.push(("input_ids", input_ids_f32.as_slice()));
432 inputs.push(("rope_cos", cos.as_slice()));
433 inputs.push(("rope_sin", sin.as_slice()));
434 for i in 0..self.cfg.num_hidden_layers {
435 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
436 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
437 }
438
439 let outputs = compiled.run(&inputs);
440 self.split_decode_outputs(outputs)
441 }
442
443 #[allow(clippy::type_complexity)]
444 fn decode_step_dynamic(
445 &mut self,
446 past_seq: usize,
447 input_tok: u32,
448 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
449 let cache = self.cache.as_ref().unwrap();
450 let binding = DimBinding::batch_past_seq(1, past_seq);
451 let opts = self
452 .profile_compile_options(true)
453 .dim_binding(binding.clone());
454 let cache_dyn = self
455 .decode_dynamic_cache
456 .as_mut()
457 .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
458 let needs_upload = !cache_dyn.contains(past_seq as u64);
459 let cfg = self.cfg.clone();
460 let weights_cache = self.weights_cache.clone();
461 let max_past = self.cfg.max_position_embeddings;
462 let compiled = cache_dyn.get_or_specialize(
463 past_seq as u64,
464 &binding,
465 || {
466 let mut wm = WeightMap::from_tensors(weights_cache);
467 build_llama32_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
468 .expect("dynamic decode HIR")
469 .0
470 },
471 &opts,
472 )?;
473 if needs_upload {
474 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
475 let (_, params) =
476 build_llama32_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
477 for (name, data) in ¶ms {
478 compiled.set_param(name, data);
479 }
480 }
481
482 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
483 let input_ids_f32 = [input_tok as f32];
484 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
485 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
486 .collect();
487 let mut inputs: Vec<(&str, &[f32])> =
488 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
489 inputs.push(("input_ids", input_ids_f32.as_slice()));
490 inputs.push(("rope_cos", cos.as_slice()));
491 inputs.push(("rope_sin", sin.as_slice()));
492 for i in 0..self.cfg.num_hidden_layers {
493 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
494 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
495 }
496 let outputs = compiled.run(&inputs);
497 self.split_decode_outputs(outputs)
498 }
499
500 #[allow(clippy::type_complexity)]
507 fn decode_step_bucketed(
508 &mut self,
509 past_seq: usize,
510 input_tok: u32,
511 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
512 let cache_dec = self.decode_compile_cache.as_ref().unwrap();
513 let bucket_idx = cache_dec
514 .bucket_for(past_seq as u64)
515 .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside any bucket"))?;
516 let upper = cache_dec
517 .buckets()
518 .nth(bucket_idx)
519 .map(|r| r.end - 1)
520 .unwrap() as usize;
521
522 let kv_dim = self.cfg.kv_proj_dim();
523 let n_layers = self.cfg.num_hidden_layers;
524
525 let needs_load = !self.decode_loaded_buckets.contains(&bucket_idx);
529 if needs_load {
530 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
531 let (hir, params) = build_llama32_decode_hir_sized_ext(
532 &self.cfg, &mut wm, 1, upper, true,
533 )?;
534 {
535 let decode_opts = self.profile_compile_options(true);
536 let cache_mut = self.decode_compile_cache.as_mut().unwrap();
537 let (_u, compiled) = cache_mut
538 .get_or_compile_hir_with_options(past_seq as u64, |_upper| hir, &decode_opts)
539 .expect("bucket must exist; we just looked it up");
540 for (name, data) in ¶ms {
541 compiled.set_param(name, data);
542 }
543 }
544 self.decode_loaded_buckets.insert(bucket_idx);
545 }
546
547 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
549 let input_ids_f32 = [input_tok as f32];
550
551 let mask_len = upper + 1;
556 let mut mask = vec![0.0f32; mask_len];
557 for v in mask.iter_mut().take(past_seq + 1) {
558 *v = 1.0;
559 }
560
561 let padded_k: Vec<Vec<f32>> = (0..n_layers)
563 .map(|i| {
564 let src = &self.cache.as_ref().unwrap().layers_k[i];
565 let mut out = vec![0f32; upper * kv_dim];
566 out[..src.len()].copy_from_slice(src);
567 out
568 })
569 .collect();
570 let padded_v: Vec<Vec<f32>> = (0..n_layers)
571 .map(|i| {
572 let src = &self.cache.as_ref().unwrap().layers_v[i];
573 let mut out = vec![0f32; upper * kv_dim];
574 out[..src.len()].copy_from_slice(src);
575 out
576 })
577 .collect();
578
579 let key_strs: Vec<String> = (0..n_layers)
580 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
581 .collect();
582 let mut inputs: Vec<(&str, &[f32])> = Vec::with_capacity(4 + 2 * n_layers);
583 inputs.push(("input_ids", input_ids_f32.as_slice()));
584 inputs.push(("rope_cos", cos.as_slice()));
585 inputs.push(("rope_sin", sin.as_slice()));
586 inputs.push(("mask", mask.as_slice()));
587 for i in 0..n_layers {
588 inputs.push((&key_strs[2 * i], padded_k[i].as_slice()));
589 inputs.push((&key_strs[2 * i + 1], padded_v[i].as_slice()));
590 }
591
592 let cache_mut = self.decode_compile_cache.as_mut().unwrap();
593 let (_u, compiled) = cache_mut
594 .get_or_compile_hir(past_seq as u64, |_| {
595 unreachable!("bucket was just loaded above")
596 })
597 .unwrap();
598 let raw_outputs = compiled.run(&inputs);
599
600 let mut iter = raw_outputs.into_iter();
604 let logits = iter.next().context("bucketed decode logits missing")?;
605 let real_len = (past_seq + 1) * kv_dim;
606 let mut new_k = Vec::with_capacity(n_layers);
607 let mut new_v = Vec::with_capacity(n_layers);
608 for _ in 0..n_layers {
609 let k = iter.next().context("bucketed k missing")?;
610 let v = iter.next().context("bucketed v missing")?;
611 new_k.push(k[..real_len].to_vec());
612 new_v.push(v[..real_len].to_vec());
613 }
614 Ok((logits, new_k, new_v))
615 }
616
617 fn run_prefill_with_cache(
621 &mut self,
622 batch: usize,
623 seq: usize,
624 ids_f32: &[f32],
625 ) -> Result<Vec<Vec<f32>>> {
626 let dynamic_prefill = self.prefill_dynamic_cache.is_some().then(|| {
627 let binding = DimBinding::batch_seq(batch, seq);
628 let opts = self
629 .profile_compile_options(false)
630 .dim_binding(binding.clone());
631 (binding, opts)
632 });
633 if let (Some(cache), Some((binding, opts))) = (
634 self.prefill_dynamic_cache.as_mut(),
635 dynamic_prefill.as_ref(),
636 ) {
637 let needs_upload = !cache.contains(seq as u64);
638 let cfg = self.cfg.clone();
639 let weights_cache = self.weights_cache.clone();
640 let max_seq = self.cfg.max_position_embeddings;
641 let compiled = cache.get_or_specialize(
642 seq as u64,
643 binding,
644 || {
645 let mut wm = WeightMap::from_tensors(weights_cache);
646 build_llama32_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
647 .expect("dynamic prefill HIR")
648 .0
649 },
650 opts,
651 )?;
652 if needs_upload {
653 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
654 let (_, params) = build_llama32_prefill_hir_dynamic_ext(
655 &self.cfg, &mut wm, batch, max_seq, true,
656 )?;
657 for (name, data) in ¶ms {
658 compiled.set_param(name, data);
659 }
660 }
661 let last_idx = vec![(seq - 1) as f32];
662 Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
663 } else if let Some(prefill_cache) = self.prefill_compile_cache.as_mut() {
664 let key = ((batch as u64) << 32) | (seq as u64);
665 if !prefill_cache.contains(key) {
666 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
667 let (graph, params) = build_llama32_graph_sized_last_logits(
668 &self.cfg, &mut wm, batch, seq, true,
669 )?;
670 {
671 let compiled = prefill_cache.get_or_compile(key, || graph);
672 for (name, data) in ¶ms {
673 compiled.set_param(name, data);
674 }
675 }
676 }
677 let compiled =
678 prefill_cache.get_or_compile(key, || unreachable!("just populated above"));
679 Ok(compiled.run(&[("input_ids", ids_f32)]))
680 } else {
681 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
682 let (graph, params) = build_llama32_graph_sized_last_logits(
683 &self.cfg, &mut wm, batch, seq, true,
684 )?;
685 let session = Session::new(self.device);
686 let mut compiled = self.compile_graph_profiled(&session, graph)?;
687 for (name, data) in ¶ms {
688 compiled.set_param(name, data);
689 }
690 Ok(compiled.run(&[("input_ids", ids_f32)]))
691 }
692 }
693
694 #[allow(clippy::type_complexity)]
698 fn split_decode_outputs(
699 &self,
700 outputs: Vec<Vec<f32>>,
701 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
702 let n_layers = self.cfg.num_hidden_layers;
703 if outputs.len() != 1 + 2 * n_layers {
704 anyhow::bail!(
705 "decode graph produced {} outputs, expected {}",
706 outputs.len(),
707 1 + 2 * n_layers
708 );
709 }
710 let mut iter = outputs.into_iter();
711 let logits = iter.next().context("decode logits missing")?;
712 let mut layers_k = Vec::with_capacity(n_layers);
713 let mut layers_v = Vec::with_capacity(n_layers);
714 for _ in 0..n_layers {
715 layers_k.push(iter.next().context("decode k missing")?);
716 layers_v.push(iter.next().context("decode v missing")?);
717 }
718 Ok((logits, layers_k, layers_v))
719 }
720
721 pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
723 self.generate_cached_with(n, opts, |_| {})
724 }
725
726 pub fn generate_cached_with(
733 &mut self,
734 n: usize,
735 opts: SampleOpts,
736 mut on_token: impl FnMut(u32),
737 ) -> Result<Vec<u32>> {
738 let start = self.tokens.len();
739 for _ in 0..n {
740 let tok = self.step_cached(opts)?;
741 on_token(tok);
742 }
743 Ok(self.tokens[start..].to_vec())
744 }
745
746 fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
751 let seq = self.tokens.len();
752 let batch = 1usize;
753 let kv_dim = self.cfg.kv_proj_dim();
754
755 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
756 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
757 if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
758 anyhow::bail!(
759 "prefill-with-cache produced {} outputs, expected {}",
760 outputs.len(),
761 1 + 2 * self.cfg.num_hidden_layers
762 );
763 }
764 let expected_kv_len = batch * seq * kv_dim;
765 let mut iter = outputs.into_iter();
766 let logits = iter.next().context("prefill logits missing")?;
767 let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
768 let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
769 for layer in 0..self.cfg.num_hidden_layers {
770 let k = iter.next().context("prefill k missing")?;
771 let v = iter.next().context("prefill v missing")?;
772 if k.len() != expected_kv_len || v.len() != expected_kv_len {
773 anyhow::bail!(
774 "layer {layer}: k.len={} v.len={} expected {}",
775 k.len(),
776 v.len(),
777 expected_kv_len
778 );
779 }
780 layers_k.push(k);
781 layers_v.push(v);
782 }
783 self.cache = Some(KvCacheState {
784 past_seq: seq,
785 layers_k,
786 layers_v,
787 });
788
789 let vocab = self.cfg.vocab_size;
790 let needed = vocab;
791 if logits.len() < needed {
792 anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
793 }
794 let last_row = &logits[..vocab];
795 let tok = sample_token(last_row, opts) as u32;
796 self.tokens.push(tok);
797 Ok(tok)
798 }
799
800 pub fn tokens(&self) -> &[u32] {
802 &self.tokens
803 }
804
805 pub fn config(&self) -> &Llama32Config {
806 &self.cfg
807 }
808
809 pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
817 if context.is_empty() {
818 anyhow::bail!("prefill_get_last_logits: empty context");
819 }
820 self.tokens.clear();
821 self.tokens.extend_from_slice(context);
822 self.cache = None;
823
824 let seq = context.len();
825 let batch = 1usize;
826 let kv_dim = self.cfg.kv_proj_dim();
827
828 let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
829 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
830 if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
831 anyhow::bail!(
832 "prefill_get_last_logits: got {} outputs, expected {}",
833 outputs.len(),
834 1 + 2 * self.cfg.num_hidden_layers
835 );
836 }
837 let expected_kv_len = batch * seq * kv_dim;
838 let mut iter = outputs.into_iter();
839 let logits = iter.next().context("logits missing")?;
840 let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
841 let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
842 for _ in 0..self.cfg.num_hidden_layers {
843 let k = iter.next().context("k missing")?;
844 let v = iter.next().context("v missing")?;
845 if k.len() != expected_kv_len || v.len() != expected_kv_len {
846 anyhow::bail!("kv length mismatch in prefill_get_last_logits");
847 }
848 layers_k.push(k);
849 layers_v.push(v);
850 }
851 self.cache = Some(KvCacheState {
852 past_seq: seq,
853 layers_k,
854 layers_v,
855 });
856
857 let vocab = self.cfg.vocab_size;
858 let needed = vocab;
859 if logits.len() < needed {
860 anyhow::bail!("logits short: {} < {}", logits.len(), needed);
861 }
862 Ok(logits[..vocab].to_vec())
863 }
864
865 pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
873 let cache = self.cache.as_ref().ok_or_else(|| {
874 anyhow::anyhow!(
875 "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
876 )
877 })?;
878 let past_seq = cache.past_seq;
879
880 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
881 let (hir, params) =
882 build_llama32_decode_hir_sized(&self.cfg, &mut wm, 1, past_seq)?;
883 let session = Session::new(self.device);
884 let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
885 for (name, data) in ¶ms {
886 compiled.set_param(name, data);
887 }
888
889 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
890 let input_ids_f32 = [input as f32];
891 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
892 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
893 .collect();
894 let mut inputs: Vec<(&str, &[f32])> =
895 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
896 inputs.push(("input_ids", input_ids_f32.as_slice()));
897 inputs.push(("rope_cos", cos.as_slice()));
898 inputs.push(("rope_sin", sin.as_slice()));
899 for i in 0..self.cfg.num_hidden_layers {
900 let pk = &cache.layers_k[i];
901 let pv = &cache.layers_v[i];
902 inputs.push((&key_strs[2 * i], pk.as_slice()));
903 inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
904 }
905
906 let outputs = compiled.run(&inputs);
907 if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
908 anyhow::bail!(
909 "decode_get_logits: got {} outputs, expected {}",
910 outputs.len(),
911 1 + 2 * self.cfg.num_hidden_layers
912 );
913 }
914 let mut iter = outputs.into_iter();
915 let logits = iter.next().context("logits missing")?;
916 let mut new_k = Vec::with_capacity(self.cfg.num_hidden_layers);
917 let mut new_v = Vec::with_capacity(self.cfg.num_hidden_layers);
918 for _ in 0..self.cfg.num_hidden_layers {
919 new_k.push(iter.next().context("k missing")?);
920 new_v.push(iter.next().context("v missing")?);
921 }
922
923 let cache_mut = self.cache.as_mut().unwrap();
924 cache_mut.past_seq = past_seq + 1;
925 cache_mut.layers_k = new_k;
926 cache_mut.layers_v = new_v;
927 self.tokens.push(input);
928
929 Ok(logits)
930 }
931}
932
933fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
937 rope_slice(inv_freq, pos)
938}
939
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Llama32Config;

    /// Two-layer, 16-token-vocabulary config small enough to run on the CPU
    /// in unit tests.
    fn tiny_cfg() -> Llama32Config {
        Llama32Config {
            vocab_size: 16,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 500_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            head_dim: Some(8),
            rope_scaling: None,
        }
    }

    /// Deterministic pseudo-random weights for every tensor the builders ask
    /// for. Values lie in `[-0.5, 0.5)` and depend only on index and salt.
    fn synthetic_weights(cfg: &Llama32Config) -> WeightMap {
        let h = cfg.hidden_size;
        let q_dim = cfg.q_proj_dim();
        let kv_dim = cfg.kv_proj_dim();
        let int_dim = cfg.intermediate_size;
        let vocab = cfg.vocab_size;

        // Multiplicative hash of the element index, mapped into [-0.5, 0.5).
        let tensor = |shape: Vec<usize>, salt: u32| -> (Vec<f32>, Vec<usize>) {
            let numel: usize = shape.iter().product();
            let data = (0..numel)
                .map(|i| {
                    let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
                    (x as f32 / (1u32 << 24) as f32) - 0.5
                })
                .collect();
            (data, shape)
        };

        let mut tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        tensors.insert(
            "model.embed_tokens.weight".into(),
            tensor(vec![vocab, h], 1),
        );
        for layer in 0..cfg.num_hidden_layers {
            let prefix = format!("model.layers.{layer}");
            let layer_salt = layer as u32;
            // (name suffix, salt base, shape)
            let per_layer: [(&str, u32, Vec<usize>); 9] = [
                ("input_layernorm.weight", 100, vec![h]),
                ("post_attention_layernorm.weight", 200, vec![h]),
                ("self_attn.q_proj.weight", 300, vec![q_dim, h]),
                ("self_attn.k_proj.weight", 400, vec![kv_dim, h]),
                ("self_attn.v_proj.weight", 500, vec![kv_dim, h]),
                ("self_attn.o_proj.weight", 600, vec![h, q_dim]),
                ("mlp.gate_proj.weight", 900, vec![int_dim, h]),
                ("mlp.up_proj.weight", 1000, vec![int_dim, h]),
                ("mlp.down_proj.weight", 1100, vec![h, int_dim]),
            ];
            for (suffix, salt_base, shape) in per_layer {
                tensors.insert(
                    format!("{prefix}.{suffix}"),
                    tensor(shape, salt_base + layer_salt),
                );
            }
        }
        tensors.insert("model.norm.weight".into(), tensor(vec![h], 2000));
        tensors.insert("lm_head.weight".into(), tensor(vec![vocab, h], 3000));
        WeightMap::from_tensors(tensors)
    }

    /// CPU generator over freshly generated synthetic weights.
    fn fresh_generator(cfg: &Llama32Config) -> Llama32Generator {
        let mut loader = synthetic_weights(cfg);
        Llama32Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap()
    }

    /// Prefills `generator` with `prompt` and greedily generates `steps`
    /// tokens through the KV-cached path.
    fn greedy_cached(mut generator: Llama32Generator, prompt: &[u32], steps: usize) -> Vec<u32> {
        generator.prefill(prompt);
        generator
            .generate_cached(steps, SampleOpts::greedy())
            .unwrap()
    }

    #[test]
    fn generator_drains_loader_and_runs_one_step() {
        let cfg = tiny_cfg();
        let mut loader = synthetic_weights(&cfg);
        let mut generator =
            Llama32Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        assert_eq!(loader.len(), 0, "loader should be drained");
        generator.prefill(&[1, 2, 3]);
        let tok = generator.step(SampleOpts::greedy()).unwrap();
        assert!((tok as usize) < cfg.vocab_size);
        assert_eq!(generator.tokens().len(), 4);
    }

    #[test]
    fn generate_n_appends_n_tokens() {
        let cfg = tiny_cfg();
        let mut generator = fresh_generator(&cfg);
        generator.prefill(&[5, 6]);
        let produced = generator.generate(3, SampleOpts::greedy()).unwrap();
        assert_eq!(produced.len(), 3);
        assert_eq!(generator.tokens().len(), 5);
        for tok in &produced {
            assert!((*tok as usize) < cfg.vocab_size);
        }
    }

    #[test]
    fn step_without_prefill_errors() {
        let cfg = tiny_cfg();
        let mut generator = fresh_generator(&cfg);
        let outcome = generator.step(SampleOpts::greedy());
        assert!(outcome.is_err());
    }

    #[test]
    fn cached_matches_naive_on_greedy() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let mut naive = fresh_generator(&cfg);
        naive.prefill(&prompt);
        let naive_tokens = naive.generate(steps, SampleOpts::greedy()).unwrap();

        let cached_tokens = greedy_cached(fresh_generator(&cfg), &prompt, steps);

        assert_eq!(
            cached_tokens, naive_tokens,
            "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
        );
    }

    #[test]
    fn cached_step_advances_cache_invariant() {
        let cfg = tiny_cfg();
        let mut generator = fresh_generator(&cfg);
        generator.prefill(&[1, 2, 3]);

        // First cached step seeds the cache from the 3-token prompt.
        let _ = generator.step_cached(SampleOpts::greedy()).unwrap();
        assert_eq!(generator.tokens().len(), 4);
        assert_eq!(generator.cache.as_ref().unwrap().past_seq, 3);

        // Second cached step decodes one token and advances the cache by one.
        let _ = generator.step_cached(SampleOpts::greedy()).unwrap();
        assert_eq!(generator.tokens().len(), 5);
        assert_eq!(generator.cache.as_ref().unwrap().past_seq, 4);
    }

    #[test]
    fn bucketed_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let oneshot_tokens = greedy_cached(fresh_generator(&cfg), &prompt, steps);
        let bucketed_tokens =
            greedy_cached(fresh_generator(&cfg).with_decode_cache(32), &prompt, steps);

        assert_eq!(
            bucketed_tokens, oneshot_tokens,
            "bucketed-cache decode diverged from one-shot decode — \
             mask, padding, or output-slice bug"
        );
    }

    #[test]
    fn prefill_compile_cache_does_not_change_output() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];

        let a = greedy_cached(fresh_generator(&cfg), &prompt, 4);
        let b = greedy_cached(fresh_generator(&cfg).with_prefill_cache(4), &prompt, 4);

        assert_eq!(a, b, "enabling prefill_cache must not change output");
    }

    #[test]
    fn dynamic_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let oneshot_tokens = greedy_cached(fresh_generator(&cfg), &prompt, steps);
        let dynamic_tokens = greedy_cached(
            fresh_generator(&cfg).with_dynamic_decode_cache(8),
            &prompt,
            steps,
        );

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic past_seq decode diverged from one-shot decode"
        );
    }

    #[test]
    fn dynamic_prefill_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 4;

        let oneshot_tokens = greedy_cached(fresh_generator(&cfg), &prompt, steps);
        let dynamic_tokens = greedy_cached(
            fresh_generator(&cfg).with_dynamic_prefill_cache(8),
            &prompt,
            steps,
        );

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic seq prefill diverged from one-shot prefill"
        );
    }

    #[test]
    fn dynamic_prefill_and_decode_matches_oneshot() {
        let cfg = tiny_cfg();
        let prompt: Vec<u32> = vec![1, 2, 3, 5];
        let steps = 6;

        let oneshot_tokens = greedy_cached(fresh_generator(&cfg), &prompt, steps);
        let dynamic_tokens = greedy_cached(
            fresh_generator(&cfg)
                .with_dynamic_prefill_cache(8)
                .with_dynamic_decode_cache(8),
            &prompt,
            steps,
        );

        assert_eq!(
            dynamic_tokens, oneshot_tokens,
            "dynamic prefill+decode diverged from one-shot path"
        );
    }

    #[test]
    fn greedy_is_deterministic_across_runs() {
        let cfg = tiny_cfg();
        let weights = synthetic_weights(&cfg);
        let build = || {
            let mut loader = WeightMap::from_tensors(weights_as_hashmap(&weights));
            Llama32Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap()
        };
        let mut first = build();
        let mut second = build();
        first.prefill(&[1, 2, 3]);
        second.prefill(&[1, 2, 3]);
        let first_tokens = first.generate(4, SampleOpts::greedy()).unwrap();
        let second_tokens = second.generate(4, SampleOpts::greedy()).unwrap();
        assert_eq!(first_tokens, second_tokens);
    }

    /// Produces a plain tensor map equivalent to `synthetic_weights(&tiny_cfg())`.
    ///
    /// The argument is intentionally ignored: the map is regenerated from
    /// scratch and drained, rather than read out of `wm`.
    fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
        let _ = wm;
        let cfg = tiny_cfg();
        let mut regenerated = synthetic_weights(&cfg);
        let names: Vec<String> = regenerated.keys().map(|s| s.to_string()).collect();
        names
            .into_iter()
            .map(|name| {
                let tensor = regenerated.take(&name).unwrap();
                (name, tensor)
            })
            .collect()
    }
}