gllm_kernels/ops/
kv_compression.rs

1//! KV cache compression utilities (low-rank, vector quantization, hybrid).
2
3use std::marker::PhantomData;
4
5use burn::tensor::backend::Backend;
6use burn::tensor::{Tensor, TensorData};
7
8use crate::ops::mla::CompressedKVCache as MlaCompressedKVCache;
9use crate::ops::paged_attention::PagedKVCache;
10
/// Fraction of the total per-dimension energy that the retained dimensions
/// must reach before low-rank selection stops adding dimensions.
const ENERGY_FRACTION: f32 = 0.9;
/// Multiplier on the standard deviation when an outlier threshold is derived
/// as `mean + factor * std`.
const OUTLIER_STD_FACTOR: f32 = 3.0;
/// Multiplier on the median absolute deviation when an outlier threshold is
/// derived as `median + factor * mad`.
const OUTLIER_MAD_FACTOR: f32 = 6.0;
14
/// Compression method for KV caches.
///
/// The selected method is applied to both the key and the value tensor.
#[derive(Debug, Clone)]
pub enum CompressionMethod {
    /// Low-rank projection (PALU-style).
    ///
    /// `rank` is an upper bound on the number of head dimensions that are kept;
    /// it must be non-zero and is capped at the tensor's `head_dim`.
    LowRank { rank: usize },
    /// Vector quantization (CommVQ-style).
    ///
    /// `codebook_size` is the number of centroids. It must be in `1..=256`,
    /// and at most 16 when codes are packed as INT4.
    VectorQuantization { codebook_size: usize },
    /// Low-rank projection with quantization.
    ///
    /// A `quant_bits` of `0` defers to the compressor's default bit width;
    /// the resulting width must be 4 or 8.
    Hybrid { rank: usize, quant_bits: u8 },
}
25
/// Layout metadata for compressed KV tensors.
///
/// Records the shape the caller passed in so that decompression can return
/// tensors of the same rank and dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KVLayout {
    /// Unbatched KV: [num_heads, seq_len, head_dim].
    Unbatched {
        num_heads: usize,
        seq_len: usize,
        head_dim: usize,
    },
    /// Batched KV: [batch, num_heads, seq_len, head_dim].
    ///
    /// Batched tensors are flattened to `[batch * num_heads, seq_len, head_dim]`
    /// before compression and reshaped back on decompression.
    Batched {
        batch: usize,
        num_heads: usize,
        seq_len: usize,
        head_dim: usize,
    },
}
43
/// Compressed KV pair with metadata.
///
/// Produced by the `compress_*` methods of [`KVCacheCompressor`] and consumed
/// by its `decompress_*` methods.
#[derive(Debug, Clone)]
pub struct CompressedKV<B: Backend> {
    /// Device of the key tensor at compression time; decompressed tensors are
    /// created on it.
    device: B::Device,
    /// Shape of the tensors that were passed in for compression.
    layout: KVLayout,
    /// Compressed keys.
    keys: CompressedTensor<B>,
    /// Compressed values.
    values: CompressedTensor<B>,
}
52
53impl<B: Backend> CompressedKV<B> {
54    /// Access the layout metadata.
55    pub fn layout(&self) -> KVLayout {
56        self.layout
57    }
58
59    /// Device used for reconstruction.
60    pub fn device(&self) -> &B::Device {
61        &self.device
62    }
63}
64
/// KV cache compressor for low-rank and quantized representations.
///
/// The compressor itself holds only configuration; all compressed state lives
/// in the [`CompressedKV`] values it returns.
#[derive(Debug, Clone)]
pub struct KVCacheCompressor<B: Backend> {
    /// Compression method selection.
    pub method: CompressionMethod,
    /// Default quantization bits (INT4/INT8).
    ///
    /// Used to choose the code packing for vector quantization and as the
    /// fallback when a hybrid method specifies `quant_bits: 0`.
    pub quant_bits: u8,
    /// Phantom marker for backend type.
    _marker: PhantomData<B>,
}
75
76impl<B: Backend> KVCacheCompressor<B> {
77    /// Create a new KV cache compressor.
78    pub fn new(method: CompressionMethod, quant_bits: u8) -> Self {
79        Self { method, quant_bits, _marker: PhantomData }
80    }
81
82    /// Access the compression method.
83    pub fn method(&self) -> &CompressionMethod {
84        &self.method
85    }
86
87    /// Default quantization bits.
88    pub fn quant_bits(&self) -> u8 {
89        self.quant_bits
90    }
91
92    /// Compress batched KV tensors.
93    ///
94    /// # Shapes
95    /// * `k`, `v`: [batch, num_heads, seq_len, head_dim]
96    pub fn compress_kv(
97        &self,
98        k: Tensor<B, 4>,
99        v: Tensor<B, 4>,
100    ) -> Result<CompressedKV<B>, &'static str> {
101        let [batch, num_heads, seq_len, head_dim] = k.dims();
102        if v.dims() != [batch, num_heads, seq_len, head_dim] {
103            return Err("keys/values shape mismatch");
104        }
105        let combined_heads = batch * num_heads;
106        let k = k.reshape([combined_heads, seq_len, head_dim]);
107        let v = v.reshape([combined_heads, seq_len, head_dim]);
108        let layout = KVLayout::Batched {
109            batch,
110            num_heads,
111            seq_len,
112            head_dim,
113        };
114        self.compress_with_layout(k, v, layout)
115    }
116
117    /// Compress KV tensors without batch dimension.
118    ///
119    /// # Shapes
120    /// * `k`, `v`: [num_heads, seq_len, head_dim]
121    pub fn compress_kv_3d(
122        &self,
123        k: Tensor<B, 3>,
124        v: Tensor<B, 3>,
125    ) -> Result<CompressedKV<B>, &'static str> {
126        let [num_heads, seq_len, head_dim] = k.dims();
127        if v.dims() != [num_heads, seq_len, head_dim] {
128            return Err("keys/values shape mismatch");
129        }
130        let layout = KVLayout::Unbatched {
131            num_heads,
132            seq_len,
133            head_dim,
134        };
135        self.compress_with_layout(k, v, layout)
136    }
137
138    /// Decompress batched KV tensors.
139    ///
140    /// # Shapes
141    /// * returns: [batch, num_heads, seq_len, head_dim]
142    pub fn decompress_kv(
143        &self,
144        compressed: CompressedKV<B>,
145    ) -> Result<(Tensor<B, 4>, Tensor<B, 4>), &'static str> {
146        let (k, v, layout) = self.decompress_to_3d(compressed)?;
147        match layout {
148            KVLayout::Batched {
149                batch,
150                num_heads,
151                seq_len,
152                head_dim,
153            } => {
154                let k = k.reshape([batch, num_heads, seq_len, head_dim]);
155                let v = v.reshape([batch, num_heads, seq_len, head_dim]);
156                Ok((k, v))
157            }
158            KVLayout::Unbatched { .. } => Err("expected batched layout"),
159        }
160    }
161
162    /// Decompress KV tensors without batch dimension.
163    ///
164    /// # Shapes
165    /// * returns: [num_heads, seq_len, head_dim]
166    pub fn decompress_kv_3d(
167        &self,
168        compressed: CompressedKV<B>,
169    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>), &'static str> {
170        let (k, v, layout) = self.decompress_to_3d(compressed)?;
171        match layout {
172            KVLayout::Unbatched { .. } => Ok((k, v)),
173            KVLayout::Batched { .. } => Err("expected unbatched layout"),
174        }
175    }
176
177    /// Compress a sequence from a paged KV cache.
178    pub fn compress_paged_cache(
179        &self,
180        cache: &PagedKVCache<B>,
181        layer: usize,
182        seq_id: usize,
183    ) -> Result<CompressedKV<B>, &'static str> {
184        let (k, v) = cache.get_kv(layer, seq_id)?;
185        self.compress_kv_3d(k, v)
186    }
187
188    /// Decompress into a paged KV cache by appending tokens.
189    pub fn decompress_to_paged_cache(
190        &self,
191        compressed: CompressedKV<B>,
192        cache: &mut PagedKVCache<B>,
193        layer: usize,
194        seq_id: usize,
195    ) -> Result<(), &'static str> {
196        let (k, v) = self.decompress_kv_3d(compressed)?;
197        cache.append(layer, seq_id, k, v)
198    }
199
200    /// Compress a sequence from an MLA compressed cache.
201    pub fn compress_mla_cache(
202        &self,
203        cache: &MlaCompressedKVCache<B>,
204        layer: usize,
205        seq_id: usize,
206    ) -> Result<CompressedKV<B>, &'static str> {
207        let (k, v) = cache.get_kv(layer, seq_id)?;
208        self.compress_kv_3d(k, v)
209    }
210
211    /// Decompress into an MLA compressed cache by appending tokens.
212    pub fn decompress_to_mla_cache(
213        &self,
214        compressed: CompressedKV<B>,
215        cache: &mut MlaCompressedKVCache<B>,
216        layer: usize,
217        seq_id: usize,
218    ) -> Result<(), &'static str> {
219        let (k, v) = self.decompress_kv_3d(compressed)?;
220        cache.append(layer, seq_id, k, v)
221    }
222
223    fn compress_with_layout(
224        &self,
225        k: Tensor<B, 3>,
226        v: Tensor<B, 3>,
227        layout: KVLayout,
228    ) -> Result<CompressedKV<B>, &'static str> {
229        if k.dims() != v.dims() {
230            return Err("keys/values shape mismatch");
231        }
232        let device = k.device();
233        let keys = self.compress_tensor(k)?;
234        let values = self.compress_tensor(v)?;
235        Ok(CompressedKV {
236            device,
237            layout,
238            keys,
239            values,
240        })
241    }
242
243    fn decompress_to_3d(
244        &self,
245        compressed: CompressedKV<B>,
246    ) -> Result<(Tensor<B, 3>, Tensor<B, 3>, KVLayout), &'static str> {
247        let device = compressed.device.clone();
248        let keys = decompress_tensor(compressed.keys, &device)?;
249        let values = decompress_tensor(compressed.values, &device)?;
250        Ok((keys, values, compressed.layout))
251    }
252
253    fn compress_tensor(&self, tensor: Tensor<B, 3>) -> Result<CompressedTensor<B>, &'static str> {
254        match self.method {
255            CompressionMethod::LowRank { rank } => {
256                let low_rank = compress_low_rank(tensor, rank)?;
257                Ok(CompressedTensor::LowRank(low_rank))
258            }
259            CompressionMethod::VectorQuantization { codebook_size } => {
260                let bits = effective_vq_bits(self.quant_bits, codebook_size)?;
261                let vq = compress_vector_quantization(tensor, codebook_size, bits)?;
262                Ok(CompressedTensor::VectorQuantized(vq))
263            }
264            CompressionMethod::Hybrid { rank, quant_bits } => {
265                let bits = if quant_bits == 0 { self.quant_bits } else { quant_bits };
266                let hybrid = compress_hybrid(tensor, rank, bits)?;
267                Ok(CompressedTensor::Hybrid(hybrid))
268            }
269        }
270    }
271}
272
/// Compressed form of one rank-3 tensor; one variant per [`CompressionMethod`].
#[derive(Debug, Clone)]
enum CompressedTensor<B: Backend> {
    /// Output of low-rank compression.
    LowRank(LowRankTensor<B>),
    /// Output of vector quantization.
    VectorQuantized(VectorQuantizedTensor),
    /// Output of low-rank compression followed by scalar quantization.
    Hybrid(HybridTensor),
}
279
/// Tensor reduced to a subset of its head dimensions.
#[derive(Debug, Clone)]
struct LowRankTensor<B: Backend> {
    /// Kept dimensions only; the last axis has one entry per `basis_indices` item.
    projected: Tensor<B, 3>,
    /// Original head-dimension index of each kept column, highest energy first.
    basis_indices: Vec<usize>,
    /// Size of the head dimension before compression; dropped dimensions are
    /// restored as zeros.
    original_head_dim: usize,
}
286
/// Tensor stored as per-token codebook indices plus exact outlier vectors.
#[derive(Debug, Clone)]
struct VectorQuantizedTensor {
    /// Centroids stored back to back, `vector_dim` floats each.
    codebook: Vec<f32>,
    /// One centroid index per token, packed as INT4 or INT8.
    codes: QuantizedCodes,
    /// Length of each centroid / token vector.
    vector_dim: usize,
    /// Shape of the tensor before compression.
    shape: [usize; 3],
    /// Tokens kept verbatim; they overwrite the codebook reconstruction.
    outliers: Vec<OutlierVector>,
}
295
/// A token vector stored exactly because its nearest centroid was too far away.
#[derive(Debug, Clone)]
struct OutlierVector {
    /// Flattened token index (position of the vector in the row-major data).
    index: usize,
    /// The original, unquantized vector.
    values: Vec<f32>,
}
301
/// Low-rank projection whose kept values are additionally scalar-quantized.
#[derive(Debug, Clone)]
struct HybridTensor {
    /// Quantized form of the projected (kept-dimension) values.
    quantized: QuantizedTensor,
    /// Original head-dimension index of each kept column.
    basis_indices: Vec<usize>,
    /// Size of the head dimension before compression.
    original_head_dim: usize,
}
308
/// Symmetric scalar quantization of a rank-3 tensor with exact outliers.
#[derive(Debug, Clone)]
struct QuantizedTensor {
    /// Quantized levels, one per tensor element.
    data: QuantizedData,
    /// Shape of the tensor that was quantized.
    shape: [usize; 3],
    /// Multiplier that turns a quantized level back into a float.
    scale: f32,
    /// Bit width used for `data` (4 or 8).
    bits: u8,
    /// `(flat index, original value)` pairs restored verbatim after dequantization.
    outliers: Vec<(usize, f32)>,
}
317
/// Storage for signed quantized levels.
#[derive(Debug, Clone)]
enum QuantizedData {
    /// One signed level per byte.
    Int8(Vec<i8>),
    /// Two levels per byte, each stored as `level + 8`, low nibble first.
    Int4(Vec<u8>),
}
323
/// Storage for unsigned codebook indices.
#[derive(Debug, Clone)]
enum QuantizedCodes {
    /// Two codes per byte, low nibble first; `len` is the number of codes, so a
    /// trailing padding nibble can be dropped when unpacking.
    Int4 { data: Vec<u8>, len: usize },
    /// One code per byte.
    Int8 { data: Vec<u8> },
}
329
330fn compress_low_rank<B: Backend>(
331    tensor: Tensor<B, 3>,
332    rank: usize,
333) -> Result<LowRankTensor<B>, &'static str> {
334    let [combined_heads, seq_len, head_dim] = tensor.dims();
335    if rank == 0 || head_dim == 0 {
336        return Err("invalid rank or head_dim");
337    }
338    let device = tensor.device();
339    let data = tensor
340        .into_data()
341        .into_vec::<f32>()
342        .map_err(|_| "low-rank compression expects f32 data")?;
343    let tokens = combined_heads * seq_len;
344    let mut energies = vec![0.0f32; head_dim];
345    for token in 0..tokens {
346        let base = token * head_dim;
347        for dim in 0..head_dim {
348            let value = data[base + dim];
349            energies[dim] += value * value;
350        }
351    }
352
353    let max_rank = rank.min(head_dim);
354    let mut ranked: Vec<(usize, f32)> = energies.into_iter().enumerate().collect();
355    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
356
357    let total_energy: f32 = ranked.iter().map(|(_, energy)| *energy).sum();
358    let mut effective_rank = max_rank;
359    if total_energy > 0.0 {
360        let mut cumulative = 0.0f32;
361        effective_rank = 0;
362        for (_, energy) in ranked.iter() {
363            cumulative += *energy;
364            effective_rank += 1;
365            if cumulative / total_energy >= ENERGY_FRACTION {
366                break;
367            }
368        }
369        effective_rank = effective_rank.min(max_rank).max(1);
370    }
371
372    let basis_indices: Vec<usize> = ranked
373        .iter()
374        .take(effective_rank)
375        .map(|(idx, _)| *idx)
376        .collect();
377
378    let mut projected = vec![0.0f32; tokens * effective_rank];
379    for token in 0..tokens {
380        let in_base = token * head_dim;
381        let out_base = token * effective_rank;
382        for (r, &dim) in basis_indices.iter().enumerate() {
383            projected[out_base + r] = data[in_base + dim];
384        }
385    }
386
387    let projected = Tensor::<B, 3>::from_data(
388        TensorData::new(projected, [combined_heads, seq_len, effective_rank]),
389        &device,
390    );
391
392    Ok(LowRankTensor {
393        projected,
394        basis_indices,
395        original_head_dim: head_dim,
396    })
397}
398
399fn compress_hybrid<B: Backend>(
400    tensor: Tensor<B, 3>,
401    rank: usize,
402    quant_bits: u8,
403) -> Result<HybridTensor, &'static str> {
404    let low_rank = compress_low_rank(tensor, rank)?;
405    let [combined_heads, seq_len, effective_rank] = low_rank.projected.dims();
406    let projected = low_rank.projected.clone();
407    let data = projected
408        .into_data()
409        .into_vec::<f32>()
410        .map_err(|_| "hybrid compression expects f32 data")?;
411
412    let quantized = quantize_values(
413        &data,
414        [combined_heads, seq_len, effective_rank],
415        quant_bits,
416    )?;
417
418    Ok(HybridTensor {
419        quantized,
420        basis_indices: low_rank.basis_indices,
421        original_head_dim: low_rank.original_head_dim,
422    })
423}
424
425fn compress_vector_quantization<B: Backend>(
426    tensor: Tensor<B, 3>,
427    codebook_size: usize,
428    quant_bits: u8,
429) -> Result<VectorQuantizedTensor, &'static str> {
430    if codebook_size == 0 {
431        return Err("codebook_size must be > 0");
432    }
433    let [combined_heads, seq_len, head_dim] = tensor.dims();
434    let data = tensor
435        .into_data()
436        .into_vec::<f32>()
437        .map_err(|_| "vector quantization expects f32 data")?;
438    let tokens = combined_heads * seq_len;
439
440    if tokens == 0 {
441        return Err("vector quantization expects non-empty tensor");
442    }
443    if codebook_size > 256 {
444        return Err("codebook_size must be <= 256");
445    }
446    if quant_bits == 4 && codebook_size > 16 {
447        return Err("codebook_size must be <= 16 for INT4");
448    }
449
450    let mut codebook = vec![0.0f32; codebook_size * head_dim];
451    for c in 0..codebook_size {
452        if c < tokens {
453            let start = c * head_dim;
454            let end = start + head_dim;
455            codebook[c * head_dim..(c + 1) * head_dim].copy_from_slice(&data[start..end]);
456        }
457    }
458
459    refine_codebook(&data, &mut codebook, codebook_size, head_dim, tokens);
460
461    let (codes, distances) = assign_codes(&data, &codebook, codebook_size, head_dim, tokens);
462    let outliers = detect_vq_outliers(&data, &distances, head_dim);
463    let packed = pack_codes(&codes, quant_bits);
464
465    Ok(VectorQuantizedTensor {
466        codebook,
467        codes: packed,
468        vector_dim: head_dim,
469        shape: [combined_heads, seq_len, head_dim],
470        outliers,
471    })
472}
473
474fn refine_codebook(
475    data: &[f32],
476    codebook: &mut [f32],
477    codebook_size: usize,
478    vector_dim: usize,
479    tokens: usize,
480) {
481    const KMEANS_ITERS: usize = 2;
482    for _ in 0..KMEANS_ITERS {
483        let mut counts = vec![0usize; codebook_size];
484        let mut sums = vec![0.0f32; codebook_size * vector_dim];
485
486        for token in 0..tokens {
487            let (idx, _) = nearest_centroid(
488                data,
489                codebook,
490                codebook_size,
491                vector_dim,
492                token,
493            );
494            counts[idx] += 1;
495            let base = token * vector_dim;
496            let sum_base = idx * vector_dim;
497            for d in 0..vector_dim {
498                sums[sum_base + d] += data[base + d];
499            }
500        }
501
502        for c in 0..codebook_size {
503            if counts[c] > 0 {
504                let base = c * vector_dim;
505                for d in 0..vector_dim {
506                    codebook[base + d] = sums[base + d] / counts[c] as f32;
507                }
508            }
509        }
510    }
511}
512
513fn assign_codes(
514    data: &[f32],
515    codebook: &[f32],
516    codebook_size: usize,
517    vector_dim: usize,
518    tokens: usize,
519) -> (Vec<u8>, Vec<f32>) {
520    let mut codes = Vec::with_capacity(tokens);
521    let mut distances = Vec::with_capacity(tokens);
522    for token in 0..tokens {
523        let (idx, dist) = nearest_centroid(data, codebook, codebook_size, vector_dim, token);
524        codes.push(idx as u8);
525        distances.push(dist);
526    }
527    (codes, distances)
528}
529
530fn detect_vq_outliers(
531    data: &[f32],
532    distances: &[f32],
533    vector_dim: usize,
534) -> Vec<OutlierVector> {
535    if distances.is_empty() {
536        return Vec::new();
537    }
538    let mean = distances.iter().sum::<f32>() / distances.len() as f32;
539    let mut var = 0.0f32;
540    for &dist in distances {
541        let diff = dist - mean;
542        var += diff * diff;
543    }
544    let std = (var / distances.len() as f32).sqrt();
545    let threshold = mean + OUTLIER_STD_FACTOR * std;
546
547    let mut outliers = Vec::new();
548    for (token, &dist) in distances.iter().enumerate() {
549        if dist > threshold {
550            let base = token * vector_dim;
551            let values = data[base..base + vector_dim].to_vec();
552            outliers.push(OutlierVector { index: token, values });
553        }
554    }
555    outliers
556}
557
/// Find the centroid closest to one token vector.
///
/// `data` and `codebook` hold vectors of `vector_dim` floats back to back.
/// Returns the index of the centroid with the smallest squared Euclidean
/// distance to token `token`, together with that distance. Ties go to the
/// lowest index; with an empty codebook the result is `(0, f32::INFINITY)`.
fn nearest_centroid(
    data: &[f32],
    codebook: &[f32],
    codebook_size: usize,
    vector_dim: usize,
    token: usize,
) -> (usize, f32) {
    let token_start = token * vector_dim;
    let mut best = (0usize, f32::INFINITY);
    for c in 0..codebook_size {
        let centroid_start = c * vector_dim;
        let dist = (0..vector_dim).fold(0.0f32, |acc, d| {
            let diff = data[token_start + d] - codebook[centroid_start + d];
            acc + diff * diff
        });
        // Strict comparison keeps the first of several equally close centroids.
        if dist < best.1 {
            best = (c, dist);
        }
    }
    best
}
582
583fn pack_codes(codes: &[u8], bits: u8) -> QuantizedCodes {
584    match bits {
585        4 => QuantizedCodes::Int4 {
586            data: pack_nibbles(codes),
587            len: codes.len(),
588        },
589        _ => QuantizedCodes::Int8 { data: codes.to_vec() },
590    }
591}
592
593fn unpack_codes(codes: &QuantizedCodes) -> Vec<u8> {
594    match codes {
595        QuantizedCodes::Int4 { data, len } => unpack_nibbles(data, *len),
596        QuantizedCodes::Int8 { data } => data.clone(),
597    }
598}
599
600fn quantize_values(
601    data: &[f32],
602    shape: [usize; 3],
603    bits: u8,
604) -> Result<QuantizedTensor, &'static str> {
605    if bits != 4 && bits != 8 {
606        return Err("quant_bits must be 4 or 8");
607    }
608    if data.is_empty() {
609        return Err("cannot quantize empty tensor");
610    }
611
612    let mut sum_abs = 0.0f32;
613    let mut abs_values = Vec::with_capacity(data.len());
614    for value in data {
615        let abs = value.abs();
616        sum_abs += abs;
617        abs_values.push(abs);
618    }
619    let mean = sum_abs / data.len() as f32;
620    let mut var = 0.0f32;
621    for &abs in &abs_values {
622        let diff = abs - mean;
623        var += diff * diff;
624    }
625    let std = (var / data.len() as f32).sqrt();
626    abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
627    let median = abs_values[abs_values.len() / 2];
628    let mut deviations: Vec<f32> = abs_values.iter().map(|v| (v - median).abs()).collect();
629    deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
630    let mad = deviations[deviations.len() / 2];
631    let threshold = if mad > 0.0 {
632        median + OUTLIER_MAD_FACTOR * mad
633    } else {
634        mean + OUTLIER_STD_FACTOR * std
635    };
636
637    let mut clipped = Vec::with_capacity(data.len());
638    let mut outliers = Vec::new();
639    let mut max_abs = 0.0f32;
640
641    for (idx, &value) in data.iter().enumerate() {
642        let abs = value.abs();
643        if abs > threshold {
644            outliers.push((idx, value));
645            let sign = if value.is_sign_negative() { -1.0 } else { 1.0 };
646            let clipped_value = sign * threshold;
647            max_abs = max_abs.max(clipped_value.abs());
648            clipped.push(clipped_value);
649        } else {
650            max_abs = max_abs.max(abs);
651            clipped.push(value);
652        }
653    }
654
655    let max_level = if bits == 4 { 7.0 } else { 127.0 };
656    let scale = if max_abs > 0.0 { max_abs / max_level } else { 1.0 };
657
658    let quantized = match bits {
659        4 => {
660            let mut values = Vec::with_capacity(clipped.len());
661            for value in &clipped {
662                let q = (value / scale).round().clamp(-max_level, max_level) as i8;
663                values.push(q);
664            }
665            QuantizedData::Int4(pack_int4(&values))
666        }
667        _ => {
668            let mut values = Vec::with_capacity(clipped.len());
669            for value in &clipped {
670                let q = (value / scale).round().clamp(-max_level, max_level) as i8;
671                values.push(q);
672            }
673            QuantizedData::Int8(values)
674        }
675    };
676
677    Ok(QuantizedTensor {
678        data: quantized,
679        shape,
680        scale,
681        bits,
682        outliers,
683    })
684}
685
686fn dequantize_values<B: Backend>(
687    quantized: &QuantizedTensor,
688    device: &B::Device,
689) -> Result<Tensor<B, 3>, &'static str> {
690    let num_values = quantized.shape[0] * quantized.shape[1] * quantized.shape[2];
691    let mut values: Vec<f32> = match &quantized.data {
692        QuantizedData::Int8(data) => data.iter().map(|&q| q as f32 * quantized.scale).collect(),
693        QuantizedData::Int4(data) => {
694            let unpacked = unpack_int4(data, num_values);
695            unpacked
696                .into_iter()
697                .map(|q| q as f32 * quantized.scale)
698                .collect()
699        }
700    };
701
702    for &(idx, value) in &quantized.outliers {
703        if idx < values.len() {
704            values[idx] = value;
705        }
706    }
707
708    Ok(Tensor::<B, 3>::from_data(
709        TensorData::new(values, quantized.shape),
710        device,
711    ))
712}
713
714fn decompress_tensor<B: Backend>(
715    compressed: CompressedTensor<B>,
716    device: &B::Device,
717) -> Result<Tensor<B, 3>, &'static str> {
718    match compressed {
719        CompressedTensor::LowRank(low_rank) => decompress_low_rank(low_rank, device),
720        CompressedTensor::VectorQuantized(vq) => decompress_vector_quantized(vq, device),
721        CompressedTensor::Hybrid(hybrid) => decompress_hybrid(hybrid, device),
722    }
723}
724
725fn decompress_low_rank<B: Backend>(
726    low_rank: LowRankTensor<B>,
727    device: &B::Device,
728) -> Result<Tensor<B, 3>, &'static str> {
729    let [combined_heads, seq_len, rank] = low_rank.projected.dims();
730    if rank == 0 {
731        return Err("low-rank projection has rank 0");
732    }
733    let data = low_rank
734        .projected
735        .into_data()
736        .into_vec::<f32>()
737        .map_err(|_| "low-rank decompression expects f32 data")?;
738    let tokens = combined_heads * seq_len;
739    let head_dim = low_rank.original_head_dim;
740    let mut full = vec![0.0f32; tokens * head_dim];
741
742    for token in 0..tokens {
743        let in_base = token * rank;
744        let out_base = token * head_dim;
745        for (r, &dim) in low_rank.basis_indices.iter().enumerate() {
746            if dim < head_dim {
747                full[out_base + dim] = data[in_base + r];
748            }
749        }
750    }
751
752    Ok(Tensor::<B, 3>::from_data(
753        TensorData::new(full, [combined_heads, seq_len, head_dim]),
754        device,
755    ))
756}
757
758fn decompress_hybrid<B: Backend>(
759    hybrid: HybridTensor,
760    device: &B::Device,
761) -> Result<Tensor<B, 3>, &'static str> {
762    let projected = dequantize_values::<B>(&hybrid.quantized, device)?;
763    let low_rank = LowRankTensor {
764        projected,
765        basis_indices: hybrid.basis_indices,
766        original_head_dim: hybrid.original_head_dim,
767    };
768    decompress_low_rank(low_rank, device)
769}
770
771fn decompress_vector_quantized<B: Backend>(
772    vq: VectorQuantizedTensor,
773    device: &B::Device,
774) -> Result<Tensor<B, 3>, &'static str> {
775    let tokens = vq.shape[0] * vq.shape[1];
776    let vector_dim = vq.vector_dim;
777    let codes = unpack_codes(&vq.codes);
778    if codes.len() != tokens {
779        return Err("vector quantization code length mismatch");
780    }
781    let mut data = vec![0.0f32; tokens * vector_dim];
782
783    for token in 0..tokens {
784        let code = codes[token] as usize;
785        let base = token * vector_dim;
786        let code_base = code * vector_dim;
787        for d in 0..vector_dim {
788            data[base + d] = vq.codebook[code_base + d];
789        }
790    }
791
792    for outlier in &vq.outliers {
793        let base = outlier.index * vector_dim;
794        if base + vector_dim <= data.len() {
795            data[base..base + vector_dim].copy_from_slice(&outlier.values);
796        }
797    }
798
799    Ok(Tensor::<B, 3>::from_data(
800        TensorData::new(data, vq.shape),
801        device,
802    ))
803}
804
/// Resolve the bit width used to store vector-quantization codes.
///
/// An explicit width of 4 or 8 is honoured if the codebook fits in it
/// (16 or 256 entries respectively). Any other width selects the smallest of
/// the two that can index the codebook.
///
/// # Errors
/// Fails when the codebook is too large for the requested or any supported
/// width.
fn effective_vq_bits(bits: u8, codebook_size: usize) -> Result<u8, &'static str> {
    match (bits, codebook_size) {
        (4, 0..=16) => Ok(4),
        (4, _) => Err("codebook_size must be <= 16 for INT4"),
        (8, 0..=256) => Ok(8),
        (8, _) => Err("codebook_size must be <= 256 for INT8"),
        // No explicit width: pick the smallest one that fits.
        (_, 0..=16) => Ok(4),
        (_, 0..=256) => Ok(8),
        _ => Err("codebook_size must be <= 256"),
    }
}
832
/// Pack unsigned 4-bit values two per byte, low nibble first.
///
/// Only the low four bits of each input are kept. An odd-length input leaves
/// the final high nibble at zero.
fn pack_nibbles(values: &[u8]) -> Vec<u8> {
    values
        .chunks(2)
        .map(|pair| {
            let low = pair[0] & 0x0F;
            let high = pair.get(1).map_or(0, |second| (second & 0x0F) << 4);
            low | high
        })
        .collect()
}
849
/// Unpack nibble-packed bytes into at most `len` unsigned 4-bit values.
///
/// Each byte yields its low nibble, then its high nibble. Fewer than `len`
/// values are returned when the input runs out first.
fn unpack_nibbles(values: &[u8], len: usize) -> Vec<u8> {
    values
        .iter()
        .flat_map(|&byte| [byte & 0x0F, byte >> 4])
        .take(len)
        .collect()
}
862
/// Pack signed 4-bit levels two per byte, low nibble first.
///
/// Each level is stored as `level + 8`, clamped to `0..=15`. An odd-length
/// input leaves the final high nibble at zero.
fn pack_int4(values: &[i8]) -> Vec<u8> {
    // Shift the signed level into the unsigned nibble range.
    let encode = |level: i8| (i16::from(level) + 8).clamp(0, 15) as u8;
    values
        .chunks(2)
        .map(|pair| {
            let low = encode(pair[0]);
            let high = pair.get(1).map_or(0, |&second| encode(second) << 4);
            low | high
        })
        .collect()
}
879
/// Unpack nibble-packed bytes into at most `len` signed 4-bit levels.
///
/// Each byte yields its low nibble, then its high nibble, each decoded as
/// `nibble - 8`. Fewer than `len` values are returned when the input runs out
/// first.
fn unpack_int4(values: &[u8], len: usize) -> Vec<i8> {
    values
        .iter()
        .flat_map(|&byte| [byte & 0x0F, byte >> 4])
        .map(|nibble| nibble as i8 - 8)
        .take(len)
        .collect()
}
892
893#[cfg(all(test, feature = "cpu"))]
894mod tests {
895    use super::*;
896    use burn::tensor::{Tensor, TensorData};
897    use burn_ndarray::NdArray;
898
899    #[test]
900    fn test_low_rank_roundtrip_preserves_top_dims() {
901        let device = <NdArray<f32> as Backend>::Device::default();
902        let num_heads = 2;
903        let seq_len = 2;
904        let head_dim = 4;
905        let mut data = Vec::new();
906        for _ in 0..(num_heads * seq_len) {
907            data.extend_from_slice(&[1.0, 0.1, 1.0, 0.1]);
908        }
909
910        let k = Tensor::<NdArray<f32>, 3>::from_data(
911            TensorData::new(data.clone(), [num_heads, seq_len, head_dim]),
912            &device,
913        );
914        let v = Tensor::<NdArray<f32>, 3>::from_data(
915            TensorData::new(data, [num_heads, seq_len, head_dim]),
916            &device,
917        );
918
919        let compressor =
920            KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 3 }, 8);
921        let compressed = compressor.compress_kv_3d(k, v).expect("compress");
922        let (k_full, _v_full) = compressor.decompress_kv_3d(compressed).expect("decompress");
923
924        let k_data = k_full.into_data().into_vec::<f32>().expect("data");
925        for token in 0..(num_heads * seq_len) {
926            let base = token * head_dim;
927            assert!((k_data[base + 1]).abs() < 1e-3);
928            assert!((k_data[base + 3]).abs() < 1e-3);
929        }
930    }
931
932    #[test]
933    fn test_hybrid_quantization_outlier_preserved() {
934        let device = <NdArray<f32> as Backend>::Device::default();
935        let num_heads = 1;
936        let seq_len = 4;
937        let head_dim = 1;
938        let data = vec![0.1, 0.2, 0.15, 10.0];
939
940        let k = Tensor::<NdArray<f32>, 3>::from_data(
941            TensorData::new(data.clone(), [num_heads, seq_len, head_dim]),
942            &device,
943        );
944        let v = Tensor::<NdArray<f32>, 3>::from_data(
945            TensorData::new(data, [num_heads, seq_len, head_dim]),
946            &device,
947        );
948
949        let compressor = KVCacheCompressor::<NdArray<f32>>::new(
950            CompressionMethod::Hybrid {
951                rank: 1,
952                quant_bits: 4,
953            },
954            4,
955        );
956        let compressed = compressor.compress_kv_3d(k, v).expect("compress");
957        let (k_full, _) = compressor.decompress_kv_3d(compressed).expect("decompress");
958        let k_data = k_full.into_data().into_vec::<f32>().expect("data");
959
960        assert!((k_data[3] - 10.0).abs() < 1e-3);
961    }
962
963    #[test]
964    fn test_vector_quantization_roundtrip() {
965        let device = <NdArray<f32> as Backend>::Device::default();
966        let num_heads = 1;
967        let seq_len = 2;
968        let head_dim = 2;
969        let data = vec![1.0, 0.0, -1.0, 0.0];
970        let original_data = data.clone();
971
972        let k = Tensor::<NdArray<f32>, 3>::from_data(
973            TensorData::new(data.clone(), [num_heads, seq_len, head_dim]),
974            &device,
975        );
976        let v = Tensor::<NdArray<f32>, 3>::from_data(
977            TensorData::new(data, [num_heads, seq_len, head_dim]),
978            &device,
979        );
980
981        let compressor = KVCacheCompressor::<NdArray<f32>>::new(
982            CompressionMethod::VectorQuantization { codebook_size: 2 },
983            8,
984        );
985        let compressed = compressor.compress_kv_3d(k, v).expect("compress");
986        let (k_full, _) = compressor.decompress_kv_3d(compressed).expect("decompress");
987        let k_data = k_full.into_data().into_vec::<f32>().expect("data");
988
989        for (orig, round) in original_data.iter().zip(k_data.iter()) {
990            assert!((orig - round).abs() < 1e-4);
991        }
992    }
993
994    #[test]
995    fn test_paged_cache_compatibility() {
996        let device = <NdArray<f32> as Backend>::Device::default();
997        let mut cache = PagedKVCache::<NdArray<f32>>::new(4, 1, 1, 2, &device);
998        let seq_id = cache.allocate_sequence();
999        let keys = Tensor::<NdArray<f32>, 3>::from_data(
1000            TensorData::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 3, 2]),
1001            &device,
1002        );
1003        let values = Tensor::<NdArray<f32>, 3>::from_data(
1004            TensorData::new(vec![0.6, 0.5, 0.4, 0.3, 0.2, 0.1], [1, 3, 2]),
1005            &device,
1006        );
1007        cache.append(0, seq_id, keys, values).expect("append");
1008
1009        let compressor =
1010            KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 2 }, 8);
1011        let compressed = compressor
1012            .compress_paged_cache(&cache, 0, seq_id)
1013            .expect("compress paged");
1014
1015        let seq_id2 = cache.allocate_sequence();
1016        compressor
1017            .decompress_to_paged_cache(compressed, &mut cache, 0, seq_id2)
1018            .expect("decompress paged");
1019        assert_eq!(cache.seq_len(0, seq_id2).expect("seq len"), 3);
1020    }
1021
1022    fn identity_matrix(
1023        dim: usize,
1024        device: &<NdArray<f32> as Backend>::Device,
1025    ) -> Tensor<NdArray<f32>, 2> {
1026        let mut data = vec![0.0f32; dim * dim];
1027        for i in 0..dim {
1028            data[i * dim + i] = 1.0;
1029        }
1030        Tensor::from_data(TensorData::new(data, [dim, dim]), device)
1031    }
1032
1033    fn zero_matrix(
1034        rows: usize,
1035        cols: usize,
1036        device: &<NdArray<f32> as Backend>::Device,
1037    ) -> Tensor<NdArray<f32>, 2> {
1038        Tensor::from_data(TensorData::new(vec![0.0f32; rows * cols], [rows, cols]), device)
1039    }
1040
1041    #[test]
1042    fn test_mla_cache_compatibility() {
1043        use crate::ops::mla::MultiHeadLatentAttention;
1044
1045        let device = <NdArray<f32> as Backend>::Device::default();
1046        let head_dim = 2;
1047        let latent_dim = 2;
1048        let down = identity_matrix(head_dim, &device);
1049        let up = identity_matrix(head_dim, &device);
1050        let rope = zero_matrix(latent_dim, head_dim, &device);
1051
1052        let mla = MultiHeadLatentAttention::new(1, latent_dim, down, up, rope);
1053        let mut cache = MlaCompressedKVCache::new(4, 1, 1, mla, &device);
1054        let seq_id = cache.allocate_sequence();
1055
1056        let keys = Tensor::<NdArray<f32>, 3>::from_data(
1057            TensorData::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 3, 2]),
1058            &device,
1059        );
1060        let values = Tensor::<NdArray<f32>, 3>::from_data(
1061            TensorData::new(vec![0.6, 0.5, 0.4, 0.3, 0.2, 0.1], [1, 3, 2]),
1062            &device,
1063        );
1064        cache.append(0, seq_id, keys, values).expect("append");
1065
1066        let compressor =
1067            KVCacheCompressor::<NdArray<f32>>::new(CompressionMethod::LowRank { rank: 2 }, 8);
1068        let compressed = compressor
1069            .compress_mla_cache(&cache, 0, seq_id)
1070            .expect("compress mla");
1071
1072        let seq_id2 = cache.allocate_sequence();
1073        compressor
1074            .decompress_to_mla_cache(compressed, &mut cache, 0, seq_id2)
1075            .expect("decompress mla");
1076        assert_eq!(cache.seq_len(0, seq_id2).expect("seq len"), 3);
1077    }
1078}