Skip to main content

bamboo_compression/
counter.rs

//! Token counting for budget management.
//!
//! Provides both heuristic and accurate BPE-based token counting.
//! `TiktokenTokenCounter` uses OpenAI's o200k_base encoding (bundled at compile
//! time) for accurate counts. `HeuristicTokenCounter` remains available as a
//! lightweight fallback.
7
8use std::sync::OnceLock;
9
10use bamboo_agent_core::Message;
11use tiktoken_rs::o200k_base;
12use tiktoken_rs::CoreBPE;
13
14/// Cached BPE encoder — initialized once, reused across all count_text calls.
15static O200K_ENCODER: OnceLock<CoreBPE> = OnceLock::new();
16
17fn o200k_encoder() -> &'static CoreBPE {
18    O200K_ENCODER.get_or_init(|| o200k_base().unwrap())
19}
20
21/// Trait for token counting implementations.
22pub trait TokenCounter: Send + Sync {
23    /// Count tokens in a single message.
24    fn count_message(&self, message: &Message) -> u32;
25
26    /// Count tokens in multiple messages.
27    fn count_messages(&self, messages: &[Message]) -> u32 {
28        messages.iter().map(|m| self.count_message(m)).sum()
29    }
30
31    /// Count tokens in a plain text string.
32    fn count_text(&self, text: &str) -> u32;
33}
34
/// Token counter that estimates from the number of characters.
///
/// Text is costed at roughly `characters / chars_per_token`, scaled by
/// `safety_margin`; every message additionally pays a flat `metadata_overhead`
/// for role, timestamps and similar metadata.
///
/// The defaults (4 characters per token, 10% margin) are deliberately
/// conservative so that usage is not underestimated.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Number of characters assumed to make up one token (default: 4).
    chars_per_token: f64,
    /// Multiplier applied to the raw estimate (default: 1.1, i.e. 10% extra).
    safety_margin: f64,
    /// Flat token cost added to every message for its metadata (default: 10).
    metadata_overhead: u32,
}

impl HeuristicTokenCounter {
    /// Build a counter from explicit parameters.
    ///
    /// The values are stored as given; no validation is performed.
    pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
        Self {
            chars_per_token,
            safety_margin,
            metadata_overhead,
        }
    }

    /// Build a counter with the default parameters: 4 characters per token,
    /// a 1.1 safety margin and 10 tokens of per-message metadata overhead.
    pub fn with_defaults() -> Self {
        Self::new(4.0, 1.1, 10)
    }
}

impl Default for HeuristicTokenCounter {
    /// Same as [`HeuristicTokenCounter::with_defaults`].
    fn default() -> Self {
        Self::with_defaults()
    }
}
76
77impl TokenCounter for HeuristicTokenCounter {
78    fn count_message(&self, message: &Message) -> u32 {
79        let content_tokens = self.count_text(&message.content);
80
81        // Add tokens for tool calls if present
82        let tool_calls_tokens = message
83            .tool_calls
84            .as_ref()
85            .map(|tc| {
86                tc.iter()
87                    .map(|c| {
88                        // Rough estimate: id + name + arguments
89                        let args_tokens = self.count_text(&c.function.arguments);
90                        let id_tokens = self.count_text(&c.id);
91                        let name_tokens = self.count_text(&c.function.name);
92                        // Use saturating_add to prevent overflow
93                        args_tokens
94                            .saturating_add(id_tokens)
95                            .saturating_add(name_tokens)
96                            .saturating_add(5) // type overhead
97                    })
98                    .fold(0u32, |acc, x| acc.saturating_add(x))
99            })
100            .unwrap_or(0);
101
102        // Add tokens for tool_call_id if present
103        let tool_call_id_tokens = message
104            .tool_call_id
105            .as_ref()
106            .map(|id| self.count_text(id).saturating_add(3)) // +3 for field name overhead
107            .unwrap_or(0);
108
109        // Use saturating_add to prevent overflow
110        content_tokens
111            .saturating_add(tool_calls_tokens)
112            .saturating_add(tool_call_id_tokens)
113            .saturating_add(self.metadata_overhead)
114    }
115
116    fn count_text(&self, text: &str) -> u32 {
117        if text.is_empty() {
118            return 0;
119        }
120
121        let char_count = text.chars().count() as f64;
122        let base_tokens = char_count / self.chars_per_token;
123        let adjusted_tokens = base_tokens * self.safety_margin;
124
125        adjusted_tokens.ceil() as u32
126    }
127}
128
/// Accurate BPE-based token counter using OpenAI's o200k_base encoding.
///
/// Uses `tiktoken-rs` with the vocabulary bundled at compile time — no runtime
/// downloads. This is the recommended counter for production use.
///
/// Derives `Clone` like `HeuristicTokenCounter`, so both counters can be
/// duplicated the same way.
#[derive(Debug, Clone)]
pub struct TiktokenTokenCounter {
    /// Per-message metadata overhead in tokens (role markers, formatting, etc.)
    metadata_overhead: u32,
}

impl TiktokenTokenCounter {
    /// Create with a custom metadata overhead.
    ///
    /// `metadata_overhead` is the flat number of tokens added to every message.
    pub fn new(metadata_overhead: u32) -> Self {
        Self { metadata_overhead }
    }
}

impl Default for TiktokenTokenCounter {
    /// Create a counter with a metadata overhead of 10 tokens per message.
    fn default() -> Self {
        Self::new(10)
    }
}
153
154impl TokenCounter for TiktokenTokenCounter {
155    fn count_message(&self, message: &Message) -> u32 {
156        let content_tokens = self.count_text(&message.content);
157
158        let tool_calls_tokens = message
159            .tool_calls
160            .as_ref()
161            .map(|tc| {
162                tc.iter()
163                    .map(|c| {
164                        let args_tokens = self.count_text(&c.function.arguments);
165                        let id_tokens = self.count_text(&c.id);
166                        let name_tokens = self.count_text(&c.function.name);
167                        args_tokens
168                            .saturating_add(id_tokens)
169                            .saturating_add(name_tokens)
170                            .saturating_add(5)
171                    })
172                    .fold(0u32, |acc, x| acc.saturating_add(x))
173            })
174            .unwrap_or(0);
175
176        let tool_call_id_tokens = message
177            .tool_call_id
178            .as_ref()
179            .map(|id| self.count_text(id).saturating_add(3))
180            .unwrap_or(0);
181
182        content_tokens
183            .saturating_add(tool_calls_tokens)
184            .saturating_add(tool_call_id_tokens)
185            .saturating_add(self.metadata_overhead)
186    }
187
188    fn count_text(&self, text: &str) -> u32 {
189        if text.is_empty() {
190            return 0;
191        }
192        let tokens = o200k_encoder().encode_with_special_tokens(text);
193        tokens.len() as u32
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    /// Builds the `search` tool call shared by the tool-call tests.
    fn search_tool_call() -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        }
    }

    // --- HeuristicTokenCounter tests ---

    #[test]
    fn heuristic_counter_counts_text() {
        // 13 chars / 4 * 1.1 ≈ 3.57, which rounds up to 4 tokens.
        let estimate = HeuristicTokenCounter::default().count_text("Hello, world!");
        assert!(
            (3..=5).contains(&estimate),
            "Expected ~4 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        let estimate = HeuristicTokenCounter::default().count_text("");
        assert_eq!(estimate, 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let counter = HeuristicTokenCounter::default();
        let estimate = counter.count_message(&Message::user("Hello, world!"));

        // The default metadata overhead alone contributes 10 tokens.
        assert!(
            estimate >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![search_tool_call()]));

        // Content + tool call (id + name + args) + metadata.
        let estimate = counter.count_message(&message);
        assert!(
            estimate >= 15,
            "Expected at least 15 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::tool_result("call_123", "Search results here");

        // Content + tool_call_id + metadata.
        let estimate = counter.count_message(&message);
        assert!(
            estimate >= 15,
            "Expected at least 15 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let conversation = [
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];

        // The batch count must match counting each message on its own.
        let per_message_total: u32 = conversation
            .iter()
            .map(|message| counter.count_message(message))
            .sum();

        assert_eq!(counter.count_messages(&conversation), per_message_total);
    }

    #[test]
    fn custom_chars_per_token() {
        // At 2 chars per token and no margin, "test" (4 chars) is 2 tokens.
        let counter = HeuristicTokenCounter::new(2.0, 1.0, 0);
        assert_eq!(counter.count_text("test"), 2);
    }

    #[test]
    fn safety_margin_applied() {
        let sample = "Hello world!"; // 12 chars

        let without_margin = HeuristicTokenCounter::new(4.0, 1.0, 0).count_text(sample);
        let with_margin = HeuristicTokenCounter::new(4.0, 1.1, 0).count_text(sample);

        assert!(
            with_margin > without_margin,
            "Safety margin should increase token count"
        );
    }

    // --- TiktokenTokenCounter tests ---

    #[test]
    fn tiktoken_counter_counts_text() {
        // Around 4 tokens are expected; the assertion tolerates 3 to 6.
        let count = TiktokenTokenCounter::default().count_text("Hello, world!");
        assert!(
            (3..=6).contains(&count),
            "Expected ~4 tokens, got {}",
            count
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        let count = TiktokenTokenCounter::default().count_text("");
        assert_eq!(count, 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        // Four CJK characters; anything from 2 to 8 tokens is accepted.
        let count = TiktokenTokenCounter::default().count_text("你好世界");
        assert!(
            (2..=8).contains(&count),
            "Expected 2-8 tokens, got {}",
            count
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let counter = TiktokenTokenCounter::default();
        let count = counter.count_message(&Message::user("Hello, world!"));

        // The default metadata overhead alone contributes 10 tokens.
        assert!(count >= 10, "Expected at least 10 tokens, got {}", count);
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![search_tool_call()]));

        let count = counter.count_message(&message);
        assert!(count >= 15, "Expected at least 15 tokens, got {}", count);
    }

    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        let sentence = "The quick brown fox jumps over the lazy dog.";

        let heuristic_count = HeuristicTokenCounter::default().count_text(sentence);
        let tiktoken_count = TiktokenTokenCounter::default().count_text(sentence);

        // Only checks that both counters report a non-zero count.
        assert!(heuristic_count > 0 && tiktoken_count > 0);
    }
}