use super::types::ConversationItem;
/// Common interface for counting tokens and reporting a model's context limit.
///
/// Implementors must be `Send + Sync`.
pub trait TokenCounter: Send + Sync {
    /// Returns the number of tokens `text` is counted as by this counter.
    fn count_tokens(&self, text: &str) -> usize;

    /// Returns the combined count for every item in `messages`.
    ///
    /// The default implementation counts each item's `content` via
    /// `count_tokens` and adds a flat 4 per item (presumably per-message
    /// framing overhead — confirm against the target chat format).
    fn count_messages(&self, messages: &[ConversationItem]) -> usize {
        let mut total = 0usize;
        for item in messages {
            total += self.count_tokens(&item.content) + 4;
        }
        total
    }

    /// Returns the context limit this counter was configured with.
    fn model_context_limit(&self) -> usize;
}
/// Returns the context limit assumed for `model`.
///
/// The name is lowercased and then checked against an ordered list of
/// substring rules; the first rule with a matching substring decides the
/// result. Names that match no rule get `32_000`.
pub fn context_limit_for_model(model: &str) -> usize {
    /// Limit used when no rule matches.
    const FALLBACK_LIMIT: usize = 32_000;

    // Ordered `(substrings, limit)` rules. Order matters: earlier rules win
    // when a name contains substrings from more than one rule.
    const RULES: &[(&[&str], usize)] = &[
        (&["claude"], 200_000),
        (&["gpt-4o", "gpt-4-turbo", "o1", "o3"], 128_000),
        (&["gpt-4"], 128_000),
        (&["gemini"], 1_000_000),
        (&["qwen"], 131_072),
        (&["llama"], 128_000),
        (&["mistral", "mixtral"], 32_000),
        (&["deepseek"], 128_000),
        (&["kimi", "moonshot"], 128_000),
        (&["command-r"], 128_000),
    ];

    let name = model.to_lowercase();
    for (needles, limit) in RULES {
        if needles.iter().any(|needle| name.contains(*needle)) {
            return *limit;
        }
    }
    FALLBACK_LIMIT
}
/// Token counter backed by a `tiktoken_rs` BPE encoder.
pub struct TiktokenCounter {
/// Encoder used to tokenize text when counting.
bpe: tiktoken_rs::CoreBPE,
/// Value reported by `model_context_limit`.
context_limit: usize,
}
impl TiktokenCounter {
    /// Builds a counter for `model`.
    ///
    /// When the lowercased name contains `gpt-4o`, `o1` or `o3`, the
    /// `o200k_base` encoding is tried first. For every other name, or if
    /// `o200k_base` fails to load, `cl100k_base` is used instead. The context
    /// limit comes from `context_limit_for_model`.
    ///
    /// # Panics
    ///
    /// Panics if `cl100k_base` is needed and cannot be loaded.
    pub fn for_model(model: &str) -> Self {
        let name = model.to_lowercase();
        let wants_o200k = ["gpt-4o", "o1", "o3"]
            .iter()
            .any(|marker| name.contains(*marker));

        // A failed o200k load is deliberately not an error: fall through to cl100k.
        let preferred = if wants_o200k {
            tiktoken_rs::o200k_base().ok()
        } else {
            None
        };
        let bpe = match preferred {
            Some(encoder) => encoder,
            None => tiktoken_rs::cl100k_base().expect("tiktoken-rs failed to load cl100k_base"),
        };

        Self {
            bpe,
            context_limit: context_limit_for_model(model),
        }
    }
}
impl TokenCounter for TiktokenCounter {
    /// Encodes `text` with the stored encoder (special tokens included) and
    /// returns how many tokens were produced.
    fn count_tokens(&self, text: &str) -> usize {
        let encoded = self.bpe.encode_with_special_tokens(text);
        encoded.len()
    }

    /// Returns the context limit chosen when the counter was built.
    fn model_context_limit(&self) -> usize {
        let limit = self.context_limit;
        limit
    }
}
/// Picks a [`TokenCounter`] implementation for `model`.
///
/// Names whose lowercase form contains `gpt`, `claude`, `o1`, `o3` or
/// `text-embedding` get a [`TiktokenCounter`]. Everything else gets a
/// [`HeuristicTokenCounter`] configured with the limit from
/// `context_limit_for_model`.
pub fn create_token_counter(model: &str) -> Box<dyn TokenCounter> {
    // Substrings that route a model name to the tiktoken-backed counter.
    const TIKTOKEN_MARKERS: [&str; 5] = ["gpt", "claude", "o1", "o3", "text-embedding"];

    let name = model.to_lowercase();
    if TIKTOKEN_MARKERS.iter().any(|marker| name.contains(*marker)) {
        return Box::new(TiktokenCounter::for_model(model));
    }

    let limit = context_limit_for_model(model);
    Box::new(HeuristicTokenCounter::new(limit))
}
/// Token counter that estimates counts from text length instead of running a
/// real tokenizer.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so the type can be logged,
/// duplicated and compared like any other plain-data value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeuristicTokenCounter {
    /// Value reported by `model_context_limit`.
    context_limit: usize,
}

impl HeuristicTokenCounter {
    /// Creates a counter that reports `context_limit` as its context limit.
    pub fn new(context_limit: usize) -> Self {
        Self { context_limit }
    }
}
impl TokenCounter for HeuristicTokenCounter {
    /// Estimates the token count of `text` from its UTF-8 byte length.
    ///
    /// The base estimate is the byte length divided by 3.5, rounded up. One
    /// seventh of that base (integer division, about 14%) is then added on
    /// top, presumably as headroom so the estimate errs high.
    fn count_tokens(&self, text: &str) -> usize {
        let byte_len = text.len() as f64;
        let base = (byte_len / 3.5).ceil() as usize;
        let padding = base / 7;
        base + padding
    }

    /// Returns the context limit supplied at construction.
    fn model_context_limit(&self) -> usize {
        let limit = self.context_limit;
        limit
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A short ASCII phrase yields a small, non-zero estimate.
    #[test]
    fn heuristic_counter_counts_tokens() {
        let estimate = HeuristicTokenCounter::new(32_000).count_tokens("hello world");
        assert!(estimate > 0, "should count some tokens");
        assert!(estimate < 20, "heuristic should be reasonable for short text");
    }

    /// Empty input is estimated as zero tokens.
    #[test]
    fn heuristic_counter_empty_string() {
        let heuristic = HeuristicTokenCounter::new(32_000);
        let estimate = heuristic.count_tokens("");
        assert_eq!(estimate, 0);
    }

    /// The limit passed to `new` is reported back unchanged.
    #[test]
    fn heuristic_counter_context_limit() {
        let heuristic = HeuristicTokenCounter::new(128_000);
        let limit = heuristic.model_context_limit();
        assert_eq!(limit, 128_000);
    }

    /// `gpt-4o` tokenizes a short greeting compactly and reports 128k.
    #[test]
    fn tiktoken_counter_counts_gpt4o() {
        let tiktoken = TiktokenCounter::for_model("gpt-4o");
        let count = tiktoken.count_tokens("Hello, world!");
        assert!(count > 0);
        assert!(
            count < 10,
            "short greeting should be under 10 tokens, got {count}"
        );
        let limit = tiktoken.model_context_limit();
        assert_eq!(limit, 128_000);
    }

    /// Claude model names still produce a working counter with a 200k limit.
    #[test]
    fn tiktoken_counter_counts_claude() {
        let tiktoken = TiktokenCounter::for_model("claude-sonnet-4-5-20250929");
        let tokens = tiktoken.count_tokens("Hello, world!");
        assert!(tokens > 0);
        let limit = tiktoken.model_context_limit();
        assert_eq!(limit, 200_000);
    }

    /// The factory yields a usable counter with a 128k limit for OpenAI names.
    #[test]
    fn factory_returns_tiktoken_for_openai() {
        let boxed = create_token_counter("gpt-4o");
        let tokens = boxed.count_tokens("Hello");
        assert!(tokens > 0);
        let limit = boxed.model_context_limit();
        assert_eq!(limit, 128_000);
    }

    /// The factory yields a usable counter with a 200k limit for Claude names.
    #[test]
    fn factory_returns_tiktoken_for_claude() {
        let boxed = create_token_counter("claude-haiku-4-5-20251001");
        let tokens = boxed.count_tokens("Hello");
        assert!(tokens > 0);
        let limit = boxed.model_context_limit();
        assert_eq!(limit, 200_000);
    }

    /// Unrecognized names get a counter reporting the 32k fallback limit.
    #[test]
    fn factory_returns_heuristic_for_unknown() {
        let boxed = create_token_counter("my-custom-local-model");
        let tokens = boxed.count_tokens("Hello");
        assert!(tokens > 0);
        let limit = boxed.model_context_limit();
        assert_eq!(limit, 32_000);
    }
}