matrixcode_core/compress/
priority.rs1use crate::providers::{ContentBlock, Message, MessageContent, Role};
7
/// Importance score of a message, normally in the inclusive range `[0.0, 1.0]`.
///
/// Build scores with [`PriorityScore::new`] to get clamping; the tuple field is
/// public, so a value constructed directly is not range-checked.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct PriorityScore(pub f32);

impl PriorityScore {
    /// Lowest score produced by [`PriorityScore::new`].
    pub const MIN: f32 = 0.0;
    /// Highest score produced by [`PriorityScore::new`].
    pub const MAX: f32 = 1.0;
    /// Scores at or above this value are high priority.
    pub const HIGH_THRESHOLD: f32 = 0.7;
    /// Scores at or above this value (and below [`Self::HIGH_THRESHOLD`]) are medium priority.
    pub const MEDIUM_THRESHOLD: f32 = 0.4;

    /// Creates a score clamped to `[MIN, MAX]`.
    ///
    /// `NaN` (e.g. produced by a NaN weight) is mapped to [`Self::MIN`]:
    /// `f32::clamp` returns NaN unchanged, and a NaN score would compare false
    /// against every threshold.
    pub fn new(score: f32) -> Self {
        if score.is_nan() {
            return Self(Self::MIN);
        }
        Self(score.clamp(Self::MIN, Self::MAX))
    }

    /// Returns the raw score value.
    pub fn value(&self) -> f32 {
        self.0
    }

    /// `true` when the score is at least [`Self::HIGH_THRESHOLD`].
    pub fn is_high(&self) -> bool {
        self.0 >= Self::HIGH_THRESHOLD
    }

    /// `true` when the score is in `[MEDIUM_THRESHOLD, HIGH_THRESHOLD)`.
    pub fn is_medium(&self) -> bool {
        self.0 >= Self::MEDIUM_THRESHOLD && self.0 < Self::HIGH_THRESHOLD
    }

    /// `true` when the score is neither high nor medium.
    ///
    /// Defined as the complement of the other two levels so that exactly one
    /// level predicate holds for every value, including a directly constructed NaN.
    pub fn is_low(&self) -> bool {
        !self.is_high() && !self.is_medium()
    }
}
/// Heuristic features of a single message, combined by a `PriorityScorer`
/// into a `PriorityScore`.
///
/// The derived `Default` means "nothing detected": every flag is `false`
/// and every numeric field is zero.
#[derive(Debug, Clone, Default)]
pub struct PriorityFactors {
    /// The text mentions a decision (e.g. "decided", "chose", "selected", "决定", "选择").
    pub has_decision: bool,
    /// The text mentions an error or failure.
    pub has_error: bool,
    /// The message contains at least one tool-use block.
    pub has_tool_use: bool,
    /// The text looks like it contains code (fences or definition keywords).
    pub has_code: bool,
    /// The text contains an emphasis keyword such as "important" or "重要".
    pub has_keywords: bool,
    /// The message was sent with the user role.
    pub is_user_message: bool,
    /// Relative position in the conversation: 0.0 for the first message, 1.0 for the last.
    pub position_weight: f32,
    /// Text length divided by 100, capped at 1.0.
    pub length_factor: f32,
    /// Number of distinct entity kinds detected (file extensions, definitions, HTTP routes).
    pub entity_count: usize,
}
/// Relative weights a `PriorityScorer` uses to combine `PriorityFactors`.
///
/// Each boolean factor contributes its weight when set; `recency_weight` and
/// `length_weight` are multiplied by the corresponding factor value, and
/// `entity_weight` caps the entity contribution. The defaults sum to 1.0.
#[derive(Debug, Clone)]
pub struct PriorityWeights {
    /// Bonus for messages that record a decision.
    pub decision_weight: f32,
    /// Bonus for messages that mention an error or failure.
    pub error_weight: f32,
    /// Bonus for messages containing a tool call.
    pub tool_weight: f32,
    /// Bonus for messages that appear to contain code.
    pub code_weight: f32,
    /// Bonus for messages with an emphasis keyword.
    pub keyword_weight: f32,
    /// Bonus for messages from the user.
    pub user_message_weight: f32,
    /// Multiplier for the relative position of the message.
    pub recency_weight: f32,
    /// Multiplier for the normalised message length.
    pub length_weight: f32,
    /// Upper bound on the contribution from detected entities.
    pub entity_weight: f32,
}

impl Default for PriorityWeights {
    /// Weights summing to 1.0 that favour decisions, errors and tool calls.
    fn default() -> Self {
        PriorityWeights {
            // Strongest signals.
            decision_weight: 0.2,
            error_weight: 0.15,
            tool_weight: 0.15,
            // Moderate signals.
            code_weight: 0.1,
            keyword_weight: 0.1,
            user_message_weight: 0.1,
            recency_weight: 0.1,
            // Minor signals.
            length_weight: 0.05,
            entity_weight: 0.05,
        }
    }
}
90pub struct PriorityScorer {
92 weights: PriorityWeights,
93}
94
95impl Default for PriorityScorer {
96 fn default() -> Self {
97 Self::new(PriorityWeights::default())
98 }
99}
100
101impl PriorityScorer {
102 pub fn new(weights: PriorityWeights) -> Self {
103 Self {
104 weights,
105 }
106 }
107
108 pub fn extract_factors(message: &Message, position: usize, total: usize) -> PriorityFactors {
110 let mut factors = PriorityFactors::default();
111
112 factors.is_user_message = matches!(message.role, Role::User);
114
115 factors.position_weight = if total > 1 {
117 position as f32 / (total - 1) as f32
118 } else {
119 1.0
120 };
121
122 match &message.content {
124 MessageContent::Text(text) => {
125 Self::analyze_text(text, &mut factors);
126 factors.length_factor = Self::calculate_length_factor(text.len());
127 }
128 MessageContent::Blocks(blocks) => {
129 let mut combined_text = String::new();
130 for block in blocks {
131 match block {
132 ContentBlock::Text { text } => {
133 combined_text.push_str(text);
134 combined_text.push(' ');
135 }
136 ContentBlock::ToolUse { name, input, .. } => {
137 factors.has_tool_use = true;
138 combined_text.push_str(name);
139 combined_text.push(' ');
140 combined_text.push_str(&input.to_string());
141 combined_text.push(' ');
142 }
143 ContentBlock::ToolResult { content, .. } => {
144 combined_text.push_str(content);
145 combined_text.push(' ');
146 if content.contains("error") || content.contains("failed") {
147 factors.has_error = true;
148 }
149 }
150 ContentBlock::Thinking { thinking, .. } => {
151 combined_text.push_str(thinking);
152 combined_text.push(' ');
153 }
154 _ => {}
155 }
156 }
157 Self::analyze_text(&combined_text, &mut factors);
158 factors.length_factor = Self::calculate_length_factor(combined_text.len());
159 }
160 }
161
162 factors
163 }
164
165 fn analyze_text(text: &str, factors: &mut PriorityFactors) {
167 let lower = text.to_lowercase();
168
169 if lower.contains("决定") || lower.contains("decided") || lower.contains("chose")
171 || lower.contains("选择") || lower.contains("selected")
172 {
173 factors.has_decision = true;
174 }
175
176 if lower.contains("error") || lower.contains("错误") || lower.contains("failed")
178 || lower.contains("失败") || lower.contains("exception") || lower.contains("异常")
179 {
180 factors.has_error = true;
181 }
182
183 if text.contains("```") || text.contains("fn ") || text.contains("function ")
185 || text.contains("class ") || text.contains("impl ")
186 {
187 factors.has_code = true;
188 }
189
190 factors.has_keywords = lower.split_whitespace().any(|word| {
192 word.trim_matches(|c: char| c.is_ascii_punctuation()).eq_ignore_ascii_case("important")
193 || word.eq_ignore_ascii_case("critical")
194 || word.eq_ignore_ascii_case("essential")
195 || word.eq_ignore_ascii_case("必须")
196 || word.eq_ignore_ascii_case("重要")
197 });
198
199 factors.entity_count = Self::count_entities(text);
201 }
202
203 fn count_entities(text: &str) -> usize {
205 let mut count = 0;
206
207 if text.contains(".rs") || text.contains(".py") || text.contains(".js")
209 || text.contains(".ts") || text.contains(".json") || text.contains(".toml")
210 {
211 count += 1;
212 }
213
214 for pattern in &["fn ", "function ", "def ", "class ", "impl "] {
216 if text.contains(pattern) {
217 count += 1;
218 }
219 }
220
221 if text.contains("GET /") || text.contains("POST /") || text.contains("PUT /")
223 || text.contains("DELETE /")
224 {
225 count += 1;
226 }
227
228 count
229 }
230
231 fn calculate_length_factor(len: usize) -> f32 {
233 (len as f32 / 100.0).min(1.0)
236 }
237
238 pub fn score(&self, message: &Message, position: usize, total: usize) -> PriorityScore {
240 let factors = Self::extract_factors(message, position, total);
241 self.score_from_factors(&factors)
242 }
243
244 pub fn score_from_factors(&self, factors: &PriorityFactors) -> PriorityScore {
246 let mut score = 0.0;
247
248 if factors.has_decision {
249 score += self.weights.decision_weight;
250 }
251 if factors.has_error {
252 score += self.weights.error_weight;
253 }
254 if factors.has_tool_use {
255 score += self.weights.tool_weight;
256 }
257 if factors.has_code {
258 score += self.weights.code_weight;
259 }
260 if factors.has_keywords {
261 score += self.weights.keyword_weight;
262 }
263 if factors.is_user_message {
264 score += self.weights.user_message_weight;
265 }
266
267 score += factors.position_weight * self.weights.recency_weight;
269
270 score += factors.length_factor * self.weights.length_weight;
272
273 score += (factors.entity_count as f32 * 0.02).min(self.weights.entity_weight);
275
276 PriorityScore::new(score)
277 }
278
279 pub fn level(score: PriorityScore) -> &'static str {
281 if score.is_high() {
282 "High"
283 } else if score.is_medium() {
284 "Medium"
285 } else {
286 "Low"
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_priority_score_clamping() {
        assert_eq!(PriorityScore::new(-1.0).value(), 0.0);
        assert_eq!(PriorityScore::new(2.0).value(), 1.0);
        assert_eq!(PriorityScore::new(0.5).value(), 0.5);
    }

    #[test]
    fn test_priority_levels() {
        let high = PriorityScore::new(0.8);
        assert!(high.is_high());
        assert!(!high.is_medium());
        assert!(!high.is_low());

        let medium = PriorityScore::new(0.5);
        assert!(!medium.is_high());
        assert!(medium.is_medium());
        assert!(!medium.is_low());

        let low = PriorityScore::new(0.2);
        assert!(!low.is_high());
        assert!(!low.is_medium());
        assert!(low.is_low());
    }

    #[test]
    fn test_extract_factors_user_message() {
        let msg = text_message(Role::User, "Hello");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.is_user_message);
    }

    #[test]
    fn test_extract_factors_decision() {
        let msg = text_message(Role::Assistant, "I decided to use Rust.");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_decision);
    }

    #[test]
    fn test_extract_factors_error() {
        let msg = text_message(Role::Assistant, "The operation failed with error.");
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_error);
    }

    #[test]
    fn test_extract_factors_code() {
        let msg = text_message(
            Role::Assistant,
            "Here's the code:\n```rust\nfn main() {}\n```",
        );
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_code);
    }

    #[test]
    fn test_extract_factors_tool_use() {
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "bash".to_string(),
                input: serde_json::json!({"command": "ls"}),
            }]),
        };
        let factors = PriorityScorer::extract_factors(&msg, 0, 1);
        assert!(factors.has_tool_use);
    }

    #[test]
    fn test_score_calculation() {
        let scorer = PriorityScorer::default();

        // With the default weights, decision + error + keyword + code + user role +
        // most recent position give 0.2 + 0.15 + 0.1 + 0.1 + 0.1 + 0.1 = 0.75 before
        // the length and entity bonuses, clearly above the 0.7 "high" threshold.
        // (Without the code snippet the total would be ~0.685, i.e. only medium.)
        let msg = text_message(
            Role::User,
            "I decided to use Rust for this important project. \
             The error in src/main.rs was fixed in `fn main() {}`.",
        );
        let score = scorer.score(&msg, 9, 10);
        assert!(score.is_high(), "expected high priority, got {}", score.value());

        let msg = text_message(Role::Assistant, "ok");
        let score = scorer.score(&msg, 0, 10);
        assert!(score.is_low(), "expected low priority, got {}", score.value());
    }

    #[test]
    fn test_position_weight() {
        let msg = text_message(Role::User, "Test");

        let first = PriorityScorer::extract_factors(&msg, 0, 10);
        assert!(first.position_weight < 0.2);

        let last = PriorityScorer::extract_factors(&msg, 9, 10);
        assert!(last.position_weight > 0.8);
    }

    #[test]
    fn test_entity_counting() {
        // One for the `.rs` file and one for the `fn ` definition keyword;
        // repeated definitions of the same kind count once.
        let text = "In src/main.rs, we have fn main() and fn helper()";
        assert_eq!(PriorityScorer::count_entities(text), 2);
    }
}