use matrixcode_core::memory::{
ConversationPattern, PatternRegistry, PatternRegistryConfig, PatternSource, PatternType,
};
use matrixcode_core::compress::{CoherenceDetector, HardcodeConfig};
use matrixcode_core::providers::{Message, MessageContent, Role};
use tempfile::tempdir;
use std::fs;
fn create_text_message(text: &str) -> Message {
Message {
role: Role::User,
content: MessageContent::Text(text.to_string()),
}
}
/// Builds a plain-text message sent by `role`.
///
/// `text` is copied into an owned `String` and wrapped in `MessageContent::Text`.
fn create_text_message_with_role(text: &str, role: Role) -> Message {
    let content = MessageContent::Text(text.to_owned());
    Message { role, content }
}
/// Returns the first pattern in `registry` whose `pattern` string equals `pattern_str`
/// exactly, or `None` when no stored pattern matches.
fn find_pattern_by_string<'a>(
    registry: &'a PatternRegistry,
    pattern_str: &str,
) -> Option<&'a ConversationPattern> {
    for candidate in registry.all_patterns().iter() {
        if candidate.pattern == pattern_str {
            return Some(candidate);
        }
    }
    None
}
/// A registry populated with custom patterns reaches the detector intact: the
/// detector reports the same pattern count, and both custom patterns show up among
/// its active reference/code patterns.
#[test]
fn test_registry_to_detector_data_flow() {
    let mut registry = PatternRegistry::new();
    registry.add_pattern(
        ConversationPattern::manual(PatternType::Reference, "custom-ref-pattern")
            .with_description("Custom reference pattern for testing"),
    );
    registry.add_pattern(
        ConversationPattern::manual(PatternType::Code, "custom_code_keyword")
            .with_description("Custom code pattern for testing"),
    );
    let expected_len = registry.len();

    let detector = CoherenceDetector::new_with_registry(0.7, registry);
    let handed_over = detector.pattern_registry();
    assert_eq!(handed_over.len(), expected_len);

    let active_refs = handed_over.get_active_reference_patterns();
    let ref_found = active_refs.iter().any(|r| r.contains("custom-ref-pattern"));
    assert!(ref_found);

    let active_code = handed_over.get_active_code_patterns();
    let code_found = active_code.iter().any(|c| c.contains("custom_code_keyword"));
    assert!(code_found);
}
/// A detector built around a registry that contains a custom reference marker still
/// produces a coherence score inside `[0.0, 1.0]`.
#[test]
fn test_detector_uses_registry_patterns_for_detection() {
    let mut registry = PatternRegistry::new();
    registry.add_pattern(ConversationPattern::manual(
        PatternType::Reference,
        "INTEGRATION_TEST_REF",
    ));
    let detector = CoherenceDetector::new_with_registry(0.5, registry);

    let conversation = vec![
        create_text_message("First message about topic"),
        create_text_message("INTEGRATION_TEST_REF as we discussed"),
    ];
    let score = detector.calculate_coherence(&conversation);
    assert!((0.0..=1.0).contains(&score));
}
/// Patterns added through `pattern_registry_mut()` are immediately visible through
/// the detector's read-only registry view.
#[test]
fn test_registry_changes_affect_detector() {
    let mut detector = CoherenceDetector::new(0.7);
    let count_before = detector.pattern_registry().len();

    detector
        .pattern_registry_mut()
        .add_pattern(ConversationPattern::manual(
            PatternType::Reference,
            "dynamic-added-pattern",
        ));

    let registry_view = detector.pattern_registry();
    assert!(registry_view.len() > count_before);
    let visible = registry_view
        .get_active_reference_patterns()
        .iter()
        .any(|r| r.contains("dynamic-added-pattern"));
    assert!(visible);
}
/// With every pattern deactivated, the detector still scores within `[0.0, 1.0]`, and
/// it reports no active reference or code patterns.
#[test]
fn test_detector_with_empty_active_patterns() {
    let mut registry = PatternRegistry::new();
    // Collect ids up front: deactivation needs `&mut registry`, which cannot overlap
    // the shared borrow taken by `all_patterns()`.
    let every_id: Vec<String> = registry
        .all_patterns()
        .iter()
        .map(|pattern| pattern.id.clone())
        .collect();
    for id in &every_id {
        registry.deactivate_pattern(id);
    }
    let detector = CoherenceDetector::new_with_registry(0.7, registry);

    let conversation = vec![
        create_text_message("Message one"),
        create_text_message("Message two"),
        create_text_message("Message three"),
    ];
    let score = detector.calculate_coherence(&conversation);
    assert!((0.0..=1.0).contains(&score));

    let view = detector.pattern_registry();
    assert!(view.get_active_reference_patterns().is_empty());
    assert!(view.get_active_code_patterns().is_empty());
}
/// Learning extracted patterns never shrinks the registry. Learning the same batch
/// twice bumps the stored entry's frequency.
#[test]
fn test_learn_patterns_from_extraction_result() {
    let mut registry = PatternRegistry::new();
    let initial_count = registry.len();
    let extracted_patterns = vec![
        ConversationPattern::new(
            PatternType::Reference,
            "learn-ref-pattern-unique",
            PatternSource::user_conversation("User said: learn-ref-pattern-unique"),
        ),
        ConversationPattern::new(
            PatternType::Code,
            "learn-code-pattern-unique",
            PatternSource::user_conversation("User code: learn-code-pattern-unique"),
        ),
    ];

    registry.learn_patterns(&extracted_patterns);
    assert!(
        registry.len() >= initial_count,
        "learning must not drop existing patterns"
    );

    // Learn the same batch again so any stored copy has been seen at least twice.
    registry.learn_patterns(&extracted_patterns);
    // NOTE(review): the frequency check only runs if the pattern was stored. Whether a
    // learn pass stores it depends on the registry's learning thresholds.
    if let Some(pattern) = find_pattern_by_string(&registry, "learn-ref-pattern-unique") {
        assert!(
            pattern.frequency >= 2,
            "re-learned pattern frequency should be >= 2, got {}",
            pattern.frequency
        );
    }
}
/// With `auto_learn: false`, `learn_patterns` must leave the registry size unchanged.
#[test]
fn test_auto_learn_disabled() {
    let config = PatternRegistryConfig {
        max_patterns_per_type: 100,
        min_confidence_threshold: 0.3,
        min_frequency: 2,
        auto_learn: false,
        inactive_after_days: 90,
    };
    let mut registry = PatternRegistry::with_config(config);
    let size_before = registry.len();

    let batch = vec![ConversationPattern::new(
        PatternType::Reference,
        "brand-new-pattern-auto-learn-off",
        PatternSource::user_conversation("test"),
    )];
    registry.learn_patterns(&batch);

    assert_eq!(registry.len(), size_before);
}
/// Learning a pattern whose string matches an existing manual pattern keeps the manual
/// entry: its source stays manual and its frequency is bumped.
#[test]
fn test_learning_preserves_existing_patterns() {
    let mut registry = PatternRegistry::new();
    let manual_pattern = ConversationPattern::manual(PatternType::Code, "preserve-test-pattern")
        .with_description("This should be preserved");
    let manual_id = manual_pattern.id.clone();
    registry.add_pattern(manual_pattern);

    let learned_patterns = vec![
        // Same pattern string as the manual entry, but attributed to a conversation.
        ConversationPattern::new(
            PatternType::Code,
            "preserve-test-pattern",
            PatternSource::user_conversation("learned context"),
        ),
        ConversationPattern::new(
            PatternType::Reference,
            "new-learned-pattern",
            PatternSource::user_conversation("another learned context"),
        ),
    ];
    registry.learn_patterns(&learned_patterns);

    let found_pattern = registry
        .get_pattern(&manual_id)
        .expect("manually added pattern should still be present after learning");
    assert!(
        found_pattern.source.is_manual(),
        "learning must not overwrite the source of a manual pattern"
    );
    assert!(
        found_pattern.frequency >= 2,
        "re-learning should bump the manual pattern's frequency, got {}",
        found_pattern.frequency
    );
}
/// End-to-end check: save a customised registry to disk, reload it, and drive a
/// detector with it. Custom patterns must survive the round trip.
#[test]
fn test_full_chain_save_load_detect() {
    let workspace = tempdir().unwrap();
    let patterns_file = workspace.path().join("chain_test_patterns.json");

    let mut source_registry = PatternRegistry::new();
    source_registry.add_patterns(vec![
        ConversationPattern::manual(PatternType::Reference, "chain-ref-pattern")
            .with_description("Chain test reference pattern")
            .with_tag("chain-test"),
        ConversationPattern::manual(PatternType::Code, "chain_code_pattern")
            .with_description("Chain test code pattern")
            .with_tag("chain-test"),
    ]);
    let saved_len = source_registry.len();

    // Round-trip through the filesystem.
    source_registry.save_to_file(&patterns_file).unwrap();
    assert!(patterns_file.exists());
    let restored = PatternRegistry::from_file(&patterns_file).unwrap();
    assert_eq!(restored.len(), saved_len);

    let detector = CoherenceDetector::new_with_registry(0.7, restored);
    let conversation = vec![
        create_text_message("Let's discuss chain-ref-pattern"),
        create_text_message("As chain-ref-pattern mentioned earlier"),
        create_text_message("chain_code_pattern implementation"),
    ];
    let score = detector.calculate_coherence(&conversation);
    assert!((0.0..=1.0).contains(&score));

    let in_detector = detector.pattern_registry();
    for expected in ["chain-ref-pattern", "chain_code_pattern"] {
        assert!(in_detector
            .all_patterns()
            .iter()
            .any(|p| p.pattern == expected));
    }
}
/// Loading a file with invalid JSON does not fail. The resulting registry is non-empty,
/// has reference and code patterns, and can drive a detector.
#[test]
fn test_chain_with_corrupted_file_fallback() {
    let workspace = tempdir().unwrap();
    let broken_file = workspace.path().join("corrupted_patterns.json");
    fs::write(&broken_file, "{ not valid json }").unwrap();

    // `from_file` is expected to return Ok with a usable fallback registry here.
    let fallback = PatternRegistry::from_file(&broken_file).unwrap();
    assert!(!fallback.is_empty());
    assert!(fallback.count_by_type(PatternType::Reference) > 0);
    assert!(fallback.count_by_type(PatternType::Code) > 0);

    let detector = CoherenceDetector::new_with_registry(0.7, fallback);
    let conversation = vec![
        create_text_message("Test message"),
        create_text_message("Another test message"),
    ];
    let score = detector.calculate_coherence(&conversation);
    assert!((0.0..=1.0).contains(&score));
}
/// Loading from a path that does not exist yields a non-empty registry that the
/// detector can use.
#[test]
fn test_chain_with_nonexistent_file() {
    let workspace = tempdir().unwrap();
    let missing_file = workspace.path().join("nonexistent_patterns.json");
    assert!(!missing_file.exists());

    let fallback = PatternRegistry::from_file(&missing_file).unwrap();
    assert!(!fallback.is_empty());

    let detector = CoherenceDetector::new_with_registry(0.7, fallback);
    let conversation = vec![
        create_text_message("As I mentioned before"),
        create_text_message("Following up on previous discussion"),
    ];
    let score = detector.calculate_coherence(&conversation);
    assert!((0.0..=1.0).contains(&score), "Coherence should be valid");
}
/// Checks pruning on three inactive learned patterns.
///
/// - The one below `min_frequency` and older than `inactive_after_days` is removed.
/// - The recent low-frequency one is kept.
/// - The old high-frequency one is kept.
#[test]
fn test_prune_removes_low_frequency_old_patterns() {
    let config = PatternRegistryConfig {
        max_patterns_per_type: 100,
        min_confidence_threshold: 0.5,
        min_frequency: 3,
        auto_learn: true,
        inactive_after_days: 30,
    };
    let mut registry = PatternRegistry::with_config(config);

    // Inactive, below min_frequency and confidence, older than 30 days: expected pruned.
    let mut low_freq_old = ConversationPattern::new(
        PatternType::Reference,
        "low-freq-old",
        PatternSource::user_conversation("test"),
    );
    low_freq_old.frequency = 1;
    low_freq_old.confidence = 0.4;
    low_freq_old.is_active = false;
    low_freq_old.last_used = chrono::Utc::now() - chrono::Duration::days(40);

    // Low frequency but used within the last 30 days: expected kept.
    let mut low_freq_recent = ConversationPattern::new(
        PatternType::Reference,
        "low-freq-recent",
        PatternSource::user_conversation("test"),
    );
    low_freq_recent.frequency = 2;
    low_freq_recent.confidence = 0.4;
    low_freq_recent.is_active = false;
    low_freq_recent.last_used = chrono::Utc::now() - chrono::Duration::days(5);

    // Old, but above min_frequency and confidence threshold: expected kept.
    let mut high_freq_old = ConversationPattern::new(
        PatternType::Reference,
        "high-freq-old",
        PatternSource::user_conversation("test"),
    );
    high_freq_old.frequency = 5;
    high_freq_old.confidence = 0.6;
    high_freq_old.is_active = false;
    high_freq_old.last_used = chrono::Utc::now() - chrono::Duration::days(60);

    registry.add_pattern(low_freq_old);
    registry.add_pattern(low_freq_recent);
    registry.add_pattern(high_freq_old);

    registry.prune();

    assert!(
        find_pattern_by_string(&registry, "low-freq-old").is_none(),
        "Low frequency old pattern should be pruned"
    );
    assert!(
        find_pattern_by_string(&registry, "low-freq-recent").is_some(),
        "Recent pattern should be kept even with low frequency"
    );
    assert!(
        find_pattern_by_string(&registry, "high-freq-old").is_some(),
        "High frequency pattern should be kept"
    );
}
/// Deactivating some preset patterns and pruning must not change the number of
/// preset-sourced patterns.
#[test]
fn test_prune_presets_not_cleaned() {
    let preset_count = |reg: &PatternRegistry| {
        reg.all_patterns()
            .iter()
            .filter(|p| p.source.is_preset())
            .count()
    };

    let mut registry = PatternRegistry::new();
    let presets_before = preset_count(&registry);

    // Deactivate up to three presets so they become candidates for pruning.
    let deactivated: Vec<String> = registry
        .all_patterns()
        .iter()
        .filter(|p| p.source.is_preset())
        .take(3)
        .map(|p| p.id.clone())
        .collect();
    for id in &deactivated {
        registry.deactivate_pattern(id);
    }

    registry.prune();

    let presets_after = preset_count(&registry);
    assert_eq!(presets_before, presets_after, "Presets should never be cleaned");
}
/// An inactive, rarely used, 100-day-old manual pattern must survive pruning. The
/// number of manual patterns stays the same.
#[test]
fn test_prune_manual_patterns_not_cleaned() {
    let manual_count = |reg: &PatternRegistry| {
        reg.all_patterns()
            .iter()
            .filter(|p| p.source.is_manual())
            .count()
    };

    let mut stale_manual = ConversationPattern::manual(PatternType::Code, "old-manual-pattern");
    stale_manual.frequency = 1;
    stale_manual.last_used = chrono::Utc::now() - chrono::Duration::days(100);
    stale_manual.is_active = false;

    let mut registry = PatternRegistry::new();
    registry.add_pattern(stale_manual);
    let before = manual_count(&registry);

    registry.prune();

    let after = manual_count(&registry);
    assert_eq!(
        before, after,
        "Manual patterns should never be cleaned regardless of frequency/age"
    );
}
/// An active pattern is kept by `prune` even when it is low-frequency and 100 days old.
#[test]
fn test_prune_keeps_active_patterns() {
    let mut registry = PatternRegistry::new();
    let mut active_old_low_freq =
        ConversationPattern::manual(PatternType::Reference, "active-old-low");
    active_old_low_freq.frequency = 1;
    active_old_low_freq.last_used = chrono::Utc::now() - chrono::Duration::days(100);
    active_old_low_freq.is_active = true;
    registry.add_pattern(active_old_low_freq);

    registry.prune();

    assert!(
        find_pattern_by_string(&registry, "active-old-low").is_some(),
        "Active patterns should be kept regardless of frequency/age"
    );
}
/// A prunable (learned), inactive, stale pattern whose frequency equals
/// `min_frequency` exactly must survive `prune`.
#[test]
fn test_prune_frequency_threshold_boundary() {
    let config = PatternRegistryConfig {
        min_frequency: 3,
        // Treat a 40-day-old pattern as stale, so the frequency rule decides its fate.
        inactive_after_days: 30,
        ..PatternRegistryConfig::default()
    };
    let mut registry = PatternRegistry::with_config(config);

    // Use a learned source. Manual patterns are never pruned, so a manual pattern
    // would be kept whatever its frequency and the boundary would go untested.
    let mut threshold_pattern = ConversationPattern::new(
        PatternType::Code,
        "threshold-exact",
        PatternSource::user_conversation("test"),
    );
    threshold_pattern.frequency = 3;
    // High confidence, so only the frequency boundary is under test.
    threshold_pattern.confidence = 0.9;
    threshold_pattern.is_active = false;
    threshold_pattern.last_used = chrono::Utc::now() - chrono::Duration::days(40);
    registry.add_pattern(threshold_pattern);

    registry.prune();

    assert!(
        find_pattern_by_string(&registry, "threshold-exact").is_some(),
        "Pattern at exact threshold frequency should be kept"
    );
}
/// Chinese back-references around a shared entity should give a coherence above 0.5.
#[test]
fn test_chinese_reference_patterns_detection() {
    let conversation = vec![
        create_text_message("我们决定使用 PostgreSQL 作为数据库"),
        create_text_message("正如我所说,PostgreSQL 是最佳选择"),
        create_text_message("之前提到过这个问题"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);
    assert!(
        score > 0.5,
        "Expected high coherence for Chinese reference patterns, got {}",
        score
    );
}
/// Chinese text mixed with a fenced Rust code block still scores within `[0.0, 1.0]`.
#[test]
fn test_chinese_code_context() {
    let conversation = vec![
        create_text_message("这是一个 Rust 函数示例:\n```rust\nfn process() {}\n```"),
        create_text_message("这个函数需要优化"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);
    assert!(
        (0.0..=1.0).contains(&score),
        "Coherence should be valid for Chinese with code context, got {}",
        score
    );
}
/// Consecutive Chinese messages about one topic score within `[0.0, 1.0]`.
#[test]
fn test_chinese_topic_continuity() {
    let conversation = vec![
        create_text_message("我们需要优化数据库性能"),
        create_text_message("数据库查询速度太慢了"),
        create_text_message("数据库索引可以解决这个问题"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);
    assert!(
        (0.0..=1.0).contains(&score),
        "Coherence should be valid for Chinese topic continuity, got {}",
        score
    );
}
/// Repeating the same file name (`process.rs`) across Chinese messages should give a
/// coherence above 0.3.
#[test]
fn test_chinese_entity_consistency() {
    let conversation = vec![
        create_text_message("在 process.rs 文件中有一个 bug"),
        create_text_message("process.rs 需要修复"),
        create_text_message("检查 process.rs 的逻辑"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);
    assert!(
        score > 0.3,
        "Expected reasonable coherence for Chinese entity consistency, got {}",
        score
    );
}
/// A dialogue that alternates English and Chinese around "PostgreSQL" scores within
/// `[0.0, 1.0]` and above 0.3.
#[test]
fn test_mixed_chinese_english_dialogue() {
    let conversation = vec![
        create_text_message("We decided to use PostgreSQL for the database"),
        create_text_message("正如前面所说,PostgreSQL 是我们的选择"),
        create_text_message("The PostgreSQL configuration needs adjustment"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);

    assert!((0.0..=1.0).contains(&score));
    assert!(
        score > 0.3,
        "Expected some coherence from shared entities like PostgreSQL, got {}",
        score
    );
}
/// A mixed-language discussion of an async function, one message carrying a
/// TypeScript block, should score above 0.4.
#[test]
fn test_mixed_with_code_blocks() {
    let conversation = vec![
        create_text_message("这是一个 async function:\n```typescript\nasync function fetch() {}\n```"),
        create_text_message("The async function needs error handling"),
        create_text_message("给这个 async 函数添加错误处理"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);
    assert!(
        score > 0.4,
        "Expected coherence from code patterns in mixed dialogue, got {}",
        score
    );
}
/// English and Chinese back-reference phrases in one conversation score within
/// `[0.0, 1.0]`.
#[test]
fn test_multilingual_reference_patterns() {
    let conversation = vec![
        create_text_message("We discussed the architecture earlier"),
        create_text_message("正如我所说,架构设计很重要"),
        create_text_message("As mentioned above, we need to finalize"),
    ];
    let score = CoherenceDetector::default().calculate_coherence(&conversation);
    assert!(
        (0.0..=1.0).contains(&score),
        "Coherence should be valid for multilingual reference patterns, got {}",
        score
    );
}
/// Twenty near-identical messages on one topic must stay in a single segment that
/// holds all of them.
#[test]
fn test_segment_long_coherent_conversation() {
    let detector = CoherenceDetector::default();
    let mut conversation: Vec<Message> = Vec::with_capacity(20);
    for idx in 0..20 {
        conversation.push(create_text_message(&format!(
            "Message {} about database optimization topic",
            idx
        )));
    }

    let segments = detector.segment_messages(&conversation);
    assert_eq!(segments.len(), 1, "Coherent conversation should stay as one segment");
    assert_eq!(segments[0].len(), 20);
}
/// A database-to-frontend topic shift yields at least one segment, and no message is
/// lost across segments.
#[test]
fn test_segment_conversation_with_topic_shift() {
    let detector = CoherenceDetector::default();
    let messages = vec![
        create_text_message("Message 1 about database optimization"),
        create_text_message("Message 2 about database queries"),
        create_text_message("Message 3 about database indexes"),
        create_text_message("Message 4 about frontend React components"),
        create_text_message("Message 5 about frontend styling"),
        create_text_message("Message 6 about frontend animations"),
    ];

    let segments = detector.segment_messages(&messages);
    // `!is_empty()` states the intent directly (clippy::len_zero).
    assert!(!segments.is_empty(), "Should have at least one segment");

    let total = segments.iter().map(|s| s.len()).sum::<usize>();
    assert_eq!(total, messages.len(), "All messages should be in segments");
}
/// Segmenting an empty conversation produces no segments.
#[test]
fn test_segment_empty_conversation() {
    let detector = CoherenceDetector::default();
    let no_messages: Vec<Message> = Vec::new();

    let segments = detector.segment_messages(&no_messages);
    assert!(segments.is_empty(), "Empty conversation should have no segments");
}
/// A single message becomes exactly one segment of length one.
#[test]
fn test_segment_single_message() {
    let detector = CoherenceDetector::default();
    let only_message = vec![create_text_message("Single message")];

    let segments = detector.segment_messages(&only_message);
    assert_eq!(segments.len(), 1);
    assert_eq!(segments[0].len(), 1);
}
/// Code-heavy messages followed by a deployment discussion yield at least one segment,
/// and every message lands in some segment.
#[test]
fn test_segment_with_code_blocks() {
    let detector = CoherenceDetector::default();
    let messages = vec![
        create_text_message("Here's the first function:\n```rust\nfn one() {}\n```"),
        create_text_message("And the second function:\n```rust\nfn two() {}\n```"),
        create_text_message("The third function:\n```rust\nfn three() {}\n```"),
        create_text_message("Now let's discuss the deployment strategy"),
        create_text_message("Deployment to Kubernetes cluster"),
    ];

    let segments = detector.segment_messages(&messages);
    assert!(!segments.is_empty(), "expected at least one segment");

    let total = segments.iter().map(|s| s.len()).sum::<usize>();
    assert_eq!(total, messages.len(), "all messages should be in segments");
}
/// Every segmentation point returned for a three-topic conversation is an interior
/// split index, i.e. in `1..messages.len()`.
#[test]
fn test_find_segmentation_points_long_conversation() {
    let detector = CoherenceDetector::default();
    let messages: Vec<Message> = vec![
        create_text_message("Discussion about API design"),
        create_text_message("API endpoints need authentication"),
        create_text_message("API rate limiting is important"),
        create_text_message("Now moving to database schema"),
        create_text_message("Database tables design"),
        create_text_message("Database migrations strategy"),
        create_text_message("Frontend component structure"),
        create_text_message("React component hierarchy"),
        create_text_message("State management approach"),
    ];

    let points = detector.find_segmentation_points(&messages);
    // The former `points.len() >= 0` check was always true for an unsigned length and
    // asserted nothing, so the bounds check below is the real contract.
    for point in &points {
        assert!(
            (1..messages.len()).contains(point),
            "segmentation point {} must lie strictly inside 0..{}",
            point,
            messages.len()
        );
    }
}
/// A detector configured with `HardcodeConfig::complex_technical()` still scores
/// within `[0.0, 1.0]`.
#[test]
fn test_detector_with_custom_hardcode_config() {
    let detector =
        CoherenceDetector::new(0.7).with_hardcode_config(HardcodeConfig::complex_technical());

    let conversation = vec![
        create_text_message("Implementing async fn process() in Rust"),
        create_text_message("The async fn needs proper error handling"),
    ];
    let score = detector.calculate_coherence(&conversation);
    assert!((0.0..=1.0).contains(&score));
}
/// Detectors built with the simple-conversation and complex-technical hardcode configs
/// both score the same conversation within `[0.0, 1.0]`.
#[test]
fn test_hardcode_config_affects_keyword_extraction() {
    let simple_detector = CoherenceDetector::new(0.7)
        .with_hardcode_config(HardcodeConfig::simple_conversation());
    let complex_detector = CoherenceDetector::new(0.7)
        .with_hardcode_config(HardcodeConfig::complex_technical());

    let conversation = vec![
        create_text_message("Short message about database"),
        create_text_message("Another short database message"),
    ];
    let simple_score = simple_detector.calculate_coherence(&conversation);
    let complex_score = complex_detector.calculate_coherence(&conversation);

    assert!((0.0..=1.0).contains(&simple_score));
    assert!((0.0..=1.0).contains(&complex_score));
}
/// Alternating user/assistant turns on one topic should score above 0.4.
#[test]
fn test_detector_with_alternating_roles() {
    let detector = CoherenceDetector::default();
    let turns = vec![
        ("User question about database", Role::User),
        ("Assistant response about database", Role::Assistant),
        ("User follow-up about database", Role::User),
        ("Assistant clarification about database", Role::Assistant),
    ];
    let conversation: Vec<Message> = turns
        .into_iter()
        .map(|(text, role)| create_text_message_with_role(text, role))
        .collect();

    let score = detector.calculate_coherence(&conversation);
    assert!(score > 0.4, "Topic continuity should work with alternating roles");
}
/// Repeatedly adding and removing a pattern succeeds each time and leaves the
/// registry's pre-existing reference patterns in place.
#[test]
fn test_registry_add_remove_cycles() {
    let mut registry = PatternRegistry::new();
    for round in 0..5 {
        let name = format!("cycle-{}", round);
        let transient = ConversationPattern::manual(PatternType::Code, &name);
        let transient_id = transient.id.clone();

        registry.add_pattern(transient);
        assert!(registry.remove_pattern(&transient_id));
    }

    assert!(!registry.is_empty());
    assert!(registry.count_by_type(PatternType::Reference) > 0);
}
/// With `max_patterns_per_type` set to 10, adding 20 reference patterns never leaves
/// more than 10 reference patterns stored.
#[test]
fn test_large_pattern_count_handling() {
    let capped_config = PatternRegistryConfig::with_max_patterns(10);
    let mut registry = PatternRegistry::with_config(capped_config);

    for n in 0..20 {
        registry.add_pattern(ConversationPattern::manual(
            PatternType::Reference,
            &format!("large-{}", n),
        ));
    }

    assert!(registry.count_by_type(PatternType::Reference) <= 10);
}
/// `mark_used` saturates `frequency` at `u32::MAX`. Starting one below the maximum,
/// two further uses both leave it at `u32::MAX` instead of wrapping.
#[test]
fn test_conversation_pattern_frequency_overflow() {
    let mut pattern = ConversationPattern::manual(PatternType::Code, "overflow-test");
    pattern.frequency = u32::MAX - 1;

    for _ in 0..2 {
        pattern.mark_used();
        assert_eq!(pattern.frequency, u32::MAX);
    }
}
/// Learning new patterns must not reduce the total pattern count in `stats()`.
#[test]
fn test_registry_stats_after_learning() {
    let mut registry = PatternRegistry::new();
    let initial_stats = registry.stats();

    let patterns = vec![
        ConversationPattern::new(
            PatternType::Reference,
            "stats-ref-1",
            PatternSource::user_conversation("test"),
        ),
        ConversationPattern::new(
            PatternType::Code,
            "stats-code-1",
            PatternSource::user_conversation("test"),
        ),
    ];
    registry.learn_patterns(&patterns);

    let new_stats = registry.stats();
    assert!(
        new_stats.total >= initial_stats.total,
        "learning reduced total patterns from {} to {}",
        initial_stats.total,
        new_stats.total
    );
}
/// Deactivating one freshly added (active) pattern moves exactly one pattern from the
/// active count to the inactive count.
#[test]
fn test_registry_stats_with_deactivation() {
    let mut registry = PatternRegistry::new();
    let p1 = ConversationPattern::manual(PatternType::Code, "stats-deact-1");
    let p1_id = p1.id.clone();
    registry.add_pattern(p1);

    let stats_before = registry.stats();
    registry.deactivate_pattern(&p1_id);
    let stats_after = registry.stats();

    // The old `a < b || a == b - 1` form reduced to `a < b`. Pin the exact transition.
    assert_eq!(
        stats_after.active,
        stats_before.active - 1,
        "deactivation should decrement the active count by one"
    );
    assert_eq!(
        stats_after.inactive,
        stats_before.inactive + 1,
        "deactivation should increment the inactive count by one"
    );
}