use std::sync::OnceLock;
// Process-wide cache for the `cl100k_base` tokenizer. It starts empty and is
// filled by `TokenEstimator::count` via `get_or_init`; `count_nonblocking`
// only peeks at it with `get` and never triggers the load itself.
static TOKENIZER: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();
/// Namespace type for token-count estimation helpers.
///
/// The type carries no state; every operation is an associated function.
/// The common marker traits are derived so the type can be logged, copied,
/// and default-constructed like any other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct TokenEstimator;
impl TokenEstimator {
/// Returns the number of `cl100k_base` tokens in `text`.
///
/// The tokenizer is built on the first call and cached in `TOKENIZER`;
/// later calls reuse the cached instance. Encoding is done with
/// `encode_with_special_tokens`.
///
/// # Panics
///
/// Panics if `tiktoken_rs::cl100k_base()` returns an error.
pub fn count(text: &str) -> usize {
    TOKENIZER
        .get_or_init(|| tiktoken_rs::cl100k_base().expect("cl100k_base must load"))
        .encode_with_special_tokens(text)
        .len()
}
/// Counts tokens without ever triggering tokenizer initialization.
///
/// If the cached tokenizer is already available it is used for a real
/// encode; otherwise the result of [`Self::count_fast`] is returned instead.
pub fn count_nonblocking(text: &str) -> usize {
    match TOKENIZER.get() {
        Some(bpe) => bpe.encode_with_special_tokens(text).len(),
        // Tokenizer not loaded yet: fall back to the byte-length heuristic.
        None => Self::count_fast(text),
    }
}
/// Cheap token estimate based only on the byte length of `text`.
///
/// Returns `0` for an empty string; otherwise one token per four bytes
/// (integer division), with a floor of `1` so non-empty input never
/// reports zero tokens.
pub fn count_fast(text: &str) -> usize {
    match text.len() {
        0 => 0,
        byte_len => (byte_len / 4).max(1),
    }
}
/// Percentage reduction going from `before` tokens to `after` tokens.
///
/// Returns `0.0` when `before` is zero (avoids dividing by zero). The
/// result is negative when `after` is larger than `before`.
pub fn savings_pct(before: usize, after: usize) -> f64 {
    match before {
        0 => 0.0,
        _ => {
            let remaining_ratio = after as f64 / before as f64;
            (1.0 - remaining_ratio) * 100.0
        }
    }
}
/// Forces the cached tokenizer to be built by counting a throwaway string.
///
/// The count itself is discarded; only the initialization side effect of
/// [`Self::count`] matters here.
pub fn warmup() {
    Self::count("warmup");
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The exact tokenizer reports at least one token for ordinary text.
    #[test]
    fn estimates_non_zero() {
        assert!(TokenEstimator::count("hello world this is a test") > 0);
    }

    /// The heuristic is non-zero for non-empty text and zero for empty text.
    #[test]
    fn count_fast_non_zero() {
        assert!(TokenEstimator::count_fast("hello world") > 0);
        assert_eq!(TokenEstimator::count_fast(""), 0);
    }

    /// Going from 1000 to 200 tokens is an 80% saving.
    #[test]
    fn savings_calculation() {
        let pct = TokenEstimator::savings_pct(1000, 200);
        assert!((pct - 80.0).abs() < 0.1);
    }

    /// A zero baseline must not divide by zero; it reports no savings.
    #[test]
    fn savings_with_zero_baseline_is_zero() {
        assert_eq!(TokenEstimator::savings_pct(0, 0), 0.0);
        assert_eq!(TokenEstimator::savings_pct(0, 42), 0.0);
    }

    /// `count_nonblocking` must return either the heuristic or the exact count.
    ///
    /// The tokenizer lives in a process-wide static shared by every test in
    /// this binary, and test order is not guaranteed, so this test cannot
    /// assume the tokenizer is uninitialized. Both legitimate outcomes are
    /// accepted instead of asserting a specific one.
    #[test]
    fn count_nonblocking_matches_fast_or_exact() {
        let text = "hello world";
        // Take the non-blocking reading first, before this test initializes
        // the tokenizer itself through `count`.
        let got = TokenEstimator::count_nonblocking(text);
        let fast = TokenEstimator::count_fast(text);
        let exact = TokenEstimator::count(text);
        assert!(got > 0);
        assert!(
            got == fast || got == exact,
            "count_nonblocking returned {got}, expected {fast} (fast) or {exact} (exact)"
        );
    }
}