use crate::messages::Message;
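/// Estimates the number of tokens in `text` using the rough heuristic of
/// ~4 characters per token. This is an approximation, not a real tokenizer;
/// use a model-specific tokenizer when exact counts matter.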
pub fn estimate_token_count(text: &str) -> usize {
if text.is_empty() {
return 0;
}
(text.len() as f64 / 4.0).ceil() as usize
}
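/// Estimates the total token count for a slice of messages.
///
/// Each message costs `tokens_per_message` of structural overhead plus its
/// estimated content tokens; AI messages additionally count their tool-call
/// names and JSON-serialized arguments.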
pub fn count_message_tokens(messages: &[Message], tokens_per_message: usize) -> usize {
let mut total = 0;
for msg in messages {
        total += tokens_per_message;
        total += estimate_token_count(&msg.content().text());
if let Message::Ai(ai) = msg {
for tc in &ai.tool_calls {
total += estimate_token_count(&tc.name);
total += estimate_token_count(&serde_json::to_string(&tc.args).unwrap_or_default());
}
}
}
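    // Reserve tokens for assistant reply priming, mirroring OpenAI's
    // published message-counting heuristic.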
    total + 3
}
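/// Trims `messages` to fit within `max_tokens`, dropping the oldest
/// non-system messages first.
///
/// Leading system messages are always preserved, even when they alone exceed
/// the budget; the newest remaining messages are kept until the budget runs
/// out.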
pub fn trim_messages(
messages: &[Message],
max_tokens: usize,
tokens_per_message: usize,
) -> Vec<Message> {
let mut system_messages: Vec<Message> = Vec::new();
let mut other_messages: Vec<&Message> = Vec::new();
let mut seen_non_system = false;
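    // Pin the leading run of system messages; everything after the first
    // non-system message is eligible for trimming.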
for msg in messages {
if !seen_non_system && msg.message_type() == crate::messages::MessageType::System {
system_messages.push(msg.clone());
} else {
seen_non_system = true;
other_messages.push(msg);
}
}
let system_tokens = if system_messages.is_empty() {
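        // No system messages: still reserve the 3 reply-priming tokens that
        // count_message_tokens would otherwise add.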
        3
    } else {
count_message_tokens(&system_messages, tokens_per_message)
};
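    // If the pinned system messages already meet or exceed the budget,
    // return just those.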
if system_tokens >= max_tokens {
return system_messages;
}
let remaining_budget = max_tokens - system_tokens;
let mut kept: Vec<Message> = Vec::new();
let mut used_tokens: usize = 0;
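    // Walk newest-to-oldest, keeping whole messages until the budget is
    // exhausted.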
for msg in other_messages.iter().rev() {
let msg_tokens = tokens_per_message + estimate_token_count(&(*msg).content().text());
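        // AI messages also pay for their tool-call names and serialized
        // arguments.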
let tool_tokens = if let Message::Ai(ai) = *msg {
ai.tool_calls
.iter()
.map(|tc| {
estimate_token_count(&tc.name)
+ estimate_token_count(&serde_json::to_string(&tc.args).unwrap_or_default())
})
.sum::<usize>()
} else {
0
};
let total_msg_tokens = msg_tokens + tool_tokens;
if used_tokens + total_msg_tokens > remaining_budget {
break;
}
used_tokens += total_msg_tokens;
kept.push((*msg).clone());
}
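    // Restore chronological order and append after the pinned system
    // messages.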
kept.reverse();
system_messages.extend(kept);
system_messages
}
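/// Returns the approximate context window size, in tokens, for a known model
/// family, matched by name prefix. Values reflect commonly published limits
/// and may lag behind newer releases; unknown models return `None`.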
pub fn get_model_context_window(model: &str) -> Option<usize> {
match model {
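        // More specific prefixes must be tested before the bare "gpt-4" arm.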
s if s.starts_with("gpt-4o") => Some(128_000),
s if s.starts_with("gpt-4-turbo") => Some(128_000),
s if s.starts_with("gpt-4") => Some(8_192),
s if s.starts_with("gpt-3.5") => Some(16_385),
s if s.starts_with("claude-3")
|| s.starts_with("claude-sonnet")
|| s.starts_with("claude-opus") =>
{
Some(200_000)
}
s if s.starts_with("gemini") => Some(1_000_000),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::messages::tool_types::ToolCall;
use crate::messages::{AIMessage, Message};
use std::collections::HashMap;
#[test]
fn test_estimate_token_count_empty() {
assert_eq!(estimate_token_count(""), 0);
}
#[test]
fn test_estimate_token_count_known_text() {
assert_eq!(estimate_token_count("hello world"), 3);
assert_eq!(estimate_token_count("a"), 1);
assert_eq!(estimate_token_count("abcdefgh"), 2);
}
#[test]
fn test_count_message_tokens_basic() {
let messages = vec![Message::human("Hello"), Message::ai("Hi there!")];
let tokens = count_message_tokens(&messages, 3);
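        // "Hello" (5 chars -> 2 tokens) and "Hi there!" (9 chars -> 3 tokens),
        // plus 3 per-message tokens each and 3 reply-priming tokens:
        // 3 + 2 + 3 + 3 + 3 = 14.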
assert_eq!(tokens, 14);
}
#[test]
fn test_count_message_tokens_with_tool_calls() {
let mut args = HashMap::new();
args.insert("query".to_string(), serde_json::json!("weather"));
let tc = ToolCall {
name: "search".to_string(),
args,
id: Some("tc_1".to_string()),
};
let ai_msg = AIMessage::new("Let me search").with_tool_calls(vec![tc]);
let messages = vec![Message::Ai(ai_msg)];
let tokens = count_message_tokens(&messages, 3);
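        // 3 per-message + "Let me search" (13 chars -> 4) + "search"
        // (6 chars -> 2) + {"query":"weather"} (19 chars -> 5) + 3 priming
        // = 17.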
assert_eq!(tokens, 17);
}
#[test]
fn test_trim_messages_keeps_system() {
let messages = vec![
Message::system("You are helpful."),
Message::human("Oldest question"),
Message::ai("Oldest answer"),
Message::human("Newest question"),
];
let trimmed = trim_messages(&messages, 20, 3);
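        // With a 20-token budget, the system message (10 tokens) leaves room
        // for only the newest question (7 tokens).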
assert_eq!(
trimmed[0].message_type(),
crate::messages::MessageType::System
);
assert!(trimmed.len() < messages.len());
assert_eq!(trimmed.last().unwrap().content().text(), "Newest question");
}
#[test]
fn test_trim_messages_removes_oldest_first() {
let messages = vec![
Message::human("First"),
Message::ai("Second"),
Message::human("Third"),
Message::ai("Fourth"),
];
let trimmed = trim_messages(&messages, 15, 3);
        // With a 15-token budget only the two newest messages fit (5 tokens
        // each against a 12-token budget after reply priming), so the oldest
        // pair is dropped.
        assert!(trimmed.len() < messages.len());
        assert_eq!(trimmed.first().unwrap().content().text(), "Third");
        assert_eq!(trimmed.last().unwrap().content().text(), "Fourth");
}
#[test]
fn test_trim_messages_all_fit() {
let messages = vec![Message::system("Be helpful."), Message::human("Hi")];
let trimmed = trim_messages(&messages, 1000, 3);
assert_eq!(trimmed.len(), messages.len());
}
#[test]
fn test_get_model_context_window_known() {
assert_eq!(get_model_context_window("gpt-4o"), Some(128_000));
assert_eq!(get_model_context_window("gpt-4o-mini"), Some(128_000));
assert_eq!(
get_model_context_window("gpt-4-turbo-preview"),
Some(128_000)
);
assert_eq!(get_model_context_window("gpt-4"), Some(8_192));
assert_eq!(get_model_context_window("gpt-3.5-turbo"), Some(16_385));
assert_eq!(get_model_context_window("claude-3-opus"), Some(200_000));
assert_eq!(get_model_context_window("claude-sonnet-4"), Some(200_000));
assert_eq!(get_model_context_window("claude-opus-4"), Some(200_000));
assert_eq!(get_model_context_window("gemini-pro"), Some(1_000_000));
}
#[test]
fn test_get_model_context_window_unknown() {
assert_eq!(get_model_context_window("unknown-model"), None);
assert_eq!(get_model_context_window("llama-3"), None);
}
}