use candle_core::{Device, Result, Tensor};
use std::collections::{HashMap, VecDeque};
/// Allocator that maps sequences onto fixed-size cache blocks and translates
/// token positions into global cache slots (block tables + slot mappings
/// suggest this backs a paged KV cache — confirm against the attention kernel).
pub struct BlockManager {
// Total number of physical cache blocks managed.
num_blocks: usize,
// Number of token slots per block (used as a divisor; must be non-zero).
block_size: usize,
// Pool of currently unassigned block indices; allocation pops the front.
free_blocks: VecDeque<usize>,
// Per-sequence ordered list of the physical block indices it owns.
seq_to_blocks: HashMap<usize, Vec<usize>>,
// Per-sequence count of tokens already placed into its blocks.
seq_to_num_tokens: HashMap<usize, usize>,
// Monotonic counter handing out fresh sequence ids.
next_seq_id: usize,
}
impl BlockManager {
pub fn new(num_blocks: usize, block_size: usize) -> Self {
let free_blocks: VecDeque<usize> = (0..num_blocks).collect();
Self {
num_blocks,
block_size,
free_blocks,
seq_to_blocks: HashMap::new(),
seq_to_num_tokens: HashMap::new(),
next_seq_id: 0,
}
}
pub fn allocate_sequence(&mut self) -> usize {
let seq_id = self.next_seq_id;
self.next_seq_id += 1;
self.seq_to_blocks.insert(seq_id, Vec::new());
self.seq_to_num_tokens.insert(seq_id, 0);
seq_id
}
pub fn allocate_slots(&mut self, seq_id: usize, num_tokens: usize) -> Result<Vec<i64>> {
let blocks = self
.seq_to_blocks
.get_mut(&seq_id)
.ok_or_else(|| candle_core::Error::Msg(format!("Sequence {} not found", seq_id)))?;
let current_tokens = *self.seq_to_num_tokens.get(&seq_id).unwrap_or(&0);
let total_tokens_needed = current_tokens + num_tokens;
let blocks_needed = total_tokens_needed.div_ceil(self.block_size);
let new_blocks_needed = blocks_needed.saturating_sub(blocks.len());
for _ in 0..new_blocks_needed {
let block_idx = self
.free_blocks
.pop_front()
.ok_or_else(|| candle_core::Error::Msg("Out of cache blocks".to_string()))?;
blocks.push(block_idx);
}
let mut slots = Vec::with_capacity(num_tokens);
for i in 0..num_tokens {
let token_pos = current_tokens + i;
let block_idx_in_seq = token_pos / self.block_size;
let offset_in_block = token_pos % self.block_size;
let global_block_idx = blocks[block_idx_in_seq];
let slot = (global_block_idx * self.block_size + offset_in_block) as i64;
slots.push(slot);
}
self.seq_to_num_tokens.insert(seq_id, total_tokens_needed);
Ok(slots)
}
pub fn get_block_table(&self, seq_id: usize) -> Option<Vec<u32>> {
self.seq_to_blocks
.get(&seq_id)
.map(|blocks| blocks.iter().map(|&b| b as u32).collect())
}
pub fn get_block_table_tensor(
&self,
seq_id: usize,
max_blocks: usize,
device: &Device,
) -> Result<Tensor> {
let blocks = self
.seq_to_blocks
.get(&seq_id)
.ok_or_else(|| candle_core::Error::Msg(format!("Sequence {} not found", seq_id)))?;
let mut padded: Vec<u32> = blocks.iter().map(|&b| b as u32).collect();
padded.resize(max_blocks, 0);
Tensor::from_vec(padded, (1, max_blocks), device)
}
pub fn get_slot_mapping_tensor(&self, slots: &[i64], device: &Device) -> Result<Tensor> {
Tensor::from_vec(slots.to_vec(), (slots.len(),), device)
}
pub fn get_context_len(&self, seq_id: usize) -> usize {
*self.seq_to_num_tokens.get(&seq_id).unwrap_or(&0)
}
pub fn free_sequence(&mut self, seq_id: usize) {
if let Some(blocks) = self.seq_to_blocks.remove(&seq_id) {
for block_idx in blocks {
self.free_blocks.push_back(block_idx);
}
}
self.seq_to_num_tokens.remove(&seq_id);
}
pub fn num_free_blocks(&self) -> usize {
self.free_blocks.len()
}
pub fn num_used_blocks(&self) -> usize {
self.num_blocks - self.free_blocks.len()
}
pub fn reset(&mut self) {
self.free_blocks = (0..self.num_blocks).collect();
self.seq_to_blocks.clear();
self.seq_to_num_tokens.clear();
self.next_seq_id = 0;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_allocation() {
        let mut manager = BlockManager::new(10, 16);
        let seq_id = manager.allocate_sequence();
        assert_eq!(seq_id, 0);

        // First 5 tokens fit in a single block and occupy slots 0..5.
        let slots = manager.allocate_slots(seq_id, 5).unwrap();
        assert_eq!(slots, vec![0, 1, 2, 3, 4]);
        assert_eq!(manager.num_used_blocks(), 1);

        // 20 more tokens (25 total) need ceil(25/16) = 2 blocks. Blocks are
        // handed out in order (0 then 1), so slots are contiguous: 5..25.
        let slots = manager.allocate_slots(seq_id, 20).unwrap();
        assert_eq!(slots, (5..25).collect::<Vec<i64>>());
        assert_eq!(manager.num_used_blocks(), 2);
        assert_eq!(manager.get_context_len(seq_id), 25);

        // Freeing the sequence returns every block to the pool.
        manager.free_sequence(seq_id);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.num_used_blocks(), 0);
    }
}