/// Errors reported by `QuantizedKvLayer` and `QuantizedKvCache`.
#[derive(Debug, thiserror::Error)]
pub enum QuantKvError {
/// A push was attempted while the layer already held `capacity` tokens;
/// `pos` is the position the rejected token would have occupied (the
/// layer's current `len`).
#[error("capacity exceeded: capacity {capacity}, tried to push token {pos}")]
CapacityExceeded { capacity: usize, pos: usize },
/// A read referenced a token position `>= len`, i.e. one that has not been
/// pushed yet.
#[error("token position {0} out of range")]
PositionOutOfRange(usize),
/// A read referenced a head index `>= num_kv_heads`.
#[error("head index {head} out of range (num_kv_heads = {num_heads})")]
HeadOutOfRange { head: usize, num_heads: usize },
/// A layer index was `>= num_layers`.
///
/// NOTE(review): `QuantizedKvCache::push_step` also returns this variant
/// when the number of per-layer key/value slices differs from
/// `num_layers`; in that case `layer` holds the supplied slice count, not a
/// layer index.
#[error("layer {layer} out of range (num_layers = {num_layers})")]
LayerOutOfRange { layer: usize, num_layers: usize },
/// A key or value slice passed to `QuantizedKvLayer::push` did not have
/// the expected flattened length `num_kv_heads * head_dim`.
#[error("key/value shape mismatch: expected {expected}, got {actual}")]
ShapeMismatch { expected: usize, actual: usize },
}
/// Symmetrically quantizes one row of `f32` values to `i8` codes sharing a single scale.
///
/// The scale maps the row's largest magnitude to ±127. Each element is stored as
/// `round(x / scale)` clamped to `[-127, 127]`, so `-128` is never produced.
/// Dequantize with `code as f32 * scale`.
///
/// Returns `(codes, scale)` with `codes.len() == row.len()`.
///
/// Edge cases:
/// * Empty, all-zero or all-NaN rows return all-zero codes and `scale == f32::EPSILON`,
///   so the returned scale is always strictly positive.
/// * Rows with a very small but non-zero maximum still use the full ±127 code range.
///   The scale is only floored at `f32::MIN_POSITIVE` so it stays a normal, non-zero
///   float. Flooring at `f32::EPSILON` would collapse rows whose max |x| is below
///   `127 * EPSILON` (≈1.5e-5) onto just a handful of codes.
/// * NaN elements are ignored when choosing the scale and are stored as code `0`,
///   because a float-to-int `as` cast maps NaN to 0.
/// * NOTE(review): an infinite element makes the scale infinite. Every code then
///   becomes 0 and dequantizing yields NaN (`0 * inf`), so callers should pass
///   finite rows.
pub fn quantize_row_i8(row: &[f32]) -> (Vec<i8>, f32) {
    // `f32::max` returns the non-NaN operand, so NaNs never become the maximum.
    let max_abs = row.iter().map(|x| x.abs()).fold(0.0_f32, f32::max);
    if max_abs == 0.0 {
        // Nothing non-zero to represent: every code is 0. Keep a positive scale so
        // callers can always divide by / multiply with it safely.
        return (vec![0; row.len()], f32::EPSILON);
    }
    // The floor only matters for subnormal maxima, where `max_abs / 127` could
    // underflow to zero and make `x / scale` infinite.
    let scale = (max_abs / 127.0_f32).max(f32::MIN_POSITIVE);
    let quantized = row
        .iter()
        .map(|&x| (x / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (quantized, scale)
}
/// Reconstructs approximate `f32` values from `i8` codes that share one scale.
///
/// Element `i` of the result is `quantized[i]` converted to `f32` and multiplied by
/// `scale`; the output has the same length as `quantized`.
pub fn dequantize_row_i8(quantized: &[i8], scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    restored.extend(quantized.iter().map(|&code| f32::from(code) * scale));
    restored
}
/// Mean absolute error between `original` and the dequantized `quantized` codes.
///
/// Only the first `min(original.len(), quantized.len())` pairs are compared; any
/// trailing elements of the longer slice are ignored. Returns `0.0` when there is
/// nothing to compare.
pub fn quant_error_mae(original: &[f32], quantized: &[i8], scale: f32) -> f32 {
    let compared = quantized.len().min(original.len());
    if compared == 0 {
        return 0.0;
    }
    let mut abs_error_total = 0.0_f32;
    for (reference, code) in original.iter().zip(quantized) {
        let reconstructed = *code as f32 * scale;
        abs_error_total += (*reference - reconstructed).abs();
    }
    abs_error_total / compared as f32
}
/// Fixed-capacity, int8-quantized key/value storage for a single layer.
///
/// Every `(token, head)` pair owns one row of `head_dim` `i8` key codes and one
/// row of value codes, each with its own `f32` scale (see [`quantize_row_i8`]).
/// Rows are laid out token-major, then by head. The row for `(token, head)` starts
/// at `(token * num_kv_heads + head) * head_dim`, and its scale lives at index
/// `token * num_kv_heads + head`.
///
/// NOTE(review): `len`, `capacity`, `num_kv_heads` and `head_dim` are public and
/// mutable. Changing them after construction desynchronizes them from the backing
/// buffers and can make the accessors panic on out-of-bounds slicing.
#[derive(Debug)]
pub struct QuantizedKvLayer {
    /// Key codes, `capacity * num_kv_heads * head_dim` entries.
    keys_i8: Vec<i8>,
    /// One key scale per `(token, head)` row, `capacity * num_kv_heads` entries.
    key_scales: Vec<f32>,
    /// Value codes, same layout as `keys_i8`.
    values_i8: Vec<i8>,
    /// One value scale per `(token, head)` row, same layout as `key_scales`.
    value_scales: Vec<f32>,
    /// Number of KV heads stored per token.
    pub num_kv_heads: usize,
    /// Number of elements in each head row.
    pub head_dim: usize,
    /// Maximum number of tokens the layer can hold.
    pub capacity: usize,
    /// Number of tokens pushed so far; positions `0..len` are readable.
    pub len: usize,
}

impl QuantizedKvLayer {
    /// Creates an empty layer with zero-initialised storage for `capacity` tokens.
    ///
    /// All buffers are allocated up front, so memory use does not grow with `len`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity * num_kv_heads * head_dim` overflows `usize`. The product
    /// is checked explicitly because an unchecked multiply wraps in release builds and
    /// would silently allocate undersized buffers.
    pub fn new(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let scale_len = capacity
            .checked_mul(num_kv_heads)
            .expect("QuantizedKvLayer::new: capacity * num_kv_heads overflows usize");
        let data_len = scale_len
            .checked_mul(head_dim)
            .expect("QuantizedKvLayer::new: capacity * num_kv_heads * head_dim overflows usize");
        Self {
            keys_i8: vec![0; data_len],
            key_scales: vec![0.0; scale_len],
            values_i8: vec![0; data_len],
            value_scales: vec![0.0; scale_len],
            num_kv_heads,
            head_dim,
            capacity,
            len: 0,
        }
    }

    /// Quantizes and appends one token's keys and values at position `len`.
    ///
    /// `keys` and `values` hold the token's heads concatenated: head `h` occupies
    /// `h * head_dim..(h + 1) * head_dim`. Each head row is quantized independently
    /// with its own scale.
    ///
    /// All checks run before anything is written, so on error the layer is unchanged.
    ///
    /// # Errors
    ///
    /// * [`QuantKvError::ShapeMismatch`] if `keys` (checked first) or `values` does
    ///   not have exactly `num_kv_heads * head_dim` elements.
    /// * [`QuantKvError::CapacityExceeded`] if the layer already holds `capacity` tokens.
    pub fn push(&mut self, keys: &[f32], values: &[f32]) -> Result<(), QuantKvError> {
        let expected = self.token_width();
        for actual in [keys.len(), values.len()] {
            if actual != expected {
                return Err(QuantKvError::ShapeMismatch { expected, actual });
            }
        }
        if self.len >= self.capacity {
            return Err(QuantKvError::CapacityExceeded {
                capacity: self.capacity,
                pos: self.len,
            });
        }

        let token_pos = self.len;
        let head_dim = self.head_dim;
        for head in 0..self.num_kv_heads {
            let src = head * head_dim..(head + 1) * head_dim;
            let dst_start = self.data_offset(token_pos, head);
            let dst = dst_start..dst_start + head_dim;
            let slot = self.scale_offset(token_pos, head);

            let (key_codes, key_scale) = quantize_row_i8(&keys[src.clone()]);
            self.keys_i8[dst.clone()].copy_from_slice(&key_codes);
            self.key_scales[slot] = key_scale;

            let (value_codes, value_scale) = quantize_row_i8(&values[src]);
            self.values_i8[dst].copy_from_slice(&value_codes);
            self.value_scales[slot] = value_scale;
        }
        self.len += 1;
        Ok(())
    }

    /// Dequantizes the key row for `(token_pos, head)`.
    ///
    /// # Errors
    ///
    /// [`QuantKvError::PositionOutOfRange`] if `token_pos >= len` (checked first), or
    /// [`QuantKvError::HeadOutOfRange`] if `head >= num_kv_heads`.
    pub fn get_key(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
        self.validate_pos_head(token_pos, head)?;
        Ok(self.decode_row(&self.keys_i8, &self.key_scales, token_pos, head))
    }

    /// Dequantizes the value row for `(token_pos, head)`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::get_key`].
    pub fn get_value(&self, token_pos: usize, head: usize) -> Result<Vec<f32>, QuantKvError> {
        self.validate_pos_head(token_pos, head)?;
        Ok(self.decode_row(&self.values_i8, &self.value_scales, token_pos, head))
    }

    /// Dequantizes all heads' keys for `token_pos`, concatenated in head order
    /// (`num_kv_heads * head_dim` elements).
    ///
    /// # Errors
    ///
    /// [`QuantKvError::PositionOutOfRange`] if `token_pos >= len`.
    pub fn get_keys_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
        self.validate_pos(token_pos)?;
        Ok(self.decode_token(&self.keys_i8, &self.key_scales, token_pos))
    }

    /// Dequantizes all heads' values for `token_pos`, concatenated in head order.
    ///
    /// # Errors
    ///
    /// [`QuantKvError::PositionOutOfRange`] if `token_pos >= len`.
    pub fn get_values_at(&self, token_pos: usize) -> Result<Vec<f32>, QuantKvError> {
        self.validate_pos(token_pos)?;
        Ok(self.decode_token(&self.values_i8, &self.value_scales, token_pos))
    }

    /// Bytes held by the quantized buffers: one byte per code plus one `f32` per scale.
    ///
    /// Counts the storage preallocated for all `capacity` tokens, not just the `len`
    /// tokens pushed so far.
    pub fn memory_bytes(&self) -> usize {
        let code_bytes = self.keys_i8.len() + self.values_i8.len();
        let scale_bytes =
            (self.key_scales.len() + self.value_scales.len()) * std::mem::size_of::<f32>();
        code_bytes + scale_bytes
    }

    /// Bytes an unquantized `f32` key+value store of the same shape would need.
    pub fn fp32_memory_bytes(&self) -> usize {
        2 * self.capacity * self.num_kv_heads * self.head_dim * std::mem::size_of::<f32>()
    }

    /// Ratio `fp32_memory_bytes / memory_bytes`; returns `1.0` when the quantized
    /// storage is empty (e.g. zero capacity) to avoid dividing by zero.
    pub fn compression_ratio(&self) -> f32 {
        let quant = self.memory_bytes();
        if quant == 0 {
            return 1.0;
        }
        self.fp32_memory_bytes() as f32 / quant as f32
    }

    /// Number of `f32` elements in one token's keys (or values): all heads concatenated.
    fn token_width(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }

    /// Index of the first code of the `(token_pos, head)` row.
    fn data_offset(&self, token_pos: usize, head: usize) -> usize {
        (token_pos * self.num_kv_heads + head) * self.head_dim
    }

    /// Index of the scale of the `(token_pos, head)` row.
    fn scale_offset(&self, token_pos: usize, head: usize) -> usize {
        token_pos * self.num_kv_heads + head
    }

    /// Dequantizes one head row from `codes`/`scales` (either the key or the value
    /// buffers). Callers must have validated `token_pos` and `head`.
    fn decode_row(&self, codes: &[i8], scales: &[f32], token_pos: usize, head: usize) -> Vec<f32> {
        let start = self.data_offset(token_pos, head);
        let scale = scales[self.scale_offset(token_pos, head)];
        dequantize_row_i8(&codes[start..start + self.head_dim], scale)
    }

    /// Dequantizes every head of one token, concatenated in head order.
    /// Callers must have validated `token_pos`.
    fn decode_token(&self, codes: &[i8], scales: &[f32], token_pos: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.token_width());
        for head in 0..self.num_kv_heads {
            out.extend(self.decode_row(codes, scales, token_pos, head));
        }
        out
    }

    /// Ensures `token_pos` refers to an already-pushed token.
    fn validate_pos(&self, token_pos: usize) -> Result<(), QuantKvError> {
        if token_pos >= self.len {
            return Err(QuantKvError::PositionOutOfRange(token_pos));
        }
        Ok(())
    }

    /// Ensures `token_pos` is pushed (checked first) and `head` is a valid head index.
    fn validate_pos_head(&self, token_pos: usize, head: usize) -> Result<(), QuantKvError> {
        self.validate_pos(token_pos)?;
        if head >= self.num_kv_heads {
            return Err(QuantKvError::HeadOutOfRange {
                head,
                num_heads: self.num_kv_heads,
            });
        }
        Ok(())
    }
}
/// Quantized KV cache made of `num_layers` [`QuantizedKvLayer`]s that all share the
/// same capacity and head geometry and advance together one token at a time.
#[derive(Debug)]
pub struct QuantizedKvCache {
    /// One layer per index `0..num_layers`, all created with the same shape.
    layers: Vec<QuantizedKvLayer>,
    /// Number of layers.
    pub num_layers: usize,
    /// KV heads per token in every layer.
    pub num_kv_heads: usize,
    /// Elements per head row in every layer.
    pub head_dim: usize,
}

impl QuantizedKvCache {
    /// Creates a cache of `num_layers` empty layers. Each layer can hold `capacity`
    /// tokens of `num_kv_heads` heads × `head_dim` elements.
    ///
    /// Any panic raised by `QuantizedKvLayer::new` propagates.
    pub fn new(num_layers: usize, capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| QuantizedKvLayer::new(capacity, num_kv_heads, head_dim))
            .collect();
        Self {
            layers,
            num_layers,
            num_kv_heads,
            head_dim,
        }
    }

    /// Appends one token to every layer.
    ///
    /// `all_keys[l]` / `all_values[l]` hold layer `l`'s keys/values with all heads
    /// concatenated (`num_kv_heads * head_dim` elements each).
    ///
    /// The step is all-or-nothing. Every layer's input is validated before any layer
    /// is written, so on error all layers keep their previous (equal) lengths.
    ///
    /// # Errors
    ///
    /// * [`QuantKvError::LayerOutOfRange`] if `all_keys` (checked first) or
    ///   `all_values` does not contain exactly `num_layers` entries. In that case
    ///   `layer` holds the supplied count rather than a layer index.
    /// * [`QuantKvError::ShapeMismatch`] or [`QuantKvError::CapacityExceeded`] from
    ///   the first layer (in index order) whose input is rejected.
    pub fn push_step(
        &mut self,
        all_keys: &[Vec<f32>],
        all_values: &[Vec<f32>],
    ) -> Result<(), QuantKvError> {
        for supplied in [all_keys.len(), all_values.len()] {
            if supplied != self.num_layers {
                return Err(QuantKvError::LayerOutOfRange {
                    layer: supplied,
                    num_layers: self.num_layers,
                });
            }
        }

        // Validate every layer before writing any of them. Pushing eagerly would leave
        // earlier layers one token ahead of later ones if a later layer rejected its input.
        for ((layer, keys), values) in self.layers.iter().zip(all_keys).zip(all_values) {
            Self::check_layer_push(layer, keys, values)?;
        }
        for ((layer, keys), values) in self.layers.iter_mut().zip(all_keys).zip(all_values) {
            layer.push(keys, values)?;
        }
        Ok(())
    }

    /// Dequantizes the key row for `(layer, token_pos, head)`.
    ///
    /// # Errors
    ///
    /// [`QuantKvError::LayerOutOfRange`] if `layer >= num_layers`; otherwise any error
    /// from `QuantizedKvLayer::get_key`.
    pub fn get_key(
        &self,
        layer: usize,
        token_pos: usize,
        head: usize,
    ) -> Result<Vec<f32>, QuantKvError> {
        self.layer_at(layer)?.get_key(token_pos, head)
    }

    /// Dequantizes the value row for `(layer, token_pos, head)`.
    ///
    /// # Errors
    ///
    /// [`QuantKvError::LayerOutOfRange`] if `layer >= num_layers`; otherwise any error
    /// from `QuantizedKvLayer::get_value`.
    pub fn get_value(
        &self,
        layer: usize,
        token_pos: usize,
        head: usize,
    ) -> Result<Vec<f32>, QuantKvError> {
        self.layer_at(layer)?.get_value(token_pos, head)
    }

    /// Sum of `memory_bytes` over all layers.
    pub fn total_memory_bytes(&self) -> usize {
        self.layers.iter().map(|l| l.memory_bytes()).sum()
    }

    /// Sum of `fp32_memory_bytes` over all layers.
    pub fn total_fp32_memory_bytes(&self) -> usize {
        self.layers.iter().map(|l| l.fp32_memory_bytes()).sum()
    }

    /// Ratio `total_fp32_memory_bytes / total_memory_bytes`; returns `1.0` when the
    /// quantized storage is empty to avoid dividing by zero.
    pub fn compression_ratio(&self) -> f32 {
        let quant = self.total_memory_bytes();
        if quant == 0 {
            return 1.0;
        }
        self.total_fp32_memory_bytes() as f32 / quant as f32
    }

    /// Number of tokens pushed so far, read from layer 0 (`0` if there are no layers).
    /// `push_step` keeps all layers at the same length.
    pub fn seq_len(&self) -> usize {
        self.layers.first().map(|l| l.len).unwrap_or(0)
    }

    /// Runs the same checks as `QuantizedKvLayer::push`, in the same order, without
    /// mutating the layer. `push_step` therefore reports the same error a direct push
    /// would have, but before any layer has been written.
    fn check_layer_push(
        layer: &QuantizedKvLayer,
        keys: &[f32],
        values: &[f32],
    ) -> Result<(), QuantKvError> {
        let expected = layer.num_kv_heads * layer.head_dim;
        for actual in [keys.len(), values.len()] {
            if actual != expected {
                return Err(QuantKvError::ShapeMismatch { expected, actual });
            }
        }
        if layer.len >= layer.capacity {
            return Err(QuantKvError::CapacityExceeded {
                capacity: layer.capacity,
                pos: layer.len,
            });
        }
        Ok(())
    }

    /// Looks up a layer by index, reporting `LayerOutOfRange` instead of panicking.
    /// It also stays panic-free if the public `num_layers` was changed after construction.
    fn layer_at(&self, layer: usize) -> Result<&QuantizedKvLayer, QuantKvError> {
        match self.layers.get(layer) {
            Some(found) if layer < self.num_layers => Ok(found),
            _ => Err(QuantKvError::LayerOutOfRange {
                layer,
                num_layers: self.num_layers,
            }),
        }
    }
}