1use std::sync::OnceLock;
6use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
7
8use crate::models::Encoding;
9
// Process-wide, lazily-initialized tokenizer tables. `OnceLock` ensures each
// BPE vocabulary is loaded at most once, on first use, and shared thereafter.
static CL100K: OnceLock<CoreBPE> = OnceLock::new();
static O200K: OnceLock<CoreBPE> = OnceLock::new();
13
14fn get_cl100k() -> &'static CoreBPE {
16 CL100K.get_or_init(|| cl100k_base().expect("Failed to load cl100k_base tokenizer"))
17}
18
19fn get_o200k() -> &'static CoreBPE {
21 O200K.get_or_init(|| o200k_base().expect("Failed to load o200k_base tokenizer"))
22}
23
24pub fn count_tokens(text: &str) -> usize {
37 count_tokens_with_encoding(text, Encoding::Cl100kBase)
38}
39
40pub fn count_tokens_with_encoding(text: &str, encoding: Encoding) -> usize {
57 match encoding {
58 Encoding::Cl100kBase => get_cl100k().encode_with_special_tokens(text).len(),
59 Encoding::O200kBase => get_o200k().encode_with_special_tokens(text).len(),
60 Encoding::LlamaBpe => {
61 get_cl100k().encode_with_special_tokens(text).len()
66 },
67 Encoding::Heuristic => {
68 heuristic_count(text)
71 },
72 }
73}
74
75pub fn count_tokens_for_model(text: &str, model: &str) -> usize {
86 let encoding = Encoding::infer_from_id(model);
87 count_tokens_with_encoding(text, encoding)
88}
89
/// Cheap token estimate: roughly one token per 4 bytes, rounded up.
/// Empty input yields 0; any non-empty input yields at least 1.
fn heuristic_count(text: &str) -> usize {
    let byte_len = text.len();
    byte_len.div_ceil(4)
}
98
/// Reusable token counter bound to one [`Encoding`].
///
/// Handy when the same encoding is used for many strings, avoiding a
/// per-call encoding choice.
pub struct TokenCounter {
    // Encoding applied to every count performed by this counter.
    encoding: Encoding,
}
118
119impl TokenCounter {
120 pub fn new(encoding: Encoding) -> Self {
122 Self { encoding }
123 }
124
125 pub fn default_encoding() -> Self {
127 Self::new(Encoding::Cl100kBase)
128 }
129
130 pub fn for_model(model: &str) -> Self {
132 Self::new(Encoding::infer_from_id(model))
133 }
134
135 pub fn count(&self, text: &str) -> usize {
137 count_tokens_with_encoding(text, self.encoding)
138 }
139
140 pub fn count_many(&self, texts: &[&str]) -> usize {
142 texts.iter().map(|t| self.count(t)).sum()
143 }
144
145 pub fn count_json(&self, value: &serde_json::Value) -> usize {
147 let text = serde_json::to_string(value).unwrap_or_default();
148 self.count(&text)
149 }
150
151 pub fn encoding(&self) -> Encoding {
153 self.encoding
154 }
155}
156
157impl Default for TokenCounter {
158 fn default() -> Self {
159 Self::default_encoding()
160 }
161}
162
163pub fn estimate_savings(
167 original: &str,
168 compressed: &str,
169 encoding: Encoding,
170) -> (usize, usize, i64, f64) {
171 let original_tokens = count_tokens_with_encoding(original, encoding);
172 let compressed_tokens = count_tokens_with_encoding(compressed, encoding);
173 let savings = original_tokens as i64 - compressed_tokens as i64;
174 let savings_percent = if original_tokens > 0 {
175 (savings as f64 / original_tokens as f64) * 100.0
176 } else {
177 0.0
178 };
179
180 (original_tokens, compressed_tokens, savings, savings_percent)
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_tokens_basic() {
        // A short phrase tokenizes to a handful of BPE tokens; sanity-bound it.
        let tokens = count_tokens("Hello, world!");
        assert!(tokens > 0);
        assert!(tokens < 10);
    }

    #[test]
    fn test_count_tokens_empty() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn test_different_encodings() {
        let text = "Hello, world! This is a test.";

        let cl100k = count_tokens_with_encoding(text, Encoding::Cl100kBase);
        let o200k = count_tokens_with_encoding(text, Encoding::O200kBase);
        let heuristic = count_tokens_with_encoding(text, Encoding::Heuristic);

        assert!(cl100k > 0);
        assert!(o200k > 0);
        assert!(heuristic > 0);

        // LlamaBpe currently falls back to cl100k_base.
        let llama = count_tokens_with_encoding(text, Encoding::LlamaBpe);
        assert_eq!(llama, cl100k);

        // Mirror the implementation's formula: ceil(byte_len / 4).
        let expected_heuristic = text.len().div_ceil(4);
        assert_eq!(heuristic, expected_heuristic);
    }

    #[test]
    fn test_count_tokens_for_model() {
        let text = "Hello!";

        let o200k_tokens = count_tokens_for_model(text, "openai/gpt-4o");

        let cl100k_tokens = count_tokens_for_model(text, "openai/gpt-4");

        assert!(o200k_tokens > 0);
        assert!(cl100k_tokens > 0);
    }

    #[test]
    fn test_token_counter_struct() {
        let counter = TokenCounter::new(Encoding::Cl100kBase);

        let tokens = counter.count("Hello, world!");
        assert!(tokens > 0);

        let total = counter.count_many(&["Hello", "World"]);
        assert!(total > 0);
    }

    #[test]
    fn test_token_counter_json() {
        let counter = TokenCounter::default();

        let json = serde_json::json!({
            "message": "Hello, world!",
            "count": 42
        });

        let tokens = counter.count_json(&json);
        assert!(tokens > 0);
    }

    #[test]
    fn test_estimate_savings() {
        let original = r#"{"messages":[{"role":"assistant","content":"Hello there! How can I help you today?"}],"temperature":1.0}"#;
        let compressed = r#"{"m":[{"r":"A","c":"Hello there! How can I help you today?"}]}"#;

        let (orig, comp, savings, percent) =
            estimate_savings(original, compressed, Encoding::Cl100kBase);

        // The signed-difference identity must hold regardless of which side
        // is larger (the previous else-only assert was tautological).
        assert_eq!(savings, orig as i64 - comp as i64);

        if orig > comp {
            assert!(savings > 0, "Should have positive savings");
            assert!(percent > 0.0, "Should have positive percentage");
        }
    }

    #[test]
    fn test_heuristic_never_zero() {
        assert!(heuristic_count("a") >= 1);
        assert!(heuristic_count("ab") >= 1);
        assert!(heuristic_count("abc") >= 1);
        assert!(heuristic_count("abcd") >= 1);
    }

    #[test]
    fn test_encoding_consistency() {
        let text = "The quick brown fox jumps over the lazy dog.";

        let count1 = count_tokens(text);
        let count2 = count_tokens(text);
        let count3 = count_tokens_with_encoding(text, Encoding::Cl100kBase);

        assert_eq!(count1, count2);
        assert_eq!(count1, count3);
    }

    #[test]
    fn test_json_message_tokens() {
        let message = r#"{"model":"openai/gpt-4o","messages":[{"role":"user","content":"Hello"}],"temperature":1.0}"#;

        let tokens = count_tokens(message);

        assert!(tokens > 10);
        assert!(tokens < 50);
    }
}