use crate::message::{Message, MessageContent};
/// Pluggable way of measuring how many tokens messages or text occupy.
///
/// Implementors must be `Send + Sync`. [`trim_messages`] consults a counter
/// through [`TrimOptions::token_counter`] to price each message against the
/// token budget.
pub trait TokenCounter: Send + Sync {
/// Returns the token count for the whole `messages` slice.
fn count_messages(&self, messages: &[Message]) -> u32;
/// Returns the token count for a single piece of `text`.
fn count_text(&self, text: &str) -> u32;
}
/// Heuristic [`TokenCounter`] that estimates tokens from character counts:
/// one token per four characters, and never fewer than one token per text.
///
/// The type is a stateless unit struct, so it is cheap to copy and can be
/// obtained through `Default` as well as by naming it directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CharTokenCounter;
impl TokenCounter for CharTokenCounter {
    /// Sums the per-message estimates of `messages`.
    ///
    /// Each message is reduced to text with `message_to_text` and priced with
    /// [`count_text`](TokenCounter::count_text), so every message contributes
    /// at least one token. The total saturates at `u32::MAX` instead of
    /// overflowing.
    fn count_messages(&self, messages: &[Message]) -> u32 {
        messages
            .iter()
            .map(|m| self.count_text(&message_to_text(m)))
            .fold(0u32, u32::saturating_add)
    }

    /// Estimates the tokens in `text` as one per four `char`s, rounded down,
    /// with a minimum of one (so the empty string still costs one token).
    fn count_text(&self, text: &str) -> u32 {
        let estimate = (text.chars().count() / 4).max(1);
        // Clamp before narrowing: a bare `as u32` would silently truncate a
        // `usize` that exceeds `u32::MAX`.
        estimate.min(u32::MAX as usize) as u32
    }
}
/// Selects which end of the input [`trim_messages`] keeps when not every
/// message fits in the token budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TrimStrategy {
/// Keep messages from the end of the input: the slice is walked backwards
/// and messages are accepted until one no longer fits the budget.
Last,
/// Keep messages from the start of the input: the slice is walked forwards
/// and messages are accepted until one no longer fits the budget.
First,
}
/// Role selector used by [`TrimOptions::start_on`] and
/// [`TrimOptions::end_on`]; each variant matches the `Message` variant of the
/// same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum MessageRole {
/// Matches `Message::Human`.
Human,
/// Matches `Message::Ai`.
Ai,
/// Matches `Message::System`.
System,
/// Matches `Message::Tool`.
Tool,
}
/// Settings consumed by [`trim_messages`].
pub struct TrimOptions<'a> {
/// Which end of the input is kept when the budget is too small.
pub strategy: TrimStrategy,
/// Token budget. A message is kept only while its cost still fits in what
/// remains of this budget.
pub max_tokens: u32,
/// Counter used to price messages; `trim_messages` calls
/// `count_messages` on one message at a time.
pub token_counter: &'a dyn TokenCounter,
/// When `Some`, messages before the first kept message with this role are
/// dropped; if no kept message has the role, the result is empty.
pub start_on: Option<MessageRole>,
/// When `Some`, trailing messages are dropped until the last one has one of
/// these roles; if none matches (including when the list is empty), the
/// result is empty.
pub end_on: Option<Vec<MessageRole>>,
}
pub fn trim_messages(messages: &[Message], opts: TrimOptions) -> Vec<Message> {
if messages.is_empty() {
return vec![];
}
let mut result: Vec<Message> = match opts.strategy {
TrimStrategy::Last => {
let mut selected = Vec::new();
let mut budget = opts.max_tokens;
for msg in messages.iter().rev() {
let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
if cost > budget {
break;
}
budget -= cost;
selected.push(msg.clone());
}
selected.reverse();
selected
}
TrimStrategy::First => {
let mut selected = Vec::new();
let mut budget = opts.max_tokens;
for msg in messages {
let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
if cost > budget {
break;
}
budget -= cost;
selected.push(msg.clone());
}
selected
}
};
if let Some(ref start_role) = opts.start_on {
if let Some(start_idx) = result.iter().position(|m| message_has_role(m, start_role)) {
result.drain(..start_idx);
} else {
result.clear();
}
}
if let Some(ref end_roles) = opts.end_on {
while !result.is_empty()
&& !end_roles
.iter()
.any(|r| message_has_role(result.last().unwrap(), r))
{
result.pop();
}
}
result
}
/// Reports whether `msg` is the `Message` variant that `role` stands for.
fn message_has_role(msg: &Message, role: &MessageRole) -> bool {
    match role {
        MessageRole::Human => matches!(msg, Message::Human(_)),
        MessageRole::Ai => matches!(msg, Message::Ai(_)),
        MessageRole::System => matches!(msg, Message::System(_)),
        MessageRole::Tool => matches!(msg, Message::Tool(_)),
    }
}
/// Produces the text of `msg` that token counting is based on.
///
/// System and tool messages yield a copy of their `content`; human and AI
/// messages have their `content` flattened by `content_to_text`.
fn message_to_text(msg: &Message) -> String {
    match msg {
        Message::System(system) => system.content.clone(),
        Message::Tool(tool) => tool.content.clone(),
        Message::Human(human) => content_to_text(&human.content),
        Message::Ai(ai) => content_to_text(&ai.content),
    }
}
/// Flattens message content into a single string.
///
/// Plain text content is returned as a copy. Block content is joined with a
/// single space between consecutive blocks; a block that is not a text block
/// contributes an empty segment, so its separator is still emitted.
fn content_to_text(content: &MessageContent) -> String {
    let blocks = match content {
        MessageContent::Text(plain) => return plain.clone(),
        MessageContent::Blocks(blocks) => blocks,
    };

    let mut joined = String::new();
    for (index, block) in blocks.iter().enumerate() {
        // One separator between every pair of adjacent blocks, text or not.
        if index > 0 {
            joined.push(' ');
        }
        if let crate::message::ContentBlock::Text { text } = block {
            joined.push_str(text);
        }
    }
    joined
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds options with the given strategy and budget, the character-based
    /// counter, and no role constraints. Tests override `start_on` / `end_on`
    /// with struct-update syntax where needed.
    fn options(strategy: TrimStrategy, max_tokens: u32) -> TrimOptions<'static> {
        TrimOptions {
            strategy,
            max_tokens,
            token_counter: &CharTokenCounter,
            start_on: None,
            end_on: None,
        }
    }

    #[test]
    fn test_char_token_counter() {
        let counter = CharTokenCounter;
        // 11 characters / 4 rounds down to 2.
        assert_eq!(counter.count_text("hello world"), 2);
        // Fewer than four characters, including none at all, still costs one.
        assert_eq!(counter.count_text("hi"), 1);
        assert_eq!(counter.count_text(""), 1);
    }

    #[test]
    fn test_trim_empty_input() {
        let result = trim_messages(&[], options(TrimStrategy::Last, 100));
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_all_under_budget() {
        let messages = vec![Message::human("hello"), Message::ai("hi")];
        let result = trim_messages(&messages, options(TrimStrategy::Last, 1000));
        assert_eq!(result.len(), 2);
        // Original order is preserved.
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_last_strategy() {
        let messages = vec![
            Message::system("You are helpful. This is a long system prompt with many tokens."),
            Message::human("short q"),
            Message::ai("short a"),
        ];
        let result = trim_messages(&messages, options(TrimStrategy::Last, 5));
        // The two short messages cost one token each; the long system prompt
        // does not fit in the remaining budget and ends the backwards walk.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_first_strategy() {
        let messages = vec![
            Message::human("first"),
            Message::ai("second"),
            Message::human("this is a much longer message that uses more tokens"),
        ];
        let result = trim_messages(&messages, options(TrimStrategy::First, 5));
        // The two short messages fit; the long trailing one ends the walk.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_zero_budget_keeps_nothing() {
        // `CharTokenCounter` never prices a message below one token.
        let messages = vec![Message::human("q"), Message::ai("a")];
        assert!(trim_messages(&messages, options(TrimStrategy::Last, 0)).is_empty());
        assert!(trim_messages(&messages, options(TrimStrategy::First, 0)).is_empty());
    }

    #[test]
    fn test_trim_start_on_human() {
        let messages = vec![
            Message::system("sys"),
            Message::ai("ai response"),
            Message::human("question"),
        ];
        let result = trim_messages(
            &messages,
            TrimOptions {
                start_on: Some(MessageRole::Human),
                ..options(TrimStrategy::Last, 1000)
            },
        );
        // Everything before the first human message is dropped.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_start_on_missing_role_yields_empty() {
        let messages = vec![Message::human("q"), Message::ai("a")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                start_on: Some(MessageRole::System),
                ..options(TrimStrategy::Last, 1000)
            },
        );
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_end_on_human_or_tool() {
        let messages = vec![Message::human("q"), Message::ai("response")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                end_on: Some(vec![MessageRole::Human, MessageRole::Tool]),
                ..options(TrimStrategy::Last, 1000)
            },
        );
        // The trailing AI message is dropped, leaving the human message last.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_end_on_no_match_yields_empty() {
        let messages = vec![Message::human("q"), Message::ai("response")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                end_on: Some(vec![MessageRole::System]),
                ..options(TrimStrategy::Last, 1000)
            },
        );
        assert!(result.is_empty());
    }
}