// llama_rs/model/paged.rs

//! Paged attention for efficient KV cache memory management
//!
//! Implements vLLM-style paged attention: the KV cache is divided into
//! fixed-size blocks that are allocated on demand and can be shared across
//! sequences (for prefix caching with copy-on-write).
6
7use std::sync::atomic::{AtomicUsize, Ordering};
8
/// Physical block ID: an index into the pool's block storage.
pub type BlockId = usize;

/// Default block size in tokens (each block stores this many KV entries per head).
pub const DEFAULT_BLOCK_SIZE: usize = 16;

/// Manages a fixed pool of physical blocks with per-block reference counts.
///
/// A block is on the free list exactly when its reference count is zero.
/// [`allocate`](Self::allocate) hands a block out with a count of one,
/// [`increment_ref`](Self::increment_ref) adds a sharer, and
/// [`free`](Self::free) drops one reference, returning the block to the free
/// list when the last reference goes away.
#[derive(Debug)]
pub struct PageAllocator {
    /// Total number of physical blocks
    num_blocks: usize,
    /// Free block IDs, used as a LIFO stack for O(1) alloc/free
    free_blocks: Vec<BlockId>,
    /// Reference count per block (for copy-on-write sharing).
    /// Atomic so that `increment_ref` can take `&self`.
    ref_counts: Vec<AtomicUsize>,
}

impl PageAllocator {
    /// Create a new page allocator with all `num_blocks` blocks initially free.
    pub fn new(num_blocks: usize) -> Self {
        Self {
            num_blocks,
            free_blocks: (0..num_blocks).collect(),
            ref_counts: (0..num_blocks).map(|_| AtomicUsize::new(0)).collect(),
        }
    }

    /// Allocate a block from the free list with a reference count of one.
    /// Returns `None` if no blocks are available.
    pub fn allocate(&mut self) -> Option<BlockId> {
        let block_id = self.free_blocks.pop()?;
        self.ref_counts[block_id].store(1, Ordering::SeqCst);
        Some(block_id)
    }

    /// Drop one reference to `block_id`; the block returns to the free list
    /// when its count reaches zero.
    ///
    /// Out-of-range IDs and blocks that are already free (count zero) are
    /// ignored. Decrementing a zero count would wrap the counter to
    /// `usize::MAX` and corrupt the allocator's bookkeeping.
    pub fn free(&mut self, block_id: BlockId) {
        if block_id >= self.num_blocks {
            return;
        }
        // `&mut self` guarantees exclusive access, so no atomic RMW is needed.
        let count = self.ref_counts[block_id].get_mut();
        match *count {
            // Double free, or a block that was never allocated: nothing to release.
            0 => {}
            1 => {
                *count = 0;
                self.free_blocks.push(block_id);
            }
            _ => *count -= 1,
        }
    }

    /// Increment the reference count of an allocated block (copy-on-write sharing).
    ///
    /// Has no effect on out-of-range IDs or on free blocks. Bumping the count
    /// of a block that is still on the free list would let `allocate` hand it
    /// out again while a sharer still points at it.
    pub fn increment_ref(&self, block_id: BlockId) {
        if let Some(counter) = self.ref_counts.get(block_id) {
            // Check-and-increment as one atomic step; a zero count is left alone.
            let _ = counter.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                if count == 0 {
                    None
                } else {
                    Some(count + 1)
                }
            });
        }
    }

    /// Current reference count of a block (0 for free or out-of-range IDs).
    pub fn ref_count(&self, block_id: BlockId) -> usize {
        self.ref_counts
            .get(block_id)
            .map_or(0, |count| count.load(Ordering::SeqCst))
    }

    /// Number of free blocks available.
    pub fn num_free(&self) -> usize {
        self.free_blocks.len()
    }

    /// Number of blocks currently in use (ref count > 0).
    pub fn num_used(&self) -> usize {
        self.num_blocks - self.free_blocks.len()
    }
}
80
81/// Per-sequence mapping from logical to physical blocks.
82pub struct BlockTable {
83    /// Logical block index -> physical BlockId
84    entries: Vec<Option<BlockId>>,
85    /// Number of tokens stored
86    num_tokens: usize,
87    /// Block size
88    block_size: usize,
89}
90
91impl BlockTable {
92    /// Create an empty block table.
93    pub fn new(block_size: usize) -> Self {
94        Self {
95            entries: Vec::new(),
96            num_tokens: 0,
97            block_size,
98        }
99    }
100
101    /// Number of allocated blocks.
102    pub fn num_blocks(&self) -> usize {
103        self.entries.len()
104    }
105
106    /// Number of tokens stored.
107    pub fn num_tokens(&self) -> usize {
108        self.num_tokens
109    }
110
111    /// Map logical block index to physical block ID.
112    pub fn logical_to_physical(&self, logical_idx: usize) -> Option<BlockId> {
113        self.entries.get(logical_idx).and_then(|e| *e)
114    }
115
116    /// Add a new block mapping.
117    pub fn append_block(&mut self, block_id: BlockId) {
118        self.entries.push(Some(block_id));
119    }
120
121    /// Convert token position to (logical_block_idx, offset_within_block).
122    pub fn token_to_block(&self, token_pos: usize) -> (usize, usize) {
123        if self.block_size == 0 {
124            return (0, 0);
125        }
126        let logical_block_idx = token_pos / self.block_size;
127        let offset_within_block = token_pos % self.block_size;
128        (logical_block_idx, offset_within_block)
129    }
130
131    /// Update the token count.
132    pub fn set_num_tokens(&mut self, n: usize) {
133        self.num_tokens = n;
134    }
135}
136
137/// The main memory pool that holds all KV data.
138pub struct PagedKVPool {
139    /// Physical KV data per layer: flat storage [block_id][head][offset][dim]
140    /// Layout: block_id * block_stride + head * (block_size * head_dim) + offset * head_dim + d
141    k_pool: Vec<Vec<f32>>,
142    v_pool: Vec<Vec<f32>>,
143    /// Page allocator
144    allocator: PageAllocator,
145    /// Configuration
146    num_layers: usize,
147    num_kv_heads: usize,
148    head_dim: usize,
149    block_size: usize,
150    num_blocks: usize,
151}
152
153impl PagedKVPool {
154    /// Size of one block in floats (all heads).
155    fn block_stride(&self) -> usize {
156        self.num_kv_heads * self.block_size * self.head_dim
157    }
158
159    /// Offset for a single (block_id, offset, head) position.
160    fn block_offset(&self, block_id: BlockId, offset: usize, head: usize) -> usize {
161        block_id * self.block_stride() + head * (self.block_size * self.head_dim) + offset * self.head_dim
162    }
163
164    /// Create a new paged KV pool.
165    pub fn new(
166        num_layers: usize,
167        num_kv_heads: usize,
168        head_dim: usize,
169        block_size: usize,
170        num_blocks: usize,
171    ) -> Self {
172        let block_stride = num_kv_heads * block_size * head_dim;
173        let layer_size = num_blocks * block_stride;
174
175        let k_pool: Vec<Vec<f32>> = (0..num_layers)
176            .map(|_| vec![0.0; layer_size])
177            .collect();
178        let v_pool: Vec<Vec<f32>> = (0..num_layers)
179            .map(|_| vec![0.0; layer_size])
180            .collect();
181
182        Self {
183            k_pool,
184            v_pool,
185            allocator: PageAllocator::new(num_blocks),
186            num_layers,
187            num_kv_heads,
188            head_dim,
189            block_size,
190            num_blocks,
191        }
192    }
193
194    /// Allocate N blocks from the pool.
195    pub fn allocate_blocks(&mut self, count: usize) -> Option<Vec<BlockId>> {
196        let mut blocks = Vec::with_capacity(count);
197        for _ in 0..count {
198            let block_id = self.allocator.allocate()?;
199            blocks.push(block_id);
200        }
201        Some(blocks)
202    }
203
204    /// Free blocks back to the pool.
205    pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
206        for &block_id in block_ids {
207            self.allocator.free(block_id);
208        }
209    }
210
211    /// Write one KV position. k and v must have length head_dim.
212    pub fn write_kv(
213        &mut self,
214        layer: usize,
215        block_id: BlockId,
216        offset: usize,
217        head: usize,
218        k: &[f32],
219        v: &[f32],
220    ) {
221        if layer >= self.num_layers
222            || head >= self.num_kv_heads
223            || offset >= self.block_size
224            || k.len() != self.head_dim
225            || v.len() != self.head_dim
226        {
227            return;
228        }
229        let base = self.block_offset(block_id, offset, head);
230        self.k_pool[layer][base..base + self.head_dim].copy_from_slice(k);
231        self.v_pool[layer][base..base + self.head_dim].copy_from_slice(v);
232    }
233
234    /// Read K for one position. Returns slice of length head_dim.
235    pub fn read_k(
236        &self,
237        layer: usize,
238        block_id: BlockId,
239        offset: usize,
240        head: usize,
241    ) -> &[f32] {
242        if layer >= self.num_layers
243            || head >= self.num_kv_heads
244            || offset >= self.block_size
245        {
246            return &[];
247        }
248        let base = self.block_offset(block_id, offset, head);
249        &self.k_pool[layer][base..base + self.head_dim]
250    }
251
252    /// Read V for one position. Returns slice of length head_dim.
253    pub fn read_v(
254        &self,
255        layer: usize,
256        block_id: BlockId,
257        offset: usize,
258        head: usize,
259    ) -> &[f32] {
260        if layer >= self.num_layers
261            || head >= self.num_kv_heads
262            || offset >= self.block_size
263        {
264            return &[];
265        }
266        let base = self.block_offset(block_id, offset, head);
267        &self.v_pool[layer][base..base + self.head_dim]
268    }
269
270    /// Copy-on-write: copy all layer data from src block to dst block.
271    pub fn copy_block(&mut self, src: BlockId, dst: BlockId) {
272        let block_stride = self.block_stride();
273        let src_base = src * block_stride;
274        let dst_base = dst * block_stride;
275        for layer in 0..self.num_layers {
276            let src_slice = self.k_pool[layer][src_base..src_base + block_stride].to_vec();
277            self.k_pool[layer][dst_base..dst_base + block_stride].copy_from_slice(&src_slice);
278            let src_slice = self.v_pool[layer][src_base..src_base + block_stride].to_vec();
279            self.v_pool[layer][dst_base..dst_base + block_stride].copy_from_slice(&src_slice);
280        }
281    }
282
283    /// Total pool memory usage in bytes.
284    pub fn memory_usage(&self) -> usize {
285        let floats_per_layer = self.num_blocks * self.block_stride();
286        let total_floats = floats_per_layer * self.num_layers * 2; // K and V
287        total_floats * std::mem::size_of::<f32>()
288    }
289
290    /// Number of free blocks.
291    pub fn num_free_blocks(&self) -> usize {
292        self.allocator.num_free()
293    }
294
295    /// Total number of blocks.
296    pub fn total_blocks(&self) -> usize {
297        self.num_blocks
298    }
299
300    /// Expose allocator for PagedSequence (needs allocate/free/increment_ref).
301    #[allow(dead_code)]
302    pub(crate) fn allocator_mut(&mut self) -> &mut PageAllocator {
303        &mut self.allocator
304    }
305
306    /// Expose allocator for ref counting.
307    #[allow(dead_code)]
308    pub(crate) fn allocator(&self) -> &PageAllocator {
309        &self.allocator
310    }
311}
312
313/// Per-sequence state for paged attention.
314pub struct PagedSequence {
315    /// Block table mapping logical blocks to physical
316    pub block_table: BlockTable,
317    /// Sequence ID
318    pub seq_id: usize,
319    /// Current token count
320    pub num_tokens: usize,
321}
322
323impl PagedSequence {
324    /// Create a new paged sequence.
325    pub fn new(seq_id: usize, block_size: usize) -> Self {
326        Self {
327            block_table: BlockTable::new(block_size),
328            seq_id,
329            num_tokens: 0,
330        }
331    }
332
333    /// Append a KV entry for one (layer, head). Allocates a new block if needed.
334    pub fn append_token(
335        &mut self,
336        pool: &mut PagedKVPool,
337        layer: usize,
338        head: usize,
339        k: &[f32],
340        v: &[f32],
341    ) -> Result<(), &'static str> {
342        let (logical_block_idx, offset_within_block) =
343            self.block_table.token_to_block(self.num_tokens);
344
345        // Allocate new block if needed
346        while logical_block_idx >= self.block_table.num_blocks() {
347            let blocks = pool
348                .allocate_blocks(1)
349                .ok_or("No free blocks in pool")?;
350            let block_id = blocks[0];
351            self.block_table.append_block(block_id);
352        }
353
354        let block_id = self
355            .block_table
356            .logical_to_physical(logical_block_idx)
357            .ok_or("Missing block mapping")?;
358
359        pool.write_kv(layer, block_id, offset_within_block, head, k, v);
360
361        Ok(())
362    }
363
364    /// Advance to the next token position after writing all (layer, head) for the current token.
365    pub fn advance_token(&mut self) {
366        self.num_tokens += 1;
367        self.block_table.set_num_tokens(self.num_tokens);
368    }
369
370    /// Gather all K/V for the given layer and head into contiguous buffers for attention.
371    /// Returns (k_buf, v_buf) each of size num_tokens * head_dim.
372    pub fn get_kv_for_attention(
373        &self,
374        pool: &PagedKVPool,
375        layer: usize,
376        head: usize,
377    ) -> (Vec<f32>, Vec<f32>) {
378        let num_tokens = self.num_tokens;
379        let head_dim = pool.head_dim;
380
381        let mut k_buf = vec![0.0; num_tokens * head_dim];
382        let mut v_buf = vec![0.0; num_tokens * head_dim];
383
384        for token_pos in 0..num_tokens {
385            let (logical_block_idx, offset) = self.block_table.token_to_block(token_pos);
386            if let Some(block_id) = self.block_table.logical_to_physical(logical_block_idx) {
387                let k_slice = pool.read_k(layer, block_id, offset, head);
388                let v_slice = pool.read_v(layer, block_id, offset, head);
389                if k_slice.len() == head_dim && v_slice.len() == head_dim {
390                    k_buf[token_pos * head_dim..(token_pos + 1) * head_dim]
391                        .copy_from_slice(k_slice);
392                    v_buf[token_pos * head_dim..(token_pos + 1) * head_dim]
393                        .copy_from_slice(v_slice);
394                }
395            }
396        }
397
398        (k_buf, v_buf)
399    }
400
401}
402
403impl BlockTable {
404    /// Clear all block mappings (caller must free physical blocks separately).
405    pub fn clear(&mut self) {
406        self.entries.clear();
407        self.num_tokens = 0;
408    }
409}
410
411impl PagedSequence {
412    /// Release all blocks back to the pool.
413    pub fn free(&mut self, pool: &mut PagedKVPool) {
414        let block_ids: Vec<BlockId> = (0..self.block_table.num_blocks())
415            .filter_map(|i| self.block_table.logical_to_physical(i))
416            .collect();
417        pool.free_blocks(&block_ids);
418        self.block_table.clear();
419        self.num_tokens = 0;
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_page_allocator_basic() {
        let mut allocator = PageAllocator::new(4);
        assert_eq!((allocator.num_free(), allocator.num_used()), (4, 0));

        let first = allocator.allocate().unwrap();
        let second = allocator.allocate().unwrap();
        assert_eq!((allocator.num_free(), allocator.num_used()), (2, 2));
        for &id in &[first, second] {
            assert_eq!(allocator.ref_count(id), 1);
        }

        // Sharing bumps the count; the first free only drops the extra share.
        allocator.increment_ref(first);
        assert_eq!(allocator.ref_count(first), 2);

        allocator.free(first);
        assert_eq!(allocator.ref_count(first), 1);
        assert_eq!(allocator.num_free(), 2);

        // Releasing the last reference puts the block back on the free list.
        allocator.free(first);
        assert_eq!(allocator.ref_count(first), 0);
        assert_eq!(allocator.num_free(), 3);

        allocator.free(second);
        assert_eq!(allocator.num_free(), 4);
    }

    #[test]
    fn test_block_table() {
        let mut table = BlockTable::new(16);
        assert_eq!((table.num_blocks(), table.num_tokens()), (0, 0));

        for &physical in &[5, 7] {
            table.append_block(physical);
        }
        assert_eq!(table.num_blocks(), 2);

        let mapped: Vec<Option<BlockId>> = (0..3).map(|i| table.logical_to_physical(i)).collect();
        let expected: Vec<Option<BlockId>> = vec![Some(5), Some(7), None];
        assert_eq!(mapped, expected);

        let cases: [(usize, (usize, usize)); 4] =
            [(0, (0, 0)), (15, (0, 15)), (16, (1, 0)), (31, (1, 15))];
        for &(pos, want) in &cases {
            assert_eq!(table.token_to_block(pos), want);
        }

        table.set_num_tokens(20);
        assert_eq!(table.num_tokens(), 20);
    }

    #[test]
    fn test_paged_kv_pool() {
        // 2 layers, 4 KV heads, head_dim 8, 16 tokens per block, 10 blocks.
        let mut pool = PagedKVPool::new(2, 4, 8, 16, 10);
        assert_eq!(pool.num_free_blocks(), 10);
        assert_eq!(pool.total_blocks(), 10);

        let ids = pool.allocate_blocks(2).unwrap();
        let (first, second) = (ids[0], ids[1]);

        let keys: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let values: Vec<f32> = (10..18).map(|i| i as f32).collect();

        pool.write_kv(0, first, 0, 0, &keys, &values);
        pool.write_kv(0, first, 1, 1, &keys, &values);

        assert_eq!(pool.read_k(0, first, 0, 0), &keys[..]);
        assert_eq!(pool.read_v(0, first, 0, 0), &values[..]);

        pool.free_blocks(&[first, second]);
        assert_eq!(pool.num_free_blocks(), 10);
        assert!(pool.memory_usage() > 0);
    }

    #[test]
    fn test_paged_sequence() {
        let mut pool = PagedKVPool::new(1, 1, 4, 8, 16);
        let mut seq = PagedSequence::new(0, 8);

        let tokens: [([f32; 4], [f32; 4]); 2] = [
            ([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]),
            ([10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0]),
        ];
        for (k, v) in &tokens {
            seq.append_token(&mut pool, 0, 0, k, v).unwrap();
            seq.advance_token();
        }
        assert_eq!(seq.num_tokens, 2);

        let (gathered_k, gathered_v) = seq.get_kv_for_attention(&pool, 0, 0);
        for (pos, (k, v)) in tokens.iter().enumerate() {
            let span = pos * 4..(pos + 1) * 4;
            assert_eq!(gathered_k[span.clone()], k[..]);
            assert_eq!(gathered_v[span], v[..]);
        }

        seq.free(&mut pool);
        assert_eq!(pool.num_free_blocks(), 16);
    }

    #[test]
    fn test_copy_on_write() {
        let mut pool = PagedKVPool::new(1, 1, 4, 8, 16);
        let ids = pool.allocate_blocks(2).unwrap();
        let (src, dst) = (ids[0], ids[1]);

        let k = [1.0_f32, 2.0, 3.0, 4.0];
        let v = [5.0_f32, 6.0, 7.0, 8.0];
        pool.write_kv(0, src, 0, 0, &k, &v);

        pool.copy_block(src, dst);
        assert_eq!(pool.read_k(0, dst, 0, 0), &k[..]);
        assert_eq!(pool.read_v(0, dst, 0, 0), &v[..]);

        pool.allocator_mut().increment_ref(src);
        assert_eq!(pool.allocator().ref_count(src), 2);

        pool.free_blocks(&[src, dst]);
    }

    #[test]
    fn test_memory_fragmentation() {
        let mut pool = PagedKVPool::new(1, 1, 4, 8, 10);
        let allocated: Vec<BlockId> = (0..10)
            .map(|_| pool.allocate_blocks(1).unwrap()[0])
            .collect();
        assert_eq!(pool.num_free_blocks(), 0);
        assert!(pool.allocate_blocks(1).is_none());

        let (front, back) = allocated.split_at(5);
        pool.free_blocks(front);
        assert_eq!(pool.num_free_blocks(), 5);

        let reused = pool.allocate_blocks(5).unwrap();
        assert_eq!(pool.num_free_blocks(), 0);

        pool.free_blocks(back);
        pool.free_blocks(&reused);
        assert_eq!(pool.num_free_blocks(), 10);
    }
}