Skip to main content

ferrum_kv/blocks/
handle.rs

1//! KV Cache handle - MVP placeholder implementation
2
3use ferrum_interfaces::{kv_cache::CacheHandleStats, BlockTable, KvCacheHandle, TensorRef};
4use ferrum_types::{Device, RequestId, Result};
5use std::sync::Arc;
6
7/// Default KV cache handle - MVP implementation
8#[derive(Debug)]
9pub struct DefaultKvCacheHandle {
10    block_table: BlockTable,
11    device: Device,
12    num_layers: usize,
13    num_heads: usize,
14    head_dim: usize,
15    cache_id: String,
16}
17
18impl DefaultKvCacheHandle {
19    pub fn new(request_id: RequestId, block_size: usize, num_tokens: usize) -> Self {
20        let mut block_table = BlockTable::new(block_size);
21        block_table.sequence_length = num_tokens;
22
23        Self {
24            cache_id: format!("cache_{}", request_id),
25            block_table,
26            device: Device::CPU,
27            num_layers: 32, // Default placeholder
28            num_heads: 32,  // Default placeholder
29            head_dim: 128,  // Default placeholder
30        }
31    }
32
33    pub fn set_num_tokens(&mut self, num_tokens: usize) {
34        self.block_table.sequence_length = num_tokens;
35    }
36}
37
38impl KvCacheHandle for DefaultKvCacheHandle {
39    fn block_table(&self) -> &BlockTable {
40        &self.block_table
41    }
42
43    fn block_table_mut(&mut self) -> &mut BlockTable {
44        &mut self.block_table
45    }
46
47    fn as_any(&self) -> &dyn std::any::Any {
48        self
49    }
50
51    fn device(&self) -> Device {
52        self.device.clone()
53    }
54
55    fn num_layers(&self) -> usize {
56        self.num_layers
57    }
58
59    fn num_heads(&self) -> usize {
60        self.num_heads
61    }
62
63    fn head_dim(&self) -> usize {
64        self.head_dim
65    }
66
67    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
68        // MVP: return None (not yet implemented)
69        Ok(None)
70    }
71
72    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
73        // MVP: return None (not yet implemented)
74        Ok(None)
75    }
76
77    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
78        Err(ferrum_types::FerrumError::model(
79            "MVP: Handle cloning not yet implemented",
80        ))
81    }
82
83    fn stats(&self) -> CacheHandleStats {
84        CacheHandleStats {
85            memory_bytes: self.block_table.num_blocks() * self.block_table.block_size * 128,
86            blocks_allocated: self.block_table.num_blocks(),
87            tokens_stored: self.block_table.sequence_length,
88            utilization: if self.block_table.num_blocks() > 0 {
89                self.block_table.sequence_length as f32
90                    / (self.block_table.num_blocks() * self.block_table.block_size) as f32
91            } else {
92                0.0
93            },
94            last_access: std::time::Instant::now(),
95        }
96    }
97
98    fn is_valid(&self) -> bool {
99        true // MVP: always valid
100    }
101
102    fn cache_id(&self) -> String {
103        self.cache_id.clone()
104    }
105}
106
// ============================================================================
// Inline unit tests
// ============================================================================
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handle for a freshly generated request id.
    fn fresh_handle(block_size: usize, num_tokens: usize) -> DefaultKvCacheHandle {
        DefaultKvCacheHandle::new(RequestId::new(), block_size, num_tokens)
    }

    #[test]
    fn test_handle_creation() {
        let id = RequestId::new();
        let id_text = id.to_string();
        let handle = DefaultKvCacheHandle::new(id, 16, 10);

        assert!(handle.cache_id().contains(&id_text));
        let table = handle.block_table();
        assert_eq!(table.block_size, 16);
        assert_eq!(table.sequence_length, 10);
        assert!(handle.is_valid());
    }

    #[test]
    fn test_handle_set_num_tokens() {
        let mut handle = fresh_handle(16, 10);

        handle.set_num_tokens(50);

        assert_eq!(handle.block_table().sequence_length, 50);
    }

    #[test]
    fn test_handle_device() {
        let device = fresh_handle(16, 10).device();

        assert!(matches!(device, Device::CPU));
    }

    #[test]
    fn test_handle_dimensions() {
        let handle = fresh_handle(16, 10);

        // Placeholder defaults assigned by the constructor.
        assert_eq!(handle.num_layers(), 32);
        assert_eq!(handle.num_heads(), 32);
        assert_eq!(handle.head_dim(), 128);
    }

    #[test]
    fn test_handle_block_table() {
        let handle = fresh_handle(16, 10);

        let table = handle.block_table();

        assert_eq!(table.block_size, 16);
        assert_eq!(table.sequence_length, 10);
    }

    #[test]
    fn test_handle_block_table_mut() {
        let mut handle = fresh_handle(16, 10);

        handle.block_table_mut().sequence_length = 20;

        assert_eq!(handle.block_table().sequence_length, 20);
    }

    #[test]
    fn test_handle_stats() {
        let handle = fresh_handle(16, 10);

        let stats = handle.stats();

        assert_eq!(stats.tokens_stored, 10);
        assert_eq!(stats.blocks_allocated, handle.block_table().num_blocks());
        // No blocks are mapped in MVP handle construction, so memory usage is 0.
        assert_eq!(stats.memory_bytes, 0);
        assert!((0.0..=1.0).contains(&stats.utilization));
    }

    #[test]
    fn test_handle_cache_id() {
        let id = RequestId::new();
        let id_text = id.to_string();

        let cache_id = DefaultKvCacheHandle::new(id, 16, 10).cache_id();

        assert!(cache_id.contains("cache_"));
        assert!(cache_id.contains(&id_text));
    }

    #[test]
    fn test_handle_is_valid() {
        // MVP implementation always returns true.
        assert!(fresh_handle(16, 10).is_valid());
    }

    #[test]
    fn test_handle_key_cache() {
        let handle = fresh_handle(16, 10);

        // MVP: the key cache is not implemented, so the lookup succeeds with `None`.
        assert!(matches!(handle.key_cache(0), Ok(None)));
    }

    #[test]
    fn test_handle_value_cache() {
        let handle = fresh_handle(16, 10);

        // MVP: the value cache is not implemented, so the lookup succeeds with `None`.
        assert!(matches!(handle.value_cache(0), Ok(None)));
    }

    #[test]
    fn test_handle_clone_not_implemented() {
        let handle = fresh_handle(16, 10);

        // MVP: clone_handle is not yet implemented and must report an error.
        assert!(handle.clone_handle().is_err());
    }

    #[test]
    fn test_handle_as_any() {
        let handle = fresh_handle(16, 10);

        let downcast = handle.as_any().downcast_ref::<DefaultKvCacheHandle>();

        assert!(downcast.is_some());
    }

    #[test]
    fn test_handle_debug_format() {
        let rendered = format!("{:?}", fresh_handle(16, 10));

        assert!(rendered.contains("DefaultKvCacheHandle"));
    }

    #[test]
    fn test_handle_with_different_block_sizes() {
        // Both handles deliberately share one request id.
        let shared_id = RequestId::new();

        let handle_16 = DefaultKvCacheHandle::new(shared_id.clone(), 16, 10);
        let handle_32 = DefaultKvCacheHandle::new(shared_id, 32, 10);

        assert_eq!(handle_16.block_table().block_size, 16);
        assert_eq!(handle_32.block_table().block_size, 32);
    }

    #[test]
    fn test_handle_stats_utilization() {
        let stats = fresh_handle(16, 8).stats();

        // If any blocks exist, utilization should fall within a sensible range.
        if stats.blocks_allocated > 0 {
            assert!(stats.utilization >= 0.0);
            assert!(stats.utilization <= 1.0);
        }
    }

    #[test]
    fn test_handle_zero_tokens() {
        let handle = fresh_handle(16, 0);

        assert_eq!(handle.block_table().sequence_length, 0);
        assert_eq!(handle.stats().tokens_stored, 0);
    }
}