oxibonsai_model/kv_cache_quant.rs
1//! Quantized KV cache: INT8 and FP8 per-row quantization for keys and values.
2//!
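//! INT8 memory reduction: ≈4× vs FP32, ≈2× vs FP16 for the data arrays (the per-row f32 scales add a small overhead).
//! FP8 memory reduction: same as INT8, but codes lie on a non-uniform floating-point grid instead of a uniform integer grid.
//! Accuracy: typically ~0.1% error vs FP32 for typical activation ranges (measure with `quant_error_mae`).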
6//!
//! # Layout
//! Each layer stores its data token-major, one contiguous block per token position:
//! - keys_i8: [seq_len, num_kv_heads, head_dim] as i8
//! - key_scales: [seq_len, num_kv_heads] as f32 (one scale per (token, head) row)
//! - values_i8: [seq_len, num_kv_heads, head_dim] as i8
//! - value_scales: [seq_len, num_kv_heads] as f32
//!
//! The FP8 layers use the same layout, storing `u8` FP8 codes instead of `i8`.
13
14use oxibonsai_core::quant_fp8::{
15 fp8_e4m3_decode, fp8_e4m3_encode, fp8_e5m2_decode, fp8_e5m2_encode, FP8_E4M3_MAX, FP8_E5M2_MAX,
16};
17
18/// Error types for quantized KV cache operations.
19#[derive(Debug, thiserror::Error)]
20pub enum QuantKvError {
21 #[error("capacity exceeded: capacity {capacity}, tried to push token {pos}")]
22 CapacityExceeded { capacity: usize, pos: usize },
23
24 #[error("token position {0} out of range")]
25 PositionOutOfRange(usize),
26
27 #[error("head index {head} out of range (num_kv_heads = {num_heads})")]
28 HeadOutOfRange { head: usize, num_heads: usize },
29
30 #[error("layer {layer} out of range (num_layers = {num_layers})")]
31 LayerOutOfRange { layer: usize, num_layers: usize },
32
33 #[error("key/value shape mismatch: expected {expected}, got {actual}")]
34 ShapeMismatch { expected: usize, actual: usize },
35}
36
37// ─── Primitive quantization helpers ──────────────────────────────────────────
38
/// Symmetric per-row INT8 quantization.
///
/// Returns the quantized row together with its scale. The scale is
/// `max(|x|) / 127.0`, floored at [`f32::EPSILON`] so that an all-zero row never
/// divides by zero. Codes are clamped to `[-127, 127]`, so rounding can never
/// produce the asymmetric `i8::MIN` (-128).
///
/// An empty row yields an empty vector and a scale of `f32::EPSILON`.
pub fn quantize_row_i8(row: &[f32]) -> (Vec<i8>, f32) {
    if row.is_empty() {
        return (Vec::new(), f32::EPSILON);
    }

    let peak = row.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
    let scale = f32::max(peak / 127.0_f32, f32::EPSILON);

    let mut codes = Vec::with_capacity(row.len());
    for &v in row {
        let level = (v / scale).round().clamp(-127.0, 127.0);
        codes.push(level as i8);
    }

    (codes, scale)
}
63
/// Reconstruct f32 values from INT8 codes and their row scale.
///
/// Every code is widened to `f32` and multiplied by `scale`; the output has
/// the same length as `quantized`. A zero scale produces all zeros, which is
/// the correct reconstruction of an all-zero input row.
pub fn dequantize_row_i8(quantized: &[i8], scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    restored.extend(quantized.iter().map(|&code| f32::from(code) * scale));
    restored
}
72
/// Mean absolute error between `original` and the reconstruction
/// `quantized[i] * scale`.
///
/// Only the first `min(original.len(), quantized.len())` elements are
/// compared. Returns `0.0` when there is nothing to compare.
pub fn quant_error_mae(original: &[f32], quantized: &[i8], scale: f32) -> f32 {
    let count = original.len().min(quantized.len());
    if count == 0 {
        return 0.0;
    }

    let mut total = 0.0_f32;
    for (&reference, &code) in original[..count].iter().zip(&quantized[..count]) {
        total += (reference - f32::from(code) * scale).abs();
    }
    total / count as f32
}
89
90// ─── Per-layer quantized KV storage ──────────────────────────────────────────
91
92/// A single layer's INT8-quantized KV cache.
93///
94/// Memory layout for the INT8 data arrays uses the token-major order
95/// `[token_pos * num_kv_heads * head_dim]`, so sequential decode steps
96/// append contiguous blocks. Scale arrays use `[token_pos * num_kv_heads]`.
97#[derive(Debug)]
98pub struct QuantizedKvLayer {
99 /// Quantized key data: `[capacity * num_kv_heads * head_dim]` as i8.
100 keys_i8: Vec<i8>,
101 /// Per-row key scales: `[capacity * num_kv_heads]` as f32.
102 key_scales: Vec<f32>,
103 /// Quantized value data: `[capacity * num_kv_heads * head_dim]` as i8.
104 values_i8: Vec<i8>,
105 /// Per-row value scales: `[capacity * num_kv_heads]` as f32.
106 value_scales: Vec<f32>,
107 /// Number of KV attention heads.
108 pub num_kv_heads: usize,
109 /// Dimension of each attention head.
110 pub head_dim: usize,
111 /// Maximum number of token positions pre-allocated.
112 pub capacity: usize,
113 /// Number of token positions actually stored so far.
114 pub len: usize,
115}
116
117impl QuantizedKvLayer {
118 /// Allocate an empty quantized KV layer with the given dimensions.
119 ///
120 /// Pre-allocates all storage so that subsequent [`push`](Self::push) calls
121 /// do not allocate.
122 pub fn new(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
123 let data_len = capacity * num_kv_heads * head_dim;
124 let scale_len = capacity * num_kv_heads;
125
126 Self {
127 keys_i8: vec![0i8; data_len],
128 key_scales: vec![0.0_f32; scale_len],
129 values_i8: vec![0i8; data_len],
130 value_scales: vec![0.0_f32; scale_len],
131 num_kv_heads,
132 head_dim,
133 capacity,
134 len: 0,
135 }
136 }
137
138 /// Append keys and values for the next token position.
139 ///
140 /// `keys` must be a flat slice of shape `[num_kv_heads * head_dim]` (heads
141 /// first, then dims). `values` must have the same shape.
142 ///
143 /// Each head's row is quantized independently with its own scale.
144 ///
145 /// # Errors
146 /// - [`QuantKvError::CapacityExceeded`] if `self.len == self.capacity`.
147 /// - [`QuantKvError::ShapeMismatch`] if `keys` or `values` length is wrong.
148 pub fn push(&mut self, keys: &[f32], values: &[f32]) -> Result<(), QuantKvError> {
149 let expected = self.num_kv_heads * self.head_dim;
150
151 if keys.len() != expected {
152 return Err(QuantKvError::ShapeMismatch {
153 expected,
154 actual: keys.len(),
155 });
156 }
157 if values.len() != expected {
158 return Err(QuantKvError::ShapeMismatch {
159 expected,
160 actual: values.len(),
161 });
162 }
163 if self.len >= self.capacity {
164 return Err(QuantKvError::CapacityExceeded {
165 capacity: self.capacity,
166 pos: self.len,
167 });
168 }
169
170 let token_pos = self.len;
171
172 for head in 0..self.num_kv_heads {
173 let row_start = head * self.head_dim;
174 let row_end = row_start + self.head_dim;
175
176 // Compute offsets before any mutable borrows to satisfy the borrow checker.
177 let data_off = self.data_offset(token_pos, head);
178 let scale_off = self.scale_offset(token_pos, head);
179
180 // Keys
181 let key_row = &keys[row_start..row_end];
182 let (kq, ks) = quantize_row_i8(key_row);
183 self.keys_i8[data_off..data_off + self.head_dim].copy_from_slice(&kq);
184 self.key_scales[scale_off] = ks;
185
186 // Values
187 let val_row = &values[row_start..row_end];
188 let (vq, vs) = quantize_row_i8(val_row);
189 self.values_i8[data_off..data_off + self.head_dim].copy_from_slice(&vq);
190 self.value_scales[scale_off] = vs;
191 }
192
193 self.len += 1;
194 Ok(())
195 }
196
197 /// Get dequantized keys for a specific token position and head.
198 ///
199 /// Returns a `Vec<f32>` of length `head_dim`.
200 ///
201 /// # Errors
202 /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
203 /// - [`QuantKvError::HeadOutOfRange`] if `head >= self.num_kv_heads`.
204 pub fn get_key(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
205 self.validate_pos_head(token_pos, head)?;
206 let data_off = self.data_offset(token_pos, head);
207 let scale = self.key_scales[self.scale_offset(token_pos, head)];
208 Ok(dequantize_row_i8(
209 &self.keys_i8[data_off..data_off + self.head_dim],
210 scale,
211 ))
212 }
213
214 /// Get dequantized values for a specific token position and head.
215 ///
216 /// Returns a `Vec<f32>` of length `head_dim`.
217 ///
218 /// # Errors
219 /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
220 /// - [`QuantKvError::HeadOutOfRange`] if `head >= self.num_kv_heads`.
221 pub fn get_value(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
222 self.validate_pos_head(token_pos, head)?;
223 let data_off = self.data_offset(token_pos, head);
224 let scale = self.value_scales[self.scale_offset(token_pos, head)];
225 Ok(dequantize_row_i8(
226 &self.values_i8[data_off..data_off + self.head_dim],
227 scale,
228 ))
229 }
230
231 /// Get all dequantized keys for a token position (all heads, interleaved).
232 ///
233 /// Returns a flat `Vec<f32>` of length `num_kv_heads * head_dim`.
234 ///
235 /// # Errors
236 /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
237 pub fn get_keys_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
238 if token_pos >= self.len {
239 return Err(QuantKvError::PositionOutOfRange(token_pos));
240 }
241 let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
242 for head in 0..self.num_kv_heads {
243 let data_off = self.data_offset(token_pos, head);
244 let scale = self.key_scales[self.scale_offset(token_pos, head)];
245 out.extend(dequantize_row_i8(
246 &self.keys_i8[data_off..data_off + self.head_dim],
247 scale,
248 ));
249 }
250 Ok(out)
251 }
252
253 /// Get all dequantized values for a token position (all heads, interleaved).
254 ///
255 /// Returns a flat `Vec<f32>` of length `num_kv_heads * head_dim`.
256 ///
257 /// # Errors
258 /// - [`QuantKvError::PositionOutOfRange`] if `token_pos >= self.len`.
259 pub fn get_values_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
260 if token_pos >= self.len {
261 return Err(QuantKvError::PositionOutOfRange(token_pos));
262 }
263 let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
264 for head in 0..self.num_kv_heads {
265 let data_off = self.data_offset(token_pos, head);
266 let scale = self.value_scales[self.scale_offset(token_pos, head)];
267 out.extend(dequantize_row_i8(
268 &self.values_i8[data_off..data_off + self.head_dim],
269 scale,
270 ));
271 }
272 Ok(out)
273 }
274
275 /// Memory used by this layer in bytes (INT8 data + f32 scales).
276 ///
277 /// Only accounts for the pre-allocated storage slabs, not struct overhead.
278 pub fn memory_bytes(&self) -> usize {
279 // INT8 data: 1 byte per element
280 let data_bytes = self.keys_i8.len() + self.values_i8.len();
281 // f32 scales: 4 bytes each
282 let scale_bytes = (self.key_scales.len() + self.value_scales.len()) * 4;
283 data_bytes + scale_bytes
284 }
285
286 /// Equivalent memory if the same data were stored as FP32 (no scales).
287 ///
288 /// `2 * capacity * num_kv_heads * head_dim * 4 bytes`
289 pub fn fp32_memory_bytes(&self) -> usize {
290 // Keys + values, each element 4 bytes
291 2 * self.capacity * self.num_kv_heads * self.head_dim * 4
292 }
293
294 /// Compression ratio versus FP32 storage.
295 ///
296 /// Values approaching 4.0 indicate near-ideal INT8 compression. The ratio
297 /// is slightly below 4.0 because per-row f32 scales add overhead.
298 pub fn compression_ratio(&self) -> f32 {
299 let quant = self.memory_bytes();
300 if quant == 0 {
301 return 1.0;
302 }
303 self.fp32_memory_bytes() as f32 / quant as f32
304 }
305
306 // ── Internal helpers ──────────────────────────────────────────────────────
307
308 /// Flat index into the INT8 data arrays for `(token_pos, head, 0)`.
309 ///
310 /// Layout: `[token_pos][head][dim]` → `(token_pos * num_kv_heads + head) * head_dim`
311 #[inline]
312 fn data_offset(&self, token_pos: usize, head: usize) -> usize {
313 (token_pos * self.num_kv_heads + head) * self.head_dim
314 }
315
316 /// Flat index into the scale arrays for `(token_pos, head)`.
317 ///
318 /// Layout: `[token_pos][head]` → `token_pos * num_kv_heads + head`
319 #[inline]
320 fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
321 token_pos * self.num_kv_heads + head
322 }
323
324 /// Validate that `token_pos < self.len` and `head < self.num_kv_heads`.
325 fn validate_pos_head(&self, token_pos: usize, head: usize) -> Result<(), QuantKvError> {
326 if token_pos >= self.len {
327 return Err(QuantKvError::PositionOutOfRange(token_pos));
328 }
329 if head >= self.num_kv_heads {
330 return Err(QuantKvError::HeadOutOfRange {
331 head,
332 num_heads: self.num_kv_heads,
333 });
334 }
335 Ok(())
336 }
337}
338
339// ─── Multi-layer quantized KV cache ──────────────────────────────────────────
340
341/// Full multi-layer INT8-quantized KV cache for autoregressive decoding.
342///
343/// Wraps one [`QuantizedKvLayer`] per transformer layer and exposes a
344/// unified decode-step interface through [`push_step`](Self::push_step).
345#[derive(Debug)]
346pub struct QuantizedKvCache {
347 layers: Vec<QuantizedKvLayer>,
348 /// Number of transformer layers.
349 pub num_layers: usize,
350 /// Number of KV attention heads per layer.
351 pub num_kv_heads: usize,
352 /// Dimension of each attention head.
353 pub head_dim: usize,
354}
355
356impl QuantizedKvCache {
357 /// Allocate a new quantized KV cache for `num_layers` transformer layers.
358 ///
359 /// Each layer is pre-allocated for `capacity` token positions.
360 pub fn new(num_layers: usize, capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
361 let layers = (0..num_layers)
362 .map(|_| QuantizedKvLayer::new(capacity, num_kv_heads, head_dim))
363 .collect();
364
365 Self {
366 layers,
367 num_layers,
368 num_kv_heads,
369 head_dim,
370 }
371 }
372
373 /// Append KV tensors for all layers at the current decode step.
374 ///
375 /// `all_keys[layer]` must be a flat slice of shape `[num_kv_heads * head_dim]`.
376 /// `all_values[layer]` must have the same shape.
377 ///
378 /// # Errors
379 /// - [`QuantKvError::LayerOutOfRange`] if `all_keys.len() != self.num_layers`.
380 /// - Propagates [`QuantKvError`] from each layer's [`push`](QuantizedKvLayer::push).
381 pub fn push_step(
382 &mut self,
383 all_keys: &[Vec<f32>],
384 all_values: &[Vec<f32>],
385 ) -> Result<(), QuantKvError> {
386 if all_keys.len() != self.num_layers {
387 return Err(QuantKvError::LayerOutOfRange {
388 layer: all_keys.len(),
389 num_layers: self.num_layers,
390 });
391 }
392 if all_values.len() != self.num_layers {
393 return Err(QuantKvError::LayerOutOfRange {
394 layer: all_values.len(),
395 num_layers: self.num_layers,
396 });
397 }
398
399 for (layer_idx, (layer, (keys, values))) in self
400 .layers
401 .iter_mut()
402 .zip(all_keys.iter().zip(all_values.iter()))
403 .enumerate()
404 {
405 layer.push(keys, values).map_err(|e| match e {
406 // Re-attach layer context to capacity errors
407 QuantKvError::CapacityExceeded { capacity, pos } => {
408 QuantKvError::CapacityExceeded { capacity, pos }
409 }
410 QuantKvError::ShapeMismatch { expected, actual } => {
411 QuantKvError::ShapeMismatch { expected, actual }
412 }
413 // Pass through other errors; we could enrich them with layer_idx
414 // but the error types don't carry that field — keep as is.
415 other => {
416 let _ = layer_idx;
417 other
418 }
419 })?;
420 }
421 Ok(())
422 }
423
424 /// Get dequantized keys for a specific layer, token position, and head.
425 ///
426 /// # Errors
427 /// - [`QuantKvError::LayerOutOfRange`] if `layer >= self.num_layers`.
428 /// - Propagates position/head errors from the underlying layer.
429 pub fn get_key(
430 &self,
431 layer: usize,
432 token_pos: usize,
433 head: usize,
434 ) -> Result<Vec<f32>, QuantKvError> {
435 self.validate_layer(layer)?;
436 self.layers[layer].get_key(token_pos, head)
437 }
438
439 /// Get dequantized values for a specific layer, token position, and head.
440 ///
441 /// # Errors
442 /// - [`QuantKvError::LayerOutOfRange`] if `layer >= self.num_layers`.
443 /// - Propagates position/head errors from the underlying layer.
444 pub fn get_value(
445 &self,
446 layer: usize,
447 token_pos: usize,
448 head: usize,
449 ) -> Result<Vec<f32>, QuantKvError> {
450 self.validate_layer(layer)?;
451 self.layers[layer].get_value(token_pos, head)
452 }
453
454 /// Total memory used across all layers in bytes.
455 pub fn total_memory_bytes(&self) -> usize {
456 self.layers.iter().map(|l| l.memory_bytes()).sum()
457 }
458
459 /// FP32-equivalent memory across all layers.
460 pub fn total_fp32_memory_bytes(&self) -> usize {
461 self.layers.iter().map(|l| l.fp32_memory_bytes()).sum()
462 }
463
464 /// Overall compression ratio vs FP32.
465 pub fn compression_ratio(&self) -> f32 {
466 let quant = self.total_memory_bytes();
467 if quant == 0 {
468 return 1.0;
469 }
470 self.total_fp32_memory_bytes() as f32 / quant as f32
471 }
472
473 /// Number of token positions currently stored (taken from layer 0).
474 ///
475 /// Returns `0` if there are no layers.
476 pub fn seq_len(&self) -> usize {
477 self.layers.first().map(|l| l.len).unwrap_or(0)
478 }
479
480 // ── Internal helpers ──────────────────────────────────────────────────────
481
482 fn validate_layer(&self, layer: usize) -> Result<(), QuantKvError> {
483 if layer >= self.num_layers {
484 return Err(QuantKvError::LayerOutOfRange {
485 layer,
486 num_layers: self.num_layers,
487 });
488 }
489 Ok(())
490 }
491}
492
493// ─── FP8 KV cache ─────────────────────────────────────────────────────────────
494
/// FP8 encoding format variant for KV cache quantization.
///
/// The format selects both the byte encoder/decoder and the `FP8_MAX` constant
/// used as the per-row scaling target (`scale = max(|row|) / FP8_MAX`). Every
/// row is rescaled so that its peak maps to `FP8_MAX`, so the two formats
/// differ mainly in precision and in dynamic range *within* a row.
///
/// - `E4M3` uses 4-bit exponent, 3-bit mantissa (max representable ≈ 448.0).
///   Finer mantissa; better accuracy for typical attention activations with
///   a bounded range.
/// - `E5M2` uses 5-bit exponent, 2-bit mantissa (max representable ≈ 57344.0).
///   Wider dynamic range, useful for outlier-heavy distributions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Fp8KvFormat {
    /// E4M3FN format: 4-bit exponent, 3-bit mantissa, bias=7.
    /// Max representable value: 448.0. No infinities; NaN = 0x7f/0xff.
    E4M3,
    /// E5M2 format: 5-bit exponent, 2-bit mantissa, bias=15.
    /// Max representable value: 57344.0. Supports infinities; NaN = 0x7e.
    /// NOTE(review): E5M2 has several NaN encodings (0x7d–0x7f, 0xfd–0xff);
    /// 0x7e is presumably the one the core encoder emits — confirm in
    /// `oxibonsai_core::quant_fp8`.
    E5M2,
}
510
511/// Quantize a row of f32 values to FP8 using per-row absolute-max scaling.
512///
513/// Returns `(quantized_bytes: Vec<u8>, scale: f32)` where
514/// `scale = max(|row|) / FP8_MAX`. One scale per head-row is stored; all
515/// values are encoded relative to that scale.
516///
517/// For an all-zero row the scale is clamped to [`f32::EPSILON`] and all output
518/// bytes are `0x00`.
519fn quantize_row_fp8(row: &[f32], format: Fp8KvFormat) -> (Vec<u8>, f32) {
520 if row.is_empty() {
521 return (Vec::new(), f32::EPSILON);
522 }
523
524 let max_abs = row.iter().map(|x| x.abs()).fold(0.0_f32, f32::max);
525
526 let fp8_max = match format {
527 Fp8KvFormat::E4M3 => FP8_E4M3_MAX,
528 Fp8KvFormat::E5M2 => FP8_E5M2_MAX,
529 };
530
531 // Clamp scale to at least EPSILON to avoid division by zero for all-zero rows.
532 let scale = (max_abs / fp8_max).max(f32::EPSILON);
533
534 let quantized = match format {
535 Fp8KvFormat::E4M3 => row.iter().map(|&x| fp8_e4m3_encode(x / scale)).collect(),
536 Fp8KvFormat::E5M2 => row.iter().map(|&x| fp8_e5m2_encode(x / scale)).collect(),
537 };
538
539 (quantized, scale)
540}
541
542/// Dequantize FP8 bytes back to f32 using the stored row scale.
543///
544/// Each element is decoded from FP8 then multiplied by `scale`.
545fn dequantize_row_fp8(quantized: &[u8], scale: f32, format: Fp8KvFormat) -> Vec<f32> {
546 match format {
547 Fp8KvFormat::E4M3 => quantized
548 .iter()
549 .map(|&b| fp8_e4m3_decode(b) * scale)
550 .collect(),
551 Fp8KvFormat::E5M2 => quantized
552 .iter()
553 .map(|&b| fp8_e5m2_decode(b) * scale)
554 .collect(),
555 }
556}
557
558/// A single transformer layer's FP8-quantized KV cache.
559///
560/// Memory layout is token-major: `[token_pos][head][dim]` for data and
561/// `[token_pos][head]` for scales. Append-only; `clear` resets `len` to 0
562/// without reallocating.
563///
564/// Per-row scaling: one `f32` scale per `(token_pos, head)` pair, computed as
565/// `scale = max(|row|) / FP8_MAX`. This mirrors the INT8 implementation but
566/// uses FP8 byte encodings rather than i8.
567#[derive(Debug)]
568pub struct Fp8KvLayer {
569 /// FP8-encoded key data: `[capacity * num_kv_heads * head_dim]` as u8.
570 keys_fp8: Vec<u8>,
571 /// Per-head-row key scales: `[capacity * num_kv_heads]` as f32.
572 key_scales: Vec<f32>,
573 /// FP8-encoded value data: `[capacity * num_kv_heads * head_dim]` as u8.
574 values_fp8: Vec<u8>,
575 /// Per-head-row value scales: `[capacity * num_kv_heads]` as f32.
576 value_scales: Vec<f32>,
577 /// Number of KV attention heads per token position.
578 pub num_kv_heads: usize,
579 /// Dimension of each attention head.
580 pub head_dim: usize,
581 /// Maximum token positions pre-allocated.
582 pub capacity: usize,
583 /// Token positions actually stored.
584 pub len: usize,
585 /// FP8 encoding format (E4M3 or E5M2).
586 pub format: Fp8KvFormat,
587}
588
589impl Fp8KvLayer {
590 /// Allocate an FP8 KV layer for `num_kv_heads` heads of dimension `head_dim`,
591 /// holding up to `capacity` token positions in the given `format`.
592 ///
593 /// All storage is pre-allocated so subsequent [`push`](Self::push) calls
594 /// perform no heap allocation.
595 pub fn with_capacity(
596 num_kv_heads: usize,
597 head_dim: usize,
598 capacity: usize,
599 format: Fp8KvFormat,
600 ) -> Self {
601 let data_len = capacity * num_kv_heads * head_dim;
602 let scale_len = capacity * num_kv_heads;
603 Self {
604 keys_fp8: vec![0u8; data_len],
605 key_scales: vec![0.0_f32; scale_len],
606 values_fp8: vec![0u8; data_len],
607 value_scales: vec![0.0_f32; scale_len],
608 num_kv_heads,
609 head_dim,
610 capacity,
611 len: 0,
612 format,
613 }
614 }
615
616 /// Append FP8-quantized keys and values for the next token position.
617 ///
618 /// `key` and `value` must each be a flat slice of length
619 /// `num_kv_heads * head_dim` (heads first, then dims within each head).
620 /// Each head-row is quantized independently with its own scale.
621 ///
622 /// # Errors
623 /// - [`QuantKvError::CapacityExceeded`] if `self.len == self.capacity`.
624 /// - [`QuantKvError::ShapeMismatch`] if `key` or `value` length is wrong.
625 pub fn push(&mut self, key: &[f32], value: &[f32]) -> Result<(), QuantKvError> {
626 let expected = self.num_kv_heads * self.head_dim;
627
628 if key.len() != expected {
629 return Err(QuantKvError::ShapeMismatch {
630 expected,
631 actual: key.len(),
632 });
633 }
634 if value.len() != expected {
635 return Err(QuantKvError::ShapeMismatch {
636 expected,
637 actual: value.len(),
638 });
639 }
640 if self.len >= self.capacity {
641 return Err(QuantKvError::CapacityExceeded {
642 capacity: self.capacity,
643 pos: self.len,
644 });
645 }
646
647 let token_pos = self.len;
648 let format = self.format;
649
650 for head in 0..self.num_kv_heads {
651 let row_start = head * self.head_dim;
652 let row_end = row_start + self.head_dim;
653
654 let data_off = self.data_offset(token_pos, head);
655 let scale_off = self.scale_offset(token_pos, head);
656
657 // Keys
658 let key_row = &key[row_start..row_end];
659 let (kq, ks) = quantize_row_fp8(key_row, format);
660 self.keys_fp8[data_off..data_off + self.head_dim].copy_from_slice(&kq);
661 self.key_scales[scale_off] = ks;
662
663 // Values
664 let val_row = &value[row_start..row_end];
665 let (vq, vs) = quantize_row_fp8(val_row, format);
666 self.values_fp8[data_off..data_off + self.head_dim].copy_from_slice(&vq);
667 self.value_scales[scale_off] = vs;
668 }
669
670 self.len += 1;
671 Ok(())
672 }
673
674 /// Dequantize and return all keys for a token position as a flat
675 /// `Vec<f32>` of length `num_kv_heads * head_dim`.
676 ///
677 /// Layout: `[head_0_dims..., head_1_dims..., ...]`
678 ///
679 /// # Panics
680 /// Panics if `pos >= self.len` (index out of bounds on the pre-allocated slab).
681 pub fn get_key(&self, pos: usize) -> Vec<f32> {
682 let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
683 for head in 0..self.num_kv_heads {
684 let data_off = self.data_offset(pos, head);
685 let scale = self.key_scales[self.scale_offset(pos, head)];
686 out.extend(dequantize_row_fp8(
687 &self.keys_fp8[data_off..data_off + self.head_dim],
688 scale,
689 self.format,
690 ));
691 }
692 out
693 }
694
695 /// Dequantize and return all values for a token position as a flat
696 /// `Vec<f32>` of length `num_kv_heads * head_dim`.
697 ///
698 /// # Panics
699 /// Panics if `pos >= self.len`.
700 pub fn get_value(&self, pos: usize) -> Vec<f32> {
701 let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
702 for head in 0..self.num_kv_heads {
703 let data_off = self.data_offset(pos, head);
704 let scale = self.value_scales[self.scale_offset(pos, head)];
705 out.extend(dequantize_row_fp8(
706 &self.values_fp8[data_off..data_off + self.head_dim],
707 scale,
708 self.format,
709 ));
710 }
711 out
712 }
713
714 /// Dequantize keys for a subset of token positions.
715 ///
716 /// Returns a `Vec` of flat key vectors, one per position in `positions`.
717 /// Positions must be < `self.len`; out-of-range positions will panic
718 /// (index-out-of-bounds on the pre-allocated slab).
719 pub fn get_keys_at(&self, positions: &[usize]) -> Vec<Vec<f32>> {
720 positions.iter().map(|&pos| self.get_key(pos)).collect()
721 }
722
723 /// Dequantize values for a subset of token positions.
724 ///
725 /// Returns a `Vec` of flat value vectors, one per position in `positions`.
726 pub fn get_values_at(&self, positions: &[usize]) -> Vec<Vec<f32>> {
727 positions.iter().map(|&pos| self.get_value(pos)).collect()
728 }
729
730 /// Number of token positions currently stored.
731 #[inline]
732 pub fn len(&self) -> usize {
733 self.len
734 }
735
736 /// Returns `true` if no token positions have been stored yet.
737 #[inline]
738 pub fn is_empty(&self) -> bool {
739 self.len == 0
740 }
741
742 /// Maximum token positions this layer can hold.
743 #[inline]
744 pub fn capacity(&self) -> usize {
745 self.capacity
746 }
747
748 /// Bytes occupied by FP8 data and f32 scales for this layer.
749 ///
750 /// `keys_fp8 + values_fp8` (1 byte/element) + `key_scales + value_scales`
751 /// (4 bytes/element).
752 pub fn memory_bytes(&self) -> usize {
753 let data_bytes = self.keys_fp8.len() + self.values_fp8.len();
754 let scale_bytes = (self.key_scales.len() + self.value_scales.len()) * 4;
755 data_bytes + scale_bytes
756 }
757
758 /// Equivalent memory if the same data were stored as FP32 with no scales.
759 ///
760 /// `2 * capacity * num_kv_heads * head_dim * 4`
761 pub fn memory_bytes_fp32_equivalent(&self) -> usize {
762 2 * self.capacity * self.num_kv_heads * self.head_dim * 4
763 }
764
765 /// Reset stored length to zero, making the layer appear empty.
766 ///
767 /// Does not free or zero memory; existing bytes are overwritten on the next
768 /// series of [`push`](Self::push) calls.
769 pub fn clear(&mut self) {
770 self.len = 0;
771 }
772
773 // ── Internal helpers ──────────────────────────────────────────────────────
774
775 /// Flat index into the FP8 data arrays for `(token_pos, head, 0)`.
776 #[inline]
777 fn data_offset(&self, token_pos: usize, head: usize) -> usize {
778 (token_pos * self.num_kv_heads + head) * self.head_dim
779 }
780
781 /// Flat index into the scale arrays for `(token_pos, head)`.
782 #[inline]
783 fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
784 token_pos * self.num_kv_heads + head
785 }
786}
787
788// ─── Multi-layer FP8 KV cache ─────────────────────────────────────────────────
789
790/// Full multi-layer FP8-quantized KV cache for autoregressive decoding.
791///
792/// Wraps one [`Fp8KvLayer`] per transformer layer and exposes per-layer
793/// mutable and immutable accessors. All layers share the same `format`,
794/// `num_kv_heads`, `head_dim`, and `capacity`.
795#[derive(Debug)]
796pub struct Fp8KvCache {
797 /// Per-transformer-layer FP8 KV stores.
798 pub layers: Vec<Fp8KvLayer>,
799}
800
801impl Fp8KvCache {
802 /// Allocate a new FP8 KV cache for `num_layers` transformer layers.
803 ///
804 /// Each layer is pre-allocated for `capacity` token positions.
805 pub fn new(
806 num_layers: usize,
807 num_kv_heads: usize,
808 head_dim: usize,
809 capacity: usize,
810 format: Fp8KvFormat,
811 ) -> Self {
812 let layers = (0..num_layers)
813 .map(|_| Fp8KvLayer::with_capacity(num_kv_heads, head_dim, capacity, format))
814 .collect();
815 Self { layers }
816 }
817
818 /// Immutable reference to a specific layer.
819 ///
820 /// # Panics
821 /// Panics if `layer_idx >= self.num_layers()`.
822 pub fn layer(&self, layer_idx: usize) -> &Fp8KvLayer {
823 &self.layers[layer_idx]
824 }
825
826 /// Mutable reference to a specific layer.
827 ///
828 /// # Panics
829 /// Panics if `layer_idx >= self.num_layers()`.
830 pub fn layer_mut(&mut self, layer_idx: usize) -> &mut Fp8KvLayer {
831 &mut self.layers[layer_idx]
832 }
833
834 /// Number of transformer layers in this cache.
835 pub fn num_layers(&self) -> usize {
836 self.layers.len()
837 }
838
839 /// Total memory used across all layers in bytes.
840 pub fn total_memory_bytes(&self) -> usize {
841 self.layers.iter().map(|l| l.memory_bytes()).sum()
842 }
843
844 /// Clear all layers, resetting stored lengths to zero.
845 pub fn clear_all(&mut self) {
846 for layer in &mut self.layers {
847 layer.clear();
848 }
849 }
850}