1use crate::context::types::{PruningConfig, PruningLevel};
28use crate::conversation::message::{Message, MessageContent};
29use glob::Pattern;
30use rmcp::model::{CallToolResult, Content, RawContent, RawTextContent, Role};
31
32pub struct ProgressivePruner;
37
38impl ProgressivePruner {
39 pub fn soft_trim(content: &str, head_chars: usize, tail_chars: usize) -> String {
64 let total_len = content.len();
65 let min_len = head_chars + tail_chars;
66
67 if total_len <= min_len {
69 return content.to_string();
70 }
71
72 let head = Self::safe_substring(content, 0, head_chars);
73 let tail = Self::safe_substring(content, total_len.saturating_sub(tail_chars), total_len);
74
75 let omitted = total_len - head.len() - tail.len();
76 format!("{}...[{} chars omitted]...{}", head, omitted, tail)
77 }
78
79 pub fn hard_clear(placeholder: &str) -> String {
89 placeholder.to_string()
90 }
91
92 pub fn prune_messages(
111 messages: &[Message],
112 usage_ratio: f64,
113 config: &PruningConfig,
114 ) -> Vec<Message> {
115 let pruning_level = config.get_pruning_level(usage_ratio);
116
117 if pruning_level == PruningLevel::None {
118 return messages.to_vec();
119 }
120
121 let protected_indices = Self::find_protected_indices(messages, config.keep_last_assistants);
123
124 messages
125 .iter()
126 .enumerate()
127 .map(|(idx, msg)| {
128 if protected_indices.contains(&idx) {
129 msg.clone()
131 } else {
132 Self::prune_message(msg, pruning_level, config)
133 }
134 })
135 .collect()
136 }
137
138 fn prune_message(
140 message: &Message,
141 pruning_level: PruningLevel,
142 config: &PruningConfig,
143 ) -> Message {
144 let pruned_content: Vec<MessageContent> = message
145 .content
146 .iter()
147 .map(|content| Self::prune_content(content, pruning_level, config))
148 .collect();
149
150 Message {
151 id: message.id.clone(),
152 role: message.role.clone(),
153 created: message.created,
154 content: pruned_content,
155 metadata: message.metadata,
156 }
157 }
158
159 fn prune_content(
161 content: &MessageContent,
162 pruning_level: PruningLevel,
163 config: &PruningConfig,
164 ) -> MessageContent {
165 match content {
166 MessageContent::ToolResponse(tool_response) => {
167 let tool_name = Self::extract_tool_name_from_response(tool_response);
169 if !Self::is_tool_prunable(&tool_name, config) {
170 return content.clone();
171 }
172
173 Self::prune_tool_response(tool_response, pruning_level, config)
174 }
175 other => other.clone(),
177 }
178 }
179
180 fn prune_tool_response(
182 tool_response: &crate::conversation::message::ToolResponse,
183 pruning_level: PruningLevel,
184 config: &PruningConfig,
185 ) -> MessageContent {
186 match &tool_response.tool_result {
187 Ok(result) => {
188 let pruned_content: Vec<Content> = result
189 .content
190 .iter()
191 .map(|c| {
192 if let RawContent::Text(text) = &c.raw {
193 let pruned_text = match pruning_level {
194 PruningLevel::SoftTrim => Self::soft_trim(
195 &text.text,
196 config.soft_trim_head_chars,
197 config.soft_trim_tail_chars,
198 ),
199 PruningLevel::HardClear => {
200 Self::hard_clear(&config.hard_clear_placeholder)
201 }
202 PruningLevel::None => text.text.clone(),
203 };
204 Content {
205 raw: RawContent::Text(RawTextContent {
206 text: pruned_text,
207 meta: text.meta.clone(),
208 }),
209 annotations: c.annotations.clone(),
210 }
211 } else {
212 c.clone()
213 }
214 })
215 .collect();
216
217 MessageContent::ToolResponse(crate::conversation::message::ToolResponse {
218 id: tool_response.id.clone(),
219 tool_result: Ok(CallToolResult {
220 content: pruned_content,
221 is_error: result.is_error,
222 meta: result.meta.clone(),
223 structured_content: result.structured_content.clone(),
224 }),
225 metadata: tool_response.metadata.clone(),
226 })
227 }
228 Err(e) => MessageContent::ToolResponse(crate::conversation::message::ToolResponse {
229 id: tool_response.id.clone(),
230 tool_result: Err(e.clone()),
231 metadata: tool_response.metadata.clone(),
232 }),
233 }
234 }
235
236 pub fn is_tool_prunable(tool_name: &str, config: &PruningConfig) -> bool {
251 for denied in &config.denied_tools {
253 if Self::matches_pattern(tool_name, denied) {
254 return false;
255 }
256 }
257
258 if config.allowed_tools.is_empty() {
260 return true;
261 }
262
263 for allowed in &config.allowed_tools {
265 if Self::matches_pattern(tool_name, allowed) {
266 return true;
267 }
268 }
269
270 false
271 }
272
273 fn matches_pattern(tool_name: &str, pattern: &str) -> bool {
275 if let Ok(glob_pattern) = Pattern::new(pattern) {
277 return glob_pattern.matches(tool_name);
278 }
279
280 tool_name == pattern
282 }
283
284 fn extract_tool_name_from_response(
286 tool_response: &crate::conversation::message::ToolResponse,
287 ) -> String {
288 tool_response
291 .metadata
292 .as_ref()
293 .and_then(|m| m.get("tool_name"))
294 .and_then(|v| v.as_str())
295 .map(|s| s.to_string())
296 .unwrap_or_else(|| tool_response.id.clone())
297 }
298
299 fn find_protected_indices(messages: &[Message], keep_last: usize) -> Vec<usize> {
307 let mut protected = Vec::new();
308 let mut assistant_count = 0;
309
310 for (idx, msg) in messages.iter().enumerate().rev() {
312 if msg.role == Role::Assistant && assistant_count < keep_last {
313 protected.push(idx);
314 assistant_count += 1;
315 }
316 }
317
318 protected
319 }
320
321 fn safe_substring(s: &str, start: usize, end: usize) -> &str {
323 if s.is_empty() || start >= s.len() {
324 return "";
325 }
326
327 let valid_start = s
329 .char_indices()
330 .map(|(i, _)| i)
331 .find(|&i| i >= start)
332 .unwrap_or(s.len());
333
334 let valid_end = if end >= s.len() {
336 s.len()
337 } else {
338 s.char_indices()
339 .map(|(i, _)| i)
340 .take_while(|&i| i <= end)
341 .last()
342 .unwrap_or(0)
343 };
344
345 if valid_start >= valid_end {
346 return "";
347 }
348
349 s.get(valid_start..valid_end).unwrap_or("")
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_soft_trim_short_content() {
        let short = "Short content";
        assert_eq!(ProgressivePruner::soft_trim(short, 500, 300), short);
    }

    #[test]
    fn test_soft_trim_long_content() {
        let original = "A".repeat(2000);
        let trimmed = ProgressivePruner::soft_trim(&original, 500, 300);

        let expected_head = "A".repeat(500);
        let expected_tail = "A".repeat(300);
        assert!(trimmed.starts_with(expected_head.as_str()));
        assert!(trimmed.contains("chars omitted"));
        assert!(trimmed.ends_with(expected_tail.as_str()));
        assert!(trimmed.len() < original.len());
    }

    #[test]
    fn test_soft_trim_preserves_head_tail() {
        let head = "HEAD".repeat(100);
        let tail = "TAIL".repeat(100);
        let original = format!("{}MIDDLE{}", head, tail);

        let trimmed = ProgressivePruner::soft_trim(&original, 400, 400);
        assert!(trimmed.starts_with("HEAD"));
        assert!(trimmed.ends_with("TAIL"));
        assert!(trimmed.contains("chars omitted"));
    }

    #[test]
    fn test_hard_clear() {
        let placeholder = "[content cleared]";
        assert_eq!(ProgressivePruner::hard_clear(placeholder), placeholder);
    }

    #[test]
    fn test_is_tool_prunable_empty_lists() {
        let config = PruningConfig::default();
        for &tool in &["read_file", "write"] {
            assert!(ProgressivePruner::is_tool_prunable(tool, &config));
        }
    }

    #[test]
    fn test_is_tool_prunable_denied_takes_precedence() {
        let allow_all = vec!["*".to_string()];
        let deny_write = vec!["write".to_string()];
        let config = PruningConfig::default()
            .with_allowed_tools(allow_all)
            .with_denied_tools(deny_write);

        assert!(ProgressivePruner::is_tool_prunable("read_file", &config));
        assert!(!ProgressivePruner::is_tool_prunable("write", &config));
    }

    #[test]
    fn test_is_tool_prunable_glob_patterns() {
        let config = PruningConfig::default()
            .with_allowed_tools(vec!["read_*".to_string(), "grep".to_string()]);

        let cases = [
            ("read_file", true),
            ("read_dir", true),
            ("grep", true),
            ("write", false),
        ];
        for &(tool, expected) in &cases {
            assert_eq!(ProgressivePruner::is_tool_prunable(tool, &config), expected);
        }
    }

    #[test]
    fn test_is_tool_prunable_denied_glob() {
        let config = PruningConfig::default().with_denied_tools(vec!["write_*".to_string()]);

        let cases = [("read_file", true), ("write_file", false), ("write_dir", false)];
        for &(tool, expected) in &cases {
            assert_eq!(ProgressivePruner::is_tool_prunable(tool, &config), expected);
        }
    }

    #[test]
    fn test_safe_substring_ascii() {
        let text = "Hello, World!";
        let cases = [((0, 5), "Hello"), ((7, 12), "World")];
        for &((start, end), expected) in &cases {
            assert_eq!(ProgressivePruner::safe_substring(text, start, end), expected);
        }
    }

    #[test]
    fn test_safe_substring_unicode() {
        let text = "Hello, 世界!";
        assert_eq!(ProgressivePruner::safe_substring(text, 0, 7), "Hello, ");

        // Bytes 7..13 cover the two 3-byte CJK characters.
        let cjk = ProgressivePruner::safe_substring(text, 7, 13);
        assert!(cjk.contains("世"));
    }

    #[test]
    fn test_safe_substring_empty() {
        assert_eq!(ProgressivePruner::safe_substring("", 0, 10), "");
        assert_eq!(ProgressivePruner::safe_substring("hello", 10, 20), "");
    }

    #[test]
    fn test_find_protected_indices() {
        let messages = vec![
            Message::user().with_text("user 1"),
            Message::assistant().with_text("assistant 1"),
            Message::user().with_text("user 2"),
            Message::assistant().with_text("assistant 2"),
            Message::user().with_text("user 3"),
            Message::assistant().with_text("assistant 3"),
        ];

        let protected = ProgressivePruner::find_protected_indices(&messages, 2);

        // The two most recent assistant messages sit at indices 3 and 5;
        // the oldest one (index 1) falls outside the protected window.
        assert!(protected.contains(&5));
        assert!(protected.contains(&3));
        assert!(!protected.contains(&1));
    }

    #[test]
    fn test_prune_messages_no_pruning() {
        let messages = vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there"),
        ];

        let result = ProgressivePruner::prune_messages(&messages, 0.2, &PruningConfig::default());
        assert_eq!(result.len(), messages.len());
    }
}