// Source file: mentedb_graph/belief.rs

1//! Belief propagation through causal and supporting edges.
2
3use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::EdgeType;
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
/// Configuration for belief propagation.
///
/// Each factor controls how strongly a node's confidence influences its
/// neighbors along one kind of edge; `max_depth` bounds how far an update
/// spreads. See [`propagate_update_with_config`] for the exact per-edge rules.
#[derive(Debug, Clone, PartialEq)]
pub struct PropagationConfig {
    /// Maximum BFS depth for propagation (default: 5).
    ///
    /// Nodes reached at this depth still receive a value but are not expanded
    /// any further.
    pub max_depth: usize,
    /// Dampening factor for `Caused` edges (default: 0.9): the child's
    /// confidence is `parent * edge weight * caused_dampening`.
    pub caused_dampening: f32,
    /// Factor for `Supports` edges (default: 0.5): the supported node moves by
    /// `(parent - 1.0) * edge weight * supports_factor`.
    pub supports_factor: f32,
    /// Factor for `Contradicts` edges (default: 0.7): the contradicted node
    /// drops by `parent * edge weight * contradicts_factor`.
    pub contradicts_factor: f32,
    /// Floor multiplier for `Supersedes` edges (default: 0.1): the superseded
    /// node is capped at `parent * supersedes_floor`.
    pub supersedes_floor: f32,
}
25
26impl Default for PropagationConfig {
27    fn default() -> Self {
28        Self {
29            max_depth: 5,
30            caused_dampening: 0.9,
31            supports_factor: 0.5,
32            contradicts_factor: 0.7,
33            supersedes_floor: 0.1,
34        }
35    }
36}
37
38/// Propagate a confidence change through the graph.
39///
40/// Returns a list of (affected_node, new_confidence) pairs. The initial node
41/// is included with `new_confidence` as its updated value.
42pub fn propagate_update(
43    graph: &CsrGraph,
44    changed_id: MemoryId,
45    new_confidence: f32,
46) -> Vec<(MemoryId, f32)> {
47    propagate_update_with_config(
48        graph,
49        changed_id,
50        new_confidence,
51        &PropagationConfig::default(),
52    )
53}
54
55/// Propagate a confidence change through the graph with custom configuration.
56pub fn propagate_update_with_config(
57    graph: &CsrGraph,
58    changed_id: MemoryId,
59    new_confidence: f32,
60    config: &PropagationConfig,
61) -> Vec<(MemoryId, f32)> {
62    let Some(_) = graph.get_idx(changed_id) else {
63        return Vec::new();
64    };
65
66    // Track computed confidences for each affected node
67    let mut confidences: HashMap<MemoryId, f32> = HashMap::default();
68    confidences.insert(changed_id, new_confidence);
69
70    // BFS queue: (node_id, confidence_at_node, current_depth)
71    let mut queue: VecDeque<(MemoryId, f32, usize)> = VecDeque::new();
72    queue.push_back((changed_id, new_confidence, 0));
73
74    let mut visited = HashSet::default();
75    visited.insert(changed_id);
76
77    while let Some((node, node_confidence, depth)) = queue.pop_front() {
78        if depth >= config.max_depth {
79            continue;
80        }
81
82        for (neighbor, edge) in graph.outgoing(node) {
83            let new_conf = match edge.edge_type {
84                EdgeType::Caused => {
85                    // child confidence = parent confidence * edge weight * dampening
86                    node_confidence * edge.weight * config.caused_dampening
87                }
88                EdgeType::Supports => {
89                    // supported node confidence += delta * edge weight * factor
90                    let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
91                    let delta = node_confidence - 1.0; // change from baseline
92                    (current + delta * edge.weight * config.supports_factor).clamp(0.0, 1.0)
93                }
94                EdgeType::Contradicts => {
95                    // contradicted node confidence -= delta * edge weight * factor
96                    let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
97                    (current - node_confidence * edge.weight * config.contradicts_factor)
98                        .clamp(0.0, 1.0)
99                }
100                EdgeType::Supersedes => {
101                    // superseded node confidence = min(current, new * floor)
102                    let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
103                    current.min(node_confidence * config.supersedes_floor)
104                }
105                // Other edge types don't propagate belief
106                _ => continue,
107            };
108
109            confidences.insert(neighbor, new_conf);
110            if visited.insert(neighbor) {
111                queue.push_back((neighbor, new_conf, depth + 1));
112            }
113        }
114    }
115
116    confidences.into_iter().collect()
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::MemoryEdge;

    /// Builds a `MemoryEdge` with fixed timestamps and no validity window.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType, weight: f32) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    #[test]
    fn test_caused_propagation() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        g.add_edge(&make_edge(a, b, EdgeType::Caused, 1.0));

        let result = propagate_update(&g, a, 0.5);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        assert!((map[&a] - 0.5).abs() < 0.001);
        // b = 0.5 * 1.0 * 0.9 = 0.45
        assert!((map[&b] - 0.45).abs() < 0.001);
    }

    #[test]
    fn test_supports_propagation() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        g.add_edge(&make_edge(a, b, EdgeType::Supports, 1.0));

        let result = propagate_update(&g, a, 0.6);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // b = 1.0 + (0.6 - 1.0) * 1.0 * 0.5 = 0.8
        assert!((map[&b] - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_contradicts_propagation() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        g.add_edge(&make_edge(a, b, EdgeType::Contradicts, 1.0));

        let result = propagate_update(&g, a, 0.8);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // b = max(0, 1.0 - 0.8 * 1.0 * 0.7) = 0.44
        assert!((map[&b] - 0.44).abs() < 0.001);
    }

    #[test]
    fn test_supersedes_propagation() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        g.add_edge(&make_edge(a, b, EdgeType::Supersedes, 1.0));

        let result = propagate_update(&g, a, 0.9);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // b = min(1.0, 0.9 * 0.1) = 0.09
        assert!((map[&b] - 0.09).abs() < 0.001);
    }

    #[test]
    fn test_max_depth_limit() {
        // Build a chain of 10 nodes linked by Caused edges; with the default
        // max_depth of 5, exactly ids[0..=5] should be reached.
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..10).map(|_| MemoryId::new()).collect();
        for pair in ids.windows(2) {
            g.add_edge(&make_edge(pair[0], pair[1], EdgeType::Caused, 1.0));
        }

        let result = propagate_update(&g, ids[0], 1.0);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // Start node plus 5 depth levels, and nothing beyond.
        assert_eq!(map.len(), 6);
        assert!(ids[..6].iter().all(|id| map.contains_key(id)));
        assert!(ids[6..].iter().all(|id| !map.contains_key(id)));
        // Deepest reached node: 1.0 * 0.9^5
        assert!((map[&ids[5]] - 0.9f32.powi(5)).abs() < 0.001);
    }

    #[test]
    fn test_unrelated_edges_dont_propagate() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        g.add_edge(&make_edge(a, b, EdgeType::Related, 1.0));

        let result = propagate_update(&g, a, 0.5);
        // Only the changed node itself, with the value the caller supplied.
        assert_eq!(result.len(), 1);
        assert!(result[0].0 == a);
        assert!((result[0].1 - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_unknown_node_yields_nothing() {
        let g = CsrGraph::new();
        assert!(propagate_update(&g, MemoryId::new(), 0.5).is_empty());
    }
}