use crate::error::{LlmError, LlmResult};
use crate::logging::{log_debug, log_warn};
use std::sync::Arc;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
/// Provider-agnostic interface for counting tokens and enforcing token limits.
///
/// Implementors must be `Send + Sync` and implement `Debug`, so a counter can
/// be shared behind an `Arc<dyn TokenCounter>`.
pub trait TokenCounter: Send + Sync + std::fmt::Debug {
/// Returns the number of tokens counted for `text`.
fn count_tokens(&self, text: &str) -> LlmResult<u32>;
/// Returns the token count for a list of chat messages given as JSON values.
///
/// How much per-message overhead is added is up to the implementation.
fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32>;
/// Returns the maximum context size, in tokens, this counter is configured with.
fn max_context_tokens(&self) -> u32;
/// Returns `Ok(())` when `text` fits the counter's token limit and an error otherwise.
fn validate_token_limit(&self, text: &str) -> LlmResult<()>;
/// Returns `text` shortened so that it fits within `max_tokens` tokens.
///
/// The implementations in this file return the text unchanged when it
/// already fits.
fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String>;
}
/// Token counter backed by a `tiktoken_rs` BPE tokenizer.
///
/// Built either from a model name (`new`) or with an explicit limit for
/// LM Studio style backends (`for_lm_studio`).
pub struct OpenAITokenCounter {
/// BPE encoder used for every count and truncation.
tokenizer: CoreBPE,
/// Context limit, in tokens, reported by `max_context_tokens`.
max_tokens: u32,
/// Model identifier; shown in `Debug` output and in debug logs.
model_name: String,
}
/// Manual `Debug` implementation that prints only the limit and the model
/// name; the tokenizer field is left out of the output.
impl std::fmt::Debug for OpenAITokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("OpenAITokenCounter");
        builder.field("max_tokens", &self.max_tokens);
        builder.field("model_name", &self.model_name);
        builder.finish()
    }
}
impl OpenAITokenCounter {
    /// Returns the context window, in tokens, for a model in the GPT-4 family.
    ///
    /// GPT-4o models as well as the "turbo" and "preview" variants get 128k,
    /// "32k" variants get 32 768, and everything else the original 8 192.
    fn gpt4_max_tokens(model: &str) -> u32 {
        if model.starts_with("gpt-4o") || model.contains("turbo") || model.contains("preview") {
            128_000
        } else if model.contains("32k") {
            32_768
        } else {
            8_192
        }
    }

    /// Returns the context window, in tokens, for a model in the GPT-3.5 family.
    fn gpt35_max_tokens(model: &str) -> u32 {
        if model.contains("16k") {
            16_384
        } else {
            4_096
        }
    }

    /// Converts a tokenizer-loading failure into a configuration error.
    fn load_tokenizer<E: std::fmt::Display>(loaded: Result<CoreBPE, E>) -> LlmResult<CoreBPE> {
        loaded.map_err(|e| {
            LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
        })
    }

    /// Picks the tokenizer and context limit for `model`.
    ///
    /// Matching is by prefix. The `gpt-4o` arm must come before the generic
    /// `gpt-4` arm: GPT-4o models use the `o200k_base` encoding rather than
    /// `cl100k_base`, and would otherwise be counted with the wrong
    /// vocabulary and an 8k limit. Unknown models log a warning and fall back
    /// to `cl100k_base` with a 4 096 token limit.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the tokenizer cannot be initialized.
    fn get_model_config(model: &str) -> LlmResult<(CoreBPE, u32)> {
        match model {
            m if m.starts_with("gpt-4o") => Ok((
                Self::load_tokenizer(o200k_base())?,
                Self::gpt4_max_tokens(m),
            )),
            m if m.starts_with("gpt-4") => Ok((
                Self::load_tokenizer(cl100k_base())?,
                Self::gpt4_max_tokens(m),
            )),
            m if m.starts_with("gpt-3.5") => Ok((
                Self::load_tokenizer(cl100k_base())?,
                Self::gpt35_max_tokens(m),
            )),
            m if m.starts_with("o1") => Ok((Self::load_tokenizer(o200k_base())?, 200_000)),
            _ => {
                log_warn!(model = %model, "Unknown model, using cl100k_base tokenizer with 4k context");
                Ok((Self::load_tokenizer(cl100k_base())?, 4_096))
            }
        }
    }

    /// Creates a counter whose tokenizer and limit are derived from `model`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the tokenizer cannot be initialized.
    pub fn new(model: &str) -> LlmResult<Self> {
        let (tokenizer, max_tokens) = Self::get_model_config(model)?;
        Ok(Self {
            tokenizer,
            max_tokens,
            model_name: model.to_string(),
        })
    }

    /// Creates a counter for an LM Studio style backend with an explicit limit.
    ///
    /// Always uses the `cl100k_base` tokenizer and the fixed model name
    /// `"lm-studio"`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the tokenizer cannot be initialized.
    pub fn for_lm_studio(max_tokens: u32) -> LlmResult<Self> {
        let tokenizer = cl100k_base().map_err(|e| {
            LlmError::configuration_error(format!(
                "Failed to initialize LM Studio tokenizer: {}",
                e
            ))
        })?;
        Ok(Self {
            tokenizer,
            max_tokens,
            model_name: "lm-studio".to_string(),
        })
    }
}
impl TokenCounter for OpenAITokenCounter {
    /// Counts the tokens of `text`, with special-token handling enabled in
    /// the encoder.
    ///
    /// A count that does not fit in `u32` saturates at `u32::MAX` instead of
    /// silently wrapping.
    fn count_tokens(&self, text: &str) -> LlmResult<u32> {
        let tokens = self.tokenizer.encode_with_special_tokens(text);
        Ok(u32::try_from(tokens.len()).unwrap_or(u32::MAX))
    }

    /// Sums the per-message token counts and adds a fixed overhead of
    /// 3 tokens before and 3 tokens after the message list.
    fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32> {
        let mut total_tokens = 3u32;
        for message in messages {
            total_tokens = total_tokens.saturating_add(self.count_single_message_tokens(message));
        }
        total_tokens = total_tokens.saturating_add(3);
        log_debug!(
            total_tokens = total_tokens,
            message_count = messages.len(),
            model = %self.model_name,
            "Calculated message token count"
        );
        Ok(total_tokens)
    }

    /// Returns the configured context limit in tokens.
    fn max_context_tokens(&self) -> u32 {
        self.max_tokens
    }

    /// Checks that `text` does not exceed the configured context limit.
    ///
    /// # Errors
    ///
    /// Returns a token-limit-exceeded error carrying the actual and the
    /// allowed token counts.
    fn validate_token_limit(&self, text: &str) -> LlmResult<()> {
        let token_count = self.count_tokens(text)?;
        if token_count > self.max_tokens {
            return Err(LlmError::token_limit_exceeded(
                token_count as usize,
                self.max_tokens as usize,
            ));
        }
        Ok(())
    }

    /// Returns `text` cut down to at most `max_tokens` tokens.
    ///
    /// Text that already fits is returned unchanged. A token boundary can
    /// fall inside a multi-byte UTF-8 character, in which case decoding the
    /// exact prefix fails; a few trailing tokens are then dropped until the
    /// prefix decodes, so the result may be slightly shorter than
    /// `max_tokens`.
    ///
    /// # Errors
    ///
    /// Returns a response-parsing error when the prefix still cannot be
    /// decoded after the bounded back-off.
    fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String> {
        // Upper bound on dropped tokens, so that an unrelated decode failure
        // still surfaces as an error instead of shrinking the text to nothing.
        const MAX_BACKOFF_TOKENS: usize = 8;

        let tokens = self.tokenizer.encode_with_special_tokens(text);
        let limit = usize::try_from(max_tokens).unwrap_or(usize::MAX);
        if tokens.len() <= limit {
            return Ok(text.to_string());
        }

        let mut end = limit;
        let mut retries_left = MAX_BACKOFF_TOKENS;
        loop {
            match self.tokenizer.decode(tokens[..end].to_vec()) {
                Ok(truncated_text) => return Ok(truncated_text),
                Err(_) if end > 0 && retries_left > 0 => {
                    end -= 1;
                    retries_left -= 1;
                }
                Err(e) => {
                    return Err(LlmError::response_parsing_error(format!(
                        "Failed to decode truncated tokens: {}",
                        e
                    )));
                }
            }
        }
    }
}
impl OpenAITokenCounter {
    /// Estimates the tokens of one chat message.
    ///
    /// The estimate is a fixed overhead of 4 plus the encoded lengths of the
    /// `role` (default `"user"`) and `content` (default `""`) string fields,
    /// plus the tool-call argument tokens. Fields that are missing or are not
    /// JSON strings use the defaults.
    fn count_single_message_tokens(&self, message: &serde_json::Value) -> u32 {
        let encoded_len =
            |text: &str| self.tokenizer.encode_with_special_tokens(text).len() as u32;

        let role = message
            .get("role")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("user");
        let content = message
            .get("content")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("");

        4u32 + encoded_len(role) + encoded_len(content) + self.count_tool_call_tokens(message)
    }

    /// Sums the encoded length of every `tool_calls[*].function.arguments`
    /// string in `message`.
    ///
    /// Returns 0 when `tool_calls` is absent or is not an array; entries
    /// without a string `arguments` field are skipped.
    fn count_tool_call_tokens(&self, message: &serde_json::Value) -> u32 {
        let calls = match message
            .get("tool_calls")
            .and_then(serde_json::Value::as_array)
        {
            Some(calls) => calls,
            None => return 0,
        };

        let mut total = 0u32;
        for call in calls {
            let arguments = call
                .get("function")
                .and_then(|function| function.get("arguments"))
                .and_then(serde_json::Value::as_str);
            if let Some(arguments) = arguments {
                total += self.tokenizer.encode_with_special_tokens(arguments).len() as u32;
            }
        }
        total
    }
}
/// Token counter for Anthropic models.
///
/// Counts are estimates: the `TokenCounter` implementation in this file
/// encodes with a `tiktoken_rs` tokenizer and scales the result.
pub struct AnthropicTokenCounter {
/// BPE encoder used as the basis of the estimate.
tokenizer: CoreBPE,
/// Context limit, in tokens, reported by `max_context_tokens`.
max_tokens: u32,
}
/// Manual `Debug` implementation that prints only the configured limit; the
/// tokenizer field is left out of the output.
impl std::fmt::Debug for AnthropicTokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AnthropicTokenCounter");
        builder.field("max_tokens", &self.max_tokens);
        builder.finish()
    }
}
impl AnthropicTokenCounter {
    /// Creates a counter whose context limit is derived from `model`.
    ///
    /// Matching is by substring: Claude 3 family models and `claude-2.1` get
    /// a 200k limit, other Claude 2 models get 100k. Unknown models log a
    /// warning and fall back to 100k. The `claude-2.1` check must come before
    /// the generic `claude-2` check, which would otherwise claim it.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the `cl100k_base` tokenizer cannot
    /// be initialized.
    pub fn new(model: &str) -> LlmResult<Self> {
        let max_tokens = match model {
            // Also covers "claude-3-5-sonnet" and other Claude 3.x names.
            m if m.contains("claude-3") => 200_000,
            m if m.contains("claude-2.1") => 200_000,
            m if m.contains("claude-2") => 100_000,
            _ => {
                log_warn!(model = %model, "Unknown Anthropic model, using 100k context");
                100_000
            }
        };
        let tokenizer = cl100k_base().map_err(|e| {
            LlmError::configuration_error(format!(
                "Failed to initialize Anthropic tokenizer: {}",
                e
            ))
        })?;
        Ok(Self {
            tokenizer,
            max_tokens,
        })
    }
}
impl TokenCounter for AnthropicTokenCounter {
    /// Estimates the token count of `text` as the tokenizer's count scaled by
    /// 1.1 and truncated to an integer.
    fn count_tokens(&self, text: &str) -> LlmResult<u32> {
        let tokens = self.tokenizer.encode_with_special_tokens(text);
        Ok((tokens.len() as f32 * 1.1) as u32)
    }

    /// Sums the estimated tokens of each message's string `content` field
    /// plus a fixed overhead of 10 tokens per message.
    ///
    /// Content that is missing or is not a JSON string is counted as empty.
    fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32> {
        let mut total_tokens = 0u32;
        for message in messages {
            let content = message
                .get("content")
                .and_then(|c| c.as_str())
                .unwrap_or("");
            let content_tokens = self.count_tokens(content)?;
            total_tokens = total_tokens
                .saturating_add(content_tokens)
                .saturating_add(10);
        }
        log_debug!(
            total_tokens = total_tokens,
            message_count = messages.len(),
            "Calculated Anthropic message token count"
        );
        Ok(total_tokens)
    }

    /// Returns the configured context limit in tokens.
    fn max_context_tokens(&self) -> u32 {
        self.max_tokens
    }

    /// Checks that the estimated token count of `text` does not exceed the
    /// configured context limit.
    ///
    /// # Errors
    ///
    /// Returns a token-limit-exceeded error carrying the estimated and the
    /// allowed token counts.
    fn validate_token_limit(&self, text: &str) -> LlmResult<()> {
        let token_count = self.count_tokens(text)?;
        if token_count > self.max_tokens {
            return Err(LlmError::token_limit_exceeded(
                token_count as usize,
                self.max_tokens as usize,
            ));
        }
        Ok(())
    }

    /// Returns `text` cut down so that its estimated count fits `max_tokens`.
    ///
    /// The limit is divided by the same 1.1 factor `count_tokens` applies.
    /// Text that already fits is returned unchanged. A token boundary can
    /// fall inside a multi-byte UTF-8 character, in which case decoding the
    /// exact prefix fails; a few trailing tokens are then dropped until the
    /// prefix decodes.
    ///
    /// # Errors
    ///
    /// Returns a response-parsing error when the prefix still cannot be
    /// decoded after the bounded back-off.
    fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String> {
        // Upper bound on dropped tokens, so that an unrelated decode failure
        // still surfaces as an error instead of shrinking the text to nothing.
        const MAX_BACKOFF_TOKENS: usize = 8;

        let tokens = self.tokenizer.encode_with_special_tokens(text);
        let adjusted_limit = (max_tokens as f32 / 1.1) as usize;
        if tokens.len() <= adjusted_limit {
            return Ok(text.to_string());
        }
        log_debug!(
            original_tokens = tokens.len(),
            max_tokens = max_tokens,
            adjusted_limit = adjusted_limit,
            "Truncating Anthropic text to fit token limit"
        );

        let mut end = adjusted_limit;
        let mut retries_left = MAX_BACKOFF_TOKENS;
        loop {
            match self.tokenizer.decode(tokens[..end].to_vec()) {
                Ok(truncated_text) => return Ok(truncated_text),
                Err(_) if end > 0 && retries_left > 0 => {
                    end -= 1;
                    retries_left -= 1;
                }
                Err(e) => {
                    return Err(LlmError::response_parsing_error(format!(
                        "Failed to decode truncated tokens: {}",
                        e
                    )));
                }
            }
        }
    }
}
pub struct TokenCounterFactory;
impl TokenCounterFactory {
pub fn create_counter(provider: &str, model: &str) -> LlmResult<Arc<dyn TokenCounter>> {
match provider.to_lowercase().as_str() {
"openai" => {
let counter = OpenAITokenCounter::new(model)?;
Ok(Arc::new(counter))
}
"lmstudio" => {
let counter = OpenAITokenCounter::for_lm_studio(4096)?;
Ok(Arc::new(counter))
}
"ollama" => {
let counter = OpenAITokenCounter::for_lm_studio(4096)?;
Ok(Arc::new(counter))
}
"anthropic" => {
let counter = AnthropicTokenCounter::new(model)?;
Ok(Arc::new(counter))
}
_ => Err(LlmError::unsupported_provider(provider)),
}
}
pub fn create_counter_with_limit(
provider: &str,
model: &str,
max_tokens: u32,
) -> LlmResult<Arc<dyn TokenCounter>> {
match provider.to_lowercase().as_str() {
"openai" => {
let mut counter = OpenAITokenCounter::new(model)?;
counter.max_tokens = max_tokens;
Ok(Arc::new(counter))
}
"lmstudio" => {
let counter = OpenAITokenCounter::for_lm_studio(max_tokens)?;
Ok(Arc::new(counter))
}
"ollama" => {
let counter = OpenAITokenCounter::for_lm_studio(max_tokens)?;
Ok(Arc::new(counter))
}
"anthropic" => {
let mut counter = AnthropicTokenCounter::new(model)?;
counter.max_tokens = max_tokens;
Ok(Arc::new(counter))
}
_ => Err(LlmError::unsupported_provider(provider)),
}
}
}