use std::sync::OnceLock;
use tiktoken::CoreBpe;
/// Hands back a shared reference to the `cl100k_base` encoding, loading it lazily.
///
/// The encoder reference is stored in a function-local [`OnceLock`], so after the
/// first successful initialization every call returns the cached value.
///
/// # Panics
///
/// Panics with `"failed to load cl100k_base"` if `tiktoken::get_encoding` fails.
fn bpe() -> &'static CoreBpe {
    static ENCODER: OnceLock<&CoreBpe> = OnceLock::new();
    let cached = ENCODER.get_or_init(|| {
        let loaded = tiktoken::get_encoding("cl100k_base");
        loaded.expect("failed to load cl100k_base")
    });
    *cached
}
/// Returns the token count of `text` as reported by the shared `cl100k_base` encoder.
///
/// Empty input short-circuits to `0` without initializing or calling the encoder.
pub fn estimate_tokens(text: &str) -> usize {
    match text {
        "" => 0,
        non_empty => bpe().count(non_empty),
    }
}
pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
if max_tokens == 0 || text.is_empty() {
return String::new();
}
let tokens = bpe().encode_with_special_tokens(text);
if tokens.len() <= max_tokens {
return text.to_string();
}
bpe()
.decode_to_string(&tokens[..max_tokens])
.unwrap_or_else(|_| {
let end = (max_tokens * 4).min(text.len());
let mut end = end;
while end > 0 && !text.is_char_boundary(end) {
end -= 1;
}
let mut result = text[..end].to_string();
result.push_str("...");
result
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_string_returns_zero() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(truncate_to_tokens("", 10), "");
    }

    #[test]
    fn count_is_reasonable() {
        let count = estimate_tokens("Hello, how are you today?");
        assert!((4..=12).contains(&count), "count={count}");
    }

    #[test]
    fn truncate_respects_limit() {
        let text = "the quick brown fox jumps over the lazy dog";
        let truncated = truncate_to_tokens(text, 5);
        let count = estimate_tokens(&truncated);
        // One token of slack: presumably re-encoding a decoded prefix can
        // tokenize slightly differently from the original token run.
        assert!(count <= 5 + 1, "count={count} should be <= 6");
    }

    #[test]
    fn truncate_within_limit_is_identity() {
        let text = "short text";
        assert_eq!(truncate_to_tokens(text, 1_000), text);
    }

    #[test]
    fn truncate_zero_returns_empty() {
        assert_eq!(truncate_to_tokens("hello", 0), "");
    }

    #[test]
    fn code_and_prose_tokenize() {
        let code = "fn main() { println!(\"hello\"); }";
        let prose = "the main function prints hello to console";
        assert!(estimate_tokens(code) > 0);
        assert!(estimate_tokens(prose) > 0);
    }

    #[test]
    fn json_tokenizes() {
        let json = r#"{"name":"test","value":123,"nested":{"key":"value"}}"#;
        let count = estimate_tokens(json);
        assert!((10..=40).contains(&count), "json count={count}");
    }
}