use crate::core::models::openai::{ChatMessage, ContentPart, MessageContent};
use crate::utils::error::{GatewayError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
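/// Heuristic token counter for chat, completion, and embedding requests.
///
/// Estimates come from character counts plus fixed per-message and
/// per-request overheads rather than a real tokenizer, so treat the numbers
/// as rough upper-bound guidance, not billing-grade counts.
///
/// A minimal usage sketch (marked `ignore` because the import path depends
/// on where this module sits in the crate):
///
/// ```ignore
/// let counter = TokenCounter::new();
/// let estimate = counter.count_completion_tokens("gpt-4", "Hello!").unwrap();
/// assert!(estimate.is_approximate);
/// ```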
#[derive(Debug, Clone)]
pub struct TokenCounter {
model_configs: HashMap<String, ModelTokenConfig>,
}
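/// Per-model tuning parameters for the character-based estimator:
/// `chars_per_token` is the assumed average characters per token,
/// `message_overhead` and `request_overhead` are fixed token costs added per
/// message and per request, and `max_context_tokens` is the context window
/// size used for capacity checks.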
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelTokenConfig {
pub model: String,
pub chars_per_token: f64,
pub message_overhead: u32,
pub request_overhead: u32,
pub max_context_tokens: u32,
pub special_tokens: HashMap<String, u32>,
}
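/// The result of an estimation pass. `is_approximate` is always `true` for
/// this counter, and `confidence` is a rough self-assessment in `[0.0, 1.0]`
/// of how close the estimate is likely to be.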
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEstimate {
pub input_tokens: u32,
pub output_tokens: Option<u32>,
pub total_tokens: u32,
pub is_approximate: bool,
pub confidence: f64,
}
impl TokenCounter {
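    /// Creates a counter preloaded with the built-in model configs.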
pub fn new() -> Self {
Self {
model_configs: Self::default_model_configs(),
}
}
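    /// Estimates input tokens for a chat request: the per-request overhead
    /// plus the per-message cost of every message in the conversation.
    ///
    /// A sketch mirroring the unit test below (`ignore`d since the import
    /// paths depend on the crate layout):
    ///
    /// ```ignore
    /// use crate::core::models::openai::{ChatMessage, MessageContent, MessageRole};
    ///
    /// let counter = TokenCounter::new();
    /// let messages = vec![ChatMessage {
    ///     role: MessageRole::User,
    ///     content: Some(MessageContent::Text("Hello!".to_string())),
    ///     name: None,
    ///     function_call: None,
    ///     tool_calls: None,
    ///     tool_call_id: None,
    ///     audio: None,
    /// }];
    /// let estimate = counter.count_chat_tokens("gpt-4", &messages).unwrap();
    /// assert!(estimate.is_approximate);
    /// ```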
#[allow(dead_code)]
pub fn count_chat_tokens(
&self,
model: &str,
messages: &[ChatMessage],
) -> Result<TokenEstimate> {
let config = self.get_model_config(model)?;
let mut total_tokens = config.request_overhead;
for message in messages {
total_tokens += self.count_message_tokens(config, message)?;
}
Ok(TokenEstimate {
input_tokens: total_tokens,
output_tokens: None,
total_tokens,
is_approximate: true,
            confidence: 0.85,
        })
}
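    /// Estimates tokens for a single message: the per-message overhead plus
    /// the role, the content, the optional name, and any function or tool
    /// call payloads.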
#[allow(dead_code)]
fn count_message_tokens(
&self,
config: &ModelTokenConfig,
message: &ChatMessage,
) -> Result<u32> {
let mut tokens = config.message_overhead;
        tokens += self.estimate_text_tokens(config, &message.role.to_string());
if let Some(content) = &message.content {
tokens += self.count_content_tokens(config, content)?;
}
if let Some(name) = &message.name {
tokens += self.estimate_text_tokens(config, name);
}
if let Some(function_call) = &message.function_call {
tokens += self.estimate_text_tokens(config, &function_call.name);
tokens += self.estimate_text_tokens(config, &function_call.arguments);
}
if let Some(tool_calls) = &message.tool_calls {
for tool_call in tool_calls {
tokens += self.estimate_text_tokens(config, &tool_call.id);
tokens += self.estimate_text_tokens(config, &tool_call.tool_type);
tokens += self.estimate_text_tokens(config, &tool_call.function.name);
tokens += self.estimate_text_tokens(config, &tool_call.function.arguments);
}
}
Ok(tokens)
}
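    /// Estimates tokens for message content, summing over the parts when the
    /// content is multi-part.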
#[allow(dead_code)]
fn count_content_tokens(
&self,
config: &ModelTokenConfig,
content: &MessageContent,
) -> Result<u32> {
match content {
MessageContent::Text(text) => Ok(self.estimate_text_tokens(config, text)),
MessageContent::Parts(parts) => {
let mut tokens = 0;
for part in parts {
tokens += self.count_content_part_tokens(config, part)?;
}
Ok(tokens)
}
}
}
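    /// Estimates tokens for one content part. Image and audio parts use flat
    /// placeholder costs, since their true cost depends on the provider.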
#[allow(dead_code)]
fn count_content_part_tokens(
&self,
config: &ModelTokenConfig,
part: &ContentPart,
) -> Result<u32> {
match part {
ContentPart::Text { text } => Ok(self.estimate_text_tokens(config, text)),
            // Flat placeholder cost; 85 roughly matches OpenAI's base charge
            // for a low-detail image, but real costs vary by size and detail.
            ContentPart::ImageUrl { image_url: _ } => Ok(85),
            // Flat placeholder cost; audio token usage varies widely by
            // provider and clip length.
            ContentPart::Audio { audio: _ } => Ok(100),
}
}
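    /// Character-count heuristic: `ceil(chars / chars_per_token)`, padded by
    /// a 10% safety margin so estimates err on the high side.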
fn estimate_text_tokens(&self, config: &ModelTokenConfig, text: &str) -> u32 {
if text.is_empty() {
return 0;
}
let char_count = text.chars().count() as f64;
let estimated_tokens = (char_count / config.chars_per_token).ceil() as u32;
        // Add a 10% safety margin so the heuristic errs toward overcounting.
        (estimated_tokens as f64 * 1.1).ceil() as u32
}
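    /// Estimates input tokens for a plain completion prompt.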
pub fn count_completion_tokens(&self, model: &str, prompt: &str) -> Result<TokenEstimate> {
let config = self.get_model_config(model)?;
let input_tokens = config.request_overhead + self.estimate_text_tokens(config, prompt);
Ok(TokenEstimate {
input_tokens,
output_tokens: None,
total_tokens: input_tokens,
is_approximate: true,
confidence: 0.8,
})
}
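    /// Estimates input tokens for an embedding request over one or more
    /// input strings. The confidence is slightly higher than for chat, since
    /// plain text carries no message-structure overhead to misjudge.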
#[allow(dead_code)]
pub fn count_embedding_tokens(&self, model: &str, input: &[String]) -> Result<TokenEstimate> {
let config = self.get_model_config(model)?;
let mut total_tokens = config.request_overhead;
for text in input {
total_tokens += self.estimate_text_tokens(config, text);
}
Ok(TokenEstimate {
input_tokens: total_tokens,
output_tokens: None,
total_tokens,
is_approximate: true,
            confidence: 0.9,
        })
}
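    /// Predicts output token usage: an explicit `max_tokens` is clamped to
    /// the space left in the context window; with no cap, a quarter of the
    /// remaining window is assumed. For example, "gpt-4" (8192-token window)
    /// with 2000 input tokens and no cap yields ceil(6192 * 0.25) = 1548.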
#[allow(dead_code)]
pub fn estimate_output_tokens(
&self,
max_tokens: Option<u32>,
input_tokens: u32,
model: &str,
) -> Result<u32> {
        let config = self.get_model_config(model)?;
        let available_tokens = config.max_context_tokens.saturating_sub(input_tokens);
        match max_tokens {
            // An explicit cap is honored, clamped to the remaining window.
            Some(max) => Ok(max.min(available_tokens)),
            // With no cap, assume a quarter of the remaining window is used.
            None => Ok((available_tokens as f64 * 0.25).ceil() as u32),
        }
}
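    /// Returns `true` if the input plus the requested output fits within the
    /// model's context window.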
#[allow(dead_code)]
pub fn check_context_window(
&self,
model: &str,
input_tokens: u32,
max_output_tokens: Option<u32>,
) -> Result<bool> {
let config = self.get_model_config(model)?;
let output_tokens = max_output_tokens.unwrap_or(0);
        // saturating_add avoids overflow if a caller passes a huge output cap.
        let total_tokens = input_tokens.saturating_add(output_tokens);
Ok(total_tokens <= config.max_context_tokens)
}
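    /// Resolves a config by exact model name, then by model family, and
    /// finally by the "default" entry.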
fn get_model_config(&self, model: &str) -> Result<&ModelTokenConfig> {
if let Some(config) = self.model_configs.get(model) {
return Ok(config);
}
let model_family = self.extract_model_family(model);
if let Some(config) = self.model_configs.get(&model_family) {
return Ok(config);
}
self.model_configs.get("default").ok_or_else(|| {
GatewayError::Config(format!("No token config found for model: {}", model))
})
}
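    /// Maps a (possibly provider-prefixed) model name to a known family,
    /// e.g. "openai/gpt-4-turbo" -> "gpt-4"; unrecognized names fall back to
    /// "default".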
fn extract_model_family(&self, model: &str) -> String {
        // Strip a provider prefix such as "openai/" if present.
        let model = if let Some(pos) = model.find('/') {
&model[pos + 1..]
} else {
model
};
if model.starts_with("gpt-4") {
"gpt-4".to_string()
} else if model.starts_with("gpt-3.5") {
"gpt-3.5-turbo".to_string()
} else if model.starts_with("claude-3") {
"claude-3".to_string()
} else if model.starts_with("claude-2") {
"claude-2".to_string()
} else {
"default".to_string()
}
}
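    /// Built-in configs. The chars-per-token ratios and overheads are rough
    /// heuristics (about 4 chars/token for GPT-family models, slightly
    /// denser for Claude), not published tokenizer constants.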
fn default_model_configs() -> HashMap<String, ModelTokenConfig> {
let mut configs = HashMap::new();
configs.insert(
"gpt-4".to_string(),
ModelTokenConfig {
model: "gpt-4".to_string(),
chars_per_token: 4.0,
message_overhead: 3,
request_overhead: 3,
max_context_tokens: 8192,
special_tokens: HashMap::new(),
},
);
configs.insert(
"gpt-3.5-turbo".to_string(),
ModelTokenConfig {
model: "gpt-3.5-turbo".to_string(),
chars_per_token: 4.0,
message_overhead: 3,
request_overhead: 3,
max_context_tokens: 4096,
special_tokens: HashMap::new(),
},
);
configs.insert(
"claude-3".to_string(),
ModelTokenConfig {
model: "claude-3".to_string(),
chars_per_token: 3.5,
message_overhead: 4,
request_overhead: 5,
max_context_tokens: 200000,
special_tokens: HashMap::new(),
},
);
configs.insert(
"claude-2".to_string(),
ModelTokenConfig {
model: "claude-2".to_string(),
chars_per_token: 3.5,
message_overhead: 4,
request_overhead: 5,
max_context_tokens: 100000,
special_tokens: HashMap::new(),
},
);
configs.insert(
"default".to_string(),
ModelTokenConfig {
model: "default".to_string(),
chars_per_token: 4.0,
message_overhead: 3,
request_overhead: 3,
max_context_tokens: 4096,
special_tokens: HashMap::new(),
},
);
configs
}
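    /// Registers or replaces a config, keyed by its `model` field.
    ///
    /// A sketch of registering a custom model (the name and values are
    /// illustrative only):
    ///
    /// ```ignore
    /// counter.add_model_config(ModelTokenConfig {
    ///     model: "my-custom-model".to_string(),
    ///     chars_per_token: 4.0,
    ///     message_overhead: 3,
    ///     request_overhead: 3,
    ///     max_context_tokens: 32768,
    ///     special_tokens: HashMap::new(),
    /// });
    /// ```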
#[allow(dead_code)]
pub fn add_model_config(&mut self, config: ModelTokenConfig) {
self.model_configs.insert(config.model.clone(), config);
}
#[allow(dead_code)]
pub fn get_supported_models(&self) -> Vec<String> {
self.model_configs.keys().cloned().collect()
}
}
impl Default for TokenCounter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::models::openai::{ChatMessage, MessageContent, MessageRole};
#[test]
fn test_text_token_estimation() {
let counter = TokenCounter::new();
let config = counter.get_model_config("gpt-3.5-turbo").unwrap();
let tokens = counter.estimate_text_tokens(config, "Hello, world!");
assert!(tokens > 0);
        assert!(tokens < 10);
    }
#[test]
fn test_chat_token_counting() {
let counter = TokenCounter::new();
let messages = vec![ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello, how are you?".to_string())),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
}];
let estimate = counter
.count_chat_tokens("gpt-3.5-turbo", &messages)
.unwrap();
assert!(estimate.input_tokens > 0);
assert!(estimate.is_approximate);
}
#[test]
fn test_context_window_check() {
let counter = TokenCounter::new();
assert!(
counter
.check_context_window("gpt-3.5-turbo", 1000, Some(1000))
.unwrap()
);
assert!(
!counter
.check_context_window("gpt-3.5-turbo", 3000, Some(2000))
.unwrap()
);
}
#[test]
fn test_model_family_extraction() {
let counter = TokenCounter::new();
assert_eq!(counter.extract_model_family("gpt-4-turbo"), "gpt-4");
assert_eq!(
counter.extract_model_family("gpt-3.5-turbo-16k"),
"gpt-3.5-turbo"
);
assert_eq!(counter.extract_model_family("claude-3-opus"), "claude-3");
assert_eq!(counter.extract_model_family("unknown-model"), "default");
}
}