matrixcode_core/compress/
priority.rs1use crate::providers::{ContentBlock, Message, MessageContent, Role};
7use std::collections::HashSet;
8
/// Importance of a message expressed as an `f32`.
///
/// Values produced by [`PriorityScore::new`] always lie in
/// [`PriorityScore::MIN`]..=[`PriorityScore::MAX`]. The inner field is public,
/// so a value built directly with the tuple constructor bypasses that
/// normalisation.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct PriorityScore(pub f32);

impl PriorityScore {
    /// Lowest score `new` will produce.
    pub const MIN: f32 = 0.0;
    /// Highest score `new` will produce.
    pub const MAX: f32 = 1.0;

    /// Scores at or above this value are "high".
    const HIGH_THRESHOLD: f32 = 0.7;
    /// Scores at or above this value (and below the high threshold) are "medium".
    const MEDIUM_THRESHOLD: f32 = 0.4;

    /// Creates a score, clamping `score` into `MIN..=MAX`.
    ///
    /// `f32::clamp` passes NaN through unchanged, which would yield a score
    /// that is neither high, medium nor low. NaN is therefore mapped to `MIN`.
    pub fn new(score: f32) -> Self {
        if score.is_nan() {
            return Self(Self::MIN);
        }
        Self(score.clamp(Self::MIN, Self::MAX))
    }

    /// Returns the raw numeric score.
    pub fn value(&self) -> f32 {
        self.0
    }

    /// `true` when the score is at least 0.7.
    pub fn is_high(&self) -> bool {
        self.0 >= Self::HIGH_THRESHOLD
    }

    /// `true` when the score is at least 0.4 and below 0.7.
    pub fn is_medium(&self) -> bool {
        self.0 >= Self::MEDIUM_THRESHOLD && self.0 < Self::HIGH_THRESHOLD
    }

    /// `true` when the score is below 0.4.
    pub fn is_low(&self) -> bool {
        self.0 < Self::MEDIUM_THRESHOLD
    }
}
37
/// Signals extracted from a single message that feed the priority score.
///
/// `Default` yields a value with every flag cleared and every number at zero.
#[derive(Debug, Clone, Default)]
pub struct PriorityFactors {
    /// The text contains a decision marker such as "decided" or "chose".
    pub has_decision: bool,
    /// The text contains an error marker such as "error" or "failed".
    pub has_error: bool,
    /// The message carries at least one tool-use block.
    pub has_tool_use: bool,
    /// The text looks like it contains code (code fence or a definition keyword).
    pub has_code: bool,
    /// The text contains one of the emphasis keywords such as "important".
    pub has_keywords: bool,
    /// The message was sent with the user role.
    pub is_user_message: bool,
    /// Relative position in the conversation: 0.0 for the first message,
    /// 1.0 for the last one.
    pub position_weight: f32,
    /// Text length mapped onto `0.0..=1.0`.
    pub length_factor: f32,
    /// Number of entity categories detected in the text.
    pub entity_count: usize,
}
60
/// Contribution of each factor to the final priority score.
///
/// The default weights add up to 1.0.
#[derive(Debug, Clone)]
pub struct PriorityWeights {
    /// Added when a decision marker is present.
    pub decision_weight: f32,
    /// Added when an error marker is present.
    pub error_weight: f32,
    /// Added when the message contains a tool-use block.
    pub tool_weight: f32,
    /// Added when the message appears to contain code.
    pub code_weight: f32,
    /// Added when an emphasis keyword is present.
    pub keyword_weight: f32,
    /// Added when the message comes from the user.
    pub user_message_weight: f32,
    /// Multiplied by the message's position weight.
    pub recency_weight: f32,
    /// Multiplied by the message's length factor.
    pub length_weight: f32,
    /// Upper bound of the entity-count contribution.
    pub entity_weight: f32,
}

impl Default for PriorityWeights {
    /// Returns the stock weighting.
    fn default() -> Self {
        Self {
            decision_weight: 0.2,
            error_weight: 0.15,
            tool_weight: 0.15,
            code_weight: 0.1,
            keyword_weight: 0.1,
            user_message_weight: 0.1,
            recency_weight: 0.1,
            length_weight: 0.05,
            entity_weight: 0.05,
        }
    }
}
90
91pub struct PriorityScorer {
93 weights: PriorityWeights,
94 important_keywords: HashSet<String>,
95}
96
97impl Default for PriorityScorer {
98 fn default() -> Self {
99 Self::new(PriorityWeights::default())
100 }
101}
102
103impl PriorityScorer {
104 pub fn new(weights: PriorityWeights) -> Self {
105 let important_keywords = Self::build_keyword_set();
106 Self {
107 weights,
108 important_keywords,
109 }
110 }
111
112 fn build_keyword_set() -> HashSet<String> {
114 let keywords = [
115 "important", "critical", "essential", "必须", "重要",
117 "决定", "选择", "decided", "chose", "selected",
118 "fix", "解决", "修复", "implement", "实现", "create", "创建",
120 "error", "错误", "failed", "失败", "exception", "异常",
122 "success", "成功", "completed", "完成", "done", "完成",
124 "requirement", "需求", "spec", "规范", "constraint", "约束",
126 ];
127
128 keywords.iter().map(|s| s.to_lowercase()).collect()
129 }
130
131 pub fn extract_factors(message: &Message, position: usize, total: usize) -> PriorityFactors {
133 let mut factors = PriorityFactors::default();
134
135 factors.is_user_message = matches!(message.role, Role::User);
137
138 factors.position_weight = if total > 1 {
140 position as f32 / (total - 1) as f32
141 } else {
142 1.0
143 };
144
145 match &message.content {
147 MessageContent::Text(text) => {
148 Self::analyze_text(text, &mut factors);
149 factors.length_factor = Self::calculate_length_factor(text.len());
150 }
151 MessageContent::Blocks(blocks) => {
152 let mut combined_text = String::new();
153 for block in blocks {
154 match block {
155 ContentBlock::Text { text } => {
156 combined_text.push_str(text);
157 combined_text.push(' ');
158 }
159 ContentBlock::ToolUse { name, input, .. } => {
160 factors.has_tool_use = true;
161 combined_text.push_str(name);
162 combined_text.push(' ');
163 combined_text.push_str(&input.to_string());
164 combined_text.push(' ');
165 }
166 ContentBlock::ToolResult { content, .. } => {
167 combined_text.push_str(content);
168 combined_text.push(' ');
169 if content.contains("error") || content.contains("failed") {
170 factors.has_error = true;
171 }
172 }
173 ContentBlock::Thinking { thinking, .. } => {
174 combined_text.push_str(thinking);
175 combined_text.push(' ');
176 }
177 _ => {}
178 }
179 }
180 Self::analyze_text(&combined_text, &mut factors);
181 factors.length_factor = Self::calculate_length_factor(combined_text.len());
182 }
183 }
184
185 factors
186 }
187
188 fn analyze_text(text: &str, factors: &mut PriorityFactors) {
190 let lower = text.to_lowercase();
191
192 if lower.contains("决定") || lower.contains("decided") || lower.contains("chose")
194 || lower.contains("选择") || lower.contains("selected")
195 {
196 factors.has_decision = true;
197 }
198
199 if lower.contains("error") || lower.contains("错误") || lower.contains("failed")
201 || lower.contains("失败") || lower.contains("exception") || lower.contains("异常")
202 {
203 factors.has_error = true;
204 }
205
206 if text.contains("```") || text.contains("fn ") || text.contains("function ")
208 || text.contains("class ") || text.contains("impl ")
209 {
210 factors.has_code = true;
211 }
212
213 factors.has_keywords = lower.split_whitespace().any(|word| {
215 word.trim_matches(|c: char| c.is_ascii_punctuation()).eq_ignore_ascii_case("important")
216 || word.eq_ignore_ascii_case("critical")
217 || word.eq_ignore_ascii_case("essential")
218 || word.eq_ignore_ascii_case("必须")
219 || word.eq_ignore_ascii_case("重要")
220 });
221
222 factors.entity_count = Self::count_entities(text);
224 }
225
226 fn count_entities(text: &str) -> usize {
228 let mut count = 0;
229
230 if text.contains(".rs") || text.contains(".py") || text.contains(".js")
232 || text.contains(".ts") || text.contains(".json") || text.contains(".toml")
233 {
234 count += 1;
235 }
236
237 for pattern in &["fn ", "function ", "def ", "class ", "impl "] {
239 if text.contains(pattern) {
240 count += 1;
241 }
242 }
243
244 if text.contains("GET /") || text.contains("POST /") || text.contains("PUT /")
246 || text.contains("DELETE /")
247 {
248 count += 1;
249 }
250
251 count
252 }
253
254 fn calculate_length_factor(len: usize) -> f32 {
256 (len as f32 / 100.0).min(1.0)
259 }
260
261 pub fn score(&self, message: &Message, position: usize, total: usize) -> PriorityScore {
263 let factors = Self::extract_factors(message, position, total);
264 self.score_from_factors(&factors)
265 }
266
267 pub fn score_from_factors(&self, factors: &PriorityFactors) -> PriorityScore {
269 let mut score = 0.0;
270
271 if factors.has_decision {
272 score += self.weights.decision_weight;
273 }
274 if factors.has_error {
275 score += self.weights.error_weight;
276 }
277 if factors.has_tool_use {
278 score += self.weights.tool_weight;
279 }
280 if factors.has_code {
281 score += self.weights.code_weight;
282 }
283 if factors.has_keywords {
284 score += self.weights.keyword_weight;
285 }
286 if factors.is_user_message {
287 score += self.weights.user_message_weight;
288 }
289
290 score += factors.position_weight * self.weights.recency_weight;
292
293 score += factors.length_factor * self.weights.length_weight;
295
296 score += (factors.entity_count as f32 * 0.02).min(self.weights.entity_weight);
298
299 PriorityScore::new(score)
300 }
301
302 pub fn level(score: PriorityScore) -> &'static str {
304 if score.is_high() {
305 "High"
306 } else if score.is_medium() {
307 "Medium"
308 } else {
309 "Low"
310 }
311 }
312}
313
314#[derive(Debug, Clone)]
316pub struct ScoredMessage {
317 pub message: Message,
318 pub score: PriorityScore,
319 pub position: usize,
320 pub factors: PriorityFactors,
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_priority_score_clamping() {
        assert_eq!(PriorityScore::new(-1.0).value(), 0.0);
        assert_eq!(PriorityScore::new(2.0).value(), 1.0);
        assert_eq!(PriorityScore::new(0.5).value(), 0.5);
    }

    #[test]
    fn test_priority_levels() {
        let high = PriorityScore::new(0.8);
        assert!(high.is_high());
        assert!(!high.is_medium());
        assert!(!high.is_low());

        let medium = PriorityScore::new(0.5);
        assert!(!medium.is_high());
        assert!(medium.is_medium());
        assert!(!medium.is_low());

        let low = PriorityScore::new(0.2);
        assert!(!low.is_high());
        assert!(!low.is_medium());
        assert!(low.is_low());
    }

    #[test]
    fn test_extract_factors_user_message() {
        let msg = text_message(Role::User, "Hello");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.is_user_message);
    }

    #[test]
    fn test_extract_factors_decision() {
        let msg = text_message(Role::Assistant, "I decided to use Rust.");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_decision);
    }

    #[test]
    fn test_extract_factors_error() {
        let msg = text_message(Role::Assistant, "The operation failed with error.");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_error);
    }

    #[test]
    fn test_extract_factors_code() {
        let msg = text_message(
            Role::Assistant,
            "Here's the code:\n```rust\nfn main() {}\n```",
        );
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_code);
    }

    #[test]
    fn test_extract_factors_tool_use() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "bash".to_string(),
                input: serde_json::json!({"command": "ls"}),
            }]),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_tool_use);
    }

    #[test]
    fn test_score_calculation() {
        let scorer = PriorityScorer::default();

        // Latest user message with a decision, an error, a keyword and code.
        // With the default weights this is roughly
        // 0.2 + 0.15 + 0.1 + 0.1 + 0.1 + 0.1 + 0.05 + 0.04 = 0.84.
        // Without the code reference the same message only reaches 0.685,
        // which is below the 0.7 "high" threshold.
        let msg = text_message(
            Role::User,
            "I decided to use Rust for this important project. \
             The error in `fn main()` (src/main.rs) was fixed.",
        );
        let score = scorer.score(&msg, 9, 10);
        assert!(score.is_high());

        // Oldest, tiny assistant message without any signal.
        let msg = text_message(Role::Assistant, "ok");
        let score = scorer.score(&msg, 0, 10);
        assert!(score.is_low());
    }

    #[test]
    fn test_position_weight() {
        let msg = text_message(Role::User, "Test");

        let first = PriorityScorer::extract_factors(&msg, 0, 10);
        assert!(first.position_weight < 0.2);

        let last = PriorityScorer::extract_factors(&msg, 9, 10);
        assert!(last.position_weight > 0.8);
    }

    #[test]
    fn test_entity_counting() {
        let text = "In src/main.rs, we have fn main() and fn helper()";
        let count = PriorityScorer::count_entities(text);
        assert!(count >= 2);
    }
}