use std::sync::OnceLock;
use bamboo_domain::Message;
use tiktoken_rs::o200k_base;
use tiktoken_rs::CoreBPE;
/// Process-wide cache for the `o200k_base` encoder.
///
/// Holds `None` when loading failed, so the load is attempted only once.
static O200K_ENCODER: OnceLock<Option<CoreBPE>> = OnceLock::new();

/// Returns the shared `o200k_base` encoder, loading it on first use.
///
/// Returns `None` when the encoder could not be loaded; the failure is logged
/// once as a warning and callers are expected to fall back to heuristic counting.
fn o200k_encoder() -> Option<&'static CoreBPE> {
    let cached = O200K_ENCODER.get_or_init(|| {
        let loaded = o200k_base();
        if let Err(err) = &loaded {
            tracing::warn!(
                error = %err,
                "failed to load bundled o200k_base tokenizer; \
                 falling back to heuristic token counting"
            );
        }
        loaded.ok()
    });
    cached.as_ref()
}
/// Estimates how many tokens a piece of text or a chat message occupies.
///
/// Implementors must be `Send + Sync`.
pub trait TokenCounter: Send + Sync {
    /// Returns the token count for a single message.
    fn count_message(&self, message: &Message) -> u32;

    /// Returns the combined token count of `messages`.
    ///
    /// The total saturates at `u32::MAX` rather than overflowing, matching the
    /// saturating arithmetic the per-message counters in this file use. An empty
    /// slice yields `0`.
    fn count_messages(&self, messages: &[Message]) -> u32 {
        messages.iter().fold(0u32, |total, message| {
            total.saturating_add(self.count_message(message))
        })
    }

    /// Returns the token count for raw text.
    fn count_text(&self, text: &str) -> u32;
}
/// Token counter that estimates counts from character length alone.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Number of characters assumed to make up one token.
    chars_per_token: f64,
    /// Multiplier applied to the raw estimate before rounding up.
    safety_margin: f64,
    /// Flat number of tokens added to every message.
    metadata_overhead: u32,
}

impl HeuristicTokenCounter {
    /// Creates a counter with explicit tuning parameters.
    pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
        HeuristicTokenCounter {
            chars_per_token,
            safety_margin,
            metadata_overhead,
        }
    }

    /// Creates a counter with the stock tuning: 4.0 characters per token,
    /// a 1.1 safety margin and 10 tokens of per-message overhead.
    pub fn with_defaults() -> Self {
        Self::new(4.0, 1.1, 10)
    }
}

impl Default for HeuristicTokenCounter {
    /// Equivalent to [`HeuristicTokenCounter::with_defaults`].
    fn default() -> Self {
        HeuristicTokenCounter::with_defaults()
    }
}
impl TokenCounter for HeuristicTokenCounter {
    /// Estimates the tokens in `message`: its content, every tool call
    /// (arguments, id and name plus 5 tokens each), the tool-call id if present
    /// (plus 3 tokens), and the configured per-message overhead.
    ///
    /// All additions saturate at `u32::MAX`.
    fn count_message(&self, message: &Message) -> u32 {
        let body = self.count_text(&message.content);

        let mut calls_total = 0u32;
        if let Some(calls) = message.tool_calls.as_ref() {
            for call in calls.iter() {
                let arguments = self.count_text(&call.function.arguments);
                let id = self.count_text(&call.id);
                let name = self.count_text(&call.function.name);
                let per_call = arguments
                    .saturating_add(id)
                    .saturating_add(name)
                    .saturating_add(5);
                calls_total = calls_total.saturating_add(per_call);
            }
        }

        let result_id = match message.tool_call_id.as_ref() {
            Some(id) => self.count_text(id).saturating_add(3),
            None => 0,
        };

        body.saturating_add(calls_total)
            .saturating_add(result_id)
            .saturating_add(self.metadata_overhead)
    }

    /// Estimates the tokens in `text` as
    /// `ceil(chars / chars_per_token * safety_margin)`; empty text counts as 0.
    fn count_text(&self, text: &str) -> u32 {
        if text.is_empty() {
            return 0;
        }
        // Count Unicode scalar values, not bytes.
        let chars = text.chars().count() as f64;
        (chars / self.chars_per_token * self.safety_margin).ceil() as u32
    }
}
/// Token counter that uses the shared encoder returned by `o200k_encoder`.
#[derive(Debug)]
pub struct TiktokenTokenCounter {
    /// Flat number of tokens added to every message by `count_message`.
    metadata_overhead: u32,
}

impl TiktokenTokenCounter {
    /// Creates a counter that adds `metadata_overhead` tokens to every message.
    pub fn new(metadata_overhead: u32) -> Self {
        Self { metadata_overhead }
    }

    /// Returns the start of `text`, cut down to at most `max_tokens` tokens.
    ///
    /// * `max_tokens == 0` yields an empty string.
    /// * Text that already fits is returned unchanged.
    /// * Otherwise the first `max_tokens` tokens are decoded and trimmed to the
    ///   longest valid UTF-8 prefix, in case the cut falls inside a multi-byte
    ///   character.
    ///
    /// When the encoder is unavailable or decoding fails, a character-based
    /// heuristic prefix is returned instead.
    pub fn truncate_to_token_prefix(&self, text: &str, max_tokens: u32) -> String {
        if max_tokens == 0 {
            return String::new();
        }
        let Some(encoder) = o200k_encoder() else {
            return heuristic_char_prefix(text, max_tokens);
        };
        let tokens = encoder.encode_with_special_tokens(text);
        // Compare in `usize`: widening the budget is lossless, whereas narrowing
        // `tokens.len()` to `u32` could wrap and misreport the text as fitting.
        let budget = max_tokens as usize;
        if tokens.len() <= budget {
            return text.to_string();
        }
        match encoder.decode_bytes(&tokens[..budget]) {
            Ok(bytes) => valid_utf8_prefix(bytes),
            Err(_) => heuristic_char_prefix(text, max_tokens),
        }
    }

    /// Returns the end of `text`, cut down to at most `max_tokens` tokens.
    ///
    /// * `max_tokens == 0` yields an empty string.
    /// * Text that already fits is returned unchanged.
    /// * Otherwise the last `max_tokens` tokens are decoded and trimmed to the
    ///   longest valid UTF-8 suffix, in case the cut falls inside a multi-byte
    ///   character.
    ///
    /// When the encoder is unavailable or decoding fails, a character-based
    /// heuristic suffix is returned instead.
    pub fn truncate_to_token_suffix(&self, text: &str, max_tokens: u32) -> String {
        if max_tokens == 0 {
            return String::new();
        }
        let Some(encoder) = o200k_encoder() else {
            return heuristic_char_suffix(text, max_tokens);
        };
        let tokens = encoder.encode_with_special_tokens(text);
        // Same lossless comparison as in the prefix variant.
        let budget = max_tokens as usize;
        if tokens.len() <= budget {
            return text.to_string();
        }
        // Cannot underflow: `tokens.len() > budget` was established above.
        let start = tokens.len() - budget;
        match encoder.decode_bytes(&tokens[start..]) {
            Ok(bytes) => valid_utf8_suffix(bytes),
            Err(_) => heuristic_char_suffix(text, max_tokens),
        }
    }
}
/// Returns the leading characters of `text` that fit the heuristic character
/// budget for `max_tokens`; the whole text when it is already short enough.
fn heuristic_char_prefix(text: &str, max_tokens: u32) -> String {
    let budget = heuristic_char_budget(max_tokens);
    // The byte offset of the character just past the budget is where to cut.
    match text.char_indices().nth(budget) {
        Some((cut, _)) => text[..cut].to_owned(),
        None => text.to_owned(),
    }
}
/// Returns the trailing characters of `text` that fit the heuristic character
/// budget for `max_tokens`; the whole text when it is already short enough.
fn heuristic_char_suffix(text: &str, max_tokens: u32) -> String {
    let budget = heuristic_char_budget(max_tokens);
    if budget == 0 {
        return String::new();
    }
    // Walk backwards to the first byte of the `budget`-th character from the
    // end; if the text has fewer characters, keep all of it.
    let start = text
        .char_indices()
        .rev()
        .nth(budget - 1)
        .map_or(0, |(offset, _)| offset);
    text[start..].to_owned()
}
/// Converts a token budget into a character budget, rounded down.
///
/// Uses the same constants as `HeuristicTokenCounter::with_defaults`:
/// 4.0 characters per token and a 1.1 safety margin.
fn heuristic_char_budget(max_tokens: u32) -> usize {
    let token_budget = max_tokens as f64;
    let char_budget = token_budget * 4.0 / 1.1;
    char_budget.floor() as usize
}
/// Returns the longest prefix of `bytes` that is valid UTF-8.
///
/// Fully valid input is converted in place, without copying the buffer or
/// validating it a second time. Input whose very first byte is invalid yields
/// an empty string.
fn valid_utf8_prefix(bytes: Vec<u8>) -> String {
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(err) => {
            let valid_len = err.utf8_error().valid_up_to();
            // Reclaim the original buffer and drop everything from the first
            // invalid byte onwards.
            let mut bytes = err.into_bytes();
            bytes.truncate(valid_len);
            // The remaining bytes were just reported valid, so this cannot
            // fail; the default only avoids a panic path.
            String::from_utf8(bytes).unwrap_or_default()
        }
    }
}
/// Returns the longest suffix of `bytes` that is valid UTF-8.
///
/// Returns an empty string when no non-empty suffix is valid.
fn valid_utf8_suffix(bytes: Vec<u8>) -> String {
    let mut start = 0;
    while start < bytes.len() {
        match std::str::from_utf8(&bytes[start..]) {
            // Already validated: reuse the `&str` instead of re-checking it.
            Ok(tail) => return tail.to_owned(),
            // No suffix that begins at or before the first invalid byte can be
            // valid: it would either hit that same byte or start in the middle
            // of a character. Jump straight past it instead of rescanning the
            // valid run one byte at a time.
            Err(err) => start += err.valid_up_to() + 1,
        }
    }
    String::new()
}
impl Default for TiktokenTokenCounter {
    /// Creates a counter with a per-message metadata overhead of 10 tokens.
    fn default() -> Self {
        Self::new(10)
    }
}
impl TokenCounter for TiktokenTokenCounter {
    /// Counts the tokens in `message`: its content, every tool call
    /// (arguments, id and name plus 5 tokens each), the tool-call id if present
    /// (plus 3 tokens), and the configured per-message overhead.
    ///
    /// All additions saturate at `u32::MAX`.
    fn count_message(&self, message: &Message) -> u32 {
        let content_tokens = self.count_text(&message.content);
        let tool_calls_tokens = message.tool_calls.as_ref().map_or(0, |calls| {
            calls
                .iter()
                .map(|call| {
                    let args_tokens = self.count_text(&call.function.arguments);
                    let id_tokens = self.count_text(&call.id);
                    let name_tokens = self.count_text(&call.function.name);
                    args_tokens
                        .saturating_add(id_tokens)
                        .saturating_add(name_tokens)
                        .saturating_add(5)
                })
                .fold(0u32, |acc, tokens| acc.saturating_add(tokens))
        });
        let tool_call_id_tokens = message
            .tool_call_id
            .as_ref()
            .map_or(0, |id| self.count_text(id).saturating_add(3));
        content_tokens
            .saturating_add(tool_calls_tokens)
            .saturating_add(tool_call_id_tokens)
            .saturating_add(self.metadata_overhead)
    }

    /// Counts the tokens in `text` with the shared encoder; empty text counts
    /// as 0.
    ///
    /// Falls back to `HeuristicTokenCounter::default()` when the encoder is
    /// unavailable. The result saturates at `u32::MAX`.
    fn count_text(&self, text: &str) -> u32 {
        if text.is_empty() {
            return 0;
        }
        match o200k_encoder() {
            Some(encoder) => {
                let token_count = encoder.encode_with_special_tokens(text).len();
                // Clamp before narrowing so an oversized count saturates, like
                // the arithmetic in `count_message`, instead of wrapping.
                token_count.min(u32::MAX as usize) as u32
            }
            None => HeuristicTokenCounter::default().count_text(text),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_domain::{FunctionCall, ToolCall};

    /// Sentence repeated to build inputs that exceed the truncation budgets.
    const PANGRAM_UNIT: &str = "The quick brown fox jumps over the lazy dog. ";

    /// Builds the `search` tool call shared by the tool-call counting tests.
    fn search_tool_call() -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        }
    }

    /// Builds the assistant message carrying a single `search` tool call.
    fn assistant_with_search_call() -> Message {
        Message::assistant("Let me search", Some(vec![search_tool_call()]))
    }

    // ---- HeuristicTokenCounter ----

    #[test]
    fn heuristic_counter_counts_text() {
        let estimate = HeuristicTokenCounter::default().count_text("Hello, world!");
        assert!(
            (3..=5).contains(&estimate),
            "Expected ~4 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        assert_eq!(HeuristicTokenCounter::default().count_text(""), 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let heuristic = HeuristicTokenCounter::default();
        let estimate = heuristic.count_message(&Message::user("Hello, world!"));
        assert!(
            estimate >= 10,
            "Expected at least 10 tokens (content + metadata), got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let heuristic = HeuristicTokenCounter::default();
        let estimate = heuristic.count_message(&assistant_with_search_call());
        assert!(
            estimate >= 15,
            "Expected at least 15 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let heuristic = HeuristicTokenCounter::default();
        let tool_result = Message::tool_result("call_123", "Search results here");
        let estimate = heuristic.count_message(&tool_result);
        assert!(
            estimate >= 15,
            "Expected at least 15 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let heuristic = HeuristicTokenCounter::default();
        let conversation = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];
        let batch_total = heuristic.count_messages(&conversation);
        let mut one_by_one: u32 = 0;
        for message in &conversation {
            one_by_one += heuristic.count_message(message);
        }
        assert_eq!(batch_total, one_by_one);
    }

    #[test]
    fn custom_chars_per_token() {
        // 4 characters at 2 characters per token, no margin.
        let two_chars_per_token = HeuristicTokenCounter::new(2.0, 1.0, 0);
        assert_eq!(two_chars_per_token.count_text("test"), 2);
    }

    #[test]
    fn safety_margin_applied() {
        let text = "Hello world!";
        let base = HeuristicTokenCounter::new(4.0, 1.0, 0).count_text(text);
        let adjusted = HeuristicTokenCounter::new(4.0, 1.1, 0).count_text(text);
        assert!(adjusted > base, "Safety margin should increase token count");
    }

    // ---- TiktokenTokenCounter: counting ----

    #[test]
    fn tiktoken_counter_counts_text() {
        let measured = TiktokenTokenCounter::default().count_text("Hello, world!");
        assert!(
            (3..=6).contains(&measured),
            "Expected ~4 tokens, got {}",
            measured
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        assert_eq!(TiktokenTokenCounter::default().count_text(""), 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        let measured = TiktokenTokenCounter::default().count_text("你好世界");
        assert!(
            (2..=8).contains(&measured),
            "Expected 2-8 tokens, got {}",
            measured
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let tiktoken = TiktokenTokenCounter::default();
        let measured = tiktoken.count_message(&Message::user("Hello, world!"));
        assert!(
            measured >= 10,
            "Expected at least 10 tokens, got {}",
            measured
        );
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let tiktoken = TiktokenTokenCounter::default();
        let measured = tiktoken.count_message(&assistant_with_search_call());
        assert!(
            measured >= 15,
            "Expected at least 15 tokens, got {}",
            measured
        );
    }

    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        // Only checks that both counters produce a non-zero count.
        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = HeuristicTokenCounter::default().count_text(text);
        let t_tokens = TiktokenTokenCounter::default().count_text(text);
        assert!(h_tokens > 0 && t_tokens > 0);
    }

    #[test]
    fn bundled_o200k_encoder_loads_successfully() {
        let loaded = o200k_base();
        assert!(
            loaded.is_ok(),
            "bundled o200k_base tokenizer failed to load; \
             suspected tiktoken-rs build/link regression"
        );
    }

    // ---- TiktokenTokenCounter: truncation ----

    #[test]
    fn truncate_prefix_keeps_start_and_stays_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let max_tokens = 30u32;
        let text = PANGRAM_UNIT.repeat(50);
        assert!(counter.count_text(&text) > 30);

        let prefix = counter.truncate_to_token_prefix(&text, max_tokens);
        assert!(
            text.starts_with(&prefix),
            "prefix must be the START of text"
        );
        assert!(
            !prefix.is_empty(),
            "prefix should not be empty under budget"
        );

        let count = counter.count_text(&prefix);
        assert!(
            count <= max_tokens,
            "prefix token count {count} exceeds budget {max_tokens}"
        );
    }

    #[test]
    fn truncate_suffix_keeps_end_and_stays_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let max_tokens = 30u32;
        let text = PANGRAM_UNIT.repeat(50);
        assert!(counter.count_text(&text) > 30);

        let suffix = counter.truncate_to_token_suffix(&text, max_tokens);
        assert!(text.ends_with(&suffix), "suffix must be the END of text");
        assert!(
            !suffix.is_empty(),
            "suffix should not be empty under budget"
        );

        let count = counter.count_text(&suffix);
        assert!(
            count <= max_tokens,
            "suffix token count {count} exceeds budget {max_tokens}"
        );
    }

    #[test]
    fn truncate_returns_text_unchanged_when_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let text = "Hello, world!";
        assert!(counter.count_text(text) <= 1000);
        assert_eq!(counter.truncate_to_token_prefix(text, 1000), text);
        assert_eq!(counter.truncate_to_token_suffix(text, 1000), text);
    }

    #[test]
    fn truncate_max_tokens_zero_returns_empty() {
        let counter = TiktokenTokenCounter::default();
        let input = "Hello, world!";
        assert_eq!(counter.truncate_to_token_prefix(input, 0), "");
        assert_eq!(counter.truncate_to_token_suffix(input, 0), "");
    }

    #[test]
    fn truncate_prefix_suffix_large_input_is_valid_and_within_budget() {
        let counter = TiktokenTokenCounter::default();
        let max_tokens = 500u32;

        // Mixed ASCII, CJK and digits, repeated past 100 kB.
        let unit = "The quick brown fox 你好世界 jumps 1234567890 over.\n";
        let text = unit.repeat(2_500);
        assert!(text.len() > 100_000, "precondition: large input");
        assert!(counter.count_text(&text) > 500);

        let prefix = counter.truncate_to_token_prefix(&text, max_tokens);
        assert!(
            text.starts_with(&prefix),
            "prefix must be the START of text"
        );
        let pcount = counter.count_text(&prefix);
        assert!(
            pcount <= max_tokens,
            "prefix token count {pcount} exceeds budget {max_tokens}"
        );

        let suffix = counter.truncate_to_token_suffix(&text, max_tokens);
        assert!(text.ends_with(&suffix), "suffix must be the END of text");
        let scount = counter.count_text(&suffix);
        assert!(
            scount <= max_tokens,
            "suffix token count {scount} exceeds budget {max_tokens}"
        );
    }
}