Skip to main content

ferrum_kv/managers/
default.rs

1//! Default KV cache manager - MVP placeholder implementation
2
3use crate::blocks::{BlockPool, DefaultKvCacheHandle};
4use async_trait::async_trait;
5use ferrum_interfaces::{
6    kv_cache::{AllocationRequest, CacheGcStats, CacheManagerStats, MemoryPressure},
7    KvCacheHandle, KvCacheManager,
8};
9use ferrum_types::{DataType, Device, FerrumError, RequestId, Result};
10use parking_lot::{Mutex, RwLock};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::debug;
14
15/// Default KV cache manager - MVP implementation
16pub struct DefaultKvCacheManager {
17    device: Device,
18    block_size: usize,
19    max_blocks: usize,
20    _gpu_pool: Option<BlockPool>,
21    _cpu_pool: Option<BlockPool>,
22    active_handles: RwLock<HashMap<RequestId, Arc<dyn KvCacheHandle>>>,
23    stats: Mutex<CacheManagerStats>,
24    #[allow(clippy::type_complexity)]
25    pressure_callback: Mutex<Option<Box<dyn Fn(MemoryPressure) + Send + Sync>>>,
26}
27
28impl DefaultKvCacheManager {
29    pub fn new(device: Device, block_size: usize, max_blocks: usize) -> Result<Self> {
30        debug!(
31            "Creating KV cache manager: device={:?}, block_size={}, max_blocks={}",
32            device, block_size, max_blocks
33        );
34
35        let gpu_pool = if device.is_gpu() {
36            Some(BlockPool::new(
37                device.clone(),
38                block_size,
39                DataType::FP16,
40                max_blocks,
41            )?)
42        } else {
43            None
44        };
45
46        let cpu_pool = Some(BlockPool::new(
47            Device::CPU,
48            block_size,
49            DataType::FP16,
50            max_blocks / 2,
51        )?);
52
53        Ok(Self {
54            device,
55            block_size,
56            max_blocks,
57            _gpu_pool: gpu_pool,
58            _cpu_pool: cpu_pool,
59            active_handles: RwLock::new(HashMap::new()),
60            stats: Mutex::new(CacheManagerStats {
61                total_memory_bytes: 0,
62                used_memory_bytes: 0,
63                active_caches: 0,
64                total_blocks: max_blocks,
65                free_blocks: max_blocks,
66                cache_hit_rate: 0.0,
67                eviction_count: 0,
68                allocation_count: 0,
69                allocation_failures: 0,
70            }),
71            pressure_callback: Mutex::new(None),
72        })
73    }
74}
75
76#[async_trait]
77impl KvCacheManager for DefaultKvCacheManager {
78    async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
79        debug!("Allocating KV cache for request: {:?}", request.request_id);
80
81        // MVP: Create a simple handle
82        let handle = DefaultKvCacheHandle::new(request.request_id.clone(), self.block_size, 0);
83
84        let handle_arc: Arc<dyn KvCacheHandle> = Arc::new(handle);
85
86        self.active_handles
87            .write()
88            .insert(request.request_id.clone(), handle_arc.clone());
89
90        // Update stats
91        {
92            let mut stats = self.stats.lock();
93            stats.active_caches += 1;
94            stats.allocation_count += 1;
95        }
96
97        Ok(handle_arc)
98    }
99
100    async fn extend(
101        &self,
102        _handle: &mut dyn KvCacheHandle,
103        _additional_tokens: usize,
104    ) -> Result<()> {
105        // MVP: Not yet implemented
106        Err(FerrumError::model("MVP: extend not yet implemented"))
107    }
108
109    async fn deallocate(&self, request_id: RequestId) -> Result<()> {
110        debug!("Deallocating KV cache for request: {:?}", request_id);
111
112        self.active_handles.write().remove(&request_id);
113
114        // Update stats
115        {
116            let mut stats = self.stats.lock();
117            if stats.active_caches > 0 {
118                stats.active_caches -= 1;
119            }
120        }
121
122        Ok(())
123    }
124
125    fn can_allocate(&self, _request: &AllocationRequest) -> bool {
126        // MVP: always allow allocation
127        let active_count = self.active_handles.read().len();
128        active_count < self.max_blocks
129    }
130
131    fn stats(&self) -> CacheManagerStats {
132        self.stats.lock().clone()
133    }
134
135    async fn gc(&self) -> Result<CacheGcStats> {
136        // MVP: No garbage collection
137        Ok(CacheGcStats {
138            memory_freed: 0,
139            caches_freed: 0,
140            gc_time_ms: 0,
141        })
142    }
143
144    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
145        *self.pressure_callback.lock() = Some(callback);
146    }
147
148    fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
149        self.active_handles.read().get(&request_id).cloned()
150    }
151
152    fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
153        self.active_handles
154            .read()
155            .iter()
156            .map(|(id, handle)| (id.clone(), handle.clone()))
157            .collect()
158    }
159}
160
161impl std::fmt::Debug for DefaultKvCacheManager {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("DefaultKvCacheManager")
164            .field("device", &self.device)
165            .field("block_size", &self.block_size)
166            .field("max_blocks", &self.max_blocks)
167            .field("active_handles_count", &self.active_handles.read().len())
168            .finish()
169    }
170}
171
172// ============================================================================
173// Unit Tests
174// ============================================================================
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU/FP16 allocation request with a fresh request id.
    fn create_test_allocation_request() -> AllocationRequest {
        AllocationRequest {
            request_id: RequestId::new(),
            initial_tokens: 10,
            max_sequence_length: 100,
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100);
        assert!(manager.is_ok());
    }

    #[tokio::test]
    async fn test_allocate_and_deallocate() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();
        let request = create_test_allocation_request();
        let request_id = request.request_id.clone();

        let handle = manager.allocate(&request).await.unwrap();
        assert!(handle.is_valid());

        // Allocation must register the handle and move the counters.
        assert!(manager.get_handle(request_id.clone()).is_some());
        let after_allocate = manager.stats();
        assert_eq!(after_allocate.active_caches, 1);
        assert_eq!(after_allocate.allocation_count, 1);

        let result = manager.deallocate(request_id.clone()).await;
        assert!(result.is_ok());

        // Deallocation must unregister the handle and release the slot.
        assert!(manager.get_handle(request_id).is_none());
        assert_eq!(manager.stats().active_caches, 0);
    }

    #[tokio::test]
    async fn test_can_allocate() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 10).unwrap();
        let request = create_test_allocation_request();

        assert!(manager.can_allocate(&request));
    }

    #[tokio::test]
    async fn test_stats() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();
        let stats = manager.stats();

        assert_eq!(stats.active_caches, 0);
        assert_eq!(stats.total_blocks, 100);
        assert_eq!(stats.free_blocks, 100);
    }

    #[tokio::test]
    async fn test_get_handle() {
        let manager = DefaultKvCacheManager::new(Device::CPU, 16, 100).unwrap();
        let request = create_test_allocation_request();
        let request_id = request.request_id.clone();

        // Nothing is registered before the first allocation.
        assert!(manager.get_handle(request_id.clone()).is_none());

        manager.allocate(&request).await.unwrap();

        let handle = manager.get_handle(request_id);
        assert!(handle.is_some());
    }
}