1use std::collections::{HashMap, HashSet, VecDeque};
8
9use kyu_common::id::TableId;
10use kyu_common::KyuResult;
11use kyu_parser::ast::Direction;
12use kyu_types::TypedValue;
13
14use crate::context::ExecutionContext;
15use crate::data_chunk::DataChunk;
16use crate::physical_plan::PhysicalOperator;
17
18pub struct RecursiveJoinConfig {
20 pub rel_table_id: TableId,
21 pub dest_table_id: TableId,
22 pub direction: Direction,
23 pub min_hops: u32,
24 pub max_hops: u32,
25 pub src_key_col: usize,
27 pub dest_key_col: usize,
29 pub dest_ncols: usize,
31}
32
33pub struct RecursiveJoinOp {
34 pub child: Box<PhysicalOperator>,
35 pub cfg: RecursiveJoinConfig,
36 results: Option<VecDeque<DataChunk>>,
38}
39
40impl RecursiveJoinOp {
41 pub fn new(child: PhysicalOperator, cfg: RecursiveJoinConfig) -> Self {
42 Self {
43 child: Box::new(child),
44 cfg,
45 results: None,
46 }
47 }
48
49 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
50 if self.results.is_none() {
51 self.results = Some(self.execute(ctx)?);
52 }
53 Ok(self.results.as_mut().unwrap().pop_front())
54 }
55
56 fn execute(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<VecDeque<DataChunk>> {
57 let mut source_rows: Vec<Vec<TypedValue>> = Vec::new();
59 while let Some(chunk) = self.child.next(ctx)? {
60 for row_idx in 0..chunk.num_rows() {
61 source_rows.push(chunk.get_row(row_idx));
62 }
63 }
64
65 let adj = build_adjacency_map(ctx, self.cfg.rel_table_id, self.cfg.direction);
67
68 let dest_lookup =
70 build_node_lookup(ctx, self.cfg.dest_table_id, self.cfg.dest_key_col);
71
72 let src_ncols = source_rows.first().map_or(0, |r| r.len());
74 let total_cols = src_ncols + self.cfg.dest_ncols;
75 let mut result_rows: Vec<Vec<TypedValue>> = Vec::new();
76
77 for src_row in &source_rows {
78 let src_key = &src_row[self.cfg.src_key_col];
79 let reachable =
80 bfs_expand(src_key, &adj, self.cfg.min_hops, self.cfg.max_hops);
81
82 for dest_key in reachable {
83 let mut combined = src_row.clone();
84 if let Some(dest_row) = dest_lookup.get(&dest_key) {
85 combined.extend_from_slice(dest_row);
86 } else {
87 combined.extend(
88 std::iter::repeat_n(TypedValue::Null, self.cfg.dest_ncols),
89 );
90 }
91 result_rows.push(combined);
92 }
93 }
94
95 let mut chunks = VecDeque::new();
97 let chunk_size = 2048;
98 for batch in result_rows.chunks(chunk_size) {
99 chunks.push_back(DataChunk::from_rows(batch, total_cols));
100 }
101
102 Ok(chunks)
103 }
104}
105
106pub fn build_adjacency_map(
112 ctx: &ExecutionContext<'_>,
113 rel_table_id: TableId,
114 direction: Direction,
115) -> HashMap<TypedValue, Vec<TypedValue>> {
116 let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
117
118 for chunk in ctx.storage.scan_table(rel_table_id) {
119 for row_idx in 0..chunk.num_rows() {
120 let col0 = chunk.get_value(row_idx, 0);
121 let col1 = chunk.get_value(row_idx, 1);
122
123 match direction {
124 Direction::Right => {
125 adj.entry(col0).or_default().push(col1);
126 }
127 Direction::Left => {
128 adj.entry(col1).or_default().push(col0);
129 }
130 Direction::Both => {
131 adj.entry(col0.clone()).or_default().push(col1.clone());
132 adj.entry(col1).or_default().push(col0);
133 }
134 }
135 }
136 }
137
138 adj
139}
140
141fn build_node_lookup(
143 ctx: &ExecutionContext<'_>,
144 table_id: TableId,
145 key_col: usize,
146) -> HashMap<TypedValue, Vec<TypedValue>> {
147 let mut lookup = HashMap::new();
148
149 for chunk in ctx.storage.scan_table(table_id) {
150 for row_idx in 0..chunk.num_rows() {
151 let key = chunk.get_value(row_idx, key_col);
152 let row = chunk.get_row(row_idx);
153 lookup.insert(key, row);
154 }
155 }
156
157 lookup
158}
159
160fn bfs_expand(
163 src: &TypedValue,
164 adj: &HashMap<TypedValue, Vec<TypedValue>>,
165 min_hops: u32,
166 max_hops: u32,
167) -> Vec<TypedValue> {
168 let mut visited: HashSet<TypedValue> = HashSet::new();
169 visited.insert(src.clone());
170
171 let mut queue: VecDeque<(TypedValue, u32)> = VecDeque::new();
173 queue.push_back((src.clone(), 0));
174
175 let mut results = Vec::new();
176
177 while let Some((node, depth)) = queue.pop_front() {
178 if depth >= max_hops {
179 continue;
180 }
181 if let Some(neighbors) = adj.get(&node) {
182 for neighbor in neighbors {
183 if visited.insert(neighbor.clone()) {
184 let next_depth = depth + 1;
185 if next_depth >= min_hops {
186 results.push(neighbor.clone());
187 }
188 queue.push_back((neighbor.clone(), next_depth));
189 }
190 }
191 }
192 }
193
194 results
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use crate::operators::scan::ScanNodeOp;
    use kyu_catalog::{CatalogContent, NodeTableEntry, Property, RelTableEntry};
    use kyu_common::id::PropertyId;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Shorthand for the string value holding a person's name.
    fn person(name: &str) -> TypedValue {
        TypedValue::String(SmolStr::new(name))
    }

    /// One KNOWS row: the two endpoint names followed by the `since` year.
    fn knows(from: &str, to: &str, since: i64) -> Vec<TypedValue> {
        vec![person(from), person(to), TypedValue::Int64(since)]
    }

    /// Catalog with a `Person` node table (id 0) and a `KNOWS` relationship
    /// table (id 1) from `Person` to `Person`.
    fn make_catalog() -> CatalogContent {
        let mut catalog = CatalogContent::new();
        catalog
            .add_node_table(NodeTableEntry {
                table_id: TableId(0),
                name: SmolStr::new("Person"),
                properties: vec![
                    Property::new(PropertyId(0), "name", LogicalType::String, true),
                    Property::new(PropertyId(1), "age", LogicalType::Int64, false),
                ],
                primary_key_idx: 0,
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
            .add_rel_table(RelTableEntry {
                table_id: TableId(1),
                name: SmolStr::new("KNOWS"),
                from_table_id: TableId(0),
                to_table_id: TableId(0),
                properties: vec![Property::new(
                    PropertyId(2),
                    "since",
                    LogicalType::Int64,
                    false,
                )],
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
    }

    /// Four people linked in a chain: Alice -> Bob -> Charlie -> Diana.
    fn make_storage() -> MockStorage {
        let mut storage = MockStorage::new();
        storage.insert_table(
            TableId(0),
            vec![
                vec![person("Alice"), TypedValue::Int64(25)],
                vec![person("Bob"), TypedValue::Int64(30)],
                vec![person("Charlie"), TypedValue::Int64(35)],
                vec![person("Diana"), TypedValue::Int64(28)],
            ],
        );
        storage.insert_table(
            TableId(1),
            vec![
                knows("Alice", "Bob", 2020),
                knows("Bob", "Charlie", 2021),
                knows("Charlie", "Diana", 2022),
            ],
        );
        storage
    }

    /// Builds the KNOWS adjacency for `direction` and expands from `start`.
    fn reach(start: &str, direction: Direction, min_hops: u32, max_hops: u32) -> Vec<TypedValue> {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);
        let adj = build_adjacency_map(&ctx, TableId(1), direction);
        bfs_expand(&person(start), &adj, min_hops, max_hops)
    }

    /// Asserts that `actual` holds exactly the people in `expected`, in any order.
    fn assert_reaches(actual: &[TypedValue], expected: &[&str]) {
        assert_eq!(actual.len(), expected.len());
        for who in expected {
            assert!(actual.contains(&person(who)), "missing {who}");
        }
    }

    #[test]
    fn recursive_join_1_hop() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        assert!(adj.contains_key(&person("Alice")));

        let reachable = bfs_expand(&person("Alice"), &adj, 1, 1);
        assert_eq!(reachable.len(), 1);
        assert_eq!(reachable[0], person("Bob"));
    }

    #[test]
    fn recursive_join_2_hops() {
        let reachable = reach("Alice", Direction::Right, 1, 2);
        assert_reaches(&reachable, &["Bob", "Charlie"]);
    }

    #[test]
    fn recursive_join_3_hops() {
        let reachable = reach("Alice", Direction::Right, 1, 3);
        assert_reaches(&reachable, &["Bob", "Charlie", "Diana"]);
    }

    #[test]
    fn recursive_join_min_2() {
        // Bob is one hop away and therefore outside the 2..=3 window.
        let reachable = reach("Alice", Direction::Right, 2, 3);
        assert_reaches(&reachable, &["Charlie", "Diana"]);
    }

    #[test]
    fn recursive_join_operator() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let scan = PhysicalOperator::ScanNode(ScanNodeOp::new(TableId(0)));
        let mut rj = RecursiveJoinOp::new(
            scan,
            RecursiveJoinConfig {
                rel_table_id: TableId(1),
                dest_table_id: TableId(0),
                direction: Direction::Right,
                min_hops: 1,
                max_hops: 1,
                src_key_col: 0,
                dest_key_col: 0,
                dest_ncols: 2,
            },
        );

        // Alice, Bob and Charlie each have one outgoing edge; Diana has none.
        let chunk = rj.next(&ctx).unwrap().unwrap();
        assert_eq!(chunk.num_rows(), 3);
        // Two source columns followed by two destination columns.
        assert_eq!(chunk.num_columns(), 4);

        assert_eq!(chunk.get_value(0, 0), person("Alice"));
        assert_eq!(chunk.get_value(0, 2), person("Bob"));

        // All three rows fit in one chunk, so the operator is now exhausted.
        assert!(rj.next(&ctx).unwrap().is_none());
    }

    #[test]
    fn recursive_join_both_direction() {
        // Followed in both directions, Bob touches Alice and Charlie.
        let reachable = reach("Bob", Direction::Both, 1, 1);
        assert_reaches(&reachable, &["Alice", "Charlie"]);
    }
}