use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

use crate::ops::paged_attention::PagedKVCache;

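/// Multi-head latent attention (MLA)-style compression of attention keys and values.
///
/// Keys and values of width `head_dim` are projected into a smaller `latent_dim` space
/// with a shared `down_proj` and reconstructed with `up_proj`; on decompression the
/// keys additionally receive a decoupled RoPE contribution through `rope_key`.
///
/// A minimal usage sketch (ignored doctest; `down`, `up`, `rope`, `k`, and `v` are
/// assumed to be built by the caller with matching shapes):
///
/// ```ignore
/// let mla = MultiHeadLatentAttention::new(compression_ratio, latent_dim, down, up, rope);
/// let (k_latent, v_latent) = mla.compress_kv(k, v)?;      // [.., head_dim] -> [.., latent_dim]
/// let (k_full, v_full) = mla.decompress_kv(k_latent, v_latent)?;
/// ```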
#[derive(Debug, Clone)]
pub struct MultiHeadLatentAttention<B: Backend> {
    compression_ratio: usize,
    latent_dim: usize,
    down_proj: Tensor<B, 2>,
    up_proj: Tensor<B, 2>,
    rope_key: Tensor<B, 2>,
}

impl<B: Backend> MultiHeadLatentAttention<B> {
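    /// Builds the module from pre-computed projection matrices.
    ///
    /// `down_proj` must be `[head_dim, latent_dim]`; `up_proj` and `rope_key` must be
    /// `[latent_dim, head_dim]`. Shapes are not checked here; `validate_projections`
    /// runs on every compression/decompression call.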
    pub fn new(
        compression_ratio: usize,
        latent_dim: usize,
        down_proj: Tensor<B, 2>,
        up_proj: Tensor<B, 2>,
        rope_key: Tensor<B, 2>,
    ) -> Self {
        Self {
            compression_ratio,
            latent_dim,
            down_proj,
            up_proj,
            rope_key,
        }
    }

    pub fn compression_ratio(&self) -> usize {
        self.compression_ratio
    }

    pub fn latent_dim(&self) -> usize {
        self.latent_dim
    }

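    /// Compresses keys/values `[batch, num_heads, seq_len, head_dim]` into the latent
    /// space, returning `[batch, num_heads, seq_len, latent_dim]` tensors.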
    pub fn compress_kv(
        &self,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Result<(Tensor<B, 4>, Tensor<B, 4>), &'static str> {
        let [batch, num_heads, seq_len, head_dim] = k.dims();
        if v.dims() != [batch, num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        self.validate_projections(head_dim)?;

        let tokens = batch * num_heads * seq_len;
        let k_flat = k.reshape([tokens, head_dim]);
        let v_flat = v.reshape([tokens, head_dim]);

        let k_latent = k_flat.matmul(self.down_proj.clone());
        let v_latent = v_flat.matmul(self.down_proj.clone());

        let k_latent = k_latent.reshape([batch, num_heads, seq_len, self.latent_dim]);
        let v_latent = v_latent.reshape([batch, num_heads, seq_len, self.latent_dim]);

        Ok((k_latent, v_latent))
    }

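    /// Reconstructs full-width keys/values from latent tensors produced by `compress_kv`.
    ///
    /// Keys are computed as `k_latent · up_proj + k_latent · rope_key`, so the RoPE
    /// contribution is folded back in during decompression; values use `up_proj` only.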
    pub fn decompress_kv(
        &self,
        k_latent: Tensor<B, 4>,
        v_latent: Tensor<B, 4>,
    ) -> Result<(Tensor<B, 4>, Tensor<B, 4>), &'static str> {
        let [batch, num_heads, seq_len, latent_dim] = k_latent.dims();
        if v_latent.dims() != [batch, num_heads, seq_len, latent_dim] {
            return Err("latent keys/values shape mismatch");
        }
        if latent_dim != self.latent_dim {
            return Err("latent dimension mismatch");
        }

        let head_dim = self.up_proj.dims()[1];
        self.validate_projections(head_dim)?;

        let tokens = batch * num_heads * seq_len;
        let k_flat = k_latent.reshape([tokens, latent_dim]);
        let v_flat = v_latent.reshape([tokens, latent_dim]);

        let mut k_full = k_flat.clone().matmul(self.up_proj.clone());
        let v_full = v_flat.matmul(self.up_proj.clone());

        if self.rope_key.dims() == [latent_dim, head_dim] {
            let rope = k_flat.clone().matmul(self.rope_key.clone());
            k_full = k_full + rope;
        }

        let k_full = k_full.reshape([batch, num_heads, seq_len, head_dim]);
        let v_full = v_full.reshape([batch, num_heads, seq_len, head_dim]);

        Ok((k_full, v_full))
    }

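    /// Unbatched variant of `compress_kv` for `[num_heads, seq_len, head_dim]` tensors;
    /// a temporary batch dimension of 1 is inserted and removed.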
    pub fn compress_kv_3d(
        &self,
        k: Tensor<B, 3>,
        v: Tensor<B, 3>,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let [num_heads, seq_len, head_dim] = k.dims();
        if v.dims() != [num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        self.validate_projections(head_dim)?;

        let k = k.reshape([1, num_heads, seq_len, head_dim]);
        let v = v.reshape([1, num_heads, seq_len, head_dim]);
        let (k_latent, v_latent) = self.compress_kv(k, v)?;

        Ok((
            k_latent.reshape([num_heads, seq_len, self.latent_dim]),
            v_latent.reshape([num_heads, seq_len, self.latent_dim]),
        ))
    }

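    /// Unbatched variant of `decompress_kv` for `[num_heads, seq_len, latent_dim]` tensors.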
    pub fn decompress_kv_3d(
        &self,
        k_latent: Tensor<B, 3>,
        v_latent: Tensor<B, 3>,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let [num_heads, seq_len, latent_dim] = k_latent.dims();
        if v_latent.dims() != [num_heads, seq_len, latent_dim] {
            return Err("latent keys/values shape mismatch");
        }
        if latent_dim != self.latent_dim {
            return Err("latent dimension mismatch");
        }

        let k = k_latent.reshape([1, num_heads, seq_len, latent_dim]);
        let v = v_latent.reshape([1, num_heads, seq_len, latent_dim]);
        let (k_full, v_full) = self.decompress_kv(k, v)?;
        let head_dim = k_full.dims()[3];
        let value_dim = v_full.dims()[3];

        Ok((
            k_full.reshape([num_heads, seq_len, head_dim]),
            v_full.reshape([num_heads, seq_len, value_dim]),
        ))
    }

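    /// Checks that all three projection matrices agree with `head_dim` and `latent_dim`
    /// and that the configuration is non-degenerate (non-zero ratio and latent size).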
    fn validate_projections(&self, head_dim: usize) -> Result<(), &'static str> {
        let down_dims = self.down_proj.dims();
        if down_dims != [head_dim, self.latent_dim] {
            return Err("down projection shape mismatch");
        }
        let up_dims = self.up_proj.dims();
        if up_dims != [self.latent_dim, head_dim] {
            return Err("up projection shape mismatch");
        }
        let rope_dims = self.rope_key.dims();
        if rope_dims != [self.latent_dim, head_dim] {
            return Err("rope key shape mismatch");
        }
        if self.compression_ratio == 0 || self.latent_dim == 0 {
            return Err("invalid compression configuration");
        }
        Ok(())
    }
}

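/// A paged KV cache that stores keys/values in the MLA latent space.
///
/// Entries are compressed with [`MultiHeadLatentAttention`] before being written to the
/// underlying [`PagedKVCache`] (whose per-head width is the latent dimension) and
/// decompressed on read, unless the `*_compressed` accessors are used.
///
/// A minimal usage sketch (ignored doctest; the projection matrices and tensor shapes
/// are assumed to be set up by the caller):
///
/// ```ignore
/// let mut cache = CompressedKVCache::new(max_blocks, num_layers, num_heads, mla, &device);
/// let seq = cache.allocate_sequence();
/// cache.append(0, seq, keys, values)?;                    // keys/values: [num_heads, seq_len, head_dim]
/// let (k, v) = cache.get_kv(0, seq)?;                     // decompressed back to head_dim
/// let (k_lat, v_lat) = cache.get_compressed_kv(0, seq)?;  // raw latent tensors
/// ```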
#[derive(Debug, Clone)]
pub struct CompressedKVCache<B: Backend> {
    inner: PagedKVCache<B>,
    mla: MultiHeadLatentAttention<B>,
}

impl<B: Backend> CompressedKVCache<B> {
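    /// Creates a compressed cache backed by a `PagedKVCache` sized with
    /// `mla.latent_dim()` as its per-head width instead of the full `head_dim`.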
    pub fn new(
        max_blocks: usize,
        num_layers: usize,
        num_heads: usize,
        mla: MultiHeadLatentAttention<B>,
        device: &B::Device,
    ) -> Self {
        let inner = PagedKVCache::new(max_blocks, num_layers, num_heads, mla.latent_dim(), device);
        Self { inner, mla }
    }

    pub fn allocate_sequence(&mut self) -> usize {
        self.inner.allocate_sequence()
    }

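    /// Compresses `keys`/`values` (`[num_heads, seq_len, head_dim]`) and appends the
    /// latent tensors to the given layer and sequence.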
    pub fn append(
        &mut self,
        layer: usize,
        seq_id: usize,
        keys: Tensor<B, 3>,
        values: Tensor<B, 3>,
    ) -> Result<(), &'static str> {
        let (k_latent, v_latent) = self.mla.compress_kv_3d(keys, values)?;
        self.inner.append(layer, seq_id, k_latent, v_latent)
    }

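    /// Batched append for `[batch, num_heads, seq_len, head_dim]` tensors; only
    /// `batch == 1` is accepted, and the batch dimension is dropped before compression.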
    pub fn append_batched(
        &mut self,
        layer: usize,
        seq_id: usize,
        keys: Tensor<B, 4>,
        values: Tensor<B, 4>,
    ) -> Result<(), &'static str> {
        let [batch, num_heads, seq_len, head_dim] = keys.dims();
        if batch != 1 {
            return Err("compressed cache expects batch=1");
        }
        if values.dims() != [batch, num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        let keys = keys.reshape([num_heads, seq_len, head_dim]);
        let values = values.reshape([num_heads, seq_len, head_dim]);
        self.append(layer, seq_id, keys, values)
    }

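    /// Appends tensors that are already in the latent space, bypassing compression.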
    pub fn append_compressed(
        &mut self,
        layer: usize,
        seq_id: usize,
        keys_latent: Tensor<B, 3>,
        values_latent: Tensor<B, 3>,
    ) -> Result<(), &'static str> {
        self.inner.append(layer, seq_id, keys_latent, values_latent)
    }

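    /// Returns decompressed keys/values `[num_heads, seq_len, head_dim]` for a sequence.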
    pub fn get_kv(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let (k_latent, v_latent) = self.inner.get_kv(layer, seq_id)?;
        self.mla.decompress_kv_3d(k_latent, v_latent)
    }

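    /// Returns the stored latent keys/values `[num_heads, seq_len, latent_dim]` without
    /// decompression.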
    pub fn get_compressed_kv(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        self.inner.get_kv(layer, seq_id)
    }

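    /// Collects per-block keys/values for a sequence, trimming each block to its
    /// populated tokens and decompressing the result.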
    pub fn iter_kv_blocks(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<Vec<(Tensor<B, 3>, Tensor<B, 3>)>, &'static str> {
        let kv_iter = self.inner.iter_kv_blocks(layer, seq_id)?;
        let mut blocks = Vec::new();
        for block in kv_iter {
            let [num_heads, _, latent_dim] = block.keys.dims();
            let k_latent = block
                .keys
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            let v_latent = block
                .values
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            let (k_full, v_full) = self.mla.decompress_kv_3d(k_latent, v_latent)?;
            blocks.push((k_full, v_full));
        }
        Ok(blocks)
    }

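    /// Like `iter_kv_blocks`, but returns the trimmed latent tensors without decompression.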
    pub fn iter_compressed_blocks(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<Vec<(Tensor<B, 3>, Tensor<B, 3>)>, &'static str> {
        let kv_iter = self.inner.iter_kv_blocks(layer, seq_id)?;
        let mut blocks = Vec::new();
        for block in kv_iter {
            let [num_heads, _, latent_dim] = block.keys.dims();
            let k_latent = block
                .keys
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            let v_latent = block
                .values
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            blocks.push((k_latent, v_latent));
        }
        Ok(blocks)
    }

    pub fn seq_len(&self, layer: usize, seq_id: usize) -> Result<usize, &'static str> {
        self.inner.seq_len(layer, seq_id)
    }

    pub fn num_free_blocks(&self) -> usize {
        self.inner.num_free_blocks()
    }

    pub fn num_heads(&self) -> usize {
        self.inner.num_heads()
    }

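    /// The latent dimension of stored entries; this is the inner cache's head dimension.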
    pub fn latent_dim(&self) -> usize {
        self.inner.head_dim()
    }

    pub fn device(&self) -> &B::Device {
        self.inner.device()
    }

    pub fn free_sequence(&mut self, seq_id: usize) -> Result<(), &'static str> {
        self.inner.free_sequence(seq_id)
    }
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, TensorData};
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    fn identity_matrix(
        dim: usize,
        device: &<TestBackend as Backend>::Device,
    ) -> Tensor<TestBackend, 2> {
        let mut data = vec![0.0f32; dim * dim];
        for i in 0..dim {
            data[i * dim + i] = 1.0;
        }
        Tensor::from_data(TensorData::new(data, [dim, dim]), device)
    }

    fn zero_matrix(
        rows: usize,
        cols: usize,
        device: &<TestBackend as Backend>::Device,
    ) -> Tensor<TestBackend, 2> {
        let data = vec![0.0f32; rows * cols];
        Tensor::from_data(TensorData::new(data, [rows, cols]), device)
    }

    #[test]
    fn test_mla_compress_decompress_roundtrip_identity() {
        let device = <TestBackend as Backend>::Device::default();
        let head_dim = 4;
        let latent_dim = 4;
        let down = identity_matrix(head_dim, &device);
        let up = identity_matrix(head_dim, &device);
        let rope = zero_matrix(latent_dim, head_dim, &device);

        let mla = MultiHeadLatentAttention::new(1, latent_dim, down, up, rope);

        let k = Tensor::<TestBackend, 4>::random(
            [1, 2, 4, head_dim],
            Distribution::Normal(0.0, 0.5),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [1, 2, 4, head_dim],
            Distribution::Normal(0.0, 0.5),
            &device,
        );

        let (k_latent, v_latent) = mla.compress_kv(k.clone(), v.clone()).expect("compress");
        let (k_full, v_full) = mla.decompress_kv(k_latent, v_latent).expect("decompress");

        let k_data = k.into_data().into_vec::<f32>().expect("k data");
        let k_roundtrip = k_full.into_data().into_vec::<f32>().expect("k roundtrip");
        for (idx, (orig, round)) in k_data.iter().zip(k_roundtrip.iter()).enumerate() {
            let diff = (orig - round).abs();
            assert!(diff < 1e-4, "k mismatch at {}: {} vs {}", idx, orig, round);
        }

        let v_data = v.into_data().into_vec::<f32>().expect("v data");
        let v_roundtrip = v_full.into_data().into_vec::<f32>().expect("v roundtrip");
        for (idx, (orig, round)) in v_data.iter().zip(v_roundtrip.iter()).enumerate() {
            let diff = (orig - round).abs();
            assert!(diff < 1e-4, "v mismatch at {}: {} vs {}", idx, orig, round);
        }
    }

410
411 #[test]
412 fn test_compressed_kv_cache_roundtrip() {
413 let device = <TestBackend as Backend>::Device::default();
414 let head_dim = 4;
415 let latent_dim = 4;
416 let down = identity_matrix(head_dim, &device);
417 let up = identity_matrix(head_dim, &device);
418 let rope = zero_matrix(latent_dim, head_dim, &device);
419
420 let mla = MultiHeadLatentAttention::new(1, latent_dim, down, up, rope);
421 let mut cache = CompressedKVCache::new(4, 1, 2, mla, &device);
422
423 let seq_id = cache.allocate_sequence();
424 let keys = Tensor::<TestBackend, 3>::random(
425 [2, 5, head_dim],
426 Distribution::Normal(0.0, 0.5),
427 &device,
428 );
429 let values = Tensor::<TestBackend, 3>::random(
430 [2, 5, head_dim],
431 Distribution::Normal(0.0, 0.5),
432 &device,
433 );
434
435 cache.append(0, seq_id, keys.clone(), values.clone()).expect("append");
436
437 let (k_full, v_full) = cache.get_kv(0, seq_id).expect("get kv");
438 assert_eq!(k_full.dims(), [2, 5, head_dim]);
439 assert_eq!(v_full.dims(), [2, 5, head_dim]);
440
441 let k_data = keys.into_data().into_vec::<f32>().expect("keys data");
442 let k_round = k_full.into_data().into_vec::<f32>().expect("keys roundtrip");
443 for (idx, (orig, round)) in k_data.iter().zip(k_round.iter()).enumerate() {
444 let diff = (orig - round).abs();
445 assert!(diff < 1e-4, "k mismatch at {}: {} vs {}", idx, orig, round);
446 }
447 }
448}