1use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::EdgeType;
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
/// Tuning parameters for confidence propagation
/// (see `propagate_update_with_config`).
#[derive(Debug, Clone)]
pub struct PropagationConfig {
    /// Maximum number of hops from the changed node. Nodes reached at exactly
    /// this depth still get a recorded confidence but are not expanded further.
    pub max_depth: usize,
    /// Multiplier applied to the source confidence (times the edge weight)
    /// along `Caused` edges.
    pub caused_dampening: f32,
    /// Scales how strongly a source's deviation from 1.0 (times the edge
    /// weight) shifts the confidence of a `Supports` target.
    pub supports_factor: f32,
    /// Scales how much of the source confidence (times the edge weight) is
    /// subtracted from the confidence of a `Contradicts` target.
    pub contradicts_factor: f32,
    /// Multiplier on the source confidence that caps a `Supersedes` target:
    /// the target ends up at most `source * supersedes_floor`. Despite the
    /// name, this acts as an upper bound on the target, not a lower bound.
    pub supersedes_floor: f32,
}
25
26impl Default for PropagationConfig {
27 fn default() -> Self {
28 Self {
29 max_depth: 5,
30 caused_dampening: 0.9,
31 supports_factor: 0.5,
32 contradicts_factor: 0.7,
33 supersedes_floor: 0.1,
34 }
35 }
36}
37
38pub fn propagate_update(
43 graph: &CsrGraph,
44 changed_id: MemoryId,
45 new_confidence: f32,
46) -> Vec<(MemoryId, f32)> {
47 propagate_update_with_config(
48 graph,
49 changed_id,
50 new_confidence,
51 &PropagationConfig::default(),
52 )
53}
54
55pub fn propagate_update_with_config(
57 graph: &CsrGraph,
58 changed_id: MemoryId,
59 new_confidence: f32,
60 config: &PropagationConfig,
61) -> Vec<(MemoryId, f32)> {
62 let Some(_) = graph.get_idx(changed_id) else {
63 return Vec::new();
64 };
65
66 let mut confidences: HashMap<MemoryId, f32> = HashMap::default();
68 confidences.insert(changed_id, new_confidence);
69
70 let mut queue: VecDeque<(MemoryId, f32, usize)> = VecDeque::new();
72 queue.push_back((changed_id, new_confidence, 0));
73
74 let mut visited = HashSet::default();
75 visited.insert(changed_id);
76
77 while let Some((node, node_confidence, depth)) = queue.pop_front() {
78 if depth >= config.max_depth {
79 continue;
80 }
81
82 for (neighbor, edge) in graph.outgoing(node) {
83 let new_conf = match edge.edge_type {
84 EdgeType::Caused => {
85 node_confidence * edge.weight * config.caused_dampening
87 }
88 EdgeType::Supports => {
89 let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
91 let delta = node_confidence - 1.0; (current + delta * edge.weight * config.supports_factor).clamp(0.0, 1.0)
93 }
94 EdgeType::Contradicts => {
95 let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
97 (current - node_confidence * edge.weight * config.contradicts_factor)
98 .clamp(0.0, 1.0)
99 }
100 EdgeType::Supersedes => {
101 let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
103 current.min(node_confidence * config.supersedes_floor)
104 }
105 _ => continue,
107 };
108
109 confidences.insert(neighbor, new_conf);
110 if visited.insert(neighbor) {
111 queue.push_back((neighbor, new_conf, depth + 1));
112 }
113 }
114 }
115
116 confidences.into_iter().collect()
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::MemoryEdge;

    /// Absolute tolerance used when comparing propagated confidences.
    const EPS: f32 = 0.001;

    /// Builds an edge with a fixed creation timestamp and no validity window.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType, weight: f32) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    /// Creates a graph `a -> b` joined by a single edge of the given type and weight.
    fn two_node_graph(etype: EdgeType, weight: f32) -> (CsrGraph, MemoryId, MemoryId) {
        let mut graph = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        graph.add_edge(&make_edge(a, b, etype, weight));
        (graph, a, b)
    }

    /// Runs the default propagation and indexes the result by memory id.
    fn propagate_as_map(
        graph: &CsrGraph,
        seed: MemoryId,
        confidence: f32,
    ) -> HashMap<MemoryId, f32> {
        propagate_update(graph, seed, confidence).into_iter().collect()
    }

    #[test]
    fn test_caused_propagation() {
        let (graph, a, b) = two_node_graph(EdgeType::Caused, 1.0);
        let map = propagate_as_map(&graph, a, 0.5);

        assert!((map[&a] - 0.5).abs() < EPS);
        // 0.5 * weight 1.0 * default dampening 0.9
        assert!((map[&b] - 0.45).abs() < EPS);
    }

    #[test]
    fn test_contradicts_propagation() {
        let (graph, a, b) = two_node_graph(EdgeType::Contradicts, 1.0);
        let map = propagate_as_map(&graph, a, 0.8);

        // 1.0 - 0.8 * weight 1.0 * default factor 0.7
        assert!((map[&b] - 0.44).abs() < EPS);
    }

    #[test]
    fn test_supersedes_propagation() {
        let (graph, a, b) = two_node_graph(EdgeType::Supersedes, 1.0);
        let map = propagate_as_map(&graph, a, 0.9);

        // min(1.0, 0.9 * default multiplier 0.1)
        assert!((map[&b] - 0.09).abs() < EPS);
    }

    #[test]
    fn test_max_depth_limit() {
        let mut graph = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..10).map(|_| MemoryId::new()).collect();
        for link in ids.windows(2) {
            graph.add_edge(&make_edge(link[0], link[1], EdgeType::Caused, 1.0));
        }

        let map = propagate_as_map(&graph, ids[0], 1.0);

        // Default max_depth is 5: the seed plus at most five hops down the chain.
        assert!(map.len() <= 6);
    }

    #[test]
    fn test_unrelated_edges_dont_propagate() {
        let (graph, a, _) = two_node_graph(EdgeType::Related, 1.0);
        let result = propagate_update(&graph, a, 0.5);

        // Only the seed itself is reported.
        assert_eq!(result.len(), 1);
    }
}