use crate::types::{ChatMessage, Conversation, MessageRole, ToolDefinition};
use crate::estimator_language::detect_language_class;
/// Stateless, heuristic token counter.
///
/// The type carries no data; every estimate is exposed as an associated function
/// (for example `TokenEstimator::estimate_tokens`).
#[derive(Clone, Copy, Debug)]
pub struct TokenEstimator;
/// Per-component token estimate for a conversation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConversationTokenEstimate {
    /// Estimated tokens in the system prompt.
    pub system_prompt: u32,
    /// Estimated tokens in the conversation summary.
    pub summary: u32,
    /// Estimated tokens in the message history.
    pub history: u32,
    /// Combined estimate of `system_prompt`, `summary` and `history`.
    pub total: u32,
}
/// Fixed number of tokens added to every chat message on top of its content estimate.
///
/// Applied by [`TokenEstimator::estimate_message`].
pub const MESSAGE_OVERHEAD_TOKENS: u32 = 4;
impl TokenEstimator {
    /// Estimates how many tokens `text` occupies.
    ///
    /// Returns `0` for an empty string and at least `1` for any non-empty string.
    /// Otherwise the result is the UTF-8 **byte** length of `text` divided by the
    /// chars-per-token ratio of the language class that `detect_language_class`
    /// reports for `text`, rounded up.
    #[must_use]
    pub fn estimate_tokens(text: &str) -> u32 {
        if text.is_empty() {
            return 0;
        }
        // NOTE(review): `str::len` counts UTF-8 bytes, not `char`s, so multi-byte scripts
        // (Cyrillic, CJK, ...) are divided as bytes. The tests in this module pin that
        // behaviour; confirm the per-class ratios were calibrated against byte lengths.
        #[allow(clippy::cast_precision_loss)] // exact for any length below 2^53 bytes
        let byte_len = text.len() as f64;
        let chars_per_token = detect_language_class(text).chars_per_token();
        // A float-to-integer `as` cast saturates (NaN becomes 0), so this never wraps.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let tokens = (byte_len / chars_per_token).ceil() as u32;
        tokens.max(1)
    }

    /// Estimates a single chat message: its content estimate plus
    /// [`MESSAGE_OVERHEAD_TOKENS`], saturating at `u32::MAX`.
    #[must_use]
    pub fn estimate_message(message: &ChatMessage) -> u32 {
        Self::estimate_tokens(&message.content).saturating_add(MESSAGE_OVERHEAD_TOKENS)
    }

    /// Sums [`Self::estimate_message`] over `messages`.
    ///
    /// Saturates at `u32::MAX` instead of overflowing. A plain `u32` sum would panic
    /// with overflow checks on and wrap with them off.
    #[must_use]
    pub fn estimate_messages(messages: &[ChatMessage]) -> u32 {
        messages
            .iter()
            .map(Self::estimate_message)
            .fold(0, u32::saturating_add)
    }

    /// Estimates each part of a conversation separately.
    ///
    /// The system prompt and summary are estimated as plain text, without per-message
    /// overhead, and count as `0` when absent. `history` is [`Self::estimate_messages`]
    /// over the conversation's messages. `total` is the saturating sum of all three.
    #[must_use]
    pub fn estimate_conversation(conversation: &Conversation) -> ConversationTokenEstimate {
        let system_prompt = conversation
            .system_prompt
            .as_deref()
            .map_or(0, Self::estimate_tokens);
        let summary = conversation
            .summary
            .as_deref()
            .map_or(0, Self::estimate_tokens);
        let history = Self::estimate_messages(&conversation.messages);
        ConversationTokenEstimate {
            system_prompt,
            summary,
            history,
            // Saturate rather than overflow so the total stays a usable upper bound.
            total: system_prompt.saturating_add(summary).saturating_add(history),
        }
    }

    /// Estimates the tokens a tool definition contributes.
    ///
    /// The estimate sums four parts:
    /// - the tool's name;
    /// - the tool's description;
    /// - each parameter description plus a fixed per-parameter overhead;
    /// - a fixed schema overhead.
    ///
    /// All additions saturate at `u32::MAX`.
    #[must_use]
    pub fn estimate_tool_definition(tool: &ToolDefinition) -> u32 {
        // Heuristic constants. Presumably they cover the schema scaffolding and the
        // parameter names/types that are not estimated directly. TODO confirm.
        const PARAM_OVERHEAD_TOKENS: u32 = 2;
        const SCHEMA_OVERHEAD_TOKENS: u32 = 8;

        let name_tokens = Self::estimate_tokens(&tool.name);
        let desc_tokens = Self::estimate_tokens(&tool.description);
        // NOTE(review): only `.values()` are visited, so the property keys (presumably the
        // parameter names) are not estimated themselves.
        let param_tokens = tool
            .parameters
            .properties
            .values()
            .map(|p| Self::estimate_tokens(&p.description).saturating_add(PARAM_OVERHEAD_TOKENS))
            .fold(0, u32::saturating_add);
        name_tokens
            .saturating_add(desc_tokens)
            .saturating_add(param_tokens)
            .saturating_add(SCHEMA_OVERHEAD_TOKENS)
    }

    /// Sums [`Self::estimate_tool_definition`] over `tools`, saturating at `u32::MAX`.
    #[must_use]
    pub fn estimate_tool_definitions(tools: &[ToolDefinition]) -> u32 {
        tools
            .iter()
            .map(Self::estimate_tool_definition)
            .fold(0, u32::saturating_add)
    }
}
/// Returns the fixed token cost attributed to a message role.
///
/// Tool messages cost `3`. System, user and assistant messages cost `2`.
#[must_use]
pub fn role_token_cost(role: MessageRole) -> u32 {
    match role {
        MessageRole::Tool => 3,
        MessageRole::System | MessageRole::User | MessageRole::Assistant => 2,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected estimate for `text` under a given chars-per-token ratio.
    ///
    /// This is the UTF-8 byte length divided by the ratio, rounded up: the same
    /// formula `TokenEstimator::estimate_tokens` applies.
    // The test strings are a few dozen bytes long, so none of these casts lose information.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    fn ratio_estimate(text: &str, chars_per_token: f64) -> u32 {
        (text.len() as f64 / chars_per_token).ceil() as u32
    }

    #[test]
    fn empty_string_is_zero_tokens() {
        let estimate = TokenEstimator::estimate_tokens("");
        assert_eq!(estimate, 0);
    }

    #[test]
    fn short_ascii_text() {
        let estimate = TokenEstimator::estimate_tokens("Hello");
        assert_eq!(estimate, 2);
    }

    #[test]
    fn longer_ascii_text() {
        let pangram = "The quick brown fox jumps over the lazy dog";
        let estimate = TokenEstimator::estimate_tokens(pangram);
        assert_eq!(estimate, 11);
    }

    #[test]
    fn german_text_uses_language_class_ratio() {
        let sample = "Ünüöäßüöäßüöäßüöäß ÖÜÄ";
        let actual = TokenEstimator::estimate_tokens(sample);
        assert!(actual > 0);
        assert_eq!(actual, ratio_estimate(sample, 4.0));
    }

    #[test]
    fn message_includes_overhead() {
        let content = "Hello";
        let message = ChatMessage::user(content);
        let expected = TokenEstimator::estimate_tokens(content) + MESSAGE_OVERHEAD_TOKENS;
        assert_eq!(TokenEstimator::estimate_message(&message), expected);
    }

    #[test]
    fn conversation_estimate_breakdown() {
        let mut conversation = Conversation::with_system_prompt("You are helpful.");
        conversation.summary = Some(String::from("Previously discussed weather."));
        conversation.add_user_message("What's the weather?");
        conversation.add_assistant_message("It's sunny today.");

        let ConversationTokenEstimate {
            system_prompt,
            summary,
            history,
            total,
        } = TokenEstimator::estimate_conversation(&conversation);
        assert!(system_prompt > 0);
        assert!(summary > 0);
        assert!(history > 0);
        assert_eq!(total, system_prompt + summary + history);
    }

    #[test]
    fn empty_conversation_estimate() {
        let estimate = TokenEstimator::estimate_conversation(&Conversation::new());
        assert_eq!(estimate.total, 0);
    }

    #[test]
    fn single_char_is_at_least_one_token() {
        let estimate = TokenEstimator::estimate_tokens("a");
        assert_eq!(estimate, 1);
    }

    #[test]
    fn cjk_text_uses_language_class_ratio() {
        let sample = "你好世界这是测试文本";
        assert_eq!(
            TokenEstimator::estimate_tokens(sample),
            ratio_estimate(sample, 1.5)
        );
    }

    #[test]
    fn cyrillic_text_uses_language_class_ratio() {
        let sample = "Привет мир как дела";
        assert_eq!(
            TokenEstimator::estimate_tokens(sample),
            ratio_estimate(sample, 2.5)
        );
    }
}