1use crate::blocks::{BlockPool, DefaultKvCacheHandle};
4use async_trait::async_trait;
5use ferrum_interfaces::{
6 kv_cache::{AllocationRequest, CacheGcStats, CacheManagerStats, MemoryPressure},
7 KvCacheHandle, KvCacheManager,
8};
9use ferrum_types::{DataType, Device, FerrumError, RequestId, Result};
10use parking_lot::{Mutex, RwLock};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::debug;
14
15pub struct DefaultKvCacheManager {
17 device: Device,
18 block_size: usize,
19 max_blocks: usize,
20 _gpu_pool: Option<BlockPool>,
21 _cpu_pool: Option<BlockPool>,
22 active_handles: RwLock<HashMap<RequestId, Arc<dyn KvCacheHandle>>>,
23 stats: Mutex<CacheManagerStats>,
24 #[allow(clippy::type_complexity)]
25 pressure_callback: Mutex<Option<Box<dyn Fn(MemoryPressure) + Send + Sync>>>,
26}
27
28impl DefaultKvCacheManager {
29 pub fn new(device: Device, block_size: usize, max_blocks: usize) -> Result<Self> {
30 debug!(
31 "Creating KV cache manager: device={:?}, block_size={}, max_blocks={}",
32 device, block_size, max_blocks
33 );
34
35 let gpu_pool = if device.is_gpu() {
36 Some(BlockPool::new(
37 device.clone(),
38 block_size,
39 DataType::FP16,
40 max_blocks,
41 )?)
42 } else {
43 None
44 };
45
46 let cpu_pool = Some(BlockPool::new(
47 Device::CPU,
48 block_size,
49 DataType::FP16,
50 max_blocks / 2,
51 )?);
52
53 Ok(Self {
54 device,
55 block_size,
56 max_blocks,
57 _gpu_pool: gpu_pool,
58 _cpu_pool: cpu_pool,
59 active_handles: RwLock::new(HashMap::new()),
60 stats: Mutex::new(CacheManagerStats {
61 total_memory_bytes: 0,
62 used_memory_bytes: 0,
63 active_caches: 0,
64 total_blocks: max_blocks,
65 free_blocks: max_blocks,
66 cache_hit_rate: 0.0,
67 eviction_count: 0,
68 allocation_count: 0,
69 allocation_failures: 0,
70 }),
71 pressure_callback: Mutex::new(None),
72 })
73 }
74}
75
76#[async_trait]
77impl KvCacheManager for DefaultKvCacheManager {
78 async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
79 debug!("Allocating KV cache for request: {:?}", request.request_id);
80
81 let handle = DefaultKvCacheHandle::new(request.request_id.clone(), self.block_size, 0);
83
84 let handle_arc: Arc<dyn KvCacheHandle> = Arc::new(handle);
85
86 self.active_handles
87 .write()
88 .insert(request.request_id.clone(), handle_arc.clone());
89
90 {
92 let mut stats = self.stats.lock();
93 stats.active_caches += 1;
94 stats.allocation_count += 1;
95 }
96
97 Ok(handle_arc)
98 }
99
100 async fn extend(
101 &self,
102 _handle: &mut dyn KvCacheHandle,
103 _additional_tokens: usize,
104 ) -> Result<()> {
105 Err(FerrumError::model("MVP: extend not yet implemented"))
107 }
108
109 async fn deallocate(&self, request_id: RequestId) -> Result<()> {
110 debug!("Deallocating KV cache for request: {:?}", request_id);
111
112 self.active_handles.write().remove(&request_id);
113
114 {
116 let mut stats = self.stats.lock();
117 if stats.active_caches > 0 {
118 stats.active_caches -= 1;
119 }
120 }
121
122 Ok(())
123 }
124
125 fn can_allocate(&self, _request: &AllocationRequest) -> bool {
126 let active_count = self.active_handles.read().len();
128 active_count < self.max_blocks
129 }
130
131 fn stats(&self) -> CacheManagerStats {
132 self.stats.lock().clone()
133 }
134
135 async fn gc(&self) -> Result<CacheGcStats> {
136 Ok(CacheGcStats {
138 memory_freed: 0,
139 caches_freed: 0,
140 gc_time_ms: 0,
141 })
142 }
143
144 fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
145 *self.pressure_callback.lock() = Some(callback);
146 }
147
148 fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
149 self.active_handles.read().get(&request_id).cloned()
150 }
151
152 fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
153 self.active_handles
154 .read()
155 .iter()
156 .map(|(id, handle)| (id.clone(), handle.clone()))
157 .collect()
158 }
159}
160
161impl std::fmt::Debug for DefaultKvCacheManager {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("DefaultKvCacheManager")
164 .field("device", &self.device)
165 .field("block_size", &self.block_size)
166 .field("max_blocks", &self.max_blocks)
167 .field("active_handles_count", &self.active_handles.read().len())
168 .finish()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU/FP16 allocation request with a fresh request id.
    fn create_test_allocation_request() -> AllocationRequest {
        AllocationRequest {
            request_id: RequestId::new(),
            initial_tokens: 10,
            max_sequence_length: 100,
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100);
        assert!(manager.is_ok());
    }

    /// Allocation registers the handle and bumps the counters. Deallocation
    /// removes it again.
    #[tokio::test]
    async fn test_allocate_and_deallocate() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();
        let request = create_test_allocation_request();
        let request_id = request.request_id.clone();

        let handle = manager.allocate(&request).await.unwrap();
        assert!(handle.is_valid());
        assert_eq!(manager.stats().active_caches, 1);
        assert_eq!(manager.stats().allocation_count, 1);
        assert_eq!(manager.list_handles().len(), 1);

        let result = manager.deallocate(request_id.clone()).await;
        assert!(result.is_ok());
        assert!(manager.get_handle(request_id).is_none());
        assert_eq!(manager.stats().active_caches, 0);
        assert!(manager.list_handles().is_empty());
    }

    /// Deallocating an id that was never allocated succeeds and leaves stats at zero.
    #[tokio::test]
    async fn test_deallocate_unknown_request_is_ok() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();

        assert!(manager.deallocate(RequestId::new()).await.is_ok());
        assert_eq!(manager.stats().active_caches, 0);
    }

    #[tokio::test]
    async fn test_can_allocate() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 10).unwrap();
        let request = create_test_allocation_request();

        assert!(manager.can_allocate(&request));
    }

    #[tokio::test]
    async fn test_stats() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();
        let stats = manager.stats();

        assert_eq!(stats.active_caches, 0);
        assert_eq!(stats.total_blocks, 100);
        assert_eq!(stats.free_blocks, 100);
        assert_eq!(stats.allocation_count, 0);
    }

    #[tokio::test]
    async fn test_get_handle() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();
        let request = create_test_allocation_request();
        let request_id = request.request_id.clone();

        manager.allocate(&request).await.unwrap();

        let handle = manager.get_handle(request_id);
        assert!(handle.is_some());
    }
}