// ferrum_kv/blocks/handle.rs
use ferrum_interfaces::{kv_cache::CacheHandleStats, BlockTable, KvCacheHandle, TensorRef};
use ferrum_types::{Device, RequestId, Result};
use std::sync::Arc;
6
7#[derive(Debug)]
9pub struct DefaultKvCacheHandle {
10 block_table: BlockTable,
11 device: Device,
12 num_layers: usize,
13 num_heads: usize,
14 head_dim: usize,
15 cache_id: String,
16}
17
18impl DefaultKvCacheHandle {
19 pub fn new(request_id: RequestId, block_size: usize, num_tokens: usize) -> Self {
20 let mut block_table = BlockTable::new(block_size);
21 block_table.sequence_length = num_tokens;
22
23 Self {
24 cache_id: format!("cache_{}", request_id),
25 block_table,
26 device: Device::CPU,
27 num_layers: 32, num_heads: 32, head_dim: 128, }
31 }
32
33 pub fn set_num_tokens(&mut self, num_tokens: usize) {
34 self.block_table.sequence_length = num_tokens;
35 }
36}
37
38impl KvCacheHandle for DefaultKvCacheHandle {
39 fn block_table(&self) -> &BlockTable {
40 &self.block_table
41 }
42
43 fn block_table_mut(&mut self) -> &mut BlockTable {
44 &mut self.block_table
45 }
46
47 fn as_any(&self) -> &dyn std::any::Any {
48 self
49 }
50
51 fn device(&self) -> Device {
52 self.device.clone()
53 }
54
55 fn num_layers(&self) -> usize {
56 self.num_layers
57 }
58
59 fn num_heads(&self) -> usize {
60 self.num_heads
61 }
62
63 fn head_dim(&self) -> usize {
64 self.head_dim
65 }
66
67 fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
68 Ok(None)
70 }
71
72 fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
73 Ok(None)
75 }
76
77 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
78 Err(ferrum_types::FerrumError::model(
79 "MVP: Handle cloning not yet implemented",
80 ))
81 }
82
83 fn stats(&self) -> CacheHandleStats {
84 CacheHandleStats {
85 memory_bytes: self.block_table.num_blocks() * self.block_table.block_size * 128,
86 blocks_allocated: self.block_table.num_blocks(),
87 tokens_stored: self.block_table.sequence_length,
88 utilization: if self.block_table.num_blocks() > 0 {
89 self.block_table.sequence_length as f32
90 / (self.block_table.num_blocks() * self.block_table.block_size) as f32
91 } else {
92 0.0
93 },
94 last_access: std::time::Instant::now(),
95 }
96 }
97
98 fn is_valid(&self) -> bool {
99 true }
101
102 fn cache_id(&self) -> String {
103 self.cache_id.clone()
104 }
105}
// Unit tests for `DefaultKvCacheHandle`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handle for a freshly generated request id.
    fn fresh(block_size: usize, num_tokens: usize) -> DefaultKvCacheHandle {
        DefaultKvCacheHandle::new(RequestId::new(), block_size, num_tokens)
    }

    #[test]
    fn test_handle_creation() {
        let id = RequestId::new();
        let handle = DefaultKvCacheHandle::new(id.clone(), 16, 10);
        let table = handle.block_table();

        assert!(handle.cache_id().contains(&id.to_string()));
        assert_eq!(table.block_size, 16);
        assert_eq!(table.sequence_length, 10);
        assert!(handle.is_valid());
    }

    #[test]
    fn test_handle_set_num_tokens() {
        let mut handle = fresh(16, 10);
        handle.set_num_tokens(50);
        assert_eq!(handle.block_table().sequence_length, 50);
    }

    #[test]
    fn test_handle_device() {
        assert!(matches!(fresh(16, 10).device(), Device::CPU));
    }

    #[test]
    fn test_handle_dimensions() {
        let handle = fresh(16, 10);
        let dims = (handle.num_layers(), handle.num_heads(), handle.head_dim());
        assert_eq!(dims, (32, 32, 128));
    }

    #[test]
    fn test_handle_block_table() {
        let handle = fresh(16, 10);
        let table = handle.block_table();
        assert_eq!(table.block_size, 16);
        assert_eq!(table.sequence_length, 10);
    }

    #[test]
    fn test_handle_block_table_mut() {
        let mut handle = fresh(16, 10);
        handle.block_table_mut().sequence_length = 20;
        assert_eq!(handle.block_table().sequence_length, 20);
    }

    #[test]
    fn test_handle_stats() {
        let handle = fresh(16, 10);
        let stats = handle.stats();

        assert_eq!(stats.tokens_stored, 10);
        assert_eq!(stats.blocks_allocated, handle.block_table().num_blocks());
        assert_eq!(stats.memory_bytes, 0);
        assert!((0.0..=1.0).contains(&stats.utilization));
    }

    #[test]
    fn test_handle_cache_id() {
        let id = RequestId::new();
        let cache_id = DefaultKvCacheHandle::new(id.clone(), 16, 10).cache_id();

        assert!(cache_id.contains("cache_"));
        assert!(cache_id.contains(&id.to_string()));
    }

    #[test]
    fn test_handle_is_valid() {
        assert!(fresh(16, 10).is_valid());
    }

    #[test]
    fn test_handle_key_cache() {
        let outcome = fresh(16, 10).key_cache(0);
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_handle_value_cache() {
        let outcome = fresh(16, 10).value_cache(0);
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_handle_clone_not_implemented() {
        assert!(fresh(16, 10).clone_handle().is_err());
    }

    #[test]
    fn test_handle_as_any() {
        let handle = fresh(16, 10);
        let downcast = handle.as_any().downcast_ref::<DefaultKvCacheHandle>();
        assert!(downcast.is_some());
    }

    #[test]
    fn test_handle_debug_format() {
        let rendered = format!("{:?}", fresh(16, 10));
        assert!(rendered.contains("DefaultKvCacheHandle"));
    }

    #[test]
    fn test_handle_with_different_block_sizes() {
        let id = RequestId::new();
        for &size in &[16usize, 32] {
            let handle = DefaultKvCacheHandle::new(id.clone(), size, 10);
            assert_eq!(handle.block_table().block_size, size);
        }
    }

    #[test]
    fn test_handle_stats_utilization() {
        let stats = fresh(16, 8).stats();
        if stats.blocks_allocated > 0 {
            assert!((0.0..=1.0).contains(&stats.utilization));
        }
    }

    #[test]
    fn test_handle_zero_tokens() {
        let handle = fresh(16, 0);
        assert_eq!(handle.block_table().sequence_length, 0);
        assert_eq!(handle.stats().tokens_stored, 0);
    }
}