bamboo_compression/
counter.rs1use bamboo_agent_core::Message;
9use std::sync::Arc;
10use tiktoken_rs::o200k_base;
11use tiktoken_rs::CoreBPE;
12
13static O200K_ENCODER: std::sync::LazyLock<CoreBPE> =
15 std::sync::LazyLock::new(|| o200k_base().unwrap());
16
17pub trait TokenCounter: Send + Sync {
19 fn count_message(&self, message: &Message) -> u32;
21
22 fn count_messages(&self, messages: &[Message]) -> u32 {
24 messages.iter().map(|m| self.count_message(m)).sum()
25 }
26
27 fn count_text(&self, text: &str) -> u32;
29}
30
/// Token counter that estimates counts from text length rather than running a tokenizer.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Number of characters assumed to make up one token.
    chars_per_token: f64,
    /// Multiplier applied to the raw character-based estimate.
    safety_margin: f64,
    /// Flat token amount added to every counted message.
    metadata_overhead: u32,
}
46
47impl HeuristicTokenCounter {
48 pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
50 Self {
51 chars_per_token,
52 safety_margin,
53 metadata_overhead,
54 }
55 }
56
57 pub fn with_defaults() -> Self {
59 Self {
60 chars_per_token: 4.0,
61 safety_margin: 1.1,
62 metadata_overhead: 10,
63 }
64 }
65}
66
67impl Default for HeuristicTokenCounter {
68 fn default() -> Self {
69 Self::with_defaults()
70 }
71}
72
73impl TokenCounter for HeuristicTokenCounter {
74 fn count_message(&self, message: &Message) -> u32 {
75 let content_tokens = self.count_text(&message.content);
76
77 let tool_calls_tokens = message
79 .tool_calls
80 .as_ref()
81 .map(|tc| {
82 tc.iter()
83 .map(|c| {
84 let args_tokens = self.count_text(&c.function.arguments);
86 let id_tokens = self.count_text(&c.id);
87 let name_tokens = self.count_text(&c.function.name);
88 args_tokens
90 .saturating_add(id_tokens)
91 .saturating_add(name_tokens)
92 .saturating_add(5) })
94 .fold(0u32, |acc, x| acc.saturating_add(x))
95 })
96 .unwrap_or(0);
97
98 let tool_call_id_tokens = message
100 .tool_call_id
101 .as_ref()
102 .map(|id| self.count_text(id).saturating_add(3)) .unwrap_or(0);
104
105 content_tokens
107 .saturating_add(tool_calls_tokens)
108 .saturating_add(tool_call_id_tokens)
109 .saturating_add(self.metadata_overhead)
110 }
111
112 fn count_text(&self, text: &str) -> u32 {
113 if text.is_empty() {
114 return 0;
115 }
116
117 let char_count = text.chars().count() as f64;
118 let base_tokens = char_count / self.chars_per_token;
119 let adjusted_tokens = base_tokens * self.safety_margin;
120
121 adjusted_tokens.ceil() as u32
122 }
123}
124
125pub type SharedTokenCounter = Arc<dyn TokenCounter>;
127
/// Token counter that measures text with the shared `o200k_base` encoder.
#[derive(Debug)]
pub struct TiktokenTokenCounter {
    /// Flat token amount added to every counted message.
    metadata_overhead: u32,
}
137
138impl TiktokenTokenCounter {
139 pub fn new(metadata_overhead: u32) -> Self {
141 Self { metadata_overhead }
142 }
143}
144
145impl Default for TiktokenTokenCounter {
146 fn default() -> Self {
147 Self {
148 metadata_overhead: 10,
149 }
150 }
151}
152
153impl TokenCounter for TiktokenTokenCounter {
154 fn count_message(&self, message: &Message) -> u32 {
155 let content_tokens = self.count_text(&message.content);
156
157 let tool_calls_tokens = message
158 .tool_calls
159 .as_ref()
160 .map(|tc| {
161 tc.iter()
162 .map(|c| {
163 let args_tokens = self.count_text(&c.function.arguments);
164 let id_tokens = self.count_text(&c.id);
165 let name_tokens = self.count_text(&c.function.name);
166 args_tokens
167 .saturating_add(id_tokens)
168 .saturating_add(name_tokens)
169 .saturating_add(5)
170 })
171 .fold(0u32, |acc, x| acc.saturating_add(x))
172 })
173 .unwrap_or(0);
174
175 let tool_call_id_tokens = message
176 .tool_call_id
177 .as_ref()
178 .map(|id| self.count_text(id).saturating_add(3))
179 .unwrap_or(0);
180
181 content_tokens
182 .saturating_add(tool_calls_tokens)
183 .saturating_add(tool_call_id_tokens)
184 .saturating_add(self.metadata_overhead)
185 }
186
187 fn count_text(&self, text: &str) -> u32 {
188 if text.is_empty() {
189 return 0;
190 }
191 let tokens = O200K_ENCODER.encode_with_special_tokens(text);
192 tokens.len() as u32
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    /// Builds the `search` tool call shared by the tool-call tests of both counters.
    fn search_tool_call() -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        }
    }

    /// A short English sentence lands within the 3..=5 band under the default heuristic.
    #[test]
    fn heuristic_counter_counts_text() {
        let estimate = HeuristicTokenCounter::default().count_text("Hello, world!");

        assert!(
            (3..=5).contains(&estimate),
            "Expected ~4 tokens, got {}",
            estimate
        );
    }

    /// Empty text is always zero tokens.
    #[test]
    fn heuristic_counter_counts_empty_text() {
        assert_eq!(HeuristicTokenCounter::default().count_text(""), 0);
    }

    /// A user message costs at least 10 tokens once overhead is included.
    #[test]
    fn heuristic_counter_counts_user_message() {
        let user_message = Message::user("Hello, world!");
        let estimate = HeuristicTokenCounter::default().count_message(&user_message);

        assert!(
            estimate >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            estimate
        );
    }

    /// Tool calls attached to an assistant message contribute to the total.
    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let assistant_message =
            Message::assistant("Let me search", Some(vec![search_tool_call()]));
        let estimate = HeuristicTokenCounter::default().count_message(&assistant_message);

        assert!(
            estimate >= 15,
            "Expected at least 15 tokens, got {}",
            estimate
        );
    }

    /// A tool result is charged for its content and its tool-call id.
    #[test]
    fn heuristic_counter_counts_tool_result() {
        let result_message = Message::tool_result("call_123", "Search results here");
        let estimate = HeuristicTokenCounter::default().count_message(&result_message);

        assert!(
            estimate >= 15,
            "Expected at least 15 tokens, got {}",
            estimate
        );
    }

    /// `count_messages` agrees with adding up `count_message` per entry.
    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let conversation = [
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];

        let mut expected: u32 = 0;
        for entry in conversation.iter() {
            expected += counter.count_message(entry);
        }

        assert_eq!(counter.count_messages(&conversation), expected);
    }

    /// With 2 chars per token and no margin, four characters are exactly 2 tokens.
    #[test]
    fn custom_chars_per_token() {
        let estimate = HeuristicTokenCounter::new(2.0, 1.0, 0).count_text("test");
        assert_eq!(estimate, 2);
    }

    /// A margin above 1.0 raises the estimate for the same text.
    #[test]
    fn safety_margin_applied() {
        let text = "Hello world!";
        let base = HeuristicTokenCounter::new(4.0, 1.0, 0).count_text(text);
        let adjusted = HeuristicTokenCounter::new(4.0, 1.1, 0).count_text(text);

        assert!(adjusted > base, "Safety margin should increase token count");
    }

    /// The real encoder puts a short English sentence within the 3..=6 band.
    #[test]
    fn tiktoken_counter_counts_text() {
        let measured = TiktokenTokenCounter::default().count_text("Hello, world!");

        assert!(
            (3..=6).contains(&measured),
            "Expected ~4 tokens, got {}",
            measured
        );
    }

    /// Empty text is always zero tokens.
    #[test]
    fn tiktoken_counter_counts_empty_text() {
        assert_eq!(TiktokenTokenCounter::default().count_text(""), 0);
    }

    /// Four CJK characters encode to somewhere between 2 and 8 tokens.
    #[test]
    fn tiktoken_counter_counts_cjk() {
        let measured = TiktokenTokenCounter::default().count_text("你好世界");

        assert!(
            (2..=8).contains(&measured),
            "Expected 2-8 tokens, got {}",
            measured
        );
    }

    /// A user message costs at least 10 tokens once overhead is included.
    #[test]
    fn tiktoken_counter_counts_user_message() {
        let user_message = Message::user("Hello, world!");
        let measured = TiktokenTokenCounter::default().count_message(&user_message);

        assert!(
            measured >= 10,
            "Expected at least 10 tokens, got {}",
            measured
        );
    }

    /// Tool calls attached to an assistant message contribute to the total.
    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let assistant_message =
            Message::assistant("Let me search", Some(vec![search_tool_call()]));
        let measured = TiktokenTokenCounter::default().count_message(&assistant_message);

        assert!(
            measured >= 15,
            "Expected at least 15 tokens, got {}",
            measured
        );
    }

    /// Both counters report a non-zero count for the same sentence.
    ///
    /// Despite the name, this only checks that each result is positive; it does
    /// not compare the two counters' accuracy.
    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = HeuristicTokenCounter::default().count_text(text);
        let t_tokens = TiktokenTokenCounter::default().count_text(text);

        assert!(h_tokens > 0 && t_tokens > 0);
    }
}