use crate::blocks::{BlockPool, BlockStorageConfig, PhysicalBlockId};
use crate::cache::prefix::{PrefixCache, PrefixCacheStats, PrefixId};
use async_trait::async_trait;
use ferrum_interfaces::{
kv_cache::{AllocationRequest, BlockTable, CacheGcStats, CacheManagerStats, MemoryPressure},
KvCacheHandle, KvCacheManager, TensorRef,
};
use ferrum_types::{DataType, Device, FerrumError, RequestId, Result};
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info};
#[derive(Debug, Clone)]
pub struct PagedKvCacheConfig {
pub block_size: usize,
pub max_gpu_blocks: usize,
pub max_cpu_blocks: usize,
pub enable_cow: bool,
pub enable_swapping: bool,
pub low_watermark: f32,
pub high_watermark: f32,
pub num_layers: usize,
pub num_heads: usize,
pub head_dim: usize,
pub enable_prefix_cache: bool,
pub max_prefixes: usize,
pub min_prefix_length: usize,
}
impl Default for PagedKvCacheConfig {
fn default() -> Self {
Self {
block_size: 16,
max_gpu_blocks: 1024,
max_cpu_blocks: 512,
enable_cow: true,
enable_swapping: true,
low_watermark: 0.3,
high_watermark: 0.1,
num_layers: 32,
num_heads: 32,
head_dim: 128,
enable_prefix_cache: true,
max_prefixes: 100,
min_prefix_length: 16,
}
}
}
#[derive(Debug)]
pub struct PagedKvCacheHandle {
request_id: RequestId,
device: Device,
block_table: RwLock<BlockTable>,
num_tokens: RwLock<usize>,
num_layers: usize,
num_heads: usize,
head_dim: usize,
block_size: usize,
last_access: RwLock<Instant>,
has_cow_refs: RwLock<bool>,
ref_count: AtomicU64,
}
impl PagedKvCacheHandle {
pub fn new(
request_id: RequestId,
device: Device,
block_size: usize,
num_layers: usize,
num_heads: usize,
head_dim: usize,
) -> Self {
Self {
request_id,
device,
block_table: RwLock::new(BlockTable::new(block_size)),
num_tokens: RwLock::new(0),
num_layers,
num_heads,
head_dim,
block_size,
last_access: RwLock::new(Instant::now()),
has_cow_refs: RwLock::new(false),
ref_count: AtomicU64::new(1),
}
}
pub fn add_block(&self, logical_id: u32, physical_id: u32) {
let mut table = self.block_table.write();
if logical_id as usize >= table.logical_to_physical.len() {
table
.logical_to_physical
.resize((logical_id + 1) as usize, 0);
}
table.logical_to_physical[logical_id as usize] = physical_id;
if physical_id as usize >= table.physical_blocks.len() {
table.physical_blocks.resize((physical_id + 1) as usize, 0);
}
table.physical_blocks[physical_id as usize] = 1;
*self.last_access.write() = Instant::now();
}
pub fn get_physical_block(&self, logical_id: u32) -> Option<u32> {
let table = self.block_table.read();
if (logical_id as usize) < table.logical_to_physical.len() {
let physical = table.logical_to_physical[logical_id as usize];
if physical > 0 {
Some(physical)
} else {
None
}
} else {
None
}
}
pub fn get_physical_blocks(&self) -> Vec<u32> {
let table = self.block_table.read();
table
.logical_to_physical
.iter()
.filter(|&&id| id > 0)
.copied()
.collect()
}
pub fn num_blocks(&self) -> usize {
let table = self.block_table.read();
table
.logical_to_physical
.iter()
.filter(|&&id| id > 0)
.count()
}
pub fn set_num_tokens(&self, tokens: usize) {
*self.num_tokens.write() = tokens;
let mut table = self.block_table.write();
table.sequence_length = tokens;
}
pub fn required_blocks(&self, num_tokens: usize) -> usize {
num_tokens.div_ceil(self.block_size)
}
pub fn add_ref(&self) {
self.ref_count.fetch_add(1, Ordering::Relaxed);
*self.has_cow_refs.write() = true;
}
pub fn remove_ref(&self) -> u64 {
self.ref_count.fetch_sub(1, Ordering::Relaxed)
}
pub fn ref_count(&self) -> u64 {
self.ref_count.load(Ordering::Relaxed)
}
pub fn is_cow(&self) -> bool {
*self.has_cow_refs.read()
}
}
impl KvCacheHandle for PagedKvCacheHandle {
fn block_table(&self) -> &BlockTable {
unsafe {
let ptr = self.block_table.data_ptr();
&*ptr
}
}
fn block_table_mut(&mut self) -> &mut BlockTable {
self.block_table.get_mut()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn device(&self) -> Device {
self.device.clone()
}
fn num_tokens(&self) -> usize {
*self.num_tokens.read()
}
fn num_layers(&self) -> usize {
self.num_layers
}
fn num_heads(&self) -> usize {
self.num_heads
}
fn head_dim(&self) -> usize {
self.head_dim
}
fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
Ok(None)
}
fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
Ok(None)
}
fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
self.add_ref();
Ok(Arc::new(PagedKvCacheHandle {
request_id: self.request_id.clone(),
device: self.device.clone(),
block_table: RwLock::new(self.block_table.read().clone()),
num_tokens: RwLock::new(*self.num_tokens.read()),
num_layers: self.num_layers,
num_heads: self.num_heads,
head_dim: self.head_dim,
block_size: self.block_size,
last_access: RwLock::new(Instant::now()),
has_cow_refs: RwLock::new(true),
ref_count: AtomicU64::new(1),
}))
}
fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
let tokens = *self.num_tokens.read();
let blocks = self.num_blocks();
let bytes_per_token = 2 * self.num_layers * self.num_heads * self.head_dim * 2;
ferrum_interfaces::kv_cache::CacheHandleStats {
memory_bytes: blocks * self.block_size * bytes_per_token,
blocks_allocated: blocks,
tokens_stored: tokens,
utilization: if blocks > 0 {
tokens as f32 / (blocks * self.block_size) as f32
} else {
0.0
},
last_access: *self.last_access.read(),
}
}
fn is_valid(&self) -> bool {
self.ref_count() > 0
}
fn cache_id(&self) -> String {
format!("paged-{}", self.request_id)
}
}
pub struct PagedKvCacheManager {
config: PagedKvCacheConfig,
gpu_pool: BlockPool,
cpu_pool: Option<BlockPool>,
active_handles: RwLock<HashMap<RequestId, Arc<PagedKvCacheHandle>>>,
block_to_request: RwLock<HashMap<PhysicalBlockId, RequestId>>,
swapped_blocks: RwLock<HashMap<PhysicalBlockId, PhysicalBlockId>>,
prefix_cache: Option<PrefixCache>,
stats: Mutex<CacheManagerStats>,
#[allow(clippy::type_complexity)]
pressure_callback: Mutex<Option<Box<dyn Fn(MemoryPressure) + Send + Sync>>>,
}
impl PagedKvCacheManager {
pub fn new(device: Device, config: PagedKvCacheConfig) -> Result<Self> {
info!(
"Creating paged KV cache manager: device={:?}, block_size={}, max_gpu_blocks={}, max_cpu_blocks={}, prefix_cache={}",
device, config.block_size, config.max_gpu_blocks, config.max_cpu_blocks, config.enable_prefix_cache
);
let storage_config = BlockStorageConfig {
num_layers: config.num_layers,
num_kv_heads: config.num_heads,
head_dim: config.head_dim,
block_size: config.block_size,
};
let gpu_pool = BlockPool::new_with_storage(
device.clone(),
config.block_size,
DataType::FP16,
config.max_gpu_blocks,
storage_config,
)?;
let cpu_pool = if config.enable_swapping {
Some(BlockPool::new_with_storage(
Device::CPU,
config.block_size,
DataType::FP16,
config.max_cpu_blocks,
storage_config,
)?)
} else {
None
};
let prefix_cache = if config.enable_prefix_cache {
Some(PrefixCache::new(
config.max_prefixes,
config.min_prefix_length,
))
} else {
None
};
Ok(Self {
config,
gpu_pool,
cpu_pool,
active_handles: RwLock::new(HashMap::new()),
block_to_request: RwLock::new(HashMap::new()),
swapped_blocks: RwLock::new(HashMap::new()),
prefix_cache,
stats: Mutex::new(CacheManagerStats {
total_memory_bytes: 0,
used_memory_bytes: 0,
active_caches: 0,
total_blocks: 0,
free_blocks: 0,
cache_hit_rate: 0.0,
eviction_count: 0,
allocation_count: 0,
allocation_failures: 0,
}),
pressure_callback: Mutex::new(None),
})
}
pub fn with_defaults(device: Device, block_size: usize, max_blocks: usize) -> Result<Self> {
let config = PagedKvCacheConfig {
block_size,
max_gpu_blocks: max_blocks,
max_cpu_blocks: max_blocks / 2,
..Default::default()
};
Self::new(device, config)
}
pub fn allocate_blocks(
&self,
handle: &PagedKvCacheHandle,
num_blocks: usize,
) -> Result<Vec<PhysicalBlockId>> {
let mut allocated = Vec::with_capacity(num_blocks);
let current_blocks = handle.num_blocks();
for i in 0..num_blocks {
let allocation = self.gpu_pool.allocate()?;
let physical_id = allocation.physical_id;
let logical_id = (current_blocks + i) as u32;
handle.add_block(logical_id, physical_id.0);
self.block_to_request
.write()
.insert(physical_id, handle.request_id.clone());
allocated.push(physical_id);
}
{
let mut stats = self.stats.lock();
stats.allocation_count += num_blocks as u64;
}
debug!(
"Allocated {} blocks for request {}: {:?}",
num_blocks, handle.request_id, allocated
);
Ok(allocated)
}
pub fn free_blocks(&self, block_ids: &[PhysicalBlockId]) -> Result<()> {
for &block_id in block_ids {
self.gpu_pool.deallocate(block_id)?;
self.block_to_request.write().remove(&block_id);
}
debug!("Freed {} blocks", block_ids.len());
Ok(())
}
pub fn write_kv(
&self,
handle: &PagedKvCacheHandle,
layer: usize,
token_position: usize,
key: &[f32],
value: &[f32],
) -> Result<()> {
let block_size = self.config.block_size;
let logical_block = token_position / block_size;
let slot = token_position % block_size;
let physical_id = handle
.get_physical_block(logical_block as u32)
.ok_or_else(|| {
FerrumError::internal(format!(
"No physical block for logical block {} (token {})",
logical_block, token_position
))
})?;
self.gpu_pool
.write_kv_slot(PhysicalBlockId::new(physical_id), layer, slot, key, value)
}
pub fn read_kv(
&self,
handle: &PagedKvCacheHandle,
layer: usize,
start_token: usize,
num_tokens: usize,
) -> Result<(Vec<f32>, Vec<f32>)> {
let block_size = self.config.block_size;
let kv_size = self.config.num_heads * self.config.head_dim;
let mut keys = Vec::with_capacity(num_tokens * kv_size);
let mut values = Vec::with_capacity(num_tokens * kv_size);
for pos in start_token..start_token + num_tokens {
let logical_block = pos / block_size;
let slot = pos % block_size;
let physical_id = handle
.get_physical_block(logical_block as u32)
.ok_or_else(|| {
FerrumError::internal(format!(
"No physical block for logical block {} (token {})",
logical_block, pos
))
})?;
let (k, v) =
self.gpu_pool
.read_kv_slot(PhysicalBlockId::new(physical_id), layer, slot)?;
keys.extend_from_slice(&k);
values.extend_from_slice(&v);
}
Ok((keys, values))
}
pub fn gpu_pool(&self) -> &BlockPool {
&self.gpu_pool
}
pub fn prefix_cache(&self) -> Option<&PrefixCache> {
self.prefix_cache.as_ref()
}
pub fn share_prefix_blocks(
&self,
source: &PagedKvCacheHandle,
target: &PagedKvCacheHandle,
num_prefix_blocks: usize,
) -> Result<()> {
let source_blocks = source.get_physical_blocks();
let n = num_prefix_blocks.min(source_blocks.len());
for i in 0..n {
let phys_id = source_blocks[i];
target.add_block(i as u32, phys_id);
let pid = PhysicalBlockId::new(phys_id);
if let Some(block) = self.gpu_pool.get_block(pid) {
block.write().add_ref();
}
}
debug!(
"Shared {} prefix blocks from {} to {}",
n, source.request_id, target.request_id
);
Ok(())
}
pub fn swap_out(&self, block_ids: &[PhysicalBlockId]) -> Result<Vec<PhysicalBlockId>> {
let cpu_pool = self
.cpu_pool
.as_ref()
.ok_or_else(|| FerrumError::unsupported("Swapping not enabled"))?;
let mut swapped = Vec::with_capacity(block_ids.len());
let mut swap_map = self.swapped_blocks.write();
for &gpu_block in block_ids {
let cpu_allocation = cpu_pool.allocate()?;
let cpu_block = cpu_allocation.physical_id;
swap_map.insert(gpu_block, cpu_block);
swapped.push(cpu_block);
self.gpu_pool.deallocate(gpu_block)?;
}
debug!("Swapped out {} blocks to CPU", swapped.len());
Ok(swapped)
}
pub fn swap_in(&self, cpu_block_ids: &[PhysicalBlockId]) -> Result<Vec<PhysicalBlockId>> {
let cpu_pool = self
.cpu_pool
.as_ref()
.ok_or_else(|| FerrumError::unsupported("Swapping not enabled"))?;
let mut swapped = Vec::with_capacity(cpu_block_ids.len());
let mut swap_map = self.swapped_blocks.write();
for &cpu_block in cpu_block_ids {
let gpu_allocation = self.gpu_pool.allocate()?;
let gpu_block = gpu_allocation.physical_id;
let gpu_original = swap_map
.iter()
.find(|(_, &cpu)| cpu == cpu_block)
.map(|(&gpu, _)| gpu);
if let Some(orig_gpu) = gpu_original {
swap_map.remove(&orig_gpu);
}
swapped.push(gpu_block);
cpu_pool.deallocate(cpu_block)?;
}
debug!("Swapped in {} blocks from CPU", swapped.len());
Ok(swapped)
}
pub fn check_pressure(&self) -> MemoryPressure {
let gpu_stats = self.gpu_pool.stats();
let free_ratio = gpu_stats.free_blocks as f32 / gpu_stats.max_blocks.max(1) as f32;
if free_ratio < self.config.high_watermark {
MemoryPressure::Critical
} else if free_ratio < self.config.low_watermark {
MemoryPressure::High
} else {
MemoryPressure::Low
}
}
fn notify_pressure(&self, pressure: MemoryPressure) {
if let Some(ref callback) = *self.pressure_callback.lock() {
callback(pressure);
}
}
pub fn free_block_count(&self) -> usize {
self.gpu_pool.stats().free_blocks
}
pub fn total_blocks(&self) -> usize {
self.gpu_pool.stats().total_blocks
}
pub fn cow_copy(&self, handle: &PagedKvCacheHandle, block_ids: &[u32]) -> Result<Vec<u32>> {
if !self.config.enable_cow {
return Err(FerrumError::unsupported("COW not enabled"));
}
let mut new_blocks = Vec::with_capacity(block_ids.len());
for &_old_physical in block_ids {
let allocation = self.gpu_pool.allocate()?;
let new_physical = allocation.physical_id;
new_blocks.push(new_physical.0);
self.block_to_request
.write()
.insert(new_physical, handle.request_id.clone());
}
debug!("COW copied {} blocks", new_blocks.len());
Ok(new_blocks)
}
pub fn find_prefix(
&self,
tokens: &[ferrum_types::TokenId],
) -> Option<(
PrefixId,
Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
Vec<f32>,
usize,
)> {
let prefix_cache = self.prefix_cache.as_ref()?;
if let Some((prefix_id, kv_handle, last_logits)) = prefix_cache.find_prefix(tokens) {
let matched_len = prefix_id.len();
debug!("Prefix cache hit: matched {} tokens", matched_len);
{
let mut stats = self.stats.lock();
let total = stats.allocation_count as f32;
if total > 0.0 {
stats.cache_hit_rate = (stats.cache_hit_rate * (total - 1.0) + 1.0) / total;
}
}
Some((prefix_id, kv_handle, last_logits, matched_len))
} else {
None
}
}
pub fn store_prefix(
&self,
tokens: &[ferrum_types::TokenId],
kv_handle: Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
last_logits: Vec<f32>,
) -> Result<()> {
if let Some(prefix_cache) = &self.prefix_cache {
prefix_cache.store_prefix(tokens, kv_handle, last_logits)?;
debug!("Stored prefix with {} tokens in cache", tokens.len());
}
Ok(())
}
pub fn prefix_cache_stats(&self) -> Option<PrefixCacheStats> {
self.prefix_cache.as_ref().map(|pc| pc.stats())
}
pub fn evict_prefixes(&self, count: usize) -> usize {
if let Some(prefix_cache) = &self.prefix_cache {
let evicted = prefix_cache.evict_n(count);
if evicted > 0 {
debug!("Evicted {} prefixes from cache", evicted);
}
evicted
} else {
0
}
}
pub fn clear_prefix_cache(&self) {
if let Some(prefix_cache) = &self.prefix_cache {
prefix_cache.clear();
debug!("Cleared prefix cache");
}
}
}
#[async_trait]
impl KvCacheManager for PagedKvCacheManager {
async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
debug!(
"Allocating paged KV cache for request: {:?}",
request.request_id
);
let pressure = self.check_pressure();
if matches!(pressure, MemoryPressure::Critical) {
self.notify_pressure(pressure);
let _ = self.gc().await;
}
let handle = Arc::new(PagedKvCacheHandle::new(
request.request_id.clone(),
request.device.clone(),
self.config.block_size,
request.num_layers,
request.num_heads,
request.head_dim,
));
let initial_blocks = handle.required_blocks(request.initial_tokens);
if initial_blocks > 0 {
self.allocate_blocks(&handle, initial_blocks)?;
}
handle.set_num_tokens(request.initial_tokens);
self.active_handles
.write()
.insert(request.request_id.clone(), handle.clone());
{
let mut stats = self.stats.lock();
stats.active_caches += 1;
stats.allocation_count += 1;
}
Ok(handle)
}
async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()> {
let paged_handle = handle
.as_any()
.downcast_ref::<PagedKvCacheHandle>()
.ok_or_else(|| FerrumError::internal("Invalid handle type"))?;
let current_tokens = paged_handle.num_tokens();
let new_tokens = current_tokens + additional_tokens;
let current_blocks = paged_handle.num_blocks();
let required_blocks = paged_handle.required_blocks(new_tokens);
if required_blocks > current_blocks {
let new_blocks = required_blocks - current_blocks;
if paged_handle.is_cow() && paged_handle.ref_count() > 1 {
let existing = paged_handle.get_physical_blocks();
let _new_physical = self.cow_copy(paged_handle, &existing)?;
}
self.allocate_blocks(paged_handle, new_blocks)?;
}
paged_handle.set_num_tokens(new_tokens);
debug!(
"Extended KV cache for {}: {} -> {} tokens",
paged_handle.request_id, current_tokens, new_tokens
);
Ok(())
}
async fn deallocate(&self, request_id: RequestId) -> Result<()> {
debug!("Deallocating paged KV cache for request: {:?}", request_id);
let handle = self.active_handles.write().remove(&request_id);
if let Some(handle) = handle {
if handle.ref_count() > 1 {
handle.remove_ref();
debug!(
"Decremented ref count for {}, remaining: {}",
request_id,
handle.ref_count()
);
return Ok(());
}
let block_ids: Vec<PhysicalBlockId> = handle
.get_physical_blocks()
.into_iter()
.map(PhysicalBlockId)
.collect();
for block_id in block_ids {
let _ = self.gpu_pool.deallocate(block_id);
self.block_to_request.write().remove(&block_id);
}
{
let mut stats = self.stats.lock();
if stats.active_caches > 0 {
stats.active_caches -= 1;
}
}
}
Ok(())
}
fn can_allocate(&self, request: &AllocationRequest) -> bool {
let required_blocks = request.initial_tokens.div_ceil(self.config.block_size);
let gpu_stats = self.gpu_pool.stats();
gpu_stats.free_blocks >= required_blocks
|| gpu_stats.total_blocks + required_blocks <= gpu_stats.max_blocks
}
fn stats(&self) -> CacheManagerStats {
let gpu_stats = self.gpu_pool.stats();
let mut stats = self.stats.lock().clone();
stats.total_blocks = gpu_stats.max_blocks;
stats.free_blocks = gpu_stats.free_blocks;
let bytes_per_block = self.config.block_size
* 2 * self.config.num_layers
* self.config.num_heads
* self.config.head_dim
* 2;
stats.total_memory_bytes = gpu_stats.max_blocks * bytes_per_block;
stats.used_memory_bytes = gpu_stats.allocated_blocks * bytes_per_block;
stats
}
async fn gc(&self) -> Result<CacheGcStats> {
let start = Instant::now();
let evicted = self.gpu_pool.evict_blocks(10)?;
{
let mut stats = self.stats.lock();
stats.eviction_count += evicted.len() as u64;
}
Ok(CacheGcStats {
memory_freed: evicted.len() * self.config.block_size * 1024, caches_freed: 0,
gc_time_ms: start.elapsed().as_millis() as u64,
})
}
fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
*self.pressure_callback.lock() = Some(callback);
}
fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
self.active_handles
.read()
.get(&request_id)
.map(|h| h.clone() as Arc<dyn KvCacheHandle>)
}
fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
self.active_handles
.read()
.iter()
.map(|(id, handle)| (id.clone(), handle.clone() as Arc<dyn KvCacheHandle>))
.collect()
}
}
impl std::fmt::Debug for PagedKvCacheManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let gpu_stats = self.gpu_pool.stats();
f.debug_struct("PagedKvCacheManager")
.field("block_size", &self.config.block_size)
.field("total_gpu_blocks", &gpu_stats.total_blocks)
.field("free_gpu_blocks", &gpu_stats.free_blocks)
.field("active_handles", &self.active_handles.read().len())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_request() -> AllocationRequest {
AllocationRequest {
request_id: RequestId::new(),
initial_tokens: 64,
max_sequence_length: 2048,
num_layers: 32,
num_heads: 32,
head_dim: 128,
device: Device::CPU,
dtype: DataType::FP16,
priority: ferrum_types::Priority::Normal,
}
}
#[tokio::test]
async fn test_manager_creation() {
let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100);
assert!(manager.is_ok());
}
#[tokio::test]
async fn test_allocate_and_deallocate() {
let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
let request = create_test_request();
let request_id = request.request_id.clone();
let handle = manager.allocate(&request).await.unwrap();
assert!(handle.is_valid());
assert_eq!(handle.num_tokens(), 64);
let stats = handle.stats();
assert!(stats.blocks_allocated >= 1 || stats.tokens_stored >= 64);
manager.deallocate(request_id).await.unwrap();
}
#[tokio::test]
async fn test_extend() {
let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
let request = create_test_request();
let request_id = request.request_id.clone();
let handle = manager.allocate(&request).await.unwrap();
let initial_blocks = handle.stats().blocks_allocated;
let paged_handle = manager.get_handle(request_id.clone()).unwrap();
let paged_ref = paged_handle
.as_any()
.downcast_ref::<PagedKvCacheHandle>()
.unwrap();
manager.allocate_blocks(paged_ref, 4).unwrap();
let new_blocks = handle.stats().blocks_allocated;
assert!(new_blocks > initial_blocks);
manager.deallocate(request_id).await.unwrap();
}
#[tokio::test]
async fn test_can_allocate() {
let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 10).unwrap();
let request = create_test_request();
assert!(manager.can_allocate(&request));
for _ in 0..8 {
let req = create_test_request();
let _ = manager.allocate(&req).await;
}
let stats = manager.stats();
assert!(stats.free_blocks < stats.total_blocks);
}
#[tokio::test]
async fn test_gc() {
let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
let request = create_test_request();
let request_id = request.request_id.clone();
let _ = manager.allocate(&request).await.unwrap();
manager.deallocate(request_id).await.unwrap();
let gc_stats = manager.gc().await.unwrap();
assert_eq!(gc_stats.caches_freed, 0);
}
#[test]
fn test_paged_handle() {
let handle = PagedKvCacheHandle::new(RequestId::new(), Device::CPU, 16, 32, 32, 128);
assert_eq!(handle.num_tokens(), 0);
assert_eq!(handle.num_blocks(), 0);
handle.add_block(0, 5);
handle.add_block(1, 10);
assert_eq!(handle.num_blocks(), 2);
assert_eq!(handle.get_physical_block(0), Some(5));
assert_eq!(handle.get_physical_block(1), Some(10));
}
#[tokio::test]
async fn test_write_read_kv_across_blocks() {
let config = PagedKvCacheConfig {
block_size: 4,
max_gpu_blocks: 16,
max_cpu_blocks: 0,
enable_cow: false,
enable_swapping: false,
num_layers: 2,
num_heads: 2,
head_dim: 4,
enable_prefix_cache: false,
..Default::default()
};
let manager = PagedKvCacheManager::new(Device::CPU, config).unwrap();
let request = AllocationRequest {
request_id: RequestId::new(),
initial_tokens: 6, max_sequence_length: 32,
num_layers: 2,
num_heads: 2,
head_dim: 4,
device: Device::CPU,
dtype: DataType::FP16,
priority: ferrum_types::Priority::Normal,
};
let request_id = request.request_id.clone();
let handle_dyn = manager.allocate(&request).await.unwrap();
let handle = handle_dyn
.as_any()
.downcast_ref::<PagedKvCacheHandle>()
.unwrap();
let kv_size = 2 * 4;
for pos in 0..6 {
let key: Vec<f32> = (0..kv_size).map(|i| (pos * 100 + i) as f32).collect();
let val: Vec<f32> = (0..kv_size).map(|i| (pos * 100 + i + 50) as f32).collect();
manager.write_kv(handle, 0, pos, &key, &val).unwrap();
}
let (keys, vals) = manager.read_kv(handle, 0, 0, 6).unwrap();
assert_eq!(keys.len(), 6 * kv_size);
assert_eq!(vals.len(), 6 * kv_size);
assert_eq!(keys[0], 0.0);
assert_eq!(keys[kv_size - 1], 7.0);
assert_eq!(keys[4 * kv_size], 400.0);
assert_eq!(vals[5 * kv_size], 550.0);
let (k1, _) = manager.read_kv(handle, 1, 0, 1).unwrap();
assert!(k1.iter().all(|&x| x == 0.0));
manager.deallocate(request_id).await.unwrap();
}
#[test]
fn test_required_blocks() {
let handle = PagedKvCacheHandle::new(
RequestId::new(),
Device::CPU,
16, 32,
32,
128,
);
assert_eq!(handle.required_blocks(0), 0);
assert_eq!(handle.required_blocks(16), 1);
assert_eq!(handle.required_blocks(17), 2);
assert_eq!(handle.required_blocks(32), 2);
assert_eq!(handle.required_blocks(33), 3);
}
}