Skip to main content

heartbit_core/agent/
pruner.rs

1//! Session context pruner — trims old messages before LLM calls to control context size.
2
3use crate::llm::types::{ContentBlock, Message, Role};
4use crate::tool::builtins::floor_char_boundary;
5
/// Configuration for session-level pruning of old tool results.
///
/// Before each LLM call, old tool results are truncated to reduce token
/// usage; the pruner returns a new message list rather than mutating its
/// input. Recent messages are preserved at full fidelity.
#[derive(Debug, Clone)]
pub struct SessionPruneConfig {
    /// Number of recent user/assistant message pairs to keep at full fidelity.
    /// The pruner treats this as the last `keep_recent_n * 2` messages.
    /// Default: 2.
    pub keep_recent_n: usize,
    /// Maximum bytes for a pruned tool result. Content exceeding this is
    /// replaced with its leading bytes followed by a
    /// `[pruned: N bytes omitted]` marker. Default: 200.
    pub pruned_tool_result_max_bytes: usize,
    /// Whether the first message (the task) is exempt from pruning.
    /// Default: true.
    pub preserve_task: bool,
}
22
23impl Default for SessionPruneConfig {
24    fn default() -> Self {
25        Self {
26            keep_recent_n: 2,
27            pruned_tool_result_max_bytes: 200,
28            preserve_task: true,
29        }
30    }
31}
32
/// Statistics from a session pruning pass.
///
/// All counters start at zero (`Default`) and are filled in by
/// `prune_old_tool_results`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PruneStats {
    /// Number of tool results that were truncated.
    pub tool_results_pruned: usize,
    /// Total bytes removed across all truncated tool results, measured as
    /// original length minus truncated length (so the marker text counts
    /// against the saving).
    pub bytes_saved: usize,
    /// Total number of tool results inspected (pruned + skipped). Results in
    /// messages the pruner leaves alone (task / recent window) are not
    /// inspected and therefore not counted.
    pub tool_results_total: usize,
}
43
44impl PruneStats {
45    /// Returns `true` if any pruning actually occurred.
46    pub fn did_prune(&self) -> bool {
47        self.tool_results_pruned > 0
48    }
49}
50
51/// Prune old tool results in a message list, returning a new list and stats.
52///
53/// Messages in the "recent" tail (last `keep_recent_n * 2` messages) are kept
54/// intact. Older messages containing tool results have their content truncated
55/// to `max_bytes` with a `[pruned: N bytes]` marker.
56///
57/// The first message (task) is always preserved if `preserve_task` is true.
58/// Message count and roles are never changed — only content is shortened.
59pub fn prune_old_tool_results(
60    messages: &[Message],
61    config: &SessionPruneConfig,
62) -> (Vec<Message>, PruneStats) {
63    if messages.is_empty() {
64        return (vec![], PruneStats::default());
65    }
66
67    let mut stats = PruneStats::default();
68
69    // Recent tail: keep the last N*2 messages (user+assistant pairs)
70    let recent_count = config.keep_recent_n * 2;
71    let recent_start = messages.len().saturating_sub(recent_count);
72
73    let pruned = messages
74        .iter()
75        .enumerate()
76        .map(|(i, msg)| {
77            // Preserve task message
78            if i == 0 && config.preserve_task {
79                return msg.clone();
80            }
81            // Preserve recent messages
82            if i >= recent_start {
83                return msg.clone();
84            }
85            // Only prune User messages with tool results
86            if msg.role != Role::User {
87                return msg.clone();
88            }
89            let has_tool_results = msg
90                .content
91                .iter()
92                .any(|b| matches!(b, ContentBlock::ToolResult { .. }));
93            if !has_tool_results {
94                return msg.clone();
95            }
96            // Prune tool result content
97            let pruned_content = msg
98                .content
99                .iter()
100                .map(|block| match block {
101                    ContentBlock::ToolResult {
102                        tool_use_id,
103                        content,
104                        is_error,
105                    } => {
106                        stats.tool_results_total += 1;
107                        let max = config.pruned_tool_result_max_bytes;
108                        let pruned = truncate_with_marker(content, max);
109                        if pruned.len() < content.len() {
110                            stats.tool_results_pruned += 1;
111                            stats.bytes_saved += content.len() - pruned.len();
112                        }
113                        ContentBlock::ToolResult {
114                            tool_use_id: tool_use_id.clone(),
115                            content: pruned,
116                            is_error: *is_error,
117                        }
118                    }
119                    other => other.clone(),
120                })
121                .collect();
122            Message {
123                role: msg.role.clone(),
124                content: pruned_content,
125            }
126        })
127        .collect();
128
129    (pruned, stats)
130}
131
132/// Truncate content to `max_bytes` with a `[pruned: N bytes]` marker.
133///
134/// If content fits within `max_bytes`, returns it unchanged.
135/// Otherwise, keeps head bytes up to a char boundary and appends marker.
136fn truncate_with_marker(content: &str, max_bytes: usize) -> String {
137    if content.len() <= max_bytes {
138        return content.to_string();
139    }
140    let omitted = content.len() - max_bytes;
141    let marker = format!("\n[pruned: {omitted} bytes omitted]");
142    let head_budget = max_bytes.saturating_sub(marker.len());
143    let boundary = floor_char_boundary(content, head_budget);
144    let head = &content[..boundary];
145    format!("{head}{marker}")
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::ToolResult;
    use serde_json::json;

    /// Assistant message holding a single tool call with an empty JSON input.
    fn tool_use_msg(id: &str, name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: vec![ContentBlock::ToolUse {
                id: id.into(),
                name: name.into(),
                input: json!({}),
            }],
        }
    }

    /// Message carrying one successful tool result for call `id`.
    fn tool_result_msg(id: &str, content: &str) -> Message {
        Message::tool_results(vec![ToolResult::success(id, content)])
    }

    /// Config with the given recent window and byte budget; the task message
    /// is preserved.
    fn prune_config(keep_recent_n: usize, max_bytes: usize) -> SessionPruneConfig {
        SessionPruneConfig {
            keep_recent_n,
            pruned_tool_result_max_bytes: max_bytes,
            preserve_task: true,
        }
    }

    /// Text of the first content block of `msg`. Panics when that block is not
    /// a tool result, so a wrong fixture fails the test instead of silently
    /// skipping its assertions.
    fn first_tool_result(msg: &Message) -> &str {
        match &msg.content[0] {
            ContentBlock::ToolResult { content, .. } => content.as_str(),
            _ => panic!("expected tool result"),
        }
    }

    #[test]
    fn prune_preserves_recent_messages() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"y".repeat(1000)),
            Message::assistant("final answer"),
        ];

        let cfg = prune_config(2, 50);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        assert_eq!(pruned.len(), messages.len(), "message count unchanged");

        // With keep_recent_n = 2 the recent window is the last 4 messages, so
        // it starts at index 2 (6 - 4). Both tool results (indices 2 and 4)
        // fall inside it and must be left intact.
        assert_eq!(
            first_tool_result(&pruned[2]).len(),
            1000,
            "tool result at the window boundary should be intact"
        );
        assert_eq!(
            first_tool_result(&pruned[4]).len(),
            1000,
            "recent tool result should be intact"
        );

        // Nothing outside the window carries a tool result, so nothing is
        // inspected or pruned.
        assert!(!stats.did_prune());
        assert_eq!(stats.tool_results_total, 0);
    }

    #[test]
    fn prune_trims_old_tool_results() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"a".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"b".repeat(500)),
            tool_use_msg("c3", "write"),
            tool_result_msg("c3", "short result"),
            Message::assistant("done"),
        ];

        let cfg = prune_config(1, 100);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        // messages[2] (1000 bytes) and messages[4] (500 bytes) are outside the
        // recent window and must be cut down to the configured budget.
        for index in [2, 4] {
            let content = first_tool_result(&pruned[index]);
            assert!(
                content.len() <= 100,
                "old tool result should be truncated, got {} bytes",
                content.len()
            );
            assert!(content.contains("[pruned:"));
        }

        // messages[6] is inside the recent window (last 2 messages).
        assert_eq!(first_tool_result(&pruned[6]), "short result");

        assert!(stats.did_prune());
        assert_eq!(stats.tool_results_pruned, 2);
        assert!(stats.bytes_saved > 0);
        assert_eq!(stats.tool_results_total, 2);
    }

    #[test]
    fn prune_preserves_task_message() {
        let messages = vec![
            Message::user("important initial task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            Message::assistant("answer"),
        ];

        let cfg = prune_config(0, 50);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        // Task message should be unchanged even though pruning happened.
        match &pruned[0].content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "important initial task"),
            _ => panic!("expected text block"),
        }
        assert!(stats.did_prune());
    }

    #[test]
    fn preserve_task_flag_controls_first_message() {
        // Put a tool result in the first slot so the flag is the only thing
        // that can protect it (keep_recent_n = 0 disables the recent window).
        let messages = vec![
            tool_result_msg("c0", &"x".repeat(1000)),
            Message::assistant("ack"),
        ];

        let keep = prune_config(0, 50);
        let (kept, kept_stats) = prune_old_tool_results(&messages, &keep);
        assert_eq!(first_tool_result(&kept[0]).len(), 1000);
        assert_eq!(kept_stats.tool_results_total, 0);

        let drop_task = SessionPruneConfig {
            preserve_task: false,
            ..prune_config(0, 50)
        };
        let (cut, cut_stats) = prune_old_tool_results(&messages, &drop_task);
        assert!(first_tool_result(&cut[0]).contains("[pruned:"));
        assert_eq!(cut_stats.tool_results_pruned, 1);
    }

    #[test]
    fn prune_preserves_message_count() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"x".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"y".repeat(1000)),
            Message::assistant("done"),
        ];

        let cfg = SessionPruneConfig::default();
        let (pruned, _stats) = prune_old_tool_results(&messages, &cfg);

        assert_eq!(pruned.len(), messages.len());
        // Verify roles are preserved
        for (before, after) in messages.iter().zip(pruned.iter()) {
            assert_eq!(before.role, after.role);
        }
    }

    #[test]
    fn prune_utf8_safe() {
        // Multi-byte UTF-8 content should not be split at invalid boundaries
        let emoji_content = "🦀".repeat(100); // 400 bytes, 100 chars
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &emoji_content),
            Message::assistant("done"),
        ];

        let cfg = prune_config(0, 50);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        // Slicing inside a code point would have panicked above; additionally
        // check that the kept head is made of whole crabs followed by the
        // marker, and that the budget is honoured.
        let content = first_tool_result(&pruned[2]);
        let (head, marker_tail) = content
            .split_once("\n[pruned:")
            .expect("truncated content should contain the marker");
        assert!(!head.is_empty());
        assert!(head.chars().all(|c| c == '🦀'));
        assert!(marker_tail.ends_with("bytes omitted]"));
        assert!(content.len() <= 50);
        assert!(stats.did_prune());
    }

    #[test]
    fn prune_empty_messages() {
        let (pruned, stats) = prune_old_tool_results(&[], &SessionPruneConfig::default());
        assert!(pruned.is_empty());
        assert!(!stats.did_prune());
    }

    #[test]
    fn prune_no_tool_results_is_noop() {
        let messages = vec![
            Message::user("task"),
            Message::assistant("response 1"),
            Message::user("follow up"),
            Message::assistant("response 2"),
        ];

        let cfg = prune_config(0, 10);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        // No tool results to prune, all messages should be unchanged
        assert_eq!(pruned.len(), messages.len());
        for (before, after) in messages.iter().zip(pruned.iter()) {
            assert_eq!(before.role, after.role);
            assert_eq!(before.content.len(), after.content.len());
        }
        assert!(!stats.did_prune());
        assert_eq!(stats.tool_results_total, 0);
    }

    #[test]
    fn prune_short_tool_results_unchanged() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", "short"),
            Message::assistant("done"),
        ];

        let cfg = prune_config(0, 200);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        assert_eq!(
            first_tool_result(&pruned[2]),
            "short",
            "short results should not be modified"
        );
        // Tool result was inspected but not truncated (under max_bytes)
        assert!(!stats.did_prune());
        assert_eq!(stats.tool_results_total, 1);
        assert_eq!(stats.tool_results_pruned, 0);
    }

    #[test]
    fn truncate_with_marker_short_content() {
        let result = truncate_with_marker("hello", 100);
        assert_eq!(result, "hello");
    }

    #[test]
    fn truncate_with_marker_long_content() {
        let content = "a".repeat(1000);
        let result = truncate_with_marker(&content, 100);
        // Head plus marker must fit the requested budget.
        assert!(result.len() <= 100, "got {} bytes", result.len());
        assert!(result.starts_with('a'));
        assert!(result.contains("[pruned:"));
        assert!(result.ends_with("bytes omitted]"));
    }

    #[test]
    fn prune_stats_bytes_saved_accurate() {
        let messages = vec![
            Message::user("task"),
            tool_use_msg("c1", "search"),
            tool_result_msg("c1", &"a".repeat(1000)),
            tool_use_msg("c2", "read"),
            tool_result_msg("c2", &"b".repeat(2000)),
            Message::assistant("done"),
        ];

        let cfg = prune_config(0, 100);
        let (pruned, stats) = prune_old_tool_results(&messages, &cfg);

        assert!(stats.did_prune());
        assert_eq!(stats.tool_results_pruned, 2);
        assert_eq!(stats.tool_results_total, 2);

        // bytes_saved = original bytes - pruned bytes for each truncated result
        let pruned_c1_len = first_tool_result(&pruned[2]).len();
        let pruned_c2_len = first_tool_result(&pruned[4]).len();
        let expected_saved = (1000 - pruned_c1_len) + (2000 - pruned_c2_len);
        assert_eq!(stats.bytes_saved, expected_saved);
    }
}