Skip to main content

ferrum_kv/managers/
paged.rs

//! Paged KV Cache Manager
//!
//! This module implements PagedAttention-style KV cache management with:
//!
//! - Non-contiguous physical memory allocation
//! - Logical-to-physical block mapping via block tables
//! - Copy-on-write support for prefix sharing
//! - GPU<->CPU block swapping for memory management
//! - Efficient block reclamation and reuse
//! - Prefix caching for shared prompt optimization
11
12use crate::blocks::{BlockPool, BlockStorageConfig, PhysicalBlockId};
13use crate::cache::prefix::{PrefixCache, PrefixCacheStats, PrefixId};
14use async_trait::async_trait;
15use ferrum_interfaces::{
16    kv_cache::{AllocationRequest, BlockTable, CacheGcStats, CacheManagerStats, MemoryPressure},
17    KvCacheHandle, KvCacheManager, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, RequestId, Result};
20use parking_lot::{Mutex, RwLock};
21use std::collections::HashMap;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::Arc;
24use std::time::Instant;
25use tracing::{debug, info};
26
/// Tunable parameters for the paged KV cache manager.
#[derive(Debug, Clone)]
pub struct PagedKvCacheConfig {
    /// Number of tokens stored in one block.
    pub block_size: usize,
    /// Upper bound on the number of blocks in the GPU pool.
    pub max_gpu_blocks: usize,
    /// Upper bound on the number of blocks in the CPU (swap) pool.
    /// The CPU pool is only created when `enable_swapping` is set.
    pub max_cpu_blocks: usize,
    /// Allow copy-on-write block copies; when `false`, `cow_copy` returns an
    /// "unsupported" error.
    pub enable_cow: bool,
    /// Create a CPU pool so blocks can be swapped out of the GPU pool.
    pub enable_swapping: bool,
    /// Free-block fraction threshold: when the fraction of free GPU blocks
    /// drops below this value (but not below `high_watermark`), memory
    /// pressure is reported as `High`.
    pub low_watermark: f32,
    /// Free-block fraction threshold: when the fraction of free GPU blocks
    /// drops below this value, memory pressure is reported as `Critical`.
    pub high_watermark: f32,
    /// Number of model layers.
    pub num_layers: usize,
    /// Number of attention heads (passed to block storage as the KV head count).
    pub num_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Create a prefix cache for shared prompts.
    pub enable_prefix_cache: bool,
    /// Capacity handed to the prefix cache (maximum number of cached prefixes).
    pub max_prefixes: usize,
    /// Minimum prefix length handed to the prefix cache
    /// (presumably measured in tokens — confirm against `PrefixCache`).
    pub min_prefix_length: usize,
}
57
58impl Default for PagedKvCacheConfig {
59    fn default() -> Self {
60        Self {
61            block_size: 16,
62            max_gpu_blocks: 1024,
63            max_cpu_blocks: 512,
64            enable_cow: true,
65            enable_swapping: true,
66            low_watermark: 0.3,
67            high_watermark: 0.1,
68            num_layers: 32,
69            num_heads: 32,
70            head_dim: 128,
71            enable_prefix_cache: true,
72            max_prefixes: 100,
73            min_prefix_length: 16,
74        }
75    }
76}
77
78/// Paged KV cache handle for a single sequence
79#[derive(Debug)]
80pub struct PagedKvCacheHandle {
81    /// Request ID
82    request_id: RequestId,
83    /// Device where blocks are allocated
84    device: Device,
85    /// Block table (logical to physical mapping)
86    block_table: RwLock<BlockTable>,
87    /// Number of tokens stored
88    num_tokens: RwLock<usize>,
89    /// Number of layers
90    num_layers: usize,
91    /// Number of heads
92    num_heads: usize,
93    /// Head dimension
94    head_dim: usize,
95    /// Block size
96    block_size: usize,
97    /// Last access time
98    last_access: RwLock<Instant>,
99    /// Whether this handle has copy-on-write references
100    has_cow_refs: RwLock<bool>,
101    /// Reference count (for COW)
102    ref_count: AtomicU64,
103}
104
105impl PagedKvCacheHandle {
106    /// Create new paged KV cache handle
107    pub fn new(
108        request_id: RequestId,
109        device: Device,
110        block_size: usize,
111        num_layers: usize,
112        num_heads: usize,
113        head_dim: usize,
114    ) -> Self {
115        Self {
116            request_id,
117            device,
118            block_table: RwLock::new(BlockTable::new(block_size)),
119            num_tokens: RwLock::new(0),
120            num_layers,
121            num_heads,
122            head_dim,
123            block_size,
124            last_access: RwLock::new(Instant::now()),
125            has_cow_refs: RwLock::new(false),
126            ref_count: AtomicU64::new(1),
127        }
128    }
129
130    /// Add a physical block to this handle
131    pub fn add_block(&self, logical_id: u32, physical_id: u32) {
132        let mut table = self.block_table.write();
133        if logical_id as usize >= table.logical_to_physical.len() {
134            table
135                .logical_to_physical
136                .resize((logical_id + 1) as usize, 0);
137        }
138        table.logical_to_physical[logical_id as usize] = physical_id;
139
140        if physical_id as usize >= table.physical_blocks.len() {
141            table.physical_blocks.resize((physical_id + 1) as usize, 0);
142        }
143        table.physical_blocks[physical_id as usize] = 1;
144
145        *self.last_access.write() = Instant::now();
146    }
147
148    /// Get physical block for logical block
149    pub fn get_physical_block(&self, logical_id: u32) -> Option<u32> {
150        let table = self.block_table.read();
151        if (logical_id as usize) < table.logical_to_physical.len() {
152            let physical = table.logical_to_physical[logical_id as usize];
153            if physical > 0 {
154                Some(physical)
155            } else {
156                None
157            }
158        } else {
159            None
160        }
161    }
162
163    /// Get all physical block IDs
164    pub fn get_physical_blocks(&self) -> Vec<u32> {
165        let table = self.block_table.read();
166        table
167            .logical_to_physical
168            .iter()
169            .filter(|&&id| id > 0)
170            .copied()
171            .collect()
172    }
173
174    /// Get number of blocks allocated
175    pub fn num_blocks(&self) -> usize {
176        let table = self.block_table.read();
177        table
178            .logical_to_physical
179            .iter()
180            .filter(|&&id| id > 0)
181            .count()
182    }
183
184    /// Update token count
185    pub fn set_num_tokens(&self, tokens: usize) {
186        *self.num_tokens.write() = tokens;
187        let mut table = self.block_table.write();
188        table.sequence_length = tokens;
189    }
190
191    /// Get required number of blocks for token count
192    pub fn required_blocks(&self, num_tokens: usize) -> usize {
193        num_tokens.div_ceil(self.block_size)
194    }
195
196    /// Increment reference count (for COW)
197    pub fn add_ref(&self) {
198        self.ref_count.fetch_add(1, Ordering::Relaxed);
199        *self.has_cow_refs.write() = true;
200    }
201
202    /// Decrement reference count
203    pub fn remove_ref(&self) -> u64 {
204        self.ref_count.fetch_sub(1, Ordering::Relaxed)
205    }
206
207    /// Get current reference count
208    pub fn ref_count(&self) -> u64 {
209        self.ref_count.load(Ordering::Relaxed)
210    }
211
212    /// Check if this is a COW reference
213    pub fn is_cow(&self) -> bool {
214        *self.has_cow_refs.read()
215    }
216}
217
218impl KvCacheHandle for PagedKvCacheHandle {
219    fn block_table(&self) -> &BlockTable {
220        // This is a bit tricky - we need to return a reference to the block table
221        // but we have it behind a RwLock. For now, we'll use an unsafe pattern.
222        // In production, this should be redesigned.
223        unsafe {
224            let ptr = self.block_table.data_ptr();
225            &*ptr
226        }
227    }
228
229    fn block_table_mut(&mut self) -> &mut BlockTable {
230        self.block_table.get_mut()
231    }
232
233    fn as_any(&self) -> &dyn std::any::Any {
234        self
235    }
236
237    fn device(&self) -> Device {
238        self.device.clone()
239    }
240
241    fn num_tokens(&self) -> usize {
242        *self.num_tokens.read()
243    }
244
245    fn num_layers(&self) -> usize {
246        self.num_layers
247    }
248
249    fn num_heads(&self) -> usize {
250        self.num_heads
251    }
252
253    fn head_dim(&self) -> usize {
254        self.head_dim
255    }
256
257    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
258        // PagedAttention stores KV cache in physical blocks, not as tensors
259        // The actual tensor access is done through the block pool
260        Ok(None)
261    }
262
263    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
264        Ok(None)
265    }
266
267    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
268        // For COW, we increment ref count instead of copying
269        self.add_ref();
270        Ok(Arc::new(PagedKvCacheHandle {
271            request_id: self.request_id.clone(),
272            device: self.device.clone(),
273            block_table: RwLock::new(self.block_table.read().clone()),
274            num_tokens: RwLock::new(*self.num_tokens.read()),
275            num_layers: self.num_layers,
276            num_heads: self.num_heads,
277            head_dim: self.head_dim,
278            block_size: self.block_size,
279            last_access: RwLock::new(Instant::now()),
280            has_cow_refs: RwLock::new(true),
281            ref_count: AtomicU64::new(1),
282        }))
283    }
284
285    fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
286        let tokens = *self.num_tokens.read();
287        let blocks = self.num_blocks();
288        let bytes_per_token = 2 * self.num_layers * self.num_heads * self.head_dim * 2; // K+V, FP16
289
290        ferrum_interfaces::kv_cache::CacheHandleStats {
291            memory_bytes: blocks * self.block_size * bytes_per_token,
292            blocks_allocated: blocks,
293            tokens_stored: tokens,
294            utilization: if blocks > 0 {
295                tokens as f32 / (blocks * self.block_size) as f32
296            } else {
297                0.0
298            },
299            last_access: *self.last_access.read(),
300        }
301    }
302
303    fn is_valid(&self) -> bool {
304        self.ref_count() > 0
305    }
306
307    fn cache_id(&self) -> String {
308        format!("paged-{}", self.request_id)
309    }
310}
311
312/// Paged KV cache manager
313pub struct PagedKvCacheManager {
314    /// Configuration
315    config: PagedKvCacheConfig,
316    /// GPU block pool
317    gpu_pool: BlockPool,
318    /// CPU block pool (for swapping)
319    cpu_pool: Option<BlockPool>,
320    /// Active handles
321    active_handles: RwLock<HashMap<RequestId, Arc<PagedKvCacheHandle>>>,
322    /// Block to request mapping (for eviction)
323    block_to_request: RwLock<HashMap<PhysicalBlockId, RequestId>>,
324    /// Swapped out blocks (GPU block ID -> CPU block ID)
325    swapped_blocks: RwLock<HashMap<PhysicalBlockId, PhysicalBlockId>>,
326    /// Prefix cache for shared prompts
327    prefix_cache: Option<PrefixCache>,
328    /// Statistics
329    stats: Mutex<CacheManagerStats>,
330    /// Pressure callback
331    #[allow(clippy::type_complexity)]
332    pressure_callback: Mutex<Option<Box<dyn Fn(MemoryPressure) + Send + Sync>>>,
333}
334
335impl PagedKvCacheManager {
336    /// Create new paged KV cache manager
337    pub fn new(device: Device, config: PagedKvCacheConfig) -> Result<Self> {
338        info!(
339            "Creating paged KV cache manager: device={:?}, block_size={}, max_gpu_blocks={}, max_cpu_blocks={}, prefix_cache={}",
340            device, config.block_size, config.max_gpu_blocks, config.max_cpu_blocks, config.enable_prefix_cache
341        );
342
343        let storage_config = BlockStorageConfig {
344            num_layers: config.num_layers,
345            num_kv_heads: config.num_heads,
346            head_dim: config.head_dim,
347            block_size: config.block_size,
348        };
349
350        let gpu_pool = BlockPool::new_with_storage(
351            device.clone(),
352            config.block_size,
353            DataType::FP16,
354            config.max_gpu_blocks,
355            storage_config,
356        )?;
357
358        let cpu_pool = if config.enable_swapping {
359            Some(BlockPool::new_with_storage(
360                Device::CPU,
361                config.block_size,
362                DataType::FP16,
363                config.max_cpu_blocks,
364                storage_config,
365            )?)
366        } else {
367            None
368        };
369
370        let prefix_cache = if config.enable_prefix_cache {
371            Some(PrefixCache::new(
372                config.max_prefixes,
373                config.min_prefix_length,
374            ))
375        } else {
376            None
377        };
378
379        Ok(Self {
380            config,
381            gpu_pool,
382            cpu_pool,
383            active_handles: RwLock::new(HashMap::new()),
384            block_to_request: RwLock::new(HashMap::new()),
385            swapped_blocks: RwLock::new(HashMap::new()),
386            prefix_cache,
387            stats: Mutex::new(CacheManagerStats {
388                total_memory_bytes: 0,
389                used_memory_bytes: 0,
390                active_caches: 0,
391                total_blocks: 0,
392                free_blocks: 0,
393                cache_hit_rate: 0.0,
394                eviction_count: 0,
395                allocation_count: 0,
396                allocation_failures: 0,
397            }),
398            pressure_callback: Mutex::new(None),
399        })
400    }
401
402    /// Create with default config
403    pub fn with_defaults(device: Device, block_size: usize, max_blocks: usize) -> Result<Self> {
404        let config = PagedKvCacheConfig {
405            block_size,
406            max_gpu_blocks: max_blocks,
407            max_cpu_blocks: max_blocks / 2,
408            ..Default::default()
409        };
410        Self::new(device, config)
411    }
412
413    /// Allocate blocks for a sequence
414    pub fn allocate_blocks(
415        &self,
416        handle: &PagedKvCacheHandle,
417        num_blocks: usize,
418    ) -> Result<Vec<PhysicalBlockId>> {
419        let mut allocated = Vec::with_capacity(num_blocks);
420        let current_blocks = handle.num_blocks();
421
422        for i in 0..num_blocks {
423            let allocation = self.gpu_pool.allocate()?;
424            let physical_id = allocation.physical_id;
425
426            // Map logical to physical
427            let logical_id = (current_blocks + i) as u32;
428            handle.add_block(logical_id, physical_id.0);
429
430            // Track block ownership
431            self.block_to_request
432                .write()
433                .insert(physical_id, handle.request_id.clone());
434
435            allocated.push(physical_id);
436        }
437
438        // Update stats
439        {
440            let mut stats = self.stats.lock();
441            stats.allocation_count += num_blocks as u64;
442        }
443
444        debug!(
445            "Allocated {} blocks for request {}: {:?}",
446            num_blocks, handle.request_id, allocated
447        );
448
449        Ok(allocated)
450    }
451
452    /// Free blocks for a sequence
453    pub fn free_blocks(&self, block_ids: &[PhysicalBlockId]) -> Result<()> {
454        for &block_id in block_ids {
455            self.gpu_pool.deallocate(block_id)?;
456            self.block_to_request.write().remove(&block_id);
457        }
458
459        debug!("Freed {} blocks", block_ids.len());
460        Ok(())
461    }
462
463    /// Write one token's K/V vectors for a given layer and absolute token position.
464    ///
465    /// The position is translated through the handle's block table to find the
466    /// physical block and slot within that block.
467    pub fn write_kv(
468        &self,
469        handle: &PagedKvCacheHandle,
470        layer: usize,
471        token_position: usize,
472        key: &[f32],
473        value: &[f32],
474    ) -> Result<()> {
475        let block_size = self.config.block_size;
476        let logical_block = token_position / block_size;
477        let slot = token_position % block_size;
478
479        let physical_id = handle
480            .get_physical_block(logical_block as u32)
481            .ok_or_else(|| {
482                FerrumError::internal(format!(
483                    "No physical block for logical block {} (token {})",
484                    logical_block, token_position
485                ))
486            })?;
487
488        self.gpu_pool
489            .write_kv_slot(PhysicalBlockId::new(physical_id), layer, slot, key, value)
490    }
491
492    /// Read K/V vectors for a range of token positions in one layer.
493    ///
494    /// Gathers data across potentially non-contiguous physical blocks using
495    /// the handle's block table. Returns `(keys, values)` each of length
496    /// `num_tokens * num_kv_heads * head_dim`, with tokens in order.
497    pub fn read_kv(
498        &self,
499        handle: &PagedKvCacheHandle,
500        layer: usize,
501        start_token: usize,
502        num_tokens: usize,
503    ) -> Result<(Vec<f32>, Vec<f32>)> {
504        let block_size = self.config.block_size;
505        let kv_size = self.config.num_heads * self.config.head_dim;
506        let mut keys = Vec::with_capacity(num_tokens * kv_size);
507        let mut values = Vec::with_capacity(num_tokens * kv_size);
508
509        for pos in start_token..start_token + num_tokens {
510            let logical_block = pos / block_size;
511            let slot = pos % block_size;
512
513            let physical_id = handle
514                .get_physical_block(logical_block as u32)
515                .ok_or_else(|| {
516                    FerrumError::internal(format!(
517                        "No physical block for logical block {} (token {})",
518                        logical_block, pos
519                    ))
520                })?;
521
522            let (k, v) =
523                self.gpu_pool
524                    .read_kv_slot(PhysicalBlockId::new(physical_id), layer, slot)?;
525            keys.extend_from_slice(&k);
526            values.extend_from_slice(&v);
527        }
528
529        Ok((keys, values))
530    }
531
532    /// Get a reference to the GPU block pool.
533    pub fn gpu_pool(&self) -> &BlockPool {
534        &self.gpu_pool
535    }
536
537    /// Get a reference to the prefix cache (if enabled).
538    pub fn prefix_cache(&self) -> Option<&PrefixCache> {
539        self.prefix_cache.as_ref()
540    }
541
542    /// Share the first `num_prefix_blocks` physical blocks from `source` into
543    /// `target`.  The shared blocks get an incremented ref count so they
544    /// survive when the source handle is deallocated.
545    ///
546    /// After this call the target handle's block table maps logical blocks
547    /// `0..num_prefix_blocks` to the same physical blocks as the source.
548    pub fn share_prefix_blocks(
549        &self,
550        source: &PagedKvCacheHandle,
551        target: &PagedKvCacheHandle,
552        num_prefix_blocks: usize,
553    ) -> Result<()> {
554        let source_blocks = source.get_physical_blocks();
555        let n = num_prefix_blocks.min(source_blocks.len());
556
557        for i in 0..n {
558            let phys_id = source_blocks[i];
559            // Map in target
560            target.add_block(i as u32, phys_id);
561            // Increment ref count on the physical block so it isn't freed
562            // when the source handle is deallocated.
563            let pid = PhysicalBlockId::new(phys_id);
564            if let Some(block) = self.gpu_pool.get_block(pid) {
565                block.write().add_ref();
566            }
567        }
568
569        debug!(
570            "Shared {} prefix blocks from {} to {}",
571            n, source.request_id, target.request_id
572        );
573
574        Ok(())
575    }
576
577    /// Swap out blocks to CPU
578    pub fn swap_out(&self, block_ids: &[PhysicalBlockId]) -> Result<Vec<PhysicalBlockId>> {
579        let cpu_pool = self
580            .cpu_pool
581            .as_ref()
582            .ok_or_else(|| FerrumError::unsupported("Swapping not enabled"))?;
583
584        let mut swapped = Vec::with_capacity(block_ids.len());
585        let mut swap_map = self.swapped_blocks.write();
586
587        for &gpu_block in block_ids {
588            // Allocate CPU block
589            let cpu_allocation = cpu_pool.allocate()?;
590            let cpu_block = cpu_allocation.physical_id;
591
592            // TODO: Actually copy data from GPU to CPU
593            // This requires tensor memory access which is backend-specific
594
595            swap_map.insert(gpu_block, cpu_block);
596            swapped.push(cpu_block);
597
598            // Free GPU block
599            self.gpu_pool.deallocate(gpu_block)?;
600        }
601
602        debug!("Swapped out {} blocks to CPU", swapped.len());
603        Ok(swapped)
604    }
605
606    /// Swap in blocks from CPU
607    pub fn swap_in(&self, cpu_block_ids: &[PhysicalBlockId]) -> Result<Vec<PhysicalBlockId>> {
608        let cpu_pool = self
609            .cpu_pool
610            .as_ref()
611            .ok_or_else(|| FerrumError::unsupported("Swapping not enabled"))?;
612
613        let mut swapped = Vec::with_capacity(cpu_block_ids.len());
614        let mut swap_map = self.swapped_blocks.write();
615
616        for &cpu_block in cpu_block_ids {
617            // Allocate GPU block
618            let gpu_allocation = self.gpu_pool.allocate()?;
619            let gpu_block = gpu_allocation.physical_id;
620
621            // TODO: Actually copy data from CPU to GPU
622
623            // Find and remove the mapping
624            let gpu_original = swap_map
625                .iter()
626                .find(|(_, &cpu)| cpu == cpu_block)
627                .map(|(&gpu, _)| gpu);
628
629            if let Some(orig_gpu) = gpu_original {
630                swap_map.remove(&orig_gpu);
631            }
632
633            swapped.push(gpu_block);
634
635            // Free CPU block
636            cpu_pool.deallocate(cpu_block)?;
637        }
638
639        debug!("Swapped in {} blocks from CPU", swapped.len());
640        Ok(swapped)
641    }
642
643    /// Check memory pressure
644    pub fn check_pressure(&self) -> MemoryPressure {
645        let gpu_stats = self.gpu_pool.stats();
646        let free_ratio = gpu_stats.free_blocks as f32 / gpu_stats.max_blocks.max(1) as f32;
647
648        if free_ratio < self.config.high_watermark {
649            MemoryPressure::Critical
650        } else if free_ratio < self.config.low_watermark {
651            MemoryPressure::High
652        } else {
653            MemoryPressure::Low
654        }
655    }
656
657    /// Trigger pressure callback if registered
658    fn notify_pressure(&self, pressure: MemoryPressure) {
659        if let Some(ref callback) = *self.pressure_callback.lock() {
660            callback(pressure);
661        }
662    }
663
664    /// Get free block count
665    pub fn free_block_count(&self) -> usize {
666        self.gpu_pool.stats().free_blocks
667    }
668
669    /// Get total block count
670    pub fn total_blocks(&self) -> usize {
671        self.gpu_pool.stats().total_blocks
672    }
673
674    /// Copy-on-write: copy blocks when a shared reference is modified
675    pub fn cow_copy(&self, handle: &PagedKvCacheHandle, block_ids: &[u32]) -> Result<Vec<u32>> {
676        if !self.config.enable_cow {
677            return Err(FerrumError::unsupported("COW not enabled"));
678        }
679
680        let mut new_blocks = Vec::with_capacity(block_ids.len());
681
682        for &_old_physical in block_ids {
683            // Allocate new block
684            let allocation = self.gpu_pool.allocate()?;
685            let new_physical = allocation.physical_id;
686
687            // TODO: Copy data from old block to new block
688            // This requires tensor memory access
689
690            new_blocks.push(new_physical.0);
691
692            // Update block ownership
693            self.block_to_request
694                .write()
695                .insert(new_physical, handle.request_id.clone());
696        }
697
698        debug!("COW copied {} blocks", new_blocks.len());
699        Ok(new_blocks)
700    }
701
702    // ==========================================================================
703    // Prefix Caching Methods
704    // ==========================================================================
705
706    /// Find a cached prefix that matches the given tokens
707    /// Returns (prefix_id, kv_handle, last_logits, matched_length) if found
708    pub fn find_prefix(
709        &self,
710        tokens: &[ferrum_types::TokenId],
711    ) -> Option<(
712        PrefixId,
713        Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
714        Vec<f32>,
715        usize,
716    )> {
717        let prefix_cache = self.prefix_cache.as_ref()?;
718
719        if let Some((prefix_id, kv_handle, last_logits)) = prefix_cache.find_prefix(tokens) {
720            let matched_len = prefix_id.len();
721            debug!("Prefix cache hit: matched {} tokens", matched_len);
722
723            // Update hit rate stats
724            {
725                let mut stats = self.stats.lock();
726                let total = stats.allocation_count as f32;
727                if total > 0.0 {
728                    stats.cache_hit_rate = (stats.cache_hit_rate * (total - 1.0) + 1.0) / total;
729                }
730            }
731
732            Some((prefix_id, kv_handle, last_logits, matched_len))
733        } else {
734            None
735        }
736    }
737
738    /// Store a prefix in the cache for future reuse
739    pub fn store_prefix(
740        &self,
741        tokens: &[ferrum_types::TokenId],
742        kv_handle: Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
743        last_logits: Vec<f32>,
744    ) -> Result<()> {
745        if let Some(prefix_cache) = &self.prefix_cache {
746            prefix_cache.store_prefix(tokens, kv_handle, last_logits)?;
747            debug!("Stored prefix with {} tokens in cache", tokens.len());
748        }
749        Ok(())
750    }
751
752    /// Get prefix cache statistics
753    pub fn prefix_cache_stats(&self) -> Option<PrefixCacheStats> {
754        self.prefix_cache.as_ref().map(|pc| pc.stats())
755    }
756
757    /// Evict oldest prefixes from cache
758    pub fn evict_prefixes(&self, count: usize) -> usize {
759        if let Some(prefix_cache) = &self.prefix_cache {
760            let evicted = prefix_cache.evict_n(count);
761            if evicted > 0 {
762                debug!("Evicted {} prefixes from cache", evicted);
763            }
764            evicted
765        } else {
766            0
767        }
768    }
769
770    /// Clear all cached prefixes
771    pub fn clear_prefix_cache(&self) {
772        if let Some(prefix_cache) = &self.prefix_cache {
773            prefix_cache.clear();
774            debug!("Cleared prefix cache");
775        }
776    }
777}
778
779#[async_trait]
780impl KvCacheManager for PagedKvCacheManager {
781    async fn allocate(&self, request: &AllocationRequest) -> Result<Arc<dyn KvCacheHandle>> {
782        debug!(
783            "Allocating paged KV cache for request: {:?}",
784            request.request_id
785        );
786
787        // Check pressure before allocation
788        let pressure = self.check_pressure();
789        if matches!(pressure, MemoryPressure::Critical) {
790            self.notify_pressure(pressure);
791            // Try to evict some blocks
792            let _ = self.gc().await;
793        }
794
795        // Create handle
796        let handle = Arc::new(PagedKvCacheHandle::new(
797            request.request_id.clone(),
798            request.device.clone(),
799            self.config.block_size,
800            request.num_layers,
801            request.num_heads,
802            request.head_dim,
803        ));
804
805        // Allocate initial blocks
806        let initial_blocks = handle.required_blocks(request.initial_tokens);
807        if initial_blocks > 0 {
808            self.allocate_blocks(&handle, initial_blocks)?;
809        }
810
811        handle.set_num_tokens(request.initial_tokens);
812
813        // Store handle
814        self.active_handles
815            .write()
816            .insert(request.request_id.clone(), handle.clone());
817
818        // Update stats
819        {
820            let mut stats = self.stats.lock();
821            stats.active_caches += 1;
822            stats.allocation_count += 1;
823        }
824
825        Ok(handle)
826    }
827
828    async fn extend(&self, handle: &mut dyn KvCacheHandle, additional_tokens: usize) -> Result<()> {
829        let paged_handle = handle
830            .as_any()
831            .downcast_ref::<PagedKvCacheHandle>()
832            .ok_or_else(|| FerrumError::internal("Invalid handle type"))?;
833
834        let current_tokens = paged_handle.num_tokens();
835        let new_tokens = current_tokens + additional_tokens;
836        let current_blocks = paged_handle.num_blocks();
837        let required_blocks = paged_handle.required_blocks(new_tokens);
838
839        if required_blocks > current_blocks {
840            let new_blocks = required_blocks - current_blocks;
841
842            // Check if this is a COW reference that needs copying
843            if paged_handle.is_cow() && paged_handle.ref_count() > 1 {
844                // Need to copy existing blocks first
845                let existing = paged_handle.get_physical_blocks();
846                let _new_physical = self.cow_copy(paged_handle, &existing)?;
847                // Update the handle's block table with new physical IDs
848                // (In a real implementation, this would update the mappings)
849            }
850
851            self.allocate_blocks(paged_handle, new_blocks)?;
852        }
853
854        paged_handle.set_num_tokens(new_tokens);
855
856        debug!(
857            "Extended KV cache for {}: {} -> {} tokens",
858            paged_handle.request_id, current_tokens, new_tokens
859        );
860
861        Ok(())
862    }
863
864    async fn deallocate(&self, request_id: RequestId) -> Result<()> {
865        debug!("Deallocating paged KV cache for request: {:?}", request_id);
866
867        let handle = self.active_handles.write().remove(&request_id);
868
869        if let Some(handle) = handle {
870            // Check reference count
871            if handle.ref_count() > 1 {
872                // Don't free blocks, just decrement ref count
873                handle.remove_ref();
874                debug!(
875                    "Decremented ref count for {}, remaining: {}",
876                    request_id,
877                    handle.ref_count()
878                );
879                return Ok(());
880            }
881
882            // Free all blocks
883            let block_ids: Vec<PhysicalBlockId> = handle
884                .get_physical_blocks()
885                .into_iter()
886                .map(PhysicalBlockId)
887                .collect();
888
889            for block_id in block_ids {
890                let _ = self.gpu_pool.deallocate(block_id);
891                self.block_to_request.write().remove(&block_id);
892            }
893
894            // Update stats
895            {
896                let mut stats = self.stats.lock();
897                if stats.active_caches > 0 {
898                    stats.active_caches -= 1;
899                }
900            }
901        }
902
903        Ok(())
904    }
905
906    fn can_allocate(&self, request: &AllocationRequest) -> bool {
907        let required_blocks = request.initial_tokens.div_ceil(self.config.block_size);
908        let gpu_stats = self.gpu_pool.stats();
909
910        gpu_stats.free_blocks >= required_blocks
911            || gpu_stats.total_blocks + required_blocks <= gpu_stats.max_blocks
912    }
913
914    fn stats(&self) -> CacheManagerStats {
915        let gpu_stats = self.gpu_pool.stats();
916        let mut stats = self.stats.lock().clone();
917
918        stats.total_blocks = gpu_stats.max_blocks;
919        stats.free_blocks = gpu_stats.free_blocks;
920
921        // Calculate memory usage (rough estimate)
922        let bytes_per_block = self.config.block_size
923            * 2 // K + V
924            * self.config.num_layers
925            * self.config.num_heads
926            * self.config.head_dim
927            * 2; // FP16
928
929        stats.total_memory_bytes = gpu_stats.max_blocks * bytes_per_block;
930        stats.used_memory_bytes = gpu_stats.allocated_blocks * bytes_per_block;
931
932        stats
933    }
934
935    async fn gc(&self) -> Result<CacheGcStats> {
936        let start = Instant::now();
937
938        // Evict unused blocks
939        let evicted = self.gpu_pool.evict_blocks(10)?;
940
941        // Update stats
942        {
943            let mut stats = self.stats.lock();
944            stats.eviction_count += evicted.len() as u64;
945        }
946
947        Ok(CacheGcStats {
948            memory_freed: evicted.len() * self.config.block_size * 1024, // Rough estimate
949            caches_freed: 0,
950            gc_time_ms: start.elapsed().as_millis() as u64,
951        })
952    }
953
954    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
955        *self.pressure_callback.lock() = Some(callback);
956    }
957
958    fn get_handle(&self, request_id: RequestId) -> Option<Arc<dyn KvCacheHandle>> {
959        self.active_handles
960            .read()
961            .get(&request_id)
962            .map(|h| h.clone() as Arc<dyn KvCacheHandle>)
963    }
964
965    fn list_handles(&self) -> Vec<(RequestId, Arc<dyn KvCacheHandle>)> {
966        self.active_handles
967            .read()
968            .iter()
969            .map(|(id, handle)| (id.clone(), handle.clone() as Arc<dyn KvCacheHandle>))
970            .collect()
971    }
972}
973
974impl std::fmt::Debug for PagedKvCacheManager {
975    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
976        let gpu_stats = self.gpu_pool.stats();
977        f.debug_struct("PagedKvCacheManager")
978            .field("block_size", &self.config.block_size)
979            .field("total_gpu_blocks", &gpu_stats.total_blocks)
980            .field("free_gpu_blocks", &gpu_stats.free_blocks)
981            .field("active_handles", &self.active_handles.read().len())
982            .finish()
983    }
984}
985
986// ============================================================================
987// Tests
988// ============================================================================
989
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU/FP16 allocation request for 64 initial tokens with a
    /// freshly generated request id.
    fn create_test_request() -> AllocationRequest {
        AllocationRequest {
            request_id: RequestId::new(),
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
            initial_tokens: 64,
            max_sequence_length: 2048,
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
        }
    }

    /// Constructing a manager with default settings succeeds.
    #[tokio::test]
    async fn test_manager_creation() {
        let created = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100);
        assert!(created.is_ok());
    }

    /// A request can be allocated, reports its tokens, and can be released.
    #[tokio::test]
    async fn test_allocate_and_deallocate() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
        let request = create_test_request();
        let id = request.request_id.clone();

        let handle = manager.allocate(&request).await.unwrap();
        assert!(handle.is_valid());
        assert_eq!(handle.num_tokens(), 64);

        // 64 tokens at block size 16 would be 4 blocks, but the check is
        // deliberately loose: either some block is accounted for or the
        // tokens are.
        let usage = handle.stats();
        let has_blocks = usage.blocks_allocated >= 1;
        let has_tokens = usage.tokens_stored >= 64;
        assert!(has_blocks || has_tokens);

        manager.deallocate(id).await.unwrap();
    }

    /// Allocating extra blocks on an existing handle raises its block count.
    #[tokio::test]
    async fn test_extend() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();
        let request = create_test_request();
        let id = request.request_id.clone();

        let handle = manager.allocate(&request).await.unwrap();
        let blocks_before = handle.stats().blocks_allocated;

        // Grow the cache by four blocks through the concrete handle type.
        let looked_up = manager.get_handle(id.clone()).unwrap();
        let concrete = looked_up
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();
        manager.allocate_blocks(concrete, 4).unwrap();

        let blocks_after = handle.stats().blocks_allocated;
        assert!(blocks_after > blocks_before);

        manager.deallocate(id).await.unwrap();
    }

    /// A fresh pool accepts a request; after many allocations it is no longer
    /// completely free.
    #[tokio::test]
    async fn test_can_allocate() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 10).unwrap();

        let first = create_test_request();
        assert!(manager.can_allocate(&first));

        // Keep allocating; individual failures are expected and ignored.
        for _ in 0..8 {
            let extra = create_test_request();
            let _ = manager.allocate(&extra).await;
        }

        // The pool must have handed out at least one block by now.
        let snapshot = manager.stats();
        assert!(snapshot.free_blocks < snapshot.total_blocks);
    }

    /// GC runs cleanly after an allocate/deallocate cycle.
    #[tokio::test]
    async fn test_gc() {
        let manager = PagedKvCacheManager::with_defaults(Device::CPU, 16, 100).unwrap();

        // One full allocate/deallocate round trip before collecting.
        let request = create_test_request();
        let id = request.request_id.clone();
        let _ = manager.allocate(&request).await.unwrap();
        manager.deallocate(id).await.unwrap();

        let report = manager.gc().await.unwrap();
        assert_eq!(report.caches_freed, 0);
    }

    /// A new handle starts empty and records logical-to-physical mappings.
    #[test]
    fn test_paged_handle() {
        let handle = PagedKvCacheHandle::new(RequestId::new(), Device::CPU, 16, 32, 32, 128);

        assert_eq!(handle.num_tokens(), 0);
        assert_eq!(handle.num_blocks(), 0);

        // Map logical blocks 0 and 1 onto physical blocks 5 and 10.
        handle.add_block(0, 5);
        handle.add_block(1, 10);

        assert_eq!(handle.num_blocks(), 2);
        assert_eq!(handle.get_physical_block(0), Some(5));
        assert_eq!(handle.get_physical_block(1), Some(10));
    }

    /// KV data written token by token reads back in order across a block
    /// boundary, and an untouched layer stays zeroed.
    #[tokio::test]
    async fn test_write_read_kv_across_blocks() {
        // Tiny model: 2 layers, 2 heads, head_dim 4, 4 tokens per block.
        let config = PagedKvCacheConfig {
            block_size: 4,
            num_layers: 2,
            num_heads: 2,
            head_dim: 4,
            max_gpu_blocks: 16,
            max_cpu_blocks: 0,
            enable_cow: false,
            enable_swapping: false,
            enable_prefix_cache: false,
            ..Default::default()
        };
        let manager = PagedKvCacheManager::new(Device::CPU, config).unwrap();

        let request = AllocationRequest {
            request_id: RequestId::new(),
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
            initial_tokens: 6, // ceil(6 / 4) = 2 blocks
            max_sequence_length: 32,
            num_layers: 2,
            num_heads: 2,
            head_dim: 4,
        };
        let id = request.request_id.clone();

        let allocated = manager.allocate(&request).await.unwrap();
        let handle = allocated
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();

        let kv_size = 2 * 4; // num_heads * head_dim = 8

        // Tokens 0-3 land in the first block, tokens 4-5 in the second.
        // Keys are `pos * 100 + i`; values are the same shifted by 50.
        for pos in 0..6 {
            let key: Vec<f32> = (0..kv_size).map(|i| (pos * 100 + i) as f32).collect();
            let val: Vec<f32> = (0..kv_size).map(|i| (pos * 100 + i + 50) as f32).collect();
            manager.write_kv(handle, 0, pos, &key, &val).unwrap();
        }

        // One read spanning both blocks.
        let (keys, vals) = manager.read_kv(handle, 0, 0, 6).unwrap();
        assert_eq!(keys.len(), 6 * kv_size);
        assert_eq!(vals.len(), 6 * kv_size);

        // Token 0: first and last key element.
        assert_eq!(keys[0], 0.0);
        assert_eq!(keys[kv_size - 1], 7.0);

        // Token 4 is the first token stored in the second block.
        assert_eq!(keys[4 * kv_size], 400.0);

        // Token 5: first value element.
        assert_eq!(vals[5 * kv_size], 550.0);

        // Nothing was written to layer 1, so it must read back as zeros.
        let (layer1_keys, _) = manager.read_kv(handle, 1, 0, 1).unwrap();
        assert!(!layer1_keys.iter().any(|&x| x != 0.0));

        manager.deallocate(id).await.unwrap();
    }

    /// `required_blocks` rounds token counts up to whole 16-token blocks.
    #[test]
    fn test_required_blocks() {
        let handle = PagedKvCacheHandle::new(
            RequestId::new(),
            Device::CPU,
            16, // block size
            32,
            32,
            128,
        );

        // (token count, expected block count)
        for (tokens, expected) in [(0, 0), (16, 1), (17, 2), (32, 2), (33, 3)] {
            assert_eq!(handle.required_blocks(tokens), expected);
        }
    }
}