Skip to main content

construct/agent/
history_pruner.rs

1use crate::providers::traits::ChatMessage;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4
// ---------------------------------------------------------------------------
// Config
// ---------------------------------------------------------------------------
8
/// Serde default for [`HistoryPrunerConfig::max_tokens`].
fn default_max_tokens() -> usize {
    8_192
}
12
/// Serde default for [`HistoryPrunerConfig::keep_recent`].
fn default_keep_recent() -> usize {
    4_usize
}
16
/// Serde default for [`HistoryPrunerConfig::collapse_tool_results`]: collapsing is on.
fn default_collapse() -> bool {
    true
}
20
21#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
22pub struct HistoryPrunerConfig {
23    /// Enable history pruning. Default: false.
24    #[serde(default)]
25    pub enabled: bool,
26    /// Maximum estimated tokens for message history. Default: 8192.
27    #[serde(default = "default_max_tokens")]
28    pub max_tokens: usize,
29    /// Keep the N most recent messages untouched. Default: 4.
30    #[serde(default = "default_keep_recent")]
31    pub keep_recent: usize,
32    /// Collapse old tool call/result pairs into short summaries. Default: true.
33    #[serde(default = "default_collapse")]
34    pub collapse_tool_results: bool,
35}
36
37impl Default for HistoryPrunerConfig {
38    fn default() -> Self {
39        Self {
40            enabled: false,
41            max_tokens: 8192,
42            keep_recent: 4,
43            collapse_tool_results: true,
44        }
45    }
46}
47
// ---------------------------------------------------------------------------
// Stats
// ---------------------------------------------------------------------------
51
/// Summary of what a single [`prune_history`] call changed.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct PruneStats {
    /// Number of messages in the history before pruning.
    pub messages_before: usize,
    /// Number of messages left after pruning.
    pub messages_after: usize,
    /// Number of `tool` messages removed by collapsing them into an `assistant` summary.
    pub collapsed_pairs: usize,
    /// Number of messages removed while enforcing the token budget.
    pub dropped_messages: usize,
}
59
// ---------------------------------------------------------------------------
// Token estimation
// ---------------------------------------------------------------------------
63
64fn estimate_tokens(messages: &[ChatMessage]) -> usize {
65    messages.iter().map(|m| m.content.len() / 4).sum()
66}
67
// ---------------------------------------------------------------------------
// Protected-index helpers
// ---------------------------------------------------------------------------
71
72fn protected_indices(messages: &[ChatMessage], keep_recent: usize) -> Vec<bool> {
73    let len = messages.len();
74    let mut protected = vec![false; len];
75    for (i, msg) in messages.iter().enumerate() {
76        if msg.role == "system" {
77            protected[i] = true;
78        }
79    }
80    let recent_start = len.saturating_sub(keep_recent);
81    for p in protected.iter_mut().skip(recent_start) {
82        *p = true;
83    }
84    protected
85}
86
// ---------------------------------------------------------------------------
// Public entry point
// ---------------------------------------------------------------------------
90
91pub fn prune_history(messages: &mut Vec<ChatMessage>, config: &HistoryPrunerConfig) -> PruneStats {
92    let messages_before = messages.len();
93    if !config.enabled || messages.is_empty() {
94        return PruneStats {
95            messages_before,
96            messages_after: messages_before,
97            collapsed_pairs: 0,
98            dropped_messages: 0,
99        };
100    }
101
102    let mut collapsed_pairs: usize = 0;
103
104    // Phase 1 – collapse assistant+tool pairs
105    if config.collapse_tool_results {
106        let mut i = 0;
107        while i + 1 < messages.len() {
108            let protected = protected_indices(messages, config.keep_recent);
109            if messages[i].role == "assistant"
110                && messages[i + 1].role == "tool"
111                && !protected[i]
112                && !protected[i + 1]
113            {
114                let tool_content = &messages[i + 1].content;
115                let truncated: String = tool_content.chars().take(100).collect();
116                let summary = format!("[Tool result: {truncated}...]");
117                messages[i] = ChatMessage {
118                    role: "assistant".to_string(),
119                    content: summary,
120                };
121                messages.remove(i + 1);
122                collapsed_pairs += 1;
123            } else {
124                i += 1;
125            }
126        }
127    }
128
129    // Phase 2 – budget enforcement
130    let mut dropped_messages: usize = 0;
131    while estimate_tokens(messages) > config.max_tokens {
132        let protected = protected_indices(messages, config.keep_recent);
133        if let Some(idx) = protected
134            .iter()
135            .enumerate()
136            .find(|&(_, &p)| !p)
137            .map(|(i, _)| i)
138        {
139            messages.remove(idx);
140            dropped_messages += 1;
141        } else {
142            break;
143        }
144    }
145
146    PruneStats {
147        messages_before,
148        messages_after: messages.len(),
149        collapsed_pairs,
150        dropped_messages,
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    fn msg(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    /// Builds an enabled config with the remaining knobs given explicitly.
    fn enabled_config(
        max_tokens: usize,
        keep_recent: usize,
        collapse_tool_results: bool,
    ) -> HistoryPrunerConfig {
        HistoryPrunerConfig {
            enabled: true,
            max_tokens,
            keep_recent,
            collapse_tool_results,
        }
    }

    #[test]
    fn prune_disabled_is_noop() {
        let mut messages = vec![
            msg("system", "You are helpful."),
            msg("user", "Hello"),
            msg("assistant", "Hi there!"),
        ];
        let config = HistoryPrunerConfig {
            enabled: false,
            ..Default::default()
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "You are helpful.");
        assert_eq!(stats.messages_before, 3);
        assert_eq!(stats.messages_after, 3);
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(stats.dropped_messages, 0);
    }

    #[test]
    fn prune_under_budget_no_change() {
        let mut messages = vec![
            msg("system", "You are helpful."),
            msg("user", "Hello"),
            msg("assistant", "Hi!"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(8192, 2, false));
        assert_eq!(messages.len(), 3);
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(stats.dropped_messages, 0);
    }

    #[test]
    fn prune_collapses_tool_pairs() {
        let tool_result = "a".repeat(160);
        let mut messages = vec![
            msg("system", "sys"),
            msg("assistant", "calling tool X"),
            msg("tool", &tool_result),
            msg("user", "thanks"),
            msg("assistant", "done"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(100_000, 2, true));
        assert_eq!(stats.collapsed_pairs, 1);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[1].role, "assistant");
        assert!(messages[1].content.starts_with("[Tool result: "));
        // Every removed message is accounted for in the stats.
        assert_eq!(
            stats.messages_after,
            stats.messages_before - stats.collapsed_pairs - stats.dropped_messages
        );
    }

    #[test]
    fn prune_does_not_collapse_protected_pair() {
        // With keep_recent = 2 the assistant/tool pair sits inside the recent window.
        let mut messages = vec![
            msg("system", "sys"),
            msg("user", "run it"),
            msg("assistant", "calling tool"),
            msg("tool", "result"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(100_000, 2, true));
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(stats.dropped_messages, 0);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[2].content, "calling tool");
        assert_eq!(messages[3].role, "tool");
    }

    #[test]
    fn prune_preserves_system_and_recent() {
        let big = "x".repeat(40_000);
        let mut messages = vec![
            msg("system", "system prompt"),
            msg("user", &big),
            msg("assistant", "old reply"),
            msg("user", "recent1"),
            msg("assistant", "recent2"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(100, 2, false));
        assert!(messages.iter().any(|m| m.role == "system"));
        assert!(messages.iter().any(|m| m.content == "recent1"));
        assert!(messages.iter().any(|m| m.content == "recent2"));
        assert!(stats.dropped_messages > 0);
    }

    #[test]
    fn prune_drops_oldest_when_over_budget() {
        let filler = "y".repeat(400);
        let mut messages = vec![
            msg("system", "sys"),
            msg("user", &filler),
            msg("assistant", &filler),
            msg("user", "recent-user"),
            msg("assistant", "recent-assistant"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(150, 2, false));
        assert!(stats.dropped_messages >= 1);
        assert_eq!(messages[0].role, "system");
        // The oldest unprotected message is the first one to go.
        assert!(!messages.iter().any(|m| m.role == "user" && m.content == filler));
        assert!(messages.iter().any(|m| m.content == "recent-user"));
        assert!(messages.iter().any(|m| m.content == "recent-assistant"));
    }

    #[test]
    fn prune_stops_when_only_protected_remain() {
        let big = "z".repeat(4_000);
        let mut messages = vec![msg("system", &big), msg("user", &big)];
        let stats = prune_history(&mut messages, &enabled_config(10, 1, true));
        assert_eq!(stats.dropped_messages, 0);
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn prune_keep_recent_larger_than_history_protects_everything() {
        let big = "z".repeat(4_000);
        let mut messages = vec![msg("user", &big), msg("assistant", &big)];
        let stats = prune_history(&mut messages, &enabled_config(1, 10, true));
        assert_eq!(stats.dropped_messages, 0);
        assert_eq!(stats.messages_after, 2);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn prune_empty_messages() {
        let mut messages: Vec<ChatMessage> = vec![];
        let config = HistoryPrunerConfig {
            enabled: true,
            ..Default::default()
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(stats.messages_before, 0);
        assert_eq!(stats.messages_after, 0);
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(stats.dropped_messages, 0);
    }
}