use crate::llm::types::{ContentBlock, Message, Role};
use crate::tool::builtins::floor_char_boundary;
/// Settings that control how aggressively old tool results are pruned.
#[derive(Debug, Clone)]
pub struct SessionPruneConfig {
    /// Recency window: the trailing `keep_recent_n * 2` messages are never pruned.
    pub keep_recent_n: usize,
    /// Byte budget passed to the truncation step for each old tool result.
    pub pruned_tool_result_max_bytes: usize,
    /// When `true`, the first message of the session is always kept verbatim.
    pub preserve_task: bool,
}

impl Default for SessionPruneConfig {
    /// Builds the stock configuration: a recency window of 2, a 200-byte
    /// budget per pruned tool result, and the first message preserved.
    fn default() -> Self {
        let keep_recent_n = 2;
        let pruned_tool_result_max_bytes = 200;
        let preserve_task = true;
        SessionPruneConfig {
            keep_recent_n,
            pruned_tool_result_max_bytes,
            preserve_task,
        }
    }
}
/// Counters describing what a pruning pass did.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PruneStats {
    /// Number of tool results whose content actually became shorter.
    pub tool_results_pruned: usize,
    /// Sum over pruned results of (original length - pruned length), in bytes.
    pub bytes_saved: usize,
    /// Number of tool results examined in messages that were eligible for
    /// pruning; results in preserved messages are not counted.
    pub tool_results_total: usize,
}

impl PruneStats {
    /// Reports whether at least one tool result was shortened.
    pub fn did_prune(&self) -> bool {
        self.tool_results_pruned != 0
    }
}
/// Returns a copy of `messages` in which tool results in older messages are
/// shortened, together with statistics about what was shortened.
///
/// The output always has the same number of messages, in the same order and
/// with the same roles, as the input. A message is copied verbatim when any of
/// the following holds:
///
/// * it is the first message and `config.preserve_task` is set;
/// * it lies within the trailing `config.keep_recent_n * 2` messages
///   (presumably two messages per turn: a tool call and its result);
/// * its role is not [`Role::User`];
/// * it contains no [`ContentBlock::ToolResult`] block.
///
/// In every other message each tool result is passed through
/// `truncate_with_marker` with `config.pruned_tool_result_max_bytes` as the
/// budget; non-tool-result blocks are cloned unchanged. A truncated form is
/// only used when it is strictly shorter than the original, so pruning never
/// makes a message larger.
///
/// `PruneStats::tool_results_total` counts only the tool results examined in
/// eligible messages, not those in preserved ones.
pub fn prune_old_tool_results(
    messages: &[Message],
    config: &SessionPruneConfig,
) -> (Vec<Message>, PruneStats) {
    let mut stats = PruneStats::default();
    if messages.is_empty() {
        return (Vec::new(), stats);
    }

    // Saturate rather than overflow: an enormous `keep_recent_n` must mean
    // "keep everything", never wrap around to a small window.
    let recent_count = config.keep_recent_n.saturating_mul(2);
    let recent_start = messages.len().saturating_sub(recent_count);
    let max = config.pruned_tool_result_max_bytes;

    let pruned = messages
        .iter()
        .enumerate()
        .map(|(i, msg)| {
            let is_task = i == 0 && config.preserve_task;
            let is_recent = i >= recent_start;
            let prunable = !is_task
                && !is_recent
                && msg.role == Role::User
                && msg
                    .content
                    .iter()
                    .any(|b| matches!(b, ContentBlock::ToolResult { .. }));
            if !prunable {
                return msg.clone();
            }

            let blocks = msg
                .content
                .iter()
                .map(|block| match block {
                    ContentBlock::ToolResult {
                        tool_use_id,
                        content,
                        is_error,
                    } => {
                        stats.tool_results_total += 1;
                        let truncated = truncate_with_marker(content, max);
                        // Keep the original unless truncation really saves
                        // bytes; with a tiny budget the marker alone can be
                        // longer than the content it would replace.
                        let kept = if truncated.len() < content.len() {
                            stats.tool_results_pruned += 1;
                            stats.bytes_saved += content.len() - truncated.len();
                            truncated
                        } else {
                            content.clone()
                        };
                        ContentBlock::ToolResult {
                            tool_use_id: tool_use_id.clone(),
                            content: kept,
                            is_error: *is_error,
                        }
                    }
                    other => other.clone(),
                })
                .collect();

            Message {
                role: msg.role.clone(),
                content: blocks,
            }
        })
        .collect();

    (pruned, stats)
}
/// Shortens `content` to about `max_bytes`, appending a marker that states how
/// many bytes were dropped.
///
/// * Content that already fits in `max_bytes` is returned unchanged.
/// * Otherwise the result is a prefix of `content` followed by
///   `"\n[pruned: N bytes omitted]"`, where `N` is the number of bytes of
///   `content` that are not part of the prefix. The prefix end is chosen by
///   `floor_char_boundary` (expected to round the budget down to a UTF-8
///   character boundary) so the slice never splits a character.
/// * When `max_bytes` is smaller than the marker, the prefix is empty and the
///   marker alone is returned; in that case the result may exceed `max_bytes`.
/// * The result is never as long as, or longer than, `content`: if truncation
///   would not save any bytes the content is returned unchanged.
fn truncate_with_marker(content: &str, max_bytes: usize) -> String {
    if content.len() <= max_bytes {
        return content.to_string();
    }

    let marker_for = |omitted: usize| format!("\n[pruned: {omitted} bytes omitted]");

    // The omitted count can never exceed `content.len()`, so a marker built
    // from that value is at least as wide as the one finally emitted.
    // Reserving that width keeps head + marker within `max_bytes` whenever the
    // marker fits at all.
    let reserved = marker_for(content.len()).len();
    let head_budget = max_bytes.saturating_sub(reserved);
    let boundary = floor_char_boundary(content, head_budget);
    let head = &content[..boundary];

    // Report what was really dropped, not `content.len() - max_bytes`: the
    // marker itself consumes part of the budget.
    let marker = marker_for(content.len() - head.len());

    if head.len() + marker.len() >= content.len() {
        // Replacing the text would not save anything; keep it intact.
        return content.to_string();
    }
    format!("{head}{marker}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::ToolResult;
    use serde_json::json;

    /// Builds an assistant message holding a single tool call with empty input.
    fn tool_use_msg(id: &str, name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: vec![ContentBlock::ToolUse {
                id: id.into(),
                name: name.into(),
                input: json!({}),
            }],
        }
    }

    /// Builds a message wrapping one successful tool result for call `id`.
    fn tool_result_msg(id: &str, content: &str) -> Message {
        Message::tool_results(vec![ToolResult::success(id, content)])
    }

    /// Returns the text of the first block of `msg`, which must be a tool
    /// result. Panics otherwise so that a test can never pass vacuously.
    fn first_tool_result(msg: &Message) -> &str {
        match &msg.content[0] {
            ContentBlock::ToolResult { content, .. } => content.as_str(),
            _ => panic!("expected tool result"),
        }
    }

    #[test]
    fn prune_preserves_recent_messages() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"y".repeat(1000)),
            Message::assistant("final answer"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 2,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);
        assert_eq!(pruned.len(), messages.len(), "message count unchanged");
        assert_eq!(
            first_tool_result(&pruned[4]).len(),
            1000,
            "recent tool result should be intact"
        );
        assert!(!stats.did_prune());
    }

    #[test]
    fn prune_trims_old_tool_results() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"a".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"b".repeat(500)),
            tool_use_msg("c3", "write"),
            tool_result_msg("c3", "short result"),
            Message::assistant("done"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 1,
            pruned_tool_result_max_bytes: 100,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);
        assert_eq!(pruned.len(), messages.len());

        // Both old results must fit the configured budget and carry the marker.
        for index in [2, 4] {
            let content = first_tool_result(&pruned[index]);
            assert!(
                content.len() <= config.pruned_tool_result_max_bytes,
                "old tool result should be truncated, got {} bytes",
                content.len()
            );
            assert!(content.contains("[pruned:"));
        }

        // The result inside the recency window is left alone.
        assert_eq!(first_tool_result(&pruned[6]), "short result");

        assert!(stats.did_prune());
        assert_eq!(stats.tool_results_pruned, 2);
        assert!(stats.bytes_saved > 0);
        assert_eq!(stats.tool_results_total, 2);
    }

    #[test]
    fn prune_preserves_task_message() {
        let messages = vec![
            Message::user("important initial task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            Message::assistant("answer"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);
        assert_eq!(pruned.len(), messages.len());
        match &pruned[0].content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "important initial task"),
            _ => panic!("expected the task message to be a text block"),
        }
        // Pruning did happen elsewhere, so the task survived a real pass.
        assert!(stats.did_prune());
    }

    #[test]
    fn prune_preserves_message_count() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"y".repeat(1000)),
            Message::assistant("done"),
        ];
        let config = SessionPruneConfig::default();
        let (pruned, _stats) = prune_old_tool_results(&messages, &config);
        assert_eq!(pruned.len(), messages.len());
        for (original, pruned) in messages.iter().zip(pruned.iter()) {
            assert_eq!(original.role, pruned.role);
        }
    }

    #[test]
    fn prune_utf8_safe() {
        // Each crab is four bytes, so most byte budgets fall inside a character.
        let emoji_content = "🦀".repeat(100);
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &emoji_content),
            Message::assistant("done"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        // Slicing inside a character would panic in this call.
        let (pruned, _stats) = prune_old_tool_results(&messages, &config);
        let content = first_tool_result(&pruned[2]);
        assert!(content.len() <= config.pruned_tool_result_max_bytes);
        assert!(content.contains("[pruned:"));
        let head = content.split('\n').next().unwrap_or("");
        assert!(head.chars().all(|c| c == '🦀'));
    }

    #[test]
    fn prune_empty_messages() {
        let (pruned, stats) = prune_old_tool_results(&[], &SessionPruneConfig::default());
        assert!(pruned.is_empty());
        assert!(!stats.did_prune());
    }

    #[test]
    fn prune_no_tool_results_is_noop() {
        let messages = vec![
            Message::user("task"),
            Message::assistant("response 1"),
            Message::user("follow up"),
            Message::assistant("response 2"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 10,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);
        // `zip` stops at the shorter side, so check the lengths explicitly.
        assert_eq!(pruned.len(), messages.len());
        for (original, pruned) in messages.iter().zip(pruned.iter()) {
            assert_eq!(original.content.len(), pruned.content.len());
        }
        assert!(!stats.did_prune());
    }

    #[test]
    fn prune_short_tool_results_unchanged() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", "short"),
            Message::assistant("done"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 200,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);
        assert_eq!(
            first_tool_result(&pruned[2]),
            "short",
            "short results should not be modified"
        );
        assert!(!stats.did_prune());
        assert_eq!(stats.tool_results_total, 1);
        assert_eq!(stats.tool_results_pruned, 0);
    }

    #[test]
    fn truncate_with_marker_short_content() {
        let result = truncate_with_marker("hello", 100);
        assert_eq!(result, "hello");
    }

    #[test]
    fn truncate_with_marker_long_content() {
        let content = "a".repeat(1000);
        let result = truncate_with_marker(&content, 100);
        assert!(
            result.len() <= 100,
            "result should fit the budget, got {} bytes",
            result.len()
        );
        assert!(result.contains("[pruned:"));
        assert!(result.contains("bytes omitted]"));
    }

    #[test]
    fn prune_stats_bytes_saved_accurate() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"a".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"b".repeat(2000)),
            Message::assistant("done"),
        ];
        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 100,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);
        assert!(stats.did_prune());
        assert_eq!(stats.tool_results_pruned, 2);
        assert_eq!(stats.tool_results_total, 2);

        // bytes_saved must equal the net shrinkage of the two results.
        let pruned_c1_len = first_tool_result(&pruned[2]).len();
        let pruned_c2_len = first_tool_result(&pruned[4]).len();
        let expected_saved = (1000 - pruned_c1_len) + (2000 - pruned_c2_len);
        assert_eq!(stats.bytes_saved, expected_saved);
    }
}