// bamboo_compression/counter.rs
//! Token counting for budget management.
//!
//! Provides both heuristic and accurate BPE-based token counting.
//! `TiktokenTokenCounter` uses OpenAI's o200k_base encoding (bundled at compile
//! time) for accurate counts. `HeuristicTokenCounter` remains available as a
//! lightweight fallback.

8use bamboo_agent_core::Message;
9use std::sync::Arc;
10use tiktoken_rs::o200k_base;
11use tiktoken_rs::CoreBPE;
12
13/// Cached BPE encoder — initialized once, reused across all count_text calls.
14static O200K_ENCODER: std::sync::LazyLock<CoreBPE> =
15    std::sync::LazyLock::new(|| o200k_base().unwrap());
16
17/// Trait for token counting implementations.
18pub trait TokenCounter: Send + Sync {
19    /// Count tokens in a single message.
20    fn count_message(&self, message: &Message) -> u32;
21
22    /// Count tokens in multiple messages.
23    fn count_messages(&self, messages: &[Message]) -> u32 {
24        messages.iter().map(|m| self.count_message(m)).sum()
25    }
26
27    /// Count tokens in a plain text string.
28    fn count_text(&self, text: &str) -> u32;
29}
30
/// Heuristic token counter using character-based estimation.
///
/// Uses the approximation: tokens ≈ characters / 4, with a 10% safety margin
/// plus additional overhead for message metadata (role, timestamps, etc.).
///
/// This is intentionally conservative to avoid underestimating token usage.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Characters per token ratio (default: 4)
    chars_per_token: f64,
    /// Safety margin multiplier (default: 1.1 = 10% extra)
    safety_margin: f64,
    /// Metadata overhead per message in tokens (default: 10)
    metadata_overhead: u32,
}

47impl HeuristicTokenCounter {
48    /// Create a new heuristic counter with custom parameters.
49    pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
50        Self {
51            chars_per_token,
52            safety_margin,
53            metadata_overhead,
54        }
55    }
56
57    /// Create with default parameters (chars/4 + 10% margin + 10 metadata overhead).
58    pub fn with_defaults() -> Self {
59        Self {
60            chars_per_token: 4.0,
61            safety_margin: 1.1,
62            metadata_overhead: 10,
63        }
64    }
65}
66
67impl Default for HeuristicTokenCounter {
68    fn default() -> Self {
69        Self::with_defaults()
70    }
71}
72
73impl TokenCounter for HeuristicTokenCounter {
74    fn count_message(&self, message: &Message) -> u32 {
75        let content_tokens = self.count_text(&message.content);
76
77        // Add tokens for tool calls if present
78        let tool_calls_tokens = message
79            .tool_calls
80            .as_ref()
81            .map(|tc| {
82                tc.iter()
83                    .map(|c| {
84                        // Rough estimate: id + name + arguments
85                        let args_tokens = self.count_text(&c.function.arguments);
86                        let id_tokens = self.count_text(&c.id);
87                        let name_tokens = self.count_text(&c.function.name);
88                        // Use saturating_add to prevent overflow
89                        args_tokens
90                            .saturating_add(id_tokens)
91                            .saturating_add(name_tokens)
92                            .saturating_add(5) // type overhead
93                    })
94                    .fold(0u32, |acc, x| acc.saturating_add(x))
95            })
96            .unwrap_or(0);
97
98        // Add tokens for tool_call_id if present
99        let tool_call_id_tokens = message
100            .tool_call_id
101            .as_ref()
102            .map(|id| self.count_text(id).saturating_add(3)) // +3 for field name overhead
103            .unwrap_or(0);
104
105        // Use saturating_add to prevent overflow
106        content_tokens
107            .saturating_add(tool_calls_tokens)
108            .saturating_add(tool_call_id_tokens)
109            .saturating_add(self.metadata_overhead)
110    }
111
112    fn count_text(&self, text: &str) -> u32 {
113        if text.is_empty() {
114            return 0;
115        }
116
117        let char_count = text.chars().count() as f64;
118        let base_tokens = char_count / self.chars_per_token;
119        let adjusted_tokens = base_tokens * self.safety_margin;
120
121        adjusted_tokens.ceil() as u32
122    }
123}
124
/// Arc-wrapped token counter for easy sharing across threads and components.
pub type SharedTokenCounter = Arc<dyn TokenCounter>;

/// Accurate BPE-based token counter using OpenAI's o200k_base encoding.
///
/// Uses `tiktoken-rs` with the vocabulary bundled at compile time — no runtime
/// downloads. This is the recommended counter for production use.
#[derive(Debug)]
pub struct TiktokenTokenCounter {
    /// Per-message metadata overhead in tokens (role markers, formatting, etc.)
    metadata_overhead: u32,
}

138impl TiktokenTokenCounter {
139    /// Create with a custom metadata overhead.
140    pub fn new(metadata_overhead: u32) -> Self {
141        Self { metadata_overhead }
142    }
143}
144
145impl Default for TiktokenTokenCounter {
146    fn default() -> Self {
147        Self {
148            metadata_overhead: 10,
149        }
150    }
151}
152
153impl TokenCounter for TiktokenTokenCounter {
154    fn count_message(&self, message: &Message) -> u32 {
155        let content_tokens = self.count_text(&message.content);
156
157        let tool_calls_tokens = message
158            .tool_calls
159            .as_ref()
160            .map(|tc| {
161                tc.iter()
162                    .map(|c| {
163                        let args_tokens = self.count_text(&c.function.arguments);
164                        let id_tokens = self.count_text(&c.id);
165                        let name_tokens = self.count_text(&c.function.name);
166                        args_tokens
167                            .saturating_add(id_tokens)
168                            .saturating_add(name_tokens)
169                            .saturating_add(5)
170                    })
171                    .fold(0u32, |acc, x| acc.saturating_add(x))
172            })
173            .unwrap_or(0);
174
175        let tool_call_id_tokens = message
176            .tool_call_id
177            .as_ref()
178            .map(|id| self.count_text(id).saturating_add(3))
179            .unwrap_or(0);
180
181        content_tokens
182            .saturating_add(tool_calls_tokens)
183            .saturating_add(tool_call_id_tokens)
184            .saturating_add(self.metadata_overhead)
185    }
186
187    fn count_text(&self, text: &str) -> u32 {
188        if text.is_empty() {
189            return 0;
190        }
191        let tokens = O200K_ENCODER.encode_with_special_tokens(text);
192        tokens.len() as u32
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    // --- HeuristicTokenCounter tests ---

    #[test]
    fn heuristic_counter_counts_text() {
        let counter = HeuristicTokenCounter::default();

        // "Hello, world!" = 13 chars -> 13/4 * 1.1 ≈ 3.58 -> ceil -> 4 tokens
        let tokens = counter.count_text("Hello, world!");
        assert!(
            tokens >= 3 && tokens <= 5,
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        let counter = HeuristicTokenCounter::default();
        // Empty text short-circuits to 0, bypassing the metadata overhead.
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::user("Hello, world!");

        let tokens = counter.count_message(&message);
        // Should include content + metadata overhead (10)
        assert!(
            tokens >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let counter = HeuristicTokenCounter::default();

        let tool_call = ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        };

        let message = Message::assistant("Let me search", Some(vec![tool_call]));

        let tokens = counter.count_message(&message);
        // Should include content + tool call (id + name + args) + metadata
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::tool_result("call_123", "Search results here");

        let tokens = counter.count_message(&message);
        // Should include content + tool_call_id (+3 field overhead) + metadata
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];

        // The trait's default count_messages must equal the per-message sum.
        let total = counter.count_messages(&messages);
        let sum: u32 = messages.iter().map(|m| counter.count_message(m)).sum();

        assert_eq!(total, sum);
    }

    #[test]
    fn custom_chars_per_token() {
        let counter = HeuristicTokenCounter::new(2.0, 1.0, 0);
        // With 2 chars per token, "test" (4 chars) = 2 tokens
        let tokens = counter.count_text("test");
        assert_eq!(tokens, 2);
    }

    #[test]
    fn safety_margin_applied() {
        let counter_no_margin = HeuristicTokenCounter::new(4.0, 1.0, 0);
        let counter_with_margin = HeuristicTokenCounter::new(4.0, 1.1, 0);

        let text = "Hello world!"; // 12 chars
        let base = counter_no_margin.count_text(text);
        let adjusted = counter_with_margin.count_text(text);

        // 12/4 = 3 exactly; any margin > 1.0 pushes ceil() to the next integer.
        assert!(adjusted > base, "Safety margin should increase token count");
    }

    // --- TiktokenTokenCounter tests ---

    #[test]
    fn tiktoken_counter_counts_text() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_text("Hello, world!");
        // "Hello, world!" is 4 tokens with o200k_base
        assert!(
            tokens >= 3 && tokens <= 6,
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        let counter = TiktokenTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        let counter = TiktokenTokenCounter::default();
        // CJK text: each character is typically 1-2 tokens
        let tokens = counter.count_text("你好世界");
        assert!(
            tokens >= 2 && tokens <= 8,
            "Expected 2-8 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::user("Hello, world!");
        let tokens = counter.count_message(&message);
        // Should include content + metadata overhead (10)
        assert!(tokens >= 10, "Expected at least 10 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let counter = TiktokenTokenCounter::default();
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        };
        let message = Message::assistant("Let me search", Some(vec![tool_call]));
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        let heuristic = HeuristicTokenCounter::default();
        let tiktoken = TiktokenTokenCounter::default();

        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = heuristic.count_text(text);
        let t_tokens = tiktoken.count_text(text);

        // Smoke test only: both counters should produce non-zero counts.
        // NOTE(review): this does not actually assert tiktoken is *more*
        // accurate, despite the test name — consider tightening.
        assert!(h_tokens > 0 && t_tokens > 0);
    }
}