use crate::error::{Error, Result};
use crate::inference::memory::{BlockAllocator, BlockId, BlockTable};
use crate::ops::traits::KvCacheOps;
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Paged key/value cache for a single attention layer.
///
/// Backing storage is a pair of dense tensors shaped
/// `[num_blocks, block_size, num_heads, head_dim]`; a token's slot index is
/// `block_id * block_size + offset` (see `compute_slot_mapping` below).
pub struct PagedKvCache<R: Runtime> {
    // Key cache, shape [num_blocks, block_size, num_heads, head_dim].
    k_cache: Tensor<R>,
    // Value cache, same shape as `k_cache`.
    v_cache: Tensor<R>,
    // Number of physical blocks in this cache.
    num_blocks: usize,
    // Tokens per block.
    block_size: usize,
    // Number of KV heads stored per token.
    num_heads: usize,
    // Per-head embedding dimension.
    head_dim: usize,
    // Element dtype of both cache tensors.
    dtype: DType,
}
impl<R: Runtime<DType = DType>> PagedKvCache<R> {
    /// Allocate a zero-initialized paged cache on `device`.
    ///
    /// Both K and V tensors are created with shape
    /// `[num_blocks, block_size, num_heads, head_dim]` and dtype `dtype`.
    pub fn new(
        num_blocks: usize,
        block_size: usize,
        num_heads: usize,
        head_dim: usize,
        dtype: DType,
        device: &R::Device,
    ) -> Self {
        let dims = [num_blocks, block_size, num_heads, head_dim];
        Self {
            k_cache: Tensor::<R>::zeros(&dims, dtype, device),
            v_cache: Tensor::<R>::zeros(&dims, dtype, device),
            num_blocks,
            block_size,
            num_heads,
            head_dim,
            dtype,
        }
    }

    /// Scatter `key`/`value` rows into the caches at the positions given by
    /// `slot_mapping`, delegating to the client's `reshape_and_cache` kernel.
    ///
    /// Takes `&self` because the cache tensors are mutated through the client
    /// op rather than through Rust references.
    ///
    /// # Errors
    /// Propagates any error returned by `reshape_and_cache`.
    pub fn update(
        &self,
        key: &Tensor<R>,
        value: &Tensor<R>,
        slot_mapping: &Tensor<R>,
        client: &R::Client,
    ) -> Result<()>
    where
        R::Client: KvCacheOps<R>,
    {
        client.reshape_and_cache(
            key,
            value,
            &self.k_cache,
            &self.v_cache,
            slot_mapping,
            self.block_size,
        )
    }

    /// Key cache tensor, shape `[num_blocks, block_size, num_heads, head_dim]`.
    pub fn k_cache(&self) -> &Tensor<R> {
        &self.k_cache
    }

    /// Value cache tensor, same shape as [`Self::k_cache`].
    pub fn v_cache(&self) -> &Tensor<R> {
        &self.v_cache
    }

    /// Number of physical blocks.
    pub fn num_blocks(&self) -> usize {
        self.num_blocks
    }

    /// Tokens per block.
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Number of KV heads.
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Per-head embedding dimension.
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Element dtype of the cache tensors.
    pub fn dtype(&self) -> DType {
        self.dtype
    }
}
/// Per-layer paged KV caches plus their block tables for one sequence.
///
/// Holds one `PagedKvCache` and one `BlockTable` per transformer layer; all
/// layers share the same block size and the same logical token count.
pub struct LayeredPagedKvCache<R: Runtime> {
    // One cache per transformer layer.
    layers: Vec<PagedKvCache<R>>,
    // One block table per layer (kept in lockstep; layer 0 is used as the
    // reference when computing slot mappings).
    block_tables: Vec<BlockTable>,
    // Tokens per block, shared by every layer.
    block_size: usize,
    // Current logical sequence length in tokens.
    seq_len: usize,
}
impl<R: Runtime<DType = DType>> LayeredPagedKvCache<R> {
#[allow(clippy::too_many_arguments)]
pub fn new(
num_layers: usize,
num_blocks_per_layer: usize,
block_size: usize,
num_kv_heads: usize,
head_dim: usize,
dtype: DType,
device: &R::Device,
) -> Self {
let mut layers = Vec::with_capacity(num_layers);
let mut block_tables = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
layers.push(PagedKvCache::new(
num_blocks_per_layer,
block_size,
num_kv_heads,
head_dim,
dtype,
device,
));
block_tables.push(BlockTable::new(block_size));
}
Self {
layers,
block_tables,
block_size,
seq_len: 0,
}
}
pub fn layer(&self, idx: usize) -> &PagedKvCache<R> {
&self.layers[idx]
}
pub fn block_table(&self, idx: usize) -> &BlockTable {
&self.block_tables[idx]
}
pub fn block_table_mut(&mut self, idx: usize) -> &mut BlockTable {
&mut self.block_tables[idx]
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
pub fn block_size(&self) -> usize {
self.block_size
}
pub fn seq_len(&self) -> usize {
self.seq_len
}
pub fn set_seq_len(&mut self, seq_len: usize) {
self.seq_len = seq_len;
for bt in &mut self.block_tables {
bt.set_num_tokens(seq_len);
}
}
pub fn allocate_blocks<A: BlockAllocator>(
&mut self,
additional_tokens: usize,
allocator: &A,
) -> Result<()> {
let needed = self.block_tables[0].additional_blocks_needed(additional_tokens);
if needed == 0 {
return Ok(());
}
let blocks = allocator.allocate(needed)?;
for bt in &mut self.block_tables {
bt.append_blocks(blocks.clone());
}
Ok(())
}
pub fn compute_slot_mapping(&self, start: usize, count: usize) -> Result<Vec<i32>> {
let bt = &self.block_tables[0];
let mut slots = Vec::with_capacity(count);
for pos in start..start + count {
let (block_id, offset) = bt.get_slot(pos).ok_or_else(|| Error::InferenceError {
reason: format!(
"no block allocated for token position {} (have {} blocks of size {})",
pos,
bt.num_blocks(),
self.block_size
),
})?;
slots.push((block_id as i32) * (self.block_size as i32) + (offset as i32));
}
Ok(slots)
}
pub fn block_table_device_format(&self, idx: usize) -> Vec<i32> {
self.block_tables[idx].to_device_format()
}
pub fn set_blocks(&mut self, blocks: Vec<BlockId>) {
for bt in &mut self.block_tables {
bt.blocks = blocks.clone();
}
}
pub fn reset(&mut self) {
self.seq_len = 0;
for bt in &mut self.block_tables {
bt.blocks.clear();
bt.num_tokens = 0;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::CpuRuntime;

    /// Writes three tokens into a fresh cache via `update` and spot-checks
    /// that the first two key elements landed at slot 0.
    #[test]
    fn test_paged_kv_cache_update() {
        let (client, device) = crate::test_utils::cpu_setup();
        let (num_blocks, block_size, num_heads, head_dim) = (4usize, 8usize, 2usize, 4usize);
        let cache = PagedKvCache::<CpuRuntime>::new(
            num_blocks,
            block_size,
            num_heads,
            head_dim,
            DType::F32,
            &device,
        );

        let num_tokens = 3;
        let elems = num_tokens * num_heads * head_dim;
        let k_data: Vec<f32> = (0..elems).map(|i| i as f32 * 0.1).collect();
        let v_data: Vec<f32> = (0..elems).map(|i| i as f32 * 0.2).collect();
        let token_shape = [num_tokens, num_heads, head_dim];
        let key = Tensor::<CpuRuntime>::from_slice(&k_data, &token_shape, &device);
        let value = Tensor::<CpuRuntime>::from_slice(&v_data, &token_shape, &device);

        // Tokens 0 and 1 go to block 0; token 2 lands mid-way into block 1.
        let slots: Vec<i32> = vec![0, 1, 9];
        let slot_mapping = Tensor::<CpuRuntime>::from_slice(&slots, &[num_tokens], &device);

        cache.update(&key, &value, &slot_mapping, &client).unwrap();

        let expected_shape = [num_blocks, block_size, num_heads, head_dim];
        assert_eq!(cache.k_cache().shape(), &expected_shape);
        assert_eq!(cache.v_cache().shape(), &expected_shape);

        let kc = cache.k_cache().to_vec::<f32>();
        assert!(kc[0].abs() < 1e-6);
        assert!((kc[1] - 0.1).abs() < 1e-6);
    }
}
}