Skip to main content

oxibonsai_model/
kv_cache_quant.rs

1//! Quantized KV cache: INT8 and FP8 per-row quantization for keys and values.
2//!
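//! INT8 memory reduction: ~4× vs FP32, ~2× vs FP16 (slightly less once the
//! per-row f32 scales are counted).
//! FP8 memory reduction: the same ~4× / ~2× footprint (one byte per element),
//! but each byte is an FP8 float (E4M3 or E5M2) instead of a uniform integer step.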
5//! Accuracy: ~0.1% error vs FP32 for typical activation ranges.
6//!
7//! # Layout
8//! For each layer, each head, each token position:
9//!   - keys_i8: [seq_len, num_kv_heads, head_dim] as i8
10//!   - key_scales: [seq_len, num_kv_heads] as f32  (per-row scale)
11//!   - values_i8: [seq_len, num_kv_heads, head_dim] as i8
12//!   - value_scales: [seq_len, num_kv_heads] as f32
13
14use oxibonsai_core::quant_fp8::{
15    fp8_e4m3_decode, fp8_e4m3_encode, fp8_e5m2_decode, fp8_e5m2_encode, FP8_E4M3_MAX, FP8_E5M2_MAX,
16};
17
/// Error types for quantized KV cache operations.
///
/// Returned by the INT8 and FP8 layer/cache types defined in this module.
#[derive(Debug, thiserror::Error)]
pub enum QuantKvError {
    /// A push was attempted on a layer that already holds `capacity` token
    /// positions. `pos` is the index the rejected token would have occupied.
    #[error("capacity exceeded: capacity {capacity}, tried to push token {pos}")]
    CapacityExceeded { capacity: usize, pos: usize },

    /// The requested token position is not smaller than the stored length.
    #[error("token position {0} out of range")]
    PositionOutOfRange(usize),

    /// The requested head index is not smaller than `num_heads`.
    #[error("head index {head} out of range (num_kv_heads = {num_heads})")]
    HeadOutOfRange { head: usize, num_heads: usize },

    /// The requested layer index is not smaller than `num_layers`.
    ///
    /// `QuantizedKvCache::push_step` also reports this variant when the number
    /// of supplied per-layer slices differs from `num_layers`; in that case
    /// `layer` carries the supplied count rather than an index.
    #[error("layer {layer} out of range (num_layers = {num_layers})")]
    LayerOutOfRange { layer: usize, num_layers: usize },

    /// A flat key or value slice did not have `expected` elements
    /// (`num_kv_heads * head_dim`); `actual` is the length that was supplied.
    #[error("key/value shape mismatch: expected {expected}, got {actual}")]
    ShapeMismatch { expected: usize, actual: usize },
}
36
37// ─── Primitive quantization helpers ──────────────────────────────────────────
38
/// Quantize a slice to INT8 with a single per-row scale.
///
/// Returns `(quantized: Vec<i8>, scale: f32)`.
///
/// `scale = max(|x|) / 127.0`, where the maximum is taken over the *finite*
/// elements only and the result is clamped to at least [`f32::EPSILON`] to
/// avoid division-by-zero. All values are symmetrically clamped to
/// `[-127, 127]` so that rounding can never produce the asymmetric
/// `i8::MIN` (-128).
///
/// Non-finite inputs are handled per element instead of poisoning the row:
/// `±inf` saturates to `±127` and `NaN` becomes `0`. The returned scale is
/// therefore always finite, so dequantizing never turns finite values into NaN.
///
/// An empty row yields `(vec![], f32::EPSILON)`.
pub fn quantize_row_i8(row: &[f32]) -> (Vec<i8>, f32) {
    if row.is_empty() {
        return (Vec::new(), f32::EPSILON);
    }

    // Only finite magnitudes may define the scale. An infinite scale would map
    // every finite element to 0 and dequantize it as `0 * inf = NaN`.
    let max_abs = row
        .iter()
        .map(|x| x.abs())
        .filter(|magnitude| magnitude.is_finite())
        .fold(0.0_f32, f32::max);

    // Clamp scale to at least EPSILON to avoid division by zero for all-zero rows.
    let scale = (max_abs / 127.0_f32).max(f32::EPSILON);

    let quantized = row
        .iter()
        // `clamp` passes NaN through and the saturating `as i8` cast maps it to 0.
        .map(|&x| (x / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();

    (quantized, scale)
}
63
/// Reconstruct f32 values from INT8 codes and their row scale.
///
/// Every code is widened to `f32` and multiplied by `scale`; the output has
/// the same length as `quantized`. A zero `scale` therefore produces an
/// all-zero row.
pub fn dequantize_row_i8(quantized: &[i8], scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    for &code in quantized {
        restored.push(f32::from(code) * scale);
    }
    restored
}
72
/// Mean absolute error (MAE) between `original` and the values obtained by
/// dequantizing `quantized` with `scale`.
///
/// Only the first `min(original.len(), quantized.len())` pairs are compared.
/// Returns `0.0` when that count is zero.
pub fn quant_error_mae(original: &[f32], quantized: &[i8], scale: f32) -> f32 {
    let compared = original.len().min(quantized.len());
    if compared == 0 {
        return 0.0;
    }

    // `zip` stops at the shorter input, matching `compared`.
    let mut total = 0.0_f32;
    for (&reference, &code) in original.iter().zip(quantized) {
        total += (reference - f32::from(code) * scale).abs();
    }
    total / compared as f32
}
89
90// ─── Per-layer quantized KV storage ──────────────────────────────────────────
91
/// A single layer's INT8-quantized KV cache.
///
/// Memory layout for the INT8 data arrays uses the token-major order
/// `[token_pos * num_kv_heads * head_dim]`, so sequential decode steps
/// append contiguous blocks. Scale arrays use `[token_pos * num_kv_heads]`.
///
/// The four storage vectors are private and sized once in the constructor.
/// The dimension fields and `len` are public and are used directly for
/// offset computation and bounds validation, so they are expected to stay
/// consistent with the allocated storage.
#[derive(Debug)]
pub struct QuantizedKvLayer {
    /// Quantized key data: `[capacity * num_kv_heads * head_dim]` as i8.
    keys_i8: Vec<i8>,
    /// Per-row key scales: `[capacity * num_kv_heads]` as f32.
    key_scales: Vec<f32>,
    /// Quantized value data: `[capacity * num_kv_heads * head_dim]` as i8.
    values_i8: Vec<i8>,
    /// Per-row value scales: `[capacity * num_kv_heads]` as f32.
    value_scales: Vec<f32>,
    /// Number of KV attention heads.
    pub num_kv_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Maximum number of token positions pre-allocated.
    pub capacity: usize,
    /// Number of token positions actually stored so far.
    /// Incremented by one on every successful push.
    pub len: usize,
}
116
117impl QuantizedKvLayer {
118    /// Allocate an empty quantized KV layer with the given dimensions.
119    ///
120    /// Pre-allocates all storage so that subsequent [`push`](Self::push) calls
121    /// do not allocate.
122    pub fn new(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
123        let data_len = capacity * num_kv_heads * head_dim;
124        let scale_len = capacity * num_kv_heads;
125
126        Self {
127            keys_i8: vec![0i8; data_len],
128            key_scales: vec![0.0_f32; scale_len],
129            values_i8: vec![0i8; data_len],
130            value_scales: vec![0.0_f32; scale_len],
131            num_kv_heads,
132            head_dim,
133            capacity,
134            len: 0,
135        }
136    }
137
138    /// Append keys and values for the next token position.
139    ///
140    /// `keys` must be a flat slice of shape `[num_kv_heads * head_dim]` (heads
141    /// first, then dims). `values` must have the same shape.
142    ///
143    /// Each head's row is quantized independently with its own scale.
144    ///
145    /// # Errors
146    /// - [`QuantKvError::CapacityExceeded`] if `self.len == self.capacity`.
147    /// - [`QuantKvError::ShapeMismatch`] if `keys` or `values` length is wrong.
148    pub fn push(&mut self, keys: &[f32], values: &[f32]) -> Result<(), QuantKvError> {
149        let expected = self.num_kv_heads * self.head_dim;
150
151        if keys.len() != expected {
152            return Err(QuantKvError::ShapeMismatch {
153                expected,
154                actual: keys.len(),
155            });
156        }
157        if values.len() != expected {
158            return Err(QuantKvError::ShapeMismatch {
159                expected,
160                actual: values.len(),
161            });
162        }
163        if self.len >= self.capacity {
164            return Err(QuantKvError::CapacityExceeded {
165                capacity: self.capacity,
166                pos: self.len,
167            });
168        }
169
170        let token_pos = self.len;
171
172        for head in 0..self.num_kv_heads {
173            let row_start = head * self.head_dim;
174            let row_end = row_start + self.head_dim;
175
176            // Compute offsets before any mutable borrows to satisfy the borrow checker.
177            let data_off = self.data_offset(token_pos, head);
178            let scale_off = self.scale_offset(token_pos, head);
179
180            // Keys
181            let key_row = &keys[row_start..row_end];
182            let (kq, ks) = quantize_row_i8(key_row);
183            self.keys_i8[data_off..data_off + self.head_dim].copy_from_slice(&kq);
184            self.key_scales[scale_off] = ks;
185
186            // Values
187            let val_row = &values[row_start..row_end];
188            let (vq, vs) = quantize_row_i8(val_row);
189            self.values_i8[data_off..data_off + self.head_dim].copy_from_slice(&vq);
190            self.value_scales[scale_off] = vs;
191        }
192
193        self.len += 1;
194        Ok(())
195    }
196
197    /// Get dequantized keys for a specific token position and head.
198    ///
199    /// Returns a `Vec<f32>` of length `head_dim`.
200    ///
201    /// # Errors
202    /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
203    /// - [`QuantKvError::HeadOutOfRange`] if `head >= self.num_kv_heads`.
204    pub fn get_key(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
205        self.validate_pos_head(token_pos, head)?;
206        let data_off = self.data_offset(token_pos, head);
207        let scale = self.key_scales[self.scale_offset(token_pos, head)];
208        Ok(dequantize_row_i8(
209            &self.keys_i8[data_off..data_off + self.head_dim],
210            scale,
211        ))
212    }
213
214    /// Get dequantized values for a specific token position and head.
215    ///
216    /// Returns a `Vec<f32>` of length `head_dim`.
217    ///
218    /// # Errors
219    /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
220    /// - [`QuantKvError::HeadOutOfRange`] if `head >= self.num_kv_heads`.
221    pub fn get_value(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
222        self.validate_pos_head(token_pos, head)?;
223        let data_off = self.data_offset(token_pos, head);
224        let scale = self.value_scales[self.scale_offset(token_pos, head)];
225        Ok(dequantize_row_i8(
226            &self.values_i8[data_off..data_off + self.head_dim],
227            scale,
228        ))
229    }
230
231    /// Get all dequantized keys for a token position (all heads, interleaved).
232    ///
233    /// Returns a flat `Vec<f32>` of length `num_kv_heads * head_dim`.
234    ///
235    /// # Errors
236    /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
237    pub fn get_keys_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
238        if token_pos >= self.len {
239            return Err(QuantKvError::PositionOutOfRange(token_pos));
240        }
241        let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
242        for head in 0..self.num_kv_heads {
243            let data_off = self.data_offset(token_pos, head);
244            let scale = self.key_scales[self.scale_offset(token_pos, head)];
245            out.extend(dequantize_row_i8(
246                &self.keys_i8[data_off..data_off + self.head_dim],
247                scale,
248            ));
249        }
250        Ok(out)
251    }
252
253    /// Get all dequantized values for a token position (all heads, interleaved).
254    ///
255    /// Returns a flat `Vec<f32>` of length `num_kv_heads * head_dim`.
256    ///
257    /// # Errors
258    /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
259    pub fn get_values_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
260        if token_pos >= self.len {
261            return Err(QuantKvError::PositionOutOfRange(token_pos));
262        }
263        let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
264        for head in 0..self.num_kv_heads {
265            let data_off = self.data_offset(token_pos, head);
266            let scale = self.value_scales[self.scale_offset(token_pos, head)];
267            out.extend(dequantize_row_i8(
268                &self.values_i8[data_off..data_off + self.head_dim],
269                scale,
270            ));
271        }
272        Ok(out)
273    }
274
275    /// Memory used by this layer in bytes (INT8 data + f32 scales).
276    ///
277    /// Only accounts for the pre-allocated storage slabs, not struct overhead.
278    pub fn memory_bytes(&self) -> usize {
279        // INT8 data: 1 byte per element
280        let data_bytes = self.keys_i8.len() + self.values_i8.len();
281        // f32 scales: 4 bytes each
282        let scale_bytes = (self.key_scales.len() + self.value_scales.len()) * 4;
283        data_bytes + scale_bytes
284    }
285
286    /// Equivalent memory if the same data were stored as FP32 (no scales).
287    ///
288    /// `2 * capacity * num_kv_heads * head_dim * 4 bytes`
289    pub fn fp32_memory_bytes(&self) -> usize {
290        // Keys + values, each element 4 bytes
291        2 * self.capacity * self.num_kv_heads * self.head_dim * 4
292    }
293
294    /// Compression ratio versus FP32 storage.
295    ///
296    /// Values approaching 4.0 indicate near-ideal INT8 compression. The ratio
297    /// is slightly below 4.0 because per-row f32 scales add overhead.
298    pub fn compression_ratio(&self) -> f32 {
299        let quant = self.memory_bytes();
300        if quant == 0 {
301            return 1.0;
302        }
303        self.fp32_memory_bytes() as f32 / quant as f32
304    }
305
306    // ── Internal helpers ──────────────────────────────────────────────────────
307
308    /// Flat index into the INT8 data arrays for `(token_pos, head, 0)`.
309    ///
310    /// Layout: `[token_pos][head][dim]` → `(token_pos * num_kv_heads + head) * head_dim`
311    #[inline]
312    fn data_offset(&self, token_pos: usize, head: usize) -> usize {
313        (token_pos * self.num_kv_heads + head) * self.head_dim
314    }
315
316    /// Flat index into the scale arrays for `(token_pos, head)`.
317    ///
318    /// Layout: `[token_pos][head]` → `token_pos * num_kv_heads + head`
319    #[inline]
320    fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
321        token_pos * self.num_kv_heads + head
322    }
323
324    /// Validate that `token_pos < self.len` and `head < self.num_kv_heads`.
325    fn validate_pos_head(&self, token_pos: usize, head: usize) -> Result<(), QuantKvError> {
326        if token_pos >= self.len {
327            return Err(QuantKvError::PositionOutOfRange(token_pos));
328        }
329        if head >= self.num_kv_heads {
330            return Err(QuantKvError::HeadOutOfRange {
331                head,
332                num_heads: self.num_kv_heads,
333            });
334        }
335        Ok(())
336    }
337}
338
339// ─── Multi-layer quantized KV cache ──────────────────────────────────────────
340
/// Full multi-layer INT8-quantized KV cache for autoregressive decoding.
///
/// Wraps one [`QuantizedKvLayer`] per transformer layer and exposes a
/// unified decode-step interface through [`push_step`](Self::push_step).
///
/// The public dimension fields are copies of the constructor arguments;
/// `num_layers` is the value layer indices and per-step slice counts are
/// validated against.
#[derive(Debug)]
pub struct QuantizedKvCache {
    /// One quantized store per transformer layer, in layer order.
    layers: Vec<QuantizedKvLayer>,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Number of KV attention heads per layer.
    pub num_kv_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
}
355
356impl QuantizedKvCache {
357    /// Allocate a new quantized KV cache for `num_layers` transformer layers.
358    ///
359    /// Each layer is pre-allocated for `capacity` token positions.
360    pub fn new(num_layers: usize, capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
361        let layers = (0..num_layers)
362            .map(|_| QuantizedKvLayer::new(capacity, num_kv_heads, head_dim))
363            .collect();
364
365        Self {
366            layers,
367            num_layers,
368            num_kv_heads,
369            head_dim,
370        }
371    }
372
373    /// Append KV tensors for all layers at the current decode step.
374    ///
375    /// `all_keys[layer]` must be a flat slice of shape `[num_kv_heads * head_dim]`.
376    /// `all_values[layer]` must have the same shape.
377    ///
378    /// # Errors
379    /// - [`QuantKvError::LayerOutOfRange`] if `all_keys.len() != self.num_layers`.
380    /// - Propagates [`QuantKvError`] from each layer's [`push`](QuantizedKvLayer::push).
381    pub fn push_step(
382        &mut self,
383        all_keys: &[Vec<f32>],
384        all_values: &[Vec<f32>],
385    ) -> Result<(), QuantKvError> {
386        if all_keys.len() != self.num_layers {
387            return Err(QuantKvError::LayerOutOfRange {
388                layer: all_keys.len(),
389                num_layers: self.num_layers,
390            });
391        }
392        if all_values.len() != self.num_layers {
393            return Err(QuantKvError::LayerOutOfRange {
394                layer: all_values.len(),
395                num_layers: self.num_layers,
396            });
397        }
398
399        for (layer_idx, (layer, (keys, values))) in self
400            .layers
401            .iter_mut()
402            .zip(all_keys.iter().zip(all_values.iter()))
403            .enumerate()
404        {
405            layer.push(keys, values).map_err(|e| match e {
406                // Re-attach layer context to capacity errors
407                QuantKvError::CapacityExceeded { capacity, pos } => {
408                    QuantKvError::CapacityExceeded { capacity, pos }
409                }
410                QuantKvError::ShapeMismatch { expected, actual } => {
411                    QuantKvError::ShapeMismatch { expected, actual }
412                }
413                // Pass through other errors; we could enrich them with layer_idx
414                // but the error types don't carry that field — keep as is.
415                other => {
416                    let _ = layer_idx;
417                    other
418                }
419            })?;
420        }
421        Ok(())
422    }
423
424    /// Get dequantized keys for a specific layer, token position, and head.
425    ///
426    /// # Errors
427    /// - [`QuantKvError::LayerOutOfRange`] if `layer >= self.num_layers`.
428    /// - Propagates position/head errors from the underlying layer.
429    pub fn get_key(
430        &self,
431        layer: usize,
432        token_pos: usize,
433        head: usize,
434    ) -> Result<Vec<f32>, QuantKvError> {
435        self.validate_layer(layer)?;
436        self.layers[layer].get_key(token_pos, head)
437    }
438
439    /// Get dequantized values for a specific layer, token position, and head.
440    ///
441    /// # Errors
442    /// - [`QuantKvError::LayerOutOfRange`] if `layer >= self.num_layers`.
443    /// - Propagates position/head errors from the underlying layer.
444    pub fn get_value(
445        &self,
446        layer: usize,
447        token_pos: usize,
448        head: usize,
449    ) -> Result<Vec<f32>, QuantKvError> {
450        self.validate_layer(layer)?;
451        self.layers[layer].get_value(token_pos, head)
452    }
453
454    /// Total memory used across all layers in bytes.
455    pub fn total_memory_bytes(&self) -> usize {
456        self.layers.iter().map(|l| l.memory_bytes()).sum()
457    }
458
459    /// FP32-equivalent memory across all layers.
460    pub fn total_fp32_memory_bytes(&self) -> usize {
461        self.layers.iter().map(|l| l.fp32_memory_bytes()).sum()
462    }
463
464    /// Overall compression ratio vs FP32.
465    pub fn compression_ratio(&self) -> f32 {
466        let quant = self.total_memory_bytes();
467        if quant == 0 {
468            return 1.0;
469        }
470        self.total_fp32_memory_bytes() as f32 / quant as f32
471    }
472
473    /// Number of token positions currently stored (taken from layer 0).
474    ///
475    /// Returns `0` if there are no layers.
476    pub fn seq_len(&self) -> usize {
477        self.layers.first().map(|l| l.len).unwrap_or(0)
478    }
479
480    // ── Internal helpers ──────────────────────────────────────────────────────
481
482    fn validate_layer(&self, layer: usize) -> Result<(), QuantKvError> {
483        if layer >= self.num_layers {
484            return Err(QuantKvError::LayerOutOfRange {
485                layer,
486                num_layers: self.num_layers,
487            });
488        }
489        Ok(())
490    }
491}
492
493// ─── FP8 KV cache ─────────────────────────────────────────────────────────────
494
/// FP8 encoding format variant for KV cache quantization.
///
/// The variant selects which encode/decode pair from
/// `oxibonsai_core::quant_fp8` is used and which `FP8_*_MAX` constant the
/// per-row scale is derived from.
///
/// - `E4M3` uses 4-bit exponent, 3-bit mantissa (max representable ≈ 448.0).
///   Better accuracy for typical attention activations with bounded range.
/// - `E5M2` uses 5-bit exponent, 2-bit mantissa (max representable ≈ 57344.0).
///   Wider dynamic range, useful for outlier-heavy distributions.
///
/// NOTE(review): the bit-level details quoted on the variants (bias, NaN
/// encodings, exact maxima) are properties of the external codec and are not
/// verifiable from this file — confirm against `quant_fp8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Fp8KvFormat {
    /// E4M3FN format: 4-bit exponent, 3-bit mantissa, bias=7.
    /// Max representable value: 448.0. No infinities; NaN = 0x7f/0xff.
    E4M3,
    /// E5M2 format: 5-bit exponent, 2-bit mantissa, bias=15.
    /// Max representable value: 57344.0. Supports infinities; NaN = 0x7e.
    E5M2,
}
510
511/// Quantize a row of f32 values to FP8 using per-row absolute-max scaling.
512///
513/// Returns `(quantized_bytes: Vec<u8>, scale: f32)` where
514/// `scale = max(|row|) / FP8_MAX`. One scale per head-row is stored; all
515/// values are encoded relative to that scale.
516///
517/// For an all-zero row the scale is clamped to [`f32::EPSILON`] and all output
518/// bytes are `0x00`.
519fn quantize_row_fp8(row: &[f32], format: Fp8KvFormat) -> (Vec<u8>, f32) {
520    if row.is_empty() {
521        return (Vec::new(), f32::EPSILON);
522    }
523
524    let max_abs = row.iter().map(|x| x.abs()).fold(0.0_f32, f32::max);
525
526    let fp8_max = match format {
527        Fp8KvFormat::E4M3 => FP8_E4M3_MAX,
528        Fp8KvFormat::E5M2 => FP8_E5M2_MAX,
529    };
530
531    // Clamp scale to at least EPSILON to avoid division by zero for all-zero rows.
532    let scale = (max_abs / fp8_max).max(f32::EPSILON);
533
534    let quantized = match format {
535        Fp8KvFormat::E4M3 => row.iter().map(|&x| fp8_e4m3_encode(x / scale)).collect(),
536        Fp8KvFormat::E5M2 => row.iter().map(|&x| fp8_e5m2_encode(x / scale)).collect(),
537    };
538
539    (quantized, scale)
540}
541
542/// Dequantize FP8 bytes back to f32 using the stored row scale.
543///
544/// Each element is decoded from FP8 then multiplied by `scale`.
545fn dequantize_row_fp8(quantized: &[u8], scale: f32, format: Fp8KvFormat) -> Vec<f32> {
546    match format {
547        Fp8KvFormat::E4M3 => quantized
548            .iter()
549            .map(|&b| fp8_e4m3_decode(b) * scale)
550            .collect(),
551        Fp8KvFormat::E5M2 => quantized
552            .iter()
553            .map(|&b| fp8_e5m2_decode(b) * scale)
554            .collect(),
555    }
556}
557
/// A single transformer layer's FP8-quantized KV cache.
///
/// Memory layout is token-major: `[token_pos][head][dim]` for data and
/// `[token_pos][head]` for scales. Append-only; `clear` resets `len` to 0
/// without reallocating.
///
/// Per-row scaling: one `f32` scale per `(token_pos, head)` pair, computed as
/// `scale = max(|row|) / FP8_MAX`. This mirrors the INT8 implementation but
/// uses FP8 byte encodings rather than i8.
///
/// The storage vectors are private and sized once in the constructor; the
/// public dimension fields feed the offset computations directly, so they
/// are expected to stay consistent with the allocated storage.
#[derive(Debug)]
pub struct Fp8KvLayer {
    /// FP8-encoded key data: `[capacity * num_kv_heads * head_dim]` as u8.
    keys_fp8: Vec<u8>,
    /// Per-head-row key scales: `[capacity * num_kv_heads]` as f32.
    key_scales: Vec<f32>,
    /// FP8-encoded value data: `[capacity * num_kv_heads * head_dim]` as u8.
    values_fp8: Vec<u8>,
    /// Per-head-row value scales: `[capacity * num_kv_heads]` as f32.
    value_scales: Vec<f32>,
    /// Number of KV attention heads per token position.
    pub num_kv_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Maximum token positions pre-allocated.
    pub capacity: usize,
    /// Token positions actually stored.
    /// Incremented by `push`, reset to 0 by `clear`.
    pub len: usize,
    /// FP8 encoding format (E4M3 or E5M2).
    pub format: Fp8KvFormat,
}
588
589impl Fp8KvLayer {
590    /// Allocate an FP8 KV layer for `num_kv_heads` heads of dimension `head_dim`,
591    /// holding up to `capacity` token positions in the given `format`.
592    ///
593    /// All storage is pre-allocated so subsequent [`push`](Self::push) calls
594    /// perform no heap allocation.
595    pub fn with_capacity(
596        num_kv_heads: usize,
597        head_dim: usize,
598        capacity: usize,
599        format: Fp8KvFormat,
600    ) -> Self {
601        let data_len = capacity * num_kv_heads * head_dim;
602        let scale_len = capacity * num_kv_heads;
603        Self {
604            keys_fp8: vec![0u8; data_len],
605            key_scales: vec![0.0_f32; scale_len],
606            values_fp8: vec![0u8; data_len],
607            value_scales: vec![0.0_f32; scale_len],
608            num_kv_heads,
609            head_dim,
610            capacity,
611            len: 0,
612            format,
613        }
614    }
615
616    /// Append FP8-quantized keys and values for the next token position.
617    ///
618    /// `key` and `value` must each be a flat slice of length
619    /// `num_kv_heads * head_dim` (heads first, then dims within each head).
620    /// Each head-row is quantized independently with its own scale.
621    ///
622    /// # Errors
623    /// - [`QuantKvError::CapacityExceeded`] if `self.len == self.capacity`.
624    /// - [`QuantKvError::ShapeMismatch`] if `key` or `value` length is wrong.
625    pub fn push(&mut self, key: &[f32], value: &[f32]) -> Result<(), QuantKvError> {
626        let expected = self.num_kv_heads * self.head_dim;
627
628        if key.len() != expected {
629            return Err(QuantKvError::ShapeMismatch {
630                expected,
631                actual: key.len(),
632            });
633        }
634        if value.len() != expected {
635            return Err(QuantKvError::ShapeMismatch {
636                expected,
637                actual: value.len(),
638            });
639        }
640        if self.len >= self.capacity {
641            return Err(QuantKvError::CapacityExceeded {
642                capacity: self.capacity,
643                pos: self.len,
644            });
645        }
646
647        let token_pos = self.len;
648        let format = self.format;
649
650        for head in 0..self.num_kv_heads {
651            let row_start = head * self.head_dim;
652            let row_end = row_start + self.head_dim;
653
654            let data_off = self.data_offset(token_pos, head);
655            let scale_off = self.scale_offset(token_pos, head);
656
657            // Keys
658            let key_row = &key[row_start..row_end];
659            let (kq, ks) = quantize_row_fp8(key_row, format);
660            self.keys_fp8[data_off..data_off + self.head_dim].copy_from_slice(&kq);
661            self.key_scales[scale_off] = ks;
662
663            // Values
664            let val_row = &value[row_start..row_end];
665            let (vq, vs) = quantize_row_fp8(val_row, format);
666            self.values_fp8[data_off..data_off + self.head_dim].copy_from_slice(&vq);
667            self.value_scales[scale_off] = vs;
668        }
669
670        self.len += 1;
671        Ok(())
672    }
673
674    /// Dequantize and return all keys for a token position as a flat
675    /// `Vec<f32>` of length `num_kv_heads * head_dim`.
676    ///
677    /// Layout: `[head_0_dims..., head_1_dims..., ...]`
678    ///
679    /// # Panics
680    /// Panics if `pos >= self.len` (index out of bounds on the pre-allocated slab).
681    pub fn get_key(&self, pos: usize) -> Vec<f32> {
682        let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
683        for head in 0..self.num_kv_heads {
684            let data_off = self.data_offset(pos, head);
685            let scale = self.key_scales[self.scale_offset(pos, head)];
686            out.extend(dequantize_row_fp8(
687                &self.keys_fp8[data_off..data_off + self.head_dim],
688                scale,
689                self.format,
690            ));
691        }
692        out
693    }
694
695    /// Dequantize and return all values for a token position as a flat
696    /// `Vec<f32>` of length `num_kv_heads * head_dim`.
697    ///
698    /// # Panics
699    /// Panics if `pos >= self.len`.
700    pub fn get_value(&self, pos: usize) -> Vec<f32> {
701        let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
702        for head in 0..self.num_kv_heads {
703            let data_off = self.data_offset(pos, head);
704            let scale = self.value_scales[self.scale_offset(pos, head)];
705            out.extend(dequantize_row_fp8(
706                &self.values_fp8[data_off..data_off + self.head_dim],
707                scale,
708                self.format,
709            ));
710        }
711        out
712    }
713
714    /// Dequantize keys for a subset of token positions.
715    ///
716    /// Returns a `Vec` of flat key vectors, one per position in `positions`.
717    /// Positions must be < `self.len`; out-of-range positions will panic
718    /// (index-out-of-bounds on the pre-allocated slab).
719    pub fn get_keys_at(&self, positions: &[usize]) -> Vec<Vec<f32>> {
720        positions.iter().map(|&pos| self.get_key(pos)).collect()
721    }
722
723    /// Dequantize values for a subset of token positions.
724    ///
725    /// Returns a `Vec` of flat value vectors, one per position in `positions`.
726    pub fn get_values_at(&self, positions: &[usize]) -> Vec<Vec<f32>> {
727        positions.iter().map(|&pos| self.get_value(pos)).collect()
728    }
729
730    /// Number of token positions currently stored.
731    #[inline]
732    pub fn len(&self) -> usize {
733        self.len
734    }
735
736    /// Returns `true` if no token positions have been stored yet.
737    #[inline]
738    pub fn is_empty(&self) -> bool {
739        self.len == 0
740    }
741
742    /// Maximum token positions this layer can hold.
743    #[inline]
744    pub fn capacity(&self) -> usize {
745        self.capacity
746    }
747
748    /// Bytes occupied by FP8 data and f32 scales for this layer.
749    ///
750    /// `keys_fp8 + values_fp8` (1 byte/element) + `key_scales + value_scales`
751    /// (4 bytes/element).
752    pub fn memory_bytes(&self) -> usize {
753        let data_bytes = self.keys_fp8.len() + self.values_fp8.len();
754        let scale_bytes = (self.key_scales.len() + self.value_scales.len()) * 4;
755        data_bytes + scale_bytes
756    }
757
758    /// Equivalent memory if the same data were stored as FP32 with no scales.
759    ///
760    /// `2 * capacity * num_kv_heads * head_dim * 4`
761    pub fn memory_bytes_fp32_equivalent(&self) -> usize {
762        2 * self.capacity * self.num_kv_heads * self.head_dim * 4
763    }
764
765    /// Reset stored length to zero, making the layer appear empty.
766    ///
767    /// Does not free or zero memory; existing bytes are overwritten on the next
768    /// series of [`push`](Self::push) calls.
769    pub fn clear(&mut self) {
770        self.len = 0;
771    }
772
773    // ── Internal helpers ──────────────────────────────────────────────────────
774
775    /// Flat index into the FP8 data arrays for `(token_pos, head, 0)`.
776    #[inline]
777    fn data_offset(&self, token_pos: usize, head: usize) -> usize {
778        (token_pos * self.num_kv_heads + head) * self.head_dim
779    }
780
781    /// Flat index into the scale arrays for `(token_pos, head)`.
782    #[inline]
783    fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
784        token_pos * self.num_kv_heads + head
785    }
786}
787
788// ─── Multi-layer FP8 KV cache ─────────────────────────────────────────────────
789
/// Full multi-layer FP8-quantized KV cache for autoregressive decoding.
///
/// Wraps one [`Fp8KvLayer`] per transformer layer and exposes per-layer
/// mutable and immutable accessors. All layers share the same `format`,
/// `num_kv_heads`, `head_dim`, and `capacity` when built through
/// [`new`](Self::new); because `layers` is public, nothing in this type
/// enforces that afterwards.
#[derive(Debug)]
pub struct Fp8KvCache {
    /// Per-transformer-layer FP8 KV stores.
    pub layers: Vec<Fp8KvLayer>,
}
800
801impl Fp8KvCache {
802    /// Allocate a new FP8 KV cache for `num_layers` transformer layers.
803    ///
804    /// Each layer is pre-allocated for `capacity` token positions.
805    pub fn new(
806        num_layers: usize,
807        num_kv_heads: usize,
808        head_dim: usize,
809        capacity: usize,
810        format: Fp8KvFormat,
811    ) -> Self {
812        let layers = (0..num_layers)
813            .map(|_| Fp8KvLayer::with_capacity(num_kv_heads, head_dim, capacity, format))
814            .collect();
815        Self { layers }
816    }
817
818    /// Immutable reference to a specific layer.
819    ///
820    /// # Panics
821    /// Panics if `layer_idx >= self.num_layers()`.
822    pub fn layer(&self, layer_idx: usize) -> &Fp8KvLayer {
823        &self.layers[layer_idx]
824    }
825
826    /// Mutable reference to a specific layer.
827    ///
828    /// # Panics
829    /// Panics if `layer_idx >= self.num_layers()`.
830    pub fn layer_mut(&mut self, layer_idx: usize) -> &mut Fp8KvLayer {
831        &mut self.layers[layer_idx]
832    }
833
834    /// Number of transformer layers in this cache.
835    pub fn num_layers(&self) -> usize {
836        self.layers.len()
837    }
838
839    /// Total memory used across all layers in bytes.
840    pub fn total_memory_bytes(&self) -> usize {
841        self.layers.iter().map(|l| l.memory_bytes()).sum()
842    }
843
844    /// Clear all layers, resetting stored lengths to zero.
845    pub fn clear_all(&mut self) {
846        for layer in &mut self.layers {
847            layer.clear();
848        }
849    }
850}