Skip to main content

mentedb_graph/
belief.rs

//! Belief propagation of confidence changes through causal, supporting, contradicting, and superseding edges.
2
3use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::EdgeType;
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
/// Configuration for belief propagation.
///
/// Each factor is applied per edge by [`propagate_update_with_config`], where
/// `source` is the confidence of the node the edge leaves and `current` is the
/// target's confidence computed so far in that call (1.0 if not yet reached).
#[derive(Debug, Clone, PartialEq)]
pub struct PropagationConfig {
    /// Maximum BFS depth for propagation (default: 5). Nodes more than this many
    /// hops from the changed node are not updated; `0` updates only the changed node.
    pub max_depth: usize,
    /// Dampening factor for `Caused` edges (default: 0.9):
    /// `target = source * weight * caused_dampening`.
    pub caused_dampening: f32,
    /// Factor for `Supports` edges (default: 0.5):
    /// `target = current + (source - 1.0) * weight * supports_factor`, clamped to `[0, 1]`.
    pub supports_factor: f32,
    /// Factor for `Contradicts` edges (default: 0.7):
    /// `target = current - source * weight * contradicts_factor`, clamped to `[0, 1]`.
    pub contradicts_factor: f32,
    /// Multiplier for `Supersedes` edges (default: 0.1). Despite the name it acts as
    /// a ceiling: `target = min(current, source * supersedes_floor)`.
    pub supersedes_floor: f32,
}
25
26impl Default for PropagationConfig {
27    fn default() -> Self {
28        Self {
29            max_depth: 5,
30            caused_dampening: 0.9,
31            supports_factor: 0.5,
32            contradicts_factor: 0.7,
33            supersedes_floor: 0.1,
34        }
35    }
36}
37
38/// Propagate a confidence change through the graph.
39///
40/// Returns a list of (affected_node, new_confidence) pairs. The initial node
41/// is included with `new_confidence` as its updated value.
42pub fn propagate_update(
43    graph: &CsrGraph,
44    changed_id: MemoryId,
45    new_confidence: f32,
46) -> Vec<(MemoryId, f32)> {
47    propagate_update_with_config(
48        graph,
49        changed_id,
50        new_confidence,
51        &PropagationConfig::default(),
52    )
53}
54
55/// Propagate a confidence change through the graph with custom configuration.
56pub fn propagate_update_with_config(
57    graph: &CsrGraph,
58    changed_id: MemoryId,
59    new_confidence: f32,
60    config: &PropagationConfig,
61) -> Vec<(MemoryId, f32)> {
62    let Some(_) = graph.get_idx(changed_id) else {
63        return Vec::new();
64    };
65
66    // Track computed confidences for each affected node
67    let mut confidences: HashMap<MemoryId, f32> = HashMap::default();
68    confidences.insert(changed_id, new_confidence);
69
70    // BFS queue: (node_id, confidence_at_node, current_depth)
71    let mut queue: VecDeque<(MemoryId, f32, usize)> = VecDeque::new();
72    queue.push_back((changed_id, new_confidence, 0));
73
74    let mut visited = HashSet::default();
75    visited.insert(changed_id);
76
77    while let Some((node, node_confidence, depth)) = queue.pop_front() {
78        if depth >= config.max_depth {
79            continue;
80        }
81
82        for (neighbor, edge) in graph.outgoing(node) {
83            let new_conf = match edge.edge_type {
84                EdgeType::Caused => {
85                    // child confidence = parent confidence * edge weight * dampening
86                    node_confidence * edge.weight * config.caused_dampening
87                }
88                EdgeType::Supports => {
89                    // supported node confidence += delta * edge weight * factor
90                    let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
91                    let delta = node_confidence - 1.0; // change from baseline
92                    (current + delta * edge.weight * config.supports_factor).clamp(0.0, 1.0)
93                }
94                EdgeType::Contradicts => {
95                    // contradicted node confidence -= delta * edge weight * factor
96                    let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
97                    (current - node_confidence * edge.weight * config.contradicts_factor)
98                        .clamp(0.0, 1.0)
99                }
100                EdgeType::Supersedes => {
101                    // superseded node confidence = min(current, new * floor)
102                    let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
103                    current.min(node_confidence * config.supersedes_floor)
104                }
105                // Other edge types don't propagate belief
106                _ => continue,
107            };
108
109            confidences.insert(neighbor, new_conf);
110            if visited.insert(neighbor) {
111                queue.push_back((neighbor, new_conf, depth + 1));
112            }
113        }
114    }
115
116    confidences.into_iter().collect()
117}
118
// Unit tests for belief propagation over small hand-built graphs.
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::MemoryEdge;
    use uuid::Uuid;

    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType, weight: f32) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight,
            created_at: 1000,
        }
    }

    #[test]
    fn test_caused_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Caused, 1.0));

        let result = propagate_update(&g, a, 0.5);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        assert!((map[&a] - 0.5).abs() < 0.001);
        // b = 0.5 * 1.0 * 0.9 = 0.45
        assert!((map[&b] - 0.45).abs() < 0.001);
    }

    #[test]
    fn test_supports_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Supports, 1.0));

        let result = propagate_update(&g, a, 0.6);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // b = clamp(1.0 + (0.6 - 1.0) * 1.0 * 0.5, 0, 1) = 0.8
        assert!((map[&b] - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_contradicts_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Contradicts, 1.0));

        let result = propagate_update(&g, a, 0.8);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // b = max(0, 1.0 - 0.8 * 1.0 * 0.7) = 0.44
        assert!((map[&b] - 0.44).abs() < 0.001);
    }

    #[test]
    fn test_supersedes_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Supersedes, 1.0));

        let result = propagate_update(&g, a, 0.9);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // b = min(1.0, 0.9 * 0.1) = 0.09
        assert!((map[&b] - 0.09).abs() < 0.001);
    }

    #[test]
    fn test_max_depth_limit() {
        // Chain of 9 Caused edges over 10 nodes; with the default max_depth of 5
        // exactly ids[0..=5] are reached.
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..10).map(|_| Uuid::new_v4()).collect();
        for pair in ids.windows(2) {
            g.add_edge(&make_edge(pair[0], pair[1], EdgeType::Caused, 1.0));
        }

        let result = propagate_update(&g, ids[0], 1.0);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // Start node plus exactly 5 depth levels.
        assert_eq!(map.len(), 6);
        assert!(map.contains_key(&ids[5]));
        assert!(!map.contains_key(&ids[6]));
        // ids[5] = 1.0 * 0.9^5
        assert!((map[&ids[5]] - 0.9f32.powi(5)).abs() < 0.001);
    }

    #[test]
    fn test_unrelated_edges_dont_propagate() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Related, 1.0));

        let result = propagate_update(&g, a, 0.5);
        // Only the changed node itself
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_unknown_node_returns_empty() {
        let g = CsrGraph::new();
        assert!(propagate_update(&g, Uuid::new_v4(), 0.5).is_empty());
    }
}