1use crate::TensorRef;
7use ferrum_types::{BlockId, Device, RequestId, Result};
8use serde::{Deserialize, Serialize};
9use smallvec::SmallVec;
10use std::{collections::HashMap, sync::Arc};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
/// Per-sequence table mapping a request's logical KV-cache blocks to the
/// physical blocks that back them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockTable {
    /// Physical block ids backing this sequence, in logical order.
    pub physical_blocks: SmallVec<[BlockId; 8]>,
    /// One entry per logical block. As populated by `add_blocks` it holds the
    /// logical index itself (an identity map into `physical_blocks`).
    /// NOTE(review): confirm whether this was meant to store something other
    /// than the identity mapping.
    pub logical_to_physical: SmallVec<[u32; 8]>,
    /// Number of tokens currently stored for this sequence.
    pub sequence_length: usize,
    /// Capacity of each block, in tokens.
    pub block_size: usize,
}
24
25impl BlockTable {
26 pub fn new(block_size: usize) -> Self {
28 Self {
29 physical_blocks: SmallVec::new(),
30 logical_to_physical: SmallVec::new(),
31 sequence_length: 0,
32 block_size,
33 }
34 }
35
36 pub fn num_blocks(&self) -> usize {
38 self.physical_blocks.len()
39 }
40
41 pub fn blocks_needed_for_length(length: usize, block_size: usize) -> usize {
43 (length + block_size - 1) / block_size }
45
46 pub fn has_free_space(&self) -> bool {
48 let used_blocks = Self::blocks_needed_for_length(self.sequence_length, self.block_size);
49 used_blocks < self.num_blocks()
50 }
51
52 pub fn free_tokens(&self) -> usize {
54 if self.num_blocks() == 0 {
55 0
56 } else {
57 self.num_blocks() * self.block_size - self.sequence_length
58 }
59 }
60
61 pub fn add_blocks(&mut self, blocks: &[BlockId]) {
63 let start_logical = self.logical_to_physical.len();
64
65 for (i, &block) in blocks.iter().enumerate() {
66 self.physical_blocks.push(block);
67 self.logical_to_physical.push((start_logical + i) as u32);
68 }
69 }
70
71 pub fn extend_sequence(&mut self, additional_tokens: usize) -> Result<()> {
73 let new_length = self.sequence_length + additional_tokens;
74 let required_blocks = Self::blocks_needed_for_length(new_length, self.block_size);
75
76 if required_blocks > self.num_blocks() {
77 return Err(ferrum_types::FerrumError::backend(format!(
78 "Insufficient blocks: need {}, have {}",
79 required_blocks,
80 self.num_blocks()
81 )));
82 }
83
84 self.sequence_length = new_length;
85 Ok(())
86 }
87}
88
/// Handle to the KV-cache storage owned by a single request/sequence.
///
/// Implementations are backend-specific; `as_any` provides an escape hatch
/// for downcasting to the concrete handle type.
pub trait KvCacheHandle: Send + Sync + std::fmt::Debug {
    /// Block table describing this cache's block layout.
    fn block_table(&self) -> &BlockTable;

    /// Mutable access to the block table (e.g. to grow the sequence).
    fn block_table_mut(&mut self) -> &mut BlockTable;

    /// Downcasting hook to reach the concrete implementation.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Device on which the cached tensors live.
    fn device(&self) -> Device;

    /// Tokens currently stored (the block table's `sequence_length`).
    fn num_tokens(&self) -> usize {
        self.block_table().sequence_length
    }

    /// Number of layers this cache stores key/value tensors for.
    fn num_layers(&self) -> usize;

    /// Attention heads per layer.
    fn num_heads(&self) -> usize;

    /// Dimension of each attention head.
    fn head_dim(&self) -> usize;

    /// Key tensor for `layer`, if one is materialized.
    fn key_cache(&self, layer: usize) -> Result<Option<TensorRef>>;

    /// Value tensor for `layer`, if one is materialized.
    fn value_cache(&self, layer: usize) -> Result<Option<TensorRef>>;

    /// Convenience accessor returning both key and value tensors for `layer`.
    fn kv_cache(&self, layer: usize) -> Result<(Option<TensorRef>, Option<TensorRef>)> {
        Ok((self.key_cache(layer)?, self.value_cache(layer)?))
    }

    /// Clone this handle behind an `Arc`.
    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>>;

    /// Current usage statistics for this handle.
    fn stats(&self) -> CacheHandleStats;

    /// Whether the handle still refers to live cache storage.
    fn is_valid(&self) -> bool;

    /// Stable identifier for this cache instance.
    fn cache_id(&self) -> String;
}
140
/// Point-in-time usage statistics for a single cache handle.
#[derive(Debug, Clone)]
pub struct CacheHandleStats {
    /// Bytes of memory backing this cache.
    pub memory_bytes: usize,
    /// Physical blocks allocated to this cache.
    pub blocks_allocated: usize,
    /// Tokens currently stored.
    pub tokens_stored: usize,
    /// Fill ratio — presumably tokens stored over capacity, in [0, 1];
    /// confirm with the producing implementation.
    pub utilization: f32,
    /// When this cache was last accessed.
    pub last_access: std::time::Instant,
}
155
/// Parameters describing a KV-cache allocation for one request.
#[derive(Debug, Clone)]
pub struct AllocationRequest {
    /// Request this cache will belong to.
    pub request_id: RequestId,
    /// Tokens that must fit immediately at allocation time.
    pub initial_tokens: usize,
    /// Upper bound on sequence length; drives the memory estimate.
    pub max_sequence_length: usize,
    /// Number of layers to cache.
    pub num_layers: usize,
    /// Attention heads per layer.
    pub num_heads: usize,
    /// Dimension of each head.
    pub head_dim: usize,
    /// Device the cache should be placed on.
    pub device: Device,
    /// Element type of the cached tensors.
    pub dtype: ferrum_types::DataType,
    /// Scheduling priority of the owning request.
    pub priority: ferrum_types::Priority,
}
178
179impl AllocationRequest {
180 pub fn estimated_memory_bytes(&self) -> usize {
182 let kv_size =
184 self.num_layers * self.num_heads * self.max_sequence_length * self.head_dim * 2;
185 kv_size * self.dtype.size_bytes()
186 }
187}
188
/// Allocates, grows, and reclaims KV caches across requests.
#[async_trait::async_trait]
pub trait KvCacheManager: Send + Sync {
    /// Allocate a new cache satisfying `request`.
    async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>>;

    /// Grow the cache behind `handle` by `additional_tokens`.
    async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()>;

    /// Release all cache resources held for `request_id`.
    async fn deallocate(&self, request_id: RequestId) -> Result<()>;

    /// Feasibility check: could `request` be allocated right now?
    fn can_allocate(&self, request: &AllocationRequest) -> bool;

    /// Aggregate statistics for the whole manager.
    fn stats(&self) -> CacheManagerStats;

    /// Run a garbage-collection pass and report what was reclaimed.
    async fn gc(&self) -> Result<CacheGcStats>;

    /// Register a callback invoked on memory-pressure events.
    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);

    /// Look up the live handle for `request_id`, if any.
    fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>>;

    /// Snapshot of all live handles.
    fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)>;
}
219
/// Aggregate statistics across all caches owned by a manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheManagerStats {
    /// Total memory budget, in bytes.
    pub total_memory_bytes: usize,
    /// Memory currently in use, in bytes.
    pub used_memory_bytes: usize,
    /// Number of live caches.
    pub active_caches: usize,
    /// Total physical blocks managed.
    pub total_blocks: usize,
    /// Blocks currently unallocated.
    pub free_blocks: usize,
    /// Cache hit rate — presumably in [0, 1]; confirm with the implementation.
    pub cache_hit_rate: f32,
    /// Cumulative number of evictions.
    pub eviction_count: u64,
    /// Cumulative number of allocations.
    pub allocation_count: u64,
    /// Cumulative number of failed allocations.
    pub allocation_failures: u64,
}
242
/// Result of one garbage-collection pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheGcStats {
    /// Bytes reclaimed by the pass.
    pub memory_freed: usize,
    /// Number of caches released.
    pub caches_freed: usize,
    /// Duration of the pass, in milliseconds.
    pub gc_time_ms: u64,
}
253
/// Coarse memory-pressure levels. `Ord` is derived, so variants compare in
/// increasing severity: `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
    /// Usage is comfortable.
    Low,
    /// Usage is elevated.
    Medium,
    /// Usage is high.
    High,
    /// Memory is nearly exhausted.
    Critical,
}
266
267pub trait AdvancedKvCacheManager: KvCacheManager {
269 async fn enable_prefix_caching(&self, config: PrefixCacheConfig) -> Result<()>;
271
272 async fn share_prefix(
274 &self,
275 source: RequestId,
276 target: RequestId,
277 shared_tokens: usize,
278 ) -> Result<()>;
279
280 async fn swap_out(&self, request_id: RequestId) -> Result<()>;
282
283 async fn swap_in(&self, request_id: RequestId) -> Result<()>;
285
286 async fn compress_cache(&self, request_id: RequestId, compression_ratio: f32) -> Result<()>;
288
289 fn compression_stats(&self) -> CompressionStats;
291}
292
/// Configuration for prefix caching.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixCacheConfig {
    /// Maximum number of prefixes retained.
    pub max_prefixes: usize,
    /// Shortest prefix worth caching — presumably measured in tokens; confirm.
    pub min_prefix_length: usize,
    /// Time-to-live for a cached prefix, in seconds.
    pub prefix_ttl_seconds: u64,
    /// Allow a prefix cached for one request to be reused by others.
    pub enable_cross_request_sharing: bool,
}
305
/// Statistics about KV-cache compression activity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Number of compressed caches.
    pub compressed_caches: usize,
    /// Total bytes saved by compression.
    pub memory_saved_bytes: usize,
    /// Mean compression ratio across compressed caches.
    pub avg_compression_ratio: f32,
    /// Mean time spent compressing a cache, in milliseconds.
    pub avg_compression_time_ms: f64,
}
318
/// Low-level allocator for fixed-size cache blocks.
pub trait BlockAllocator: Send + Sync {
    /// Allocate `num_blocks` blocks, returning their ids.
    fn allocate_blocks(&self, num_blocks: usize) -> Result<Vec<BlockId>>;

    /// Return `blocks` to the free pool.
    fn free_blocks(&self, blocks: &[BlockId]) -> Result<()>;

    /// Blocks currently available for allocation.
    fn free_block_count(&self) -> usize;

    /// Total blocks managed (free plus allocated).
    fn total_block_count(&self) -> usize;

    /// Capacity of each block, in tokens.
    fn block_size(&self) -> usize;

    /// Compact free space; exact semantics are implementation-defined.
    fn defragment(&self) -> Result<()>;
}
339
/// Cache manager that can place and migrate caches across multiple devices.
#[async_trait::async_trait]
pub trait MultiDeviceCacheManager: KvCacheManager {
    /// Devices this manager can place caches on.
    fn supported_devices(&self) -> Vec<Device>;

    /// Set the preferred device order for placement decisions.
    fn set_device_preference(&self, devices: Vec<Device>);

    /// Migrate the cache for `request_id` to `target_device`.
    async fn move_cache(&self, request_id: RequestId, target_device: Device) -> Result<()>;

    /// Device currently holding the cache for `request_id`, if any.
    fn get_cache_device(&self, request_id: RequestId) -> Option<Device>;

    /// Redistribute caches across devices.
    async fn rebalance_devices(&self) -> Result<()>;

    /// Per-device manager statistics.
    fn device_stats(&self) -> HashMap<Device, CacheManagerStats>;
}
361
/// Strategy for choosing which caches to evict under memory pressure.
pub trait CacheEvictionPolicy: Send + Sync {
    /// Choose request ids whose combined cache memory covers at least
    /// `required_memory` bytes (best effort if that is not achievable).
    fn select_eviction_candidates(
        &self,
        required_memory: usize,
        active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
    ) -> Vec<RequestId>;

    /// Inform the policy that `request_id` was accessed at `access_time`.
    fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant);

    /// Short identifier for this policy.
    fn name(&self) -> &str;
}
377
/// Least-recently-used eviction policy backed by a per-request access-time map.
pub struct LruEvictionPolicy {
    // Last recorded access time per request, fed by `record_access`.
    access_times: HashMap<RequestId, std::time::Instant>,
}
382
383impl LruEvictionPolicy {
384 pub fn new() -> Self {
385 Self {
386 access_times: HashMap::new(),
387 }
388 }
389}
390
391impl CacheEvictionPolicy for LruEvictionPolicy {
392 fn select_eviction_candidates(
393 &self,
394 required_memory: usize,
395 active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
396 ) -> Vec<RequestId> {
397 let mut candidates: Vec<_> = active_caches
398 .iter()
399 .map(|(req_id, handle)| {
400 let access_time = self
401 .access_times
402 .get(req_id)
403 .copied()
404 .unwrap_or_else(std::time::Instant::now);
405 (req_id.clone(), handle.stats().memory_bytes, access_time)
406 })
407 .collect();
408
409 candidates.sort_by(|a, b| a.2.cmp(&b.2));
411
412 let mut freed_memory = 0;
413 let mut result = Vec::new();
414
415 for (req_id, memory_bytes, _) in candidates {
416 result.push(req_id);
417 freed_memory += memory_bytes;
418 if freed_memory >= required_memory {
419 break;
420 }
421 }
422
423 result
424 }
425
426 fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant) {
427 self.access_times.insert(request_id, access_time);
428 }
429
430 fn name(&self) -> &str {
431 "lru"
432 }
433}
434
impl Default for LruEvictionPolicy {
    /// Equivalent to [`LruEvictionPolicy::new`]: an empty access-time map.
    fn default() -> Self {
        Self::new()
    }
}
440
/// Top-level configuration for the KV-cache subsystem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Tokens per cache block.
    pub block_size: usize,
    /// Hard cap on the number of blocks.
    pub max_blocks: usize,
    /// Blocks allocated up front.
    pub initial_blocks: usize,
    /// Enable block pooling — presumably reuse of freed blocks; confirm
    /// exact semantics with the implementation.
    pub enable_pooling: bool,
    /// Devices caches may be placed on.
    pub target_devices: Vec<Device>,
    /// Whether prefix caching is enabled.
    pub enable_prefix_caching: bool,
    /// Prefix-cache settings; expected when `enable_prefix_caching` is true.
    pub prefix_cache_config: Option<PrefixCacheConfig>,
    /// Whether multi-device placement/migration is enabled.
    pub enable_multi_device: bool,
    /// Thresholds used to classify memory pressure.
    pub pressure_thresholds: MemoryPressureThresholds,
}
463
/// Thresholds mapping memory usage to `MemoryPressure` levels — presumably
/// fractions of total memory in [0, 1]; confirm with consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPressureThresholds {
    /// Usage at which pressure becomes `Medium` (default 0.6).
    pub medium_threshold: f32,
    /// Usage at which pressure becomes `High` (default 0.8).
    pub high_threshold: f32,
    /// Usage at which pressure becomes `Critical` (default 0.95).
    pub critical_threshold: f32,
}
474
475impl Default for MemoryPressureThresholds {
476 fn default() -> Self {
477 Self {
478 medium_threshold: 0.6,
479 high_threshold: 0.8,
480 critical_threshold: 0.95,
481 }
482 }
483}
484
485impl Default for CacheConfig {
486 fn default() -> Self {
487 Self {
488 block_size: 16,
489 max_blocks: 1000,
490 initial_blocks: 100,
491 enable_pooling: true,
492 target_devices: vec![Device::CPU],
493 enable_prefix_caching: false,
494 prefix_cache_config: None,
495 enable_multi_device: false,
496 pressure_thresholds: MemoryPressureThresholds::default(),
497 }
498 }
499}