// llama_rs/model/kv_quantized.rs

//! Quantized KV cache for reduced VRAM usage
//!
//! Stores KV cache entries in INT8 or FP8 format instead of F32,
//! achieving ~4x memory reduction (1 byte instead of 4 per element, plus a
//! small per-head scale overhead for INT8) with minimal quality impact.

6#[allow(unused_imports)]
7use crate::tensor::{DType, Tensor};
8
/// KV cache storage format.
///
/// `F32` is the default (see the `Default` impl that follows the enum).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KVCacheFormat {
    /// Full precision, 4 bytes per element (default)
    F32,
    /// INT8 with symmetric quantization and one scale per (head, position):
    /// value = scale * int8_value
    Int8,
    /// FP8 E4M3 format (4 exponent, 3 mantissa bits) - good for inference
    Fp8E4M3,
    /// FP8 E5M2 format (5 exponent, 2 mantissa bits) - wider range, less precision
    Fp8E5M2,
}

impl Default for KVCacheFormat {
    /// Full precision, as documented on the `F32` variant.
    fn default() -> Self {
        KVCacheFormat::F32
    }
}
21
22impl KVCacheFormat {
23    /// Bytes per element for this format
24    const fn bytes_per_element(&self) -> usize {
25        match self {
26            KVCacheFormat::F32 => 4,
27            KVCacheFormat::Int8 | KVCacheFormat::Fp8E4M3 | KVCacheFormat::Fp8E5M2 => 1,
28        }
29    }
30
31    /// Whether this format uses per-head scales (INT8 only)
32    const fn uses_scales(&self) -> bool {
33        matches!(self, KVCacheFormat::Int8)
34    }
35}
36
37/// Quantized KV cache that stores K/V in reduced precision
38pub struct QuantizedKVCache {
39    /// Quantized key data per layer - raw bytes in the chosen format
40    pub k_data: Vec<Vec<u8>>,
41    /// Quantized value data per layer
42    pub v_data: Vec<Vec<u8>>,
43    /// Per-head scale factors for INT8 (one per head per position per layer)
44    /// Layout: [layer][head * max_seq_len + pos]
45    pub k_scales: Vec<Vec<f32>>,
46    pub v_scales: Vec<Vec<f32>>,
47    /// Storage format
48    pub format: KVCacheFormat,
49    /// Current sequence length
50    pub seq_len: usize,
51    pub max_seq_len: usize,
52    pub num_kv_heads: usize,
53    pub head_dim: usize,
54    pub num_layers: usize,
55}
56
57impl QuantizedKVCache {
58    /// Create a new quantized KV cache with the given dimensions and format
59    pub fn new(
60        num_layers: usize,
61        num_kv_heads: usize,
62        max_seq_len: usize,
63        head_dim: usize,
64        format: KVCacheFormat,
65    ) -> Self {
66        let elements_per_layer = num_kv_heads * max_seq_len * head_dim;
67        let bytes_per_element = format.bytes_per_element();
68        let layer_bytes = elements_per_layer * bytes_per_element;
69
70        let k_data: Vec<Vec<u8>> = (0..num_layers)
71            .map(|_| vec![0u8; layer_bytes])
72            .collect();
73        let v_data: Vec<Vec<u8>> = (0..num_layers)
74            .map(|_| vec![0u8; layer_bytes])
75            .collect();
76
77        let scales_per_layer = if format.uses_scales() {
78            num_kv_heads * max_seq_len
79        } else {
80            0
81        };
82
83        let k_scales: Vec<Vec<f32>> = (0..num_layers)
84            .map(|_| vec![0.0f32; scales_per_layer])
85            .collect();
86        let v_scales: Vec<Vec<f32>> = (0..num_layers)
87            .map(|_| vec![0.0f32; scales_per_layer])
88            .collect();
89
90        Self {
91            k_data,
92            v_data,
93            k_scales,
94            v_scales,
95            format,
96            seq_len: 0,
97            max_seq_len,
98            num_kv_heads,
99            head_dim,
100            num_layers,
101        }
102    }
103
104    /// Reset the cache for a new sequence
105    pub fn reset(&mut self) {
106        self.seq_len = 0;
107        for k in &mut self.k_data {
108            k.fill(0);
109        }
110        for v in &mut self.v_data {
111            v.fill(0);
112        }
113        for s in &mut self.k_scales {
114            s.fill(0.0);
115        }
116        for s in &mut self.v_scales {
117            s.fill(0.0);
118        }
119    }
120
121    /// Get remaining capacity
122    pub fn remaining_capacity(&self) -> usize {
123        self.max_seq_len.saturating_sub(self.seq_len)
124    }
125
126    /// Check if cache is full
127    pub fn is_full(&self) -> bool {
128        self.seq_len >= self.max_seq_len
129    }
130
131    /// Get memory usage in bytes
132    pub fn memory_usage(&self) -> usize {
133        let data_bytes: usize = self.k_data.iter().map(|v| v.len()).sum::<usize>()
134            + self.v_data.iter().map(|v| v.len()).sum::<usize>();
135        let scale_bytes: usize = self.k_scales.iter().map(|v| v.len() * 4).sum::<usize>()
136            + self.v_scales.iter().map(|v| v.len() * 4).sum::<usize>();
137        data_bytes + scale_bytes
138    }
139
140    /// Write quantized K/V for one position
141    ///
142    /// `k_data` and `v_data` are `[num_kv_heads * head_dim]` each
143    pub fn write_kv(
144        &mut self,
145        layer: usize,
146        pos: usize,
147        k_data: &[f32],
148        v_data: &[f32],
149    ) {
150        assert!(layer < self.num_layers);
151        assert!(pos < self.max_seq_len);
152        assert_eq!(k_data.len(), self.num_kv_heads * self.head_dim);
153        assert_eq!(v_data.len(), self.num_kv_heads * self.head_dim);
154
155        let k_layer = &mut self.k_data[layer];
156        let v_layer = &mut self.v_data[layer];
157
158        for head in 0..self.num_kv_heads {
159            let head_start = head * self.head_dim;
160            let head_end = head_start + self.head_dim;
161            let k_head = &k_data[head_start..head_end];
162            let v_head = &v_data[head_start..head_end];
163
164            let k_offset = (head * self.max_seq_len + pos) * self.head_dim
165                * self.format.bytes_per_element();
166            let v_offset = (head * self.max_seq_len + pos) * self.head_dim
167                * self.format.bytes_per_element();
168
169            match self.format {
170                KVCacheFormat::F32 => {
171                    for (i, &val) in k_head.iter().enumerate() {
172                        let bytes = val.to_le_bytes();
173                        k_layer[k_offset + i * 4..k_offset + (i + 1) * 4]
174                            .copy_from_slice(&bytes);
175                    }
176                    for (i, &val) in v_head.iter().enumerate() {
177                        let bytes = val.to_le_bytes();
178                        v_layer[v_offset + i * 4..v_offset + (i + 1) * 4]
179                            .copy_from_slice(&bytes);
180                    }
181                }
182                KVCacheFormat::Int8 => {
183                    let (k_quant, k_scale) = quantize_int8(k_head);
184                    let (v_quant, v_scale) = quantize_int8(v_head);
185
186                    let scale_idx = head * self.max_seq_len + pos;
187                    self.k_scales[layer][scale_idx] = k_scale;
188                    self.v_scales[layer][scale_idx] = v_scale;
189
190                    for (i, &q) in k_quant.iter().enumerate() {
191                        k_layer[k_offset + i] = q as u8;
192                    }
193                    for (i, &q) in v_quant.iter().enumerate() {
194                        v_layer[v_offset + i] = q as u8;
195                    }
196                }
197                KVCacheFormat::Fp8E4M3 => {
198                    for (i, &val) in k_head.iter().enumerate() {
199                        k_layer[k_offset + i] = quantize_fp8_e4m3(val);
200                    }
201                    for (i, &val) in v_head.iter().enumerate() {
202                        v_layer[v_offset + i] = quantize_fp8_e4m3(val);
203                    }
204                }
205                KVCacheFormat::Fp8E5M2 => {
206                    for (i, &val) in k_head.iter().enumerate() {
207                        k_layer[k_offset + i] = quantize_fp8_e5m2(val);
208                    }
209                    for (i, &val) in v_head.iter().enumerate() {
210                        v_layer[v_offset + i] = quantize_fp8_e5m2(val);
211                    }
212                }
213            }
214        }
215    }
216
217    /// Dequantize and return key for one head at one position
218    pub fn read_k(&self, layer: usize, head: usize, pos: usize) -> Vec<f32> {
219        self.read_k_range(layer, head, pos, pos + 1)
220    }
221
222    /// Dequantize and return value for one head at one position
223    pub fn read_v(&self, layer: usize, head: usize, pos: usize) -> Vec<f32> {
224        self.read_v_range(layer, head, pos, pos + 1)
225    }
226
227    /// Dequantize key range for one head
228    ///
229    /// Returns `[end_pos - start_pos, head_dim]` as flat vec
230    pub fn read_k_range(
231        &self,
232        layer: usize,
233        head: usize,
234        start_pos: usize,
235        end_pos: usize,
236    ) -> Vec<f32> {
237        let k_layer = &self.k_data[layer];
238        let bpe = self.format.bytes_per_element();
239        let mut result = Vec::with_capacity((end_pos - start_pos) * self.head_dim);
240
241        for pos in start_pos..end_pos {
242            let offset = (head * self.max_seq_len + pos) * self.head_dim * bpe;
243
244            for d in 0..self.head_dim {
245                let val = match self.format {
246                    KVCacheFormat::F32 => {
247                        let byte_offset = offset + d * 4;
248                        f32::from_le_bytes(
249                            k_layer[byte_offset..byte_offset + 4]
250                                .try_into()
251                                .unwrap(),
252                        )
253                    }
254                    KVCacheFormat::Int8 => {
255                        let scale_idx = head * self.max_seq_len + pos;
256                        let scale = self.k_scales[layer][scale_idx];
257                        let q = k_layer[offset + d] as i8;
258                        dequantize_int8(&[q], scale)[0]
259                    }
260                    KVCacheFormat::Fp8E4M3 => dequantize_fp8_e4m3(k_layer[offset + d]),
261                    KVCacheFormat::Fp8E5M2 => dequantize_fp8_e5m2(k_layer[offset + d]),
262                };
263                result.push(val);
264            }
265        }
266        result
267    }
268
269    /// Dequantize value range for one head
270    ///
271    /// Returns `[end_pos - start_pos, head_dim]` as flat vec
272    pub fn read_v_range(
273        &self,
274        layer: usize,
275        head: usize,
276        start_pos: usize,
277        end_pos: usize,
278    ) -> Vec<f32> {
279        let v_layer = &self.v_data[layer];
280        let bpe = self.format.bytes_per_element();
281        let mut result = Vec::with_capacity((end_pos - start_pos) * self.head_dim);
282
283        for pos in start_pos..end_pos {
284            let offset = (head * self.max_seq_len + pos) * self.head_dim * bpe;
285
286            for d in 0..self.head_dim {
287                let val = match self.format {
288                    KVCacheFormat::F32 => {
289                        let byte_offset = offset + d * 4;
290                        f32::from_le_bytes(
291                            v_layer[byte_offset..byte_offset + 4]
292                                .try_into()
293                                .unwrap(),
294                        )
295                    }
296                    KVCacheFormat::Int8 => {
297                        let scale_idx = head * self.max_seq_len + pos;
298                        let scale = self.v_scales[layer][scale_idx];
299                        let q = v_layer[offset + d] as i8;
300                        dequantize_int8(&[q], scale)[0]
301                    }
302                    KVCacheFormat::Fp8E4M3 => dequantize_fp8_e4m3(v_layer[offset + d]),
303                    KVCacheFormat::Fp8E5M2 => dequantize_fp8_e5m2(v_layer[offset + d]),
304                };
305                result.push(val);
306            }
307        }
308        result
309    }
310
311    /// Shift cache left by `amount` positions (for sliding window)
312    pub fn shift_left(&mut self, amount: usize) {
313        if amount == 0 || amount >= self.seq_len {
314            self.reset();
315            return;
316        }
317
318        let new_len = self.seq_len - amount;
319        let bpe = self.format.bytes_per_element();
320
321        for layer_idx in 0..self.num_layers {
322            let k_layer = &mut self.k_data[layer_idx];
323            let v_layer = &mut self.v_data[layer_idx];
324
325            for head in 0..self.num_kv_heads {
326                for pos in 0..new_len {
327                    let src_pos = pos + amount;
328                    let src_offset = (head * self.max_seq_len + src_pos) * self.head_dim * bpe;
329                    let dst_offset = (head * self.max_seq_len + pos) * self.head_dim * bpe;
330                    let block_len = self.head_dim * bpe;
331
332                    k_layer.copy_within(src_offset..src_offset + block_len, dst_offset);
333                    v_layer.copy_within(src_offset..src_offset + block_len, dst_offset);
334                }
335            }
336
337            if self.format.uses_scales() {
338                let k_scales = &mut self.k_scales[layer_idx];
339                let v_scales = &mut self.v_scales[layer_idx];
340
341                for head in 0..self.num_kv_heads {
342                    for pos in 0..new_len {
343                        let src_idx = head * self.max_seq_len + (pos + amount);
344                        let dst_idx = head * self.max_seq_len + pos;
345                        k_scales[dst_idx] = k_scales[src_idx];
346                        v_scales[dst_idx] = v_scales[src_idx];
347                    }
348                }
349            }
350        }
351
352        self.seq_len = new_len;
353    }
354
355    /// Truncate cache to a specific length
356    pub fn truncate(&mut self, new_len: usize) {
357        if new_len < self.seq_len {
358            self.seq_len = new_len;
359        }
360    }
361}

// --- Internal helpers ---

/// Symmetric INT8 quantization of one vector: scale = max(|x|) / 127.
///
/// The largest-magnitude element maps to ±127. When that magnitude is not
/// above `1e-10` the scale falls back to `1.0`. Every element becomes
/// `round(x / scale)` clamped to `[-128, 127]`.
///
/// Returns the quantized values and the scale needed to dequantize them.
fn quantize_int8(data: &[f32]) -> (Vec<i8>, f32) {
    let mut peak = 0.0f32;
    for &x in data {
        peak = peak.max(x.abs());
    }

    let scale = match peak > 1e-10 {
        true => peak / 127.0,
        false => 1.0,
    };

    let mut quantized = Vec::with_capacity(data.len());
    for &x in data {
        let rounded = (x / scale).round();
        quantized.push(rounded.clamp(-128.0, 127.0) as i8);
    }
    (quantized, scale)
}
388
/// Dequantize INT8 values: each output is `scale * q`.
fn dequantize_int8(data: &[i8], scale: f32) -> Vec<f32> {
    let mut values = Vec::with_capacity(data.len());
    values.extend(data.iter().map(|&q| f32::from(q) * scale));
    values
}
393
/// Convert f32 to FP8 E4M3 (1 sign + 4 exp + 3 mantissa, bias 7).
///
/// Uses the "E4M3FN" encoding: there is no infinity, `S.1111.111` is NaN and
/// the largest finite magnitude is 448 (`0x7E`).
///
/// * Finite values round to nearest, ties to even. Truncating instead would
///   bias every value toward zero and double the worst-case error.
/// * Magnitudes that round above 448 saturate to ±448 (`0x7E` / `0xFE`) and
///   never land on the NaN code `S.1111.111`.
/// * Magnitudes that round below the smallest subnormal (2^-9) become a
///   signed zero; `0.0` and `-0.0` both encode as `0x00`.
/// * NaN encodes as `0xFF`; +inf as `0x7F` and -inf as `0xFF` (NaN codes,
///   since the format cannot represent infinity).
fn quantize_fp8_e4m3(value: f32) -> u8 {
    // Right-shift `x` by `shift` (>= 1) bits, rounding to nearest, ties to even.
    fn round_shift_rne(x: u32, shift: u32) -> u32 {
        if shift > 24 {
            // `x < 2^24 <= 2^(shift - 1)`: below half an output step, rounds to 0.
            return 0;
        }
        let kept = x >> shift;
        let rem = x & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if rem > half || (rem == half && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    const MANTISSA_BITS: u32 = 3;
    const BIAS: i32 = 7;
    // `S.1111.110` = 448; `S.1111.111` is reserved for NaN.
    const MAX_FINITE: u32 = 0x7E;

    if value.is_nan() {
        return 0xFF;
    }
    if value.is_infinite() {
        return if value > 0.0 { 0x7F } else { 0xFF };
    }
    if value == 0.0 {
        return 0x00;
    }

    let bits = value.to_bits();
    let sign = ((bits >> 31) as u8) << 7;
    let exp_field = ((bits >> 23) & 0xFF) as i32;
    let fraction = bits & 0x7F_FFFF;
    // 24-bit significand with the implicit leading one made explicit; f32
    // subnormals have no implicit one and a fixed exponent of -126.
    let (exponent, significand) = if exp_field == 0 {
        (-126, fraction)
    } else {
        (exp_field - 127, fraction | 0x80_0000)
    };

    let biased = exponent + BIAS;
    let magnitude = if biased >= 1 {
        // Normal range. Adding (not OR-ing) the rounded significand lets a
        // mantissa carry correctly bump the exponent field.
        let exponent_bits = ((biased - 1) as u32) << MANTISSA_BITS;
        exponent_bits + round_shift_rne(significand, 23 - MANTISSA_BITS)
    } else {
        // Subnormal range: shift further by the distance below the minimum
        // normal exponent. Rounding up to `1 << MANTISSA_BITS` yields exactly
        // the smallest normal encoding.
        round_shift_rne(significand, 23 - MANTISSA_BITS + (1 - biased) as u32)
    };

    if magnitude > MAX_FINITE {
        sign | MAX_FINITE as u8
    } else {
        sign | magnitude as u8
    }
}
432
/// Convert FP8 E4M3 back to f32.
///
/// Decoding rules (bias 7, no infinities):
/// * `S.0000.000` -> `+0.0` (the sign of zero is not preserved)
/// * `S.1111.111` -> NaN
/// * exponent field 0 -> subnormal `mantissa / 8 * 2^-6`
/// * otherwise -> `(1 + mantissa / 8) * 2^(exponent - 7)`
fn dequantize_fp8_e4m3(value: u8) -> f32 {
    let magnitude_bits = value & 0x7F;
    match magnitude_bits {
        0x00 => return 0.0,
        0x7F => return f32::NAN,
        _ => {}
    }

    let exp_field = i32::from(magnitude_bits >> 3);
    let mantissa = u32::from(magnitude_bits & 0x7);
    // Both cases are a small integer times a power of two, so the result is exact.
    let (significand, power) = if exp_field == 0 {
        // m/8 * 2^-6 == m * 2^-9
        (mantissa, -9)
    } else {
        // (1 + m/8) * 2^(e-7) == (8 + m) * 2^(e-10)
        (8 + mantissa, exp_field - 10)
    };
    let power_of_two = f32::from_bits(((power + 127) as u32) << 23);
    let magnitude = significand as f32 * power_of_two;

    if value & 0x80 != 0 {
        -magnitude
    } else {
        magnitude
    }
}
464
/// Convert f32 to FP8 E5M2 (1 sign + 5 exp + 2 mantissa, bias 15).
///
/// E5M2 follows IEEE-754 conventions: `S.11111.00` is ±infinity,
/// `S.11111.{01,10,11}` are NaN and the largest finite magnitude is 57344
/// (`0x7B`).
///
/// * Finite values round to nearest, ties to even. Truncating instead would
///   bias every value toward zero and double the worst-case error.
/// * Magnitudes that round above 57344 become ±infinity (`0x7C` / `0xFC`) and
///   never land on a NaN code.
/// * Magnitudes that round below the smallest subnormal (2^-16) become a
///   signed zero; `0.0` and `-0.0` both encode as `0x00`.
/// * NaN encodes as `0xFF`; ±infinity as `0x7C` / `0xFC`.
fn quantize_fp8_e5m2(value: f32) -> u8 {
    // Right-shift `x` by `shift` (>= 1) bits, rounding to nearest, ties to even.
    fn round_shift_rne(x: u32, shift: u32) -> u32 {
        if shift > 24 {
            // `x < 2^24 <= 2^(shift - 1)`: below half an output step, rounds to 0.
            return 0;
        }
        let kept = x >> shift;
        let rem = x & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if rem > half || (rem == half && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    const MANTISSA_BITS: u32 = 2;
    const BIAS: i32 = 15;
    // `S.11110.11` = 57344; exponent field 31 is reserved for inf/NaN.
    const MAX_FINITE: u32 = 0x7B;
    const INF_CODE: u8 = 0x7C;

    if value.is_nan() {
        return 0xFF;
    }
    if value.is_infinite() {
        return if value > 0.0 { 0x7C } else { 0xFC };
    }
    if value == 0.0 {
        return 0x00;
    }

    let bits = value.to_bits();
    let sign = ((bits >> 31) as u8) << 7;
    let exp_field = ((bits >> 23) & 0xFF) as i32;
    let fraction = bits & 0x7F_FFFF;
    // 24-bit significand with the implicit leading one made explicit; f32
    // subnormals have no implicit one and a fixed exponent of -126.
    let (exponent, significand) = if exp_field == 0 {
        (-126, fraction)
    } else {
        (exp_field - 127, fraction | 0x80_0000)
    };

    let biased = exponent + BIAS;
    let magnitude = if biased >= 1 {
        // Normal range. Adding (not OR-ing) the rounded significand lets a
        // mantissa carry correctly bump the exponent field.
        let exponent_bits = ((biased - 1) as u32) << MANTISSA_BITS;
        exponent_bits + round_shift_rne(significand, 23 - MANTISSA_BITS)
    } else {
        // Subnormal range: shift further by the distance below the minimum
        // normal exponent. Rounding up to `1 << MANTISSA_BITS` yields exactly
        // the smallest normal encoding.
        round_shift_rne(significand, 23 - MANTISSA_BITS + (1 - biased) as u32)
    };

    if magnitude > MAX_FINITE {
        sign | INF_CODE
    } else {
        sign | magnitude as u8
    }
}
503
/// Convert FP8 E5M2 back to f32.
///
/// Decoding rules (bias 15, IEEE-style specials):
/// * `S.00000.00` -> `+0.0` (the sign of zero is not preserved)
/// * `S.11111.00` -> ±infinity
/// * `S.11111.{01,10,11}` -> NaN
/// * exponent field 0 -> subnormal `mantissa / 4 * 2^-14`
/// * otherwise -> `(1 + mantissa / 4) * 2^(exponent - 15)`
fn dequantize_fp8_e5m2(value: u8) -> f32 {
    let negative = value & 0x80 != 0;
    let magnitude_bits = value & 0x7F;

    if magnitude_bits == 0 {
        return 0.0;
    }
    if magnitude_bits == 0x7C {
        return if negative {
            f32::NEG_INFINITY
        } else {
            f32::INFINITY
        };
    }
    if magnitude_bits > 0x7C {
        return f32::NAN;
    }

    let exp_field = i32::from(magnitude_bits >> 2);
    let mantissa = u32::from(magnitude_bits & 0x3);
    // Both cases are a small integer times a power of two, so the result is exact.
    let (significand, power) = if exp_field == 0 {
        // m/4 * 2^-14 == m * 2^-16
        (mantissa, -16)
    } else {
        // (1 + m/4) * 2^(e-15) == (4 + m) * 2^(e-17)
        (4 + mantissa, exp_field - 17)
    };
    let power_of_two = f32::from_bits(((power + 127) as u32) << 23);
    let magnitude = significand as f32 * power_of_two;

    if negative {
        -magnitude
    } else {
        magnitude
    }
}
542
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_roundtrip() {
        let data: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 - 6.4).collect();
        let (quantized, scale) = quantize_int8(&data);
        let dequantized = dequantize_int8(&quantized, scale);
        for (orig, dec) in data.iter().zip(dequantized.iter()) {
            let rel_err = if orig.abs() > 1e-6 {
                (orig - dec).abs() / orig.abs()
            } else {
                (orig - dec).abs()
            };
            assert!(rel_err < 0.02, "orig={orig}, dec={dec}, rel_err={rel_err}");
        }
    }

    #[test]
    fn test_fp8_e4m3_roundtrip() {
        let values = [
            0.0f32,
            1.0,
            -1.0,
            0.5,
            0.0136719,
            448.0,
            2f32.powi(-6),
            2f32.powi(-9),
        ];
        for &val in &values {
            let q = quantize_fp8_e4m3(val);
            let d = dequantize_fp8_e4m3(q);
            if val == 0.0 {
                assert_eq!(d, 0.0, "zero roundtrip");
            } else if val.abs() < 1e-5 {
                assert!(d.abs() < 0.01, "small value {val} -> {d}");
            } else {
                let rel_err = (val - d).abs() / val.abs();
                assert!(rel_err < 0.05, "val={val}, d={d}, rel_err={rel_err}");
            }
        }
    }

    #[test]
    fn test_fp8_e5m2_roundtrip() {
        let values = [
            0.0f32,
            1.0,
            -1.0,
            0.5,
            57344.0,
            2f32.powi(-14),
            1.52588e-5,
        ];
        for &val in &values {
            let q = quantize_fp8_e5m2(val);
            let d = dequantize_fp8_e5m2(q);
            if val == 0.0 {
                assert_eq!(d, 0.0, "zero roundtrip");
            } else if val.abs() < 1e-5 {
                assert!(d.abs() < 0.01, "small value {val} -> {d}");
            } else {
                let rel_err = (val - d).abs() / val.abs();
                assert!(rel_err < 0.1, "val={val}, d={d}, rel_err={rel_err}");
            }
        }
    }

    #[test]
    fn test_quantized_kv_cache_basic() {
        let num_layers = 2;
        let num_kv_heads = 4;
        let max_seq_len = 16;
        let head_dim = 64;

        // Every format, including the full-precision passthrough.
        for format in [
            KVCacheFormat::F32,
            KVCacheFormat::Int8,
            KVCacheFormat::Fp8E4M3,
            KVCacheFormat::Fp8E5M2,
        ] {
            let mut cache =
                QuantizedKVCache::new(num_layers, num_kv_heads, max_seq_len, head_dim, format);

            let k_data: Vec<f32> = (0..num_kv_heads * head_dim)
                .map(|i| (i as f32) * 0.01 - 1.0)
                .collect();
            let v_data: Vec<f32> = (0..num_kv_heads * head_dim)
                .map(|i| (i as f32) * 0.02 - 0.5)
                .collect();

            cache.write_kv(0, 0, &k_data, &v_data);
            cache.seq_len = 1;

            let read_k = cache.read_k(0, 0, 0);
            let read_v = cache.read_v(0, 0, 0);

            assert_eq!(read_k.len(), head_dim);
            assert_eq!(read_v.len(), head_dim);

            let orig_k_head = &k_data[0..head_dim];
            let orig_v_head = &v_data[0..head_dim];

            // Exhaustive on purpose: a new format must pick its own tolerance.
            let tol = match format {
                KVCacheFormat::F32 => 1e-6,
                KVCacheFormat::Int8 => 0.15,
                KVCacheFormat::Fp8E4M3 | KVCacheFormat::Fp8E5M2 => 0.25,
            };
            for (a, b) in orig_k_head.iter().zip(read_k.iter()) {
                let rel_err = if a.abs() > 1e-6 {
                    (a - b).abs() / a.abs()
                } else {
                    (a - b).abs()
                };
                assert!(rel_err < tol, "{format:?} k: orig={a}, read={b}");
            }
            for (a, b) in orig_v_head.iter().zip(read_v.iter()) {
                let rel_err = if a.abs() > 1e-6 {
                    (a - b).abs() / a.abs()
                } else {
                    (a - b).abs()
                };
                assert!(rel_err < tol, "{format:?} v: orig={a}, read={b}");
            }
        }
    }

    #[test]
    fn test_memory_savings() {
        let num_layers = 4;
        let num_kv_heads = 32;
        let max_seq_len = 2048;
        let head_dim = 128;

        // K and V for every layer at 4 bytes per element.
        let f32_size = num_layers * 2 * (num_kv_heads * max_seq_len * head_dim * 4);
        // One f32 scale per (head, position) for K and for V in every layer.
        let int8_scale_bytes = num_layers * 2 * num_kv_heads * max_seq_len * 4;

        let int8_cache =
            QuantizedKVCache::new(num_layers, num_kv_heads, max_seq_len, head_dim, KVCacheFormat::Int8);
        let fp8_cache =
            QuantizedKVCache::new(num_layers, num_kv_heads, max_seq_len, head_dim, KVCacheFormat::Fp8E4M3);

        let int8_size = int8_cache.memory_usage();
        let fp8_size = fp8_cache.memory_usage();

        // One byte per element is exactly a 4x reduction of the data buffers.
        assert_eq!(fp8_size, f32_size / 4);
        assert_eq!(int8_size, f32_size / 4 + int8_scale_bytes);
        assert!(int8_size < f32_size / 3);
    }

    #[test]
    fn test_capacity_and_truncate() {
        let mut cache = QuantizedKVCache::new(1, 1, 4, 2, KVCacheFormat::Fp8E4M3);
        assert_eq!(cache.remaining_capacity(), 4);
        assert!(!cache.is_full());

        cache.seq_len = 4;
        assert!(cache.is_full());
        assert_eq!(cache.remaining_capacity(), 0);

        cache.truncate(6);
        assert_eq!(cache.seq_len, 4);
        cache.truncate(1);
        assert_eq!(cache.seq_len, 1);
    }

    #[test]
    fn test_shift_left() {
        let num_layers = 1;
        let num_kv_heads = 2;
        let max_seq_len = 8;
        let head_dim = 4;

        let mut cache = QuantizedKVCache::new(
            num_layers,
            num_kv_heads,
            max_seq_len,
            head_dim,
            KVCacheFormat::Int8,
        );

        for pos in 0..5 {
            let k_data: Vec<f32> = (0..num_kv_heads * head_dim)
                .map(|_| pos as f32)
                .collect();
            let v_data = k_data.clone();
            cache.write_kv(0, pos, &k_data, &v_data);
        }
        cache.seq_len = 5;

        cache.shift_left(2);

        assert_eq!(cache.seq_len, 3);

        // Every head of both K and V must have moved.
        for head in 0..num_kv_heads {
            for (i, pos) in (2..5).enumerate() {
                let read_k = cache.read_k(0, head, i);
                let read_v = cache.read_v(0, head, i);
                let expected: Vec<f32> = (0..head_dim).map(|_| pos as f32).collect();
                for (a, b) in read_k.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 0.01, "k head {head} pos {i}: expected {b}, got {a}");
                }
                for (a, b) in read_v.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 0.01, "v head {head} pos {i}: expected {b}, got {a}");
                }
            }
        }
    }
}