1use crate::PackError;
2
/// Counts tokens in text using a byte-pair encoder from `tiktoken_rs`.
pub struct TokenCounter {
    /// The underlying `tiktoken_rs` encoder used to tokenize input.
    bpe: tiktoken_rs::CoreBPE,
}
9
10impl TokenCounter {
11 pub fn new() -> Result<Self, PackError> {
18 let bpe = tiktoken_rs::cl100k_base().map_err(|e| PackError::Io(e.to_string()))?;
19 Ok(Self { bpe })
20 }
21
22 pub fn count(&self, text: &str) -> usize {
24 self.bpe.encode_ordinary(text).len()
25 }
26
27 pub fn count_bytes(&self, text: &[u8]) -> usize {
30 let s = String::from_utf8_lossy(text);
31 self.count(&s)
32 }
33}
34
#[cfg(test)]
// Unwrapping is acceptable in tests: a failed `TokenCounter::new()` should
// abort the test with a panic.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// An empty input must encode to no tokens at all.
    #[test]
    fn empty_string_is_zero_tokens() {
        let tc = TokenCounter::new().unwrap();
        assert_eq!(tc.count(""), 0);
    }

    /// A short English sentence should land in a small, plausible token range.
    #[test]
    fn simple_english_sentence() {
        let tc = TokenCounter::new().unwrap();
        let n = tc.count("Hello, world!");
        assert!((3..=6).contains(&n), "expected 3-6 tokens, got {n}");
    }

    /// A tiny Rust function should produce a modest, bounded token count.
    #[test]
    fn rust_function_body() {
        let tc = TokenCounter::new().unwrap();
        let code = "fn main() {\n    println!(\"Hello\");\n}\n";
        let n = tc.count(code);
        assert!(n > 5, "expected >5 tokens for a small function, got {n}");
        assert!(n < 30, "expected <30 tokens, got {n}");
    }

    /// `count_bytes` must agree with `count` on valid UTF-8 and must still
    /// produce a count (via lossy decoding) when the bytes are not valid UTF-8.
    #[test]
    fn count_bytes_falls_back_to_lossy() {
        let tc = TokenCounter::new().unwrap();

        // Valid UTF-8 ("Hello") is counted exactly like the equivalent `&str`.
        let n = tc.count_bytes(&[0x48, 0x65, 0x6c, 0x6c, 0x6f]);
        assert!(n > 0, "expected >0 tokens for 'Hello' bytes");
        assert_eq!(n, tc.count("Hello"), "valid bytes should match the &str count");

        // 0xFF never occurs in well-formed UTF-8, so this input forces the
        // lossy path, which substitutes U+FFFD for the bad byte.
        let lossy = tc.count_bytes(&[0x48, 0xFF, 0x69]);
        assert!(lossy > 0, "expected >0 tokens for invalid UTF-8 bytes");
        assert_eq!(
            lossy,
            tc.count("H\u{FFFD}i"),
            "invalid bytes should be counted as their lossy decoding"
        );
    }

    /// Adding more text must increase the token count.
    #[test]
    fn longer_text_is_more_tokens() {
        let tc = TokenCounter::new().unwrap();
        let short = tc.count("fn");
        let long = tc.count("fn foo(x: i32) -> i32 { x + 1 }");
        assert!(long > short, "longer text should have more tokens");
    }

    /// Counting the same text twice must give the same answer.
    #[test]
    fn token_count_is_repeatable() {
        let tc = TokenCounter::new().unwrap();
        let text = "fn factorial(n: u64) -> u64 { if n <= 1 { 1 } else { n * factorial(n - 1) } }";
        let n1 = tc.count(text);
        let n2 = tc.count(text);
        assert_eq!(n1, n2);
    }
}