Skip to main content

mentedb_graph/
traversal.rs

1//! Graph traversal algorithms: BFS, DFS, shortest path, subgraph extraction.
2
3use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::{EdgeType, MemoryEdge};
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11/// Breadth-first search returning (node, depth) pairs.
12pub fn bfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
13    let Some(_) = graph.get_idx(start) else {
14        return Vec::new();
15    };
16
17    let mut visited = HashSet::default();
18    let mut queue = VecDeque::new();
19    let mut result = Vec::new();
20
21    visited.insert(start);
22    queue.push_back((start, 0usize));
23
24    while let Some((node, depth)) = queue.pop_front() {
25        result.push((node, depth));
26        if depth >= max_depth {
27            continue;
28        }
29        for (neighbor, _edge) in graph.outgoing(node) {
30            if visited.insert(neighbor) {
31                queue.push_back((neighbor, depth + 1));
32            }
33        }
34    }
35
36    result
37}
38
39/// Depth-first search returning (node, depth) pairs.
40pub fn dfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
41    let Some(_) = graph.get_idx(start) else {
42        return Vec::new();
43    };
44
45    let mut visited = HashSet::default();
46    let mut stack = vec![(start, 0usize)];
47    let mut result = Vec::new();
48
49    while let Some((node, depth)) = stack.pop() {
50        if !visited.insert(node) {
51            continue;
52        }
53        result.push((node, depth));
54        if depth >= max_depth {
55            continue;
56        }
57        for (neighbor, _edge) in graph.outgoing(node) {
58            if !visited.contains(&neighbor) {
59                stack.push((neighbor, depth + 1));
60            }
61        }
62    }
63
64    result
65}
66
67/// BFS that only follows edges matching the given edge types.
68pub fn bfs_filtered(
69    graph: &CsrGraph,
70    start: MemoryId,
71    max_depth: usize,
72    edge_filter: &[EdgeType],
73) -> Vec<(MemoryId, usize)> {
74    let Some(_) = graph.get_idx(start) else {
75        return Vec::new();
76    };
77
78    let filter_set: HashSet<EdgeType> = edge_filter.iter().copied().collect();
79    let mut visited = HashSet::default();
80    let mut queue = VecDeque::new();
81    let mut result = Vec::new();
82
83    visited.insert(start);
84    queue.push_back((start, 0usize));
85
86    while let Some((node, depth)) = queue.pop_front() {
87        result.push((node, depth));
88        if depth >= max_depth {
89            continue;
90        }
91        for (neighbor, edge) in graph.outgoing(node) {
92            if filter_set.contains(&edge.edge_type) && visited.insert(neighbor) {
93                queue.push_back((neighbor, depth + 1));
94            }
95        }
96    }
97
98    result
99}
100
101/// Extract all nodes and edges within `radius` hops of `center`.
102pub fn extract_subgraph(
103    graph: &CsrGraph,
104    center: MemoryId,
105    radius: usize,
106) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
107    let nodes_with_depth = bfs(graph, center, radius);
108    let node_set: HashSet<MemoryId> = nodes_with_depth.iter().map(|&(id, _)| id).collect();
109
110    let nodes: Vec<MemoryId> = nodes_with_depth.into_iter().map(|(id, _)| id).collect();
111    let mut edges = Vec::new();
112
113    for &node in &nodes {
114        for (neighbor, stored) in graph.outgoing(node) {
115            if node_set.contains(&neighbor) {
116                edges.push(MemoryEdge {
117                    source: node,
118                    target: neighbor,
119                    edge_type: stored.edge_type,
120                    weight: stored.weight,
121                    created_at: stored.created_at,
122                    valid_from: stored.valid_from,
123                    valid_until: stored.valid_until,
124                    label: stored.label.clone(),
125                });
126            }
127        }
128    }
129
130    (nodes, edges)
131}
132
133/// Find shortest path using BFS. Returns None if no path exists.
134pub fn shortest_path(graph: &CsrGraph, from: MemoryId, to: MemoryId) -> Option<Vec<MemoryId>> {
135    if from == to {
136        return Some(vec![from]);
137    }
138
139    let _ = graph.get_idx(from)?;
140    let _ = graph.get_idx(to)?;
141
142    let mut visited = HashSet::default();
143    let mut parent: HashMap<MemoryId, MemoryId> = HashMap::default();
144    let mut queue = VecDeque::new();
145
146    visited.insert(from);
147    queue.push_back(from);
148
149    while let Some(node) = queue.pop_front() {
150        for (neighbor, _) in graph.outgoing(node) {
151            if visited.insert(neighbor) {
152                parent.insert(neighbor, node);
153                if neighbor == to {
154                    // Reconstruct path
155                    let mut path = vec![to];
156                    let mut cur = to;
157                    while let Some(&prev) = parent.get(&cur) {
158                        path.push(prev);
159                        cur = prev;
160                        if cur == from {
161                            break;
162                        }
163                    }
164                    path.reverse();
165                    return Some(path);
166                }
167                queue.push_back(neighbor);
168            }
169        }
170    }
171
172    None
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-weight edge with no validity window or label.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 1.0,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    /// Builds the directed chain `a -Caused-> b -Caused-> c -Related-> d`.
    fn build_chain() -> (CsrGraph, Vec<MemoryId>) {
        // a -> b -> c -> d
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        g.add_edge(&make_edge(ids[0], ids[1], EdgeType::Caused));
        g.add_edge(&make_edge(ids[1], ids[2], EdgeType::Caused));
        g.add_edge(&make_edge(ids[2], ids[3], EdgeType::Related));
        (g, ids)
    }

    #[test]
    fn test_bfs_chain() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 10);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], (ids[0], 0));
        assert_eq!(result[1], (ids[1], 1));
    }

    #[test]
    fn test_bfs_max_depth() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 1);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_unknown_start_is_empty() {
        // A fresh id was never added to the graph, so every traversal bails out early.
        let (g, _) = build_chain();
        assert!(bfs(&g, MemoryId::new(), 10).is_empty());
        assert!(dfs(&g, MemoryId::new(), 10).is_empty());
        assert!(bfs_filtered(&g, MemoryId::new(), 10, &[EdgeType::Caused]).is_empty());
    }

    #[test]
    fn test_dfs_chain() {
        let (g, ids) = build_chain();
        let result = dfs(&g, ids[0], 10);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].0, ids[0]);
    }

    #[test]
    fn test_dfs_max_depth() {
        let (g, ids) = build_chain();
        let result = dfs(&g, ids[0], 1);
        assert_eq!(result, vec![(ids[0], 0), (ids[1], 1)]);
    }

    #[test]
    fn test_bfs_filtered() {
        let (g, ids) = build_chain();
        // Only follow Caused edges, so we stop before the Related edge
        let result = bfs_filtered(&g, ids[0], 10, &[EdgeType::Caused]);
        assert_eq!(result.len(), 3); // a, b, c but not d
    }

    #[test]
    fn test_bfs_filtered_empty_filter() {
        // No edge type is allowed, so only the start node is reported.
        let (g, ids) = build_chain();
        let result = bfs_filtered(&g, ids[0], 10, &[]);
        assert_eq!(result, vec![(ids[0], 0)]);
    }

    #[test]
    fn test_shortest_path() {
        let (g, ids) = build_chain();
        let path = shortest_path(&g, ids[0], ids[3]);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.len(), 4);
        assert_eq!(path[0], ids[0]);
        assert_eq!(path[3], ids[3]);
    }

    #[test]
    fn test_shortest_path_same_node() {
        let (g, ids) = build_chain();
        assert_eq!(shortest_path(&g, ids[2], ids[2]), Some(vec![ids[2]]));
    }

    #[test]
    fn test_shortest_path_prefers_fewer_hops() {
        // Add a shortcut a -> c on top of the chain; a -> c -> d beats a -> b -> c -> d.
        let (mut g, ids) = build_chain();
        g.add_edge(&make_edge(ids[0], ids[2], EdgeType::Related));
        assert_eq!(
            shortest_path(&g, ids[0], ids[3]),
            Some(vec![ids[0], ids[2], ids[3]])
        );
    }

    #[test]
    fn test_shortest_path_no_path() {
        let (g, ids) = build_chain();
        // No reverse path from d to a in a directed graph
        let path = shortest_path(&g, ids[3], ids[0]);
        assert!(path.is_none());
    }

    #[test]
    fn test_extract_subgraph() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 2);
        assert_eq!(nodes.len(), 3); // a, b, c
        assert_eq!(edges.len(), 2); // a->b, b->c
    }

    #[test]
    fn test_extract_subgraph_edges() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 2);
        assert_eq!(nodes, vec![ids[0], ids[1], ids[2]]);
        let pairs: Vec<(MemoryId, MemoryId)> =
            edges.iter().map(|e| (e.source, e.target)).collect();
        assert!(pairs.contains(&(ids[0], ids[1])));
        assert!(pairs.contains(&(ids[1], ids[2])));
        // The Related edge c -> d leaves the subgraph and must be excluded.
        assert!(edges.iter().all(|e| e.edge_type == EdgeType::Caused));
    }
}