1use std::sync::OnceLock;
12use tiktoken::CoreBpe;
13
14fn bpe() -> &'static CoreBpe {
18 static BPE: OnceLock<&CoreBpe> = OnceLock::new();
19 BPE.get_or_init(|| tiktoken::get_encoding("cl100k_base").expect("failed to load cl100k_base"))
20}
21
22pub fn estimate_tokens(text: &str) -> usize {
26 if text.is_empty() {
27 return 0;
28 }
29 bpe().count(text)
30}
31
32pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
38 if max_tokens == 0 || text.is_empty() {
39 return String::new();
40 }
41 let tokens = bpe().encode_with_special_tokens(text);
42 if tokens.len() <= max_tokens {
43 return text.to_string();
44 }
45 bpe()
46 .decode_to_string(&tokens[..max_tokens])
47 .unwrap_or_else(|_| {
48 let end = (max_tokens * 4).min(text.len());
50 let mut end = end;
51 while end > 0 && !text.is_char_boundary(end) {
52 end -= 1;
53 }
54 let mut result = text[..end].to_string();
55 result.push_str("...");
56 result
57 })
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty input yields a zero count and an empty truncation.
    #[test]
    fn empty_string_returns_zero() {
        let empty = "";
        assert_eq!(estimate_tokens(empty), 0);
        assert_eq!(truncate_to_tokens(empty, 10), "");
    }

    /// A short English sentence lands in a plausible token range.
    #[test]
    fn count_is_reasonable() {
        let sentence = "Hello, how are you today?";
        let count = estimate_tokens(sentence);
        assert!((4..=12).contains(&count), "count={count}");
    }

    /// Truncating to five tokens leaves at most six when re-counted.
    #[test]
    fn truncate_respects_limit() {
        let limit = 5;
        let truncated = truncate_to_tokens("the quick brown fox jumps over the lazy dog", limit);
        let count = estimate_tokens(&truncated);
        assert!(count <= limit + 1, "count={count} should be <= 6");
    }

    /// A zero budget always produces an empty string.
    #[test]
    fn truncate_zero_returns_empty() {
        let truncated = truncate_to_tokens("hello", 0);
        assert_eq!(truncated, "");
    }

    /// Both source code and plain prose produce a non-zero count.
    #[test]
    fn code_and_prose_tokenize() {
        let code_count = estimate_tokens("fn main() { println!(\"hello\"); }");
        let prose_count = estimate_tokens("the main function prints hello to console");
        assert!(code_count > 0);
        assert!(prose_count > 0);
    }

    /// A compact JSON document lands in a plausible token range.
    #[test]
    fn json_tokenizes() {
        let count = estimate_tokens(r#"{"name":"test","value":123,"nested":{"key":"value"}}"#);
        assert!((10..=40).contains(&count), "json count={count}");
    }
}