use std::collections::VecDeque;
/// Strategy used by `CacheManager::evict` to decide which cached positions to drop.
///
/// All payloads are plain `usize` values, so the policy is `Copy` and can be
/// compared for equality or used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvictionPolicy {
    /// Heavy-hitter eviction: drop the entries with the lowest accumulated
    /// attention score.
    H2O,
    /// Keep a prefix of "sink" entries plus the most recent entries.
    SlidingWindow {
        /// Number of most recent entries to keep.
        window: usize,
        /// Number of leading entries (attention sinks) to keep.
        sink: usize,
    },
    /// Pyramid-shaped per-layer budgeting. `CacheManager::evict` currently
    /// evicts the same way as `H2O` for this variant.
    PyramidKV {
        /// Number of layers the budget is spread across.
        /// NOTE(review): not read by the eviction code in this file — confirm intended use.
        total_layers: usize,
    },
}
/// Configuration for a `CacheManager`.
#[derive(Debug, Clone)]
pub struct KVCacheConfig {
/// Maximum number of cached positions; `CacheManager::append` triggers eviction
/// once the entry count exceeds this value.
pub max_seq_len: usize,
/// Number of heads each appended key/value tensor is split into for
/// per-head quantization.
pub num_heads: usize,
/// Elements per head. Only read by the memory / compression estimates;
/// presumably each appended key/value holds `num_heads * head_dim` elements,
/// but `append` does not check this — TODO confirm with callers.
pub head_dim: usize,
/// Bit width passed to `quantize_asymmetric` for keys and values.
pub quantization_bits: u8,
/// Policy applied when the cache must shrink.
pub eviction_policy: EvictionPolicy,
}
/// Result of per-head asymmetric quantization (see `quantize_asymmetric`).
/// Element `i` of head `h` decodes as `scales[h] * (data[i] - zero_points[h])`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
/// One code per element, one byte each (codes are not bit-packed); heads are
/// stored contiguously, one after another.
pub data: Vec<u8>,
/// Per-head scale factor.
pub scales: Vec<f32>,
/// Per-head zero point, kept as `f32` (not rounded to an integer code).
pub zero_points: Vec<f32>,
/// Bit width the codes were produced with.
pub bits: u8,
}
/// Rounds `x` to the nearest integer, breaking exact ties toward the even
/// neighbour (banker's rounding).
///
/// Ties are detected exactly: `x - x.trunc()` is computed without rounding
/// error for every finite `f32`, so only values whose fractional part is
/// exactly one half take the tie path. An epsilon-based check would also
/// misclassify values one ULP away from a half (e.g. `0.50000006`) as ties and
/// round them the wrong way.
///
/// Non-finite inputs are returned unchanged (`NaN` stays `NaN`, `±inf` stays `±inf`).
#[inline]
pub fn round_to_nearest_even(x: f32) -> f32 {
    // `f32::round` resolves halfway cases away from zero.
    let rounded = x.round();
    let is_tie = (x - x.trunc()).abs() == 0.5;
    if is_tie && rounded % 2.0 != 0.0 {
        // Away-from-zero landed on an odd integer; the even neighbour is one
        // step back toward zero.
        rounded - x.signum()
    } else {
        rounded
    }
}
/// Quantizes `tensor` per head with an asymmetric (min/max) affine scheme.
///
/// `tensor` is treated as `num_heads` contiguous, equally sized slices. For
/// each slice, `scale = (max - min) / (2^bits - 1)` and a floating-point
/// `zero_point = -min / scale` are chosen, so the slice minimum maps to code
/// `0` and the maximum to code `2^bits - 1`. Codes are stored one per byte
/// (not bit-packed); `dequantize` reconstructs `scale * (code - zero_point)`.
///
/// A slice whose range is below `f32::EPSILON` (i.e. constant) uses
/// `scale = 1.0` and `zero_point = -min`, so it decodes back to its own value.
/// With a zero point of `0.0`, a constant negative slice would clamp to code 0
/// and decode as `0.0`.
///
/// # Panics
///
/// Panics if `bits` is not in `1..=8` (codes must fit in a `u8`), if
/// `num_heads` is zero, or if `tensor.len()` is not a multiple of `num_heads`.
pub fn quantize_asymmetric(tensor: &[f32], num_heads: usize, bits: u8) -> QuantizedTensor {
    assert!(
        (1..=8).contains(&bits),
        "quantize_asymmetric: bits must be in [1, 8], got {}",
        bits
    );
    assert!(num_heads > 0, "quantize_asymmetric: num_heads must be non-zero");
    assert!(
        tensor.len() % num_heads == 0,
        "quantize_asymmetric: tensor length {} is not a multiple of num_heads {}",
        tensor.len(),
        num_heads
    );
    let head_dim = tensor.len() / num_heads;
    let qmax = ((1u32 << bits) - 1) as f32;
    let mut data = Vec::with_capacity(tensor.len());
    let mut scales = Vec::with_capacity(num_heads);
    let mut zero_points = Vec::with_capacity(num_heads);
    for h in 0..num_heads {
        let channel = &tensor[h * head_dim..(h + 1) * head_dim];
        let min_val = channel.iter().copied().fold(f32::INFINITY, f32::min);
        let max_val = channel.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let range = max_val - min_val;
        let (scale, zp) = if range >= f32::EPSILON {
            let scale = range / qmax;
            (scale, -min_val / scale)
        } else if min_val.is_finite() {
            // Constant slice: every value maps to code 0, which decodes to `min_val`.
            (1.0, -min_val)
        } else {
            // No finite values (empty or all-NaN slice): nothing meaningful to encode.
            (1.0, 0.0)
        };
        scales.push(scale);
        zero_points.push(zp);
        data.extend(
            channel
                .iter()
                .map(|&v| round_to_nearest_even(v / scale + zp).clamp(0.0, qmax) as u8),
        );
    }
    QuantizedTensor { data, scales, zero_points, bits }
}
/// Quantizes `tensor` with one symmetric scale, returning `(codes, scale)`.
///
/// The scale maps the largest magnitude onto `2^(bits-1) - 1`. Every scaled
/// value is shifted up by `2^(bits-1)` so the codes are unsigned (zero encodes
/// as the offset), then clamped to `[0, 2^bits - 1]`. An all-(near-)zero input
/// uses a scale of `1.0`. Decode with `dequantize_symmetric`.
///
/// # Panics
///
/// Panics if `bits` is outside `2..=8`.
pub fn quantize_symmetric(tensor: &[f32], bits: u8) -> (Vec<u8>, f32) {
    assert!((2..=8).contains(&bits), "quantize_symmetric: bits must be in [2, 8], got {}", bits);
    let half_range = 1u32 << (bits - 1);
    let positive_levels = (half_range - 1) as f32;
    let top_code = (1u32 << bits) as f32 - 1.0;
    let offset = half_range as f32;

    let mut peak = 0.0_f32;
    for &v in tensor {
        peak = peak.max(v.abs());
    }
    let scale = if peak < f32::EPSILON {
        1.0
    } else {
        peak / positive_levels
    };

    let mut codes = Vec::with_capacity(tensor.len());
    for &v in tensor {
        let level = round_to_nearest_even(v / scale + offset);
        codes.push(level.clamp(0.0, top_code) as u8);
    }
    (codes, scale)
}
/// Decodes codes produced by `quantize_symmetric` back to `f32`.
///
/// Each code is shifted down by `2^(bits-1)` and multiplied by `scale`.
///
/// # Panics
///
/// Panics if `bits` is outside `2..=8`, the same range `quantize_symmetric`
/// accepts. Without the check, `bits == 0` underflows the shift amount, and
/// wider widths yield offsets that no `u8` code can sit above.
pub fn dequantize_symmetric(data: &[u8], scale: f32, bits: u8) -> Vec<f32> {
    assert!(
        (2..=8).contains(&bits),
        "dequantize_symmetric: bits must be in [2, 8], got {}",
        bits
    );
    let offset = (1u32 << (bits - 1)) as f32;
    data.iter().map(|&q| (f32::from(q) - offset) * scale).collect()
}
/// Reconstructs `f32` values from a per-head `QuantizedTensor`.
///
/// The codes are split into `num_heads` equal contiguous slices; a code in
/// slice `h` decodes as `scales[h] * (code - zero_points[h])`. Trailing codes
/// that do not fill a whole head (when the code count is not a multiple of
/// `num_heads`) are ignored.
///
/// # Panics
///
/// Panics if `num_heads` is zero, or if it exceeds the number of stored
/// scales or zero points.
pub fn dequantize(qt: &QuantizedTensor, num_heads: usize) -> Vec<f32> {
    let per_head = qt.data.len() / num_heads;
    let mut restored = Vec::with_capacity(qt.data.len());
    for head in 0..num_heads {
        let (scale, zero_point) = (qt.scales[head], qt.zero_points[head]);
        let lo = head * per_head;
        let codes = &qt.data[lo..lo + per_head];
        restored.extend(codes.iter().map(|&code| scale * (code as f32 - zero_point)));
    }
    restored
}
/// One cached position: quantized key/value plus bookkeeping used by eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
key: QuantizedTensor,
value: QuantizedTensor,
/// Running sum of the scores passed to `update_attention_scores`; starts at
/// 0.0 and is what H2O eviction minimises over.
attention_score: f64,
/// Insertion counter assigned by `append`. Unlike the entry's position in the
/// deque, it does not change when other entries are evicted.
seq_idx: usize,
}
/// Quantized KV cache holding one entry per appended position.
pub struct CacheManager {
config: KVCacheConfig,
/// Entries in insertion order (oldest at the front); eviction may remove
/// entries from anywhere in the deque.
entries: VecDeque<CacheEntry>,
/// `seq_idx` to assign to the next appended entry.
next_seq: usize,
}
impl CacheManager {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: KVCacheConfig) -> Self {
        Self {
            config,
            entries: VecDeque::new(),
            next_seq: 0,
        }
    }

    /// Number of positions currently cached.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no positions are cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Quantizes one position's key and value and appends them to the cache.
    ///
    /// Both tensors are quantized per head with `config.quantization_bits`, and
    /// the new entry starts with an attention score of `0.0`. If the cache then
    /// holds more than `config.max_seq_len` entries, [`evict`](Self::evict) runs
    /// with that limit as the budget. `_layer_idx` is currently unused.
    pub fn append(&mut self, key: &[f32], value: &[f32], _layer_idx: usize) {
        let bits = self.config.quantization_bits;
        let heads = self.config.num_heads;
        self.entries.push_back(CacheEntry {
            key: quantize_asymmetric(key, heads, bits),
            value: quantize_asymmetric(value, heads, bits),
            attention_score: 0.0,
            seq_idx: self.next_seq,
        });
        self.next_seq += 1;
        // NOTE(review): eviction runs before the caller can credit the new entry
        // with attention. Under H2O, the just-appended entry (score 0.0) is
        // therefore evicted immediately whenever every older entry has a
        // positive score. Confirm this is intended.
        if self.entries.len() > self.config.max_seq_len {
            self.evict(self.config.max_seq_len);
        }
    }

    /// Dequantizes the keys and values at the given positions.
    ///
    /// Positions index the cache in its current order (after any evictions),
    /// not by `seq_idx`. Out-of-range positions are skipped silently, so the
    /// returned vectors can be shorter than `positions`.
    pub fn get(&self, positions: &[usize]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let heads = self.config.num_heads;
        let mut keys = Vec::with_capacity(positions.len());
        let mut values = Vec::with_capacity(positions.len());
        for entry in positions.iter().filter_map(|&pos| self.entries.get(pos)) {
            keys.push(dequantize(&entry.key, heads));
            values.push(dequantize(&entry.value, heads));
        }
        (keys, values)
    }

    /// Shrinks the cache to at most `budget` entries using the configured policy.
    ///
    /// Does nothing when the cache already fits. `PyramidKV` currently evicts
    /// like `H2O` with the budget given here; see
    /// [`pyramid_budget`](Self::pyramid_budget) for computing per-layer budgets.
    pub fn evict(&mut self, budget: usize) {
        if self.entries.len() <= budget {
            return;
        }
        match &self.config.eviction_policy {
            EvictionPolicy::H2O | EvictionPolicy::PyramidKV { .. } => self.evict_h2o(budget),
            EvictionPolicy::SlidingWindow { window, sink } => {
                self.evict_sliding_window(budget, *window, *sink);
            }
        }
    }

    /// Heavy-hitter eviction: repeatedly removes the entry with the lowest
    /// accumulated attention score until `budget` entries remain. On ties the
    /// front-most (oldest) entry goes first. Scores that cannot be compared
    /// (NaN) are treated as equal.
    fn evict_h2o(&mut self, budget: usize) {
        while self.entries.len() > budget {
            let victim = self
                .entries
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    a.attention_score
                        .partial_cmp(&b.attention_score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(i, _)| i)
                .expect("cache is non-empty while its length exceeds the budget");
            self.entries.remove(victim);
        }
    }

    /// Attention-sink + recent-window eviction. Keeps the first `sink` entries
    /// and the most recent `window` entries, and never keeps more than `budget`
    /// entries in total.
    ///
    /// When `budget < sink + window`, the recent window is filled first and the
    /// sink prefix gets whatever budget remains. The budget is always honoured,
    /// so the cache never ends up with `sink + window` entries when `budget` is
    /// smaller.
    fn evict_sliding_window(&mut self, budget: usize, window: usize, sink: usize) {
        let keep_total = budget.min(sink.saturating_add(window));
        let len = self.entries.len();
        if len <= keep_total {
            return;
        }
        let keep_recent = window.min(keep_total);
        let keep_sinks = sink.min(keep_total - keep_recent);
        // Detach the recent tail, drop everything after the sink prefix, then
        // re-attach the tail. This moves entries instead of cloning them.
        let mut recent = self.entries.split_off(len - keep_recent);
        self.entries.truncate(keep_sinks);
        self.entries.append(&mut recent);
    }

    /// Adds `scores[i]` to the accumulated attention score of the entry at
    /// position `i`. Surplus scores, and entries without a matching score, are
    /// ignored.
    pub fn update_attention_scores(&mut self, scores: &[f64]) {
        for (entry, &s) in self.entries.iter_mut().zip(scores) {
            entry.attention_score += s;
        }
    }

    /// PyramidKV token budget for layer `layer_idx` out of `total_layers`.
    ///
    /// Layer `l` has weight `(total_layers - l) / total_layers`. Budgets are the
    /// weights normalised by their sum and scaled by `max_seq_len`, so lower
    /// layers get more tokens. Each budget is rounded up and is at least 1.
    /// Returns `max_seq_len` when `total_layers` is zero. A `layer_idx` at or
    /// beyond `total_layers` gets the minimum budget of 1 instead of underflowing.
    pub fn pyramid_budget(&self, layer_idx: usize, total_layers: usize) -> usize {
        if total_layers == 0 {
            return self.config.max_seq_len;
        }
        let layers = total_layers as f64;
        let weight = total_layers.saturating_sub(layer_idx) as f64 / layers;
        let sum_weights: f64 = (1..=total_layers).map(|i| i as f64 / layers).sum();
        let budget = (weight / sum_weights) * self.config.max_seq_len as f64;
        (budget.ceil() as usize).max(1)
    }

    /// Ratio of one key+value pair's `f32` size to its estimated quantized size,
    /// based on the configured `num_heads * head_dim`.
    ///
    /// Codes count as one byte each regardless of `quantization_bits`, because
    /// they are stored unpacked. Returns `0.0` when the quantized size is zero.
    pub fn compression_ratio(&self) -> f64 {
        let elements = self.config.num_heads * self.config.head_dim;
        let f32_bytes = (elements * 4 * 2) as f64;
        let q_bytes = self.entry_quantized_bytes() as f64;
        if q_bytes < f64::EPSILON {
            return 0.0;
        }
        f32_bytes / q_bytes
    }

    /// Estimated bytes for one entry. For each of key and value it counts one
    /// byte per element plus an `f32` scale and an `f32` zero point per head.
    fn entry_quantized_bytes(&self) -> usize {
        let elements = self.config.num_heads * self.config.head_dim;
        let per_tensor = elements + self.config.num_heads * 4 * 2;
        per_tensor * 2
    }

    /// Estimated quantized memory of all cached entries: the entry count times
    /// the per-entry estimate derived from the config. Container overhead is
    /// not counted.
    pub fn memory_bytes(&self) -> usize {
        self.entries.len() * self.entry_quantized_bytes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 2 heads of 4 dims each, at most 8 cached positions.
    fn make_config(bits: u8, policy: EvictionPolicy) -> KVCacheConfig {
        KVCacheConfig {
            max_seq_len: 8,
            num_heads: 2,
            head_dim: 4,
            quantization_bits: bits,
            eviction_policy: policy,
        }
    }

    /// Fresh manager built from `make_config`.
    fn manager(bits: u8, policy: EvictionPolicy) -> CacheManager {
        CacheManager::new(make_config(bits, policy))
    }

    /// Appends positions `0..count`; position `i` has every key and value
    /// element set to `i`.
    fn push_ramp(mgr: &mut CacheManager, count: usize) {
        for i in 0..count {
            let row = vec![i as f32; 8];
            mgr.append(&row, &row, 0);
        }
    }

    #[test]
    fn test_quantize_roundtrip_4bit() {
        let data = [0.0_f32, 0.5, 1.0, -1.0, 0.25, -0.5, 0.75, -0.25];
        let restored = dequantize(&quantize_asymmetric(&data, 2, 4), 2);
        for (orig, rest) in data.iter().zip(&restored) {
            assert!((orig - rest).abs() < 0.15, "4-bit error too large: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_3bit() {
        let data = [0.0_f32, 0.5, 1.0, -1.0, 0.3, -0.7, 0.8, -0.2];
        let restored = dequantize(&quantize_asymmetric(&data, 2, 3), 2);
        for (orig, rest) in data.iter().zip(&restored) {
            assert!((orig - rest).abs() < 0.35, "3-bit error too large: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_symmetric_quantize_roundtrip() {
        let data = [0.0_f32, 0.5, -0.5, 1.0, -1.0];
        let (codes, scale) = quantize_symmetric(&data, 4);
        let restored = dequantize_symmetric(&codes, scale, 4);
        for (orig, rest) in data.iter().zip(&restored) {
            assert!((orig - rest).abs() < 0.2, "sym roundtrip: {orig} vs {rest}");
        }
    }

    #[test]
    fn test_bankers_rounding() {
        let cases = [(2.5_f32, 2.0_f32), (3.5, 4.0), (4.5, 4.0), (1.3, 1.0), (1.7, 2.0)];
        for &(input, expected) in &cases {
            assert_eq!(round_to_nearest_even(input), expected);
        }
    }

    #[test]
    fn test_cache_append_and_get() {
        let mut mgr = manager(4, EvictionPolicy::H2O);
        mgr.append(&[1.0_f32; 8], &[-1.0_f32; 8], 0);
        assert_eq!(mgr.len(), 1);
        let (keys, vals) = mgr.get(&[0]);
        assert_eq!((keys.len(), vals.len()), (1, 1));
        assert_eq!(keys[0].len(), 8);
    }

    #[test]
    fn test_cache_empty() {
        let mgr = manager(4, EvictionPolicy::H2O);
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        let (k, v) = mgr.get(&[0]);
        assert!(k.is_empty() && v.is_empty());
    }

    #[test]
    fn test_h2o_eviction() {
        let mut mgr = manager(4, EvictionPolicy::H2O);
        push_ramp(&mut mgr, 4);
        mgr.update_attention_scores(&[5.0, 1.0, 3.0, 4.0]);
        mgr.evict(3);
        assert_eq!(mgr.len(), 3);
        // The entry credited with 1.0 had the lowest score and must be gone.
        assert!(!mgr.entries.iter().any(|e| e.attention_score == 1.0));
    }

    #[test]
    fn test_sliding_window_eviction() {
        let cfg = KVCacheConfig {
            max_seq_len: 100,
            ..make_config(4, EvictionPolicy::SlidingWindow { window: 3, sink: 2 })
        };
        let mut mgr = CacheManager::new(cfg);
        push_ramp(&mut mgr, 10);
        assert_eq!(mgr.len(), 10);
        mgr.evict(5);
        assert_eq!(mgr.len(), 5);
        let seq_idxs: Vec<usize> = mgr.entries.iter().map(|e| e.seq_idx).collect();
        assert_eq!(&seq_idxs[..2], &[0, 1]);
        for recent in 7..10 {
            assert!(seq_idxs.contains(&recent));
        }
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = manager(4, EvictionPolicy::H2O).compression_ratio();
        assert!(ratio > 1.0, "compression ratio should be > 1.0, got {ratio}");
    }

    #[test]
    fn test_memory_bytes() {
        let mut mgr = manager(4, EvictionPolicy::H2O);
        assert_eq!(mgr.memory_bytes(), 0);
        let (k, v) = ([0.5_f32; 8], [-0.5_f32; 8]);
        mgr.append(&k, &v, 0);
        let bytes_one = mgr.memory_bytes();
        assert!(bytes_one > 0);
        mgr.append(&k, &v, 0);
        assert_eq!(mgr.memory_bytes(), bytes_one * 2);
    }

    #[test]
    fn test_auto_eviction_on_append() {
        let mut mgr = manager(4, EvictionPolicy::H2O);
        push_ramp(&mut mgr, 12);
        assert!(mgr.len() <= 8);
    }

    #[test]
    fn test_pyramid_budget() {
        let mgr = manager(4, EvictionPolicy::PyramidKV { total_layers: 4 });
        let (b0, b3) = (mgr.pyramid_budget(0, 4), mgr.pyramid_budget(3, 4));
        assert!(b0 > b3, "layer 0 budget ({b0}) should exceed layer 3 ({b3})");
    }

    #[test]
    fn test_single_entry_operations() {
        let mut mgr = manager(3, EvictionPolicy::H2O);
        mgr.append(&[0.42_f32; 8], &[-0.42_f32; 8], 0);
        mgr.update_attention_scores(&[1.0]);
        mgr.evict(1);
        assert_eq!(mgr.len(), 1);
        let (keys, vals) = mgr.get(&[0]);
        assert_eq!(keys.len(), 1);
        assert_eq!(vals.len(), 1);
    }
}