use crate::TensorRef;
use ferrum_types::{BlockId, Device, RequestId, Result};
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::{collections::HashMap, sync::Arc};
/// Bookkeeping for the blocks that back one cache.
///
/// Records which physical blocks are assigned, a per-logical-block index
/// vector, the number of tokens currently stored, and the per-block capacity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockTable {
/// Identifiers of the assigned physical blocks, in the order `add_blocks` appended them.
pub physical_blocks: SmallVec<[BlockId; 8]>,
/// One entry per added block; `add_blocks` appends consecutive indices that
/// continue from this vector's previous length.
pub logical_to_physical: SmallVec<[u32; 8]>,
/// Number of tokens currently recorded as stored (advanced by `extend_sequence`).
pub sequence_length: usize,
/// Number of tokens one block holds; used as the divisor when computing block counts.
pub block_size: usize,
}
impl BlockTable {
    /// Creates an empty table whose blocks each hold `block_size` tokens.
    ///
    /// No blocks are assigned and the sequence length starts at zero.
    pub fn new(block_size: usize) -> Self {
        Self {
            physical_blocks: SmallVec::new(),
            logical_to_physical: SmallVec::new(),
            sequence_length: 0,
            block_size,
        }
    }

    /// Returns the number of physical blocks currently assigned.
    pub fn num_blocks(&self) -> usize {
        self.physical_blocks.len()
    }

    /// Returns how many blocks of `block_size` tokens are needed to hold
    /// `length` tokens, rounding up.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero.
    pub fn blocks_needed_for_length(length: usize, block_size: usize) -> usize {
        // `div_ceil` rounds up without forming `length + block_size - 1`,
        // which overflows for lengths close to `usize::MAX`.
        length.div_ceil(block_size)
    }

    /// Returns `true` when at least one assigned block is not needed to hold
    /// the current sequence length.
    ///
    /// Spare room inside a partially filled last block does not count here;
    /// [`free_tokens`](Self::free_tokens) reports token-level headroom.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero.
    pub fn has_free_space(&self) -> bool {
        let used_blocks = Self::blocks_needed_for_length(self.sequence_length, self.block_size);
        used_blocks < self.num_blocks()
    }

    /// Returns how many more tokens fit into the blocks already assigned.
    ///
    /// The result is zero when no blocks are assigned. Because the fields are
    /// public and may be set inconsistently, the arithmetic saturates: a
    /// `sequence_length` larger than the assigned capacity yields zero
    /// instead of underflowing.
    pub fn free_tokens(&self) -> usize {
        let capacity = self.num_blocks().saturating_mul(self.block_size);
        capacity.saturating_sub(self.sequence_length)
    }

    /// Appends `blocks` to the table.
    ///
    /// Each block is pushed onto `physical_blocks`, and a matching entry is
    /// pushed onto `logical_to_physical`, numbered consecutively from that
    /// vector's length before the call.
    pub fn add_blocks(&mut self, blocks: &[BlockId]) {
        let start_logical = self.logical_to_physical.len();
        self.physical_blocks.extend(blocks.iter().copied());
        // The `as u32` cast truncates if the index ever exceeds `u32::MAX`;
        // this keeps the original storage type and behavior.
        self.logical_to_physical
            .extend((start_logical..start_logical + blocks.len()).map(|index| index as u32));
    }

    /// Records `additional_tokens` more tokens in the sequence.
    ///
    /// The sequence length is only updated when the blocks already assigned
    /// can hold the new length.
    ///
    /// # Errors
    ///
    /// Returns a backend error when the new length would overflow `usize`, or
    /// when more blocks are required than are currently assigned.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero.
    pub fn extend_sequence(&mut self, additional_tokens: usize) -> Result<()> {
        let new_length = self
            .sequence_length
            .checked_add(additional_tokens)
            .ok_or_else(|| {
                ferrum_types::FerrumError::backend(format!(
                    "Sequence length overflow: {} + {}",
                    self.sequence_length, additional_tokens
                ))
            })?;
        let required_blocks = Self::blocks_needed_for_length(new_length, self.block_size);
        if required_blocks > self.num_blocks() {
            return Err(ferrum_types::FerrumError::backend(format!(
                "Insufficient blocks: need {}, have {}",
                required_blocks,
                self.num_blocks()
            )));
        }
        self.sequence_length = new_length;
        Ok(())
    }
}
/// Handle to one allocated key/value cache.
///
/// Implementors expose the block bookkeeping, the tensor geometry and the
/// per-layer key and value tensors. Handles must be `Send + Sync` and `Debug`.
pub trait KvCacheHandle: Send + Sync + std::fmt::Debug {
    /// Borrows the block table describing this cache.
    fn block_table(&self) -> &BlockTable;

    /// Mutably borrows the block table describing this cache.
    fn block_table_mut(&mut self) -> &mut BlockTable;

    /// Exposes the handle as `&dyn Any` so callers can attempt a downcast to
    /// the concrete implementation.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Returns the device associated with this cache.
    fn device(&self) -> Device;

    /// Returns the number of tokens stored.
    ///
    /// The default implementation reads `sequence_length` from the block table.
    fn num_tokens(&self) -> usize {
        let table = self.block_table();
        table.sequence_length
    }

    /// Returns the number of layers this cache covers.
    fn num_layers(&self) -> usize;

    /// Returns the number of heads per layer.
    fn num_heads(&self) -> usize;

    /// Returns the per-head dimension.
    fn head_dim(&self) -> usize;

    /// Returns the key tensor for `layer`.
    ///
    /// `Ok(None)` is a valid result; presumably it means nothing is cached
    /// for that layer yet (implementor-defined, confirm against implementations).
    fn key_cache(&self, layer: usize) -> Result<Option<TensorRef>>;

    /// Returns the value tensor for `layer`, with the same `Option` semantics
    /// as [`key_cache`](Self::key_cache).
    fn value_cache(&self, layer: usize) -> Result<Option<TensorRef>>;

    /// Returns the key and value tensors for `layer` as a `(key, value)` pair.
    ///
    /// The default implementation calls [`key_cache`](Self::key_cache) and
    /// then [`value_cache`](Self::value_cache), returning the first error.
    fn kv_cache(&self, layer: usize) -> Result<(Option<TensorRef>, Option<TensorRef>)> {
        let keys = self.key_cache(layer)?;
        let values = self.value_cache(layer)?;
        Ok((keys, values))
    }

    /// Produces another shared handle for this cache.
    ///
    /// Whether the underlying data is copied or shared is implementor-defined.
    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>>;

    /// Returns usage statistics for this cache.
    fn stats(&self) -> CacheHandleStats;

    /// Reports whether the handle is still usable; the exact validity
    /// criteria are implementor-defined.
    fn is_valid(&self) -> bool;

    /// Returns an identifier string for this cache.
    fn cache_id(&self) -> String;
}
/// Usage statistics reported by a single cache handle (see `KvCacheHandle::stats`).
#[derive(Debug, Clone)]
pub struct CacheHandleStats {
/// Memory attributed to the cache, in bytes.
pub memory_bytes: usize,
/// Number of blocks allocated to the cache.
pub blocks_allocated: usize,
/// Number of tokens stored in the cache.
pub tokens_stored: usize,
/// Utilization figure; presumably a 0.0..=1.0 fraction of allocated capacity
/// in use (range is not enforced here, confirm with implementors).
pub utilization: f32,
/// Instant of the most recent access.
pub last_access: std::time::Instant,
}
/// Parameters describing a cache allocation, passed to `KvCacheManager::allocate`
/// and `KvCacheManager::can_allocate`.
#[derive(Debug, Clone)]
pub struct AllocationRequest {
/// Identifier of the request the cache is allocated for.
pub request_id: RequestId,
/// Number of tokens to account for initially.
pub initial_tokens: usize,
/// Maximum sequence length; `estimated_memory_bytes` sizes its estimate with this value.
pub max_sequence_length: usize,
/// Number of layers.
pub num_layers: usize,
/// Number of heads per layer.
pub num_heads: usize,
/// Per-head dimension.
pub head_dim: usize,
/// Device the cache should be associated with.
pub device: Device,
/// Element data type; its `size_bytes()` scales the memory estimate.
pub dtype: ferrum_types::DataType,
/// Priority of the request.
pub priority: ferrum_types::Priority,
}
impl AllocationRequest {
    /// Estimates the memory, in bytes, needed for a cache sized for
    /// `max_sequence_length` tokens.
    ///
    /// The estimate is
    /// `num_layers * num_heads * max_sequence_length * head_dim * 2 * dtype.size_bytes()`.
    /// The factor 2 presumably accounts for storing both keys and values.
    ///
    /// The multiplication saturates at `usize::MAX` instead of overflowing, so
    /// an absurdly large request yields an unsatisfiable estimate rather than
    /// a panic (debug builds) or a wrapped, too-small number (release builds).
    pub fn estimated_memory_bytes(&self) -> usize {
        let element_count = self
            .num_layers
            .saturating_mul(self.num_heads)
            .saturating_mul(self.max_sequence_length)
            .saturating_mul(self.head_dim)
            .saturating_mul(2);
        element_count.saturating_mul(self.dtype.size_bytes())
    }
}
/// Manager that allocates, extends, looks up and releases key/value caches.
///
/// The trait is processed by `async_trait`, and implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait KvCacheManager: Send + Sync {
/// Allocates a cache for `request` and returns a shared handle to it.
async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>>;
/// Extends the cache behind `handle` to accommodate `additional_tokens` more tokens.
async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()>;
/// Releases the cache associated with `request_id`.
async fn deallocate(&self, request_id: RequestId) -> Result<()>;
/// Reports whether `request` could currently be satisfied.
fn can_allocate(&self, request: &AllocationRequest) -> bool;
/// Returns aggregate statistics for the manager.
fn stats(&self) -> CacheManagerStats;
/// Runs a garbage-collection pass and reports what it freed.
async fn gc(&self) -> Result<CacheGcStats>;
/// Registers a callback that receives a `MemoryPressure` level; when it is
/// invoked is implementor-defined.
fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);
/// Returns the handle associated with `request_id`, or `None` if there is none.
fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>>;
/// Returns `(request id, handle)` pairs for the caches the manager exposes.
fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)>;
}
/// Aggregate statistics reported by a cache manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheManagerStats {
/// Total memory under management, in bytes.
pub total_memory_bytes: usize,
/// Memory currently in use, in bytes.
pub used_memory_bytes: usize,
/// Number of active caches.
pub active_caches: usize,
/// Total number of blocks.
pub total_blocks: usize,
/// Number of blocks currently free.
pub free_blocks: usize,
/// Cache hit rate; presumably a 0.0..=1.0 fraction (range not enforced here).
pub cache_hit_rate: f32,
/// Number of evictions performed.
pub eviction_count: u64,
/// Number of allocations performed.
pub allocation_count: u64,
/// Number of allocations that failed.
pub allocation_failures: u64,
}
/// Result of a garbage-collection pass (returned by `KvCacheManager::gc`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheGcStats {
/// Amount of memory freed; presumably in bytes (unit is not stated in the type).
pub memory_freed: usize,
/// Number of caches freed.
pub caches_freed: usize,
/// Duration of the pass, in milliseconds.
pub gc_time_ms: u64,
}
/// Memory pressure level passed to pressure callbacks.
///
/// `PartialOrd`/`Ord` are derived, so variants compare in declaration order:
/// `Low < Medium < High < Critical`. Reordering the variants changes comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
/// Lowest pressure level.
Low,
/// Second level, above `Low`.
Medium,
/// Third level, above `Medium`.
High,
/// Highest pressure level.
Critical,
}
/// Extension of `KvCacheManager` with prefix sharing, swapping and compression operations.
// NOTE(review): unlike `KvCacheManager` and `MultiDeviceCacheManager` in this
// file, this trait is not annotated with `#[async_trait::async_trait]`. As
// written, its `async fn`s rely on native async-fn-in-trait support, which
// means the trait cannot be used as `dyn AdvancedKvCacheManager` and the
// returned futures carry no `Send` bound. Confirm whether the attribute was
// omitted intentionally before relying on either property.
pub trait AdvancedKvCacheManager: KvCacheManager {
/// Enables prefix caching using `config`.
async fn enable_prefix_caching(&self, config: PrefixCacheConfig) -> Result<()>;
/// Shares the first `shared_tokens` tokens of the `source` request's cache
/// with the `target` request (presumably a common prompt prefix; confirm
/// with implementors).
async fn share_prefix(
&self,
source: RequestId,
target: RequestId,
shared_tokens: usize,
) -> Result<()>;
/// Swaps out the cache associated with `request_id`; the destination is
/// implementor-defined.
async fn swap_out(&self, request_id: RequestId) -> Result<()>;
/// Swaps the cache associated with `request_id` back in.
async fn swap_in(&self, request_id: RequestId) -> Result<()>;
/// Compresses the cache associated with `request_id` using `compression_ratio`;
/// the interpretation of the ratio is implementor-defined.
async fn compress_cache(&self, request_id: RequestId, compression_ratio: f32) -> Result<()>;
/// Returns compression statistics.
fn compression_stats(&self) -> CompressionStats;
}
/// Settings for prefix caching (see `AdvancedKvCacheManager::enable_prefix_caching`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixCacheConfig {
/// Maximum number of prefixes to keep.
pub max_prefixes: usize,
/// Minimum prefix length to consider; presumably measured in tokens.
pub min_prefix_length: usize,
/// Time-to-live for a cached prefix, in seconds.
pub prefix_ttl_seconds: u64,
/// Whether prefixes may be shared across requests.
pub enable_cross_request_sharing: bool,
}
/// Statistics about cache compression (see `AdvancedKvCacheManager::compression_stats`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
/// Number of caches that are compressed.
pub compressed_caches: usize,
/// Memory saved by compression, in bytes.
pub memory_saved_bytes: usize,
/// Average compression ratio; how the ratio is defined is implementor-specific.
pub avg_compression_ratio: f32,
/// Average time spent compressing, in milliseconds.
pub avg_compression_time_ms: f64,
}
/// Allocator that hands out and reclaims blocks identified by `BlockId`.
///
/// All methods take `&self`, so implementors that track state need interior
/// mutability; implementors must be `Send + Sync`.
pub trait BlockAllocator: Send + Sync {
/// Allocates `num_blocks` blocks and returns their identifiers.
fn allocate_blocks(&self, num_blocks: usize) -> Result<Vec<BlockId>>;
/// Returns `blocks` to the allocator.
fn free_blocks(&self, blocks: &[BlockId]) -> Result<()>;
/// Returns the number of blocks currently free.
fn free_block_count(&self) -> usize;
/// Returns the total number of blocks.
fn total_block_count(&self) -> usize;
/// Returns the block size; presumably in tokens per block, matching
/// `BlockTable::block_size` (confirm with implementors).
fn block_size(&self) -> usize;
/// Defragments the allocator's blocks; the strategy is implementor-defined.
fn defragment(&self) -> Result<()>;
}
/// Extension of `KvCacheManager` for managers that work with several devices.
#[async_trait::async_trait]
pub trait MultiDeviceCacheManager: KvCacheManager {
/// Returns the devices this manager supports.
fn supported_devices(&self) -> Vec<Device>;
/// Sets the device preference list; presumably ordered from most to least
/// preferred (ordering semantics are implementor-defined).
fn set_device_preference(&self, devices: Vec<Device>);
/// Moves the cache associated with `request_id` to `target_device`.
async fn move_cache(&self, request_id: RequestId, target_device: Device) -> Result<()>;
/// Returns the device holding the cache for `request_id`, or `None` if unknown.
fn get_cache_device(&self, request_id: RequestId) -> Option<Device>;
/// Rebalances caches across devices; the policy is implementor-defined.
async fn rebalance_devices(&self) -> Result<()>;
/// Returns manager statistics keyed by device.
fn device_stats(&self) -> HashMap<Device, CacheManagerStats>;
}
/// Strategy for choosing which caches to evict.
///
/// Selection takes `&self`, while recording an access takes `&mut self`.
pub trait CacheEvictionPolicy: Send + Sync {
/// Selects the requests whose caches should be evicted to free
/// `required_memory`, choosing among `active_caches`.
fn select_eviction_candidates(
&self,
required_memory: usize,
active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
) -> Vec<RequestId>;
/// Records that the cache for `request_id` was accessed at `access_time`.
fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant);
/// Returns the policy's name.
fn name(&self) -> &str;
}
/// Eviction policy that prefers the least recently accessed caches.
pub struct LruEvictionPolicy {
/// Most recently recorded access time per request, written by `record_access`.
access_times: HashMap<RequestId, std::time::Instant>,
}
impl LruEvictionPolicy {
    /// Creates a policy with no recorded accesses.
    pub fn new() -> Self {
        let access_times = HashMap::new();
        Self { access_times }
    }
}
impl CacheEvictionPolicy for LruEvictionPolicy {
    /// Selects caches to evict, least recently accessed first, until their
    /// combined `stats().memory_bytes` reaches `required_memory`.
    ///
    /// * Returns an empty list when `required_memory` is zero, since nothing
    ///   needs to be freed.
    /// * Caches with no recorded access are treated as accessed at the time of
    ///   this call, so they are considered last.
    /// * If all caches together hold less than `required_memory`, every
    ///   request id is returned.
    fn select_eviction_candidates(
        &self,
        required_memory: usize,
        active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
    ) -> Vec<RequestId> {
        if required_memory == 0 {
            return Vec::new();
        }

        // Read the clock once so every untracked cache gets the same fallback
        // time; the stable sort below then keeps them in input order.
        let now = std::time::Instant::now();
        let mut candidates: Vec<_> = active_caches
            .iter()
            .map(|(req_id, handle)| {
                let access_time = self.access_times.get(req_id).copied().unwrap_or(now);
                (req_id, handle.stats().memory_bytes, access_time)
            })
            .collect();
        // Oldest access first.
        candidates.sort_by_key(|candidate| candidate.2);

        let mut freed_memory = 0usize;
        let mut result = Vec::new();
        for (req_id, memory_bytes, _) in candidates {
            // Only the ids that are actually selected get cloned.
            result.push(req_id.clone());
            freed_memory = freed_memory.saturating_add(memory_bytes);
            if freed_memory >= required_memory {
                break;
            }
        }
        result
    }

    /// Stores `access_time` as the latest access for `request_id`, replacing
    /// any previously recorded time.
    fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant) {
        self.access_times.insert(request_id, access_time);
    }

    /// Returns the policy name, `"lru"`.
    fn name(&self) -> &str {
        "lru"
    }
}
impl Default for LruEvictionPolicy {
fn default() -> Self {
Self::new()
}
}
/// Configuration for a cache manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Block size; presumably tokens per block (defaults to 16).
pub block_size: usize,
/// Maximum number of blocks (defaults to 1000).
pub max_blocks: usize,
/// Number of blocks to start with (defaults to 100).
pub initial_blocks: usize,
/// Whether pooling is enabled (defaults to `true`); what is pooled is
/// implementor-defined.
pub enable_pooling: bool,
/// Devices the cache targets (defaults to `[Device::CPU]`).
pub target_devices: Vec<Device>,
/// Whether prefix caching is enabled (defaults to `false`).
pub enable_prefix_caching: bool,
/// Optional prefix-cache settings (defaults to `None`).
pub prefix_cache_config: Option<PrefixCacheConfig>,
/// Whether multi-device support is enabled (defaults to `false`).
pub enable_multi_device: bool,
/// Thresholds used to classify memory pressure.
pub pressure_thresholds: MemoryPressureThresholds,
}
/// Thresholds for the `Medium`, `High` and `Critical` memory pressure levels.
///
/// The defaults (0.6, 0.8, 0.95) suggest these are fractions of memory in use;
/// how they are compared is up to the consumer (confirm against the manager).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPressureThresholds {
/// Threshold for medium pressure (defaults to 0.6).
pub medium_threshold: f32,
/// Threshold for high pressure (defaults to 0.8).
pub high_threshold: f32,
/// Threshold for critical pressure (defaults to 0.95).
pub critical_threshold: f32,
}
impl Default for MemoryPressureThresholds {
    /// Returns thresholds of 0.6 (medium), 0.8 (high) and 0.95 (critical).
    fn default() -> Self {
        let medium_threshold = 0.6;
        let high_threshold = 0.8;
        let critical_threshold = 0.95;
        Self {
            medium_threshold,
            high_threshold,
            critical_threshold,
        }
    }
}
impl Default for CacheConfig {
    /// Returns the default configuration: 16-token blocks, at most 1000 blocks
    /// with 100 initial ones, pooling enabled, CPU as the only target device,
    /// prefix caching and multi-device support disabled, and default pressure
    /// thresholds.
    fn default() -> Self {
        let target_devices = vec![Device::CPU];
        let pressure_thresholds = MemoryPressureThresholds::default();
        Self {
            block_size: 16,
            max_blocks: 1000,
            initial_blocks: 100,
            enable_pooling: true,
            target_devices,
            enable_prefix_caching: false,
            prefix_cache_config: None,
            enable_multi_device: false,
            pressure_thresholds,
        }
    }
}