gllm_kernels/ops/mla.rs

//! Multi-Head Latent Attention (MLA) compression utilities.
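//!
//! Keys and values are down-projected into a smaller latent space before they
//! enter the KV cache and up-projected back to `head_dim` when read, trading a
//! pair of matmuls per access for a smaller cache footprint.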

use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

use crate::ops::paged_attention::PagedKVCache;

/// Multi-head latent attention module for KV compression.
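///
/// As an illustrative sketch (dimensions assumed here, not taken from any
/// particular model): with `head_dim = 128` and `latent_dim = 16`, the cache
/// stores 16 instead of 128 values per token per head, which would correspond
/// to a `compression_ratio` of 8.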
#[derive(Debug, Clone)]
pub struct MultiHeadLatentAttention<B: Backend> {
    /// Compression ratio (typical: 8-16).
    compression_ratio: usize,
    /// Latent dimension after compression.
    latent_dim: usize,
    /// Down projection matrix W_dkv: [head_dim, latent_dim].
    down_proj: Tensor<B, 2>,
    /// Up projection matrix W_ukv: [latent_dim, head_dim].
    up_proj: Tensor<B, 2>,
    /// Decoupled RoPE projection: [latent_dim, head_dim].
    rope_key: Tensor<B, 2>,
}

impl<B: Backend> MultiHeadLatentAttention<B> {
    /// Create a new MLA module with explicit projections.
    pub fn new(
        compression_ratio: usize,
        latent_dim: usize,
        down_proj: Tensor<B, 2>,
        up_proj: Tensor<B, 2>,
        rope_key: Tensor<B, 2>,
    ) -> Self {
        Self {
            compression_ratio,
            latent_dim,
            down_proj,
            up_proj,
            rope_key,
        }
    }

    /// Compression ratio configured for this module.
    pub fn compression_ratio(&self) -> usize {
        self.compression_ratio
    }

    /// Latent dimension configured for this module.
    pub fn latent_dim(&self) -> usize {
        self.latent_dim
    }

    /// Compress KV tensors to latent space.
    ///
    /// # Shapes
    /// * `k`, `v`: [batch, num_heads, seq_len, head_dim]
    /// * returns: [batch, num_heads, seq_len, latent_dim]
    pub fn compress_kv(
        &self,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Result<(Tensor<B, 4>, Tensor<B, 4>), &'static str> {
        let [batch, num_heads, seq_len, head_dim] = k.dims();
        if v.dims() != [batch, num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        self.validate_projections(head_dim)?;

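        // Flatten [batch, num_heads, seq_len, head_dim] into [tokens, head_dim]
        // so a single 2D matmul with the down projection handles every token at once.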
        let tokens = batch * num_heads * seq_len;
        let k_flat = k.reshape([tokens, head_dim]);
        let v_flat = v.reshape([tokens, head_dim]);

        let k_latent = k_flat.matmul(self.down_proj.clone());
        let v_latent = v_flat.matmul(self.down_proj.clone());

        let k_latent = k_latent.reshape([batch, num_heads, seq_len, self.latent_dim]);
        let v_latent = v_latent.reshape([batch, num_heads, seq_len, self.latent_dim]);

        Ok((k_latent, v_latent))
    }

    /// Decompress KV tensors from latent space.
    ///
    /// # Shapes
    /// * `k_latent`, `v_latent`: [batch, num_heads, seq_len, latent_dim]
    /// * returns: [batch, num_heads, seq_len, head_dim]
    pub fn decompress_kv(
        &self,
        k_latent: Tensor<B, 4>,
        v_latent: Tensor<B, 4>,
    ) -> Result<(Tensor<B, 4>, Tensor<B, 4>), &'static str> {
        let [batch, num_heads, seq_len, latent_dim] = k_latent.dims();
        if v_latent.dims() != [batch, num_heads, seq_len, latent_dim] {
            return Err("latent keys/values shape mismatch");
        }
        if latent_dim != self.latent_dim {
            return Err("latent dimension mismatch");
        }

        let head_dim = self.up_proj.dims()[1];
        self.validate_projections(head_dim)?;

        let tokens = batch * num_heads * seq_len;
        let k_flat = k_latent.reshape([tokens, latent_dim]);
        let v_flat = v_latent.reshape([tokens, latent_dim]);

        let mut k_full = k_flat.clone().matmul(self.up_proj.clone());
        let v_full = v_flat.matmul(self.up_proj.clone());

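        // Decoupled RoPE path: project the latent keys through the RoPE key
        // matrix and add the positional component onto the up-projected keys.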
        if self.rope_key.dims() == [latent_dim, head_dim] {
            let rope = k_flat.clone().matmul(self.rope_key.clone());
            k_full = k_full + rope;
        }

        let k_full = k_full.reshape([batch, num_heads, seq_len, head_dim]);
        let v_full = v_full.reshape([batch, num_heads, seq_len, head_dim]);

        Ok((k_full, v_full))
    }

    /// Compress KV tensors with 3D shapes.
    ///
    /// # Shapes
    /// * `k`, `v`: [num_heads, seq_len, head_dim]
    /// * returns: [num_heads, seq_len, latent_dim]
    pub fn compress_kv_3d(
        &self,
        k: Tensor<B, 3>,
        v: Tensor<B, 3>,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let [num_heads, seq_len, head_dim] = k.dims();
        if v.dims() != [num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        self.validate_projections(head_dim)?;

        let k = k.reshape([1, num_heads, seq_len, head_dim]);
        let v = v.reshape([1, num_heads, seq_len, head_dim]);
        let (k_latent, v_latent) = self.compress_kv(k, v)?;

        Ok((
            k_latent.reshape([num_heads, seq_len, self.latent_dim]),
            v_latent.reshape([num_heads, seq_len, self.latent_dim]),
        ))
    }

    /// Decompress KV tensors with 3D shapes.
    ///
    /// # Shapes
    /// * `k_latent`, `v_latent`: [num_heads, seq_len, latent_dim]
    /// * returns: [num_heads, seq_len, head_dim]
    pub fn decompress_kv_3d(
        &self,
        k_latent: Tensor<B, 3>,
        v_latent: Tensor<B, 3>,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let [num_heads, seq_len, latent_dim] = k_latent.dims();
        if v_latent.dims() != [num_heads, seq_len, latent_dim] {
            return Err("latent keys/values shape mismatch");
        }
        if latent_dim != self.latent_dim {
            return Err("latent dimension mismatch");
        }

        let k = k_latent.reshape([1, num_heads, seq_len, latent_dim]);
        let v = v_latent.reshape([1, num_heads, seq_len, latent_dim]);
        let (k_full, v_full) = self.decompress_kv(k, v)?;
        let head_dim = k_full.dims()[3];
        let value_dim = v_full.dims()[3];

        Ok((
            k_full.reshape([num_heads, seq_len, head_dim]),
            v_full.reshape([num_heads, seq_len, value_dim]),
        ))
    }

    fn validate_projections(&self, head_dim: usize) -> Result<(), &'static str> {
        let down_dims = self.down_proj.dims();
        if down_dims != [head_dim, self.latent_dim] {
            return Err("down projection shape mismatch");
        }
        let up_dims = self.up_proj.dims();
        if up_dims != [self.latent_dim, head_dim] {
            return Err("up projection shape mismatch");
        }
        let rope_dims = self.rope_key.dims();
        if rope_dims != [self.latent_dim, head_dim] {
            return Err("rope key shape mismatch");
        }
        if self.compression_ratio == 0 || self.latent_dim == 0 {
            return Err("invalid compression configuration");
        }
        Ok(())
    }
}

/// Compressed KV cache compatible with paged attention APIs.
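///
/// Latent tensors are stored in an inner `PagedKVCache` whose per-head
/// dimension is `mla.latent_dim()`; `append` compresses on the way in and
/// `get_kv` decompresses on the way out.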
#[derive(Debug, Clone)]
pub struct CompressedKVCache<B: Backend> {
    inner: PagedKVCache<B>,
    mla: MultiHeadLatentAttention<B>,
}

impl<B: Backend> CompressedKVCache<B> {
    /// Create a compressed KV cache for a given MLA configuration.
    pub fn new(
        max_blocks: usize,
        num_layers: usize,
        num_heads: usize,
        mla: MultiHeadLatentAttention<B>,
        device: &B::Device,
    ) -> Self {
        let inner = PagedKVCache::new(max_blocks, num_layers, num_heads, mla.latent_dim(), device);
        Self { inner, mla }
    }

    /// Allocate a new sequence and return its id.
    pub fn allocate_sequence(&mut self) -> usize {
        self.inner.allocate_sequence()
    }

    /// Append uncompressed KV tensors (3D) into the compressed cache.
    pub fn append(
        &mut self,
        layer: usize,
        seq_id: usize,
        keys: Tensor<B, 3>,
        values: Tensor<B, 3>,
    ) -> Result<(), &'static str> {
        let (k_latent, v_latent) = self.mla.compress_kv_3d(keys, values)?;
        self.inner.append(layer, seq_id, k_latent, v_latent)
    }

    /// Append uncompressed KV tensors (4D, batch=1) into the compressed cache.
    pub fn append_batched(
        &mut self,
        layer: usize,
        seq_id: usize,
        keys: Tensor<B, 4>,
        values: Tensor<B, 4>,
    ) -> Result<(), &'static str> {
        let [batch, num_heads, seq_len, head_dim] = keys.dims();
        if batch != 1 {
            return Err("compressed cache expects batch=1");
        }
        if values.dims() != [batch, num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        let keys = keys.reshape([num_heads, seq_len, head_dim]);
        let values = values.reshape([num_heads, seq_len, head_dim]);
        self.append(layer, seq_id, keys, values)
    }

    /// Append pre-compressed KV tensors directly.
    pub fn append_compressed(
        &mut self,
        layer: usize,
        seq_id: usize,
        keys_latent: Tensor<B, 3>,
        values_latent: Tensor<B, 3>,
    ) -> Result<(), &'static str> {
        self.inner.append(layer, seq_id, keys_latent, values_latent)
    }

    /// Get decompressed KV tensors for a layer/sequence.
    pub fn get_kv(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let (k_latent, v_latent) = self.inner.get_kv(layer, seq_id)?;
        self.mla.decompress_kv_3d(k_latent, v_latent)
    }

    /// Get compressed KV tensors for a layer/sequence.
    pub fn get_compressed_kv(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        self.inner.get_kv(layer, seq_id)
    }

    /// Iterate over decompressed KV blocks for a sequence.
    pub fn iter_kv_blocks(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<Vec<(Tensor<B, 3>, Tensor<B, 3>)>, &'static str> {
        let kv_iter = self.inner.iter_kv_blocks(layer, seq_id)?;
        let mut blocks = Vec::new();
        for block in kv_iter {
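            // Keep only the first `block.num_tokens` positions so slots beyond
            // the block's valid token count are never decompressed.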
            let [num_heads, _, latent_dim] = block.keys.dims();
            let k_latent = block
                .keys
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            let v_latent = block
                .values
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            let (k_full, v_full) = self.mla.decompress_kv_3d(k_latent, v_latent)?;
            blocks.push((k_full, v_full));
        }
        Ok(blocks)
    }

    /// Iterate over compressed KV blocks for a sequence.
    pub fn iter_compressed_blocks(
        &self,
        layer: usize,
        seq_id: usize,
    ) -> Result<Vec<(Tensor<B, 3>, Tensor<B, 3>)>, &'static str> {
        let kv_iter = self.inner.iter_kv_blocks(layer, seq_id)?;
        let mut blocks = Vec::new();
        for block in kv_iter {
            let [num_heads, _, latent_dim] = block.keys.dims();
            let k_latent = block
                .keys
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            let v_latent = block
                .values
                .clone()
                .slice([0..num_heads, 0..block.num_tokens, 0..latent_dim]);
            blocks.push((k_latent, v_latent));
        }
        Ok(blocks)
    }

    /// Get sequence length for a layer/sequence.
    pub fn seq_len(&self, layer: usize, seq_id: usize) -> Result<usize, &'static str> {
        self.inner.seq_len(layer, seq_id)
    }

    /// Get the number of free blocks in the cache.
    pub fn num_free_blocks(&self) -> usize {
        self.inner.num_free_blocks()
    }

    /// Number of attention heads configured for the cache.
    pub fn num_heads(&self) -> usize {
        self.inner.num_heads()
    }

    /// Latent dimension of the cached tensors (the inner cache's head dimension).
    pub fn latent_dim(&self) -> usize {
        self.inner.head_dim()
    }

    /// Device backing the cached tensors.
    pub fn device(&self) -> &B::Device {
        self.inner.device()
    }

    /// Release all blocks associated with a sequence id.
    pub fn free_sequence(&mut self, seq_id: usize) -> Result<(), &'static str> {
        self.inner.free_sequence(seq_id)
    }
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, TensorData};
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    fn identity_matrix(dim: usize, device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 2> {
        let mut data = vec![0.0f32; dim * dim];
        for i in 0..dim {
            data[i * dim + i] = 1.0;
        }
        Tensor::from_data(TensorData::new(data, [dim, dim]), device)
    }

    fn zero_matrix(rows: usize, cols: usize, device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 2> {
        let data = vec![0.0f32; rows * cols];
        Tensor::from_data(TensorData::new(data, [rows, cols]), device)
    }

    #[test]
    fn test_mla_compress_decompress_roundtrip_identity() {
        let device = <TestBackend as Backend>::Device::default();
        let head_dim = 4;
        let latent_dim = 4;
        let down = identity_matrix(head_dim, &device);
        let up = identity_matrix(head_dim, &device);
        let rope = zero_matrix(latent_dim, head_dim, &device);

        let mla = MultiHeadLatentAttention::new(1, latent_dim, down, up, rope);

        let k = Tensor::<TestBackend, 4>::random([1, 2, 4, head_dim], Distribution::Normal(0.0, 0.5), &device);
        let v = Tensor::<TestBackend, 4>::random([1, 2, 4, head_dim], Distribution::Normal(0.0, 0.5), &device);

        let (k_latent, v_latent) = mla.compress_kv(k.clone(), v.clone()).expect("compress");
        let (k_full, v_full) = mla.decompress_kv(k_latent, v_latent).expect("decompress");

        let k_data = k.into_data().into_vec::<f32>().expect("k data");
        let k_roundtrip = k_full.into_data().into_vec::<f32>().expect("k roundtrip");
        for (idx, (orig, round)) in k_data.iter().zip(k_roundtrip.iter()).enumerate() {
            let diff = (orig - round).abs();
            assert!(diff < 1e-4, "k mismatch at {}: {} vs {}", idx, orig, round);
        }

        let v_data = v.into_data().into_vec::<f32>().expect("v data");
        let v_roundtrip = v_full.into_data().into_vec::<f32>().expect("v roundtrip");
        for (idx, (orig, round)) in v_data.iter().zip(v_roundtrip.iter()).enumerate() {
            let diff = (orig - round).abs();
            assert!(diff < 1e-4, "v mismatch at {}: {} vs {}", idx, orig, round);
        }
    }

    #[test]
    fn test_compressed_kv_cache_roundtrip() {
        let device = <TestBackend as Backend>::Device::default();
        let head_dim = 4;
        let latent_dim = 4;
        let down = identity_matrix(head_dim, &device);
        let up = identity_matrix(head_dim, &device);
        let rope = zero_matrix(latent_dim, head_dim, &device);

        let mla = MultiHeadLatentAttention::new(1, latent_dim, down, up, rope);
        let mut cache = CompressedKVCache::new(4, 1, 2, mla, &device);

        let seq_id = cache.allocate_sequence();
        let keys = Tensor::<TestBackend, 3>::random(
            [2, 5, head_dim],
            Distribution::Normal(0.0, 0.5),
            &device,
        );
        let values = Tensor::<TestBackend, 3>::random(
            [2, 5, head_dim],
            Distribution::Normal(0.0, 0.5),
            &device,
        );

        cache.append(0, seq_id, keys.clone(), values.clone()).expect("append");

        let (k_full, v_full) = cache.get_kv(0, seq_id).expect("get kv");
        assert_eq!(k_full.dims(), [2, 5, head_dim]);
        assert_eq!(v_full.dims(), [2, 5, head_dim]);

        let k_data = keys.into_data().into_vec::<f32>().expect("keys data");
        let k_round = k_full.into_data().into_vec::<f32>().expect("keys roundtrip");
        for (idx, (orig, round)) in k_data.iter().zip(k_round.iter()).enumerate() {
            let diff = (orig - round).abs();
            assert!(diff < 1e-4, "k mismatch at {}: {} vs {}", idx, orig, round);
        }
    }
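
    // Illustrative sketch: the 3D helpers should also be a lossless round trip
    // when both projections are identity matrices and the RoPE projection is zero.
    #[test]
    fn test_mla_3d_roundtrip_identity() {
        let device = <TestBackend as Backend>::Device::default();
        let head_dim = 4;
        let latent_dim = 4;
        let mla = MultiHeadLatentAttention::new(
            1,
            latent_dim,
            identity_matrix(head_dim, &device),
            identity_matrix(head_dim, &device),
            zero_matrix(latent_dim, head_dim, &device),
        );

        let k = Tensor::<TestBackend, 3>::random([2, 3, head_dim], Distribution::Normal(0.0, 0.5), &device);
        let v = Tensor::<TestBackend, 3>::random([2, 3, head_dim], Distribution::Normal(0.0, 0.5), &device);

        let (k_latent, v_latent) = mla.compress_kv_3d(k.clone(), v).expect("compress");
        assert_eq!(k_latent.dims(), [2, 3, latent_dim]);

        let (k_full, _v_full) = mla.decompress_kv_3d(k_latent, v_latent).expect("decompress");
        let orig = k.into_data().into_vec::<f32>().expect("orig");
        let round = k_full.into_data().into_vec::<f32>().expect("round");
        for (idx, (a, b)) in orig.iter().zip(round.iter()).enumerate() {
            assert!((a - b).abs() < 1e-4, "k mismatch at {}: {} vs {}", idx, a, b);
        }
    }

    // Illustrative sketch of the error path: the dimensions are arbitrary test
    // values, with the down projection deliberately transposed so that
    // validate_projections rejects it.
    #[test]
    fn test_mla_rejects_mismatched_projection_shapes() {
        let device = <TestBackend as Backend>::Device::default();
        let head_dim = 4;
        let latent_dim = 2;
        // Wrong on purpose: the down projection should be [head_dim, latent_dim].
        let down = zero_matrix(latent_dim, head_dim, &device);
        let up = zero_matrix(latent_dim, head_dim, &device);
        let rope = zero_matrix(latent_dim, head_dim, &device);

        let mla = MultiHeadLatentAttention::new(2, latent_dim, down, up, rope);
        let k = Tensor::<TestBackend, 4>::random([1, 1, 3, head_dim], Distribution::Normal(0.0, 0.5), &device);
        let v = k.clone();

        assert!(mla.compress_kv(k, v).is_err());
    }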
}