1use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::engine::tokenizer::Tokenizer;
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::types::{AmemResult, CognitiveEvent, Edge, EdgeType, EventType};
9
/// Input parameters for a belief-revision query.
#[derive(Debug, Clone)]
pub struct BeliefRevisionParams {
    /// Hypothesis text to test the graph against.
    pub hypothesis: String,
    /// Optional embedding of the hypothesis; when present, node relevance
    /// blends text overlap and cosine similarity 50/50.
    pub hypothesis_vec: Option<Vec<f32>>,
    /// Minimum relevance a node must reach to be considered a contradiction.
    pub contradiction_threshold: f32,
    /// Maximum BFS depth for the confidence-weakening cascade.
    pub max_depth: u32,
    /// Confidence assigned to the hypothesis itself.
    /// NOTE(review): not read by the code visible in this file — confirm use.
    pub hypothesis_confidence: f32,
}

/// A node judged to contradict the hypothesis.
#[derive(Debug, Clone)]
pub struct ContradictedNode {
    /// Id of the contradicting node.
    pub node_id: u64,
    /// Contradiction strength, clamped to [0, 1].
    pub contradiction_strength: f32,
    /// Human-readable explanation of why the node was flagged.
    pub reason: String,
}

/// A downstream node whose confidence was reduced by the cascade.
#[derive(Debug, Clone)]
pub struct WeakenedNode {
    /// Id of the weakened node.
    pub node_id: u64,
    /// Confidence before revision.
    pub original_confidence: f32,
    /// Confidence after weakening, clamped to [0, 1].
    pub revised_confidence: f32,
    /// Cascade depth (hops from a contradicted node) at which this node was
    /// reached.
    pub depth: u32,
}

/// One hop of the weakening cascade, recorded for traceability.
#[derive(Debug, Clone)]
pub struct CascadeStep {
    /// Node reached by this step.
    pub node_id: u64,
    /// Edge type traversed (CausedBy or Supports in the visible code).
    pub via_edge: EdgeType,
    /// Node the cascade arrived from.
    pub from_node: u64,
    /// Depth of this step within the cascade.
    pub depth: u32,
}

/// Aggregate weakening applied to a node.
/// NOTE(review): not constructed anywhere in this file — presumably consumed
/// by callers elsewhere; verify before removing.
#[derive(Debug, Clone)]
pub struct CascadeEffect {
    pub node_id: u64,
    pub weakening: f32,
}

/// Full result of a belief-revision query.
#[derive(Debug, Clone)]
pub struct RevisionReport {
    /// Nodes directly contradicting the hypothesis, strongest first.
    pub contradicted: Vec<ContradictedNode>,
    /// Dependents whose confidence was reduced by the cascade.
    pub weakened: Vec<WeakenedNode>,
    /// Decision nodes among the affected set (sorted, deduplicated).
    pub invalidated_decisions: Vec<u64>,
    /// Total distinct nodes affected (contradicted + weakened).
    pub total_affected: usize,
    /// Step-by-step trace of the cascade traversal.
    pub cascade: Vec<CascadeStep>,
}
89
/// Requested ordering of gap-detection results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GapSeverity {
    /// Sort by number of downstream dependents, descending.
    HighestImpact,
    /// Sort by severity score, descending.
    LowestConfidence,
    /// Sort by node creation time, newest first.
    MostRecent,
}

/// Category of knowledge gap detected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GapType {
    /// A Decision node with no CausedBy/Supports justification.
    UnjustifiedDecision,
    /// An Inference backed by fewer Supports edges than required.
    SingleSourceInference,
    /// A low-confidence Fact/Inference that other nodes depend on.
    LowConfidenceFoundation,
    /// A node whose Supersedes chain is 3+ hops long.
    UnstableKnowledge,
    /// A heavily decayed Fact that is still depended upon.
    StaleEvidence,
}

/// Parameters for gap detection.
#[derive(Debug, Clone)]
pub struct GapDetectionParams {
    /// Confidence below which a depended-upon node counts as a weak
    /// foundation.
    pub confidence_threshold: f32,
    /// Minimum number of Supports edges an Inference needs.
    pub min_support_count: u32,
    /// Maximum number of gaps returned (the summary still covers all gaps).
    pub max_results: usize,
    /// Optional inclusive (start, end) session filter.
    pub session_range: Option<(u32, u32)>,
    /// Requested result ordering.
    pub sort_by: GapSeverity,
}

/// A single detected gap.
#[derive(Debug, Clone)]
pub struct Gap {
    /// Node exhibiting the gap.
    pub node_id: u64,
    /// Which category of gap this is.
    pub gap_type: GapType,
    /// Severity in [0, 1]; exact semantics depend on `gap_type`.
    pub severity: f32,
    /// Human-readable description of the gap.
    pub description: String,
    /// Number of transitive dependents of the node.
    pub downstream_count: usize,
}

/// Per-category gap totals plus an overall health score.
#[derive(Debug, Clone)]
pub struct GapSummary {
    /// Total gaps detected (before truncation to `max_results`).
    pub total_gaps: usize,
    pub unjustified_decisions: usize,
    pub single_source_inferences: usize,
    pub low_confidence_foundations: usize,
    pub unstable_knowledge: usize,
    pub stale_evidence: usize,
    /// 1.0 minus the clamped gap-to-node ratio; 1.0 on an empty graph.
    pub health_score: f32,
}

/// Result of gap detection: sorted, truncated gaps plus the full summary.
#[derive(Debug, Clone)]
pub struct GapReport {
    pub gaps: Vec<Gap>,
    pub summary: GapSummary,
}
177
/// Anchor for an analogical search.
#[derive(Debug, Clone)]
pub enum AnalogicalAnchor {
    /// Use an existing node as the anchor.
    Node(u64),
    /// Use a raw feature vector; resolved to the most similar node.
    Vector(Vec<f32>),
}

/// Structural fingerprint of a subgraph.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Histogram of node event types, keyed by `EventType as u8`.
    pub event_type_counts: HashMap<u8, usize>,
    /// Histogram of edge types, keyed by `EdgeType as u8`.
    pub edge_type_counts: HashMap<u8, usize>,
    /// Deepest BFS level reachable along CausedBy edges within the subgraph.
    pub causal_chain_depth: u32,
    /// Mean out-degree over the subgraph's nodes.
    pub branching_factor: f32,
}

/// One analogy candidate returned by `analogical`.
#[derive(Debug, Clone)]
pub struct Analogy {
    /// Center node of the analogous subgraph.
    pub center_id: u64,
    /// Fingerprint similarity against the anchor subgraph.
    pub structural_similarity: f32,
    /// Feature-vector cosine similarity (floored at 0) against the anchor.
    pub content_similarity: f32,
    /// 0.6 * structural + 0.4 * content.
    pub combined_score: f32,
    /// Fingerprint of the candidate subgraph.
    pub pattern: PatternMatch,
    /// Ids of the nodes in the candidate subgraph.
    pub subgraph_nodes: Vec<u64>,
}

/// Parameters for analogical search.
#[derive(Debug, Clone)]
pub struct AnalogicalParams {
    /// The anchor (node id or raw vector) to find analogies for.
    pub anchor: AnalogicalAnchor,
    /// Context-subgraph expansion depth around each center.
    pub context_depth: u32,
    /// Maximum number of analogies returned.
    pub max_results: usize,
    /// Minimum combined score for inclusion.
    pub min_similarity: f32,
    /// Sessions whose nodes are excluded as candidates.
    pub exclude_sessions: Vec<u32>,
}
235
/// How a belief snapshot relates to its predecessor in a timeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeType {
    /// First snapshot in the chain.
    Initial,
    /// No explicit edge; confidence did not drop.
    Refined,
    /// Correction event, Supersedes edge, or a confidence drop.
    Corrected,
    /// Explicit Contradicts edge against the predecessor.
    Contradicted,
    /// Explicit Supports edge relative to the predecessor.
    Reinforced,
}

/// One point in a belief's evolution.
#[derive(Debug, Clone)]
pub struct BeliefSnapshot {
    /// Node this snapshot was taken from.
    pub node_id: u64,
    /// Session in which the node was created.
    pub session_id: u32,
    /// Node creation timestamp.
    pub created_at: u64,
    /// Node confidence at this point.
    pub confidence: f32,
    /// Content truncated for display (at most 120 bytes plus "...").
    pub content_preview: String,
    /// Relationship to the previous snapshot.
    pub change_type: ChangeType,
}

/// Chronological evolution of a single belief (one Supersedes chain).
#[derive(Debug, Clone)]
pub struct BeliefTimeline {
    /// Snapshots ordered by creation time.
    pub snapshots: Vec<BeliefSnapshot>,
    /// Number of transitions (snapshots - 1).
    pub change_count: usize,
    /// Transitions classified as Corrected.
    pub correction_count: usize,
    /// Transitions classified as Contradicted.
    pub contradiction_count: usize,
}

/// Parameters for drift detection.
#[derive(Debug, Clone)]
pub struct DriftParams {
    /// Topic text whose belief history is being examined.
    pub topic: String,
    /// Optional topic embedding; blended 50/50 with text overlap.
    pub topic_vec: Option<Vec<f32>>,
    /// Maximum number of timelines returned.
    pub max_results: usize,
    /// Minimum relevance for a node to join the analysis.
    pub min_relevance: f32,
}

/// Result of drift detection.
#[derive(Debug, Clone)]
pub struct DriftReport {
    /// Belief timelines, most volatile first.
    pub timelines: Vec<BeliefTimeline>,
    /// 1 minus the fraction of snapshots that are corrections or
    /// contradictions, clamped to [0, 1].
    pub stability: f32,
    /// True when more than 30% of transitions were corrections or
    /// contradictions.
    pub likely_to_change: bool,
}
308
/// Lowercase words treated as signals that a node's content negates or
/// disputes something; matched against lowercased node content during
/// belief revision.
const NEGATION_WORDS: &[&str] = &[
    "not",
    "no",
    "never",
    "neither",
    "nor",
    "none",
    "nothing",
    "nowhere",
    "nobody",
    "cannot",
    "can't",
    "don't",
    "doesn't",
    "didn't",
    "won't",
    "wouldn't",
    "shouldn't",
    "couldn't",
    "isn't",
    "aren't",
    "wasn't",
    "weren't",
    "hasn't",
    "haven't",
    "hadn't",
    "false",
    "incorrect",
    "wrong",
    "invalid",
    "untrue",
    "deny",
    "denied",
    "disagree",
    "unlike",
    "opposite",
    "contrary",
    "instead",
    "rather",
];
353
354impl super::query::QueryEngine {
359 pub fn belief_revision(
370 &self,
371 graph: &MemoryGraph,
372 params: BeliefRevisionParams,
373 ) -> AmemResult<RevisionReport> {
374 let tokenizer = Tokenizer::new();
375 let hypothesis_terms: HashSet<String> =
376 tokenizer.tokenize(¶ms.hypothesis).into_iter().collect();
377
378 if hypothesis_terms.is_empty() && params.hypothesis_vec.is_none() {
379 return Ok(RevisionReport {
380 contradicted: Vec::new(),
381 weakened: Vec::new(),
382 invalidated_decisions: Vec::new(),
383 total_affected: 0,
384 cascade: Vec::new(),
385 });
386 }
387
388 let negation_set: HashSet<&str> = NEGATION_WORDS.iter().copied().collect();
390
391 let mut contradicted: Vec<ContradictedNode> = Vec::new();
398 let mut contradicted_ids: HashSet<u64> = HashSet::new();
399
400 for node in graph.nodes() {
401 let node_terms: HashSet<String> =
402 tokenizer.tokenize(&node.content).into_iter().collect();
403
404 let overlap_count = hypothesis_terms.intersection(&node_terms).count();
406 let text_sim = if hypothesis_terms.is_empty() {
407 0.0
408 } else {
409 overlap_count as f32 / hypothesis_terms.len() as f32
410 };
411
412 let vec_sim = if let Some(ref hvec) = params.hypothesis_vec {
414 if !node.feature_vec.iter().all(|&x| x == 0.0) {
415 cosine_similarity(hvec, &node.feature_vec)
416 } else {
417 0.0
418 }
419 } else {
420 0.0
421 };
422
423 let relevance = if params.hypothesis_vec.is_some() {
425 0.5 * text_sim + 0.5 * vec_sim
426 } else {
427 text_sim
428 };
429
430 if relevance < params.contradiction_threshold {
431 continue;
432 }
433
434 let node_content_lower = node.content.to_lowercase();
437 let has_negation = negation_set
438 .iter()
439 .any(|neg| node_content_lower.contains(neg));
440
441 let has_contradicts_edge = graph
443 .edges_from(node.id)
444 .iter()
445 .any(|e| e.edge_type == EdgeType::Contradicts)
446 || graph
447 .edges_to(node.id)
448 .iter()
449 .any(|e| e.edge_type == EdgeType::Contradicts);
450
451 let is_correction = node.event_type == EventType::Correction;
453
454 if has_negation || has_contradicts_edge || is_correction {
455 let strength = relevance
456 * if has_contradicts_edge { 1.0 } else { 0.8 }
457 * if has_negation { 1.0 } else { 0.7 };
458
459 let reason = if has_contradicts_edge {
460 "explicit Contradicts edge in graph".to_string()
461 } else if has_negation {
462 "negation detected in content".to_string()
463 } else {
464 "correction event with high similarity".to_string()
465 };
466
467 contradicted_ids.insert(node.id);
468 contradicted.push(ContradictedNode {
469 node_id: node.id,
470 contradiction_strength: strength.clamp(0.0, 1.0),
471 reason,
472 });
473 }
474 }
475
476 contradicted.sort_by(|a, b| {
478 b.contradiction_strength
479 .partial_cmp(&a.contradiction_strength)
480 .unwrap_or(std::cmp::Ordering::Equal)
481 });
482
483 let mut weakened: Vec<WeakenedNode> = Vec::new();
485 let mut cascade: Vec<CascadeStep> = Vec::new();
486 let mut visited: HashSet<u64> = contradicted_ids.clone();
487 let mut queue: VecDeque<(u64, u32, f32)> = VecDeque::new();
488
489 for cn in &contradicted {
491 queue.push_back((cn.node_id, 0, cn.contradiction_strength));
492 }
493
494 while let Some((current_id, depth, weakening_factor)) = queue.pop_front() {
495 if depth >= params.max_depth {
496 continue;
497 }
498
499 for edge in graph.edges_to(current_id) {
504 if edge.edge_type != EdgeType::CausedBy && edge.edge_type != EdgeType::Supports {
505 continue;
506 }
507 let dependent_id = edge.source_id;
508 if visited.contains(&dependent_id) {
509 continue;
510 }
511 visited.insert(dependent_id);
512
513 if let Some(dep_node) = graph.get_node(dependent_id) {
514 let decay = 0.7f32.powi(depth as i32 + 1);
516 let effective_weakening = weakening_factor * edge.weight * decay;
517 let revised = (dep_node.confidence - effective_weakening).clamp(0.0, 1.0);
518
519 weakened.push(WeakenedNode {
520 node_id: dependent_id,
521 original_confidence: dep_node.confidence,
522 revised_confidence: revised,
523 depth: depth + 1,
524 });
525
526 cascade.push(CascadeStep {
527 node_id: dependent_id,
528 via_edge: edge.edge_type,
529 from_node: current_id,
530 depth: depth + 1,
531 });
532
533 queue.push_back((dependent_id, depth + 1, effective_weakening));
534 }
535 }
536 }
537
538 let mut invalidated_decisions: Vec<u64> = Vec::new();
540 let affected_ids: HashSet<u64> = contradicted_ids
541 .iter()
542 .chain(weakened.iter().map(|w| &w.node_id))
543 .copied()
544 .collect();
545
546 for &node_id in &affected_ids {
547 if let Some(node) = graph.get_node(node_id) {
548 if node.event_type == EventType::Decision {
549 invalidated_decisions.push(node_id);
550 }
551 }
552 }
553 invalidated_decisions.sort_unstable();
554 invalidated_decisions.dedup();
555
556 let total_affected = affected_ids.len();
557
558 Ok(RevisionReport {
559 contradicted,
560 weakened,
561 invalidated_decisions,
562 total_affected,
563 cascade,
564 })
565 }
566
567 pub fn gap_detection(
576 &self,
577 graph: &MemoryGraph,
578 params: GapDetectionParams,
579 ) -> AmemResult<GapReport> {
580 let session_filter: Option<(u32, u32)> = params.session_range;
581 let mut gaps: Vec<Gap> = Vec::new();
582
583 for node in graph.nodes() {
584 if let Some((start, end)) = session_filter {
586 if node.session_id < start || node.session_id > end {
587 continue;
588 }
589 }
590
591 if node.event_type == EventType::Decision {
593 let incoming = graph.edges_to(node.id);
594 let has_justification = incoming.iter().any(|e| {
595 e.edge_type == EdgeType::CausedBy || e.edge_type == EdgeType::Supports
596 });
597 if !has_justification {
598 let downstream = self.count_downstream(graph, node.id);
599 gaps.push(Gap {
600 node_id: node.id,
601 gap_type: GapType::UnjustifiedDecision,
602 severity: 0.9, description: format!(
604 "Decision node {} has no CausedBy or Supports edges",
605 node.id
606 ),
607 downstream_count: downstream,
608 });
609 }
610 }
611
612 if node.event_type == EventType::Inference {
614 let incoming = graph.edges_to(node.id);
615 let support_count = incoming
616 .iter()
617 .filter(|e| e.edge_type == EdgeType::Supports)
618 .count();
619 if (support_count as u32) < params.min_support_count {
620 let downstream = self.count_downstream(graph, node.id);
621 gaps.push(Gap {
622 node_id: node.id,
623 gap_type: GapType::SingleSourceInference,
624 severity: 0.7,
625 description: format!(
626 "Inference node {} has only {} Supports edge(s), needs at least {}",
627 node.id, support_count, params.min_support_count
628 ),
629 downstream_count: downstream,
630 });
631 }
632 }
633
634 if (node.event_type == EventType::Fact || node.event_type == EventType::Inference)
636 && node.confidence < params.confidence_threshold
637 {
638 let dependents = graph.edges_to(node.id);
640 let has_dependents = dependents.iter().any(|e| {
641 e.edge_type == EdgeType::CausedBy || e.edge_type == EdgeType::Supports
642 });
643 if has_dependents {
644 let downstream = self.count_downstream(graph, node.id);
645 gaps.push(Gap {
646 node_id: node.id,
647 gap_type: GapType::LowConfidenceFoundation,
648 severity: 1.0 - node.confidence, description: format!(
650 "Node {} has confidence {:.2} (below {:.2}) and is depended upon",
651 node.id, node.confidence, params.confidence_threshold
652 ),
653 downstream_count: downstream,
654 });
655 }
656 }
657
658 {
660 let supersedes_count = self.count_supersedes_chain(graph, node.id);
661 if supersedes_count >= 3 {
662 let downstream = self.count_downstream(graph, node.id);
663 gaps.push(Gap {
664 node_id: node.id,
665 gap_type: GapType::UnstableKnowledge,
666 severity: (supersedes_count as f32 / 5.0).clamp(0.0, 1.0),
667 description: format!(
668 "Node {} has been superseded {} times (unstable)",
669 node.id, supersedes_count
670 ),
671 downstream_count: downstream,
672 });
673 }
674 }
675
676 if node.decay_score < 0.2 && node.event_type == EventType::Fact {
678 let has_dependents = graph.edges_to(node.id).iter().any(|e| {
679 e.edge_type == EdgeType::CausedBy || e.edge_type == EdgeType::Supports
680 });
681 if has_dependents {
682 let downstream = self.count_downstream(graph, node.id);
683 gaps.push(Gap {
684 node_id: node.id,
685 gap_type: GapType::StaleEvidence,
686 severity: 1.0 - node.decay_score,
687 description: format!(
688 "Fact node {} has decay score {:.2} and is depended upon",
689 node.id, node.decay_score
690 ),
691 downstream_count: downstream,
692 });
693 }
694 }
695 }
696
697 match params.sort_by {
699 GapSeverity::HighestImpact => {
700 gaps.sort_by(|a, b| b.downstream_count.cmp(&a.downstream_count));
701 }
702 GapSeverity::LowestConfidence => {
703 gaps.sort_by(|a, b| {
704 b.severity
705 .partial_cmp(&a.severity)
706 .unwrap_or(std::cmp::Ordering::Equal)
707 });
708 }
709 GapSeverity::MostRecent => {
710 gaps.sort_by(|a, b| {
711 let ts_a = graph.get_node(a.node_id).map(|n| n.created_at).unwrap_or(0);
712 let ts_b = graph.get_node(b.node_id).map(|n| n.created_at).unwrap_or(0);
713 ts_b.cmp(&ts_a)
714 });
715 }
716 }
717
718 let total_gaps = gaps.len();
720 let unjustified_decisions = gaps
721 .iter()
722 .filter(|g| g.gap_type == GapType::UnjustifiedDecision)
723 .count();
724 let single_source_inferences = gaps
725 .iter()
726 .filter(|g| g.gap_type == GapType::SingleSourceInference)
727 .count();
728 let low_confidence_foundations = gaps
729 .iter()
730 .filter(|g| g.gap_type == GapType::LowConfidenceFoundation)
731 .count();
732 let unstable_knowledge = gaps
733 .iter()
734 .filter(|g| g.gap_type == GapType::UnstableKnowledge)
735 .count();
736 let stale_evidence = gaps
737 .iter()
738 .filter(|g| g.gap_type == GapType::StaleEvidence)
739 .count();
740
741 let total_nodes = graph.node_count();
742 let health_score = if total_nodes > 0 {
743 1.0 - (total_gaps as f32 / total_nodes as f32).clamp(0.0, 1.0)
744 } else {
745 1.0
746 };
747
748 let summary = GapSummary {
749 total_gaps,
750 unjustified_decisions,
751 single_source_inferences,
752 low_confidence_foundations,
753 unstable_knowledge,
754 stale_evidence,
755 health_score,
756 };
757
758 gaps.truncate(params.max_results);
759
760 Ok(GapReport { gaps, summary })
761 }
762
763 fn count_downstream(&self, graph: &MemoryGraph, node_id: u64) -> usize {
765 let mut visited: HashSet<u64> = HashSet::new();
766 let mut queue: VecDeque<u64> = VecDeque::new();
767 visited.insert(node_id);
768 queue.push_back(node_id);
769
770 while let Some(current) = queue.pop_front() {
771 for edge in graph.edges_to(current) {
772 if (edge.edge_type == EdgeType::CausedBy || edge.edge_type == EdgeType::Supports)
773 && !visited.contains(&edge.source_id)
774 {
775 visited.insert(edge.source_id);
776 queue.push_back(edge.source_id);
777 }
778 }
779 }
780
781 visited.len().saturating_sub(1)
783 }
784
785 fn count_supersedes_chain(&self, graph: &MemoryGraph, node_id: u64) -> usize {
789 let mut count = 0usize;
790 let mut current = node_id;
791 let mut visited: HashSet<u64> = HashSet::new();
792 visited.insert(current);
793
794 loop {
796 let mut found = false;
797 for edge in graph.edges_to(current) {
798 if edge.edge_type == EdgeType::Supersedes && !visited.contains(&edge.source_id) {
799 visited.insert(edge.source_id);
800 current = edge.source_id;
801 count += 1;
802 found = true;
803 break;
804 }
805 }
806 if !found {
807 break;
808 }
809 }
810
811 current = node_id;
813 loop {
814 let mut found = false;
815 for edge in graph.edges_from(current) {
816 if edge.edge_type == EdgeType::Supersedes && !visited.contains(&edge.target_id) {
817 visited.insert(edge.target_id);
818 current = edge.target_id;
819 count += 1;
820 found = true;
821 break;
822 }
823 }
824 if !found {
825 break;
826 }
827 }
828
829 count
830 }
831
    /// Finds subgraphs structurally and semantically analogous to an anchor.
    ///
    /// The anchor is either an existing node or a raw vector (resolved to
    /// the most similar node). Each candidate node's context subgraph is
    /// fingerprinted and compared against the anchor's; the final score is
    /// 0.6 * structural + 0.4 * content similarity, filtered by
    /// `min_similarity` and capped at `max_results`.
    pub fn analogical(
        &self,
        graph: &MemoryGraph,
        params: AnalogicalParams,
    ) -> AmemResult<Vec<Analogy>> {
        let exclude_sessions: HashSet<u32> = params.exclude_sessions.iter().copied().collect();

        // Resolve the anchor to a concrete center node + feature vector.
        let (anchor_center, anchor_vec) = match &params.anchor {
            AnalogicalAnchor::Node(id) => {
                let node = graph
                    .get_node(*id)
                    .ok_or(crate::types::AmemError::NodeNotFound(*id))?;
                (*id, node.feature_vec.clone())
            }
            AnalogicalAnchor::Vector(v) => {
                // Nearest node (by cosine) with a non-zero feature vector.
                let mut best_id = 0u64;
                let mut best_sim = -1.0f32;
                for node in graph.nodes() {
                    if node.feature_vec.iter().all(|&x| x == 0.0) {
                        continue;
                    }
                    let sim = cosine_similarity(v, &node.feature_vec);
                    if sim > best_sim {
                        best_sim = sim;
                        best_id = node.id;
                    }
                }
                if best_sim < 0.0 {
                    // No node had a usable feature vector.
                    return Ok(Vec::new());
                }
                (best_id, v.clone())
            }
        };

        let anchor_subgraph = self.context(graph, anchor_center, params.context_depth)?;
        let anchor_fp =
            self.structural_fingerprint(graph, &anchor_subgraph.nodes, &anchor_subgraph.edges);
        let anchor_session = graph
            .get_node(anchor_center)
            .map(|n| n.session_id)
            .unwrap_or(0);

        let mut analogies: Vec<Analogy> = Vec::new();
        let anchor_node_set: HashSet<u64> = anchor_subgraph.nodes.iter().map(|n| n.id).collect();

        for node in graph.nodes() {
            // Nodes inside the anchor's own context are not analogies.
            if anchor_node_set.contains(&node.id) {
                continue;
            }
            if exclude_sessions.contains(&node.session_id) {
                continue;
            }
            // Skip candidates from the anchor's own session.
            // NOTE(review): the size guard only admits same-session nodes
            // when the anchor subgraph covers the entire graph — confirm
            // this is the intended escape hatch.
            if node.session_id == anchor_session
                && graph.nodes().len() > anchor_subgraph.nodes.len()
            {
                continue;
            }

            let candidate_subgraph = match self.context(graph, node.id, params.context_depth) {
                Ok(sg) => sg,
                Err(_) => continue,
            };

            let candidate_fp = self.structural_fingerprint(
                graph,
                &candidate_subgraph.nodes,
                &candidate_subgraph.edges,
            );
            let structural_sim = self.compare_fingerprints(&anchor_fp, &candidate_fp);

            // Content similarity only when both vectors carry signal.
            let content_sim = if !anchor_vec.iter().all(|&x| x == 0.0)
                && !node.feature_vec.iter().all(|&x| x == 0.0)
            {
                cosine_similarity(&anchor_vec, &node.feature_vec).max(0.0)
            } else {
                0.0
            };

            let combined = 0.6 * structural_sim + 0.4 * content_sim;

            if combined >= params.min_similarity {
                analogies.push(Analogy {
                    center_id: node.id,
                    structural_similarity: structural_sim,
                    content_similarity: content_sim,
                    combined_score: combined,
                    pattern: candidate_fp,
                    subgraph_nodes: candidate_subgraph.nodes.iter().map(|n| n.id).collect(),
                });
            }
        }

        // Best matches first, capped at max_results.
        analogies.sort_by(|a, b| {
            b.combined_score
                .partial_cmp(&a.combined_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        analogies.truncate(params.max_results);

        Ok(analogies)
    }
957
958 fn structural_fingerprint(
960 &self,
961 _graph: &MemoryGraph,
962 nodes: &[CognitiveEvent],
963 edges: &[Edge],
964 ) -> PatternMatch {
965 let mut event_type_counts: HashMap<u8, usize> = HashMap::new();
967 for node in nodes {
968 *event_type_counts.entry(node.event_type as u8).or_insert(0) += 1;
969 }
970
971 let mut edge_type_counts: HashMap<u8, usize> = HashMap::new();
973 for edge in edges {
974 *edge_type_counts.entry(edge.edge_type as u8).or_insert(0) += 1;
975 }
976
977 let node_set: HashSet<u64> = nodes.iter().map(|n| n.id).collect();
979 let causal_edges: Vec<&Edge> = edges
980 .iter()
981 .filter(|e| e.edge_type == EdgeType::CausedBy)
982 .collect();
983
984 let causal_chain_depth = if causal_edges.is_empty() {
985 0
986 } else {
987 let mut causal_adj: HashMap<u64, Vec<u64>> = HashMap::new();
989 for e in &causal_edges {
990 if node_set.contains(&e.source_id) && node_set.contains(&e.target_id) {
991 causal_adj.entry(e.source_id).or_default().push(e.target_id);
992 }
993 }
994 let mut max_depth = 0u32;
996 for &start_id in node_set.iter() {
997 let mut visited_local: HashSet<u64> = HashSet::new();
998 let mut q: VecDeque<(u64, u32)> = VecDeque::new();
999 visited_local.insert(start_id);
1000 q.push_back((start_id, 0));
1001 while let Some((cur, d)) = q.pop_front() {
1002 max_depth = max_depth.max(d);
1003 if let Some(neighbors) = causal_adj.get(&cur) {
1004 for &nb in neighbors {
1005 if !visited_local.contains(&nb) {
1006 visited_local.insert(nb);
1007 q.push_back((nb, d + 1));
1008 }
1009 }
1010 }
1011 }
1012 }
1013 max_depth
1014 };
1015
1016 let branching_factor = if nodes.is_empty() {
1018 0.0
1019 } else {
1020 let mut out_counts: HashMap<u64, usize> = HashMap::new();
1021 for n in nodes {
1022 out_counts.insert(n.id, 0);
1023 }
1024 for e in edges {
1025 if let Some(c) = out_counts.get_mut(&e.source_id) {
1026 *c += 1;
1027 }
1028 }
1029 let total: usize = out_counts.values().sum();
1030 total as f32 / nodes.len() as f32
1031 };
1032
1033 PatternMatch {
1034 event_type_counts,
1035 edge_type_counts,
1036 causal_chain_depth,
1037 branching_factor,
1038 }
1039 }
1040
1041 fn compare_fingerprints(&self, a: &PatternMatch, b: &PatternMatch) -> f32 {
1043 let et_sim = self.map_cosine_similarity(&a.event_type_counts, &b.event_type_counts);
1045
1046 let edge_sim = self.map_cosine_similarity(&a.edge_type_counts, &b.edge_type_counts);
1048
1049 let max_chain = a.causal_chain_depth.max(b.causal_chain_depth).max(1) as f32;
1051 let chain_sim =
1052 1.0 - (a.causal_chain_depth as f32 - b.causal_chain_depth as f32).abs() / max_chain;
1053
1054 let max_bf = a.branching_factor.max(b.branching_factor).max(0.01);
1056 let bf_sim = 1.0 - (a.branching_factor - b.branching_factor).abs() / max_bf;
1057
1058 0.3 * et_sim + 0.3 * edge_sim + 0.2 * chain_sim + 0.2 * bf_sim
1060 }
1061
1062 fn map_cosine_similarity(&self, a: &HashMap<u8, usize>, b: &HashMap<u8, usize>) -> f32 {
1064 let all_keys: HashSet<u8> = a.keys().chain(b.keys()).copied().collect();
1065 if all_keys.is_empty() {
1066 return 1.0; }
1068
1069 let mut dot = 0.0f64;
1070 let mut norm_a = 0.0f64;
1071 let mut norm_b = 0.0f64;
1072
1073 for &key in &all_keys {
1074 let va = *a.get(&key).unwrap_or(&0) as f64;
1075 let vb = *b.get(&key).unwrap_or(&0) as f64;
1076 dot += va * vb;
1077 norm_a += va * va;
1078 norm_b += vb * vb;
1079 }
1080
1081 let denom = norm_a.sqrt() * norm_b.sqrt();
1082 if denom < 1e-12 {
1083 0.0
1084 } else {
1085 (dot / denom) as f32
1086 }
1087 }
1088
1089 pub fn drift_detection(
1100 &self,
1101 graph: &MemoryGraph,
1102 params: DriftParams,
1103 ) -> AmemResult<DriftReport> {
1104 let tokenizer = Tokenizer::new();
1105 let topic_terms: HashSet<String> = tokenizer.tokenize(¶ms.topic).into_iter().collect();
1106
1107 if topic_terms.is_empty() && params.topic_vec.is_none() {
1108 return Ok(DriftReport {
1109 timelines: Vec::new(),
1110 stability: 1.0,
1111 likely_to_change: false,
1112 });
1113 }
1114
1115 let mut relevant: Vec<(u64, f32)> = Vec::new(); for node in graph.nodes() {
1119 let node_terms: HashSet<String> =
1120 tokenizer.tokenize(&node.content).into_iter().collect();
1121
1122 let overlap = topic_terms.intersection(&node_terms).count();
1124 let text_sim = if topic_terms.is_empty() {
1125 0.0
1126 } else {
1127 overlap as f32 / topic_terms.len() as f32
1128 };
1129
1130 let vec_sim = if let Some(ref tvec) = params.topic_vec {
1132 if !node.feature_vec.iter().all(|&x| x == 0.0) {
1133 cosine_similarity(tvec, &node.feature_vec).max(0.0)
1134 } else {
1135 0.0
1136 }
1137 } else {
1138 0.0
1139 };
1140
1141 let relevance = if params.topic_vec.is_some() {
1142 0.5 * text_sim + 0.5 * vec_sim
1143 } else {
1144 text_sim
1145 };
1146
1147 if relevance >= params.min_relevance {
1148 relevant.push((node.id, relevance));
1149 }
1150 }
1151
1152 if relevant.is_empty() {
1153 return Ok(DriftReport {
1154 timelines: Vec::new(),
1155 stability: 1.0,
1156 likely_to_change: false,
1157 });
1158 }
1159
1160 relevant.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1162
1163 let relevant_ids: HashSet<u64> = relevant.iter().map(|(id, _)| *id).collect();
1164
1165 let mut chain_roots: Vec<u64> = Vec::new();
1168 let mut assigned: HashSet<u64> = HashSet::new();
1169
1170 for &(node_id, _) in &relevant {
1172 let _supersedes_another = graph.edges_from(node_id).iter().any(|e| {
1174 e.edge_type == EdgeType::Supersedes && relevant_ids.contains(&e.target_id)
1175 });
1176
1177 let is_superseded = graph.edges_to(node_id).iter().any(|e| {
1179 e.edge_type == EdgeType::Supersedes && relevant_ids.contains(&e.source_id)
1180 });
1181
1182 if !is_superseded {
1185 chain_roots.push(node_id);
1186 }
1187 }
1188
1189 if chain_roots.is_empty() {
1191 for &(node_id, _) in relevant.iter().take(params.max_results) {
1193 chain_roots.push(node_id);
1194 }
1195 }
1196
1197 let mut timelines: Vec<BeliefTimeline> = Vec::new();
1198
1199 for &root_id in &chain_roots {
1200 if assigned.contains(&root_id) {
1201 continue;
1202 }
1203
1204 let mut chain: Vec<u64> = Vec::new();
1205 let mut current = root_id;
1206 let mut chain_visited: HashSet<u64> = HashSet::new();
1207
1208 chain_visited.insert(current);
1211 chain.push(current);
1212 assigned.insert(current);
1213
1214 loop {
1215 let mut next = None;
1216 for edge in graph.edges_from(current) {
1217 if edge.edge_type == EdgeType::Supersedes
1218 && !chain_visited.contains(&edge.target_id)
1219 {
1220 next = Some(edge.target_id);
1221 break;
1222 }
1223 }
1224 match next {
1225 Some(next_id) => {
1226 chain_visited.insert(next_id);
1227 chain.push(next_id);
1228 assigned.insert(next_id);
1229 current = next_id;
1230 }
1231 None => break,
1232 }
1233 }
1234
1235 current = root_id;
1237 loop {
1238 let mut prev = None;
1239 for edge in graph.edges_to(current) {
1240 if edge.edge_type == EdgeType::Supersedes
1241 && !chain_visited.contains(&edge.source_id)
1242 {
1243 prev = Some(edge.source_id);
1244 break;
1245 }
1246 }
1247 match prev {
1248 Some(prev_id) => {
1249 chain_visited.insert(prev_id);
1250 chain.insert(0, prev_id);
1251 assigned.insert(prev_id);
1252 current = prev_id;
1253 }
1254 None => break,
1255 }
1256 }
1257
1258 chain.sort_by_key(|&id| graph.get_node(id).map(|n| n.created_at).unwrap_or(0));
1262
1263 let mut snapshots: Vec<BeliefSnapshot> = Vec::new();
1265 let mut correction_count = 0usize;
1266 let mut contradiction_count = 0usize;
1267
1268 for (i, &nid) in chain.iter().enumerate() {
1269 if let Some(node) = graph.get_node(nid) {
1270 let change_type =
1271 if i == 0 {
1272 ChangeType::Initial
1273 } else {
1274 let prev_id = chain[i - 1];
1275 let has_supersedes = graph.edges_from(nid).iter().any(|e| {
1277 e.edge_type == EdgeType::Supersedes && e.target_id == prev_id
1278 });
1279 let has_contradicts = graph.edges_from(nid).iter().any(|e| {
1280 e.edge_type == EdgeType::Contradicts && e.target_id == prev_id
1281 }) || graph.edges_to(nid).iter().any(|e| {
1282 e.edge_type == EdgeType::Contradicts && e.source_id == prev_id
1283 });
1284 let has_supports = graph.edges_from(nid).iter().any(|e| {
1285 e.edge_type == EdgeType::Supports && e.target_id == prev_id
1286 }) || graph.edges_to(nid).iter().any(|e| {
1287 e.edge_type == EdgeType::Supports && e.source_id == prev_id
1288 });
1289
1290 if has_contradicts {
1291 ChangeType::Contradicted
1292 } else if node.event_type == EventType::Correction || has_supersedes {
1293 ChangeType::Corrected
1294 } else if has_supports {
1295 ChangeType::Reinforced
1296 } else {
1297 let prev_conf =
1299 graph.get_node(prev_id).map(|n| n.confidence).unwrap_or(0.0);
1300 if node.confidence >= prev_conf {
1301 ChangeType::Refined
1302 } else {
1303 ChangeType::Corrected
1304 }
1305 }
1306 };
1307
1308 match change_type {
1309 ChangeType::Corrected => correction_count += 1,
1310 ChangeType::Contradicted => contradiction_count += 1,
1311 _ => {}
1312 }
1313
1314 let content_preview = if node.content.len() > 120 {
1315 format!("{}...", &node.content[..120])
1316 } else {
1317 node.content.clone()
1318 };
1319
1320 snapshots.push(BeliefSnapshot {
1321 node_id: nid,
1322 session_id: node.session_id,
1323 created_at: node.created_at,
1324 confidence: node.confidence,
1325 content_preview,
1326 change_type,
1327 });
1328 }
1329 }
1330
1331 if !snapshots.is_empty() {
1332 let change_count = snapshots.len().saturating_sub(1);
1333 timelines.push(BeliefTimeline {
1334 snapshots,
1335 change_count,
1336 correction_count,
1337 contradiction_count,
1338 });
1339 }
1340 }
1341
1342 for &(node_id, _) in &relevant {
1344 if assigned.contains(&node_id) {
1345 continue;
1346 }
1347 assigned.insert(node_id);
1348
1349 if let Some(node) = graph.get_node(node_id) {
1350 let content_preview = if node.content.len() > 120 {
1351 format!("{}...", &node.content[..120])
1352 } else {
1353 node.content.clone()
1354 };
1355
1356 timelines.push(BeliefTimeline {
1357 snapshots: vec![BeliefSnapshot {
1358 node_id,
1359 session_id: node.session_id,
1360 created_at: node.created_at,
1361 confidence: node.confidence,
1362 content_preview,
1363 change_type: ChangeType::Initial,
1364 }],
1365 change_count: 0,
1366 correction_count: 0,
1367 contradiction_count: 0,
1368 });
1369 }
1370 }
1371
1372 timelines.sort_by(|a, b| b.change_count.cmp(&a.change_count));
1374 timelines.truncate(params.max_results);
1375
1376 let total_changes: usize = timelines.iter().map(|t| t.change_count).sum();
1378 let total_corrections: usize = timelines.iter().map(|t| t.correction_count).sum();
1379 let total_contradictions: usize = timelines.iter().map(|t| t.contradiction_count).sum();
1380 let total_snapshots: usize = timelines.iter().map(|t| t.snapshots.len()).sum();
1381
1382 let stability = if total_snapshots <= 1 {
1383 1.0
1384 } else {
1385 let volatility =
1386 (total_corrections + total_contradictions) as f32 / total_snapshots as f32;
1387 (1.0 - volatility).clamp(0.0, 1.0)
1388 };
1389
1390 let likely_to_change = if total_changes == 0 {
1393 false
1394 } else {
1395 let instability_ratio =
1396 (total_corrections + total_contradictions) as f32 / total_changes as f32;
1397 instability_ratio > 0.3
1398 };
1399
1400 Ok(DriftReport {
1401 timelines,
1402 stability,
1403 likely_to_change,
1404 })
1405 }
1406}