// kyu_executor/operators/bfs.rs

//! BFS shortest path — finds the shortest path between two nodes.
//!
//! Builds an adjacency map from the relationship table, then runs BFS from the
//! source to the target, reconstructing the path via parent pointers.

use std::collections::hash_map::Entry;
use std::collections::{HashMap, VecDeque};

use kyu_common::id::TableId;
use kyu_common::KyuResult;
use kyu_parser::ast::Direction;
use kyu_types::TypedValue;

use crate::context::ExecutionContext;
use crate::data_chunk::DataChunk;
use crate::operators::recursive_join::build_adjacency_map;
use crate::physical_plan::PhysicalOperator;

18/// Configuration for a shortest path operation.
19pub struct ShortestPathConfig {
20    pub rel_table_id: TableId,
21    pub direction: Direction,
22    /// Column in child rows holding source node primary key.
23    pub src_key_col: usize,
24    /// Column in child rows holding destination node primary key.
25    pub dst_key_col: usize,
26}
27
28pub struct ShortestPathOp {
29    pub child: Box<PhysicalOperator>,
30    pub cfg: ShortestPathConfig,
31    results: Option<VecDeque<DataChunk>>,
32}
33
34impl ShortestPathOp {
35    pub fn new(child: PhysicalOperator, cfg: ShortestPathConfig) -> Self {
36        Self {
37            child: Box::new(child),
38            cfg,
39            results: None,
40        }
41    }
42
43    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
44        if self.results.is_none() {
45            self.results = Some(self.execute(ctx)?);
46        }
47        Ok(self.results.as_mut().unwrap().pop_front())
48    }
49
50    fn execute(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<VecDeque<DataChunk>> {
51        // 1. Drain child to collect source rows (pairs of src, dst nodes).
52        let mut source_rows: Vec<Vec<TypedValue>> = Vec::new();
53        while let Some(chunk) = self.child.next(ctx)? {
54            for row_idx in 0..chunk.num_rows() {
55                source_rows.push(chunk.get_row(row_idx));
56            }
57        }
58
59        // 2. Build adjacency map.
60        let adj = build_adjacency_map(ctx, self.cfg.rel_table_id, self.cfg.direction);
61
62        // 3. For each (src, dst) pair, find shortest path.
63        let mut result_rows: Vec<Vec<TypedValue>> = Vec::new();
64
65        for row in &source_rows {
66            let src = &row[self.cfg.src_key_col];
67            let dst = &row[self.cfg.dst_key_col];
68            let path = bfs_shortest_path(src, dst, &adj);
69
70            // Output: src columns + path as list.
71            let mut out = row.clone();
72            out.push(TypedValue::List(path));
73            result_rows.push(out);
74        }
75
76        // 4. Convert to DataChunks.
77        let ncols = source_rows.first().map_or(1, |r| r.len()) + 1;
78        let mut chunks = VecDeque::new();
79        for batch in result_rows.chunks(2048) {
80            chunks.push_back(DataChunk::from_rows(batch, ncols));
81        }
82        Ok(chunks)
83    }
84}
85
86/// BFS from `src` to `dst` through the adjacency map.
87/// Returns the shortest path as a list of node keys (including src and dst).
88/// Returns empty list if no path exists.
89pub fn bfs_shortest_path(
90    src: &TypedValue,
91    dst: &TypedValue,
92    adj: &HashMap<TypedValue, Vec<TypedValue>>,
93) -> Vec<TypedValue> {
94    if src == dst {
95        return vec![src.clone()];
96    }
97
98    // BFS with parent tracking.
99    let mut visited: HashMap<TypedValue, TypedValue> = HashMap::new(); // child -> parent
100    visited.insert(src.clone(), src.clone()); // sentinel: src's parent is itself
101    let mut queue = VecDeque::new();
102    queue.push_back(src.clone());
103
104    while let Some(node) = queue.pop_front() {
105        if let Some(neighbors) = adj.get(&node) {
106            for neighbor in neighbors {
107                if !visited.contains_key(neighbor) {
108                    visited.insert(neighbor.clone(), node.clone());
109                    if neighbor == dst {
110                        // Reconstruct path.
111                        return reconstruct_path(&visited, src, dst);
112                    }
113                    queue.push_back(neighbor.clone());
114                }
115            }
116        }
117    }
118
119    // No path found.
120    Vec::new()
121}
122
123/// Reconstruct path from parent map.
124fn reconstruct_path(
125    parents: &HashMap<TypedValue, TypedValue>,
126    src: &TypedValue,
127    dst: &TypedValue,
128) -> Vec<TypedValue> {
129    let mut path = Vec::new();
130    let mut current = dst.clone();
131    loop {
132        path.push(current.clone());
133        if &current == src {
134            break;
135        }
136        match parents.get(&current) {
137            Some(parent) => current = parent.clone(),
138            None => break, // shouldn't happen if called correctly
139        }
140    }
141    path.reverse();
142    path
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use smol_str::SmolStr;

    /// Builds a string-typed node key.
    fn tv(s: &str) -> TypedValue {
        TypedValue::String(SmolStr::new(s))
    }

    #[test]
    fn shortest_path_direct() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B"), tv("C")]);
        adj.insert(tv("B"), vec![tv("D")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("B"), &adj);
        assert_eq!(path, vec![tv("A"), tv("B")]);
    }

    #[test]
    fn shortest_path_two_hops() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("B"), vec![tv("C")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("C"), &adj);
        assert_eq!(path, vec![tv("A"), tv("B"), tv("C")]);
    }

    #[test]
    fn shortest_path_prefers_direct() {
        // A -> B -> C, A -> C (direct)
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B"), tv("C")]);
        adj.insert(tv("B"), vec![tv("C")]);

        let path = bfs_shortest_path(&tv("A"), &tv("C"), &adj);
        // BFS finds direct A->C first.
        assert_eq!(path, vec![tv("A"), tv("C")]);
    }

    #[test]
    fn shortest_path_no_path() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("D"), &adj);
        assert!(path.is_empty());
    }

    #[test]
    fn shortest_path_respects_edge_direction() {
        // Only A -> B exists, so B cannot reach A.
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);

        let path = bfs_shortest_path(&tv("B"), &tv("A"), &adj);
        assert!(path.is_empty());
    }

    #[test]
    fn shortest_path_same_node() {
        let adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        let path = bfs_shortest_path(&tv("A"), &tv("A"), &adj);
        assert_eq!(path, vec![tv("A")]);
    }

    #[test]
    fn shortest_path_cycle() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("B"), vec![tv("C")]);
        adj.insert(tv("C"), vec![tv("A")]); // cycle back

        let path = bfs_shortest_path(&tv("A"), &tv("C"), &adj);
        assert_eq!(path, vec![tv("A"), tv("B"), tv("C")]);
    }

    #[test]
    fn shortest_path_diamond() {
        // A -> B -> D, A -> C -> D
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B"), tv("C")]);
        adj.insert(tv("B"), vec![tv("D")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("D"), &adj);
        // Both paths A->B->D and A->C->D are length 2; either is acceptable,
        // but the middle hop must be one of the two real intermediates.
        assert_eq!(path.len(), 3);
        assert_eq!(path[0], tv("A"));
        assert!(path[1] == tv("B") || path[1] == tv("C"));
        assert_eq!(path[2], tv("D"));
    }

    #[test]
    fn shortest_path_long_chain() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("B"), vec![tv("C")]);
        adj.insert(tv("C"), vec![tv("D")]);
        adj.insert(tv("D"), vec![tv("E")]);
        adj.insert(tv("E"), vec![tv("F")]);

        let path = bfs_shortest_path(&tv("A"), &tv("F"), &adj);
        assert_eq!(
            path,
            vec![tv("A"), tv("B"), tv("C"), tv("D"), tv("E"), tv("F")]
        );
    }
}