use crate::inference::memory::BlockId;
/// Read-only, shareable lookup of cached block ids by token-block hash.
pub trait PrefixLookup: Send + Sync {
/// Returns, for each hash in `token_block_hashes` and in the same order,
/// the cached block id if present, `None` otherwise.
fn lookup_blocks(&self, token_block_hashes: &[u64]) -> Vec<Option<BlockId>>;
}
/// Snapshot of cache counters and derived ratios (see `GpuRadixTree::stats`).
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuRadixStats {
/// Lookups (via `GpuRadixTree::lookup`) that found a cached block.
pub hits: usize,
/// Lookups that found nothing.
pub misses: usize,
/// Occupied slots in the table.
pub num_entries: usize,
/// Total slots (a power of two).
pub capacity: usize,
/// `num_entries / capacity`; 0.0 when capacity is zero.
pub occupancy: f64,
/// `hits / (hits + misses)`; 0.0 when no lookups have happened.
pub hit_rate: f64,
}
/// Sentinel stored in `values` marking an unoccupied slot.
const EMPTY_SLOT: i32 = -1;
/// FNV-1a 64-bit offset basis.
const FNV_OFFSET: u64 = 0xcbf29ce484222325;
/// FNV-1a 64-bit prime.
const FNV_PRIME: u64 = 0x100000001b3;
/// Rounds `n` up to the next power of two, with a floor of 16.
///
/// The result is used as the table capacity so that `capacity - 1` works
/// as a bit mask for slot indices. Delegates to the standard library's
/// `usize::next_power_of_two` instead of a hand-rolled shift loop.
fn next_power_of_two(n: usize) -> usize {
    n.max(16).next_power_of_two()
}
/// Packs per-slot metadata into one word: reference count in the high
/// 32 bits, LRU timestamp in the low 32 bits.
fn pack_meta(ref_count: u32, lru_timestamp: u32) -> u64 {
    let high = (ref_count as u64) << 32;
    let low = lru_timestamp as u64;
    high | low
}
/// Splits a packed metadata word into `(ref_count, lru_timestamp)`,
/// the inverse of `pack_meta`.
fn unpack_meta(meta: u64) -> (u32, u32) {
    let ref_count = (meta >> 32) as u32;
    let lru_timestamp = meta as u32;
    (ref_count, lru_timestamp)
}
/// Open-addressing hash table mapping token-prefix hashes to GPU cache
/// block ids, with per-entry reference counts and LRU timestamps.
///
/// Capacity is a power of two so slot indices can be masked; collisions
/// are resolved by linear probing, insertions stop at 75% load, and
/// eviction removes the least-recently-used unreferenced entries.
pub struct GpuRadixTree {
/// Token-prefix hash per slot (0 when the slot is empty).
keys: Vec<u64>,
/// Block id per slot; `EMPTY_SLOT` (-1) marks an unoccupied slot.
values: Vec<i32>,
/// Packed `(ref_count << 32) | lru_timestamp` per slot (see `pack_meta`).
metadata: Vec<u64>,
/// Count of occupied slots.
num_entries: usize,
/// Total slots; always a power of two (>= 16).
capacity: usize,
/// Tokens covered by each cache block.
block_size: usize,
/// Monotonic (wrapping) logical clock used for LRU timestamps.
clock: u32,
/// Cumulative lookup hits.
hits: usize,
/// Cumulative lookup misses.
misses: usize,
}
impl GpuRadixTree {
    /// Creates an empty table with at least `capacity` slots (rounded up
    /// to a power of two, minimum 16) for prefixes of `block_size` tokens.
    pub fn new(capacity: usize, block_size: usize) -> Self {
        let cap = next_power_of_two(capacity);
        Self {
            keys: vec![0u64; cap],
            values: vec![EMPTY_SLOT; cap],
            metadata: vec![0u64; cap],
            num_entries: 0,
            capacity: cap,
            block_size,
            clock: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Linear-probes for an occupied slot holding `key`.
    ///
    /// Stops at the first empty slot: eviction rehashes the whole table
    /// (see `evict_lru`), so probe chains never contain tombstones.
    fn probe(&self, key: u64) -> Option<usize> {
        if self.num_entries == 0 {
            return None;
        }
        let mask = self.capacity - 1;
        let mut slot = (key as usize) & mask;
        for _ in 0..self.capacity {
            let v = self.values[slot];
            if v == EMPTY_SLOT {
                return None;
            }
            if self.keys[slot] == key {
                return Some(slot);
            }
            slot = (slot + 1) & mask;
        }
        None
    }

    /// Finds where `key` lives or should be stored: the slot already
    /// holding `key`, or the first empty slot on its probe chain. `None`
    /// only in the degenerate completely-full case.
    fn probe_for_insert(&self, key: u64) -> Option<usize> {
        let mask = self.capacity - 1;
        let mut slot = (key as usize) & mask;
        for _ in 0..self.capacity {
            if self.values[slot] == EMPTY_SLOT || self.keys[slot] == key {
                return Some(slot);
            }
            slot = (slot + 1) & mask;
        }
        None
    }

    /// FNV-1a hash over the little-endian bytes of `tokens`.
    pub fn hash_tokens(tokens: &[u32]) -> u64 {
        let mut hash = FNV_OFFSET;
        for &token in tokens {
            for byte in token.to_le_bytes() {
                hash ^= byte as u64;
                hash = hash.wrapping_mul(FNV_PRIME);
            }
        }
        hash
    }

    /// Returns one hash per block of `block_size` tokens. Block `i` hashes
    /// the entire prefix `tokens[..(i + 1) * block_size]` (clamped to the
    /// token count), so each hash identifies a prefix, not an isolated
    /// block. Empty input or a zero block size yields no hashes.
    pub fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
        if block_size == 0 || tokens.is_empty() {
            return Vec::new();
        }
        let num_blocks = tokens.len().div_ceil(block_size);
        let mut hashes = Vec::with_capacity(num_blocks);
        for block_idx in 0..num_blocks {
            let end = ((block_idx + 1) * block_size).min(tokens.len());
            hashes.push(Self::hash_tokens(&tokens[..end]));
        }
        hashes
    }

    /// Looks up each hash, counting a hit or miss per entry and refreshing
    /// the LRU timestamp of every hit (all hits in one call share a single
    /// timestamp). Returns block ids in input order.
    pub fn lookup(&mut self, token_block_hashes: &[u64]) -> Vec<Option<BlockId>> {
        let ts = self.clock;
        self.clock = self.clock.wrapping_add(1);
        token_block_hashes
            .iter()
            .map(|&key| {
                if let Some(slot) = self.probe(key) {
                    let (ref_count, _) = unpack_meta(self.metadata[slot]);
                    self.metadata[slot] = pack_meta(ref_count, ts);
                    self.hits += 1;
                    Some(self.values[slot] as BlockId)
                } else {
                    self.misses += 1;
                    None
                }
            })
            .collect()
    }

    /// Maps `token_hash` to `block_id`; returns `false` when the table is
    /// at its 75% load-factor limit (or degenerately full).
    ///
    /// Overwriting an existing key keeps its reference count: resetting it
    /// to zero would let `evict_lru` reclaim a block that callers still
    /// hold references to. The LRU timestamp is refreshed either way.
    ///
    /// NOTE(review): assumes block ids fit in `i32` and are never -1
    /// (the `EMPTY_SLOT` sentinel) — confirm at call sites.
    pub fn insert(&mut self, token_hash: u64, block_id: BlockId) -> bool {
        // Cap at 75% occupancy to keep linear-probe chains short.
        if self.num_entries * 4 >= self.capacity * 3 {
            return false;
        }
        if let Some(slot) = self.probe_for_insert(token_hash) {
            let is_new = self.values[slot] == EMPTY_SLOT;
            // New entries start unreferenced; updates preserve the count.
            let ref_count = if is_new {
                0
            } else {
                unpack_meta(self.metadata[slot]).0
            };
            self.keys[slot] = token_hash;
            self.values[slot] = block_id as i32;
            let ts = self.clock;
            self.clock = self.clock.wrapping_add(1);
            self.metadata[slot] = pack_meta(ref_count, ts);
            if is_new {
                self.num_entries += 1;
            }
            true
        } else {
            false
        }
    }

    /// Increments the reference count for `token_hash` (saturating).
    /// A missing key is a no-op.
    pub fn inc_ref(&mut self, token_hash: u64) {
        if let Some(slot) = self.probe(token_hash) {
            let (ref_count, ts) = unpack_meta(self.metadata[slot]);
            self.metadata[slot] = pack_meta(ref_count.saturating_add(1), ts);
        }
    }

    /// Decrements the reference count for `token_hash`, saturating at
    /// zero. A missing key is a no-op.
    pub fn dec_ref(&mut self, token_hash: u64) {
        if let Some(slot) = self.probe(token_hash) {
            let (ref_count, ts) = unpack_meta(self.metadata[slot]);
            self.metadata[slot] = pack_meta(ref_count.saturating_sub(1), ts);
        }
    }

    /// Evicts up to `num_to_evict` entries whose reference count is zero,
    /// oldest LRU timestamp first, returning their block ids.
    ///
    /// After clearing slots the surviving entries are rehashed so the
    /// holes cannot break other keys' probe chains (probing stops at the
    /// first empty slot). Rehashing is skipped when nothing was evicted.
    pub fn evict_lru(&mut self, num_to_evict: usize) -> Vec<BlockId> {
        if num_to_evict == 0 || self.num_entries == 0 {
            return Vec::new();
        }
        // Gather (timestamp, slot) for every unreferenced live entry.
        let mut candidates: Vec<(u32, usize)> = self
            .values
            .iter()
            .enumerate()
            .filter_map(|(slot, &v)| {
                if v == EMPTY_SLOT {
                    return None;
                }
                let (ref_count, ts) = unpack_meta(self.metadata[slot]);
                if ref_count == 0 {
                    Some((ts, slot))
                } else {
                    None
                }
            })
            .collect();
        candidates.sort_unstable_by_key(|&(ts, _)| ts);
        let evict_count = num_to_evict.min(candidates.len());
        if evict_count == 0 {
            // Everything is referenced: nothing to evict, no need to rehash.
            return Vec::new();
        }
        let mut evicted = Vec::with_capacity(evict_count);
        for (_, slot) in candidates.into_iter().take(evict_count) {
            evicted.push(self.values[slot] as BlockId);
            self.values[slot] = EMPTY_SLOT;
            self.keys[slot] = 0;
            self.metadata[slot] = 0;
            self.num_entries -= 1;
        }
        self.rehash();
        evicted
    }

    /// Rebuilds probe chains in place: snapshots all live entries, clears
    /// every slot, then reinserts each entry at its home slot or the next
    /// free one. Capacity is unchanged.
    fn rehash(&mut self) {
        let live: Vec<(u64, i32, u64)> = (0..self.capacity)
            .filter_map(|slot| {
                if self.values[slot] != EMPTY_SLOT {
                    Some((self.keys[slot], self.values[slot], self.metadata[slot]))
                } else {
                    None
                }
            })
            .collect();
        self.keys.iter_mut().for_each(|k| *k = 0);
        self.values.iter_mut().for_each(|v| *v = EMPTY_SLOT);
        self.metadata.iter_mut().for_each(|m| *m = 0);
        self.num_entries = 0;
        let mask = self.capacity - 1;
        for (key, value, meta) in live {
            let mut slot = (key as usize) & mask;
            // Live count < capacity (75% insert cap), so a free slot exists.
            loop {
                if self.values[slot] == EMPTY_SLOT {
                    self.keys[slot] = key;
                    self.values[slot] = value;
                    self.metadata[slot] = meta;
                    self.num_entries += 1;
                    break;
                }
                slot = (slot + 1) & mask;
            }
        }
    }

    /// Returns a snapshot of counters, occupancy, and hit rate. Ratios are
    /// 0.0 when their denominators are zero.
    pub fn stats(&self) -> GpuRadixStats {
        let occupancy = if self.capacity == 0 {
            0.0
        } else {
            self.num_entries as f64 / self.capacity as f64
        };
        let total_lookups = self.hits + self.misses;
        let hit_rate = if total_lookups == 0 {
            0.0
        } else {
            self.hits as f64 / total_lookups as f64
        };
        GpuRadixStats {
            hits: self.hits,
            misses: self.misses,
            num_entries: self.num_entries,
            capacity: self.capacity,
            occupancy,
            hit_rate,
        }
    }

    /// Number of tokens covered by each cache block.
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Number of occupied slots.
    pub fn num_entries(&self) -> usize {
        self.num_entries
    }

    /// Raw key array, exposed only under the `cuda` feature.
    /// NOTE(review): presumably uploaded to the device by the CUDA path —
    /// confirm against the consumer.
    #[cfg(feature = "cuda")]
    pub fn keys(&self) -> &[u64] {
        &self.keys
    }

    /// Raw value array (block ids / `EMPTY_SLOT`), exposed only under the
    /// `cuda` feature.
    #[cfg(feature = "cuda")]
    pub fn values(&self) -> &[i32] {
        &self.values
    }
}
impl PrefixLookup for GpuRadixTree {
    /// Read-only batch lookup. Unlike `GpuRadixTree::lookup`, this does
    /// not touch hit/miss counters or LRU timestamps.
    fn lookup_blocks(&self, token_block_hashes: &[u64]) -> Vec<Option<BlockId>> {
        let mut results = Vec::with_capacity(token_block_hashes.len());
        for &key in token_block_hashes {
            let found = self.probe(key).map(|slot| self.values[slot] as BlockId);
            results.push(found);
        }
        results
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_tokens_deterministic() {
        // Identical token sequences hash identically; one differing token
        // changes the hash.
        let a = GpuRadixTree::hash_tokens(&[1, 2, 3, 4]);
        let b = GpuRadixTree::hash_tokens(&[1, 2, 3, 4]);
        assert_eq!(a, b);
        let c = GpuRadixTree::hash_tokens(&[1, 2, 3, 5]);
        assert_ne!(a, c);
    }

    #[test]
    fn test_compute_block_hashes_length() {
        // 10 tokens in blocks of 4 -> ceil(10 / 4) = 3 prefix hashes.
        assert_eq!(GpuRadixTree::compute_block_hashes(&[0u32; 10], 4).len(), 3);
    }

    #[test]
    fn test_insert_and_lookup() {
        let mut cache = GpuRadixTree::new(64, 16);
        let hash = GpuRadixTree::hash_tokens(&[1, 2, 3, 4]);
        assert!(cache.insert(hash, 42));
        assert_eq!(cache.lookup(&[hash]), vec![Some(42)]);
    }

    #[test]
    fn test_miss_returns_none() {
        let mut cache = GpuRadixTree::new(64, 16);
        let absent = GpuRadixTree::hash_tokens(&[9, 9, 9]);
        assert_eq!(cache.lookup(&[absent]), vec![None]);
    }

    #[test]
    fn test_evict_lru_removes_unreferenced() {
        let mut cache = GpuRadixTree::new(64, 16);
        cache.insert(GpuRadixTree::hash_tokens(&[1]), 10);
        cache.insert(GpuRadixTree::hash_tokens(&[2]), 11);
        assert_eq!(cache.num_entries(), 2);
        let gone = cache.evict_lru(1);
        assert_eq!(gone.len(), 1);
        assert_eq!(cache.num_entries(), 1);
    }

    #[test]
    fn test_evict_lru_skips_referenced() {
        let mut cache = GpuRadixTree::new(64, 16);
        let hash = GpuRadixTree::hash_tokens(&[1]);
        cache.insert(hash, 10);
        cache.inc_ref(hash);
        // A live reference pins the entry no matter how many evictions
        // are requested.
        assert!(cache.evict_lru(10).is_empty());
        assert_eq!(cache.num_entries(), 1);
    }

    #[test]
    fn test_prefix_lookup_trait() {
        let mut cache = GpuRadixTree::new(64, 16);
        let hash = GpuRadixTree::hash_tokens(&[5, 6, 7]);
        cache.insert(hash, 99);
        let dyn_lookup: &dyn PrefixLookup = &cache;
        let found = dyn_lookup.lookup_blocks(&[hash, 0xdeadbeefdeadbeef]);
        assert_eq!(found[0], Some(99));
        assert_eq!(found[1], None);
    }

    #[test]
    fn test_stats_hit_rate() {
        let mut cache = GpuRadixTree::new(64, 16);
        let hash = GpuRadixTree::hash_tokens(&[1, 2]);
        cache.insert(hash, 7);
        cache.lookup(&[hash]); // hit
        cache.lookup(&[0xdeadbeef]); // miss
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_load_factor_limit() {
        // Capacity 16 with a 75% load cap admits at most 12 entries.
        let mut cache = GpuRadixTree::new(16, 1);
        let inserted = (0u32..20)
            .filter(|&i| cache.insert(GpuRadixTree::hash_tokens(&[i]), i as BlockId))
            .count();
        assert!(inserted <= 12);
    }
}