1use crate::providers::traits::ChatMessage;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4
/// Serde default for `HistoryPrunerConfig::max_tokens`: an 8K estimated-token budget.
fn default_max_tokens() -> usize {
    const BUDGET: usize = 8 * 1024;
    BUDGET
}

/// Serde default for `HistoryPrunerConfig::keep_recent`: the last four messages are protected.
fn default_keep_recent() -> usize {
    const RECENT: usize = 4;
    RECENT
}

/// Serde default for `HistoryPrunerConfig::collapse_tool_results`: collapsing is on.
fn default_collapse() -> bool {
    const COLLAPSE: bool = true;
    COLLAPSE
}
20
/// Configuration for [`prune_history`], which trims a chat history so that its
/// estimated token count fits a budget.
///
/// Every field is optional when deserializing; missing fields take the same
/// values as [`Default`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HistoryPrunerConfig {
    /// Whether pruning runs at all. When `false`, `prune_history` leaves the
    /// history untouched. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Estimated token budget for the whole history (roughly one token per four
    /// bytes of message content). Oldest unprotected messages are dropped while
    /// the estimate exceeds this value. Defaults to 8192.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Number of most recent messages that are never collapsed or dropped
    /// (messages with role `"system"` are always protected as well). Defaults to 4.
    #[serde(default = "default_keep_recent")]
    pub keep_recent: usize,
    /// Whether an unprotected assistant message followed by an unprotected tool
    /// message is collapsed into a short `[Tool result: ...]` summary before any
    /// messages are dropped. Defaults to `true`.
    #[serde(default = "default_collapse")]
    pub collapse_tool_results: bool,
}
36
37impl Default for HistoryPrunerConfig {
38 fn default() -> Self {
39 Self {
40 enabled: false,
41 max_tokens: 8192,
42 keep_recent: 4,
43 collapse_tool_results: true,
44 }
45 }
46}
47
/// Summary of what a single [`prune_history`] call did to the history.
///
/// Derives `Default` (all counters zero) so callers can accumulate totals across
/// calls, and `Hash` so stats can be used as map/set keys in tests or metrics.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct PruneStats {
    /// Number of messages before pruning.
    pub messages_before: usize,
    /// Number of messages after pruning.
    pub messages_after: usize,
    /// Number of tool-result messages collapsed into the assistant message before them.
    pub collapsed_pairs: usize,
    /// Number of messages removed because the estimated token count exceeded the budget.
    pub dropped_messages: usize,
}
59
60fn estimate_tokens(messages: &[ChatMessage]) -> usize {
65 messages.iter().map(|m| m.content.len() / 4).sum()
66}
67
68fn protected_indices(messages: &[ChatMessage], keep_recent: usize) -> Vec<bool> {
73 let len = messages.len();
74 let mut protected = vec![false; len];
75 for (i, msg) in messages.iter().enumerate() {
76 if msg.role == "system" {
77 protected[i] = true;
78 }
79 }
80 let recent_start = len.saturating_sub(keep_recent);
81 for p in protected.iter_mut().skip(recent_start) {
82 *p = true;
83 }
84 protected
85}
86
87pub fn prune_history(messages: &mut Vec<ChatMessage>, config: &HistoryPrunerConfig) -> PruneStats {
92 let messages_before = messages.len();
93 if !config.enabled || messages.is_empty() {
94 return PruneStats {
95 messages_before,
96 messages_after: messages_before,
97 collapsed_pairs: 0,
98 dropped_messages: 0,
99 };
100 }
101
102 let mut collapsed_pairs: usize = 0;
103
104 if config.collapse_tool_results {
106 let mut i = 0;
107 while i + 1 < messages.len() {
108 let protected = protected_indices(messages, config.keep_recent);
109 if messages[i].role == "assistant"
110 && messages[i + 1].role == "tool"
111 && !protected[i]
112 && !protected[i + 1]
113 {
114 let tool_content = &messages[i + 1].content;
115 let truncated: String = tool_content.chars().take(100).collect();
116 let summary = format!("[Tool result: {truncated}...]");
117 messages[i] = ChatMessage {
118 role: "assistant".to_string(),
119 content: summary,
120 };
121 messages.remove(i + 1);
122 collapsed_pairs += 1;
123 } else {
124 i += 1;
125 }
126 }
127 }
128
129 let mut dropped_messages: usize = 0;
131 while estimate_tokens(messages) > config.max_tokens {
132 let protected = protected_indices(messages, config.keep_recent);
133 if let Some(idx) = protected
134 .iter()
135 .enumerate()
136 .find(|&(_, &p)| !p)
137 .map(|(i, _)| i)
138 {
139 messages.remove(idx);
140 dropped_messages += 1;
141 } else {
142 break;
143 }
144 }
145
146 PruneStats {
147 messages_before,
148 messages_after: messages.len(),
149 collapsed_pairs,
150 dropped_messages,
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    fn msg(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    /// Builds an enabled config with the given budget, tail size and collapse flag.
    fn enabled_config(
        max_tokens: usize,
        keep_recent: usize,
        collapse_tool_results: bool,
    ) -> HistoryPrunerConfig {
        HistoryPrunerConfig {
            enabled: true,
            max_tokens,
            keep_recent,
            collapse_tool_results,
        }
    }

    #[test]
    fn prune_disabled_is_noop() {
        let mut messages = vec![
            msg("system", "You are helpful."),
            msg("user", "Hello"),
            msg("assistant", "Hi there!"),
        ];
        let config = HistoryPrunerConfig {
            enabled: false,
            ..Default::default()
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "You are helpful.");
        assert_eq!(
            stats,
            PruneStats {
                messages_before: 3,
                messages_after: 3,
                collapsed_pairs: 0,
                dropped_messages: 0,
            }
        );
    }

    #[test]
    fn prune_under_budget_no_change() {
        let mut messages = vec![
            msg("system", "You are helpful."),
            msg("user", "Hello"),
            msg("assistant", "Hi!"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(8192, 2, false));
        assert_eq!(messages.len(), 3);
        assert_eq!(
            stats,
            PruneStats {
                messages_before: 3,
                messages_after: 3,
                collapsed_pairs: 0,
                dropped_messages: 0,
            }
        );
    }

    #[test]
    fn prune_collapses_tool_pairs() {
        let tool_result = "a".repeat(160);
        let mut messages = vec![
            msg("system", "sys"),
            msg("assistant", "calling tool X"),
            msg("tool", &tool_result),
            msg("user", "thanks"),
            msg("assistant", "done"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(100_000, 2, true));
        assert_eq!(
            stats,
            PruneStats {
                messages_before: 5,
                messages_after: 4,
                collapsed_pairs: 1,
                dropped_messages: 0,
            }
        );
        assert_eq!(messages[0].content, "sys");
        assert_eq!(messages[1].role, "assistant");
        // A 160-character result is cut to its first 100 characters.
        assert_eq!(
            messages[1].content,
            format!("[Tool result: {}...]", "a".repeat(100))
        );
        assert_eq!(messages[2].content, "thanks");
        assert_eq!(messages[3].content, "done");
    }

    #[test]
    fn prune_never_collapses_into_protected_tail() {
        let mut messages = vec![
            msg("system", "sys"),
            msg("user", "question"),
            msg("assistant", "calling tool"),
            msg("tool", "result"),
        ];
        // keep_recent = 1 protects the tool message, so its pair must survive intact.
        let stats = prune_history(&mut messages, &enabled_config(100_000, 1, true));
        assert_eq!(stats.collapsed_pairs, 0);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[2].content, "calling tool");
        assert_eq!(messages[3].role, "tool");
    }

    #[test]
    fn prune_preserves_system_and_recent() {
        let big = "x".repeat(40_000);
        let mut messages = vec![
            msg("system", "system prompt"),
            msg("user", &big),
            msg("assistant", "old reply"),
            msg("user", "recent1"),
            msg("assistant", "recent2"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(100, 2, false));
        // Dropping the single large message is enough to get under budget.
        assert_eq!(stats.dropped_messages, 1);
        assert_eq!(stats.messages_after, 4);
        assert!(!messages.iter().any(|m| m.content == big));
        assert!(messages.iter().any(|m| m.role == "system"));
        assert!(messages.iter().any(|m| m.content == "old reply"));
        assert!(messages.iter().any(|m| m.content == "recent1"));
        assert!(messages.iter().any(|m| m.content == "recent2"));
    }

    #[test]
    fn prune_drops_oldest_when_over_budget() {
        let filler = "y".repeat(400);
        let mut messages = vec![
            msg("system", "sys"),
            msg("user", &filler),
            msg("assistant", &filler),
            msg("user", "recent-user"),
            msg("assistant", "recent-assistant"),
        ];
        let stats = prune_history(&mut messages, &enabled_config(150, 2, false));
        assert_eq!(stats.dropped_messages, 1);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, "system");
        // The oldest unprotected message (the user filler) goes first.
        assert_eq!(messages[1].role, "assistant");
        assert!(messages.iter().any(|m| m.content == "recent-user"));
        assert!(messages.iter().any(|m| m.content == "recent-assistant"));
    }

    #[test]
    fn prune_stops_when_only_protected_messages_remain() {
        let mut messages = vec![
            msg("system", &"x".repeat(4000)),
            msg("user", &"y".repeat(4000)),
        ];
        let stats = prune_history(&mut messages, &enabled_config(10, 1, true));
        assert_eq!(stats.dropped_messages, 0);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn prune_empty_messages() {
        let mut messages: Vec<ChatMessage> = vec![];
        let config = HistoryPrunerConfig {
            enabled: true,
            ..Default::default()
        };
        let stats = prune_history(&mut messages, &config);
        assert_eq!(
            stats,
            PruneStats {
                messages_before: 0,
                messages_after: 0,
                collapsed_pairs: 0,
                dropped_messages: 0,
            }
        );
    }
}