use crate::chat::ChatMessage;
/// A token-budgeted collection of chat messages.
///
/// Every stored message carries a caller-supplied token count. The window
/// compares the sum of those counts against its *input budget*, which is
/// `max_tokens` minus `reserved_for_output`.
#[derive(Debug)]
pub struct ContextWindow {
    /// Total token capacity of the window.
    max_tokens: u32,
    /// Portion of `max_tokens` withheld from stored messages.
    reserved_for_output: u32,
    /// Stored messages, oldest first, each with its bookkeeping data.
    messages: Vec<TrackedMessage>,
}
/// A stored message together with the bookkeeping the window keeps for it.
#[derive(Debug, Clone)]
struct TrackedMessage {
    /// The chat message itself.
    message: ChatMessage,
    /// Token cost recorded for this message by the caller.
    token_count: u32,
    /// `true` while the message may be removed by compaction;
    /// `false` once it has been protected.
    compactable: bool,
}
impl ContextWindow {
pub fn new(max_tokens: u32, reserved_for_output: u32) -> Self {
assert!(
reserved_for_output < max_tokens,
"reserved_for_output ({reserved_for_output}) must be less than max_tokens ({max_tokens})"
);
Self {
max_tokens,
reserved_for_output,
messages: Vec::new(),
}
}
pub fn push(&mut self, message: ChatMessage, tokens: u32) {
self.messages.push(TrackedMessage {
message,
token_count: tokens,
compactable: true,
});
}
pub fn available(&self) -> u32 {
let input_budget = self.max_tokens.saturating_sub(self.reserved_for_output);
input_budget.saturating_sub(self.total_tokens())
}
pub fn iter(&self) -> impl Iterator<Item = &ChatMessage> {
self.messages.iter().map(|t| &t.message)
}
pub fn messages(&self) -> Vec<&ChatMessage> {
self.messages.iter().map(|t| &t.message).collect()
}
pub fn messages_owned(&self) -> Vec<ChatMessage> {
self.messages.iter().map(|t| t.message.clone()).collect()
}
pub fn total_tokens(&self) -> u32 {
self.messages
.iter()
.map(|t| t.token_count)
.fold(0, u32::saturating_add)
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
#[allow(clippy::cast_precision_loss)]
pub fn needs_compaction(&self, threshold: f32) -> bool {
let input_budget = self.max_tokens.saturating_sub(self.reserved_for_output);
if input_budget == 0 {
return false;
}
let usage_ratio = self.total_tokens() as f32 / input_budget as f32;
usage_ratio > threshold
}
pub fn compact(&mut self) -> Vec<ChatMessage> {
let mut removed = Vec::new();
let mut retained = Vec::new();
for tracked in self.messages.drain(..) {
if tracked.compactable {
removed.push(tracked.message);
} else {
retained.push(tracked);
}
}
self.messages = retained;
removed
}
pub fn protect_recent(&mut self, n: usize) {
let len = self.messages.len();
let start = len.saturating_sub(n);
for msg in &mut self.messages[start..] {
msg.compactable = false;
}
}
pub fn protect(&mut self, index: usize) {
self.messages[index].compactable = false;
}
pub fn unprotect(&mut self, index: usize) {
self.messages[index].compactable = true;
}
pub fn is_protected(&self, index: usize) -> bool {
!self.messages[index].compactable
}
pub fn input_budget(&self) -> u32 {
self.max_tokens.saturating_sub(self.reserved_for_output)
}
pub fn max_tokens(&self) -> u32 {
self.max_tokens
}
pub fn reserved_for_output(&self) -> u32 {
self.reserved_for_output
}
pub fn clear(&mut self) {
self.messages.clear();
}
pub fn token_count(&self, index: usize) -> u32 {
self.messages[index].token_count
}
pub fn update_token_count(&mut self, index: usize, tokens: u32) {
self.messages[index].token_count = tokens;
}
pub fn force_fit(&mut self) -> Vec<ChatMessage> {
let mut removed = Vec::new();
while self.needs_compaction(1.0) {
let idx = self.messages.iter().position(|m| m.compactable);
match idx {
Some(i) => removed.push(self.messages.remove(i).message),
None => break, }
}
removed
}
}
/// Gives a rough token estimate for `text`: one token per four bytes of
/// UTF-8, rounded up.
///
/// Empty text costs 0 tokens; any non-empty text costs at least 1. Byte
/// lengths beyond `u32::MAX` are clamped to `u32::MAX` before dividing.
// The length is clamped before the narrowing cast, so nothing is truncated.
#[allow(clippy::cast_possible_truncation)]
pub fn estimate_tokens(text: &str) -> u32 {
    const BYTES_PER_TOKEN: u32 = 4;

    let byte_len = text.len();
    if byte_len == 0 {
        return 0;
    }

    let clamped = std::cmp::min(byte_len, u32::MAX as usize) as u32;
    clamped.div_ceil(BYTES_PER_TOKEN).max(1)
}
/// Gives a rough token estimate for a whole message.
///
/// Each content block is estimated on its own and the results are added to a
/// fixed per-message overhead:
///
/// - text and reasoning blocks: [`estimate_tokens`] of their text;
/// - image blocks: a flat 85 tokens;
/// - tool calls: [`estimate_tokens`] of the name plus that of the
///   stringified arguments;
/// - tool results: [`estimate_tokens`] of the content plus 10.
///
/// All additions saturate at `u32::MAX` rather than overflowing, in line with
/// how `ContextWindow::total_tokens` sums counts.
pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
    use crate::chat::ContentBlock;

    // NOTE(review): these constants are heuristics carried over unchanged;
    // their origin is not visible in this file.
    const MESSAGE_OVERHEAD: u32 = 4;
    const IMAGE_TOKENS: u32 = 85;
    const TOOL_RESULT_OVERHEAD: u32 = 10;

    message
        .content
        .iter()
        .map(|block| match block {
            ContentBlock::Text(text) => estimate_tokens(text),
            ContentBlock::Image { .. } => IMAGE_TOKENS,
            ContentBlock::ToolCall(tc) => estimate_tokens(&tc.name)
                .saturating_add(estimate_tokens(&tc.arguments.to_string())),
            ContentBlock::ToolResult(tr) => {
                estimate_tokens(&tr.content).saturating_add(TOOL_RESULT_OVERHEAD)
            }
            ContentBlock::Reasoning { content } => estimate_tokens(content),
        })
        // Start from the per-message overhead and saturate instead of
        // panicking (debug) or wrapping (release) on overflow.
        .fold(MESSAGE_OVERHEAD, u32::saturating_add)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    // --- Message fixtures ---------------------------------------------------

    /// Builds a message through `ChatMessage::user`.
    fn user_msg(text: &str) -> ChatMessage {
        ChatMessage::user(text)
    }

    /// Builds a message through `ChatMessage::assistant`.
    fn assistant_msg(text: &str) -> ChatMessage {
        ChatMessage::assistant(text)
    }

    /// Builds a message through `ChatMessage::system`.
    fn system_msg(text: &str) -> ChatMessage {
        ChatMessage::system(text)
    }

    // --- Construction -------------------------------------------------------

    #[test]
    fn test_new_context_window() {
        let ctx = ContextWindow::new(8000, 1000);
        assert_eq!(ctx.max_tokens(), 8000);
        assert_eq!(ctx.reserved_for_output(), 1000);
        // The input budget is capacity minus the output reservation.
        assert_eq!(ctx.input_budget(), 7000);
        assert_eq!(ctx.len(), 0);
        assert!(ctx.is_empty());
    }

    #[test]
    #[should_panic(expected = "reserved_for_output")]
    fn test_new_invalid_reserved() {
        // A reservation equal to the capacity must be rejected.
        let _ = ContextWindow::new(1000, 1000);
    }

    #[test]
    #[should_panic(expected = "reserved_for_output")]
    fn test_new_reserved_exceeds_max() {
        // A reservation larger than the capacity must be rejected.
        let _ = ContextWindow::new(1000, 2000);
    }

    // --- Pushing and accounting ---------------------------------------------

    #[test]
    fn test_push_and_len() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 2);
    }

    #[test]
    fn test_total_tokens() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);
        ctx.push(user_msg("How are you?"), 15);
        // 10 + 8 + 15
        assert_eq!(ctx.total_tokens(), 33);
    }

    #[test]
    fn test_available_tokens() {
        let mut ctx = ContextWindow::new(8000, 1000);
        assert_eq!(ctx.available(), 7000);

        ctx.push(user_msg("Hello"), 100);
        assert_eq!(ctx.available(), 6900);

        ctx.push(assistant_msg("Hi"), 50);
        assert_eq!(ctx.available(), 6850);
    }

    #[test]
    fn test_available_saturates() {
        let mut ctx = ContextWindow::new(1000, 100);
        // 1000 tokens against a 900-token budget: clamps to 0, no underflow.
        ctx.push(user_msg("Large message"), 1000);
        assert_eq!(ctx.available(), 0);
    }

    // --- Message views ------------------------------------------------------

    #[test]
    fn test_messages() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);

        let view = ctx.messages();
        assert_eq!(view.len(), 2);
        assert_eq!(view[0].role, ChatRole::User);
        assert_eq!(view[1].role, ChatRole::Assistant);
    }

    #[test]
    fn test_messages_owned() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);

        let owned = ctx.messages_owned();
        assert_eq!(owned.len(), 1);
        assert_eq!(owned[0].role, ChatRole::User);
    }

    // --- needs_compaction ---------------------------------------------------

    #[test]
    fn test_needs_compaction_below_threshold() {
        let mut ctx = ContextWindow::new(1000, 200);
        // 400 / 800 = 0.5
        ctx.push(user_msg("Hello"), 400);
        assert!(!ctx.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let mut ctx = ContextWindow::new(1000, 200);
        // 700 / 800 = 0.875
        ctx.push(user_msg("Hello"), 700);
        assert!(ctx.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_at_threshold() {
        let mut ctx = ContextWindow::new(1000, 200);
        // 640 / 800 = 0.8 exactly; the comparison is strict.
        ctx.push(user_msg("Hello"), 640);
        assert!(!ctx.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_zero_budget() {
        // `new` rejects reserved == max, so a budget of 1 is the smallest
        // that can be constructed; with no messages the usage is 0.
        let ctx = ContextWindow::new(100, 99);
        assert!(!ctx.needs_compaction(0.8));
    }

    // --- compact ------------------------------------------------------------

    #[test]
    fn test_compact_all_compactable() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);
        ctx.push(user_msg("Bye"), 5);

        let dropped = ctx.compact();
        assert_eq!(dropped.len(), 3);
        assert!(ctx.is_empty());
        assert_eq!(ctx.total_tokens(), 0);
    }

    #[test]
    fn test_compact_with_protected() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(system_msg("System"), 20);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);
        ctx.push(user_msg("Question"), 15);
        ctx.protect(0);
        ctx.protect_recent(2);

        // Only the second message ("Hello") is left unprotected.
        let dropped = ctx.compact();
        assert_eq!(dropped.len(), 1);
        assert_eq!(ctx.len(), 3);
        assert_eq!(ctx.total_tokens(), 20 + 8 + 15);
    }

    #[test]
    fn test_compact_none_compactable() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(system_msg("System"), 20);
        ctx.push(user_msg("Hello"), 10);
        ctx.protect_recent(2);

        let dropped = ctx.compact();
        assert!(dropped.is_empty());
        assert_eq!(ctx.len(), 2);
    }

    // --- Protection ---------------------------------------------------------

    #[test]
    fn test_protect_recent() {
        let mut ctx = ContextWindow::new(8000, 1000);
        for label in &["1", "2", "3", "4"] {
            ctx.push(user_msg(label), 10);
        }
        ctx.protect_recent(2);

        let dropped = ctx.compact();
        assert_eq!(dropped.len(), 2);
        assert_eq!(ctx.len(), 2);
    }

    #[test]
    fn test_protect_recent_more_than_len() {
        let mut ctx = ContextWindow::new(8000, 1000);
        for label in &["1", "2"] {
            ctx.push(user_msg(label), 10);
        }
        // Asking for more than exist protects everything without panicking.
        ctx.protect_recent(10);

        let dropped = ctx.compact();
        assert!(dropped.is_empty());
        assert_eq!(ctx.len(), 2);
    }

    #[test]
    fn test_protect_index() {
        let mut ctx = ContextWindow::new(8000, 1000);
        for label in &["1", "2", "3"] {
            ctx.push(user_msg(label), 10);
        }
        ctx.protect(1);

        let dropped = ctx.compact();
        assert_eq!(dropped.len(), 2);
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_unprotect() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("1"), 10);
        ctx.push(user_msg("2"), 10);

        ctx.protect(0);
        assert!(ctx.is_protected(0));

        ctx.unprotect(0);
        assert!(!ctx.is_protected(0));

        // With protection lifted, both messages are compacted away.
        let dropped = ctx.compact();
        assert_eq!(dropped.len(), 2);
    }

    #[test]
    fn test_is_protected() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("1"), 10);
        ctx.push(user_msg("2"), 10);
        assert!(!ctx.is_protected(0));
        assert!(!ctx.is_protected(1));

        ctx.protect(0);
        assert!(ctx.is_protected(0));
        assert!(!ctx.is_protected(1));
    }

    // --- Iteration and per-message counts -----------------------------------

    #[test]
    fn test_iter() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);

        let seen: Vec<_> = ctx.iter().collect();
        assert_eq!(seen.len(), 2);
        assert_eq!(seen[0].role, ChatRole::User);
        assert_eq!(seen[1].role, ChatRole::Assistant);
    }

    #[test]
    fn test_token_count() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 42);
        assert_eq!(ctx.token_count(0), 42);
    }

    #[test]
    fn test_update_token_count() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        assert_eq!(ctx.total_tokens(), 10);

        ctx.update_token_count(0, 15);
        assert_eq!(ctx.token_count(0), 15);
        assert_eq!(ctx.total_tokens(), 15);
    }

    #[test]
    fn test_clear() {
        let mut ctx = ContextWindow::new(8000, 1000);
        ctx.push(user_msg("Hello"), 10);
        ctx.push(assistant_msg("Hi"), 8);

        ctx.clear();
        assert!(ctx.is_empty());
        assert_eq!(ctx.total_tokens(), 0);
        assert_eq!(ctx.available(), 7000);
    }

    // --- estimate_tokens ----------------------------------------------------

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_short() {
        // 2 bytes round up to a single token.
        assert_eq!(estimate_tokens("Hi"), 1);
    }

    #[test]
    fn test_estimate_tokens_medium() {
        // 11 bytes -> ceil(11 / 4) = 3
        assert_eq!(estimate_tokens("Hello world"), 3);
    }

    #[test]
    fn test_estimate_tokens_exact_multiple() {
        // 16 bytes -> 4
        assert_eq!(estimate_tokens("1234567890123456"), 4);
    }

    #[test]
    fn test_estimate_tokens_minimum() {
        assert_eq!(estimate_tokens("a"), 1);
    }

    // --- estimate_message_tokens --------------------------------------------

    #[test]
    fn test_estimate_message_tokens() {
        // 3 tokens of text plus the fixed overhead of 4.
        let msg = user_msg("Hello world");
        assert_eq!(estimate_message_tokens(&msg), 7);
    }

    #[test]
    fn test_estimate_message_tokens_empty() {
        // With no content blocks only the fixed overhead remains.
        let msg = ChatMessage {
            role: ChatRole::User,
            content: vec![],
        };
        assert_eq!(estimate_message_tokens(&msg), 4);
    }

    // --- Trait coverage -----------------------------------------------------

    #[test]
    fn test_context_window_debug() {
        let ctx = ContextWindow::new(8000, 1000);
        let rendered = format!("{ctx:?}");
        assert!(rendered.contains("ContextWindow"));
        assert!(rendered.contains("8000"));
    }

    #[test]
    fn test_context_window_is_send_sync() {
        // Compile-time check only: fails to build if the bounds do not hold.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ContextWindow>();
    }

    // --- End-to-end scenarios -----------------------------------------------

    #[test]
    fn test_typical_conversation_flow() {
        let mut ctx = ContextWindow::new(4000, 500);

        // A protected system prompt followed by two short exchanges.
        ctx.push(system_msg("You are a helpful assistant."), 15);
        ctx.protect(0);
        ctx.push(user_msg("What is 2+2?"), 20);
        ctx.push(assistant_msg("2+2 equals 4."), 25);
        ctx.push(user_msg("What about 3+3?"), 22);
        ctx.push(assistant_msg("3+3 equals 6."), 25);

        assert_eq!(ctx.len(), 5);
        assert_eq!(ctx.total_tokens(), 107);
        assert_eq!(ctx.available(), 3500 - 107);
        assert!(!ctx.needs_compaction(0.8));

        // 100 more messages at 30 tokens each push usage past 80%.
        for i in 0..50 {
            ctx.push(user_msg(&format!("Question {i}")), 30);
            ctx.push(assistant_msg(&format!("Answer {i}")), 30);
        }
        assert!(ctx.needs_compaction(0.8));

        // Keep the system prompt and the four newest messages.
        ctx.protect_recent(4);
        let dropped = ctx.compact();
        assert!(!dropped.is_empty());
        assert!(ctx.len() <= 5);
        assert!(ctx.messages()[0].role == ChatRole::System);
    }

    #[test]
    fn test_compact_then_add_summary() {
        let mut ctx = ContextWindow::new(1000, 100);
        ctx.push(system_msg("System"), 20);
        ctx.protect(0);
        (0..10).for_each(|_| ctx.push(user_msg("Message"), 80));

        let dropped = ctx.compact();
        assert_eq!(dropped.len(), 10);
        assert_eq!(ctx.len(), 1);

        // A summary takes the place of the compacted history.
        ctx.push(
            ChatMessage::system("Summary of previous conversation..."),
            50,
        );
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.total_tokens(), 70);
    }

    // --- force_fit ----------------------------------------------------------

    #[test]
    fn test_force_fit_drops_oldest_first() {
        let mut ctx = ContextWindow::new(1000, 100);
        ctx.push(system_msg("System"), 20);
        ctx.protect(0);
        ctx.push(user_msg("Old"), 500);
        ctx.push(user_msg("Newer"), 500);
        assert!(ctx.needs_compaction(1.0));

        // 1020 tokens against a budget of 900: one eviction is enough.
        let evicted = ctx.force_fit();
        assert_eq!(evicted.len(), 1);
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.total_tokens(), 520);
        assert!(!ctx.needs_compaction(1.0));
    }

    #[test]
    fn test_force_fit_stops_when_under_budget() {
        let mut ctx = ContextWindow::new(1000, 100);
        ctx.push(user_msg("A"), 300);
        ctx.push(user_msg("B"), 300);
        ctx.push(user_msg("C"), 300);
        ctx.push(user_msg("D"), 200);
        assert!(ctx.needs_compaction(1.0));

        // 1100 -> 800 after a single eviction; nothing further is removed.
        let evicted = ctx.force_fit();
        assert_eq!(evicted.len(), 1);
        assert_eq!(ctx.total_tokens(), 800);
    }

    #[test]
    fn test_force_fit_skips_protected() {
        let mut ctx = ContextWindow::new(1000, 100);
        ctx.push(system_msg("System"), 400);
        ctx.protect(0);
        ctx.push(user_msg("Old 1"), 300);
        ctx.push(user_msg("Old 2"), 300);

        let evicted = ctx.force_fit();
        assert_eq!(evicted.len(), 1);
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.total_tokens(), 700);
    }

    #[test]
    fn test_force_fit_noop_when_under_budget() {
        let mut ctx = ContextWindow::new(1000, 100);
        ctx.push(user_msg("Small"), 50);

        let evicted = ctx.force_fit();
        assert!(evicted.is_empty());
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_force_fit_stops_when_only_protected_remain() {
        let mut ctx = ContextWindow::new(1000, 100);
        ctx.push(system_msg("Big system"), 600);
        ctx.protect(0);
        ctx.push(user_msg("Big user"), 400);
        ctx.protect(1);

        // Nothing can be evicted, so the window stays over budget.
        let evicted = ctx.force_fit();
        assert!(evicted.is_empty());
        assert_eq!(ctx.len(), 2);
        assert!(ctx.needs_compaction(1.0));
    }
}