use std::sync::Arc;
use cudarc::driver::{CudaSlice, CudaStream};
/// Sizing parameters for a [`GpuPagedKvPool`].
///
/// `GpuPagedKvPool::new` derives the per-token element count as
/// `num_kv_heads * head_dim`, the per-block element count as
/// `block_size` times that, and allocates `max_blocks` such blocks for the
/// keys and again for the values of every layer.
#[derive(Debug, Clone)]
pub struct GpuPagedKvConfig {
    /// Number of token slots in one block. `GpuPagedKvPool::new` asserts that
    /// this is a power of two.
    pub block_size: usize,
    /// Number of blocks held by each per-layer pool.
    pub max_blocks: usize,
    /// Number of key/value heads stored per token.
    pub num_kv_heads: usize,
    /// Number of elements stored per head.
    pub head_dim: usize,
    /// Number of layers; one key pool and one value pool is allocated per layer.
    pub num_layers: usize,
}
/// Per-layer key and value pools in device memory, addressed in fixed-size
/// blocks of token slots.
///
/// Within a pool, the token at (`physical_block`, `slot`) starts at element
/// `physical_block * block_stride + slot * token_stride` and spans
/// `token_stride` `f16` elements.
pub struct GpuPagedKvPool {
    /// One key buffer per layer, each `max_blocks * block_stride` elements long.
    k_pools: Vec<CudaSlice<half::f16>>,
    /// One value buffer per layer, sized like the matching key buffer.
    v_pools: Vec<CudaSlice<half::f16>>,
    /// The configuration the pools were sized from.
    config: GpuPagedKvConfig,
    /// Elements per block: `block_size * token_stride`.
    block_stride: usize,
    /// Elements per token: `num_kv_heads * head_dim`.
    token_stride: usize,
    /// Stream used for the allocations and for every copy issued by the pool.
    stream: Arc<CudaStream>,
}
impl GpuPagedKvPool {
    /// Allocates one key pool and one value pool per layer on `stream`.
    ///
    /// Every pool holds `max_blocks * block_size * num_kv_heads * head_dim`
    /// `f16` elements. The memory is *not* zeroed: a slot contains unspecified
    /// data until it has been written through [`Self::write_k_token`],
    /// [`Self::write_v_token`] or [`Self::copy_contiguous_to_paged`].
    ///
    /// # Errors
    ///
    /// Returns the driver error of the first device allocation that fails.
    ///
    /// # Panics
    ///
    /// Panics if `config.block_size` is not a power of two.
    pub fn new(
        config: GpuPagedKvConfig,
        stream: Arc<CudaStream>,
    ) -> Result<Self, cudarc::driver::DriverError> {
        assert!(
            config.block_size.is_power_of_two(),
            "block_size must be power of 2, got {}",
            config.block_size
        );

        let token_stride = config.num_kv_heads * config.head_dim;
        let block_stride = config.block_size * token_stride;
        let pool_size = config.max_blocks * block_stride;

        let mut k_pools = Vec::with_capacity(config.num_layers);
        let mut v_pools = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            // SAFETY: `alloc` returns uninitialized device memory. Every bit
            // pattern is a valid `f16`, so no invalid value can be produced;
            // consumers of the pools must only rely on slots they have written.
            k_pools.push(unsafe { stream.alloc::<half::f16>(pool_size)? });
            // SAFETY: same reasoning as for the key pool above.
            v_pools.push(unsafe { stream.alloc::<half::f16>(pool_size)? });
        }

        // The size reported here is that of a single per-layer pool (one K or
        // one V buffer); `memory_bytes` gives the total across all pools.
        tracing::info!(
            "GpuPagedKvPool allocated: {} layers × {} blocks × {} tok/block = {:.1}MB per K/V",
            config.num_layers,
            config.max_blocks,
            config.block_size,
            (pool_size * std::mem::size_of::<half::f16>()) as f64 / (1024.0 * 1024.0),
        );

        Ok(Self {
            k_pools,
            v_pools,
            config,
            block_stride,
            token_stride,
            stream,
        })
    }

    /// Element range occupied by token `slot` of `physical_block` inside a
    /// per-layer pool.
    ///
    /// In debug builds this checks `slot < block_size`; a larger slot would
    /// address a token of the following block instead.
    fn token_range(&self, physical_block: usize, slot: usize) -> std::ops::Range<usize> {
        debug_assert!(
            slot < self.config.block_size,
            "slot {} is outside a block of {} tokens",
            slot,
            self.config.block_size
        );
        let start = physical_block * self.block_stride + slot * self.token_stride;
        start..start + self.token_stride
    }

    /// Copies one token's keys from `k_data` into slot `slot` of
    /// `physical_block` in the key pool of `layer`.
    ///
    /// The destination is the `num_kv_heads * head_dim` elements of that slot;
    /// the copy is issued on the pool's stream.
    ///
    /// # Errors
    ///
    /// Returns the driver error reported by the device-to-device copy.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is out of range or the addressed slot lies outside
    /// the pool. Debug builds additionally panic if `slot >= block_size`.
    pub fn write_k_token(
        &mut self,
        layer: usize,
        physical_block: usize,
        slot: usize,
        k_data: &cudarc::driver::CudaView<half::f16>,
    ) -> Result<(), cudarc::driver::DriverError> {
        let range = self.token_range(physical_block, slot);
        let mut dst = self.k_pools[layer].slice_mut(range);
        self.stream.memcpy_dtod(k_data, &mut dst)
    }

    /// Copies one token's values from `v_data` into slot `slot` of
    /// `physical_block` in the value pool of `layer`.
    ///
    /// The destination is the `num_kv_heads * head_dim` elements of that slot;
    /// the copy is issued on the pool's stream.
    ///
    /// # Errors
    ///
    /// Returns the driver error reported by the device-to-device copy.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is out of range or the addressed slot lies outside
    /// the pool. Debug builds additionally panic if `slot >= block_size`.
    pub fn write_v_token(
        &mut self,
        layer: usize,
        physical_block: usize,
        slot: usize,
        v_data: &cudarc::driver::CudaView<half::f16>,
    ) -> Result<(), cudarc::driver::DriverError> {
        let range = self.token_range(physical_block, slot);
        let mut dst = self.v_pools[layer].slice_mut(range);
        self.stream.memcpy_dtod(v_data, &mut dst)
    }

    /// Scatters the first `seq_len` tokens of two contiguous buffers into the
    /// paged pools of `layer`.
    ///
    /// The contiguous buffers are read token by token, `num_kv_heads * head_dim`
    /// elements each. Logical block `i` (tokens `i * block_size ..`) is written
    /// to the start of physical block `block_ids[i]`; the last block may be
    /// partially filled. Entries of `block_ids` beyond the blocks needed for
    /// `seq_len` are ignored, so a padded block table is accepted. If
    /// `block_ids` is too short for `seq_len`, the remaining tokens are not
    /// copied.
    ///
    /// # Errors
    ///
    /// Returns the first driver error reported by a device-to-device copy;
    /// blocks copied before the failure stay written.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is out of range, if a used block id is negative, or
    /// if a source or destination range lies outside its buffer.
    pub fn copy_contiguous_to_paged(
        &mut self,
        layer: usize,
        k_contiguous: &CudaSlice<half::f16>,
        v_contiguous: &CudaSlice<half::f16>,
        seq_len: usize,
        block_ids: &[i32],
    ) -> Result<(), cudarc::driver::DriverError> {
        let bs = self.config.block_size;
        for (logical_block, &physical_block) in block_ids.iter().enumerate() {
            let first_token = logical_block * bs;
            // All tokens are placed; surplus table entries must not be touched.
            // Without this guard `seq_len - first_token` would underflow.
            if first_token >= seq_len {
                break;
            }
            assert!(
                physical_block >= 0,
                "negative physical block id {} at logical block {}",
                physical_block,
                logical_block
            );
            // Non-negative `i32` always fits in `usize` on supported targets.
            let phys = physical_block as usize;

            let tokens_in_block = (seq_len - first_token).min(bs);
            let elems = tokens_in_block * self.token_stride;
            let src_start = first_token * self.token_stride;
            let dst_offset = phys * self.block_stride;

            let k_src = k_contiguous.slice(src_start..src_start + elems);
            let mut k_dst = self.k_pools[layer].slice_mut(dst_offset..dst_offset + elems);
            self.stream.memcpy_dtod(&k_src, &mut k_dst)?;

            let v_src = v_contiguous.slice(src_start..src_start + elems);
            let mut v_dst = self.v_pools[layer].slice_mut(dst_offset..dst_offset + elems);
            self.stream.memcpy_dtod(&v_src, &mut v_dst)?;
        }
        Ok(())
    }

    /// Copies `block_ids` from host memory into a new device buffer created on
    /// the pool's stream.
    ///
    /// # Errors
    ///
    /// Returns the driver error reported by the allocation or the copy.
    pub fn upload_block_table(
        &self,
        block_ids: &[i32],
    ) -> Result<CudaSlice<i32>, cudarc::driver::DriverError> {
        self.stream.clone_htod(block_ids)
    }

    /// Key pool of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`.
    pub fn k_pool(&self, layer: usize) -> &CudaSlice<half::f16> {
        &self.k_pools[layer]
    }

    /// Value pool of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`.
    pub fn v_pool(&self, layer: usize) -> &CudaSlice<half::f16> {
        &self.v_pools[layer]
    }

    /// Number of token slots per block.
    pub fn block_size(&self) -> usize {
        self.config.block_size
    }

    /// Number of blocks in each per-layer pool.
    pub fn max_blocks(&self) -> usize {
        self.config.max_blocks
    }

    /// Total device memory, in bytes, of all key and value pools together.
    pub fn memory_bytes(&self) -> usize {
        let pool_elems = self.config.max_blocks * self.block_stride;
        pool_elems * std::mem::size_of::<half::f16>() * 2 * self.config.num_layers
    }
}