use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use dashmap::DashMap;
use tiktoken_rs::CoreBPE;
use zeph_common::text::estimate_tokens;
use zeph_llm::provider::{Message, MessagePart};
// Process-wide, lazily-initialized cl100k_base BPE. `None` records a failed
// initialization so the chars/4 fallback is used without retrying.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
// Soft cap on memoized token counts per `TokenCounter` instance.
const CACHE_CAP: usize = 10_000;
// Inputs longer than this (in bytes) bypass both the BPE and the cache.
const MAX_INPUT_LEN: usize = 65_536;
// Heuristic constants for tool-schema token estimation (see
// `count_tool_schema_tokens` / `count_schema_value`). Presumably calibrated
// against a provider's function-calling token accounting — TODO confirm.
const FUNC_INIT: usize = 7;
const PROP_INIT: usize = 3;
const PROP_KEY: usize = 3;
// Negative adjustment, hence `isize`; the final total is clamped at zero.
const ENUM_INIT: isize = -3;
const ENUM_ITEM: usize = 3;
const FUNC_END: usize = 12;
// Fixed per-part overheads added on top of the counted text content.
const TOOL_USE_OVERHEAD: usize = 20;
const TOOL_RESULT_OVERHEAD: usize = 15;
const TOOL_OUTPUT_OVERHEAD: usize = 8;
const IMAGE_OVERHEAD: usize = 50;
// Flat token estimate applied to every image, regardless of its actual size.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
const THINKING_OVERHEAD: usize = 10;
/// Token counter backed by a process-shared tiktoken cl100k_base BPE with a
/// bounded per-instance memoization cache. Falls back to a chars/4 estimate
/// when the BPE failed to initialize or the input is oversized.
pub struct TokenCounter {
    // Shared process-wide BPE; `None` means initialization failed and the
    // `estimate_tokens` fallback is used for every count.
    bpe: &'static Option<CoreBPE>,
    // Memoized token counts, keyed by a 64-bit hash of the input text.
    cache: DashMap<u64, usize>,
    // Soft upper bound on cache entries; one entry is evicted per insert
    // once this capacity is reached.
    cache_cap: usize,
}
impl TokenCounter {
    /// Creates a counter backed by the process-wide cl100k_base BPE.
    ///
    /// The BPE is initialized at most once per process. If initialization
    /// fails, a warning is logged and every count uses the chars/4 fallback.
    #[must_use]
    pub fn new() -> Self {
        let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
            Ok(b) => Some(b),
            Err(e) => {
                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
                None
            }
        });
        Self {
            bpe,
            cache: DashMap::new(),
            cache_cap: CACHE_CAP,
        }
    }

    /// Returns the token count for `text`, using the BPE when available and
    /// the `estimate_tokens` heuristic otherwise.
    ///
    /// Results are memoized by a 64-bit hash of the text. Inputs longer than
    /// `MAX_INPUT_LEN` bytes skip both the BPE and the cache.
    #[must_use]
    pub fn count_tokens(&self, text: &str) -> usize {
        if text.is_empty() {
            return 0;
        }
        if text.len() > MAX_INPUT_LEN {
            // Oversized inputs are deliberately not cached: they are rare and
            // an exact BPE pass over them would be expensive.
            // (Was `zeph_common::text::estimate_tokens`; the name is already
            // imported at the top of the file — use it consistently.)
            return estimate_tokens(text);
        }
        let key = hash_text(text);
        if let Some(cached) = self.cache.get(&key) {
            return *cached;
        }
        let count = match self.bpe {
            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
            None => estimate_tokens(text),
        };
        if self.cache.len() >= self.cache_cap {
            // Evict one arbitrary entry to stay near capacity. NOTE(review):
            // under concurrent inserts the cache may briefly exceed
            // `cache_cap` by a few entries; this is a soft bound.
            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
            if let Some(k) = key_to_evict {
                self.cache.remove(&k);
            }
        }
        self.cache.insert(key, count);
        count
    }

    /// Counts tokens for a whole message.
    ///
    /// Structured `parts` take priority; the flat `content` string is only
    /// counted when `parts` is empty.
    #[must_use]
    pub fn count_message_tokens(&self, msg: &Message) -> usize {
        if msg.parts.is_empty() {
            return self.count_tokens(&msg.content);
        }
        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
    }

    /// Counts tokens for a single message part, adding the fixed per-part
    /// overhead constants for structured variants.
    #[must_use]
    fn count_part_tokens(&self, part: &MessagePart) -> usize {
        match part {
            // Plain-text-like parts: whitespace-only text counts as zero.
            MessagePart::Text { text }
            | MessagePart::Recall { text }
            | MessagePart::CodeContext { text }
            | MessagePart::Summary { text }
            | MessagePart::CrossSession { text } => {
                if text.trim().is_empty() {
                    return 0;
                }
                self.count_tokens(text)
            }
            MessagePart::ToolOutput {
                tool_name, body, ..
            } => {
                TOOL_OUTPUT_OVERHEAD
                    + self.count_tokens(tool_name.as_str())
                    + self.count_tokens(body)
            }
            // Tool-use input is serialized to JSON before counting.
            MessagePart::ToolUse { id, name, input } => {
                TOOL_USE_OVERHEAD
                    + self.count_tokens(id)
                    + self.count_tokens(name)
                    + self.count_tokens(&input.to_string())
            }
            MessagePart::ToolResult {
                tool_use_id,
                content,
                ..
            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
            // Images use a flat estimate regardless of the actual pixel data.
            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
            MessagePart::ThinkingBlock {
                thinking,
                signature,
            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
            // Redacted thinking carries opaque data; use the cheap estimate
            // instead of running it through the BPE.
            MessagePart::RedactedThinkingBlock { data } => {
                THINKING_OVERHEAD + estimate_tokens(data)
            }
            MessagePart::Compaction { summary } => self.count_tokens(summary),
        }
    }

    /// Estimates the token cost of a tool's JSON schema using the heuristic
    /// per-function / per-property / per-enum constants defined above.
    #[must_use]
    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
        let base = count_schema_value(self, schema);
        // `ENUM_INIT` is negative, so the sum is taken in signed space and
        // clamped at zero before converting back to `usize`.
        let total =
            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
        total.max(0).cast_unsigned()
    }
}
impl Default for TokenCounter {
fn default() -> Self {
Self::new()
}
}
/// Derives the 64-bit cache key for a piece of text via the std
/// `DefaultHasher` (deterministic for identical inputs within a process).
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    state.finish()
}
/// Recursively scores a JSON schema node using the heuristic constants:
/// objects cost `PROP_INIT` plus `PROP_KEY` + key tokens per property,
/// arrays cost `ENUM_ITEM` plus their elements, strings cost their token
/// count, and scalar leaves (bool/number/null) cost one token each.
fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
    match value {
        serde_json::Value::Object(map) => {
            PROP_INIT
                + map
                    .iter()
                    .map(|(key, val)| {
                        PROP_KEY + counter.count_tokens(key) + count_schema_value(counter, val)
                    })
                    .sum::<usize>()
        }
        serde_json::Value::Array(arr) => {
            ENUM_ITEM
                + arr
                    .iter()
                    .map(|item| count_schema_value(counter, item))
                    .sum::<usize>()
        }
        serde_json::Value::String(s) => counter.count_tokens(s),
        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // Shared "BPE unavailable" sentinel so tests can force the chars/4
    // fallback path deterministically.
    static BPE_NONE: Option<CoreBPE> = None;

    // Builds a counter that always uses the estimate fallback, with a
    // caller-chosen cache capacity (used for eviction tests).
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    // Message carrying only flat `content` and no structured parts.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    // With no parts, counting a message must equal counting its content.
    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    // A single Text part carries no overhead: it matches count_tokens exactly.
    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    // Structured ToolUse accounting (id + name + JSON input + overhead) must
    // exceed counting the flattened message content.
    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    // A compacted (empty-body) tool output should cost little more than the
    // fixed overhead plus the tool name.
    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    // Image parts ignore the payload size and always cost the flat constant.
    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    // ThinkingBlock counts both the thinking text and the signature.
    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    // Empty and whitespace-only text-like parts must count as zero tokens.
    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: " ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    // The message total must equal the sum of its per-part counts.
    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    // When parts are present, flat `content` must be ignored entirely.
    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    // ToolResult = fixed overhead + id tokens + content tokens.
    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    // Memoized lookups must agree with the first computation.
    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    // Fallback mode is chars/4: 8 chars -> 2 tokens, empty -> 0.
    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    // Oversized input must use the estimate and must not populate the cache.
    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        assert!(counter.cache.is_empty());
    }

    // For multi-byte unicode text the BPE count should diverge from the
    // chars/4 estimate, proving the BPE path is actually exercised.
    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    // Golden value pinning the schema heuristic for a representative schema.
    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        assert_eq!(tokens, 82);
    }

    // The OnceLock must hand every counter the same static BPE reference.
    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    // At capacity, each insert evicts one entry so the size stays at cap.
    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}