use slotmap::{SlotMap, new_key_type};
use std::collections::{HashMap, HashSet};
use std::time::Instant;
new_key_type! {
pub struct NodeId;
}
pub struct TokenPool {
free: Vec<usize>,
total: usize,
}
impl TokenPool {
pub fn new(total: usize) -> Self {
let free: Vec<usize> = (0..total).rev().collect();
Self { free, total }
}
pub fn allocate(&mut self, n: usize) -> Option<Vec<usize>> {
if self.free.len() < n {
return None;
}
let start = self.free.len() - n;
let indices: Vec<usize> = self.free.drain(start..).collect();
Some(indices)
}
pub fn free(&mut self, indices: &[usize]) {
self.free.extend(indices);
}
pub fn available(&self) -> usize {
self.free.len()
}
pub fn total(&self) -> usize {
self.total
}
}
pub struct TreeNode {
pub children: HashMap<Vec<u64>, NodeId>,
pub parent: Option<NodeId>,
pub key: Vec<u64>,
pub value: Vec<usize>,
pub lock_ref: usize,
pub last_access_time: Instant,
}
pub struct RadixCache {
nodes: SlotMap<NodeId, TreeNode>,
root: NodeId,
pub token_pool: TokenPool,
page_size: usize,
pub evictable_leaves: HashSet<NodeId>,
pub evictable_size: usize,
pub protected_size: usize,
}
impl RadixCache {
pub fn new(total_tokens: usize, page_size: usize) -> Self {
assert!(page_size >= 1, "page_size must be >= 1");
let mut nodes = SlotMap::with_key();
let root = nodes.insert(TreeNode {
children: HashMap::new(),
parent: None,
key: Vec::new(),
value: Vec::new(),
lock_ref: 0,
last_access_time: Instant::now(),
});
Self {
nodes,
root,
token_pool: TokenPool::new(total_tokens),
page_size,
evictable_leaves: HashSet::new(),
evictable_size: 0,
protected_size: 0,
}
}
pub fn root(&self) -> NodeId {
self.root
}
pub fn node(&self, id: NodeId) -> &TreeNode {
&self.nodes[id]
}
pub fn page_size(&self) -> usize {
self.page_size
}
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
fn child_key(&self, key: &[u64]) -> Vec<u64> {
key[..self.page_size.min(key.len())].to_vec()
}
fn page_align(&self, len: usize) -> usize {
len / self.page_size * self.page_size
}
fn key_match(&self, key0: &[u64], key1: &[u64]) -> usize {
if self.page_size == 1 {
key0.iter().zip(key1).take_while(|(a, b)| a == b).count()
} else {
let min_len = key0.len().min(key1.len());
let mut i = 0;
while i + self.page_size <= min_len {
if key0[i..i + self.page_size] != key1[i..i + self.page_size] {
break;
}
i += self.page_size;
}
i
}
}
pub fn match_prefix(&mut self, key: &[u64]) -> (usize, NodeId) {
let now = Instant::now();
self.nodes[self.root].last_access_time = now;
let mut current = self.root;
let mut matched: usize = 0;
while matched < key.len() {
let ck = self.child_key(&key[matched..]);
let child_id = match self.nodes[current].children.get(&ck).copied() {
Some(id) => id,
None => break,
};
let child_key = self.nodes[child_id].key.clone();
let common_len = self.key_match(&child_key, &key[matched..]);
if common_len < child_key.len() {
if common_len > 0 {
let intermediate = self.split_node(child_id, common_len);
current = intermediate;
}
matched += common_len;
break;
}
matched += common_len;
current = child_id;
self.nodes[current].last_access_time = now;
}
(matched, current)
}
pub fn prefix_match_len(&self, key: &[u64]) -> usize {
let mut current = self.root;
let mut matched: usize = 0;
while matched < key.len() {
let ck = self.child_key(&key[matched..]);
let child_id = match self.nodes[current].children.get(&ck).copied() {
Some(id) => id,
None => break,
};
let child_key = &self.nodes[child_id].key;
let common_len = self.key_match(child_key, &key[matched..]);
if common_len < child_key.len() {
matched += common_len;
break;
}
matched += common_len;
current = child_id;
}
matched / self.page_size * self.page_size
}
pub fn insert(&mut self, key: &[u64], value: &[usize]) {
let aligned_len = self.page_align(key.len());
if aligned_len == 0 {
return;
}
assert!(
value.len() >= aligned_len,
"not enough token indices: need {aligned_len}, got {}",
value.len()
);
let key = &key[..aligned_len];
let value = &value[..aligned_len];
let now = Instant::now();
self.nodes[self.root].last_access_time = now;
let mut current = self.root;
let mut key_offset: usize = 0;
while key_offset < key.len() {
let ck = self.child_key(&key[key_offset..]);
let child_id = match self.nodes[current].children.get(&ck).copied() {
Some(id) => id,
None => {
self.create_child(current, &key[key_offset..], &value[key_offset..]);
return;
}
};
let child_key = self.nodes[child_id].key.clone();
let common_len = self.key_match(&child_key, &key[key_offset..]);
if common_len == child_key.len() {
key_offset += common_len;
current = child_id;
self.nodes[current].last_access_time = now;
} else {
if common_len > 0 {
let intermediate = self.split_node(child_id, common_len);
key_offset += common_len;
if key_offset < key.len() {
self.create_child(intermediate, &key[key_offset..], &value[key_offset..]);
}
}
return;
}
}
}
fn split_node(&mut self, child_id: NodeId, split_pos: usize) -> NodeId {
let child = &self.nodes[child_id];
let child_parent = child.parent;
let child_key = child.key.clone();
let child_value = child.value.clone();
let child_lock_ref = child.lock_ref;
let child_last_access = child.last_access_time;
let suffix_ck = self.child_key(&child_key[split_pos..]);
let mut inter_children = HashMap::new();
inter_children.insert(suffix_ck, child_id);
let intermediate = TreeNode {
children: inter_children,
parent: child_parent,
key: child_key[..split_pos].to_vec(),
value: child_value[..split_pos].to_vec(),
lock_ref: child_lock_ref,
last_access_time: child_last_access,
};
let inter_id = self.nodes.insert(intermediate);
let child = &mut self.nodes[child_id];
child.key = child_key[split_pos..].to_vec();
child.value = child_value[split_pos..].to_vec();
child.parent = Some(inter_id);
let original_ck = self.child_key(&child_key);
if let Some(parent_id) = child_parent {
self.nodes[parent_id].children.insert(original_ck, inter_id);
}
if self.evictable_leaves.contains(&child_id) {
let old_tokens = child_key.len();
let new_tokens = child_key.len() - split_pos;
self.evictable_size = self.evictable_size - old_tokens + new_tokens;
}
inter_id
}
fn create_child(&mut self, parent_id: NodeId, key: &[u64], value: &[usize]) {
let new_node = TreeNode {
children: HashMap::new(),
parent: Some(parent_id),
key: key.to_vec(),
value: value.to_vec(),
lock_ref: 0,
last_access_time: Instant::now(),
};
let ck = self.child_key(key);
let new_id = self.nodes.insert(new_node);
if self.evictable_leaves.remove(&parent_id) {
let parent_tokens = self.nodes[parent_id].key.len();
self.evictable_size -= parent_tokens;
}
self.nodes[parent_id].children.insert(ck, new_id);
self.evictable_leaves.insert(new_id);
self.evictable_size += key.len();
}
pub fn is_leaf(&self, id: NodeId) -> bool {
self.nodes[id].children.is_empty()
}
pub fn inc_lock_ref(&mut self, node_id: NodeId) {
let mut current = Some(node_id);
while let Some(id) = current {
if id == self.root {
break;
}
let node = &mut self.nodes[id];
let tokens = node.key.len();
node.lock_ref += 1;
if node.lock_ref == 1 {
if self.evictable_leaves.remove(&id) {
self.evictable_size -= tokens;
}
self.protected_size += tokens;
}
current = self.nodes[id].parent;
}
}
pub fn dec_lock_ref(&mut self, node_id: NodeId) {
let mut current = Some(node_id);
while let Some(id) = current {
if id == self.root {
break;
}
let node = &mut self.nodes[id];
if node.lock_ref == 0 {
tracing::warn!("dec_lock_ref on node with lock_ref == 0, skipping");
break;
}
node.lock_ref -= 1;
if node.lock_ref == 0 {
let tokens = node.key.len();
self.protected_size -= tokens;
if self.is_leaf(id) {
self.evictable_leaves.insert(id);
self.evictable_size += tokens;
}
}
current = self.nodes[id].parent;
}
}
pub fn evict(&mut self, num_tokens: usize) -> (usize, Vec<usize>) {
let mut evicted = 0;
let mut evicted_indices = Vec::new();
while evicted < num_tokens {
let victim = self
.evictable_leaves
.iter()
.min_by_key(|&&id| self.nodes[id].last_access_time)
.copied();
let Some(victim_id) = victim else {
break;
};
let victim_node = &self.nodes[victim_id];
let tokens = victim_node.key.len();
let pool_indices = victim_node.value.clone();
let parent_id = victim_node.parent;
let victim_key = victim_node.key.clone();
self.evictable_leaves.remove(&victim_id);
self.evictable_size -= tokens;
evicted += tokens;
evicted_indices.extend_from_slice(&pool_indices);
self.token_pool.free(&pool_indices);
if let Some(pid) = parent_id {
let ck = self.child_key(&victim_key);
self.nodes[pid].children.remove(&ck);
if pid != self.root
&& self.nodes[pid].children.is_empty()
&& self.nodes[pid].lock_ref == 0
{
let parent_tokens = self.nodes[pid].key.len();
self.evictable_leaves.insert(pid);
self.evictable_size += parent_tokens;
}
}
self.nodes.remove(victim_id);
}
(evicted, evicted_indices)
}
pub fn available_tokens(&self) -> usize {
self.token_pool.available()
}
pub fn total_tokens(&self) -> usize {
self.token_pool.total()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_pool_allocate_and_free() {
let mut pool = TokenPool::new(10);
assert_eq!(pool.available(), 10);
let a = pool.allocate(3).unwrap();
assert_eq!(a.len(), 3);
assert_eq!(pool.available(), 7);
let b = pool.allocate(7).unwrap();
assert_eq!(pool.available(), 0);
assert!(pool.allocate(1).is_none());
pool.free(&a);
assert_eq!(pool.available(), 3);
pool.free(&b);
assert_eq!(pool.available(), 10);
}
#[test]
fn test_match_prefix() {
let mut cache = RadixCache::new(100, 1);
let (len, node) = cache.match_prefix(&[1, 2, 3]);
assert_eq!(len, 0);
assert_eq!(node, cache.root());
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5]).0, 5);
cache.insert(&[1, 2, 3, 4, 5, 6, 7], &[10, 20, 30, 40, 50, 60, 70]);
let (len, node) = cache.match_prefix(&[1, 2, 3, 4, 5, 9, 9]);
assert_eq!(len, 5);
let n = cache.node(node);
assert_eq!(n.key, vec![1, 2, 3, 4, 5]);
assert_eq!(n.value, vec![10, 20, 30, 40, 50]);
let &suffix_id = n.children.get(&vec![6]).unwrap();
assert_eq!(cache.node(suffix_id).value, vec![60, 70]);
}
#[test]
fn test_insert() {
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
cache.insert(&[1, 2, 3, 6, 7], &[10, 20, 30, 60, 70]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5]).0, 5);
assert_eq!(cache.match_prefix(&[1, 2, 3, 6, 7]).0, 5);
assert_eq!(cache.match_prefix(&[1, 2, 3, 9]).0, 3);
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3], &[10, 20, 30]);
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5]).0, 5);
cache.insert(&[1, 2, 3], &[10, 20, 30]);
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5, 6, 7, 8]).0, 5);
cache.insert(&[1, 2, 3, 4, 5, 6, 7, 8], &[10, 20, 30, 40, 50, 60, 70, 80]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5, 6, 7, 8]).0, 8);
}
#[test]
fn test_page_size() {
let mut cache = RadixCache::new(100, 4);
assert_eq!(cache.token_pool.total(), 100);
cache.insert(&[1, 2, 3, 4, 5, 6, 7], &[0, 1, 2, 3, 4, 5, 6]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4]).0, 4);
let (_, node) = cache.match_prefix(&[1, 2, 3, 4]);
assert_eq!(cache.node(node).value, vec![0, 1, 2, 3]);
cache.insert(&[1, 2, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3, 10, 11, 12, 13]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5, 6, 7, 8]).0, 8);
let mut cache = RadixCache::new(100, 4);
cache.insert(&[1, 2, 3, 4], &[0, 1, 2, 3]);
cache.insert(&[1, 2, 3, 5], &[4, 5, 6, 7]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4]).0, 4);
assert_eq!(cache.match_prefix(&[1, 2, 3, 5]).0, 4);
assert_eq!(cache.match_prefix(&[1, 2, 3, 6]).0, 0);
let mut cache = RadixCache::new(100, 4);
cache.insert(&[1, 2, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3, 10, 11, 12, 13]);
cache.match_prefix(&[1, 2, 3, 4, 9, 9, 9, 9]);
let (_, node) = cache.match_prefix(&[1, 2, 3, 4]);
assert_eq!(cache.node(node).value, vec![0, 1, 2, 3]);
}
#[test]
fn test_lock_unlock_shared_prefix() {
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3, 4, 5], &[0, 1, 2, 3, 4]);
cache.insert(&[1, 2, 3, 6, 7], &[0, 1, 2, 5, 6]);
let (_, node_a) = cache.match_prefix(&[1, 2, 3, 4, 5]);
let (_, node_b) = cache.match_prefix(&[1, 2, 3, 6, 7]);
cache.inc_lock_ref(node_a);
cache.inc_lock_ref(node_b);
assert_eq!(cache.protected_size, 7);
cache.dec_lock_ref(node_a);
assert!(cache.evictable_leaves.contains(&node_a));
cache.dec_lock_ref(node_b);
assert_eq!(cache.protected_size, 0);
}
#[test]
fn test_evict() {
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3], &[0, 1, 2]);
let (_, n1) = cache.match_prefix(&[1, 2, 3]);
cache.inc_lock_ref(n1);
cache.dec_lock_ref(n1);
std::thread::sleep(std::time::Duration::from_millis(1));
cache.insert(&[4, 5, 6], &[3, 4, 5]);
let (_, n2) = cache.match_prefix(&[4, 5, 6]);
cache.inc_lock_ref(n2);
cache.dec_lock_ref(n2);
let (evicted_count, evicted_indices) = cache.evict(3);
assert_eq!(evicted_count, 3);
let mut sorted_evicted = evicted_indices.clone();
sorted_evicted.sort();
let mut expected_indices = vec![0, 1, 2];
expected_indices.sort();
assert_eq!(
sorted_evicted, expected_indices,
"evicted indices should match inserted indices"
);
assert_eq!(cache.match_prefix(&[1, 2, 3]).0, 0); assert_eq!(cache.match_prefix(&[4, 5, 6]).0, 3);
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3], &[0, 1, 2]);
cache.insert(&[4, 5, 6], &[3, 4, 5]);
let (_, locked) = cache.match_prefix(&[1, 2, 3]);
cache.inc_lock_ref(locked);
let (_, unlocked) = cache.match_prefix(&[4, 5, 6]);
cache.inc_lock_ref(unlocked);
cache.dec_lock_ref(unlocked);
let (evicted_count, evicted_indices) = cache.evict(6);
assert_eq!(evicted_count, 3); let mut sorted_evicted = evicted_indices;
sorted_evicted.sort();
assert_eq!(
sorted_evicted,
vec![3, 4, 5],
"should evict unlocked [4,5,6] indices"
);
assert_eq!(cache.match_prefix(&[1, 2, 3]).0, 3);
}
#[test]
fn test_query_methods() {
let cache = RadixCache::new(100, 1);
assert_eq!(cache.available_tokens(), 100);
assert_eq!(cache.total_tokens(), 100);
let cache4 = RadixCache::new(100, 4);
assert_eq!(cache4.available_tokens(), 100);
assert_eq!(cache4.total_tokens(), 100);
}
}