Skip to main content

ferrum_interfaces/
kv_cache.rs

//! KV-cache abstraction with handle semantics and block management.
//!
//! This module provides a handle-based abstraction (one handle per request/sequence) for KV
//! cache management, supporting both contiguous and paged attention patterns with zero-copy operations.
5
6use crate::TensorRef;
7use ferrum_types::{BlockId, Device, RequestId, Result};
8use serde::{Deserialize, Serialize};
9use smallvec::SmallVec;
10use std::{collections::HashMap, sync::Arc};
11
12/// Block table for mapping logical to physical cache blocks
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct BlockTable {
15    /// Physical block IDs allocated for this sequence
16    pub physical_blocks: SmallVec<[BlockId; 8]>,
17    /// Mapping from logical to physical block indices
18    pub logical_to_physical: SmallVec<[u32; 8]>,
19    /// Cache-fill bookkeeping owned by the `KvCacheManager` that allocated
20    /// this handle. **Not authoritative for engine-side position tracking**:
21    /// only updated by the manager that minted the handle (Paged sets it
22    /// at allocate-time and never bumps it; Default ignores `initial_tokens`
23    /// and leaves it at 0; `GenericKvCacheHandle` is the only one updated
24    /// per decode step, via `with_sequence_length`). The engine sources the
25    /// "where is the next token written" answer from `SequenceState`
26    /// (`input_tokens.len() + generated_tokens.len() - 1`) — see
27    /// `ContinuousBatchEngine::process_batch_unified` / `run_batch_decode`.
28    pub sequence_length: usize,
29    /// Block size (tokens per block)
30    pub block_size: usize,
31}
32
33impl BlockTable {
34    /// Create new block table
35    pub fn new(block_size: usize) -> Self {
36        Self {
37            physical_blocks: SmallVec::new(),
38            logical_to_physical: SmallVec::new(),
39            sequence_length: 0,
40            block_size,
41        }
42    }
43
44    /// Get number of blocks allocated
45    pub fn num_blocks(&self) -> usize {
46        self.physical_blocks.len()
47    }
48
49    /// Get required number of blocks for sequence length
50    pub fn blocks_needed_for_length(length: usize, block_size: usize) -> usize {
51        length.div_ceil(block_size) // Ceiling division
52    }
53
54    /// Check if can accommodate more tokens without new blocks
55    pub fn has_free_space(&self) -> bool {
56        let used_blocks = Self::blocks_needed_for_length(self.sequence_length, self.block_size);
57        used_blocks < self.num_blocks()
58    }
59
60    /// Get number of free tokens in allocated blocks
61    pub fn free_tokens(&self) -> usize {
62        if self.num_blocks() == 0 {
63            0
64        } else {
65            self.num_blocks() * self.block_size - self.sequence_length
66        }
67    }
68
69    /// Add blocks to the table
70    pub fn add_blocks(&mut self, blocks: &[BlockId]) {
71        let start_logical = self.logical_to_physical.len();
72
73        for (i, &block) in blocks.iter().enumerate() {
74            self.physical_blocks.push(block);
75            self.logical_to_physical.push((start_logical + i) as u32);
76        }
77    }
78
79    /// Extend sequence length
80    pub fn extend_sequence(&mut self, additional_tokens: usize) -> Result<()> {
81        let new_length = self.sequence_length + additional_tokens;
82        let required_blocks = Self::blocks_needed_for_length(new_length, self.block_size);
83
84        if required_blocks > self.num_blocks() {
85            return Err(ferrum_types::FerrumError::backend(format!(
86                "Insufficient blocks: need {}, have {}",
87                required_blocks,
88                self.num_blocks()
89            )));
90        }
91
92        self.sequence_length = new_length;
93        Ok(())
94    }
95}
96
97/// KV cache handle providing access to cached key-value states
98pub trait KvCacheHandle: Send + Sync + std::fmt::Debug {
99    /// Get block table for this cache
100    fn block_table(&self) -> &BlockTable;
101
102    /// Get mutable block table (for extending)
103    fn block_table_mut(&mut self) -> &mut BlockTable;
104
105    /// Downcast support for backend-specific handles
106    fn as_any(&self) -> &dyn std::any::Any;
107
108    /// Get device where cache resides
109    fn device(&self) -> Device;
110
111    /// Get number of tokens stored in cache
112    fn num_tokens(&self) -> usize {
113        self.block_table().sequence_length
114    }
115
116    /// Get number of layers cached
117    fn num_layers(&self) -> usize;
118
119    /// Get number of attention heads
120    fn num_heads(&self) -> usize;
121
122    /// Get head dimension
123    fn head_dim(&self) -> usize;
124
125    /// Get key cache for specific layer (returns tensor reference)
126    fn key_cache(&self, layer: usize) -> Result<Option<TensorRef>>;
127
128    /// Get value cache for specific layer
129    fn value_cache(&self, layer: usize) -> Result<Option<TensorRef>>;
130
131    /// Get both key and value caches for layer
132    fn kv_cache(&self, layer: usize) -> Result<(Option<TensorRef>, Option<TensorRef>)> {
133        Ok((self.key_cache(layer)?, self.value_cache(layer)?))
134    }
135
136    /// Clone handle (creates new reference, not deep copy)
137    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>>;
138
139    /// Get cache statistics
140    fn stats(&self) -> CacheHandleStats;
141
142    /// Check if cache is valid and accessible
143    fn is_valid(&self) -> bool;
144
145    /// Get unique identifier for this cache instance
146    fn cache_id(&self) -> String;
147}
148
/// Usage statistics reported by a single [`KvCacheHandle`].
#[derive(Debug, Clone)]
pub struct CacheHandleStats {
    /// Memory used by this handle, in bytes.
    pub memory_bytes: usize,
    /// Count of blocks allocated to this handle.
    pub blocks_allocated: usize,
    /// Count of tokens stored.
    pub tokens_stored: usize,
    /// Fraction of the allocated memory that is in use.
    pub utilization: f32,
    /// When the handle was last accessed (used for LRU decisions).
    pub last_access: std::time::Instant,
}
163
164/// KV cache allocation request
165#[derive(Debug, Clone)]
166pub struct AllocationRequest {
167    /// Request ID this allocation is for
168    pub request_id: RequestId,
169    /// Initial number of tokens
170    pub initial_tokens: usize,
171    /// Maximum expected sequence length
172    pub max_sequence_length: usize,
173    /// Number of layers to cache
174    pub num_layers: usize,
175    /// Number of attention heads
176    pub num_heads: usize,
177    /// Head dimension
178    pub head_dim: usize,
179    /// Target device
180    pub device: Device,
181    /// Data type for cache
182    pub dtype: ferrum_types::DataType,
183    /// Priority level for allocation
184    pub priority: ferrum_types::Priority,
185}
186
187impl AllocationRequest {
188    /// Calculate estimated memory requirement
189    pub fn estimated_memory_bytes(&self) -> usize {
190        // Key + Value cache size: layers * heads * max_seq * head_dim * 2 * dtype_size
191        let kv_size =
192            self.num_layers * self.num_heads * self.max_sequence_length * self.head_dim * 2;
193        kv_size * self.dtype.size_bytes()
194    }
195}
196
197/// KV cache manager for allocation and lifecycle management
198#[async_trait::async_trait]
199pub trait KvCacheManager: Send + Sync {
200    /// Allocate cache for new sequence
201    async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>>;
202
203    /// Extend existing cache to accommodate more tokens
204    async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()>;
205
206    /// Deallocate cache (handle becomes invalid)
207    async fn deallocate(&self, request_id: RequestId) -> Result<()>;
208
209    /// Check if can allocate requested cache size
210    fn can_allocate(&self, request: &AllocationRequest) -> bool;
211
212    /// Get cache statistics
213    fn stats(&self) -> CacheManagerStats;
214
215    /// Force garbage collection of unused caches
216    async fn gc(&self) -> Result<CacheGcStats>;
217
218    /// Set memory pressure callback
219    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);
220
221    /// Get handle for existing request (if exists)
222    fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>>;
223
224    /// List all active cache handles
225    fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)>;
226}
227
228/// Cache manager statistics
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct CacheManagerStats {
231    /// Total memory allocated in bytes
232    pub total_memory_bytes: usize,
233    /// Memory currently in use
234    pub used_memory_bytes: usize,
235    /// Number of active caches
236    pub active_caches: usize,
237    /// Total blocks allocated
238    pub total_blocks: usize,
239    /// Free blocks available
240    pub free_blocks: usize,
241    /// Cache hit rate (for prefix caching)
242    pub cache_hit_rate: f32,
243    /// Number of evictions performed
244    pub eviction_count: u64,
245    /// Number of successful allocations
246    pub allocation_count: u64,
247    /// Number of failed allocations
248    pub allocation_failures: u64,
249}
250
251/// Garbage collection statistics
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct CacheGcStats {
254    /// Memory freed in bytes
255    pub memory_freed: usize,
256    /// Number of caches garbage collected
257    pub caches_freed: usize,
258    /// Time taken for GC
259    pub gc_time_ms: u64,
260}
261
262/// Memory pressure levels for adaptive management
263#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
264pub enum MemoryPressure {
265    /// Low memory usage, allocations can proceed freely
266    Low,
267    /// Moderate usage, start being more conservative
268    Medium,
269    /// High usage, consider eviction
270    High,
271    /// Critical usage, must evict or reject allocations
272    Critical,
273}
274
275/// Advanced KV cache capabilities
276pub trait AdvancedKvCacheManager: KvCacheManager {
277    /// Enable prefix caching for common prompt prefixes
278    async fn enable_prefix_caching(&self, config: PrefixCacheConfig) -> Result<()>;
279
280    /// Share cache blocks between compatible sequences
281    async fn share_prefix(
282        &self,
283        source: RequestId,
284        target: RequestId,
285        shared_tokens: usize,
286    ) -> Result<()>;
287
288    /// Swap cache from GPU to CPU to free GPU memory
289    async fn swap_out(&self, request_id: RequestId) -> Result<()>;
290
291    /// Swap cache from CPU back to GPU
292    async fn swap_in(&self, request_id: RequestId) -> Result<()>;
293
294    /// Compress cache to reduce memory usage
295    async fn compress_cache(&self, request_id: RequestId, compression_ratio: f32) -> Result<()>;
296
297    /// Get cache compression statistics
298    fn compression_stats(&self) -> CompressionStats;
299}
300
301/// Prefix caching configuration
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct PrefixCacheConfig {
304    /// Maximum number of prefixes to cache
305    pub max_prefixes: usize,
306    /// Minimum prefix length to be eligible for caching
307    pub min_prefix_length: usize,
308    /// TTL for cached prefixes
309    pub prefix_ttl_seconds: u64,
310    /// Enable cross-request prefix sharing
311    pub enable_cross_request_sharing: bool,
312}
313
314/// Cache compression statistics
315#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct CompressionStats {
317    /// Number of compressed caches
318    pub compressed_caches: usize,
319    /// Total memory saved by compression
320    pub memory_saved_bytes: usize,
321    /// Average compression ratio achieved
322    pub avg_compression_ratio: f32,
323    /// Compression/decompression time overhead
324    pub avg_compression_time_ms: f64,
325}
326
327/// Block-based cache allocator
328pub trait BlockAllocator: Send + Sync {
329    /// Allocate specified number of blocks
330    fn allocate_blocks(&self, num_blocks: usize) -> Result<Vec<BlockId>>;
331
332    /// Free blocks back to allocator
333    fn free_blocks(&self, blocks: &[BlockId]) -> Result<()>;
334
335    /// Get number of free blocks
336    fn free_block_count(&self) -> usize;
337
338    /// Get total block count
339    fn total_block_count(&self) -> usize;
340
341    /// Get block size in tokens
342    fn block_size(&self) -> usize;
343
344    /// Defragment free block list
345    fn defragment(&self) -> Result<()>;
346}
347
348/// Multi-device cache manager supporting GPU/CPU hierarchies
349#[async_trait::async_trait]
350pub trait MultiDeviceCacheManager: KvCacheManager {
351    /// Get supported devices
352    fn supported_devices(&self) -> Vec<Device>;
353
354    /// Set device preference for new allocations
355    fn set_device_preference(&self, devices: Vec<Device>);
356
357    /// Move cache between devices
358    async fn move_cache(&self, request_id: RequestId, target_device: Device) -> Result<()>;
359
360    /// Get cache location
361    fn get_cache_device(&self, request_id: RequestId) -> Option<Device>;
362
363    /// Balance cache distribution across devices
364    async fn rebalance_devices(&self) -> Result<()>;
365
366    /// Get per-device statistics
367    fn device_stats(&self) -> HashMap<Device, CacheManagerStats>;
368}
369
370/// Cache eviction strategies
371pub trait CacheEvictionPolicy: Send + Sync {
372    /// Select caches to evict to free requested memory
373    fn select_eviction_candidates(
374        &self,
375        required_memory: usize,
376        active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
377    ) -> Vec<RequestId>;
378
379    /// Update cache access information
380    fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant);
381
382    /// Get policy name
383    fn name(&self) -> &str;
384}
385
386/// Least Recently Used eviction policy
387pub struct LruEvictionPolicy {
388    access_times: HashMap<RequestId, std::time::Instant>,
389}
390
391impl LruEvictionPolicy {
392    pub fn new() -> Self {
393        Self {
394            access_times: HashMap::new(),
395        }
396    }
397}
398
399impl CacheEvictionPolicy for LruEvictionPolicy {
400    fn select_eviction_candidates(
401        &self,
402        required_memory: usize,
403        active_caches: &[(RequestId, Arc<dyn KvCacheHandle>)],
404    ) -> Vec<RequestId> {
405        let mut candidates: Vec<_> = active_caches
406            .iter()
407            .map(|(req_id, handle)| {
408                let access_time = self
409                    .access_times
410                    .get(req_id)
411                    .copied()
412                    .unwrap_or_else(std::time::Instant::now);
413                (req_id.clone(), handle.stats().memory_bytes, access_time)
414            })
415            .collect();
416
417        // Sort by access time (oldest first)
418        candidates.sort_by(|a, b| a.2.cmp(&b.2));
419
420        let mut freed_memory = 0;
421        let mut result = Vec::new();
422
423        for (req_id, memory_bytes, _) in candidates {
424            result.push(req_id);
425            freed_memory += memory_bytes;
426            if freed_memory >= required_memory {
427                break;
428            }
429        }
430
431        result
432    }
433
434    fn record_access(&mut self, request_id: RequestId, access_time: std::time::Instant) {
435        self.access_times.insert(request_id, access_time);
436    }
437
438    fn name(&self) -> &str {
439        "lru"
440    }
441}
442
443impl Default for LruEvictionPolicy {
444    fn default() -> Self {
445        Self::new()
446    }
447}
448
449/// Cache configuration
450#[derive(Debug, Clone, Serialize, Deserialize)]
451pub struct CacheConfig {
452    /// Block size in tokens
453    pub block_size: usize,
454    /// Maximum number of blocks
455    pub max_blocks: usize,
456    /// Initial number of blocks to allocate
457    pub initial_blocks: usize,
458    /// Enable memory pooling
459    pub enable_pooling: bool,
460    /// Target devices for allocation
461    pub target_devices: Vec<Device>,
462    /// Enable prefix caching
463    pub enable_prefix_caching: bool,
464    /// Prefix cache configuration
465    pub prefix_cache_config: Option<PrefixCacheConfig>,
466    /// Enable multi-device support
467    pub enable_multi_device: bool,
468    /// Memory pressure thresholds
469    pub pressure_thresholds: MemoryPressureThresholds,
470}
471
472/// Memory pressure threshold configuration
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct MemoryPressureThresholds {
475    /// Medium pressure threshold (0.0-1.0)
476    pub medium_threshold: f32,
477    /// High pressure threshold (0.0-1.0)
478    pub high_threshold: f32,
479    /// Critical pressure threshold (0.0-1.0)
480    pub critical_threshold: f32,
481}
482
483impl Default for MemoryPressureThresholds {
484    fn default() -> Self {
485        Self {
486            medium_threshold: 0.6,
487            high_threshold: 0.8,
488            critical_threshold: 0.95,
489        }
490    }
491}
492
493impl Default for CacheConfig {
494    fn default() -> Self {
495        Self {
496            block_size: 16,
497            max_blocks: 1000,
498            initial_blocks: 100,
499            enable_pooling: true,
500            target_devices: vec![Device::CPU],
501            enable_prefix_caching: false,
502            prefix_cache_config: None,
503            enable_multi_device: false,
504            pressure_thresholds: MemoryPressureThresholds::default(),
505        }
506    }
507}