Skip to main content

hirn_engine/graph/
causal.rs

1//! Causal reasoning, trust scoring, and contradiction detection.
2//!
3//! This module provides:
4//! - Causal chain extraction via directed graph traversal
5//! - Causal influence scoring (ε weight in composite formula)
6//! - Counterfactual constraint detection
7//! - Trustworthiness scoring engine
8//! - Automatic contradiction detection on insertion
9
10use std::collections::HashSet;
11
12use hirn_core::id::MemoryId;
13use hirn_core::provenance::Provenance;
14use hirn_core::record::MemoryRecord;
15use hirn_core::types::{EdgeRelation, Namespace, Origin};
16
17use crate::graph_store::GraphStore;
18
19use hirn_core::error::HirnResult;
20
21// ── Causal Chain Types ─────────────────────────────────────────────────
22
23/// A single link in a causal chain: source → (edge) → target.
24#[derive(Debug, Clone)]
25pub struct CausalLink {
26    pub source: MemoryId,
27    pub target: MemoryId,
28    pub weight: f32,
29    pub edge_id: MemoryId,
30    /// Causal strength in `[0, 1]`.  `None` for legacy edges.
31    pub strength: Option<f32>,
32    /// Confidence in `[0, 1]`.  `None` for legacy edges.
33    pub confidence: Option<f32>,
34    /// Evidence count supporting this causal link.
35    pub evidence_count: Option<u32>,
36    /// Free-text provenance tag (e.g. "RCT", "observational").
37    pub provenance: Option<String>,
38    /// Mechanism description (e.g. "dopamine release").
39    pub mechanism: Option<String>,
40}
41
42/// A complete causal chain: an ordered sequence of links forming a path.
43#[derive(Debug, Clone)]
44pub struct CausalChain {
45    pub links: Vec<CausalLink>,
46}
47
48impl CausalChain {
49    /// The starting node of this chain.
50    pub fn start(&self) -> Option<MemoryId> {
51        self.links.first().map(|l| l.source)
52    }
53
54    /// The ending node of this chain.
55    pub fn end(&self) -> Option<MemoryId> {
56        self.links.last().map(|l| l.target)
57    }
58
59    /// All node IDs in this chain (in order).
60    pub fn node_ids(&self) -> Vec<MemoryId> {
61        if self.links.is_empty() {
62            return vec![];
63        }
64        let mut ids = vec![self.links[0].source];
65        for link in &self.links {
66            ids.push(link.target);
67        }
68        ids
69    }
70
71    /// Number of hops in this chain.
72    pub fn depth(&self) -> usize {
73        self.links.len()
74    }
75
76    /// Average edge weight across the chain.
77    pub fn avg_weight(&self) -> f32 {
78        if self.links.is_empty() {
79            return 0.0;
80        }
81        let sum: f32 = self.links.iter().map(|l| l.weight).sum();
82        sum / self.links.len() as f32
83    }
84}
85
86/// Result of causal chain extraction.
87#[derive(Debug, Clone)]
88pub struct CausalChainResult {
89    /// All causal chains found from the starting node.
90    pub chains: Vec<CausalChain>,
91    /// Whether any cycles were detected (and broken).
92    pub cycles_detected: bool,
93}
94
95// ── Causal Relevance Scoring ────────────────────────────────────────────
96
97/// Compute the causal-relevance score ε for a memory that participates in one
98/// or more causal chains.  Returns a value in `[0, 1]`.
99///
100/// The score is the **maximum** across all chains that touch the memory,
101/// computed as:
102///
103///   `chain_score = avg(link_score)` over every link in the chain
104///   `link_score  = strength * confidence` when both are present,
105///                  or `weight` as a fallback for legacy edges.
106///
107/// Using the max (rather than mean) ensures that a single strong causal
108/// connection is not diluted by unrelated weak chains.
109pub fn causal_relevance(result: &CausalChainResult) -> f32 {
110    if result.chains.is_empty() {
111        return 0.0;
112    }
113    let mut max_score: f32 = 0.0;
114    for chain in &result.chains {
115        if chain.links.is_empty() {
116            continue;
117        }
118        let sum: f32 = chain.links.iter().map(|l| link_score(l)).sum();
119        let avg = sum / chain.links.len() as f32;
120        max_score = max_score.max(avg);
121    }
122    max_score.clamp(0.0, 1.0)
123}
124
125/// Per-link score: `strength × confidence × ln(1 + evidence_count)` when
126/// rich fields are available, falls back to `weight` for legacy edges.
127fn link_score(link: &CausalLink) -> f32 {
128    match (link.strength, link.confidence) {
129        (Some(s), Some(c)) => {
130            let ev = link.evidence_count.unwrap_or(1).max(1) as f32;
131            s * c * (1.0 + ev).ln()
132        }
133        (Some(s), None) => s,
134        (None, Some(c)) => c,
135        (None, None) => link.weight,
136    }
137}
138
139// ── Causal Chain Extraction ─────────────────────────────────────────────
140
141/// Extract causal chains forward from a starting node via any [`GraphStore`].
142pub async fn causal_chain_forward(
143    store: &dyn GraphStore,
144    start: MemoryId,
145    max_depth: usize,
146    confidence_threshold: f32,
147    allowed_namespaces: Option<&[Namespace]>,
148) -> HirnResult<CausalChainResult> {
149    extract_causal_chains(
150        store,
151        start,
152        max_depth,
153        EdgeRelation::Causes,
154        confidence_threshold,
155        allowed_namespaces,
156    )
157    .await
158}
159
160/// Extract causal chains backward from a starting node via any [`GraphStore`].
161pub async fn causal_chain_backward(
162    store: &dyn GraphStore,
163    start: MemoryId,
164    max_depth: usize,
165    confidence_threshold: f32,
166    allowed_namespaces: Option<&[Namespace]>,
167) -> HirnResult<CausalChainResult> {
168    extract_causal_chains(
169        store,
170        start,
171        max_depth,
172        EdgeRelation::CausedBy,
173        confidence_threshold,
174        allowed_namespaces,
175    )
176    .await
177}
178
179/// Async core causal chain extraction: iterative DFS with cycle detection.
180async fn extract_causal_chains(
181    store: &dyn GraphStore,
182    start: MemoryId,
183    max_depth: usize,
184    relation: EdgeRelation,
185    confidence_threshold: f32,
186    allowed_namespaces: Option<&[Namespace]>,
187) -> HirnResult<CausalChainResult> {
188    if max_depth == 0 || !store.has_node(start).await? {
189        return Ok(CausalChainResult {
190            chains: vec![],
191            cycles_detected: false,
192        });
193    }
194
195    if let Some(allowed) = allowed_namespaces {
196        let Some(namespace) = store.node_namespace(start).await? else {
197            return Ok(CausalChainResult {
198                chains: vec![],
199                cycles_detected: false,
200            });
201        };
202        if !allowed.contains(&namespace) {
203            return Ok(CausalChainResult {
204                chains: vec![],
205                cycles_detected: false,
206            });
207        }
208    }
209
210    let mut chains = Vec::new();
211    let mut cycles_detected = false;
212
213    let mut stack: Vec<(MemoryId, Vec<CausalLink>, HashSet<MemoryId>)> = Vec::new();
214    let mut initial_visited = HashSet::new();
215    initial_visited.insert(start);
216    stack.push((start, Vec::new(), initial_visited));
217
218    while let Some((current, path, visited)) = stack.pop() {
219        if path.len() >= max_depth {
220            if !path.is_empty() {
221                chains.push(CausalChain { links: path });
222            }
223            continue;
224        }
225
226        let edges = store.get_edges_of_type(current, relation).await?;
227        let mut outgoing = Vec::new();
228        for edge in &edges {
229            if edge.source != current {
230                continue;
231            }
232
233            let confidence = edge.confidence().unwrap_or(0.5);
234            if confidence < confidence_threshold {
235                continue;
236            }
237
238            if let Some(allowed) = allowed_namespaces {
239                let Some(namespace) = store.node_namespace(edge.target).await? else {
240                    continue;
241                };
242                if !allowed.contains(&namespace) {
243                    continue;
244                }
245            }
246
247            outgoing.push(edge);
248        }
249
250        if outgoing.is_empty() {
251            if !path.is_empty() {
252                chains.push(CausalChain { links: path });
253            }
254            continue;
255        }
256
257        let mut any_extended = false;
258        for edge in &outgoing {
259            let target = edge.target;
260            if visited.contains(&target) {
261                cycles_detected = true;
262                if !path.is_empty() {
263                    chains.push(CausalChain {
264                        links: path.clone(),
265                    });
266                }
267                continue;
268            }
269
270            any_extended = true;
271            let link = CausalLink {
272                source: current,
273                target,
274                weight: edge.weight,
275                edge_id: edge.id,
276                strength: edge.strength(),
277                confidence: edge.confidence(),
278                evidence_count: edge.evidence_count(),
279                provenance: edge.provenance().map(str::to_owned),
280                mechanism: edge.mechanism().map(str::to_owned),
281            };
282            let mut new_path = path.clone();
283            new_path.push(link);
284            let mut new_visited = visited.clone();
285            new_visited.insert(target);
286            stack.push((target, new_path, new_visited));
287        }
288
289        if !any_extended && !path.is_empty() {
290            // Already recorded in cycle detection above.
291        }
292    }
293
294    chains.sort_by_key(|c| std::cmp::Reverse(c.depth()));
295    chains.dedup_by(|a, b| a.node_ids() == b.node_ids());
296
297    Ok(CausalChainResult {
298        chains,
299        cycles_detected,
300    })
301}
302
303// ── Counterfactual Detection ───────────────────────────────────────────
304
305/// A counterfactual constraint: if memory A is true, memory B is under tension.
306#[derive(Debug, Clone)]
307pub struct Counterfactual {
308    pub memory_a: MemoryId,
309    pub memory_b: MemoryId,
310    pub constraint: CounterfactualConstraint,
311    pub explanation: String,
312}
313
314/// The type of counterfactual constraint detected.
315#[derive(Debug, Clone, PartialEq, Eq)]
316pub enum CounterfactualConstraint {
317    /// Direct contradiction via `contradicts` edge.
318    DirectContradiction,
319    /// Temporal impossibility: events are temporally inconsistent.
320    TemporalImpossibility,
321    /// F-44: Temporal supersession: a newer record updates/replaces an older one.
322    /// The newer record (memory_a) supersedes the older (memory_b).
323    TemporalSupersession,
324}
325
326// ── Trust Scoring Engine ───────────────────────────────────────────────
327
328/// Compute the trust score for a memory record.
329///
330/// Trust factors:
331/// 1. Origin type: direct observation = 1.0, user = 0.9, LLM = 0.7, etc.
332/// 2. Evidence diversity: more diverse sources = higher trust
333/// 3. Reconsolidation penalty: each mutation slightly reduces trust
334///    (unless supported by new evidence)
335///
336/// Returns a score in [0.0, 1.0].
337pub fn compute_trust_score(provenance: &Provenance, contradiction_count: usize) -> f32 {
338    // Base trust from origin type.
339    let origin_trust = match *provenance.origin() {
340        Origin::DirectObservation => 1.0,
341        Origin::UserProvided => 0.9,
342        Origin::LlmExtraction => 0.7,
343        Origin::Consolidation => {
344            // Consolidation trust depends on evidence diversity.
345            let evidence_count = provenance.confidence_basis.len();
346            if evidence_count >= 5 {
347                0.9
348            } else if evidence_count >= 3 {
349                0.8
350            } else if evidence_count >= 1 {
351                0.6
352            } else {
353                0.4
354            }
355        }
356        Origin::CrossAgent => 0.6,
357        Origin::DreamReplay => 0.3, // low trust until validated
358    };
359
360    // Evidence diversity bonus: unique source IDs.
361    let unique_sources: HashSet<MemoryId> = provenance
362        .confidence_basis
363        .iter()
364        .map(|e| e.source_id)
365        .collect();
366    let diversity_bonus = if unique_sources.len() >= 5 {
367        0.1
368    } else if unique_sources.len() >= 3 {
369        0.05
370    } else {
371        0.0
372    };
373
374    // Mutation penalty: each mutation without new evidence slightly reduces trust.
375    let mutations_without_evidence = count_mutations_without_evidence(provenance);
376    let mutation_penalty = (mutations_without_evidence as f32 * 0.05).min(0.3);
377
378    // Contradiction penalty.
379    let contradiction_penalty = (contradiction_count as f32 * 0.1).min(0.3);
380
381    let score = origin_trust + diversity_bonus - mutation_penalty - contradiction_penalty;
382    score.clamp(0.0, 1.0)
383}
384
385/// Count mutations that are not supported by new evidence.
386fn count_mutations_without_evidence(provenance: &Provenance) -> usize {
387    // Each evidence ref added AFTER a mutation counts as "supported".
388    // Simple heuristic: mutations beyond the evidence count are unsupported.
389    let evidence_count = provenance.confidence_basis.len();
390    let mutation_count = provenance.mutation_log.len();
391    mutation_count.saturating_sub(evidence_count)
392}
393
394// ── Contradiction Detection (for insertion) ────────────────────────────
395
396/// Result of contradiction detection during insertion.
397#[derive(Debug, Clone)]
398pub struct ContradictionDetection {
399    /// IDs of records that contradict the new record.
400    pub contradicting_ids: Vec<MemoryId>,
401    /// Whether any contradictions were found.
402    pub has_contradictions: bool,
403}
404
405#[derive(Clone, Copy)]
406pub struct InsertionCandidateRecord<'a> {
407    pub id: MemoryId,
408    pub content_lower: &'a str,
409    pub has_negation: bool,
410    pub entities: &'a [String],
411    pub similarity: f32,
412}
413
414/// Check if a new record (represented by its content and embedding)
415/// contradicts any existing records.
416///
417/// Detection signals:
418/// 1. High embedding similarity (same topic) with conflicting content
419/// 2. Entity-value conflicts (same entity, different claims)
420/// 3. Negation patterns
421pub fn detect_contradictions_on_insert(
422    content: &str,
423    entities: &[String],
424    similar_records: &[InsertionCandidateRecord<'_>],
425    similarity_threshold: f32,
426) -> ContradictionDetection {
427    let mut contradicting_ids = Vec::new();
428    let content_lower = content.to_lowercase();
429    let new_has_negation = contains_negation(&content_lower);
430
431    for candidate in similar_records {
432        if candidate.similarity < similarity_threshold {
433            continue;
434        }
435
436        let existing_content = candidate.content_lower;
437        let existing_has_negation = candidate.has_negation;
438
439        // Signal 1: Same topic (high cosine sim) + one has negation, the other doesn't.
440        let negation_conflict = new_has_negation != existing_has_negation
441            && content_similarity_simple(&content_lower, existing_content) > 0.3;
442
443        // Signal 2: Entity-value conflicts — same entities but different values.
444        let entity_conflict = if !entities.is_empty() {
445            let shared_entities: Vec<&String> = entities
446                .iter()
447                .filter(|e| candidate.entities.iter().any(|ee| ee == *e))
448                .collect();
449            // If they share entities but have different content, likely a conflict.
450            !shared_entities.is_empty()
451                && content_similarity_simple(&content_lower, existing_content) < 0.8
452                && (new_has_negation
453                    || existing_has_negation
454                    || value_conflict(&content_lower, existing_content))
455        } else {
456            false
457        };
458
459        if negation_conflict || entity_conflict {
460            contradicting_ids.push(candidate.id);
461        }
462    }
463
464    ContradictionDetection {
465        has_contradictions: !contradicting_ids.is_empty(),
466        contradicting_ids,
467    }
468}
469
470// ── TRACE / Provenance Lineage ─────────────────────────────────────────
471
472/// Complete trace result for a memory record.
473#[derive(Debug, Clone)]
474pub struct TraceReport {
475    /// The traced record.
476    pub record: MemoryRecord,
477    /// Full provenance chain.
478    pub provenance: Provenance,
479    /// Source episodes (for semantic/consolidated records).
480    pub source_episodes: Vec<MemoryId>,
481    /// Records derived FROM this record (via DerivedFrom edges).
482    pub derived_records: Vec<MemoryId>,
483    /// Mutation history summary.
484    pub mutation_count: usize,
485    /// Trust score.
486    pub trust_score: f32,
487    /// Textual lineage tree.
488    pub lineage_tree: String,
489}
490
491/// Build a trace report for a memory record via any [`GraphStore`].
492pub async fn build_trace_report(
493    store: &dyn GraphStore,
494    record: MemoryRecord,
495    provenance: Provenance,
496    source_episodes: Vec<MemoryId>,
497) -> HirnResult<TraceReport> {
498    let record_id = record.id();
499
500    let derived_edges = store
501        .get_edges_of_type(record_id, EdgeRelation::DerivedFrom)
502        .await?;
503    let derived_records: Vec<MemoryId> = derived_edges
504        .iter()
505        .filter(|e| e.target == record_id)
506        .map(|e| e.source)
507        .collect();
508
509    let contra_edges = store
510        .get_edges_of_type(record_id, EdgeRelation::Contradicts)
511        .await?;
512    let contradiction_count = contra_edges.len();
513
514    let trust_score = compute_trust_score(&provenance, contradiction_count);
515    let mutation_count = provenance.mutation_log.len();
516
517    let lineage_tree =
518        format_lineage_tree(record_id, &provenance, &source_episodes, &derived_records);
519
520    Ok(TraceReport {
521        record,
522        provenance,
523        source_episodes,
524        derived_records,
525        mutation_count,
526        trust_score,
527        lineage_tree,
528    })
529}
530
531/// Format a textual lineage tree for display.
532fn format_lineage_tree(
533    record_id: MemoryId,
534    provenance: &Provenance,
535    source_episodes: &[MemoryId],
536    derived_records: &[MemoryId],
537) -> String {
538    let mut out = String::new();
539    out.push_str(&format!("Lineage for {record_id}:\n"));
540    out.push_str(&format!("  Origin: {:?}\n", provenance.origin()));
541    out.push_str(&format!("  Created by: {}\n", provenance.created_by));
542
543    if !source_episodes.is_empty() {
544        out.push_str("  Source episodes:\n");
545        for ep in source_episodes {
546            out.push_str(&format!("    <- {ep}\n"));
547        }
548    }
549
550    if let Some(ref model) = provenance.extraction_model {
551        out.push_str(&format!("  Extraction model: {model}\n"));
552    }
553
554    if !provenance.mutation_log.is_empty() {
555        out.push_str(&format!(
556            "  Mutations ({}):\n",
557            provenance.mutation_log.len()
558        ));
559        for m in &provenance.mutation_log {
560            out.push_str(&format!(
561                "    [{:?}] {}: {} -> {} ({})\n",
562                m.trigger, m.field, m.old_value, m.new_value, m.reason
563            ));
564        }
565    }
566
567    if !derived_records.is_empty() {
568        out.push_str("  Derived records:\n");
569        for d in derived_records {
570            out.push_str(&format!("    -> {d}\n"));
571        }
572    }
573
574    out
575}
576
577// ── Helpers ────────────────────────────────────────────────────────────
578
579/// Extract the primary text content from a memory record (any layer).
580pub fn record_content_str(record: &MemoryRecord) -> &str {
581    match record {
582        MemoryRecord::Episodic(e) => &e.content,
583        MemoryRecord::Semantic(s) => &s.description,
584        MemoryRecord::Working(w) => &w.content,
585        MemoryRecord::Procedural(p) => &p.description,
586    }
587}
588
/// Simple negation detection: reports whether `text` contains a negation or
/// contrast marker.
///
/// `text` is expected to be lowercase already (callers pass lowercased
/// content); matching is case-sensitive.
///
/// Matching works on whole words rather than raw substrings.  Words are
/// maximal runs of alphanumerics, hyphens and apostrophes (`'` or `’`), with
/// leading and trailing apostrophes removed.  This detects markers at the end
/// of the text or before punctuation ("… is not.") and avoids false hits
/// inside other words ("piano", "knot").  A word counts as negating when
/// - it ends in `n't` / `n’t` (don't, isn't, can't, won't, …), or
/// - it is one of the listed negation/contrast words, or
/// - it starts one of the listed multi-word phrases ("rather than",
///   "on the contrary").
///
/// **Limitation (F-48):** This is surface-level pattern matching, not semantic
/// entailment. For example, "the project succeeded" vs "the project failed" won't
/// be caught unless negation markers are present. Full semantic contradiction
/// detection would require an LLM or NLI (Natural Language Inference) model.
/// The current approach works well for explicit negations and numerical conflicts
/// but misses implicit contradictions via paraphrase.
pub(crate) fn contains_negation(text: &str) -> bool {
    const NEGATION_WORDS: &[&str] = &[
        "not",
        "no",
        "never",
        "cannot",
        "nor",
        "neither",
        "nowhere",
        "nothing",
        "nobody",
        "hardly",
        "barely",
        "scarcely",
        "seldom",
        "rarely",
        "however",
        "actually",
        "instead",
        "contrary",
        "incorrect",
        "false",
        "wrong",
        "failed",
        "impossible",
        "unlike",
        "slower", // contextual negation for performance claims
    ];
    const NEGATION_PHRASES: &[&[&str]] = &[&["rather", "than"], &["on", "the", "contrary"]];
    const APOSTROPHES: &[char] = &['\'', '\u{2019}'];

    let words: Vec<&str> = text
        .split(|c: char| !(c.is_alphanumeric() || c == '-' || APOSTROPHES.contains(&c)))
        .map(|word| word.trim_matches(APOSTROPHES))
        .filter(|word| !word.is_empty())
        .collect();

    let has_negating_word = words.iter().any(|word| {
        word.ends_with("n't") || word.ends_with("n\u{2019}t") || NEGATION_WORDS.contains(word)
    });

    has_negating_word
        || NEGATION_PHRASES
            .iter()
            .any(|phrase| words.windows(phrase.len()).any(|window| window == *phrase))
}
645
/// Word-overlap similarity: Jaccard index of the whitespace-separated word
/// sets of `a` and `b`.  Two texts with no words at all score `0.0`.
fn content_similarity_simple(a: &str, b: &str) -> f32 {
    let left: HashSet<&str> = a.split_whitespace().collect();
    let right: HashSet<&str> = b.split_whitespace().collect();

    let shared = left.iter().filter(|word| right.contains(*word)).count();
    // |A ∪ B| = |A| + |B| − |A ∩ B|
    match left.len() + right.len() - shared {
        0 => 0.0,
        total => shared as f32 / total as f32,
    }
}
657
/// Detect value conflicts between two content strings.
///
/// Both texts must mention at least one number.  They conflict when some
/// number in either text has no equal counterpart (within `f64::EPSILON`)
/// in the other.  Identical texts never conflict, even when they mention
/// several different numbers.
fn value_conflict(a: &str, b: &str) -> bool {
    /// Whether `pool` holds a value equal to `value` within `f64::EPSILON`.
    fn has_counterpart(value: f64, pool: &[f64]) -> bool {
        pool.iter().any(|other| (value - other).abs() <= f64::EPSILON)
    }

    let nums_a = extract_numbers(a);
    let nums_b = extract_numbers(b);
    if nums_a.is_empty() || nums_b.is_empty() {
        return false;
    }

    nums_a.iter().any(|&x| !has_counterpart(x, &nums_b))
        || nums_b.iter().any(|&y| !has_counterpart(y, &nums_a))
}

/// Extract numeric values from whitespace-separated words of `text`.
///
/// Units, percent signs and trailing punctuation are stripped from the end of
/// each word ("10gb." → 10, "50%," → 50).  A few leading opening characters
/// (`(`, `[`, quotes, `$`) are stripped from the start.  Words that still do
/// not parse as `f64` are skipped.
fn extract_numbers(text: &str) -> Vec<f64> {
    text.split_whitespace()
        .filter_map(|word| {
            let cleaned = word
                .trim_start_matches(|c: char| matches!(c, '(' | '[' | '"' | '\'' | '$'))
                .trim_end_matches(|c: char| !c.is_ascii_digit());
            cleaned.parse::<f64>().ok()
        })
        .collect()
}
695
696// ── Tests ──────────────────────────────────────────────────────────────
697
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::provenance::{EvidenceRef, Provenance};
    use hirn_core::timestamp::Timestamp;
    use hirn_core::types::MutationTrigger;

    // ── Trust Scoring Tests ────────────────────────────────────────────

    #[test]
    fn direct_observation_high_trust() {
        let p = Provenance::direct(hirn_core::types::AgentId::new("test").unwrap());
        let score = compute_trust_score(&p, 0);
        assert!(score >= 0.95, "score={score}");
    }

    #[test]
    fn llm_extraction_lower_trust() {
        let p = Provenance::with_origin(
            Origin::LlmExtraction,
            hirn_core::types::AgentId::new("test").unwrap(),
        );
        let score = compute_trust_score(&p, 0);
        assert!((score - 0.7).abs() < 0.05, "score={score}");
    }

    #[test]
    fn consolidation_with_diverse_sources_high_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::with_origin(Origin::Consolidation, agent);
        for i in 0..5 {
            p.confidence_basis.push(EvidenceRef {
                source_id: MemoryId::new(),
                description: format!("source {i}"),
            });
        }
        let score = compute_trust_score(&p, 0);
        assert!(score > 0.8, "score={score}");
    }

    #[test]
    fn consolidation_with_single_source_low_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::with_origin(Origin::Consolidation, agent);
        p.confidence_basis.push(EvidenceRef {
            source_id: MemoryId::new(),
            description: "only source".to_string(),
        });
        let score = compute_trust_score(&p, 0);
        assert!(score < 0.7, "score={score}");
    }

    #[test]
    fn mutations_without_evidence_reduce_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::direct(agent);
        // Add 3 mutations without evidence.
        for i in 0..3 {
            p.record_mutation(hirn_core::provenance::Mutation {
                timestamp: Timestamp::now(),
                trigger: MutationTrigger::Reconsolidation,
                field: "description".to_string(),
                old_value: format!("old {i}"),
                new_value: format!("new {i}"),
                reason: "test".to_string(),
            });
        }
        let score = compute_trust_score(&p, 0);
        // 1.0 - 3 * 0.05 = 0.85
        assert!((score - 0.85).abs() < 0.01, "score={score}");
    }

    #[test]
    fn mutations_with_evidence_maintain_trust() {
        let agent = hirn_core::types::AgentId::new("test").unwrap();
        let mut p = Provenance::direct(agent);
        // Add evidence for each mutation.
        for i in 0..3 {
            p.confidence_basis.push(EvidenceRef {
                source_id: MemoryId::new(),
                description: format!("evidence {i}"),
            });
            p.record_mutation(hirn_core::provenance::Mutation {
                timestamp: Timestamp::now(),
                trigger: MutationTrigger::Reconsolidation,
                field: "description".to_string(),
                old_value: format!("old {i}"),
                new_value: format!("new {i}"),
                reason: "supported update".to_string(),
            });
        }
        let score = compute_trust_score(&p, 0);
        // No unsupported mutations → trust maintained.
        assert!(score >= 0.95, "score={score}");
    }

    #[test]
    fn contradictions_reduce_trust_up_to_cap() {
        let p = Provenance::direct(hirn_core::types::AgentId::new("test").unwrap());
        // 1.0 - 2 * 0.1 = 0.8
        let two = compute_trust_score(&p, 2);
        assert!((two - 0.8).abs() < 0.01, "score={two}");
        // The contradiction penalty is capped at 0.3.
        let many = compute_trust_score(&p, 10);
        assert!((many - 0.7).abs() < 0.01, "score={many}");
    }

    // ── Contradiction Detection Tests ──────────────────────────────────

    #[test]
    fn negation_detection() {
        assert!(contains_negation("hnsw is not faster"));
        assert!(contains_negation("it doesn't work"));
        assert!(contains_negation("system never recovered"));
        assert!(!contains_negation("system is fast"));
    }

    #[test]
    fn value_conflict_detection() {
        assert!(value_conflict(
            "system uses 10gb ram",
            "system uses 5gb ram"
        ));
        assert!(!value_conflict(
            "system uses 10gb ram",
            "system uses 10gb ram"
        ));
    }

    #[test]
    fn content_similarity_identical() {
        let sim = content_similarity_simple("hello world test", "hello world test");
        assert!((sim - 1.0).abs() < f32::EPSILON, "sim={sim}");
    }

    #[test]
    fn content_similarity_different() {
        let sim = content_similarity_simple("hello world", "foo bar baz");
        assert!(sim < 0.1);
    }

    // ── Causal Relevance Scoring Tests ─────────────────────────────────

    fn make_link(weight: f32, strength: Option<f32>, confidence: Option<f32>) -> CausalLink {
        CausalLink {
            source: MemoryId::new(),
            target: MemoryId::new(),
            weight,
            edge_id: MemoryId::new(),
            strength,
            confidence,
            evidence_count: None,
            provenance: None,
            mechanism: None,
        }
    }

    #[test]
    fn causal_chain_accessors() {
        let empty = CausalChain { links: vec![] };
        assert!(empty.start().is_none());
        assert!(empty.end().is_none());
        assert!(empty.node_ids().is_empty());
        assert_eq!(empty.depth(), 0);
        assert!(empty.avg_weight().abs() < f32::EPSILON);

        let l1 = make_link(0.2, None, None);
        let l2 = make_link(0.6, None, None);
        let expected_ids = vec![l1.source, l1.target, l2.target];
        let (first, last) = (l1.source, l2.target);
        let chain = CausalChain {
            links: vec![l1, l2],
        };
        assert_eq!(chain.start(), Some(first));
        assert_eq!(chain.end(), Some(last));
        assert_eq!(chain.node_ids(), expected_ids);
        assert_eq!(chain.depth(), 2);
        assert!((chain.avg_weight() - 0.4).abs() < 1e-6, "avg={}", chain.avg_weight());
    }

    #[test]
    fn causal_relevance_empty_chains() {
        let result = CausalChainResult {
            chains: vec![],
            cycles_detected: false,
        };
        assert!(causal_relevance(&result).abs() < f32::EPSILON);
    }

    #[test]
    fn causal_relevance_uses_strength_and_confidence() {
        let link = make_link(0.5, Some(0.9), Some(0.8));
        let result = CausalChainResult {
            chains: vec![CausalChain { links: vec![link] }],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        // 0.9 * 0.8 * ln(2) ≈ 0.72 * 0.693 ≈ 0.499
        let expected = 0.9 * 0.8 * (2.0_f32).ln();
        assert!(
            (score - expected).abs() < 0.01,
            "score={score}, expected={expected}"
        );
    }

    #[test]
    fn causal_relevance_falls_back_to_weight() {
        let link = make_link(0.6, None, None);
        let result = CausalChainResult {
            chains: vec![CausalChain { links: vec![link] }],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        assert!((score - 0.6).abs() < 0.01, "score={score}");
    }

    #[test]
    fn causal_relevance_takes_max_across_chains() {
        let weak = make_link(0.2, None, None);
        let strong = make_link(0.0, Some(0.95), Some(0.95));
        let result = CausalChainResult {
            chains: vec![
                CausalChain { links: vec![weak] },
                CausalChain {
                    links: vec![strong],
                },
            ],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        // max(0.2, 0.95*0.95*ln(2)) = max(0.2, 0.625) = 0.625
        let expected = 0.95 * 0.95 * (2.0_f32).ln();
        assert!(
            (score - expected).abs() < 0.01,
            "score={score}, expected={expected}"
        );
    }

    #[test]
    fn causal_relevance_averages_links_in_chain() {
        let l1 = make_link(0.0, Some(1.0), Some(1.0)); // 1.0 * ln(2) ≈ 0.693
        let l2 = make_link(0.0, Some(0.5), Some(0.5)); // 0.25 * ln(2) ≈ 0.173
        let result = CausalChainResult {
            chains: vec![CausalChain {
                links: vec![l1, l2],
            }],
            cycles_detected: false,
        };
        let score = causal_relevance(&result);
        // avg(0.693, 0.173) ≈ 0.433
        let expected = f32::midpoint(1.0 * 1.0 * (2.0_f32).ln(), 0.5 * 0.5 * (2.0_f32).ln());
        assert!(
            (score - expected).abs() < 0.01,
            "score={score}, expected={expected}"
        );
    }

    #[test]
    fn link_score_strength_only() {
        let link = make_link(0.3, Some(0.8), None);
        assert!((link_score(&link) - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn link_score_confidence_only() {
        let link = make_link(0.3, None, Some(0.7));
        assert!((link_score(&link) - 0.7).abs() < f32::EPSILON);
    }
}