ferrum_models/common/paged_pool.rs
1//! Multi-sequence paged-KV pool — Phase 4 of Metal paged attention.
2//!
3//! Replaces the per-cache_id `KvCache` allocation model with a
4//! shared-pool architecture matching vLLM / mistral.rs:
5//!
6//! - **One pool per layer** holds K and V for *all* concurrent sequences.
7//! Sized to `MAX_TOTAL_BLOCKS × num_kv_heads × block_size × head_dim`.
8//! - **Per-cache_id state** ([`PagedSeqState`]) carries that sequence's
9//! logical → physical block mapping (`block_table`) plus its current
10//! length (`len`). Multiple cache_ids can index into the same pool
11//! without colliding because their block_tables point at disjoint
12//! physical blocks (or shared ones, if prefix-caching is enabled
13//! later).
14//! - **[`BlockAllocator`]** is a free-list owning physical block indices.
15//! `allocate` pops, `free` pushes back. Out-of-memory surfaces as
16//! `Result::Err` so the scheduler can refuse the request rather than
17//! panicking deep in the model forward.
18//!
19//! What's *not* here yet (deferred to Phase 4c / 5):
20//! - Prefix sharing: today a fresh `PagedSeqState` always allocates new
21//! blocks even if its prompt overlaps another live sequence's.
22//! - Eviction / preemption: when blocks run out we just `Err`. A real
23//! scheduler would either refuse-and-queue or evict the least-recently-used.
24//! - Cross-process or cross-model pooling.
25
26use ferrum_kernels::backend::Backend;
27use ferrum_types::{FerrumError, Result};
28use std::sync::atomic::{AtomicUsize, Ordering};
29
30/// LIFO free-list block allocator. `O(1)` allocate / free, no fragmentation
31/// (all blocks are uniform size).
32///
33/// `capacity` is the total physical block count baked into the pool at
34/// load time. The allocator is independent per-model (block index space
35/// is not portable across models).
36pub struct BlockAllocator {
37 free_list: Vec<u32>,
38 capacity: u32,
39 /// Watermark: how many blocks have been live at peak, useful for
40 /// pool-sizing diagnostics in the bench harness.
41 peak_in_use: AtomicUsize,
42}
43
44impl BlockAllocator {
45 /// Create a fresh allocator. All `num_blocks` blocks start free.
46 /// Free-list is built so `allocate()` returns block 0 first, then 1,
47 /// etc. — predictable for tests and ensures the lower physical
48 /// blocks see the most reuse (better cache locality on M1's SLC).
49 pub fn new(num_blocks: u32) -> Self {
50 let mut free_list: Vec<u32> = (0..num_blocks).collect();
51 free_list.reverse(); // pop() yields 0 first
52 Self {
53 free_list,
54 capacity: num_blocks,
55 peak_in_use: AtomicUsize::new(0),
56 }
57 }
58
59 /// Allocate a single physical block. Returns `Err` when the pool is
60 /// exhausted — caller is expected to refuse the request and queue
61 /// it (or evict another seq, when that's wired up).
62 pub fn allocate(&mut self) -> Result<u32> {
63 match self.free_list.pop() {
64 Some(b) => {
65 let in_use = self.capacity as usize - self.free_list.len();
66 self.peak_in_use.fetch_max(in_use, Ordering::Relaxed);
67 Ok(b)
68 }
69 None => Err(FerrumError::resource_exhausted(format!(
70 "paged KV pool exhausted (capacity={} blocks, all in use)",
71 self.capacity
72 ))),
73 }
74 }
75
76 /// Bulk allocate. Atomic: either all `n` succeed or none are taken.
77 pub fn allocate_n(&mut self, n: usize) -> Result<Vec<u32>> {
78 if self.free_list.len() < n {
79 return Err(FerrumError::resource_exhausted(format!(
80 "paged KV pool exhausted: need {n} blocks but only {} free",
81 self.free_list.len()
82 )));
83 }
84 let mut out = Vec::with_capacity(n);
85 for _ in 0..n {
86 out.push(self.free_list.pop().unwrap());
87 }
88 let in_use = self.capacity as usize - self.free_list.len();
89 self.peak_in_use.fetch_max(in_use, Ordering::Relaxed);
90 Ok(out)
91 }
92
93 /// Return blocks to the free list. Caller is responsible for
94 /// ensuring no live sequence still references them; freeing a block
95 /// while it's still in a `PagedSeqState::blocks` will silently
96 /// corrupt the next allocation that gets it.
97 pub fn free(&mut self, blocks: &[u32]) {
98 self.free_list.extend_from_slice(blocks);
99 }
100
101 pub fn free_count(&self) -> usize {
102 self.free_list.len()
103 }
104
105 pub fn capacity(&self) -> u32 {
106 self.capacity
107 }
108
109 pub fn peak_in_use(&self) -> usize {
110 self.peak_in_use.load(Ordering::Relaxed)
111 }
112}
113
114/// Per-sequence paged-KV state.
115///
116/// Holds the logical→physical block mapping for ONE sequence (one
117/// `cache_id`) plus its current token count. The mapping is stored as
118/// both:
119/// - `blocks: Vec<u32>` — the host-side source of truth, used by the
120/// block allocator + grow logic.
121/// - `block_table_buf: B::Buffer` — a device-side u32 buffer that mirrors
122/// `blocks` and is read directly by the paged Metal kernels (PR #68 /
123/// #69). Kept in sync via [`Self::ensure_capacity`].
124///
125/// `context_lens_buf` is a 1-element u32 device buffer holding `len`.
126/// The kernel reads it each forward; we update it via `B::write_u32`.
127pub struct PagedSeqState<B: Backend> {
128 pub blocks: Vec<u32>,
129 pub block_table_buf: B::Buffer,
130 pub context_lens_buf: B::Buffer,
131 pub len: usize,
132 pub block_size: usize,
133 pub max_blocks_per_seq: usize,
134}
135
136impl<B: Backend> PagedSeqState<B> {
137 /// Allocate buffers for a sequence that hasn't yet allocated any
138 /// blocks. The allocator isn't touched here — the first call to
139 /// [`Self::ensure_capacity`] does the real work.
140 pub fn new(block_size: usize, max_blocks_per_seq: usize) -> Self {
141 let block_table_buf = B::alloc_u32(max_blocks_per_seq);
142 let context_lens_buf = B::alloc_u32(1);
143 // Initialise context_lens to 0 so a forward dispatched before
144 // any token has been written sees an empty context.
145 let mut ctx = B::new_context();
146 let mut cl = context_lens_buf;
147 B::write_u32(&mut ctx, &mut cl, &[0u32]);
148 B::sync(&mut ctx);
149 Self {
150 blocks: Vec::with_capacity(max_blocks_per_seq),
151 block_table_buf,
152 context_lens_buf: cl,
153 len: 0,
154 block_size,
155 max_blocks_per_seq,
156 }
157 }
158
159 /// Ensure the seq has enough blocks to hold `target_len` tokens.
160 /// Allocates additional blocks from the pool if needed and re-syncs
161 /// `block_table_buf` to the device. Idempotent if already big enough.
162 pub fn ensure_capacity(
163 &mut self,
164 ctx: &mut B::Context,
165 alloc: &mut BlockAllocator,
166 target_len: usize,
167 ) -> Result<()> {
168 let needed = target_len.div_ceil(self.block_size);
169 if needed > self.max_blocks_per_seq {
170 return Err(FerrumError::model(format!(
171 "paged KV: target_len={target_len} would need {needed} blocks, exceeds max_blocks_per_seq={}",
172 self.max_blocks_per_seq
173 )));
174 }
175 while self.blocks.len() < needed {
176 let block = alloc.allocate()?;
177 self.blocks.push(block);
178 }
179 // Mirror the host-side blocks list into the device buffer. We
180 // write the FULL `max_blocks_per_seq` entries — unused slots
181 // beyond `needed` are never read by the kernel (it only walks
182 // `[0, ceil(context_len / block_size))`), but writing them
183 // keeps the buffer's content predictable.
184 let mut padded = self.blocks.clone();
185 padded.resize(self.max_blocks_per_seq, 0);
186 B::write_u32(ctx, &mut self.block_table_buf, &padded);
187 Ok(())
188 }
189
190 /// Update the on-device `context_lens_buf` to the current `self.len`.
191 /// Call this after [`Self::ensure_capacity`] but before dispatching
192 /// the paged attention kernel for this seq.
193 pub fn sync_context_len(&mut self, ctx: &mut B::Context) {
194 B::write_u32(ctx, &mut self.context_lens_buf, &[self.len as u32]);
195 }
196
197 /// Release all blocks back to the allocator. Buffers are kept (cheap
198 /// to reuse for a future cache_id), but blocks become available for
199 /// other sequences. Sets `len` back to 0.
200 pub fn release(&mut self, alloc: &mut BlockAllocator) {
201 alloc.free(&self.blocks);
202 self.blocks.clear();
203 self.len = 0;
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allocator_basic() {
        let mut a = BlockAllocator::new(4);
        assert_eq!(a.free_count(), 4);
        assert_eq!(a.allocate().unwrap(), 0);
        assert_eq!(a.allocate().unwrap(), 1);
        assert_eq!(a.allocate().unwrap(), 2);
        assert_eq!(a.allocate().unwrap(), 3);
        assert!(a.allocate().is_err());
        assert_eq!(a.free_count(), 0);

        a.free(&[1, 3]);
        assert_eq!(a.free_count(), 2);
        // LIFO: most recently freed comes back first.
        assert_eq!(a.allocate().unwrap(), 3);
        assert_eq!(a.allocate().unwrap(), 1);
    }

    #[test]
    fn allocator_atomic_n_failure() {
        let mut a = BlockAllocator::new(3);
        let _ = a.allocate().unwrap(); // 2 free
        let _ = a.allocate().unwrap(); // 1 free
        // Asking for 2 must fail without consuming the remaining block...
        assert!(a.allocate_n(2).is_err());
        assert_eq!(a.free_count(), 1);
        // ...which is still allocatable afterwards.
        assert_eq!(a.allocate().unwrap(), 2);
    }

    #[test]
    fn allocator_n_matches_sequential_order() {
        let mut a = BlockAllocator::new(6);
        assert_eq!(a.allocate_n(3).unwrap(), vec![0, 1, 2]);
        a.free(&[1, 0]);
        // Same order as two successive `allocate()` calls (LIFO).
        assert_eq!(a.allocate_n(2).unwrap(), vec![0, 1]);
        assert_eq!(a.allocate_n(0).unwrap(), Vec::<u32>::new());
        assert_eq!(a.free_count(), 3);
    }

    #[test]
    fn allocator_n_exact_fit() {
        let mut a = BlockAllocator::new(2);
        assert_eq!(a.allocate_n(2).unwrap(), vec![0, 1]);
        assert_eq!(a.free_count(), 0);
        assert!(a.allocate().is_err());
    }

    #[test]
    fn allocator_peak_tracking() {
        let mut a = BlockAllocator::new(8);
        let blocks = a.allocate_n(5).unwrap();
        assert_eq!(a.peak_in_use(), 5);
        a.free(&blocks);
        assert_eq!(a.peak_in_use(), 5); // peak doesn't decrease
        let _ = a.allocate_n(3).unwrap();
        assert_eq!(a.peak_in_use(), 5);
    }
}
249}