Skip to main content

matrixcode_core/compress/
priority.rs

1//! Dynamic priority scoring for messages.
2//!
3//! This module implements intelligent message prioritization based on
4//! multiple factors such as importance, recency, tool usage, and code content.
5
6use crate::providers::{ContentBlock, Message, MessageContent, Role};
7
/// Priority score for a message, kept within `[PriorityScore::MIN, PriorityScore::MAX]`
/// (0.0 to 1.0) when constructed through [`PriorityScore::new`].
///
/// The inner field is public, so a value built directly as `PriorityScore(x)` is not
/// clamped; prefer `new`.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct PriorityScore(pub f32);

impl PriorityScore {
    /// Lowest score produced by [`PriorityScore::new`].
    pub const MIN: f32 = 0.0;
    /// Highest score produced by [`PriorityScore::new`].
    pub const MAX: f32 = 1.0;
    /// Scores at or above this value are high priority.
    pub const HIGH_THRESHOLD: f32 = 0.7;
    /// Scores at or above this value (and below `HIGH_THRESHOLD`) are medium priority.
    pub const MEDIUM_THRESHOLD: f32 = 0.4;

    /// Creates a score, clamping `score` into `[MIN, MAX]`.
    ///
    /// `NaN` maps to [`Self::MIN`]. `f32::clamp` returns `NaN` unchanged, which would
    /// otherwise produce a score that is neither high, medium nor low.
    #[must_use]
    pub fn new(score: f32) -> Self {
        if score.is_nan() {
            return Self(Self::MIN);
        }
        Self(score.clamp(Self::MIN, Self::MAX))
    }

    /// Returns the raw score value.
    #[must_use]
    pub fn value(&self) -> f32 {
        self.0
    }

    /// `true` when the score is at least [`Self::HIGH_THRESHOLD`].
    pub fn is_high(&self) -> bool {
        self.0 >= Self::HIGH_THRESHOLD
    }

    /// `true` when the score lies in `[MEDIUM_THRESHOLD, HIGH_THRESHOLD)`.
    pub fn is_medium(&self) -> bool {
        (Self::MEDIUM_THRESHOLD..Self::HIGH_THRESHOLD).contains(&self.0)
    }

    /// `true` when the score is below [`Self::MEDIUM_THRESHOLD`].
    pub fn is_low(&self) -> bool {
        self.0 < Self::MEDIUM_THRESHOLD
    }
}
36
/// Raw signals extracted from a single message, later combined into a priority score.
///
/// The boolean fields are detection flags; the numeric fields carry graded signals.
#[derive(Default, Clone, Debug)]
pub struct PriorityFactors {
    /// The message appears to record a decision or a choice.
    pub has_decision: bool,
    /// The message mentions an error, a failure or an exception.
    pub has_error: bool,
    /// The message includes at least one tool call.
    pub has_tool_use: bool,
    /// The message looks like it contains code.
    pub has_code: bool,
    /// The message contains an emphasis keyword such as "important".
    pub has_keywords: bool,
    /// The message was authored by the user.
    pub is_user_message: bool,
    /// Relative position in the conversation, normalized so 0.0 is oldest and 1.0 newest.
    pub position_weight: f32,
    /// Message length mapped onto a 0.0-1.0 scale.
    pub length_factor: f32,
    /// Number of referenced entities (files, definitions, endpoints) detected.
    pub entity_count: usize,
}
59
/// Relative weights applied to each extracted signal when computing a priority score.
///
/// The default weights add up to 1.0.
#[derive(Debug, Clone)]
pub struct PriorityWeights {
    /// Added when the message records a decision.
    pub decision_weight: f32,
    /// Added when the message mentions an error or failure.
    pub error_weight: f32,
    /// Added when the message contains a tool call.
    pub tool_weight: f32,
    /// Added when the message appears to contain code.
    pub code_weight: f32,
    /// Added when the message contains an emphasis keyword.
    pub keyword_weight: f32,
    /// Added when the message was written by the user.
    pub user_message_weight: f32,
    /// Scaled by the message's position weight (newer messages score higher).
    pub recency_weight: f32,
    /// Scaled by the message's length factor.
    pub length_weight: f32,
    /// Upper bound on the contribution from referenced entities.
    pub entity_weight: f32,
}

impl Default for PriorityWeights {
    /// Weights favoring decisions, errors and tool calls; they sum to 1.0.
    fn default() -> Self {
        // Strongest signals: what was decided, what broke, and which tools ran.
        let (decision_weight, error_weight, tool_weight) = (0.2, 0.15, 0.15);
        // Mid-strength signals.
        let (code_weight, keyword_weight, user_message_weight, recency_weight) =
            (0.1, 0.1, 0.1, 0.1);
        // Weak tie-breakers.
        let (length_weight, entity_weight) = (0.05, 0.05);

        Self {
            decision_weight,
            error_weight,
            tool_weight,
            code_weight,
            keyword_weight,
            user_message_weight,
            recency_weight,
            length_weight,
            entity_weight,
        }
    }
}
89
90/// Dynamic priority scorer.
91pub struct PriorityScorer {
92    weights: PriorityWeights,
93}
94
95impl Default for PriorityScorer {
96    fn default() -> Self {
97        Self::new(PriorityWeights::default())
98    }
99}
100
101impl PriorityScorer {
102    pub fn new(weights: PriorityWeights) -> Self {
103        Self {
104            weights,
105        }
106    }
107
108    /// Extract priority factors from a message.
109    pub fn extract_factors(message: &Message, position: usize, total: usize) -> PriorityFactors {
110        let mut factors = PriorityFactors::default();
111
112        // Check role
113        factors.is_user_message = matches!(message.role, Role::User);
114
115        // Position weight (0 = oldest, 1 = newest)
116        factors.position_weight = if total > 1 {
117            position as f32 / (total - 1) as f32
118        } else {
119            1.0
120        };
121
122        // Analyze content
123        match &message.content {
124            MessageContent::Text(text) => {
125                Self::analyze_text(text, &mut factors);
126                factors.length_factor = Self::calculate_length_factor(text.len());
127            }
128            MessageContent::Blocks(blocks) => {
129                let mut combined_text = String::new();
130                for block in blocks {
131                    match block {
132                        ContentBlock::Text { text } => {
133                            combined_text.push_str(text);
134                            combined_text.push(' ');
135                        }
136                        ContentBlock::ToolUse { name, input, .. } => {
137                            factors.has_tool_use = true;
138                            combined_text.push_str(name);
139                            combined_text.push(' ');
140                            combined_text.push_str(&input.to_string());
141                            combined_text.push(' ');
142                        }
143                        ContentBlock::ToolResult { content, .. } => {
144                            combined_text.push_str(content);
145                            combined_text.push(' ');
146                            if content.contains("error") || content.contains("failed") {
147                                factors.has_error = true;
148                            }
149                        }
150                        ContentBlock::Thinking { thinking, .. } => {
151                            combined_text.push_str(thinking);
152                            combined_text.push(' ');
153                        }
154                        _ => {}
155                    }
156                }
157                Self::analyze_text(&combined_text, &mut factors);
158                factors.length_factor = Self::calculate_length_factor(combined_text.len());
159            }
160        }
161
162        factors
163    }
164
165    /// Analyze text content for priority indicators.
166    fn analyze_text(text: &str, factors: &mut PriorityFactors) {
167        let lower = text.to_lowercase();
168
169        // Check for decision indicators
170        if lower.contains("决定") || lower.contains("decided") || lower.contains("chose")
171            || lower.contains("选择") || lower.contains("selected")
172        {
173            factors.has_decision = true;
174        }
175
176        // Check for error indicators
177        if lower.contains("error") || lower.contains("错误") || lower.contains("failed")
178            || lower.contains("失败") || lower.contains("exception") || lower.contains("异常")
179        {
180            factors.has_error = true;
181        }
182
183        // Check for code blocks
184        if text.contains("```") || text.contains("fn ") || text.contains("function ")
185            || text.contains("class ") || text.contains("impl ")
186        {
187            factors.has_code = true;
188        }
189
190        // Check for important keywords
191        factors.has_keywords = lower.split_whitespace().any(|word| {
192            word.trim_matches(|c: char| c.is_ascii_punctuation()).eq_ignore_ascii_case("important")
193                || word.eq_ignore_ascii_case("critical")
194                || word.eq_ignore_ascii_case("essential")
195                || word.eq_ignore_ascii_case("必须")
196                || word.eq_ignore_ascii_case("重要")
197        });
198
199        // Count entities (files, functions, etc.)
200        factors.entity_count = Self::count_entities(text);
201    }
202
203    /// Count important entities in text.
204    fn count_entities(text: &str) -> usize {
205        let mut count = 0;
206
207        // Count file references (e.g., "src/main.rs", "package.json")
208        if text.contains(".rs") || text.contains(".py") || text.contains(".js")
209            || text.contains(".ts") || text.contains(".json") || text.contains(".toml")
210        {
211            count += 1;
212        }
213
214        // Count function references (e.g., "fn main", "function test")
215        for pattern in &["fn ", "function ", "def ", "class ", "impl "] {
216            if text.contains(pattern) {
217                count += 1;
218            }
219        }
220
221        // Count API endpoints
222        if text.contains("GET /") || text.contains("POST /") || text.contains("PUT /")
223            || text.contains("DELETE /")
224        {
225            count += 1;
226        }
227
228        count
229    }
230
231    /// Calculate length factor (longer messages may be more important).
232    fn calculate_length_factor(len: usize) -> f32 {
233        // Normalize length: 0-100 chars = 0.0-1.0
234        // Cap at 1.0 for very long messages
235        (len as f32 / 100.0).min(1.0)
236    }
237
238    /// Calculate priority score for a message.
239    pub fn score(&self, message: &Message, position: usize, total: usize) -> PriorityScore {
240        let factors = Self::extract_factors(message, position, total);
241        self.score_from_factors(&factors)
242    }
243
244    /// Calculate priority score from factors.
245    pub fn score_from_factors(&self, factors: &PriorityFactors) -> PriorityScore {
246        let mut score = 0.0;
247
248        if factors.has_decision {
249            score += self.weights.decision_weight;
250        }
251        if factors.has_error {
252            score += self.weights.error_weight;
253        }
254        if factors.has_tool_use {
255            score += self.weights.tool_weight;
256        }
257        if factors.has_code {
258            score += self.weights.code_weight;
259        }
260        if factors.has_keywords {
261            score += self.weights.keyword_weight;
262        }
263        if factors.is_user_message {
264            score += self.weights.user_message_weight;
265        }
266
267        // Add recency weight (more recent = higher priority)
268        score += factors.position_weight * self.weights.recency_weight;
269
270        // Add length weight
271        score += factors.length_factor * self.weights.length_weight;
272
273        // Add entity weight
274        score += (factors.entity_count as f32 * 0.02).min(self.weights.entity_weight);
275
276        PriorityScore::new(score)
277    }
278
279    /// Get priority level description.
280    pub fn level(score: PriorityScore) -> &'static str {
281        if score.is_high() {
282            "High"
283        } else if score.is_medium() {
284            "Medium"
285        } else {
286            "Low"
287        }
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message whose content is a single plain-text string.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_priority_score_clamping() {
        assert_eq!(PriorityScore::new(-1.0).value(), 0.0);
        assert_eq!(PriorityScore::new(2.0).value(), 1.0);
        assert_eq!(PriorityScore::new(0.5).value(), 0.5);
    }

    #[test]
    fn test_priority_levels() {
        let high = PriorityScore::new(0.8);
        assert!(high.is_high());
        assert!(!high.is_medium());
        assert!(!high.is_low());

        let medium = PriorityScore::new(0.5);
        assert!(!medium.is_high());
        assert!(medium.is_medium());
        assert!(!medium.is_low());

        let low = PriorityScore::new(0.2);
        assert!(!low.is_high());
        assert!(!low.is_medium());
        assert!(low.is_low());
    }

    #[test]
    fn test_level_labels() {
        assert_eq!(PriorityScorer::level(PriorityScore::new(0.8)), "High");
        assert_eq!(PriorityScorer::level(PriorityScore::new(0.5)), "Medium");
        assert_eq!(PriorityScorer::level(PriorityScore::new(0.1)), "Low");
    }

    #[test]
    fn test_extract_factors_user_message() {
        let msg = text_message(Role::User, "Hello");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.is_user_message);
        // A single-message conversation counts as the newest message.
        assert_eq!(factors.position_weight, 1.0);
    }

    #[test]
    fn test_extract_factors_decision() {
        let msg = text_message(Role::Assistant, "I decided to use Rust.");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_decision);
    }

    #[test]
    fn test_extract_factors_error() {
        let msg = text_message(Role::Assistant, "The operation failed with error.");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_error);
    }

    #[test]
    fn test_extract_factors_code() {
        let msg = text_message(
            Role::Assistant,
            "Here's the code:\n```rust\nfn main() {}\n```",
        );
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_code);
    }

    #[test]
    fn test_extract_factors_tool_use() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "bash".to_string(),
                input: serde_json::json!({"command": "ls"}),
            }]),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_tool_use);
    }

    #[test]
    fn test_score_calculation() {
        let scorer = PriorityScorer::default();

        // Newest user message with a decision, an error, an emphasis keyword, a code block
        // and a file reference. With default weights this scores about 0.84, comfortably
        // above the 0.7 "high" threshold. Decision + error + keyword alone reach only ~0.685.
        let msg = text_message(
            Role::User,
            "I decided to use Rust for this important project. \
             The error in src/main.rs was fixed:\n```rust\nfn main() {}\n```",
        );
        let score = scorer.score(&msg, 9, 10);
        assert!(score.is_high(), "expected high priority, got {}", score.value());

        // Oldest, short assistant acknowledgement with no signals.
        let msg = text_message(Role::Assistant, "ok");
        let score = scorer.score(&msg, 0, 10);
        assert!(score.is_low(), "expected low priority, got {}", score.value());
    }

    #[test]
    fn test_position_weight() {
        let msg = text_message(Role::User, "Test");

        // Oldest message (position 0).
        let oldest = PriorityScorer::extract_factors(&msg, 0, 10);
        assert!(oldest.position_weight < 0.2);

        // Newest message (position 9).
        let newest = PriorityScorer::extract_factors(&msg, 9, 10);
        assert!(newest.position_weight > 0.8);
    }

    #[test]
    fn test_entity_counting() {
        let text = "In src/main.rs, we have fn main() and fn helper()";
        let count = PriorityScorer::count_entities(text);
        assert!(count >= 2); // At least the .rs file and the `fn` definitions.
    }

    #[test]
    fn test_length_factor_is_capped() {
        assert_eq!(PriorityScorer::calculate_length_factor(0), 0.0);
        assert_eq!(PriorityScorer::calculate_length_factor(50), 0.5);
        assert_eq!(PriorityScorer::calculate_length_factor(1_000), 1.0);
    }
}