use oxicuda_blas::GpuFloat;
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_memory::DeviceBuffer;
use crate::error::{DnnError, DnnResult};
/// Geometry and capacity parameters for a paged key/value cache.
///
/// One page holds `num_heads * page_size * head_dim` elements (see
/// `page_elements`), and each pool is expected to hold `max_pages` such pages.
#[derive(Debug, Clone)]
pub struct KvCacheConfig {
/// Number of layers.
///
/// NOTE(review): no size computation visible in this file reads this field;
/// confirm how callers are meant to account for multiple layers.
pub num_layers: u32,
/// Number of heads stored for every token slot.
pub num_heads: u32,
/// Number of elements per head.
pub head_dim: u32,
/// Number of token slots in one page; token positions are divided by this
/// value to find their logical page.
pub page_size: u32,
/// Number of physical pages per pool; physical page indices are
/// `0..max_pages`.
pub max_pages: u32,
}
impl KvCacheConfig {
    /// Returns the number of elements in one page of a single pool:
    /// `num_heads * page_size * head_dim`.
    ///
    /// The product saturates at `usize::MAX` instead of wrapping, so an
    /// absurdly large configuration yields a size no real buffer can satisfy
    /// rather than a small, wrapped-around value.
    #[must_use]
    pub fn page_elements(&self) -> usize {
        (self.num_heads as usize)
            .saturating_mul(self.page_size as usize)
            .saturating_mul(self.head_dim as usize)
    }

    /// Returns the size in bytes of one page for element type `T`
    /// (`page_elements() * T::SIZE`), saturating at `usize::MAX`.
    #[must_use]
    pub fn page_bytes<T: GpuFloat>(&self) -> usize {
        self.page_elements().saturating_mul(T::SIZE)
    }

    /// Returns the combined element count of both pools:
    /// `page_elements() * max_pages * 2`, saturating at `usize::MAX`.
    ///
    /// The factor of two covers the separate key and value pools, each of
    /// which holds `max_pages` pages.
    #[must_use]
    pub fn total_pool_elements(&self) -> usize {
        self.page_elements()
            .saturating_mul(self.max_pages as usize)
            .saturating_mul(2)
    }
}
/// Host-side bookkeeping for a paged key/value cache.
///
/// Tracks the base device addresses of the key and value pools, a free list
/// of physical page indices, and one page-table row per sequence that maps
/// logical pages to physical pages.
///
/// NOTE(review): only raw device addresses are stored, so nothing here appears
/// to keep the backing buffers alive; confirm that callers guarantee this.
pub struct KvCache {
/// Geometry the cache was created with.
config: KvCacheConfig,
/// Base device address of the key pool, taken from the buffer passed to `new`.
k_pool_ptr: CUdeviceptr,
/// Base device address of the value pool, taken from the buffer passed to `new`.
v_pool_ptr: CUdeviceptr,
/// Physical page indices that are currently available; `allocate_page` pops
/// from the back.
free_pages: Vec<u32>,
/// Number of pages currently handed out.
allocated_count: u32,
/// Flattened page tables with `max_pages_per_seq` entries per sequence. An
/// entry is a physical page index, or `-1` when the logical page is unmapped.
page_tables: Vec<i32>,
/// Number of page-table rows (the `max_sequences` value passed to `new`).
num_sequences: u32,
/// Number of logical pages each sequence can map.
max_pages_per_seq: u32,
}
impl KvCache {
pub fn new<T: GpuFloat>(
config: KvCacheConfig,
k_pool: &DeviceBuffer<T>,
v_pool: &DeviceBuffer<T>,
max_sequences: u32,
max_pages_per_seq: u32,
) -> DnnResult<Self> {
let required = config.page_elements() * config.max_pages as usize;
if k_pool.len() < required {
return Err(DnnError::BufferTooSmall {
expected: required * T::SIZE,
actual: k_pool.len() * T::SIZE,
});
}
if v_pool.len() < required {
return Err(DnnError::BufferTooSmall {
expected: required * T::SIZE,
actual: v_pool.len() * T::SIZE,
});
}
let free_pages: Vec<u32> = (0..config.max_pages).collect();
let table_size = max_sequences as usize * max_pages_per_seq as usize;
let page_tables = vec![-1i32; table_size];
Ok(Self {
config,
k_pool_ptr: k_pool.as_device_ptr(),
v_pool_ptr: v_pool.as_device_ptr(),
free_pages,
allocated_count: 0,
page_tables,
num_sequences: max_sequences,
max_pages_per_seq,
})
}
pub fn allocate_page(&mut self) -> DnnResult<u32> {
self.free_pages
.pop()
.ok_or_else(|| {
DnnError::WorkspaceRequired(
self.config.page_elements() * std::mem::size_of::<f32>(),
)
})
.inspect(|_| {
self.allocated_count += 1;
})
}
pub fn free_page(&mut self, page: u32) -> DnnResult<()> {
if page >= self.config.max_pages {
return Err(DnnError::InvalidArgument(format!(
"page {} out of range (max {})",
page, self.config.max_pages
)));
}
self.free_pages.push(page);
self.allocated_count = self.allocated_count.saturating_sub(1);
Ok(())
}
pub fn append_kv<T: GpuFloat>(
&mut self,
seq_id: u32,
token_pos: u32,
) -> DnnResult<(CUdeviceptr, CUdeviceptr)> {
if seq_id >= self.num_sequences {
return Err(DnnError::InvalidArgument(format!(
"seq_id {} out of range (max {})",
seq_id, self.num_sequences
)));
}
let logical_page = token_pos / self.config.page_size;
let offset_in_page = token_pos % self.config.page_size;
if logical_page >= self.max_pages_per_seq {
return Err(DnnError::InvalidArgument(format!(
"logical page {} exceeds max_pages_per_seq {}",
logical_page, self.max_pages_per_seq
)));
}
let table_idx = seq_id as usize * self.max_pages_per_seq as usize + logical_page as usize;
let phys_page = if self.page_tables[table_idx] < 0 {
let new_page = self.allocate_page()?;
self.page_tables[table_idx] = new_page as i32;
new_page
} else {
self.page_tables[table_idx] as u32
};
let page_elem_offset = phys_page as usize * self.config.page_elements();
let token_offset_in_page = offset_in_page as usize
* self.config.num_heads as usize
* self.config.head_dim as usize;
let total_elem_offset = page_elem_offset + token_offset_in_page;
let byte_offset = (total_elem_offset * T::SIZE) as u64;
let k_ptr = self.k_pool_ptr + byte_offset;
let v_ptr = self.v_pool_ptr + byte_offset;
Ok((k_ptr, v_ptr))
}
pub fn get_page_table(&self, seq_id: u32) -> DnnResult<&[i32]> {
if seq_id >= self.num_sequences {
return Err(DnnError::InvalidArgument(format!(
"seq_id {} out of range (max {})",
seq_id, self.num_sequences
)));
}
let start = seq_id as usize * self.max_pages_per_seq as usize;
let end = start + self.max_pages_per_seq as usize;
Ok(&self.page_tables[start..end])
}
pub fn k_pool_ptr(&self) -> CUdeviceptr {
self.k_pool_ptr
}
pub fn v_pool_ptr(&self) -> CUdeviceptr {
self.v_pool_ptr
}
pub fn allocated_pages(&self) -> u32 {
self.allocated_count
}
pub fn free_pages_count(&self) -> u32 {
self.free_pages.len() as u32
}
pub fn config(&self) -> &KvCacheConfig {
&self.config
}
pub fn reset_sequence(&mut self, seq_id: u32) -> DnnResult<()> {
if seq_id >= self.num_sequences {
return Err(DnnError::InvalidArgument(format!(
"seq_id {} out of range (max {})",
seq_id, self.num_sequences
)));
}
let start = seq_id as usize * self.max_pages_per_seq as usize;
for i in 0..self.max_pages_per_seq as usize {
let phys = self.page_tables[start + i];
if phys >= 0 {
self.free_pages.push(phys as u32);
self.allocated_count = self.allocated_count.saturating_sub(1);
self.page_tables[start + i] = -1;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-layer configuration with the given geometry.
    fn geometry(num_heads: u32, head_dim: u32, page_size: u32, max_pages: u32) -> KvCacheConfig {
        KvCacheConfig {
            num_layers: 1,
            num_heads,
            head_dim,
            page_size,
            max_pages,
        }
    }

    /// Builds a cache straight from its fields so that page-table bookkeeping
    /// can be exercised on the host without allocating device buffers.
    ///
    /// The pool addresses are placeholder default values and must never be
    /// dereferenced.
    fn host_only_cache(
        config: KvCacheConfig,
        num_sequences: u32,
        max_pages_per_seq: u32,
    ) -> KvCache {
        let free_pages = (0..config.max_pages).collect();
        KvCache {
            config,
            k_pool_ptr: CUdeviceptr::default(),
            v_pool_ptr: CUdeviceptr::default(),
            free_pages,
            allocated_count: 0,
            page_tables: vec![-1; num_sequences as usize * max_pages_per_seq as usize],
            num_sequences,
            max_pages_per_seq,
        }
    }

    #[test]
    fn config_page_elements() {
        // The layer count must not influence the per-page element count.
        let cfg = KvCacheConfig {
            num_layers: 32,
            ..geometry(8, 64, 16, 1024)
        };
        assert_eq!(cfg.page_elements(), 8192);
    }

    #[test]
    fn config_page_bytes() {
        let cfg = geometry(4, 128, 16, 256);
        assert_eq!(cfg.page_bytes::<f32>(), 4 * 16 * 128 * 4);
        assert_eq!(cfg.page_bytes::<f64>(), 4 * 16 * 128 * 8);
    }

    #[test]
    fn config_total_pool() {
        let cfg = geometry(8, 64, 16, 100);
        assert_eq!(cfg.total_pool_elements(), 8192 * 100 * 2);
    }

    #[test]
    fn page_table_out_of_range() {
        let cache = host_only_cache(geometry(8, 64, 16, 10), 2, 4);

        // Valid sequence ids expose a full row of unmapped entries.
        for seq_id in 0..2 {
            let row = cache.get_page_table(seq_id).ok();
            assert_eq!(row.map(|r| r.len()), Some(4));
            assert_eq!(row.map(|r| r.iter().all(|&page| page == -1)), Some(true));
        }

        // The first id past the end, and the largest possible id, are rejected.
        assert!(matches!(
            cache.get_page_table(2),
            Err(DnnError::InvalidArgument(_))
        ));
        assert!(cache.get_page_table(u32::MAX).is_err());
    }
}