Skip to main content

matrixcode_core/compress/
priority.rs

1//! Dynamic priority scoring for messages.
2//!
3//! This module implements intelligent message prioritization based on
4//! multiple factors such as importance, recency, tool usage, and code content.
5
6use crate::providers::{ContentBlock, Message, MessageContent, Role};
7use std::collections::HashSet;
8
/// Priority score for a message, normalized to the inclusive range 0.0 to 1.0.
///
/// Build values with [`PriorityScore::new`] so the range invariant holds. The
/// inner field is public, so a value constructed directly is not validated.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct PriorityScore(pub f32);

impl PriorityScore {
    /// Lowest representable score.
    pub const MIN: f32 = 0.0;
    /// Highest representable score.
    pub const MAX: f32 = 1.0;

    /// Scores at or above this value are classified as high.
    const HIGH_THRESHOLD: f32 = 0.7;
    /// Scores at or above this value (and below the high threshold) are medium.
    const MEDIUM_THRESHOLD: f32 = 0.4;

    /// Creates a score clamped into `[MIN, MAX]`.
    ///
    /// A NaN input is mapped to [`Self::MIN`]: `f32::clamp` would otherwise
    /// return NaN, producing a score outside the documented range that belongs
    /// to no priority tier.
    pub fn new(score: f32) -> Self {
        if score.is_nan() {
            return Self(Self::MIN);
        }
        Self(score.clamp(Self::MIN, Self::MAX))
    }

    /// Returns the raw score.
    pub fn value(&self) -> f32 {
        self.0
    }

    /// Returns `true` when the score is at least 0.7.
    pub fn is_high(&self) -> bool {
        self.0 >= Self::HIGH_THRESHOLD
    }

    /// Returns `true` when the score is in `[0.4, 0.7)`.
    pub fn is_medium(&self) -> bool {
        self.0 >= Self::MEDIUM_THRESHOLD && self.0 < Self::HIGH_THRESHOLD
    }

    /// Returns `true` when the score is below 0.4.
    ///
    /// Defined as "neither high nor medium" so that every value, including a
    /// NaN placed directly into the public field, falls into exactly one tier.
    pub fn is_low(&self) -> bool {
        !self.is_high() && !self.is_medium()
    }
}
37
/// Signals extracted from a single message that feed into its priority score.
///
/// `Default` yields a value with every flag cleared and every numeric field zero.
#[derive(Debug, Clone, Default)]
pub struct PriorityFactors {
    /// Set when the text mentions a decision or a choice being made.
    pub has_decision: bool,
    /// Set when the text mentions an error or a failure.
    pub has_error: bool,
    /// Set when the message carries at least one tool call.
    pub has_tool_use: bool,
    /// Set when the text appears to contain code.
    pub has_code: bool,
    /// Set when the text contains one of the importance keywords.
    pub has_keywords: bool,
    /// Set when the message was written by the user (usually higher priority).
    pub is_user_message: bool,
    /// Where the message sits in the conversation, normalized to 0-1.
    pub position_weight: f32,
    /// Contribution derived from how long the message is.
    pub length_factor: f32,
    /// How many important entities (files, functions, etc.) were detected.
    pub entity_count: usize,
}
60
/// How much each priority factor contributes to the final score.
///
/// The default weights add up to 1.0.
#[derive(Debug, Clone)]
pub struct PriorityWeights {
    /// Added when the message contains a decision.
    pub decision_weight: f32,
    /// Added when the message contains an error.
    pub error_weight: f32,
    /// Added when the message contains a tool call.
    pub tool_weight: f32,
    /// Added when the message contains code.
    pub code_weight: f32,
    /// Added when the message contains an importance keyword.
    pub keyword_weight: f32,
    /// Added when the message comes from the user.
    pub user_message_weight: f32,
    /// Scaled by the message's position weight (newer messages score higher).
    pub recency_weight: f32,
    /// Scaled by the message's length factor.
    pub length_weight: f32,
    /// Upper bound on the contribution from detected entities.
    pub entity_weight: f32,
}

impl Default for PriorityWeights {
    /// Returns the stock weighting, grouped below by importance tier.
    fn default() -> Self {
        PriorityWeights {
            // High importance
            decision_weight: 0.2,
            error_weight: 0.15,
            tool_weight: 0.15,
            // Medium importance
            code_weight: 0.1,
            keyword_weight: 0.1,
            user_message_weight: 0.1,
            recency_weight: 0.1,
            // Low importance
            length_weight: 0.05,
            entity_weight: 0.05,
        }
    }
}
90
91/// Dynamic priority scorer.
92pub struct PriorityScorer {
93    weights: PriorityWeights,
94    important_keywords: HashSet<String>,
95}
96
97impl Default for PriorityScorer {
98    fn default() -> Self {
99        Self::new(PriorityWeights::default())
100    }
101}
102
103impl PriorityScorer {
104    pub fn new(weights: PriorityWeights) -> Self {
105        let important_keywords = Self::build_keyword_set();
106        Self {
107            weights,
108            important_keywords,
109        }
110    }
111
112    /// Build a set of important keywords that indicate high priority.
113    fn build_keyword_set() -> HashSet<String> {
114        let keywords = [
115            // Decision keywords
116            "important", "critical", "essential", "必须", "重要",
117            "决定", "选择", "decided", "chose", "selected",
118            // Action keywords
119            "fix", "解决", "修复", "implement", "实现", "create", "创建",
120            // Error keywords
121            "error", "错误", "failed", "失败", "exception", "异常",
122            // Success keywords
123            "success", "成功", "completed", "完成", "done", "完成",
124            // Requirement keywords
125            "requirement", "需求", "spec", "规范", "constraint", "约束",
126        ];
127
128        keywords.iter().map(|s| s.to_lowercase()).collect()
129    }
130
131    /// Extract priority factors from a message.
132    pub fn extract_factors(message: &Message, position: usize, total: usize) -> PriorityFactors {
133        let mut factors = PriorityFactors::default();
134
135        // Check role
136        factors.is_user_message = matches!(message.role, Role::User);
137
138        // Position weight (0 = oldest, 1 = newest)
139        factors.position_weight = if total > 1 {
140            position as f32 / (total - 1) as f32
141        } else {
142            1.0
143        };
144
145        // Analyze content
146        match &message.content {
147            MessageContent::Text(text) => {
148                Self::analyze_text(text, &mut factors);
149                factors.length_factor = Self::calculate_length_factor(text.len());
150            }
151            MessageContent::Blocks(blocks) => {
152                let mut combined_text = String::new();
153                for block in blocks {
154                    match block {
155                        ContentBlock::Text { text } => {
156                            combined_text.push_str(text);
157                            combined_text.push(' ');
158                        }
159                        ContentBlock::ToolUse { name, input, .. } => {
160                            factors.has_tool_use = true;
161                            combined_text.push_str(name);
162                            combined_text.push(' ');
163                            combined_text.push_str(&input.to_string());
164                            combined_text.push(' ');
165                        }
166                        ContentBlock::ToolResult { content, .. } => {
167                            combined_text.push_str(content);
168                            combined_text.push(' ');
169                            if content.contains("error") || content.contains("failed") {
170                                factors.has_error = true;
171                            }
172                        }
173                        ContentBlock::Thinking { thinking, .. } => {
174                            combined_text.push_str(thinking);
175                            combined_text.push(' ');
176                        }
177                        _ => {}
178                    }
179                }
180                Self::analyze_text(&combined_text, &mut factors);
181                factors.length_factor = Self::calculate_length_factor(combined_text.len());
182            }
183        }
184
185        factors
186    }
187
188    /// Analyze text content for priority indicators.
189    fn analyze_text(text: &str, factors: &mut PriorityFactors) {
190        let lower = text.to_lowercase();
191
192        // Check for decision indicators
193        if lower.contains("决定") || lower.contains("decided") || lower.contains("chose")
194            || lower.contains("选择") || lower.contains("selected")
195        {
196            factors.has_decision = true;
197        }
198
199        // Check for error indicators
200        if lower.contains("error") || lower.contains("错误") || lower.contains("failed")
201            || lower.contains("失败") || lower.contains("exception") || lower.contains("异常")
202        {
203            factors.has_error = true;
204        }
205
206        // Check for code blocks
207        if text.contains("```") || text.contains("fn ") || text.contains("function ")
208            || text.contains("class ") || text.contains("impl ")
209        {
210            factors.has_code = true;
211        }
212
213        // Check for important keywords
214        factors.has_keywords = lower.split_whitespace().any(|word| {
215            word.trim_matches(|c: char| c.is_ascii_punctuation()).eq_ignore_ascii_case("important")
216                || word.eq_ignore_ascii_case("critical")
217                || word.eq_ignore_ascii_case("essential")
218                || word.eq_ignore_ascii_case("必须")
219                || word.eq_ignore_ascii_case("重要")
220        });
221
222        // Count entities (files, functions, etc.)
223        factors.entity_count = Self::count_entities(text);
224    }
225
226    /// Count important entities in text.
227    fn count_entities(text: &str) -> usize {
228        let mut count = 0;
229
230        // Count file references (e.g., "src/main.rs", "package.json")
231        if text.contains(".rs") || text.contains(".py") || text.contains(".js")
232            || text.contains(".ts") || text.contains(".json") || text.contains(".toml")
233        {
234            count += 1;
235        }
236
237        // Count function references (e.g., "fn main", "function test")
238        for pattern in &["fn ", "function ", "def ", "class ", "impl "] {
239            if text.contains(pattern) {
240                count += 1;
241            }
242        }
243
244        // Count API endpoints
245        if text.contains("GET /") || text.contains("POST /") || text.contains("PUT /")
246            || text.contains("DELETE /")
247        {
248            count += 1;
249        }
250
251        count
252    }
253
254    /// Calculate length factor (longer messages may be more important).
255    fn calculate_length_factor(len: usize) -> f32 {
256        // Normalize length: 0-100 chars = 0.0-1.0
257        // Cap at 1.0 for very long messages
258        (len as f32 / 100.0).min(1.0)
259    }
260
261    /// Calculate priority score for a message.
262    pub fn score(&self, message: &Message, position: usize, total: usize) -> PriorityScore {
263        let factors = Self::extract_factors(message, position, total);
264        self.score_from_factors(&factors)
265    }
266
267    /// Calculate priority score from factors.
268    pub fn score_from_factors(&self, factors: &PriorityFactors) -> PriorityScore {
269        let mut score = 0.0;
270
271        if factors.has_decision {
272            score += self.weights.decision_weight;
273        }
274        if factors.has_error {
275            score += self.weights.error_weight;
276        }
277        if factors.has_tool_use {
278            score += self.weights.tool_weight;
279        }
280        if factors.has_code {
281            score += self.weights.code_weight;
282        }
283        if factors.has_keywords {
284            score += self.weights.keyword_weight;
285        }
286        if factors.is_user_message {
287            score += self.weights.user_message_weight;
288        }
289
290        // Add recency weight (more recent = higher priority)
291        score += factors.position_weight * self.weights.recency_weight;
292
293        // Add length weight
294        score += factors.length_factor * self.weights.length_weight;
295
296        // Add entity weight
297        score += (factors.entity_count as f32 * 0.02).min(self.weights.entity_weight);
298
299        PriorityScore::new(score)
300    }
301
302    /// Get priority level description.
303    pub fn level(score: PriorityScore) -> &'static str {
304        if score.is_high() {
305            "High"
306        } else if score.is_medium() {
307            "Medium"
308        } else {
309            "Low"
310        }
311    }
312}
313
314/// Message with priority score.
315#[derive(Debug, Clone)]
316pub struct ScoredMessage {
317    pub message: Message,
318    pub score: PriorityScore,
319    pub position: usize,
320    pub factors: PriorityFactors,
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Out-of-range inputs are clamped; in-range inputs pass through.
    #[test]
    fn test_priority_score_clamping() {
        assert_eq!(PriorityScore::new(-1.0).value(), 0.0);
        assert_eq!(PriorityScore::new(2.0).value(), 1.0);
        assert_eq!(PriorityScore::new(0.5).value(), 0.5);
    }

    /// Each score belongs to exactly one of the high/medium/low tiers.
    #[test]
    fn test_priority_levels() {
        let high = PriorityScore::new(0.8);
        assert!(high.is_high());
        assert!(!high.is_medium());
        assert!(!high.is_low());

        let medium = PriorityScore::new(0.5);
        assert!(!medium.is_high());
        assert!(medium.is_medium());
        assert!(!medium.is_low());

        let low = PriorityScore::new(0.2);
        assert!(!low.is_high());
        assert!(!low.is_medium());
        assert!(low.is_low());
    }

    #[test]
    fn test_extract_factors_user_message() {
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text("Hello".to_string()),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.is_user_message);
    }

    #[test]
    fn test_extract_factors_decision() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Text("I decided to use Rust.".to_string()),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_decision);
    }

    #[test]
    fn test_extract_factors_error() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Text("The operation failed with error.".to_string()),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_error);
    }

    #[test]
    fn test_extract_factors_code() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Text(
                "Here's the code:\n```rust\nfn main() {}\n```".to_string(),
            ),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_code);
    }

    #[test]
    fn test_extract_factors_tool_use() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "bash".to_string(),
                input: serde_json::json!({"command": "ls"}),
            }]),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_tool_use);
    }

    #[test]
    fn test_score_calculation() {
        let scorer = PriorityScorer::default();

        // High priority message.
        //
        // With the default weights this is: decision 0.2 + error 0.15 +
        // code 0.1 + keyword 0.1 + user 0.1 + recency 0.1 + length ~0.05 +
        // entities 0.04 ≈ 0.84. The code/file mention is required: without it
        // the same sentence only reaches ~0.685, just under the 0.7 threshold.
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text(
                "I decided to use Rust for this important project. \
                 The error in fn main of src/main.rs was fixed."
                    .to_string(),
            ),
        };
        let score = scorer.score(&msg, 9, 10);
        assert!(score.is_high());

        // Low priority message
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Text("ok".to_string()),
        };
        let score = scorer.score(&msg, 0, 10);
        assert!(score.is_low());
    }

    #[test]
    fn test_position_weight() {
        // Old message (position 0)
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text("Test".to_string()),
        };
        let factors1 = PriorityScorer::extract_factors(&msg, 0, 10);
        assert!(factors1.position_weight < 0.2);

        // New message (position 9)
        let factors2 = PriorityScorer::extract_factors(&msg, 9, 10);
        assert!(factors2.position_weight > 0.8);
    }

    #[test]
    fn test_entity_counting() {
        let text = "In src/main.rs, we have fn main() and fn helper()";
        let count = PriorityScorer::count_entities(text);
        assert!(count >= 2); // At least .rs and fn mentions
    }
}