use std::{
mem::size_of,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use blake3;
use dashmap::DashMap;
use crate::traits::TokenIdType;
/// A 32-byte BLAKE3 digest of an input prefix (`input[..boundary]`); used as the cache key.
type Blake3Hash = [u8; 32];
/// Number of independent `DashMap`s entries are spread across. An entry's shard is
/// `hash[0] % NUM_SHARDS`.
const NUM_SHARDS: usize = 16;
/// Returns the sorted, de-duplicated byte offsets just past each occurrence of any
/// special token in `text`.
///
/// Occurrences of a single token are found left-to-right without overlap. An offset
/// equal to `text.len()` (a token that ends the input) is omitted, so every returned
/// offset splits `text` into a non-empty prefix and a non-empty suffix. Every offset is
/// the end of a matched `&str`, so it always lies on a UTF-8 character boundary and is
/// safe to slice at.
///
/// Empty tokens are ignored. They match at every position and would otherwise never
/// advance the search.
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let mut boundaries: Vec<usize> = special_tokens
        .iter()
        .copied()
        .filter(|token| !token.is_empty())
        .flat_map(move |token| {
            text.match_indices(token)
                .map(|(pos, matched)| pos + matched.len())
        })
        .filter(|&end| end < text.len())
        .collect();
    boundaries.sort_unstable();
    boundaries.dedup();
    boundaries
}
/// Cached token ids for one input prefix (the text up to a special-token boundary).
#[derive(Debug, Clone)]
struct CachedPrefix {
    /// Token ids for the whole prefix; `Arc` makes handing them out cheap to share.
    tokens: Arc<[TokenIdType]>,
    /// Logical timestamp (taken from `L1Cache::access_counter`) of the last insert or
    /// lookup hit. Eviction prefers entries with the smallest value.
    last_accessed: Arc<AtomicU64>,
    /// Bytes charged against `L1Cache::max_memory` for this entry: the prefix's byte
    /// length plus the size of its token ids.
    size_bytes: usize,
}
/// Concurrent cache mapping hashed input prefixes (each ending at a special-token
/// boundary) to their token ids, with memory-budgeted, sampled approximate-LRU eviction.
pub struct L1Cache {
    /// `NUM_SHARDS` maps; an entry lives in shard `hash[0] % NUM_SHARDS`.
    shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
    /// Budget in accounted bytes, checked before each batch insert. It is a soft limit:
    /// a batch is still inserted after eviction.
    max_memory: usize,
    /// Sum of `size_bytes` over all live entries.
    current_memory: AtomicU64,
    /// Number of `longest_prefix_match` calls that returned an entry.
    hits: AtomicU64,
    /// Number of `longest_prefix_match` calls that returned `None`.
    misses: AtomicU64,
    /// Logical clock (only ever incremented) used to stamp `CachedPrefix::last_accessed`.
    access_counter: AtomicU64,
}
impl L1Cache {
    /// Creates an empty cache that tries to keep its accounted size at or below
    /// `max_memory` bytes.
    pub fn new(max_memory: usize) -> Self {
        let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
        Self {
            shards,
            max_memory,
            current_memory: AtomicU64::new(0),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            access_counter: AtomicU64::new(0),
        }
    }

    /// Finds the longest cached prefix of `input` that ends at a special-token boundary.
    ///
    /// Returns the cached token ids together with the byte offset in `input` where the
    /// matched prefix ends. Returns `None` when `input` has no usable boundary or when
    /// no boundary prefix is cached. Every call records exactly one hit or one miss. A
    /// hit also refreshes the entry's `last_accessed` stamp.
    ///
    /// NOTE(review): the cached ids were produced with whatever `add_special_tokens`
    /// value was passed to `insert_at_boundaries`. Callers presumably need to use the
    /// same setting — confirm.
    pub fn longest_prefix_match(
        &self,
        input: &str,
        special_tokens: &[&str],
    ) -> Option<(Vec<TokenIdType>, usize)> {
        let boundaries = find_special_token_boundaries(input, special_tokens);
        if boundaries.is_empty() {
            self.misses.fetch_add(1, Ordering::Relaxed);
            return None;
        }

        // Hash incrementally: the digest pushed for a boundary covers input[..boundary],
        // matching the keys written by `insert_at_boundaries`.
        let mut hasher = blake3::Hasher::new();
        let mut prefix_hashes = Vec::with_capacity(boundaries.len());
        let mut last_pos = 0;
        let bytes = input.as_bytes();
        for &boundary_pos in &boundaries {
            hasher.update(&bytes[last_pos..boundary_pos]);
            prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
            last_pos = boundary_pos;
        }

        // Probe longest-first so the first hit is the longest cached prefix.
        for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
            let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
            if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
                let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
                entry.last_accessed.store(timestamp, Ordering::Relaxed);
                self.hits.fetch_add(1, Ordering::Relaxed);
                return Some((entry.tokens.to_vec(), boundary_pos));
            }
        }

        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Tokenizes `input` one segment at a time and caches the token ids of every prefix
    /// that ends at a special-token boundary.
    ///
    /// The text between consecutive boundaries is encoded on its own and appended to a
    /// running token list. The entry for boundary `k` therefore holds the concatenated
    /// encodings of segments `0..=k`. Only the first segment is encoded with
    /// `add_special_tokens`. Keys are BLAKE3 digests of `input[..boundary]`, computed
    /// the same way as in `longest_prefix_match`.
    ///
    /// If the new entries would push accounted memory above `max_memory`, only the
    /// overflow is evicted first. The budget can still be exceeded when the batch alone
    /// is larger than the cache can free.
    ///
    /// # Errors
    /// Returns the first error from `tokenizer.encode`. No entry is inserted in that
    /// case, because all encoding happens before any insertion.
    pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
        &self,
        input: &str,
        tokenizer: &E,
        special_tokens: &[&str],
        add_special_tokens: bool,
    ) -> anyhow::Result<()> {
        let boundaries = find_special_token_boundaries(input, special_tokens);
        if boundaries.is_empty() {
            return Ok(());
        }

        let bytes = input.as_bytes();
        let mut hasher = blake3::Hasher::new();
        let mut running_tokens = Vec::new();
        let mut last_pos = 0;
        let mut entries_to_insert = Vec::with_capacity(boundaries.len());
        for (i, &boundary_pos) in boundaries.iter().enumerate() {
            let delta_text = &input[last_pos..boundary_pos];
            hasher.update(&bytes[last_pos..boundary_pos]);
            let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();

            let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
            running_tokens.extend_from_slice(segment_encoding.token_ids());
            let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();

            // NOTE(review): this charges the prefix's byte length even though only its
            // 32-byte hash is stored. Kept unchanged so the accounting stays comparable.
            let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
            entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
            last_pos = boundary_pos;
        }

        // Evict only what exceeds the budget. Asking for the whole batch size would also
        // discard entries that already fit in the remaining headroom.
        let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
        let current = self.current_memory.load(Ordering::Relaxed) as usize;
        let projected = current.saturating_add(total_size_needed);
        if projected > self.max_memory {
            self.evict_lru(projected - self.max_memory);
        }

        let current_timestamp = self.access_counter.load(Ordering::Relaxed);
        for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
            let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
            let cached = CachedPrefix {
                tokens: prefix_tokens,
                last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
                size_bytes,
            };
            // When a key is replaced, adjust by the size difference only, so a prefix
            // shared by several inputs is never counted twice.
            if let Some(old) = self.shards[shard_idx].insert(hash_bytes, cached) {
                let old_size = old.size_bytes as u64;
                let new_size = size_bytes as u64;
                if new_size >= old_size {
                    self.current_memory
                        .fetch_add(new_size - old_size, Ordering::Relaxed);
                } else {
                    self.current_memory
                        .fetch_sub(old_size - new_size, Ordering::Relaxed);
                }
            } else {
                self.current_memory
                    .fetch_add(size_bytes as u64, Ordering::Relaxed);
            }
        }
        Ok(())
    }

    /// Sampled approximate-LRU eviction. Removes entries until at least `space_needed`
    /// accounted bytes have been freed or every shard is empty.
    ///
    /// Each round takes the first `SAMPLE_SIZE / NUM_SHARDS` entries (at least one) from
    /// every shard as candidates, then removes the candidate with the smallest
    /// `last_accessed` stamp. Looking past the first entry of each shard keeps a single
    /// recently used entry at the front of a shard from shielding the colder entries
    /// behind it.
    fn evict_lru(&self, space_needed: usize) {
        const SAMPLE_SIZE: usize = 32;
        let per_shard = (SAMPLE_SIZE / NUM_SHARDS).max(1);
        let mut freed = 0usize;
        while freed < space_needed {
            // Candidates are gathered first and removed afterwards, so no shard iterator
            // is alive while `remove` locks that shard.
            let mut oldest: Option<(usize, Blake3Hash, u64)> = None;
            for (shard_idx, shard) in self.shards.iter().enumerate() {
                for entry in shard.iter().take(per_shard) {
                    let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
                    let is_older = match oldest {
                        Some((_, _, best)) => timestamp < best,
                        None => true,
                    };
                    if is_older {
                        oldest = Some((shard_idx, *entry.key(), timestamp));
                    }
                }
            }

            let (shard_idx, hash, _) = match oldest {
                Some(candidate) => candidate,
                None => break, // every shard is empty
            };
            // `remove` can return `None` if another thread evicted the entry first. In
            // that case, sample again.
            if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
                freed += removed.size_bytes;
                self.current_memory
                    .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
            }
        }
    }

    /// Total number of cached prefixes across all shards.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.len()).sum()
    }

    /// Returns `true` when no shard holds an entry.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.is_empty())
    }

    /// Reads the current counters. Each field is loaded separately, so under concurrent
    /// use the values are not one atomic snapshot.
    pub fn stats(&self) -> L1CacheStats {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        let total_requests = hits + misses;
        L1CacheStats {
            hits,
            misses,
            entries: self.len(),
            memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
            hit_rate: if total_requests > 0 {
                hits as f64 / total_requests as f64
            } else {
                0.0
            },
        }
    }

    /// Removes every entry and resets memory accounting and the hit/miss counters.
    /// `access_counter` is left untouched.
    ///
    /// NOTE(review): this is not atomic with respect to concurrent inserts. An insert
    /// that lands between the shard clears and the `current_memory` reset leaves the
    /// accounting out of sync.
    pub fn clear(&self) {
        for shard in &self.shards {
            shard.clear();
        }
        self.current_memory.store(0, Ordering::Relaxed);
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }
}
/// Counters read from an `L1Cache` by `L1Cache::stats`.
#[derive(Debug, Clone)]
pub struct L1CacheStats {
    /// Lookups that returned a cached prefix.
    pub hits: u64,
    /// Lookups that returned nothing, including inputs with no usable boundary.
    pub misses: u64,
    /// Number of cached prefixes across all shards.
    pub entries: usize,
    /// Accounted size of all entries, in bytes.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    pub hit_rate: f64,
}
// Unit tests for `L1Cache`, driven by `MockTokenizer` and ChatML-style special tokens.
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    // A second conversation that shares the system turn with a cached one should hit on
    // the shared prefix.
    #[test]
    fn test_basic_prefix_match() {
        let cache = L1Cache::new(1024 * 1024);
        let special_tokens = &["<|im_start|>", "<|im_end|>"];
        let tokenizer = MockTokenizer::new();
        let input1 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
        cache
            .insert_at_boundaries(input1, &tokenizer, special_tokens, false)
            .unwrap();
        assert!(!cache.is_empty());
        let input2 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
        let result = cache.longest_prefix_match(input2, special_tokens);
        assert!(result.is_some());
        let (tokens, offset) = result.unwrap();
        assert!(offset > 0);
        assert!(!tokens.is_empty());
    }

    // A token that ends the input is not a boundary, so only the prefix ending after
    // the leading `<|im_start|>` is cached; the same input must still match it.
    #[test]
    fn test_short_input_with_boundaries() {
        let cache = L1Cache::new(1024 * 1024);
        let special_tokens = &["<|im_start|>", "<|im_end|>"];
        let tokenizer = MockTokenizer::new();
        let input = "<|im_start|>user\nHi<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
            .unwrap();
        assert!(!cache.is_empty());
        let result = cache.longest_prefix_match(input, special_tokens);
        assert!(result.is_some());
    }

    // A truncated conversation must match a cached prefix lying within its length.
    // This does not assert which boundary was chosen.
    #[test]
    fn test_longest_match() {
        let cache = L1Cache::new(1024 * 1024);
        let special_tokens = &["<|im_start|>", "<|im_end|>"];
        let tokenizer = MockTokenizer::new();
        let input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
            .unwrap();
        assert!(cache.len() >= 2);
        let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
        let result = cache.longest_prefix_match(partial_input, special_tokens);
        assert!(result.is_some());
        let (_, offset) = result.unwrap();
        assert!(offset > 0);
        assert!(offset <= partial_input.len());
    }

    // Inserts do not touch the hit/miss counters, so a single successful lookup gives a
    // hit rate of exactly 1.0.
    #[test]
    fn test_stats() {
        let cache = L1Cache::new(1024 * 1024);
        let special_tokens = &["<|im_start|>", "<|im_end|>"];
        let tokenizer = MockTokenizer::new();
        let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
            .unwrap();
        let _ = cache.longest_prefix_match(input, special_tokens);
        let stats = cache.stats();
        assert!(stats.hits >= 1);
        assert_eq!(stats.hit_rate, 1.0);
    }

    // `clear` empties every shard and resets the hit/miss counters.
    #[test]
    fn test_clear() {
        let cache = L1Cache::new(1024 * 1024);
        let special_tokens = &["<|im_start|>", "<|im_end|>"];
        let tokenizer = MockTokenizer::new();
        let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
            .unwrap();
        assert!(!cache.is_empty());
        cache.clear();
        assert!(cache.is_empty());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    // A 5 KiB budget across three conversations: accounted memory must stay within the
    // budget, and the most recent insert must remain retrievable.
    #[test]
    fn test_lru_eviction() {
        let cache = L1Cache::new(5 * 1024);
        let special_tokens = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
        let tokenizer = MockTokenizer::new();
        let input1 = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
        cache
            .insert_at_boundaries(input1, &tokenizer, special_tokens, false)
            .unwrap();
        let result = cache.longest_prefix_match(input1, special_tokens);
        assert!(result.is_some());
        let input2 = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
        cache
            .insert_at_boundaries(input2, &tokenizer, special_tokens, false)
            .unwrap();
        let result = cache.longest_prefix_match(input2, special_tokens);
        assert!(result.is_some());
        let input3 = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";
        cache
            .insert_at_boundaries(input3, &tokenizer, special_tokens, false)
            .unwrap();
        let stats = cache.stats();
        assert!(stats.memory_bytes <= 5 * 1024);
        let result = cache.longest_prefix_match(input3, special_tokens);
        assert!(result.is_some());
    }

    // Ten threads each insert a distinct conversation and immediately look it up. The
    // 1 MiB budget is meant to avoid eviction, so every lookup should hit.
    #[test]
    fn test_concurrent_access() {
        use std::{sync::Arc, thread};
        let cache = Arc::new(L1Cache::new(1024 * 1024));
        let special_tokens_owned: Vec<String> =
            vec!["<|im_start|>".to_string(), "<|im_end|>".to_string()];
        let special_tokens_arc = Arc::new(special_tokens_owned);
        let mut handles = vec![];
        for i in 0..10 {
            let cache_clone = cache.clone();
            let st_clone = special_tokens_arc.clone();
            handles.push(thread::spawn(move || {
                let tokenizer = MockTokenizer::new();
                // Borrow `&str` views of the shared owned tokens for this thread.
                let special_tokens: Vec<&str> = st_clone.iter().map(|s| s.as_str()).collect();
                let input = format!(
                    "<|im_start|>system\nYou are assistant number {i}.<|im_end|>\
                    <|im_start|>user\nThread {i} says hello world test token.<|im_end|>"
                );
                cache_clone
                    .insert_at_boundaries(&input, &tokenizer, &special_tokens, false)
                    .unwrap();
                let result = cache_clone.longest_prefix_match(&input, &special_tokens);
                assert!(
                    result.is_some(),
                    "Thread {i} expected a prefix match after insertion"
                );
                let (tokens, offset) = result.unwrap();
                assert!(
                    !tokens.is_empty(),
                    "Thread {i} expected non-empty cached tokens"
                );
                assert!(offset > 0, "Thread {i} expected positive byte offset");
                assert!(
                    offset <= input.len(),
                    "Thread {i}: offset {offset} exceeds input length {}",
                    input.len()
                );
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(!cache.is_empty());
        let stats = cache.stats();
        assert!(
            stats.memory_bytes > 0,
            "Expected non-zero memory tracking after concurrent inserts"
        );
        assert!(
            stats.entries > 0,
            "Expected non-zero cache entries after concurrent inserts"
        );
        assert!(
            stats.hits >= 10,
            "Expected at least 10 cache hits, got {}",
            stats.hits
        );
    }
}