use ferrum_kernels::backend::Backend;
use ferrum_types::{FerrumError, Result};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Fixed-capacity pool of block ids `0..capacity` for the paged KV cache.
///
/// Ids are handed out LIFO: a fresh pool yields `0, 1, 2, …`, and freed ids
/// are reused most-recently-freed first. The allocator also records the
/// highest number of simultaneously allocated blocks observed at allocation
/// time (`peak_in_use`).
pub struct BlockAllocator {
    /// Ids currently available; `pop()` takes from the tail.
    free_list: Vec<u32>,
    /// Total number of blocks managed by this pool.
    capacity: u32,
    /// High-water mark of blocks in use, updated on every successful allocation.
    peak_in_use: AtomicUsize,
}

impl BlockAllocator {
    /// Creates a pool owning block ids `0..num_blocks`, all initially free.
    pub fn new(num_blocks: u32) -> Self {
        // Stored in descending order so that `pop()` hands out the lowest ids first.
        let free_list: Vec<u32> = (0..num_blocks).rev().collect();
        Self {
            free_list,
            capacity: num_blocks,
            peak_in_use: AtomicUsize::new(0),
        }
    }

    /// Takes one free block id from the pool.
    ///
    /// # Errors
    /// Returns a resource-exhausted error if no block is free.
    pub fn allocate(&mut self) -> Result<u32> {
        let Some(block) = self.free_list.pop() else {
            return Err(FerrumError::resource_exhausted(format!(
                "paged KV pool exhausted (capacity={} blocks, all in use)",
                self.capacity
            )));
        };
        self.record_peak();
        Ok(block)
    }

    /// Takes `n` free block ids atomically: either all `n` are returned or, on
    /// failure, the pool is left untouched.
    ///
    /// The ids come back in the same order that `n` successive
    /// [`allocate`](Self::allocate) calls would produce.
    ///
    /// # Errors
    /// Returns a resource-exhausted error if fewer than `n` blocks are free.
    pub fn allocate_n(&mut self, n: usize) -> Result<Vec<u32>> {
        let free = self.free_list.len();
        if free < n {
            return Err(FerrumError::resource_exhausted(format!(
                "paged KV pool exhausted: need {n} blocks but only {} free",
                free
            )));
        }
        // Drain the tail and reverse it, matching the order of repeated `pop()`s.
        let out: Vec<u32> = self.free_list.drain(free - n..).rev().collect();
        self.record_peak();
        Ok(out)
    }

    /// Returns `blocks` to the pool.
    ///
    /// Double frees are not detected. Callers must only return ids they
    /// obtained from this allocator and have not already freed. Out-of-range
    /// ids trigger a panic in debug builds.
    pub fn free(&mut self, blocks: &[u32]) {
        debug_assert!(
            blocks.iter().all(|&b| b < self.capacity),
            "BlockAllocator::free: block id out of range (capacity={})",
            self.capacity
        );
        self.free_list.extend_from_slice(blocks);
    }

    /// Number of block ids currently available.
    pub fn free_count(&self) -> usize {
        self.free_list.len()
    }

    /// Total number of blocks managed by this pool.
    pub fn capacity(&self) -> u32 {
        self.capacity
    }

    /// Highest number of blocks observed in use at any allocation so far.
    pub fn peak_in_use(&self) -> usize {
        self.peak_in_use.load(Ordering::Relaxed)
    }

    /// Folds the current in-use count into the high-water mark.
    fn record_peak(&self) {
        // `saturating_sub` guards against an over-full free list (e.g. after a
        // double free). Plain subtraction would panic in debug or wrap in release.
        let in_use = (self.capacity as usize).saturating_sub(self.free_list.len());
        self.peak_in_use.fetch_max(in_use, Ordering::Relaxed);
    }
}
/// Paged-KV bookkeeping for a single sequence.
///
/// Holds the block ids this sequence owns, plus two backend buffers that
/// mirror that state for kernels: a zero-padded block table and a one-element
/// context-length buffer.
pub struct PagedSeqState<B: Backend> {
    /// Block ids owned by this sequence, in the order they were acquired.
    pub blocks: Vec<u32>,
    /// Backend buffer of `max_blocks_per_seq` u32 entries. `ensure_capacity`
    /// writes `blocks` into it, zero-padded to the full length.
    pub block_table_buf: B::Buffer,
    /// Backend buffer of one u32, written from `len` by `sync_context_len`.
    pub context_lens_buf: B::Buffer,
    /// Sequence length uploaded by `sync_context_len` and reset to 0 by
    /// `release`. Presumably a token count maintained by the caller.
    pub len: usize,
    /// Number of positions covered by one block; must be non-zero.
    pub block_size: usize,
    /// Upper bound on `blocks.len()` and the length of `block_table_buf`.
    pub max_blocks_per_seq: usize,
}

impl<B: Backend> PagedSeqState<B> {
    /// Allocates the backend buffers for a new, empty sequence.
    ///
    /// The context-length buffer is written with 0 on a dedicated context
    /// that is then synced before returning.
    pub fn new(block_size: usize, max_blocks_per_seq: usize) -> Self {
        let block_table_buf = B::alloc_u32(max_blocks_per_seq);
        let mut context_lens_buf = B::alloc_u32(1);

        let mut ctx = B::new_context();
        B::write_u32(&mut ctx, &mut context_lens_buf, &[0u32]);
        B::sync(&mut ctx);

        Self {
            blocks: Vec::with_capacity(max_blocks_per_seq),
            block_table_buf,
            context_lens_buf,
            len: 0,
            block_size,
            max_blocks_per_seq,
        }
    }

    /// Grows `blocks` until it covers `target_len` positions, then uploads the
    /// zero-padded block table via `ctx`.
    ///
    /// Missing blocks are acquired with a single all-or-nothing
    /// [`BlockAllocator::allocate_n`]. If the pool cannot satisfy the request,
    /// `blocks` and the allocator are left unchanged.
    ///
    /// # Errors
    /// - a model error if `block_size` is 0 or if `target_len` would need more
    ///   than `max_blocks_per_seq` blocks;
    /// - the allocator's resource-exhausted error if too few blocks are free.
    pub fn ensure_capacity(
        &mut self,
        ctx: &mut B::Context,
        alloc: &mut BlockAllocator,
        target_len: usize,
    ) -> Result<()> {
        if self.block_size == 0 {
            // `div_ceil` by zero would panic; report it as a configuration error instead.
            return Err(FerrumError::model(
                "paged KV: block_size must be non-zero".to_string(),
            ));
        }
        let needed = target_len.div_ceil(self.block_size);
        if needed > self.max_blocks_per_seq {
            return Err(FerrumError::model(format!(
                "paged KV: target_len={target_len} would need {needed} blocks, exceeds max_blocks_per_seq={}",
                self.max_blocks_per_seq
            )));
        }

        let missing = needed.saturating_sub(self.blocks.len());
        if missing > 0 {
            let fresh = alloc.allocate_n(missing)?;
            self.blocks.extend_from_slice(&fresh);
        }

        // Always re-upload: `blocks` is public and may have been changed by the caller.
        // Pad with 0 (or truncate) to exactly `max_blocks_per_seq` entries in one allocation.
        let padded: Vec<u32> = self
            .blocks
            .iter()
            .copied()
            .chain(std::iter::repeat(0))
            .take(self.max_blocks_per_seq)
            .collect();
        B::write_u32(ctx, &mut self.block_table_buf, &padded);
        Ok(())
    }

    /// Writes `len` (as u32) into `context_lens_buf` via `ctx`.
    pub fn sync_context_len(&mut self, ctx: &mut B::Context) {
        // The buffer is u32, so a length past u32::MAX would be silently truncated.
        debug_assert!(
            u32::try_from(self.len).is_ok(),
            "PagedSeqState::sync_context_len: len does not fit in u32"
        );
        B::write_u32(ctx, &mut self.context_lens_buf, &[self.len as u32]);
    }

    /// Returns every owned block to `alloc` and resets the host-side length.
    ///
    /// The backend buffers are not rewritten here.
    pub fn release(&mut self, alloc: &mut BlockAllocator) {
        alloc.free(&self.blocks);
        self.blocks.clear();
        self.len = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allocator_basic() {
        let mut pool = BlockAllocator::new(4);
        assert_eq!(pool.free_count(), 4);

        // A fresh pool hands out ids in ascending order until it runs dry.
        for expected in 0..4u32 {
            assert_eq!(pool.allocate().unwrap(), expected);
        }
        assert!(pool.allocate().is_err());
        assert_eq!(pool.free_count(), 0);

        // Returned blocks are reused last-in, first-out.
        pool.free(&[1, 3]);
        assert_eq!(pool.free_count(), 2);
        assert_eq!(pool.allocate().unwrap(), 3);
        assert_eq!(pool.allocate().unwrap(), 1);
    }

    #[test]
    fn allocator_atomic_n_failure() {
        let mut pool = BlockAllocator::new(3);
        for _ in 0..2 {
            let _taken = pool.allocate().unwrap();
        }

        // Requesting more than is free must fail without consuming anything.
        assert!(pool.allocate_n(2).is_err());
        assert_eq!(pool.free_count(), 1);
    }

    #[test]
    fn allocator_peak_tracking() {
        let mut pool = BlockAllocator::new(8);
        let first_batch = pool.allocate_n(5).unwrap();
        assert_eq!(pool.peak_in_use(), 5);

        // The high-water mark survives frees and smaller follow-up allocations.
        pool.free(&first_batch);
        assert_eq!(pool.peak_in_use(), 5);
        let _second_batch = pool.allocate_n(3).unwrap();
        assert_eq!(pool.peak_in_use(), 5);
    }
}