use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, r50k_base, CoreBPE};
/// Errors reported by the token-counting helpers.
#[derive(Debug, thiserror::Error)]
pub enum TokenError {
/// A tokenizer operation failed; the payload is the underlying error's message.
///
/// Besides tokenizer construction, `TokenCounter::decode` also maps its
/// failures to this variant, so the "Failed to initialize tokenizer" prefix
/// can appear on decode errors as well.
#[error("Failed to initialize tokenizer: {0}")]
InitError(String),
}
/// Selects and constructs the BPE tokenizer used for `model`.
///
/// Matching is case-insensitive. Families are tried in the order below and
/// the first match wins:
///
/// 1. `o200k_base` — names containing `gpt-5`, `gpt-4o` (this also covers
///    `chatgpt-4o`), `gpt-4.1` or `gpt-4.5`, and names starting with `o1`,
///    `o3` or `o4`.
/// 2. `cl100k_base` — names containing `gpt-4`, `text-embedding`, `claude`,
///    `gemini` or `deepseek`.
/// 3. `p50k_base` — names containing `davinci` or `code-`.
/// 4. `r50k_base` — names containing `ada`, `babbage` or `curie`.
/// 5. Any other name falls back to `cl100k_base`.
///
/// NOTE(review): `cl100k_base` is presumably only an approximation for the
/// non-OpenAI names in family 2 — confirm with the callers that rely on it.
///
/// # Errors
///
/// Returns [`TokenError::InitError`] carrying the underlying error message
/// when the selected tokenizer cannot be constructed.
fn get_tokenizer_for_model(model: &str) -> Result<CoreBPE, TokenError> {
    // `gpt-4.1` / `gpt-4.5` must be listed here: they would otherwise be
    // swallowed by the generic `gpt-4` rule below and get the wrong encoding.
    const O200K_CONTAINS: &[&str] = &["gpt-5", "gpt-4o", "gpt-4.1", "gpt-4.5"];
    const O200K_PREFIXES: &[&str] = &["o1", "o3", "o4"];
    const CL100K_CONTAINS: &[&str] =
        &["gpt-4", "text-embedding", "claude", "gemini", "deepseek"];
    const P50K_CONTAINS: &[&str] = &["davinci", "code-"];
    const R50K_CONTAINS: &[&str] = &["ada", "babbage", "curie"];

    let name = model.to_lowercase();
    let contains_any = |needles: &[&str]| needles.iter().any(|&needle| name.contains(needle));
    let starts_with_any =
        |prefixes: &[&str]| prefixes.iter().any(|&prefix| name.starts_with(prefix));

    // Order matters: the o200k family has to be tested before the broader
    // `gpt-4` substring, which would also match `gpt-4o`, `gpt-4.1`, etc.
    let tokenizer = if contains_any(O200K_CONTAINS) || starts_with_any(O200K_PREFIXES) {
        o200k_base()
    } else if contains_any(CL100K_CONTAINS) {
        cl100k_base()
    } else if contains_any(P50K_CONTAINS) {
        p50k_base()
    } else if contains_any(R50K_CONTAINS) {
        r50k_base()
    } else {
        // Unknown model: default encoding.
        cl100k_base()
    };

    tokenizer.map_err(|e| TokenError::InitError(e.to_string()))
}
/// Counts the tokens `text` encodes to under the tokenizer selected for `model`.
///
/// The tokenizer is looked up through `get_tokenizer_for_model` on every call
/// and the text is encoded with `encode_with_special_tokens`.
///
/// # Errors
///
/// Propagates the [`TokenError`] returned when the tokenizer cannot be built.
pub fn count_tokens(text: &str, model: &str) -> Result<usize, TokenError> {
    get_tokenizer_for_model(model)
        .map(|tokenizer| tokenizer.encode_with_special_tokens(text).len())
}
/// Counts the tokens in `text` using the tokenizer selected for `"gpt-4"`.
///
/// # Errors
///
/// Propagates any [`TokenError`] produced by [`count_tokens`].
pub fn count_tokens_default(text: &str) -> Result<usize, TokenError> {
    // Model name whose tokenizer is used when the caller does not pick one.
    const DEFAULT_MODEL: &str = "gpt-4";

    count_tokens(text, DEFAULT_MODEL)
}
/// A token counter bound to the tokenizer of a single model.
///
/// The tokenizer is resolved once in `TokenCounter::new` and then reused by
/// `count`, `encode` and `decode`.
pub struct TokenCounter {
/// Tokenizer selected for `model` when the counter was constructed.
bpe: CoreBPE,
/// Model name exactly as passed to the constructor (stored without lowercasing).
model: String,
}
impl TokenCounter {
    /// Builds a counter for `model`, resolving its tokenizer up front.
    ///
    /// # Errors
    ///
    /// Propagates the [`TokenError`] returned when the tokenizer cannot be
    /// constructed.
    pub fn new(model: &str) -> Result<Self, TokenError> {
        get_tokenizer_for_model(model).map(|bpe| Self {
            bpe,
            model: model.to_owned(),
        })
    }

    /// Returns how many tokens `text` encodes to.
    pub fn count(&self, text: &str) -> usize {
        self.encode(text).len()
    }

    /// Encodes `text` into token ids via `encode_with_special_tokens`.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.bpe.encode_with_special_tokens(text)
    }

    /// Decodes token ids back into a string.
    ///
    /// # Errors
    ///
    /// A decoding failure is reported as [`TokenError::InitError`] carrying
    /// the underlying error's message.
    pub fn decode(&self, tokens: &[u32]) -> Result<String, TokenError> {
        // The tokenizer takes ownership of the ids, so hand it a copy.
        let owned_ids = tokens.to_vec();
        match self.bpe.decode(owned_ids) {
            Ok(text) => Ok(text),
            Err(err) => Err(TokenError::InitError(err.to_string())),
        }
    }

    /// Returns the model name this counter was created with.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
/// Estimates the token cost of one message.
///
/// The estimate is the token count of `content` plus the token count of
/// `role` plus a fixed allowance of 4 tokens (presumably the per-message
/// framing overhead of chat formats — confirm).
///
/// # Errors
///
/// Propagates the first [`TokenError`] returned by [`count_tokens`].
pub fn estimate_message_tokens(
    content: &str,
    role: &str,
    model: &str,
) -> Result<usize, TokenError> {
    // Fixed number of tokens added on top of the content and role counts.
    const PER_MESSAGE_OVERHEAD: usize = 4;

    let mut total = PER_MESSAGE_OVERHEAD;
    // Content is counted before role, so a failure surfaces from the content call first.
    for part in [content, role] {
        total += count_tokens(part, model)?;
    }
    Ok(total)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A short greeting yields a small, non-zero token count.
    #[test]
    fn test_count_tokens_gpt4() {
        let n = count_tokens("Hello, world!", "gpt-4").unwrap();
        assert!((1..10).contains(&n), "unexpected token count: {n}");
    }

    /// Empty input encodes to zero tokens.
    #[test]
    fn test_count_tokens_empty() {
        assert_eq!(count_tokens("", "gpt-4").unwrap(), 0);
    }

    /// A sentence repeated 100 times lands strictly between 500 and 2000 tokens.
    #[test]
    fn test_count_tokens_long_text() {
        let long_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
        let n = count_tokens(&long_text, "gpt-4").unwrap();
        assert!(n > 500 && n < 2000, "unexpected token count: {n}");
    }

    /// One counter instance can be used for several inputs.
    #[test]
    fn test_token_counter_reuse() {
        let counter = TokenCounter::new("gpt-4").unwrap();
        for word in ["Hello", "World"] {
            assert!(counter.count(word) > 0);
        }
    }

    /// Encoding followed by decoding returns the original text.
    #[test]
    fn test_encode_decode() {
        let counter = TokenCounter::new("gpt-4").unwrap();
        let original = "Hello, world!";
        let round_tripped = counter.decode(&counter.encode(original)).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// `gpt-5` and `gpt-4o` produce the same count for the same text.
    #[test]
    fn test_different_models() {
        let sample = "Testing different models";
        let [gpt5, gpt4o] = ["gpt-5", "gpt-4o"].map(|m| count_tokens(sample, m).unwrap());
        assert_eq!(gpt5, gpt4o);
    }

    /// An unrecognised model name still yields a usable tokenizer.
    #[test]
    fn test_unknown_model_fallback() {
        let n = count_tokens("Hello", "unknown-model-xyz").unwrap();
        assert!(n > 0);
    }

    /// All `gpt-5` variants agree with the base `gpt-5` count.
    #[test]
    fn test_gpt5_models() {
        let baseline = count_tokens("Hello", "gpt-5").unwrap();
        assert!(baseline > 0);
        for variant in ["gpt-5-mini", "gpt-5-codex", "gpt-5.1"] {
            assert_eq!(baseline, count_tokens("Hello", variant).unwrap());
        }
    }

    /// `o3-mini` and `o4-mini` produce the same non-zero count.
    #[test]
    fn test_o_series_models() {
        let [o3, o4] = ["o3-mini", "o4-mini"].map(|m| count_tokens("Hello", m).unwrap());
        assert!(o3 > 0);
        assert_eq!(o3, o4);
    }
}