use crate::error::{Result, RuvLLMError};
use crate::memory_pool::{BufferPool, BufferSize, PooledBuffer};
use crate::types::Precision;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::alloc::{alloc, dealloc, Layout};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
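// Alignment and sizing constants: 64 bytes matches the cache-line width of
// common AArch64/x86_64 cores; 16 bytes covers NEON's 128-bit registers;
// POOL_BLOCK_SIZE is a default pool block granularity.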
const CACHE_LINE_SIZE: usize = 64;
const NEON_ALIGNMENT: usize = 16;
const POOL_BLOCK_SIZE: usize = 4096;
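/// A heap-allocated `f32` buffer aligned to `CACHE_LINE_SIZE`, used as a
/// SIMD-friendly target for dequantized KV data. Length and capacity are
/// tracked separately, `Vec`-style; memory beyond `len` is uninitialized.
///
/// `Send`/`Sync` are sound because the pointer is uniquely owned and all
/// mutation requires `&mut self`.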
#[derive(Debug)]
pub struct AlignedBuffer {
ptr: *mut f32,
len: usize,
capacity: usize,
layout: Layout,
}
unsafe impl Send for AlignedBuffer {}
unsafe impl Sync for AlignedBuffer {}
impl AlignedBuffer {
pub fn new(capacity: usize) -> Self {
let size = capacity * std::mem::size_of::<f32>();
let layout = Layout::from_size_align(size.max(CACHE_LINE_SIZE), CACHE_LINE_SIZE)
.expect("Invalid layout");
        let ptr = unsafe { alloc(layout) as *mut f32 };
        if ptr.is_null() {
            // Idiomatic handling of allocation failure instead of a panic.
            std::alloc::handle_alloc_error(layout);
        }
Self {
ptr,
len: 0,
capacity,
layout,
}
}
#[inline(always)]
pub fn as_slice(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
#[inline(always)]
pub fn as_mut_slice(&mut self) -> &mut [f32] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
#[inline(always)]
pub fn extend_from_slice(&mut self, data: &[f32]) {
let new_len = self.len + data.len();
assert!(new_len <= self.capacity, "Buffer overflow");
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr.add(self.len), data.len());
}
self.len = new_len;
}
#[inline(always)]
pub fn clear(&mut self) {
self.len = 0;
}
#[inline(always)]
pub fn as_ptr(&self) -> *const f32 {
self.ptr
}
#[inline(always)]
pub fn as_mut_ptr(&mut self) -> *mut f32 {
self.ptr
}
#[inline(always)]
pub fn len(&self) -> usize {
self.len
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline(always)]
pub fn capacity(&self) -> usize {
self.capacity
}
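    /// Sets the logical length without initializing memory.
    ///
    /// Safety: the caller must ensure the first `new_len` elements are
    /// initialized and that `new_len <= capacity`.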
#[inline(always)]
pub(crate) unsafe fn set_len_unchecked(&mut self, new_len: usize) {
debug_assert!(
new_len <= self.capacity,
"set_len_unchecked: {} > {}",
new_len,
self.capacity
);
self.len = new_len;
}
}
impl Drop for AlignedBuffer {
fn drop(&mut self) {
unsafe {
dealloc(self.ptr as *mut u8, self.layout);
}
}
}
impl Clone for AlignedBuffer {
fn clone(&self) -> Self {
let mut new_buf = Self::new(self.capacity);
new_buf.extend_from_slice(self.as_slice());
new_buf
}
}
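/// A free-list pool of cache-line-aligned key/value buffers. Returned
/// buffers are cleared and recycled (up to `max_blocks` per side) to avoid
/// allocator round trips on the decode path.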
#[derive(Debug)]
pub struct KvMemoryPool {
key_pool: RwLock<Vec<AlignedBuffer>>,
value_pool: RwLock<Vec<AlignedBuffer>>,
block_size: usize,
max_blocks: usize,
allocated_blocks: AtomicUsize,
}
impl KvMemoryPool {
pub fn new(block_size: usize, max_blocks: usize) -> Self {
Self {
key_pool: RwLock::new(Vec::with_capacity(max_blocks)),
value_pool: RwLock::new(Vec::with_capacity(max_blocks)),
block_size,
max_blocks,
allocated_blocks: AtomicUsize::new(0),
}
}
pub fn get_key_buffer(&self) -> AlignedBuffer {
let mut pool = self.key_pool.write();
if let Some(buf) = pool.pop() {
buf
} else {
self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
AlignedBuffer::new(self.block_size)
}
}
pub fn get_value_buffer(&self) -> AlignedBuffer {
let mut pool = self.value_pool.write();
if let Some(buf) = pool.pop() {
buf
} else {
self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
AlignedBuffer::new(self.block_size)
}
}
pub fn return_key_buffer(&self, mut buf: AlignedBuffer) {
buf.clear();
let mut pool = self.key_pool.write();
if pool.len() < self.max_blocks {
pool.push(buf);
}
}
pub fn return_value_buffer(&self, mut buf: AlignedBuffer) {
buf.clear();
let mut pool = self.value_pool.write();
if pool.len() < self.max_blocks {
pool.push(buf);
}
}
pub fn prewarm(&self, count: usize) {
let count = count.min(self.max_blocks);
let mut key_pool = self.key_pool.write();
let mut value_pool = self.value_pool.write();
for _ in 0..count {
if key_pool.len() < self.max_blocks {
key_pool.push(AlignedBuffer::new(self.block_size));
self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
}
if value_pool.len() < self.max_blocks {
value_pool.push(AlignedBuffer::new(self.block_size));
self.allocated_blocks.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn stats(&self) -> PoolStats {
PoolStats {
key_pool_size: self.key_pool.read().len(),
value_pool_size: self.value_pool.read().len(),
total_allocated: self.allocated_blocks.load(Ordering::Relaxed),
block_size_bytes: self.block_size * std::mem::size_of::<f32>(),
}
}
}
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
pub key_pool_size: usize,
pub value_pool_size: usize,
pub total_allocated: usize,
pub block_size_bytes: usize,
}
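/// Configuration for `TwoTierKvCache`: a high-precision tail of the most
/// recent `tail_length` tokens, migrated in `migration_batch` chunks into a
/// store quantized at `store_precision`.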
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvCacheConfig {
pub tail_length: usize,
pub tail_precision: Precision,
pub store_precision: Precision,
pub max_tokens: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub migration_batch: usize,
}
impl Default for KvCacheConfig {
fn default() -> Self {
Self {
tail_length: 256,
tail_precision: Precision::FP16,
store_precision: Precision::Q4,
max_tokens: 4096,
num_kv_heads: 8,
head_dim: 128,
migration_batch: 64,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CacheTier {
Hot,
Warm,
Cold,
TurboQuant,
}
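/// Quantization policy variants for the cache. The default `Hybrid` matches
/// `KvCacheConfig::default()`: a full-precision tail over a quantized store.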
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CacheQuantization {
HighPrecisionTail {
tail_length: usize,
precision: Precision,
},
QuantizedStore {
precision: Precision,
compression_ratio: f32,
},
Hybrid {
tail_length: usize,
tail_precision: Precision,
store_precision: Precision,
},
TurboQuantHybrid {
tail_length: usize,
tail_precision: Precision,
turbo_bits: f32,
},
}
impl Default for CacheQuantization {
fn default() -> Self {
Self::Hybrid {
tail_length: 256,
tail_precision: Precision::FP16,
store_precision: Precision::Q4,
}
}
}
#[derive(Debug, Clone)]
struct KvPair {
keys: Vec<f32>,
values: Vec<f32>,
position: usize,
}
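/// A KV pair quantized with a single affine `(scale, zero_point)` mapping.
/// Codes are stored as `f32`, which keeps the NEON dequantization paths free
/// of widening conversions.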
#[derive(Debug, Clone)]
struct QuantizedKvPair {
keys: Vec<f32>,
values: Vec<f32>,
scale: f32,
zero_point: f32,
position: usize,
}
impl QuantizedKvPair {
fn from_kv_pair(pair: &KvPair, precision: Precision) -> Self {
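        // Note: the affine parameters are derived from the keys alone and
        // shared with the values; value magnitudes outside the key range
        // still round-trip (codes are stored as f32) but exceed the nominal
        // quantization range.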
let (scale, zero_point) = Self::compute_scale_and_zero(&pair.keys, precision);
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
let quantize = |vals: &[f32]| -> Vec<f32> { Self::quantize_neon(vals, scale, zero_point) };
#[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
let quantize = |vals: &[f32]| -> Vec<f32> {
vals.iter()
.map(|v| ((v - zero_point) / scale).round())
.collect()
};
Self {
keys: quantize(&pair.keys),
values: quantize(&pair.values),
scale,
zero_point,
position: pair.position,
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn quantize_neon(values: &[f32], scale: f32, zero_point: f32) -> Vec<f32> {
use std::arch::aarch64::*;
let mut result = vec![0.0f32; values.len()];
let inv_scale = 1.0 / scale;
unsafe {
let inv_scale_vec = vdupq_n_f32(inv_scale);
let zero_vec = vdupq_n_f32(zero_point);
const UNROLL_8X: usize = 8;
let chunks = values.len() / UNROLL_8X;
for c in 0..chunks {
let base = c * UNROLL_8X;
let v0 = vld1q_f32(values.as_ptr().add(base));
let v1 = vld1q_f32(values.as_ptr().add(base + 4));
let sub0 = vsubq_f32(v0, zero_vec);
let sub1 = vsubq_f32(v1, zero_vec);
let scaled0 = vmulq_f32(sub0, inv_scale_vec);
let scaled1 = vmulq_f32(sub1, inv_scale_vec);
let rounded0 = vrndnq_f32(scaled0);
let rounded1 = vrndnq_f32(scaled1);
vst1q_f32(result.as_mut_ptr().add(base), rounded0);
vst1q_f32(result.as_mut_ptr().add(base + 4), rounded1);
}
for i in (chunks * UNROLL_8X)..values.len() {
result[i] = ((values[i] - zero_point) * inv_scale).round();
}
}
result
}
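    /// Asymmetric affine parameters: `q = round((v - zero_point) / scale)`
    /// with `zero_point = min(v)` and `scale = (max - min) / levels`
    /// (255 levels for Q8, 15 for Q4/Q4K).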
fn compute_scale_and_zero(values: &[f32], precision: Precision) -> (f32, f32) {
if values.is_empty() {
return (1.0, 0.0);
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
let (min_val, max_val) = unsafe { Self::minmax_neon(values) };
#[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
let (min_val, max_val) = {
let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
(min, max)
};
let range = match precision {
Precision::Q8 => 255.0,
Precision::Q4 | Precision::Q4K => 15.0,
_ => 255.0,
};
let scale = (max_val - min_val) / range;
let zero_point = min_val;
(scale.max(1e-8), zero_point)
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe fn minmax_neon(values: &[f32]) -> (f32, f32) {
use std::arch::aarch64::*;
let mut min_vec = vdupq_n_f32(f32::INFINITY);
let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);
const UNROLL_8X: usize = 8;
let chunks = values.len() / UNROLL_8X;
for c in 0..chunks {
let base = c * UNROLL_8X;
let v0 = vld1q_f32(values.as_ptr().add(base));
let v1 = vld1q_f32(values.as_ptr().add(base + 4));
min_vec = vminq_f32(min_vec, vminq_f32(v0, v1));
max_vec = vmaxq_f32(max_vec, vmaxq_f32(v0, v1));
}
let min_val = vminvq_f32(min_vec);
let max_val = vmaxvq_f32(max_vec);
let mut final_min = min_val;
let mut final_max = max_val;
for i in (chunks * UNROLL_8X)..values.len() {
final_min = final_min.min(values[i]);
final_max = final_max.max(values[i]);
}
(final_min, final_max)
}
fn dequantize(&self) -> KvPair {
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
let dequant =
|vals: &[f32]| -> Vec<f32> { Self::dequantize_neon(vals, self.scale, self.zero_point) };
#[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
let dequant = |vals: &[f32]| -> Vec<f32> {
vals.iter()
.map(|v| v * self.scale + self.zero_point)
.collect()
};
KvPair {
keys: dequant(&self.keys),
values: dequant(&self.values),
position: self.position,
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn dequantize_neon(quantized: &[f32], scale: f32, zero_point: f32) -> Vec<f32> {
use std::arch::aarch64::*;
let mut result = vec![0.0f32; quantized.len()];
unsafe {
let scale_vec = vdupq_n_f32(scale);
let zero_vec = vdupq_n_f32(zero_point);
const UNROLL_8X: usize = 8;
let chunks = quantized.len() / UNROLL_8X;
for c in 0..chunks {
let base = c * UNROLL_8X;
let q0 = vld1q_f32(quantized.as_ptr().add(base));
let q1 = vld1q_f32(quantized.as_ptr().add(base + 4));
let d0 = vfmaq_f32(zero_vec, q0, scale_vec);
let d1 = vfmaq_f32(zero_vec, q1, scale_vec);
vst1q_f32(result.as_mut_ptr().add(base), d0);
vst1q_f32(result.as_mut_ptr().add(base + 4), d1);
}
for i in (chunks * UNROLL_8X)..quantized.len() {
result[i] = quantized[i] * scale + zero_point;
}
}
result
}
#[inline(always)]
fn dequantize_into(&self, key_buf: &mut AlignedBuffer, value_buf: &mut AlignedBuffer) {
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe {
let key_new_len = key_buf.len() + self.keys.len();
let value_new_len = value_buf.len() + self.values.len();
assert!(
key_new_len <= key_buf.capacity(),
"Key buffer overflow: {} > {}",
key_new_len,
key_buf.capacity()
);
assert!(
value_new_len <= value_buf.capacity(),
"Value buffer overflow: {} > {}",
value_new_len,
value_buf.capacity()
);
Self::dequantize_neon_into(
&self.keys,
key_buf.as_mut_ptr().add(key_buf.len()),
self.scale,
self.zero_point,
);
Self::dequantize_neon_into(
&self.values,
value_buf.as_mut_ptr().add(value_buf.len()),
self.scale,
self.zero_point,
);
key_buf.set_len_unchecked(key_new_len);
value_buf.set_len_unchecked(value_new_len);
}
#[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
{
let keys: Vec<f32> = self
.keys
.iter()
.map(|v| v * self.scale + self.zero_point)
.collect();
let values: Vec<f32> = self
.values
.iter()
.map(|v| v * self.scale + self.zero_point)
.collect();
key_buf.extend_from_slice(&keys);
value_buf.extend_from_slice(&values);
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline(always)]
unsafe fn dequantize_neon_into(
quantized: &[f32],
output: *mut f32,
scale: f32,
zero_point: f32,
) {
use std::arch::aarch64::*;
let scale_vec = vdupq_n_f32(scale);
let zero_vec = vdupq_n_f32(zero_point);
const UNROLL_8X: usize = 8;
let chunks = quantized.len() / UNROLL_8X;
for c in 0..chunks {
let base = c * UNROLL_8X;
let q0 = vld1q_f32(quantized.as_ptr().add(base));
let q1 = vld1q_f32(quantized.as_ptr().add(base + 4));
let d0 = vfmaq_f32(zero_vec, q0, scale_vec);
let d1 = vfmaq_f32(zero_vec, q1, scale_vec);
vst1q_f32(output.add(base), d0);
vst1q_f32(output.add(base + 4), d1);
}
for i in (chunks * UNROLL_8X)..quantized.len() {
*output.add(i) = quantized[i] * scale + zero_point;
}
}
}
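/// Two-tier KV cache: the most recent tokens live in a full-precision tail,
/// and older tokens are migrated in batches into an affine-quantized store.
/// `append` is the hot path; the `get_all_kv*` methods rebuild the
/// full-precision view on demand.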
#[derive(Debug)]
pub struct TwoTierKvCache {
config: KvCacheConfig,
tail: RwLock<VecDeque<KvPair>>,
store: RwLock<Vec<QuantizedKvPair>>,
total_tokens: AtomicUsize,
quantization_policy: Arc<RwLock<CacheQuantization>>,
memory_pool: Arc<KvMemoryPool>,
}
impl TwoTierKvCache {
pub fn new(config: KvCacheConfig) -> Self {
let quantization_policy = Arc::new(RwLock::new(CacheQuantization::Hybrid {
tail_length: config.tail_length,
tail_precision: config.tail_precision,
store_precision: config.store_precision,
}));
let stride = config.num_kv_heads * config.head_dim;
let block_size = stride * config.tail_length;
let max_blocks = (config.max_tokens / config.tail_length).max(4);
let memory_pool = Arc::new(KvMemoryPool::new(block_size, max_blocks));
memory_pool.prewarm(2);
Self {
config,
tail: RwLock::new(VecDeque::new()),
store: RwLock::new(Vec::new()),
total_tokens: AtomicUsize::new(0),
quantization_policy,
memory_pool,
}
}
pub fn with_pool(config: KvCacheConfig, pool: Arc<KvMemoryPool>) -> Self {
let quantization_policy = Arc::new(RwLock::new(CacheQuantization::Hybrid {
tail_length: config.tail_length,
tail_precision: config.tail_precision,
store_precision: config.store_precision,
}));
Self {
config,
tail: RwLock::new(VecDeque::new()),
store: RwLock::new(Vec::new()),
total_tokens: AtomicUsize::new(0),
quantization_policy,
memory_pool: pool,
}
}
pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<()> {
let stride = self.config.num_kv_heads * self.config.head_dim;
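        // Integer division: trailing elements short of a full token are ignored.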
let num_tokens = keys.len() / stride;
if keys.len() != values.len() {
return Err(RuvLLMError::KvCache(
"Key and value lengths must match".to_string(),
));
}
let current_tokens = self.total_tokens.load(Ordering::SeqCst);
let mut tail = self.tail.write();
for i in 0..num_tokens {
let offset = i * stride;
tail.push_back(KvPair {
keys: keys[offset..offset + stride].to_vec(),
values: values[offset..offset + stride].to_vec(),
position: current_tokens + i,
});
}
while tail.len() > self.config.tail_length {
let batch_size = self
.config
.migration_batch
.min(tail.len() - self.config.tail_length);
let to_migrate: Vec<_> = (0..batch_size).filter_map(|_| tail.pop_front()).collect();
let mut store = self.store.write();
for pair in to_migrate {
let quantized = QuantizedKvPair::from_kv_pair(&pair, self.config.store_precision);
store.push(quantized);
}
}
self.total_tokens.fetch_add(num_tokens, Ordering::SeqCst);
self.enforce_max_tokens()?;
Ok(())
}
    fn enforce_max_tokens(&self) -> Result<()> {
        let total = self.total_tokens.load(Ordering::SeqCst);
        if total <= self.config.max_tokens {
            return Ok(());
        }
        let to_evict = total - self.config.max_tokens;
        let mut store = self.store.write();
        let store_evict = to_evict.min(store.len());
        store.drain(0..store_evict);
        self.total_tokens.fetch_sub(store_evict, Ordering::SeqCst);
        let remaining = to_evict - store_evict;
        if remaining > 0 {
            let mut tail = self.tail.write();
            // Snapshot the eviction count before popping so the token counter
            // is decremented by exactly the number of tokens removed.
            let tail_evict = remaining.min(tail.len());
            for _ in 0..tail_evict {
                tail.pop_front();
            }
            self.total_tokens.fetch_sub(tail_evict, Ordering::SeqCst);
        }
        Ok(())
    }
pub fn get_all_kv(&self) -> (Vec<f32>, Vec<f32>) {
let stride = self.config.num_kv_heads * self.config.head_dim;
let total = self.total_tokens.load(Ordering::SeqCst);
let mut all_keys = Vec::with_capacity(total * stride);
let mut all_values = Vec::with_capacity(total * stride);
let store = self.store.read();
for qpair in store.iter() {
let pair = qpair.dequantize();
all_keys.extend_from_slice(&pair.keys);
all_values.extend_from_slice(&pair.values);
}
drop(store);
let tail = self.tail.read();
for pair in tail.iter() {
all_keys.extend_from_slice(&pair.keys);
all_values.extend_from_slice(&pair.values);
}
(all_keys, all_values)
}
pub fn get_all_kv_aligned(&self) -> (AlignedBuffer, AlignedBuffer) {
let stride = self.config.num_kv_heads * self.config.head_dim;
let total = self.total_tokens.load(Ordering::SeqCst);
let mut key_buf = AlignedBuffer::new(total * stride);
let mut value_buf = AlignedBuffer::new(total * stride);
let store = self.store.read();
for qpair in store.iter() {
qpair.dequantize_into(&mut key_buf, &mut value_buf);
}
drop(store);
let tail = self.tail.read();
for pair in tail.iter() {
key_buf.extend_from_slice(&pair.keys);
value_buf.extend_from_slice(&pair.values);
}
(key_buf, value_buf)
}
pub fn memory_pool(&self) -> &Arc<KvMemoryPool> {
&self.memory_pool
}
pub fn pool_stats(&self) -> PoolStats {
self.memory_pool.stats()
}
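    /// Reference single-query attention over the entire cache: scaled
    /// dot-product scores against every cached key, a softmax, then a
    /// weighted sum of values. `query` is expected to hold
    /// `num_kv_heads * head_dim` elements.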
pub fn attend(&self, query: &[f32], scale: f32) -> Result<Vec<f32>> {
let (keys, values) = self.get_all_kv();
let stride = self.config.num_kv_heads * self.config.head_dim;
let num_tokens = keys.len() / stride;
if num_tokens == 0 {
return Ok(vec![0.0; query.len()]);
}
let mut scores = Vec::with_capacity(num_tokens);
for t in 0..num_tokens {
let k_offset = t * stride;
let k_slice = &keys[k_offset..k_offset + stride];
let score: f32 = query
.iter()
.zip(k_slice.iter())
.map(|(q, k)| q * k * scale)
.sum();
scores.push(score);
}
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
let mut output = vec![0.0; stride];
for (t, weight) in attn_weights.iter().enumerate() {
let v_offset = t * stride;
for (i, v) in values[v_offset..v_offset + stride].iter().enumerate() {
output[i] += weight * v;
}
}
Ok(output)
}
pub fn stats(&self) -> KvCacheStats {
let tail = self.tail.read();
let store = self.store.read();
let stride = self.config.num_kv_heads * self.config.head_dim;
        // 4 bytes per f32, times two for keys + values.
        let tail_bytes = tail.len() * stride * 4 * 2;
        let store_bytes =
            store.len() * stride * self.config.store_precision.bytes_per_element() as usize * 2;
KvCacheStats {
total_tokens: self.total_tokens.load(Ordering::SeqCst),
tail_tokens: tail.len(),
store_tokens: store.len(),
tail_bytes,
store_bytes,
compression_ratio: tail_bytes as f32 / store_bytes.max(1) as f32,
}
}
pub fn clear(&self) {
let mut tail = self.tail.write();
let mut store = self.store.write();
tail.clear();
store.clear();
self.total_tokens.store(0, Ordering::SeqCst);
}
pub fn update_policy(&self, policy: CacheQuantization) {
let mut current = self.quantization_policy.write();
*current = policy;
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KvCacheStats {
pub total_tokens: usize,
pub tail_tokens: usize,
pub store_tokens: usize,
pub tail_bytes: usize,
pub store_bytes: usize,
pub compression_ratio: f32,
}
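/// A fixed-capacity KV block whose backing storage is leased from a shared
/// `BufferPool`. Tokens are written contiguously until the block fills.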
pub struct PooledKvBlock {
keys: PooledBuffer,
values: PooledBuffer,
token_count: usize,
stride: usize,
}
impl PooledKvBlock {
pub fn new(
pool: &BufferPool,
max_tokens: usize,
num_heads: usize,
head_dim: usize,
) -> Option<Self> {
let stride = num_heads * head_dim;
let bytes_needed = max_tokens * stride * std::mem::size_of::<f32>();
let keys = pool.acquire_for_size(bytes_needed).ok()??;
let values = pool.acquire_for_size(bytes_needed).ok()??;
Some(Self {
keys,
values,
token_count: 0,
stride,
})
}
pub fn append(&mut self, keys: &[f32], values: &[f32]) -> usize {
let capacity_tokens = self.keys.capacity() / (self.stride * std::mem::size_of::<f32>());
let input_tokens = keys.len() / self.stride;
let space_remaining = capacity_tokens.saturating_sub(self.token_count);
let tokens_to_append = input_tokens.min(space_remaining);
if tokens_to_append == 0 {
return 0;
}
let elements = tokens_to_append * self.stride;
let offset = self.token_count * self.stride;
let key_slice = self.keys.as_slice_mut::<f32>();
key_slice[offset..offset + elements].copy_from_slice(&keys[..elements]);
let value_slice = self.values.as_slice_mut::<f32>();
value_slice[offset..offset + elements].copy_from_slice(&values[..elements]);
self.token_count += tokens_to_append;
tokens_to_append
}
pub fn keys(&self) -> &[f32] {
let elements = self.token_count * self.stride;
&self.keys.as_slice::<f32>()[..elements]
}
pub fn values(&self) -> &[f32] {
let elements = self.token_count * self.stride;
&self.values.as_slice::<f32>()[..elements]
}
pub fn token_count(&self) -> usize {
self.token_count
}
pub fn is_full(&self) -> bool {
let capacity_tokens = self.keys.capacity() / (self.stride * std::mem::size_of::<f32>());
self.token_count >= capacity_tokens
}
pub fn remaining_tokens(&self) -> usize {
let capacity_tokens = self.keys.capacity() / (self.stride * std::mem::size_of::<f32>());
capacity_tokens.saturating_sub(self.token_count)
}
pub fn clear(&mut self) {
self.token_count = 0;
}
}
impl std::fmt::Debug for PooledKvBlock {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PooledKvBlock")
.field("token_count", &self.token_count)
.field("stride", &self.stride)
.field("key_capacity", &self.keys.capacity())
.field("value_capacity", &self.values.capacity())
.finish()
}
}
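/// Block-based KV cache backed by a `BufferPool`: new blocks are leased as
/// needed, and whole blocks are evicted oldest-first once `max_tokens` is
/// exceeded.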
#[derive(Debug)]
pub struct PooledKvCache {
config: KvCacheConfig,
pool: BufferPool,
blocks: RwLock<Vec<PooledKvBlock>>,
tokens_per_block: usize,
total_tokens: AtomicUsize,
}
impl PooledKvCache {
pub fn new(config: KvCacheConfig, pool: BufferPool, tokens_per_block: usize) -> Self {
Self {
config,
pool,
blocks: RwLock::new(Vec::new()),
tokens_per_block,
total_tokens: AtomicUsize::new(0),
}
}
pub fn with_new_pool(config: KvCacheConfig, tokens_per_block: usize) -> Self {
let pool = BufferPool::new();
Self::new(config, pool, tokens_per_block)
}
pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<()> {
let stride = self.config.num_kv_heads * self.config.head_dim;
let input_tokens = keys.len() / stride;
if keys.len() != values.len() {
return Err(RuvLLMError::KvCache(
"Key and value lengths must match".to_string(),
));
}
let mut blocks = self.blocks.write();
let mut remaining_keys = keys;
let mut remaining_values = values;
while !remaining_keys.is_empty() {
            let need_new_block = blocks.last().map_or(true, |b| b.is_full());
if need_new_block {
let new_block = PooledKvBlock::new(
&self.pool,
self.tokens_per_block,
self.config.num_kv_heads,
self.config.head_dim,
)
.ok_or_else(|| {
RuvLLMError::OutOfMemory("Failed to allocate KV block from pool".to_string())
})?;
blocks.push(new_block);
}
let block = blocks
.last_mut()
.expect("blocks should be non-empty after allocation");
let tokens_appended = block.append(remaining_keys, remaining_values);
if tokens_appended == 0 {
break;
}
let elements = tokens_appended * stride;
remaining_keys = &remaining_keys[elements..];
remaining_values = &remaining_values[elements..];
self.total_tokens
.fetch_add(tokens_appended, Ordering::SeqCst);
}
self.enforce_max_tokens(&mut blocks)?;
Ok(())
}
    fn enforce_max_tokens(&self, blocks: &mut Vec<PooledKvBlock>) -> Result<()> {
        let total = self.total_tokens.load(Ordering::SeqCst);
        if total <= self.config.max_tokens {
            return Ok(());
        }
        let mut to_evict = total - self.config.max_tokens;
        // Eviction is block-granular: the oldest block is always dropped
        // whole, so the final step may evict slightly past the target.
        while to_evict > 0 && !blocks.is_empty() {
            let first_block_tokens = blocks[0].token_count();
            blocks.remove(0);
            self.total_tokens
                .fetch_sub(first_block_tokens, Ordering::SeqCst);
            if first_block_tokens >= to_evict {
                break;
            }
            to_evict -= first_block_tokens;
        }
        Ok(())
    }
pub fn get_all_kv(&self) -> (Vec<f32>, Vec<f32>) {
let blocks = self.blocks.read();
let total = self.total_tokens.load(Ordering::SeqCst);
let stride = self.config.num_kv_heads * self.config.head_dim;
let mut all_keys = Vec::with_capacity(total * stride);
let mut all_values = Vec::with_capacity(total * stride);
for block in blocks.iter() {
all_keys.extend_from_slice(block.keys());
all_values.extend_from_slice(block.values());
}
(all_keys, all_values)
}
pub fn stats(&self) -> PooledKvCacheStats {
let blocks = self.blocks.read();
let total_tokens = self.total_tokens.load(Ordering::SeqCst);
let stride = self.config.num_kv_heads * self.config.head_dim;
PooledKvCacheStats {
total_tokens,
block_count: blocks.len(),
tokens_per_block: self.tokens_per_block,
total_bytes: total_tokens * stride * std::mem::size_of::<f32>() * 2,
pool_stats: self.pool.stats(),
}
}
pub fn clear(&self) {
let mut blocks = self.blocks.write();
blocks.clear();
self.total_tokens.store(0, Ordering::SeqCst);
}
pub fn pool(&self) -> &BufferPool {
&self.pool
}
}
#[derive(Debug, Clone)]
pub struct PooledKvCacheStats {
pub total_tokens: usize,
pub block_count: usize,
pub tokens_per_block: usize,
pub total_bytes: usize,
pub pool_stats: crate::memory_pool::BufferPoolStats,
}
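/// KV cache pairing a full-precision tail with a TurboQuant compressed tier
/// (behind the `quantize` feature). Migration and eviction mirror
/// `TwoTierKvCache`.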
#[cfg(feature = "quantize")]
pub struct TurboQuantKvCache {
config: TurboQuantKvCacheConfig,
tail: RwLock<VecDeque<KvPair>>,
turbo_tier: RwLock<crate::quantize::turbo_quant::TurboQuantCacheTier>,
total_tokens: AtomicUsize,
}
#[cfg(feature = "quantize")]
#[derive(Debug, Clone)]
pub struct TurboQuantKvCacheConfig {
pub tail_length: usize,
pub max_tokens: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub migration_batch: usize,
pub turbo_config: crate::quantize::turbo_quant::TurboQuantConfig,
}
#[cfg(feature = "quantize")]
impl Default for TurboQuantKvCacheConfig {
fn default() -> Self {
Self {
tail_length: 256,
max_tokens: 8192,
num_kv_heads: 8,
head_dim: 128,
migration_batch: 64,
turbo_config: crate::quantize::turbo_quant::TurboQuantConfig::default(),
}
}
}
#[cfg(feature = "quantize")]
impl TurboQuantKvCache {
pub fn new(config: TurboQuantKvCacheConfig) -> Result<Self> {
let turbo_tier =
crate::quantize::turbo_quant::TurboQuantCacheTier::new(config.turbo_config.clone())?;
Ok(Self {
config,
tail: RwLock::new(VecDeque::new()),
turbo_tier: RwLock::new(turbo_tier),
total_tokens: AtomicUsize::new(0),
})
}
pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<()> {
let stride = self.config.num_kv_heads * self.config.head_dim;
let num_tokens = keys.len() / stride;
if keys.len() != values.len() {
return Err(RuvLLMError::KvCache(
"Key and value lengths must match".to_string(),
));
}
let current_tokens = self.total_tokens.load(Ordering::SeqCst);
let mut tail = self.tail.write();
for i in 0..num_tokens {
let offset = i * stride;
tail.push_back(KvPair {
keys: keys[offset..offset + stride].to_vec(),
values: values[offset..offset + stride].to_vec(),
position: current_tokens + i,
});
}
while tail.len() > self.config.tail_length {
let batch_size = self
.config
.migration_batch
.min(tail.len() - self.config.tail_length);
let to_migrate: Vec<_> = (0..batch_size).filter_map(|_| tail.pop_front()).collect();
let mut turbo = self.turbo_tier.write();
for pair in to_migrate {
turbo.push(&pair.keys, &pair.values, pair.position)?;
}
}
self.total_tokens.fetch_add(num_tokens, Ordering::SeqCst);
self.enforce_max_tokens()?;
Ok(())
}
fn enforce_max_tokens(&self) -> Result<()> {
let total = self.total_tokens.load(Ordering::SeqCst);
if total <= self.config.max_tokens {
return Ok(());
}
let to_evict = total - self.config.max_tokens;
let mut turbo = self.turbo_tier.write();
let turbo_evict = to_evict.min(turbo.len());
turbo.evict_oldest(turbo_evict);
self.total_tokens.fetch_sub(turbo_evict, Ordering::SeqCst);
let remaining = to_evict - turbo_evict;
if remaining > 0 {
let mut tail = self.tail.write();
let tail_evict = remaining.min(tail.len());
for _ in 0..tail_evict {
tail.pop_front();
}
self.total_tokens.fetch_sub(tail_evict, Ordering::SeqCst);
}
Ok(())
}
pub fn get_all_kv(&self) -> Result<(Vec<f32>, Vec<f32>)> {
let stride = self.config.num_kv_heads * self.config.head_dim;
let total = self.total_tokens.load(Ordering::SeqCst);
let mut all_keys = Vec::with_capacity(total * stride);
let mut all_values = Vec::with_capacity(total * stride);
let turbo = self.turbo_tier.read();
let (turbo_keys, turbo_values) = turbo.get_all_kv()?;
all_keys.extend(turbo_keys);
all_values.extend(turbo_values);
drop(turbo);
let tail = self.tail.read();
for pair in tail.iter() {
all_keys.extend_from_slice(&pair.keys);
all_values.extend_from_slice(&pair.values);
}
Ok((all_keys, all_values))
}
pub fn stats(&self) -> TurboQuantKvCacheStats {
let tail = self.tail.read();
let turbo = self.turbo_tier.read();
let stride = self.config.num_kv_heads * self.config.head_dim;
        // 4 bytes per f32, times two for keys + values.
        let tail_bytes = tail.len() * stride * 4 * 2;
        let turbo_stats = turbo.stats();
TurboQuantKvCacheStats {
total_tokens: self.total_tokens.load(Ordering::SeqCst),
tail_tokens: tail.len(),
turbo_tokens: turbo.len(),
tail_bytes,
turbo_bytes: turbo_stats.compressed_bytes,
turbo_original_bytes: turbo_stats.original_bytes,
turbo_compression_ratio: turbo_stats.compression_ratio,
turbo_bits_per_value: turbo_stats.bits_per_value,
}
}
pub fn clear(&self) {
let mut tail = self.tail.write();
let mut turbo = self.turbo_tier.write();
tail.clear();
turbo.clear();
self.total_tokens.store(0, Ordering::SeqCst);
}
}
#[cfg(feature = "quantize")]
#[derive(Debug, Clone)]
pub struct TurboQuantKvCacheStats {
pub total_tokens: usize,
pub tail_tokens: usize,
pub turbo_tokens: usize,
pub tail_bytes: usize,
pub turbo_bytes: usize,
pub turbo_original_bytes: usize,
pub turbo_compression_ratio: f32,
pub turbo_bits_per_value: f32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_kv_cache_append() {
let config = KvCacheConfig {
tail_length: 4,
num_kv_heads: 2,
head_dim: 4,
migration_batch: 2,
..Default::default()
};
let cache = TwoTierKvCache::new(config);
        let keys = vec![1.0; 2 * 4];
        let values = vec![1.0; 2 * 4];
cache.append(&keys, &values).unwrap();
let stats = cache.stats();
assert_eq!(stats.total_tokens, 1);
assert_eq!(stats.tail_tokens, 1);
assert_eq!(stats.store_tokens, 0);
}
#[test]
fn test_kv_cache_migration() {
let config = KvCacheConfig {
tail_length: 2,
num_kv_heads: 2,
head_dim: 4,
migration_batch: 1,
max_tokens: 100,
..Default::default()
};
let cache = TwoTierKvCache::new(config);
for _ in 0..5 {
let keys = vec![1.0; 2 * 4];
let values = vec![1.0; 2 * 4];
cache.append(&keys, &values).unwrap();
}
let stats = cache.stats();
assert_eq!(stats.total_tokens, 5);
assert_eq!(stats.tail_tokens, 2);
assert_eq!(stats.store_tokens, 3);
}
#[test]
fn test_kv_cache_attend() {
let config = KvCacheConfig {
tail_length: 4,
num_kv_heads: 1,
head_dim: 4,
..Default::default()
};
let cache = TwoTierKvCache::new(config);
let keys = vec![1.0, 0.0, 0.0, 0.0];
let values = vec![1.0, 2.0, 3.0, 4.0];
cache.append(&keys, &values).unwrap();
let query = vec![1.0, 0.0, 0.0, 0.0];
let output = cache.attend(&query, 1.0).unwrap();
assert_eq!(output.len(), 4);
assert!((output[0] - 1.0).abs() < 0.1);
}
#[test]
fn test_pooled_kv_cache_basic() {
let config = KvCacheConfig {
tail_length: 4,
num_kv_heads: 2,
head_dim: 4,
max_tokens: 100,
..Default::default()
};
let cache = PooledKvCache::with_new_pool(config, 16);
        let stride = 2 * 4;
        let keys = vec![1.0; stride];
        let values = vec![2.0; stride];
cache.append(&keys, &values).unwrap();
let stats = cache.stats();
assert_eq!(stats.total_tokens, 1);
assert_eq!(stats.block_count, 1);
}
#[test]
fn test_pooled_kv_cache_multiple_blocks() {
let config = KvCacheConfig {
tail_length: 4,
num_kv_heads: 2,
head_dim: 4,
max_tokens: 100,
..Default::default()
};
let cache = PooledKvCache::with_new_pool(config, 2);
let stride = 2 * 4;
for i in 0..5 {
let keys = vec![i as f32; stride];
let values = vec![(i * 2) as f32; stride];
cache.append(&keys, &values).unwrap();
}
let stats = cache.stats();
assert_eq!(stats.total_tokens, 5);
assert!(stats.block_count >= 1, "Should have at least 1 block");
assert!(stats.block_count <= 5, "Should have at most 5 blocks");
let (all_keys, all_values) = cache.get_all_kv();
assert_eq!(all_keys.len(), 5 * stride);
assert_eq!(all_values.len(), 5 * stride);
assert_eq!(all_keys[0], 0.0);
assert_eq!(all_keys[4 * stride], 4.0);
}
#[test]
fn test_pooled_kv_cache_pool_reuse() {
let config = KvCacheConfig {
tail_length: 4,
num_kv_heads: 2,
head_dim: 4,
max_tokens: 100,
..Default::default()
};
let pool = BufferPool::new();
pool.prewarm(BufferSize::KB4, 4);
let cache = PooledKvCache::new(config, pool, 16);
let stride = 2 * 4;
let keys = vec![1.0; stride];
let values = vec![2.0; stride];
for _ in 0..3 {
cache.append(&keys, &values).unwrap();
cache.clear();
}
let stats = cache.stats();
assert_eq!(stats.total_tokens, 0);
assert!(stats.pool_stats.returns > 0 || stats.pool_stats.hits > 0);
}
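    // Additional usage sketches (added for illustration); they rely only on
    // the APIs defined in this module.
    #[test]
    fn test_aligned_buffer_alignment() {
        let buf = AlignedBuffer::new(8);
        assert_eq!(buf.as_ptr() as usize % CACHE_LINE_SIZE, 0);
        assert!(buf.is_empty());
        assert_eq!(buf.capacity(), 8);
    }
    #[test]
    fn test_kv_memory_pool_reuse() {
        let pool = KvMemoryPool::new(16, 2);
        let buf = pool.get_key_buffer();
        assert_eq!(pool.stats().total_allocated, 1);
        pool.return_key_buffer(buf);
        assert_eq!(pool.stats().key_pool_size, 1);
        // A second request is served from the free list, not a fresh allocation.
        let _buf = pool.get_key_buffer();
        assert_eq!(pool.stats().total_allocated, 1);
        assert_eq!(pool.stats().key_pool_size, 0);
    }
    #[test]
    fn test_quantized_round_trip() {
        let pair = KvPair {
            keys: vec![0.0, 0.5, 1.0, 1.5],
            values: vec![0.25, 0.75, 1.25, 1.5],
            position: 0,
        };
        let q = QuantizedKvPair::from_kv_pair(&pair, Precision::Q8);
        let restored = q.dequantize();
        // Affine round trip: every element lands within one quantization step.
        for (orig, rest) in pair
            .keys
            .iter()
            .chain(pair.values.iter())
            .zip(restored.keys.iter().chain(restored.values.iter()))
        {
            assert!((orig - rest).abs() <= q.scale);
        }
        assert_eq!(restored.position, 0);
    }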
}