Skip to main content

mentedb_graph/
traversal.rs

1//! Graph traversal algorithms: BFS, DFS, shortest path, subgraph extraction.
2
3use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::{EdgeType, MemoryEdge};
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11/// Breadth-first search returning (node, depth) pairs.
12pub fn bfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
13    let Some(_) = graph.get_idx(start) else {
14        return Vec::new();
15    };
16
17    let mut visited = HashSet::default();
18    let mut queue = VecDeque::new();
19    let mut result = Vec::new();
20
21    visited.insert(start);
22    queue.push_back((start, 0usize));
23
24    while let Some((node, depth)) = queue.pop_front() {
25        result.push((node, depth));
26        if depth >= max_depth {
27            continue;
28        }
29        for (neighbor, _edge) in graph.outgoing(node) {
30            if visited.insert(neighbor) {
31                queue.push_back((neighbor, depth + 1));
32            }
33        }
34    }
35
36    result
37}
38
39/// Depth-first search returning (node, depth) pairs.
40pub fn dfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
41    let Some(_) = graph.get_idx(start) else {
42        return Vec::new();
43    };
44
45    let mut visited = HashSet::default();
46    let mut stack = vec![(start, 0usize)];
47    let mut result = Vec::new();
48
49    while let Some((node, depth)) = stack.pop() {
50        if !visited.insert(node) {
51            continue;
52        }
53        result.push((node, depth));
54        if depth >= max_depth {
55            continue;
56        }
57        for (neighbor, _edge) in graph.outgoing(node) {
58            if !visited.contains(&neighbor) {
59                stack.push((neighbor, depth + 1));
60            }
61        }
62    }
63
64    result
65}
66
67/// BFS that only follows edges matching the given edge types.
68pub fn bfs_filtered(
69    graph: &CsrGraph,
70    start: MemoryId,
71    max_depth: usize,
72    edge_filter: &[EdgeType],
73) -> Vec<(MemoryId, usize)> {
74    let Some(_) = graph.get_idx(start) else {
75        return Vec::new();
76    };
77
78    let filter_set: HashSet<EdgeType> = edge_filter.iter().copied().collect();
79    let mut visited = HashSet::default();
80    let mut queue = VecDeque::new();
81    let mut result = Vec::new();
82
83    visited.insert(start);
84    queue.push_back((start, 0usize));
85
86    while let Some((node, depth)) = queue.pop_front() {
87        result.push((node, depth));
88        if depth >= max_depth {
89            continue;
90        }
91        for (neighbor, edge) in graph.outgoing(node) {
92            if filter_set.contains(&edge.edge_type) && visited.insert(neighbor) {
93                queue.push_back((neighbor, depth + 1));
94            }
95        }
96    }
97
98    result
99}
100
101/// Extract all nodes and edges within `radius` hops of `center`.
102pub fn extract_subgraph(
103    graph: &CsrGraph,
104    center: MemoryId,
105    radius: usize,
106) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
107    let nodes_with_depth = bfs(graph, center, radius);
108    let node_set: HashSet<MemoryId> = nodes_with_depth.iter().map(|&(id, _)| id).collect();
109
110    let nodes: Vec<MemoryId> = nodes_with_depth.into_iter().map(|(id, _)| id).collect();
111    let mut edges = Vec::new();
112
113    for &node in &nodes {
114        for (neighbor, stored) in graph.outgoing(node) {
115            if node_set.contains(&neighbor) {
116                edges.push(MemoryEdge {
117                    source: node,
118                    target: neighbor,
119                    edge_type: stored.edge_type,
120                    weight: stored.weight,
121                    created_at: stored.created_at,
122                    valid_from: stored.valid_from,
123                    valid_until: stored.valid_until,
124                });
125            }
126        }
127    }
128
129    (nodes, edges)
130}
131
132/// Find shortest path using BFS. Returns None if no path exists.
133pub fn shortest_path(graph: &CsrGraph, from: MemoryId, to: MemoryId) -> Option<Vec<MemoryId>> {
134    if from == to {
135        return Some(vec![from]);
136    }
137
138    let _ = graph.get_idx(from)?;
139    let _ = graph.get_idx(to)?;
140
141    let mut visited = HashSet::default();
142    let mut parent: HashMap<MemoryId, MemoryId> = HashMap::default();
143    let mut queue = VecDeque::new();
144
145    visited.insert(from);
146    queue.push_back(from);
147
148    while let Some(node) = queue.pop_front() {
149        for (neighbor, _) in graph.outgoing(node) {
150            if visited.insert(neighbor) {
151                parent.insert(neighbor, node);
152                if neighbor == to {
153                    // Reconstruct path
154                    let mut path = vec![to];
155                    let mut cur = to;
156                    while let Some(&prev) = parent.get(&cur) {
157                        path.push(prev);
158                        cur = prev;
159                        if cur == from {
160                            break;
161                        }
162                    }
163                    path.reverse();
164                    return Some(path);
165                }
166                queue.push_back(neighbor);
167            }
168        }
169    }
170
171    None
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a unit-weight edge with fixed timestamps and no validity window.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 1.0,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    /// Build the directed chain `a -Caused-> b -Caused-> c -Related-> d`.
    fn build_chain() -> (CsrGraph, Vec<MemoryId>) {
        // a -> b -> c -> d
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        g.add_edge(&make_edge(ids[0], ids[1], EdgeType::Caused));
        g.add_edge(&make_edge(ids[1], ids[2], EdgeType::Caused));
        g.add_edge(&make_edge(ids[2], ids[3], EdgeType::Related));
        (g, ids)
    }

    #[test]
    fn test_bfs_chain() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 10);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], (ids[0], 0));
        assert_eq!(result[1], (ids[1], 1));
    }

    #[test]
    fn test_bfs_max_depth() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 1);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_bfs_depth_zero_returns_only_start() {
        let (g, ids) = build_chain();
        assert_eq!(bfs(&g, ids[0], 0), vec![(ids[0], 0)]);
    }

    #[test]
    fn test_dfs_chain() {
        let (g, ids) = build_chain();
        let result = dfs(&g, ids[0], 10);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].0, ids[0]);
    }

    #[test]
    fn test_dfs_max_depth() {
        let (g, ids) = build_chain();
        assert_eq!(dfs(&g, ids[0], 1), vec![(ids[0], 0), (ids[1], 1)]);
    }

    #[test]
    fn test_bfs_filtered() {
        let (g, ids) = build_chain();
        // Only follow Caused edges, so we stop before the Related edge
        let result = bfs_filtered(&g, ids[0], 10, &[EdgeType::Caused]);
        assert_eq!(result.len(), 3); // a, b, c but not d
    }

    #[test]
    fn test_bfs_filtered_empty_filter() {
        let (g, ids) = build_chain();
        // No edge type is allowed, so nothing beyond the start is reachable.
        assert_eq!(bfs_filtered(&g, ids[0], 10, &[]), vec![(ids[0], 0)]);
    }

    #[test]
    fn test_bfs_filtered_rejected_edge_does_not_block_node() {
        // a -Related-> c is rejected, but c is still reachable via a -Caused-> b -Caused-> c.
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        g.add_edge(&make_edge(ids[0], ids[2], EdgeType::Related));
        g.add_edge(&make_edge(ids[0], ids[1], EdgeType::Caused));
        g.add_edge(&make_edge(ids[1], ids[2], EdgeType::Caused));
        let result = bfs_filtered(&g, ids[0], 10, &[EdgeType::Caused]);
        assert_eq!(result, vec![(ids[0], 0), (ids[1], 1), (ids[2], 2)]);
    }

    #[test]
    fn test_unknown_start_yields_nothing() {
        let (g, ids) = build_chain();
        let ghost = MemoryId::new();
        assert!(bfs(&g, ghost, 3).is_empty());
        assert!(dfs(&g, ghost, 3).is_empty());
        assert!(bfs_filtered(&g, ghost, 3, &[EdgeType::Caused]).is_empty());
        let (nodes, edges) = extract_subgraph(&g, ghost, 3);
        assert!(nodes.is_empty());
        assert!(edges.is_empty());
        assert!(shortest_path(&g, ghost, ids[0]).is_none());
        assert!(shortest_path(&g, ids[0], ghost).is_none());
    }

    #[test]
    fn test_shortest_path() {
        let (g, ids) = build_chain();
        let path = shortest_path(&g, ids[0], ids[3]);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.len(), 4);
        assert_eq!(path[0], ids[0]);
        assert_eq!(path[3], ids[3]);
    }

    #[test]
    fn test_shortest_path_same_node() {
        let (g, ids) = build_chain();
        assert_eq!(shortest_path(&g, ids[2], ids[2]), Some(vec![ids[2]]));
    }

    #[test]
    fn test_shortest_path_prefers_fewer_hops() {
        let (mut g, ids) = build_chain();
        // Shortcut a -> d makes the one-hop path shorter than the chain.
        g.add_edge(&make_edge(ids[0], ids[3], EdgeType::Related));
        assert_eq!(
            shortest_path(&g, ids[0], ids[3]),
            Some(vec![ids[0], ids[3]])
        );
    }

    #[test]
    fn test_shortest_path_no_path() {
        let (g, ids) = build_chain();
        // No reverse path from d to a in a directed graph
        let path = shortest_path(&g, ids[3], ids[0]);
        assert!(path.is_none());
    }

    #[test]
    fn test_extract_subgraph() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 2);
        assert_eq!(nodes.len(), 3); // a, b, c
        assert_eq!(edges.len(), 2); // a->b, b->c
    }

    #[test]
    fn test_extract_subgraph_from_middle() {
        let (g, ids) = build_chain();
        // Outgoing edges only: starting at b never reaches a.
        let (nodes, edges) = extract_subgraph(&g, ids[1], 5);
        assert_eq!(nodes, vec![ids[1], ids[2], ids[3]]);
        assert_eq!(edges.len(), 2);
        assert_eq!((edges[0].source, edges[0].target), (ids[1], ids[2]));
        assert_eq!((edges[1].source, edges[1].target), (ids[2], ids[3]));
    }
}