1use crate::compile_cache::pad_rows;
19
/// Key/value cache holding one flattened `f32` tensor per layer for keys and
/// one for values, together with the number of cached sequence positions.
///
/// `layers_k[i]` and `layers_v[i]` belong to the same layer `i`; the two
/// vectors are expected to have the same length. The fields are public, so
/// this invariant is not enforced by the type itself.
#[derive(Debug, Clone)]
pub struct LayerKvCache {
    /// Number of sequence positions currently represented by the cache.
    pub past_len: usize,
    /// Flattened key tensor of each layer, indexed by layer.
    pub layers_k: Vec<Vec<f32>>,
    /// Flattened value tensor of each layer, indexed by layer.
    pub layers_v: Vec<Vec<f32>>,
}
27
28impl LayerKvCache {
29 pub fn from_layer_outputs(
30 num_layers: usize,
31 batch: usize,
32 past_seq: usize,
33 kv_dim: usize,
34 outputs: &[Vec<f32>],
35 ) -> Result<Self, String> {
36 if outputs.len() != 2 * num_layers {
37 return Err(format!(
38 "from_layer_outputs: expected {} K/V tensors, got {}",
39 2 * num_layers,
40 outputs.len()
41 ));
42 }
43 let expected = batch * past_seq * kv_dim;
44 let mut layers_k = Vec::with_capacity(num_layers);
45 let mut layers_v = Vec::with_capacity(num_layers);
46 for layer in 0..num_layers {
47 let k = &outputs[2 * layer];
48 let v = &outputs[2 * layer + 1];
49 if k.len() != expected || v.len() != expected {
50 return Err(format!(
51 "layer {layer}: k.len={} v.len={} expected {expected}",
52 k.len(),
53 v.len()
54 ));
55 }
56 layers_k.push(k.clone());
57 layers_v.push(v.clone());
58 }
59 Ok(Self {
60 past_len: past_seq,
61 layers_k,
62 layers_v,
63 })
64 }
65
66 pub fn pad_layers_to_upper(&self, upper: u64, kv_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
68 let padded_k = self
69 .layers_k
70 .iter()
71 .map(|k| pad_rows(k, kv_dim, upper))
72 .collect();
73 let padded_v = self
74 .layers_v
75 .iter()
76 .map(|v| pad_rows(v, kv_dim, upper))
77 .collect();
78 (padded_k, padded_v)
79 }
80
81 pub fn advance_from_decode_outputs(
83 &mut self,
84 outputs: Vec<Vec<f32>>,
85 _batch: usize,
86 kv_dim: usize,
87 ) -> Result<(), String> {
88 let n = self.layers_k.len();
89 if outputs.len() != 1 + 2 * n {
90 return Err(format!(
91 "advance_from_decode_outputs: expected {} outputs, got {}",
92 1 + 2 * n,
93 outputs.len()
94 ));
95 }
96 let new_len = self.past_len + 1;
97 let real_len = new_len * kv_dim;
98 let mut iter = outputs.into_iter();
99 let _logits = iter.next().ok_or("missing logits")?;
100 for i in 0..n {
101 let k = iter.next().ok_or("missing k")?;
102 let v = iter.next().ok_or("missing v")?;
103 self.layers_k[i] = k[..real_len.min(k.len())].to_vec();
104 self.layers_v[i] = v[..real_len.min(v.len())].to_vec();
105 }
106 self.past_len = new_len;
107 Ok(())
108 }
109}