use oxibonsai_core::quant_fp8::{
fp8_e4m3_decode, fp8_e4m3_encode, fp8_e5m2_decode, fp8_e5m2_encode, FP8_E4M3_MAX, FP8_E5M2_MAX,
};
/// Errors returned by the quantized KV-cache types in this module.
#[derive(Debug, thiserror::Error)]
pub enum QuantKvError {
    /// A push was attempted on a layer that already holds `capacity` tokens;
    /// `pos` is the position the rejected token would have occupied.
    #[error("capacity exceeded: capacity {capacity}, tried to push token {pos}")]
    CapacityExceeded { capacity: usize, pos: usize },
    /// A read referenced a token position at or beyond the number of tokens
    /// currently stored.
    #[error("token position {0} out of range")]
    PositionOutOfRange(usize),
    /// A read referenced a head index `>= num_kv_heads`.
    #[error("head index {head} out of range (num_kv_heads = {num_heads})")]
    HeadOutOfRange { head: usize, num_heads: usize },
    /// A layer index was `>= num_layers`.
    ///
    /// `QuantizedKvCache::push_step` also reuses this variant when the number of
    /// per-layer inputs differs from `num_layers`; in that case `layer` holds the
    /// offending input count rather than an index.
    #[error("layer {layer} out of range (num_layers = {num_layers})")]
    LayerOutOfRange { layer: usize, num_layers: usize },
    /// A key or value slice had the wrong number of elements (pushes expect
    /// `num_kv_heads * head_dim`).
    #[error("key/value shape mismatch: expected {expected}, got {actual}")]
    ShapeMismatch { expected: usize, actual: usize },
}
/// Symmetric absmax int8 quantization of one row.
///
/// The scale is chosen so that the largest-magnitude element maps to ±127
/// (floored at `f32::EPSILON` so it is never zero). Each element is divided by
/// the scale, rounded half away from zero and clamped to `[-127, 127]`.
///
/// Returns `(codes, scale)`; reconstruct with `code as f32 * scale`. An empty
/// row yields `(vec![], f32::EPSILON)`.
pub fn quantize_row_i8(row: &[f32]) -> (Vec<i8>, f32) {
    // For an empty row the peak stays 0.0, so the scale falls back to EPSILON
    // and no codes are produced. No special case is needed.
    let peak = row.iter().fold(0.0_f32, |acc, v| acc.max(v.abs()));
    let step = (peak / 127.0_f32).max(f32::EPSILON);
    let codes: Vec<i8> = row
        .iter()
        .map(|&v| {
            let level = (v / step).round().clamp(-127.0, 127.0);
            level as i8
        })
        .collect();
    (codes, step)
}
/// Reconstructs floats from int8 codes produced by [`quantize_row_i8`]:
/// each output element is `code * scale`.
pub fn dequantize_row_i8(quantized: &[i8], scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    restored.extend(quantized.iter().map(|&code| f32::from(code) * scale));
    restored
}
/// Mean absolute error between `original` and its int8 reconstruction
/// (`quantized[i] * scale`).
///
/// Only the overlapping prefix (`min` of the two lengths) is compared. Returns
/// `0.0` when there is nothing to compare.
pub fn quant_error_mae(original: &[f32], quantized: &[i8], scale: f32) -> f32 {
    let pairs = original.len().min(quantized.len());
    if pairs == 0 {
        return 0.0;
    }
    let mut total_abs_err = 0.0_f32;
    for (&orig, &code) in original.iter().zip(quantized) {
        total_abs_err += (orig - f32::from(code) * scale).abs();
    }
    total_abs_err / pairs as f32
}
/// One layer of a key/value cache whose entries are stored as int8 codes with
/// one `f32` scale per (token, head) row.
///
/// Storage is preallocated for `capacity` tokens. Tokens are appended with
/// `push` and read back (dequantized) by position.
#[derive(Debug)]
pub struct QuantizedKvLayer {
    /// Key codes, token-major: the row for (token, head) starts at
    /// `(token * num_kv_heads + head) * head_dim` and is `head_dim` long.
    keys_i8: Vec<i8>,
    /// Per-row key scales, indexed by `token * num_kv_heads + head`.
    key_scales: Vec<f32>,
    /// Value codes, same layout as `keys_i8`.
    values_i8: Vec<i8>,
    /// Per-row value scales, same layout as `key_scales`.
    value_scales: Vec<f32>,
    /// Number of key/value heads stored per token.
    pub num_kv_heads: usize,
    /// Number of elements in each head row.
    pub head_dim: usize,
    /// Maximum number of tokens the buffers were sized for.
    pub capacity: usize,
    /// Number of tokens pushed so far.
    /// NOTE(review): this is a public field, so writing it directly bypasses
    /// the checks performed by `push`.
    pub len: usize,
}
impl QuantizedKvLayer {
    /// Creates an empty layer with zero-filled storage for `capacity` tokens.
    ///
    /// Each token holds `num_kv_heads` key rows and `num_kv_heads` value rows of
    /// `head_dim` int8 codes, plus one `f32` scale per (token, head) row.
    pub fn new(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let rows = capacity * num_kv_heads;
        let codes = rows * head_dim;
        Self {
            keys_i8: vec![0i8; codes],
            key_scales: vec![0.0_f32; rows],
            values_i8: vec![0i8; codes],
            value_scales: vec![0.0_f32; rows],
            num_kv_heads,
            head_dim,
            capacity,
            len: 0,
        }
    }

    /// Quantizes and appends one token's keys and values.
    ///
    /// `keys` and `values` must each hold `num_kv_heads * head_dim` floats,
    /// with head `h` occupying `[h * head_dim, (h + 1) * head_dim)`. Every head
    /// row is quantized independently with its own absmax scale.
    ///
    /// # Errors
    /// - [`QuantKvError::ShapeMismatch`] if `keys` (checked first) or `values`
    ///   has the wrong length.
    /// - [`QuantKvError::CapacityExceeded`] if `capacity` tokens are already
    ///   stored.
    pub fn push(&mut self, keys: &[f32], values: &[f32]) -> Result<(), QuantKvError> {
        let expected = self.num_kv_heads * self.head_dim;
        for input in [keys, values] {
            if input.len() != expected {
                return Err(QuantKvError::ShapeMismatch {
                    expected,
                    actual: input.len(),
                });
            }
        }
        if self.len >= self.capacity {
            return Err(QuantKvError::CapacityExceeded {
                capacity: self.capacity,
                pos: self.len,
            });
        }

        let pos = self.len;
        let width = self.head_dim;
        for head in 0..self.num_kv_heads {
            let src = head * width..(head + 1) * width;
            let dst = self.data_offset(pos, head);
            let slot = self.scale_offset(pos, head);

            let (k_codes, k_scale) = quantize_row_i8(&keys[src.clone()]);
            self.keys_i8[dst..dst + width].copy_from_slice(&k_codes);
            self.key_scales[slot] = k_scale;

            let (v_codes, v_scale) = quantize_row_i8(&values[src]);
            self.values_i8[dst..dst + width].copy_from_slice(&v_codes);
            self.value_scales[slot] = v_scale;
        }
        self.len = pos + 1;
        Ok(())
    }

    /// Dequantized key row (`head_dim` floats) for one head of one token.
    ///
    /// # Errors
    /// [`QuantKvError::PositionOutOfRange`] if `token_pos >= len`, otherwise
    /// [`QuantKvError::HeadOutOfRange`] if `head >= num_kv_heads`.
    pub fn get_key(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
        self.validate_pos_head(token_pos, head)?;
        Ok(self.read_row(&self.keys_i8, &self.key_scales, token_pos, head))
    }

    /// Dequantized value row (`head_dim` floats) for one head of one token.
    ///
    /// # Errors
    /// Same as [`Self::get_key`].
    pub fn get_value(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
        self.validate_pos_head(token_pos, head)?;
        Ok(self.read_row(&self.values_i8, &self.value_scales, token_pos, head))
    }

    /// All heads' dequantized keys for `token_pos`, concatenated head by head
    /// (`num_kv_heads * head_dim` floats).
    ///
    /// # Errors
    /// [`QuantKvError::PositionOutOfRange`] if `token_pos >= len`.
    pub fn get_keys_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
        self.gather_token(&self.keys_i8, &self.key_scales, token_pos)
    }

    /// All heads' dequantized values for `token_pos`, concatenated head by head.
    ///
    /// # Errors
    /// [`QuantKvError::PositionOutOfRange`] if `token_pos >= len`.
    pub fn get_values_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
        self.gather_token(&self.values_i8, &self.value_scales, token_pos)
    }

    /// Bytes held by the quantized buffers: one byte per code plus four bytes
    /// per scale. Reflects the full preallocated capacity, not just `len`.
    pub fn memory_bytes(&self) -> usize {
        let code_bytes = self.keys_i8.len() + self.values_i8.len();
        let scale_count = self.key_scales.len() + self.value_scales.len();
        code_bytes + scale_count * std::mem::size_of::<f32>()
    }

    /// Bytes the same keys and values would occupy at full capacity as plain `f32`.
    pub fn fp32_memory_bytes(&self) -> usize {
        2 * self.capacity * self.num_kv_heads * self.head_dim * std::mem::size_of::<f32>()
    }

    /// Ratio `fp32_memory_bytes / memory_bytes`; `1.0` when nothing is allocated.
    pub fn compression_ratio(&self) -> f32 {
        match self.memory_bytes() {
            0 => 1.0,
            quant => self.fp32_memory_bytes() as f32 / quant as f32,
        }
    }

    /// Start index of the (token, head) row inside the code buffers.
    #[inline]
    fn data_offset(&self, token_pos: usize, head: usize) -> usize {
        self.scale_offset(token_pos, head) * self.head_dim
    }

    /// Index of the (token, head) entry inside the scale buffers.
    #[inline]
    fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
        token_pos * self.num_kv_heads + head
    }

    /// Dequantizes the `head_dim`-long row for (`token_pos`, `head`) from one
    /// code/scale buffer pair. Callers are responsible for bounds-checking.
    fn read_row(&self, codes: &[i8], scales: &[f32], token_pos: usize, head: usize) -> Vec<f32> {
        let start = self.data_offset(token_pos, head);
        let scale = scales[self.scale_offset(token_pos, head)];
        dequantize_row_i8(&codes[start..start + self.head_dim], scale)
    }

    /// Concatenates every head's dequantized row for `token_pos` from one
    /// code/scale buffer pair, after checking the position.
    fn gather_token(
        &self,
        codes: &[i8],
        scales: &[f32],
        token_pos: usize,
    ) -> Result<Vec<f32>, QuantKvError> {
        if token_pos >= self.len {
            return Err(QuantKvError::PositionOutOfRange(token_pos));
        }
        Ok((0..self.num_kv_heads).fold(
            Vec::with_capacity(self.num_kv_heads * self.head_dim),
            |mut acc, head| {
                acc.extend(self.read_row(codes, scales, token_pos, head));
                acc
            },
        ))
    }

    /// Checks the token position first, then the head index.
    fn validate_pos_head(&self, token_pos: usize, head: usize) -> Result<(), QuantKvError> {
        if token_pos >= self.len {
            Err(QuantKvError::PositionOutOfRange(token_pos))
        } else if head >= self.num_kv_heads {
            Err(QuantKvError::HeadOutOfRange {
                head,
                num_heads: self.num_kv_heads,
            })
        } else {
            Ok(())
        }
    }
}
/// A multi-layer int8 KV cache: one [`QuantizedKvLayer`] per layer. As
/// constructed by `new`, all layers share the same head count, head size and
/// capacity.
#[derive(Debug)]
pub struct QuantizedKvCache {
    /// Per-layer storage; `layers[i]` holds layer `i`.
    layers: Vec<QuantizedKvLayer>,
    /// Number of layers; layer indices passed to accessors are validated
    /// against this value.
    pub num_layers: usize,
    /// Heads per token in every layer.
    pub num_kv_heads: usize,
    /// Elements per head row in every layer.
    pub head_dim: usize,
}
impl QuantizedKvCache {
pub fn new(num_layers: usize, capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
let layers = (0..num_layers)
.map(|_| QuantizedKvLayer::new(capacity, num_kv_heads, head_dim))
.collect();
Self {
layers,
num_layers,
num_kv_heads,
head_dim,
}
}
pub fn push_step(
&mut self,
all_keys: &[Vec<f32>],
all_values: &[Vec<f32>],
) -> Result<(), QuantKvError> {
if all_keys.len() != self.num_layers {
return Err(QuantKvError::LayerOutOfRange {
layer: all_keys.len(),
num_layers: self.num_layers,
});
}
if all_values.len() != self.num_layers {
return Err(QuantKvError::LayerOutOfRange {
layer: all_values.len(),
num_layers: self.num_layers,
});
}
for (layer_idx, (layer, (keys, values))) in self
.layers
.iter_mut()
.zip(all_keys.iter().zip(all_values.iter()))
.enumerate()
{
layer.push(keys, values).map_err(|e| match e {
QuantKvError::CapacityExceeded { capacity, pos } => {
QuantKvError::CapacityExceeded { capacity, pos }
}
QuantKvError::ShapeMismatch { expected, actual } => {
QuantKvError::ShapeMismatch { expected, actual }
}
other => {
let _ = layer_idx;
other
}
})?;
}
Ok(())
}
pub fn get_key(
&self,
layer: usize,
token_pos: usize,
head: usize,
) -> Result<Vec<f32>, QuantKvError> {
self.validate_layer(layer)?;
self.layers[layer].get_key(token_pos, head)
}
pub fn get_value(
&self,
layer: usize,
token_pos: usize,
head: usize,
) -> Result<Vec<f32>, QuantKvError> {
self.validate_layer(layer)?;
self.layers[layer].get_value(token_pos, head)
}
pub fn total_memory_bytes(&self) -> usize {
self.layers.iter().map(|l| l.memory_bytes()).sum()
}
pub fn total_fp32_memory_bytes(&self) -> usize {
self.layers.iter().map(|l| l.fp32_memory_bytes()).sum()
}
pub fn compression_ratio(&self) -> f32 {
let quant = self.total_memory_bytes();
if quant == 0 {
return 1.0;
}
self.total_fp32_memory_bytes() as f32 / quant as f32
}
pub fn seq_len(&self) -> usize {
self.layers.first().map(|l| l.len).unwrap_or(0)
}
fn validate_layer(&self, layer: usize) -> Result<(), QuantKvError> {
if layer >= self.num_layers {
return Err(QuantKvError::LayerOutOfRange {
layer,
num_layers: self.num_layers,
});
}
Ok(())
}
}
/// Selects which 8-bit float codec an [`Fp8KvLayer`] uses for its keys and
/// values, and therefore which representable maximum sizes the per-row scale.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Fp8KvFormat {
    /// E4M3 layout: `fp8_e4m3_encode` / `fp8_e4m3_decode`, scaled to `FP8_E4M3_MAX`.
    E4M3,
    /// E5M2 layout: `fp8_e5m2_encode` / `fp8_e5m2_decode`, scaled to `FP8_E5M2_MAX`.
    E5M2,
}
/// Absmax-scaled FP8 quantization of one row.
///
/// The scale maps the row's largest magnitude onto the format's maximum
/// representable value (floored at `f32::EPSILON`). Each element is divided by
/// the scale and encoded with the format's encoder. Returns `(codes, scale)`.
/// An empty row yields `(vec![], f32::EPSILON)`.
fn quantize_row_fp8(row: &[f32], format: Fp8KvFormat) -> (Vec<u8>, f32) {
    if row.is_empty() {
        return (Vec::new(), f32::EPSILON);
    }
    let limit = match format {
        Fp8KvFormat::E4M3 => FP8_E4M3_MAX,
        Fp8KvFormat::E5M2 => FP8_E5M2_MAX,
    };
    let peak = row.iter().fold(0.0_f32, |acc, v| acc.max(v.abs()));
    let step = (peak / limit).max(f32::EPSILON);
    let encode = |v: f32| -> u8 {
        match format {
            Fp8KvFormat::E4M3 => fp8_e4m3_encode(v),
            Fp8KvFormat::E5M2 => fp8_e5m2_encode(v),
        }
    };
    let codes: Vec<u8> = row.iter().map(|&v| encode(v / step)).collect();
    (codes, step)
}
/// Decodes FP8 codes with the given format's decoder and multiplies each by
/// `scale`, inverting `quantize_row_fp8` up to rounding.
fn dequantize_row_fp8(quantized: &[u8], scale: f32, format: Fp8KvFormat) -> Vec<f32> {
    let decode = |byte: u8| -> f32 {
        match format {
            Fp8KvFormat::E4M3 => fp8_e4m3_decode(byte),
            Fp8KvFormat::E5M2 => fp8_e5m2_decode(byte),
        }
    };
    quantized.iter().map(|&byte| decode(byte) * scale).collect()
}
/// One layer of a key/value cache whose entries are stored as 8-bit float codes
/// (format chosen by `format`) with one `f32` scale per (token, head) row.
///
/// Storage is preallocated for `capacity` tokens. Tokens are appended with
/// `push` and read back (dequantized) by position.
#[derive(Debug)]
pub struct Fp8KvLayer {
    /// Key codes, token-major: the row for (token, head) starts at
    /// `(token * num_kv_heads + head) * head_dim` and is `head_dim` long.
    keys_fp8: Vec<u8>,
    /// Per-row key scales, indexed by `token * num_kv_heads + head`.
    key_scales: Vec<f32>,
    /// Value codes, same layout as `keys_fp8`.
    values_fp8: Vec<u8>,
    /// Per-row value scales, same layout as `key_scales`.
    value_scales: Vec<f32>,
    /// Number of key/value heads stored per token.
    pub num_kv_heads: usize,
    /// Number of elements in each head row.
    pub head_dim: usize,
    /// Maximum number of tokens the buffers were sized for.
    pub capacity: usize,
    /// Number of tokens pushed since construction or the last `clear`.
    /// NOTE(review): public field; writing it directly bypasses `push`'s checks.
    pub len: usize,
    /// FP8 codec used for both keys and values.
    pub format: Fp8KvFormat,
}
impl Fp8KvLayer {
    /// Allocates an empty layer with zero-filled storage for `capacity` tokens of
    /// `num_kv_heads * head_dim` keys and values, encoded as `format`.
    pub fn with_capacity(
        num_kv_heads: usize,
        head_dim: usize,
        capacity: usize,
        format: Fp8KvFormat,
    ) -> Self {
        let data_len = capacity * num_kv_heads * head_dim;
        let scale_len = capacity * num_kv_heads;
        Self {
            keys_fp8: vec![0u8; data_len],
            key_scales: vec![0.0_f32; scale_len],
            values_fp8: vec![0u8; data_len],
            value_scales: vec![0.0_f32; scale_len],
            num_kv_heads,
            head_dim,
            capacity,
            len: 0,
            format,
        }
    }

    /// Quantizes and appends one token's keys and values.
    ///
    /// `key` and `value` must each hold `num_kv_heads * head_dim` floats, with
    /// head `h` occupying `[h * head_dim, (h + 1) * head_dim)`. Each head row
    /// gets its own absmax scale.
    ///
    /// # Errors
    /// - [`QuantKvError::ShapeMismatch`] if `key` (checked first) or `value`
    ///   has the wrong length.
    /// - [`QuantKvError::CapacityExceeded`] if `capacity` tokens are stored.
    pub fn push(&mut self, key: &[f32], value: &[f32]) -> Result<(), QuantKvError> {
        let expected = self.num_kv_heads * self.head_dim;
        if key.len() != expected {
            return Err(QuantKvError::ShapeMismatch {
                expected,
                actual: key.len(),
            });
        }
        if value.len() != expected {
            return Err(QuantKvError::ShapeMismatch {
                expected,
                actual: value.len(),
            });
        }
        if self.len >= self.capacity {
            return Err(QuantKvError::CapacityExceeded {
                capacity: self.capacity,
                pos: self.len,
            });
        }
        let token_pos = self.len;
        let format = self.format;
        for head in 0..self.num_kv_heads {
            let row_start = head * self.head_dim;
            let row_end = row_start + self.head_dim;
            let data_off = self.data_offset(token_pos, head);
            let scale_off = self.scale_offset(token_pos, head);
            let (kq, ks) = quantize_row_fp8(&key[row_start..row_end], format);
            self.keys_fp8[data_off..data_off + self.head_dim].copy_from_slice(&kq);
            self.key_scales[scale_off] = ks;
            let (vq, vs) = quantize_row_fp8(&value[row_start..row_end], format);
            self.values_fp8[data_off..data_off + self.head_dim].copy_from_slice(&vq);
            self.value_scales[scale_off] = vs;
        }
        self.len += 1;
        Ok(())
    }

    /// All heads' dequantized keys for token `pos`, concatenated head by head
    /// (`num_kv_heads * head_dim` floats).
    ///
    /// # Panics
    /// Panics if `pos >= self.len()`. Slots past `len` hold either zeros or stale
    /// data from before the last [`clear`](Self::clear), so reading them is a
    /// logic error rather than a valid lookup.
    pub fn get_key(&self, pos: usize) -> Vec<f32> {
        self.dequantize_token(&self.keys_fp8, &self.key_scales, pos)
    }

    /// All heads' dequantized values for token `pos`, concatenated head by head.
    ///
    /// # Panics
    /// Panics if `pos >= self.len()` (see [`get_key`](Self::get_key)).
    pub fn get_value(&self, pos: usize) -> Vec<f32> {
        self.dequantize_token(&self.values_fp8, &self.value_scales, pos)
    }

    /// [`get_key`](Self::get_key) for each position in `positions`, in order.
    ///
    /// # Panics
    /// Panics if any position is `>= self.len()`.
    pub fn get_keys_at(&self, positions: &[usize]) -> Vec<Vec<f32>> {
        positions.iter().map(|&pos| self.get_key(pos)).collect()
    }

    /// [`get_value`](Self::get_value) for each position in `positions`, in order.
    ///
    /// # Panics
    /// Panics if any position is `>= self.len()`.
    pub fn get_values_at(&self, positions: &[usize]) -> Vec<Vec<f32>> {
        positions.iter().map(|&pos| self.get_value(pos)).collect()
    }

    /// Number of tokens currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when no tokens are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Maximum number of tokens this layer can hold.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Bytes held by the quantized buffers at full capacity: one byte per code
    /// plus one `f32` per scale.
    pub fn memory_bytes(&self) -> usize {
        let data_bytes = self.keys_fp8.len() + self.values_fp8.len();
        let scale_bytes =
            (self.key_scales.len() + self.value_scales.len()) * std::mem::size_of::<f32>();
        data_bytes + scale_bytes
    }

    /// Bytes the same keys and values would occupy at full capacity as `f32`.
    pub fn memory_bytes_fp32_equivalent(&self) -> usize {
        2 * self.capacity * self.num_kv_heads * self.head_dim * std::mem::size_of::<f32>()
    }

    /// Marks the layer empty without zeroing or freeing its buffers. Earlier
    /// tokens become unreadable through `get_key`/`get_value` and are
    /// overwritten by subsequent pushes.
    pub fn clear(&mut self) {
        self.len = 0;
    }

    /// Start index of the (token, head) row inside the code buffers.
    #[inline]
    fn data_offset(&self, token_pos: usize, head: usize) -> usize {
        (token_pos * self.num_kv_heads + head) * self.head_dim
    }

    /// Index of the (token, head) entry inside the scale buffers.
    #[inline]
    fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
        token_pos * self.num_kv_heads + head
    }

    /// Shared body of `get_key`/`get_value`: bounds-checks `pos` against `len`,
    /// then dequantizes every head's row from the given code/scale buffers.
    fn dequantize_token(&self, data: &[u8], scales: &[f32], pos: usize) -> Vec<f32> {
        assert!(
            pos < self.len,
            "token position {} out of range (len = {})",
            pos,
            self.len
        );
        let mut out = Vec::with_capacity(self.num_kv_heads * self.head_dim);
        for head in 0..self.num_kv_heads {
            let data_off = self.data_offset(pos, head);
            let scale = scales[self.scale_offset(pos, head)];
            out.extend(dequantize_row_fp8(
                &data[data_off..data_off + self.head_dim],
                scale,
                self.format,
            ));
        }
        out
    }
}
/// A multi-layer FP8 KV cache: one [`Fp8KvLayer`] per layer. As constructed
/// by `new`, all layers share the same head count, head size, capacity and
/// format.
#[derive(Debug)]
pub struct Fp8KvCache {
    /// Per-layer storage; `layers[i]` holds layer `i`. Public, so callers may
    /// access or replace layers directly.
    pub layers: Vec<Fp8KvLayer>,
}
impl Fp8KvCache {
    /// Builds `num_layers` empty FP8 layers that all share the given head count,
    /// head size, token capacity and encoding format.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        capacity: usize,
        format: Fp8KvFormat,
    ) -> Self {
        let mut layers = Vec::with_capacity(num_layers);
        layers.extend(
            std::iter::repeat_with(|| {
                Fp8KvLayer::with_capacity(num_kv_heads, head_dim, capacity, format)
            })
            .take(num_layers),
        );
        Self { layers }
    }

    /// Shared access to layer `layer_idx`.
    ///
    /// # Panics
    /// Panics if `layer_idx >= self.num_layers()`.
    pub fn layer(&self, layer_idx: usize) -> &Fp8KvLayer {
        &self.layers[layer_idx]
    }

    /// Mutable access to layer `layer_idx`.
    ///
    /// # Panics
    /// Panics if `layer_idx >= self.num_layers()`.
    pub fn layer_mut(&mut self, layer_idx: usize) -> &mut Fp8KvLayer {
        &mut self.layers[layer_idx]
    }

    /// Number of layers in the cache.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Sum of every layer's quantized storage, in bytes.
    pub fn total_memory_bytes(&self) -> usize {
        self.layers
            .iter()
            .fold(0, |total, layer| total + layer.memory_bytes())
    }

    /// Resets every layer to empty (buffers are kept for reuse).
    pub fn clear_all(&mut self) {
        self.layers.iter_mut().for_each(Fp8KvLayer::clear);
    }
}