use crate::{Message, Role};
/// Divisor used to turn a message's text length into an estimated token
/// count: the estimator charges one token per `CHARS_PER_TOKEN` units of length.
const CHARS_PER_TOKEN: usize = 4;
/// A cap on the estimated token size of a list of messages.
///
/// Derives `PartialEq`/`Eq` so budgets can be compared directly (for example
/// with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenBudget {
    /// Largest estimated token total the budget allows.
    pub max_tokens: usize,
}
impl TokenBudget {
    /// Creates a budget that allows at most `max_tokens` estimated tokens.
    pub fn new(max_tokens: usize) -> Self {
        Self { max_tokens }
    }

    /// Estimates the token cost of a single message.
    ///
    /// The result is a fixed 4 (presumably per-message overhead) plus the
    /// length of the message's lossy text rendering divided by
    /// `CHARS_PER_TOKEN`, using integer division.
    fn estimate_tokens(msg: &Message) -> usize {
        // NOTE(review): if `to_text_lossy` yields a `String`/`str`, `len()` is a
        // byte count rather than a character count — confirm this is intended.
        let chars = msg.content.to_text_lossy().len();
        4 + chars / CHARS_PER_TOKEN
    }

    /// Returns the sum of the per-message estimates for `messages`.
    pub fn estimate_total(messages: &[Message]) -> usize {
        messages.iter().map(Self::estimate_tokens).sum()
    }

    /// Trims `messages` in place until its estimated total fits the budget.
    ///
    /// Non-system messages are dropped from the front of the list, one at a
    /// time, until the estimate is at most `max_tokens` or no non-system
    /// message is left. System messages are never removed, so the list may
    /// still exceed the budget afterwards. The relative order of the surviving
    /// messages is preserved.
    pub fn apply(&self, messages: &mut Vec<Message>) {
        let mut total = Self::estimate_total(messages);
        if total <= self.max_tokens {
            return;
        }

        // Single in-order pass with a running total: avoids re-estimating the
        // whole list and shifting the vector once per dropped message.
        messages.retain(|msg| {
            if total <= self.max_tokens || matches!(msg.role, Role::System) {
                return true;
            }
            // Cannot underflow: `total` still includes this message's estimate.
            total -= Self::estimate_tokens(msg);
            false
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Message;

    /// Over-budget lists lose their oldest non-system messages, and only as
    /// many as needed to fit.
    #[test]
    fn trims_oldest_messages_to_fit() {
        let budget = TokenBudget::new(50);
        // Each 200-character message alone exceeds the budget (4 + 200 / 4 = 54),
        // so both long messages have to go before the list fits.
        let long = "x".repeat(200);
        let mut messages = vec![
            Message::system("sys"),
            Message::user(long.clone()),
            Message::user(long),
            Message::user("short"),
        ];

        budget.apply(&mut messages);

        assert!(messages.iter().any(|m| matches!(m.role, Role::System)));
        assert!(TokenBudget::estimate_total(&messages) <= 50);
        // Only the two long messages should be dropped: the system prompt and
        // the newest short message survive, with the system prompt still first.
        assert_eq!(messages.len(), 2);
        assert!(matches!(messages[0].role, Role::System));
    }

    /// A system message is kept even when it alone exceeds the budget.
    #[test]
    fn does_not_trim_system_messages() {
        let budget = TokenBudget::new(1);
        let mut messages = vec![Message::system("system prompt")];

        budget.apply(&mut messages);

        assert_eq!(messages.len(), 1);
    }

    /// When the budget cannot be met, every non-system message is dropped and
    /// trimming stops instead of looping or touching system messages.
    #[test]
    fn drops_all_non_system_when_budget_is_unreachable() {
        let budget = TokenBudget::new(1);
        let mut messages = vec![Message::system("system prompt"), Message::user("hello")];

        budget.apply(&mut messages);

        assert_eq!(messages.len(), 1);
        assert!(matches!(messages[0].role, Role::System));
    }

    /// Lists already within budget are left untouched.
    #[test]
    fn noop_when_within_budget() {
        let budget = TokenBudget::new(10_000);
        let mut messages = vec![Message::system("sys"), Message::user("hello")];
        let original_len = messages.len();

        budget.apply(&mut messages);

        assert_eq!(messages.len(), original_len);
    }
}