use crate::chat::{ContentBlock, Message, SystemPrompt};
/// Heuristically estimates how many tokens `text` will occupy.
///
/// * Pure-ASCII input is charged one token per three bytes, rounded up.
/// * Any other input is split into CJK and non-CJK characters. Two candidate
///   estimates are computed and the larger one is returned:
///   - a per-class ratio of 0.3 tokens per non-CJK character plus 0.6 tokens
///     per CJK character, each term rounded up separately;
///   - one token per three characters (of either class), rounded up.
///
/// The empty string yields `0`.
#[must_use]
pub fn estimate_text_tokens(text: &str) -> usize {
    if text.is_ascii() {
        // For ASCII, bytes and characters coincide, so the byte length suffices.
        text.len().div_ceil(3)
    } else {
        let (cjk_chars, other_chars) = count_cjk_and_other_chars(text);

        // Per-class ratio: 3/10 for non-CJK characters, 6/10 for CJK characters.
        let other_tokens = other_chars.saturating_mul(3).div_ceil(10);
        let cjk_tokens = cjk_chars.saturating_mul(6).div_ceil(10);
        let ratio_estimate = other_tokens.saturating_add(cjk_tokens);

        // Class-agnostic floor: one token per three characters.
        let chars_third_estimate = cjk_chars.saturating_add(other_chars).div_ceil(3);

        // Never report less than either calibration.
        chars_third_estimate.max(ratio_estimate)
    }
}
/// Flat token overhead added once per message by the request-level estimators.
pub const MESSAGE_FRAMING_TOKENS: usize = 12;
/// Flat token overhead added once per request by the request-level estimators,
/// independent of how many messages the request contains.
pub const SESSION_FRAMING_TOKENS: usize = 48;
/// Numerator of the scaling factor applied to the summed raw message tokens
/// (together with the denominator below this is 3/2, rounded up).
const MESSAGE_TOKENS_NUMERATOR: usize = 3;
/// Denominator of the scaling factor applied to the summed raw message tokens.
const MESSAGE_TOKENS_DENOMINATOR: usize = 2;
/// Stateless handle for the token-estimation routines.
///
/// This is a field-less unit struct that is `Copy` and `Default`, so it can be
/// created on the spot (`TokenEstimator`) and passed around by value.
#[derive(Debug, Clone, Copy, Default)]
pub struct TokenEstimator;
impl TokenEstimator {
#[must_use]
#[inline]
pub fn estimate_text(self, text: &str) -> usize {
estimate_text_tokens(text)
}
#[must_use]
pub fn estimate_block(self, block: &ContentBlock, include_thinking: bool) -> usize {
match block {
ContentBlock::Text { text, .. } => estimate_text_tokens(text),
ContentBlock::Thinking { thinking } => {
if include_thinking {
estimate_text_tokens(thinking)
} else {
0
}
}
ContentBlock::ToolUse { input, .. } => estimate_text_tokens(&input.to_string()),
ContentBlock::ToolResult { content, .. } => estimate_text_tokens(content),
ContentBlock::ServerToolUse { input, .. } => estimate_text_tokens(&input.to_string()),
ContentBlock::ToolSearchToolResult { content, .. } => {
estimate_text_tokens(&content.to_string())
}
ContentBlock::CodeExecutionToolResult { content, .. } => {
estimate_text_tokens(&content.to_string())
}
}
}
#[must_use]
pub fn estimate_message(self, message: &Message, include_thinking: bool) -> usize {
message
.content
.iter()
.map(|block| self.estimate_block(block, include_thinking))
.sum()
}
#[must_use]
pub fn estimate_system(self, system: Option<&SystemPrompt>) -> usize {
match system {
Some(SystemPrompt::Text(text)) => estimate_text_tokens(text),
Some(SystemPrompt::Blocks(blocks)) => {
blocks.iter().map(|b| estimate_text_tokens(&b.text)).sum()
}
None => 0,
}
}
#[must_use]
pub fn estimate_request_input(
self,
messages: &[Message],
system: Option<&SystemPrompt>,
include_thinking: bool,
) -> usize {
let raw: usize = messages
.iter()
.map(|m| self.estimate_message(m, include_thinking))
.sum();
let message_tokens = raw
.saturating_mul(MESSAGE_TOKENS_NUMERATOR)
.div_ceil(MESSAGE_TOKENS_DENOMINATOR);
let system_tokens = self.estimate_system(system);
let framing = messages
.len()
.saturating_mul(MESSAGE_FRAMING_TOKENS)
.saturating_add(SESSION_FRAMING_TOKENS);
message_tokens
.saturating_add(system_tokens)
.saturating_add(framing)
}
#[must_use]
pub fn estimate_request_input_with_selective_thinking(
self,
messages: &[Message],
system: Option<&SystemPrompt>,
) -> usize {
let raw: usize = messages
.iter()
.map(|m| {
let has_tool_use = m
.content
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
self.estimate_message(m, has_tool_use)
})
.sum();
let message_tokens = raw
.saturating_mul(MESSAGE_TOKENS_NUMERATOR)
.div_ceil(MESSAGE_TOKENS_DENOMINATOR);
let system_tokens = self.estimate_system(system);
let framing = messages
.len()
.saturating_mul(MESSAGE_FRAMING_TOKENS)
.saturating_add(SESSION_FRAMING_TOKENS);
message_tokens
.saturating_add(system_tokens)
.saturating_add(framing)
}
}
/// Walks `text` once and returns `(cjk, other)`: the number of characters that
/// `is_cjk_char` accepts and the number it rejects.
fn count_cjk_and_other_chars(text: &str) -> (usize, usize) {
    text.chars()
        .fold((0usize, 0usize), |(cjk_total, other_total), ch| {
            if is_cjk_char(ch) {
                (cjk_total + 1, other_total)
            } else {
                (cjk_total, other_total + 1)
            }
        })
}
/// Reports whether `ch` lies in one of the code-point ranges this module
/// bills at the CJK rate.
///
/// The ranges are inclusive on both ends; kana and Hangul are not included.
fn is_cjk_char(ch: char) -> bool {
    /// Inclusive `(first, last)` code-point bounds.
    const CJK_RANGES: [(u32, u32); 5] = [
        (0x4E00, 0x9FFF), // CJK Unified Ideographs
        (0x3400, 0x4DBF), // CJK Unified Ideographs Extension A
        (0x3000, 0x303F), // CJK Symbols and Punctuation
        (0xFF00, 0xFFEF), // Halfwidth and Fullwidth Forms
        (0x2E80, 0x2FDF), // CJK Radicals Supplement and Kangxi Radicals
    ];

    let code_point = u32::from(ch);
    CJK_RANGES
        .iter()
        .any(|&(first, last)| (first..=last).contains(&code_point))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"user"` message holding a single text block with `text`.
    fn user_text_message(text: &str) -> Message {
        Message {
            role: "user".to_string(),
            content: vec![ContentBlock::Text {
                text: text.to_string(),
                cache_control: None,
            }],
        }
    }

    /// The empty string costs nothing.
    #[test]
    fn empty_text_is_zero() {
        let tokens = estimate_text_tokens("");
        assert_eq!(tokens, 0);
    }

    /// Thirty ASCII bytes are charged at one token per three bytes.
    #[test]
    fn ascii_uses_chars_third_envelope() {
        let thirty_ascii = "a".repeat(30);
        let tokens = estimate_text_tokens(&thirty_ascii);
        assert_eq!(tokens, 10);
    }

    /// Ten CJK ideographs are charged at the 0.6-per-character ratio.
    #[test]
    fn cjk_uses_deepseek_envelope() {
        let ten_ideographs = "汉".repeat(10);
        let tokens = estimate_text_tokens(&ten_ideographs);
        assert_eq!(tokens, 6);
    }

    /// The estimate is at least as large as each of the two calibrations.
    #[test]
    fn never_below_either_legacy_calibration() {
        let samples = [
            "hello world",
            "纯中文内容若干字符",
            "mixed 中英 content with 标点。",
            "fn main() { println!(\"hi\"); }",
        ];
        for text in samples {
            let (cjk, other) = count_cjk_and_other_chars(text);
            let ratio_floor =
                other.saturating_mul(3).div_ceil(10) + cjk.saturating_mul(6).div_ceil(10);
            let legacy_floor = text.chars().count().div_ceil(3);

            let estimate = estimate_text_tokens(text);
            assert!(estimate >= legacy_floor, "below legacy core: {text}");
            assert!(estimate >= ratio_floor, "below DeepSeek ratio: {text}");
        }
    }

    /// Full-width punctuation and ideographs are CJK; ASCII is not.
    #[test]
    fn cjk_classification_covers_fullwidth_punctuation() {
        for ch in ['。', ',', '汉'] {
            assert!(is_cjk_char(ch));
        }
        for ch in ['a', ' '] {
            assert!(!is_cjk_char(ch));
        }
    }

    /// The method form and the free function agree.
    #[test]
    fn token_estimator_estimate_text_matches_free_fn() {
        let estimator = TokenEstimator;
        let samples = ["hello", "世界", "mixed 中英 text", ""];
        for sample in samples {
            let via_method = estimator.estimate_text(sample);
            let via_free_fn = estimate_text_tokens(sample);
            assert_eq!(via_method, via_free_fn);
        }
    }

    /// Thinking blocks only count when the flag asks for them.
    #[test]
    fn token_estimator_exclude_thinking_when_flag_false() {
        let estimator = TokenEstimator;
        let thinking_block = ContentBlock::Thinking {
            thinking: "lots of reasoning".to_string(),
        };
        let excluded = estimator.estimate_block(&thinking_block, false);
        let included = estimator.estimate_block(&thinking_block, true);
        assert_eq!(excluded, 0);
        assert!(included > 0);
    }

    /// A one-message request equals scaled raw tokens plus both framing terms.
    #[test]
    fn token_estimator_request_input_formula() {
        let estimator = TokenEstimator;
        let messages = vec![user_text_message("hello world")];

        let raw = estimator.estimate_message(&messages[0], true);
        let scaled = raw.saturating_mul(3).div_ceil(2);
        let expected = scaled + (MESSAGE_FRAMING_TOKENS + SESSION_FRAMING_TOKENS);

        let actual = estimator.estimate_request_input(&messages, None, true);
        assert_eq!(actual, expected);
    }

    /// Without thinking blocks, all three request paths give the same total.
    #[test]
    fn three_path_consistency_no_thinking_blocks() {
        let estimator = TokenEstimator;
        let long_ascii = "x".repeat(1000);
        let long_cjk = "汉".repeat(100);
        let samples = [
            "fn main() { println!(\"Hello, world!\"); }",
            "use std::collections::HashMap;",
            "这是一段中文内容,用于测试 CJK 字符计费。",
            "mixed 中英 content: struct Foo { bar: u32 }",
            "",
            long_ascii.as_str(),
            long_cjk.as_str(),
        ];

        for text in samples {
            let messages = vec![user_text_message(text)];

            let path1 = estimator.estimate_request_input(&messages, None, true);
            let path2 = estimator.estimate_request_input_with_selective_thinking(&messages, None);
            let path3 = estimator.estimate_request_input(&messages, None, false);

            assert_eq!(path1, path2, "path1 vs path2 diverge for: {text:?}");
            assert_eq!(path1, path3, "path1 vs path3 diverge for: {text:?}");

            // Relative spread between the largest and smallest path, in percent.
            let highest = path1.max(path2).max(path3) as f64;
            let lowest = path1.min(path2).min(path3) as f64;
            if highest > 0.0 {
                let deviation_pct = (highest - lowest) / highest * 100.0;
                assert!(
                    deviation_pct < 1.0,
                    "deviation {deviation_pct:.2}% >= 1% for: {text:?}"
                );
            }
        }
    }
}