use crate::scheduler::block_manager::BlockId;
use std::collections::HashMap;
use std::time::Instant;
#[derive(Debug)]
struct RadixNode {
tokens: Vec<u32>,
block_ids: Vec<BlockId>,
children: HashMap<u32, Box<RadixNode>>,
last_accessed: Instant,
}
impl RadixNode {
fn new(tokens: Vec<u32>, block_ids: Vec<BlockId>) -> Self {
Self {
tokens,
block_ids,
children: HashMap::new(),
last_accessed: Instant::now(),
}
}
}
pub struct RadixCache {
root: RadixNode,
pub num_cached_blocks: usize,
pub max_capacity: usize,
}
impl RadixCache {
pub fn new(max_capacity: usize) -> Self {
Self {
root: RadixNode::new(vec![], vec![]),
num_cached_blocks: 0,
max_capacity,
}
}
pub fn match_prefix(&mut self, tokens: &[u32]) -> (Vec<BlockId>, usize) {
let mut current_node = &mut self.root;
let mut matched_blocks = Vec::new();
let mut total_matched_tokens = 0;
let mut token_idx = 0;
while token_idx < tokens.len() {
let first_token = tokens[token_idx];
if let Some(child) = current_node.children.get_mut(&first_token) {
let match_len = child
.tokens
.iter()
.zip(&tokens[token_idx..])
.take_while(|(a, b)| a == b)
.count();
if match_len == child.tokens.len() {
matched_blocks.extend_from_slice(&child.block_ids);
total_matched_tokens += match_len;
token_idx += match_len;
child.last_accessed = Instant::now();
current_node = child;
} else {
break;
}
} else {
break;
}
}
(matched_blocks, total_matched_tokens)
}
pub fn insert(&mut self, tokens: &[u32], block_ids: &[BlockId]) {
let mut current_node = &mut self.root;
let mut token_idx = 0;
while token_idx < tokens.len() {
let first_token = tokens[token_idx];
if let std::collections::hash_map::Entry::Vacant(e) =
current_node.children.entry(first_token)
{
let new_node = RadixNode::new(tokens[token_idx..].to_vec(), block_ids.to_vec());
e.insert(Box::new(new_node));
self.num_cached_blocks += block_ids.len();
break;
}
let child = current_node.children.get_mut(&first_token).unwrap();
token_idx += child.tokens.len();
current_node = child;
}
if self.num_cached_blocks > self.max_capacity {
self.evict_lru();
}
}
pub fn evict_lru(&mut self) -> Vec<BlockId> {
let mut blocks_to_free = Vec::new();
loop {
let blocks = self.remove_oldest_leaf();
match blocks {
Some(evicted_blocks) => {
self.num_cached_blocks -= evicted_blocks.len();
blocks_to_free.extend(evicted_blocks);
if self.num_cached_blocks <= self.max_capacity {
break;
}
}
None => break,
}
}
blocks_to_free
}
fn remove_oldest_leaf(&mut self) -> Option<Vec<BlockId>> {
let node = &mut self.root;
if node.children.is_empty() {
return None;
}
let mut oldest_token = None;
let mut oldest_time = Instant::now();
for (token, child) in &node.children {
if child.children.is_empty() && child.last_accessed < oldest_time {
oldest_time = child.last_accessed;
oldest_token = Some(*token);
}
}
if let Some(token) = oldest_token {
let child = node.children.remove(&token).unwrap();
Some(child.block_ids)
} else {
None
}
}
}