use async_trait::async_trait;
use ferrum_interfaces::{
kv_cache::{
AllocationRequest, CacheGcStats, CacheHandleStats, CacheManagerStats, MemoryPressure,
},
BlockTable, KvCacheHandle, KvCacheManager, TensorRef,
};
use ferrum_types::{Device, RequestId, Result};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Mock KV-cache handle that only tracks metadata.
///
/// It stores no tensors: its `key_cache` and `value_cache` implementations
/// always return `Ok(None)`.
#[derive(Debug)]
pub struct MockKvCacheHandle {
/// Request this handle was created for; also embedded in `cache_id()`.
request_id: RequestId,
/// Block table populated by `new` with sequential block ids.
block_table: BlockTable,
/// Layer count supplied by the caller of `new`.
num_layers: usize,
/// Head count reported by `num_heads()`; `new` fixes it to 12.
num_heads: usize,
/// Per-head dimension reported by `head_dim()`; `new` fixes it to 64.
head_dim: usize,
/// Device reported by `device()`; `new` sets it to `Device::CPU`.
device: Device,
}
impl MockKvCacheHandle {
    /// Builds a mock handle for `request_id` covering `seq_len` tokens.
    ///
    /// The block table is created with a block size of 16 and filled with
    /// the sequential block ids `0..n`, where `n` is the value returned by
    /// `BlockTable::blocks_needed_for_length(seq_len, 16)`. The head count,
    /// head dimension and device are fixed to 12, 64 and `Device::CPU`.
    pub fn new(request_id: RequestId, num_layers: usize, seq_len: usize) -> Self {
        // Work out how many blocks the sequence occupies and give them
        // consecutive ids starting at zero.
        let block_count = BlockTable::blocks_needed_for_length(seq_len, 16);
        let sequential_ids = (0..block_count as u32).collect::<Vec<u32>>();

        let mut table = BlockTable::new(16);
        table.sequence_length = seq_len;
        table.add_blocks(&sequential_ids);

        Self {
            request_id,
            block_table: table,
            num_layers,
            num_heads: 12,
            head_dim: 64,
            device: Device::CPU,
        }
    }
}
impl KvCacheHandle for MockKvCacheHandle {
    /// Shared view of the handle's block table.
    fn block_table(&self) -> &BlockTable {
        &self.block_table
    }

    /// Mutable view of the handle's block table.
    fn block_table_mut(&mut self) -> &mut BlockTable {
        &mut self.block_table
    }

    /// Exposes the handle as `Any` so callers can downcast to the mock type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Returns a clone of the stored device.
    fn device(&self) -> Device {
        self.device.clone()
    }

    /// Layer count given at construction.
    fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Stored head count.
    fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Stored per-head dimension.
    fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// The mock holds no key tensors, so this is always `Ok(None)`.
    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
        Ok(None)
    }

    /// The mock holds no value tensors, so this is always `Ok(None)`.
    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
        Ok(None)
    }

    /// Produces an independent handle carrying a field-by-field copy of this one.
    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
        let duplicate: Arc<dyn KvCacheHandle> = Arc::new(Self {
            request_id: self.request_id.clone(),
            block_table: self.block_table.clone(),
            num_layers: self.num_layers,
            num_heads: self.num_heads,
            head_dim: self.head_dim,
            device: self.device.clone(),
        });
        Ok(duplicate)
    }

    /// Derives statistics from the block table and the stored dimensions.
    ///
    /// `last_access` is stamped with the current instant on every call.
    fn stats(&self) -> CacheHandleStats {
        let table = &self.block_table;
        let blocks = table.num_blocks();
        let tokens = table.sequence_length;

        // Token slots across all blocks; 16 is the block size `new` passes
        // to `BlockTable::new`.
        let slot_capacity = blocks * 16;

        // An empty table reports zero utilization rather than dividing by zero.
        let utilization = match blocks {
            0 => 0.0,
            _ => tokens as f32 / slot_capacity as f32,
        };

        CacheHandleStats {
            // NOTE(review): the trailing `* 2` presumably accounts for keys
            // plus values — confirm against the real cache implementation.
            memory_bytes: slot_capacity * self.num_layers * self.num_heads * self.head_dim * 2,
            blocks_allocated: blocks,
            tokens_stored: tokens,
            utilization,
            last_access: std::time::Instant::now(),
        }
    }

    /// The mock handle never becomes invalid.
    fn is_valid(&self) -> bool {
        true
    }

    /// Identifier of the form `mock_<request_id>`.
    fn cache_id(&self) -> String {
        format!("mock_{}", self.request_id)
    }
}
/// Mock KV-cache manager that does block accounting only.
///
/// It tracks one handle per request id and compares the blocks those handles
/// report against a fixed block budget.
pub struct MockKvCacheManager {
/// Handles currently tracked, keyed by request id.
handles: RwLock<HashMap<RequestId, Arc<dyn KvCacheHandle>>>,
/// Block budget that `allocate` and `can_allocate` check against.
total_blocks: usize,
/// Block size passed to `BlockTable::blocks_needed_for_length`; `new` sets it to 16.
block_size: usize,
/// Counter bumped by `allocate` after a handle has been stored.
allocation_count: AtomicU64,
/// Counter bumped by `deallocate`.
deallocation_count: AtomicU64,
}
impl MockKvCacheManager {
    /// Creates a manager with no tracked handles and a budget of
    /// `total_blocks` blocks. The block size is fixed to 16 and both
    /// counters start at zero.
    pub fn new(total_blocks: usize) -> Self {
        let handles = RwLock::new(HashMap::new());
        let allocation_count = AtomicU64::new(0);
        let deallocation_count = AtomicU64::new(0);
        Self {
            handles,
            total_blocks,
            block_size: 16,
            allocation_count,
            deallocation_count,
        }
    }

    /// Number of handles currently tracked by the manager.
    pub fn active_count(&self) -> usize {
        let tracked = self.handles.read();
        tracked.len()
    }
}
#[async_trait]
impl KvCacheManager for MockKvCacheManager {
async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
let blocks_needed =
BlockTable::blocks_needed_for_length(request.initial_tokens, self.block_size);
let used_blocks: usize = self
.handles
.read()
.values()
.map(|h| h.block_table().num_blocks())
.sum();
if used_blocks + blocks_needed > self.total_blocks {
return Err(ferrum_types::FerrumError::backend(format!(
"OOM: need {} blocks, have {} free out of {}",
blocks_needed,
self.total_blocks - used_blocks,
self.total_blocks
)));
}
let handle: Arc<dyn KvCacheHandle> = Arc::new(MockKvCacheHandle::new(
request.request_id.clone(),
request.num_layers,
request.initial_tokens,
));
self.handles
.write()
.insert(request.request_id.clone(), handle.clone());
self.allocation_count.fetch_add(1, Ordering::Relaxed);
Ok(handle)
}
async fn extend(
&self,
_handle: &mut dyn KvCacheHandle,
_additional_tokens: usize,
) -> Result<()> {
Ok(())
}
async fn deallocate(&self, request_id: RequestId) -> Result<()> {
self.handles.write().remove(&request_id);
self.deallocation_count.fetch_add(1, Ordering::Relaxed);
Ok(())
}
fn can_allocate(&self, request: &AllocationRequest) -> bool {
let blocks_needed =
BlockTable::blocks_needed_for_length(request.initial_tokens, self.block_size);
let used_blocks: usize = self
.handles
.read()
.values()
.map(|h| h.block_table().num_blocks())
.sum();
used_blocks + blocks_needed <= self.total_blocks
}
fn stats(&self) -> CacheManagerStats {
let handles = self.handles.read();
let used_blocks: usize = handles.values().map(|h| h.block_table().num_blocks()).sum();
CacheManagerStats {
total_memory_bytes: self.total_blocks * self.block_size * 1024,
used_memory_bytes: used_blocks * self.block_size * 1024,
active_caches: handles.len(),
total_blocks: self.total_blocks,
free_blocks: self.total_blocks - used_blocks,
cache_hit_rate: 0.0,
eviction_count: 0,
allocation_count: self.allocation_count.load(Ordering::Relaxed),
allocation_failures: 0,
}
}
async fn gc(&self) -> Result<CacheGcStats> {
Ok(CacheGcStats {
memory_freed: 0,
caches_freed: 0,
gc_time_ms: 0,
})
}
fn set_pressure_callback(&self, _callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
}
fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
self.handles.read().get(&request_id).cloned()
}
fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
self.handles
.read()
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
}