1use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::EdgeType;
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
/// Tuning knobs for [`propagate_update_with_config`].
///
/// Each factor scales how strongly a confidence change on one memory is
/// transmitted across a particular [`EdgeType`].
#[derive(Debug, Clone, PartialEq)]
pub struct PropagationConfig {
    /// Maximum number of hops from the changed memory. Memories exactly this
    /// many hops away still receive an updated confidence, but are not expanded
    /// any further.
    pub max_depth: usize,
    /// Multiplier applied along `Caused` edges: the target's confidence becomes
    /// `source * weight * caused_dampening`.
    pub caused_dampening: f32,
    /// Scales how much of the source's deviation from full confidence
    /// (`source - 1.0`, times the edge weight) is added to a `Supports` target.
    pub supports_factor: f32,
    /// Scales how much of the source's confidence (times the edge weight) is
    /// subtracted from a `Contradicts` target.
    pub contradicts_factor: f32,
    /// Multiplier on the source's confidence that caps a `Supersedes` target:
    /// the target becomes `min(current, source * supersedes_floor)`. Despite the
    /// name, this acts as an upper bound on the superseded memory.
    pub supersedes_floor: f32,
}

impl Default for PropagationConfig {
    /// Returns the default tuning: up to 5 hops, `Caused` dampening 0.9,
    /// `Supports` factor 0.5, `Contradicts` factor 0.7 and `Supersedes` cap
    /// multiplier 0.1.
    fn default() -> Self {
        Self {
            max_depth: 5,
            caused_dampening: 0.9,
            supports_factor: 0.5,
            contradicts_factor: 0.7,
            supersedes_floor: 0.1,
        }
    }
}
37
38pub fn propagate_update(
43 graph: &CsrGraph,
44 changed_id: MemoryId,
45 new_confidence: f32,
46) -> Vec<(MemoryId, f32)> {
47 propagate_update_with_config(
48 graph,
49 changed_id,
50 new_confidence,
51 &PropagationConfig::default(),
52 )
53}
54
55pub fn propagate_update_with_config(
57 graph: &CsrGraph,
58 changed_id: MemoryId,
59 new_confidence: f32,
60 config: &PropagationConfig,
61) -> Vec<(MemoryId, f32)> {
62 let Some(_) = graph.get_idx(changed_id) else {
63 return Vec::new();
64 };
65
66 let mut confidences: HashMap<MemoryId, f32> = HashMap::default();
68 confidences.insert(changed_id, new_confidence);
69
70 let mut queue: VecDeque<(MemoryId, f32, usize)> = VecDeque::new();
72 queue.push_back((changed_id, new_confidence, 0));
73
74 let mut visited = HashSet::default();
75 visited.insert(changed_id);
76
77 while let Some((node, node_confidence, depth)) = queue.pop_front() {
78 if depth >= config.max_depth {
79 continue;
80 }
81
82 for (neighbor, edge) in graph.outgoing(node) {
83 let new_conf = match edge.edge_type {
84 EdgeType::Caused => {
85 node_confidence * edge.weight * config.caused_dampening
87 }
88 EdgeType::Supports => {
89 let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
91 let delta = node_confidence - 1.0; (current + delta * edge.weight * config.supports_factor).clamp(0.0, 1.0)
93 }
94 EdgeType::Contradicts => {
95 let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
97 (current - node_confidence * edge.weight * config.contradicts_factor)
98 .clamp(0.0, 1.0)
99 }
100 EdgeType::Supersedes => {
101 let current = confidences.get(&neighbor).copied().unwrap_or(1.0);
103 current.min(node_confidence * config.supersedes_floor)
104 }
105 _ => continue,
107 };
108
109 confidences.insert(neighbor, new_conf);
110 if visited.insert(neighbor) {
111 queue.push_back((neighbor, new_conf, depth + 1));
112 }
113 }
114 }
115
116 confidences.into_iter().collect()
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::MemoryEdge;
    use uuid::Uuid;

    /// Builds an edge with a fixed timestamp; only type and weight matter here.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType, weight: f32) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight,
            created_at: 1000,
        }
    }

    #[test]
    fn test_caused_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Caused, 1.0));

        let result = propagate_update(&g, a, 0.5);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        assert!((map[&a] - 0.5).abs() < 0.001);
        // 0.5 * 1.0 * 0.9
        assert!((map[&b] - 0.45).abs() < 0.001);
    }

    #[test]
    fn test_supports_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Supports, 1.0));

        let result = propagate_update(&g, a, 0.5);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // 1.0 + (0.5 - 1.0) * 1.0 * 0.5
        assert!((map[&b] - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_contradicts_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Contradicts, 1.0));

        let result = propagate_update(&g, a, 0.8);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        assert!((map[&a] - 0.8).abs() < 0.001);
        // 1.0 - 0.8 * 1.0 * 0.7
        assert!((map[&b] - 0.44).abs() < 0.001);
    }

    #[test]
    fn test_supersedes_propagation() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Supersedes, 1.0));

        let result = propagate_update(&g, a, 0.9);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // min(1.0, 0.9 * 0.1)
        assert!((map[&b] - 0.09).abs() < 0.001);
    }

    #[test]
    fn test_max_depth_limit() {
        let mut g = CsrGraph::new();
        // A 10-node Caused chain: ids[0] -> ids[1] -> ... -> ids[9].
        let ids: Vec<MemoryId> = (0..10).map(|_| Uuid::new_v4()).collect();
        for pair in ids.windows(2) {
            g.add_edge(&make_edge(pair[0], pair[1], EdgeType::Caused, 1.0));
        }

        let result = propagate_update(&g, ids[0], 1.0);
        let map: HashMap<MemoryId, f32> = result.into_iter().collect();

        // Default max_depth is 5: the root plus exactly five hops are updated.
        assert_eq!(map.len(), 6);
        assert!(map.contains_key(&ids[5]));
        assert!(!map.contains_key(&ids[6]));
        // 0.9^5
        assert!((map[&ids[5]] - 0.59049).abs() < 0.001);
    }

    #[test]
    fn test_unrelated_edges_dont_propagate() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Related, 1.0));

        let result = propagate_update(&g, a, 0.5);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, a);
        assert!((result[0].1 - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_unknown_id_returns_empty() {
        let mut g = CsrGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        g.add_edge(&make_edge(a, b, EdgeType::Caused, 1.0));

        let result = propagate_update(&g, Uuid::new_v4(), 0.5);
        assert!(result.is_empty());
    }
}