1use crate::blocks::{BlockPool, BlockStorageConfig, PhysicalBlockId};
13use crate::cache::prefix::{PrefixCache, PrefixCacheStats, PrefixId};
14use async_trait::async_trait;
15use ferrum_interfaces::{
16 kv_cache::{AllocationRequest, BlockTable, CacheGcStats, CacheManagerStats, MemoryPressure},
17 KvCacheHandle, KvCacheManager, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, RequestId, Result};
20use parking_lot::{Mutex, RwLock};
21use std::collections::HashMap;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::Arc;
24use std::time::Instant;
25use tracing::{debug, info};
26
#[derive(Debug, Clone)]
/// Tunable settings for [`PagedKvCacheManager`].
pub struct PagedKvCacheConfig {
    /// Number of token slots per block; token positions are mapped to blocks
    /// by dividing by this value.
    pub block_size: usize,
    /// Capacity handed to the primary block pool (created on the device passed
    /// to `PagedKvCacheManager::new`).
    pub max_gpu_blocks: usize,
    /// Capacity handed to the CPU block pool; only used when
    /// `enable_swapping` is `true`.
    pub max_cpu_blocks: usize,
    /// When `false`, `PagedKvCacheManager::cow_copy` returns an "unsupported"
    /// error.
    pub enable_cow: bool,
    /// When `true`, a CPU block pool is created so `swap_out` / `swap_in` can
    /// be used.
    pub enable_swapping: bool,
    /// Free-block ratio below which `check_pressure` reports
    /// `MemoryPressure::High`.
    pub low_watermark: f32,
    /// Free-block ratio below which `check_pressure` reports
    /// `MemoryPressure::Critical`.
    pub high_watermark: f32,
    /// Layer count forwarded to the block storage configuration.
    pub num_layers: usize,
    /// Head count forwarded to the block storage configuration (as
    /// `num_kv_heads`).
    pub num_heads: usize,
    /// Per-head dimension forwarded to the block storage configuration.
    pub head_dim: usize,
    /// When `true`, a `PrefixCache` is created alongside the block pools.
    pub enable_prefix_cache: bool,
    /// First argument passed to `PrefixCache::new`.
    pub max_prefixes: usize,
    /// Second argument passed to `PrefixCache::new`.
    pub min_prefix_length: usize,
}
57
58impl Default for PagedKvCacheConfig {
59 fn default() -> Self {
60 Self {
61 block_size: 16,
62 max_gpu_blocks: 1024,
63 max_cpu_blocks: 512,
64 enable_cow: true,
65 enable_swapping: true,
66 low_watermark: 0.3,
67 high_watermark: 0.1,
68 num_layers: 32,
69 num_heads: 32,
70 head_dim: 128,
71 enable_prefix_cache: true,
72 max_prefixes: 100,
73 min_prefix_length: 16,
74 }
75 }
76}
77
#[derive(Debug)]
/// Per-request view of a paged KV cache: a logical-to-physical block table
/// plus the model dimensions and bookkeeping needed to report statistics.
pub struct PagedKvCacheHandle {
    /// Request this handle belongs to; also used to build `cache_id`.
    request_id: RequestId,
    /// Device reported through `KvCacheHandle::device`.
    device: Device,
    /// Logical-to-physical block mapping, guarded for interior mutability.
    block_table: RwLock<BlockTable>,
    /// Number of tokens recorded for this handle; `set_num_tokens` mirrors it
    /// into `BlockTable::sequence_length`.
    num_tokens: RwLock<usize>,
    /// Layer count reported through the trait and used in memory estimates.
    num_layers: usize,
    /// Head count reported through the trait and used in memory estimates.
    num_heads: usize,
    /// Per-head dimension reported through the trait and used in memory
    /// estimates.
    head_dim: usize,
    /// Token slots per block.
    block_size: usize,
    /// Refreshed whenever a block is added; surfaced in handle statistics.
    last_access: RwLock<Instant>,
    /// Set once `add_ref` has been called, and always set on handles produced
    /// by `clone_handle`.
    has_cow_refs: RwLock<bool>,
    /// Starts at 1; the handle is considered valid while this is above zero.
    ref_count: AtomicU64,
}
104
105impl PagedKvCacheHandle {
106 pub fn new(
108 request_id: RequestId,
109 device: Device,
110 block_size: usize,
111 num_layers: usize,
112 num_heads: usize,
113 head_dim: usize,
114 ) -> Self {
115 Self {
116 request_id,
117 device,
118 block_table: RwLock::new(BlockTable::new(block_size)),
119 num_tokens: RwLock::new(0),
120 num_layers,
121 num_heads,
122 head_dim,
123 block_size,
124 last_access: RwLock::new(Instant::now()),
125 has_cow_refs: RwLock::new(false),
126 ref_count: AtomicU64::new(1),
127 }
128 }
129
130 pub fn add_block(&self, logical_id: u32, physical_id: u32) {
132 let mut table = self.block_table.write();
133 if logical_id as usize >= table.logical_to_physical.len() {
134 table
135 .logical_to_physical
136 .resize((logical_id + 1) as usize, 0);
137 }
138 table.logical_to_physical[logical_id as usize] = physical_id;
139
140 if physical_id as usize >= table.physical_blocks.len() {
141 table.physical_blocks.resize((physical_id + 1) as usize, 0);
142 }
143 table.physical_blocks[physical_id as usize] = 1;
144
145 *self.last_access.write() = Instant::now();
146 }
147
148 pub fn get_physical_block(&self, logical_id: u32) -> Option<u32> {
150 let table = self.block_table.read();
151 if (logical_id as usize) < table.logical_to_physical.len() {
152 let physical = table.logical_to_physical[logical_id as usize];
153 if physical > 0 {
154 Some(physical)
155 } else {
156 None
157 }
158 } else {
159 None
160 }
161 }
162
163 pub fn get_physical_blocks(&self) -> Vec<u32> {
165 let table = self.block_table.read();
166 table
167 .logical_to_physical
168 .iter()
169 .filter(|&&id| id > 0)
170 .copied()
171 .collect()
172 }
173
174 pub fn num_blocks(&self) -> usize {
176 let table = self.block_table.read();
177 table
178 .logical_to_physical
179 .iter()
180 .filter(|&&id| id > 0)
181 .count()
182 }
183
184 pub fn set_num_tokens(&self, tokens: usize) {
186 *self.num_tokens.write() = tokens;
187 let mut table = self.block_table.write();
188 table.sequence_length = tokens;
189 }
190
191 pub fn required_blocks(&self, num_tokens: usize) -> usize {
193 num_tokens.div_ceil(self.block_size)
194 }
195
196 pub fn add_ref(&self) {
198 self.ref_count.fetch_add(1, Ordering::Relaxed);
199 *self.has_cow_refs.write() = true;
200 }
201
202 pub fn remove_ref(&self) -> u64 {
204 self.ref_count.fetch_sub(1, Ordering::Relaxed)
205 }
206
207 pub fn ref_count(&self) -> u64 {
209 self.ref_count.load(Ordering::Relaxed)
210 }
211
212 pub fn is_cow(&self) -> bool {
214 *self.has_cow_refs.read()
215 }
216}
217
218impl KvCacheHandle for PagedKvCacheHandle {
219 fn block_table(&self) -> &BlockTable {
220 unsafe {
224 let ptr = self.block_table.data_ptr();
225 &*ptr
226 }
227 }
228
229 fn block_table_mut(&mut self) -> &mut BlockTable {
230 self.block_table.get_mut()
231 }
232
233 fn as_any(&self) -> &dyn std::any::Any {
234 self
235 }
236
237 fn device(&self) -> Device {
238 self.device.clone()
239 }
240
241 fn num_tokens(&self) -> usize {
242 *self.num_tokens.read()
243 }
244
245 fn num_layers(&self) -> usize {
246 self.num_layers
247 }
248
249 fn num_heads(&self) -> usize {
250 self.num_heads
251 }
252
253 fn head_dim(&self) -> usize {
254 self.head_dim
255 }
256
257 fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
258 Ok(None)
261 }
262
263 fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
264 Ok(None)
265 }
266
267 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
268 self.add_ref();
270 Ok(Arc::new(PagedKvCacheHandle {
271 request_id: self.request_id.clone(),
272 device: self.device.clone(),
273 block_table: RwLock::new(self.block_table.read().clone()),
274 num_tokens: RwLock::new(*self.num_tokens.read()),
275 num_layers: self.num_layers,
276 num_heads: self.num_heads,
277 head_dim: self.head_dim,
278 block_size: self.block_size,
279 last_access: RwLock::new(Instant::now()),
280 has_cow_refs: RwLock::new(true),
281 ref_count: AtomicU64::new(1),
282 }))
283 }
284
285 fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
286 let tokens = *self.num_tokens.read();
287 let blocks = self.num_blocks();
288 let bytes_per_token = 2 * self.num_layers * self.num_heads * self.head_dim * 2; ferrum_interfaces::kv_cache::CacheHandleStats {
291 memory_bytes: blocks * self.block_size * bytes_per_token,
292 blocks_allocated: blocks,
293 tokens_stored: tokens,
294 utilization: if blocks > 0 {
295 tokens as f32 / (blocks * self.block_size) as f32
296 } else {
297 0.0
298 },
299 last_access: *self.last_access.read(),
300 }
301 }
302
303 fn is_valid(&self) -> bool {
304 self.ref_count() > 0
305 }
306
307 fn cache_id(&self) -> String {
308 format!("paged-{}", self.request_id)
309 }
310}
311
/// Block-based KV cache manager backed by a primary block pool and, when
/// swapping is enabled, a secondary CPU pool.
pub struct PagedKvCacheManager {
    /// Settings the manager was created with.
    config: PagedKvCacheConfig,
    /// Primary block pool, created on the device passed to `new`.
    gpu_pool: BlockPool,
    /// CPU block pool; `None` when swapping is disabled.
    cpu_pool: Option<BlockPool>,
    /// Handles registered by `allocate`, keyed by request.
    active_handles: RwLock<HashMap<RequestId, Arc<PagedKvCacheHandle>>>,
    /// Owning request for each physical block handed out by this manager.
    block_to_request: RwLock<HashMap<PhysicalBlockId, RequestId>>,
    /// For each swapped-out primary-pool block, the CPU block recorded for it.
    swapped_blocks: RwLock<HashMap<PhysicalBlockId, PhysicalBlockId>>,
    /// Prefix cache; `None` when disabled in the configuration.
    prefix_cache: Option<PrefixCache>,
    /// Running counters; pool-derived fields are filled in by `stats()`.
    stats: Mutex<CacheManagerStats>,
    #[allow(clippy::type_complexity)]
    /// Callback invoked by `notify_pressure`, if one has been registered.
    pressure_callback: Mutex<Option<Box<dyn Fn(MemoryPressure) + Send + Sync>>>,
}
334
335impl PagedKvCacheManager {
336 pub fn new(device: Device, config: PagedKvCacheConfig) -> Result<Self> {
338 info!(
339 "Creating paged KV cache manager: device={:?}, block_size={}, max_gpu_blocks={}, max_cpu_blocks={}, prefix_cache={}",
340 device, config.block_size, config.max_gpu_blocks, config.max_cpu_blocks, config.enable_prefix_cache
341 );
342
343 let storage_config = BlockStorageConfig {
344 num_layers: config.num_layers,
345 num_kv_heads: config.num_heads,
346 head_dim: config.head_dim,
347 block_size: config.block_size,
348 };
349
350 let gpu_pool = BlockPool::new_with_storage(
351 device.clone(),
352 config.block_size,
353 DataType::FP16,
354 config.max_gpu_blocks,
355 storage_config,
356 )?;
357
358 let cpu_pool = if config.enable_swapping {
359 Some(BlockPool::new_with_storage(
360 Device::CPU,
361 config.block_size,
362 DataType::FP16,
363 config.max_cpu_blocks,
364 storage_config,
365 )?)
366 } else {
367 None
368 };
369
370 let prefix_cache = if config.enable_prefix_cache {
371 Some(PrefixCache::new(
372 config.max_prefixes,
373 config.min_prefix_length,
374 ))
375 } else {
376 None
377 };
378
379 Ok(Self {
380 config,
381 gpu_pool,
382 cpu_pool,
383 active_handles: RwLock::new(HashMap::new()),
384 block_to_request: RwLock::new(HashMap::new()),
385 swapped_blocks: RwLock::new(HashMap::new()),
386 prefix_cache,
387 stats: Mutex::new(CacheManagerStats {
388 total_memory_bytes: 0,
389 used_memory_bytes: 0,
390 active_caches: 0,
391 total_blocks: 0,
392 free_blocks: 0,
393 cache_hit_rate: 0.0,
394 eviction_count: 0,
395 allocation_count: 0,
396 allocation_failures: 0,
397 }),
398 pressure_callback: Mutex::new(None),
399 })
400 }
401
402 pub fn with_defaults(device: Device, block_size: usize, max_blocks: usize) -> Result<Self> {
404 let config = PagedKvCacheConfig {
405 block_size,
406 max_gpu_blocks: max_blocks,
407 max_cpu_blocks: max_blocks / 2,
408 ..Default::default()
409 };
410 Self::new(device, config)
411 }
412
413 pub fn allocate_blocks(
415 &self,
416 handle: &PagedKvCacheHandle,
417 num_blocks: usize,
418 ) -> Result<Vec<PhysicalBlockId>> {
419 let mut allocated = Vec::with_capacity(num_blocks);
420 let current_blocks = handle.num_blocks();
421
422 for i in 0..num_blocks {
423 let allocation = self.gpu_pool.allocate()?;
424 let physical_id = allocation.physical_id;
425
426 let logical_id = (current_blocks + i) as u32;
428 handle.add_block(logical_id, physical_id.0);
429
430 self.block_to_request
432 .write()
433 .insert(physical_id, handle.request_id.clone());
434
435 allocated.push(physical_id);
436 }
437
438 {
440 let mut stats = self.stats.lock();
441 stats.allocation_count += num_blocks as u64;
442 }
443
444 debug!(
445 "Allocated {} blocks for request {}: {:?}",
446 num_blocks, handle.request_id, allocated
447 );
448
449 Ok(allocated)
450 }
451
452 pub fn free_blocks(&self, block_ids: &[PhysicalBlockId]) -> Result<()> {
454 for &block_id in block_ids {
455 self.gpu_pool.deallocate(block_id)?;
456 self.block_to_request.write().remove(&block_id);
457 }
458
459 debug!("Freed {} blocks", block_ids.len());
460 Ok(())
461 }
462
463 pub fn write_kv(
468 &self,
469 handle: &PagedKvCacheHandle,
470 layer: usize,
471 token_position: usize,
472 key: &[f32],
473 value: &[f32],
474 ) -> Result<()> {
475 let block_size = self.config.block_size;
476 let logical_block = token_position / block_size;
477 let slot = token_position % block_size;
478
479 let physical_id = handle
480 .get_physical_block(logical_block as u32)
481 .ok_or_else(|| {
482 FerrumError::internal(format!(
483 "No physical block for logical block {} (token {})",
484 logical_block, token_position
485 ))
486 })?;
487
488 self.gpu_pool
489 .write_kv_slot(PhysicalBlockId::new(physical_id), layer, slot, key, value)
490 }
491
492 pub fn read_kv(
498 &self,
499 handle: &PagedKvCacheHandle,
500 layer: usize,
501 start_token: usize,
502 num_tokens: usize,
503 ) -> Result<(Vec<f32>, Vec<f32>)> {
504 let block_size = self.config.block_size;
505 let kv_size = self.config.num_heads * self.config.head_dim;
506 let mut keys = Vec::with_capacity(num_tokens * kv_size);
507 let mut values = Vec::with_capacity(num_tokens * kv_size);
508
509 for pos in start_token..start_token + num_tokens {
510 let logical_block = pos / block_size;
511 let slot = pos % block_size;
512
513 let physical_id = handle
514 .get_physical_block(logical_block as u32)
515 .ok_or_else(|| {
516 FerrumError::internal(format!(
517 "No physical block for logical block {} (token {})",
518 logical_block, pos
519 ))
520 })?;
521
522 let (k, v) =
523 self.gpu_pool
524 .read_kv_slot(PhysicalBlockId::new(physical_id), layer, slot)?;
525 keys.extend_from_slice(&k);
526 values.extend_from_slice(&v);
527 }
528
529 Ok((keys, values))
530 }
531
532 pub fn gpu_pool(&self) -> &BlockPool {
534 &self.gpu_pool
535 }
536
537 pub fn prefix_cache(&self) -> Option<&PrefixCache> {
539 self.prefix_cache.as_ref()
540 }
541
542 pub fn share_prefix_blocks(
549 &self,
550 source: &PagedKvCacheHandle,
551 target: &PagedKvCacheHandle,
552 num_prefix_blocks: usize,
553 ) -> Result<()> {
554 let source_blocks = source.get_physical_blocks();
555 let n = num_prefix_blocks.min(source_blocks.len());
556
557 for i in 0..n {
558 let phys_id = source_blocks[i];
559 target.add_block(i as u32, phys_id);
561 let pid = PhysicalBlockId::new(phys_id);
564 if let Some(block) = self.gpu_pool.get_block(pid) {
565 block.write().add_ref();
566 }
567 }
568
569 debug!(
570 "Shared {} prefix blocks from {} to {}",
571 n, source.request_id, target.request_id
572 );
573
574 Ok(())
575 }
576
577 pub fn swap_out(&self, block_ids: &[PhysicalBlockId]) -> Result<Vec<PhysicalBlockId>> {
579 let cpu_pool = self
580 .cpu_pool
581 .as_ref()
582 .ok_or_else(|| FerrumError::unsupported("Swapping not enabled"))?;
583
584 let mut swapped = Vec::with_capacity(block_ids.len());
585 let mut swap_map = self.swapped_blocks.write();
586
587 for &gpu_block in block_ids {
588 let cpu_allocation = cpu_pool.allocate()?;
590 let cpu_block = cpu_allocation.physical_id;
591
592 swap_map.insert(gpu_block, cpu_block);
596 swapped.push(cpu_block);
597
598 self.gpu_pool.deallocate(gpu_block)?;
600 }
601
602 debug!("Swapped out {} blocks to CPU", swapped.len());
603 Ok(swapped)
604 }
605
606 pub fn swap_in(&self, cpu_block_ids: &[PhysicalBlockId]) -> Result<Vec<PhysicalBlockId>> {
608 let cpu_pool = self
609 .cpu_pool
610 .as_ref()
611 .ok_or_else(|| FerrumError::unsupported("Swapping not enabled"))?;
612
613 let mut swapped = Vec::with_capacity(cpu_block_ids.len());
614 let mut swap_map = self.swapped_blocks.write();
615
616 for &cpu_block in cpu_block_ids {
617 let gpu_allocation = self.gpu_pool.allocate()?;
619 let gpu_block = gpu_allocation.physical_id;
620
621 let gpu_original = swap_map
625 .iter()
626 .find(|(_, &cpu)| cpu == cpu_block)
627 .map(|(&gpu, _)| gpu);
628
629 if let Some(orig_gpu) = gpu_original {
630 swap_map.remove(&orig_gpu);
631 }
632
633 swapped.push(gpu_block);
634
635 cpu_pool.deallocate(cpu_block)?;
637 }
638
639 debug!("Swapped in {} blocks from CPU", swapped.len());
640 Ok(swapped)
641 }
642
643 pub fn check_pressure(&self) -> MemoryPressure {
645 let gpu_stats = self.gpu_pool.stats();
646 let free_ratio = gpu_stats.free_blocks as f32 / gpu_stats.max_blocks.max(1) as f32;
647
648 if free_ratio < self.config.high_watermark {
649 MemoryPressure::Critical
650 } else if free_ratio < self.config.low_watermark {
651 MemoryPressure::High
652 } else {
653 MemoryPressure::Low
654 }
655 }
656
657 fn notify_pressure(&self, pressure: MemoryPressure) {
659 if let Some(ref callback) = *self.pressure_callback.lock() {
660 callback(pressure);
661 }
662 }
663
664 pub fn free_block_count(&self) -> usize {
666 self.gpu_pool.stats().free_blocks
667 }
668
669 pub fn total_blocks(&self) -> usize {
671 self.gpu_pool.stats().total_blocks
672 }
673
674 pub fn cow_copy(&self, handle: &PagedKvCacheHandle, block_ids: &[u32]) -> Result<Vec<u32>> {
676 if !self.config.enable_cow {
677 return Err(FerrumError::unsupported("COW not enabled"));
678 }
679
680 let mut new_blocks = Vec::with_capacity(block_ids.len());
681
682 for &_old_physical in block_ids {
683 let allocation = self.gpu_pool.allocate()?;
685 let new_physical = allocation.physical_id;
686
687 new_blocks.push(new_physical.0);
691
692 self.block_to_request
694 .write()
695 .insert(new_physical, handle.request_id.clone());
696 }
697
698 debug!("COW copied {} blocks", new_blocks.len());
699 Ok(new_blocks)
700 }
701
702 pub fn find_prefix(
709 &self,
710 tokens: &[ferrum_types::TokenId],
711 ) -> Option<(
712 PrefixId,
713 Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
714 Vec<f32>,
715 usize,
716 )> {
717 let prefix_cache = self.prefix_cache.as_ref()?;
718
719 if let Some((prefix_id, kv_handle, last_logits)) = prefix_cache.find_prefix(tokens) {
720 let matched_len = prefix_id.len();
721 debug!("Prefix cache hit: matched {} tokens", matched_len);
722
723 {
725 let mut stats = self.stats.lock();
726 let total = stats.allocation_count as f32;
727 if total > 0.0 {
728 stats.cache_hit_rate = (stats.cache_hit_rate * (total - 1.0) + 1.0) / total;
729 }
730 }
731
732 Some((prefix_id, kv_handle, last_logits, matched_len))
733 } else {
734 None
735 }
736 }
737
738 pub fn store_prefix(
740 &self,
741 tokens: &[ferrum_types::TokenId],
742 kv_handle: Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
743 last_logits: Vec<f32>,
744 ) -> Result<()> {
745 if let Some(prefix_cache) = &self.prefix_cache {
746 prefix_cache.store_prefix(tokens, kv_handle, last_logits)?;
747 debug!("Stored prefix with {} tokens in cache", tokens.len());
748 }
749 Ok(())
750 }
751
752 pub fn prefix_cache_stats(&self) -> Option<PrefixCacheStats> {
754 self.prefix_cache.as_ref().map(|pc| pc.stats())
755 }
756
757 pub fn evict_prefixes(&self, count: usize) -> usize {
759 if let Some(prefix_cache) = &self.prefix_cache {
760 let evicted = prefix_cache.evict_n(count);
761 if evicted > 0 {
762 debug!("Evicted {} prefixes from cache", evicted);
763 }
764 evicted
765 } else {
766 0
767 }
768 }
769
770 pub fn clear_prefix_cache(&self) {
772 if let Some(prefix_cache) = &self.prefix_cache {
773 prefix_cache.clear();
774 debug!("Cleared prefix cache");
775 }
776 }
777}
778
779#[async_trait]
780impl KvCacheManager for PagedKvCacheManager {
781 async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
782 debug!(
783 "Allocating paged KV cache for request: {:?}",
784 request.request_id
785 );
786
787 let pressure = self.check_pressure();
789 if matches!(pressure, MemoryPressure::Critical) {
790 self.notify_pressure(pressure);
791 let _ = self.gc().await;
793 }
794
795 let handle = Arc::new(PagedKvCacheHandle::new(
797 request.request_id.clone(),
798 request.device.clone(),
799 self.config.block_size,
800 request.num_layers,
801 request.num_heads,
802 request.head_dim,
803 ));
804
805 let initial_blocks = handle.required_blocks(request.initial_tokens);
807 if initial_blocks > 0 {
808 self.allocate_blocks(&handle, initial_blocks)?;
809 }
810
811 handle.set_num_tokens(request.initial_tokens);
812
813 self.active_handles
815 .write()
816 .insert(request.request_id.clone(), handle.clone());
817
818 {
820 let mut stats = self.stats.lock();
821 stats.active_caches += 1;
822 stats.allocation_count += 1;
823 }
824
825 Ok(handle)
826 }
827
828 async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()> {
829 let paged_handle = handle
830 .as_any()
831 .downcast_ref::<PagedKvCacheHandle>()
832 .ok_or_else(|| FerrumError::internal("Invalid handle type"))?;
833
834 let current_tokens = paged_handle.num_tokens();
835 let new_tokens = current_tokens + additional_tokens;
836 let current_blocks = paged_handle.num_blocks();
837 let required_blocks = paged_handle.required_blocks(new_tokens);
838
839 if required_blocks > current_blocks {
840 let new_blocks = required_blocks - current_blocks;
841
842 if paged_handle.is_cow() && paged_handle.ref_count() > 1 {
844 let existing = paged_handle.get_physical_blocks();
846 let _new_physical = self.cow_copy(paged_handle, &existing)?;
847 }
850
851 self.allocate_blocks(paged_handle, new_blocks)?;
852 }
853
854 paged_handle.set_num_tokens(new_tokens);
855
856 debug!(
857 "Extended KV cache for {}: {} -> {} tokens",
858 paged_handle.request_id, current_tokens, new_tokens
859 );
860
861 Ok(())
862 }
863
864 async fn deallocate(&self, request_id: RequestId) -> Result<()> {
865 debug!("Deallocating paged KV cache for request: {:?}", request_id);
866
867 let handle = self.active_handles.write().remove(&request_id);
868
869 if let Some(handle) = handle {
870 if handle.ref_count() > 1 {
872 handle.remove_ref();
874 debug!(
875 "Decremented ref count for {}, remaining: {}",
876 request_id,
877 handle.ref_count()
878 );
879 return Ok(());
880 }
881
882 let block_ids: Vec<PhysicalBlockId> = handle
884 .get_physical_blocks()
885 .into_iter()
886 .map(PhysicalBlockId)
887 .collect();
888
889 for block_id in block_ids {
890 let _ = self.gpu_pool.deallocate(block_id);
891 self.block_to_request.write().remove(&block_id);
892 }
893
894 {
896 let mut stats = self.stats.lock();
897 if stats.active_caches > 0 {
898 stats.active_caches -= 1;
899 }
900 }
901 }
902
903 Ok(())
904 }
905
906 fn can_allocate(&self, request: &AllocationRequest) -> bool {
907 let required_blocks = request.initial_tokens.div_ceil(self.config.block_size);
908 let gpu_stats = self.gpu_pool.stats();
909
910 gpu_stats.free_blocks >= required_blocks
911 || gpu_stats.total_blocks + required_blocks <= gpu_stats.max_blocks
912 }
913
914 fn stats(&self) -> CacheManagerStats {
915 let gpu_stats = self.gpu_pool.stats();
916 let mut stats = self.stats.lock().clone();
917
918 stats.total_blocks = gpu_stats.max_blocks;
919 stats.free_blocks = gpu_stats.free_blocks;
920
921 let bytes_per_block = self.config.block_size
923 * 2 * self.config.num_layers
925 * self.config.num_heads
926 * self.config.head_dim
927 * 2; stats.total_memory_bytes = gpu_stats.max_blocks * bytes_per_block;
930 stats.used_memory_bytes = gpu_stats.allocated_blocks * bytes_per_block;
931
932 stats
933 }
934
935 async fn gc(&self) -> Result<CacheGcStats> {
936 let start = Instant::now();
937
938 let evicted = self.gpu_pool.evict_blocks(10)?;
940
941 {
943 let mut stats = self.stats.lock();
944 stats.eviction_count += evicted.len() as u64;
945 }
946
947 Ok(CacheGcStats {
948 memory_freed: evicted.len() * self.config.block_size * 1024, caches_freed: 0,
950 gc_time_ms: start.elapsed().as_millis() as u64,
951 })
952 }
953
954 fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
955 *self.pressure_callback.lock() = Some(callback);
956 }
957
958 fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
959 self.active_handles
960 .read()
961 .get(&request_id)
962 .map(|h| h.clone() as Arc<dyn KvCacheHandle>)
963 }
964
965 fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
966 self.active_handles
967 .read()
968 .iter()
969 .map(|(id, handle)| (id.clone(), handle.clone() as Arc<dyn KvCacheHandle>))
970 .collect()
971 }
972}
973
974impl std::fmt::Debug for PagedKvCacheManager {
975 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
976 let gpu_stats = self.gpu_pool.stats();
977 f.debug_struct("PagedKvCacheManager")
978 .field("block_size", &self.config.block_size)
979 .field("total_gpu_blocks", &gpu_stats.total_blocks)
980 .field("free_gpu_blocks", &gpu_stats.free_blocks)
981 .field("active_handles", &self.active_handles.read().len())
982 .finish()
983 }
984}
985
#[cfg(test)]
// Unit tests for the paged KV cache handle and manager. All managers are
// created on `Device::CPU`.
mod tests {
    use super::*;

    /// Builds a request for 64 initial tokens with a fresh request id.
    fn create_test_request() -> AllocationRequest {
        AllocationRequest {
            request_id: RequestId::new(),
            initial_tokens: 64,
            max_sequence_length: 2048,
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
        }
    }

    /// The convenience constructor succeeds.
    #[tokio::test]
    async fn test_manager_creation() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100);
        assert!(manager.is_ok());
    }

    /// A handle can be allocated, reports the requested token count, and can
    /// be deallocated again.
    #[tokio::test]
    async fn test_allocate_and_deallocate() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
        let request = create_test_request();
        let request_id = request.request_id.clone();

        let handle = manager.allocate(&request).await.unwrap();
        assert!(handle.is_valid());
        assert_eq!(handle.num_tokens(), 64);

        let stats = handle.stats();
        // Deliberately loose: passes if either blocks or tokens are reported.
        assert!(stats.blocks_allocated >= 1 || stats.tokens_stored >= 64);

        manager.deallocate(request_id).await.unwrap();
    }

    /// Adding blocks to a registered handle is visible through the handle
    /// returned by `allocate`.
    #[tokio::test]
    async fn test_extend() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
        let request = create_test_request();
        let request_id = request.request_id.clone();

        let handle = manager.allocate(&request).await.unwrap();
        let initial_blocks = handle.stats().blocks_allocated;

        let paged_handle = manager.get_handle(request_id.clone()).unwrap();
        // Downcast to the concrete handle so blocks can be added directly.
        let paged_ref = paged_handle
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();
        manager.allocate_blocks(paged_ref, 4).unwrap();

        let new_blocks = handle.stats().blocks_allocated;
        assert!(new_blocks > initial_blocks);

        manager.deallocate(request_id).await.unwrap();
    }

    /// A fresh manager can allocate; after several allocations the free-block
    /// count is below the total.
    #[tokio::test]
    async fn test_can_allocate() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 10).unwrap();

        let request = create_test_request();
        assert!(manager.can_allocate(&request));

        for _ in 0..8 {
            // Results are ignored on purpose: with only 10 blocks available,
            // later allocations are allowed to fail.
            let req = create_test_request();
            let _ = manager.allocate(&req).await;
        }

        let stats = manager.stats();
        assert!(stats.free_blocks < stats.total_blocks);
    }

    /// GC runs after an allocate/deallocate cycle and reports no freed caches.
    #[tokio::test]
    async fn test_gc() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();

        let request = create_test_request();
        let request_id = request.request_id.clone();
        let _ = manager.allocate(&request).await.unwrap();
        manager.deallocate(request_id).await.unwrap();

        let gc_stats = manager.gc().await.unwrap();
        assert_eq!(gc_stats.caches_freed, 0);
    }

    /// A new handle is empty, and added blocks can be looked up by logical id.
    #[test]
    fn test_paged_handle() {
        let handle = PagedKvCacheHandle::new(RequestId::new(), Device::CPU, 16, 32, 32, 128);

        assert_eq!(handle.num_tokens(), 0);
        assert_eq!(handle.num_blocks(), 0);

        // Non-zero physical ids are used because 0 is treated as "unmapped".
        handle.add_block(0, 5);
        handle.add_block(1, 10);

        assert_eq!(handle.num_blocks(), 2);
        assert_eq!(handle.get_physical_block(0), Some(5));
        assert_eq!(handle.get_physical_block(1), Some(10));
    }

    /// Values written token-by-token can be read back in one call, including
    /// across a block boundary, and an unwritten layer reads as zeros.
    #[tokio::test]
    async fn test_write_read_kv_across_blocks() {
        let config = PagedKvCacheConfig {
            // Small blocks so six tokens span two blocks.
            block_size: 4,
            max_gpu_blocks: 16,
            max_cpu_blocks: 0,
            enable_cow: false,
            enable_swapping: false,
            num_layers: 2,
            num_heads: 2,
            head_dim: 4,
            enable_prefix_cache: false,
            ..Default::default()
        };
        let manager = PagedKvCacheManager::new(Device::CPU, config).unwrap();

        let request = AllocationRequest {
            request_id: RequestId::new(),
            // 6 tokens at block size 4 requires 2 blocks.
            initial_tokens: 6,
            max_sequence_length: 32,
            num_layers: 2,
            num_heads: 2,
            head_dim: 4,
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
        };
        let request_id = request.request_id.clone();

        let handle_dyn = manager.allocate(&request).await.unwrap();
        let handle = handle_dyn
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();

        // num_heads * head_dim values per token.
        let kv_size = 2 * 4;
        // Keys are pos*100 + i and values pos*100 + i + 50, so every element
        // identifies its token and offset.
        for pos in 0..6 {
            let key: Vec<f32> = (0..kv_size).map(|i| (pos * 100 + i) as f32).collect();
            let val: Vec<f32> = (0..kv_size).map(|i| (pos * 100 + i + 50) as f32).collect();
            manager.write_kv(handle, 0, pos, &key, &val).unwrap();
        }

        let (keys, vals) = manager.read_kv(handle, 0, 0, 6).unwrap();
        assert_eq!(keys.len(), 6 * kv_size);
        assert_eq!(vals.len(), 6 * kv_size);

        // First token: first and last key elements.
        assert_eq!(keys[0], 0.0);
        assert_eq!(keys[kv_size - 1], 7.0);

        // Token 4 is the first token stored in the second block.
        assert_eq!(keys[4 * kv_size], 400.0);

        // Token 5: first value element.
        assert_eq!(vals[5 * kv_size], 550.0);

        // Layer 1 was never written, so it reads back as zeros.
        let (k1, _) = manager.read_kv(handle, 1, 0, 1).unwrap();
        assert!(k1.iter().all(|&x| x == 0.0));

        manager.deallocate(request_id).await.unwrap();
    }

    /// `required_blocks` rounds up to whole blocks of 16 tokens.
    #[test]
    fn test_required_blocks() {
        let handle = PagedKvCacheHandle::new(
            RequestId::new(),
            Device::CPU,
            // block_size
            16,
            32,
            32,
            128,
        );

        assert_eq!(handle.required_blocks(0), 0);
        assert_eq!(handle.required_blocks(16), 1);
        assert_eq!(handle.required_blocks(17), 2);
        assert_eq!(handle.required_blocks(32), 2);
        assert_eq!(handle.required_blocks(33), 3);
    }
}