Skip to main content

mentedb_graph/
traversal.rs

1//! Graph traversal algorithms: BFS, DFS, shortest path, subgraph extraction.
2
3use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::{EdgeType, MemoryEdge};
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11/// Breadth-first search returning (node, depth) pairs.
12pub fn bfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
13    let Some(_) = graph.get_idx(start) else {
14        return Vec::new();
15    };
16
17    let mut visited = HashSet::default();
18    let mut queue = VecDeque::new();
19    let mut result = Vec::new();
20
21    visited.insert(start);
22    queue.push_back((start, 0usize));
23
24    while let Some((node, depth)) = queue.pop_front() {
25        result.push((node, depth));
26        if depth >= max_depth {
27            continue;
28        }
29        for (neighbor, _edge) in graph.outgoing(node) {
30            if visited.insert(neighbor) {
31                queue.push_back((neighbor, depth + 1));
32            }
33        }
34    }
35
36    result
37}
38
39/// Depth-first search returning (node, depth) pairs.
40pub fn dfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
41    let Some(_) = graph.get_idx(start) else {
42        return Vec::new();
43    };
44
45    let mut visited = HashSet::default();
46    let mut stack = vec![(start, 0usize)];
47    let mut result = Vec::new();
48
49    while let Some((node, depth)) = stack.pop() {
50        if !visited.insert(node) {
51            continue;
52        }
53        result.push((node, depth));
54        if depth >= max_depth {
55            continue;
56        }
57        for (neighbor, _edge) in graph.outgoing(node) {
58            if !visited.contains(&neighbor) {
59                stack.push((neighbor, depth + 1));
60            }
61        }
62    }
63
64    result
65}
66
67/// BFS that only follows edges matching the given edge types.
68pub fn bfs_filtered(
69    graph: &CsrGraph,
70    start: MemoryId,
71    max_depth: usize,
72    edge_filter: &[EdgeType],
73) -> Vec<(MemoryId, usize)> {
74    let Some(_) = graph.get_idx(start) else {
75        return Vec::new();
76    };
77
78    let filter_set: HashSet<EdgeType> = edge_filter.iter().copied().collect();
79    let mut visited = HashSet::default();
80    let mut queue = VecDeque::new();
81    let mut result = Vec::new();
82
83    visited.insert(start);
84    queue.push_back((start, 0usize));
85
86    while let Some((node, depth)) = queue.pop_front() {
87        result.push((node, depth));
88        if depth >= max_depth {
89            continue;
90        }
91        for (neighbor, edge) in graph.outgoing(node) {
92            if filter_set.contains(&edge.edge_type) && visited.insert(neighbor) {
93                queue.push_back((neighbor, depth + 1));
94            }
95        }
96    }
97
98    result
99}
100
101/// Extract all nodes and edges within `radius` hops of `center`.
102pub fn extract_subgraph(
103    graph: &CsrGraph,
104    center: MemoryId,
105    radius: usize,
106) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
107    let nodes_with_depth = bfs(graph, center, radius);
108    let node_set: HashSet<MemoryId> = nodes_with_depth.iter().map(|&(id, _)| id).collect();
109
110    let nodes: Vec<MemoryId> = nodes_with_depth.into_iter().map(|(id, _)| id).collect();
111    let mut edges = Vec::new();
112
113    for &node in &nodes {
114        for (neighbor, stored) in graph.outgoing(node) {
115            if node_set.contains(&neighbor) {
116                edges.push(MemoryEdge {
117                    source: node,
118                    target: neighbor,
119                    edge_type: stored.edge_type,
120                    weight: stored.weight,
121                    created_at: stored.created_at,
122                });
123            }
124        }
125    }
126
127    (nodes, edges)
128}
129
130/// Find shortest path using BFS. Returns None if no path exists.
131pub fn shortest_path(graph: &CsrGraph, from: MemoryId, to: MemoryId) -> Option<Vec<MemoryId>> {
132    if from == to {
133        return Some(vec![from]);
134    }
135
136    let _ = graph.get_idx(from)?;
137    let _ = graph.get_idx(to)?;
138
139    let mut visited = HashSet::default();
140    let mut parent: HashMap<MemoryId, MemoryId> = HashMap::default();
141    let mut queue = VecDeque::new();
142
143    visited.insert(from);
144    queue.push_back(from);
145
146    while let Some(node) = queue.pop_front() {
147        for (neighbor, _) in graph.outgoing(node) {
148            if visited.insert(neighbor) {
149                parent.insert(neighbor, node);
150                if neighbor == to {
151                    // Reconstruct path
152                    let mut path = vec![to];
153                    let mut cur = to;
154                    while let Some(&prev) = parent.get(&cur) {
155                        path.push(prev);
156                        cur = prev;
157                        if cur == from {
158                            break;
159                        }
160                    }
161                    path.reverse();
162                    return Some(path);
163                }
164                queue.push_back(neighbor);
165            }
166        }
167    }
168
169    None
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a unit-weight edge with a fixed timestamp for test graphs.
    fn edge(source: MemoryId, target: MemoryId, edge_type: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source,
            target,
            edge_type,
            weight: 1.0,
            created_at: 1000,
        }
    }

    /// Build the directed chain `a -Caused-> b -Caused-> c -Related-> d`.
    /// Returns the graph and `[a, b, c, d]`.
    fn chain() -> (CsrGraph, Vec<MemoryId>) {
        let ids: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        let mut graph = CsrGraph::new();
        for (src, dst, kind) in [
            (0, 1, EdgeType::Caused),
            (1, 2, EdgeType::Caused),
            (2, 3, EdgeType::Related),
        ] {
            graph.add_edge(&edge(ids[src], ids[dst], kind));
        }
        (graph, ids)
    }

    #[test]
    fn test_bfs_chain() {
        let (graph, ids) = chain();
        let visited = bfs(&graph, ids[0], 10);
        assert_eq!(visited.len(), 4);
        assert_eq!(&visited[..2], &[(ids[0], 0), (ids[1], 1)]);
    }

    #[test]
    fn test_bfs_max_depth() {
        let (graph, ids) = chain();
        assert_eq!(bfs(&graph, ids[0], 1).len(), 2);
    }

    #[test]
    fn test_dfs_chain() {
        let (graph, ids) = chain();
        let visited = dfs(&graph, ids[0], 10);
        assert_eq!(visited.len(), 4);
        assert_eq!(visited.first().map(|&(id, _)| id), Some(ids[0]));
    }

    #[test]
    fn test_bfs_filtered() {
        let (graph, ids) = chain();
        // Following only `Caused` edges reaches a, b, c; the `Related` edge c -> d is skipped.
        let visited = bfs_filtered(&graph, ids[0], 10, &[EdgeType::Caused]);
        assert_eq!(visited.len(), 3);
    }

    #[test]
    fn test_shortest_path() {
        let (graph, ids) = chain();
        let path = shortest_path(&graph, ids[0], ids[3]).expect("a path a -> d should exist");
        assert_eq!(path.len(), 4);
        assert_eq!(path.first(), Some(&ids[0]));
        assert_eq!(path.last(), Some(&ids[3]));
    }

    #[test]
    fn test_shortest_path_no_path() {
        let (graph, ids) = chain();
        // Edges are directed, so d cannot reach a.
        assert!(shortest_path(&graph, ids[3], ids[0]).is_none());
    }

    #[test]
    fn test_extract_subgraph() {
        let (graph, ids) = chain();
        let (nodes, edges) = extract_subgraph(&graph, ids[0], 2);
        // Nodes {a, b, c}; edges a -> b and b -> c.
        assert_eq!((nodes.len(), edges.len()), (3, 2));
    }
}