use ferrum_kernels::backend::Backend;
use ferrum_types::{FerrumError, Result, TokenId};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
/// 64-bit hash identifying one block of tokens together with everything that
/// preceded it (see `block_hash`, which folds the parent hash into the result).
pub type BlockHash = u64;
/// Hashes one block of tokens, chained onto the hash of the preceding block.
///
/// `parent` is fed into the hasher first, followed by the raw value of every
/// token in `tokens` in order, so two blocks with identical tokens but
/// different prefixes produce different hashes. Pass `0` as `parent` for the
/// first block of a sequence (that is what `block_hash_chain` does).
pub fn block_hash(parent: BlockHash, tokens: &[TokenId]) -> BlockHash {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    parent.hash(&mut state);
    tokens.iter().for_each(|tok| tok.get().hash(&mut state));
    state.finish()
}
/// Computes the chained hash of every *full* block in `tokens`.
///
/// The sequence is cut into consecutive chunks of `block_size` tokens; a
/// trailing partial chunk is ignored. Entry `i` of the result is
/// `block_hash(entry[i - 1], chunk_i)`, with `0` standing in as the parent of
/// the first chunk.
///
/// # Panics
///
/// Panics if `block_size` is zero.
pub fn block_hash_chain(tokens: &[TokenId], block_size: usize) -> Vec<BlockHash> {
    let full_blocks = tokens.len() / block_size;
    let mut chain: Vec<BlockHash> = Vec::with_capacity(full_blocks);
    for chunk in tokens.chunks_exact(block_size) {
        // The previous entry (or 0 for the very first block) is the parent.
        let previous = chain.last().copied().unwrap_or(0);
        chain.push(block_hash(previous, chunk));
    }
    chain
}
/// Reference-counted allocator for the blocks of the paged KV pool, with a
/// hash index that lets a block be found again by the hash of its contents.
///
/// Blocks are identified by `u32` ids in `0..capacity`. A block with a
/// reference count of zero sits in the free list; its registered hash (if any)
/// is kept until the block is handed out again, so a freed block can still be
/// recovered through [`BlockAllocator::try_acquire_by_hash`].
pub struct BlockAllocator {
    /// Ids of all blocks whose reference count is zero. Used as a stack:
    /// allocation pops from the back, freeing pushes to the back.
    free_list: Vec<u32>,
    /// Total number of blocks managed by this allocator.
    capacity: u32,
    /// Highest value of `capacity - free_list.len()` observed so far.
    peak_in_use: AtomicUsize,
    /// Reference count of every block, indexed by block id.
    ref_counts: Vec<u16>,
    /// Hash -> id of the block that currently answers lookups for that hash.
    hash_table: HashMap<BlockHash, u32>,
    /// Hash most recently registered for every block, indexed by block id.
    block_to_hash: Vec<Option<BlockHash>>,
}

impl BlockAllocator {
    /// Creates an allocator managing block ids `0..num_blocks`, all free.
    pub fn new(num_blocks: u32) -> Self {
        let slots = num_blocks as usize;
        Self {
            // Descending order, so popping from the back yields 0, 1, 2, ...
            free_list: (0..num_blocks).rev().collect(),
            capacity: num_blocks,
            peak_in_use: AtomicUsize::new(0),
            ref_counts: vec![0u16; slots],
            hash_table: HashMap::new(),
            block_to_hash: vec![None; slots],
        }
    }

    /// Takes one block from the free list and returns it with a reference
    /// count of 1. Any hash previously registered for the block is dropped.
    ///
    /// # Errors
    ///
    /// Returns a resource-exhausted error when no block is free.
    pub fn allocate(&mut self) -> Result<u32> {
        let Some(block) = self.pop_free_block("allocate") else {
            return Err(FerrumError::resource_exhausted(format!(
                "paged KV pool exhausted (capacity={} blocks, all in use)",
                self.capacity
            )));
        };
        self.record_peak();
        Ok(block)
    }

    /// Removes the hash registered for `block`, if any. The hash-table entry is
    /// only deleted when it still points at `block`; another block may have
    /// taken over the same hash in the meantime.
    fn evict_hash_if_any(&mut self, block: u32) {
        if let Some(h) = self.block_to_hash[block as usize].take() {
            if self.hash_table.get(&h) == Some(&block) {
                self.hash_table.remove(&h);
            }
        }
    }

    /// Pops the next free block, clears its stale hash and marks it as held
    /// by exactly one owner. `caller` only labels the debug assertion.
    fn pop_free_block(&mut self, caller: &str) -> Option<u32> {
        let block = self.free_list.pop()?;
        let bi = block as usize;
        debug_assert!(
            self.ref_counts[bi] == 0,
            "{caller} yielded block {block} with non-zero ref_count {}",
            self.ref_counts[bi]
        );
        self.evict_hash_if_any(block);
        self.ref_counts[bi] = 1;
        Some(block)
    }

    /// Raises the high-water mark to the current number of blocks in use.
    fn record_peak(&mut self) {
        let in_use = self.capacity as usize - self.free_list.len();
        self.peak_in_use.fetch_max(in_use, Ordering::Relaxed);
    }

    /// Allocates `n` blocks at once, all-or-nothing: when fewer than `n`
    /// blocks are free nothing is allocated.
    ///
    /// # Errors
    ///
    /// Returns a resource-exhausted error when fewer than `n` blocks are free.
    pub fn allocate_n(&mut self, n: usize) -> Result<Vec<u32>> {
        if self.free_list.len() < n {
            return Err(FerrumError::resource_exhausted(format!(
                "paged KV pool exhausted: need {n} blocks but only {} free",
                self.free_list.len()
            )));
        }
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            let block = self
                .pop_free_block("allocate_n")
                .expect("free_list length was checked above");
            out.push(block);
        }
        self.record_peak();
        Ok(out)
    }

    /// Looks up the block registered under `hash` and takes a reference to it.
    ///
    /// * If the block is currently held, its reference count is incremented.
    /// * If the block is free but has not been handed out again, it is pulled
    ///   back out of the free list with a reference count of 1.
    ///
    /// Returns `None` when no block is registered under `hash`.
    pub fn try_acquire_by_hash(&mut self, hash: BlockHash) -> Option<u32> {
        let block = *self.hash_table.get(&hash)?;
        let bi = block as usize;

        if self.ref_counts[bi] > 0 {
            self.acquire(block);
            return Some(block);
        }

        // Freed but not yet reused: revive it from the free list. The search
        // starts at the back, where recently freed blocks live.
        let Some(pos) = self.free_list.iter().rposition(|&b| b == block) else {
            debug_assert!(false, "block {block} has ref=0 but not in free_list");
            return None;
        };
        self.free_list.swap_remove(pos);
        self.ref_counts[bi] = 1;
        self.record_peak();
        Some(block)
    }

    /// Registers `hash` as the content hash of `block`, replacing any hash the
    /// block carried before. `block` must currently be allocated.
    ///
    /// If another block is already registered under `hash`, a *new*
    /// registration takes the entry over. Re-registering the hash the block
    /// already carries never displaces another block, but it does restore the
    /// table entry if it was lost (another block took the hash over and has
    /// since been re-hashed or reallocated).
    pub fn register_block_hash(&mut self, block: u32, hash: BlockHash) {
        let bi = block as usize;
        debug_assert!(
            self.ref_counts[bi] > 0,
            "register_block_hash on block {block} with ref_count=0 (not allocated)"
        );
        match self.block_to_hash[bi].replace(hash) {
            Some(old_h) if old_h == hash => {
                // Same hash again: keep whichever block answers for it now,
                // and only fill the slot when nobody does.
                self.hash_table.entry(hash).or_insert(block);
            }
            Some(old_h) => {
                // Drop the old entry only if it is still ours.
                if self.hash_table.get(&old_h) == Some(&block) {
                    self.hash_table.remove(&old_h);
                }
                self.hash_table.insert(hash, block);
            }
            None => {
                self.hash_table.insert(hash, block);
            }
        }
    }

    /// Number of hashes currently present in the hash index.
    #[inline]
    pub fn hash_table_size(&self) -> usize {
        self.hash_table.len()
    }

    /// Adds one reference to an already-allocated block.
    ///
    /// # Panics
    ///
    /// Panics if the reference count would exceed `u16::MAX`. In debug builds
    /// it also panics when `block` is not currently allocated.
    pub fn acquire(&mut self, block: u32) {
        let bi = block as usize;
        debug_assert!(
            self.ref_counts[bi] > 0,
            "acquire on block {block} with ref_count=0 (not currently allocated)"
        );
        self.ref_counts[bi] = self.ref_counts[bi]
            .checked_add(1)
            .expect("BlockAllocator ref_count u16 overflow (>65535 sharers)");
    }

    /// Adds one reference to every block in `blocks` (see [`Self::acquire`]).
    pub fn acquire_many(&mut self, blocks: &[u32]) {
        for &b in blocks {
            self.acquire(b);
        }
    }

    /// Drops one reference from every block in `blocks`. A block whose count
    /// reaches zero goes back onto the free list; its registered hash is kept
    /// until the block is allocated again.
    ///
    /// Freeing a block whose count is already zero is a caller bug: debug
    /// builds panic, release builds skip the block so the count cannot wrap
    /// around to `u16::MAX`.
    pub fn free(&mut self, blocks: &[u32]) {
        for &b in blocks {
            let bi = b as usize;
            debug_assert!(
                self.ref_counts[bi] > 0,
                "free on block {b} with ref_count=0 (double-free)"
            );
            // With the assertion compiled out, a plain `-= 1` would wrap.
            let Some(remaining) = self.ref_counts[bi].checked_sub(1) else {
                continue;
            };
            self.ref_counts[bi] = remaining;
            if remaining == 0 {
                self.free_list.push(b);
            }
        }
    }

    /// Current reference count of `block` (0 means the block is free).
    #[inline]
    pub fn ref_count(&self, block: u32) -> u16 {
        self.ref_counts[block as usize]
    }

    /// Number of blocks currently on the free list.
    pub fn free_count(&self) -> usize {
        self.free_list.len()
    }

    /// Total number of blocks managed by this allocator.
    pub fn capacity(&self) -> u32 {
        self.capacity
    }

    /// Largest number of simultaneously allocated blocks seen so far.
    pub fn peak_in_use(&self) -> usize {
        self.peak_in_use.load(Ordering::Relaxed)
    }
}
/// Paged-KV bookkeeping for a single sequence: the blocks it owns plus the
/// backend buffers that mirror the block table and the context length.
pub struct PagedSeqState<B: Backend> {
    /// Ids of the blocks owned by this sequence, in the order they were
    /// allocated by `ensure_capacity`.
    pub blocks: Vec<u32>,
    /// Backend buffer allocated with `max_blocks_per_seq` `U32` entries;
    /// `ensure_capacity` writes `blocks` into it, padded with zeros.
    pub block_table_buf: B::Buffer,
    /// Backend buffer allocated with one `U32` entry; `sync_context_len`
    /// writes `len` into it.
    pub context_lens_buf: B::Buffer,
    /// Current sequence length. Reset to 0 by `release`; otherwise maintained
    /// by the caller.
    pub len: usize,
    /// Number of tokens covered by one block.
    pub block_size: usize,
    /// Upper bound on the number of blocks this sequence may own.
    pub max_blocks_per_seq: usize,
}

impl<B: Backend> PagedSeqState<B> {
    /// Creates an empty sequence state.
    ///
    /// Allocates both backend buffers and initialises the context-length
    /// buffer to 0 through a temporary backend context, which is synced before
    /// returning.
    pub fn new(block_size: usize, max_blocks_per_seq: usize) -> Self {
        let block_table_buf =
            B::alloc_typed(ferrum_kernels::backend::Dtype::U32, max_blocks_per_seq);
        let mut context_lens_buf = B::alloc_typed(ferrum_kernels::backend::Dtype::U32, 1);

        let mut init_ctx = B::new_context();
        B::write_typed::<u32>(&mut init_ctx, &mut context_lens_buf, &[0u32]);
        B::sync(&mut init_ctx);

        Self {
            blocks: Vec::with_capacity(max_blocks_per_seq),
            block_table_buf,
            context_lens_buf,
            len: 0,
            block_size,
            max_blocks_per_seq,
        }
    }

    /// Makes sure the sequence owns enough blocks to hold `target_len` tokens
    /// and rewrites the block-table buffer.
    ///
    /// Blocks are taken from `alloc` one at a time. The block table is written
    /// on every successful call, even when no new block was needed.
    ///
    /// # Errors
    ///
    /// * A model error when `target_len` needs more than `max_blocks_per_seq`
    ///   blocks; nothing is allocated in that case.
    /// * Whatever `alloc.allocate()` returns when the pool runs dry. Blocks
    ///   obtained before the failure stay in `self.blocks`, and the block
    ///   table buffer is not rewritten.
    ///
    /// # Panics
    ///
    /// Panics if `self.block_size` is zero.
    pub fn ensure_capacity(
        &mut self,
        ctx: &mut B::Context,
        alloc: &mut BlockAllocator,
        target_len: usize,
    ) -> Result<()> {
        let needed = target_len.div_ceil(self.block_size);
        if needed > self.max_blocks_per_seq {
            return Err(FerrumError::model(format!(
                "paged KV: target_len={target_len} would need {needed} blocks, exceeds max_blocks_per_seq={}",
                self.max_blocks_per_seq
            )));
        }

        let missing = needed.saturating_sub(self.blocks.len());
        for _ in 0..missing {
            let fresh = alloc.allocate()?;
            self.blocks.push(fresh);
        }

        // Host-side copy of the table, zero-padded (or truncated) to the
        // fixed size of the device buffer.
        let mut table = Vec::with_capacity(self.max_blocks_per_seq);
        table.extend_from_slice(&self.blocks);
        table.resize(self.max_blocks_per_seq, 0);
        B::write_typed::<u32>(ctx, &mut self.block_table_buf, &table);
        Ok(())
    }

    /// Writes the current `len` (cast to `u32`) into the context-length buffer.
    pub fn sync_context_len(&mut self, ctx: &mut B::Context) {
        let current_len = self.len as u32;
        B::write_typed::<u32>(ctx, &mut self.context_lens_buf, &[current_len]);
    }

    /// Gives every owned block back to `alloc` and resets the sequence to
    /// empty. The backend buffers are left untouched.
    pub fn release(&mut self, alloc: &mut BlockAllocator) {
        alloc.free(self.blocks.as_slice());
        self.len = 0;
        self.blocks.clear();
    }
}
// Unit tests for the block-hash helpers and `BlockAllocator`.
#[cfg(test)]
mod tests {
use super::*;
// A fresh allocator hands out ids in ascending order, reports exhaustion as an
// error, and reuses freed blocks most-recently-freed first.
#[test]
fn allocator_basic() {
let mut a = BlockAllocator::new(4);
assert_eq!(a.free_count(), 4);
assert_eq!(a.allocate().unwrap(), 0);
assert_eq!(a.allocate().unwrap(), 1);
assert_eq!(a.allocate().unwrap(), 2);
assert_eq!(a.allocate().unwrap(), 3);
assert!(a.allocate().is_err());
assert_eq!(a.free_count(), 0);
a.free(&[1, 3]);
assert_eq!(a.free_count(), 2);
// 3 was pushed last, so it comes back first.
assert_eq!(a.allocate().unwrap(), 3);
assert_eq!(a.allocate().unwrap(), 1);
}
// A failing `allocate_n` must not consume any of the remaining free blocks.
#[test]
fn allocator_atomic_n_failure() {
let mut a = BlockAllocator::new(3);
let _ = a.allocate().unwrap(); let _ = a.allocate().unwrap();
assert!(a.allocate_n(2).is_err());
assert_eq!(a.free_count(), 1);
}
// `peak_in_use` is a high-water mark: it does not drop when blocks are freed
// and does not rise for a later, smaller allocation.
#[test]
fn allocator_peak_tracking() {
let mut a = BlockAllocator::new(8);
let blocks = a.allocate_n(5).unwrap();
assert_eq!(a.peak_in_use(), 5);
a.free(&blocks);
assert_eq!(a.peak_in_use(), 5); let _ = a.allocate_n(3).unwrap();
assert_eq!(a.peak_in_use(), 5);
}
// A freshly allocated block has exactly one reference.
#[test]
fn refcount_allocate_starts_at_one() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
assert_eq!(a.ref_count(b), 1);
}
// Each `acquire` adds one reference.
#[test]
fn refcount_acquire_increments() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
a.acquire(b);
a.acquire(b);
assert_eq!(a.ref_count(b), 3);
}
// Dropping one of two references leaves the block allocated (it does not
// return to the free list).
#[test]
fn refcount_free_decrements_no_physical_release() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
a.acquire(b); a.free(&[b]); assert_eq!(a.ref_count(b), 1);
assert_eq!(a.free_count(), 3); }
// Dropping the last reference returns the block to the free list.
#[test]
fn refcount_free_physical_release_at_zero() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
a.acquire(b); a.free(&[b]); a.free(&[b]); assert_eq!(a.ref_count(b), 0);
assert_eq!(a.free_count(), 4);
}
// Plain allocate/free without sharing: the block is released immediately and
// is the next one handed out.
#[test]
fn refcount_legacy_single_ref_behaviour_unchanged() {
let mut a = BlockAllocator::new(2);
let b = a.allocate().unwrap();
assert_eq!(a.free_count(), 1);
a.free(&[b]);
assert_eq!(a.free_count(), 2);
assert_eq!(a.ref_count(b), 0);
let b2 = a.allocate().unwrap();
assert_eq!(b2, b);
}
// `acquire_many` bumps every listed block; two rounds of `free` are then
// needed before the blocks return to the free list.
#[test]
fn refcount_bulk_acquire_and_release() {
let mut a = BlockAllocator::new(8);
let blocks = a.allocate_n(3).unwrap();
a.acquire_many(&blocks); for &b in &blocks {
assert_eq!(a.ref_count(b), 2);
}
a.free(&blocks); assert_eq!(a.free_count(), 5);
a.free(&blocks); assert_eq!(a.free_count(), 8);
}
// Helper: wraps raw ids into `TokenId`s.
fn toks(ids: &[u32]) -> Vec<TokenId> {
ids.iter().map(|&i| TokenId::new(i)).collect()
}
// Eight tokens with block size 4 give two distinct chained hashes.
#[test]
fn block_hash_chain_basic() {
let tokens = toks(&[1, 2, 3, 4, 5, 6, 7, 8]);
let chain = block_hash_chain(&tokens, 4);
assert_eq!(chain.len(), 2);
assert_ne!(chain[0], chain[1]);
}
// Sequences sharing their first block share the first hash; a differing
// second block gives a different second hash.
#[test]
fn block_hash_chain_identical_prefix_same_hashes() {
let a = toks(&[10, 20, 30, 40, 50, 60, 70, 80]);
let b = toks(&[10, 20, 30, 40, 99, 99, 99, 99]);
let ca = block_hash_chain(&a, 4);
let cb = block_hash_chain(&b, 4);
assert_eq!(ca[0], cb[0], "first block matches → same hash[0]");
assert_ne!(ca[1], cb[1], "second block differs → different hash[1]");
}
// Seven tokens with block size 4: the three trailing tokens do not form a
// full block and produce no hash.
#[test]
fn block_hash_chain_drops_partial_trailing() {
let tokens = toks(&[1, 2, 3, 4, 5, 6, 7]); let chain = block_hash_chain(&tokens, 4);
assert_eq!(chain.len(), 1);
}
// Looking up the hash of a block that is still held adds a reference to it.
#[test]
fn hash_table_register_and_acquire_live() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
let h = 0xDEAD_BEEFu64;
a.register_block_hash(b, h);
assert_eq!(a.hash_table_size(), 1);
let got = a.try_acquire_by_hash(h);
assert_eq!(got, Some(b));
assert_eq!(a.ref_count(b), 2);
}
// A block that was freed but not yet handed out again can be recovered by
// hash; it leaves the free list with a reference count of 1.
#[test]
fn hash_table_soft_free_resurrection() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
let h = 0xABCD_1234u64;
a.register_block_hash(b, h);
a.free(&[b]); assert_eq!(a.free_count(), 4);
assert_eq!(a.ref_count(b), 0);
let got = a.try_acquire_by_hash(h);
assert_eq!(got, Some(b));
assert_eq!(a.ref_count(b), 1);
assert_eq!(a.free_count(), 3);
}
// An unknown hash yields `None` and leaves existing reference counts alone.
#[test]
fn hash_table_miss_returns_none() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
a.register_block_hash(b, 1);
assert_eq!(a.try_acquire_by_hash(99), None);
assert_eq!(a.ref_count(b), 1);
}
// The hash survives `free`, but is erased once the block is allocated again.
#[test]
fn hash_table_evicts_on_realloc() {
let mut a = BlockAllocator::new(2);
let b1 = a.allocate().unwrap();
let h = 0x1234u64;
a.register_block_hash(b1, h);
a.free(&[b1]); assert_eq!(a.hash_table_size(), 1);
let b2 = a.allocate().unwrap();
assert_eq!(b2, b1, "LIFO should reuse b1");
assert_eq!(a.hash_table_size(), 0, "stale hash erased on realloc");
assert_eq!(a.try_acquire_by_hash(h), None);
}
// Registering a second hash for the same block replaces the first one.
#[test]
fn hash_table_replace_block_hash() {
let mut a = BlockAllocator::new(4);
let b = a.allocate().unwrap();
a.register_block_hash(b, 1);
a.register_block_hash(b, 2); assert_eq!(a.hash_table_size(), 1);
assert_eq!(a.try_acquire_by_hash(1), None);
assert_eq!(a.try_acquire_by_hash(2), Some(b));
}
}