Skip to main content

ferrum_testkit/
kv_cache.rs

//! Mock KV cache manager for testing without GPU memory.
2
3use async_trait::async_trait;
4use ferrum_interfaces::{
5    kv_cache::{
6        AllocationRequest, CacheGcStats, CacheHandleStats, CacheManagerStats, MemoryPressure,
7    },
8    BlockTable, KvCacheHandle, KvCacheManager, TensorRef,
9};
10use ferrum_types::{Device, RequestId, Result};
11use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16/// Mock KV cache handle — tracks block metadata without allocating real memory.
17#[derive(Debug)]
18pub struct MockKvCacheHandle {
19    request_id: RequestId,
20    block_table: BlockTable,
21    num_layers: usize,
22    num_heads: usize,
23    head_dim: usize,
24    device: Device,
25}
26
27impl MockKvCacheHandle {
28    pub fn new(request_id: RequestId, num_layers: usize, seq_len: usize) -> Self {
29        let mut block_table = BlockTable::new(16);
30        block_table.sequence_length = seq_len;
31        // Add a physical block
32        let blocks_needed = BlockTable::blocks_needed_for_length(seq_len, 16);
33        let block_ids: Vec<u32> = (0..blocks_needed as u32).collect();
34        block_table.add_blocks(&block_ids);
35
36        Self {
37            request_id,
38            block_table,
39            num_layers,
40            num_heads: 12,
41            head_dim: 64,
42            device: Device::CPU,
43        }
44    }
45}
46
47impl KvCacheHandle for MockKvCacheHandle {
48    fn block_table(&self) -> &BlockTable {
49        &self.block_table
50    }
51
52    fn block_table_mut(&mut self) -> &mut BlockTable {
53        &mut self.block_table
54    }
55
56    fn as_any(&self) -> &dyn std::any::Any {
57        self
58    }
59
60    fn device(&self) -> Device {
61        self.device.clone()
62    }
63
64    fn num_layers(&self) -> usize {
65        self.num_layers
66    }
67
68    fn num_heads(&self) -> usize {
69        self.num_heads
70    }
71
72    fn head_dim(&self) -> usize {
73        self.head_dim
74    }
75
76    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
77        Ok(None)
78    }
79
80    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
81        Ok(None)
82    }
83
84    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
85        Ok(Arc::new(MockKvCacheHandle {
86            request_id: self.request_id.clone(),
87            block_table: self.block_table.clone(),
88            num_layers: self.num_layers,
89            num_heads: self.num_heads,
90            head_dim: self.head_dim,
91            device: self.device.clone(),
92        }))
93    }
94
95    fn stats(&self) -> CacheHandleStats {
96        CacheHandleStats {
97            memory_bytes: self.block_table.num_blocks()
98                * 16
99                * self.num_layers
100                * self.num_heads
101                * self.head_dim
102                * 2,
103            blocks_allocated: self.block_table.num_blocks(),
104            tokens_stored: self.block_table.sequence_length,
105            utilization: if self.block_table.num_blocks() > 0 {
106                self.block_table.sequence_length as f32
107                    / (self.block_table.num_blocks() * 16) as f32
108            } else {
109                0.0
110            },
111            last_access: std::time::Instant::now(),
112        }
113    }
114
115    fn is_valid(&self) -> bool {
116        true
117    }
118
119    fn cache_id(&self) -> String {
120        format!("mock_{}", self.request_id)
121    }
122}
123
124/// Mock KV cache manager — tracks allocations in memory, simulates block limits.
125pub struct MockKvCacheManager {
126    handles: RwLock<HashMap<RequestId, Arc<dyn KvCacheHandle>>>,
127    total_blocks: usize,
128    block_size: usize,
129    allocation_count: AtomicU64,
130    deallocation_count: AtomicU64,
131}
132
133impl MockKvCacheManager {
134    /// Create with a fixed total block budget.
135    pub fn new(total_blocks: usize) -> Self {
136        Self {
137            handles: RwLock::new(HashMap::new()),
138            total_blocks,
139            block_size: 16,
140            allocation_count: AtomicU64::new(0),
141            deallocation_count: AtomicU64::new(0),
142        }
143    }
144
145    pub fn active_count(&self) -> usize {
146        self.handles.read().len()
147    }
148}
149
150#[async_trait]
151impl KvCacheManager for MockKvCacheManager {
152    async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
153        let blocks_needed =
154            BlockTable::blocks_needed_for_length(request.initial_tokens, self.block_size);
155
156        // Check block budget
157        let used_blocks: usize = self
158            .handles
159            .read()
160            .values()
161            .map(|h| h.block_table().num_blocks())
162            .sum();
163
164        if used_blocks + blocks_needed > self.total_blocks {
165            return Err(ferrum_types::FerrumError::backend(format!(
166                "OOM: need {} blocks, have {} free out of {}",
167                blocks_needed,
168                self.total_blocks - used_blocks,
169                self.total_blocks
170            )));
171        }
172
173        let handle: Arc<dyn KvCacheHandle> = Arc::new(MockKvCacheHandle::new(
174            request.request_id.clone(),
175            request.num_layers,
176            request.initial_tokens,
177        ));
178
179        self.handles
180            .write()
181            .insert(request.request_id.clone(), handle.clone());
182        self.allocation_count.fetch_add(1, Ordering::Relaxed);
183
184        Ok(handle)
185    }
186
187    async fn extend(
188        &self,
189        _handle: &mut dyn KvCacheHandle,
190        _additional_tokens: usize,
191    ) -> Result<()> {
192        // Mock: no-op, real impl would allocate more blocks
193        Ok(())
194    }
195
196    async fn deallocate(&self, request_id: RequestId) -> Result<()> {
197        self.handles.write().remove(&request_id);
198        self.deallocation_count.fetch_add(1, Ordering::Relaxed);
199        Ok(())
200    }
201
202    fn can_allocate(&self, request: &AllocationRequest) -> bool {
203        let blocks_needed =
204            BlockTable::blocks_needed_for_length(request.initial_tokens, self.block_size);
205        let used_blocks: usize = self
206            .handles
207            .read()
208            .values()
209            .map(|h| h.block_table().num_blocks())
210            .sum();
211        used_blocks + blocks_needed <= self.total_blocks
212    }
213
214    fn stats(&self) -> CacheManagerStats {
215        let handles = self.handles.read();
216        let used_blocks: usize = handles.values().map(|h| h.block_table().num_blocks()).sum();
217        CacheManagerStats {
218            total_memory_bytes: self.total_blocks * self.block_size * 1024,
219            used_memory_bytes: used_blocks * self.block_size * 1024,
220            active_caches: handles.len(),
221            total_blocks: self.total_blocks,
222            free_blocks: self.total_blocks - used_blocks,
223            cache_hit_rate: 0.0,
224            eviction_count: 0,
225            allocation_count: self.allocation_count.load(Ordering::Relaxed),
226            allocation_failures: 0,
227        }
228    }
229
230    async fn gc(&self) -> Result<CacheGcStats> {
231        Ok(CacheGcStats {
232            memory_freed: 0,
233            caches_freed: 0,
234            gc_time_ms: 0,
235        })
236    }
237
238    fn set_pressure_callback(&self, _callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
239        // Mock: no-op
240    }
241
242    fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
243        self.handles.read().get(&request_id).cloned()
244    }
245
246    fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
247        self.handles
248            .read()
249            .iter()
250            .map(|(k, v)| (k.clone(), v.clone()))
251            .collect()
252    }
253}