1use std::collections::HashSet;
11
12use hirn_core::id::MemoryId;
13use hirn_core::provenance::Provenance;
14use hirn_core::record::MemoryRecord;
15use hirn_core::types::{EdgeRelation, Namespace, Origin};
16
17use crate::graph_store::GraphStore;
18
19use hirn_core::error::HirnResult;
20
/// One directed hop of a causal chain, copied from a graph edge.
///
/// Links are produced by the chain traversal in this file, which fills every
/// field from the edge it followed.
#[derive(Debug, Clone)]
pub struct CausalLink {
    /// Node the traversal was expanding when it followed the edge.
    pub source: MemoryId,
    /// Node reached through the edge.
    pub target: MemoryId,
    /// The edge's `weight`; `link_score` falls back to it when neither
    /// `strength` nor `confidence` is set.
    pub weight: f32,
    /// Identifier of the underlying graph edge.
    pub edge_id: MemoryId,
    /// Value of `edge.strength()`, if the edge carries one.
    pub strength: Option<f32>,
    /// Value of `edge.confidence()`, if the edge carries one.
    pub confidence: Option<f32>,
    /// Value of `edge.evidence_count()`, if the edge carries one.
    pub evidence_count: Option<u32>,
    /// Owned copy of `edge.provenance()`, if present.
    pub provenance: Option<String>,
    /// Owned copy of `edge.mechanism()`, if present.
    pub mechanism: Option<String>,
}
41
/// An ordered sequence of causal links forming one path through the graph.
#[derive(Debug, Clone)]
pub struct CausalChain {
    /// Links in traversal order; each link's `target` is the `source` of the
    /// next one when the chain comes from the traversal in this file.
    pub links: Vec<CausalLink>,
}
47
48impl CausalChain {
49 pub fn start(&self) -> Option<MemoryId> {
51 self.links.first().map(|l| l.source)
52 }
53
54 pub fn end(&self) -> Option<MemoryId> {
56 self.links.last().map(|l| l.target)
57 }
58
59 pub fn node_ids(&self) -> Vec<MemoryId> {
61 if self.links.is_empty() {
62 return vec![];
63 }
64 let mut ids = vec![self.links[0].source];
65 for link in &self.links {
66 ids.push(link.target);
67 }
68 ids
69 }
70
71 pub fn depth(&self) -> usize {
73 self.links.len()
74 }
75
76 pub fn avg_weight(&self) -> f32 {
78 if self.links.is_empty() {
79 return 0.0;
80 }
81 let sum: f32 = self.links.iter().map(|l| l.weight).sum();
82 sum / self.links.len() as f32
83 }
84}
85
/// Outcome of a causal-chain traversal.
#[derive(Debug, Clone)]
pub struct CausalChainResult {
    /// Chains found by the traversal; empty when nothing was reachable.
    pub chains: Vec<CausalChain>,
    /// Set when the traversal met an edge leading back to a node that was
    /// already on the path being extended.
    pub cycles_detected: bool,
}
94
95pub fn causal_relevance(result: &CausalChainResult) -> f32 {
110 if result.chains.is_empty() {
111 return 0.0;
112 }
113 let mut max_score: f32 = 0.0;
114 for chain in &result.chains {
115 if chain.links.is_empty() {
116 continue;
117 }
118 let sum: f32 = chain.links.iter().map(|l| link_score(l)).sum();
119 let avg = sum / chain.links.len() as f32;
120 max_score = max_score.max(avg);
121 }
122 max_score.clamp(0.0, 1.0)
123}
124
125fn link_score(link: &CausalLink) -> f32 {
128 match (link.strength, link.confidence) {
129 (Some(s), Some(c)) => {
130 let ev = link.evidence_count.unwrap_or(1).max(1) as f32;
131 s * c * (1.0 + ev).ln()
132 }
133 (Some(s), None) => s,
134 (None, Some(c)) => c,
135 (None, None) => link.weight,
136 }
137}
138
/// Enumerates causal chains starting at `start` by following
/// `EdgeRelation::Causes` edges.
///
/// Thin wrapper over `extract_causal_chains`:
/// - `max_depth` bounds the number of links per chain; `0` yields an empty result.
/// - `confidence_threshold` drops edges whose confidence (taken as `0.5` when
///   the edge has none) is below it.
/// - `allowed_namespaces`, when `Some`, restricts the start node and every
///   edge target to the listed namespaces.
///
/// # Errors
///
/// Propagates any error returned by the `GraphStore` calls.
pub async fn causal_chain_forward(
    store: &dyn GraphStore,
    start: MemoryId,
    max_depth: usize,
    confidence_threshold: f32,
    allowed_namespaces: Option<&[Namespace]>,
) -> HirnResult<CausalChainResult> {
    extract_causal_chains(
        store,
        start,
        max_depth,
        EdgeRelation::Causes,
        confidence_threshold,
        allowed_namespaces,
    )
    .await
}
159
/// Enumerates causal chains starting at `start` by following
/// `EdgeRelation::CausedBy` edges.
///
/// Thin wrapper over `extract_causal_chains`:
/// - `max_depth` bounds the number of links per chain; `0` yields an empty result.
/// - `confidence_threshold` drops edges whose confidence (taken as `0.5` when
///   the edge has none) is below it.
/// - `allowed_namespaces`, when `Some`, restricts the start node and every
///   edge target to the listed namespaces.
///
/// # Errors
///
/// Propagates any error returned by the `GraphStore` calls.
pub async fn causal_chain_backward(
    store: &dyn GraphStore,
    start: MemoryId,
    max_depth: usize,
    confidence_threshold: f32,
    allowed_namespaces: Option<&[Namespace]>,
) -> HirnResult<CausalChainResult> {
    extract_causal_chains(
        store,
        start,
        max_depth,
        EdgeRelation::CausedBy,
        confidence_threshold,
        allowed_namespaces,
    )
    .await
}
178
179async fn extract_causal_chains(
181 store: &dyn GraphStore,
182 start: MemoryId,
183 max_depth: usize,
184 relation: EdgeRelation,
185 confidence_threshold: f32,
186 allowed_namespaces: Option<&[Namespace]>,
187) -> HirnResult<CausalChainResult> {
188 if max_depth == 0 || !store.has_node(start).await? {
189 return Ok(CausalChainResult {
190 chains: vec![],
191 cycles_detected: false,
192 });
193 }
194
195 if let Some(allowed) = allowed_namespaces {
196 let Some(namespace) = store.node_namespace(start).await? else {
197 return Ok(CausalChainResult {
198 chains: vec![],
199 cycles_detected: false,
200 });
201 };
202 if !allowed.contains(&namespace) {
203 return Ok(CausalChainResult {
204 chains: vec![],
205 cycles_detected: false,
206 });
207 }
208 }
209
210 let mut chains = Vec::new();
211 let mut cycles_detected = false;
212
213 let mut stack: Vec<(MemoryId, Vec<CausalLink>, HashSet<MemoryId>)> = Vec::new();
214 let mut initial_visited = HashSet::new();
215 initial_visited.insert(start);
216 stack.push((start, Vec::new(), initial_visited));
217
218 while let Some((current, path, visited)) = stack.pop() {
219 if path.len() >= max_depth {
220 if !path.is_empty() {
221 chains.push(CausalChain { links: path });
222 }
223 continue;
224 }
225
226 let edges = store.get_edges_of_type(current, relation).await?;
227 let mut outgoing = Vec::new();
228 for edge in &edges {
229 if edge.source != current {
230 continue;
231 }
232
233 let confidence = edge.confidence().unwrap_or(0.5);
234 if confidence < confidence_threshold {
235 continue;
236 }
237
238 if let Some(allowed) = allowed_namespaces {
239 let Some(namespace) = store.node_namespace(edge.target).await? else {
240 continue;
241 };
242 if !allowed.contains(&namespace) {
243 continue;
244 }
245 }
246
247 outgoing.push(edge);
248 }
249
250 if outgoing.is_empty() {
251 if !path.is_empty() {
252 chains.push(CausalChain { links: path });
253 }
254 continue;
255 }
256
257 let mut any_extended = false;
258 for edge in &outgoing {
259 let target = edge.target;
260 if visited.contains(&target) {
261 cycles_detected = true;
262 if !path.is_empty() {
263 chains.push(CausalChain {
264 links: path.clone(),
265 });
266 }
267 continue;
268 }
269
270 any_extended = true;
271 let link = CausalLink {
272 source: current,
273 target,
274 weight: edge.weight,
275 edge_id: edge.id,
276 strength: edge.strength(),
277 confidence: edge.confidence(),
278 evidence_count: edge.evidence_count(),
279 provenance: edge.provenance().map(str::to_owned),
280 mechanism: edge.mechanism().map(str::to_owned),
281 };
282 let mut new_path = path.clone();
283 new_path.push(link);
284 let mut new_visited = visited.clone();
285 new_visited.insert(target);
286 stack.push((target, new_path, new_visited));
287 }
288
289 if !any_extended && !path.is_empty() {
290 }
292 }
293
294 chains.sort_by_key(|c| std::cmp::Reverse(c.depth()));
295 chains.dedup_by(|a, b| a.node_ids() == b.node_ids());
296
297 Ok(CausalChainResult {
298 chains,
299 cycles_detected,
300 })
301}
302
/// A pair of memories related by a counterfactual constraint.
///
/// NOTE(review): nothing in this part of the file constructs or reads this
/// type, so the field descriptions below follow the field names only —
/// confirm against the code that produces it.
#[derive(Debug, Clone)]
pub struct Counterfactual {
    /// First memory of the pair.
    pub memory_a: MemoryId,
    /// Second memory of the pair.
    pub memory_b: MemoryId,
    /// Kind of constraint relating the two memories.
    pub constraint: CounterfactualConstraint,
    /// Free-text explanation accompanying the constraint.
    pub explanation: String,
}
313
/// Kind of constraint recorded in a `Counterfactual`.
///
/// NOTE(review): the variants are not constructed or matched in this part of
/// the file; the descriptions are presumed from the variant names — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CounterfactualConstraint {
    /// Presumably: the two memories state opposite things.
    DirectContradiction,
    /// Presumably: the two memories cannot both hold given their timing.
    TemporalImpossibility,
    /// Presumably: one memory replaces the other over time.
    TemporalSupersession,
}
325
326pub fn compute_trust_score(provenance: &Provenance, contradiction_count: usize) -> f32 {
338 let origin_trust = match *provenance.origin() {
340 Origin::DirectObservation => 1.0,
341 Origin::UserProvided => 0.9,
342 Origin::LlmExtraction => 0.7,
343 Origin::Consolidation => {
344 let evidence_count = provenance.confidence_basis.len();
346 if evidence_count >= 5 {
347 0.9
348 } else if evidence_count >= 3 {
349 0.8
350 } else if evidence_count >= 1 {
351 0.6
352 } else {
353 0.4
354 }
355 }
356 Origin::CrossAgent => 0.6,
357 Origin::DreamReplay => 0.3, };
359
360 let unique_sources: HashSet<MemoryId> = provenance
362 .confidence_basis
363 .iter()
364 .map(|e| e.source_id)
365 .collect();
366 let diversity_bonus = if unique_sources.len() >= 5 {
367 0.1
368 } else if unique_sources.len() >= 3 {
369 0.05
370 } else {
371 0.0
372 };
373
374 let mutations_without_evidence = count_mutations_without_evidence(provenance);
376 let mutation_penalty = (mutations_without_evidence as f32 * 0.05).min(0.3);
377
378 let contradiction_penalty = (contradiction_count as f32 * 0.1).min(0.3);
380
381 let score = origin_trust + diversity_bonus - mutation_penalty - contradiction_penalty;
382 score.clamp(0.0, 1.0)
383}
384
385fn count_mutations_without_evidence(provenance: &Provenance) -> usize {
387 let evidence_count = provenance.confidence_basis.len();
390 let mutation_count = provenance.mutation_log.len();
391 mutation_count.saturating_sub(evidence_count)
392}
393
/// Result of checking new content against similar existing records.
#[derive(Debug, Clone)]
pub struct ContradictionDetection {
    /// Ids of the candidate records judged to contradict the new content.
    pub contradicting_ids: Vec<MemoryId>,
    /// `true` exactly when `contradicting_ids` is non-empty.
    pub has_contradictions: bool,
}
404
/// Borrowed view of an existing record that is similar to content about to
/// be inserted; input to `detect_contradictions_on_insert`.
#[derive(Clone, Copy)]
pub struct InsertionCandidateRecord<'a> {
    /// Id reported back when this record is judged contradictory.
    pub id: MemoryId,
    /// The record's text; compared word-by-word against the lowercased new
    /// content, so it is expected to be lowercased already.
    pub content_lower: &'a str,
    /// Whether the record's text contains a negation cue.
    pub has_negation: bool,
    /// Entities attached to the record; matched by exact string equality.
    pub entities: &'a [String],
    /// Similarity to the new content; candidates below the caller's
    /// threshold are ignored.
    pub similarity: f32,
}
413
414pub fn detect_contradictions_on_insert(
422 content: &str,
423 entities: &[String],
424 similar_records: &[InsertionCandidateRecord<'_>],
425 similarity_threshold: f32,
426) -> ContradictionDetection {
427 let mut contradicting_ids = Vec::new();
428 let content_lower = content.to_lowercase();
429 let new_has_negation = contains_negation(&content_lower);
430
431 for candidate in similar_records {
432 if candidate.similarity < similarity_threshold {
433 continue;
434 }
435
436 let existing_content = candidate.content_lower;
437 let existing_has_negation = candidate.has_negation;
438
439 let negation_conflict = new_has_negation != existing_has_negation
441 && content_similarity_simple(&content_lower, existing_content) > 0.3;
442
443 let entity_conflict = if !entities.is_empty() {
445 let shared_entities: Vec<&String> = entities
446 .iter()
447 .filter(|e| candidate.entities.iter().any(|ee| ee == *e))
448 .collect();
449 !shared_entities.is_empty()
451 && content_similarity_simple(&content_lower, existing_content) < 0.8
452 && (new_has_negation
453 || existing_has_negation
454 || value_conflict(&content_lower, existing_content))
455 } else {
456 false
457 };
458
459 if negation_conflict || entity_conflict {
460 contradicting_ids.push(candidate.id);
461 }
462 }
463
464 ContradictionDetection {
465 has_contradictions: !contradicting_ids.is_empty(),
466 contradicting_ids,
467 }
468}
469
/// Provenance and lineage summary for a single record, as assembled by
/// `build_trace_report`.
#[derive(Debug, Clone)]
pub struct TraceReport {
    /// The record being traced.
    pub record: MemoryRecord,
    /// The record's provenance, as supplied by the caller.
    pub provenance: Provenance,
    /// Source episode ids, as supplied by the caller.
    pub source_episodes: Vec<MemoryId>,
    /// Sources of `DerivedFrom` edges that point at this record.
    pub derived_records: Vec<MemoryId>,
    /// Number of entries in the provenance mutation log.
    pub mutation_count: usize,
    /// Output of `compute_trust_score` for the provenance and the number of
    /// `Contradicts` edges found for the record.
    pub trust_score: f32,
    /// Human-readable, multi-line lineage description.
    pub lineage_tree: String,
}
490
491pub async fn build_trace_report(
493 store: &dyn GraphStore,
494 record: MemoryRecord,
495 provenance: Provenance,
496 source_episodes: Vec<MemoryId>,
497) -> HirnResult<TraceReport> {
498 let record_id = record.id();
499
500 let derived_edges = store
501 .get_edges_of_type(record_id, EdgeRelation::DerivedFrom)
502 .await?;
503 let derived_records: Vec<MemoryId> = derived_edges
504 .iter()
505 .filter(|e| e.target == record_id)
506 .map(|e| e.source)
507 .collect();
508
509 let contra_edges = store
510 .get_edges_of_type(record_id, EdgeRelation::Contradicts)
511 .await?;
512 let contradiction_count = contra_edges.len();
513
514 let trust_score = compute_trust_score(&provenance, contradiction_count);
515 let mutation_count = provenance.mutation_log.len();
516
517 let lineage_tree =
518 format_lineage_tree(record_id, &provenance, &source_episodes, &derived_records);
519
520 Ok(TraceReport {
521 record,
522 provenance,
523 source_episodes,
524 derived_records,
525 mutation_count,
526 trust_score,
527 lineage_tree,
528 })
529}
530
531fn format_lineage_tree(
533 record_id: MemoryId,
534 provenance: &Provenance,
535 source_episodes: &[MemoryId],
536 derived_records: &[MemoryId],
537) -> String {
538 let mut out = String::new();
539 out.push_str(&format!("Lineage for {record_id}:\n"));
540 out.push_str(&format!(" Origin: {:?}\n", provenance.origin()));
541 out.push_str(&format!(" Created by: {}\n", provenance.created_by));
542
543 if !source_episodes.is_empty() {
544 out.push_str(" Source episodes:\n");
545 for ep in source_episodes {
546 out.push_str(&format!(" <- {ep}\n"));
547 }
548 }
549
550 if let Some(ref model) = provenance.extraction_model {
551 out.push_str(&format!(" Extraction model: {model}\n"));
552 }
553
554 if !provenance.mutation_log.is_empty() {
555 out.push_str(&format!(
556 " Mutations ({}):\n",
557 provenance.mutation_log.len()
558 ));
559 for m in &provenance.mutation_log {
560 out.push_str(&format!(
561 " [{:?}] {}: {} -> {} ({})\n",
562 m.trigger, m.field, m.old_value, m.new_value, m.reason
563 ));
564 }
565 }
566
567 if !derived_records.is_empty() {
568 out.push_str(" Derived records:\n");
569 for d in derived_records {
570 out.push_str(&format!(" -> {d}\n"));
571 }
572 }
573
574 out
575}
576
/// Returns the main text of a record as a string slice.
///
/// Episodic and working records expose their `content`; semantic and
/// procedural records expose their `description`.
pub fn record_content_str(record: &MemoryRecord) -> &str {
    match record {
        MemoryRecord::Episodic(e) => &e.content,
        MemoryRecord::Semantic(s) => &s.description,
        MemoryRecord::Working(w) => &w.content,
        MemoryRecord::Procedural(p) => &p.description,
    }
}
588
/// Returns `true` when `text` contains a negation or contrast cue.
///
/// `text` is expected to be lowercased already (all cues are lowercase).
/// Besides grammatical negations the list contains contrast and failure cues
/// such as `however`, `failed` or `slower`.
///
/// Cues are matched as whole words or phrases: a cue counts only when it is
/// delimited by the start/end of the text or by a non-alphanumeric character.
/// This finds cues at the end of the text or before punctuation
/// (`"it is not"`, `"no, thanks"`) and ignores accidental substrings
/// (`"piano "`, `"minor "`, `"knot "`). The contraction suffix `n't` is the one
/// exception: it may follow letters, as in `"doesn't"`.
pub(crate) fn contains_negation(text: &str) -> bool {
    // Trailing spaces are kept from the original list and trimmed before
    // matching; word boundaries are checked explicitly below.
    let negation_patterns = [
        "not ",
        "n't ",
        "never ",
        "no ",
        "doesn't ",
        "didn't ",
        "isn't ",
        "wasn't ",
        "aren't ",
        "won't ",
        "cannot ",
        "can't ",
        "shouldn't ",
        "wouldn't ",
        "hasn't ",
        "haven't ",
        "weren't ",
        "couldn't ",
        "needn't ",
        "shan't ",
        "nor ",
        "neither ",
        "nowhere ",
        "nothing ",
        "nobody ",
        "hardly ",
        "barely ",
        "scarcely ",
        "seldom ",
        "rarely ",
        "however ",
        "actually ",
        "instead ",
        "contrary ",
        "incorrect ",
        "false ",
        "wrong ",
        "failed ",
        "impossible ",
        "unlike ",
        "rather than ",
        "on the contrary",
        "slower ",
    ];

    // A neighbour is a boundary when it is absent or cannot be part of a word.
    let is_boundary = |c: Option<char>| !matches!(c, Some(ch) if ch.is_alphanumeric());

    negation_patterns.iter().any(|pat| {
        let needle = pat.trim_end();
        // "n't" is a suffix, so letters in front of it are expected.
        let is_suffix = needle == "n't";
        text.match_indices(needle).any(|(at, _)| {
            let before = text[..at].chars().next_back();
            let after = text[at + needle.len()..].chars().next();
            (is_suffix || is_boundary(before)) && is_boundary(after)
        })
    })
}
645
/// Jaccard similarity of the whitespace-separated word sets of `a` and `b`.
///
/// Returns the number of shared distinct words divided by the number of
/// distinct words overall, or `0.0` when neither text contains a word.
fn content_similarity_simple(a: &str, b: &str) -> f32 {
    let left: HashSet<&str> = a.split_whitespace().collect();
    let right: HashSet<&str> = b.split_whitespace().collect();

    let shared = left.iter().filter(|word| right.contains(*word)).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let total = left.len() + right.len() - shared;

    match total {
        0 => 0.0,
        _ => shared as f32 / total as f32,
    }
}
657
/// Returns `true` when both texts contain numbers and the numbers disagree.
///
/// The texts conflict when each contains at least one number and one of them
/// mentions a value the other does not. Two texts carrying the same set of
/// values never conflict, however many values there are; comparing every
/// pair instead would flag `"10 cores, 5 disks"` against itself.
fn value_conflict(a: &str, b: &str) -> bool {
    let nums_a = extract_numbers(a);
    let nums_b = extract_numbers(b);

    if nums_a.is_empty() || nums_b.is_empty() {
        return false;
    }

    let same = |x: f64, y: f64| (x - y).abs() <= f64::EPSILON;
    // True when `from` holds a value with no counterpart in `within`.
    let has_unmatched = |from: &[f64], within: &[f64]| {
        from.iter()
            .any(|&x| !within.iter().any(|&y| same(x, y)))
    };

    has_unmatched(&nums_a, &nums_b) || has_unmatched(&nums_b, &nums_a)
}

/// Extracts the finite numbers found in whitespace-separated words of `text`.
///
/// Before parsing, each word loses trailing ASCII punctuation (so `"10gb,"`
/// and `"75%."` are recognised), then a trailing alphabetic unit suffix
/// (`"10gb"` → `10`), then a trailing `%`. Words that do not parse as `f64`
/// afterwards are skipped, as are non-finite results (e.g. from `"nan%"`).
fn extract_numbers(text: &str) -> Vec<f64> {
    text.split_whitespace()
        .filter_map(|word| {
            word
                // '%' is handled last so that "75%." and "75%" behave alike.
                .trim_end_matches(|c: char| c.is_ascii_punctuation() && c != '%')
                .trim_end_matches(|c: char| c.is_alphabetic())
                .trim_end_matches('%')
                .parse::<f64>()
                .ok()
        })
        .filter(|n| n.is_finite())
        .collect()
}
695
// Unit tests for trust scoring, the contradiction heuristics and causal
// relevance scoring defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::provenance::{EvidenceRef, Provenance};
    use hirn_core::timestamp::Timestamp;
    use hirn_core::types::MutationTrigger;

    // ---- compute_trust_score ----

    // A direct provenance with no mutations or contradictions scores near 1.
    #[test]
    fn direct_observation_high_trust() {
        let p = Provenance::direct(hirn_core::types::AgentId::new("test").unwrap());
        let score = compute_trust_score(&p, 0);
        assert!(score >= 0.95, "score={score}");
    }

    // LLM extraction has a base trust of 0.7 and nothing else applies here.
    #[test]
    fn llm_extraction_lower_trust() {
        let p = Provenance::with_origin(
            Origin::LlmExtraction,
            hirn_core::types::AgentId::new("test").unwrap(),
        );
        let score = compute_trust_score(&p, 0);
        assert!((score - 0.7).abs() < 0.05, "score={score}");
    }

    // Five evidence entries with fresh ids: base 0.9 plus the diversity bonus.
    #[test]
    fn consolidation_with_diverse_sources_high_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::with_origin(Origin::Consolidation, agent);
        for i in 0..5 {
            p.confidence_basis.push(EvidenceRef {
                source_id: MemoryId::new(),
                description: format!("source {i}"),
            });
        }
        let score = compute_trust_score(&p, 0);
        assert!(score > 0.8, "score={score}");
    }

    // A single evidence entry puts consolidation in the 0.6 tier, no bonus.
    #[test]
    fn consolidation_with_single_source_low_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::with_origin(Origin::Consolidation, agent);
        p.confidence_basis.push(EvidenceRef {
            source_id: MemoryId::new(),
            description: "only source".to_string(),
        });
        let score = compute_trust_score(&p, 0);
        assert!(score < 0.7, "score={score}");
    }

    // Three mutations and no evidence: each unsupported mutation costs 0.05.
    #[test]
    fn mutations_without_evidence_reduce_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::direct(agent);
        for i in 0..3 {
            p.record_mutation(hirn_core::provenance::Mutation {
                timestamp: Timestamp::now(),
                trigger: MutationTrigger::Reconsolidation,
                field: "description".to_string(),
                old_value: format!("old {i}"),
                new_value: format!("new {i}"),
                reason: "test".to_string(),
            });
        }
        let score = compute_trust_score(&p, 0);
        assert!(score < 1.0, "score={score}");
        assert!(score > 0.7, "score={score}");
    }

    // Every mutation is matched by an evidence entry, so no mutation penalty
    // applies (the three distinct sources also earn a small bonus).
    #[test]
    fn mutations_with_evidence_maintain_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::direct(agent);
        for i in 0..3 {
            p.confidence_basis.push(EvidenceRef {
                source_id: MemoryId::new(),
                description: format!("evidence {i}"),
            });
            p.record_mutation(hirn_core::provenance::Mutation {
                timestamp: Timestamp::now(),
                trigger: MutationTrigger::Reconsolidation,
                field: "description".to_string(),
                old_value: format!("old {i}"),
                new_value: format!("new {i}"),
                reason: "supported update".to_string(),
            });
        }
        let score = compute_trust_score(&p, 0);
        assert!(score >= 0.95, "score={score}");
    }

    // ---- contradiction heuristics ----

    #[test]
    fn negation_detection() {
        assert!(contains_negation("hnsw is not faster"));
        assert!(contains_negation("it doesn't work"));
        assert!(contains_negation("system never recovered"));
        assert!(!contains_negation("system is fast"));
    }

    // Unit suffixes ("gb") are stripped before the numbers are compared.
    #[test]
    fn value_conflict_detection() {
        assert!(value_conflict(
            "system uses 10gb ram",
            "system uses 5gb ram"
        ));
        assert!(!value_conflict(
            "system uses 10gb ram",
            "system uses 10gb ram"
        ));
    }

    #[test]
    fn content_similarity_identical() {
        let sim = content_similarity_simple("hello world test", "hello world test");
        assert!((sim - 1.0).abs() < f64::EPSILON as f32);
    }

    #[test]
    fn content_similarity_different() {
        let sim = content_similarity_simple("hello world", "foo bar baz");
        assert!(sim < 0.1);
    }

    // ---- causal relevance ----

    // Builds a link with fresh ids and no evidence count, provenance or
    // mechanism; only the scoring inputs vary between tests.
    fn make_link(weight: f32, strength: Option<f32>, confidence: Option<f32>) -> CausalLink {
        CausalLink {
            source: MemoryId::new(),
            target: MemoryId::new(),
            weight,
            edge_id: MemoryId::new(),
            strength,
            confidence,
            evidence_count: None,
            provenance: None,
            mechanism: None,
        }
    }

    #[test]
    fn causal_relevance_empty_chains() {
        let result = CausalChainResult {
            chains: vec![],
            cycles_detected: false,
        };
        assert!((causal_relevance(&result)).abs() < f32::EPSILON);
    }

    #[test]
    fn causal_relevance_uses_strength_and_confidence() {
        let link = make_link(0.5, Some(0.9), Some(0.8));
        let result = CausalChainResult {
            chains: vec![CausalChain { links: vec![link] }],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        // With no evidence count the evidence term is ln(1 + 1) = ln(2).
        let expected = 0.9 * 0.8 * (2.0_f32).ln();
        assert!(
            (score - expected).abs() < 0.01,
            "score={score}, expected={expected}"
        );
    }

    #[test]
    fn causal_relevance_falls_back_to_weight() {
        let link = make_link(0.6, None, None);
        let result = CausalChainResult {
            chains: vec![CausalChain { links: vec![link] }],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        assert!((score - 0.6).abs() < 0.01, "score={score}");
    }

    // The weak chain scores 0.2 (weight fallback); the strong one scores
    // 0.95 * 0.95 * ln(2), which is above 0.5.
    #[test]
    fn causal_relevance_takes_max_across_chains() {
        let weak = make_link(0.2, None, None);
        let strong = make_link(0.0, Some(0.95), Some(0.95));
        let result = CausalChainResult {
            chains: vec![
                CausalChain { links: vec![weak] },
                CausalChain {
                    links: vec![strong],
                },
            ],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        assert!(score > 0.5, "score={score}");
    }

    #[test]
    fn causal_relevance_averages_links_in_chain() {
        // Link scores: 1.0 * 1.0 * ln(2) and 0.5 * 0.5 * ln(2).
        let l1 = make_link(0.0, Some(1.0), Some(1.0));
        let l2 = make_link(0.0, Some(0.5), Some(0.5));
        let result = CausalChainResult {
            chains: vec![CausalChain {
                links: vec![l1, l2],
            }],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        let expected = f32::midpoint(1.0 * 1.0 * (2.0_f32).ln(), 0.5 * 0.5 * (2.0_f32).ln());
        assert!(
            (score - expected).abs() < 0.01,
            "score={score}, expected={expected}"
        );
    }

    // With only one of strength/confidence set, that value is the score and
    // the weight is ignored.
    #[test]
    fn link_score_strength_only() {
        let link = make_link(0.3, Some(0.8), None);
        assert!((link_score(&link) - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn link_score_confidence_only() {
        let link = make_link(0.3, None, Some(0.7));
        assert!((link_score(&link) - 0.7).abs() < f32::EPSILON);
    }
}