use crate::inference::Message;
use crate::provider::ProviderType;
/// Counts (or estimates) how many tokens a piece of text or a list of
/// messages occupies.
///
/// Implementors must be `Send + Sync`. Both implementations in this file are
/// byte-length heuristics rather than real tokenizers.
pub trait TokenCounter: Send + Sync {
/// Returns the token count for a single string.
fn count(&self, text: &str) -> u32;
/// Returns the token count for a whole slice of messages.
fn count_messages(&self, messages: &[Message]) -> u32;
}
/// Provider-agnostic token estimator that assumes a fixed number of bytes
/// per token.
///
/// This is a zero-sized unit struct, so it is free to copy, compare and
/// default-construct.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SimpleTokenCounter;
impl SimpleTokenCounter {
/// Assumed average number of UTF-8 bytes per token; the byte length of a
/// string is divided by this value to obtain its estimate.
const BYTES_PER_TOKEN: f64 = 4.0;
/// Flat number of tokens added for every message on top of its content.
const MESSAGE_OVERHEAD: u32 = 4;
}
impl TokenCounter for SimpleTokenCounter {
    /// Estimates the tokens in `text` as its UTF-8 byte length divided by
    /// `BYTES_PER_TOKEN`, rounded up to the next whole token.
    #[inline]
    fn count(&self, text: &str) -> u32 {
        let byte_len = text.len() as f64;
        let estimate = (byte_len / Self::BYTES_PER_TOKEN).ceil();
        // Float-to-integer `as` casts saturate, so an oversized estimate
        // clamps to `u32::MAX` instead of wrapping.
        estimate as u32
    }

    /// Estimates the tokens used by `messages`.
    ///
    /// Each message contributes its content estimate, `MESSAGE_OVERHEAD`, and
    /// the estimates for the name and serialized arguments of every tool call
    /// it carries. A fixed 3 tokens is added once per call (presumably for
    /// reply priming -- TODO confirm). The total saturates at `u32::MAX`.
    fn count_messages(&self, messages: &[Message]) -> u32 {
        // Saturating addition of unsigned values clamps at `u32::MAX` however
        // the terms are grouped, so each message can be costed on its own.
        let message_cost = |msg: &Message| -> u32 {
            let content_tokens = self.count(&msg.content.text());
            let tool_tokens = msg.tool_calls.iter().fold(0u32, |acc, tc| {
                acc.saturating_add(self.count(&tc.name))
                    .saturating_add(self.count(&tc.arguments.to_string()))
            });
            content_tokens
                .saturating_add(Self::MESSAGE_OVERHEAD)
                .saturating_add(tool_tokens)
        };

        messages
            .iter()
            .map(message_cost)
            .fold(0u32, u32::saturating_add)
            .saturating_add(3)
    }
}
/// Token estimator whose heuristics are chosen per provider.
///
/// Values are obtained through [`ProviderTokenCounter::for_provider`]; the
/// fields are private, so the ratio cannot be set to an arbitrary value from
/// outside this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProviderTokenCounter {
    /// Assumed average number of UTF-8 bytes that make up one token.
    bytes_per_token: f64,
    /// Flat number of tokens added for every message on top of its content.
    message_overhead: u32,
}
impl ProviderTokenCounter {
    /// Builds a counter using the bytes-per-token ratio associated with
    /// `provider`.
    ///
    /// Providers without a dedicated ratio fall back to 4.0 bytes per token.
    /// Every provider currently gets the same per-message overhead of 4.
    #[must_use]
    pub fn for_provider(provider: ProviderType) -> Self {
        // The overhead does not vary by provider today, so only the ratio
        // needs to be selected.
        const MESSAGE_OVERHEAD: u32 = 4;

        let bytes_per_token = match provider {
            ProviderType::Anthropic => 3.5,
            ProviderType::Ollama
            | ProviderType::LlamaCpp
            | ProviderType::LmStudio
            | ProviderType::LocalAi => 3.7,
            ProviderType::OpenAi
            | ProviderType::DeepSeek
            | ProviderType::Mistral
            | ProviderType::Groq
            | ProviderType::Grok
            | ProviderType::OpenRouter => 3.8,
            _ => 4.0,
        };

        Self {
            bytes_per_token,
            message_overhead: MESSAGE_OVERHEAD,
        }
    }
}
impl TokenCounter for ProviderTokenCounter {
    /// Estimates the tokens in `text` as its UTF-8 byte length divided by this
    /// counter's `bytes_per_token`, rounded up to the next whole token.
    #[inline]
    fn count(&self, text: &str) -> u32 {
        let byte_len = text.len() as f64;
        let estimate = (byte_len / self.bytes_per_token).ceil();
        // Float-to-integer `as` casts saturate, so an oversized estimate
        // clamps to `u32::MAX` instead of wrapping.
        estimate as u32
    }

    /// Estimates the tokens used by `messages`.
    ///
    /// Each message contributes its content estimate, this counter's
    /// `message_overhead`, and the estimates for the name and serialized
    /// arguments of every tool call it carries. A fixed 3 tokens is added once
    /// per call (presumably for reply priming -- TODO confirm). The total
    /// saturates at `u32::MAX`.
    fn count_messages(&self, messages: &[Message]) -> u32 {
        // Saturating addition of unsigned values clamps at `u32::MAX` however
        // the terms are grouped, so each message can be costed on its own.
        let message_cost = |msg: &Message| -> u32 {
            let content_tokens = self.count(&msg.content.text());
            let tool_tokens = msg.tool_calls.iter().fold(0u32, |acc, tc| {
                acc.saturating_add(self.count(&tc.name))
                    .saturating_add(self.count(&tc.arguments.to_string()))
            });
            content_tokens
                .saturating_add(self.message_overhead)
                .saturating_add(tool_tokens)
        };

        messages
            .iter()
            .map(message_cost)
            .fold(0u32, u32::saturating_add)
            .saturating_add(3)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::Role;

    /// Builds an assistant message carrying one `get_weather` tool call, shared
    /// by the tool-call tests for both counters.
    fn assistant_message_with_tool_call() -> Message {
        Message {
            role: Role::Assistant,
            content: "Let me check that.".into(),
            tool_call_id: None,
            tool_calls: vec![crate::tools::ToolCall {
                id: "call_1".into(),
                name: "get_weather".into(),
                arguments: serde_json::json!({"city": "London"}),
            }],
        }
    }

    #[test]
    fn simple_counter_empty_string() {
        let counter = SimpleTokenCounter;
        assert_eq!(counter.count(""), 0);
    }

    #[test]
    fn simple_counter_short_text() {
        let counter = SimpleTokenCounter;
        // 5 bytes / 4.0 = 1.25, rounded up.
        assert_eq!(counter.count("hello"), 2);
    }

    #[test]
    fn simple_counter_longer_text() {
        let counter = SimpleTokenCounter;
        let text = "a".repeat(100);
        assert_eq!(counter.count(&text), 25);
    }

    #[test]
    fn simple_counter_messages() {
        let counter = SimpleTokenCounter;
        let messages = vec![
            Message::new(Role::System, "You are helpful."),
            Message::new(Role::User, "Hello there"),
        ];
        let count = counter.count_messages(&messages);
        // (4 + 4) for the system message, (3 + 4) for the user message, + 3.
        assert_eq!(count, 18);
    }

    #[test]
    fn simple_counter_empty_messages() {
        let counter = SimpleTokenCounter;
        // Only the fixed trailing 3 tokens remain when there are no messages.
        assert_eq!(counter.count_messages(&[]), 3);
    }

    #[test]
    fn provider_counter_openai() {
        let counter = ProviderTokenCounter::for_provider(ProviderType::OpenAi);
        let text = "a".repeat(100);
        assert_eq!(counter.count(&text), 27);
    }

    #[test]
    fn provider_counter_anthropic() {
        let counter = ProviderTokenCounter::for_provider(ProviderType::Anthropic);
        let text = "a".repeat(100);
        assert_eq!(counter.count(&text), 29);
    }

    #[test]
    fn provider_counter_local() {
        let counter = ProviderTokenCounter::for_provider(ProviderType::Ollama);
        let text = "a".repeat(100);
        assert_eq!(counter.count(&text), 28);
    }

    #[test]
    fn provider_counter_messages() {
        let counter = ProviderTokenCounter::for_provider(ProviderType::OpenAi);
        let messages = vec![Message::new(Role::User, "What is Rust?")];
        let count = counter.count_messages(&messages);
        // ceil(13 / 3.8) = 4 content tokens, + 4 overhead, + 3.
        assert_eq!(count, 11);
    }

    #[test]
    fn saturation_on_large_input() {
        // Despite the name, 10 MB is nowhere near the `u32` ceiling; this
        // pins the exact estimate for a large input instead of merely
        // checking that it is non-zero.
        let counter = SimpleTokenCounter;
        let text = "a".repeat(10_000_000);
        let count = counter.count(&text);
        assert_eq!(count, 2_500_000);
    }

    #[test]
    fn simple_counter_with_tool_calls() {
        let counter = SimpleTokenCounter;
        let messages = [assistant_message_with_tool_call()];
        let count = counter.count_messages(&messages);
        assert!(count > 15, "expected more than 15 tokens, got {count}");
    }

    #[test]
    fn provider_counter_with_tool_calls() {
        // Previously this test built a `SimpleTokenCounter`, leaving the
        // provider counter's tool-call path uncovered.
        let counter = ProviderTokenCounter::for_provider(ProviderType::OpenAi);
        let messages = [assistant_message_with_tool_call()];
        let count = counter.count_messages(&messages);
        assert!(count > 15, "expected more than 15 tokens, got {count}");
    }
}