use super::context_monitor::CompactionUrgency;
use crate::session::{ContentPart, Message, Role};
/// Drops older messages, keeping only the `keep_recent` most recent ones.
///
/// If the first message has `Role::System`, it is always retained in front
/// and does not count toward `keep_recent`. Inputs that already have at most
/// `keep_recent` messages are returned unchanged.
pub fn truncate_messages(messages: Vec<Message>, keep_recent: usize) -> Vec<Message> {
    if messages.len() <= keep_recent {
        return messages;
    }

    let pin_system = matches!(messages.first(), Some(first) if first.role == Role::System);

    let mut remaining = messages.into_iter();
    let mut kept = Vec::with_capacity(keep_recent + usize::from(pin_system));
    if pin_system {
        // The leading system prompt survives every truncation.
        kept.extend(remaining.next());
    }
    // Skip everything except the trailing `keep_recent` messages. With
    // `keep_recent == 0` this skips all of them, leaving at most the system
    // prompt.
    let drop_count = remaining.len().saturating_sub(keep_recent);
    kept.extend(remaining.skip(drop_count));
    kept
}
pub fn summarize_messages(
messages: Vec<Message>,
keep_recent: usize,
summary_text: &str,
) -> Vec<Message> {
if messages.is_empty() {
return vec![Message::system(&format!(
"[Conversation Summary]\n{}",
summary_text
))];
}
if messages.len() <= keep_recent {
return messages;
}
let has_system_prefix = messages
.first()
.map(|m| m.role == Role::System)
.unwrap_or(false);
let summary_msg = Message::system(&format!("[Conversation Summary]\n{}", summary_text));
let mut result = if has_system_prefix {
let total = messages.len();
let skip = (total - 1).saturating_sub(keep_recent);
let mut result = Vec::with_capacity(2 + keep_recent);
let mut iter = messages.into_iter();
result.push(iter.next().unwrap()); result.push(summary_msg);
for msg in iter.skip(skip) {
result.push(msg);
}
result
} else {
let total = messages.len();
let skip = total - keep_recent;
let mut result = Vec::with_capacity(1 + keep_recent);
result.push(summary_msg);
for msg in messages.into_iter().skip(skip) {
result.push(msg);
}
result
};
strip_images_from_messages(&mut result);
result
}
/// Shrinks `content` in place so its body is at most `budget` bytes, then
/// appends a `\n...[shrunk from X to Y bytes]` marker.
///
/// Returns `true` if the text was shrunk and `false` if it already fit.
///
/// The cut point is moved back to the nearest UTF-8 character boundary at or
/// below `budget`, so multi-byte text never panics. `Y` in the marker is the
/// length of the kept body, excluding the marker itself.
fn shrink_text_in_place(content: &mut String, budget: usize) -> bool {
    let original_len = content.len();
    if original_len <= budget {
        return false;
    }
    // `String::truncate` panics when the new length is not on a char
    // boundary, so find a valid cut point first. Index 0 is always a
    // boundary, so this loop terminates.
    let mut cut = budget;
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }
    content.truncate(cut);
    let kept_len = content.len();
    content.push_str(&format!(
        "\n...[shrunk from {} to {} bytes]",
        original_len, kept_len
    ));
    true
}

/// Caps every tool-result message's `content` at `max_bytes_per_result` bytes.
///
/// Oversized results are cut at a UTF-8 boundary and get a size marker
/// appended. Returns the rewritten messages and how many were shrunk.
///
/// NOTE(review): only `content` is shortened; if `content_parts` also carries
/// the text, confirm which field is sent to the model.
pub fn shrink_tool_results(
    messages: Vec<Message>,
    max_bytes_per_result: usize,
) -> (Vec<Message>, usize) {
    let mut shrunk_count = 0;
    let result = messages
        .into_iter()
        .map(|mut msg| {
            if msg.is_tool_result() && shrink_text_in_place(&mut msg.content, max_bytes_per_result)
            {
                shrunk_count += 1;
            }
            msg
        })
        .collect();
    (result, shrunk_count)
}

/// Shrinks tool results with an age-dependent budget.
///
/// The last `recent_count` tool results (counting only tool-result messages)
/// may keep up to `target_max_bytes` bytes. Older ones are capped at a quarter
/// of that. Non-tool messages are left untouched. Cuts respect UTF-8 character
/// boundaries.
pub fn shrink_tool_results_progressive(
    mut messages: Vec<Message>,
    target_max_bytes: usize,
    recent_count: usize,
) -> Vec<Message> {
    let total_tool_results = messages.iter().filter(|m| m.is_tool_result()).count();
    // Tool results at or after this ordinal position count as "recent".
    let first_recent = total_tool_results.saturating_sub(recent_count);

    for (pos, msg) in messages
        .iter_mut()
        .filter(|m| m.is_tool_result())
        .enumerate()
    {
        let budget = if pos >= first_recent {
            target_max_bytes
        } else {
            target_max_bytes / 4
        };
        shrink_text_in_place(&mut msg.content, budget);
    }
    messages
}
/// Runs tiered context recovery at `CompactionUrgency::Normal`.
///
/// Returns the (possibly reduced) messages and the tier that was applied
/// (0 means no recovery was needed). See `try_recover_context_with_urgency`.
pub fn try_recover_context(
    messages: Vec<Message>,
    context_limit: usize,
    keep_recent_tier1: usize,
    tool_result_budget: usize,
) -> (Vec<Message>, u8) {
    let urgency = CompactionUrgency::Normal;
    try_recover_context_with_urgency(
        messages,
        context_limit,
        urgency,
        keep_recent_tier1,
        tool_result_budget,
    )
}
/// Shrinks `messages` until their estimated token count fits within 95% of
/// `context_limit`, escalating through tiers.
///
/// Returns the messages and the tier applied:
/// - `0`: already fits; returned unchanged.
/// - `1`: truncated to the most recent messages.
/// - `2`: additionally shrank tool results with `shrink_tool_results_progressive`.
/// - `3`: hard-truncated to 3 recent messages, without re-checking the fit.
///
/// Urgency selects the parameters:
/// - `Critical`: skips straight to tier 3 by keeping 6 recent messages.
/// - `Emergency`: keeps at most `min(keep_recent_tier1, 5)` messages, then
///   uses half of `tool_result_budget` (at least 1, never more than the budget)
///   with 2 recent results.
/// - `Normal`: keeps `keep_recent_tier1` messages, then uses `tool_result_budget`
///   with 3 recent results.
pub fn try_recover_context_with_urgency(
    messages: Vec<Message>,
    context_limit: usize,
    urgency: CompactionUrgency,
    keep_recent_tier1: usize,
    tool_result_budget: usize,
) -> (Vec<Message>, u8) {
    use super::context_monitor::ContextMonitor;

    // Aim a little under the hard limit to leave headroom.
    let target = context_limit as f64 * 0.95;
    let fits = |msgs: &Vec<Message>| (ContextMonitor::estimate_tokens(msgs) as f64) <= target;

    if fits(&messages) {
        return (messages, 0);
    }

    // (tier-1 keep count, tier-2 byte budget, tier-2 recent-result count)
    let (keep, budget, recent) = match urgency {
        CompactionUrgency::Critical => return (truncate_messages(messages, 6), 3),
        CompactionUrgency::Emergency => (
            keep_recent_tier1.min(5),
            (tool_result_budget / 2).max(1).min(tool_result_budget),
            2,
        ),
        CompactionUrgency::Normal => (keep_recent_tier1, tool_result_budget, 3),
    };

    let tier1 = truncate_messages(messages, keep);
    if fits(&tier1) {
        return (tier1, 1);
    }

    let tier2 = shrink_tool_results_progressive(tier1, budget, recent);
    if fits(&tier2) {
        return (tier2, 2);
    }

    (truncate_messages(tier2, 3), 3)
}
/// Builds an LLM prompt asking for a concise summary of `messages`.
///
/// Each message is rendered as a `"{role}: {content}\n"` transcript line,
/// below a fixed instruction paragraph.
pub fn build_summary_prompt(messages: &[Message]) -> String {
    let transcript: String = messages
        .iter()
        .map(|msg| format!("{}: {}\n", msg.role, msg.content))
        .collect();

    format!(
        "Summarize the following conversation focusing on key decisions, \
         information exchanged, and actions taken. Be concise.\n\n{}",
        transcript
    )
}
/// Removes non-text content parts from every message that contains images.
///
/// Messages for which `has_images()` is false are left untouched. For the
/// others, only `ContentPart::Text` parts are retained, so any other
/// non-text part is dropped along with the images.
pub fn strip_images_from_messages(messages: &mut [Message]) {
    messages
        .iter_mut()
        .filter(|m| m.has_images())
        .for_each(|m| {
            m.content_parts
                .retain(|part| matches!(part, ContentPart::Text { .. }))
        });
}
// Unit tests for message truncation, summarization, tool-result shrinking,
// image stripping, summary-prompt building, and tiered context recovery.
#[cfg(test)]
mod tests {
use super::*;
use crate::session::{ContentPart, ImageSource};
// --- strip_images_from_messages ---
#[test]
fn test_strip_images_removes_image_parts() {
let images = vec![ContentPart::Image {
source: ImageSource::Base64 {
data: "big_data".to_string(),
},
media_type: "image/jpeg".to_string(),
}];
let mut msgs = vec![Message::user_with_images("What is this?", images)];
assert!(msgs[0].has_images());
strip_images_from_messages(&mut msgs);
assert!(!msgs[0].has_images());
// Only the text part should remain, and the plain `content` string is unchanged.
assert_eq!(msgs[0].content_parts.len(), 1); assert_eq!(msgs[0].content, "What is this?");
}
#[test]
fn test_strip_images_leaves_text_only_unchanged() {
let mut msgs = vec![Message::user("Hello"), Message::assistant("Hi")];
strip_images_from_messages(&mut msgs);
assert_eq!(msgs[0].content_parts.len(), 1);
assert_eq!(msgs[1].content_parts.len(), 1);
}
// --- truncate_messages ---
#[test]
fn test_truncate_keeps_n_recent() {
let msgs = vec![
Message::user("one"),
Message::user("two"),
Message::user("three"),
Message::user("four"),
];
let result = truncate_messages(msgs, 2);
assert_eq!(result.len(), 2);
assert_eq!(result[0].content, "three");
assert_eq!(result[1].content, "four");
}
#[test]
fn test_truncate_preserves_system_message() {
let msgs = vec![
Message::system("system prompt"),
Message::user("one"),
Message::user("two"),
Message::user("three"),
];
let result = truncate_messages(msgs, 2);
// The leading system message is kept in addition to the 2 most recent messages.
assert_eq!(result.len(), 3); assert_eq!(result[0].role, Role::System);
assert_eq!(result[0].content, "system prompt");
assert_eq!(result[1].content, "two");
assert_eq!(result[2].content, "three");
}
#[test]
fn test_truncate_empty_messages() {
let result = truncate_messages(Vec::new(), 5);
assert!(result.is_empty());
}
#[test]
fn test_truncate_keep_greater_than_len() {
let msgs = vec![Message::user("one"), Message::user("two")];
let result = truncate_messages(msgs, 10);
assert_eq!(result.len(), 2);
assert_eq!(result[0].content, "one");
assert_eq!(result[1].content, "two");
}
#[test]
fn test_truncate_keep_equal_to_len() {
let msgs = vec![
Message::user("one"),
Message::user("two"),
Message::user("three"),
];
let result = truncate_messages(msgs, 3);
assert_eq!(result.len(), 3);
}
#[test]
fn test_truncate_keep_zero() {
let msgs = vec![Message::user("one"), Message::user("two")];
let result = truncate_messages(msgs, 0);
assert!(result.is_empty());
}
#[test]
fn test_truncate_keep_zero_with_system() {
let msgs = vec![
Message::system("sys"),
Message::user("one"),
Message::user("two"),
];
// keep_recent == 0 still retains the leading system message.
let result = truncate_messages(msgs, 0);
assert_eq!(result.len(), 1);
assert_eq!(result[0].role, Role::System);
assert_eq!(result[0].content, "sys");
}
#[test]
fn test_truncate_single_message() {
let msgs = vec![Message::user("only")];
let result = truncate_messages(msgs, 1);
assert_eq!(result.len(), 1);
assert_eq!(result[0].content, "only");
}
// --- summarize_messages ---
#[test]
fn test_summarize_with_system_message() {
let msgs = vec![
Message::system("You are helpful."),
Message::user("Tell me about Rust"),
Message::assistant("Rust is great."),
Message::user("And async?"),
Message::assistant("Use tokio."),
];
// Expected layout: [original system, summary, 2 most recent messages].
let result = summarize_messages(msgs, 2, "Discussed Rust basics.");
assert_eq!(result.len(), 4);
assert_eq!(result[0].role, Role::System);
assert_eq!(result[0].content, "You are helpful.");
assert_eq!(result[1].role, Role::System);
assert!(result[1].content.contains("[Conversation Summary]"));
assert!(result[1].content.contains("Discussed Rust basics."));
assert_eq!(result[2].content, "And async?");
assert_eq!(result[3].content, "Use tokio.");
}
#[test]
fn test_summarize_without_system_message() {
let msgs = vec![
Message::user("Hello"),
Message::assistant("Hi!"),
Message::user("Bye"),
Message::assistant("Goodbye!"),
];
// Without a system prefix, the summary becomes the first message.
let result = summarize_messages(msgs, 2, "User greeted.");
assert_eq!(result.len(), 3);
assert_eq!(result[0].role, Role::System);
assert!(result[0].content.contains("[Conversation Summary]"));
assert!(result[0].content.contains("User greeted."));
assert_eq!(result[1].content, "Bye");
assert_eq!(result[2].content, "Goodbye!");
}
#[test]
fn test_summarize_empty_messages() {
let result = summarize_messages(Vec::new(), 2, "Nothing happened.");
assert_eq!(result.len(), 1);
assert!(result[0].content.contains("[Conversation Summary]"));
assert!(result[0].content.contains("Nothing happened."));
}
#[test]
fn test_summarize_keep_greater_than_len() {
let msgs = vec![Message::user("one"), Message::user("two")];
// Nothing needs dropping, so no summary is inserted.
let result = summarize_messages(msgs, 10, "summary");
assert_eq!(result.len(), 2);
assert_eq!(result[0].content, "one");
assert_eq!(result[1].content, "two");
}
// --- build_summary_prompt ---
#[test]
fn test_build_summary_prompt_includes_content() {
let msgs = vec![
Message::user("What is Rust?"),
Message::assistant("A systems programming language."),
];
let prompt = build_summary_prompt(&msgs);
assert!(prompt.contains("What is Rust?"));
assert!(prompt.contains("A systems programming language."));
}
#[test]
fn test_build_summary_prompt_includes_role_labels() {
let msgs = vec![
Message::user("Hi"),
Message::assistant("Hello"),
Message::system("Be concise"),
];
// Relies on Role's Display rendering lowercase role names.
let prompt = build_summary_prompt(&msgs);
assert!(prompt.contains("user: Hi"));
assert!(prompt.contains("assistant: Hello"));
assert!(prompt.contains("system: Be concise"));
}
#[test]
fn test_build_summary_prompt_includes_instruction() {
let msgs = vec![Message::user("test")];
let prompt = build_summary_prompt(&msgs);
assert!(prompt.contains("Summarize the following conversation"));
assert!(prompt.contains("key decisions"));
assert!(prompt.contains("Be concise"));
}
#[test]
fn test_build_summary_prompt_empty_messages() {
let prompt = build_summary_prompt(&[]);
assert!(prompt.contains("Summarize the following conversation"));
assert!(!prompt.contains("user:"));
}
// --- shrink_tool_results / shrink_tool_results_progressive ---
#[test]
fn test_shrink_tool_results_basic() {
let long_content = "x".repeat(200);
let msgs = vec![
Message::user("Hello"),
Message::tool_result("call_1", &long_content),
Message::assistant("Done"),
];
// Only the tool result is shrunk; user/assistant messages are untouched.
let (result, count) = shrink_tool_results(msgs, 50);
assert_eq!(count, 1);
assert!(result[1].content.contains("...[shrunk from 200 to"));
assert_eq!(result[0].content, "Hello");
assert_eq!(result[2].content, "Done");
}
#[test]
fn test_shrink_tool_results_preserves_small() {
let msgs = vec![
Message::user("Hello"),
Message::tool_result("call_1", "short result"),
Message::assistant("Done"),
];
let (result, count) = shrink_tool_results(msgs, 1000);
assert_eq!(count, 0);
assert_eq!(result[1].content, "short result");
}
#[test]
fn test_shrink_tool_results_no_tool_results() {
let msgs = vec![
Message::user("Hello"),
Message::assistant("Hi there"),
Message::user("Bye"),
];
let (result, count) = shrink_tool_results(msgs, 10);
assert_eq!(count, 0);
assert_eq!(result.len(), 3);
assert_eq!(result[0].content, "Hello");
assert_eq!(result[1].content, "Hi there");
assert_eq!(result[2].content, "Bye");
}
#[test]
fn test_shrink_tool_results_progressive_older_smaller() {
let long_content = "x".repeat(500);
let msgs = vec![
Message::tool_result("call_1", &long_content), Message::user("middle"),
Message::tool_result("call_2", &long_content), Message::tool_result("call_3", &long_content), ];
// recent_count = 1: only call_3 gets the full 200-byte budget; older
// tool results get 200 / 4 = 50 bytes.
let result = shrink_tool_results_progressive(msgs, 200, 1);
assert!(result[0].content.contains("...[shrunk from"));
assert!(result[2].content.contains("...[shrunk from"));
assert!(result[3].content.contains("...[shrunk from"));
let old_base_len = result[0].content.find("\n...[shrunk").unwrap();
let recent_base_len = result[3].content.find("\n...[shrunk").unwrap();
assert!(
old_base_len < recent_base_len,
"Old result base ({}) should be shorter than recent ({})",
old_base_len,
recent_base_len
);
assert_eq!(result[1].content, "middle");
}
// --- try_recover_context / try_recover_context_with_urgency ---
#[test]
fn test_try_recover_context_no_recovery_needed() {
let msgs = vec![Message::user("Hello"), Message::assistant("Hi!")];
let (result, tier) = try_recover_context(msgs.clone(), 100_000, 8, 5120);
assert_eq!(tier, 0);
assert_eq!(result.len(), 2);
}
#[test]
fn test_try_recover_context_tier1_sufficient() {
let msgs: Vec<Message> = (0..6)
.map(|_| Message::user("one two three four five six seven eight nine ten"))
.collect();
let (result, tier) = try_recover_context(msgs, 100, 3, 5120);
assert_eq!(tier, 1);
assert_eq!(result.len(), 3);
}
#[test]
fn test_try_recover_context_tier2_needed() {
let mut msgs = vec![Message::system("system prompt")];
for _ in 0..20 {
msgs.push(Message::user(
"one two three four five six seven eight nine ten",
));
}
for i in 0..5 {
let big = "word ".repeat(600); msgs.push(Message::tool_result(&format!("call_{}", i), &big));
}
for _ in 0..3 {
msgs.push(Message::user(
"one two three four five six seven eight nine ten",
));
}
// The exact tier depends on ContextMonitor::estimate_tokens, so tier 1
// or 2 is accepted. The key invariant is that the result fits under 95%
// of the limit.
let (result, tier) = try_recover_context(msgs, 2000, 8, 100);
assert!(
tier == 2 || tier == 1,
"Expected tier 1 or 2, got tier {}",
tier
);
let estimated = super::super::context_monitor::ContextMonitor::estimate_tokens(&result);
assert!(
(estimated as f64) <= 2000.0 * 0.95,
"Estimated {} should be <= {}",
estimated,
(2000.0 * 0.95) as usize
);
}
#[test]
fn test_try_recover_context_tier3_needed() {
let msgs: Vec<Message> = (0..10)
.map(|_| Message::user("one two three four five six seven eight nine ten"))
.collect();
// There are no tool results to shrink, so recovery falls through to the
// tier-3 hard truncation to 3 messages.
let (result, tier) = try_recover_context(msgs, 100, 8, 5120);
assert_eq!(tier, 3);
assert_eq!(result.len(), 3);
let estimated = super::super::context_monitor::ContextMonitor::estimate_tokens(&result);
assert!(
(estimated as f64) <= 100.0 * 0.95,
"Estimated {} should be <= 95",
estimated
);
}
#[test]
fn test_try_recover_context_with_emergency_uses_truncate_path() {
let msgs: Vec<Message> = (0..12)
.map(|_| Message::user("one two three four five six seven eight nine ten"))
.collect();
// Emergency caps tier-1 truncation at min(keep_recent_tier1, 5) messages.
let (result, tier) =
try_recover_context_with_urgency(msgs, 100, CompactionUrgency::Emergency, 8, 5120);
assert!(tier >= 1);
assert!(result.len() <= 6);
}
#[test]
fn test_try_recover_context_with_critical_hard_trims() {
let msgs: Vec<Message> = (0..20)
.map(|_| Message::user("one two three four five six seven eight nine ten"))
.collect();
// Critical trims directly to 6 recent messages and reports tier 3.
let (result, tier) =
try_recover_context_with_urgency(msgs, 100, CompactionUrgency::Critical, 8, 5120);
assert_eq!(tier, 3);
assert!(result.len() <= 6);
}
}