1use async_trait::async_trait;
4use ferrum_interfaces::{
5 kv_cache::{
6 AllocationRequest, CacheGcStats, CacheHandleStats, CacheManagerStats, MemoryPressure,
7 },
8 BlockTable, KvCacheHandle, KvCacheManager, TensorRef,
9};
10use ferrum_types::{Device, RequestId, Result};
11use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16#[derive(Debug)]
18pub struct MockKvCacheHandle {
19 request_id: RequestId,
20 block_table: BlockTable,
21 num_layers: usize,
22 num_heads: usize,
23 head_dim: usize,
24 device: Device,
25}
26
27impl MockKvCacheHandle {
28 pub fn new(request_id: RequestId, num_layers: usize, seq_len: usize) -> Self {
29 let mut block_table = BlockTable::new(16);
30 block_table.sequence_length = seq_len;
31 let blocks_needed = BlockTable::blocks_needed_for_length(seq_len, 16);
33 let block_ids: Vec<u32> = (0..blocks_needed as u32).collect();
34 block_table.add_blocks(&block_ids);
35
36 Self {
37 request_id,
38 block_table,
39 num_layers,
40 num_heads: 12,
41 head_dim: 64,
42 device: Device::CPU,
43 }
44 }
45}
46
47impl KvCacheHandle for MockKvCacheHandle {
48 fn block_table(&self) -> &BlockTable {
49 &self.block_table
50 }
51
52 fn block_table_mut(&mut self) -> &mut BlockTable {
53 &mut self.block_table
54 }
55
56 fn as_any(&self) -> &dyn std::any::Any {
57 self
58 }
59
60 fn device(&self) -> Device {
61 self.device.clone()
62 }
63
64 fn num_layers(&self) -> usize {
65 self.num_layers
66 }
67
68 fn num_heads(&self) -> usize {
69 self.num_heads
70 }
71
72 fn head_dim(&self) -> usize {
73 self.head_dim
74 }
75
76 fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
77 Ok(None)
78 }
79
80 fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
81 Ok(None)
82 }
83
84 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
85 Ok(Arc::new(MockKvCacheHandle {
86 request_id: self.request_id.clone(),
87 block_table: self.block_table.clone(),
88 num_layers: self.num_layers,
89 num_heads: self.num_heads,
90 head_dim: self.head_dim,
91 device: self.device.clone(),
92 }))
93 }
94
95 fn stats(&self) -> CacheHandleStats {
96 CacheHandleStats {
97 memory_bytes: self.block_table.num_blocks()
98 * 16
99 * self.num_layers
100 * self.num_heads
101 * self.head_dim
102 * 2,
103 blocks_allocated: self.block_table.num_blocks(),
104 tokens_stored: self.block_table.sequence_length,
105 utilization: if self.block_table.num_blocks() > 0 {
106 self.block_table.sequence_length as f32
107 / (self.block_table.num_blocks() * 16) as f32
108 } else {
109 0.0
110 },
111 last_access: std::time::Instant::now(),
112 }
113 }
114
115 fn is_valid(&self) -> bool {
116 true
117 }
118
119 fn cache_id(&self) -> String {
120 format!("mock_{}", self.request_id)
121 }
122}
123
124pub struct MockKvCacheManager {
126 handles: RwLock<HashMap<RequestId, Arc<dyn KvCacheHandle>>>,
127 total_blocks: usize,
128 block_size: usize,
129 allocation_count: AtomicU64,
130 deallocation_count: AtomicU64,
131}
132
133impl MockKvCacheManager {
134 pub fn new(total_blocks: usize) -> Self {
136 Self {
137 handles: RwLock::new(HashMap::new()),
138 total_blocks,
139 block_size: 16,
140 allocation_count: AtomicU64::new(0),
141 deallocation_count: AtomicU64::new(0),
142 }
143 }
144
145 pub fn active_count(&self) -> usize {
146 self.handles.read().len()
147 }
148}
149
150#[async_trait]
151impl KvCacheManager for MockKvCacheManager {
152 async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
153 let blocks_needed =
154 BlockTable::blocks_needed_for_length(request.initial_tokens, self.block_size);
155
156 let used_blocks: usize = self
158 .handles
159 .read()
160 .values()
161 .map(|h| h.block_table().num_blocks())
162 .sum();
163
164 if used_blocks + blocks_needed > self.total_blocks {
165 return Err(ferrum_types::FerrumError::backend(format!(
166 "OOM: need {} blocks, have {} free out of {}",
167 blocks_needed,
168 self.total_blocks - used_blocks,
169 self.total_blocks
170 )));
171 }
172
173 let handle: Arc<dyn KvCacheHandle> = Arc::new(MockKvCacheHandle::new(
174 request.request_id.clone(),
175 request.num_layers,
176 request.initial_tokens,
177 ));
178
179 self.handles
180 .write()
181 .insert(request.request_id.clone(), handle.clone());
182 self.allocation_count.fetch_add(1, Ordering::Relaxed);
183
184 Ok(handle)
185 }
186
187 async fn extend(
188 &self,
189 _handle: &mut dyn KvCacheHandle,
190 _additional_tokens: usize,
191 ) -> Result<()> {
192 Ok(())
194 }
195
196 async fn deallocate(&self, request_id: RequestId) -> Result<()> {
197 self.handles.write().remove(&request_id);
198 self.deallocation_count.fetch_add(1, Ordering::Relaxed);
199 Ok(())
200 }
201
202 fn can_allocate(&self, request: &AllocationRequest) -> bool {
203 let blocks_needed =
204 BlockTable::blocks_needed_for_length(request.initial_tokens, self.block_size);
205 let used_blocks: usize = self
206 .handles
207 .read()
208 .values()
209 .map(|h| h.block_table().num_blocks())
210 .sum();
211 used_blocks + blocks_needed <= self.total_blocks
212 }
213
214 fn stats(&self) -> CacheManagerStats {
215 let handles = self.handles.read();
216 let used_blocks: usize = handles.values().map(|h| h.block_table().num_blocks()).sum();
217 CacheManagerStats {
218 total_memory_bytes: self.total_blocks * self.block_size * 1024,
219 used_memory_bytes: used_blocks * self.block_size * 1024,
220 active_caches: handles.len(),
221 total_blocks: self.total_blocks,
222 free_blocks: self.total_blocks - used_blocks,
223 cache_hit_rate: 0.0,
224 eviction_count: 0,
225 allocation_count: self.allocation_count.load(Ordering::Relaxed),
226 allocation_failures: 0,
227 }
228 }
229
230 async fn gc(&self) -> Result<CacheGcStats> {
231 Ok(CacheGcStats {
232 memory_freed: 0,
233 caches_freed: 0,
234 gc_time_ms: 0,
235 })
236 }
237
238 fn set_pressure_callback(&self, _callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
239 }
241
242 fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
243 self.handles.read().get(&request_id).cloned()
244 }
245
246 fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
247 self.handles
248 .read()
249 .iter()
250 .map(|(k, v)| (k.clone(), v.clone()))
251 .collect()
252 }
253}