1use crate::llm::types::{ContentBlock, Message, Role};
4use crate::tool::builtins::floor_char_boundary;
5
/// Knobs controlling how aggressively [`prune_old_tool_results`] shrinks
/// older tool output in a conversation.
#[derive(Debug, Clone)]
pub struct SessionPruneConfig {
    /// Size of the protected tail: the last `keep_recent_n * 2` messages are
    /// never pruned.
    pub keep_recent_n: usize,
    /// Byte budget for a pruned tool result's text, including the
    /// `[pruned: …]` marker appended to it.
    pub pruned_tool_result_max_bytes: usize,
    /// When `true`, the very first message (the task) is never pruned.
    pub preserve_task: bool,
}

impl Default for SessionPruneConfig {
    /// Keeps the last 4 messages intact, trims older tool results to 200
    /// bytes, and protects the initial task message.
    fn default() -> Self {
        SessionPruneConfig {
            preserve_task: true,
            keep_recent_n: 2,
            pruned_tool_result_max_bytes: 200,
        }
    }
}
32
/// Counters describing what a single [`prune_old_tool_results`] call did.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PruneStats {
    /// Tool results whose text was actually shortened.
    pub tool_results_pruned: usize,
    /// Total bytes removed across all shortened tool results.
    pub bytes_saved: usize,
    /// Tool results that were examined for pruning, i.e. those found in
    /// messages outside the protected task/recent window. Results in
    /// protected messages are not counted.
    pub tool_results_total: usize,
}

impl PruneStats {
    /// Returns `true` when at least one tool result was shortened.
    pub fn did_prune(&self) -> bool {
        self.tool_results_pruned != 0
    }
}
50
51pub fn prune_old_tool_results(
60 messages: &[Message],
61 config: &SessionPruneConfig,
62) -> (Vec<Message>, PruneStats) {
63 if messages.is_empty() {
64 return (vec![], PruneStats::default());
65 }
66
67 let mut stats = PruneStats::default();
68
69 let recent_count = config.keep_recent_n * 2;
71 let recent_start = messages.len().saturating_sub(recent_count);
72
73 let pruned = messages
74 .iter()
75 .enumerate()
76 .map(|(i, msg)| {
77 if i == 0 && config.preserve_task {
79 return msg.clone();
80 }
81 if i >= recent_start {
83 return msg.clone();
84 }
85 if msg.role != Role::User {
87 return msg.clone();
88 }
89 let has_tool_results = msg
90 .content
91 .iter()
92 .any(|b| matches!(b, ContentBlock::ToolResult { .. }));
93 if !has_tool_results {
94 return msg.clone();
95 }
96 let pruned_content = msg
98 .content
99 .iter()
100 .map(|block| match block {
101 ContentBlock::ToolResult {
102 tool_use_id,
103 content,
104 is_error,
105 } => {
106 stats.tool_results_total += 1;
107 let max = config.pruned_tool_result_max_bytes;
108 let pruned = truncate_with_marker(content, max);
109 if pruned.len() < content.len() {
110 stats.tool_results_pruned += 1;
111 stats.bytes_saved += content.len() - pruned.len();
112 }
113 ContentBlock::ToolResult {
114 tool_use_id: tool_use_id.clone(),
115 content: pruned,
116 is_error: *is_error,
117 }
118 }
119 other => other.clone(),
120 })
121 .collect();
122 Message {
123 role: msg.role.clone(),
124 content: pruned_content,
125 }
126 })
127 .collect();
128
129 (pruned, stats)
130}
131
132fn truncate_with_marker(content: &str, max_bytes: usize) -> String {
137 if content.len() <= max_bytes {
138 return content.to_string();
139 }
140 let omitted = content.len() - max_bytes;
141 let marker = format!("\n[pruned: {omitted} bytes omitted]");
142 let head_budget = max_bytes.saturating_sub(marker.len());
143 let boundary = floor_char_boundary(content, head_budget);
144 let head = &content[..boundary];
145 format!("{head}{marker}")
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::ToolResult;
    use serde_json::json;

    /// Builds an assistant message containing a single `ToolUse` block.
    fn tool_use_msg(id: &str, name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: vec![ContentBlock::ToolUse {
                id: id.into(),
                name: name.into(),
                input: json!({}),
            }],
        }
    }

    /// Builds a message carrying a single successful tool result.
    fn tool_result_msg(id: &str, content: &str) -> Message {
        Message::tool_results(vec![ToolResult::success(id, content)])
    }

    /// Returns the text of the first block of `msg`, panicking if that block
    /// is not a `ToolResult`. Using this instead of `if let` keeps assertions
    /// from silently passing when the block shape is unexpected.
    fn first_tool_result(msg: &Message) -> &str {
        match msg.content.first() {
            Some(ContentBlock::ToolResult { content, .. }) => content.as_str(),
            _ => panic!("expected the first block to be a tool result"),
        }
    }

    #[test]
    fn prune_preserves_recent_messages() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"y".repeat(1000)),
            Message::assistant("final answer"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 2,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        assert_eq!(pruned.len(), messages.len(), "message count unchanged");
        assert_eq!(
            first_tool_result(&pruned[4]).len(),
            1000,
            "recent tool result should be intact"
        );
        assert!(!stats.did_prune());
    }

    #[test]
    fn prune_trims_old_tool_results() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"a".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"b".repeat(500)),
            tool_use_msg("c3", "write"),
            tool_result_msg("c3", "short result"),
            Message::assistant("done"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 1,
            pruned_tool_result_max_bytes: 100,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        for idx in [2, 4] {
            let content = first_tool_result(&pruned[idx]);
            assert!(
                content.len() <= config.pruned_tool_result_max_bytes,
                "old tool result should be truncated, got {} bytes",
                content.len()
            );
            assert!(content.contains("[pruned:"));
        }
        assert_eq!(first_tool_result(&pruned[6]), "short result");

        assert!(stats.did_prune());
        assert_eq!(stats.tool_results_pruned, 2);
        assert!(stats.bytes_saved > 0);
        assert_eq!(stats.tool_results_total, 2);
    }

    #[test]
    fn prune_preserves_task_message() {
        let messages = vec![
            Message::user("important initial task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            Message::assistant("answer"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        match pruned[0].content.first() {
            Some(ContentBlock::Text { text }) => assert_eq!(text, "important initial task"),
            _ => panic!("expected the task message to keep its text block"),
        }
        assert!(stats.did_prune(), "the old tool result should still be pruned");
    }

    #[test]
    fn prune_first_message_respects_preserve_task() {
        let messages = vec![
            tool_result_msg("c0", &"z".repeat(1000)),
            Message::assistant("done"),
        ];

        let preserve = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        let (kept, kept_stats) = prune_old_tool_results(&messages, &preserve);
        assert_eq!(first_tool_result(&kept[0]).len(), 1000);
        assert!(!kept_stats.did_prune());

        let no_preserve = SessionPruneConfig {
            preserve_task: false,
            ..preserve
        };
        let (trimmed, trimmed_stats) = prune_old_tool_results(&messages, &no_preserve);
        assert!(first_tool_result(&trimmed[0]).contains("[pruned:"));
        assert_eq!(trimmed_stats.tool_results_pruned, 1);
    }

    #[test]
    fn prune_preserves_message_count() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"y".repeat(1000)),
            Message::assistant("done"),
        ];

        let config = SessionPruneConfig::default();
        let (pruned, _stats) = prune_old_tool_results(&messages, &config);

        assert_eq!(pruned.len(), messages.len());
        for (original, pruned) in messages.iter().zip(pruned.iter()) {
            assert_eq!(original.role, pruned.role);
        }
    }

    #[test]
    fn prune_utf8_safe() {
        // Each crab is 4 bytes, so naive byte slicing would split a character.
        let emoji_content = "🦀".repeat(100);
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &emoji_content),
            Message::assistant("done"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 50,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        let content = first_tool_result(&pruned[2]);
        assert!(stats.did_prune());
        assert!(content.len() <= config.pruned_tool_result_max_bytes);
        let (head, marker) = content
            .split_once('\n')
            .expect("the pruning marker is separated by a newline");
        assert!(
            head.chars().all(|c| c == '🦀'),
            "head must consist of whole characters only"
        );
        assert!(marker.starts_with("[pruned:"));
    }

    #[test]
    fn prune_empty_messages() {
        let (pruned, stats) = prune_old_tool_results(&[], &SessionPruneConfig::default());
        assert!(pruned.is_empty());
        assert!(!stats.did_prune());
    }

    #[test]
    fn prune_no_tool_results_is_noop() {
        let messages = vec![
            Message::user("task"),
            Message::assistant("response 1"),
            Message::user("follow up"),
            Message::assistant("response 2"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 10,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        // Check lengths first: `zip` would silently stop at the shorter list.
        assert_eq!(pruned.len(), messages.len());
        for (original, pruned) in messages.iter().zip(pruned.iter()) {
            assert_eq!(original.content.len(), pruned.content.len());
        }
        assert!(!stats.did_prune());
        assert_eq!(stats.tool_results_total, 0);
    }

    #[test]
    fn prune_short_tool_results_unchanged() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", "short"),
            Message::assistant("done"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 200,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        assert_eq!(
            first_tool_result(&pruned[2]),
            "short",
            "short results should not be modified"
        );
        assert!(!stats.did_prune());
        // The result was examined even though it did not need trimming.
        assert_eq!(stats.tool_results_total, 1);
        assert_eq!(stats.tool_results_pruned, 0);
    }

    #[test]
    fn truncate_with_marker_short_content() {
        let result = truncate_with_marker("hello", 100);
        assert_eq!(result, "hello");
    }

    #[test]
    fn truncate_with_marker_exact_limit_unchanged() {
        let content = "a".repeat(100);
        assert_eq!(truncate_with_marker(&content, 100), content);
    }

    #[test]
    fn truncate_with_marker_long_content() {
        let content = "a".repeat(1000);
        let result = truncate_with_marker(&content, 100);
        assert!(result.len() <= 100, "got {} bytes", result.len());
        assert!(result.starts_with('a'), "the head of the content is kept");
        assert!(result.contains("[pruned:"));
        assert!(result.contains("bytes omitted]"));
    }

    #[test]
    fn prune_stats_bytes_saved_accurate() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"a".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"b".repeat(2000)),
            Message::assistant("done"),
        ];

        let config = SessionPruneConfig {
            keep_recent_n: 0,
            pruned_tool_result_max_bytes: 100,
            preserve_task: true,
        };
        let (pruned, stats) = prune_old_tool_results(&messages, &config);

        assert!(stats.did_prune());
        assert_eq!(stats.tool_results_pruned, 2);
        assert_eq!(stats.tool_results_total, 2);

        let pruned_c1_len = first_tool_result(&pruned[2]).len();
        let pruned_c2_len = first_tool_result(&pruned[4]).len();
        let expected_saved = (1000 - pruned_c1_len) + (2000 - pruned_c2_len);
        assert_eq!(stats.bytes_saved, expected_saved);
    }
}