// aster/context/pruner.rs

//! Progressive Pruner Module
//!
//! This module provides progressive pruning functionality for Tool outputs
//! to manage context size while preserving important information.
//!
//! # Pruning Strategies
//!
//! - **Soft Trim**: Preserves the head and tail of the content, replacing the middle with an
//!   `...[N chars omitted]...` marker
//! - **Hard Clear**: Completely replaces the content with a placeholder
//!
//! # Example
//!
//! ```rust,ignore
//! use aster::context::pruner::ProgressivePruner;
//! use aster::context::PruningConfig;
//!
//! let config = PruningConfig::default();
//! let content = "Very long tool output...";
//!
//! // Soft trim: keep head and tail
//! let trimmed = ProgressivePruner::soft_trim(content, 500, 300);
//!
//! // Hard clear: replace with placeholder
//! let cleared = ProgressivePruner::hard_clear("[content cleared]");
//! ```
26
27use crate::context::types::{PruningConfig, PruningLevel};
28use crate::conversation::message::{Message, MessageContent};
29use glob::Pattern;
30use rmcp::model::{CallToolResult, Content, RawContent, RawTextContent, Role};
31
32/// Progressive pruner for Tool output management.
33///
34/// Provides methods for soft trimming and hard clearing of content
35/// based on context usage thresholds.
36pub struct ProgressivePruner;
37
38impl ProgressivePruner {
39    // ========================================================================
40    // Core Pruning Operations
41    // ========================================================================
42
43    /// Soft trim content by preserving head and tail, replacing middle with "...".
44    ///
45    /// # Arguments
46    ///
47    /// * `content` - The content to trim
48    /// * `head_chars` - Number of characters to preserve from the head
49    /// * `tail_chars` - Number of characters to preserve from the tail
50    ///
51    /// # Returns
52    ///
53    /// The trimmed content string. If content is shorter than head_chars + tail_chars,
54    /// returns the original content unchanged.
55    ///
56    /// # Example
57    ///
58    /// ```rust,ignore
59    /// let content = "A".repeat(2000);
60    /// let trimmed = ProgressivePruner::soft_trim(&content, 500, 300);
61    /// // Result: first 500 chars + "..." + last 300 chars
62    /// ```
63    pub fn soft_trim(content: &str, head_chars: usize, tail_chars: usize) -> String {
64        let total_len = content.len();
65        let min_len = head_chars + tail_chars;
66
67        // If content is short enough, return unchanged
68        if total_len <= min_len {
69            return content.to_string();
70        }
71
72        let head = Self::safe_substring(content, 0, head_chars);
73        let tail = Self::safe_substring(content, total_len.saturating_sub(tail_chars), total_len);
74
75        let omitted = total_len - head.len() - tail.len();
76        format!("{}...[{} chars omitted]...{}", head, omitted, tail)
77    }
78
79    /// Hard clear content by replacing it entirely with a placeholder.
80    ///
81    /// # Arguments
82    ///
83    /// * `placeholder` - The placeholder text to use
84    ///
85    /// # Returns
86    ///
87    /// The placeholder string.
88    pub fn hard_clear(placeholder: &str) -> String {
89        placeholder.to_string()
90    }
91
92    // ========================================================================
93    // Message Pruning
94    // ========================================================================
95
96    /// Prune messages based on context usage ratio.
97    ///
98    /// This function applies progressive pruning to Tool outputs in messages
99    /// based on the current context usage ratio and configuration.
100    ///
101    /// # Arguments
102    ///
103    /// * `messages` - The messages to prune
104    /// * `usage_ratio` - Current context usage ratio (0.0-1.0)
105    /// * `config` - Pruning configuration
106    ///
107    /// # Returns
108    ///
109    /// A new vector of messages with pruned Tool outputs.
110    pub fn prune_messages(
111        messages: &[Message],
112        usage_ratio: f64,
113        config: &PruningConfig,
114    ) -> Vec<Message> {
115        let pruning_level = config.get_pruning_level(usage_ratio);
116
117        if pruning_level == PruningLevel::None {
118            return messages.to_vec();
119        }
120
121        // Find indices of assistant messages to protect
122        let protected_indices = Self::find_protected_indices(messages, config.keep_last_assistants);
123
124        messages
125            .iter()
126            .enumerate()
127            .map(|(idx, msg)| {
128                if protected_indices.contains(&idx) {
129                    // Protected message, don't prune
130                    msg.clone()
131                } else {
132                    Self::prune_message(msg, pruning_level, config)
133                }
134            })
135            .collect()
136    }
137
138    /// Prune a single message's Tool responses.
139    fn prune_message(
140        message: &Message,
141        pruning_level: PruningLevel,
142        config: &PruningConfig,
143    ) -> Message {
144        let pruned_content: Vec<MessageContent> = message
145            .content
146            .iter()
147            .map(|content| Self::prune_content(content, pruning_level, config))
148            .collect();
149
150        Message {
151            id: message.id.clone(),
152            role: message.role.clone(),
153            created: message.created,
154            content: pruned_content,
155            metadata: message.metadata,
156        }
157    }
158
159    /// Prune a single content block.
160    fn prune_content(
161        content: &MessageContent,
162        pruning_level: PruningLevel,
163        config: &PruningConfig,
164    ) -> MessageContent {
165        match content {
166            MessageContent::ToolResponse(tool_response) => {
167                // Check if this tool should be pruned
168                let tool_name = Self::extract_tool_name_from_response(tool_response);
169                if !Self::is_tool_prunable(&tool_name, config) {
170                    return content.clone();
171                }
172
173                Self::prune_tool_response(tool_response, pruning_level, config)
174            }
175            // Other content types pass through unchanged
176            other => other.clone(),
177        }
178    }
179
180    /// Prune a tool response based on pruning level.
181    fn prune_tool_response(
182        tool_response: &crate::conversation::message::ToolResponse,
183        pruning_level: PruningLevel,
184        config: &PruningConfig,
185    ) -> MessageContent {
186        match &tool_response.tool_result {
187            Ok(result) => {
188                let pruned_content: Vec<Content> = result
189                    .content
190                    .iter()
191                    .map(|c| {
192                        if let RawContent::Text(text) = &c.raw {
193                            let pruned_text = match pruning_level {
194                                PruningLevel::SoftTrim => Self::soft_trim(
195                                    &text.text,
196                                    config.soft_trim_head_chars,
197                                    config.soft_trim_tail_chars,
198                                ),
199                                PruningLevel::HardClear => {
200                                    Self::hard_clear(&config.hard_clear_placeholder)
201                                }
202                                PruningLevel::None => text.text.clone(),
203                            };
204                            Content {
205                                raw: RawContent::Text(RawTextContent {
206                                    text: pruned_text,
207                                    meta: text.meta.clone(),
208                                }),
209                                annotations: c.annotations.clone(),
210                            }
211                        } else {
212                            c.clone()
213                        }
214                    })
215                    .collect();
216
217                MessageContent::ToolResponse(crate::conversation::message::ToolResponse {
218                    id: tool_response.id.clone(),
219                    tool_result: Ok(CallToolResult {
220                        content: pruned_content,
221                        is_error: result.is_error,
222                        meta: result.meta.clone(),
223                        structured_content: result.structured_content.clone(),
224                    }),
225                    metadata: tool_response.metadata.clone(),
226                })
227            }
228            Err(e) => MessageContent::ToolResponse(crate::conversation::message::ToolResponse {
229                id: tool_response.id.clone(),
230                tool_result: Err(e.clone()),
231                metadata: tool_response.metadata.clone(),
232            }),
233        }
234    }
235
236    // ========================================================================
237    // Tool Filtering
238    // ========================================================================
239
240    /// Check if a tool is allowed to be pruned based on configuration.
241    ///
242    /// # Arguments
243    ///
244    /// * `tool_name` - The name of the tool
245    /// * `config` - Pruning configuration
246    ///
247    /// # Returns
248    ///
249    /// `true` if the tool can be pruned, `false` otherwise.
250    pub fn is_tool_prunable(tool_name: &str, config: &PruningConfig) -> bool {
251        // Check denied list first (takes precedence)
252        for denied in &config.denied_tools {
253            if Self::matches_pattern(tool_name, denied) {
254                return false;
255            }
256        }
257
258        // If allowed list is empty, all tools are allowed (except denied)
259        if config.allowed_tools.is_empty() {
260            return true;
261        }
262
263        // Check allowed list
264        for allowed in &config.allowed_tools {
265            if Self::matches_pattern(tool_name, allowed) {
266                return true;
267            }
268        }
269
270        false
271    }
272
273    /// Check if a tool name matches a pattern (supports glob patterns).
274    fn matches_pattern(tool_name: &str, pattern: &str) -> bool {
275        // Try glob pattern matching first
276        if let Ok(glob_pattern) = Pattern::new(pattern) {
277            return glob_pattern.matches(tool_name);
278        }
279
280        // Fall back to exact match
281        tool_name == pattern
282    }
283
284    /// Extract tool name from a tool response (if available).
285    fn extract_tool_name_from_response(
286        tool_response: &crate::conversation::message::ToolResponse,
287    ) -> String {
288        // The tool name is typically stored in metadata or can be inferred
289        // For now, we'll use the id as a fallback
290        tool_response
291            .metadata
292            .as_ref()
293            .and_then(|m| m.get("tool_name"))
294            .and_then(|v| v.as_str())
295            .map(|s| s.to_string())
296            .unwrap_or_else(|| tool_response.id.clone())
297    }
298
299    // ========================================================================
300    // Helper Functions
301    // ========================================================================
302
303    /// Find indices of messages that should be protected from pruning.
304    ///
305    /// Protects the last N assistant messages.
306    fn find_protected_indices(messages: &[Message], keep_last: usize) -> Vec<usize> {
307        let mut protected = Vec::new();
308        let mut assistant_count = 0;
309
310        // Iterate in reverse to find the last N assistant messages
311        for (idx, msg) in messages.iter().enumerate().rev() {
312            if msg.role == Role::Assistant && assistant_count < keep_last {
313                protected.push(idx);
314                assistant_count += 1;
315            }
316        }
317
318        protected
319    }
320
321    /// Safely extract a substring respecting UTF-8 boundaries.
322    fn safe_substring(s: &str, start: usize, end: usize) -> &str {
323        if s.is_empty() || start >= s.len() {
324            return "";
325        }
326
327        // Find the valid start position (first char boundary >= start)
328        let valid_start = s
329            .char_indices()
330            .map(|(i, _)| i)
331            .find(|&i| i >= start)
332            .unwrap_or(s.len());
333
334        // Find the valid end position
335        let valid_end = if end >= s.len() {
336            s.len()
337        } else {
338            s.char_indices()
339                .map(|(i, _)| i)
340                .take_while(|&i| i <= end)
341                .last()
342                .unwrap_or(0)
343        };
344
345        if valid_start >= valid_end {
346            return "";
347        }
348
349        s.get(valid_start..valid_end).unwrap_or("")
350    }
351}
352
// ============================================================================
// Tests
// ============================================================================
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned pattern list the config builders take.
    fn patterns(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    // ---- soft_trim / hard_clear --------------------------------------------

    #[test]
    fn test_soft_trim_short_content() {
        let input = "Short content";
        let output = ProgressivePruner::soft_trim(input, 500, 300);
        assert_eq!(output, input);
    }

    #[test]
    fn test_soft_trim_long_content() {
        let input = "A".repeat(2000);
        let output = ProgressivePruner::soft_trim(&input, 500, 300);

        let expected_head = "A".repeat(500);
        let expected_tail = "A".repeat(300);
        assert!(output.starts_with(expected_head.as_str()));
        assert!(output.contains("chars omitted"));
        assert!(output.ends_with(expected_tail.as_str()));
        assert!(output.len() < input.len());
    }

    #[test]
    fn test_soft_trim_preserves_head_tail() {
        let input = ["HEAD".repeat(100), "MIDDLE".to_string(), "TAIL".repeat(100)].concat();
        let output = ProgressivePruner::soft_trim(&input, 400, 400);

        assert!(output.starts_with("HEAD"));
        assert!(output.ends_with("TAIL"));
        assert!(output.contains("chars omitted"));
    }

    #[test]
    fn test_hard_clear() {
        let placeholder = "[content cleared]";
        assert_eq!(ProgressivePruner::hard_clear(placeholder), placeholder);
    }

    // ---- tool filtering ----------------------------------------------------

    #[test]
    fn test_is_tool_prunable_empty_lists() {
        let config = PruningConfig::default();
        for tool in ["read_file", "write"] {
            assert!(ProgressivePruner::is_tool_prunable(tool, &config));
        }
    }

    #[test]
    fn test_is_tool_prunable_denied_takes_precedence() {
        let config = PruningConfig::default()
            .with_allowed_tools(patterns(&["*"]))
            .with_denied_tools(patterns(&["write"]));

        assert!(ProgressivePruner::is_tool_prunable("read_file", &config));
        assert!(!ProgressivePruner::is_tool_prunable("write", &config));
    }

    #[test]
    fn test_is_tool_prunable_glob_patterns() {
        let config = PruningConfig::default().with_allowed_tools(patterns(&["read_*", "grep"]));

        for tool in ["read_file", "read_dir", "grep"] {
            assert!(ProgressivePruner::is_tool_prunable(tool, &config));
        }
        assert!(!ProgressivePruner::is_tool_prunable("write", &config));
    }

    #[test]
    fn test_is_tool_prunable_denied_glob() {
        let config = PruningConfig::default().with_denied_tools(patterns(&["write_*"]));

        assert!(ProgressivePruner::is_tool_prunable("read_file", &config));
        for tool in ["write_file", "write_dir"] {
            assert!(!ProgressivePruner::is_tool_prunable(tool, &config));
        }
    }

    // ---- safe_substring ----------------------------------------------------

    #[test]
    fn test_safe_substring_ascii() {
        let text = "Hello, World!";
        assert_eq!(ProgressivePruner::safe_substring(text, 0, 5), "Hello");
        assert_eq!(ProgressivePruner::safe_substring(text, 7, 12), "World");
    }

    #[test]
    fn test_safe_substring_unicode() {
        let text = "Hello, 世界!";
        assert_eq!(ProgressivePruner::safe_substring(text, 0, 7), "Hello, ");

        // Byte range 7..13 spans the two 3-byte CJK characters.
        let cjk = ProgressivePruner::safe_substring(text, 7, 13);
        assert!(cjk.contains('世'));
    }

    #[test]
    fn test_safe_substring_empty() {
        assert_eq!(ProgressivePruner::safe_substring("", 0, 10), "");
        assert_eq!(ProgressivePruner::safe_substring("hello", 10, 20), "");
    }

    // ---- message-level behaviour -------------------------------------------

    #[test]
    fn test_find_protected_indices() {
        let conversation = vec![
            Message::user().with_text("user 1"),
            Message::assistant().with_text("assistant 1"),
            Message::user().with_text("user 2"),
            Message::assistant().with_text("assistant 2"),
            Message::user().with_text("user 3"),
            Message::assistant().with_text("assistant 3"),
        ];

        let kept = ProgressivePruner::find_protected_indices(&conversation, 2);

        // Only the two most recent assistant turns (indices 5 and 3) are protected.
        assert!(kept.contains(&5));
        assert!(kept.contains(&3));
        assert!(!kept.contains(&1));
    }

    #[test]
    fn test_prune_messages_no_pruning() {
        let conversation = vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there"),
        ];
        let config = PruningConfig::default();

        // 0.2 is below the default soft_trim_ratio (0.3), so nothing should be pruned.
        let pruned = ProgressivePruner::prune_messages(&conversation, 0.2, &config);

        assert_eq!(pruned.len(), conversation.len());
    }
}