use std::marker::PhantomData;

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};

use crate::ops::mla::CompressedKVCache as MlaCompressedKVCache;
use crate::ops::paged_attention::PagedKVCache;

/// Fraction of total per-dimension energy that the kept low-rank basis must cover.
const ENERGY_FRACTION: f32 = 0.9;
/// Std-deviation multiplier for outlier detection (also the fallback when the MAD is zero).
const OUTLIER_STD_FACTOR: f32 = 3.0;
/// Median-absolute-deviation multiplier for the scalar-quantization outlier threshold.
const OUTLIER_MAD_FACTOR: f32 = 6.0;

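/// Strategy used to compress cached keys and values.
///
/// A minimal construction sketch, hedged: `kv_compression` stands in for wherever
/// this module is actually re-exported, so the example is marked `ignore`.
///
/// ```ignore
/// use kv_compression::CompressionMethod;
///
/// // Keep only the 32 highest-energy feature dimensions.
/// let low_rank = CompressionMethod::LowRank { rank: 32 };
/// // Map each token vector to its nearest entry in a 16-entry codebook.
/// let vq = CompressionMethod::VectorQuantization { codebook_size: 16 };
/// // Low-rank projection followed by 4-bit scalar quantization.
/// let hybrid = CompressionMethod::Hybrid { rank: 32, quant_bits: 4 };
/// ```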
#[derive(Debug, Clone)]
pub enum CompressionMethod {
    /// Keep only the `rank` highest-energy feature dimensions.
    LowRank { rank: usize },
    /// Replace each token vector with the index of its nearest codebook entry.
    VectorQuantization { codebook_size: usize },
    /// Low-rank projection followed by scalar quantization of the projected values.
    Hybrid { rank: usize, quant_bits: u8 },
}

/// Shape metadata recorded at compression time so tensors can be rebuilt exactly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KVLayout {
    /// A single sequence laid out as `[num_heads, seq_len, head_dim]`.
    Unbatched {
        num_heads: usize,
        seq_len: usize,
        head_dim: usize,
    },
    /// A batch laid out as `[batch, num_heads, seq_len, head_dim]`.
    Batched {
        batch: usize,
        num_heads: usize,
        seq_len: usize,
        head_dim: usize,
    },
}

/// A key/value pair compressed with one [`CompressionMethod`], together with the
/// device and layout needed to reconstruct the original tensors.
#[derive(Debug, Clone)]
pub struct CompressedKV<B: Backend> {
    device: B::Device,
    layout: KVLayout,
    keys: CompressedTensor<B>,
    values: CompressedTensor<B>,
}

impl<B: Backend> CompressedKV<B> {
    /// Layout of the tensors that were compressed.
    pub fn layout(&self) -> KVLayout {
        self.layout
    }

    /// Device the original tensors lived on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }
}

/// Compresses and decompresses KV-cache tensors for a single backend `B`.
#[derive(Debug, Clone)]
pub struct KVCacheCompressor<B: Backend> {
    /// Compression strategy applied to both keys and values.
    pub method: CompressionMethod,
    /// Default bit width (4 or 8) used when the method does not fix one.
    pub quant_bits: u8,
    _marker: PhantomData<B>,
}

impl<B: Backend> KVCacheCompressor<B> {
    /// Creates a compressor with the given method and default quantization bit width.
    pub fn new(method: CompressionMethod, quant_bits: u8) -> Self {
        Self { method, quant_bits, _marker: PhantomData }
    }

    /// Configured compression method.
    pub fn method(&self) -> &CompressionMethod {
        &self.method
    }

    /// Default quantization bit width.
    pub fn quant_bits(&self) -> u8 {
        self.quant_bits
    }

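    /// Compresses a batched KV pair shaped `[batch, num_heads, seq_len, head_dim]`
    /// and returns a [`CompressedKV`] that [`Self::decompress_kv`] can restore.
    ///
    /// A minimal round-trip sketch, hedged: it assumes the `burn-ndarray` backend is
    /// available and that this module is reachable at a hypothetical `kv_compression`
    /// path, so the example is marked `ignore`.
    ///
    /// ```ignore
    /// use burn::tensor::{backend::Backend, Tensor};
    /// use burn_ndarray::NdArray;
    /// use kv_compression::{CompressionMethod, KVCacheCompressor};
    ///
    /// let device = <NdArray<f32> as Backend>::Device::default();
    /// let k = Tensor::<NdArray<f32>, 4>::zeros([1, 2, 8, 4], &device);
    /// let v = Tensor::<NdArray<f32>, 4>::zeros([1, 2, 8, 4], &device);
    ///
    /// let compressor =
    ///     KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 2 }, 8);
    /// let compressed = compressor.compress_kv(k, v).expect("compress");
    /// let (k_restored, _v_restored) = compressor.decompress_kv(compressed).expect("decompress");
    /// assert_eq!(k_restored.dims(), [1, 2, 8, 4]);
    /// ```
    ///
    /// Errors if the key and value shapes do not match.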
    pub fn compress_kv(
        &self,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Result<CompressedKV<B>, &'static str> {
        let [batch, num_heads, seq_len, head_dim] = k.dims();
        if v.dims() != [batch, num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        let combined_heads = batch * num_heads;
        let k = k.reshape([combined_heads, seq_len, head_dim]);
        let v = v.reshape([combined_heads, seq_len, head_dim]);
        let layout = KVLayout::Batched {
            batch,
            num_heads,
            seq_len,
            head_dim,
        };
        self.compress_with_layout(k, v, layout)
    }

    /// Compresses an unbatched KV pair shaped `[num_heads, seq_len, head_dim]`.
    pub fn compress_kv_3d(
        &self,
        k: Tensor<B, 3>,
        v: Tensor<B, 3>,
    ) -> Result<CompressedKV<B>, &'static str> {
        let [num_heads, seq_len, head_dim] = k.dims();
        if v.dims() != [num_heads, seq_len, head_dim] {
            return Err("keys/values shape mismatch");
        }
        let layout = KVLayout::Unbatched {
            num_heads,
            seq_len,
            head_dim,
        };
        self.compress_with_layout(k, v, layout)
    }

    /// Decompresses a batched KV pair back to `[batch, num_heads, seq_len, head_dim]`.
    ///
    /// Fails if the cache was compressed with [`Self::compress_kv_3d`].
    pub fn decompress_kv(
        &self,
        compressed: CompressedKV<B>,
    ) -> Result<(Tensor<B, 4>, Tensor<B, 4>), &'static str> {
        let (k, v, layout) = self.decompress_to_3d(compressed)?;
        match layout {
            KVLayout::Batched {
                batch,
                num_heads,
                seq_len,
                head_dim,
            } => {
                let k = k.reshape([batch, num_heads, seq_len, head_dim]);
                let v = v.reshape([batch, num_heads, seq_len, head_dim]);
                Ok((k, v))
            }
            KVLayout::Unbatched { .. } => Err("expected batched layout"),
        }
    }

    /// Decompresses an unbatched KV pair back to `[num_heads, seq_len, head_dim]`.
    ///
    /// Fails if the cache was compressed with [`Self::compress_kv`].
    pub fn decompress_kv_3d(
        &self,
        compressed: CompressedKV<B>,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
        let (k, v, layout) = self.decompress_to_3d(compressed)?;
        match layout {
            KVLayout::Unbatched { .. } => Ok((k, v)),
            KVLayout::Batched { .. } => Err("expected unbatched layout"),
        }
    }

    /// Compresses the cached keys/values of one sequence in a [`PagedKVCache`] layer.
    pub fn compress_paged_cache(
        &self,
        cache: &PagedKVCache<B>,
        layer: usize,
        seq_id: usize,
    ) -> Result<CompressedKV<B>, &'static str> {
        let (k, v) = cache.get_kv(layer, seq_id)?;
        self.compress_kv_3d(k, v)
    }

    /// Decompresses and appends the keys/values to a sequence in a [`PagedKVCache`].
    pub fn decompress_to_paged_cache(
        &self,
        compressed: CompressedKV<B>,
        cache: &mut PagedKVCache<B>,
        layer: usize,
        seq_id: usize,
    ) -> Result<(), &'static str> {
        let (k, v) = self.decompress_kv_3d(compressed)?;
        cache.append(layer, seq_id, k, v)
    }

    /// Compresses the cached keys/values of one sequence in an MLA compressed cache.
    pub fn compress_mla_cache(
        &self,
        cache: &MlaCompressedKVCache<B>,
        layer: usize,
        seq_id: usize,
    ) -> Result<CompressedKV<B>, &'static str> {
        let (k, v) = cache.get_kv(layer, seq_id)?;
        self.compress_kv_3d(k, v)
    }

    /// Decompresses and appends the keys/values to a sequence in an MLA compressed cache.
    pub fn decompress_to_mla_cache(
        &self,
        compressed: CompressedKV<B>,
        cache: &mut MlaCompressedKVCache<B>,
        layer: usize,
        seq_id: usize,
    ) -> Result<(), &'static str> {
        let (k, v) = self.decompress_kv_3d(compressed)?;
        cache.append(layer, seq_id, k, v)
    }

    fn compress_with_layout(
        &self,
        k: Tensor<B, 3>,
        v: Tensor<B, 3>,
        layout: KVLayout,
    ) -> Result<CompressedKV<B>, &'static str> {
        if k.dims() != v.dims() {
            return Err("keys/values shape mismatch");
        }
        let device = k.device();
        let keys = self.compress_tensor(k)?;
        let values = self.compress_tensor(v)?;
        Ok(CompressedKV {
            device,
            layout,
            keys,
            values,
        })
    }

    fn decompress_to_3d(
        &self,
        compressed: CompressedKV<B>,
    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>, KVLayout), &'static str> {
        let device = compressed.device.clone();
        let keys = decompress_tensor(compressed.keys, &device)?;
        let values = decompress_tensor(compressed.values, &device)?;
        Ok((keys, values, compressed.layout))
    }

    fn compress_tensor(&self, tensor: Tensor<B, 3>) -> Result<CompressedTensor<B>, &'static str> {
        match self.method {
            CompressionMethod::LowRank { rank } => {
                let low_rank = compress_low_rank(tensor, rank)?;
                Ok(CompressedTensor::LowRank(low_rank))
            }
            CompressionMethod::VectorQuantization { codebook_size } => {
                let bits = effective_vq_bits(self.quant_bits, codebook_size)?;
                let vq = compress_vector_quantization(tensor, codebook_size, bits)?;
                Ok(CompressedTensor::VectorQuantized(vq))
            }
            CompressionMethod::Hybrid { rank, quant_bits } => {
                // A per-method bit width of 0 falls back to the compressor default.
                let bits = if quant_bits == 0 { self.quant_bits } else { quant_bits };
                let hybrid = compress_hybrid(tensor, rank, bits)?;
                Ok(CompressedTensor::Hybrid(hybrid))
            }
        }
    }
}

#[derive(Debug, Clone)]
enum CompressedTensor<B: Backend> {
    LowRank(LowRankTensor<B>),
    VectorQuantized(VectorQuantizedTensor),
    Hybrid(HybridTensor),
}

#[derive(Debug, Clone)]
struct LowRankTensor<B: Backend> {
    projected: Tensor<B, 3>,
    basis_indices: Vec<usize>,
    original_head_dim: usize,
}

#[derive(Debug, Clone)]
struct VectorQuantizedTensor {
    codebook: Vec<f32>,
    codes: QuantizedCodes,
    vector_dim: usize,
    shape: [usize; 3],
    outliers: Vec<OutlierVector>,
}

#[derive(Debug, Clone)]
struct OutlierVector {
    index: usize,
    values: Vec<f32>,
}

#[derive(Debug, Clone)]
struct HybridTensor {
    quantized: QuantizedTensor,
    basis_indices: Vec<usize>,
    original_head_dim: usize,
}

#[derive(Debug, Clone)]
struct QuantizedTensor {
    data: QuantizedData,
    shape: [usize; 3],
    scale: f32,
    bits: u8,
    outliers: Vec<(usize, f32)>,
}

#[derive(Debug, Clone)]
enum QuantizedData {
    Int8(Vec<i8>),
    Int4(Vec<u8>),
}

#[derive(Debug, Clone)]
enum QuantizedCodes {
    Int4 { data: Vec<u8>, len: usize },
    Int8 { data: Vec<u8> },
}

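/// Selects a per-dimension basis by energy and keeps only those coordinates.
///
/// The effective rank is the smallest prefix of dimensions (sorted by squared-sum
/// energy) whose cumulative energy reaches `ENERGY_FRACTION`, capped at `rank`.
/// A small worked example, hedged as an illustration of the formula only: with
/// per-dimension energies `[4.0, 0.04, 4.0, 0.04]` (total 8.08) and `rank = 3`,
/// the sorted cumulative fractions are 0.495 and then 0.990 >= 0.9, so only the
/// two high-energy dimensions are kept and the remaining ones decompress to zero.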
fn compress_low_rank<B: Backend>(
    tensor: Tensor<B, 3>,
    rank: usize,
) -> Result<LowRankTensor<B>, &'static str> {
    let [combined_heads, seq_len, head_dim] = tensor.dims();
    if rank == 0 || head_dim == 0 {
        return Err("invalid rank or head_dim");
    }
    let device = tensor.device();
    let data = tensor
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "low-rank compression expects f32 data")?;
    let tokens = combined_heads * seq_len;
    let mut energies = vec![0.0f32; head_dim];
    for token in 0..tokens {
        let base = token * head_dim;
        for dim in 0..head_dim {
            let value = data[base + dim];
            energies[dim] += value * value;
        }
    }

    let max_rank = rank.min(head_dim);
    let mut ranked: Vec<(usize, f32)> = energies.into_iter().enumerate().collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let total_energy: f32 = ranked.iter().map(|(_, energy)| *energy).sum();
    let mut effective_rank = max_rank;
    if total_energy > 0.0 {
        let mut cumulative = 0.0f32;
        effective_rank = 0;
        for (_, energy) in ranked.iter() {
            cumulative += *energy;
            effective_rank += 1;
            if cumulative / total_energy >= ENERGY_FRACTION {
                break;
            }
        }
        effective_rank = effective_rank.min(max_rank).max(1);
    }

    let basis_indices: Vec<usize> = ranked
        .iter()
        .take(effective_rank)
        .map(|(idx, _)| *idx)
        .collect();

    let mut projected = vec![0.0f32; tokens * effective_rank];
    for token in 0..tokens {
        let in_base = token * head_dim;
        let out_base = token * effective_rank;
        for (r, &dim) in basis_indices.iter().enumerate() {
            projected[out_base + r] = data[in_base + dim];
        }
    }

    let projected = Tensor::<B, 3>::from_data(
        TensorData::new(projected, [combined_heads, seq_len, effective_rank]),
        &device,
    );

    Ok(LowRankTensor {
        projected,
        basis_indices,
        original_head_dim: head_dim,
    })
}

fn compress_hybrid<B: Backend>(
    tensor: Tensor<B, 3>,
    rank: usize,
    quant_bits: u8,
) -> Result<HybridTensor, &'static str> {
    let low_rank = compress_low_rank(tensor, rank)?;
    let [combined_heads, seq_len, effective_rank] = low_rank.projected.dims();
    let projected = low_rank.projected.clone();
    let data = projected
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "hybrid compression expects f32 data")?;

    let quantized = quantize_values(
        &data,
        [combined_heads, seq_len, effective_rank],
        quant_bits,
    )?;

    Ok(HybridTensor {
        quantized,
        basis_indices: low_rank.basis_indices,
        original_head_dim: low_rank.original_head_dim,
    })
}

fn compress_vector_quantization<B: Backend>(
    tensor: Tensor<B, 3>,
    codebook_size: usize,
    quant_bits: u8,
) -> Result<VectorQuantizedTensor, &'static str> {
    if codebook_size == 0 {
        return Err("codebook_size must be > 0");
    }
    let [combined_heads, seq_len, head_dim] = tensor.dims();
    let data = tensor
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "vector quantization expects f32 data")?;
    let tokens = combined_heads * seq_len;

    if tokens == 0 {
        return Err("vector quantization expects non-empty tensor");
    }
    if codebook_size > 256 {
        return Err("codebook_size must be <= 256");
    }
    if quant_bits == 4 && codebook_size > 16 {
        return Err("codebook_size must be <= 16 for INT4");
    }

    // Seed the codebook with the first `codebook_size` token vectors, then refine
    // the centroids with a few k-means iterations.
    let mut codebook = vec![0.0f32; codebook_size * head_dim];
    for c in 0..codebook_size {
        if c < tokens {
            let start = c * head_dim;
            let end = start + head_dim;
            codebook[c * head_dim..(c + 1) * head_dim].copy_from_slice(&data[start..end]);
        }
    }

    refine_codebook(&data, &mut codebook, codebook_size, head_dim, tokens);

    let (codes, distances) = assign_codes(&data, &codebook, codebook_size, head_dim, tokens);
    let outliers = detect_vq_outliers(&data, &distances, head_dim);
    let packed = pack_codes(&codes, quant_bits);

    Ok(VectorQuantizedTensor {
        codebook,
        codes: packed,
        vector_dim: head_dim,
        shape: [combined_heads, seq_len, head_dim],
        outliers,
    })
}

fn refine_codebook(
    data: &[f32],
    codebook: &mut [f32],
    codebook_size: usize,
    vector_dim: usize,
    tokens: usize,
) {
    const KMEANS_ITERS: usize = 2;
    for _ in 0..KMEANS_ITERS {
        let mut counts = vec![0usize; codebook_size];
        let mut sums = vec![0.0f32; codebook_size * vector_dim];

        for token in 0..tokens {
            let (idx, _) = nearest_centroid(
                data,
                codebook,
                codebook_size,
                vector_dim,
                token,
            );
            counts[idx] += 1;
            let base = token * vector_dim;
            let sum_base = idx * vector_dim;
            for d in 0..vector_dim {
                sums[sum_base + d] += data[base + d];
            }
        }

        for c in 0..codebook_size {
            if counts[c] > 0 {
                let base = c * vector_dim;
                for d in 0..vector_dim {
                    codebook[base + d] = sums[base + d] / counts[c] as f32;
                }
            }
        }
    }
}

fn assign_codes(
    data: &[f32],
    codebook: &[f32],
    codebook_size: usize,
    vector_dim: usize,
    tokens: usize,
) -> (Vec<u8>, Vec<f32>) {
    let mut codes = Vec::with_capacity(tokens);
    let mut distances = Vec::with_capacity(tokens);
    for token in 0..tokens {
        let (idx, dist) = nearest_centroid(data, codebook, codebook_size, vector_dim, token);
        codes.push(idx as u8);
        distances.push(dist);
    }
    (codes, distances)
}

/// Flags token vectors whose distance to their nearest centroid is more than
/// `OUTLIER_STD_FACTOR` standard deviations above the mean distance; these are
/// stored exactly and restored verbatim on decompression.
fn detect_vq_outliers(
    data: &[f32],
    distances: &[f32],
    vector_dim: usize,
) -> Vec<OutlierVector> {
    if distances.is_empty() {
        return Vec::new();
    }
    let mean = distances.iter().sum::<f32>() / distances.len() as f32;
    let mut var = 0.0f32;
    for &dist in distances {
        let diff = dist - mean;
        var += diff * diff;
    }
    let std = (var / distances.len() as f32).sqrt();
    let threshold = mean + OUTLIER_STD_FACTOR * std;

    let mut outliers = Vec::new();
    for (token, &dist) in distances.iter().enumerate() {
        if dist > threshold {
            let base = token * vector_dim;
            let values = data[base..base + vector_dim].to_vec();
            outliers.push(OutlierVector { index: token, values });
        }
    }
    outliers
}

fn nearest_centroid(
    data: &[f32],
    codebook: &[f32],
    codebook_size: usize,
    vector_dim: usize,
    token: usize,
) -> (usize, f32) {
    let base = token * vector_dim;
    let mut best_idx = 0;
    let mut best_dist = f32::INFINITY;
    for c in 0..codebook_size {
        let mut dist = 0.0f32;
        let code_base = c * vector_dim;
        for d in 0..vector_dim {
            let diff = data[base + d] - codebook[code_base + d];
            dist += diff * diff;
        }
        if dist < best_dist {
            best_dist = dist;
            best_idx = c;
        }
    }
    (best_idx, best_dist)
}

fn pack_codes(codes: &[u8], bits: u8) -> QuantizedCodes {
    match bits {
        4 => QuantizedCodes::Int4 {
            data: pack_nibbles(codes),
            len: codes.len(),
        },
        _ => QuantizedCodes::Int8 { data: codes.to_vec() },
    }
}

fn unpack_codes(codes: &QuantizedCodes) -> Vec<u8> {
    match codes {
        QuantizedCodes::Int4 { data, len } => unpack_nibbles(data, *len),
        QuantizedCodes::Int8 { data } => data.clone(),
    }
}

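/// Symmetric scalar quantization with outlier extraction.
///
/// Values whose magnitude exceeds `median + OUTLIER_MAD_FACTOR * MAD` (or the
/// std-based fallback when the MAD is zero) are stored exactly as `(index, value)`
/// pairs and clipped before the scale is chosen, so one large value does not
/// destroy the resolution of everything else. A small worked example, hedged as
/// an illustration of the formula only: for `[0.1, 0.2, 0.15, 10.0]` the median of
/// absolute values is 0.2 and the MAD is 0.1, so the threshold is 0.8; 10.0 is kept
/// as an outlier, the remaining values are quantized with scale 0.8 / 7 in 4-bit
/// mode, and 10.0 is written back verbatim on dequantization.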
fn quantize_values(
    data: &[f32],
    shape: [usize; 3],
    bits: u8,
) -> Result<QuantizedTensor, &'static str> {
    if bits != 4 && bits != 8 {
        return Err("quant_bits must be 4 or 8");
    }
    if data.is_empty() {
        return Err("cannot quantize empty tensor");
    }

    let mut sum_abs = 0.0f32;
    let mut abs_values = Vec::with_capacity(data.len());
    for value in data {
        let abs = value.abs();
        sum_abs += abs;
        abs_values.push(abs);
    }
    let mean = sum_abs / data.len() as f32;
    let mut var = 0.0f32;
    for &abs in &abs_values {
        let diff = abs - mean;
        var += diff * diff;
    }
    let std = (var / data.len() as f32).sqrt();
    abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let median = abs_values[abs_values.len() / 2];
    let mut deviations: Vec<f32> = abs_values.iter().map(|v| (v - median).abs()).collect();
    deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mad = deviations[deviations.len() / 2];
    let threshold = if mad > 0.0 {
        median + OUTLIER_MAD_FACTOR * mad
    } else {
        mean + OUTLIER_STD_FACTOR * std
    };

    let mut clipped = Vec::with_capacity(data.len());
    let mut outliers = Vec::new();
    let mut max_abs = 0.0f32;

    for (idx, &value) in data.iter().enumerate() {
        let abs = value.abs();
        if abs > threshold {
            outliers.push((idx, value));
            let sign = if value.is_sign_negative() { -1.0 } else { 1.0 };
            let clipped_value = sign * threshold;
            max_abs = max_abs.max(clipped_value.abs());
            clipped.push(clipped_value);
        } else {
            max_abs = max_abs.max(abs);
            clipped.push(value);
        }
    }

    let max_level = if bits == 4 { 7.0 } else { 127.0 };
    let scale = if max_abs > 0.0 { max_abs / max_level } else { 1.0 };

    let quantized = match bits {
        4 => {
            let mut values = Vec::with_capacity(clipped.len());
            for value in &clipped {
                let q = (value / scale).round().clamp(-max_level, max_level) as i8;
                values.push(q);
            }
            QuantizedData::Int4(pack_int4(&values))
        }
        _ => {
            let mut values = Vec::with_capacity(clipped.len());
            for value in &clipped {
                let q = (value / scale).round().clamp(-max_level, max_level) as i8;
                values.push(q);
            }
            QuantizedData::Int8(values)
        }
    };

    Ok(QuantizedTensor {
        data: quantized,
        shape,
        scale,
        bits,
        outliers,
    })
}

fn dequantize_values<B: Backend>(
    quantized: &QuantizedTensor,
    device: &B::Device,
) -> Result<Tensor<B, 3>, &'static str> {
    let num_values = quantized.shape[0] * quantized.shape[1] * quantized.shape[2];
    let mut values: Vec<f32> = match &quantized.data {
        QuantizedData::Int8(data) => data.iter().map(|&q| q as f32 * quantized.scale).collect(),
        QuantizedData::Int4(data) => {
            let unpacked = unpack_int4(data, num_values);
            unpacked
                .into_iter()
                .map(|q| q as f32 * quantized.scale)
                .collect()
        }
    };

    for &(idx, value) in &quantized.outliers {
        if idx < values.len() {
            values[idx] = value;
        }
    }

    Ok(Tensor::<B, 3>::from_data(
        TensorData::new(values, quantized.shape),
        device,
    ))
}

fn decompress_tensor<B: Backend>(
    compressed: CompressedTensor<B>,
    device: &B::Device,
) -> Result<Tensor<B, 3>, &'static str> {
    match compressed {
        CompressedTensor::LowRank(low_rank) => decompress_low_rank(low_rank, device),
        CompressedTensor::VectorQuantized(vq) => decompress_vector_quantized(vq, device),
        CompressedTensor::Hybrid(hybrid) => decompress_hybrid(hybrid, device),
    }
}

fn decompress_low_rank<B: Backend>(
    low_rank: LowRankTensor<B>,
    device: &B::Device,
) -> Result<Tensor<B, 3>, &'static str> {
    let [combined_heads, seq_len, rank] = low_rank.projected.dims();
    if rank == 0 {
        return Err("low-rank projection has rank 0");
    }
    let data = low_rank
        .projected
        .into_data()
        .into_vec::<f32>()
        .map_err(|_| "low-rank decompression expects f32 data")?;
    let tokens = combined_heads * seq_len;
    let head_dim = low_rank.original_head_dim;
    let mut full = vec![0.0f32; tokens * head_dim];

    for token in 0..tokens {
        let in_base = token * rank;
        let out_base = token * head_dim;
        for (r, &dim) in low_rank.basis_indices.iter().enumerate() {
            if dim < head_dim {
                full[out_base + dim] = data[in_base + r];
            }
        }
    }

    Ok(Tensor::<B, 3>::from_data(
        TensorData::new(full, [combined_heads, seq_len, head_dim]),
        device,
    ))
}

fn decompress_hybrid<B: Backend>(
    hybrid: HybridTensor,
    device: &B::Device,
) -> Result<Tensor<B, 3>, &'static str> {
    let projected = dequantize_values::<B>(&hybrid.quantized, device)?;
    let low_rank = LowRankTensor {
        projected,
        basis_indices: hybrid.basis_indices,
        original_head_dim: hybrid.original_head_dim,
    };
    decompress_low_rank(low_rank, device)
}

fn decompress_vector_quantized<B: Backend>(
    vq: VectorQuantizedTensor,
    device: &B::Device,
) -> Result<Tensor<B, 3>, &'static str> {
    let tokens = vq.shape[0] * vq.shape[1];
    let vector_dim = vq.vector_dim;
    let codes = unpack_codes(&vq.codes);
    if codes.len() != tokens {
        return Err("vector quantization code length mismatch");
    }
    let mut data = vec![0.0f32; tokens * vector_dim];

    for token in 0..tokens {
        let code = codes[token] as usize;
        let base = token * vector_dim;
        let code_base = code * vector_dim;
        for d in 0..vector_dim {
            data[base + d] = vq.codebook[code_base + d];
        }
    }

    for outlier in &vq.outliers {
        let base = outlier.index * vector_dim;
        if base + vector_dim <= data.len() {
            data[base..base + vector_dim].copy_from_slice(&outlier.values);
        }
    }

    Ok(Tensor::<B, 3>::from_data(
        TensorData::new(data, vq.shape),
        device,
    ))
}

/// Resolves the code bit width for vector quantization: 4-bit codes require a
/// codebook of at most 16 entries, 8-bit codes at most 256. Any other requested
/// width falls back to the smallest width that fits the codebook.
fn effective_vq_bits(bits: u8, codebook_size: usize) -> Result<u8, &'static str> {
    match bits {
        4 => {
            if codebook_size > 16 {
                Err("codebook_size must be <= 16 for INT4")
            } else {
                Ok(4)
            }
        }
        8 => {
            if codebook_size > 256 {
                Err("codebook_size must be <= 256 for INT8")
            } else {
                Ok(8)
            }
        }
        _ => {
            if codebook_size <= 16 {
                Ok(4)
            } else if codebook_size <= 256 {
                Ok(8)
            } else {
                Err("codebook_size must be <= 256")
            }
        }
    }
}

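/// Packs 4-bit codes two per byte, low nibble first.
///
/// A small worked example, hedged as an illustration only: packing `[0x1, 0x2, 0x3]`
/// yields `[0x21, 0x03]`; the odd trailing nibble is padded with zero in the high
/// half, and `unpack_nibbles` uses the stored length to drop that padding again.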
fn pack_nibbles(values: &[u8]) -> Vec<u8> {
    let mut packed = Vec::with_capacity((values.len() + 1) / 2);
    let mut iter = values.iter();
    loop {
        let low = match iter.next() {
            Some(v) => v & 0x0F,
            None => break,
        };
        let high = match iter.next() {
            Some(v) => (v & 0x0F) << 4,
            None => 0,
        };
        packed.push(low | high);
    }
    packed
}

fn unpack_nibbles(values: &[u8], len: usize) -> Vec<u8> {
    let mut unpacked = Vec::with_capacity(len);
    for &byte in values {
        if unpacked.len() < len {
            unpacked.push(byte & 0x0F);
        }
        if unpacked.len() < len {
            unpacked.push((byte >> 4) & 0x0F);
        }
    }
    unpacked
}

/// Packs signed 4-bit values two per byte using an offset-binary encoding
/// (each value is shifted by +8 into the range 0..=15).
fn pack_int4(values: &[i8]) -> Vec<u8> {
    let mut packed = Vec::with_capacity((values.len() + 1) / 2);
    let mut iter = values.iter();
    loop {
        let low = match iter.next() {
            Some(v) => (*v as i16 + 8).clamp(0, 15) as u8,
            None => break,
        };
        let high = match iter.next() {
            Some(v) => ((*v as i16 + 8).clamp(0, 15) as u8) << 4,
            None => 0,
        };
        packed.push(low | high);
    }
    packed
}

fn unpack_int4(values: &[u8], len: usize) -> Vec<i8> {
    let mut unpacked = Vec::with_capacity(len);
    for &byte in values {
        if unpacked.len() < len {
            unpacked.push(((byte & 0x0F) as i8) - 8);
        }
        if unpacked.len() < len {
            unpacked.push(((byte >> 4) as i8) - 8);
        }
    }
    unpacked
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::{Tensor, TensorData};
    use burn_ndarray::NdArray;

    #[test]
    fn test_low_rank_roundtrip_preserves_top_dims() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let num_heads = 2;
        let seq_len = 2;
        let head_dim = 4;
        let mut data = Vec::new();
        for _ in 0..(num_heads * seq_len) {
            data.extend_from_slice(&[1.0, 0.1, 1.0, 0.1]);
        }

        let k = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(data.clone(), [num_heads, seq_len, head_dim]),
            &device,
        );
        let v = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(data, [num_heads, seq_len, head_dim]),
            &device,
        );

        let compressor =
            KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 3 }, 8);
        let compressed = compressor.compress_kv_3d(k, v).expect("compress");
        let (k_full, _v_full) = compressor.decompress_kv_3d(compressed).expect("decompress");

        let k_data = k_full.into_data().into_vec::<f32>().expect("data");
        for token in 0..(num_heads * seq_len) {
            let base = token * head_dim;
            assert!((k_data[base + 1]).abs() < 1e-3);
            assert!((k_data[base + 3]).abs() < 1e-3);
        }
    }

    #[test]
    fn test_hybrid_quantization_outlier_preserved() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let num_heads = 1;
        let seq_len = 4;
        let head_dim = 1;
        let data = vec![0.1, 0.2, 0.15, 10.0];

        let k = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(data.clone(), [num_heads, seq_len, head_dim]),
            &device,
        );
        let v = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(data, [num_heads, seq_len, head_dim]),
            &device,
        );

        let compressor = KVCacheCompressor::<NdArray<f32>>::new(
            CompressionMethod::Hybrid {
                rank: 1,
                quant_bits: 4,
            },
            4,
        );
        let compressed = compressor.compress_kv_3d(k, v).expect("compress");
        let (k_full, _) = compressor.decompress_kv_3d(compressed).expect("decompress");
        let k_data = k_full.into_data().into_vec::<f32>().expect("data");

        assert!((k_data[3] - 10.0).abs() < 1e-3);
    }

    #[test]
    fn test_vector_quantization_roundtrip() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let num_heads = 1;
        let seq_len = 2;
        let head_dim = 2;
        let data = vec![1.0, 0.0, -1.0, 0.0];
        let original_data = data.clone();

        let k = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(data.clone(), [num_heads, seq_len, head_dim]),
            &device,
        );
        let v = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(data, [num_heads, seq_len, head_dim]),
            &device,
        );

        let compressor = KVCacheCompressor::<NdArray<f32>>::new(
            CompressionMethod::VectorQuantization { codebook_size: 2 },
            8,
        );
        let compressed = compressor.compress_kv_3d(k, v).expect("compress");
        let (k_full, _) = compressor.decompress_kv_3d(compressed).expect("decompress");
        let k_data = k_full.into_data().into_vec::<f32>().expect("data");

        for (orig, round) in original_data.iter().zip(k_data.iter()) {
            assert!((orig - round).abs() < 1e-4);
        }
    }

    #[test]
    fn test_paged_cache_compatibility() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let mut cache = PagedKVCache::<NdArray<f32>>::new(4, 1, 1, 2, &device);
        let seq_id = cache.allocate_sequence();
        let keys = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 3, 2]),
            &device,
        );
        let values = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(vec![0.6, 0.5, 0.4, 0.3, 0.2, 0.1], [1, 3, 2]),
            &device,
        );
        cache.append(0, seq_id, keys, values).expect("append");

        let compressor =
            KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 2 }, 8);
        let compressed = compressor
            .compress_paged_cache(&cache, 0, seq_id)
            .expect("compress paged");

        let seq_id2 = cache.allocate_sequence();
        compressor
            .decompress_to_paged_cache(compressed, &mut cache, 0, seq_id2)
            .expect("decompress paged");
        assert_eq!(cache.seq_len(0, seq_id2).expect("seq len"), 3);
    }

    fn identity_matrix(
        dim: usize,
        device: &<NdArray<f32> as Backend>::Device,
    ) -> Tensor<NdArray<f32>, 2> {
        let mut data = vec![0.0f32; dim * dim];
        for i in 0..dim {
            data[i * dim + i] = 1.0;
        }
        Tensor::from_data(TensorData::new(data, [dim, dim]), device)
    }

    fn zero_matrix(
        rows: usize,
        cols: usize,
        device: &<NdArray<f32> as Backend>::Device,
    ) -> Tensor<NdArray<f32>, 2> {
        Tensor::from_data(TensorData::new(vec![0.0f32; rows * cols], [rows, cols]), device)
    }

    #[test]
    fn test_mla_cache_compatibility() {
        use crate::ops::mla::MultiHeadLatentAttention;

        let device = <NdArray<f32> as Backend>::Device::default();
        let head_dim = 2;
        let latent_dim = 2;
        let down = identity_matrix(head_dim, &device);
        let up = identity_matrix(head_dim, &device);
        let rope = zero_matrix(latent_dim, head_dim, &device);

        let mla = MultiHeadLatentAttention::new(1, latent_dim, down, up, rope);
        let mut cache = MlaCompressedKVCache::new(4, 1, 1, mla, &device);
        let seq_id = cache.allocate_sequence();

        let keys = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 3, 2]),
            &device,
        );
        let values = Tensor::<NdArray<f32>, 3>::from_data(
            TensorData::new(vec![0.6, 0.5, 0.4, 0.3, 0.2, 0.1], [1, 3, 2]),
            &device,
        );
        cache.append(0, seq_id, keys, values).expect("append");

        let compressor =
            KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 2 }, 8);
        let compressed = compressor
            .compress_mla_cache(&cache, 0, seq_id)
            .expect("compress mla");

        let seq_id2 = cache.allocate_sequence();
        compressor
            .decompress_to_mla_cache(compressed, &mut cache, 0, seq_id2)
            .expect("decompress mla");
        assert_eq!(cache.seq_len(0, seq_id2).expect("seq len"), 3);
    }
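
    // Added as a small, self-contained sketch: round-trip checks for the private
    // packing helpers defined in this module (no external APIs are assumed).
    #[test]
    fn test_nibble_and_int4_packing_roundtrip() {
        // Odd length exercises the zero-padded trailing nibble.
        let codes: Vec<u8> = vec![0x1, 0x2, 0x3, 0xF, 0x0];
        let packed = pack_nibbles(&codes);
        assert_eq!(unpack_nibbles(&packed, codes.len()), codes);

        // The full signed 4-bit range survives the +8 offset-binary encoding.
        let values: Vec<i8> = (-8..=7).collect();
        let packed = pack_int4(&values);
        assert_eq!(unpack_int4(&packed, values.len()), values);
    }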
}