// enact_context/token_counter.rs
//! Token Counter
//!
//! Token counting using tiktoken for accurate context management.

use thiserror::Error;
use tiktoken_rs::{cl100k_base, CoreBPE};

8/// Token counter errors
9#[derive(Debug, Error)]
10pub enum TokenCounterError {
11    #[error("Failed to initialize tokenizer: {0}")]
12    InitError(String),
13}
14
15/// Token counter using tiktoken
16///
17/// Uses cl100k_base encoding (GPT-4, GPT-3.5-turbo, text-embedding-ada-002)
18pub struct TokenCounter {
19    bpe: CoreBPE,
20}
21
22impl TokenCounter {
23    /// Create a new token counter with cl100k_base encoding
24    pub fn new() -> Result<Self, TokenCounterError> {
25        let bpe = cl100k_base().map_err(|e| TokenCounterError::InitError(e.to_string()))?;
26        Ok(Self { bpe })
27    }
28
29    /// Count tokens in a string
30    pub fn count(&self, text: &str) -> usize {
31        self.bpe.encode_with_special_tokens(text).len()
32    }
33
34    /// Count tokens for a chat message (includes role overhead)
35    ///
36    /// Approximates the token overhead for chat formatting:
37    /// - Each message has ~4 tokens overhead (role, separators)
38    /// - System messages have additional overhead
39    pub fn count_message(&self, role: &str, content: &str) -> usize {
40        let content_tokens = self.count(content);
41        let role_overhead = match role {
42            "system" => 4,
43            "user" => 4,
44            "assistant" => 4,
45            "function" | "tool" => 5,
46            _ => 4,
47        };
48        content_tokens + role_overhead
49    }
50
51    /// Count tokens for multiple messages
52    pub fn count_messages(&self, messages: &[(String, String)]) -> usize {
53        messages
54            .iter()
55            .map(|(role, content)| self.count_message(role, content))
56            .sum::<usize>()
57            + 3 // priming tokens
58    }
59
60    /// Truncate text to fit within a token limit
61    ///
62    /// Returns the truncated text and the actual token count.
63    pub fn truncate(&self, text: &str, max_tokens: usize) -> (String, usize) {
64        let tokens = self.bpe.encode_with_special_tokens(text);
65        if tokens.len() <= max_tokens {
66            return (text.to_string(), tokens.len());
67        }
68
69        let truncated_tokens = &tokens[..max_tokens];
70        let truncated_text = self
71            .bpe
72            .decode(truncated_tokens.to_vec())
73            .unwrap_or_else(|_| text[..text.len() / 2].to_string());
74
75        (truncated_text, max_tokens)
76    }
77
78    /// Split text into chunks of approximately equal token size
79    pub fn chunk(&self, text: &str, chunk_size: usize) -> Vec<String> {
80        let tokens = self.bpe.encode_with_special_tokens(text);
81        let mut chunks = Vec::new();
82
83        for chunk_tokens in tokens.chunks(chunk_size) {
84            if let Ok(chunk_text) = self.bpe.decode(chunk_tokens.to_vec()) {
85                chunks.push(chunk_text);
86            }
87        }
88
89        chunks
90    }
91
92    /// Estimate if text will fit within budget
93    pub fn will_fit(&self, text: &str, budget: usize) -> bool {
94        self.count(text) <= budget
95    }
96}
97
98impl Default for TokenCounter {
99    fn default() -> Self {
100        Self::new().expect("Failed to initialize default token counter")
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_tokens() {
        let counter = TokenCounter::new().unwrap();

        // Empty string encodes to no tokens.
        assert_eq!(counter.count(""), 0);

        // Simple text: a handful of tokens, not one per character.
        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 10);
    }

    #[test]
    fn test_count_message() {
        let counter = TokenCounter::new().unwrap();

        let content_tokens = counter.count("Hello");
        let message_tokens = counter.count_message("user", "Hello");

        // A message must cost more than its bare content (role overhead).
        assert!(message_tokens > content_tokens);
    }

    #[test]
    fn test_truncate() {
        let counter = TokenCounter::new().unwrap();

        let long_text = "This is a long text that we want to truncate to a smaller size.";
        let (truncated, count) = counter.truncate(long_text, 5);

        assert!(count <= 5);
        assert!(truncated.len() < long_text.len());
    }

    #[test]
    fn test_truncate_multibyte() {
        let counter = TokenCounter::new().unwrap();

        // Non-ASCII input must not panic and must return valid UTF-8.
        let text = "héllo wörld — ünïcode text with multi-byte characters";
        let (truncated, count) = counter.truncate(text, 3);
        assert!(count <= 3);
        assert!(truncated.len() <= text.len());
    }

    #[test]
    fn test_chunk() {
        let counter = TokenCounter::new().unwrap();

        let text = "one two three four five six seven eight nine ten";
        let chunks = counter.chunk(text, 3);

        // Multiple chunks, each within the requested token budget.
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(counter.count(chunk) <= 3);
        }
    }

    #[test]
    fn test_will_fit() {
        let counter = TokenCounter::new().unwrap();

        assert!(counter.will_fit("Hello", 100));
        assert!(!counter.will_fit("Hello ".repeat(1000).as_str(), 10));
    }
}