mod fingerprint;
mod l0;
mod l1;
use std::sync::Arc;
use anyhow::Result;
pub use fingerprint::TokenizerFingerprint;
pub use l0::{CacheStats, L0Cache};
pub use l1::{L1Cache, L1CacheStats};
use rayon::prelude::*;
use crate::{
chat_template::{
ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
},
traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
};
/// Settings that decide which cache layers a cached tokenizer builds and how
/// large each layer may grow.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// When `true`, an L0 cache is constructed.
    pub enable_l0: bool,
    /// Capacity handed to the L0 cache constructor, expressed as a number of entries.
    pub l0_max_entries: usize,
    /// When `true`, an L1 cache is constructed.
    pub enable_l1: bool,
    /// Memory budget handed to the L1 cache constructor
    /// (presumably in bytes — confirm against `L1Cache::new`).
    pub l1_max_memory: usize,
}

impl Default for CacheConfig {
    /// Returns a configuration with L0 enabled (10 000 entries) and L1 disabled
    /// (its budget is still populated with `50 * 1024 * 1024`).
    fn default() -> Self {
        const DEFAULT_L0_MAX_ENTRIES: usize = 10_000;
        const DEFAULT_L1_MAX_MEMORY: usize = 50 * 1024 * 1024;

        CacheConfig {
            l1_max_memory: DEFAULT_L1_MAX_MEMORY,
            enable_l1: false,
            l0_max_entries: DEFAULT_L0_MAX_ENTRIES,
            enable_l0: true,
        }
    }
}
/// Tokenizer wrapper that places optional L0 and L1 caches in front of an inner tokenizer.
pub struct CachedTokenizer {
/// Wrapped tokenizer; cache misses and every non-encode call are forwarded to it.
inner: Arc<dyn Tokenizer>,
/// L0 cache, looked up by `(input, add_special_tokens)`; `None` when disabled in the config.
l0: Option<L0Cache>,
/// L1 cache, queried for the longest cached prefix of an input; `None` when disabled in the config.
l1: Option<L1Cache>,
/// Fingerprint derived from the inner tokenizer at construction time.
fingerprint: TokenizerFingerprint,
/// Special-token strings collected from the inner tokenizer at construction time;
/// they are passed to the L1 cache on lookup and insertion.
special_token_strings: Vec<String>,
}
impl CachedTokenizer {
    /// Wraps `inner`, building only the cache layers that `config` enables.
    ///
    /// The tokenizer fingerprint and the list of special-token strings are
    /// computed once here and kept for the lifetime of the wrapper.
    pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
        let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
        let l0 = config
            .enable_l0
            .then(|| L0Cache::new(config.l0_max_entries));
        let l1 = config
            .enable_l1
            .then(|| L1Cache::new(config.l1_max_memory));
        let special_token_strings = Self::extract_special_token_strings(&inner);

        Self {
            inner,
            l0,
            l1,
            fingerprint,
            special_token_strings,
        }
    }

    /// Collects every special-token string the tokenizer declares.
    ///
    /// The named slots come first, in the order bos, eos, unk, sep, pad, cls,
    /// mask (absent slots are skipped), followed by the additional special
    /// tokens. No de-duplication is performed.
    fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
        let specials = tokenizer.get_special_tokens();
        let named_slots = [
            specials.bos_token.as_ref(),
            specials.eos_token.as_ref(),
            specials.unk_token.as_ref(),
            specials.sep_token.as_ref(),
            specials.pad_token.as_ref(),
            specials.cls_token.as_ref(),
            specials.mask_token.as_ref(),
        ];

        let mut strings: Vec<String> = named_slots
            .iter()
            .flatten()
            .map(|token| (*token).clone())
            .collect();
        strings.extend(specials.additional_special_tokens.iter().cloned());
        strings
    }

    /// Statistics of the L0 cache, or `None` when L0 is disabled.
    pub fn cache_stats(&self) -> Option<CacheStats> {
        let cache = self.l0.as_ref()?;
        Some(cache.stats())
    }

    /// Statistics of the L1 cache, or `None` when L1 is disabled.
    pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
        let cache = self.l1.as_ref()?;
        Some(cache.stats())
    }

    /// Clears every enabled cache layer (L0 first, then L1).
    pub fn clear_cache(&self) {
        if let Some(cache) = self.l0.as_ref() {
            cache.clear();
        }
        if let Some(cache) = self.l1.as_ref() {
            cache.clear();
        }
    }

    /// Fingerprint computed from the wrapped tokenizer at construction time.
    pub fn fingerprint(&self) -> &TokenizerFingerprint {
        &self.fingerprint
    }

    /// Shared handle to the wrapped tokenizer.
    pub fn inner(&self) -> &Arc<dyn Tokenizer> {
        &self.inner
    }
}
impl Encoder for CachedTokenizer {
    /// Encodes `input`, consulting the enabled cache layers before the wrapped tokenizer.
    ///
    /// Lookup order:
    /// 1. L0: a hit for `(input, add_special_tokens)` is cloned and returned.
    /// 2. L1: when a cached prefix is found and a non-empty suffix remains, only
    ///    the suffix is encoded by the wrapped tokenizer; prefix and suffix ids
    ///    are concatenated into an `Encoding::Plain`, stored in L0 and returned.
    /// 3. Otherwise the whole input is encoded by the wrapped tokenizer, stored
    ///    in L0, and offered to L1 via `insert_at_boundaries`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped tokenizer's `encode`.
    /// A failure while populating L1 is deliberately ignored (best effort).
    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
        if let Some(l0) = &self.l0 {
            if let Some(cached) = l0.get(input, add_special_tokens) {
                return Ok((*cached).clone());
            }
        }

        // Borrowed view of the special-token strings. It is needed by both the
        // L1 lookup and the L1 insertion, so build it once per call, and only
        // when L1 is actually enabled (`Vec::new()` does not allocate).
        let boundary_tokens: Vec<&str> = if self.l1.is_some() {
            self.special_token_strings
                .iter()
                .map(String::as_str)
                .collect()
        } else {
            Vec::new()
        };

        if let Some(l1) = &self.l1 {
            if let Some((prefix_tokens, prefix_len)) =
                l1.longest_prefix_match(input, &boundary_tokens)
            {
                // Checked slicing: an offset that is past the end of `input` or
                // not on a UTF-8 character boundary yields `None` here instead
                // of panicking, and we fall through to a full encode below.
                // An empty suffix also falls through, as before.
                if let Some(suffix) = input.get(prefix_len..).filter(|rest| !rest.is_empty()) {
                    // NOTE(review): the suffix is encoded with the caller's
                    // `add_special_tokens`; for tokenizers that prepend a
                    // sequence-start token this may place one mid-sequence.
                    // Confirm against the L1 cache contract.
                    let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
                    let mut merged_tokens = prefix_tokens;
                    merged_tokens.extend_from_slice(suffix_encoding.token_ids());
                    let merged_encoding = Encoding::Plain(merged_tokens);
                    if let Some(l0) = &self.l0 {
                        l0.insert(
                            input.to_string(),
                            add_special_tokens,
                            merged_encoding.clone(),
                        );
                    }
                    return Ok(merged_encoding);
                }
            }
        }

        // Full miss: encode the whole input and populate the enabled layers.
        let encoding = self.inner.encode(input, add_special_tokens)?;
        if let Some(l0) = &self.l0 {
            l0.insert(input.to_string(), add_special_tokens, encoding.clone());
        }
        if let Some(l1) = &self.l1 {
            // Best effort: a failure to populate L1 must not fail the encode.
            let _ = l1.insert_at_boundaries(
                input,
                self.inner.as_ref(),
                &boundary_tokens,
                add_special_tokens,
            );
        }
        Ok(encoding)
    }

    /// Encodes every input through [`Encoder::encode`] on a rayon parallel
    /// iterator, so each input benefits from the same cache layers.
    ///
    /// # Errors
    ///
    /// Fails the whole batch if encoding any input fails.
    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
        inputs
            .par_iter()
            .map(|&input| self.encode(input, add_special_tokens))
            .collect()
    }
}
impl Decoder for CachedTokenizer {
    /// Decoding is not cached: the call is forwarded unchanged to the wrapped
    /// tokenizer and its result (or error) is returned as-is.
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
        let wrapped = self.inner();
        wrapped.decode(token_ids, skip_special_tokens)
    }
}
// Every method below forwards to the wrapped tokenizer; only `as_any`
// refers to the wrapper itself.
impl Tokenizer for CachedTokenizer {
    /// Vocabulary size reported by the wrapped tokenizer.
    fn vocab_size(&self) -> usize {
        self.inner().vocab_size()
    }

    /// Special tokens declared by the wrapped tokenizer.
    fn get_special_tokens(&self) -> &SpecialTokens {
        self.inner().get_special_tokens()
    }

    /// Looks up the id of `token` in the wrapped tokenizer.
    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
        self.inner().token_to_id(token)
    }

    /// Looks up the token string for `id` in the wrapped tokenizer.
    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
        self.inner().id_to_token(id)
    }

    /// Exposes the wrapper itself (not the wrapped tokenizer) for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Renders `messages` with the wrapped tokenizer's chat template.
    fn apply_chat_template(
        &self,
        messages: &[serde_json::Value],
        params: ChatTemplateParams,
    ) -> Result<String> {
        self.inner().apply_chat_template(messages, params)
    }

    /// Content format reported by the wrapped tokenizer's chat template.
    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
        self.inner().chat_template_content_format()
    }

    /// Thinking toggle reported by the wrapped tokenizer.
    fn thinking_toggle(&self) -> ThinkingToggle {
        self.inner().thinking_toggle()
    }

    /// Thinking key name reported by the wrapped tokenizer, if any.
    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
        self.inner().thinking_key_name()
    }

    /// Forwards the wrapped tokenizer's `think_in_prefill` flag.
    fn think_in_prefill(&self) -> bool {
        self.inner().think_in_prefill()
    }
}
// Unit tests for the caching wrapper, driven by the crate's mock tokenizer.
#[cfg(test)]
mod tests {
use crate::{mock::MockTokenizer, *};
// Encoding the same input twice with the default config must record exactly
// one L0 miss (first call) and one L0 hit (second call).
#[test]
fn test_cache_hit() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
let input = "Hello world";
let result1 = cached.encode(input, false).unwrap();
let result2 = cached.encode(input, false).unwrap();
assert_eq!(result1.token_ids(), result2.token_ids());
let stats = cached.cache_stats().unwrap();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
}
// With both layers disabled, encoding still works and no L0 stats are exposed.
#[test]
fn test_cache_disabled() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = CacheConfig {
enable_l0: false,
l0_max_entries: 0,
enable_l1: false,
l1_max_memory: 0,
};
let cached = CachedTokenizer::new(tokenizer, config);
let input = "Hello world";
let result1 = cached.encode(input, false).unwrap();
let result2 = cached.encode(input, false).unwrap();
assert_eq!(result1.token_ids(), result2.token_ids());
assert!(cached.cache_stats().is_none());
}
// Batch encoding returns one result per input and equal ids for equal inputs.
#[test]
fn test_encode_batch() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
let inputs = vec!["Hello", "world", "Hello"];
let results = cached.encode_batch(&inputs, false).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].token_ids(), results[2].token_ids());
// The batch runs on a parallel iterator, so whether the duplicate "Hello"
// hits inside the batch is not asserted; this follow-up sequential encode
// is what the `>= 1` hit expectation below relies on.
let _ = cached.encode("Hello", false).unwrap();
let stats = cached.cache_stats().unwrap();
assert!(
stats.hits >= 1,
"Expected at least 1 cache hit after batch processing"
);
}
// Decoding is forwarded to the mock tokenizer and yields a non-empty string.
#[test]
fn test_decoder_passthrough() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
let tokens = vec![1, 2, 3];
let decoded = cached.decode(&tokens, false).unwrap();
assert!(!decoded.is_empty());
}
// The `Tokenizer` trait methods on the wrapper agree with the wrapped tokenizer.
#[test]
fn test_tokenizer_trait_methods() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());
assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
assert!(cached.token_to_id("Hello").is_some());
assert!(cached.id_to_token(1).is_some());
}
}