Skip to main content

ferrum_interfaces/
kv_cache.rs

//! KV-Cache abstraction with handle semantics and block management
//!
//! This module provides a sequence-handle based abstraction for KV cache management,
//! supporting both contiguous and paged attention patterns with zero-copy operations.

6use crate::TensorRef;
7use ferrum_types::{BlockId, Device, RequestId, Result};
8use serde::{Deserialize, Serialize};
9use smallvec::SmallVec;
10use std::{collections::HashMap, sync::Arc};
11
/// Block table for mapping logical to physical cache blocks
///
/// Used for paged-attention style caches: a sequence's tokens live in
/// fixed-size blocks, and this table records which physical blocks back
/// each logical block of the sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockTable {
    /// Physical block IDs allocated for this sequence
    pub physical_blocks: SmallVec<[BlockId; 8]>,
    /// Mapping from logical to physical block indices
    /// (index = logical block, value = index into `physical_blocks`;
    /// maintained by `add_blocks`)
    pub logical_to_physical: SmallVec<[u32; 8]>,
    /// Current sequence length in tokens
    pub sequence_length: usize,
    /// Block size (tokens per block); must be non-zero
    pub block_size: usize,
}
24
25impl BlockTable {
26    /// Create new block table
27    pub fn new(block_size: usize) -> Self {
28        Self {
29            physical_blocks: SmallVec::new(),
30            logical_to_physical: SmallVec::new(),
31            sequence_length: 0,
32            block_size,
33        }
34    }
35
36    /// Get number of blocks allocated
37    pub fn num_blocks(&self) -> usize {
38        self.physical_blocks.len()
39    }
40
41    /// Get required number of blocks for sequence length
42    pub fn blocks_needed_for_length(length: usize, block_size: usize) -> usize {
43        (length + block_size - 1) / block_size // Ceiling division
44    }
45
46    /// Check if can accommodate more tokens without new blocks
47    pub fn has_free_space(&self) -> bool {
48        let used_blocks = Self::blocks_needed_for_length(self.sequence_length, self.block_size);
49        used_blocks < self.num_blocks()
50    }
51
52    /// Get number of free tokens in allocated blocks
53    pub fn free_tokens(&self) -> usize {
54        if self.num_blocks() == 0 {
55            0
56        } else {
57            self.num_blocks() * self.block_size - self.sequence_length
58        }
59    }
60
61    /// Add blocks to the table
62    pub fn add_blocks(&mut self, blocks: &[BlockId]) {
63        let start_logical = self.logical_to_physical.len();
64
65        for (i, &block) in blocks.iter().enumerate() {
66            self.physical_blocks.push(block);
67            self.logical_to_physical.push((start_logical + i) as u32);
68        }
69    }
70
71    /// Extend sequence length
72    pub fn extend_sequence(&mut self, additional_tokens: usize) -> Result<()> {
73        let new_length = self.sequence_length + additional_tokens;
74        let required_blocks = Self::blocks_needed_for_length(new_length, self.block_size);
75
76        if required_blocks > self.num_blocks() {
77            return Err(ferrum_types::FerrumError::backend(format!(
78                "Insufficient blocks: need {}, have {}",
79                required_blocks,
80                self.num_blocks()
81            )));
82        }
83
84        self.sequence_length = new_length;
85        Ok(())
86    }
87}
88
/// KV cache handle providing access to cached key-value states
///
/// A handle is a live reference to one sequence's cache. Backends provide
/// concrete implementations and hand them out as `Arc<dyn KvCacheHandle>`;
/// `as_any` allows downcasting back to the concrete type.
pub trait KvCacheHandle: Send + Sync + std::fmt::Debug {
    /// Get block table for this cache
    fn block_table(&self) -> &BlockTable;

    /// Get mutable block table (for extending)
    fn block_table_mut(&mut self) -> &mut BlockTable;

    /// Downcast support for backend-specific handles
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get device where cache resides
    fn device(&self) -> Device;

    /// Get number of tokens stored in cache
    /// (default: the block table's current sequence length)
    fn num_tokens(&self) -> usize {
        self.block_table().sequence_length
    }

    /// Get number of layers cached
    fn num_layers(&self) -> usize;

    /// Get number of attention heads
    fn num_heads(&self) -> usize;

    /// Get head dimension
    fn head_dim(&self) -> usize;

    /// Get key cache for specific layer (returns tensor reference)
    /// NOTE(review): `Ok(None)` presumably means no keys cached for this
    /// layer yet — confirm against backend implementations.
    fn key_cache(&self, layer: usize) -> Result<Option<TensorRef>>;

    /// Get value cache for specific layer
    fn value_cache(&self, layer: usize) -> Result<Option<TensorRef>>;

    /// Get both key and value caches for layer
    /// (convenience wrapper; fails if either single-sided accessor fails)
    fn kv_cache(&self, layer: usize) -> Result<(Option<TensorRef>, Option<TensorRef>)> {
        Ok((self.key_cache(layer)?, self.value_cache(layer)?))
    }

    /// Clone handle (creates new reference, not deep copy)
    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>>;

    /// Get cache statistics
    fn stats(&self) -> CacheHandleStats;

    /// Check if cache is valid and accessible
    fn is_valid(&self) -> bool;

    /// Get unique identifier for this cache instance
    fn cache_id(&self) -> String;
}
140
/// Statistics for individual cache handle
///
/// Snapshot returned by [`KvCacheHandle::stats`]; also consumed by eviction
/// policies (e.g. `memory_bytes` drives LRU candidate selection below).
#[derive(Debug, Clone)]
pub struct CacheHandleStats {
    /// Total memory usage in bytes
    pub memory_bytes: usize,
    /// Number of blocks allocated
    pub blocks_allocated: usize,
    /// Number of tokens stored
    pub tokens_stored: usize,
    /// Memory utilization ratio
    /// (presumably tokens/capacity in 0.0..=1.0 — confirm with backend)
    pub utilization: f32,
    /// Last access timestamp (for LRU)
    pub last_access: std::time::Instant,
}
155
/// KV cache allocation request
///
/// Describes the geometry (layers x heads x head_dim), capacity
/// (`max_sequence_length`), placement and dtype of a cache to allocate via
/// [`KvCacheManager::allocate`].
#[derive(Debug, Clone)]
pub struct AllocationRequest {
    /// Request ID this allocation is for
    pub request_id: RequestId,
    /// Initial number of tokens
    pub initial_tokens: usize,
    /// Maximum expected sequence length (upper bound used for sizing)
    pub max_sequence_length: usize,
    /// Number of layers to cache
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Head dimension
    pub head_dim: usize,
    /// Target device
    pub device: Device,
    /// Data type for cache
    pub dtype: ferrum_types::DataType,
    /// Priority level for allocation
    pub priority: ferrum_types::Priority,
}
178
179impl AllocationRequest {
180    /// Calculate estimated memory requirement
181    pub fn estimated_memory_bytes(&self) -> usize {
182        // Key + Value cache size: layers * heads * max_seq * head_dim * 2 * dtype_size
183        let kv_size =
184            self.num_layers * self.num_heads * self.max_sequence_length * self.head_dim * 2;
185        kv_size * self.dtype.size_bytes()
186    }
187}
188
/// KV cache manager for allocation and lifecycle management
///
/// Owns the mapping from `RequestId` to cache handle; handles are shared
/// out as `Arc<dyn KvCacheHandle>` and invalidated by `deallocate`.
#[async_trait::async_trait]
pub trait KvCacheManager: Send + Sync {
    /// Allocate cache for new sequence
    async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>>;

    /// Extend existing cache to accommodate more tokens
    async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()>;

    /// Deallocate cache (handle becomes invalid)
    async fn deallocate(&self, request_id: RequestId) -> Result<()>;

    /// Check if can allocate requested cache size
    /// (advisory — a subsequent `allocate` may still fail under contention)
    fn can_allocate(&self, request: &AllocationRequest) -> bool;

    /// Get cache statistics
    fn stats(&self) -> CacheManagerStats;

    /// Force garbage collection of unused caches
    async fn gc(&self) -> Result<CacheGcStats>;

    /// Set memory pressure callback
    /// (invoked by the manager when pressure level changes)
    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);

    /// Get handle for existing request (if exists)
    fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>>;

    /// List all active cache handles
    fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)>;
}
219
/// Cache manager statistics
///
/// Aggregate counters reported by [`KvCacheManager::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheManagerStats {
    /// Total memory allocated in bytes
    pub total_memory_bytes: usize,
    /// Memory currently in use
    pub used_memory_bytes: usize,
    /// Number of active caches
    pub active_caches: usize,
    /// Total blocks allocated
    pub total_blocks: usize,
    /// Free blocks available
    pub free_blocks: usize,
    /// Cache hit rate (for prefix caching)
    pub cache_hit_rate: f32,
    /// Number of evictions performed
    pub eviction_count: u64,
    /// Number of successful allocations
    pub allocation_count: u64,
    /// Number of failed allocations
    pub allocation_failures: u64,
}
242
/// Garbage collection statistics
///
/// Returned by [`KvCacheManager::gc`] describing one GC pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheGcStats {
    /// Memory freed in bytes
    pub memory_freed: usize,
    /// Number of caches garbage collected
    pub caches_freed: usize,
    /// Time taken for GC
    pub gc_time_ms: u64,
}
253
/// Memory pressure levels for adaptive management
///
/// Ordered from lowest to highest pressure (derives `Ord`, so
/// `Low < Medium < High < Critical` holds for threshold comparisons).
/// Delivered to the callback set via `set_pressure_callback`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
    /// Low memory usage, allocations can proceed freely
    Low,
    /// Moderate usage, start being more conservative
    Medium,
    /// High usage, consider eviction
    High,
    /// Critical usage, must evict or reject allocations
    Critical,
}
266
267/// Advanced KV cache capabilities
268pub trait AdvancedKvCacheManager: KvCacheManager {
269    /// Enable prefix caching for common prompt prefixes
270    async fn enable_prefix_caching(&self, config: PrefixCacheConfig) -> Result<()>;
271
272    /// Share cache blocks between compatible sequences
273    async fn share_prefix(
274        &self,
275        source: RequestId,
276        target: RequestId,
277        shared_tokens: usize,
278    ) -> Result<()>;
279
280    /// Swap cache from GPU to CPU to free GPU memory
281    async fn swap_out(&self, request_id: RequestId) -> Result<()>;
282
283    /// Swap cache from CPU back to GPU
284    async fn swap_in(&self, request_id: RequestId) -> Result<()>;
285
286    /// Compress cache to reduce memory usage
287    async fn compress_cache(&self, request_id: RequestId, compression_ratio: f32) -> Result<()>;
288
289    /// Get cache compression statistics
290    fn compression_stats(&self) -> CompressionStats;
291}
292
/// Prefix caching configuration
///
/// Passed to [`AdvancedKvCacheManager::enable_prefix_caching`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixCacheConfig {
    /// Maximum number of prefixes to cache
    pub max_prefixes: usize,
    /// Minimum prefix length to be eligible for caching
    /// (presumably measured in tokens — confirm with implementation)
    pub min_prefix_length: usize,
    /// TTL for cached prefixes
    pub prefix_ttl_seconds: u64,
    /// Enable cross-request prefix sharing
    pub enable_cross_request_sharing: bool,
}
305
/// Cache compression statistics
///
/// Returned by [`AdvancedKvCacheManager::compression_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Number of compressed caches
    pub compressed_caches: usize,
    /// Total memory saved by compression
    pub memory_saved_bytes: usize,
    /// Average compression ratio achieved
    pub avg_compression_ratio: f32,
    /// Compression/decompression time overhead
    pub avg_compression_time_ms: f64,
}
318
/// Block-based cache allocator
///
/// Low-level allocator behind paged cache managers: hands out and reclaims
/// fixed-size physical blocks identified by `BlockId`.
pub trait BlockAllocator: Send + Sync {
    /// Allocate specified number of blocks
    fn allocate_blocks(&self, num_blocks: usize) -> Result<Vec<BlockId>>;

    /// Free blocks back to allocator
    fn free_blocks(&self, blocks: &[BlockId]) -> Result<()>;

    /// Get number of free blocks
    fn free_block_count(&self) -> usize;

    /// Get total block count
    fn total_block_count(&self) -> usize;

    /// Get block size in tokens
    fn block_size(&self) -> usize;

    /// Defragment free block list
    fn defragment(&self) -> Result<()>;
}
339
/// Multi-device cache manager supporting GPU/CPU hierarchies
///
/// Extends [`KvCacheManager`] with per-device placement, migration and
/// rebalancing of caches.
#[async_trait::async_trait]
pub trait MultiDeviceCacheManager: KvCacheManager {
    /// Get supported devices
    fn supported_devices(&self) -> Vec<Device>;

    /// Set device preference for new allocations
    /// (presumably ordered most- to least-preferred — confirm with impls)
    fn set_device_preference(&self, devices: Vec<Device>);

    /// Move cache between devices
    async fn move_cache(&self, request_id: RequestId, target_device: Device) -> Result<()>;

    /// Get cache location (`None` if the request has no cache)
    fn get_cache_device(&self, request_id: RequestId) -> Option<Device>;

    /// Balance cache distribution across devices
    async fn rebalance_devices(&self) -> Result<()>;

    /// Get per-device statistics
    fn device_stats(&self) -> HashMap<Device, CacheManagerStats>;
}
361
/// Cache eviction strategies
///
/// Pluggable policy: managers call `record_access` on cache use and
/// `select_eviction_candidates` when memory must be reclaimed.
pub trait CacheEvictionPolicy: Send + Sync {
    /// Select caches to evict to free requested memory
    /// (returns request IDs in eviction order)
    fn select_eviction_candidates(
        &self,
        required_memory: usize,
        active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
    ) -> Vec<RequestId>;

    /// Update cache access information
    fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant);

    /// Get policy name (stable identifier, e.g. for config/logging)
    fn name(&self) -> &str;
}
377
/// Least Recently Used eviction policy
///
/// Evicts the caches with the oldest recorded access time first.
pub struct LruEvictionPolicy {
    // Last recorded access time per request; written by `record_access`.
    access_times: HashMap<RequestId, std::time::Instant>,
}
382
383impl LruEvictionPolicy {
384    pub fn new() -> Self {
385        Self {
386            access_times: HashMap::new(),
387        }
388    }
389}
390
391impl CacheEvictionPolicy for LruEvictionPolicy {
392    fn select_eviction_candidates(
393        &self,
394        required_memory: usize,
395        active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
396    ) -> Vec<RequestId> {
397        let mut candidates: Vec<_> = active_caches
398            .iter()
399            .map(|(req_id, handle)| {
400                let access_time = self
401                    .access_times
402                    .get(req_id)
403                    .copied()
404                    .unwrap_or_else(std::time::Instant::now);
405                (req_id.clone(), handle.stats().memory_bytes, access_time)
406            })
407            .collect();
408
409        // Sort by access time (oldest first)
410        candidates.sort_by(|a, b| a.2.cmp(&b.2));
411
412        let mut freed_memory = 0;
413        let mut result = Vec::new();
414
415        for (req_id, memory_bytes, _) in candidates {
416            result.push(req_id);
417            freed_memory += memory_bytes;
418            if freed_memory >= required_memory {
419                break;
420            }
421        }
422
423        result
424    }
425
426    fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant) {
427        self.access_times.insert(request_id, access_time);
428    }
429
430    fn name(&self) -> &str {
431        "lru"
432    }
433}
434
435impl Default for LruEvictionPolicy {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
/// Cache configuration
///
/// Top-level configuration consumed by cache manager implementations;
/// see [`CacheConfig::default`] for the conservative defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Block size in tokens
    pub block_size: usize,
    /// Maximum number of blocks
    pub max_blocks: usize,
    /// Initial number of blocks to allocate
    pub initial_blocks: usize,
    /// Enable memory pooling
    pub enable_pooling: bool,
    /// Target devices for allocation
    pub target_devices: Vec<Device>,
    /// Enable prefix caching
    pub enable_prefix_caching: bool,
    /// Prefix cache configuration
    /// (expected to be `Some` when `enable_prefix_caching` is true — confirm)
    pub prefix_cache_config: Option<PrefixCacheConfig>,
    /// Enable multi-device support
    pub enable_multi_device: bool,
    /// Memory pressure thresholds
    pub pressure_thresholds: MemoryPressureThresholds,
}
463
/// Memory pressure threshold configuration
///
/// Utilization fractions at which [`MemoryPressure`] escalates; expected
/// ordering is `medium < high < critical` (not enforced here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPressureThresholds {
    /// Medium pressure threshold (0.0-1.0)
    pub medium_threshold: f32,
    /// High pressure threshold (0.0-1.0)
    pub high_threshold: f32,
    /// Critical pressure threshold (0.0-1.0)
    pub critical_threshold: f32,
}
474
475impl Default for MemoryPressureThresholds {
476    fn default() -> Self {
477        Self {
478            medium_threshold: 0.6,
479            high_threshold: 0.8,
480            critical_threshold: 0.95,
481        }
482    }
483}
484
485impl Default for CacheConfig {
486    fn default() -> Self {
487        Self {
488            block_size: 16,
489            max_blocks: 1000,
490            initial_blocks: 100,
491            enable_pooling: true,
492            target_devices: vec![Device::CPU],
493            enable_prefix_caching: false,
494            prefix_cache_config: None,
495            enable_multi_device: false,
496            pressure_thresholds: MemoryPressureThresholds::default(),
497        }
498    }
499}