1use crate::compile_cache::pad_rows;
19
/// Per-layer key/value history carried between decode steps.
///
/// `layers_k[i]` and `layers_v[i]` are the flat `f32` K and V buffers for
/// layer `i`. The constructors in this file validate each buffer against
/// `batch * past_seq * kv_dim` elements; the fields are public, so code that
/// mutates them directly is responsible for keeping K and V consistent.
#[derive(Clone, Debug)]
pub struct LayerKvCache {
    /// Number of positions recorded so far: initialised from `past_seq` and
    /// incremented by one on every successful decode advance.
    pub past_len: usize,
    /// One flat key buffer per layer.
    pub layers_k: Vec<Vec<f32>>,
    /// One flat value buffer per layer, parallel to `layers_k`.
    pub layers_v: Vec<Vec<f32>>,
}
27
28impl LayerKvCache {
29 pub fn from_layer_outputs(
30 num_layers: usize,
31 batch: usize,
32 past_seq: usize,
33 kv_dim: usize,
34 outputs: &[Vec<f32>],
35 ) -> Result<Self, String> {
36 let dims: Vec<usize> = vec![kv_dim; num_layers];
37 Self::from_layer_outputs_per_layer(num_layers, batch, past_seq, &dims, outputs)
38 }
39
40 pub fn from_layer_outputs_per_layer(
45 num_layers: usize,
46 batch: usize,
47 past_seq: usize,
48 kv_dims: &[usize],
49 outputs: &[Vec<f32>],
50 ) -> Result<Self, String> {
51 if outputs.len() != 2 * num_layers {
52 return Err(format!(
53 "from_layer_outputs_per_layer: expected {} K/V tensors, got {}",
54 2 * num_layers,
55 outputs.len()
56 ));
57 }
58 if kv_dims.len() != num_layers {
59 return Err(format!(
60 "from_layer_outputs_per_layer: expected {} kv_dims, got {}",
61 num_layers,
62 kv_dims.len()
63 ));
64 }
65 let mut layers_k = Vec::with_capacity(num_layers);
66 let mut layers_v = Vec::with_capacity(num_layers);
67 for layer in 0..num_layers {
68 let kv_dim = kv_dims[layer];
69 let expected = batch * past_seq * kv_dim;
70 let k = &outputs[2 * layer];
71 let v = &outputs[2 * layer + 1];
72 if k.len() != expected || v.len() != expected {
73 return Err(format!(
74 "layer {layer}: k.len={} v.len={} expected {expected} (kv_dim={kv_dim})",
75 k.len(),
76 v.len()
77 ));
78 }
79 layers_k.push(k.clone());
80 layers_v.push(v.clone());
81 }
82 Ok(Self {
83 past_len: past_seq,
84 layers_k,
85 layers_v,
86 })
87 }
88
89 pub fn pad_layers_to_upper(&self, upper: u64, kv_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
91 let dims: Vec<usize> = vec![kv_dim; self.layers_k.len()];
92 self.pad_layers_to_upper_per_layer(upper, &dims)
93 }
94
95 pub fn pad_layers_to_upper_per_layer(
99 &self,
100 upper: u64,
101 kv_dims: &[usize],
102 ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
103 assert_eq!(
104 kv_dims.len(),
105 self.layers_k.len(),
106 "pad_layers_to_upper_per_layer: kv_dims len {} != layers {}",
107 kv_dims.len(),
108 self.layers_k.len(),
109 );
110 let padded_k = self
111 .layers_k
112 .iter()
113 .zip(kv_dims.iter())
114 .map(|(k, &d)| pad_rows(k, d, upper))
115 .collect();
116 let padded_v = self
117 .layers_v
118 .iter()
119 .zip(kv_dims.iter())
120 .map(|(v, &d)| pad_rows(v, d, upper))
121 .collect();
122 (padded_k, padded_v)
123 }
124
125 pub fn advance_from_decode_outputs(
127 &mut self,
128 outputs: Vec<Vec<f32>>,
129 batch: usize,
130 kv_dim: usize,
131 ) -> Result<(), String> {
132 let dims: Vec<usize> = vec![kv_dim; self.layers_k.len()];
133 self.advance_from_decode_outputs_per_layer(outputs, batch, &dims)
134 }
135
136 pub fn trim_sliding_window_per_layer(
155 &mut self,
156 kv_dims_keep: &[Option<(usize, usize)>],
157 ) -> Result<(), String> {
158 if kv_dims_keep.len() != self.layers_k.len() {
159 return Err(format!(
160 "trim_sliding_window_per_layer: kv_dims_keep len {} != layers {}",
161 kv_dims_keep.len(),
162 self.layers_k.len(),
163 ));
164 }
165 for (i, spec) in kv_dims_keep.iter().enumerate() {
166 let Some((kv_dim, window)) = spec else {
167 continue;
168 };
169 let kv_dim = *kv_dim;
170 let window = *window;
171 if window == 0 || kv_dim == 0 {
172 continue;
173 }
174 let rows = self.layers_k[i].len() / kv_dim;
175 if rows <= window {
176 continue;
177 }
178 let drop_rows = rows - window;
179 let drop_bytes = drop_rows * kv_dim;
180 self.layers_k[i].drain(..drop_bytes);
181 self.layers_v[i].drain(..drop_bytes);
182 }
183 Ok(())
184 }
185
186 pub fn advance_from_decode_outputs_per_layer(
188 &mut self,
189 outputs: Vec<Vec<f32>>,
190 _batch: usize,
191 kv_dims: &[usize],
192 ) -> Result<(), String> {
193 let n = self.layers_k.len();
194 if outputs.len() != 1 + 2 * n {
195 return Err(format!(
196 "advance_from_decode_outputs_per_layer: expected {} outputs, got {}",
197 1 + 2 * n,
198 outputs.len()
199 ));
200 }
201 if kv_dims.len() != n {
202 return Err(format!(
203 "advance_from_decode_outputs_per_layer: kv_dims len {} != layers {n}",
204 kv_dims.len()
205 ));
206 }
207 let new_len = self.past_len + 1;
208 let mut iter = outputs.into_iter();
209 let _logits = iter.next().ok_or("missing logits")?;
210 for i in 0..n {
211 let k = iter.next().ok_or("missing k")?;
212 let v = iter.next().ok_or("missing v")?;
213 let real_len = new_len * kv_dims[i];
214 self.layers_k[i] = k[..real_len.min(k.len())].to_vec();
215 self.layers_v[i] = v[..real_len.min(v.len())].to_vec();
216 }
217 self.past_len = new_len;
218 Ok(())
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache with `layers` identical layers of `rows * kv_dim`
    /// elements: K counts up from 0, V counts up from 100, so the two buffers
    /// can be told apart in assertions.
    fn ramp_cache(layers: usize, rows: usize, kv_dim: usize) -> LayerKvCache {
        let k: Vec<f32> = (0..rows * kv_dim).map(|x| x as f32).collect();
        let v: Vec<f32> = k.iter().map(|x| x + 100.0).collect();
        LayerKvCache {
            past_len: rows,
            layers_k: vec![k; layers],
            layers_v: vec![v; layers],
        }
    }

    #[test]
    fn sliding_window_trim_keeps_last_w_rows() {
        let kv_dim = 4;
        let rows = 6;
        let mut cache = ramp_cache(3, rows, kv_dim);
        // Layer 0 keeps 2 rows, layer 1 is not a sliding layer, layer 2 keeps 4.
        let spec = [Some((kv_dim, 2)), None, Some((kv_dim, 4))];
        cache.trim_sliding_window_per_layer(&spec).unwrap();

        assert_eq!(cache.layers_k[0].len(), 2 * kv_dim);
        // The surviving rows are the most recent ones.
        assert_eq!(
            cache.layers_k[0],
            vec![16., 17., 18., 19., 20., 21., 22., 23.]
        );
        // V is trimmed in lockstep with K.
        assert_eq!(
            cache.layers_v[0],
            vec![116., 117., 118., 119., 120., 121., 122., 123.]
        );
        assert_eq!(
            cache.layers_k[1].len(),
            6 * kv_dim,
            "untouched layer keeps full history"
        );
        assert_eq!(cache.layers_v[1].len(), 6 * kv_dim);
        assert_eq!(cache.layers_k[2].len(), 4 * kv_dim);
        assert_eq!(cache.layers_v[2].len(), 4 * kv_dim);
        // Trimming does not rewind the position counter.
        assert_eq!(cache.past_len, rows);
    }

    #[test]
    fn sliding_window_trim_no_op_when_under_window() {
        let kv_dim = 4;
        let rows = 3;
        let mut cache = LayerKvCache {
            past_len: rows,
            layers_k: vec![vec![1.0f32; rows * kv_dim]],
            layers_v: vec![vec![2.0f32; rows * kv_dim]],
        };
        cache
            .trim_sliding_window_per_layer(&[Some((kv_dim, 10))])
            .unwrap();
        assert_eq!(cache.layers_k[0].len(), rows * kv_dim);
        assert_eq!(cache.layers_v[0].len(), rows * kv_dim);
    }

    #[test]
    fn sliding_window_trim_rejects_spec_length_mismatch() {
        let mut cache = ramp_cache(2, 3, 4);
        let before = cache.clone();
        assert!(cache
            .trim_sliding_window_per_layer(&[Some((4, 1))])
            .is_err());
        assert_eq!(cache.layers_k, before.layers_k);
        assert_eq!(cache.layers_v, before.layers_v);
    }

    #[test]
    fn sliding_window_trim_skips_zero_window_or_zero_dim() {
        let kv_dim = 4;
        let rows = 5;
        let mut cache = ramp_cache(2, rows, kv_dim);
        cache
            .trim_sliding_window_per_layer(&[Some((kv_dim, 0)), Some((0, 2))])
            .unwrap();
        assert_eq!(cache.layers_k[0].len(), rows * kv_dim);
        assert_eq!(cache.layers_k[1].len(), rows * kv_dim);
        assert_eq!(cache.layers_v[0].len(), rows * kv_dim);
        assert_eq!(cache.layers_v[1].len(), rows * kv_dim);
    }
}