use crate::llm::message::Message;
use crate::services::tokens;
/// Outcome of collapsing a conversation so it fits a token budget.
pub struct CollapseResult {
/// Messages to send to the API: the first round, a marker user message
/// noting the gap, then every round that survived the snip.
pub api_messages: Vec<Message>,
/// How many individual messages were removed from the middle.
pub snipped_count: usize,
/// Estimated number of tokens reclaimed by removing those messages.
pub tokens_freed: u64,
}
/// Trim the middle of a conversation so its estimated token count fits
/// within `max_tokens`.
///
/// The first and last rounds are always preserved. Complete middle rounds
/// are dropped, oldest first, until at least the overshoot has been
/// reclaimed (best effort — the result may still exceed the budget if the
/// middle rounds alone are not enough). A placeholder user message marks
/// where history was removed.
///
/// Returns `None` when the conversation already fits, when there are two
/// or fewer rounds (no middle to snip), or when nothing could be freed.
pub fn collapse_to_budget(messages: &[Message], max_tokens: u64) -> Option<CollapseResult> {
    let total = tokens::estimate_context_tokens(messages);
    if total <= max_tokens {
        return None;
    }
    let need = total - max_tokens;

    let rounds = group_by_round(messages);
    // Need at least one round between the first and the last to collapse.
    if rounds.len() <= 2 {
        return None;
    }

    // Walk the middle rounds (absolute indices 1..len-1), accumulating
    // reclaimed tokens. `cut` is one past the last snipped round.
    let mut reclaimed: u64 = 0;
    let mut cut: usize = 1;
    for idx in 1..rounds.len() - 1 {
        let round_tokens: u64 = rounds[idx]
            .iter()
            .map(tokens::estimate_message_tokens)
            .sum();
        reclaimed += round_tokens;
        cut = idx + 1;
        if reclaimed >= need {
            break;
        }
    }
    if reclaimed == 0 {
        return None;
    }

    // Rebuild: first round, the collapse marker, then everything past the cut.
    let mut api_messages: Vec<Message> = rounds[0].to_vec();
    api_messages.push(crate::llm::message::user_message(
        "[Earlier messages collapsed to fit context window]",
    ));
    for round in &rounds[cut..] {
        api_messages.extend_from_slice(round);
    }

    let snipped_count: usize = rounds[1..cut].iter().map(Vec::len).sum();
    Some(CollapseResult {
        api_messages,
        snipped_count,
        tokens_freed: reclaimed,
    })
}
/// Recover after the model reports a context-window overflow.
///
/// Computes how many tokens to shed — the reported `token_gap` plus a 10%
/// safety margin, or a fixed 20 000-token reduction when no gap is known —
/// then delegates to [`collapse_to_budget`] with the reduced budget.
pub fn recover_from_overflow(
    messages: &[Message],
    token_gap: Option<u64>,
) -> Option<CollapseResult> {
    // Gap + 10% headroom; fall back to a flat 20k reduction when unknown.
    let reduction = match token_gap {
        Some(gap) => gap + gap / 10,
        None => 20_000,
    };
    let current = tokens::estimate_context_tokens(messages);
    collapse_to_budget(messages, current.saturating_sub(reduction))
}
/// Partition messages into "rounds": each non-meta user message opens a new
/// round that carries every message up to (but excluding) the next non-meta
/// user message. Any leading messages before the first such user message
/// form a round of their own. Returns no rounds for an empty input.
fn group_by_round(messages: &[Message]) -> Vec<Vec<Message>> {
    let mut rounds: Vec<Vec<Message>> = Vec::new();
    for msg in messages {
        // A real (non-meta) user message starts a fresh round; anything else
        // joins the round in progress.
        let opens_round = matches!(msg, Message::User(u) if !u.is_meta);
        if opens_round || rounds.is_empty() {
            rounds.push(Vec::new());
        }
        rounds
            .last_mut()
            .expect("rounds is non-empty: a round was just pushed if needed")
            .push(msg.clone());
    }
    rounds
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::message::user_message;

    /// Build a plain-text assistant message for use as a test fixture.
    fn make_assistant(text: String) -> crate::llm::message::Message {
        use crate::llm::message::*;
        Message::Assistant(AssistantMessage {
            uuid: uuid::Uuid::new_v4(),
            timestamp: String::new(),
            content: vec![ContentBlock::Text { text }],
            model: None,
            usage: None,
            stop_reason: None,
            request_id: None,
        })
    }

    #[test]
    fn test_no_collapse_within_budget() {
        // A single short message is comfortably inside a huge budget.
        let messages = vec![user_message("short")];
        assert!(collapse_to_budget(&messages, 1_000_000).is_none());
    }

    #[test]
    fn test_collapse_empty_messages() {
        // Nothing to collapse in an empty conversation.
        let messages: Vec<crate::llm::message::Message> = Vec::new();
        assert!(collapse_to_budget(&messages, 100).is_none());
    }

    #[test]
    fn test_collapse_preserves_first_and_last() {
        // Ten user/assistant rounds with enough padding to exceed a tiny budget.
        let messages: Vec<_> = (0..10)
            .flat_map(|i| {
                vec![
                    user_message(format!("message {i} with some content padding")),
                    make_assistant(format!("response {i} with content")),
                ]
            })
            .collect();
        if let Some(result) = collapse_to_budget(&messages, 50) {
            assert!(result.snipped_count > 0);
            assert!(result.tokens_freed > 0);
            assert!(result.api_messages.len() < messages.len());
        }
        // With a generous budget nothing should be collapsed.
        assert!(collapse_to_budget(&messages, 1_000_000).is_none());
    }

    #[test]
    fn test_recover_from_overflow() {
        // Twenty padded rounds so a 5k-token gap forces a real collapse.
        let messages: Vec<_> = (0..20)
            .flat_map(|i| {
                vec![
                    user_message(format!("msg {i} {}", "x".repeat(200))),
                    make_assistant(format!("resp {i} {}", "y".repeat(200))),
                ]
            })
            .collect();
        let result = recover_from_overflow(&messages, Some(5000));
        assert!(result.is_some());
        let r = result.unwrap();
        assert!(r.snipped_count > 0);
    }
}