use candle_core::{DType, Device, Result, Tensor};
use std::sync::{Arc, Mutex};
/// Sizing parameters for a paged KV cache: how many blocks to allocate
/// and the per-layer tensor dimensions.
#[derive(Clone, Debug)]
pub struct CacheConfig {
/// Number of tokens stored in each cache block (fixed to 16 by `new`).
pub block_size: usize,
/// Total blocks allocated on the device (worst case: full budget per sequence).
pub num_gpu_blocks: usize,
/// Number of key/value attention heads per layer.
pub num_kv_heads: usize,
/// Dimension of each attention head.
pub head_dim: usize,
/// Number of transformer layers, i.e. the number of `PagedKVCache` instances.
pub num_layers: usize,
}
impl CacheConfig {
pub fn new(
num_kv_heads: usize,
head_dim: usize,
num_layers: usize,
max_seq_len: usize,
max_batch_size: usize,
) -> Self {
let block_size = 16;
let blocks_per_seq = max_seq_len.div_ceil(block_size);
let num_gpu_blocks = blocks_per_seq * max_batch_size;
Self {
block_size,
num_gpu_blocks,
num_kv_heads,
head_dim,
num_layers,
}
}
pub fn memory_bytes(&self, dtype: DType) -> usize {
let bytes_per_elem = match dtype {
DType::F32 => 4,
DType::F16 | DType::BF16 => 2,
_ => 4,
};
let key_size = self.num_gpu_blocks * self.num_kv_heads * self.head_dim * self.block_size;
let value_size = self.num_gpu_blocks * self.num_kv_heads * self.head_dim * self.block_size;
(key_size + value_size) * bytes_per_elem * self.num_layers
}
}
/// Per-layer paged KV cache holding preallocated key and value tensors.
/// Cloning is cheap in spirit — candle `Tensor` clones share storage —
/// NOTE(review): confirm callers rely on that sharing rather than a deep copy.
#[derive(Clone)]
pub struct PagedKVCache {
// Shape (num_blocks, num_heads, head_dim, block_size), zero-initialized in `new`.
pub key_cache: Tensor,
// Same shape as `key_cache`.
pub value_cache: Tensor,
// Tokens per block; mirrors `CacheConfig::block_size`.
pub block_size: usize,
}
impl PagedKVCache {
/// Allocates zero-filled key and value cache tensors of shape
/// `(num_blocks, num_heads, head_dim, block_size)` on `device`.
///
/// # Errors
/// Propagates any allocation failure from `Tensor::zeros`.
pub fn new(
num_blocks: usize,
num_heads: usize,
head_dim: usize,
block_size: usize,
dtype: DType,
device: &Device,
) -> Result<Self> {
// Key and value caches share one shape; allocate each via the same closure.
let shape = (num_blocks, num_heads, head_dim, block_size);
let allocate = || Tensor::zeros(shape, dtype, device);
Ok(Self {
key_cache: allocate()?,
value_cache: allocate()?,
block_size,
})
}
}
/// Owns one `PagedKVCache` per model layer plus the config used to size them.
pub struct CacheEngine {
// One entry per layer, indexed by layer number.
// NOTE(review): the Vec is never mutated after construction in this file;
// if that holds everywhere, Arc<Vec<_>> without the Mutex would suffice.
caches: Arc<Mutex<Vec<PagedKVCache>>>,
config: CacheConfig,
}
impl CacheEngine {
/// Allocates one `PagedKVCache` per layer according to `config` and logs
/// the total memory footprint.
///
/// # Errors
/// Propagates any tensor allocation failure from `PagedKVCache::new`.
pub fn new(config: CacheConfig, dtype: DType, device: &Device) -> Result<Self> {
let mut caches = Vec::with_capacity(config.num_layers);
for _ in 0..config.num_layers {
caches.push(PagedKVCache::new(
config.num_gpu_blocks,
config.num_kv_heads,
config.head_dim,
config.block_size,
dtype,
device,
)?);
}
tracing::info!(
"CacheEngine: Allocated {} blocks × {} layers = {:.2} MB",
config.num_gpu_blocks,
config.num_layers,
config.memory_bytes(dtype) as f64 / (1024.0 * 1024.0)
);
Ok(Self {
caches: Arc::new(Mutex::new(caches)),
config,
})
}
/// Acquires the caches lock, recovering the data if a previous holder
/// panicked. The Vec is only read after construction, so the data behind
/// a poisoned lock is still valid — previously `get_all_caches` panicked
/// on poison while `get_cache` silently returned `None`.
fn lock_caches(&self) -> std::sync::MutexGuard<'_, Vec<PagedKVCache>> {
self.caches.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Returns a clone of the cache for `layer_idx`, or `None` if out of range.
pub fn get_cache(&self, layer_idx: usize) -> Option<PagedKVCache> {
self.lock_caches().get(layer_idx).cloned()
}
/// Returns clones of every layer's cache, in layer order.
pub fn get_all_caches(&self) -> Vec<PagedKVCache> {
self.lock_caches().clone()
}
/// The config this engine was built with.
pub fn config(&self) -> &CacheConfig {
&self.config
}
/// Tokens per cache block.
pub fn block_size(&self) -> usize {
self.config.block_size
}
/// Total blocks allocated per layer.
pub fn num_blocks(&self) -> usize {
self.config.num_gpu_blocks
}
}