Skip to main content

argyph_pack/
tokenize.rs

1use crate::PackError;
2
3/// Estimates token counts using the `cl100k_base` tokenizer (GPT-4 / Claude
4/// compatible). Accuracy is within ~5% across non-OpenAI providers; the budget
5/// is a soft guarantee.
6pub struct TokenCounter {
7    bpe: tiktoken_rs::CoreBPE,
8}
9
10impl TokenCounter {
11    /// Create a new token counter backed by the `cl100k_base` encoding.
12    ///
13    /// # Errors
14    ///
15    /// Returns [`PackError::Io`] if the tokenizer cannot be loaded (should only
16    /// happen if the `tiktoken-rs` data files are missing).
17    pub fn new() -> Result<Self, PackError> {
18        let bpe = tiktoken_rs::cl100k_base().map_err(|e| PackError::Io(e.to_string()))?;
19        Ok(Self { bpe })
20    }
21
22    /// Count the number of tokens in a UTF-8 string.
23    pub fn count(&self, text: &str) -> usize {
24        self.bpe.encode_ordinary(text).len()
25    }
26
27    /// Count the number of tokens in a byte slice, handling non-UTF-8 data
28    /// gracefully via lossy conversion.
29    pub fn count_bytes(&self, text: &[u8]) -> usize {
30        let s = String::from_utf8_lossy(text);
31        self.count(&s)
32    }
33}
34
#[cfg(test)]
// Tests deliberately panic on setup failure; `unwrap` gives the clearest message.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn empty_string_is_zero_tokens() {
        let tc = TokenCounter::new().unwrap();
        assert_eq!(tc.count(""), 0);
    }

    #[test]
    fn simple_english_sentence() {
        let tc = TokenCounter::new().unwrap();
        let n = tc.count("Hello, world!");
        assert!((3..=6).contains(&n), "expected 3-6 tokens, got {n}");
    }

    #[test]
    fn rust_function_body() {
        let tc = TokenCounter::new().unwrap();
        let code = "fn main() {\n    println!(\"Hello\");\n}\n";
        let n = tc.count(code);
        assert!(n > 5, "expected >5 tokens for a small function, got {n}");
        assert!(n < 30, "expected <30 tokens, got {n}");
    }

    #[test]
    fn count_bytes_falls_back_to_lossy() {
        let tc = TokenCounter::new().unwrap();
        // "He\xFFlo": 0xFF can never occur in UTF-8, so this forces the lossy
        // path, where the bad byte becomes a single U+FFFD replacement char.
        let bytes = [0x48, 0x65, 0xff, 0x6c, 0x6f];
        assert!(
            std::str::from_utf8(&bytes).is_err(),
            "fixture must be invalid UTF-8"
        );
        let n = tc.count_bytes(&bytes);
        assert!(n > 0, "expected >0 tokens for lossily-decoded bytes");
        assert_eq!(n, tc.count("He\u{FFFD}lo"));
    }

    #[test]
    fn count_bytes_matches_count_for_valid_utf8() {
        let tc = TokenCounter::new().unwrap();
        let text = "Hello, world!";
        assert_eq!(tc.count_bytes(text.as_bytes()), tc.count(text));
    }

    #[test]
    fn longer_text_is_more_tokens() {
        let tc = TokenCounter::new().unwrap();
        let short = tc.count("fn");
        let long = tc.count("fn foo(x: i32) -> i32 { x + 1 }");
        assert!(long > short, "longer text should have more tokens");
    }

    #[test]
    fn token_count_is_repeatable() {
        let tc = TokenCounter::new().unwrap();
        let text = "fn factorial(n: u64) -> u64 { if n <= 1 { 1 } else { n * factorial(n - 1) } }";
        let n1 = tc.count(text);
        let n2 = tc.count(text);
        assert_eq!(n1, n2);
    }
}