matrixcode_core/compress/
hierarchical.rs1use crate::providers::{Message, MessageContent, Role};
12use crate::compress::priority::PriorityScore;
13use crate::compress::hardcode_config::HardcodeConfig;
14
/// How aggressively a single message is condensed.
///
/// Variants run from most compact (`Brief`) to most verbose (`Detailed`);
/// `SummaryLevel::retention_ratio` and `SummaryLevel::max_tokens` expose the
/// per-level budgets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SummaryLevel {
    /// Role tag plus a truncated first sentence, optionally followed by key actions.
    Brief,
    /// First sentence, an optional key sentence, an optional last sentence and quoted entities.
    Standard,
    /// Every sentence (middle ones compressed) joined with arrows.
    Detailed,
}

31impl SummaryLevel {
32 pub fn from_priority(priority: PriorityScore) -> Self {
34 if priority.is_high() {
35 SummaryLevel::Detailed
36 } else if priority.is_medium() {
37 SummaryLevel::Standard
38 } else {
39 SummaryLevel::Brief
40 }
41 }
42
43 pub fn retention_ratio(&self) -> f32 {
45 match self {
46 SummaryLevel::Brief => 0.25, SummaryLevel::Standard => 0.45, SummaryLevel::Detailed => 0.65, }
50 }
51
52 pub fn max_tokens(&self) -> usize {
54 match self {
55 SummaryLevel::Brief => 100,
56 SummaryLevel::Standard => 200,
57 SummaryLevel::Detailed => 350,
58 }
59 }
60}
61
/// Behavioural switches for `HierarchicalSummarizer`.
#[derive(Clone, Debug)]
pub struct HierarchicalConfig {
    /// When `true`, `progressive_summarize` lowers the summary level of
    /// messages that sit later in the input slice.
    pub progressive: bool,
    /// Lower message-count bound.
    /// NOTE(review): not read anywhere else in this file — confirm where it is used.
    pub min_messages: usize,
    /// Upper message-count bound.
    /// NOTE(review): not read anywhere else in this file — confirm where it is used.
    pub max_messages: usize,
}

impl Default for HierarchicalConfig {
    /// Progressive mode on, message bounds 10..=50.
    fn default() -> Self {
        let progressive = true;
        let (min_messages, max_messages) = (10, 50);
        Self {
            progressive,
            min_messages,
            max_messages,
        }
    }
}

83pub struct HierarchicalSummarizer {
85 config: HierarchicalConfig,
86 hardcode_config: HardcodeConfig,
87}
88
89impl Default for HierarchicalSummarizer {
90 fn default() -> Self {
91 Self::new(HierarchicalConfig::default())
92 }
93}
94
95impl HierarchicalSummarizer {
96 pub fn new(config: HierarchicalConfig) -> Self {
97 Self {
98 config,
99 hardcode_config: HardcodeConfig::default(),
100 }
101 }
102
103 pub fn with_hardcode_config(mut self, hardcode_config: HardcodeConfig) -> Self {
105 self.hardcode_config = hardcode_config;
106 self
107 }
108
109 pub fn summarize_message(&self, message: &Message, level: SummaryLevel) -> String {
111 let content = match &message.content {
112 MessageContent::Text(text) => text.clone(),
113 MessageContent::Blocks(blocks) => {
114 blocks.iter()
115 .filter_map(|b| match b {
116 crate::providers::ContentBlock::Text { text } => Some(text.clone()),
117 _ => None,
118 })
119 .collect::<Vec<_>>()
120 .join("\n")
121 }
122 };
123
124 if content.is_empty() {
125 return String::new();
126 }
127
128 match level {
130 SummaryLevel::Brief => self.brief_summary(&content, &message.role),
131 SummaryLevel::Standard => self.standard_summary(&content, &message.role),
132 SummaryLevel::Detailed => self.detailed_summary(&content, &message.role),
133 }
134 }
135
136 fn brief_summary(&self, content: &str, role: &Role) -> String {
138 let sentences: Vec<&str> = content
139 .split(|c| c == '。' || c == '.' || c == '\n')
140 .filter(|s| !s.trim().is_empty())
141 .collect();
142
143 if sentences.is_empty() {
144 return truncate_to_chars(content, 50);
145 }
146
147 let first_sentence = sentences[0].trim();
149
150 let key_actions = extract_key_actions(content);
152
153 if key_actions.is_empty() {
154 format!("[{}] {}", role_label(role), truncate_to_chars(first_sentence, 40))
155 } else {
156 format!("[{}] {} | {}", role_label(role), truncate_to_chars(first_sentence, 30), key_actions.join(", "))
157 }
158 }
159
160 fn standard_summary(&self, content: &str, role: &Role) -> String {
162 let sentences: Vec<&str> = content
163 .split(|c| c == '。' || c == '.' || c == '\n')
164 .filter(|s| !s.trim().is_empty())
165 .collect();
166
167 if sentences.is_empty() {
168 return truncate_to_chars(content, 100);
169 }
170
171 let mut summary_parts = Vec::new();
173
174 if let Some(first) = sentences.first() {
176 summary_parts.push(first.trim().to_string());
177 }
178
179 if sentences.len() > self.hardcode_config.brief_summary_sentence_count {
181 if let Some(key_sentence) = find_key_sentence(&sentences[1..sentences.len()-1], &self.hardcode_config) {
182 summary_parts.push(key_sentence);
183 }
184 }
185
186 if sentences.len() > self.hardcode_config.min_messages_for_compression {
188 if let Some(last) = sentences.last() {
189 summary_parts.push(last.trim().to_string());
190 }
191 }
192
193 let entities = extract_entities(content, &self.hardcode_config);
195 if !entities.is_empty() {
196 summary_parts.push(format!("[{}]", entities.join(", ")));
197 }
198
199 format!("[{}] {}", role_label(role), summary_parts.join(" | "))
200 }
201
202 fn detailed_summary(&self, content: &str, role: &Role) -> String {
204 let sentences: Vec<&str> = content
205 .split(|c| c == '。' || c == '.' || c == '\n')
206 .filter(|s| !s.trim().is_empty())
207 .collect();
208
209 if sentences.is_empty() {
210 return truncate_to_chars(content, 200);
211 }
212
213 let compressed_sentences: Vec<String> = sentences
215 .iter()
216 .enumerate()
217 .map(|(i, s)| {
218 if i == 0 || i == sentences.len() - 1 {
219 s.trim().to_string()
221 } else {
222 compress_sentence(s.trim(), &self.hardcode_config)
224 }
225 })
226 .collect();
227
228 let mut result = format!("[{}] ", role_label(role));
230 result.push_str(&compressed_sentences.join(" → "));
231
232 if content.contains("```") {
234 let code_blocks = extract_code_blocks(content);
235 if !code_blocks.is_empty() {
236 result.push_str("\n[代码: ");
237 result.push_str(&code_blocks.len().to_string());
238 result.push_str(" 个代码块]");
239 }
240 }
241
242 result
243 }
244
245 pub fn determine_batch_level(&self, messages: &[Message], priorities: &[PriorityScore]) -> SummaryLevel {
247 if messages.is_empty() || priorities.is_empty() {
248 return SummaryLevel::Standard;
249 }
250
251 let priority_scores: Vec<f32> = priorities
253 .iter()
254 .map(|p| p.value())
255 .collect();
256
257 let avg_score: f32 = priority_scores.iter().sum::<f32>() / priority_scores.len() as f32;
258
259 let count_factor = if messages.len() > self.hardcode_config.large_conversation_threshold {
261 0.8 } else if messages.len() > self.hardcode_config.medium_conversation_threshold {
263 0.9
264 } else {
265 1.0
266 };
267
268 let adjusted_score = avg_score * count_factor;
269
270 if adjusted_score >= 0.75 {
271 SummaryLevel::Detailed
272 } else if adjusted_score >= 0.45 {
273 SummaryLevel::Standard
274 } else {
275 SummaryLevel::Brief
276 }
277 }
278
279 pub fn progressive_summarize(&self, messages: &[Message], priorities: &[PriorityScore]) -> Vec<String> {
281 if messages.is_empty() {
282 return Vec::new();
283 }
284
285 let mut summaries = Vec::with_capacity(messages.len());
286 let total = messages.len();
287
288 for (i, (msg, priority)) in messages.iter().zip(priorities.iter()).enumerate() {
289 let base_level = SummaryLevel::from_priority(*priority);
291
292 let level = if self.config.progressive {
293 let age_factor = (total - i) as f32 / total as f32;
294
295 if age_factor > 0.7 {
297 base_level
299 } else if age_factor > 0.4 {
300 compress_level(base_level)
302 } else {
303 compress_level(compress_level(base_level))
305 }
306 } else {
307 base_level
308 };
309
310 summaries.push(self.summarize_message(msg, level));
311 }
312
313 summaries
314 }
315}
316
317fn role_label(role: &Role) -> &'static str {
321 match role {
322 Role::User => "U",
323 Role::Assistant => "A",
324 Role::System => "S",
325 Role::Tool => "T",
326 }
327}
328
/// Returns at most `max_chars` characters (Unicode scalar values, not bytes) of `s`.
///
/// When `s` is longer, the first `max_chars` characters are kept and `"..."` is
/// appended, so the result may be up to `max_chars + 3` characters long.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // Character number `max_chars` (0-based) exists only when `s` is too long;
    // its byte offset is exactly where the cut must happen.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_owned(),
    }
}

/// Returns up to three action keywords (e.g. "创建", "fix") mentioned in `content`.
///
/// Keywords are reported in the fixed order of the keyword table, not in order of
/// appearance, and are matched against the lowercased text. Chinese keywords match
/// as plain substrings. ASCII keywords must start a word (not be preceded by an
/// ASCII letter or digit), so "prefix", "latest" and "research" no longer report
/// "fix", "test" and "search". Suffixed forms such as "fixed" or "added" still count.
fn extract_key_actions(content: &str) -> Vec<String> {
    const ACTION_KEYWORDS: &[&str] = &[
        "创建", "删除", "修改", "更新", "查询", "搜索", "分析", "优化",
        "create", "delete", "update", "query", "search", "analyze", "optimize",
        "fix", "add", "remove", "refactor", "test",
    ];
    const MAX_ACTIONS: usize = 3;

    /// True if `keyword` occurs in `haystack`; ASCII keywords must begin a word.
    fn mentions(haystack: &str, keyword: &str) -> bool {
        if !keyword.is_ascii() {
            // CJK text has no word separators, so substring matching is all we can do.
            return haystack.contains(keyword);
        }
        haystack.match_indices(keyword).any(|(start, _)| {
            haystack[..start]
                .chars()
                .next_back()
                .map_or(true, |prev| !prev.is_ascii_alphanumeric())
        })
    }

    let lower = content.to_lowercase();
    ACTION_KEYWORDS
        .iter()
        .copied()
        .filter(|keyword| mentions(&lower, keyword))
        .take(MAX_ACTIONS)
        .map(String::from)
        .collect()
}

361fn extract_entities(content: &str, config: &HardcodeConfig) -> Vec<String> {
363 let mut entities = Vec::new();
365
366 let in_quotes: Vec<&str> = content
368 .split('"')
369 .enumerate()
370 .filter(|(i, _)| i % 2 == 1)
371 .map(|(_, s)| s)
372 .take(3)
373 .collect();
374
375 for q in in_quotes {
376 if config.is_valid_question_length(q.len()) {
377 entities.push(format!("\"{}\"", truncate_to_chars(q, config.max_question_extract_length)));
378 }
379 }
380
381 entities
382}
383
384fn find_key_sentence(sentences: &[&str], config: &HardcodeConfig) -> Option<String> {
386 let key_terms = ["error", "问题", "result", "结果", "success", "成功", "fail", "失败"];
388
389 sentences
390 .iter()
391 .filter(|s| s.len() > config.min_sentence_length)
392 .max_by(|a, b| {
393 let a_score = key_terms.iter().filter(|t| a.contains(*t)).count();
394 let b_score = key_terms.iter().filter(|t| b.contains(*t)).count();
395 a_score.cmp(&b_score)
396 })
397 .map(|s| s.to_string())
398}
399
400fn compress_sentence(sentence: &str, config: &HardcodeConfig) -> String {
402 let fillers = ["的", "了", "然后", "接着", "因此", "所以", "that", "the", "then", "therefore"];
404
405 let mut compressed = sentence.to_string();
406 for filler in &fillers {
407 if compressed.len() > config.max_compressed_output_length {
408 compressed = compressed.replace(filler, "");
409 }
410 }
411
412 truncate_to_chars(&compressed, config.short_summary_word_count * 5)
413}
414
/// Returns the body of every fenced code block in `content`, one entry per block.
///
/// A fence is any line containing three backticks; fences alternate between
/// opening and closing a block. Each entry is the text between an opening and
/// its closing fence, without trailing newlines (it may be empty). An unterminated
/// final block runs to the end of `content`. Previously this returned the
/// individual code *lines*, so callers counting "code blocks" counted lines.
fn extract_code_blocks(content: &str) -> Vec<&str> {
    let mut blocks = Vec::new();
    // Byte offset where the currently open block's body starts, if one is open.
    let mut body_start: Option<usize> = None;
    let mut offset = 0;

    for line in content.split_inclusive('\n') {
        let line_start = offset;
        offset += line.len();

        if line.contains("```") {
            match body_start.take() {
                Some(start) => blocks.push(
                    content[start..line_start].trim_end_matches(|c: char| c == '\n' || c == '\r'),
                ),
                None => body_start = Some(offset),
            }
        }
    }

    if let Some(start) = body_start {
        blocks.push(&content[start..]);
    }

    blocks
}

431fn compress_level(level: SummaryLevel) -> SummaryLevel {
433 match level {
434 SummaryLevel::Detailed => SummaryLevel::Standard,
435 SummaryLevel::Standard => SummaryLevel::Brief,
436 SummaryLevel::Brief => SummaryLevel::Brief,
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_summary_level_from_priority() {
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.9)), SummaryLevel::Detailed);
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.75)), SummaryLevel::Detailed);
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.5)), SummaryLevel::Standard);
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.3)), SummaryLevel::Brief);
    }

    #[test]
    fn test_retention_ratio() {
        assert!((SummaryLevel::Brief.retention_ratio() - 0.25).abs() < 0.01);
        assert!((SummaryLevel::Standard.retention_ratio() - 0.45).abs() < 0.01);
        assert!((SummaryLevel::Detailed.retention_ratio() - 0.65).abs() < 0.01);
    }

    #[test]
    fn test_brief_summary() {
        let summarizer = HierarchicalSummarizer::default();
        let msg = Message {
            role: Role::User,
            content: MessageContent::Text("我需要创建一个新的API接口来处理用户认证。请帮我实现这个功能。".to_string()),
        };

        let summary = summarizer.summarize_message(&msg, SummaryLevel::Brief);
        assert!(summary.contains("[U]"));
        assert!(summary.len() < 100);
    }

    #[test]
    fn test_standard_summary() {
        let summarizer = HierarchicalSummarizer::default();
        let msg = Message {
            role: Role::Assistant,
            content: MessageContent::Text("好的,我来创建API接口。首先需要设计数据结构。然后实现认证逻辑。最后添加测试用例。".to_string()),
        };

        let summary = summarizer.summarize_message(&msg, SummaryLevel::Standard);
        assert!(summary.contains("[A]"));
        assert!(summary.len() < 200);
    }

    #[test]
    fn test_detailed_summary() {
        let summarizer = HierarchicalSummarizer::default();
        let msg = Message {
            role: Role::Assistant,
            // The third character run of this literal was mis-encoded (U+FFFD) in
            // the original; restored to valid Chinese text.
            content: MessageContent::Text("这是一个详细的实现方案。首先,我们需要考虑性能问题。其次,安全性也很重要。最后,要确保代码可维护性。".to_string()),
        };

        let summary = summarizer.summarize_message(&msg, SummaryLevel::Detailed);
        assert!(summary.contains("[A]"));
        assert!(summary.contains("→"));
    }

    #[test]
    fn test_progressive_summarize() {
        let summarizer = HierarchicalSummarizer::default();
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("第一条消息".to_string()),
            },
            Message {
                role: Role::Assistant,
                content: MessageContent::Text("第二条消息".to_string()),
            },
            Message {
                role: Role::User,
                content: MessageContent::Text("第三条消息".to_string()),
            },
        ];
        let priorities = vec![
            PriorityScore::new(0.3),
            PriorityScore::new(0.5),
            PriorityScore::new(0.8),
        ];

        let summaries = summarizer.progressive_summarize(&messages, &priorities);
        assert_eq!(summaries.len(), 3);
    }

    #[test]
    fn test_truncate_to_chars_counts_chars_not_bytes() {
        assert_eq!(truncate_to_chars("你好世界", 2), "你好...");
        assert_eq!(truncate_to_chars("abc", 3), "abc");
    }

    #[test]
    fn test_compress_level_saturates_at_brief() {
        assert_eq!(compress_level(SummaryLevel::Detailed), SummaryLevel::Standard);
        assert_eq!(compress_level(SummaryLevel::Standard), SummaryLevel::Brief);
        assert_eq!(compress_level(SummaryLevel::Brief), SummaryLevel::Brief);
    }
}