use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::RwLock;
use tiktoken_rs::{cl100k_base, CoreBPE};
use tokenizers::Tokenizer;
use tracing::{debug, warn};
/// Hard cap on cached entries; the whole cache is wiped once this is reached
/// (crude eviction — avoids per-entry LRU bookkeeping).
const MAX_CACHE_ENTRIES: usize = 1_000;
/// Process-wide tokenizer, resolved lazily on first use.
/// NOTE(review): initialization may hit the network (HF Hub download) — first
/// call to token estimation can be slow; confirm this is acceptable at startup.
static TOKENIZER: Lazy<TokenizerState> = Lazy::new(TokenizerState::new);
/// Memoized token counts keyed by a 64-bit hash of the content.
/// A hash collision would return the other string's count — presumably an
/// accepted trade-off for an estimate; verify if exact counts ever matter.
static TOKEN_CACHE: Lazy<RwLock<HashMap<u64, usize>>> =
    Lazy::new(|| RwLock::new(HashMap::with_capacity(256)));
/// Backend used for token counting, in descending order of fidelity.
enum TokenizerState {
    /// HF `tokenizers` model loaded from the Hub (boxed: large struct).
    Qwen(Box<Tokenizer>),
    /// tiktoken cl100k BPE as a close proxy when the Qwen model is unavailable.
    Tiktoken(CoreBPE),
    /// Pure character-count heuristic; last resort, see `heuristic_estimate`.
    Heuristic,
}
impl TokenizerState {
    /// Resolve the best available backend, preferring the exact Qwen tokenizer,
    /// then tiktoken's cl100k BPE, and finally the character heuristic.
    ///
    /// Loading from the Hub can fail (offline, missing cache); each downgrade
    /// is logged so operators can tell which fidelity level is in effect.
    fn new() -> Self {
        match Tokenizer::from_pretrained("Qwen/Qwen2.5-Coder-32B", None) {
            Ok(tokenizer) => {
                debug!("Successfully loaded Qwen tokenizer from HF Hub");
                TokenizerState::Qwen(Box::new(tokenizer))
            }
            Err(e) => {
                warn!(
                    "Failed to load Qwen tokenizer: {}. Falling back to tiktoken cl100k",
                    e
                );
                // Second choice is tiktoken; if even that fails, fall through
                // to the allocation-free heuristic so estimation always works.
                cl100k_base()
                    .map(TokenizerState::Tiktoken)
                    .unwrap_or(TokenizerState::Heuristic)
            }
        }
    }

    /// Count the tokens in `content` using whichever backend was resolved.
    ///
    /// A Qwen encode error degrades to the heuristic rather than propagating,
    /// so this function is infallible by design.
    fn count(&self, content: &str) -> usize {
        match self {
            TokenizerState::Qwen(tok) => match tok.encode(content, false) {
                Ok(encoding) => encoding.get_tokens().len(),
                Err(_) => heuristic_estimate(content),
            },
            TokenizerState::Tiktoken(bpe) => bpe.encode_with_special_tokens(content).len(),
            TokenizerState::Heuristic => heuristic_estimate(content),
        }
    }
}
/// Estimate the token cost of one message: content tokens plus a fixed
/// per-message overhead (role markers, separators, etc. — caller-supplied).
#[inline]
pub fn estimate_tokens_with_overhead(content: &str, message_overhead: usize) -> usize {
    let content_tokens = estimate_content_tokens(content);
    content_tokens + message_overhead
}
/// Estimate the number of tokens in `content`, memoizing results in a
/// process-wide cache keyed by a 64-bit hash of the text.
///
/// Lock-poisoning on either the read or write path is tolerated silently:
/// a poisoned read falls through to a fresh count, and a poisoned write
/// merely skips caching — the function always returns a count.
#[inline]
pub fn estimate_content_tokens(content: &str) -> usize {
    let key = hash_content(content);

    // Fast path: return a previously computed count if one exists.
    let cached = TOKEN_CACHE
        .read()
        .ok()
        .and_then(|cache| cache.get(&key).copied());
    if let Some(count) = cached {
        return count;
    }

    // Slow path: run the tokenizer, then try to remember the result.
    let count = TOKENIZER.count(content);
    if let Ok(mut cache) = TOKEN_CACHE.write() {
        // Crude eviction: wipe everything once the cap is hit rather than
        // tracking recency per entry.
        if cache.len() >= MAX_CACHE_ENTRIES {
            cache.clear();
        }
        cache.insert(key, count);
    }
    count
}
/// Hash `content` to a 64-bit cache key using the std `DefaultHasher`
/// (SipHash). Deterministic within a process run, which is all the
/// in-memory cache requires.
fn hash_content(content: &str) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::default();
    content.hash(&mut hasher);
    hasher.finish()
}
/// Rough fallback estimate when no real tokenizer is available.
///
/// Code-like text (detected by braces or semicolons) tokenizes more densely,
/// so it uses ~3 bytes per token versus ~4 for prose. Always reports at
/// least one token, even for an empty string.
fn heuristic_estimate(content: &str) -> usize {
    let looks_like_code = content.contains('{') || content.contains(';');
    let bytes_per_token = if looks_like_code { 3 } else { 4 };
    std::cmp::max(content.len() / bytes_per_token, 1)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Any backend (Qwen, tiktoken, or heuristic) must report > 0 tokens
    // for non-empty input.
    #[test]
    fn test_estimate_content_tokens_non_zero() {
        let tokens = estimate_content_tokens("hello world");
        assert!(tokens > 0);
    }

    // Overhead is added on top of the content estimate, so the total is at
    // least overhead + 1 (content contributes >= 1).
    #[test]
    fn test_estimate_tokens_with_overhead() {
        let tokens = estimate_tokens_with_overhead("hello", 10);
        assert!(tokens >= 11);
    }

    // Code-like input (braces/semicolons) also yields a positive count.
    #[test]
    fn test_estimate_content_tokens_code() {
        let tokens = estimate_content_tokens("fn main() { println!(\"hi\"); }");
        assert!(tokens > 0);
    }

    // Two calls with identical content must agree — the second is expected
    // to be served from TOKEN_CACHE.
    #[test]
    fn test_cache_returns_consistent_results() {
        let content = "The quick brown fox jumps over the lazy dog";
        let first = estimate_content_tokens(content);
        let second = estimate_content_tokens(content);
        assert_eq!(first, second);
    }

    // Cache keys must be stable for equal input and (with overwhelming
    // probability) distinct for different input.
    #[test]
    fn test_hash_content_deterministic() {
        let a = hash_content("hello");
        let b = hash_content("hello");
        assert_eq!(a, b);
        let c = hash_content("world");
        assert_ne!(a, c);
    }
}