use std::collections::hash_map::Entry;
use std::collections::{HashMap, VecDeque};

use kyu_common::id::TableId;
use kyu_common::KyuResult;
use kyu_parser::ast::Direction;
use kyu_types::TypedValue;

use crate::context::ExecutionContext;
use crate::data_chunk::DataChunk;
use crate::operators::recursive_join::build_adjacency_map;
use crate::physical_plan::PhysicalOperator;
17
18pub struct ShortestPathConfig {
20 pub rel_table_id: TableId,
21 pub direction: Direction,
22 pub src_key_col: usize,
24 pub dst_key_col: usize,
26}
27
28pub struct ShortestPathOp {
29 pub child: Box<PhysicalOperator>,
30 pub cfg: ShortestPathConfig,
31 results: Option<VecDeque<DataChunk>>,
32}
33
34impl ShortestPathOp {
35 pub fn new(child: PhysicalOperator, cfg: ShortestPathConfig) -> Self {
36 Self {
37 child: Box::new(child),
38 cfg,
39 results: None,
40 }
41 }
42
43 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
44 if self.results.is_none() {
45 self.results = Some(self.execute(ctx)?);
46 }
47 Ok(self.results.as_mut().unwrap().pop_front())
48 }
49
50 fn execute(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<VecDeque<DataChunk>> {
51 let mut source_rows: Vec<Vec<TypedValue>> = Vec::new();
53 while let Some(chunk) = self.child.next(ctx)? {
54 for row_idx in 0..chunk.num_rows() {
55 source_rows.push(chunk.get_row(row_idx));
56 }
57 }
58
59 let adj = build_adjacency_map(ctx, self.cfg.rel_table_id, self.cfg.direction);
61
62 let mut result_rows: Vec<Vec<TypedValue>> = Vec::new();
64
65 for row in &source_rows {
66 let src = &row[self.cfg.src_key_col];
67 let dst = &row[self.cfg.dst_key_col];
68 let path = bfs_shortest_path(src, dst, &adj);
69
70 let mut out = row.clone();
72 out.push(TypedValue::List(path));
73 result_rows.push(out);
74 }
75
76 let ncols = source_rows.first().map_or(1, |r| r.len()) + 1;
78 let mut chunks = VecDeque::new();
79 for batch in result_rows.chunks(2048) {
80 chunks.push_back(DataChunk::from_rows(batch, ncols));
81 }
82 Ok(chunks)
83 }
84}
85
86pub fn bfs_shortest_path(
90 src: &TypedValue,
91 dst: &TypedValue,
92 adj: &HashMap<TypedValue, Vec<TypedValue>>,
93) -> Vec<TypedValue> {
94 if src == dst {
95 return vec![src.clone()];
96 }
97
98 let mut visited: HashMap<TypedValue, TypedValue> = HashMap::new(); visited.insert(src.clone(), src.clone()); let mut queue = VecDeque::new();
102 queue.push_back(src.clone());
103
104 while let Some(node) = queue.pop_front() {
105 if let Some(neighbors) = adj.get(&node) {
106 for neighbor in neighbors {
107 if !visited.contains_key(neighbor) {
108 visited.insert(neighbor.clone(), node.clone());
109 if neighbor == dst {
110 return reconstruct_path(&visited, src, dst);
112 }
113 queue.push_back(neighbor.clone());
114 }
115 }
116 }
117 }
118
119 Vec::new()
121}
122
123fn reconstruct_path(
125 parents: &HashMap<TypedValue, TypedValue>,
126 src: &TypedValue,
127 dst: &TypedValue,
128) -> Vec<TypedValue> {
129 let mut path = Vec::new();
130 let mut current = dst.clone();
131 loop {
132 path.push(current.clone());
133 if ¤t == src {
134 break;
135 }
136 match parents.get(¤t) {
137 Some(parent) => current = parent.clone(),
138 None => break, }
140 }
141 path.reverse();
142 path
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use smol_str::SmolStr;

    /// Shorthand for a string-valued node key.
    fn tv(s: &str) -> TypedValue {
        TypedValue::String(SmolStr::new(s))
    }

    /// A single edge from source to destination yields a two-node path.
    #[test]
    fn shortest_path_direct() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B"), tv("C")]);
        adj.insert(tv("B"), vec![tv("D")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("B"), &adj);
        assert_eq!(path, vec![tv("A"), tv("B")]);
    }

    /// The intermediate node is included in a two-hop path.
    #[test]
    fn shortest_path_two_hops() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("B"), vec![tv("C")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("C"), &adj);
        assert_eq!(path, vec![tv("A"), tv("B"), tv("C")]);
    }

    /// A direct edge wins over a longer route to the same destination.
    #[test]
    fn shortest_path_prefers_direct() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B"), tv("C")]);
        adj.insert(tv("B"), vec![tv("C")]);

        let path = bfs_shortest_path(&tv("A"), &tv("C"), &adj);
        assert_eq!(path, vec![tv("A"), tv("C")]);
    }

    /// Disconnected components produce an empty path.
    #[test]
    fn shortest_path_no_path() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("D"), &adj);
        assert!(path.is_empty());
    }

    /// Source equal to destination yields a single-node path, even when the
    /// node is absent from the graph.
    #[test]
    fn shortest_path_same_node() {
        let adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        let path = bfs_shortest_path(&tv("A"), &tv("A"), &adj);
        assert_eq!(path, vec![tv("A")]);
    }

    /// A cycle back to the source must not trap the search.
    #[test]
    fn shortest_path_cycle() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("B"), vec![tv("C")]);
        adj.insert(tv("C"), vec![tv("A")]);

        let path = bfs_shortest_path(&tv("A"), &tv("C"), &adj);
        assert_eq!(path, vec![tv("A"), tv("B"), tv("C")]);
    }

    /// Two equally short routes exist; only the length and the endpoints are
    /// pinned, not which middle node is chosen.
    #[test]
    fn shortest_path_diamond() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B"), tv("C")]);
        adj.insert(tv("B"), vec![tv("D")]);
        adj.insert(tv("C"), vec![tv("D")]);

        let path = bfs_shortest_path(&tv("A"), &tv("D"), &adj);
        assert_eq!(path.len(), 3);
        assert_eq!(path[0], tv("A"));
        assert_eq!(path[2], tv("D"));
    }

    /// A five-edge chain is followed end to end.
    #[test]
    fn shortest_path_long_chain() {
        let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
        adj.insert(tv("A"), vec![tv("B")]);
        adj.insert(tv("B"), vec![tv("C")]);
        adj.insert(tv("C"), vec![tv("D")]);
        adj.insert(tv("D"), vec![tv("E")]);
        adj.insert(tv("E"), vec![tv("F")]);

        let path = bfs_shortest_path(&tv("A"), &tv("F"), &adj);
        assert_eq!(path.len(), 6);
        assert_eq!(
            path,
            vec![tv("A"), tv("B"), tv("C"), tv("D"), tv("E"), tv("F")]
        );
    }
}