// enact_context/token_counter.rs
use thiserror::Error;
use tiktoken_rs::{cl100k_base, CoreBPE};

8#[derive(Debug, Error)]
10pub enum TokenCounterError {
11 #[error("Failed to initialize tokenizer: {0}")]
12 InitError(String),
13}
14
15pub struct TokenCounter {
19 bpe: CoreBPE,
20}
21
22impl TokenCounter {
23 pub fn new() -> Result<Self, TokenCounterError> {
25 let bpe = cl100k_base().map_err(|e| TokenCounterError::InitError(e.to_string()))?;
26 Ok(Self { bpe })
27 }
28
29 pub fn count(&self, text: &str) -> usize {
31 self.bpe.encode_with_special_tokens(text).len()
32 }
33
34 pub fn count_message(&self, role: &str, content: &str) -> usize {
40 let content_tokens = self.count(content);
41 let role_overhead = match role {
42 "system" => 4,
43 "user" => 4,
44 "assistant" => 4,
45 "function" | "tool" => 5,
46 _ => 4,
47 };
48 content_tokens + role_overhead
49 }
50
51 pub fn count_messages(&self, messages: &[(String, String)]) -> usize {
53 messages
54 .iter()
55 .map(|(role, content)| self.count_message(role, content))
56 .sum::<usize>()
57 + 3 }
59
60 pub fn truncate(&self, text: &str, max_tokens: usize) -> (String, usize) {
64 let tokens = self.bpe.encode_with_special_tokens(text);
65 if tokens.len() <= max_tokens {
66 return (text.to_string(), tokens.len());
67 }
68
69 let truncated_tokens = &tokens[..max_tokens];
70 let truncated_text = self
71 .bpe
72 .decode(truncated_tokens.to_vec())
73 .unwrap_or_else(|_| text[..text.len() / 2].to_string());
74
75 (truncated_text, max_tokens)
76 }
77
78 pub fn chunk(&self, text: &str, chunk_size: usize) -> Vec<String> {
80 let tokens = self.bpe.encode_with_special_tokens(text);
81 let mut chunks = Vec::new();
82
83 for chunk_tokens in tokens.chunks(chunk_size) {
84 if let Ok(chunk_text) = self.bpe.decode(chunk_tokens.to_vec()) {
85 chunks.push(chunk_text);
86 }
87 }
88
89 chunks
90 }
91
92 pub fn will_fit(&self, text: &str, budget: usize) -> bool {
94 self.count(text) <= budget
95 }
96}
97
98impl Default for TokenCounter {
99 fn default() -> Self {
100 Self::new().expect("Failed to initialize default token counter")
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_tokens() {
        let tc = TokenCounter::new().unwrap();

        // Empty input encodes to zero tokens.
        assert_eq!(tc.count(""), 0);

        // A short phrase has a small, nonzero token count.
        let n = tc.count("Hello, world!");
        assert!(n > 0);
        assert!(n < 10);
    }

    #[test]
    fn test_count_message() {
        let tc = TokenCounter::new().unwrap();

        // The per-message role overhead makes a full message cost more
        // than its bare content.
        let content_only = tc.count("Hello");
        let with_role = tc.count_message("user", "Hello");

        assert!(with_role > content_only);
    }

    #[test]
    fn test_truncate() {
        let tc = TokenCounter::new().unwrap();

        let original = "This is a long text that we want to truncate to a smaller size.";
        let (shorter, n) = tc.truncate(original, 5);

        assert!(n <= 5);
        assert!(shorter.len() < original.len());
    }

    #[test]
    fn test_will_fit() {
        let tc = TokenCounter::new().unwrap();

        assert!(tc.will_fit("Hello", 100));
        assert!(!tc.will_fit("Hello ".repeat(1000).as_str(), 10));
    }
}