1use crate::TensorRef;
7use ferrum_types::{BlockId, Device, RequestId, Result};
8use serde::{Deserialize, Serialize};
9use smallvec::SmallVec;
10use std::{collections::HashMap, sync::Arc};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct BlockTable {
15 pub physical_blocks: SmallVec<[BlockId; 8]>,
17 pub logical_to_physical: SmallVec<[u32; 8]>,
19 pub sequence_length: usize,
29 pub block_size: usize,
31}
32
33impl BlockTable {
34 pub fn new(block_size: usize) -> Self {
36 Self {
37 physical_blocks: SmallVec::new(),
38 logical_to_physical: SmallVec::new(),
39 sequence_length: 0,
40 block_size,
41 }
42 }
43
44 pub fn num_blocks(&self) -> usize {
46 self.physical_blocks.len()
47 }
48
49 pub fn blocks_needed_for_length(length: usize, block_size: usize) -> usize {
51 length.div_ceil(block_size) }
53
54 pub fn has_free_space(&self) -> bool {
56 let used_blocks = Self::blocks_needed_for_length(self.sequence_length, self.block_size);
57 used_blocks < self.num_blocks()
58 }
59
60 pub fn free_tokens(&self) -> usize {
62 if self.num_blocks() == 0 {
63 0
64 } else {
65 self.num_blocks() * self.block_size - self.sequence_length
66 }
67 }
68
69 pub fn add_blocks(&mut self, blocks: &[BlockId]) {
71 let start_logical = self.logical_to_physical.len();
72
73 for (i, &block) in blocks.iter().enumerate() {
74 self.physical_blocks.push(block);
75 self.logical_to_physical.push((start_logical + i) as u32);
76 }
77 }
78
79 pub fn extend_sequence(&mut self, additional_tokens: usize) -> Result<()> {
81 let new_length = self.sequence_length + additional_tokens;
82 let required_blocks = Self::blocks_needed_for_length(new_length, self.block_size);
83
84 if required_blocks > self.num_blocks() {
85 return Err(ferrum_types::FerrumError::backend(format!(
86 "Insufficient blocks: need {}, have {}",
87 required_blocks,
88 self.num_blocks()
89 )));
90 }
91
92 self.sequence_length = new_length;
93 Ok(())
94 }
95}
96
97pub trait KvCacheHandle: Send + Sync + std::fmt::Debug {
99 fn block_table(&self) -> &BlockTable;
101
102 fn block_table_mut(&mut self) -> &mut BlockTable;
104
105 fn as_any(&self) -> &dyn std::any::Any;
107
108 fn device(&self) -> Device;
110
111 fn num_tokens(&self) -> usize {
113 self.block_table().sequence_length
114 }
115
116 fn num_layers(&self) -> usize;
118
119 fn num_heads(&self) -> usize;
121
122 fn head_dim(&self) -> usize;
124
125 fn key_cache(&self, layer: usize) -> Result<Option<TensorRef>>;
127
128 fn value_cache(&self, layer: usize) -> Result<Option<TensorRef>>;
130
131 fn kv_cache(&self, layer: usize) -> Result<(Option<TensorRef>, Option<TensorRef>)> {
133 Ok((self.key_cache(layer)?, self.value_cache(layer)?))
134 }
135
136 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>>;
138
139 fn stats(&self) -> CacheHandleStats;
141
142 fn is_valid(&self) -> bool;
144
145 fn cache_id(&self) -> String;
147}
148
/// Point-in-time statistics reported by a single cache handle.
#[derive(Debug, Clone)]
pub struct CacheHandleStats {
    /// Memory attributed to the cache, in bytes.
    pub memory_bytes: usize,
    /// Number of blocks allocated to the cache.
    pub blocks_allocated: usize,
    /// Number of tokens stored in the cache.
    pub tokens_stored: usize,
    /// Utilization figure for the cache (presumably a 0.0–1.0 fraction —
    /// confirm with the producer of these stats).
    pub utilization: f32,
    /// Instant of the most recent access recorded for the cache.
    pub last_access: std::time::Instant,
}
163
164#[derive(Debug, Clone)]
166pub struct AllocationRequest {
167 pub request_id: RequestId,
169 pub initial_tokens: usize,
171 pub max_sequence_length: usize,
173 pub num_layers: usize,
175 pub num_heads: usize,
177 pub head_dim: usize,
179 pub device: Device,
181 pub dtype: ferrum_types::DataType,
183 pub priority: ferrum_types::Priority,
185}
186
187impl AllocationRequest {
188 pub fn estimated_memory_bytes(&self) -> usize {
190 let kv_size =
192 self.num_layers * self.num_heads * self.max_sequence_length * self.head_dim * 2;
193 kv_size * self.dtype.size_bytes()
194 }
195}
196
197#[async_trait::async_trait]
199pub trait KvCacheManager: Send + Sync {
200 async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>>;
202
203 async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()>;
205
206 async fn deallocate(&self, request_id: RequestId) -> Result<()>;
208
209 fn can_allocate(&self, request: &AllocationRequest) -> bool;
211
212 fn stats(&self) -> CacheManagerStats;
214
215 async fn gc(&self) -> Result<CacheGcStats>;
217
218 fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);
220
221 fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>>;
223
224 fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)>;
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct CacheManagerStats {
231 pub total_memory_bytes: usize,
233 pub used_memory_bytes: usize,
235 pub active_caches: usize,
237 pub total_blocks: usize,
239 pub free_blocks: usize,
241 pub cache_hit_rate: f32,
243 pub eviction_count: u64,
245 pub allocation_count: u64,
247 pub allocation_failures: u64,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct CacheGcStats {
254 pub memory_freed: usize,
256 pub caches_freed: usize,
258 pub gc_time_ms: u64,
260}
261
262#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
264pub enum MemoryPressure {
265 Low,
267 Medium,
269 High,
271 Critical,
273}
274
275pub trait AdvancedKvCacheManager: KvCacheManager {
277 async fn enable_prefix_caching(&self, config: PrefixCacheConfig) -> Result<()>;
279
280 async fn share_prefix(
282 &self,
283 source: RequestId,
284 target: RequestId,
285 shared_tokens: usize,
286 ) -> Result<()>;
287
288 async fn swap_out(&self, request_id: RequestId) -> Result<()>;
290
291 async fn swap_in(&self, request_id: RequestId) -> Result<()>;
293
294 async fn compress_cache(&self, request_id: RequestId, compression_ratio: f32) -> Result<()>;
296
297 fn compression_stats(&self) -> CompressionStats;
299}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct PrefixCacheConfig {
304 pub max_prefixes: usize,
306 pub min_prefix_length: usize,
308 pub prefix_ttl_seconds: u64,
310 pub enable_cross_request_sharing: bool,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct CompressionStats {
317 pub compressed_caches: usize,
319 pub memory_saved_bytes: usize,
321 pub avg_compression_ratio: f32,
323 pub avg_compression_time_ms: f64,
325}
326
327pub trait BlockAllocator: Send + Sync {
329 fn allocate_blocks(&self, num_blocks: usize) -> Result<Vec<BlockId>>;
331
332 fn free_blocks(&self, blocks: &[BlockId]) -> Result<()>;
334
335 fn free_block_count(&self) -> usize;
337
338 fn total_block_count(&self) -> usize;
340
341 fn block_size(&self) -> usize;
343
344 fn defragment(&self) -> Result<()>;
346}
347
348#[async_trait::async_trait]
350pub trait MultiDeviceCacheManager: KvCacheManager {
351 fn supported_devices(&self) -> Vec<Device>;
353
354 fn set_device_preference(&self, devices: Vec<Device>);
356
357 async fn move_cache(&self, request_id: RequestId, target_device: Device) -> Result<()>;
359
360 fn get_cache_device(&self, request_id: RequestId) -> Option<Device>;
362
363 async fn rebalance_devices(&self) -> Result<()>;
365
366 fn device_stats(&self) -> HashMap<Device, CacheManagerStats>;
368}
369
370pub trait CacheEvictionPolicy: Send + Sync {
372 fn select_eviction_candidates(
374 &self,
375 required_memory: usize,
376 active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
377 ) -> Vec<RequestId>;
378
379 fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant);
381
382 fn name(&self) -> &str;
384}
385
386pub struct LruEvictionPolicy {
388 access_times: HashMap<RequestId, std::time::Instant>,
389}
390
391impl LruEvictionPolicy {
392 pub fn new() -> Self {
393 Self {
394 access_times: HashMap::new(),
395 }
396 }
397}
398
399impl CacheEvictionPolicy for LruEvictionPolicy {
400 fn select_eviction_candidates(
401 &self,
402 required_memory: usize,
403 active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
404 ) -> Vec<RequestId> {
405 let mut candidates: Vec<_> = active_caches
406 .iter()
407 .map(|(req_id, handle)| {
408 let access_time = self
409 .access_times
410 .get(req_id)
411 .copied()
412 .unwrap_or_else(std::time::Instant::now);
413 (req_id.clone(), handle.stats().memory_bytes, access_time)
414 })
415 .collect();
416
417 candidates.sort_by(|a, b| a.2.cmp(&b.2));
419
420 let mut freed_memory = 0;
421 let mut result = Vec::new();
422
423 for (req_id, memory_bytes, _) in candidates {
424 result.push(req_id);
425 freed_memory += memory_bytes;
426 if freed_memory >= required_memory {
427 break;
428 }
429 }
430
431 result
432 }
433
434 fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant) {
435 self.access_times.insert(request_id, access_time);
436 }
437
438 fn name(&self) -> &str {
439 "lru"
440 }
441}
442
443impl Default for LruEvictionPolicy {
444 fn default() -> Self {
445 Self::new()
446 }
447}
448
449#[derive(Debug, Clone, Serialize, Deserialize)]
451pub struct CacheConfig {
452 pub block_size: usize,
454 pub max_blocks: usize,
456 pub initial_blocks: usize,
458 pub enable_pooling: bool,
460 pub target_devices: Vec<Device>,
462 pub enable_prefix_caching: bool,
464 pub prefix_cache_config: Option<PrefixCacheConfig>,
466 pub enable_multi_device: bool,
468 pub pressure_thresholds: MemoryPressureThresholds,
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct MemoryPressureThresholds {
475 pub medium_threshold: f32,
477 pub high_threshold: f32,
479 pub critical_threshold: f32,
481}
482
483impl Default for MemoryPressureThresholds {
484 fn default() -> Self {
485 Self {
486 medium_threshold: 0.6,
487 high_threshold: 0.8,
488 critical_threshold: 0.95,
489 }
490 }
491}
492
493impl Default for CacheConfig {
494 fn default() -> Self {
495 Self {
496 block_size: 16,
497 max_blocks: 1000,
498 initial_blocks: 100,
499 enable_pooling: true,
500 target_devices: vec![Device::CPU],
501 enable_prefix_caching: false,
502 prefix_cache_config: None,
503 enable_multi_device: false,
504 pressure_thresholds: MemoryPressureThresholds::default(),
505 }
506 }
507}