use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
// Lazily-initialized BPE encoders; built once on first use (see the
// `get_*_tokenizer` helpers below).
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new(); // o200k_base encoding
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new(); // cl100k_base encoding
/// A cached token count together with the time it was last read, used for
/// least-recently-used eviction in `maybe_cleanup_cache`.
struct CacheEntry {
    // Token count for the (content hash, model) key.
    count: u32,
    // Unix timestamp in whole seconds (truncated to u32) of the last cache hit.
    last_access: u32,
}
// Process-wide cache of token counts, keyed by (content hash, model).
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), CacheEntry>> = OnceLock::new();
// Cache size at which `maybe_cleanup_cache` starts evicting entries.
const MAX_CACHE_ENTRIES: usize = 100_000;
// Reciprocal of the fraction of entries evicted per cleanup (2 => evict half).
const EVICTION_FRACTION: usize = 2;
/// Returns the process-wide token-count cache, initializing it on first use.
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), CacheEntry> {
    TOKEN_CACHE.get_or_init(DashMap::default)
}
/// Seconds since the Unix epoch, truncated to `u32`.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn current_timestamp() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as u32,
        Err(_) => 0,
    }
}
/// Evicts the least-recently-used `1/EVICTION_FRACTION` of the cache once it
/// reaches `MAX_CACHE_ENTRIES`; a no-op below that threshold.
///
/// The eviction is best-effort: the snapshot and the removals are not atomic
/// with respect to concurrent inserts, which is acceptable for a cache.
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), CacheEntry>) {
    if cache.len() < MAX_CACHE_ENTRIES {
        return;
    }
    // Snapshot (key, last_access) pairs; DashMap has no ordered view.
    let mut entries: Vec<((u64, TokenModel), u32)> = cache
        .iter()
        .map(|entry| (*entry.key(), entry.value().last_access))
        .collect();
    // sort_unstable is faster and allocation-free; stability is irrelevant
    // because entries with equal timestamps may be evicted in any order.
    entries.sort_unstable_by_key(|&(_, ts)| ts);
    let to_remove = entries.len() / EVICTION_FRACTION;
    for (key, _) in entries.into_iter().take(to_remove) {
        cache.remove(&key);
    }
}
/// Hashes `content` with the standard-library `DefaultHasher` to produce a
/// compact cache key. Deterministic within a process, so equal strings always
/// map to the same key.
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::default();
    content.hash(&mut hasher);
    hasher.finish()
}
/// Returns the shared o200k_base encoder, building it on first call.
///
/// Panics (once, at first use) if tiktoken cannot construct the encoder from
/// its bundled vocabulary data — that would be a bug, not a runtime condition.
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    let build = || {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    };
    GPT4O_TOKENIZER.get_or_init(build)
}
/// Returns the shared cl100k_base encoder, building it on first call.
///
/// Panics (once, at first use) if tiktoken cannot construct the encoder from
/// its bundled vocabulary data — that would be a bug, not a runtime condition.
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    let build = || {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    };
    GPT4_TOKENIZER.get_or_init(build)
}
/// Single-pass, byte-level character statistics that feed the heuristic
/// token estimators in `estimate_from_stats`.
#[derive(Clone, Copy)]
struct EstimationStats {
    // Text length in bytes (not chars).
    len: usize,
    // ASCII spaces and tabs.
    whitespace_count: u32,
    // ASCII '\n' bytes.
    newline_count: u32,
    // Code-style punctuation bytes (braces, operators, quotes, etc.).
    special_char_count: u32,
}
/// Token counter that combines exact BPE tokenization (for models with a
/// known encoding) and heuristic estimation, with an optional global cache
/// of results keyed by content hash and model.
pub struct Tokenizer {
    // When true, use the exact tiktoken encoders for models that have one;
    // otherwise always estimate heuristically.
    use_exact: bool,
    // When true, consult and populate the process-wide token-count cache.
    use_cache: bool,
}
impl Default for Tokenizer {
    /// Equivalent to [`Tokenizer::new`]: exact tokenization with caching.
    fn default() -> Self {
        Self::new()
    }
}
impl Tokenizer {
pub fn new() -> Self {
Self { use_exact: true, use_cache: true }
}
pub fn estimation_only() -> Self {
Self { use_exact: false, use_cache: true }
}
pub fn without_cache() -> Self {
Self { use_exact: true, use_cache: false }
}
#[must_use]
pub fn count(&self, text: &str, model: TokenModel) -> u32 {
if text.is_empty() {
return 0;
}
if self.use_cache {
let cache = get_token_cache();
let content_hash = hash_content(text);
let key = (content_hash, model);
let now = current_timestamp();
if let Some(mut entry) = cache.get_mut(&key) {
entry.last_access = now;
return entry.count;
}
let count = self.count_uncached(text, model);
maybe_cleanup_cache(cache);
cache.insert(key, CacheEntry { count, last_access: now });
count
} else {
self.count_uncached(text, model)
}
}
fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
if self.use_exact && model.has_exact_tokenizer() {
self.count_exact(text, model)
} else {
self.estimate(text, model)
}
}
fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
if model.uses_o200k() {
let tokenizer = get_gpt4o_tokenizer();
self.tokenize_with_panic_guard(tokenizer, text, model)
} else if model.uses_cl100k() {
let tokenizer = get_gpt4_tokenizer();
self.tokenize_with_panic_guard(tokenizer, text, model)
} else {
self.estimate(text, model)
}
}
fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(|_| {
}));
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
tokenizer.encode_ordinary(text).len() as u32
}));
std::panic::set_hook(prev_hook);
match result {
Ok(count) => count,
Err(_) => self.estimate(text, model), }
}
fn estimate(&self, text: &str, model: TokenModel) -> u32 {
if text.is_empty() {
return 0;
}
let stats = compute_estimation_stats(text);
estimate_from_stats(&stats, model)
}
pub fn count_all(&self, text: &str) -> TokenCounts {
if text.is_empty() {
return TokenCounts::default();
}
let content_hash = hash_content(text);
let cache = if self.use_cache {
Some(get_token_cache())
} else {
None
};
let now = current_timestamp();
let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
if let Some(cache) = cache {
let key = (content_hash, model);
if let Some(mut entry) = cache.get_mut(&key) {
entry.last_access = now;
return entry.count;
}
let count = self.tokenize_with_panic_guard(tokenizer, text, model);
maybe_cleanup_cache(cache);
cache.insert(key, CacheEntry { count, last_access: now });
count
} else {
self.tokenize_with_panic_guard(tokenizer, text, model)
}
};
let stats = compute_estimation_stats(text);
let o200k = if self.use_exact {
get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
} else {
estimate_from_stats(&stats, TokenModel::Gpt4o)
};
let cl100k = if self.use_exact {
get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
} else {
estimate_from_stats(&stats, TokenModel::Gpt4)
};
TokenCounts {
o200k,
cl100k,
claude: estimate_from_stats(&stats, TokenModel::Claude),
gemini: estimate_from_stats(&stats, TokenModel::Gemini),
llama: estimate_from_stats(&stats, TokenModel::Llama),
mistral: estimate_from_stats(&stats, TokenModel::Mistral),
deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
qwen: estimate_from_stats(&stats, TokenModel::Qwen),
cohere: estimate_from_stats(&stats, TokenModel::Cohere),
grok: estimate_from_stats(&stats, TokenModel::Grok),
}
}
pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
let counts = self.count_all(text);
let models = [
(TokenModel::Gpt4o, counts.o200k), (TokenModel::Gpt4, counts.cl100k), (TokenModel::Claude, counts.claude),
(TokenModel::Gemini, counts.gemini),
(TokenModel::Llama, counts.llama),
(TokenModel::Mistral, counts.mistral),
(TokenModel::DeepSeek, counts.deepseek),
(TokenModel::Qwen, counts.qwen),
(TokenModel::Cohere, counts.cohere),
(TokenModel::Grok, counts.grok),
];
models
.into_iter()
.min_by_key(|(_, count)| *count)
.unwrap_or((TokenModel::Claude, 0))
}
pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
let current = self.count(text, model);
if current <= budget {
return text;
}
let mut low = 0usize;
let mut high = text.len();
while low < high {
let mid_raw = (low + high).div_ceil(2);
let mid = text.floor_char_boundary(mid_raw);
if mid <= low {
break;
}
let count = self.count(&text[..mid], model);
if count <= budget {
low = mid;
} else {
high = mid.saturating_sub(1);
}
}
let low = text.floor_char_boundary(low);
let mut end = low;
while end > 0 {
let boundary = text.floor_char_boundary(end);
if boundary < end {
end = boundary;
continue;
}
let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
if c == b' ' || c == b'\n' {
break;
}
end -= 1;
}
let end = text.floor_char_boundary(end);
if end > 0 {
&text[..end]
} else {
&text[..low]
}
}
pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
self.count(text, model) > budget
}
}
/// Fast, length-only token estimate: byte length divided by the model's
/// average chars-per-token ratio, rounded up, with a minimum of 1 for any
/// non-empty input. Empty input yields 0.
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let ratio = model.chars_per_token();
    let tokens = text.len() as f32 / ratio;
    tokens.ceil().max(1.0) as u32
}
/// Scans `text` once, byte by byte, tallying the character classes the
/// heuristic estimators care about: whitespace, newlines, and code-style
/// punctuation. Multi-byte UTF-8 continuation bytes fall through unclassified.
fn compute_estimation_stats(text: &str) -> EstimationStats {
    let (mut spaces, mut newlines, mut specials) = (0u32, 0u32, 0u32);
    for &b in text.as_bytes() {
        match b {
            b' ' | b'\t' => spaces += 1,
            b'\n' => newlines += 1,
            b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'=' | b'+'
            | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#' | b'$' | b'%'
            | b'^' | b'~' | b'`' | b'"' | b'\'' => specials += 1,
            _ => {}
        }
    }
    EstimationStats {
        len: text.len(),
        whitespace_count: spaces,
        newline_count: newlines,
        special_char_count: specials,
    }
}
/// Converts raw character statistics into a token estimate for `model`.
///
/// Heuristic: start from bytes / chars-per-token, subtract a discount for
/// whitespace, add a per-newline cost, and — for models whose tokenizers
/// split code punctuation more aggressively — add a small per-special-char
/// cost. The result is rounded up and never below 1.
fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
    let base = stats.len as f32 / model.chars_per_token();
    let whitespace_discount = stats.whitespace_count as f32 * 0.3;
    let newline_cost = stats.newline_count as f32 * 0.5;
    let mut estimate = base - whitespace_discount + newline_cost;
    let punctuation_heavy = matches!(
        model,
        TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
    );
    if punctuation_heavy {
        estimate += stats.special_char_count as f32 * 0.3;
    }
    estimate.ceil().max(1.0) as u32
}