use bamboo_agent_core::Message;
use std::sync::Arc;
use tiktoken_rs::o200k_base;
use tiktoken_rs::CoreBPE;
/// Process-wide `o200k_base` encoder, built on first access and reused afterwards.
static O200K_ENCODER: std::sync::LazyLock<CoreBPE> = std::sync::LazyLock::new(|| {
    // Without the encoder no text can be tokenised, so a construction failure is
    // treated as a fatal invariant violation and named in the panic message.
    o200k_base().expect("failed to initialise the o200k_base tokenizer")
});
/// Strategy for estimating how many tokens messages and raw text occupy.
///
/// Implementors must be `Send + Sync`, so a counter can be shared between threads.
pub trait TokenCounter: Send + Sync {
    /// Returns the token count attributed to a single `message`.
    fn count_message(&self, message: &Message) -> u32;

    /// Returns the combined token count of every message in `messages`.
    ///
    /// An empty slice yields `0`. The running total saturates at `u32::MAX`
    /// rather than overflowing, matching the saturating arithmetic the
    /// per-message counters in this file use.
    fn count_messages(&self, messages: &[Message]) -> u32 {
        messages.iter().fold(0u32, |total, message| {
            total.saturating_add(self.count_message(message))
        })
    }

    /// Returns the token count of a bare piece of `text`.
    fn count_text(&self, text: &str) -> u32;
}
/// Token counter that estimates counts from character length alone.
#[derive(Debug, Clone)]
pub struct HeuristicTokenCounter {
    /// Divisor applied to the character count to get a raw token estimate.
    chars_per_token: f64,
    /// Multiplier applied to the raw estimate before it is rounded up.
    safety_margin: f64,
    /// Flat number of tokens added to every counted message.
    metadata_overhead: u32,
}
impl HeuristicTokenCounter {
pub fn new(chars_per_token: f64, safety_margin: f64, metadata_overhead: u32) -> Self {
Self {
chars_per_token,
safety_margin,
metadata_overhead,
}
}
pub fn with_defaults() -> Self {
Self {
chars_per_token: 4.0,
safety_margin: 1.1,
metadata_overhead: 10,
}
}
}
impl Default for HeuristicTokenCounter {
fn default() -> Self {
Self::with_defaults()
}
}
/// Character-ratio implementation of [`TokenCounter`].
impl TokenCounter for HeuristicTokenCounter {
    /// Estimates one message as content + tool calls + tool-call id + overhead.
    ///
    /// Each tool call contributes its arguments, id and function name plus a
    /// flat 5 tokens; a tool-call id contributes its own text plus 3 tokens.
    /// Every addition saturates at `u32::MAX`.
    fn count_message(&self, message: &Message) -> u32 {
        let mut tool_calls_total = 0u32;
        if let Some(calls) = message.tool_calls.as_ref() {
            for call in calls.iter() {
                let per_call = self
                    .count_text(&call.function.arguments)
                    .saturating_add(self.count_text(&call.id))
                    .saturating_add(self.count_text(&call.function.name))
                    .saturating_add(5);
                tool_calls_total = tool_calls_total.saturating_add(per_call);
            }
        }

        let tool_call_id_total = match message.tool_call_id.as_ref() {
            Some(id) => self.count_text(id).saturating_add(3),
            None => 0,
        };

        self.count_text(&message.content)
            .saturating_add(tool_calls_total)
            .saturating_add(tool_call_id_total)
            .saturating_add(self.metadata_overhead)
    }

    /// Estimates `text` as `ceil(chars / chars_per_token * safety_margin)`.
    ///
    /// Characters are Unicode scalar values, not bytes. Empty text is 0 tokens.
    fn count_text(&self, text: &str) -> u32 {
        match text.chars().count() {
            0 => 0,
            chars => {
                // Divide first, then scale: keeping this order keeps rounding stable.
                let estimate = (chars as f64 / self.chars_per_token) * self.safety_margin;
                estimate.ceil() as u32
            }
        }
    }
}
/// Reference-counted, dynamically dispatched [`TokenCounter`] handle.
///
/// Because the trait requires `Send + Sync`, the handle can be cloned and
/// shared between threads.
pub type SharedTokenCounter = Arc<dyn TokenCounter>;
/// Token counter that tokenises text with the shared `o200k_base` encoder.
///
/// Derives `Clone` as well as `Debug` so it offers the same capabilities as
/// [`HeuristicTokenCounter`]; the only state is a plain integer.
#[derive(Debug, Clone)]
pub struct TiktokenTokenCounter {
    /// Flat number of tokens added to every counted message.
    metadata_overhead: u32,
}
impl TiktokenTokenCounter {
pub fn new(metadata_overhead: u32) -> Self {
Self { metadata_overhead }
}
}
/// `Default` yields a counter with 10 tokens of per-message overhead.
impl Default for TiktokenTokenCounter {
    fn default() -> Self {
        const DEFAULT_METADATA_OVERHEAD: u32 = 10;
        Self::new(DEFAULT_METADATA_OVERHEAD)
    }
}
impl TokenCounter for TiktokenTokenCounter {
fn count_message(&self, message: &Message) -> u32 {
let content_tokens = self.count_text(&message.content);
let tool_calls_tokens = message
.tool_calls
.as_ref()
.map(|tc| {
tc.iter()
.map(|c| {
let args_tokens = self.count_text(&c.function.arguments);
let id_tokens = self.count_text(&c.id);
let name_tokens = self.count_text(&c.function.name);
args_tokens
.saturating_add(id_tokens)
.saturating_add(name_tokens)
.saturating_add(5)
})
.fold(0u32, |acc, x| acc.saturating_add(x))
})
.unwrap_or(0);
let tool_call_id_tokens = message
.tool_call_id
.as_ref()
.map(|id| self.count_text(id).saturating_add(3))
.unwrap_or(0);
content_tokens
.saturating_add(tool_calls_tokens)
.saturating_add(tool_call_id_tokens)
.saturating_add(self.metadata_overhead)
}
fn count_text(&self, text: &str) -> u32 {
if text.is_empty() {
return 0;
}
let tokens = O200K_ENCODER.encode_with_special_tokens(text);
tokens.len() as u32
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    /// Builds the `search` tool call shared by the tool-call tests.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"query":"test"}"#.to_string(),
            },
        }
    }

    #[test]
    fn heuristic_counter_counts_text() {
        let counter = HeuristicTokenCounter::default();
        let tokens = counter.count_text("Hello, world!");
        assert!(
            (3..=5).contains(&tokens),
            "Expected ~4 tokens, got {tokens}"
        );
    }

    #[test]
    fn heuristic_counter_counts_empty_text() {
        let counter = HeuristicTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn heuristic_counter_counts_user_message() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::user("Hello, world!");
        let tokens = counter.count_message(&message);
        assert!(
            tokens >= 10,
            "Expected at least 10 tokens (content + metadata), got {tokens}"
        );
    }

    #[test]
    fn heuristic_counter_counts_tool_calls() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![sample_tool_call()]));
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {tokens}");
    }

    #[test]
    fn heuristic_counter_counts_tool_result() {
        let counter = HeuristicTokenCounter::default();
        let message = Message::tool_result("call_123", "Search results here");
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {tokens}");
    }

    #[test]
    fn heuristic_counter_counts_multiple_messages() {
        let counter = HeuristicTokenCounter::default();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];
        let total = counter.count_messages(&messages);
        let sum: u32 = messages.iter().map(|m| counter.count_message(m)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn custom_chars_per_token() {
        // 4 characters at 2 characters per token, with no margin, is exactly 2 tokens.
        let counter = HeuristicTokenCounter::new(2.0, 1.0, 0);
        let tokens = counter.count_text("test");
        assert_eq!(tokens, 2);
    }

    #[test]
    fn safety_margin_applied() {
        let counter_no_margin = HeuristicTokenCounter::new(4.0, 1.0, 0);
        let counter_with_margin = HeuristicTokenCounter::new(4.0, 1.1, 0);
        let text = "Hello world!";
        let base = counter_no_margin.count_text(text);
        let adjusted = counter_with_margin.count_text(text);
        assert!(adjusted > base, "Safety margin should increase token count");
    }

    #[test]
    fn tiktoken_counter_counts_text() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_text("Hello, world!");
        assert!(
            (3..=6).contains(&tokens),
            "Expected ~4 tokens, got {tokens}"
        );
    }

    #[test]
    fn tiktoken_counter_counts_empty_text() {
        let counter = TiktokenTokenCounter::default();
        assert_eq!(counter.count_text(""), 0);
    }

    #[test]
    fn tiktoken_counter_counts_cjk() {
        let counter = TiktokenTokenCounter::default();
        let tokens = counter.count_text("你好世界");
        assert!(
            (2..=8).contains(&tokens),
            "Expected 2-8 tokens, got {tokens}"
        );
    }

    #[test]
    fn tiktoken_counter_counts_user_message() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::user("Hello, world!");
        let tokens = counter.count_message(&message);
        assert!(tokens >= 10, "Expected at least 10 tokens, got {tokens}");
    }

    #[test]
    fn tiktoken_counter_counts_tool_calls() {
        let counter = TiktokenTokenCounter::default();
        let message = Message::assistant("Let me search", Some(vec![sample_tool_call()]));
        let tokens = counter.count_message(&message);
        assert!(tokens >= 15, "Expected at least 15 tokens, got {tokens}");
    }

    #[test]
    fn tiktoken_counter_counts_multiple_messages() {
        // Mirrors the heuristic test: the batch total equals the per-message sum.
        let counter = TiktokenTokenCounter::default();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi there", None),
        ];
        let total = counter.count_messages(&messages);
        let sum: u32 = messages.iter().map(|m| counter.count_message(m)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn tiktoken_counter_more_accurate_than_heuristic() {
        // Smoke test only: it checks that both counters return a non-zero count
        // for the same sentence. It does not compare their accuracy.
        let heuristic = HeuristicTokenCounter::default();
        let tiktoken = TiktokenTokenCounter::default();
        let text = "The quick brown fox jumps over the lazy dog.";
        let h_tokens = heuristic.count_text(text);
        let t_tokens = tiktoken.count_text(text);
        assert!(h_tokens > 0 && t_tokens > 0);
    }
}