matrixcode_core/compress/
coherence.rs1use crate::providers::Message;
8use crate::compress::hardcode_config::HardcodeConfig;
9use crate::memory::PatternRegistry;
10use std::collections::HashSet;
11
12#[derive(Debug, Clone)]
14pub struct CoherenceDetector {
15 threshold: f32,
17 pattern_registry: PatternRegistry,
19 hardcode_config: HardcodeConfig,
21}
22
23impl Default for CoherenceDetector {
24 fn default() -> Self {
25 Self {
26 threshold: 0.7,
27 pattern_registry: PatternRegistry::new(),
28 hardcode_config: HardcodeConfig::default(),
29 }
30 }
31}
32
33impl CoherenceDetector {
34 pub fn new(threshold: f32) -> Self {
38 Self {
39 threshold,
40 ..Default::default()
41 }
42 }
43
44 pub fn new_with_registry(threshold: f32, registry: PatternRegistry) -> Self {
48 Self {
49 threshold,
50 pattern_registry: registry,
51 hardcode_config: HardcodeConfig::default(),
52 }
53 }
54
55 pub fn with_hardcode_config(mut self, config: HardcodeConfig) -> Self {
57 self.hardcode_config = config;
58 self
59 }
60
61 pub fn pattern_registry(&self) -> &PatternRegistry {
63 &self.pattern_registry
64 }
65
66 pub fn pattern_registry_mut(&mut self) -> &mut PatternRegistry {
68 &mut self.pattern_registry
69 }
70
71 pub fn should_keep_together(&self, messages: &[Message]) -> bool {
73 if messages.len() < 2 {
74 return true;
75 }
76
77 let coherence_score = self.calculate_coherence(messages);
78 coherence_score >= self.threshold
79 }
80
81 pub fn calculate_coherence(&self, messages: &[Message]) -> f32 {
83 if messages.len() < 2 {
84 return 1.0;
85 }
86
87 let topic_score = self.check_topic_continuity(messages);
89 let reference_score = self.check_reference_patterns(messages);
90 let code_score = self.check_code_context(messages);
91 let entity_score = self.check_entity_consistency(messages);
92
93 topic_score * 0.3 + reference_score * 0.25 + code_score * 0.25 + entity_score * 0.2
95 }
96
97 fn check_topic_continuity(&self, messages: &[Message]) -> f32 {
99 let topics: Vec<HashSet<String>> = messages
100 .iter()
101 .map(|m| self.extract_topic_keywords(&self.get_message_content(m)))
102 .collect();
103
104 if topics.len() < 2 {
105 return 1.0;
106 }
107
108 let mut overlap_scores = Vec::new();
110 for i in 1..topics.len() {
111 let overlap = self.calculate_set_overlap(&topics[i - 1], &topics[i]);
112 overlap_scores.push(overlap);
113 }
114
115 if !overlap_scores.is_empty() {
117 overlap_scores.iter().sum::<f32>() / overlap_scores.len() as f32
118 } else {
119 0.5
120 }
121 }
122
123 fn check_reference_patterns(&self, messages: &[Message]) -> f32 {
125 let patterns = self.pattern_registry.get_active_reference_patterns();
126 if patterns.is_empty() {
127 return 0.5; }
129
130 let mut has_references = false;
131
132 for i in 1..messages.len() {
133 let content_lower = self.get_message_content_lower(&messages[i]);
134
135 for pattern in &patterns {
136 if let Ok(re) = regex::Regex::new(&format!("(?i){}", pattern)) {
138 if re.is_match(&content_lower) {
139 has_references = true;
140 break;
141 }
142 } else {
143 if content_lower.contains(pattern.to_lowercase().as_str()) {
145 has_references = true;
146 break;
147 }
148 }
149 }
150 }
151
152 if has_references {
153 1.0 } else {
155 0.5 }
157 }
158
159 fn check_code_context(&self, messages: &[Message]) -> f32 {
161 let patterns = self.pattern_registry.get_active_code_patterns();
162 if patterns.is_empty() {
163 return 0.5; }
165
166 let mut code_messages = Vec::new();
167
168 for (i, msg) in messages.iter().enumerate() {
169 let content_lower = self.get_message_content_lower(msg);
170
171 for pattern in &patterns {
172 if let Ok(re) = regex::Regex::new(&format!("(?i){}", pattern)) {
174 if re.is_match(&content_lower) {
175 code_messages.push(i);
176 break;
177 }
178 } else {
179 if content_lower.contains(pattern.to_lowercase().as_str()) {
181 code_messages.push(i);
182 break;
183 }
184 }
185 }
186 }
187
188 if code_messages.is_empty() {
189 return 0.5; }
191
192 let mut consecutive_score = 0.0;
194 for i in 1..code_messages.len() {
195 let distance = code_messages[i] - code_messages[i - 1];
196 if distance <= 2 {
197 consecutive_score += 1.0;
198 } else if distance <= 4 {
199 consecutive_score += 0.5;
200 }
201 }
202
203 if !code_messages.is_empty() && consecutive_score > 0.0 {
204 consecutive_score / (code_messages.len() - 1).max(1) as f32
205 } else {
206 0.5
207 }
208 }
209
210 fn check_entity_consistency(&self, messages: &[Message]) -> f32 {
212 let entities: Vec<HashSet<String>> = messages
213 .iter()
214 .map(|m| self.extract_entities(&self.get_message_content(m)))
215 .collect();
216
217 if entities.len() < 2 {
218 return 1.0;
219 }
220
221 let mut common_entities = HashSet::new();
223 let all_entities: HashSet<String> = entities
224 .iter()
225 .flat_map(|e| e.iter().cloned())
226 .collect();
227
228 for entity in &all_entities {
229 let count = entities.iter().filter(|e| e.contains(entity)).count();
230 if count >= 2 {
231 common_entities.insert(entity.clone());
232 }
233 }
234
235 if common_entities.is_empty() {
237 0.3 } else if common_entities.len() >= 3 {
239 1.0 } else {
241 0.7 }
243 }
244
245 fn extract_topic_keywords(&self, content: &str) -> HashSet<String> {
247 let content_lower = content.to_lowercase();
248 let words = content_lower
249 .split_whitespace()
250 .filter(|w| w.len() > self.hardcode_config.min_word_length) .take(20) .map(|w| w.to_string())
253 .collect();
254 words
255 }
256
257 fn extract_entities(&self, content: &str) -> HashSet<String> {
259 let mut entities = HashSet::new();
260
261 let file_pattern = regex::Regex::new(r"\b[\w]+\.[\w]{2,4}\b").unwrap();
263 for cap in file_pattern.find_iter(content) {
264 entities.insert(cap.as_str().to_string());
265 }
266
267 let name_pattern = regex::Regex::new(r"\b[A-Z][a-zA-Z]+\b|\b[a-z_][a-z0-9_]{3,}\b").unwrap();
269 for cap in name_pattern.find_iter(content) {
270 let name = cap.as_str();
271 if !["true", "false", "null", "some", "none", "this", "that", "here", "there"].contains(&name.to_lowercase().as_str()) {
273 entities.insert(name.to_string());
274 }
275 }
276
277 entities
278 }
279
280 fn calculate_set_overlap<T: std::hash::Hash + Eq>(&self, set1: &HashSet<T>, set2: &HashSet<T>) -> f32 {
282 if set1.is_empty() || set2.is_empty() {
283 return 0.0;
284 }
285
286 let intersection = set1.intersection(set2).count();
287 let union = set1.union(set2).count();
288
289 if union > 0 {
290 intersection as f32 / union as f32
291 } else {
292 0.0
293 }
294 }
295
296 fn get_message_content(&self, message: &Message) -> String {
298 match &message.content {
299 crate::providers::MessageContent::Text(text) => text.clone(),
300 crate::providers::MessageContent::Blocks(blocks) => {
301 blocks
302 .iter()
303 .filter_map(|b| {
304 if let crate::providers::ContentBlock::Text { text } = b {
305 Some(text.clone())
306 } else {
307 None
308 }
309 })
310 .collect::<Vec<_>>()
311 .join(" ")
312 }
313 }
314 }
315
316 fn get_message_content_lower(&self, message: &Message) -> String {
318 match &message.content {
319 crate::providers::MessageContent::Text(text) => text.to_lowercase(),
320 crate::providers::MessageContent::Blocks(blocks) => {
321 blocks
322 .iter()
323 .filter_map(|b| {
324 if let crate::providers::ContentBlock::Text { text } = b {
325 Some(text.clone())
326 } else {
327 None
328 }
329 })
330 .collect::<Vec<_>>()
331 .join(" ")
332 .to_lowercase()
333 }
334 }
335 }
336
337 pub fn find_segmentation_points(&self, messages: &[Message]) -> Vec<usize> {
339 if messages.len() < 3 {
340 return vec![];
341 }
342
343 let mut points = Vec::new();
344
345 for i in 1..messages.len() - 1 {
347 let before = &messages[0..i];
348 let after = &messages[i..messages.len()];
349
350 let coherence_before = self.calculate_coherence(before);
352 let coherence_after = self.calculate_coherence(after);
353 let coherence_cross = self.calculate_coherence(&[messages[i - 1].clone(), messages[i].clone()]);
354
355 if coherence_cross < coherence_before * 0.7 && coherence_cross < coherence_after * 0.7 {
357 points.push(i);
358 }
359 }
360
361 points
362 }
363
364 pub fn segment_messages(&self, messages: &[Message]) -> Vec<Vec<Message>> {
366 if messages.is_empty() {
367 return vec![];
368 }
369
370 let points = self.find_segmentation_points(messages);
371
372 if points.is_empty() {
373 return vec![messages.to_vec()];
374 }
375
376 let mut segments = Vec::new();
377 let mut start = 0;
378
379 for point in points {
380 if point > start {
381 segments.push(messages[start..point].to_vec());
382 start = point;
383 }
384 }
385
386 if start < messages.len() {
387 segments.push(messages[start..messages.len()].to_vec());
388 }
389
390 segments
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{ConversationPattern, PatternType};
    use crate::providers::{Message, MessageContent, Role};

    /// Wraps `text` in a user-role plain-text message.
    fn create_text_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Builds a detector (threshold 0.7) whose registry holds exactly one
    /// manually added pattern of the given kind.
    fn detector_with_pattern(kind: PatternType, pattern: &'static str) -> CoherenceDetector {
        let mut registry = PatternRegistry::new();
        registry.add_pattern(ConversationPattern::manual(kind, pattern));
        CoherenceDetector::new_with_registry(0.7, registry)
    }

    #[test]
    fn test_coherence_detector_creation() {
        let detector = CoherenceDetector::default();
        assert_eq!(detector.threshold, 0.7);
    }

    #[test]
    fn test_single_message_coherence() {
        let detector = CoherenceDetector::default();
        let messages = vec![create_text_message("test")];
        assert!(detector.should_keep_together(&messages));
    }

    #[test]
    fn test_topic_continuity() {
        let detector = CoherenceDetector::default();
        let messages = vec![
            create_text_message("We need to optimize database performance"),
            create_text_message("The database queries are slow"),
            create_text_message("Let's add database indexes"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score > 0.4, "Expected coherence > 0.4 for topic continuity, got {}", score);
    }

    #[test]
    fn test_reference_patterns() {
        let detector = detector_with_pattern(PatternType::Reference, "as i mentioned");
        let messages = vec![
            create_text_message("We decided to use PostgreSQL"),
            create_text_message("As I mentioned, PostgreSQL is the choice"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score > 0.5, "Expected coherence > 0.5 for reference patterns, got {}", score);
    }

    #[test]
    fn test_code_context() {
        let detector = detector_with_pattern(PatternType::Code, "fn ");
        let messages = vec![
            create_text_message("Here's the function:\n```rust\nfn test() {}\n```"),
            create_text_message("This function needs optimization"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score >= 0.35, "Expected coherence >= 0.35 for code context, got {}", score);
    }

    #[test]
    fn test_segmentation() {
        let detector = CoherenceDetector::default();
        let messages = vec![
            create_text_message("Topic A: database optimization"),
            create_text_message("More about database"),
            create_text_message("Topic B: frontend design"),
            create_text_message("More about frontend"),
        ];

        let segments = detector.segment_messages(&messages);
        assert!(!segments.is_empty());
    }

    #[test]
    fn test_new_with_registry() {
        let registry = PatternRegistry::new();
        let detector = CoherenceDetector::new_with_registry(0.8, registry);

        assert_eq!(detector.threshold, 0.8);
        assert!(detector.pattern_registry().is_empty());
    }

    #[test]
    fn test_backward_compatible_new() {
        let detector = CoherenceDetector::new(0.7);

        assert_eq!(detector.threshold, 0.7);
        assert!(detector.pattern_registry().is_empty());

        let ref_patterns = detector.pattern_registry().get_active_reference_patterns();
        let code_patterns = detector.pattern_registry().get_active_code_patterns();

        assert!(ref_patterns.is_empty(), "Reference patterns should be empty");
        assert!(code_patterns.is_empty(), "Code patterns should be empty");
    }

    #[test]
    fn test_default_uses_pattern_registry() {
        let detector = CoherenceDetector::default();

        let ref_patterns = detector.pattern_registry().get_active_reference_patterns();
        let code_patterns = detector.pattern_registry().get_active_code_patterns();

        assert!(ref_patterns.is_empty());
        assert!(code_patterns.is_empty());
    }

    #[test]
    fn test_pattern_registry_accessor() {
        let detector = CoherenceDetector::default();

        let registry = detector.pattern_registry();
        assert!(registry.is_empty());

        assert_eq!(registry.count_by_type(PatternType::Reference), 0);
        assert_eq!(registry.count_by_type(PatternType::Code), 0);
    }

    #[test]
    fn test_pattern_registry_mut_accessor() {
        let mut detector = CoherenceDetector::default();
        assert!(detector.pattern_registry().is_empty());

        let pattern = ConversationPattern::manual(PatternType::Code, "test-pattern-mut");
        detector.pattern_registry_mut().add_pattern(pattern);

        assert_eq!(detector.pattern_registry().len(), 1);
    }

    #[test]
    fn test_with_hardcode_config() {
        let config = HardcodeConfig::complex_technical();
        let detector = CoherenceDetector::new(0.7).with_hardcode_config(config.clone());

        assert_eq!(detector.hardcode_config.min_word_length, config.min_word_length);
    }

    #[test]
    fn test_reference_patterns_from_registry() {
        let detector = detector_with_pattern(PatternType::Reference, "as i mentioned");

        let messages = vec![
            create_text_message("Let's implement feature X"),
            create_text_message("As I mentioned earlier, feature X is important"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score > 0.5, "Expected coherence > 0.5 for reference patterns, got {}", score);
    }

    #[test]
    fn test_code_patterns_from_registry() {
        let detector = detector_with_pattern(PatternType::Code, "fn ");

        let messages = vec![
            create_text_message("Here is a function:\n```rust\nfn example() {}\n```"),
            create_text_message("This function does something"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score >= 0.35, "Expected coherence >= 0.35 for code patterns, got {}", score);
    }

    #[test]
    fn test_empty_registry_graceful_handling() {
        let registry = PatternRegistry::new();
        assert!(registry.is_empty());

        let detector = CoherenceDetector::new_with_registry(0.7, registry);

        let messages = vec![
            create_text_message("Message one"),
            create_text_message("Message two"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_chinese_reference_patterns() {
        let detector = detector_with_pattern(PatternType::Reference, "正如我所说");

        let messages = vec![
            create_text_message("我们决定使用 PostgreSQL"),
            create_text_message("正如我所说,PostgreSQL 是最佳选择"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score > 0.5, "Expected coherence > 0.5 for Chinese reference patterns, got {}", score);
    }

    #[test]
    fn test_empty_messages() {
        let detector = CoherenceDetector::default();

        let messages: Vec<Message> = vec![];
        assert!(detector.should_keep_together(&messages));

        let score = detector.calculate_coherence(&messages);
        assert!((score - 1.0).abs() < 0.001, "Expected coherence 1.0 for empty messages, got {}", score);
    }

    #[test]
    fn test_find_segmentation_points_empty() {
        let detector = CoherenceDetector::default();

        let messages: Vec<Message> = vec![];
        let points = detector.find_segmentation_points(&messages);
        assert!(points.is_empty());

        let single = vec![create_text_message("single message")];
        let points = detector.find_segmentation_points(&single);
        assert!(points.is_empty());

        let two = vec![
            create_text_message("first"),
            create_text_message("second"),
        ];
        let points = detector.find_segmentation_points(&two);
        assert!(points.is_empty());
    }

    #[test]
    fn test_segment_messages_empty() {
        let detector = CoherenceDetector::default();

        let messages: Vec<Message> = vec![];
        let segments = detector.segment_messages(&messages);
        assert!(segments.is_empty());
    }

    // NOTE: the default detector has an empty registry, so this only checks
    // that mixed English/Chinese input yields a score in the valid range.
    #[test]
    fn test_regex_pattern_matching() {
        let detector = CoherenceDetector::default();

        let messages = vec![
            create_text_message("We discussed the architecture"),
            create_text_message("之前的讨论很有价值"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!((0.0..=1.0).contains(&score), "Score should be valid, got {}", score);
    }

    #[test]
    fn test_simple_pattern_matching() {
        let detector = detector_with_pattern(PatternType::Reference, "as mentioned");

        let messages = vec![
            create_text_message("Let's implement feature A"),
            create_text_message("As mentioned above, feature A is important"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!(score > 0.5, "Expected high coherence for reference pattern match, got {}", score);
    }

    #[test]
    fn test_low_coherence_messages() {
        let detector = CoherenceDetector::new(0.7);

        let messages = vec![
            create_text_message("The quick brown fox jumps"),
            create_text_message("Database optimization strategies"),
            create_text_message("Weather forecast tomorrow"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!((0.0..=1.0).contains(&score), "Score should be valid, got {}", score);
    }

    #[test]
    fn test_entity_consistency() {
        let detector = CoherenceDetector::default();

        let messages = vec![
            create_text_message("In process.rs we have a bug"),
            create_text_message("The process.rs file needs fixing"),
            create_text_message("Let me check process.rs again"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!((0.3..=1.0).contains(&score), "Expected valid coherence score, got {}", score);

        let no_entity_messages = vec![
            create_text_message("First topic discussion"),
            create_text_message("Second unrelated topic"),
            create_text_message("Third different subject"),
        ];
        let no_entity_score = detector.calculate_coherence(&no_entity_messages);

        assert!((0.0..=1.0).contains(&no_entity_score));
    }

    #[test]
    fn test_custom_registry_affects_detection() {
        let mut registry = PatternRegistry::new();
        let initial_count = registry.len();

        let custom_pattern =
            ConversationPattern::manual(PatternType::Reference, "custom_reference_pattern_xyz");
        registry.add_pattern(custom_pattern);

        assert!(registry.len() > initial_count);

        let detector = CoherenceDetector::new_with_registry(0.7, registry);

        let ref_patterns = detector.pattern_registry().get_active_reference_patterns();
        assert!(ref_patterns.iter().any(|p| p.contains("custom_reference_pattern_xyz")));
    }

    // NOTE: the default detector has no code patterns registered, so this
    // only checks that the score stays in the valid range.
    #[test]
    fn test_multiple_code_blocks() {
        let detector = CoherenceDetector::default();

        let messages = vec![
            create_text_message("Here is the first function:\n```rust\nfn one() {}\n```"),
            create_text_message("And the second:\n```rust\nfn two() {}\n```"),
            create_text_message("The third function:\n```rust\nfn three() {}\n```"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_threshold_affects_should_keep_together() {
        let low_threshold = CoherenceDetector::new(0.3);
        let high_threshold = CoherenceDetector::new(0.9);

        let messages = vec![
            create_text_message("Topic A discussion"),
            create_text_message("Related to topic A"),
        ];

        assert!(low_threshold.should_keep_together(&messages));

        // With an empty registry the reference and code signals are both
        // fixed at 0.5, so the score can never exceed
        // 0.3 + 0.125 + 0.125 + 0.2 = 0.75, which is below the 0.9 threshold.
        let score = high_threshold.calculate_coherence(&messages);
        assert!(score < 0.9, "Expected coherence below 0.9, got {}", score);
        assert!(!high_threshold.should_keep_together(&messages));
    }
}