1use std::sync::atomic::{AtomicUsize, Ordering};
8
/// Identifier of a physical block. It is a plain index: the allocator in this
/// file uses it to index its per-block reference-count table.
pub type BlockId = usize;

/// Default block size.
// NOTE(review): presumably the number of token positions per block; nothing in
// the visible part of this file reads this constant, so confirm against callers.
pub const DEFAULT_BLOCK_SIZE: usize = 16;
14
15pub struct PageAllocator {
17 num_blocks: usize,
19 free_blocks: Vec<BlockId>,
21 ref_counts: Vec<AtomicUsize>,
23}
24
25impl PageAllocator {
26 pub fn new(num_blocks: usize) -> Self {
28 let free_blocks: Vec<BlockId> = (0..num_blocks).collect();
29 let ref_counts: Vec<AtomicUsize> = (0..num_blocks).map(|_| AtomicUsize::new(0)).collect();
30 Self {
31 num_blocks,
32 free_blocks,
33 ref_counts,
34 }
35 }
36
37 pub fn allocate(&mut self) -> Option<BlockId> {
39 let block_id = self.free_blocks.pop()?;
40 self.ref_counts[block_id].store(1, Ordering::SeqCst);
41 Some(block_id)
42 }
43
44 pub fn free(&mut self, block_id: BlockId) {
46 if block_id >= self.num_blocks {
47 return;
48 }
49 let prev = self.ref_counts[block_id].fetch_sub(1, Ordering::SeqCst);
50 if prev == 1 {
51 self.free_blocks.push(block_id);
52 }
53 }
54
55 pub fn increment_ref(&self, block_id: BlockId) {
57 if block_id < self.num_blocks {
58 self.ref_counts[block_id].fetch_add(1, Ordering::SeqCst);
59 }
60 }
61
62 pub fn ref_count(&self, block_id: BlockId) -> usize {
64 if block_id >= self.num_blocks {
65 return 0;
66 }
67 self.ref_counts[block_id].load(Ordering::SeqCst)
68 }
69
70 pub fn num_free(&self) -> usize {
72 self.free_blocks.len()
73 }
74
75 pub fn num_used(&self) -> usize {
77 self.num_blocks - self.free_blocks.len()
78 }
79}
80
81pub struct BlockTable {
83 entries: Vec<Option<BlockId>>,
85 num_tokens: usize,
87 block_size: usize,
89}
90
91impl BlockTable {
92 pub fn new(block_size: usize) -> Self {
94 Self {
95 entries: Vec::new(),
96 num_tokens: 0,
97 block_size,
98 }
99 }
100
101 pub fn num_blocks(&self) -> usize {
103 self.entries.len()
104 }
105
106 pub fn num_tokens(&self) -> usize {
108 self.num_tokens
109 }
110
111 pub fn logical_to_physical(&self, logical_idx: usize) -> Option<BlockId> {
113 self.entries.get(logical_idx).and_then(|e| *e)
114 }
115
116 pub fn append_block(&mut self, block_id: BlockId) {
118 self.entries.push(Some(block_id));
119 }
120
121 pub fn token_to_block(&self, token_pos: usize) -> (usize, usize) {
123 if self.block_size == 0 {
124 return (0, 0);
125 }
126 let logical_block_idx = token_pos / self.block_size;
127 let offset_within_block = token_pos % self.block_size;
128 (logical_block_idx, offset_within_block)
129 }
130
131 pub fn set_num_tokens(&mut self, n: usize) {
133 self.num_tokens = n;
134 }
135}
136
137pub struct PagedKVPool {
139 k_pool: Vec<Vec<f32>>,
142 v_pool: Vec<Vec<f32>>,
143 allocator: PageAllocator,
145 num_layers: usize,
147 num_kv_heads: usize,
148 head_dim: usize,
149 block_size: usize,
150 num_blocks: usize,
151}
152
153impl PagedKVPool {
154 fn block_stride(&self) -> usize {
156 self.num_kv_heads * self.block_size * self.head_dim
157 }
158
159 fn block_offset(&self, block_id: BlockId, offset: usize, head: usize) -> usize {
161 block_id * self.block_stride() + head * (self.block_size * self.head_dim) + offset * self.head_dim
162 }
163
164 pub fn new(
166 num_layers: usize,
167 num_kv_heads: usize,
168 head_dim: usize,
169 block_size: usize,
170 num_blocks: usize,
171 ) -> Self {
172 let block_stride = num_kv_heads * block_size * head_dim;
173 let layer_size = num_blocks * block_stride;
174
175 let k_pool: Vec<Vec<f32>> = (0..num_layers)
176 .map(|_| vec![0.0; layer_size])
177 .collect();
178 let v_pool: Vec<Vec<f32>> = (0..num_layers)
179 .map(|_| vec![0.0; layer_size])
180 .collect();
181
182 Self {
183 k_pool,
184 v_pool,
185 allocator: PageAllocator::new(num_blocks),
186 num_layers,
187 num_kv_heads,
188 head_dim,
189 block_size,
190 num_blocks,
191 }
192 }
193
194 pub fn allocate_blocks(&mut self, count: usize) -> Option<Vec<BlockId>> {
196 let mut blocks = Vec::with_capacity(count);
197 for _ in 0..count {
198 let block_id = self.allocator.allocate()?;
199 blocks.push(block_id);
200 }
201 Some(blocks)
202 }
203
204 pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
206 for &block_id in block_ids {
207 self.allocator.free(block_id);
208 }
209 }
210
211 pub fn write_kv(
213 &mut self,
214 layer: usize,
215 block_id: BlockId,
216 offset: usize,
217 head: usize,
218 k: &[f32],
219 v: &[f32],
220 ) {
221 if layer >= self.num_layers
222 || head >= self.num_kv_heads
223 || offset >= self.block_size
224 || k.len() != self.head_dim
225 || v.len() != self.head_dim
226 {
227 return;
228 }
229 let base = self.block_offset(block_id, offset, head);
230 self.k_pool[layer][base..base + self.head_dim].copy_from_slice(k);
231 self.v_pool[layer][base..base + self.head_dim].copy_from_slice(v);
232 }
233
234 pub fn read_k(
236 &self,
237 layer: usize,
238 block_id: BlockId,
239 offset: usize,
240 head: usize,
241 ) -> &[f32] {
242 if layer >= self.num_layers
243 || head >= self.num_kv_heads
244 || offset >= self.block_size
245 {
246 return &[];
247 }
248 let base = self.block_offset(block_id, offset, head);
249 &self.k_pool[layer][base..base + self.head_dim]
250 }
251
252 pub fn read_v(
254 &self,
255 layer: usize,
256 block_id: BlockId,
257 offset: usize,
258 head: usize,
259 ) -> &[f32] {
260 if layer >= self.num_layers
261 || head >= self.num_kv_heads
262 || offset >= self.block_size
263 {
264 return &[];
265 }
266 let base = self.block_offset(block_id, offset, head);
267 &self.v_pool[layer][base..base + self.head_dim]
268 }
269
270 pub fn copy_block(&mut self, src: BlockId, dst: BlockId) {
272 let block_stride = self.block_stride();
273 let src_base = src * block_stride;
274 let dst_base = dst * block_stride;
275 for layer in 0..self.num_layers {
276 let src_slice = self.k_pool[layer][src_base..src_base + block_stride].to_vec();
277 self.k_pool[layer][dst_base..dst_base + block_stride].copy_from_slice(&src_slice);
278 let src_slice = self.v_pool[layer][src_base..src_base + block_stride].to_vec();
279 self.v_pool[layer][dst_base..dst_base + block_stride].copy_from_slice(&src_slice);
280 }
281 }
282
283 pub fn memory_usage(&self) -> usize {
285 let floats_per_layer = self.num_blocks * self.block_stride();
286 let total_floats = floats_per_layer * self.num_layers * 2; total_floats * std::mem::size_of::<f32>()
288 }
289
290 pub fn num_free_blocks(&self) -> usize {
292 self.allocator.num_free()
293 }
294
295 pub fn total_blocks(&self) -> usize {
297 self.num_blocks
298 }
299
300 #[allow(dead_code)]
302 pub(crate) fn allocator_mut(&mut self) -> &mut PageAllocator {
303 &mut self.allocator
304 }
305
306 #[allow(dead_code)]
308 pub(crate) fn allocator(&self) -> &PageAllocator {
309 &self.allocator
310 }
311}
312
313pub struct PagedSequence {
315 pub block_table: BlockTable,
317 pub seq_id: usize,
319 pub num_tokens: usize,
321}
322
323impl PagedSequence {
324 pub fn new(seq_id: usize, block_size: usize) -> Self {
326 Self {
327 block_table: BlockTable::new(block_size),
328 seq_id,
329 num_tokens: 0,
330 }
331 }
332
333 pub fn append_token(
335 &mut self,
336 pool: &mut PagedKVPool,
337 layer: usize,
338 head: usize,
339 k: &[f32],
340 v: &[f32],
341 ) -> Result<(), &'static str> {
342 let (logical_block_idx, offset_within_block) =
343 self.block_table.token_to_block(self.num_tokens);
344
345 while logical_block_idx >= self.block_table.num_blocks() {
347 let blocks = pool
348 .allocate_blocks(1)
349 .ok_or("No free blocks in pool")?;
350 let block_id = blocks[0];
351 self.block_table.append_block(block_id);
352 }
353
354 let block_id = self
355 .block_table
356 .logical_to_physical(logical_block_idx)
357 .ok_or("Missing block mapping")?;
358
359 pool.write_kv(layer, block_id, offset_within_block, head, k, v);
360
361 Ok(())
362 }
363
364 pub fn advance_token(&mut self) {
366 self.num_tokens += 1;
367 self.block_table.set_num_tokens(self.num_tokens);
368 }
369
370 pub fn get_kv_for_attention(
373 &self,
374 pool: &PagedKVPool,
375 layer: usize,
376 head: usize,
377 ) -> (Vec<f32>, Vec<f32>) {
378 let num_tokens = self.num_tokens;
379 let head_dim = pool.head_dim;
380
381 let mut k_buf = vec![0.0; num_tokens * head_dim];
382 let mut v_buf = vec![0.0; num_tokens * head_dim];
383
384 for token_pos in 0..num_tokens {
385 let (logical_block_idx, offset) = self.block_table.token_to_block(token_pos);
386 if let Some(block_id) = self.block_table.logical_to_physical(logical_block_idx) {
387 let k_slice = pool.read_k(layer, block_id, offset, head);
388 let v_slice = pool.read_v(layer, block_id, offset, head);
389 if k_slice.len() == head_dim && v_slice.len() == head_dim {
390 k_buf[token_pos * head_dim..(token_pos + 1) * head_dim]
391 .copy_from_slice(k_slice);
392 v_buf[token_pos * head_dim..(token_pos + 1) * head_dim]
393 .copy_from_slice(v_slice);
394 }
395 }
396 }
397
398 (k_buf, v_buf)
399 }
400
401}
402
403impl BlockTable {
404 pub fn clear(&mut self) {
406 self.entries.clear();
407 self.num_tokens = 0;
408 }
409}
410
411impl PagedSequence {
412 pub fn free(&mut self, pool: &mut PagedKVPool) {
414 let block_ids: Vec<BlockId> = (0..self.block_table.num_blocks())
415 .filter_map(|i| self.block_table.logical_to_physical(i))
416 .collect();
417 pool.free_blocks(&block_ids);
418 self.block_table.clear();
419 self.num_tokens = 0;
420 }
421}
422
// Unit tests for the allocator, block table, pool and sequence types defined
// in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Allocation, shared references and release through `PageAllocator`.
    #[test]
    fn test_page_allocator_basic() {
        let mut alloc = PageAllocator::new(4);
        assert_eq!(alloc.num_free(), 4);
        assert_eq!(alloc.num_used(), 0);

        let b0 = alloc.allocate().unwrap();
        let b1 = alloc.allocate().unwrap();
        assert_eq!(alloc.num_free(), 2);
        assert_eq!(alloc.num_used(), 2);
        assert_eq!(alloc.ref_count(b0), 1);
        assert_eq!(alloc.ref_count(b1), 1);

        alloc.increment_ref(b0);
        assert_eq!(alloc.ref_count(b0), 2);

        // Dropping one of two references keeps the block allocated.
        alloc.free(b0);
        assert_eq!(alloc.ref_count(b0), 1);
        assert_eq!(alloc.num_free(), 2);

        // Dropping the last reference returns the block to the free list.
        alloc.free(b0);
        assert_eq!(alloc.ref_count(b0), 0);
        assert_eq!(alloc.num_free(), 3);

        alloc.free(b1);
        assert_eq!(alloc.num_free(), 4);
    }

    // Logical-to-physical lookup and token-position arithmetic of `BlockTable`.
    #[test]
    fn test_block_table() {
        let mut table = BlockTable::new(16);
        assert_eq!(table.num_blocks(), 0);
        assert_eq!(table.num_tokens(), 0);

        table.append_block(5);
        table.append_block(7);
        assert_eq!(table.num_blocks(), 2);
        assert_eq!(table.logical_to_physical(0), Some(5));
        assert_eq!(table.logical_to_physical(1), Some(7));
        assert_eq!(table.logical_to_physical(2), None);

        // With a block size of 16, position 16 is the first slot of block 1.
        assert_eq!(table.token_to_block(0), (0, 0));
        assert_eq!(table.token_to_block(15), (0, 15));
        assert_eq!(table.token_to_block(16), (1, 0));
        assert_eq!(table.token_to_block(31), (1, 15));

        table.set_num_tokens(20);
        assert_eq!(table.num_tokens(), 20);
    }

    // Write/read round trip through `PagedKVPool` and release of its blocks.
    #[test]
    fn test_paged_kv_pool() {
        let mut pool = PagedKVPool::new(2, 4, 8, 16, 10);
        assert_eq!(pool.num_free_blocks(), 10);
        assert_eq!(pool.total_blocks(), 10);

        let blocks = pool.allocate_blocks(2).unwrap();
        let b0 = blocks[0];
        let b1 = blocks[1];

        let k: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let v: Vec<f32> = (0..8).map(|i| (i + 10) as f32).collect();

        pool.write_kv(0, b0, 0, 0, &k, &v);
        pool.write_kv(0, b0, 1, 1, &k, &v);

        let read_k = pool.read_k(0, b0, 0, 0);
        let read_v = pool.read_v(0, b0, 0, 0);
        assert_eq!(read_k, &k[..]);
        assert_eq!(read_v, &v[..]);

        pool.free_blocks(&[b0, b1]);
        assert_eq!(pool.num_free_blocks(), 10);
        assert!(pool.memory_usage() > 0);
    }

    // Appends two tokens to a sequence, gathers them back in order, then frees
    // the sequence and checks that every block is free again.
    #[test]
    fn test_paged_sequence() {
        let mut pool = PagedKVPool::new(1, 1, 4, 8, 16);
        let mut seq = PagedSequence::new(0, 8);

        let k: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let v: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];

        seq.append_token(&mut pool, 0, 0, &k, &v).unwrap();
        seq.advance_token();

        let k2: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let v2: Vec<f32> = vec![50.0, 60.0, 70.0, 80.0];
        seq.append_token(&mut pool, 0, 0, &k2, &v2).unwrap();
        seq.advance_token();

        assert_eq!(seq.num_tokens, 2);

        let (gathered_k, gathered_v) = seq.get_kv_for_attention(&pool, 0, 0);
        assert_eq!(gathered_k[0..4], k[..]);
        assert_eq!(gathered_v[0..4], v[..]);
        assert_eq!(gathered_k[4..8], k2[..]);
        assert_eq!(gathered_v[4..8], v2[..]);

        seq.free(&mut pool);
        assert_eq!(pool.num_free_blocks(), 16);
    }

    // Copies one block into another, checks the copy reads back identically,
    // and checks that an extra reference raises the source's count to 2.
    #[test]
    fn test_copy_on_write() {
        let mut pool = PagedKVPool::new(1, 1, 4, 8, 16);
        let blocks = pool.allocate_blocks(2).unwrap();
        let src = blocks[0];
        let dst = blocks[1];

        let k: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let v: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
        pool.write_kv(0, src, 0, 0, &k, &v);

        pool.copy_block(src, dst);

        let read_k = pool.read_k(0, dst, 0, 0);
        let read_v = pool.read_v(0, dst, 0, 0);
        assert_eq!(read_k, &k[..]);
        assert_eq!(read_v, &v[..]);

        pool.allocator_mut().increment_ref(src);
        assert_eq!(pool.allocator().ref_count(src), 2);

        // No assertion follows this call: `src` still has one reference left
        // afterwards, so only `dst` goes back to the free list here.
        pool.free_blocks(&[src, dst]);
    }

    // Exhausts the pool, frees half of the blocks and reallocates them, then
    // frees everything and checks the pool is fully free again.
    #[test]
    fn test_memory_fragmentation() {
        let mut pool = PagedKVPool::new(1, 1, 4, 8, 10);
        let mut allocated = Vec::new();

        for _ in 0..10 {
            let blocks = pool.allocate_blocks(1).unwrap();
            allocated.push(blocks[0]);
        }
        assert_eq!(pool.num_free_blocks(), 0);
        assert!(pool.allocate_blocks(1).is_none());

        pool.free_blocks(&allocated[0..5]);
        assert_eq!(pool.num_free_blocks(), 5);

        let blocks = pool.allocate_blocks(5).unwrap();
        assert_eq!(pool.num_free_blocks(), 0);

        pool.free_blocks(&allocated[5..10]);
        pool.free_blocks(&blocks);
        assert_eq!(pool.num_free_blocks(), 10);
    }
}