1use std::sync::OnceLock;
11
12use bamboo_domain::Message;
13use tiktoken_rs::o200k_base;
14use tiktoken_rs::CoreBPE;
15
16static O200K_ENCODER: OnceLock<Option<CoreBPE>> = OnceLock::new();
23
24fn o200k_encoder() -> Option<&'static CoreBPE> {
30 O200K_ENCODER
31 .get_or_init(|| match o200k_base() {
32 Ok(encoder) => Some(encoder),
33 Err(err) => {
34 tracing::warn!(
35 error = %err,
36 "failed to load bundled o200k_base tokenizer; \
37 falling back to heuristic token counting"
38 );
39 None
40 }
41 })
42 .as_ref()
43}
44
45pub trait TokenCounter: Send + Sync {
47 fn count_message(&self, message: &Message) -> u32;
49
50 fn count_messages(&self, messages: &[Message]) -> u32 {
52 messages.iter().map(|m| self.count_message(m)).sum()
53 }
54
55 fn count_text(&self, text: &str) -> u32;
57}
58
/// Estimates token counts from character counts, without a real tokenizer.
///
/// A text of `n` `char`s is counted as
/// `ceil(n / chars_per_token * safety_margin)` tokens. Each counted message
/// additionally pays `metadata_overhead` tokens.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Divisor applied to a text's `char` count.
    chars_per_token: f64,
    /// Multiplier applied after dividing by `chars_per_token`; values above
    /// 1.0 inflate the estimate.
    safety_margin: f64,
    /// Fixed number of tokens added once per counted message.
    metadata_overhead: u32,
}
74
75impl HeuristicTokenCounter {
76 pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
78 Self {
79 chars_per_token,
80 safety_margin,
81 metadata_overhead,
82 }
83 }
84
85 pub fn with_defaults() -> Self {
87 Self {
88 chars_per_token: 4.0,
89 safety_margin: 1.1,
90 metadata_overhead: 10,
91 }
92 }
93}
94
95impl Default for HeuristicTokenCounter {
96 fn default() -> Self {
97 Self::with_defaults()
98 }
99}
100
101impl TokenCounter for HeuristicTokenCounter {
102 fn count_message(&self, message: &Message) -> u32 {
103 let content_tokens = self.count_text(&message.content);
104
105 let tool_calls_tokens = message
107 .tool_calls
108 .as_ref()
109 .map(|tc| {
110 tc.iter()
111 .map(|c| {
112 let args_tokens = self.count_text(&c.function.arguments);
114 let id_tokens = self.count_text(&c.id);
115 let name_tokens = self.count_text(&c.function.name);
116 args_tokens
118 .saturating_add(id_tokens)
119 .saturating_add(name_tokens)
120 .saturating_add(5) })
122 .fold(0u32, |acc, x| acc.saturating_add(x))
123 })
124 .unwrap_or(0);
125
126 let tool_call_id_tokens = message
128 .tool_call_id
129 .as_ref()
130 .map(|id| self.count_text(id).saturating_add(3)) .unwrap_or(0);
132
133 content_tokens
135 .saturating_add(tool_calls_tokens)
136 .saturating_add(tool_call_id_tokens)
137 .saturating_add(self.metadata_overhead)
138 }
139
140 fn count_text(&self, text: &str) -> u32 {
141 if text.is_empty() {
142 return 0;
143 }
144
145 let char_count = text.chars().count() as f64;
146 let base_tokens = char_count / self.chars_per_token;
147 let adjusted_tokens = base_tokens * self.safety_margin;
148
149 adjusted_tokens.ceil() as u32
150 }
151}
152
/// Counts tokens with the bundled `o200k_base` BPE tokenizer.
///
/// When the tokenizer cannot be loaded, it falls back to a default
/// [`HeuristicTokenCounter`].
///
/// Derives `Clone` like [`HeuristicTokenCounter`] so the two counters are
/// interchangeable wherever a cloneable counter is required.
#[derive(Debug, Clone)]
pub struct TiktokenTokenCounter {
    /// Fixed number of tokens added once per counted message.
    metadata_overhead: u32,
}
162
163impl TiktokenTokenCounter {
164 pub fn new(metadata_overhead: u32) -> Self {
166 Self { metadata_overhead }
167 }
168
169 pub fn truncate_to_token_prefix(&self, text: &str, max_tokens: u32) -> String {
185 if max_tokens == 0 {
186 return String::new();
187 }
188 let Some(encoder) = o200k_encoder() else {
189 return heuristic_char_prefix(text, max_tokens);
190 };
191 let tokens = encoder.encode_with_special_tokens(text);
194 if (tokens.len() as u32) <= max_tokens {
195 return text.to_string();
196 }
197 let end = max_tokens as usize;
198 match encoder.decode_bytes(&tokens[..end]) {
199 Ok(bytes) => valid_utf8_prefix(bytes),
204 Err(_) => heuristic_char_prefix(text, max_tokens),
205 }
206 }
207
208 pub fn truncate_to_token_suffix(&self, text: &str, max_tokens: u32) -> String {
215 if max_tokens == 0 {
216 return String::new();
217 }
218 let Some(encoder) = o200k_encoder() else {
219 return heuristic_char_suffix(text, max_tokens);
220 };
221 let tokens = encoder.encode_with_special_tokens(text);
222 if (tokens.len() as u32) <= max_tokens {
223 return text.to_string();
224 }
225 let start = tokens.len() - (max_tokens as usize);
226 match encoder.decode_bytes(&tokens[start..]) {
227 Ok(bytes) => valid_utf8_suffix(bytes),
232 Err(_) => heuristic_char_suffix(text, max_tokens),
233 }
234 }
235}
236
237fn heuristic_char_prefix(text: &str, max_tokens: u32) -> String {
245 text.chars()
246 .take(heuristic_char_budget(max_tokens))
247 .collect()
248}
249
250fn heuristic_char_suffix(text: &str, max_tokens: u32) -> String {
252 let max_chars = heuristic_char_budget(max_tokens);
253 let skip = text.chars().count().saturating_sub(max_chars);
254 text.chars().skip(skip).collect()
255}
256
257fn heuristic_char_budget(max_tokens: u32) -> usize {
260 ((max_tokens as f64) * 4.0 / 1.1).floor() as usize
261}
262
/// Returns the longest prefix of `bytes` that is valid UTF-8.
///
/// Used after decoding a leading slice of tokens, where the cut may land
/// inside a multi-byte character. Everything from the first invalid sequence
/// onward is discarded.
fn valid_utf8_prefix(bytes: Vec<u8>) -> String {
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(err) => {
            // Reuse the owned buffer rather than copying the valid part into
            // a fresh allocation.
            let valid_up_to = err.utf8_error().valid_up_to();
            let mut bytes = err.into_bytes();
            bytes.truncate(valid_up_to);
            String::from_utf8(bytes).expect("bytes before `valid_up_to` are valid UTF-8")
        }
    }
}
274
/// Returns the longest suffix of `bytes` that is valid UTF-8. If no non-empty
/// suffix is valid, the result is empty.
///
/// Used after decoding a trailing slice of tokens, where the cut may land
/// inside a multi-byte character and leave stray leading bytes.
fn valid_utf8_suffix(mut bytes: Vec<u8>) -> String {
    // When validation from `start` fails at `start + valid_up_to`, no suffix
    // starting at or before that error byte can be valid: UTF-8 is
    // self-synchronizing, so any such start either lands on a continuation
    // byte or re-reaches the same error. Jumping straight past the error
    // gives the same answer as trying every offset, in O(n) rather than O(n²).
    let mut start = 0;
    while let Err(err) = std::str::from_utf8(&bytes[start..]) {
        start += err.valid_up_to() + 1;
    }
    bytes.drain(..start);
    String::from_utf8(bytes).expect("suffix was validated as UTF-8 above")
}
290
291impl Default for TiktokenTokenCounter {
292 fn default() -> Self {
293 Self {
294 metadata_overhead: 10,
295 }
296 }
297}
298
299impl TokenCounter for TiktokenTokenCounter {
300 fn count_message(&self, message: &Message) -> u32 {
301 let content_tokens = self.count_text(&message.content);
302
303 let tool_calls_tokens = message
304 .tool_calls
305 .as_ref()
306 .map(|tc| {
307 tc.iter()
308 .map(|c| {
309 let args_tokens = self.count_text(&c.function.arguments);
310 let id_tokens = self.count_text(&c.id);
311 let name_tokens = self.count_text(&c.function.name);
312 args_tokens
313 .saturating_add(id_tokens)
314 .saturating_add(name_tokens)
315 .saturating_add(5)
316 })
317 .fold(0u32, |acc, x| acc.saturating_add(x))
318 })
319 .unwrap_or(0);
320
321 let tool_call_id_tokens = message
322 .tool_call_id
323 .as_ref()
324 .map(|id| self.count_text(id).saturating_add(3))
325 .unwrap_or(0);
326
327 content_tokens
328 .saturating_add(tool_calls_tokens)
329 .saturating_add(tool_call_id_tokens)
330 .saturating_add(self.metadata_overhead)
331 }
332
333 fn count_text(&self, text: &str) -> u32 {
334 if text.is_empty() {
335 return 0;
336 }
337 match o200k_encoder() {
338 Some(encoder) => encoder.encode_with_special_tokens(text).len() as u32,
340 None => HeuristicTokenCounter::default().count_text(text),
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_domain::{FunctionCall, ToolCall};

    #[test]
    fn heuristic_counter_counts_text() {
        let counter = HeuristicTokenCounter::default();

        let tokens = counter.count_text("Hello, world!");
        // 13 chars / 4.0 * 1.1 = 3.575, rounded up to 4.
        assert!(
            (3..=5).contains(&tokens),
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        let counter = HeuristicTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::user("Hello, world!");

        let tokens = counter.count_message(&message);
        // Content tokens plus the default 10-token metadata overhead.
        assert!(
            tokens >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            tokens
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let counter = HeuristicTokenCounter::default();

        let tool_call = ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        };

        let message = Message::assistant("Let me search", Some(vec![tool_call]));

        let tokens = counter.count_message(&message);
        // Content + tool-call parts + per-call overhead + metadata overhead.
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::tool_result("call_123", "Search results here");

        let tokens = counter.count_message(&message);
        // Content + tool_call_id tokens (+3) + metadata overhead.
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];

        let total = counter.count_messages(&messages);
        let sum: u32 = messages.iter().map(|m| counter.count_message(m)).sum();

        assert_eq!(total, sum);
    }

    #[test]
    fn custom_chars_per_token() {
        let counter = HeuristicTokenCounter::new(2.0, 1.0, 0);
        let tokens = counter.count_text("test");
        // 4 chars / 2.0 chars-per-token, no safety margin.
        assert_eq!(tokens, 2);
    }

    #[test]
    fn safety_margin_applied() {
        let counter_no_margin = HeuristicTokenCounter::new(4.0, 1.0, 0);
        let counter_with_margin = HeuristicTokenCounter::new(4.0, 1.1, 0);

        let text = "Hello world!";
        let base = counter_no_margin.count_text(text);
        let adjusted = counter_with_margin.count_text(text);

        assert!(adjusted > base, "Safety margin should increase token count");
    }

    #[test]
    fn tiktoken_counter_counts_text() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_text("Hello, world!");
        assert!(
            (3..=6).contains(&tokens),
            "Expected ~4 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        let counter = TiktokenTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_text("你好世界");
        // Loose bounds: the exact split of CJK text depends on the vocabulary.
        assert!(
            (2..=8).contains(&tokens),
            "Expected 2-8 tokens, got {}",
            tokens
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::user("Hello, world!");
        let tokens = counter.count_message(&message);
        // Content tokens plus the default 10-token metadata overhead.
        assert!(tokens >= 10, "Expected at least 10 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let counter = TiktokenTokenCounter::default();
        let tool_call = ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        };
        let message = Message::assistant("Let me search", Some(vec![tool_call]));
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {}", tokens);
    }

    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        let heuristic = HeuristicTokenCounter::default();
        let tiktoken = TiktokenTokenCounter::default();

        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = heuristic.count_text(text);
        let t_tokens = tiktoken.count_text(text);

        assert!(h_tokens > 0 && t_tokens > 0);
        // The heuristic pads its estimate with a safety margin. For this
        // plain-English sentence, the real tokenizer should therefore not
        // report more tokens than the heuristic does.
        assert!(
            t_tokens <= h_tokens,
            "tiktoken ({t_tokens}) should not exceed the padded heuristic ({h_tokens})"
        );
    }

    #[test]
    fn bundled_o200k_encoder_loads_successfully() {
        // Fail loudly here. Otherwise the counter would silently degrade to
        // the heuristic fallback if the bundled tokenizer stopped loading.
        assert!(
            o200k_base().is_ok(),
            "bundled o200k_base tokenizer failed to load; \
             suspected tiktoken-rs build/link regression"
        );
    }

    #[test]
    fn truncate_prefix_keeps_start_and_stays_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = "The quick brown fox jumps over the lazy dog. ".repeat(50);
        // Precondition: the input must actually need truncation.
        assert!(counter.count_text(&text) > 30);

        let max_tokens = 30u32;
        let prefix = counter.truncate_to_token_prefix(&text, max_tokens);

        assert!(
            text.starts_with(&prefix),
            "prefix must be the START of text"
        );
        assert!(
            !prefix.is_empty(),
            "prefix should not be empty under budget"
        );
        let count = counter.count_text(&prefix);
        assert!(
            count <= max_tokens,
            "prefix token count {count} exceeds budget {max_tokens}"
        );
    }

    #[test]
    fn truncate_suffix_keeps_end_and_stays_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = "The quick brown fox jumps over the lazy dog. ".repeat(50);
        // Precondition: the input must actually need truncation.
        assert!(counter.count_text(&text) > 30);

        let max_tokens = 30u32;
        let suffix = counter.truncate_to_token_suffix(&text, max_tokens);

        assert!(text.ends_with(&suffix), "suffix must be the END of text");
        assert!(
            !suffix.is_empty(),
            "suffix should not be empty under budget"
        );
        let count = counter.count_text(&suffix);
        assert!(
            count <= max_tokens,
            "suffix token count {count} exceeds budget {max_tokens}"
        );
    }

    #[test]
    fn truncate_returns_text_unchanged_when_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = "Hello, world!";
        assert!(counter.count_text(text) <= 1000);

        assert_eq!(counter.truncate_to_token_prefix(text, 1000), text);
        assert_eq!(counter.truncate_to_token_suffix(text, 1000), text);
    }

    #[test]
    fn truncate_max_tokens_zero_returns_empty() {
        let counter = TiktokenTokenCounter::default();
        assert_eq!(counter.truncate_to_token_prefix("Hello, world!", 0), "");
        assert_eq!(counter.truncate_to_token_suffix("Hello, world!", 0), "");
    }

    #[test]
    fn truncate_prefix_suffix_large_input_is_valid_and_within_budget() {
        let counter = TiktokenTokenCounter::default();
        // Mixing ASCII words, CJK characters and digits puts multi-byte
        // characters near many potential cut points.
        let unit = "The quick brown fox 你好世界 jumps 1234567890 over.\n";
        let text = unit.repeat(2_500);
        assert!(text.len() > 100_000, "precondition: large input");
        assert!(counter.count_text(&text) > 500);

        let max_tokens = 500u32;

        let prefix = counter.truncate_to_token_prefix(&text, max_tokens);
        assert!(
            text.starts_with(&prefix),
            "prefix must be the START of text"
        );
        let pcount = counter.count_text(&prefix);
        assert!(
            pcount <= max_tokens,
            "prefix token count {pcount} exceeds budget {max_tokens}"
        );

        let suffix = counter.truncate_to_token_suffix(&text, max_tokens);
        assert!(text.ends_with(&suffix), "suffix must be the END of text");
        let scount = counter.count_text(&suffix);
        assert!(
            scount <= max_tokens,
            "suffix token count {scount} exceeds budget {max_tokens}"
        );
    }
}