Skip to main content

ferrum_models/common/
paged_pool.rs

1//! Multi-sequence paged-KV pool — Phase 4 of Metal paged attention.
2//!
3//! Replaces the per-cache_id `KvCache` allocation model with a
4//! shared-pool architecture matching vLLM / mistral.rs:
5//!
6//! - **One pool per layer** holds K and V for *all* concurrent sequences.
7//!   Sized to `MAX_TOTAL_BLOCKS × num_kv_heads × block_size × head_dim`.
8//! - **Per-cache_id state** ([`PagedSeqState`]) carries that sequence's
9//!   logical → physical block mapping (`block_table`) plus its current
10//!   length (`len`). Multiple cache_ids can index into the same pool
11//!   without colliding because their block_tables point at disjoint
12//!   physical blocks (or shared ones, if prefix-caching is enabled
13//!   later).
14//! - **[`BlockAllocator`]** is a free-list owning physical block indices.
15//!   `allocate` pops, `free` pushes back. Out-of-memory surfaces as
16//!   `Result::Err` so the scheduler can refuse the request rather than
17//!   panicking deep in the model forward.
18//!
19//! What's *not* here yet (deferred to Phase 4c / 5):
20//! - Prefix sharing: today a fresh `PagedSeqState` always allocates new
21//!   blocks even if its prompt overlaps another live sequence's.
22//! - Eviction / preemption: when blocks run out we just `Err`. A real
23//!   scheduler would either refuse-and-queue or evict the least-recently-used.
24//! - Cross-process or cross-model pooling.
25
26use ferrum_kernels::backend::Backend;
27use ferrum_types::{FerrumError, Result};
28use std::sync::atomic::{AtomicUsize, Ordering};
29
30/// LIFO free-list block allocator. `O(1)` allocate / free, no fragmentation
31/// (all blocks are uniform size).
32///
33/// `capacity` is the total physical block count baked into the pool at
34/// load time. The allocator is independent per-model (block index space
35/// is not portable across models).
36pub struct BlockAllocator {
37    free_list: Vec<u32>,
38    capacity: u32,
39    /// Watermark: how many blocks have been live at peak, useful for
40    /// pool-sizing diagnostics in the bench harness.
41    peak_in_use: AtomicUsize,
42}
43
44impl BlockAllocator {
45    /// Create a fresh allocator. All `num_blocks` blocks start free.
46    /// Free-list is built so `allocate()` returns block 0 first, then 1,
47    /// etc. — predictable for tests and ensures the lower physical
48    /// blocks see the most reuse (better cache locality on M1's SLC).
49    pub fn new(num_blocks: u32) -> Self {
50        let mut free_list: Vec<u32> = (0..num_blocks).collect();
51        free_list.reverse(); // pop() yields 0 first
52        Self {
53            free_list,
54            capacity: num_blocks,
55            peak_in_use: AtomicUsize::new(0),
56        }
57    }
58
59    /// Allocate a single physical block. Returns `Err` when the pool is
60    /// exhausted — caller is expected to refuse the request and queue
61    /// it (or evict another seq, when that's wired up).
62    pub fn allocate(&mut self) -> Result<u32> {
63        match self.free_list.pop() {
64            Some(b) => {
65                let in_use = self.capacity as usize - self.free_list.len();
66                self.peak_in_use.fetch_max(in_use, Ordering::Relaxed);
67                Ok(b)
68            }
69            None => Err(FerrumError::resource_exhausted(format!(
70                "paged KV pool exhausted (capacity={} blocks, all in use)",
71                self.capacity
72            ))),
73        }
74    }
75
76    /// Bulk allocate. Atomic: either all `n` succeed or none are taken.
77    pub fn allocate_n(&mut self, n: usize) -> Result<Vec<u32>> {
78        if self.free_list.len() < n {
79            return Err(FerrumError::resource_exhausted(format!(
80                "paged KV pool exhausted: need {n} blocks but only {} free",
81                self.free_list.len()
82            )));
83        }
84        let mut out = Vec::with_capacity(n);
85        for _ in 0..n {
86            out.push(self.free_list.pop().unwrap());
87        }
88        let in_use = self.capacity as usize - self.free_list.len();
89        self.peak_in_use.fetch_max(in_use, Ordering::Relaxed);
90        Ok(out)
91    }
92
93    /// Return blocks to the free list. Caller is responsible for
94    /// ensuring no live sequence still references them; freeing a block
95    /// while it's still in a `PagedSeqState::blocks` will silently
96    /// corrupt the next allocation that gets it.
97    pub fn free(&mut self, blocks: &[u32]) {
98        self.free_list.extend_from_slice(blocks);
99    }
100
101    pub fn free_count(&self) -> usize {
102        self.free_list.len()
103    }
104
105    pub fn capacity(&self) -> u32 {
106        self.capacity
107    }
108
109    pub fn peak_in_use(&self) -> usize {
110        self.peak_in_use.load(Ordering::Relaxed)
111    }
112}
113
114/// Per-sequence paged-KV state.
115///
116/// Holds the logical→physical block mapping for ONE sequence (one
117/// `cache_id`) plus its current token count. The mapping is stored as
118/// both:
119/// - `blocks: Vec<u32>` — the host-side source of truth, used by the
120///   block allocator + grow logic.
121/// - `block_table_buf: B::Buffer` — a device-side u32 buffer that mirrors
122///   `blocks` and is read directly by the paged Metal kernels (PR #68 /
123///   #69). Kept in sync via [`Self::ensure_capacity`].
124///
125/// `context_lens_buf` is a 1-element u32 device buffer holding `len`.
126/// The kernel reads it each forward; we update it via `B::write_u32`.
127pub struct PagedSeqState<B: Backend> {
128    pub blocks: Vec<u32>,
129    pub block_table_buf: B::Buffer,
130    pub context_lens_buf: B::Buffer,
131    pub len: usize,
132    pub block_size: usize,
133    pub max_blocks_per_seq: usize,
134}
135
136impl<B: Backend> PagedSeqState<B> {
137    /// Allocate buffers for a sequence that hasn't yet allocated any
138    /// blocks. The allocator isn't touched here — the first call to
139    /// [`Self::ensure_capacity`] does the real work.
140    pub fn new(block_size: usize, max_blocks_per_seq: usize) -> Self {
141        let block_table_buf = B::alloc_u32(max_blocks_per_seq);
142        let context_lens_buf = B::alloc_u32(1);
143        // Initialise context_lens to 0 so a forward dispatched before
144        // any token has been written sees an empty context.
145        let mut ctx = B::new_context();
146        let mut cl = context_lens_buf;
147        B::write_u32(&mut ctx, &mut cl, &[0u32]);
148        B::sync(&mut ctx);
149        Self {
150            blocks: Vec::with_capacity(max_blocks_per_seq),
151            block_table_buf,
152            context_lens_buf: cl,
153            len: 0,
154            block_size,
155            max_blocks_per_seq,
156        }
157    }
158
159    /// Ensure the seq has enough blocks to hold `target_len` tokens.
160    /// Allocates additional blocks from the pool if needed and re-syncs
161    /// `block_table_buf` to the device. Idempotent if already big enough.
162    pub fn ensure_capacity(
163        &mut self,
164        ctx: &mut B::Context,
165        alloc: &mut BlockAllocator,
166        target_len: usize,
167    ) -> Result<()> {
168        let needed = target_len.div_ceil(self.block_size);
169        if needed > self.max_blocks_per_seq {
170            return Err(FerrumError::model(format!(
171                "paged KV: target_len={target_len} would need {needed} blocks, exceeds max_blocks_per_seq={}",
172                self.max_blocks_per_seq
173            )));
174        }
175        while self.blocks.len() < needed {
176            let block = alloc.allocate()?;
177            self.blocks.push(block);
178        }
179        // Mirror the host-side blocks list into the device buffer. We
180        // write the FULL `max_blocks_per_seq` entries — unused slots
181        // beyond `needed` are never read by the kernel (it only walks
182        // `[0, ceil(context_len / block_size))`), but writing them
183        // keeps the buffer's content predictable.
184        let mut padded = self.blocks.clone();
185        padded.resize(self.max_blocks_per_seq, 0);
186        B::write_u32(ctx, &mut self.block_table_buf, &padded);
187        Ok(())
188    }
189
190    /// Update the on-device `context_lens_buf` to the current `self.len`.
191    /// Call this after [`Self::ensure_capacity`] but before dispatching
192    /// the paged attention kernel for this seq.
193    pub fn sync_context_len(&mut self, ctx: &mut B::Context) {
194        B::write_u32(ctx, &mut self.context_lens_buf, &[self.len as u32]);
195    }
196
197    /// Release all blocks back to the allocator. Buffers are kept (cheap
198    /// to reuse for a future cache_id), but blocks become available for
199    /// other sequences. Sets `len` back to 0.
200    pub fn release(&mut self, alloc: &mut BlockAllocator) {
201        alloc.free(&self.blocks);
202        self.blocks.clear();
203        self.len = 0;
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Pull `n` blocks one at a time, returning them in hand-out order.
    fn take(pool: &mut BlockAllocator, n: usize) -> Vec<u32> {
        (0..n).map(|_| pool.allocate().unwrap()).collect()
    }

    #[test]
    fn allocator_basic() {
        let mut pool = BlockAllocator::new(4);
        assert_eq!(pool.free_count(), 4);

        // A fresh pool hands out ascending ids until it runs dry.
        assert_eq!(take(&mut pool, 4), [0, 1, 2, 3]);
        assert!(pool.allocate().is_err());
        assert_eq!(pool.free_count(), 0);

        pool.free(&[1, 3]);
        assert_eq!(pool.free_count(), 2);
        // LIFO: most recently freed comes back first.
        assert_eq!(take(&mut pool, 2), [3, 1]);
    }

    #[test]
    fn allocator_atomic_n_failure() {
        let mut pool = BlockAllocator::new(3);
        // Use up two of the three blocks, leaving exactly one free.
        let _held = take(&mut pool, 2);

        // Asking for 2 with 1 free must fail without consuming that 1.
        assert!(pool.allocate_n(2).is_err());
        assert_eq!(pool.free_count(), 1);
    }

    #[test]
    fn allocator_peak_tracking() {
        let mut pool = BlockAllocator::new(8);

        let first_batch = pool.allocate_n(5).unwrap();
        assert_eq!(pool.peak_in_use(), 5);

        // The watermark is sticky: freeing never lowers it...
        pool.free(&first_batch);
        assert_eq!(pool.peak_in_use(), 5);

        // ...and a smaller follow-up allocation does not move it either.
        let _second_batch = pool.allocate_n(3).unwrap();
        assert_eq!(pool.peak_in_use(), 5);
    }
}