1#[allow(unused_imports)]
7use crate::tensor::{DType, Tensor};
8
/// Element encoding used for the key/value buffers of a quantized KV cache.
///
/// Fieldless and `Copy`, so it also derives `Eq` and `Hash` and can be used as a
/// map/set key (e.g. to group caches or statistics by format).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KVCacheFormat {
    /// Uncompressed little-endian `f32`, 4 bytes per element.
    F32,
    /// Symmetric signed 8-bit integers; one `f32` scale is stored per
    /// (head, position) row.
    Int8,
    /// 8-bit float: 1 sign, 4 exponent, 3 mantissa bits (largest finite value 448).
    Fp8E4M3,
    /// 8-bit float: 1 sign, 5 exponent, 2 mantissa bits (largest finite value 57344).
    Fp8E5M2,
}
21
22impl KVCacheFormat {
23 const fn bytes_per_element(&self) -> usize {
25 match self {
26 KVCacheFormat::F32 => 4,
27 KVCacheFormat::Int8 | KVCacheFormat::Fp8E4M3 | KVCacheFormat::Fp8E5M2 => 1,
28 }
29 }
30
31 const fn uses_scales(&self) -> bool {
33 matches!(self, KVCacheFormat::Int8)
34 }
35}
36
37pub struct QuantizedKVCache {
39 pub k_data: Vec<Vec<u8>>,
41 pub v_data: Vec<Vec<u8>>,
43 pub k_scales: Vec<Vec<f32>>,
46 pub v_scales: Vec<Vec<f32>>,
47 pub format: KVCacheFormat,
49 pub seq_len: usize,
51 pub max_seq_len: usize,
52 pub num_kv_heads: usize,
53 pub head_dim: usize,
54 pub num_layers: usize,
55}
56
57impl QuantizedKVCache {
58 pub fn new(
60 num_layers: usize,
61 num_kv_heads: usize,
62 max_seq_len: usize,
63 head_dim: usize,
64 format: KVCacheFormat,
65 ) -> Self {
66 let elements_per_layer = num_kv_heads * max_seq_len * head_dim;
67 let bytes_per_element = format.bytes_per_element();
68 let layer_bytes = elements_per_layer * bytes_per_element;
69
70 let k_data: Vec<Vec<u8>> = (0..num_layers)
71 .map(|_| vec![0u8; layer_bytes])
72 .collect();
73 let v_data: Vec<Vec<u8>> = (0..num_layers)
74 .map(|_| vec![0u8; layer_bytes])
75 .collect();
76
77 let scales_per_layer = if format.uses_scales() {
78 num_kv_heads * max_seq_len
79 } else {
80 0
81 };
82
83 let k_scales: Vec<Vec<f32>> = (0..num_layers)
84 .map(|_| vec![0.0f32; scales_per_layer])
85 .collect();
86 let v_scales: Vec<Vec<f32>> = (0..num_layers)
87 .map(|_| vec![0.0f32; scales_per_layer])
88 .collect();
89
90 Self {
91 k_data,
92 v_data,
93 k_scales,
94 v_scales,
95 format,
96 seq_len: 0,
97 max_seq_len,
98 num_kv_heads,
99 head_dim,
100 num_layers,
101 }
102 }
103
104 pub fn reset(&mut self) {
106 self.seq_len = 0;
107 for k in &mut self.k_data {
108 k.fill(0);
109 }
110 for v in &mut self.v_data {
111 v.fill(0);
112 }
113 for s in &mut self.k_scales {
114 s.fill(0.0);
115 }
116 for s in &mut self.v_scales {
117 s.fill(0.0);
118 }
119 }
120
121 pub fn remaining_capacity(&self) -> usize {
123 self.max_seq_len.saturating_sub(self.seq_len)
124 }
125
126 pub fn is_full(&self) -> bool {
128 self.seq_len >= self.max_seq_len
129 }
130
131 pub fn memory_usage(&self) -> usize {
133 let data_bytes: usize = self.k_data.iter().map(|v| v.len()).sum::<usize>()
134 + self.v_data.iter().map(|v| v.len()).sum::<usize>();
135 let scale_bytes: usize = self.k_scales.iter().map(|v| v.len() * 4).sum::<usize>()
136 + self.v_scales.iter().map(|v| v.len() * 4).sum::<usize>();
137 data_bytes + scale_bytes
138 }
139
140 pub fn write_kv(
144 &mut self,
145 layer: usize,
146 pos: usize,
147 k_data: &[f32],
148 v_data: &[f32],
149 ) {
150 assert!(layer < self.num_layers);
151 assert!(pos < self.max_seq_len);
152 assert_eq!(k_data.len(), self.num_kv_heads * self.head_dim);
153 assert_eq!(v_data.len(), self.num_kv_heads * self.head_dim);
154
155 let k_layer = &mut self.k_data[layer];
156 let v_layer = &mut self.v_data[layer];
157
158 for head in 0..self.num_kv_heads {
159 let head_start = head * self.head_dim;
160 let head_end = head_start + self.head_dim;
161 let k_head = &k_data[head_start..head_end];
162 let v_head = &v_data[head_start..head_end];
163
164 let k_offset = (head * self.max_seq_len + pos) * self.head_dim
165 * self.format.bytes_per_element();
166 let v_offset = (head * self.max_seq_len + pos) * self.head_dim
167 * self.format.bytes_per_element();
168
169 match self.format {
170 KVCacheFormat::F32 => {
171 for (i, &val) in k_head.iter().enumerate() {
172 let bytes = val.to_le_bytes();
173 k_layer[k_offset + i * 4..k_offset + (i + 1) * 4]
174 .copy_from_slice(&bytes);
175 }
176 for (i, &val) in v_head.iter().enumerate() {
177 let bytes = val.to_le_bytes();
178 v_layer[v_offset + i * 4..v_offset + (i + 1) * 4]
179 .copy_from_slice(&bytes);
180 }
181 }
182 KVCacheFormat::Int8 => {
183 let (k_quant, k_scale) = quantize_int8(k_head);
184 let (v_quant, v_scale) = quantize_int8(v_head);
185
186 let scale_idx = head * self.max_seq_len + pos;
187 self.k_scales[layer][scale_idx] = k_scale;
188 self.v_scales[layer][scale_idx] = v_scale;
189
190 for (i, &q) in k_quant.iter().enumerate() {
191 k_layer[k_offset + i] = q as u8;
192 }
193 for (i, &q) in v_quant.iter().enumerate() {
194 v_layer[v_offset + i] = q as u8;
195 }
196 }
197 KVCacheFormat::Fp8E4M3 => {
198 for (i, &val) in k_head.iter().enumerate() {
199 k_layer[k_offset + i] = quantize_fp8_e4m3(val);
200 }
201 for (i, &val) in v_head.iter().enumerate() {
202 v_layer[v_offset + i] = quantize_fp8_e4m3(val);
203 }
204 }
205 KVCacheFormat::Fp8E5M2 => {
206 for (i, &val) in k_head.iter().enumerate() {
207 k_layer[k_offset + i] = quantize_fp8_e5m2(val);
208 }
209 for (i, &val) in v_head.iter().enumerate() {
210 v_layer[v_offset + i] = quantize_fp8_e5m2(val);
211 }
212 }
213 }
214 }
215 }
216
217 pub fn read_k(&self, layer: usize, head: usize, pos: usize) -> Vec<f32> {
219 self.read_k_range(layer, head, pos, pos + 1)
220 }
221
222 pub fn read_v(&self, layer: usize, head: usize, pos: usize) -> Vec<f32> {
224 self.read_v_range(layer, head, pos, pos + 1)
225 }
226
227 pub fn read_k_range(
231 &self,
232 layer: usize,
233 head: usize,
234 start_pos: usize,
235 end_pos: usize,
236 ) -> Vec<f32> {
237 let k_layer = &self.k_data[layer];
238 let bpe = self.format.bytes_per_element();
239 let mut result = Vec::with_capacity((end_pos - start_pos) * self.head_dim);
240
241 for pos in start_pos..end_pos {
242 let offset = (head * self.max_seq_len + pos) * self.head_dim * bpe;
243
244 for d in 0..self.head_dim {
245 let val = match self.format {
246 KVCacheFormat::F32 => {
247 let byte_offset = offset + d * 4;
248 f32::from_le_bytes(
249 k_layer[byte_offset..byte_offset + 4]
250 .try_into()
251 .unwrap(),
252 )
253 }
254 KVCacheFormat::Int8 => {
255 let scale_idx = head * self.max_seq_len + pos;
256 let scale = self.k_scales[layer][scale_idx];
257 let q = k_layer[offset + d] as i8;
258 dequantize_int8(&[q], scale)[0]
259 }
260 KVCacheFormat::Fp8E4M3 => dequantize_fp8_e4m3(k_layer[offset + d]),
261 KVCacheFormat::Fp8E5M2 => dequantize_fp8_e5m2(k_layer[offset + d]),
262 };
263 result.push(val);
264 }
265 }
266 result
267 }
268
269 pub fn read_v_range(
273 &self,
274 layer: usize,
275 head: usize,
276 start_pos: usize,
277 end_pos: usize,
278 ) -> Vec<f32> {
279 let v_layer = &self.v_data[layer];
280 let bpe = self.format.bytes_per_element();
281 let mut result = Vec::with_capacity((end_pos - start_pos) * self.head_dim);
282
283 for pos in start_pos..end_pos {
284 let offset = (head * self.max_seq_len + pos) * self.head_dim * bpe;
285
286 for d in 0..self.head_dim {
287 let val = match self.format {
288 KVCacheFormat::F32 => {
289 let byte_offset = offset + d * 4;
290 f32::from_le_bytes(
291 v_layer[byte_offset..byte_offset + 4]
292 .try_into()
293 .unwrap(),
294 )
295 }
296 KVCacheFormat::Int8 => {
297 let scale_idx = head * self.max_seq_len + pos;
298 let scale = self.v_scales[layer][scale_idx];
299 let q = v_layer[offset + d] as i8;
300 dequantize_int8(&[q], scale)[0]
301 }
302 KVCacheFormat::Fp8E4M3 => dequantize_fp8_e4m3(v_layer[offset + d]),
303 KVCacheFormat::Fp8E5M2 => dequantize_fp8_e5m2(v_layer[offset + d]),
304 };
305 result.push(val);
306 }
307 }
308 result
309 }
310
311 pub fn shift_left(&mut self, amount: usize) {
313 if amount == 0 || amount >= self.seq_len {
314 self.reset();
315 return;
316 }
317
318 let new_len = self.seq_len - amount;
319 let bpe = self.format.bytes_per_element();
320
321 for layer_idx in 0..self.num_layers {
322 let k_layer = &mut self.k_data[layer_idx];
323 let v_layer = &mut self.v_data[layer_idx];
324
325 for head in 0..self.num_kv_heads {
326 for pos in 0..new_len {
327 let src_pos = pos + amount;
328 let src_offset = (head * self.max_seq_len + src_pos) * self.head_dim * bpe;
329 let dst_offset = (head * self.max_seq_len + pos) * self.head_dim * bpe;
330 let block_len = self.head_dim * bpe;
331
332 k_layer.copy_within(src_offset..src_offset + block_len, dst_offset);
333 v_layer.copy_within(src_offset..src_offset + block_len, dst_offset);
334 }
335 }
336
337 if self.format.uses_scales() {
338 let k_scales = &mut self.k_scales[layer_idx];
339 let v_scales = &mut self.v_scales[layer_idx];
340
341 for head in 0..self.num_kv_heads {
342 for pos in 0..new_len {
343 let src_idx = head * self.max_seq_len + (pos + amount);
344 let dst_idx = head * self.max_seq_len + pos;
345 k_scales[dst_idx] = k_scales[src_idx];
346 v_scales[dst_idx] = v_scales[src_idx];
347 }
348 }
349 }
350 }
351
352 self.seq_len = new_len;
353 }
354
355 pub fn truncate(&mut self, new_len: usize) {
357 if new_len < self.seq_len {
358 self.seq_len = new_len;
359 }
360 }
361}
362
/// Symmetric int8 quantization of `data` with one shared scale.
///
/// Returns `(codes, scale)` with `x ≈ code as f32 * scale`. The scale maps the
/// largest magnitude in `data` to 127. When that magnitude is not above `1e-10`
/// (including empty input), the scale is `1.0`. NaN inputs do not influence the
/// scale and encode as 0.
fn quantize_int8(data: &[f32]) -> (Vec<i8>, f32) {
    let peak = data.iter().fold(0.0f32, |acc, &x| f32::max(acc, x.abs()));
    let scale = if peak > 1e-10 { peak / 127.0 } else { 1.0 };

    let codes = data
        .iter()
        .map(|&x| (x / scale).round().clamp(-128.0, 127.0) as i8)
        .collect::<Vec<i8>>();

    (codes, scale)
}
388
/// Maps int8 codes back to `f32` by multiplying each code by `scale`.
fn dequantize_int8(data: &[i8], scale: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(data.len());
    for &code in data {
        out.push(f32::from(code) * scale);
    }
    out
}
393
/// Encodes `value` as an FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits,
/// exponent bias 7), rounding to nearest with ties to even.
///
/// * NaN encodes as `0xFF`. `+inf`/`-inf` encode as the NaN codes `0x7F`/`0xFF`,
///   because E4M3 has no infinity.
/// * Finite magnitudes beyond the largest finite value (448.0) saturate to
///   `0x7E`/`0xFE`; a finite input never produces the NaN code.
/// * `±0.0` encodes as `0x00`. Non-zero inputs that round to zero keep their sign
///   (`0x80` for negatives).
fn quantize_fp8_e4m3(value: f32) -> u8 {
    // Largest finite E4M3 magnitude code (448.0); 0x7F is NaN.
    const MAX_FINITE: u32 = 0x7E;

    // `x / 2^shift` rounded to nearest, ties to even. Callers pass 24-bit
    // significands only, so any shift of 32 or more rounds to 0.
    fn shift_round_ties_even(x: u32, shift: u32) -> u32 {
        if shift == 0 {
            return x;
        }
        if shift >= 32 {
            return 0;
        }
        let quotient = x >> shift;
        let remainder = x & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if remainder > half || (remainder == half && (quotient & 1) == 1) {
            quotient + 1
        } else {
            quotient
        }
    }

    if value.is_nan() {
        return 0xFF;
    }
    if value.is_infinite() {
        return if value > 0.0 { 0x7F } else { 0xFF };
    }
    if value == 0.0 {
        return 0x00;
    }

    let bits = value.to_bits();
    let sign = ((bits >> 31) as u8) << 7;
    let biased_exp = ((bits >> 23) & 0xFF) as i32;
    if biased_exp == 0 {
        // f32 subnormals (< 2^-126) are far below half of the smallest E4M3
        // subnormal (2^-9), so they round to a signed zero.
        return sign;
    }
    // 24-bit significand including the implicit leading 1.
    let significand = (bits & 0x7F_FFFF) | 0x80_0000;
    let e4m3_exp = biased_exp - 127 + 7;
    if e4m3_exp > 15 {
        return sign | MAX_FINITE as u8;
    }

    // For a positive normal, code = (exp_field << 3) | mantissa3, which equals
    // ((exp_field - 1) << 3) + significand4 with significand4 in 8..=15. A subnormal's
    // code is its 3-bit mantissa. Rounding a significand up to 16, or a subnormal up
    // to 8, therefore carries into the exponent field with plain addition.
    let code = if e4m3_exp >= 1 {
        ((e4m3_exp as u32 - 1) << 3) + shift_round_ties_even(significand, 20)
    } else {
        // Subnormal: value / 2^-9 == significand >> (21 - e4m3_exp).
        shift_round_ties_even(significand, (21 - e4m3_exp) as u32)
    };

    // 0x7F would be NaN and 0x80 would overflow the magnitude bits: saturate instead.
    sign | code.min(MAX_FINITE) as u8
}
432
/// Decodes an FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7) to `f32`.
///
/// `0x00` and `0x80` both decode to `+0.0`. `0x7F`/`0xFF` decode to NaN. Subnormals
/// (exponent field 0) are renormalised into f32 normals, so the result is exact.
fn dequantize_fp8_e4m3(value: u8) -> f32 {
    let magnitude = value & 0x7F;
    match magnitude {
        0x00 => return 0.0,
        0x7F => return f32::NAN,
        _ => {}
    }

    let sign_bit = u32::from(value >> 7) << 31;
    let exp_field = u32::from(magnitude >> 3);
    let mant_field = u32::from(magnitude & 0x7);
    // Rebias from E4M3 (7) to f32 (127).
    let f32_exp = exp_field + 120;

    let magnitude_bits = match (exp_field, mant_field) {
        // Subnormal 1 * 2^-9.
        (0, 1) => (f32_exp - 2) << 23,
        // Subnormals 2..=3 * 2^-9: leading bit at 2^-8.
        (0, m) if m < 4 => ((f32_exp - 1) << 23) | ((m & 1) << 22),
        // Subnormals 4..=7 * 2^-9: leading bit at 2^-7.
        (0, m) => (f32_exp << 23) | ((m & 3) << 21),
        // Normal: copy the 3 mantissa bits into the top of the f32 fraction.
        (_, m) => (f32_exp << 23) | (m << 20),
    };

    f32::from_bits(sign_bit | magnitude_bits)
}
464
/// Encodes `value` as an FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits,
/// exponent bias 15), rounding to nearest with ties to even.
///
/// * NaN encodes as `0xFF`. `+inf`/`-inf` encode as `0x7C`/`0xFC`.
/// * Finite magnitudes beyond the largest finite value (57344.0) saturate to
///   `0x7B`/`0xFB`, so a finite input never produces an infinity or NaN code.
///   This matches the saturating behaviour of the E4M3 encoder and keeps
///   downstream arithmetic finite.
/// * `±0.0` encodes as `0x00`. Non-zero inputs that round to zero keep their sign
///   (`0x80` for negatives).
fn quantize_fp8_e5m2(value: f32) -> u8 {
    // Largest finite E5M2 magnitude code (57344.0); 0x7C is infinity.
    const MAX_FINITE: u32 = 0x7B;

    // `x / 2^shift` rounded to nearest, ties to even. Callers pass 24-bit
    // significands only, so any shift of 32 or more rounds to 0.
    fn shift_round_ties_even(x: u32, shift: u32) -> u32 {
        if shift == 0 {
            return x;
        }
        if shift >= 32 {
            return 0;
        }
        let quotient = x >> shift;
        let remainder = x & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if remainder > half || (remainder == half && (quotient & 1) == 1) {
            quotient + 1
        } else {
            quotient
        }
    }

    if value.is_nan() {
        return 0xFF;
    }
    if value.is_infinite() {
        return if value > 0.0 { 0x7C } else { 0xFC };
    }
    if value == 0.0 {
        return 0x00;
    }

    let bits = value.to_bits();
    let sign = ((bits >> 31) as u8) << 7;
    let biased_exp = ((bits >> 23) & 0xFF) as i32;
    if biased_exp == 0 {
        // f32 subnormals (< 2^-126) are far below half of the smallest E5M2
        // subnormal (2^-16), so they round to a signed zero.
        return sign;
    }
    // 24-bit significand including the implicit leading 1.
    let significand = (bits & 0x7F_FFFF) | 0x80_0000;
    let e5m2_exp = biased_exp - 127 + 15;
    // Exponent field 31 is reserved for inf/NaN, so anything past 30 overflows.
    if e5m2_exp > 30 {
        return sign | MAX_FINITE as u8;
    }

    // For a positive normal, code = (exp_field << 2) | mantissa2, which equals
    // ((exp_field - 1) << 2) + significand3 with significand3 in 4..=7. A subnormal's
    // code is its 2-bit mantissa. Round-up carries flow into the exponent by addition.
    let code = if e5m2_exp >= 1 {
        ((e5m2_exp as u32 - 1) << 2) + shift_round_ties_even(significand, 21)
    } else {
        // Subnormal: value / 2^-16 == significand >> (22 - e5m2_exp).
        shift_round_ties_even(significand, (22 - e5m2_exp) as u32)
    };

    // Rounding past 57344 would yield the infinity code: saturate instead.
    sign | code.min(MAX_FINITE) as u8
}
503
/// Decodes an FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits, bias 15) to `f32`.
///
/// `0x00` and `0x80` both decode to `+0.0`. `0x7C`/`0xFC` decode to `±inf`, and
/// magnitudes above `0x7C` decode to NaN. Subnormals are renormalised into f32
/// normals, so the result is exact.
fn dequantize_fp8_e5m2(value: u8) -> f32 {
    let negative = value & 0x80 != 0;
    let magnitude = value & 0x7F;

    if magnitude == 0 {
        return 0.0;
    }
    if magnitude == 0x7C {
        return if negative {
            f32::NEG_INFINITY
        } else {
            f32::INFINITY
        };
    }
    if magnitude > 0x7C {
        return f32::NAN;
    }

    let exp_field = u32::from(magnitude >> 2);
    let mant_field = u32::from(magnitude & 0b11);
    // Rebias from E5M2 (15) to f32 (127).
    let f32_exp = exp_field + 112;

    let magnitude_bits = match (exp_field, mant_field) {
        // Subnormal 1 * 2^-16.
        (0, 1) => (f32_exp - 1) << 23,
        // Subnormals 2..=3 * 2^-16: leading bit at 2^-15.
        (0, m) => (f32_exp << 23) | ((m & 1) << 22),
        // Normal: copy the 2 mantissa bits into the top of the f32 fraction.
        (_, m) => (f32_exp << 23) | (m << 21),
    };

    let sign_bit = if negative { 1u32 << 31 } else { 0 };
    f32::from_bits(sign_bit | magnitude_bits)
}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative error of `got` against `expected`, falling back to the absolute
    /// error when `expected` is (nearly) zero and a ratio would blow up.
    fn rel_err(expected: f32, got: f32) -> f32 {
        if expected.abs() > 1e-6 {
            (expected - got).abs() / expected.abs()
        } else {
            (expected - got).abs()
        }
    }

    #[test]
    fn test_int8_roundtrip() {
        let data: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 - 6.4).collect();
        let (quantized, scale) = quantize_int8(&data);
        let dequantized = dequantize_int8(&quantized, scale);
        assert_eq!(dequantized.len(), data.len());
        for (&orig, &dec) in data.iter().zip(&dequantized) {
            let err = rel_err(orig, dec);
            assert!(err < 0.02, "orig={orig}, dec={dec}, rel_err={err}");
        }
    }

    #[test]
    fn test_fp8_e4m3_roundtrip() {
        let values = [
            0.0f32,
            1.0,
            -1.0,
            0.5,
            0.0136719,
            448.0,
            2f32.powi(-6),
            2f32.powi(-9),
        ];
        for &val in &values {
            let d = dequantize_fp8_e4m3(quantize_fp8_e4m3(val));
            if val == 0.0 {
                assert_eq!(d, 0.0, "zero roundtrip");
            } else if val.abs() < 1e-5 {
                assert!(d.abs() < 0.01, "small value {val} -> {d}");
            } else {
                let err = (val - d).abs() / val.abs();
                assert!(err < 0.05, "val={val}, d={d}, rel_err={err}");
            }
        }
    }

    #[test]
    fn test_fp8_e5m2_roundtrip() {
        let values = [
            0.0f32,
            1.0,
            -1.0,
            0.5,
            57344.0,
            2f32.powi(-14),
            1.52588e-5,
        ];
        for &val in &values {
            let d = dequantize_fp8_e5m2(quantize_fp8_e5m2(val));
            if val == 0.0 {
                assert_eq!(d, 0.0, "zero roundtrip");
            } else if val.abs() < 1e-5 {
                assert!(d.abs() < 0.01, "small value {val} -> {d}");
            } else {
                let err = (val - d).abs() / val.abs();
                assert!(err < 0.1, "val={val}, d={d}, rel_err={err}");
            }
        }
    }

    /// Writes one position with every format (including the uncompressed F32
    /// path) and checks that head 0 reads back within a format-specific tolerance.
    #[test]
    fn test_quantized_kv_cache_basic() {
        let num_layers = 2;
        let num_kv_heads = 4;
        let max_seq_len = 16;
        let head_dim = 64;

        for format in [
            KVCacheFormat::F32,
            KVCacheFormat::Int8,
            KVCacheFormat::Fp8E4M3,
            KVCacheFormat::Fp8E5M2,
        ] {
            let mut cache =
                QuantizedKVCache::new(num_layers, num_kv_heads, max_seq_len, head_dim, format);

            let k_data: Vec<f32> = (0..num_kv_heads * head_dim)
                .map(|i| (i as f32) * 0.01 - 1.0)
                .collect();
            let v_data: Vec<f32> = (0..num_kv_heads * head_dim)
                .map(|i| (i as f32) * 0.02 - 0.5)
                .collect();

            cache.write_kv(0, 0, &k_data, &v_data);
            cache.seq_len = 1;

            let read_k = cache.read_k(0, 0, 0);
            let read_v = cache.read_v(0, 0, 0);

            assert_eq!(read_k.len(), head_dim);
            assert_eq!(read_v.len(), head_dim);

            let tol = match format {
                // F32 is stored losslessly.
                KVCacheFormat::F32 => 1e-6,
                KVCacheFormat::Int8 => 0.15,
                KVCacheFormat::Fp8E4M3 | KVCacheFormat::Fp8E5M2 => 0.25,
            };
            for (&a, &b) in k_data[..head_dim].iter().zip(&read_k) {
                assert!(rel_err(a, b) < tol, "{format:?} k: orig={a}, read={b}");
            }
            for (&a, &b) in v_data[..head_dim].iter().zip(&read_v) {
                assert!(rel_err(a, b) < tol, "{format:?} v: orig={a}, read={b}");
            }
        }
    }

    #[test]
    fn test_memory_savings() {
        let num_layers = 4;
        let num_kv_heads = 32;
        let max_seq_len = 2048;
        let head_dim = 128;

        let f32_size = num_layers * 2 * (num_kv_heads * max_seq_len * head_dim * 4);

        let int8_cache = QuantizedKVCache::new(
            num_layers,
            num_kv_heads,
            max_seq_len,
            head_dim,
            KVCacheFormat::Int8,
        );
        let fp8_cache = QuantizedKVCache::new(
            num_layers,
            num_kv_heads,
            max_seq_len,
            head_dim,
            KVCacheFormat::Fp8E4M3,
        );

        let int8_size = int8_cache.memory_usage();
        let fp8_size = fp8_cache.memory_usage();

        assert!(int8_size < f32_size / 2 + f32_size / 4);
        assert!(fp8_size < f32_size / 2 + f32_size / 4);
    }

    /// After dropping two positions, every head's keys and values must have moved.
    #[test]
    fn test_shift_left() {
        let num_layers = 1;
        let num_kv_heads = 2;
        let max_seq_len = 8;
        let head_dim = 4;

        let mut cache = QuantizedKVCache::new(
            num_layers,
            num_kv_heads,
            max_seq_len,
            head_dim,
            KVCacheFormat::Int8,
        );

        for pos in 0..5 {
            let k_data: Vec<f32> = vec![pos as f32; num_kv_heads * head_dim];
            let v_data = k_data.clone();
            cache.write_kv(0, pos, &k_data, &v_data);
        }
        cache.seq_len = 5;

        cache.shift_left(2);

        assert_eq!(cache.seq_len, 3);

        for (i, pos) in (2..5).enumerate() {
            let expected = pos as f32;
            for head in 0..num_kv_heads {
                let read_k = cache.read_k(0, head, i);
                let read_v = cache.read_v(0, head, i);
                assert_eq!(read_k.len(), head_dim);
                assert_eq!(read_v.len(), head_dim);
                for got in read_k.into_iter().chain(read_v) {
                    assert!(
                        (got - expected).abs() < 0.01,
                        "head {head}, pos {i}: expected {expected}, got {got}"
                    );
                }
            }
        }
    }

    #[test]
    fn test_capacity_helpers() {
        let mut cache = QuantizedKVCache::new(1, 1, 4, 2, KVCacheFormat::Fp8E4M3);
        assert_eq!(cache.remaining_capacity(), 4);
        assert!(!cache.is_full());

        cache.seq_len = 4;
        assert_eq!(cache.remaining_capacity(), 0);
        assert!(cache.is_full());

        // `truncate` never grows the sequence.
        cache.truncate(6);
        assert_eq!(cache.seq_len, 4);

        cache.truncate(1);
        assert_eq!(cache.seq_len, 1);
        assert_eq!(cache.remaining_capacity(), 3);
        assert!(!cache.is_full());
    }
}