Skip to main content

aster/context/
priority_sorter.rs

//! Message Priority Sorter Module
//!
//! This module provides message priority sorting functionality for context management.
//! It assigns priority levels to messages based on their content (summaries, tool calls)
//! and their position in the conversation, enabling intelligent compression and
//! truncation decisions.
//!
//! # Priority Levels
//!
//! Position bands are measured as a fraction of the conversation, from the oldest
//! message (0%) to the newest (100%).
//!
//! - **Critical**: Summary messages, detected by keyword (must be preserved); system
//!   prompts are handled separately and are not classified here
//! - **High**: Recent messages (last 20%) and messages with tool calls
//! - **Medium**: Middle messages (50-80% of conversation)
//! - **Low**: Older messages (20-50% of conversation)
//! - **Minimal**: Oldest messages (first 20%)
//!
//! # Example
//!
//! ```rust,ignore
//! use aster::context::priority_sorter::PrioritySorter;
//! use aster::context::types::MessagePriority;
//!
//! let messages = vec![/* ... */];
//! let prioritized = PrioritySorter::sort_by_priority(&messages, |m| estimate_tokens(m));
//! ```
24
25use crate::context::token_estimator::TokenEstimator;
26use crate::context::types::{MessagePriority, PrioritizedMessage};
27use crate::conversation::message::{Message, MessageContent};
28
// ============================================================================
// Constants
// ============================================================================

// Position bands are expressed as a "position ratio": 0.0 is the oldest
// message in the conversation and 1.0 is the newest.

/// Lower bound of the low-priority band: a ratio in `[0.2, 0.5)` is Low,
/// anything below 0.2 (roughly the first 20%) is Minimal.
const LOW_THRESHOLD: f64 = 0.2;

/// Lower bound of the medium-priority band: a ratio in `[0.5, 0.8)` is Medium.
const MEDIUM_THRESHOLD: f64 = 0.5;

/// Lower bound of the "recent" band: a ratio of at least 0.8 (roughly the
/// last 20% of the conversation) is High.
const RECENT_THRESHOLD: f64 = 0.8;

/// Substrings that mark a message as a summary.
///
/// They are written in lowercase because they are matched against the
/// lowercased message text.
const SUMMARY_KEYWORDS: &[&str] = &[
    "[summary]",
    "[conversation summary]",
    "summary:",
    "summarized:",
    "previous conversation:",
];
50
// ============================================================================
// PrioritySorter
// ============================================================================

/// Message priority sorter for intelligent context management.
///
/// This is a field-less namespace type: all functionality is exposed through
/// associated functions such as `PrioritySorter::evaluate_priority` and
/// `PrioritySorter::sort_by_priority`.
///
/// Priority levels are assigned based on:
/// - Message content: text containing a summary keyword is critical. The
///   message role is not inspected; system prompts are handled separately.
/// - Tool usage: messages with tool-related content are high priority
/// - Message position: more recent messages receive higher priority
#[derive(Debug, Clone, Copy, Default)]
pub struct PrioritySorter;
63
64impl PrioritySorter {
65    /// Evaluate the priority of a message based on its position and content.
66    ///
67    /// # Priority Assignment Rules
68    ///
69    /// 1. System messages and summaries → Critical
70    /// 2. Recent messages (last 20%) → High
71    /// 3. Messages with tool calls → High
72    /// 4. Middle messages (50-80%) → Medium
73    /// 5. Older messages (20-50%) → Low
74    /// 6. Oldest messages (first 20%) → Minimal
75    ///
76    /// # Arguments
77    ///
78    /// * `message` - The message to evaluate
79    /// * `index` - The message's position in the conversation (0-based)
80    /// * `total_messages` - Total number of messages in the conversation
81    ///
82    /// # Returns
83    ///
84    /// The assigned `MessagePriority` level.
85    ///
86    /// # Example
87    ///
88    /// ```rust,ignore
89    /// let priority = PrioritySorter::evaluate_priority(&message, 5, 10);
90    /// assert_eq!(priority, MessagePriority::High); // Last 50% = recent
91    /// ```
92    pub fn evaluate_priority(
93        message: &Message,
94        index: usize,
95        total_messages: usize,
96    ) -> MessagePriority {
97        // Rule 1: System messages and summaries are Critical
98        if Self::is_system_or_summary(message) {
99            return MessagePriority::Critical;
100        }
101
102        // Rule 2 & 3: Check for tool calls (High priority)
103        if Self::has_tool_calls(message) {
104            return MessagePriority::High;
105        }
106
107        // Calculate position ratio (0.0 = oldest, 1.0 = newest)
108        let position_ratio = if total_messages <= 1 {
109            1.0
110        } else {
111            index as f64 / (total_messages - 1) as f64
112        };
113
114        // Rule 2: Recent messages (last 20%) are High priority
115        if position_ratio >= RECENT_THRESHOLD {
116            return MessagePriority::High;
117        }
118
119        // Rule 4: Middle messages (50-80%) are Medium priority
120        if position_ratio >= MEDIUM_THRESHOLD {
121            return MessagePriority::Medium;
122        }
123
124        // Rule 5: Older messages (20-50%) are Low priority
125        if position_ratio >= LOW_THRESHOLD {
126            return MessagePriority::Low;
127        }
128
129        // Rule 6: Oldest messages (first 20%) are Minimal priority
130        MessagePriority::Minimal
131    }
132
133    /// Sort messages by priority, then by timestamp (descending).
134    ///
135    /// Creates a list of `PrioritizedMessage` objects sorted by:
136    /// 1. Priority (Critical > High > Medium > Low > Minimal)
137    /// 2. Timestamp (newer messages first within same priority)
138    ///
139    /// # Arguments
140    ///
141    /// * `messages` - The messages to sort
142    /// * `estimate_tokens` - Function to estimate token count for a message
143    ///
144    /// # Returns
145    ///
146    /// A vector of `PrioritizedMessage` sorted by priority and timestamp.
147    ///
148    /// # Example
149    ///
150    /// ```rust,ignore
151    /// let sorted = PrioritySorter::sort_by_priority(&messages, |m| {
152    ///     TokenEstimator::estimate_message_tokens(m)
153    /// });
154    /// ```
155    pub fn sort_by_priority<F>(messages: &[Message], estimate_tokens: F) -> Vec<PrioritizedMessage>
156    where
157        F: Fn(&Message) -> usize,
158    {
159        let total_messages = messages.len();
160
161        let mut prioritized: Vec<PrioritizedMessage> = messages
162            .iter()
163            .enumerate()
164            .map(|(index, message)| {
165                let priority = Self::evaluate_priority(message, index, total_messages);
166                let tokens = estimate_tokens(message);
167
168                PrioritizedMessage::new(message.clone(), priority, message.created, tokens)
169            })
170            .collect();
171
172        // Sort by priority (descending) then by timestamp (descending)
173        prioritized.sort_by(|a, b| match b.priority.cmp(&a.priority) {
174            std::cmp::Ordering::Equal => b.timestamp.cmp(&a.timestamp),
175            other => other,
176        });
177
178        prioritized
179    }
180
181    /// Sort messages by priority using the default token estimator.
182    ///
183    /// Convenience method that uses `TokenEstimator::estimate_message_tokens`.
184    ///
185    /// # Arguments
186    ///
187    /// * `messages` - The messages to sort
188    ///
189    /// # Returns
190    ///
191    /// A vector of `PrioritizedMessage` sorted by priority and timestamp.
192    pub fn sort_by_priority_default(messages: &[Message]) -> Vec<PrioritizedMessage> {
193        Self::sort_by_priority(messages, TokenEstimator::estimate_message_tokens)
194    }
195
196    /// Check if a message is a system message or contains a summary.
197    ///
198    /// # Arguments
199    ///
200    /// * `message` - The message to check
201    ///
202    /// # Returns
203    ///
204    /// `true` if the message is a system message or contains summary content.
205    pub fn is_system_or_summary(message: &Message) -> bool {
206        // Check if it's a system role (Note: rmcp::model::Role doesn't have System,
207        // but we check for user messages that might contain system-like content)
208        // In practice, system prompts are handled separately, so we focus on summaries
209
210        // Check message content for summary indicators
211        for content in &message.content {
212            if let MessageContent::Text(text_content) = content {
213                let text_lower = text_content.text.to_lowercase();
214                for keyword in SUMMARY_KEYWORDS {
215                    if text_lower.contains(keyword) {
216                        return true;
217                    }
218                }
219            }
220        }
221
222        false
223    }
224
225    /// Check if a message contains tool calls (requests or responses).
226    ///
227    /// # Arguments
228    ///
229    /// * `message` - The message to check
230    ///
231    /// # Returns
232    ///
233    /// `true` if the message contains any tool-related content.
234    pub fn has_tool_calls(message: &Message) -> bool {
235        message.content.iter().any(|content| {
236            matches!(
237                content,
238                MessageContent::ToolRequest(_)
239                    | MessageContent::ToolResponse(_)
240                    | MessageContent::ToolConfirmationRequest(_)
241                    | MessageContent::FrontendToolRequest(_)
242            )
243        })
244    }
245
246    /// Filter messages by minimum priority level.
247    ///
248    /// Returns only messages with priority >= the specified minimum.
249    ///
250    /// # Arguments
251    ///
252    /// * `prioritized` - The prioritized messages to filter
253    /// * `min_priority` - Minimum priority level to include
254    ///
255    /// # Returns
256    ///
257    /// A vector of messages meeting the minimum priority requirement.
258    pub fn filter_by_priority(
259        prioritized: &[PrioritizedMessage],
260        min_priority: MessagePriority,
261    ) -> Vec<PrioritizedMessage> {
262        prioritized
263            .iter()
264            .filter(|p| p.priority >= min_priority)
265            .cloned()
266            .collect()
267    }
268
269    /// Select messages within a token budget, prioritizing higher priority messages.
270    ///
271    /// # Arguments
272    ///
273    /// * `prioritized` - The prioritized messages (should be pre-sorted)
274    /// * `max_tokens` - Maximum total tokens to include
275    ///
276    /// # Returns
277    ///
278    /// A vector of messages fitting within the token budget.
279    pub fn select_within_budget(
280        prioritized: &[PrioritizedMessage],
281        max_tokens: usize,
282    ) -> Vec<PrioritizedMessage> {
283        let mut result = Vec::new();
284        let mut current_tokens = 0;
285
286        for pm in prioritized {
287            if current_tokens + pm.tokens <= max_tokens {
288                result.push(pm.clone());
289                current_tokens += pm.tokens;
290            }
291        }
292
293        result
294    }
295
296    /// Get priority distribution statistics for a set of messages.
297    ///
298    /// # Arguments
299    ///
300    /// * `messages` - The messages to analyze
301    ///
302    /// # Returns
303    ///
304    /// A tuple of (critical_count, high_count, medium_count, low_count, minimal_count)
305    pub fn get_priority_distribution(messages: &[Message]) -> (usize, usize, usize, usize, usize) {
306        let total = messages.len();
307        let mut critical = 0;
308        let mut high = 0;
309        let mut medium = 0;
310        let mut low = 0;
311        let mut minimal = 0;
312
313        for (index, message) in messages.iter().enumerate() {
314            match Self::evaluate_priority(message, index, total) {
315                MessagePriority::Critical => critical += 1,
316                MessagePriority::High => high += 1,
317                MessagePriority::Medium => medium += 1,
318                MessagePriority::Low => low += 1,
319                MessagePriority::Minimal => minimal += 1,
320            }
321        }
322
323        (critical, high, medium, low, minimal)
324    }
325}
326
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolRequestParam, JsonObject, Role};

    /// Build a plain text message with the given role.
    fn create_text_message(role: Role, text: &str) -> Message {
        match role {
            Role::User => Message::user().with_text(text),
            Role::Assistant => Message::assistant().with_text(text),
        }
    }

    /// Build an assistant message containing a single tool request.
    fn create_tool_call_message() -> Message {
        Message::assistant().with_tool_request(
            "tool_1",
            Ok(CallToolRequestParam {
                name: "test_tool".into(),
                arguments: Some(JsonObject::new()),
            }),
        )
    }

    /// Build a user message whose text matches the `[summary]` keyword.
    fn create_summary_message() -> Message {
        Message::user().with_text("[Summary] Previous conversation discussed file operations.")
    }

    #[test]
    fn test_evaluate_priority_summary_is_critical() {
        let message = create_summary_message();
        let priority = PrioritySorter::evaluate_priority(&message, 0, 10);
        assert_eq!(priority, MessagePriority::Critical);
    }

    #[test]
    fn test_evaluate_priority_tool_call_is_high() {
        let message = create_tool_call_message();
        let priority = PrioritySorter::evaluate_priority(&message, 0, 10);
        assert_eq!(priority, MessagePriority::High);
    }

    #[test]
    fn test_evaluate_priority_recent_is_high() {
        let message = create_text_message(Role::User, "Recent message");
        // Index 9 of 10 → ratio 9/9 = 1.0 (recent)
        let priority = PrioritySorter::evaluate_priority(&message, 9, 10);
        assert_eq!(priority, MessagePriority::High);
    }

    #[test]
    fn test_evaluate_priority_middle_is_medium() {
        let message = create_text_message(Role::User, "Middle message");
        // Index 6 of 10 → ratio 6/9 ≈ 0.67 (medium)
        let priority = PrioritySorter::evaluate_priority(&message, 6, 10);
        assert_eq!(priority, MessagePriority::Medium);
    }

    #[test]
    fn test_evaluate_priority_older_is_low() {
        let message = create_text_message(Role::User, "Older message");
        // Index 3 of 10 → ratio 3/9 ≈ 0.33 (low)
        let priority = PrioritySorter::evaluate_priority(&message, 3, 10);
        assert_eq!(priority, MessagePriority::Low);
    }

    #[test]
    fn test_evaluate_priority_oldest_is_minimal() {
        let message = create_text_message(Role::User, "Oldest message");
        // Index 1 of 10 → ratio 1/9 ≈ 0.11 (minimal)
        let priority = PrioritySorter::evaluate_priority(&message, 1, 10);
        assert_eq!(priority, MessagePriority::Minimal);
    }

    #[test]
    fn test_evaluate_priority_band_boundaries() {
        // With 10 messages the position ratios are 0/9, 1/9, ..., 9/9.
        let message = create_text_message(Role::User, "Plain message");
        let expected = [
            MessagePriority::Minimal, // 0.00
            MessagePriority::Minimal, // 0.11
            MessagePriority::Low,     // 0.22
            MessagePriority::Low,     // 0.33
            MessagePriority::Low,     // 0.44
            MessagePriority::Medium,  // 0.56
            MessagePriority::Medium,  // 0.67
            MessagePriority::Medium,  // 0.78
            MessagePriority::High,    // 0.89
            MessagePriority::High,    // 1.00
        ];

        for (index, want) in expected.iter().enumerate() {
            let got = PrioritySorter::evaluate_priority(&message, index, expected.len());
            assert_eq!(&got, want, "index {}", index);
        }
    }

    #[test]
    fn test_is_system_or_summary_with_summary() {
        let message = create_summary_message();
        assert!(PrioritySorter::is_system_or_summary(&message));
    }

    #[test]
    fn test_is_system_or_summary_without_summary() {
        let message = create_text_message(Role::User, "Regular message");
        assert!(!PrioritySorter::is_system_or_summary(&message));
    }

    #[test]
    fn test_has_tool_calls_with_tool() {
        let message = create_tool_call_message();
        assert!(PrioritySorter::has_tool_calls(&message));
    }

    #[test]
    fn test_has_tool_calls_without_tool() {
        let message = create_text_message(Role::User, "No tools here");
        assert!(!PrioritySorter::has_tool_calls(&message));
    }

    #[test]
    fn test_sort_by_priority_ordering() {
        // With 5 messages the position ratios are 0.0, 0.25, 0.5, 0.75, 1.0.
        let messages = vec![
            create_text_message(Role::User, "First message"), // Minimal (index 0)
            create_text_message(Role::Assistant, "Second message"), // Low (index 1)
            create_summary_message(),                         // Critical (summary)
            create_text_message(Role::User, "Fourth message"), // Medium (index 3)
            create_text_message(Role::Assistant, "Fifth message"), // High (index 4)
        ];

        let sorted = PrioritySorter::sort_by_priority_default(&messages);

        // Every priority appears exactly once, so the full order is fixed.
        let expected = [
            MessagePriority::Critical,
            MessagePriority::High,
            MessagePriority::Medium,
            MessagePriority::Low,
            MessagePriority::Minimal,
        ];
        assert_eq!(sorted.len(), expected.len());
        for (pm, want) in sorted.iter().zip(expected.iter()) {
            assert_eq!(&pm.priority, want);
        }
    }

    #[test]
    fn test_filter_by_priority() {
        let messages = vec![
            create_text_message(Role::User, "First"),
            create_text_message(Role::Assistant, "Second"),
            create_text_message(Role::User, "Third"),
            create_text_message(Role::Assistant, "Fourth"),
            create_text_message(Role::User, "Fifth"),
        ];

        let prioritized = PrioritySorter::sort_by_priority_default(&messages);
        let high_and_above =
            PrioritySorter::filter_by_priority(&prioritized, MessagePriority::High);

        // Ratios are 0.0, 0.25, 0.5, 0.75, 1.0 → only the last message is High.
        assert_eq!(high_and_above.len(), 1);
        for pm in &high_and_above {
            assert!(pm.priority >= MessagePriority::High);
        }

        // The lowest level keeps everything.
        let everything =
            PrioritySorter::filter_by_priority(&prioritized, MessagePriority::Minimal);
        assert_eq!(everything.len(), messages.len());
    }

    #[test]
    fn test_select_within_budget() {
        let messages = vec![
            create_text_message(Role::User, "Short"),
            create_text_message(Role::Assistant, "Also short"),
            create_text_message(Role::User, "Another short one"),
        ];

        let prioritized = PrioritySorter::sort_by_priority_default(&messages);
        let selected = PrioritySorter::select_within_budget(&prioritized, 50);

        // Whatever is selected must fit within the budget.
        let total_tokens: usize = selected.iter().map(|p| p.tokens).sum();
        assert!(total_tokens <= 50);

        // An unbounded budget selects every message.
        let all = PrioritySorter::select_within_budget(&prioritized, usize::MAX);
        assert_eq!(all.len(), messages.len());
    }

    #[test]
    fn test_get_priority_distribution() {
        // 10 messages → position ratios 0/9 .. 9/9.
        let messages = vec![
            create_summary_message(),                       // Critical (summary)
            create_text_message(Role::User, "First"),       // Minimal (0.11)
            create_text_message(Role::Assistant, "Second"), // Low (0.22)
            create_text_message(Role::User, "Third"),       // Low (0.33)
            create_text_message(Role::Assistant, "Fourth"), // Low (0.44)
            create_text_message(Role::User, "Fifth"),       // Medium (0.56)
            create_text_message(Role::Assistant, "Sixth"),  // Medium (0.67)
            create_text_message(Role::User, "Seventh"),     // Medium (0.78)
            create_text_message(Role::Assistant, "Eighth"), // High (0.89)
            create_tool_call_message(),                     // High (tool call)
        ];

        let distribution = PrioritySorter::get_priority_distribution(&messages);

        // (critical, high, medium, low, minimal)
        assert_eq!(distribution, (1, 2, 3, 3, 1));
    }

    #[test]
    fn test_single_message_is_high_priority() {
        let message = create_text_message(Role::User, "Only message");
        let priority = PrioritySorter::evaluate_priority(&message, 0, 1);
        // Single message should be high priority (position ratio = 1.0)
        assert_eq!(priority, MessagePriority::High);
    }

    #[test]
    fn test_empty_messages() {
        let messages: Vec<Message> = vec![];
        let sorted = PrioritySorter::sort_by_priority_default(&messages);
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_summary_keywords_case_insensitive() {
        let message = create_text_message(Role::User, "[SUMMARY] This is a summary");
        assert!(PrioritySorter::is_system_or_summary(&message));

        let message2 = create_text_message(Role::User, "Conversation Summary: blah blah");
        assert!(PrioritySorter::is_system_or_summary(&message2));
    }
}