use std::collections::HashMap;
use super::block_hash::{BlockHash, BlockHashWithGroupId};
/// Sentinel index meaning "no neighbor" in the intrusive free-list links.
/// `usize::MAX` can never be a valid block index, so it is safe as a marker.
const NO_LINK: usize = usize::MAX;
/// A single KV-cache block tracked by the pool.
///
/// Blocks participate in an intrusive doubly-linked free list via
/// `prev_free` / `next_free` (with `NO_LINK` meaning "not linked").
#[derive(Debug)]
pub struct KVCacheBlock {
    /// Index of this block within the pool's block vector.
    #[allow(dead_code)]
    pub block_id: usize,
    /// Number of active references to this block; 0 means evictable.
    pub ref_cnt: u32,
    /// Content hash (plus KV-cache group id) once cached; `None` otherwise.
    pub block_hash: Option<BlockHashWithGroupId>,
    // Intrusive free-list links; both are NO_LINK when not in the free list.
    prev_free: usize,
    next_free: usize,
    /// Whether this is the pool's reserved null block (never freed or cached).
    pub is_null: bool,
}
impl KVCacheBlock {
    /// Builds a fresh, unreferenced block that carries no cached hash and is
    /// not yet linked into the free list.
    fn new(block_id: usize) -> Self {
        Self {
            block_id,
            is_null: false,
            ref_cnt: 0,
            block_hash: None,
            next_free: NO_LINK,
            prev_free: NO_LINK,
        }
    }

    /// Reports whether this block is currently linked into the free list
    /// (i.e. at least one of its intrusive links is set).
    #[allow(dead_code)]
    fn is_in_free_list(&self) -> bool {
        !(self.prev_free == NO_LINK && self.next_free == NO_LINK)
    }

    /// Drops the cached content hash, returning the block to the
    /// "uncached" state.
    fn reset_hash(&mut self) {
        self.block_hash.take();
    }
}
/// Intrusive doubly-linked free list of allocatable blocks.
///
/// The list nodes live inside the pool's `KVCacheBlock` slice; this struct
/// only stores the sentinel ("fake") head/tail indices plus a length counter.
/// Blocks are popped from the front and appended at the back, so the
/// longest-free block is reused first.
struct FreeKVCacheBlockQueue {
    /// Number of real (non-sentinel) blocks currently in the list.
    num_free_blocks: usize,
    /// Index of the sentinel head node within the blocks slice.
    fake_head: usize,
    /// Index of the sentinel tail node within the blocks slice.
    fake_tail: usize,
}
impl FreeKVCacheBlockQueue {
    /// Builds the free list over `block_ids` (kept in the given order),
    /// wiring the intrusive links inside `blocks` and attaching the sentinel
    /// `fake_head` / `fake_tail` nodes at the ends.
    fn new(
        blocks: &mut [KVCacheBlock],
        block_ids: &[usize],
        fake_head: usize,
        fake_tail: usize,
    ) -> Self {
        let n = block_ids.len();
        for i in 0..n {
            let id = block_ids[i];
            blocks[id].prev_free = if i > 0 { block_ids[i - 1] } else { fake_head };
            blocks[id].next_free = if i + 1 < n {
                block_ids[i + 1]
            } else {
                fake_tail
            };
        }
        if n > 0 {
            blocks[fake_head].next_free = block_ids[0];
            blocks[fake_tail].prev_free = block_ids[n - 1];
        } else {
            // Empty list: the two sentinels point at each other.
            blocks[fake_head].next_free = fake_tail;
            blocks[fake_tail].prev_free = fake_head;
        }
        Self {
            num_free_blocks: n,
            fake_head,
            fake_tail,
        }
    }

    /// Pops and returns the block at the front of the list, or `None` if the
    /// list is empty. The popped block's links are reset to `NO_LINK`.
    fn popleft(&mut self, blocks: &mut [KVCacheBlock]) -> Option<usize> {
        let first_id = blocks[self.fake_head].next_free;
        if first_id == self.fake_tail {
            return None;
        }
        let next_id = blocks[first_id].next_free;
        blocks[self.fake_head].next_free = next_id;
        blocks[next_id].prev_free = self.fake_head;
        blocks[first_id].prev_free = NO_LINK;
        blocks[first_id].next_free = NO_LINK;
        self.num_free_blocks -= 1;
        Some(first_id)
    }

    /// Unlinks `block_id` from an arbitrary position in the list.
    ///
    /// The block must currently be in the list; violating that corrupts the
    /// links (checked with `debug_assert!` in debug builds).
    fn remove(&mut self, blocks: &mut [KVCacheBlock], block_id: usize) {
        let prev_id = blocks[block_id].prev_free;
        let next_id = blocks[block_id].next_free;
        debug_assert!(
            prev_id != NO_LINK && next_id != NO_LINK,
            "remove() called on block {} not in free list",
            block_id
        );
        blocks[prev_id].next_free = next_id;
        blocks[next_id].prev_free = prev_id;
        blocks[block_id].prev_free = NO_LINK;
        blocks[block_id].next_free = NO_LINK;
        self.num_free_blocks -= 1;
    }

    /// Appends `block_id` at the back of the list.
    ///
    /// The block must not already be linked; appending a linked block would
    /// corrupt the list. Checked with `debug_assert!`, mirroring the
    /// precondition check in `remove`.
    fn append(&mut self, blocks: &mut [KVCacheBlock], block_id: usize) {
        debug_assert!(
            !blocks[block_id].is_in_free_list(),
            "append() called on block {} already in free list",
            block_id
        );
        let last_id = blocks[self.fake_tail].prev_free;
        blocks[last_id].next_free = block_id;
        blocks[block_id].prev_free = last_id;
        blocks[block_id].next_free = self.fake_tail;
        blocks[self.fake_tail].prev_free = block_id;
        self.num_free_blocks += 1;
    }
}
/// Prefix-cache index: maps a content hash (plus KV-cache group id) to the
/// block id(s) currently holding that content.
pub struct BlockHashToBlockMap {
    // One entry per distinct hash key; see `CachedBlocks` for the value side.
    cache: HashMap<BlockHashWithGroupId, CachedBlocks>,
}
/// Value side of the prefix-cache map: either exactly one block holds the
/// hashed content, or several do. `Multiple` maps each block id to itself,
/// i.e. it is a `HashMap` used as a set of block ids.
enum CachedBlocks {
    /// Exactly one cached block for this hash.
    Single(usize),
    /// Several cached blocks; keys and values are both the block id.
    Multiple(HashMap<usize, usize>),
}
impl BlockHashToBlockMap {
    /// Creates an empty hash-to-blocks index.
    fn new() -> Self {
        Self {
            cache: HashMap::new(),
        }
    }

    /// Returns one cached block id for `key` (an arbitrary one when several
    /// blocks share the hash), or `None` on a cache miss.
    fn get_one(&self, key: &BlockHashWithGroupId) -> Option<usize> {
        self.cache.get(key).and_then(|entry| match entry {
            CachedBlocks::Single(id) => Some(*id),
            CachedBlocks::Multiple(map) => map.values().next().copied(),
        })
    }

    /// Records that `block_id` holds the content hashed by `key`, upgrading a
    /// `Single` entry to `Multiple` when a second block shares the hash.
    fn insert(&mut self, key: BlockHashWithGroupId, block_id: usize) {
        if let Some(entry) = self.cache.get_mut(&key) {
            match entry {
                CachedBlocks::Single(existing_id) => {
                    let existing = *existing_id;
                    let mut map = HashMap::new();
                    map.insert(existing, existing);
                    map.insert(block_id, block_id);
                    *entry = CachedBlocks::Multiple(map);
                }
                CachedBlocks::Multiple(map) => {
                    map.insert(block_id, block_id);
                }
            }
        } else {
            self.cache.insert(key, CachedBlocks::Single(block_id));
        }
    }

    /// Removes the association between `key` and `block_id`, returning the
    /// block id if it was present. Other blocks sharing the hash remain
    /// cached; a `Multiple` entry left with one block collapses to `Single`.
    fn pop(&mut self, key: &BlockHashWithGroupId, block_id: usize) -> Option<usize> {
        match self.cache.remove(key)? {
            CachedBlocks::Single(id) if id == block_id => Some(id),
            CachedBlocks::Single(id) => {
                // A different block owns this hash: restore the entry as-is.
                self.cache.insert(*key, CachedBlocks::Single(id));
                None
            }
            CachedBlocks::Multiple(mut map) => {
                let removed = map.remove(&block_id);
                match map.len() {
                    0 => {} // nothing left to restore
                    1 => {
                        let only = *map.values().next().unwrap();
                        self.cache.insert(*key, CachedBlocks::Single(only));
                    }
                    _ => {
                        self.cache.insert(*key, CachedBlocks::Multiple(map));
                    }
                }
                removed
            }
        }
    }

    /// Number of distinct hash keys currently cached.
    fn len(&self) -> usize {
        self.cache.len()
    }

    /// Drops every cached association.
    fn clear(&mut self) {
        self.cache.clear();
    }
}
/// Pool of KV-cache blocks with optional prefix caching.
///
/// Blocks whose `ref_cnt` drops to 0 (other than the null block) go back
/// into `free_queue`; when caching is enabled they keep their content hash
/// in `cached_block_hash_to_block` until they are reallocated.
pub struct BlockPool {
    /// All blocks, including the two free-list sentinel nodes at the end.
    blocks: Vec<KVCacheBlock>,
    /// Free list of blocks with `ref_cnt == 0` that can be (re)allocated.
    free_queue: FreeKVCacheBlockQueue,
    /// Prefix cache: content hash (+ group id) -> cached block id(s).
    cached_block_hash_to_block: BlockHashToBlockMap,
    /// Whether prefix caching is enabled.
    enable_caching: bool,
    /// Number of real GPU blocks (excludes the sentinel nodes).
    num_gpu_blocks: usize,
    /// Id of the reserved null block (never freed, never cached).
    null_block_id: usize,
    /// Token-block size used for hashing; exposed via `hash_block_size()`.
    hash_block_size: usize,
}
impl BlockPool {
    /// Creates a pool of `num_gpu_blocks` blocks (plus two internal free-list
    /// sentinel nodes) and reserves one block as the null block, so only
    /// `num_gpu_blocks - 1` blocks are actually allocatable.
    ///
    /// # Panics
    /// Panics if `num_gpu_blocks == 0`.
    pub fn new(num_gpu_blocks: usize, enable_caching: bool, hash_block_size: usize) -> Self {
        assert!(num_gpu_blocks > 0, "Must have at least 1 GPU block");
        // The free-list sentinel nodes live at the end of `blocks`.
        let fake_head = num_gpu_blocks;
        let fake_tail = num_gpu_blocks + 1;
        let total = num_gpu_blocks + 2;
        let mut blocks: Vec<KVCacheBlock> = (0..total).map(KVCacheBlock::new).collect();
        let all_ids: Vec<usize> = (0..num_gpu_blocks).collect();
        let free_queue = FreeKVCacheBlockQueue::new(&mut blocks, &all_ids, fake_head, fake_tail);
        let mut pool = Self {
            blocks,
            free_queue,
            cached_block_hash_to_block: BlockHashToBlockMap::new(),
            enable_caching,
            num_gpu_blocks,
            null_block_id: 0, // placeholder; overwritten below
            hash_block_size,
        };
        // Permanently reserve the first free block as the null block: it is
        // never returned to the free list and never cached.
        let null_id = pool
            .free_queue
            .popleft(&mut pool.blocks)
            .expect("Pool should have blocks");
        pool.blocks[null_id].is_null = true;
        pool.null_block_id = null_id;
        pool
    }

    /// Id of the reserved null block.
    pub fn null_block_id(&self) -> usize {
        self.null_block_id
    }

    /// Number of blocks currently available for allocation.
    pub fn num_free_blocks(&self) -> usize {
        self.free_queue.num_free_blocks
    }

    /// Total number of GPU blocks (including the reserved null block).
    pub fn num_gpu_blocks(&self) -> usize {
        self.num_gpu_blocks
    }

    /// Fraction of allocatable blocks currently in use, in `[0.0, 1.0]`.
    #[allow(clippy::cast_precision_loss)]
    pub fn usage(&self) -> f64 {
        // The null block is permanently reserved, so exclude it.
        let total = self.num_gpu_blocks - 1;
        if total == 0 {
            return 0.0;
        }
        1.0 - (self.num_free_blocks() as f64 / total as f64)
    }

    /// Looks up `block_hash` for every group in `kv_cache_group_ids`.
    /// Returns one cached block id per group, in order, or `None` if any
    /// group misses the cache.
    pub fn get_cached_block(
        &self,
        block_hash: BlockHash,
        kv_cache_group_ids: &[u32],
    ) -> Option<Vec<usize>> {
        let mut cached_ids = Vec::with_capacity(kv_cache_group_ids.len());
        for &group_id in kv_cache_group_ids {
            let key = BlockHashWithGroupId {
                block_hash,
                group_id,
            };
            match self.cached_block_hash_to_block.get_one(&key) {
                Some(id) => cached_ids.push(id),
                None => return None,
            }
        }
        Some(cached_ids)
    }

    /// Increments the ref count of each listed block. A block with
    /// `ref_cnt == 0` sits in the free list (unless it is the null block, so
    /// which is never in the free list) and is reclaimed from it first.
    pub fn touch(&mut self, block_ids: &[usize]) {
        for &block_id in block_ids {
            if self.blocks[block_id].ref_cnt == 0 && !self.blocks[block_id].is_null {
                self.free_queue.remove(&mut self.blocks, block_id);
            }
            self.blocks[block_id].ref_cnt += 1;
        }
    }

    /// Decrements the ref count of each listed block, then appends every
    /// block that reached `ref_cnt == 0` (except the null block) to the back
    /// of the free list, preserving the caller-supplied order.
    pub fn free_blocks(&mut self, ordered_block_ids: &[usize]) {
        for &block_id in ordered_block_ids {
            debug_assert!(
                self.blocks[block_id].ref_cnt > 0,
                "Block {block_id} ref_cnt underflow: attempting to free block with ref_cnt=0"
            );
            // saturating_sub keeps release builds from underflowing even when
            // the debug assertion above is compiled out.
            self.blocks[block_id].ref_cnt = self.blocks[block_id].ref_cnt.saturating_sub(1);
        }
        for &block_id in ordered_block_ids {
            if self.blocks[block_id].ref_cnt == 0 && !self.blocks[block_id].is_null {
                self.free_queue.append(&mut self.blocks, block_id);
            }
        }
    }

    /// Allocates `num_blocks` fresh blocks with `ref_cnt == 1`, evicting any
    /// stale prefix-cache entries the reused blocks carried. Returns `None`
    /// (allocating nothing) if the pool cannot satisfy the whole request.
    pub fn get_new_blocks(&mut self, num_blocks: usize) -> Option<Vec<usize>> {
        if num_blocks > self.free_queue.num_free_blocks {
            return None;
        }
        let mut result = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            let block_id = self
                .free_queue
                .popleft(&mut self.blocks)
                .expect("Should have enough free blocks");
            if self.enable_caching {
                // Reusing the block invalidates whatever content it cached.
                self.maybe_evict_cached_block(block_id);
            }
            debug_assert_eq!(self.blocks[block_id].ref_cnt, 0);
            self.blocks[block_id].ref_cnt = 1;
            result.push(block_id);
        }
        Some(result)
    }

    /// Registers blocks `num_cached_blocks..num_full_blocks` (indices into
    /// `block_ids`/`block_hashes`) in the prefix cache under
    /// `kv_cache_group_id`, skipping the null block and blocks that already
    /// carry a hash. No-op when caching is disabled or the range is empty.
    ///
    /// # Panics
    /// Panics if fewer than `num_full_blocks` hashes are provided.
    pub fn cache_full_blocks(
        &mut self,
        block_ids: &[usize],
        block_hashes: &[BlockHash],
        num_cached_blocks: usize,
        num_full_blocks: usize,
        kv_cache_group_id: u32,
    ) {
        if !self.enable_caching || num_cached_blocks >= num_full_blocks {
            return;
        }
        assert!(
            block_hashes.len() >= num_full_blocks,
            "Not enough block hashes ({}) for {} full blocks",
            block_hashes.len(),
            num_full_blocks
        );
        for idx in num_cached_blocks..num_full_blocks {
            let block_id = block_ids[idx];
            let block = &mut self.blocks[block_id];
            if block.is_null || block.block_hash.is_some() {
                continue;
            }
            let hash_with_group = BlockHashWithGroupId {
                block_hash: block_hashes[idx],
                group_id: kv_cache_group_id,
            };
            block.block_hash = Some(hash_with_group);
            self.cached_block_hash_to_block
                .insert(hash_with_group, block_id);
        }
    }

    /// Removes this block's prefix-cache entry (if any) and clears its hash.
    fn maybe_evict_cached_block(&mut self, block_id: usize) {
        let block_hash = self.blocks[block_id].block_hash;
        if let Some(hash) = block_hash {
            self.cached_block_hash_to_block.pop(&hash, block_id);
            self.blocks[block_id].reset_hash();
        }
    }

    /// Clears the entire prefix cache. Succeeds only when no block is in use
    /// besides the permanently reserved null block (i.e. `num_used == 1`).
    pub fn reset_prefix_cache(&mut self) -> bool {
        let num_used = self.num_gpu_blocks - self.num_free_blocks();
        if num_used != 1 {
            return false;
        }
        self.cached_block_hash_to_block.clear();
        for block in &mut self.blocks {
            block.reset_hash();
        }
        true
    }

    /// Number of distinct hash keys currently in the prefix cache.
    pub fn num_cached_blocks(&self) -> usize {
        self.cached_block_hash_to_block.len()
    }

    /// Token-block size used when hashing block contents.
    pub fn hash_block_size(&self) -> usize {
        self.hash_block_size
    }

    /// Whether prefix caching is enabled.
    pub fn caching_enabled(&self) -> bool {
        self.enable_caching
    }

    /// Current reference count of `block_id`.
    pub fn block_ref_cnt(&self, block_id: usize) -> u32 {
        self.blocks[block_id].ref_cnt
    }

    /// Cached hash of `block_id`, if it has one.
    pub fn block_hash(&self, block_id: usize) -> Option<BlockHashWithGroupId> {
        self.blocks[block_id].block_hash
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::paged_attention::block_hash::hash_block_tokens;

    // One block is always reserved as the null block, so a pool of N GPU
    // blocks exposes N - 1 allocatable blocks.
    #[test]
    fn test_basic_allocation() {
        let mut pool = BlockPool::new(4, false, 16);
        assert_eq!(pool.num_free_blocks(), 3);
        let blocks = pool.get_new_blocks(2).unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(pool.num_free_blocks(), 1);
        for &id in &blocks {
            assert_eq!(pool.block_ref_cnt(id), 1);
        }
    }

    // Freed blocks return to the free list and drop to ref_cnt == 0.
    #[test]
    fn test_free_returns_to_pool() {
        let mut pool = BlockPool::new(4, false, 16);
        let blocks = pool.get_new_blocks(3).unwrap();
        assert_eq!(pool.num_free_blocks(), 0);
        pool.free_blocks(&blocks);
        assert_eq!(pool.num_free_blocks(), 3);
        for &id in &blocks {
            assert_eq!(pool.block_ref_cnt(id), 0);
        }
    }

    // Allocation is all-or-nothing: an exhausted pool yields None.
    #[test]
    fn test_allocation_fails_when_exhausted() {
        let mut pool = BlockPool::new(2, false, 16);
        assert_eq!(pool.num_free_blocks(), 1);
        let _b = pool.get_new_blocks(1).unwrap();
        assert_eq!(pool.num_free_blocks(), 0);
        assert!(pool.get_new_blocks(1).is_none());
    }

    // Caching full blocks makes them retrievable by their content hash.
    #[test]
    fn test_prefix_cache_basic() {
        let mut pool = BlockPool::new(8, true, 4);
        let block_ids = pool.get_new_blocks(3).unwrap();
        // Chained hashes: each block's hash incorporates its parent's.
        let h0 = hash_block_tokens(None, &[1, 2, 3, 4], None);
        let h1 = hash_block_tokens(Some(h0), &[5, 6, 7, 8], None);
        let h2 = hash_block_tokens(Some(h1), &[9, 10, 11, 12], None);
        let hashes = vec![h0, h1, h2];
        pool.cache_full_blocks(&block_ids, &hashes, 0, 3, 0);
        assert_eq!(pool.num_cached_blocks(), 3);
        let cached = pool.get_cached_block(h0, &[0]);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap()[0], block_ids[0]);
    }

    // A freed-but-cached block can be found by hash and re-referenced via
    // touch(), which pulls it back out of the free list.
    #[test]
    fn test_prefix_cache_reuse_after_free() {
        let mut pool = BlockPool::new(8, true, 4);
        let block_ids = pool.get_new_blocks(2).unwrap();
        let h0 = hash_block_tokens(None, &[1, 2, 3, 4], None);
        let h1 = hash_block_tokens(Some(h0), &[5, 6, 7, 8], None);
        let hashes = vec![h0, h1];
        pool.cache_full_blocks(&block_ids, &hashes, 0, 2, 0);
        pool.free_blocks(&block_ids);
        assert_eq!(pool.num_free_blocks(), 7);
        let cached = pool.get_cached_block(h0, &[0]);
        assert!(cached.is_some());
        let cached_ids = cached.unwrap();
        pool.touch(&cached_ids);
        assert_eq!(pool.block_ref_cnt(cached_ids[0]), 1);
        assert_eq!(pool.num_free_blocks(), 6);
    }

    // Reallocating freed blocks evicts their stale cache entries instead of
    // failing.
    #[test]
    fn test_eviction_on_reallocation() {
        let mut pool = BlockPool::new(4, true, 4);
        let block_ids = pool.get_new_blocks(3).unwrap();
        let h0 = hash_block_tokens(None, &[1, 2, 3, 4], None);
        pool.cache_full_blocks(&block_ids, &[h0, h0, h0], 0, 1, 0);
        pool.free_blocks(&block_ids);
        let new_ids = pool.get_new_blocks(3).unwrap();
        assert_eq!(new_ids.len(), 3);
    }

    // touch() increments and free_blocks() decrements the ref count.
    #[test]
    fn test_touch_ref_cnt_management() {
        let mut pool = BlockPool::new(8, true, 4);
        let block_ids = pool.get_new_blocks(1).unwrap();
        assert_eq!(pool.block_ref_cnt(block_ids[0]), 1);
        pool.touch(&block_ids);
        assert_eq!(pool.block_ref_cnt(block_ids[0]), 2);
        pool.free_blocks(&block_ids);
        assert_eq!(pool.block_ref_cnt(block_ids[0]), 1);
        pool.free_blocks(&block_ids);
        assert_eq!(pool.block_ref_cnt(block_ids[0]), 0);
    }

    // Even when freed, the null block must never re-enter the free list.
    #[test]
    fn test_null_block_never_freed() {
        let mut pool = BlockPool::new(4, false, 16);
        let null_id = pool.null_block_id();
        assert!(pool.blocks[null_id].is_null);
        pool.blocks[null_id].ref_cnt = 1;
        pool.free_blocks(&[null_id]);
        assert_eq!(pool.block_ref_cnt(null_id), 0);
    }

    // usage() is 0.0 when idle and 1.0 when all allocatable blocks are used.
    #[test]
    fn test_usage() {
        let mut pool = BlockPool::new(4, false, 16);
        assert!(pool.usage() < 0.01);
        let _b = pool.get_new_blocks(3).unwrap();
        assert!((pool.usage() - 1.0).abs() < 0.01);
    }

    // A lookup across several groups succeeds only if every group hits.
    #[test]
    fn test_get_cached_block_multiple_groups() {
        let mut pool = BlockPool::new(8, true, 4);
        let ids_g0 = pool.get_new_blocks(1).unwrap();
        let ids_g1 = pool.get_new_blocks(1).unwrap();
        let h0 = hash_block_tokens(None, &[1, 2, 3, 4], None);
        pool.cache_full_blocks(&ids_g0, &[h0], 0, 1, 0);
        pool.cache_full_blocks(&ids_g1, &[h0], 0, 1, 1);
        let cached = pool.get_cached_block(h0, &[0, 1]);
        assert!(cached.is_some());
        let cached = cached.unwrap();
        assert_eq!(cached.len(), 2);
        assert_eq!(cached[0], ids_g0[0]);
        assert_eq!(cached[1], ids_g1[0]);
        // Group 2 was never cached, so the combined lookup must miss.
        let cached = pool.get_cached_block(h0, &[0, 2]);
        assert!(cached.is_none());
    }

    // reset_prefix_cache() refuses while blocks are in use and clears the
    // cache once only the null block remains referenced.
    #[test]
    fn test_reset_prefix_cache() {
        let mut pool = BlockPool::new(4, true, 4);
        let ids = pool.get_new_blocks(2).unwrap();
        let h0 = hash_block_tokens(None, &[1, 2, 3, 4], None);
        pool.cache_full_blocks(&ids, &[h0, h0], 0, 1, 0);
        assert!(!pool.reset_prefix_cache());
        pool.free_blocks(&ids);
        assert!(pool.reset_prefix_cache());
        assert_eq!(pool.num_cached_blocks(), 0);
    }
}