Skip to main content

mentedb_graph/
contradiction.rs

1//! Contradiction and cycle detection in the knowledge graph.
2
3use std::collections::VecDeque;
4
5use ahash::HashSet;
6use mentedb_core::edge::EdgeType;
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11/// Find all nodes that contradict `id`, including transitive contradictions
12/// through Supports->Contradicts chains.
13pub fn find_contradictions(graph: &CsrGraph, id: MemoryId) -> Vec<MemoryId> {
14    let mut contradictions = HashSet::default();
15    let mut visited = HashSet::default();
16    let mut queue = VecDeque::new();
17
18    visited.insert(id);
19
20    // Direct contradictions
21    for (neighbor, edge) in graph.outgoing(id) {
22        if edge.edge_type == EdgeType::Contradicts {
23            contradictions.insert(neighbor);
24            visited.insert(neighbor);
25        }
26    }
27    // Also check incoming Contradicts edges
28    for (neighbor, edge) in graph.incoming(id) {
29        if edge.edge_type == EdgeType::Contradicts {
30            contradictions.insert(neighbor);
31            visited.insert(neighbor);
32        }
33    }
34
35    // Transitive: nodes that Support `id` may have Contradicts edges
36    // Follow Supports edges to `id` (incoming), then their Contradicts
37    queue.push_back(id);
38    while let Some(node) = queue.pop_front() {
39        // Find nodes that support this node
40        for (supporter, edge) in graph.incoming(node) {
41            if edge.edge_type == EdgeType::Supports && visited.insert(supporter) {
42                // Check if supporter contradicts anything
43                for (target, e2) in graph.outgoing(supporter) {
44                    if e2.edge_type == EdgeType::Contradicts && target != id {
45                        contradictions.insert(target);
46                    }
47                }
48            }
49        }
50        // Follow outgoing Supports to find Contradicts chains
51        for (supported, edge) in graph.outgoing(node) {
52            if edge.edge_type == EdgeType::Supports && visited.insert(supported) {
53                for (target, e2) in graph.outgoing(supported) {
54                    if e2.edge_type == EdgeType::Contradicts {
55                        contradictions.insert(target);
56                    }
57                }
58            }
59        }
60    }
61
62    contradictions.into_iter().collect()
63}
64
65/// Find all nodes superseded by `id` (directly or transitively).
66pub fn find_superseded(graph: &CsrGraph, id: MemoryId) -> Vec<MemoryId> {
67    let mut result = Vec::new();
68    let mut visited = HashSet::default();
69    let mut queue = VecDeque::new();
70
71    visited.insert(id);
72    queue.push_back(id);
73
74    while let Some(node) = queue.pop_front() {
75        for (neighbor, edge) in graph.outgoing(node) {
76            if edge.edge_type == EdgeType::Supersedes && visited.insert(neighbor) {
77                result.push(neighbor);
78                queue.push_back(neighbor);
79            }
80        }
81    }
82
83    result
84}
85
86/// Detect cycles in the graph considering only the specified edge types.
87/// Returns a list of cycles, each represented as a vector of node IDs.
88pub fn detect_cycles(graph: &CsrGraph, edge_types: &[EdgeType]) -> Vec<Vec<MemoryId>> {
89    let filter: HashSet<EdgeType> = edge_types.iter().copied().collect();
90    let mut cycles = Vec::new();
91    let mut globally_visited = HashSet::default();
92
93    for &start_id in graph.node_ids() {
94        if globally_visited.contains(&start_id) {
95            continue;
96        }
97
98        // DFS with path tracking for cycle detection
99        let mut stack: Vec<(MemoryId, Vec<MemoryId>)> = vec![(start_id, vec![start_id])];
100        let mut in_stack = HashSet::default();
101        in_stack.insert(start_id);
102        let mut local_visited = HashSet::default();
103        local_visited.insert(start_id);
104
105        while let Some((node, path)) = stack.pop() {
106            // Rebuild in_stack from the current path
107            in_stack.clear();
108            for &p in &path {
109                in_stack.insert(p);
110            }
111
112            for (neighbor, edge) in graph.outgoing(node) {
113                if !filter.contains(&edge.edge_type) {
114                    continue;
115                }
116
117                if in_stack.contains(&neighbor) {
118                    // Found a cycle - extract it
119                    if let Some(pos) = path.iter().position(|&n| n == neighbor) {
120                        let cycle: Vec<MemoryId> = path[pos..].to_vec();
121                        // Only add if we haven't found an equivalent cycle
122                        if !cycles.iter().any(|c: &Vec<MemoryId>| {
123                            c.len() == cycle.len() && cycle.iter().all(|n| c.contains(n))
124                        }) {
125                            cycles.push(cycle);
126                        }
127                    }
128                } else if local_visited.insert(neighbor) {
129                    let mut new_path = path.clone();
130                    new_path.push(neighbor);
131                    stack.push((neighbor, new_path));
132                }
133            }
134
135            globally_visited.insert(node);
136        }
137    }
138
139    cycles
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use mentedb_core::edge::MemoryEdge;

    /// Build a `MemoryEdge` with a fixed creation time and no validity window.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType, weight: f32) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    /// Build a graph holding exactly `edges`, each with weight 1.0.
    fn graph_with(edges: &[(MemoryId, MemoryId, EdgeType)]) -> CsrGraph {
        let mut graph = CsrGraph::new();
        for &(src, tgt, etype) in edges {
            graph.add_edge(&make_edge(src, tgt, etype, 1.0));
        }
        graph
    }

    /// Three freshly allocated memory ids.
    fn three_ids() -> [MemoryId; 3] {
        [MemoryId::new(), MemoryId::new(), MemoryId::new()]
    }

    #[test]
    fn test_direct_contradictions() {
        let [a, b, c] = three_ids();
        // a contradicts b (outgoing from a); c contradicts a (incoming to a).
        let graph = graph_with(&[
            (a, b, EdgeType::Contradicts),
            (c, a, EdgeType::Contradicts),
        ]);

        let found = find_contradictions(&graph, a);
        assert!(found.contains(&b));
        assert!(found.contains(&c));
    }

    #[test]
    fn test_transitive_contradictions() {
        let [a, b, c] = three_ids();
        // b supports a, and b contradicts c, so c is reported against a.
        let graph = graph_with(&[
            (b, a, EdgeType::Supports),
            (b, c, EdgeType::Contradicts),
        ]);

        let found = find_contradictions(&graph, a);
        assert!(found.contains(&c));
    }

    #[test]
    fn test_find_superseded() {
        let [a, b, c] = three_ids();
        // Chain a -> b -> c of Supersedes edges.
        let graph = graph_with(&[
            (a, b, EdgeType::Supersedes),
            (b, c, EdgeType::Supersedes),
        ]);

        let replaced = find_superseded(&graph, a);
        assert_eq!(replaced.len(), 2);
        assert!(replaced.contains(&b));
        assert!(replaced.contains(&c));
    }

    #[test]
    fn test_detect_cycle() {
        let [a, b, c] = three_ids();
        // a -> b -> c -> a forms one three-node Caused cycle.
        let graph = graph_with(&[
            (a, b, EdgeType::Caused),
            (b, c, EdgeType::Caused),
            (c, a, EdgeType::Caused),
        ]);

        let cycles = detect_cycles(&graph, &[EdgeType::Caused]);
        assert!(!cycles.is_empty());
        assert_eq!(cycles[0].len(), 3);
    }

    #[test]
    fn test_no_cycle() {
        let [a, b, c] = three_ids();
        // a -> b -> c is acyclic.
        let graph = graph_with(&[
            (a, b, EdgeType::Caused),
            (b, c, EdgeType::Caused),
        ]);

        let cycles = detect_cycles(&graph, &[EdgeType::Caused]);
        assert!(cycles.is_empty());
    }
}