Skip to main content

matrixcode_core/compress/
coherence.rs

//! Semantic coherence detection: preserve conversation continuity.
//!
//! This module analyzes conversation segments to decide whether messages
//! should be compressed together or kept separate, so that semantically
//! related messages stay coherent.
6
7use crate::providers::Message;
8use crate::compress::hardcode_config::HardcodeConfig;
9use crate::memory::PatternRegistry;
10use std::collections::HashSet;
11
12/// Coherence detector for conversation segments.
13#[derive(Debug, Clone)]
14pub struct CoherenceDetector {
15    /// Minimum coherence score to keep messages together
16    threshold: f32,
17    /// Pattern registry for dynamic pattern loading (replaces hardcoded patterns)
18    pattern_registry: PatternRegistry,
19    /// Hardcode configuration
20    hardcode_config: HardcodeConfig,
21}
22
23impl Default for CoherenceDetector {
24    fn default() -> Self {
25        Self {
26            threshold: 0.7,
27            pattern_registry: PatternRegistry::new(),
28            hardcode_config: HardcodeConfig::default(),
29        }
30    }
31}
32
33impl CoherenceDetector {
34    /// Create a new coherence detector with default settings.
35    ///
36    /// This is backward compatible - automatically loads preset patterns.
37    pub fn new(threshold: f32) -> Self {
38        Self {
39            threshold,
40            ..Default::default()
41        }
42    }
43
44    /// Create a coherence detector with a custom pattern registry.
45    ///
46    /// Use this when you need to customize patterns or load from a file.
47    pub fn new_with_registry(threshold: f32, registry: PatternRegistry) -> Self {
48        Self {
49            threshold,
50            pattern_registry: registry,
51            hardcode_config: HardcodeConfig::default(),
52        }
53    }
54
55    /// Create a coherence detector with custom hardcode config.
56    pub fn with_hardcode_config(mut self, config: HardcodeConfig) -> Self {
57        self.hardcode_config = config;
58        self
59    }
60
61    /// Get a reference to the pattern registry.
62    pub fn pattern_registry(&self) -> &PatternRegistry {
63        &self.pattern_registry
64    }
65
66    /// Get a mutable reference to the pattern registry.
67    pub fn pattern_registry_mut(&mut self) -> &mut PatternRegistry {
68        &mut self.pattern_registry
69    }
70
71    /// Check if a group of messages should be kept together.
72    pub fn should_keep_together(&self, messages: &[Message]) -> bool {
73        if messages.len() < 2 {
74            return true;
75        }
76
77        let coherence_score = self.calculate_coherence(messages);
78        coherence_score >= self.threshold
79    }
80
81    /// Calculate coherence score for a message group.
82    pub fn calculate_coherence(&self, messages: &[Message]) -> f32 {
83        if messages.len() < 2 {
84            return 1.0;
85        }
86
87        // Weighted sum of scores (weights sum to 1.0, no need to divide)
88        let topic_score = self.check_topic_continuity(messages);
89        let reference_score = self.check_reference_patterns(messages);
90        let code_score = self.check_code_context(messages);
91        let entity_score = self.check_entity_consistency(messages);
92
93        // Weighted average: weights sum to 1.0
94        topic_score * 0.3 + reference_score * 0.25 + code_score * 0.25 + entity_score * 0.2
95    }
96
97    /// Check topic continuity across messages.
98    fn check_topic_continuity(&self, messages: &[Message]) -> f32 {
99        let topics: Vec<HashSet<String>> = messages
100            .iter()
101            .map(|m| self.extract_topic_keywords(&self.get_message_content(m)))
102            .collect();
103
104        if topics.len() < 2 {
105            return 1.0;
106        }
107
108        // Calculate overlap between consecutive messages
109        let mut overlap_scores = Vec::new();
110        for i in 1..topics.len() {
111            let overlap = self.calculate_set_overlap(&topics[i - 1], &topics[i]);
112            overlap_scores.push(overlap);
113        }
114
115        // Average overlap
116        if !overlap_scores.is_empty() {
117            overlap_scores.iter().sum::<f32>() / overlap_scores.len() as f32
118        } else {
119            0.5
120        }
121    }
122
123    /// Check if messages reference each other.
124    fn check_reference_patterns(&self, messages: &[Message]) -> f32 {
125        let patterns = self.pattern_registry.get_active_reference_patterns();
126        if patterns.is_empty() {
127            return 0.5; // Neutral if no patterns available
128        }
129
130        let mut has_references = false;
131
132        for i in 1..messages.len() {
133            let content_lower = self.get_message_content_lower(&messages[i]);
134
135            for pattern in &patterns {
136                // Try regex match first (case-insensitive), fallback to simple contains for non-regex patterns
137                if let Ok(re) = regex::Regex::new(&format!("(?i){}", pattern)) {
138                    if re.is_match(&content_lower) {
139                        has_references = true;
140                        break;
141                    }
142                } else {
143                    // Fallback to simple contains for invalid regex patterns
144                    if content_lower.contains(pattern.to_lowercase().as_str()) {
145                        has_references = true;
146                        break;
147                    }
148                }
149            }
150        }
151
152        if has_references {
153            1.0 // High coherence if messages reference each other
154        } else {
155            0.5 // Neutral
156        }
157    }
158
159    /// Check if messages contain code that should stay together.
160    fn check_code_context(&self, messages: &[Message]) -> f32 {
161        let patterns = self.pattern_registry.get_active_code_patterns();
162        if patterns.is_empty() {
163            return 0.5; // Neutral if no patterns available
164        }
165
166        let mut code_messages = Vec::new();
167
168        for (i, msg) in messages.iter().enumerate() {
169            let content_lower = self.get_message_content_lower(msg);
170
171            for pattern in &patterns {
172                // Try regex match first (case-insensitive), fallback to simple contains for non-regex patterns
173                if let Ok(re) = regex::Regex::new(&format!("(?i){}", pattern)) {
174                    if re.is_match(&content_lower) {
175                        code_messages.push(i);
176                        break;
177                    }
178                } else {
179                    // Fallback to simple contains for invalid regex patterns
180                    if content_lower.contains(pattern.to_lowercase().as_str()) {
181                        code_messages.push(i);
182                        break;
183                    }
184                }
185            }
186        }
187
188        if code_messages.is_empty() {
189            return 0.5; // Neutral if no code
190        }
191
192        // Check if code messages are consecutive or close
193        let mut consecutive_score = 0.0;
194        for i in 1..code_messages.len() {
195            let distance = code_messages[i] - code_messages[i - 1];
196            if distance <= 2 {
197                consecutive_score += 1.0;
198            } else if distance <= 4 {
199                consecutive_score += 0.5;
200            }
201        }
202
203        if !code_messages.is_empty() && consecutive_score > 0.0 {
204            consecutive_score / (code_messages.len() - 1).max(1) as f32
205        } else {
206            0.5
207        }
208    }
209
210    /// Check entity consistency (files, functions, modules).
211    fn check_entity_consistency(&self, messages: &[Message]) -> f32 {
212        let entities: Vec<HashSet<String>> = messages
213            .iter()
214            .map(|m| self.extract_entities(&self.get_message_content(m)))
215            .collect();
216
217        if entities.len() < 2 {
218            return 1.0;
219        }
220
221        // Find entities that appear in multiple messages
222        let mut common_entities = HashSet::new();
223        let all_entities: HashSet<String> = entities
224            .iter()
225            .flat_map(|e| e.iter().cloned())
226            .collect();
227
228        for entity in &all_entities {
229            let count = entities.iter().filter(|e| e.contains(entity)).count();
230            if count >= 2 {
231                common_entities.insert(entity.clone());
232            }
233        }
234
235        // Score based on number of common entities
236        if common_entities.is_empty() {
237            0.3 // Low coherence if no common entities
238        } else if common_entities.len() >= 3 {
239            1.0 // High coherence
240        } else {
241            0.7 // Medium coherence
242        }
243    }
244
245    /// Extract topic keywords from message content.
246    fn extract_topic_keywords(&self, content: &str) -> HashSet<String> {
247        let content_lower = content.to_lowercase();
248        let words = content_lower
249            .split_whitespace()
250            .filter(|w| w.len() > self.hardcode_config.min_word_length) // Skip short words
251            .take(20) // Limit to 20 keywords
252            .map(|w| w.to_string())
253            .collect();
254        words
255    }
256
257    /// Extract entities (file names, function names) from content.
258    fn extract_entities(&self, content: &str) -> HashSet<String> {
259        let mut entities = HashSet::new();
260
261        // Pattern: file.rs, module.ts, etc.
262        let file_pattern = regex::Regex::new(r"\b[\w]+\.[\w]{2,4}\b").unwrap();
263        for cap in file_pattern.find_iter(content) {
264            entities.insert(cap.as_str().to_string());
265        }
266
267        // Pattern: function_name, ClassName, etc.
268        let name_pattern = regex::Regex::new(r"\b[A-Z][a-zA-Z]+\b|\b[a-z_][a-z0-9_]{3,}\b").unwrap();
269        for cap in name_pattern.find_iter(content) {
270            let name = cap.as_str();
271            // Skip common words
272            if !["true", "false", "null", "some", "none", "this", "that", "here", "there"].contains(&name.to_lowercase().as_str()) {
273                entities.insert(name.to_string());
274            }
275        }
276
277        entities
278    }
279
280    /// Calculate overlap between two sets.
281    fn calculate_set_overlap<T: std::hash::Hash + Eq>(&self, set1: &HashSet<T>, set2: &HashSet<T>) -> f32 {
282        if set1.is_empty() || set2.is_empty() {
283            return 0.0;
284        }
285
286        let intersection = set1.intersection(set2).count();
287        let union = set1.union(set2).count();
288
289        if union > 0 {
290            intersection as f32 / union as f32
291        } else {
292            0.0
293        }
294    }
295
296    /// Get message content as string.
297    fn get_message_content(&self, message: &Message) -> String {
298        match &message.content {
299            crate::providers::MessageContent::Text(text) => text.clone(),
300            crate::providers::MessageContent::Blocks(blocks) => {
301                blocks
302                    .iter()
303                    .filter_map(|b| {
304                        if let crate::providers::ContentBlock::Text { text } = b {
305                            Some(text.clone())
306                        } else {
307                            None
308                        }
309                    })
310                    .collect::<Vec<_>>()
311                    .join(" ")
312            }
313        }
314    }
315
316    /// Get message content as lowercase string.
317    fn get_message_content_lower(&self, message: &Message) -> String {
318        match &message.content {
319            crate::providers::MessageContent::Text(text) => text.to_lowercase(),
320            crate::providers::MessageContent::Blocks(blocks) => {
321                blocks
322                    .iter()
323                    .filter_map(|b| {
324                        if let crate::providers::ContentBlock::Text { text } = b {
325                            Some(text.clone())
326                        } else {
327                            None
328                        }
329                    })
330                    .collect::<Vec<_>>()
331                    .join(" ")
332                    .to_lowercase()
333            }
334        }
335    }
336
337    /// Find optimal segmentation points in messages.
338    pub fn find_segmentation_points(&self, messages: &[Message]) -> Vec<usize> {
339        if messages.len() < 3 {
340            return vec![];
341        }
342
343        let mut points = Vec::new();
344
345        // Check coherence between consecutive pairs
346        for i in 1..messages.len() - 1 {
347            let before = &messages[0..i];
348            let after = &messages[i..messages.len()];
349
350            // Calculate coherence drop at this point
351            let coherence_before = self.calculate_coherence(before);
352            let coherence_after = self.calculate_coherence(after);
353            let coherence_cross = self.calculate_coherence(&[messages[i - 1].clone(), messages[i].clone()]);
354
355            // If cross coherence is significantly lower, mark as segmentation point
356            if coherence_cross < coherence_before * 0.7 && coherence_cross < coherence_after * 0.7 {
357                points.push(i);
358            }
359        }
360
361        points
362    }
363
364    /// Segment messages into coherent groups.
365    pub fn segment_messages(&self, messages: &[Message]) -> Vec<Vec<Message>> {
366        if messages.is_empty() {
367            return vec![];
368        }
369
370        let points = self.find_segmentation_points(messages);
371
372        if points.is_empty() {
373            return vec![messages.to_vec()];
374        }
375
376        let mut segments = Vec::new();
377        let mut start = 0;
378
379        for point in points {
380            if point > start {
381                segments.push(messages[start..point].to_vec());
382                start = point;
383            }
384        }
385
386        if start < messages.len() {
387            segments.push(messages[start..messages.len()].to_vec());
388        }
389
390        segments
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{ConversationPattern, PatternType};
    use crate::providers::{Message, MessageContent, Role};

    /// Build a plain-text user message.
    fn create_text_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Build a detector (threshold 0.7) whose registry holds exactly one
    /// manual pattern of the given kind.
    fn detector_with_pattern(kind: PatternType, pattern: &'static str) -> CoherenceDetector {
        let mut registry = PatternRegistry::new();
        registry.add_pattern(ConversationPattern::manual(kind, pattern));
        CoherenceDetector::new_with_registry(0.7, registry)
    }

    #[test]
    fn test_coherence_detector_creation() {
        let detector = CoherenceDetector::default();
        assert_eq!(detector.threshold, 0.7);
    }

    #[test]
    fn test_single_message_coherence() {
        let detector = CoherenceDetector::default();
        let messages = vec![create_text_message("test")];
        assert!(detector.should_keep_together(&messages));
    }

    #[test]
    fn test_topic_continuity() {
        let detector = CoherenceDetector::default();
        let messages = vec![
            create_text_message("We need to optimize database performance"),
            create_text_message("The database queries are slow"),
            create_text_message("Let's add database indexes"),
        ];

        let score = detector.calculate_coherence(&messages);
        // Topic continuity with database keywords should give reasonable score
        assert!(score > 0.4, "Expected coherence > 0.4 for topic continuity, got {}", score);
    }

    #[test]
    fn test_reference_patterns() {
        let detector = detector_with_pattern(PatternType::Reference, "as i mentioned");
        let messages = vec![
            create_text_message("We decided to use PostgreSQL"),
            create_text_message("As I mentioned, PostgreSQL is the choice"),
        ];

        let score = detector.calculate_coherence(&messages);
        // A reference match sets the reference signal to 1.0.
        assert!(score > 0.5, "Expected coherence > 0.5 for reference patterns, got {}", score);
    }

    #[test]
    fn test_code_context() {
        let detector = detector_with_pattern(PatternType::Code, "fn ");
        let messages = vec![
            create_text_message("Here's the function:\n```rust\nfn test() {}\n```"),
            create_text_message("This function needs optimization"),
        ];

        let score = detector.calculate_coherence(&messages);
        // Only the first message contains "fn ", so the code signal stays at the
        // neutral 0.5; the rest of the score comes from topic and entity overlap.
        assert!(score >= 0.35, "Expected coherence >= 0.35 for code context, got {}", score);
    }

    #[test]
    fn test_segmentation() {
        let detector = CoherenceDetector::default();
        let messages = vec![
            create_text_message("Topic A: database optimization"),
            create_text_message("More about database"),
            create_text_message("Topic B: frontend design"),
            create_text_message("More about frontend"),
        ];

        let segments = detector.segment_messages(&messages);
        assert!(!segments.is_empty());
        // Segmentation only splits: no message is dropped or duplicated, and
        // no segment is empty.
        let total: usize = segments.iter().map(Vec::len).sum();
        assert_eq!(total, messages.len());
        assert!(segments.iter().all(|segment| !segment.is_empty()));
    }

    #[test]
    fn test_segmentation_points_are_interior() {
        let detector = CoherenceDetector::default();
        let messages = vec![
            create_text_message("Topic A: database optimization"),
            create_text_message("More about database"),
            create_text_message("Topic B: frontend design"),
            create_text_message("More about frontend"),
            create_text_message("Weather forecast tomorrow"),
        ];

        let points = detector.find_segmentation_points(&messages);
        assert!(points.iter().all(|&p| p >= 1 && p < messages.len()));
        assert!(points.windows(2).all(|pair| pair[0] < pair[1]));
    }

    // =========================================================================
    // Pattern Registry Integration Tests
    // =========================================================================

    #[test]
    fn test_new_with_registry() {
        let registry = PatternRegistry::new();
        let detector = CoherenceDetector::new_with_registry(0.8, registry);

        assert_eq!(detector.threshold, 0.8);
        assert!(detector.pattern_registry().is_empty());
    }

    #[test]
    fn test_backward_compatible_new() {
        // new() works with an empty registry
        let detector = CoherenceDetector::new(0.7);

        assert_eq!(detector.threshold, 0.7);
        // PatternRegistry::new() is empty; no presets are loaded
        assert!(detector.pattern_registry().is_empty());

        let ref_patterns = detector.pattern_registry().get_active_reference_patterns();
        let code_patterns = detector.pattern_registry().get_active_code_patterns();

        assert!(ref_patterns.is_empty(), "Reference patterns should be empty");
        assert!(code_patterns.is_empty(), "Code patterns should be empty");
    }

    #[test]
    fn test_default_uses_pattern_registry() {
        let detector = CoherenceDetector::default();

        // Default has an empty registry (no presets)
        let ref_patterns = detector.pattern_registry().get_active_reference_patterns();
        let code_patterns = detector.pattern_registry().get_active_code_patterns();

        assert!(ref_patterns.is_empty());
        assert!(code_patterns.is_empty());
    }

    #[test]
    fn test_pattern_registry_accessor() {
        let detector = CoherenceDetector::default();

        let registry = detector.pattern_registry();
        assert!(registry.is_empty());

        // No patterns loaded
        assert_eq!(registry.count_by_type(PatternType::Reference), 0);
        assert_eq!(registry.count_by_type(PatternType::Code), 0);
    }

    #[test]
    fn test_pattern_registry_mut_accessor() {
        let mut detector = CoherenceDetector::default();
        assert!(detector.pattern_registry().is_empty());

        let pattern = ConversationPattern::manual(PatternType::Code, "test-pattern-mut");
        detector.pattern_registry_mut().add_pattern(pattern);

        assert_eq!(detector.pattern_registry().len(), 1);
    }

    #[test]
    fn test_with_hardcode_config() {
        let config = HardcodeConfig::complex_technical();
        let detector = CoherenceDetector::new(0.7).with_hardcode_config(config.clone());

        assert_eq!(detector.hardcode_config.min_word_length, config.min_word_length);
    }

    #[test]
    fn test_reference_patterns_from_registry() {
        let detector = detector_with_pattern(PatternType::Reference, "as i mentioned");

        let messages = vec![
            create_text_message("Let's implement feature X"),
            create_text_message("As I mentioned earlier, feature X is important"),
        ];

        let score = detector.calculate_coherence(&messages);
        // High coherence due to the reference pattern match
        assert!(score > 0.5, "Expected coherence > 0.5 for reference patterns, got {}", score);
    }

    #[test]
    fn test_code_patterns_from_registry() {
        let detector = detector_with_pattern(PatternType::Code, "fn ");

        let messages = vec![
            create_text_message("Here is a function:\n```rust\nfn example() {}\n```"),
            create_text_message("This function does something"),
        ];

        let score = detector.calculate_coherence(&messages);
        // A single code-bearing message keeps the code signal neutral (0.5).
        assert!(score >= 0.35, "Expected coherence >= 0.35 for code patterns, got {}", score);
    }

    #[test]
    fn test_empty_registry_graceful_handling() {
        let registry = PatternRegistry::new();
        assert!(registry.is_empty());

        let detector = CoherenceDetector::new_with_registry(0.7, registry);

        let messages = vec![
            create_text_message("Message one"),
            create_text_message("Message two"),
        ];

        // Must not panic; pattern-based signals fall back to neutral 0.5.
        let score = detector.calculate_coherence(&messages);
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_chinese_reference_patterns() {
        let detector = detector_with_pattern(PatternType::Reference, "正如我所说");

        let messages = vec![
            create_text_message("我们决定使用 PostgreSQL"),
            create_text_message("正如我所说,PostgreSQL 是最佳选择"),
        ];

        let score = detector.calculate_coherence(&messages);
        // Detects the Chinese reference pattern "正如我所说" ("as I said")
        assert!(score > 0.5, "Expected coherence > 0.5 for Chinese reference patterns, got {}", score);
    }

    // =========================================================================
    // Additional Coverage Tests
    // =========================================================================

    #[test]
    fn test_empty_messages() {
        let detector = CoherenceDetector::default();

        // Empty input is kept together
        let messages: Vec<Message> = vec![];
        assert!(detector.should_keep_together(&messages));

        // and scores 1.0
        let score = detector.calculate_coherence(&messages);
        assert!((score - 1.0).abs() < 0.001, "Expected coherence 1.0 for empty messages, got {}", score);
    }

    #[test]
    fn test_find_segmentation_points_empty() {
        let detector = CoherenceDetector::default();

        let messages: Vec<Message> = vec![];
        assert!(detector.find_segmentation_points(&messages).is_empty());

        let single = vec![create_text_message("single message")];
        assert!(detector.find_segmentation_points(&single).is_empty());

        let two = vec![
            create_text_message("first"),
            create_text_message("second"),
        ];
        assert!(detector.find_segmentation_points(&two).is_empty());
    }

    #[test]
    fn test_segment_messages_empty() {
        let detector = CoherenceDetector::default();

        let messages: Vec<Message> = vec![];
        let segments = detector.segment_messages(&messages);
        assert!(segments.is_empty());
    }

    #[test]
    fn test_regex_pattern_matching() {
        // "之前的?讨论" ("previous discussion", optional 的) matches
        // "之前的讨论" only when treated as a regex; as a literal substring
        // it would not match.
        let messages = vec![
            create_text_message("We discussed the architecture"),
            create_text_message("之前的讨论很有价值"), // Chinese: "the previous discussion was valuable"
        ];

        let baseline = CoherenceDetector::default().calculate_coherence(&messages);
        let score = detector_with_pattern(PatternType::Reference, "之前的?讨论")
            .calculate_coherence(&messages);

        assert!((0.0..=1.0).contains(&score), "Score should be valid, got {}", score);
        // Only the reference signal differs from the empty-registry baseline.
        assert!(
            score > baseline,
            "Expected regex reference match to raise coherence ({} vs baseline {})",
            score,
            baseline
        );
    }

    #[test]
    fn test_simple_pattern_matching() {
        let detector = detector_with_pattern(PatternType::Reference, "as mentioned");

        let messages = vec![
            create_text_message("Let's implement feature A"),
            create_text_message("As mentioned above, feature A is important"),
        ];

        let score = detector.calculate_coherence(&messages);
        // Detects the "as mentioned" reference pattern
        assert!(score > 0.5, "Expected high coherence for reference pattern match, got {}", score);
    }

    #[test]
    fn test_low_coherence_messages() {
        let detector = CoherenceDetector::new(0.7);

        // No shared words or entities and no patterns registered
        let messages = vec![
            create_text_message("The quick brown fox jumps"),
            create_text_message("Database optimization strategies"),
            create_text_message("Weather forecast tomorrow"),
        ];

        let score = detector.calculate_coherence(&messages);
        assert!((0.0..=1.0).contains(&score), "Score should be valid, got {}", score);
        // Topic 0.0, neutral reference/code 0.5, no common entities 0.3
        // => 0.31, well below the 0.7 threshold.
        assert!(score < 0.7, "Expected low coherence for unrelated messages, got {}", score);
        assert!(!detector.should_keep_together(&messages));
    }

    #[test]
    fn test_entity_consistency() {
        let detector = CoherenceDetector::default();

        // Messages sharing the entity "process.rs"
        let messages = vec![
            create_text_message("In process.rs we have a bug"),
            create_text_message("The process.rs file needs fixing"),
            create_text_message("Let me check process.rs again"),
        ];

        let score = detector.calculate_coherence(&messages);
        // The entity signal is only 20% of the total, so just check the range.
        assert!(score >= 0.3 && score <= 1.0, "Expected valid coherence score, got {}", score);

        // These still share the identifier "topic" (messages 1 and 2), so they
        // are not entity-free. Only check that the score is valid.
        let unrelated_messages = vec![
            create_text_message("First topic discussion"),
            create_text_message("Second unrelated topic"),
            create_text_message("Third different subject"),
        ];
        let unrelated_score = detector.calculate_coherence(&unrelated_messages);
        assert!((0.0..=1.0).contains(&unrelated_score));
    }

    #[test]
    fn test_custom_registry_affects_detection() {
        let mut registry = PatternRegistry::new();
        let initial_count = registry.len();

        registry.add_pattern(ConversationPattern::manual(
            PatternType::Reference,
            "custom_reference_pattern_xyz",
        ));

        assert!(registry.len() > initial_count);

        let detector = CoherenceDetector::new_with_registry(0.7, registry);

        // The custom pattern is exposed as an active reference pattern
        let ref_patterns = detector.pattern_registry().get_active_reference_patterns();
        assert!(ref_patterns.iter().any(|p| p.contains("custom_reference_pattern_xyz")));
    }

    #[test]
    fn test_multiple_code_blocks() {
        // Three adjacent messages that all contain code
        let messages = vec![
            create_text_message("Here is the first function:\n```rust\nfn one() {}\n```"),
            create_text_message("And the second:\n```rust\nfn two() {}\n```"),
            create_text_message("The third function:\n```rust\nfn three() {}\n```"),
        ];

        let baseline = CoherenceDetector::default().calculate_coherence(&messages);
        let score = detector_with_pattern(PatternType::Code, "fn ").calculate_coherence(&messages);

        assert!((0.0..=1.0).contains(&score));
        // Consecutive code-bearing messages raise the code signal from the
        // neutral 0.5 to 1.0.
        assert!(
            score > baseline,
            "Expected adjacent code to raise coherence ({} vs baseline {})",
            score,
            baseline
        );
    }

    #[test]
    fn test_threshold_affects_should_keep_together() {
        let low_threshold = CoherenceDetector::new(0.3);
        let high_threshold = CoherenceDetector::new(0.9);

        let messages = vec![
            create_text_message("Topic A discussion"),
            create_text_message("Related to topic A"),
        ];

        // should_keep_together is exactly "score >= threshold"
        let low_score = low_threshold.calculate_coherence(&messages);
        let high_score = high_threshold.calculate_coherence(&messages);
        assert_eq!(low_threshold.should_keep_together(&messages), low_score >= 0.3);
        assert_eq!(high_threshold.should_keep_together(&messages), high_score >= 0.9);

        // These related messages clear the low threshold
        assert!(low_threshold.should_keep_together(&messages));
    }
}