matrixcode_core/compress/
hierarchical.rs1use crate::providers::{Message, MessageContent, Role};
12use crate::compress::priority::PriorityScore;
13use crate::compress::hardcode_config::HardcodeConfig;
14use std::collections::HashSet;
15
/// How aggressively a message is condensed when it is summarized.
///
/// Variants are listed from most compressed to least compressed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SummaryLevel {
    /// Shortest form: role tag, truncated first sentence and detected action keywords.
    Brief,

    /// Middle form: first sentence plus, depending on configured thresholds, a key
    /// sentence, the last sentence and quoted entities.
    Standard,

    /// Longest form: every sentence is kept, with interior sentences compressed.
    Detailed,
}
31
32impl SummaryLevel {
33 pub fn from_priority(priority: PriorityScore) -> Self {
35 if priority.is_high() {
36 SummaryLevel::Detailed
37 } else if priority.is_medium() {
38 SummaryLevel::Standard
39 } else {
40 SummaryLevel::Brief
41 }
42 }
43
44 pub fn retention_ratio(&self) -> f32 {
46 match self {
47 SummaryLevel::Brief => 0.25, SummaryLevel::Standard => 0.45, SummaryLevel::Detailed => 0.65, }
51 }
52
53 pub fn max_tokens(&self) -> usize {
55 match self {
56 SummaryLevel::Brief => 100,
57 SummaryLevel::Standard => 200,
58 SummaryLevel::Detailed => 350,
59 }
60 }
61}
62
/// Settings for [`HierarchicalSummarizer`].
#[derive(Debug, Clone)]
pub struct HierarchicalConfig {
    /// When `true`, `progressive_summarize` lowers the summary level of messages
    /// according to their position in the batch.
    pub progressive: bool,
    /// NOTE(review): presumably the minimum batch size worth summarizing; not read
    /// anywhere in this file, so confirm against callers.
    pub min_messages: usize,
    /// NOTE(review): presumably the maximum batch size to summarize at once; not read
    /// anywhere in this file, so confirm against callers.
    pub max_messages: usize,
}
73
74impl Default for HierarchicalConfig {
75 fn default() -> Self {
76 Self {
77 progressive: true,
78 min_messages: 10,
79 max_messages: 50,
80 }
81 }
82}
83
84pub struct HierarchicalSummarizer {
86 config: HierarchicalConfig,
87 hardcode_config: HardcodeConfig,
88}
89
90impl Default for HierarchicalSummarizer {
91 fn default() -> Self {
92 Self::new(HierarchicalConfig::default())
93 }
94}
95
96impl HierarchicalSummarizer {
97 pub fn new(config: HierarchicalConfig) -> Self {
98 Self {
99 config,
100 hardcode_config: HardcodeConfig::default(),
101 }
102 }
103
104 pub fn with_hardcode_config(mut self, hardcode_config: HardcodeConfig) -> Self {
106 self.hardcode_config = hardcode_config;
107 self
108 }
109
110 pub fn summarize_message(&self, message: &Message, level: SummaryLevel) -> String {
112 let content = match &message.content {
113 MessageContent::Text(text) => text.clone(),
114 MessageContent::Blocks(blocks) => {
115 blocks.iter()
116 .filter_map(|b| match b {
117 crate::providers::ContentBlock::Text { text } => Some(text.clone()),
118 _ => None,
119 })
120 .collect::<Vec<_>>()
121 .join("\n")
122 }
123 };
124
125 if content.is_empty() {
126 return String::new();
127 }
128
129 match level {
131 SummaryLevel::Brief => self.brief_summary(&content, &message.role),
132 SummaryLevel::Standard => self.standard_summary(&content, &message.role),
133 SummaryLevel::Detailed => self.detailed_summary(&content, &message.role),
134 }
135 }
136
137 fn brief_summary(&self, content: &str, role: &Role) -> String {
139 let sentences: Vec<&str> = content
140 .split(|c| c == '。' || c == '.' || c == '\n')
141 .filter(|s| !s.trim().is_empty())
142 .collect();
143
144 if sentences.is_empty() {
145 return truncate_to_chars(content, 50);
146 }
147
148 let first_sentence = sentences[0].trim();
150
151 let key_actions = extract_key_actions(content);
153
154 if key_actions.is_empty() {
155 format!("[{}] {}", role_label(role), truncate_to_chars(first_sentence, 40))
156 } else {
157 format!("[{}] {} | {}", role_label(role), truncate_to_chars(first_sentence, 30), key_actions.join(", "))
158 }
159 }
160
161 fn standard_summary(&self, content: &str, role: &Role) -> String {
163 let sentences: Vec<&str> = content
164 .split(|c| c == '。' || c == '.' || c == '\n')
165 .filter(|s| !s.trim().is_empty())
166 .collect();
167
168 if sentences.is_empty() {
169 return truncate_to_chars(content, 100);
170 }
171
172 let mut summary_parts = Vec::new();
174
175 if let Some(first) = sentences.first() {
177 summary_parts.push(first.trim().to_string());
178 }
179
180 if sentences.len() > self.hardcode_config.brief_summary_sentence_count {
182 if let Some(key_sentence) = find_key_sentence(&sentences[1..sentences.len()-1], &self.hardcode_config) {
183 summary_parts.push(key_sentence);
184 }
185 }
186
187 if sentences.len() > self.hardcode_config.min_messages_for_compression {
189 if let Some(last) = sentences.last() {
190 summary_parts.push(last.trim().to_string());
191 }
192 }
193
194 let entities = extract_entities(content, &self.hardcode_config);
196 if !entities.is_empty() {
197 summary_parts.push(format!("[{}]", entities.join(", ")));
198 }
199
200 format!("[{}] {}", role_label(role), summary_parts.join(" | "))
201 }
202
203 fn detailed_summary(&self, content: &str, role: &Role) -> String {
205 let sentences: Vec<&str> = content
206 .split(|c| c == '。' || c == '.' || c == '\n')
207 .filter(|s| !s.trim().is_empty())
208 .collect();
209
210 if sentences.is_empty() {
211 return truncate_to_chars(content, 200);
212 }
213
214 let compressed_sentences: Vec<String> = sentences
216 .iter()
217 .enumerate()
218 .map(|(i, s)| {
219 if i == 0 || i == sentences.len() - 1 {
220 s.trim().to_string()
222 } else {
223 compress_sentence(s.trim(), &self.hardcode_config)
225 }
226 })
227 .collect();
228
229 let mut result = format!("[{}] ", role_label(role));
231 result.push_str(&compressed_sentences.join(" → "));
232
233 if content.contains("```") {
235 let code_blocks = extract_code_blocks(content);
236 if !code_blocks.is_empty() {
237 result.push_str("\n[代码: ");
238 result.push_str(&code_blocks.len().to_string());
239 result.push_str(" 个代码块]");
240 }
241 }
242
243 result
244 }
245
246 pub fn determine_batch_level(&self, messages: &[Message], priorities: &[PriorityScore]) -> SummaryLevel {
248 if messages.is_empty() || priorities.is_empty() {
249 return SummaryLevel::Standard;
250 }
251
252 let priority_scores: Vec<f32> = priorities
254 .iter()
255 .map(|p| p.value())
256 .collect();
257
258 let avg_score: f32 = priority_scores.iter().sum::<f32>() / priority_scores.len() as f32;
259
260 let count_factor = if messages.len() > self.hardcode_config.large_conversation_threshold {
262 0.8 } else if messages.len() > self.hardcode_config.medium_conversation_threshold {
264 0.9
265 } else {
266 1.0
267 };
268
269 let adjusted_score = avg_score * count_factor;
270
271 if adjusted_score >= 0.75 {
272 SummaryLevel::Detailed
273 } else if adjusted_score >= 0.45 {
274 SummaryLevel::Standard
275 } else {
276 SummaryLevel::Brief
277 }
278 }
279
280 pub fn progressive_summarize(&self, messages: &[Message], priorities: &[PriorityScore]) -> Vec<String> {
282 if messages.is_empty() {
283 return Vec::new();
284 }
285
286 let mut summaries = Vec::with_capacity(messages.len());
287 let total = messages.len();
288
289 for (i, (msg, priority)) in messages.iter().zip(priorities.iter()).enumerate() {
290 let base_level = SummaryLevel::from_priority(*priority);
292
293 let level = if self.config.progressive {
294 let age_factor = (total - i) as f32 / total as f32;
295
296 if age_factor > 0.7 {
298 base_level
300 } else if age_factor > 0.4 {
301 compress_level(base_level)
303 } else {
304 compress_level(compress_level(base_level))
306 }
307 } else {
308 base_level
309 };
310
311 summaries.push(self.summarize_message(msg, level));
312 }
313
314 summaries
315 }
316}
317
318fn role_label(role: &Role) -> &'static str {
322 match role {
323 Role::User => "U",
324 Role::Assistant => "A",
325 Role::System => "S",
326 Role::Tool => "T",
327 }
328}
329
/// Limits `s` to `max_chars` characters (Unicode scalar values, not bytes).
///
/// A string that already fits is returned unchanged; otherwise the first `max_chars`
/// characters are returned followed by `"..."`.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // The character at position `max_chars` exists only when the string is too long,
    // and its byte offset is exactly where the cut must happen.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
338
/// Returns up to three action keywords found in `content`.
///
/// Matching is a substring search on the lowercased content, so a keyword also matches
/// inside a longer word. Results follow the order of the keyword table, not the order in
/// which they appear in the text.
fn extract_key_actions(content: &str) -> Vec<String> {
    const ACTION_KEYWORDS: &[&str] = &[
        "创建", "删除", "修改", "更新", "查询", "搜索", "分析", "优化",
        "create", "delete", "update", "query", "search", "analyze", "optimize",
        "fix", "add", "remove", "refactor", "test",
    ];
    const MAX_ACTIONS: usize = 3;

    let haystack = content.to_lowercase();
    ACTION_KEYWORDS
        .iter()
        .filter(|keyword| haystack.contains(**keyword))
        .take(MAX_ACTIONS)
        .map(|keyword| keyword.to_string())
        .collect()
}
361
362fn extract_entities(content: &str, config: &HardcodeConfig) -> Vec<String> {
364 let mut entities = Vec::new();
366
367 let in_quotes: Vec<&str> = content
369 .split('"')
370 .enumerate()
371 .filter(|(i, _)| i % 2 == 1)
372 .map(|(_, s)| s)
373 .take(3)
374 .collect();
375
376 for q in in_quotes {
377 if config.is_valid_question_length(q.len()) {
378 entities.push(format!("\"{}\"", truncate_to_chars(q, config.max_question_extract_length)));
379 }
380 }
381
382 entities
383}
384
385fn find_key_sentence(sentences: &[&str], config: &HardcodeConfig) -> Option<String> {
387 let key_terms = ["error", "问题", "result", "结果", "success", "成功", "fail", "失败"];
389
390 sentences
391 .iter()
392 .filter(|s| s.len() > config.min_sentence_length)
393 .max_by(|a, b| {
394 let a_score = key_terms.iter().filter(|t| a.contains(*t)).count();
395 let b_score = key_terms.iter().filter(|t| b.contains(*t)).count();
396 a_score.cmp(&b_score)
397 })
398 .map(|s| s.to_string())
399}
400
401fn compress_sentence(sentence: &str, config: &HardcodeConfig) -> String {
403 let fillers = ["的", "了", "然后", "接着", "因此", "所以", "that", "the", "then", "therefore"];
405
406 let mut compressed = sentence.to_string();
407 for filler in &fillers {
408 if compressed.len() > config.max_compressed_output_length {
409 compressed = compressed.replace(filler, "");
410 }
411 }
412
413 truncate_to_chars(&compressed, config.short_summary_word_count * 5)
414}
415
/// Returns the body of each fenced code block in `content`, one entry per block.
///
/// Any line containing "```" toggles between inside and outside a block; fence lines
/// themselves are not part of the body. Each body is the text between the two fences with
/// trailing line breaks stripped. A block with no lines at all is skipped, and an
/// unterminated block runs to the end of `content`.
///
/// Returning one entry per block (instead of one per line inside a block) makes `len()`
/// the number of code blocks.
fn extract_code_blocks(content: &str) -> Vec<&str> {
    let mut blocks = Vec::new();
    // Byte offset at which the body of the currently open block starts, if any.
    let mut body_start: Option<usize> = None;
    let mut offset = 0;

    for line in content.split_inclusive('\n') {
        let line_start = offset;
        offset += line.len();

        if !line.contains("```") {
            continue;
        }

        match body_start.take() {
            // Closing fence: the body ends where this line begins.
            Some(start) => {
                if start < line_start {
                    blocks.push(content[start..line_start].trim_end_matches(&['\r', '\n'][..]));
                }
            }
            // Opening fence: the body begins on the following line.
            None => body_start = Some(offset),
        }
    }

    // An unterminated fence still yields whatever follows it.
    if let Some(start) = body_start {
        if start < content.len() {
            blocks.push(content[start..].trim_end_matches(&['\r', '\n'][..]));
        }
    }

    blocks
}
431
432fn compress_level(level: SummaryLevel) -> SummaryLevel {
434 match level {
435 SummaryLevel::Detailed => SummaryLevel::Standard,
436 SummaryLevel::Standard => SummaryLevel::Brief,
437 SummaryLevel::Brief => SummaryLevel::Brief,
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message for the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_summary_level_from_priority() {
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.9)), SummaryLevel::Detailed);
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.75)), SummaryLevel::Detailed);
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.5)), SummaryLevel::Standard);
        assert_eq!(SummaryLevel::from_priority(PriorityScore::new(0.3)), SummaryLevel::Brief);
    }

    #[test]
    fn test_retention_ratio() {
        assert!((SummaryLevel::Brief.retention_ratio() - 0.25).abs() < 0.01);
        assert!((SummaryLevel::Standard.retention_ratio() - 0.45).abs() < 0.01);
        assert!((SummaryLevel::Detailed.retention_ratio() - 0.65).abs() < 0.01);
    }

    #[test]
    fn test_brief_summary() {
        let summarizer = HierarchicalSummarizer::default();
        let msg = text_message(
            Role::User,
            "我需要创建一个新的API接口来处理用户认证。请帮我实现这个功能。",
        );

        let summary = summarizer.summarize_message(&msg, SummaryLevel::Brief);
        assert!(summary.contains("[U]"));
        // Byte length, not character count.
        assert!(summary.len() < 100);
    }

    #[test]
    fn test_standard_summary() {
        let summarizer = HierarchicalSummarizer::default();
        let msg = text_message(
            Role::Assistant,
            "好的,我来创建API接口。首先需要设计数据结构。然后实现认证逻辑。最后添加测试用例。",
        );

        let summary = summarizer.summarize_message(&msg, SummaryLevel::Standard);
        assert!(summary.contains("[A]"));
        // Byte length, not character count.
        assert!(summary.len() < 200);
    }

    #[test]
    fn test_detailed_summary() {
        let summarizer = HierarchicalSummarizer::default();
        let msg = text_message(
            Role::Assistant,
            "这是一个详细的实现方案。首先,我们需要考虑性能问题。其次,安全性也很重要。最后,要确保代码可维护性。",
        );

        let summary = summarizer.summarize_message(&msg, SummaryLevel::Detailed);
        assert!(summary.contains("[A]"));
        assert!(summary.contains("→"));
    }

    #[test]
    fn test_progressive_summarize() {
        let summarizer = HierarchicalSummarizer::default();
        let messages = vec![
            text_message(Role::User, "第一条消息"),
            text_message(Role::Assistant, "第二条消息"),
            text_message(Role::User, "第三条消息"),
        ];
        let priorities = vec![
            PriorityScore::new(0.3),
            PriorityScore::new(0.5),
            PriorityScore::new(0.8),
        ];

        let summaries = summarizer.progressive_summarize(&messages, &priorities);
        assert_eq!(summaries.len(), 3);
    }
}