use crate::contracts::runtime::tool_call::ToolDescriptor;
use crate::contracts::thread::Message;
/// Characters-per-token ratio applied by `estimate_tokens` to every character
/// that `is_cjk` does not classify as CJK (so it covers all non-CJK text, not
/// strictly ASCII).
const CHARS_PER_TOKEN_ASCII: f32 = 4.0;
/// Characters-per-token ratio applied by `estimate_tokens` to characters that
/// `is_cjk` classifies as CJK.
const CHARS_PER_TOKEN_CJK: f32 = 1.5;
/// Flat token cost added once per message by `estimate_message_tokens`.
// NOTE(review): presumably stands in for role/framing tokens — confirm.
const MESSAGE_OVERHEAD: usize = 4;
/// Flat token cost added once per tool call by `estimate_message_tokens`, on
/// top of the estimated cost of the call's name and arguments.
const TOOL_CALL_OVERHEAD: usize = 20;
/// Flat token cost added once per tool descriptor by `estimate_tool_tokens`, on
/// top of the estimated cost of its name, description and parameters.
const TOOL_DESCRIPTOR_OVERHEAD: usize = 20;
/// Returns `true` when `c` falls in one of the CJK ranges that the token
/// estimator treats as "dense" text (fewer characters per token).
///
/// Besides the Basic Multilingual Plane blocks, the supplementary-plane
/// ideograph blocks are included so that rare ideographs (e.g. U+20000) are not
/// mistaken for Latin-like text and under-counted.
fn is_cjk(c: char) -> bool {
    matches!(
        c,
        '\u{4E00}'..='\u{9FFF}' // CJK Unified Ideographs
            | '\u{3400}'..='\u{4DBF}' // CJK Unified Ideographs Extension A
            | '\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs
            | '\u{3000}'..='\u{303F}' // CJK Symbols and Punctuation
            | '\u{3040}'..='\u{309F}' // Hiragana
            | '\u{30A0}'..='\u{30FF}' // Katakana
            | '\u{AC00}'..='\u{D7AF}' // Hangul Syllables
            // Plane 2: Extension B through the Compatibility Ideographs
            // Supplement. The few unassigned gaps inside this span are harmless.
            | '\u{20000}'..='\u{2FA1F}'
            // Plane 3: Extensions G and H.
            | '\u{30000}'..='\u{323AF}'
    )
}
/// Estimates how many tokens `text` occupies using a character-count heuristic.
///
/// Each character is classified with `is_cjk`. CJK characters are charged at
/// `CHARS_PER_TOKEN_CJK` characters per token and everything else at
/// `CHARS_PER_TOKEN_ASCII`; each group is rounded up separately and the two
/// results are added.
///
/// Returns `0` for an empty string and at least `1` for any non-empty string.
pub fn estimate_tokens(text: &str) -> usize {
    if text.is_empty() {
        return 0;
    }

    // One pass over the text, tallying (CJK, non-CJK) character counts.
    let (dense, sparse) = text.chars().fold((0usize, 0usize), |(dense, sparse), ch| {
        if is_cjk(ch) {
            (dense + 1, sparse)
        } else {
            (dense, sparse + 1)
        }
    });

    // Converts a character count to tokens, rounding up so that a partial
    // token still counts as a whole one.
    let to_tokens =
        |count: usize, chars_per_token: f32| (count as f32 / chars_per_token).ceil() as usize;

    let estimate = to_tokens(dense, CHARS_PER_TOKEN_CJK) + to_tokens(sparse, CHARS_PER_TOKEN_ASCII);
    estimate.max(1)
}
/// Estimates the token cost of a single message.
///
/// The result is the estimated cost of `msg.content`, plus `MESSAGE_OVERHEAD`,
/// plus — for every tool call attached to the message — the estimated cost of
/// the call's name and of its arguments rendered with `to_string()`, plus
/// `TOOL_CALL_OVERHEAD`.
pub fn estimate_message_tokens(msg: &Message) -> usize {
    let mut total = estimate_tokens(&msg.content);

    // Messages without tool calls contribute nothing here.
    if let Some(calls) = msg.tool_calls.as_ref() {
        for call in calls.iter() {
            let name_cost = estimate_tokens(&call.name);
            let args_cost = estimate_tokens(&call.arguments.to_string());
            total += name_cost + args_cost + TOOL_CALL_OVERHEAD;
        }
    }

    total + MESSAGE_OVERHEAD
}
pub fn estimate_messages_tokens(messages: &[Message]) -> usize {
messages.iter().map(estimate_message_tokens).sum()
}
/// Estimates the token cost of a set of tool descriptors.
///
/// Each descriptor is charged the estimated cost of its name, its description
/// and its parameters rendered with `to_string()`, plus
/// `TOOL_DESCRIPTOR_OVERHEAD`. An empty slice yields `0`.
pub fn estimate_tool_tokens(tools: &[ToolDescriptor]) -> usize {
    let mut total = 0usize;
    for tool in tools {
        let name_cost = estimate_tokens(&tool.name);
        let description_cost = estimate_tokens(&tool.description);
        let rendered_parameters = tool.parameters.to_string();
        let parameters_cost = estimate_tokens(&rendered_parameters);
        total += name_cost + description_cost + parameters_cost + TOOL_DESCRIPTOR_OVERHEAD;
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An empty string costs nothing.
    #[test]
    fn estimate_tokens_empty() {
        let tokens = estimate_tokens("");
        assert_eq!(tokens, 0);
    }

    /// Short ASCII text lands in a small, loosely bounded range.
    #[test]
    fn estimate_tokens_ascii() {
        let text = "Hello world";
        let tokens = estimate_tokens(text);
        assert!(matches!(tokens, 2..=5), "got {tokens}");
    }

    /// Pure CJK text is charged at the denser CJK ratio.
    #[test]
    fn estimate_tokens_cjk() {
        let text = "你好世界";
        let tokens = estimate_tokens(text);
        assert!(matches!(tokens, 2..=5), "got {tokens}");
    }

    /// Mixed text is the sum of its CJK and non-CJK parts.
    #[test]
    fn estimate_tokens_mixed() {
        let text = "Hello 你好 world 世界";
        let tokens = estimate_tokens(text);
        assert!(matches!(tokens, 4..=10), "got {tokens}");
    }

    /// Source code, including newlines and punctuation, is counted like any
    /// other non-CJK text.
    #[test]
    fn estimate_tokens_code_block() {
        let code = "fn main() {\n let x = compute(42);\n return x;\n}";
        let tokens = estimate_tokens(code);
        assert!(matches!(tokens, 8..=20), "got {tokens}");
    }

    /// A plain user message costs at least its overhead plus one content token.
    #[test]
    fn estimate_message_tokens_simple() {
        let question = "What is 2+2?";
        let msg = Message::user(question);
        let tokens = estimate_message_tokens(&msg);
        let minimum = 5;
        assert!(tokens >= minimum, "got {tokens}");
    }

    /// Attached tool calls raise the estimate for a message.
    #[test]
    fn estimate_message_tokens_with_tool_calls() {
        use crate::contracts::thread::ToolCall;

        let call = ToolCall::new(
            "call_1",
            "calculator",
            json!({"expr": "2+2"}),
        );
        let msg = Message::assistant_with_tool_calls("I'll calculate that.", vec![call]);

        let tokens = estimate_message_tokens(&msg);
        let minimum = 15;
        assert!(tokens >= minimum, "got {tokens}");
    }

    /// A single descriptor costs at least the per-descriptor overhead.
    #[test]
    fn estimate_tool_tokens_basic() {
        let schema = json!({
            "type": "object",
            "properties": {
                "expression": { "type": "string" }
            },
            "required": ["expression"]
        });
        let descriptor = ToolDescriptor::new("calc", "Calculator", "Evaluate math expressions")
            .with_parameters(schema);
        let tools = vec![descriptor];

        let tokens = estimate_tool_tokens(&tools);
        let minimum = 20;
        assert!(tokens >= minimum, "got {tokens}");
    }

    /// The slice-level estimate equals the sum of the per-message estimates.
    #[test]
    fn estimate_messages_tokens_multiple() {
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ];

        let total = estimate_messages_tokens(&messages);

        let mut sum = 0usize;
        for message in &messages {
            sum += estimate_message_tokens(message);
        }
        assert_eq!(total, sum);
    }
}