1use crate::builder::{
34 build_llama32_decode_hir_dynamic_ext, build_llama32_decode_hir_sized,
35 build_llama32_decode_hir_sized_ext, build_llama32_graph_sized_last_logits,
36 build_llama32_prefill_hir_dynamic_ext,
37};
38use crate::config::Llama32Config;
39use crate::rope::{resolve_inv_freq, rope_slice};
40use anyhow::{Context, Result};
41use rlx_core::flow_bridge::compile_options_from_profile;
42use rlx_core::weight_loader::WeightLoader;
43use rlx_core::weight_map::WeightMap;
44use rlx_flow::CompileProfile;
45use rlx_ir::DimBinding;
46use rlx_ir::logical_kernel::KernelDispatchConfig;
47use rlx_qwen3::sampling::{SampleOpts, sample_token};
48use rlx_runtime::compile_cache::{BucketedCompileCache, CompileCache, DynamicDimCompileCache};
49use rlx_runtime::{CompileOptions, Device, Session};
50use std::collections::{HashMap, HashSet};
51use std::path::Path;
52
53fn metal_thunk_compile_guard<R, F>(device: Device, f: F) -> R
61where
62 F: FnOnce() -> R,
63{
64 if device == Device::Metal {
65 rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
66 let out = f();
67 rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
68 out
69 } else {
70 f()
71 }
72}
73
74fn metal_decode_compile_guard<R, F>(device: Device, decode: bool, f: F) -> R
75where
76 F: FnOnce() -> R,
77{
78 if decode {
79 metal_thunk_compile_guard(device, f)
80 } else {
81 f()
82 }
83}
84
/// Host-side KV cache carried between cached decode steps.
#[derive(Clone)]
struct KvCacheState {
    /// Number of positions already stored in `layers_k` / `layers_v`.
    past_seq: usize,
    /// Per-layer key cache: one flat `Vec<f32>` per hidden layer.
    // Seed paths check `len == batch * seq * kv_proj_dim` with batch = 1;
    // presumably row-major (past_seq, kv_proj_dim) — TODO confirm layout.
    layers_k: Vec<Vec<f32>>,
    /// Per-layer value cache, same layout as `layers_k`.
    layers_v: Vec<Vec<f32>>,
}
93
/// Autoregressive token generator for Llama 3.2 models.
///
/// Holds every weight in host memory and compiles graphs on demand, either
/// one-shot per call or through the optional prefill/decode compile caches.
pub struct Llama32Generator {
    /// Model hyper-parameters.
    cfg: Llama32Config,
    /// All weights drained from the loader as `(data, shape)`, keyed by the
    /// HF-style name when a GGUF mapping exists, else by the loader's key.
    weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
    /// Prompt plus every token sampled or fed so far.
    tokens: Vec<u32>,
    /// Backend every graph is compiled for.
    device: Device,
    /// KV cache for the cached paths; `None` until the prompt has been prefilled.
    cache: Option<KvCacheState>,
    /// Static-shape prefill cache keyed by (batch, seq); exclusive with
    /// `prefill_dynamic_cache`.
    prefill_compile_cache: Option<CompileCache>,
    /// Dynamic-seq prefill cache; exclusive with `prefill_compile_cache`.
    prefill_dynamic_cache: Option<DynamicDimCompileCache>,
    /// Bucketed (power-of-two `past_seq` ladder) decode cache; exclusive with
    /// `decode_dynamic_cache`.
    decode_compile_cache: Option<BucketedCompileCache>,
    /// Dynamic-`past_seq` decode cache; exclusive with `decode_compile_cache`.
    decode_dynamic_cache: Option<DynamicDimCompileCache>,
    /// Indices of `decode_compile_cache` buckets whose parameters are uploaded.
    decode_loaded_buckets: HashSet<usize>,
    /// Sequence cap for dynamic graphs; `None` means `max_position_embeddings`.
    compile_seq_cap: Option<usize>,
    /// RoPE inverse frequencies, resolved once at construction.
    inv_freq: Vec<f64>,
    /// Compile profile used for prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile used for decode graphs.
    decode_profile: CompileProfile,
}
139
140impl Llama32Generator {
141 pub fn from_loader(
144 cfg: Llama32Config,
145 loader: &mut dyn WeightLoader,
146 device: Device,
147 ) -> Result<Self> {
148 let keys = loader.remaining_keys();
149 let mut weights_cache = HashMap::with_capacity(keys.len());
150 for k in keys {
151 let v = loader
152 .take(&k)
153 .with_context(|| format!("draining weight {k}"))?;
154 let canonical =
160 rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
161 weights_cache.insert(canonical, v);
162 }
163 let rope_factors = weights_cache
164 .get("rope_freqs.weight")
165 .map(|(d, _)| d.as_slice());
166 let inv_freq = resolve_inv_freq(&cfg, rope_factors);
167 Ok(Self {
168 cfg,
169 weights_cache,
170 tokens: Vec::new(),
171 device,
172 cache: None,
173 prefill_compile_cache: None,
174 prefill_dynamic_cache: None,
175 decode_compile_cache: None,
176 decode_dynamic_cache: None,
177 decode_loaded_buckets: HashSet::new(),
178 compile_seq_cap: None,
179 inv_freq,
180 prefill_profile: CompileProfile::llama32_prefill(),
181 decode_profile: CompileProfile::llama32_decode(),
182 })
183 }
184
185 fn compile_seq_cap(&self) -> usize {
186 self.compile_seq_cap
187 .unwrap_or(self.cfg.max_position_embeddings)
188 }
189
190 pub fn with_compile_seq_cap(mut self, cap: usize) -> Self {
194 self.compile_seq_cap = Some(cap.max(1));
195 self
196 }
197
198 pub fn from_loader_at(
201 cfg: Llama32Config,
202 loader: &mut dyn WeightLoader,
203 device: Device,
204 weights_path: &Path,
205 ) -> Result<Self> {
206 let mut g = Self::from_loader(cfg, loader, device)?;
207 g.prefill_profile = crate::llama32_profile_near_weights(weights_path, false);
208 g.decode_profile = crate::llama32_profile_near_weights(weights_path, true);
209 Ok(g)
210 }
211
212 pub fn with_compile_profiles(
214 mut self,
215 prefill: CompileProfile,
216 decode: CompileProfile,
217 ) -> Self {
218 self.prefill_profile = prefill;
219 self.decode_profile = decode;
220 self
221 }
222
223 pub fn prefill_profile(&self) -> &CompileProfile {
224 &self.prefill_profile
225 }
226
227 pub fn decode_profile(&self) -> &CompileProfile {
228 &self.decode_profile
229 }
230
231 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
232 let profile = if decode {
233 &self.decode_profile
234 } else {
235 &self.prefill_profile
236 };
237 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
238 }
239
240 fn compile_hir_profiled(
241 &self,
242 session: &Session,
243 hir: rlx_ir::hir::HirModule,
244 decode: bool,
245 ) -> Result<rlx_runtime::CompiledGraph> {
246 let opts = self.profile_compile_options(decode);
247 Ok(metal_decode_compile_guard(self.device, decode, || {
248 session.compile_hir_with(hir, &opts)
249 })?)
250 }
251
252 fn compile_graph_profiled(
253 &self,
254 session: &Session,
255 graph: rlx_ir::Graph,
256 ) -> Result<rlx_runtime::CompiledGraph> {
257 let opts = self.profile_compile_options(false);
258 Ok(session.compile_with(graph, &opts))
259 }
260
261 pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
266 self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
267 self.prefill_dynamic_cache = None;
268 self
269 }
270
271 pub fn with_dynamic_prefill_cache(mut self, capacity: usize) -> Self {
273 self.prefill_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
274 self.prefill_compile_cache = None;
275 self
276 }
277
278 pub fn with_decode_cache(mut self, max_past: usize) -> Self {
288 let cache = BucketedCompileCache::power_of_two_ladder(
289 self.device,
290 1,
291 max_past.max(1) as u64,
292 );
293 self.decode_compile_cache = Some(cache);
294 self.decode_dynamic_cache = None;
295 self.decode_loaded_buckets.clear();
296 self
297 }
298
299 pub fn with_dynamic_decode_cache(mut self, capacity: usize) -> Self {
301 self.decode_dynamic_cache = Some(DynamicDimCompileCache::new(self.device, capacity));
302 self.decode_compile_cache = None;
303 self.decode_loaded_buckets.clear();
304 self
305 }
306
307 pub fn from_path(cfg: Llama32Config, path: &str, device: Device) -> Result<Self> {
310 let mut loader = rlx_core::weight_loader::load_from_path(path)?;
311 Self::from_loader(cfg, loader.as_mut(), device)
312 }
313
314 pub fn from_path_with_mtp(
322 cfg: Llama32Config,
323 path: &str,
324 device: Device,
325 include_mtp: bool,
326 ) -> Result<Self> {
327 if path.ends_with(".gguf") {
331 let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
332 gguf.include_mtp(include_mtp);
333 Self::from_loader(cfg, &mut gguf, device)
334 } else {
335 Self::from_path(cfg, path, device)
336 }
337 }
338
339 pub fn prefill(&mut self, prompt_ids: &[u32]) {
343 self.tokens.clear();
344 self.tokens.extend_from_slice(prompt_ids);
345 self.cache = None;
346 }
347
348 pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
352 if self.tokens.is_empty() {
353 anyhow::bail!("step() called with empty token history; call prefill() first");
354 }
355 let seq = self.tokens.len();
356 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
357 let (graph, params) = build_llama32_graph_sized_last_logits(
358 &self.cfg, &mut wm, 1, seq, false,
359 )?;
360 let session = Session::new(self.device);
361 let mut compiled = self.compile_graph_profiled(&session, graph)?;
362 for (name, data) in ¶ms {
363 compiled.set_param(name, data);
364 }
365 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
366 let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
367 let logits = outputs
368 .into_iter()
369 .next()
370 .context("compiled.run returned no outputs")?;
371
372 let vocab = self.cfg.vocab_size;
373 let expected = vocab;
374 if logits.len() < expected {
375 anyhow::bail!(
376 "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
377 logits.len(),
378 expected
379 );
380 }
381 let last_row = &logits[..vocab];
383 let tok = sample_token(last_row, opts) as u32;
384 self.tokens.push(tok);
385 Ok(tok)
386 }
387
388 pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
391 let start = self.tokens.len();
392 for _ in 0..n {
393 self.step(opts)?;
394 }
395 Ok(self.tokens[start..].to_vec())
396 }
397
398 pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
408 if self.tokens.is_empty() {
409 anyhow::bail!("step_cached() called with empty token history; call prefill() first");
410 }
411 if self.cache.is_none() {
412 let tok = self.seed_cache_from_prompt(opts)?;
416 return Ok(tok);
417 }
418 let cache = self.cache.as_ref().unwrap();
419 let past_seq = cache.past_seq;
420 if self.tokens.len() <= past_seq {
424 anyhow::bail!(
425 "cache invariant violated: tokens.len() {} <= past_seq {}",
426 self.tokens.len(),
427 past_seq
428 );
429 }
430 let input_tok = self.tokens[past_seq];
431
432 let (logits, new_k, new_v) = if self.decode_dynamic_cache.is_some() {
434 self.decode_step_dynamic(past_seq, input_tok)?
435 } else if self.decode_compile_cache.is_some()
436 && self
437 .decode_compile_cache
438 .as_ref()
439 .unwrap()
440 .bucket_for(past_seq as u64)
441 .is_some()
442 {
443 self.decode_step_bucketed(past_seq, input_tok)?
444 } else {
445 self.decode_step_oneshot(past_seq, input_tok)?
446 };
447
448 let cache_mut = self.cache.as_mut().unwrap();
449 cache_mut.past_seq = past_seq + 1;
450 cache_mut.layers_k = new_k;
451 cache_mut.layers_v = new_v;
452
453 let vocab = self.cfg.vocab_size;
454 if logits.len() != vocab {
455 anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
456 }
457 let tok = sample_token(&logits, opts) as u32;
458 self.tokens.push(tok);
459 Ok(tok)
460 }
461
462 #[allow(clippy::type_complexity)]
465 fn decode_step_oneshot(
466 &mut self,
467 past_seq: usize,
468 input_tok: u32,
469 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
470 let cache = self.cache.as_ref().unwrap();
471
472 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
473 let (hir, params) =
474 build_llama32_decode_hir_sized(&self.cfg, &mut wm, 1, past_seq)?;
475 let session = Session::new(self.device);
476 let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
477 for (name, data) in ¶ms {
478 compiled.set_param(name, data);
479 }
480
481 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
482 let input_ids_f32 = [input_tok as f32];
483 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
484 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
485 .collect();
486 let mut inputs: Vec<(&str, &[f32])> =
487 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
488 inputs.push(("input_ids", input_ids_f32.as_slice()));
489 inputs.push(("rope_cos", cos.as_slice()));
490 inputs.push(("rope_sin", sin.as_slice()));
491 for i in 0..self.cfg.num_hidden_layers {
492 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
493 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
494 }
495
496 let outputs = compiled.run(&inputs);
497 self.split_decode_outputs(outputs)
498 }
499
500 #[allow(clippy::type_complexity)]
501 fn decode_step_dynamic(
502 &mut self,
503 past_seq: usize,
504 input_tok: u32,
505 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
506 let cache = self.cache.as_ref().unwrap();
507 let binding = DimBinding::batch_past_seq(1, past_seq);
508 let opts = self
509 .profile_compile_options(true)
510 .dim_binding(binding.clone());
511 let max_past = self.compile_seq_cap();
512 let cache_dyn = self
513 .decode_dynamic_cache
514 .as_mut()
515 .ok_or_else(|| anyhow::anyhow!("dynamic decode without cache"))?;
516 let needs_upload = !cache_dyn.contains(past_seq as u64);
517 let cfg = self.cfg.clone();
518 let weights_cache = self.weights_cache.clone();
519 let device = self.device;
520 let compiled = cache_dyn.get_or_specialize(
521 past_seq as u64,
522 &binding,
523 || {
524 metal_decode_compile_guard(device, true, || {
525 let mut wm = WeightMap::from_tensors(weights_cache);
526 build_llama32_decode_hir_dynamic_ext(&cfg, &mut wm, 1, max_past)
527 .expect("dynamic decode HIR")
528 .0
529 })
530 },
531 &opts,
532 )?;
533 if needs_upload {
534 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
535 let (_, params) =
536 build_llama32_decode_hir_dynamic_ext(&self.cfg, &mut wm, 1, max_past)?;
537 for (name, data) in ¶ms {
538 compiled.set_param(name, data);
539 }
540 }
541
542 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
543 let input_ids_f32 = [input_tok as f32];
544 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
545 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
546 .collect();
547 let mut inputs: Vec<(&str, &[f32])> =
548 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
549 inputs.push(("input_ids", input_ids_f32.as_slice()));
550 inputs.push(("rope_cos", cos.as_slice()));
551 inputs.push(("rope_sin", sin.as_slice()));
552 for i in 0..self.cfg.num_hidden_layers {
553 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
554 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
555 }
556 let outputs = compiled.run(&inputs);
557 self.split_decode_outputs(outputs)
558 }
559
560 #[allow(clippy::type_complexity)]
567 fn decode_step_bucketed(
568 &mut self,
569 past_seq: usize,
570 input_tok: u32,
571 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
572 let cache_dec = self.decode_compile_cache.as_ref().unwrap();
573 let bucket_idx = cache_dec
574 .bucket_for(past_seq as u64)
575 .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside any bucket"))?;
576 let upper = cache_dec
577 .buckets()
578 .nth(bucket_idx)
579 .map(|r| r.end - 1)
580 .unwrap() as usize;
581
582 let kv_dim = self.cfg.kv_proj_dim();
583 let n_layers = self.cfg.num_hidden_layers;
584
585 let needs_load = !self.decode_loaded_buckets.contains(&bucket_idx);
589 if needs_load {
590 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
591 let (hir, params) = build_llama32_decode_hir_sized_ext(
592 &self.cfg, &mut wm, 1, upper, true,
593 )?;
594 {
595 let decode_opts = self.profile_compile_options(true);
596 let cache_mut = self.decode_compile_cache.as_mut().unwrap();
597 metal_decode_compile_guard(self.device, true, || {
598 let (_u, compiled) = cache_mut
599 .get_or_compile_hir_with_options(
600 past_seq as u64,
601 |_upper| hir,
602 &decode_opts,
603 )
604 .expect("bucket must exist; we just looked it up");
605 for (name, data) in ¶ms {
606 compiled.set_param(name, data);
607 }
608 });
609 }
610 self.decode_loaded_buckets.insert(bucket_idx);
611 }
612
613 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
615 let input_ids_f32 = [input_tok as f32];
616
617 let mask_len = upper + 1;
622 let mut mask = vec![0.0f32; mask_len];
623 for v in mask.iter_mut().take(past_seq + 1) {
624 *v = 1.0;
625 }
626
627 let padded_k: Vec<Vec<f32>> = (0..n_layers)
629 .map(|i| {
630 let src = &self.cache.as_ref().unwrap().layers_k[i];
631 let mut out = vec![0f32; upper * kv_dim];
632 out[..src.len()].copy_from_slice(src);
633 out
634 })
635 .collect();
636 let padded_v: Vec<Vec<f32>> = (0..n_layers)
637 .map(|i| {
638 let src = &self.cache.as_ref().unwrap().layers_v[i];
639 let mut out = vec![0f32; upper * kv_dim];
640 out[..src.len()].copy_from_slice(src);
641 out
642 })
643 .collect();
644
645 let key_strs: Vec<String> = (0..n_layers)
646 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
647 .collect();
648 let mut inputs: Vec<(&str, &[f32])> = Vec::with_capacity(4 + 2 * n_layers);
649 inputs.push(("input_ids", input_ids_f32.as_slice()));
650 inputs.push(("rope_cos", cos.as_slice()));
651 inputs.push(("rope_sin", sin.as_slice()));
652 inputs.push(("mask", mask.as_slice()));
653 for i in 0..n_layers {
654 inputs.push((&key_strs[2 * i], padded_k[i].as_slice()));
655 inputs.push((&key_strs[2 * i + 1], padded_v[i].as_slice()));
656 }
657
658 let cache_mut = self.decode_compile_cache.as_mut().unwrap();
659 let (_u, compiled) = cache_mut
660 .get_or_compile_hir(past_seq as u64, |_| {
661 unreachable!("bucket was just loaded above")
662 })
663 .unwrap();
664 let raw_outputs = compiled.run(&inputs);
665
666 let mut iter = raw_outputs.into_iter();
670 let logits = iter.next().context("bucketed decode logits missing")?;
671 let real_len = (past_seq + 1) * kv_dim;
672 let mut new_k = Vec::with_capacity(n_layers);
673 let mut new_v = Vec::with_capacity(n_layers);
674 for _ in 0..n_layers {
675 let k = iter.next().context("bucketed k missing")?;
676 let v = iter.next().context("bucketed v missing")?;
677 new_k.push(k[..real_len].to_vec());
678 new_v.push(v[..real_len].to_vec());
679 }
680 Ok((logits, new_k, new_v))
681 }
682
683 fn run_prefill_with_cache(
687 &mut self,
688 batch: usize,
689 seq: usize,
690 ids_f32: &[f32],
691 ) -> Result<Vec<Vec<f32>>> {
692 let compile_cap = self.compile_seq_cap();
693 let dynamic_prefill = self.prefill_dynamic_cache.is_some().then(|| {
694 let binding = DimBinding::batch_seq(batch, seq);
695 let opts = self
696 .profile_compile_options(false)
697 .dim_binding(binding.clone());
698 (binding, opts)
699 });
700 if let (Some(cache), Some((binding, opts))) = (
701 self.prefill_dynamic_cache.as_mut(),
702 dynamic_prefill.as_ref(),
703 ) {
704 let max_seq = compile_cap;
705 let needs_upload = !cache.contains(seq as u64);
706 let cfg = self.cfg.clone();
707 let weights_cache = self.weights_cache.clone();
708 let compiled = cache.get_or_specialize(
709 seq as u64,
710 binding,
711 || {
712 let mut wm = WeightMap::from_tensors(weights_cache);
713 build_llama32_prefill_hir_dynamic_ext(&cfg, &mut wm, batch, max_seq, true)
714 .expect("dynamic prefill HIR")
715 .0
716 },
717 opts,
718 )?;
719 if needs_upload {
720 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
721 let (_, params) = build_llama32_prefill_hir_dynamic_ext(
722 &self.cfg, &mut wm, batch, max_seq, true,
723 )?;
724 for (name, data) in ¶ms {
725 compiled.set_param(name, data);
726 }
727 }
728 let last_idx = vec![(seq - 1) as f32];
729 Ok(compiled.run(&[("input_ids", ids_f32), ("last_token_idx", &last_idx)]))
730 } else if let Some(prefill_cache) = self.prefill_compile_cache.as_mut() {
731 let key = ((batch as u64) << 32) | (seq as u64);
732 if !prefill_cache.contains(key) {
733 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
734 let (graph, params) = build_llama32_graph_sized_last_logits(
735 &self.cfg, &mut wm, batch, seq, true,
736 )?;
737 {
738 let compiled = prefill_cache.get_or_compile(key, || graph);
739 for (name, data) in ¶ms {
740 compiled.set_param(name, data);
741 }
742 }
743 }
744 let compiled =
745 prefill_cache.get_or_compile(key, || unreachable!("just populated above"));
746 Ok(compiled.run(&[("input_ids", ids_f32)]))
747 } else {
748 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
749 let (graph, params) = build_llama32_graph_sized_last_logits(
750 &self.cfg, &mut wm, batch, seq, true,
751 )?;
752 let session = Session::new(self.device);
753 let mut compiled = self.compile_graph_profiled(&session, graph)?;
754 for (name, data) in ¶ms {
755 compiled.set_param(name, data);
756 }
757 Ok(compiled.run(&[("input_ids", ids_f32)]))
758 }
759 }
760
761 #[allow(clippy::type_complexity)]
765 fn split_decode_outputs(
766 &self,
767 outputs: Vec<Vec<f32>>,
768 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>)> {
769 let n_layers = self.cfg.num_hidden_layers;
770 if outputs.len() != 1 + 2 * n_layers {
771 anyhow::bail!(
772 "decode graph produced {} outputs, expected {}",
773 outputs.len(),
774 1 + 2 * n_layers
775 );
776 }
777 let mut iter = outputs.into_iter();
778 let logits = iter.next().context("decode logits missing")?;
779 let mut layers_k = Vec::with_capacity(n_layers);
780 let mut layers_v = Vec::with_capacity(n_layers);
781 for _ in 0..n_layers {
782 layers_k.push(iter.next().context("decode k missing")?);
783 layers_v.push(iter.next().context("decode v missing")?);
784 }
785 Ok((logits, layers_k, layers_v))
786 }
787
788 pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
790 self.generate_cached_with(n, opts, |_| {})
791 }
792
793 pub fn generate_cached_with(
800 &mut self,
801 n: usize,
802 opts: SampleOpts,
803 mut on_token: impl FnMut(u32),
804 ) -> Result<Vec<u32>> {
805 let start = self.tokens.len();
806 for _ in 0..n {
807 let tok = self.step_cached(opts)?;
808 on_token(tok);
809 }
810 Ok(self.tokens[start..].to_vec())
811 }
812
813 fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
818 let seq = self.tokens.len();
819 let batch = 1usize;
820 let kv_dim = self.cfg.kv_proj_dim();
821
822 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
823 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
824 if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
825 anyhow::bail!(
826 "prefill-with-cache produced {} outputs, expected {}",
827 outputs.len(),
828 1 + 2 * self.cfg.num_hidden_layers
829 );
830 }
831 let expected_kv_len = batch * seq * kv_dim;
832 let mut iter = outputs.into_iter();
833 let logits = iter.next().context("prefill logits missing")?;
834 let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
835 let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
836 for layer in 0..self.cfg.num_hidden_layers {
837 let k = iter.next().context("prefill k missing")?;
838 let v = iter.next().context("prefill v missing")?;
839 if k.len() != expected_kv_len || v.len() != expected_kv_len {
840 anyhow::bail!(
841 "layer {layer}: k.len={} v.len={} expected {}",
842 k.len(),
843 v.len(),
844 expected_kv_len
845 );
846 }
847 layers_k.push(k);
848 layers_v.push(v);
849 }
850 self.cache = Some(KvCacheState {
851 past_seq: seq,
852 layers_k,
853 layers_v,
854 });
855
856 let vocab = self.cfg.vocab_size;
857 let needed = vocab;
858 if logits.len() < needed {
859 anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
860 }
861 let last_row = &logits[..vocab];
862 let tok = sample_token(last_row, opts) as u32;
863 self.tokens.push(tok);
864 Ok(tok)
865 }
866
867 pub fn tokens(&self) -> &[u32] {
869 &self.tokens
870 }
871
872 pub fn config(&self) -> &Llama32Config {
873 &self.cfg
874 }
875
876 pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
884 if context.is_empty() {
885 anyhow::bail!("prefill_get_last_logits: empty context");
886 }
887 self.tokens.clear();
888 self.tokens.extend_from_slice(context);
889 self.cache = None;
890
891 let seq = context.len();
892 let batch = 1usize;
893 let kv_dim = self.cfg.kv_proj_dim();
894
895 let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
896 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
897 if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
898 anyhow::bail!(
899 "prefill_get_last_logits: got {} outputs, expected {}",
900 outputs.len(),
901 1 + 2 * self.cfg.num_hidden_layers
902 );
903 }
904 let expected_kv_len = batch * seq * kv_dim;
905 let mut iter = outputs.into_iter();
906 let logits = iter.next().context("logits missing")?;
907 let mut layers_k = Vec::with_capacity(self.cfg.num_hidden_layers);
908 let mut layers_v = Vec::with_capacity(self.cfg.num_hidden_layers);
909 for _ in 0..self.cfg.num_hidden_layers {
910 let k = iter.next().context("k missing")?;
911 let v = iter.next().context("v missing")?;
912 if k.len() != expected_kv_len || v.len() != expected_kv_len {
913 anyhow::bail!("kv length mismatch in prefill_get_last_logits");
914 }
915 layers_k.push(k);
916 layers_v.push(v);
917 }
918 self.cache = Some(KvCacheState {
919 past_seq: seq,
920 layers_k,
921 layers_v,
922 });
923
924 let vocab = self.cfg.vocab_size;
925 let needed = vocab;
926 if logits.len() < needed {
927 anyhow::bail!("logits short: {} < {}", logits.len(), needed);
928 }
929 Ok(logits[..vocab].to_vec())
930 }
931
932 pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
940 let cache = self.cache.as_ref().ok_or_else(|| {
941 anyhow::anyhow!(
942 "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
943 )
944 })?;
945 let past_seq = cache.past_seq;
946
947 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
948 let (hir, params) =
949 build_llama32_decode_hir_sized(&self.cfg, &mut wm, 1, past_seq)?;
950 let session = Session::new(self.device);
951 let mut compiled = self.compile_hir_profiled(&session, hir, true)?;
952 for (name, data) in ¶ms {
953 compiled.set_param(name, data);
954 }
955
956 let (cos, sin) = compute_rope_slice(&self.inv_freq, past_seq);
957 let input_ids_f32 = [input as f32];
958 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
959 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
960 .collect();
961 let mut inputs: Vec<(&str, &[f32])> =
962 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
963 inputs.push(("input_ids", input_ids_f32.as_slice()));
964 inputs.push(("rope_cos", cos.as_slice()));
965 inputs.push(("rope_sin", sin.as_slice()));
966 for i in 0..self.cfg.num_hidden_layers {
967 let pk = &cache.layers_k[i];
968 let pv = &cache.layers_v[i];
969 inputs.push((&key_strs[2 * i], pk.as_slice()));
970 inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
971 }
972
973 let outputs = compiled.run(&inputs);
974 if outputs.len() != 1 + 2 * self.cfg.num_hidden_layers {
975 anyhow::bail!(
976 "decode_get_logits: got {} outputs, expected {}",
977 outputs.len(),
978 1 + 2 * self.cfg.num_hidden_layers
979 );
980 }
981 let mut iter = outputs.into_iter();
982 let logits = iter.next().context("logits missing")?;
983 let mut new_k = Vec::with_capacity(self.cfg.num_hidden_layers);
984 let mut new_v = Vec::with_capacity(self.cfg.num_hidden_layers);
985 for _ in 0..self.cfg.num_hidden_layers {
986 new_k.push(iter.next().context("k missing")?);
987 new_v.push(iter.next().context("v missing")?);
988 }
989
990 let cache_mut = self.cache.as_mut().unwrap();
991 cache_mut.past_seq = past_seq + 1;
992 cache_mut.layers_k = new_k;
993 cache_mut.layers_v = new_v;
994 self.tokens.push(input);
995
996 Ok(logits)
997 }
998}
999
1000fn compute_rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
1004 rope_slice(inv_freq, pos)
1005}
1006
1007#[cfg(test)]
1008mod tests {
1009 use super::*;
1010 use crate::config::Llama32Config;
1011
1012 fn tiny_cfg() -> Llama32Config {
1013 Llama32Config {
1014 vocab_size: 16,
1015 hidden_size: 16,
1016 intermediate_size: 32,
1017 num_hidden_layers: 2,
1018 num_attention_heads: 4,
1019 num_key_value_heads: 2,
1020 max_position_embeddings: 16,
1021 rms_norm_eps: 1e-5,
1022 rope_theta: 500_000.0,
1023 hidden_act: "silu".into(),
1024 tie_word_embeddings: false,
1025 attention_bias: false,
1026 head_dim: Some(8),
1027 rope_scaling: None,
1028 }
1029 }
1030
1031 fn synthetic_weights(cfg: &Llama32Config) -> WeightMap {
1032 let h = cfg.hidden_size;
1033 let q_dim = cfg.q_proj_dim();
1034 let kv_dim = cfg.kv_proj_dim();
1035 let int_dim = cfg.intermediate_size;
1036 let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
1037 let pat = |n: usize, salt: u32| -> Vec<f32> {
1040 (0..n)
1041 .map(|i| {
1042 let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
1043 (x as f32 / (1u32 << 24) as f32) - 0.5
1044 })
1045 .collect()
1046 };
1047 t.insert(
1048 "model.embed_tokens.weight".into(),
1049 (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
1050 );
1051 for i in 0..cfg.num_hidden_layers {
1052 let lp = format!("model.layers.{i}");
1053 t.insert(
1054 format!("{lp}.input_layernorm.weight"),
1055 (pat(h, 100 + i as u32), vec![h]),
1056 );
1057 t.insert(
1058 format!("{lp}.post_attention_layernorm.weight"),
1059 (pat(h, 200 + i as u32), vec![h]),
1060 );
1061 t.insert(
1062 format!("{lp}.self_attn.q_proj.weight"),
1063 (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
1064 );
1065 t.insert(
1066 format!("{lp}.self_attn.k_proj.weight"),
1067 (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
1068 );
1069 t.insert(
1070 format!("{lp}.self_attn.v_proj.weight"),
1071 (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
1072 );
1073 t.insert(
1074 format!("{lp}.self_attn.o_proj.weight"),
1075 (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
1076 );
1077 t.insert(
1078 format!("{lp}.mlp.gate_proj.weight"),
1079 (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
1080 );
1081 t.insert(
1082 format!("{lp}.mlp.up_proj.weight"),
1083 (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
1084 );
1085 t.insert(
1086 format!("{lp}.mlp.down_proj.weight"),
1087 (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
1088 );
1089 }
1090 t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
1091 t.insert(
1092 "lm_head.weight".into(),
1093 (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
1094 );
1095 WeightMap::from_tensors(t)
1096 }
1097
1098 #[test]
1099 fn generator_drains_loader_and_runs_one_step() {
1100 let cfg = tiny_cfg();
1101 let mut wm = synthetic_weights(&cfg);
1102 let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1103 assert_eq!(wm.len(), 0, "loader should be drained");
1104 gn.prefill(&[1, 2, 3]);
1105 let t = gn.step(SampleOpts::greedy()).unwrap();
1106 assert!((t as usize) < cfg.vocab_size);
1107 assert_eq!(gn.tokens().len(), 4);
1108 }
1109
1110 #[test]
1111 fn generate_n_appends_n_tokens() {
1112 let cfg = tiny_cfg();
1113 let mut wm = synthetic_weights(&cfg);
1114 let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1115 gn.prefill(&[5, 6]);
1116 let new_tokens = gn.generate(3, SampleOpts::greedy()).unwrap();
1117 assert_eq!(new_tokens.len(), 3);
1118 assert_eq!(gn.tokens().len(), 5);
1119 for t in &new_tokens {
1120 assert!((*t as usize) < cfg.vocab_size);
1121 }
1122 }
1123
1124 #[test]
1125 fn step_without_prefill_errors() {
1126 let cfg = tiny_cfg();
1127 let mut wm = synthetic_weights(&cfg);
1128 let mut gn = Llama32Generator::from_loader(cfg, &mut wm, Device::Cpu).unwrap();
1129 let r = gn.step(SampleOpts::greedy());
1130 assert!(r.is_err());
1131 }
1132
1133 #[test]
1134 fn cached_matches_naive_on_greedy() {
1135 let cfg = tiny_cfg();
1142 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1143 let steps = 4;
1144
1145 let mut wm_n = synthetic_weights(&cfg);
1146 let mut gn_naive =
1147 Llama32Generator::from_loader(cfg.clone(), &mut wm_n, Device::Cpu).unwrap();
1148 gn_naive.prefill(&prompt);
1149 let naive_tokens = gn_naive.generate(steps, SampleOpts::greedy()).unwrap();
1150
1151 let mut wm_c = synthetic_weights(&cfg);
1152 let mut gn_cached =
1153 Llama32Generator::from_loader(cfg.clone(), &mut wm_c, Device::Cpu).unwrap();
1154 gn_cached.prefill(&prompt);
1155 let cached_tokens = gn_cached
1156 .generate_cached(steps, SampleOpts::greedy())
1157 .unwrap();
1158
1159 assert_eq!(
1160 cached_tokens, naive_tokens,
1161 "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
1162 );
1163 }
1164
1165 #[test]
1166 fn cached_step_advances_cache_invariant() {
1167 let cfg = tiny_cfg();
1168 let mut wm = synthetic_weights(&cfg);
1169 let mut gn = Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap();
1170 gn.prefill(&[1, 2, 3]);
1171 let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
1172 assert_eq!(gn.tokens().len(), 4);
1174 assert_eq!(gn.cache.as_ref().unwrap().past_seq, 3);
1175 let _ = gn.step_cached(SampleOpts::greedy()).unwrap();
1176 assert_eq!(gn.tokens().len(), 5);
1178 assert_eq!(gn.cache.as_ref().unwrap().past_seq, 4);
1179 }
1180
1181 #[test]
1182 fn bucketed_decode_matches_oneshot() {
1183 let cfg = tiny_cfg();
1189 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1190 let steps = 6;
1191
1192 let mut wm_one = synthetic_weights(&cfg);
1193 let mut gn_one =
1194 Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1195 gn_one.prefill(&prompt);
1196 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1197
1198 let mut wm_buc = synthetic_weights(&cfg);
1199 let mut gn_buc = Llama32Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
1200 .unwrap()
1201 .with_decode_cache(32);
1202 gn_buc.prefill(&prompt);
1203 let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();
1204
1205 assert_eq!(
1206 bucketed_tokens, oneshot_tokens,
1207 "bucketed-cache decode diverged from one-shot decode — \
1208 mask, padding, or output-slice bug"
1209 );
1210 }
1211
1212 #[test]
1213 fn prefill_compile_cache_does_not_change_output() {
1214 let cfg = tiny_cfg();
1215 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1216 let mut wm_a = synthetic_weights(&cfg);
1217 let mut gn_a = Llama32Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
1218 gn_a.prefill(&prompt);
1219 let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();
1220
1221 let mut wm_b = synthetic_weights(&cfg);
1222 let mut gn_b = Llama32Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
1223 .unwrap()
1224 .with_prefill_cache(4);
1225 gn_b.prefill(&prompt);
1226 let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();
1227
1228 assert_eq!(a, b, "enabling prefill_cache must not change output");
1229 }
1230
1231 #[test]
1232 fn dynamic_decode_matches_oneshot() {
1233 let cfg = tiny_cfg();
1234 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1235 let steps = 6;
1236
1237 let mut wm_one = synthetic_weights(&cfg);
1238 let mut gn_one =
1239 Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1240 gn_one.prefill(&prompt);
1241 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1242
1243 let mut wm_dyn = synthetic_weights(&cfg);
1244 let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1245 .unwrap()
1246 .with_dynamic_decode_cache(8);
1247 gn_dyn.prefill(&prompt);
1248 let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1249
1250 assert_eq!(
1251 dynamic_tokens, oneshot_tokens,
1252 "dynamic past_seq decode diverged from one-shot decode"
1253 );
1254 }
1255
1256 #[test]
1257 fn dynamic_prefill_matches_oneshot() {
1258 let cfg = tiny_cfg();
1259 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1260 let steps = 4;
1261
1262 let mut wm_one = synthetic_weights(&cfg);
1263 let mut gn_one =
1264 Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1265 gn_one.prefill(&prompt);
1266 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1267
1268 let mut wm_dyn = synthetic_weights(&cfg);
1269 let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1270 .unwrap()
1271 .with_dynamic_prefill_cache(8);
1272 gn_dyn.prefill(&prompt);
1273 let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1274
1275 assert_eq!(
1276 dynamic_tokens, oneshot_tokens,
1277 "dynamic seq prefill diverged from one-shot prefill"
1278 );
1279 }
1280
1281 #[test]
1282 fn dynamic_prefill_and_decode_matches_oneshot() {
1283 let cfg = tiny_cfg();
1284 let prompt: Vec<u32> = vec![1, 2, 3, 5];
1285 let steps = 6;
1286
1287 let mut wm_one = synthetic_weights(&cfg);
1288 let mut gn_one =
1289 Llama32Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
1290 gn_one.prefill(&prompt);
1291 let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();
1292
1293 let mut wm_dyn = synthetic_weights(&cfg);
1294 let mut gn_dyn = Llama32Generator::from_loader(cfg.clone(), &mut wm_dyn, Device::Cpu)
1295 .unwrap()
1296 .with_dynamic_prefill_cache(8)
1297 .with_dynamic_decode_cache(8);
1298 gn_dyn.prefill(&prompt);
1299 let dynamic_tokens = gn_dyn.generate_cached(steps, SampleOpts::greedy()).unwrap();
1300
1301 assert_eq!(
1302 dynamic_tokens, oneshot_tokens,
1303 "dynamic prefill+decode diverged from one-shot path"
1304 );
1305 }
1306
1307 #[test]
1308 fn greedy_is_deterministic_across_runs() {
1309 let cfg = tiny_cfg();
1310 let weights = synthetic_weights(&cfg);
1311 let mk = || {
1312 let mut wm = WeightMap::from_tensors(weights_as_hashmap(&weights));
1313 Llama32Generator::from_loader(cfg.clone(), &mut wm, Device::Cpu).unwrap()
1314 };
1315 let mut a = mk();
1316 let mut b = mk();
1317 a.prefill(&[1, 2, 3]);
1318 b.prefill(&[1, 2, 3]);
1319 let ta = a.generate(4, SampleOpts::greedy()).unwrap();
1320 let tb = b.generate(4, SampleOpts::greedy()).unwrap();
1321 assert_eq!(ta, tb);
1322 }
1323
1324 fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
1325 let _ = wm; let cfg = tiny_cfg();
1331 let mut new = synthetic_weights(&cfg);
1332 let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
1333 let mut out = HashMap::new();
1334 for k in keys {
1335 out.insert(k.clone(), new.take(&k).unwrap());
1336 }
1337 out
1338 }
1339}