forge_reasoning/belief/
graph.rs1use petgraph::graph::{DiGraph, NodeIndex};
4use petgraph::algo::tarjan_scc;
5use petgraph::visit::Dfs;
6use std::collections::{HashMap, HashSet};
7use indexmap::IndexSet; use crate::hypothesis::types::HypothesisId;
10use crate::errors::Result;
11
12pub struct BeliefGraph {
17 graph: DiGraph<HypothesisId, ()>,
18 node_indices: HashMap<HypothesisId, NodeIndex>,
19}
20
21impl BeliefGraph {
22 pub fn new() -> Self {
23 Self {
24 graph: DiGraph::new(),
25 node_indices: HashMap::new(),
26 }
27 }
28
29 pub fn add_dependency(
33 &mut self,
34 hypothesis_id: HypothesisId,
35 depends_on: HypothesisId,
36 ) -> Result<()> {
37 if self.would_create_cycle(hypothesis_id, depends_on) {
40 return Err(crate::errors::ReasoningError::InvalidState(
41 format!("Adding dependency {} -> {} would create a cycle",
42 hypothesis_id, depends_on)
43 ));
44 }
45
46 let from_idx = self.get_or_create_node(hypothesis_id);
47 let to_idx = self.get_or_create_node(depends_on);
48
49 if self.graph.find_edge(from_idx, to_idx).is_some() {
51 return Ok(()); }
53
54 self.graph.add_edge(from_idx, to_idx, ());
55 Ok(())
56 }
57
58 pub fn remove_dependency(
60 &mut self,
61 hypothesis_id: HypothesisId,
62 depends_on: HypothesisId,
63 ) -> Result<bool> {
64 let from_idx = *self.node_indices.get(&hypothesis_id)
65 .ok_or_else(|| crate::errors::ReasoningError::NotFound(
66 format!("Hypothesis {} not found in graph", hypothesis_id)
67 ))?;
68 let to_idx = *self.node_indices.get(&depends_on)
69 .ok_or_else(|| crate::errors::ReasoningError::NotFound(
70 format!("Hypothesis {} not found in graph", depends_on)
71 ))?;
72
73 if let Some(edge) = self.graph.find_edge(from_idx, to_idx) {
74 self.graph.remove_edge(edge);
75 Ok(true)
76 } else {
77 Ok(false)
78 }
79 }
80
81 pub fn dependents(&self, hypothesis_id: HypothesisId) -> Result<IndexSet<HypothesisId>> {
85 let node_idx = *self.node_indices.get(&hypothesis_id)
86 .ok_or_else(|| crate::errors::ReasoningError::NotFound(
87 format!("Hypothesis {} not found in graph", hypothesis_id)
88 ))?;
89
90 let mut result = IndexSet::new();
91 for neighbor in self.graph.neighbors_directed(node_idx, petgraph::Direction::Incoming) {
92 result.insert(self.graph[neighbor]);
93 }
94 Ok(result)
95 }
96
97 pub fn dependees(&self, hypothesis_id: HypothesisId) -> Result<IndexSet<HypothesisId>> {
101 let node_idx = *self.node_indices.get(&hypothesis_id)
102 .ok_or_else(|| crate::errors::ReasoningError::NotFound(
103 format!("Hypothesis {} not found in graph", hypothesis_id)
104 ))?;
105
106 let mut result = IndexSet::new();
107 for neighbor in self.graph.neighbors_directed(node_idx, petgraph::Direction::Outgoing) {
108 result.insert(self.graph[neighbor]);
109 }
110 Ok(result)
111 }
112
113 pub fn dependency_chain(&self, hypothesis_id: HypothesisId) -> Result<IndexSet<HypothesisId>> {
117 let node_idx = *self.node_indices.get(&hypothesis_id)
118 .ok_or_else(|| crate::errors::ReasoningError::NotFound(
119 format!("Hypothesis {} not found in graph", hypothesis_id)
120 ))?;
121
122 let mut result = IndexSet::new();
123 let mut dfs = Dfs::new(&self.graph, node_idx);
124 while let Some(reached) = dfs.next(&self.graph) {
125 if reached != node_idx { result.insert(self.graph[reached]);
127 }
128 }
129 Ok(result)
130 }
131
132 pub fn reverse_dependency_chain(&self, hypothesis_id: HypothesisId) -> Result<IndexSet<HypothesisId>> {
134 let _reversed = self.graph.clone();
136 let mut result = IndexSet::new();
138 let mut visited = HashSet::new();
139 self.collect_reverse_dependencies(hypothesis_id, &mut visited, &mut result);
140 result.shift_remove(&hypothesis_id); Ok(result)
142 }
143
144 fn collect_reverse_dependencies(
145 &self,
146 hypothesis_id: HypothesisId,
147 visited: &mut HashSet<HypothesisId>,
148 result: &mut IndexSet<HypothesisId>,
149 ) {
150 if !visited.insert(hypothesis_id) {
151 return; }
153
154 result.insert(hypothesis_id);
155
156 if let Ok(direct_dependents) = self.dependents(hypothesis_id) {
157 for dependent in direct_dependents {
158 self.collect_reverse_dependencies(dependent, visited, result);
159 }
160 }
161 }
162
163 pub fn detect_cycles(&self) -> Vec<Vec<HypothesisId>> {
165 let sccs = tarjan_scc(&self.graph);
166
167 sccs.into_iter()
168 .filter(|scc| scc.len() > 1)
169 .map(|scc| {
170 scc.into_iter()
171 .map(|idx| self.graph[idx])
172 .collect()
173 })
174 .collect()
175 }
176
177 pub fn would_create_cycle(
182 &self,
183 hypothesis_id: HypothesisId,
184 depends_on: HypothesisId,
185 ) -> bool {
186 let mut temp_graph = self.graph.clone();
188
189 let from_idx = if let Some(&idx) = self.node_indices.get(&hypothesis_id) {
191 idx
192 } else {
193 temp_graph.add_node(hypothesis_id)
194 };
195
196 let to_idx = if let Some(&idx) = self.node_indices.get(&depends_on) {
197 idx
198 } else {
199 temp_graph.add_node(depends_on)
200 };
201
202 temp_graph.add_edge(from_idx, to_idx, ());
204
205 let mut dfs = Dfs::new(&temp_graph, to_idx);
207 while let Some(reached) = dfs.next(&temp_graph) {
208 if reached == from_idx {
209 return true; }
211 }
212 false
213 }
214
215 pub fn nodes(&self) -> IndexSet<HypothesisId> {
217 self.node_indices.keys().copied().collect()
218 }
219
220 pub fn remove_hypothesis(&mut self, hypothesis_id: HypothesisId) -> Result<bool> {
222 if let Some(node_idx) = self.node_indices.remove(&hypothesis_id) {
223 self.graph.remove_node(node_idx);
224 Ok(true)
225 } else {
226 Ok(false)
227 }
228 }
229
230 pub fn all_edges(&self) -> Vec<(HypothesisId, HypothesisId)> {
234 self.graph.raw_edges()
235 .iter()
236 .map(|e| {
237 let dependent = self.graph[e.source()];
238 let dependee = self.graph[e.target()];
239 (dependent, dependee)
240 })
241 .collect()
242 }
243
244 fn get_or_create_node(&mut self, id: HypothesisId) -> NodeIndex {
245 if let Some(&idx) = self.node_indices.get(&id) {
246 return idx;
247 }
248 let idx = self.graph.add_node(id);
249 self.node_indices.insert(id, idx);
250 idx
251 }
252}
253
254impl Default for BeliefGraph {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_dependency() {
        let (a, b) = (HypothesisId::new(), HypothesisId::new());
        let mut graph = BeliefGraph::new();

        graph.add_dependency(a, b).unwrap();

        let dependees = graph.dependees(a).unwrap();
        assert_eq!(dependees.len(), 1);
        assert!(dependees.contains(&b));
    }

    #[test]
    fn test_dependents_and_dependees() {
        let (a, b, c) = (HypothesisId::new(), HypothesisId::new(), HypothesisId::new());
        let mut graph = BeliefGraph::new();

        // Chain a -> b -> c: each hypothesis depends on the next one.
        graph.add_dependency(a, b).unwrap();
        graph.add_dependency(b, c).unwrap();

        let dependees_of_a = graph.dependees(a).unwrap();
        assert_eq!(dependees_of_a.len(), 1);
        assert!(dependees_of_a.contains(&b));

        let dependents_of_b = graph.dependents(b).unwrap();
        assert_eq!(dependents_of_b.len(), 1);
        assert!(dependents_of_b.contains(&a));

        let dependents_of_c = graph.dependents(c).unwrap();
        assert_eq!(dependents_of_c.len(), 1);
        assert!(dependents_of_c.contains(&b));
    }

    #[test]
    fn test_dependency_chain() {
        let (a, b) = (HypothesisId::new(), HypothesisId::new());
        let (c, d) = (HypothesisId::new(), HypothesisId::new());
        let mut graph = BeliefGraph::new();

        // Linear chain a -> b -> c -> d.
        graph.add_dependency(a, b).unwrap();
        graph.add_dependency(b, c).unwrap();
        graph.add_dependency(c, d).unwrap();

        // Everything downstream of `a`, excluding `a` itself.
        let chain = graph.dependency_chain(a).unwrap();
        assert_eq!(chain.len(), 3);
        for expected in &[b, c, d] {
            assert!(chain.contains(expected));
        }
    }

    #[test]
    fn test_cycle_detection_prevents_cycle() {
        let (a, b, c) = (HypothesisId::new(), HypothesisId::new(), HypothesisId::new());
        let mut graph = BeliefGraph::new();

        graph.add_dependency(a, b).unwrap();
        graph.add_dependency(b, c).unwrap();

        // c -> a would close the loop a -> b -> c -> a.
        assert!(graph.add_dependency(c, a).is_err());
    }

    #[test]
    fn test_detect_existing_cycles() {
        let (a, b) = (HypothesisId::new(), HypothesisId::new());
        let mut graph = BeliefGraph::new();

        // Bypass `add_dependency` (which rejects cycles) and wire a <-> b directly.
        let a_idx = graph.get_or_create_node(a);
        let b_idx = graph.get_or_create_node(b);
        graph.graph.add_edge(a_idx, b_idx, ());
        graph.graph.add_edge(b_idx, a_idx, ());

        assert_eq!(graph.detect_cycles().len(), 1);
    }
}