Skip to main content

matrixcode_core/
tokenizer.rs

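//! Token counting using tiktoken's `cl100k_base` encoding.
//!
//! Provides per-message token counts for context window management. Counts
//! are exact for models that use `cl100k_base` (e.g. GPT-4, GPT-3.5-turbo)
//! and an approximation for model families that use other tokenizers.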
5
6use once_cell::sync::Lazy;
7use std::sync::Arc;
8use tiktoken_rs::CoreBPE;
9
10/// Global BPE encoder for token counting (cl100k_base for GPT-4/Claude).
11static BPE: Lazy<Arc<CoreBPE>> = Lazy::new(|| {
12    Arc::new(tiktoken_rs::cl100k_base().expect("Failed to initialize tokenizer"))
13});
14
15/// Count tokens in a text string.
16pub fn count_tokens(text: &str) -> u32 {
17    BPE.encode_with_special_tokens(text).len() as u32
18}
19
/// Fixed per-message token overhead for role markers and formatting.
///
/// This does not tokenize anything. It returns a constant estimate, presumably
/// meant to be added to the content count of each message:
/// - role prefix (e.g. `"user: "`, `"assistant: "`): ~4 tokens
/// - message separators: ~2 tokens
///
/// Being a `const fn`, it can also be used in constant contexts.
#[must_use]
pub const fn message_overhead() -> u32 {
    const ROLE_PREFIX_TOKENS: u32 = 4;
    const SEPARATOR_TOKENS: u32 = 2;
    ROLE_PREFIX_TOKENS + SEPARATOR_TOKENS
}
29
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_tokens_simple() {
        let count = count_tokens("Hello, world!");
        // "Hello, world!" is typically 4 tokens; allow a little slack.
        assert!(
            (3..=5).contains(&count),
            "expected 3..=5 tokens, got {}",
            count
        );
    }

    #[test]
    fn test_count_tokens_chinese() {
        let count = count_tokens("你好,世界!");
        // CJK characters usually need more than one token per character pair.
        assert!(count >= 5, "expected at least 5 tokens, got {}", count);
    }

    #[test]
    fn test_count_tokens_code() {
        let code = r#"
fn main() {
    println!("Hello");
}
"#;
        let count = count_tokens(code);
        // Punctuation-heavy source code tokenizes densely; actual count is ~13.
        assert!(count >= 10, "Code should use at least 10 tokens, got {}", count);
    }

    #[test]
    fn test_message_overhead() {
        assert_eq!(message_overhead(), 6);
    }

    #[test]
    fn test_token_counting_accuracy() {
        // Known cl100k_base counts.
        assert_eq!(count_tokens("Hello, world!"), 4, "Hello, world!");
        assert_eq!(count_tokens("Hello"), 1, "single word");
        // Digit runs are split into groups of at most three: "123" + "45".
        assert_eq!(count_tokens("12345"), 2, "digit run");

        let chinese_count = count_tokens("你好世界");
        assert!(
            chinese_count >= 4,
            "Chinese text should use at least 4 tokens, got {}",
            chinese_count
        );

        assert_eq!(count_tokens(""), 0, "empty string");
        assert_eq!(count_tokens("   "), 1, "run of spaces");
    }
}