Skip to main content

bamboo_compression/
counter.rs

1//! Token counting for budget management.
2//!
3//! Provides both heuristic and accurate BPE-based token counting.
4//! `TiktokenTokenCounter` uses OpenAI's o200k_base encoding (bundled at compile
5//! time) for accurate counts. `HeuristicTokenCounter` remains available as a
6//! lightweight fallback and is also used automatically if the bundled BPE
7//! vocabulary ever fails to load, so a failed load degrades gracefully rather
8//! than panicking.
9
10use std::sync::OnceLock;
11
12use bamboo_domain::Message;
13use tiktoken_rs::o200k_base;
14use tiktoken_rs::CoreBPE;
15
16/// Cached BPE encoder — initialized once, reused across all count_text calls.
17///
18/// Holds `None` if the bundled o200k_base vocabulary failed to load (e.g. a
19/// corrupt or unlinkable build). When `None`, `TiktokenTokenCounter` falls back
20/// to `HeuristicTokenCounter` instead of panicking. The failure is logged once
21/// at initialization time so the degradation is observable.
22static O200K_ENCODER: OnceLock<Option<CoreBPE>> = OnceLock::new();
23
24/// Returns the cached o200k_base encoder, or `None` if it failed to load.
25///
26/// The first call loads (or attempts to load) the bundled vocabulary exactly
27/// once; a load failure is logged a single time and cached as `None`, so the
28/// hot path never panics and never re-attempts the failing load.
29fn o200k_encoder() -> Option<&'static CoreBPE> {
30    O200K_ENCODER
31        .get_or_init(|| match o200k_base() {
32            Ok(encoder) => Some(encoder),
33            Err(err) => {
34                tracing::warn!(
35                    error = %err,
36                    "failed to load bundled o200k_base tokenizer; \
37                     falling back to heuristic token counting"
38                );
39                None
40            }
41        })
42        .as_ref()
43}
44
45/// Trait for token counting implementations.
46pub trait TokenCounter: Send + Sync {
47    /// Count tokens in a single message.
48    fn count_message(&self, message: &Message) -> u32;
49
50    /// Count tokens in multiple messages.
51    fn count_messages(&self, messages: &[Message]) -> u32 {
52        messages.iter().map(|m| self.count_message(m)).sum()
53    }
54
55    /// Count tokens in a plain text string.
56    fn count_text(&self, text: &str) -> u32;
57}
58
/// Heuristic token counter based on character counts.
///
/// A piece of text is estimated at `ceil(chars / chars_per_token * safety_margin)`
/// tokens. Each message additionally pays a fixed `metadata_overhead`, which
/// accounts for role, timestamps and similar framing.
///
/// The defaults (4 chars per token, +10% margin) are deliberately conservative
/// so that token usage is not underestimated.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Average number of characters per token (default: 4).
    chars_per_token: f64,
    /// Multiplier applied to the raw estimate (default: 1.1, i.e. +10%).
    safety_margin: f64,
    /// Tokens charged to every message for its metadata (default: 10).
    metadata_overhead: u32,
}

impl HeuristicTokenCounter {
    /// Build a counter from explicit tuning parameters.
    pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
        HeuristicTokenCounter {
            metadata_overhead,
            safety_margin,
            chars_per_token,
        }
    }

    /// Build a counter with the stock tuning: 4 chars/token, a +10% safety
    /// margin, and 10 tokens of per-message metadata overhead.
    pub fn with_defaults() -> Self {
        Self::new(4.0, 1.1, 10)
    }
}

impl Default for HeuristicTokenCounter {
    /// Same as [`HeuristicTokenCounter::with_defaults`].
    fn default() -> Self {
        Self::with_defaults()
    }
}
100
101impl TokenCounter for HeuristicTokenCounter {
102    fn count_message(&self, message: &Message) -> u32 {
103        let content_tokens = self.count_text(&message.content);
104
105        // Add tokens for tool calls if present
106        let tool_calls_tokens = message
107            .tool_calls
108            .as_ref()
109            .map(|tc| {
110                tc.iter()
111                    .map(|c| {
112                        // Rough estimate: id + name + arguments
113                        let args_tokens = self.count_text(&c.function.arguments);
114                        let id_tokens = self.count_text(&c.id);
115                        let name_tokens = self.count_text(&c.function.name);
116                        // Use saturating_add to prevent overflow
117                        args_tokens
118                            .saturating_add(id_tokens)
119                            .saturating_add(name_tokens)
120                            .saturating_add(5) // type overhead
121                    })
122                    .fold(0u32, |acc, x| acc.saturating_add(x))
123            })
124            .unwrap_or(0);
125
126        // Add tokens for tool_call_id if present
127        let tool_call_id_tokens = message
128            .tool_call_id
129            .as_ref()
130            .map(|id| self.count_text(id).saturating_add(3)) // +3 for field name overhead
131            .unwrap_or(0);
132
133        // Use saturating_add to prevent overflow
134        content_tokens
135            .saturating_add(tool_calls_tokens)
136            .saturating_add(tool_call_id_tokens)
137            .saturating_add(self.metadata_overhead)
138    }
139
140    fn count_text(&self, text: &str) -> u32 {
141        if text.is_empty() {
142            return 0;
143        }
144
145        let char_count = text.chars().count() as f64;
146        let base_tokens = char_count / self.chars_per_token;
147        let adjusted_tokens = base_tokens * self.safety_margin;
148
149        adjusted_tokens.ceil() as u32
150    }
151}
152
153/// Accurate BPE-based token counter using OpenAI's o200k_base encoding.
154///
155/// Uses `tiktoken-rs` with the vocabulary bundled at compile time — no runtime
156/// downloads. This is the recommended counter for production use.
157#[derive(Debug)]
158pub struct TiktokenTokenCounter {
159    /// Per-message metadata overhead in tokens (role markers, formatting, etc.)
160    metadata_overhead: u32,
161}
162
163impl TiktokenTokenCounter {
164    /// Create with a custom metadata overhead.
165    pub fn new(metadata_overhead: u32) -> Self {
166        Self { metadata_overhead }
167    }
168
169    /// Truncate `text` to at most `max_tokens` tokens, keeping the START.
170    ///
171    /// Encodes the text **once** and decodes the first `max_tokens` tokens back
172    /// to a string — O(N) (one encode + one decode), versus the O(N²)
173    /// char-by-char re-tokenization the previous `find_prefix_within_tokens`
174    /// performed (which called `count_text(&text[..i])` on every char index).
175    ///
176    /// # Semantics
177    /// - `max_tokens == 0` → empty string (exactly 0 tokens; never exceeds budget).
178    /// - Text already within `max_tokens` → returned unchanged (fast path).
179    /// - Otherwise the result is an exact prefix of `text` (its START preserved),
180    ///   is valid UTF-8, and re-counts to ≤ `max_tokens`.
181    ///
182    /// If the o200k encoder is unavailable (the issue #25 fallback path), this
183    /// degrades to a conservative char-based cut instead of panicking.
184    pub fn truncate_to_token_prefix(&self, text: &str, max_tokens: u32) -> String {
185        if max_tokens == 0 {
186            return String::new();
187        }
188        let Some(encoder) = o200k_encoder() else {
189            return heuristic_char_prefix(text, max_tokens);
190        };
191        // One encode — same encoder `count_text` uses, so the fast-path length
192        // check is consistent with `count_text(text)`.
193        let tokens = encoder.encode_with_special_tokens(text);
194        if (tokens.len() as u32) <= max_tokens {
195            return text.to_string();
196        }
197        let end = max_tokens as usize;
198        match encoder.decode_bytes(&tokens[..end]) {
199            // `decode_bytes` yields exactly the bytes of `text` spanned by the
200            // first `end` tokens — a byte-prefix of `text`. A token boundary can
201            // fall inside a multi-byte UTF-8 char, so trim any partial trailing
202            // char: the result stays a valid-UTF-8 exact prefix of `text`.
203            Ok(bytes) => valid_utf8_prefix(bytes),
204            Err(_) => heuristic_char_prefix(text, max_tokens),
205        }
206    }
207
208    /// Truncate `text` to at most `max_tokens` tokens, keeping the END.
209    ///
210    /// Symmetric to [`truncate_to_token_prefix`](Self::truncate_to_token_prefix):
211    /// encodes once and decodes the **last** `max_tokens` tokens. Same budget /
212    /// fast-path / fallback semantics; the result is a valid-UTF-8 exact suffix
213    /// of `text` (its END preserved) that re-counts to ≤ `max_tokens`.
214    pub fn truncate_to_token_suffix(&self, text: &str, max_tokens: u32) -> String {
215        if max_tokens == 0 {
216            return String::new();
217        }
218        let Some(encoder) = o200k_encoder() else {
219            return heuristic_char_suffix(text, max_tokens);
220        };
221        let tokens = encoder.encode_with_special_tokens(text);
222        if (tokens.len() as u32) <= max_tokens {
223            return text.to_string();
224        }
225        let start = tokens.len() - (max_tokens as usize);
226        match encoder.decode_bytes(&tokens[start..]) {
227            // The last `max_tokens` tokens span a byte-suffix of `text`; a
228            // boundary may split a *leading* multi-byte char, so drop any partial
229            // leading bytes to keep the result valid UTF-8 (still an exact suffix
230            // of `text`).
231            Ok(bytes) => valid_utf8_suffix(bytes),
232            Err(_) => heuristic_char_suffix(text, max_tokens),
233        }
234    }
235}
236
237// ── Encode-once truncation helpers ───────────────────────────────────────────
238//
239// These are used only by `TiktokenTokenCounter::truncate_to_token_{prefix,suffix}`.
240
/// Conservative char-based prefix, used only when the BPE encoder is
/// unavailable (the issue #25 fallback).
///
/// Keeps at most `heuristic_char_budget(max_tokens)` chars from the START of
/// `text`, so the default `HeuristicTokenCounter` re-counts the result to
/// ≤ `max_tokens`.
fn heuristic_char_prefix(text: &str, max_tokens: u32) -> String {
    let max_chars = heuristic_char_budget(max_tokens);
    // Byte offset of the first char that no longer fits, or the whole text.
    let end = text
        .char_indices()
        .nth(max_chars)
        .map_or(text.len(), |(byte_idx, _)| byte_idx);
    text[..end].to_string()
}

/// Conservative char-based suffix — symmetric to [`heuristic_char_prefix`].
///
/// Keeps at most `heuristic_char_budget(max_tokens)` chars from the END of
/// `text`.
fn heuristic_char_suffix(text: &str, max_tokens: u32) -> String {
    let max_chars = heuristic_char_budget(max_tokens);
    if max_chars == 0 {
        return String::new();
    }
    // Walk backwards only as far as needed: the `max_chars`-th char from the
    // end is where the kept suffix begins (or the whole text if it is shorter).
    let start = text
        .char_indices()
        .rev()
        .nth(max_chars - 1)
        .map_or(0, |(byte_idx, _)| byte_idx);
    text[start..].to_string()
}

/// Largest char count that the default heuristic prices at ≤ `max_tokens`.
///
/// The default heuristic is `ceil(chars / 4 · 1.1)`, matching
/// `HeuristicTokenCounter::with_defaults`.
///
/// Algebraically the answer is `floor(max_tokens · 4 / 1.1)`, but evaluating
/// that in `f64` is not exact. For `max_tokens = 1430` it yields 5200 chars,
/// and the heuristic prices 5200 chars at `ceil(1300.0 * 1.1) = 1431`, because
/// that product rounds to just above 1430.
///
/// The closed form is therefore only a starting point. It is stepped down
/// until the heuristic, evaluated with exactly the same float operations as
/// `HeuristicTokenCounter::count_text`, agrees the count fits. The rounding
/// error is far below one char, so at most a char or two is ever dropped.
fn heuristic_char_budget(max_tokens: u32) -> usize {
    // Mirror `HeuristicTokenCounter::with_defaults`; keep these in sync.
    const CHARS_PER_TOKEN: f64 = 4.0;
    const SAFETY_MARGIN: f64 = 1.1;

    let limit = f64::from(max_tokens);
    // Same expression shape as `HeuristicTokenCounter::count_text`.
    let estimated_tokens = |chars: usize| (chars as f64 / CHARS_PER_TOKEN * SAFETY_MARGIN).ceil();

    let mut chars = (limit * CHARS_PER_TOKEN / SAFETY_MARGIN).floor() as usize;
    while chars > 0 && estimated_tokens(chars) > limit {
        chars -= 1;
    }
    chars
}
262
/// Turn a byte-prefix of some valid UTF-8 text into a valid UTF-8 string.
///
/// If the token boundary landed mid-character, the partial trailing
/// multi-byte char is trimmed. The result is still an exact prefix of the
/// original text.
///
/// The buffer is validated in place and, when needed, truncated at the first
/// invalid byte. `bytes`' allocation is reused rather than copied.
fn valid_utf8_prefix(bytes: Vec<u8>) -> String {
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(err) => {
            let valid_up_to = err.utf8_error().valid_up_to();
            let mut bytes = err.into_bytes();
            bytes.truncate(valid_up_to);
            String::from_utf8(bytes).expect("bytes up to `valid_up_to` were just validated as UTF-8")
        }
    }
}
274
/// Turn a byte-suffix of some valid UTF-8 text into a valid UTF-8 string.
///
/// Leading bytes that belong to a partial multi-byte char are dropped. The
/// result is still an exact suffix of the original text.
///
/// A suffix of valid UTF-8 can only be invalid at its very start, where it may
/// begin with up to three continuation bytes (`0b10xx_xxxx`) of a split char.
/// Those are skipped directly and the remainder is validated once, in place,
/// without reallocating. This is O(N) even if the precondition is violated.
///
/// If the input is not such a suffix, the first start index whose tail is
/// valid UTF-8 is used instead (empty if there is none).
fn valid_utf8_suffix(mut bytes: Vec<u8>) -> String {
    let partial = bytes.iter().take_while(|&&b| b & 0xC0 == 0x80).count();
    bytes.drain(..partial);
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(err) => {
            // Precondition violated; search for the first valid tail.
            let bytes = err.into_bytes();
            (0..bytes.len())
                .find_map(|start| std::str::from_utf8(&bytes[start..]).ok())
                .map(str::to_owned)
                .unwrap_or_default()
        }
    }
}
290
291impl Default for TiktokenTokenCounter {
292    fn default() -> Self {
293        Self {
294            metadata_overhead: 10,
295        }
296    }
297}
298
299impl TokenCounter for TiktokenTokenCounter {
300    fn count_message(&self, message: &Message) -> u32 {
301        let content_tokens = self.count_text(&message.content);
302
303        let tool_calls_tokens = message
304            .tool_calls
305            .as_ref()
306            .map(|tc| {
307                tc.iter()
308                    .map(|c| {
309                        let args_tokens = self.count_text(&c.function.arguments);
310                        let id_tokens = self.count_text(&c.id);
311                        let name_tokens = self.count_text(&c.function.name);
312                        args_tokens
313                            .saturating_add(id_tokens)
314                            .saturating_add(name_tokens)
315                            .saturating_add(5)
316                    })
317                    .fold(0u32, |acc, x| acc.saturating_add(x))
318            })
319            .unwrap_or(0);
320
321        let tool_call_id_tokens = message
322            .tool_call_id
323            .as_ref()
324            .map(|id| self.count_text(id).saturating_add(3))
325            .unwrap_or(0);
326
327        content_tokens
328            .saturating_add(tool_calls_tokens)
329            .saturating_add(tool_call_id_tokens)
330            .saturating_add(self.metadata_overhead)
331    }
332
333    fn count_text(&self, text: &str) -> u32 {
334        if text.is_empty() {
335            return 0;
336        }
337        match o200k_encoder() {
338            // Accurate BPE count.
339            Some(encoder) => encoder.encode_with_special_tokens(text).len() as u32,
340            // Encoder unavailable — degrade to the char-based heuristic instead
341            // of panicking. Reuses the existing HeuristicTokenCounter.
342            None => HeuristicTokenCounter::default().count_text(text),
343        }
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_domain::{FunctionCall, ToolCall};

    /// A representative `search` tool call shared by the tool-call tests.
    fn search_tool_call() -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        }
    }

    /// About 2.2 KB of repeated English prose — well over a 30-token budget.
    fn repeated_pangram() -> String {
        "The quick brown fox jumps over the lazy dog. ".repeat(50)
    }

    #[test]
    fn heuristic_counter_counts_text() {
        let counter = HeuristicTokenCounter::default();

        // "Hello, world!" = 13 chars -> 13/4 * 1.1 = 3.575 -> ceil -> 4 tokens.
        // The heuristic is deterministic, so pin the exact value.
        assert_eq!(counter.count_text("Hello, world!"), 4);
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        let counter = HeuristicTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::user("Hello, world!");

        let tokens = counter.count_message(&message);
        // Should include content + metadata overhead (10)
        assert!(
            tokens >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![search_tool_call()]));

        let tokens = counter.count_message(&message);
        // Should include content + tool call (id + name + args) + metadata
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::tool_result("call_123", "Search results here");

        let tokens = counter.count_message(&message);
        // Should include content + tool_call_id + metadata
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];

        let total = counter.count_messages(&messages);
        let sum: u32 = messages.iter().map(|m| counter.count_message(m)).sum();

        assert_eq!(total, sum);
    }

    #[test]
    fn custom_chars_per_token() {
        let counter = HeuristicTokenCounter::new(2.0, 1.0, 0);
        // With 2 chars per token, "test" (4 chars) = 2 tokens
        let tokens = counter.count_text("test");
        assert_eq!(tokens, 2);
    }

    #[test]
    fn safety_margin_applied() {
        let counter_no_margin = HeuristicTokenCounter::new(4.0, 1.0, 0);
        let counter_with_margin = HeuristicTokenCounter::new(4.0, 1.1, 0);

        let text = "Hello world!"; // 12 chars
        let base = counter_no_margin.count_text(text);
        let adjusted = counter_with_margin.count_text(text);

        assert!(adjusted > base, "Safety margin should increase token count");
    }

    // --- TiktokenTokenCounter tests ---

    #[test]
    fn tiktoken_counter_counts_text() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_text("Hello, world!");
        // "Hello, world!" is 4 tokens with o200k_base
        assert!(
            (3..=6).contains(&tokens),
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        let counter = TiktokenTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        let counter = TiktokenTokenCounter::default();
        // CJK text: each character is typically 1-2 tokens
        let tokens = counter.count_text("你好世界");
        assert!(
            (2..=8).contains(&tokens),
            "Expected 2-8 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::user("Hello, world!");
        let tokens = counter.count_message(&message);
        // Should include content + metadata overhead (10)
        assert!(tokens >= 10, "Expected at least 10 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![search_tool_call()]));
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_and_heuristic_both_count_plain_english() {
        // Sanity check only: both counters yield a positive count for ordinary
        // prose. (This does not measure which one is more accurate.)
        let heuristic = HeuristicTokenCounter::default();
        let tiktoken = TiktokenTokenCounter::default();

        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = heuristic.count_text(text);
        let t_tokens = tiktoken.count_text(text);

        assert!(h_tokens > 0 && t_tokens > 0);
    }

    #[test]
    fn bundled_o200k_encoder_loads_successfully() {
        // Regression guard: if the bundled o200k_base vocabulary ever fails to
        // load (a build/link regression in tiktoken-rs), TiktokenTokenCounter
        // would silently fall back to heuristic counting. Assert the bundled
        // encoder actually loads so such a regression is caught here.
        assert!(
            o200k_base().is_ok(),
            "bundled o200k_base tokenizer failed to load; \
             suspected tiktoken-rs build/link regression"
        );
        // The counters go through the cached accessor, so check that path too.
        assert!(
            o200k_encoder().is_some(),
            "cached o200k encoder is unavailable; counters would use the heuristic fallback"
        );
    }

    #[test]
    fn valid_utf8_helpers_trim_split_chars() {
        let both = "你好".as_bytes(); // two 3-byte chars
        // A cut one byte into the second char keeps only the first.
        assert_eq!(valid_utf8_prefix(both[..4].to_vec()), "你");
        assert_eq!(valid_utf8_prefix(both.to_vec()), "你好");
        // A cut one byte into the first char keeps only the second.
        assert_eq!(valid_utf8_suffix(both[1..].to_vec()), "好");
        assert_eq!(valid_utf8_suffix(both.to_vec()), "你好");
        assert_eq!(valid_utf8_suffix(Vec::new()), "");
    }

    // ── Encode-once truncation (issue #24: O(N²) → O(N)) ──

    #[test]
    fn truncate_prefix_keeps_start_and_stays_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = repeated_pangram();
        // Sanity: text is well over the budget.
        assert!(counter.count_text(&text) > 30);

        let max_tokens = 30u32;
        let prefix = counter.truncate_to_token_prefix(&text, max_tokens);

        // (c) keep the START: prefix must be an exact prefix of `text`.
        assert!(
            text.starts_with(&prefix),
            "prefix must be the START of text"
        );
        assert!(
            !prefix.is_empty(),
            "prefix should not be empty under budget"
        );
        // (a) never exceed max_tokens.
        let count = counter.count_text(&prefix);
        assert!(
            count <= max_tokens,
            "prefix token count {count} exceeds budget {max_tokens}"
        );
    }

    #[test]
    fn truncate_suffix_keeps_end_and_stays_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = repeated_pangram();
        assert!(counter.count_text(&text) > 30);

        let max_tokens = 30u32;
        let suffix = counter.truncate_to_token_suffix(&text, max_tokens);

        // (c) keep the END: suffix must be an exact suffix of `text`.
        assert!(text.ends_with(&suffix), "suffix must be the END of text");
        assert!(
            !suffix.is_empty(),
            "suffix should not be empty under budget"
        );
        // (a) never exceed max_tokens.
        let count = counter.count_text(&suffix);
        assert!(
            count <= max_tokens,
            "suffix token count {count} exceeds budget {max_tokens}"
        );
    }

    #[test]
    fn truncate_returns_text_unchanged_when_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = "Hello, world!"; // a handful of tokens
        assert!(counter.count_text(text) <= 1000);

        assert_eq!(counter.truncate_to_token_prefix(text, 1000), text);
        assert_eq!(counter.truncate_to_token_suffix(text, 1000), text);
    }

    #[test]
    fn truncate_max_tokens_zero_returns_empty() {
        let counter = TiktokenTokenCounter::default();
        // (a) with budget 0 the only value that never exceeds it is empty.
        assert_eq!(counter.truncate_to_token_prefix("Hello, world!", 0), "");
        assert_eq!(counter.truncate_to_token_suffix("Hello, world!", 0), "");
    }

    #[test]
    fn truncate_prefix_suffix_large_input_is_valid_and_within_budget() {
        // Correctness + perf sanity on a ~100KB input mixing ASCII, CJK, digits
        // and newlines — exercises multi-byte token-boundary alignment.
        let counter = TiktokenTokenCounter::default();
        let unit = "The quick brown fox 你好世界 jumps 1234567890 over.\n";
        let text = unit.repeat(2_500);
        assert!(text.len() > 100_000, "precondition: large input");
        assert!(counter.count_text(&text) > 500);

        let max_tokens = 500u32;

        let prefix = counter.truncate_to_token_prefix(&text, max_tokens);
        assert!(
            text.starts_with(&prefix),
            "prefix must be the START of text"
        );
        let pcount = counter.count_text(&prefix);
        assert!(
            pcount <= max_tokens,
            "prefix token count {pcount} exceeds budget {max_tokens}"
        );

        let suffix = counter.truncate_to_token_suffix(&text, max_tokens);
        assert!(text.ends_with(&suffix), "suffix must be the END of text");
        let scount = counter.count_text(&suffix);
        assert!(
            scount <= max_tokens,
            "suffix token count {scount} exceeds budget {max_tokens}"
        );
    }
}