1use super::query_core::{QueryCore, QuestionType};
44use brainwires_core::confidence::ResponseConfidence;
45use brainwires_core::graph::EntityType;
46use brainwires_tool_runtime::{ToolErrorCategory, ToolOutcome};
47use chrono::Utc;
48use std::collections::HashMap;
49
50#[derive(Debug, Clone)]
52pub struct TrackedEntity {
53 pub name: String,
55 pub entity_type: EntityType,
57 pub mention_turns: Vec<u32>,
59 pub was_queried: bool,
61 pub was_modified: bool,
63 pub discovered_relations: Vec<(String, String)>, }
66
67impl TrackedEntity {
68 pub fn new(name: String, entity_type: EntityType, turn: u32) -> Self {
70 Self {
71 name,
72 entity_type,
73 mention_turns: vec![turn],
74 was_queried: false,
75 was_modified: false,
76 discovered_relations: Vec::new(),
77 }
78 }
79
80 pub fn record_mention(&mut self, turn: u32) {
82 if !self.mention_turns.contains(&turn) {
83 self.mention_turns.push(turn);
84 }
85 }
86
87 pub fn frequency(&self) -> usize {
89 self.mention_turns.len()
90 }
91}
92
/// One entry in the coreference log kept by [`LocalMemory`].
///
/// Entries are appended by [`LocalMemory::record_coreference`].
#[derive(Debug, Clone)]
pub struct CoreferenceRecord {
    /// The reference text that was resolved (presumably a pronoun or alias — confirm with callers).
    pub reference: String,
    /// The name the reference was resolved to.
    pub resolved_to: String,
    /// Confidence supplied by the caller; its range is not enforced here
    /// (presumably 0.0–1.0 — confirm).
    pub confidence: f32,
    /// Value of `LocalMemory::current_turn` when the entry was appended.
    pub turn: u32,
    /// Presumably `Some(true)`/`Some(false)` once the resolution is confirmed/rejected.
    /// `record_coreference` always stores `None`; nothing in this module updates it.
    pub confirmed: Option<bool>,
}
107
/// One executed query, as appended by [`LocalMemory::record_query`].
#[derive(Debug, Clone)]
pub struct QueryRecord {
    /// The query text as originally supplied.
    pub original: String,
    /// The query text after reference resolution. `LearningCoordinator::record_outcome`
    /// falls back to `original` when no resolved form exists.
    pub resolved: String,
    /// Classified question type of the query.
    pub question_type: QuestionType,
    /// S-expression form of the query, if one was recorded.
    pub query_sexp: Option<String>,
    /// Value of `LocalMemory::current_turn` when the record was appended.
    pub turn: u32,
    /// Whether the query was reported as successful.
    pub success: bool,
    /// Number of results reported for the query.
    pub result_count: usize,
    /// Reported execution time in milliseconds.
    pub execution_time_ms: u64,
}
128
/// Memory scoped to a single conversation.
///
/// Holds tracked entities, coreference and query logs, and a most-recent-first focus stack.
#[derive(Debug)]
pub struct LocalMemory {
    /// Identifier of the conversation this memory belongs to.
    pub conversation_id: String,
    /// Tracked entities keyed by entity name.
    pub entities: HashMap<String, TrackedEntity>,
    /// Coreference resolutions in the order `record_coreference` appended them.
    pub coreference_log: Vec<CoreferenceRecord>,
    /// Executed queries in the order `record_query` appended them.
    pub query_history: Vec<QueryRecord>,
    /// Entity names, most recently tracked first; `track_entity` keeps at most 20.
    pub focus_stack: Vec<String>,
    /// Current turn number; `new` starts it at 0.
    pub current_turn: u32,
}
145
146impl LocalMemory {
147 pub fn new(conversation_id: String) -> Self {
149 Self {
150 conversation_id,
151 entities: HashMap::new(),
152 coreference_log: Vec::new(),
153 query_history: Vec::new(),
154 focus_stack: Vec::new(),
155 current_turn: 0,
156 }
157 }
158
159 pub fn next_turn(&mut self) {
161 self.current_turn += 1;
162 }
163
164 pub fn track_entity(&mut self, name: &str, entity_type: EntityType) {
166 if let Some(entity) = self.entities.get_mut(name) {
167 entity.record_mention(self.current_turn);
168 } else {
169 self.entities.insert(
170 name.to_string(),
171 TrackedEntity::new(name.to_string(), entity_type, self.current_turn),
172 );
173 }
174
175 self.focus_stack.retain(|n| n != name);
177 self.focus_stack.insert(0, name.to_string());
178 if self.focus_stack.len() > 20 {
179 self.focus_stack.truncate(20);
180 }
181 }
182
183 pub fn record_coreference(&mut self, reference: &str, resolved_to: &str, confidence: f32) {
185 self.coreference_log.push(CoreferenceRecord {
186 reference: reference.to_string(),
187 resolved_to: resolved_to.to_string(),
188 confidence,
189 turn: self.current_turn,
190 confirmed: None,
191 });
192 }
193
194 #[allow(clippy::too_many_arguments)]
196 pub fn record_query(
197 &mut self,
198 original: &str,
199 resolved: &str,
200 question_type: QuestionType,
201 query_sexp: Option<String>,
202 success: bool,
203 result_count: usize,
204 execution_time_ms: u64,
205 ) {
206 self.query_history.push(QueryRecord {
207 original: original.to_string(),
208 resolved: resolved.to_string(),
209 question_type,
210 query_sexp,
211 turn: self.current_turn,
212 success,
213 result_count,
214 execution_time_ms,
215 });
216 }
217
218 pub fn get_frequent_entities(&self, limit: usize) -> Vec<&TrackedEntity> {
220 let mut entities: Vec<_> = self.entities.values().collect();
221 entities.sort_by_key(|e| std::cmp::Reverse(e.frequency()));
222 entities.into_iter().take(limit).collect()
223 }
224
225 pub fn get_recent_coreferences(&self, count: usize) -> Vec<&CoreferenceRecord> {
227 self.coreference_log.iter().rev().take(count).collect()
228 }
229
230 pub fn get_success_rate(&self, question_type: &QuestionType) -> f32 {
232 let relevant: Vec<_> = self
233 .query_history
234 .iter()
235 .filter(|q| &q.question_type == question_type)
236 .collect();
237
238 if relevant.is_empty() {
239 return 0.5; }
241
242 let successes = relevant.iter().filter(|q| q.success).count();
243 successes as f32 / relevant.len() as f32
244 }
245}
246
247#[derive(Debug, Clone)]
249pub struct QueryPattern {
250 pub id: String,
252 pub question_type: QuestionType,
254 pub template: String,
256 pub required_types: Vec<EntityType>,
258 pub success_count: u32,
260 pub failure_count: u32,
262 pub avg_results: f32,
264 pub created_at: i64,
266 pub last_used_at: i64,
268}
269
270impl QueryPattern {
271 pub fn new(
273 question_type: QuestionType,
274 template: String,
275 required_types: Vec<EntityType>,
276 ) -> Self {
277 let now = Utc::now().timestamp();
278 Self {
279 id: uuid::Uuid::new_v4().to_string(),
280 question_type,
281 template,
282 required_types,
283 success_count: 0,
284 failure_count: 0,
285 avg_results: 0.0,
286 created_at: now,
287 last_used_at: now,
288 }
289 }
290
291 pub fn reliability(&self) -> f32 {
293 let total = self.success_count + self.failure_count;
294 if total == 0 {
295 return 0.5; }
297 self.success_count as f32 / total as f32
298 }
299
300 pub fn record_success(&mut self, result_count: usize) {
302 self.success_count += 1;
303 self.last_used_at = Utc::now().timestamp();
304
305 let alpha = 0.3;
307 self.avg_results = alpha * result_count as f32 + (1.0 - alpha) * self.avg_results;
308 }
309
310 pub fn record_failure(&mut self) {
312 self.failure_count += 1;
313 self.last_used_at = Utc::now().timestamp();
314 }
315
316 pub fn matches_types(&self, types: &[EntityType]) -> bool {
318 self.required_types.iter().all(|rt| types.contains(rt))
319 }
320}
321
/// Statistics about how a kind of reference resolves to an entity type.
///
/// NOTE(review): nothing in this module creates or reads these beyond storing them in
/// `GlobalMemory::resolution_patterns`; the field semantics below are inferred from
/// their names — confirm against the producer.
#[derive(Debug, Clone)]
pub struct ResolutionPattern {
    /// Kind of reference (presumably a pronoun or reference category — confirm).
    pub reference_type: String,
    /// Entity type the reference resolved to.
    pub entity_type: EntityType,
    /// Optional context in which the pattern applies; format not established here.
    pub context_pattern: Option<String>,
    /// Number of successful resolutions recorded.
    pub success_count: u32,
    /// Number of failed resolutions recorded.
    pub failure_count: u32,
}
336
337#[derive(Debug, Clone)]
339pub struct ToolErrorPattern {
340 pub tool_name: String,
342 pub error_category: String,
344 pub occurrence_count: u32,
346 pub last_occurred: i64,
348 pub suggested_fix: Option<String>,
350 pub input_patterns: Vec<String>,
352}
353
354impl ToolErrorPattern {
355 pub fn new(tool_name: &str, error_category: &ToolErrorCategory) -> Self {
357 Self {
358 tool_name: tool_name.to_string(),
359 error_category: error_category.category_name().to_string(),
360 occurrence_count: 1,
361 last_occurred: Utc::now().timestamp(),
362 suggested_fix: error_category.get_suggestion(),
363 input_patterns: Vec::new(),
364 }
365 }
366
367 pub fn record_occurrence(&mut self) {
369 self.occurrence_count += 1;
370 self.last_occurred = Utc::now().timestamp();
371 }
372
373 pub fn is_frequent(&self) -> bool {
375 self.occurrence_count >= 3
376 }
377}
378
379#[derive(Debug, Clone, Default)]
381pub struct ToolStats {
382 pub success_count: u32,
384 pub failure_count: u32,
386 pub total_retries: u32,
388 pub avg_execution_time_ms: f64,
390 pub last_used: i64,
392}
393
394impl ToolStats {
395 pub fn record_success(&mut self, retries: u32, execution_time_ms: u64) {
397 self.success_count += 1;
398 self.total_retries += retries;
399 self.last_used = Utc::now().timestamp();
400
401 let alpha = 0.3;
403 self.avg_execution_time_ms =
404 alpha * execution_time_ms as f64 + (1.0 - alpha) * self.avg_execution_time_ms;
405 }
406
407 pub fn record_failure(&mut self, retries: u32, execution_time_ms: u64) {
409 self.failure_count += 1;
410 self.total_retries += retries;
411 self.last_used = Utc::now().timestamp();
412
413 let alpha = 0.3;
414 self.avg_execution_time_ms =
415 alpha * execution_time_ms as f64 + (1.0 - alpha) * self.avg_execution_time_ms;
416 }
417
418 pub fn success_rate(&self) -> f64 {
420 let total = self.success_count + self.failure_count;
421 if total == 0 {
422 0.5 } else {
424 self.success_count as f64 / total as f64
425 }
426 }
427
428 pub fn avg_retries(&self) -> f64 {
430 let total = self.success_count + self.failure_count;
431 if total == 0 {
432 0.0
433 } else {
434 self.total_retries as f64 / total as f64
435 }
436 }
437}
438
439#[derive(Debug, Clone, Default)]
441pub struct ConfidenceStats {
442 pub sample_count: u32,
444 pub confidence_sum: f64,
446 pub low_confidence_count: u32,
448 pub high_confidence_count: u32,
450}
451
452impl ConfidenceStats {
453 pub fn record_sample(&mut self, confidence: &ResponseConfidence) {
455 self.sample_count += 1;
456 self.confidence_sum += confidence.score;
457
458 if confidence.is_low_confidence() {
459 self.low_confidence_count += 1;
460 } else if confidence.is_high_confidence() {
461 self.high_confidence_count += 1;
462 }
463 }
464
465 pub fn avg_confidence(&self) -> f64 {
467 if self.sample_count == 0 {
468 0.5
469 } else {
470 self.confidence_sum / self.sample_count as f64
471 }
472 }
473
474 pub fn low_confidence_ratio(&self) -> f64 {
476 if self.sample_count == 0 {
477 0.0
478 } else {
479 self.low_confidence_count as f64 / self.sample_count as f64
480 }
481 }
482}
483
/// A hint pairing a context pattern with a rule.
///
/// Stored by `GlobalMemory::add_pattern_hint`. NOTE(review): nothing in this module
/// interprets the fields; the descriptions below are inferred from names — confirm.
#[derive(Debug, Clone)]
pub struct PatternHint {
    /// Context in which the hint applies; format not established here.
    pub context_pattern: String,
    /// The rule or advice associated with the context.
    pub rule: String,
    /// Confidence in the hint (presumably 0.0–1.0 — not enforced here).
    pub confidence: f64,
    /// Where the hint came from (free-form; not interpreted in this module).
    pub source: String,
}
496
/// Memory that is not tied to one conversation (contrast `LocalMemory`, which carries a
/// `conversation_id`).
///
/// Holds learned query patterns, tool error patterns and statistics, response-confidence
/// statistics, and pattern hints.
#[derive(Debug, Default)]
pub struct GlobalMemory {
    /// Learned query patterns grouped by question type.
    pub query_patterns: HashMap<QuestionType, Vec<QueryPattern>>,
    /// Reference-resolution patterns; nothing in this module adds to it.
    pub resolution_patterns: Vec<ResolutionPattern>,
    /// Recurring tool errors keyed by `"<tool_name>:<category_name>"`.
    pub tool_error_patterns: HashMap<String, ToolErrorPattern>,
    /// Per-tool execution statistics keyed by tool name.
    pub tool_stats: HashMap<String, ToolStats>,
    /// Aggregate response-confidence statistics.
    pub confidence_stats: ConfidenceStats,
    /// Hints added via `add_pattern_hint`, in insertion order.
    pub pattern_hints: Vec<PatternHint>,
}
513
514impl GlobalMemory {
515 pub fn new() -> Self {
517 Self::default()
518 }
519
520 pub fn add_pattern_hint(&mut self, hint: PatternHint) {
522 self.pattern_hints.push(hint);
523 }
524
525 pub fn get_pattern_hints(&self) -> &[PatternHint] {
527 &self.pattern_hints
528 }
529
530 pub fn add_pattern(&mut self, pattern: QueryPattern) {
532 self.query_patterns
533 .entry(pattern.question_type.clone())
534 .or_default()
535 .push(pattern);
536 }
537
538 pub fn get_patterns(&self, question_type: &QuestionType) -> Vec<&QueryPattern> {
540 if let Some(patterns) = self.query_patterns.get(question_type) {
541 let mut sorted: Vec<_> = patterns.iter().collect();
542 sorted.sort_by(|a, b| {
543 b.reliability()
544 .partial_cmp(&a.reliability())
545 .unwrap_or(std::cmp::Ordering::Equal)
546 });
547 sorted
548 } else {
549 Vec::new()
550 }
551 }
552
553 pub fn get_best_pattern(
555 &self,
556 question_type: &QuestionType,
557 entity_types: &[EntityType],
558 ) -> Option<&QueryPattern> {
559 self.get_patterns(question_type)
560 .into_iter()
561 .find(|p| p.matches_types(entity_types))
562 }
563
564 pub fn get_pattern_mut(&mut self, id: &str) -> Option<&mut QueryPattern> {
566 for patterns in self.query_patterns.values_mut() {
567 if let Some(pattern) = patterns.iter_mut().find(|p| p.id == id) {
568 return Some(pattern);
569 }
570 }
571 None
572 }
573
574 pub fn prune_patterns(&mut self, min_reliability: f32, min_uses: u32) {
576 for patterns in self.query_patterns.values_mut() {
577 patterns.retain(|p| {
578 let total_uses = p.success_count + p.failure_count;
579 total_uses < min_uses || p.reliability() >= min_reliability
580 });
581 }
582 }
583
584 pub fn record_tool_outcome(&mut self, outcome: &ToolOutcome) {
586 let stats = self
587 .tool_stats
588 .entry(outcome.tool_name.clone())
589 .or_default();
590
591 if outcome.success {
592 stats.record_success(outcome.retries, outcome.execution_time_ms);
593 } else {
594 stats.record_failure(outcome.retries, outcome.execution_time_ms);
595
596 if let Some(ref error_category) = outcome.error_category {
598 let key = format!("{}:{}", outcome.tool_name, error_category.category_name());
599
600 if let Some(pattern) = self.tool_error_patterns.get_mut(&key) {
601 pattern.record_occurrence();
602 } else {
603 self.tool_error_patterns.insert(
604 key,
605 ToolErrorPattern::new(&outcome.tool_name, error_category),
606 );
607 }
608 }
609 }
610 }
611
612 pub fn record_confidence(&mut self, confidence: &ResponseConfidence) {
614 self.confidence_stats.record_sample(confidence);
615 }
616
617 pub fn get_common_errors(&self, tool_name: &str) -> Vec<&ToolErrorPattern> {
619 self.tool_error_patterns
620 .values()
621 .filter(|p| p.tool_name == tool_name && p.is_frequent())
622 .collect()
623 }
624
625 pub fn get_error_prevention_hints(&self, tool_name: &str) -> Option<String> {
627 let common_errors = self.get_common_errors(tool_name);
628 if common_errors.is_empty() {
629 return None;
630 }
631
632 let hints: Vec<String> = common_errors
633 .iter()
634 .filter_map(|e| e.suggested_fix.clone())
635 .collect();
636
637 if hints.is_empty() {
638 None
639 } else {
640 Some(format!(
641 "Common pitfalls for {}: {}",
642 tool_name,
643 hints.join("; ")
644 ))
645 }
646 }
647
648 pub fn get_tool_reliability(&self, tool_name: &str) -> Option<f64> {
650 self.tool_stats.get(tool_name).map(|s| s.success_rate())
651 }
652}
653
/// Ties session-local and global memory together.
///
/// Suggests learned query patterns, records query and tool outcomes, and learns new
/// patterns from successful queries.
#[derive(Debug)]
pub struct LearningCoordinator {
    /// Memory for the current conversation.
    pub local: LocalMemory,
    /// Learned patterns and tool/confidence statistics.
    pub global: GlobalMemory,
    /// Learning rate; `new` sets it to 0.3, but it is not read anywhere in this module
    /// (hence the leading underscore).
    _learning_rate: f32,
    /// Minimum success count a pattern needs before `get_context_for_prompt` lists it.
    min_successes: u32,
}
666
667impl LearningCoordinator {
668 pub fn new(conversation_id: String) -> Self {
670 Self {
671 local: LocalMemory::new(conversation_id),
672 global: GlobalMemory::new(),
673 _learning_rate: 0.3,
674 min_successes: 3,
675 }
676 }
677
678 pub fn process_query(
680 &mut self,
681 _original: &str,
682 _resolved: &str,
683 core: Option<QueryCore>,
684 turn: u32,
685 ) -> Option<&QueryPattern> {
686 self.local.current_turn = turn;
687
688 if let Some(ref c) = core {
689 let entity_types: Vec<_> = c.entities.iter().map(|(_, t)| t.clone()).collect();
691
692 if let Some(pattern) = self
694 .global
695 .get_best_pattern(&c.question_type, &entity_types)
696 {
697 return Some(pattern);
698 }
699 }
700
701 None
702 }
703
704 pub fn record_outcome(
706 &mut self,
707 pattern_id: Option<&str>,
708 success: bool,
709 result_count: usize,
710 query_core: Option<&QueryCore>,
711 execution_time_ms: u64,
712 ) {
713 if let Some(id) = pattern_id
715 && let Some(pattern) = self.global.get_pattern_mut(id)
716 {
717 if success {
718 pattern.record_success(result_count);
719 } else {
720 pattern.record_failure();
721 }
722 }
723
724 if let Some(core) = query_core {
726 self.local.record_query(
727 &core.original,
728 core.resolved.as_deref().unwrap_or(&core.original),
729 core.question_type.clone(),
730 Some(core.to_sexp()),
731 success,
732 result_count,
733 execution_time_ms,
734 );
735
736 if success && pattern_id.is_none() && result_count > 0 {
738 let _ = self.learn_pattern(core, result_count);
739 }
740 }
741 }
742
743 pub fn learn_pattern(&mut self, query: &QueryCore, result_count: usize) -> Option<String> {
745 if result_count == 0 || result_count > 100 {
747 return None;
748 }
749
750 let template = self.generalize_query(query);
752
753 let required_types: Vec<_> = query.entities.iter().map(|(_, t)| t.clone()).collect();
755
756 if let Some(existing) = self
758 .global
759 .get_best_pattern(&query.question_type, &required_types)
760 && existing.template == template
761 {
762 return None; }
764
765 let mut pattern = QueryPattern::new(query.question_type.clone(), template, required_types);
767 pattern.record_success(result_count);
768
769 let id = pattern.id.clone();
770 self.global.add_pattern(pattern);
771
772 Some(id)
773 }
774
775 fn generalize_query(&self, query: &QueryCore) -> String {
777 let mut template = query.to_sexp();
778
779 for (name, entity_type) in &query.entities {
781 let placeholder = format!("${{{}}}", entity_type.as_str().to_uppercase());
782 template = template.replace(&format!("\"{}\"", name), &placeholder);
783 }
784
785 template
786 }
787
788 pub fn get_context_for_prompt(&self) -> String {
790 let mut context = String::new();
791
792 let frequent = self.local.get_frequent_entities(5);
794 if !frequent.is_empty() {
795 context.push_str("Frequently referenced entities:\n");
796 for entity in frequent {
797 context.push_str(&format!(
798 "- {} ({}): {} mentions\n",
799 entity.name,
800 entity.entity_type.as_str(),
801 entity.frequency()
802 ));
803 }
804 context.push('\n');
805 }
806
807 for question_type in [
809 QuestionType::Definition,
810 QuestionType::Location,
811 QuestionType::Dependency,
812 ] {
813 let patterns = self.global.get_patterns(&question_type);
814 let good_patterns: Vec<_> = patterns
815 .iter()
816 .filter(|p| p.reliability() > 0.7 && p.success_count >= self.min_successes)
817 .take(2)
818 .collect();
819
820 if !good_patterns.is_empty() {
821 context.push_str(&format!("Effective {:?} patterns:\n", question_type));
822 for pattern in good_patterns {
823 context.push_str(&format!(
824 "- {} ({}% reliable)\n",
825 pattern.template,
826 (pattern.reliability() * 100.0) as u32
827 ));
828 }
829 context.push('\n');
830 }
831 }
832
833 context
834 }
835
836 pub fn get_promotable_patterns(
840 &self,
841 min_reliability: f32,
842 min_uses: u32,
843 ) -> Vec<&QueryPattern> {
844 let mut promotable = Vec::new();
845
846 for patterns in self.global.query_patterns.values() {
847 for pattern in patterns {
848 let total_uses = pattern.success_count + pattern.failure_count;
849 if pattern.reliability() >= min_reliability && total_uses >= min_uses {
850 promotable.push(pattern);
851 }
852 }
853 }
854
855 promotable.sort_by(|a, b| {
857 b.reliability()
858 .partial_cmp(&a.reliability())
859 .unwrap_or(std::cmp::Ordering::Equal)
860 });
861
862 promotable
863 }
864
865 pub fn get_stats(&self) -> LearningStats {
867 let total_patterns: usize = self.global.query_patterns.values().map(|v| v.len()).sum();
868
869 let mut total_successes = 0u32;
870 let mut total_failures = 0u32;
871 for patterns in self.global.query_patterns.values() {
872 for pattern in patterns {
873 total_successes += pattern.success_count;
874 total_failures += pattern.failure_count;
875 }
876 }
877
878 LearningStats {
879 session_queries: self.local.query_history.len(),
880 session_entities: self.local.entities.len(),
881 session_coreferences: self.local.coreference_log.len(),
882 global_patterns: total_patterns,
883 global_successes: total_successes,
884 global_failures: total_failures,
885 overall_reliability: if total_successes + total_failures > 0 {
886 total_successes as f32 / (total_successes + total_failures) as f32
887 } else {
888 0.5
889 },
890 }
891 }
892
893 pub fn record_tool_outcome(&mut self, outcome: &ToolOutcome) {
899 self.global.record_tool_outcome(outcome);
900 }
901
902 pub fn record_confidence(&mut self, confidence: &ResponseConfidence) {
904 self.global.record_confidence(confidence);
905 }
906
907 pub fn get_error_prevention_hints(&self, tool_name: &str) -> Option<String> {
909 self.global.get_error_prevention_hints(tool_name)
910 }
911
912 pub fn get_tool_reliability(&self, tool_name: &str) -> Option<f64> {
914 self.global.get_tool_reliability(tool_name)
915 }
916
917 pub fn get_common_errors(&self, tool_name: &str) -> Vec<&ToolErrorPattern> {
919 self.global.get_common_errors(tool_name)
920 }
921
922 pub fn get_avg_confidence(&self) -> f64 {
924 self.global.confidence_stats.avg_confidence()
925 }
926
927 pub fn has_confidence_issues(&self) -> bool {
929 self.global.confidence_stats.low_confidence_ratio() > 0.3
930 }
931}
932
/// Snapshot of learning counters produced by `LearningCoordinator::get_stats`.
#[derive(Debug, Clone)]
pub struct LearningStats {
    /// Number of queries in the session's query history.
    pub session_queries: usize,
    /// Number of distinct entities tracked in the session.
    pub session_entities: usize,
    /// Number of coreference records logged in the session.
    pub session_coreferences: usize,
    /// Total number of stored query patterns across all question types.
    pub global_patterns: usize,
    /// Sum of `success_count` over all stored patterns.
    pub global_successes: u32,
    /// Sum of `failure_count` over all stored patterns.
    pub global_failures: u32,
    /// `global_successes / (global_successes + global_failures)`, or `0.5` when both are 0.
    pub overall_reliability: f32,
}
951
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `Definition` query about `main.rs` shared by several tests.
    fn definition_core() -> QueryCore {
        QueryCore::new(
            QuestionType::Definition,
            crate::query_core::QueryExpr::var("x"),
            vec![("main.rs".to_string(), EntityType::File)],
            "What is main.rs?".to_string(),
        )
    }

    #[test]
    fn test_tracked_entity() {
        let mut entity = TrackedEntity::new("main.rs".to_string(), EntityType::File, 1);
        assert_eq!(entity.frequency(), 1);

        entity.record_mention(2);
        entity.record_mention(3);
        assert_eq!(entity.frequency(), 3);

        // A repeated mention in an already-recorded turn is not counted again.
        entity.record_mention(2);
        assert_eq!(entity.frequency(), 3);
    }

    #[test]
    fn test_local_memory() {
        let mut local = LocalMemory::new("test-conv".to_string());

        local.track_entity("main.rs", EntityType::File);
        local.next_turn();
        local.track_entity("config.toml", EntityType::File);
        // A mention in a later turn counts again and moves the entity back into focus.
        local.track_entity("main.rs", EntityType::File);

        assert_eq!(local.entities.len(), 2);
        assert_eq!(local.entities["main.rs"].frequency(), 2);
        assert_eq!(local.focus_stack[0], "main.rs");
        assert_eq!(local.focus_stack[1], "config.toml");
    }

    #[test]
    fn test_focus_stack_is_capped() {
        let mut local = LocalMemory::new("cap".to_string());
        for i in 0..25 {
            local.track_entity(&format!("e{i}"), EntityType::File);
        }
        assert_eq!(local.focus_stack.len(), 20);
        assert_eq!(local.focus_stack[0], "e24");
    }

    #[test]
    fn test_get_success_rate() {
        let mut local = LocalMemory::new("rate".to_string());
        // No history yet: neutral prior.
        assert_eq!(local.get_success_rate(&QuestionType::Definition), 0.5);

        local.record_query("q1", "q1", QuestionType::Definition, None, true, 1, 5);
        local.record_query("q2", "q2", QuestionType::Definition, None, false, 0, 5);
        local.record_query("q3", "q3", QuestionType::Location, None, true, 3, 5);

        assert_eq!(local.get_success_rate(&QuestionType::Definition), 0.5);
        assert_eq!(local.get_success_rate(&QuestionType::Location), 1.0);
    }

    #[test]
    fn test_query_pattern_reliability() {
        let mut pattern =
            QueryPattern::new(QuestionType::Definition, "template".to_string(), vec![]);

        // Unused patterns start at the neutral prior.
        assert_eq!(pattern.reliability(), 0.5);

        pattern.record_success(5);
        pattern.record_success(3);
        pattern.record_failure();

        assert!((pattern.reliability() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_global_memory_patterns() {
        let mut global = GlobalMemory::new();

        let mut pattern1 =
            QueryPattern::new(QuestionType::Definition, "template1".to_string(), vec![]);
        pattern1.record_success(5);
        pattern1.record_success(5);

        let mut pattern2 =
            QueryPattern::new(QuestionType::Definition, "template2".to_string(), vec![]);
        pattern2.record_failure();

        global.add_pattern(pattern1);
        global.add_pattern(pattern2);

        // Patterns come back most reliable first.
        let patterns = global.get_patterns(&QuestionType::Definition);
        assert_eq!(patterns.len(), 2);
        assert!(patterns[0].reliability() > patterns[1].reliability());
    }

    #[test]
    fn test_learning_coordinator() {
        let mut coordinator = LearningCoordinator::new("test-conv".to_string());
        let core = definition_core();

        coordinator.record_outcome(None, true, 1, Some(&core), 0);

        let stats = coordinator.get_stats();
        assert_eq!(stats.session_queries, 1);
        // A successful query without a stored pattern is learned as a new pattern.
        assert_eq!(stats.global_patterns, 1);
    }

    #[test]
    fn test_learn_pattern_limits_and_duplicates() {
        let mut coordinator = LearningCoordinator::new("learn".to_string());
        let core = definition_core();

        // Zero results or more than 100 results are never learned.
        assert!(coordinator.learn_pattern(&core, 0).is_none());
        assert!(coordinator.learn_pattern(&core, 101).is_none());

        // The first acceptable outcome is learned; the same template is not stored twice.
        assert!(coordinator.learn_pattern(&core, 1).is_some());
        assert!(coordinator.learn_pattern(&core, 1).is_none());
        assert_eq!(coordinator.get_stats().global_patterns, 1);
    }

    #[test]
    fn test_pattern_matching() {
        let pattern = QueryPattern::new(
            QuestionType::Definition,
            "template".to_string(),
            vec![EntityType::File],
        );

        assert!(pattern.matches_types(&[EntityType::File]));
        assert!(pattern.matches_types(&[EntityType::File, EntityType::Function]));
        assert!(!pattern.matches_types(&[EntityType::Function]));
    }

    #[test]
    fn test_prune_patterns() {
        let mut global = GlobalMemory::new();

        let mut good_pattern =
            QueryPattern::new(QuestionType::Definition, "good".to_string(), vec![]);
        for _ in 0..10 {
            good_pattern.record_success(5);
        }

        let mut bad_pattern =
            QueryPattern::new(QuestionType::Definition, "bad".to_string(), vec![]);
        for _ in 0..10 {
            bad_pattern.record_failure();
        }

        global.add_pattern(good_pattern);
        global.add_pattern(bad_pattern);

        assert_eq!(global.get_patterns(&QuestionType::Definition).len(), 2);

        global.prune_patterns(0.5, 5);

        // Only the unreliable, well-used pattern is removed.
        assert_eq!(global.get_patterns(&QuestionType::Definition).len(), 1);
    }

    #[test]
    fn test_tool_stats_rates() {
        let mut stats = ToolStats::default();
        assert_eq!(stats.success_rate(), 0.5);
        assert_eq!(stats.avg_retries(), 0.0);

        stats.record_success(1, 10);
        stats.record_failure(3, 20);

        assert_eq!(stats.success_rate(), 0.5);
        assert_eq!(stats.avg_retries(), 2.0);
    }

    #[test]
    fn test_get_context_for_prompt() {
        let mut coordinator = LearningCoordinator::new("test".to_string());

        // Mentions within one turn are deduplicated, so advance the turn to make
        // main.rs genuinely more frequent than config.toml.
        coordinator.local.track_entity("main.rs", EntityType::File);
        coordinator.local.next_turn();
        coordinator.local.track_entity("main.rs", EntityType::File);
        coordinator
            .local
            .track_entity("config.toml", EntityType::File);

        let context = coordinator.get_context_for_prompt();

        assert!(context.starts_with("Frequently referenced entities:\n"));
        assert!(context.contains("): 2 mentions\n"));
        let main_pos = context.find("- main.rs (").expect("main.rs should be listed");
        let config_pos = context
            .find("- config.toml (")
            .expect("config.toml should be listed");
        assert!(main_pos < config_pos, "more frequent entity should come first");
        // Nothing has been learned yet, so no pattern section is emitted.
        assert!(!context.contains("Effective"));
    }
}