bamboo_compression/
counter.rs1use std::sync::OnceLock;
9
10use bamboo_agent_core::Message;
11use tiktoken_rs::o200k_base;
12use tiktoken_rs::CoreBPE;
13
14static O200K_ENCODER: OnceLock<CoreBPE> = OnceLock::new();
16
17fn o200k_encoder() -> &'static CoreBPE {
18 O200K_ENCODER.get_or_init(|| o200k_base().unwrap())
19}
20
21pub trait TokenCounter: Send + Sync {
23 fn count_message(&self, message: &Message) -> u32;
25
26 fn count_messages(&self, messages: &[Message]) -> u32 {
28 messages.iter().map(|m| self.count_message(m)).sum()
29 }
30
31 fn count_text(&self, text: &str) -> u32;
33}
34
/// Token counter that estimates from the number of characters in the text
/// rather than running a tokenizer.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Divisor applied to the character count to obtain the base token estimate.
    chars_per_token: f64,
    /// Multiplier applied to the base estimate before it is rounded up.
    safety_margin: f64,
    /// Flat number of tokens added to every per-message estimate.
    metadata_overhead: u32,
}
50
51impl HeuristicTokenCounter {
52 pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
54 Self {
55 chars_per_token,
56 safety_margin,
57 metadata_overhead,
58 }
59 }
60
61 pub fn with_defaults() -> Self {
63 Self {
64 chars_per_token: 4.0,
65 safety_margin: 1.1,
66 metadata_overhead: 10,
67 }
68 }
69}
70
71impl Default for HeuristicTokenCounter {
72 fn default() -> Self {
73 Self::with_defaults()
74 }
75}
76
77impl TokenCounter for HeuristicTokenCounter {
78 fn count_message(&self, message: &Message) -> u32 {
79 let content_tokens = self.count_text(&message.content);
80
81 let tool_calls_tokens = message
83 .tool_calls
84 .as_ref()
85 .map(|tc| {
86 tc.iter()
87 .map(|c| {
88 let args_tokens = self.count_text(&c.function.arguments);
90 let id_tokens = self.count_text(&c.id);
91 let name_tokens = self.count_text(&c.function.name);
92 args_tokens
94 .saturating_add(id_tokens)
95 .saturating_add(name_tokens)
96 .saturating_add(5) })
98 .fold(0u32, |acc, x| acc.saturating_add(x))
99 })
100 .unwrap_or(0);
101
102 let tool_call_id_tokens = message
104 .tool_call_id
105 .as_ref()
106 .map(|id| self.count_text(id).saturating_add(3)) .unwrap_or(0);
108
109 content_tokens
111 .saturating_add(tool_calls_tokens)
112 .saturating_add(tool_call_id_tokens)
113 .saturating_add(self.metadata_overhead)
114 }
115
116 fn count_text(&self, text: &str) -> u32 {
117 if text.is_empty() {
118 return 0;
119 }
120
121 let char_count = text.chars().count() as f64;
122 let base_tokens = char_count / self.chars_per_token;
123 let adjusted_tokens = base_tokens * self.safety_margin;
124
125 adjusted_tokens.ceil() as u32
126 }
127}
128
/// Token counter whose text counts come from the `o200k_base` tokenizer.
///
/// Derives `Clone` in addition to `Debug` so it can be duplicated the same
/// way as `HeuristicTokenCounter`.
#[derive(Debug, Clone)]
pub struct TiktokenTokenCounter {
    /// Flat number of tokens added to every per-message count.
    metadata_overhead: u32,
}
138
139impl TiktokenTokenCounter {
140 pub fn new(metadata_overhead: u32) -> Self {
142 Self { metadata_overhead }
143 }
144}
145
146impl Default for TiktokenTokenCounter {
147 fn default() -> Self {
148 Self {
149 metadata_overhead: 10,
150 }
151 }
152}
153
154impl TokenCounter for TiktokenTokenCounter {
155 fn count_message(&self, message: &Message) -> u32 {
156 let content_tokens = self.count_text(&message.content);
157
158 let tool_calls_tokens = message
159 .tool_calls
160 .as_ref()
161 .map(|tc| {
162 tc.iter()
163 .map(|c| {
164 let args_tokens = self.count_text(&c.function.arguments);
165 let id_tokens = self.count_text(&c.id);
166 let name_tokens = self.count_text(&c.function.name);
167 args_tokens
168 .saturating_add(id_tokens)
169 .saturating_add(name_tokens)
170 .saturating_add(5)
171 })
172 .fold(0u32, |acc, x| acc.saturating_add(x))
173 })
174 .unwrap_or(0);
175
176 let tool_call_id_tokens = message
177 .tool_call_id
178 .as_ref()
179 .map(|id| self.count_text(id).saturating_add(3))
180 .unwrap_or(0);
181
182 content_tokens
183 .saturating_add(tool_calls_tokens)
184 .saturating_add(tool_call_id_tokens)
185 .saturating_add(self.metadata_overhead)
186 }
187
188 fn count_text(&self, text: &str) -> u32 {
189 if text.is_empty() {
190 return 0;
191 }
192 let tokens = o200k_encoder().encode_with_special_tokens(text);
193 tokens.len() as u32
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    /// Tool call fixture shared by the tool-call tests of both counters.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        }
    }

    // ---- HeuristicTokenCounter ----

    #[test]
    fn heuristic_counter_counts_text() {
        let tokens = HeuristicTokenCounter::default().count_text("Hello, world!");
        assert!(
            (3..=5).contains(&tokens),
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        assert_eq!(HeuristicTokenCounter::default().count_text(""), 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let counter = HeuristicTokenCounter::default();
        let tokens = counter.count_message(&Message::user("Hello, world!"));
        assert!(
            tokens >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![sample_tool_call()]));

        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let counter = HeuristicTokenCounter::default();
        let tokens = counter.count_message(&Message::tool_result("call_123", "Search results here"));
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];

        // The batch total must equal the per-message counts added by hand.
        let mut sum = 0u32;
        for message in &messages {
            sum += counter.count_message(message);
        }

        assert_eq!(counter.count_messages(&messages), sum);
    }

    #[test]
    fn custom_chars_per_token() {
        // 4 characters at 2 characters per token with no margin -> 2 tokens.
        let counter = HeuristicTokenCounter::new(2.0, 1.0, 0);
        assert_eq!(counter.count_text("test"), 2);
    }

    #[test]
    fn safety_margin_applied() {
        let text = "Hello world!";
        let base = HeuristicTokenCounter::new(4.0, 1.0, 0).count_text(text);
        let adjusted = HeuristicTokenCounter::new(4.0, 1.1, 0).count_text(text);

        assert!(adjusted > base, "Safety margin should increase token count");
    }

    // ---- TiktokenTokenCounter ----

    #[test]
    fn tiktoken_counter_counts_text() {
        let tokens = TiktokenTokenCounter::default().count_text("Hello, world!");
        assert!(
            (3..=6).contains(&tokens),
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        assert_eq!(TiktokenTokenCounter::default().count_text(""), 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        let tokens = TiktokenTokenCounter::default().count_text("你好世界");
        assert!(
            (2..=8).contains(&tokens),
            "Expected 2-8 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_message(&Message::user("Hello, world!"));
        assert!(tokens >= 10, "Expected at least 10 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![sample_tool_call()]));
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        // Only checks that both counters return a non-zero count for the same
        // sentence; no accuracy comparison is asserted.
        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = HeuristicTokenCounter::default().count_text(text);
        let t_tokens = TiktokenTokenCounter::default().count_text(text);

        assert!(h_tokens > 0 && t_tokens > 0);
    }
}