1use std::collections::{HashMap, HashSet, VecDeque};
8
9use kyu_common::KyuResult;
10use kyu_common::id::TableId;
11use kyu_parser::ast::Direction;
12use kyu_types::TypedValue;
13
14use crate::context::ExecutionContext;
15use crate::data_chunk::DataChunk;
16use crate::physical_plan::PhysicalOperator;
17
18pub struct RecursiveJoinConfig {
20 pub rel_table_id: TableId,
21 pub dest_table_id: TableId,
22 pub direction: Direction,
23 pub min_hops: u32,
24 pub max_hops: u32,
25 pub src_key_col: usize,
27 pub dest_key_col: usize,
29 pub dest_ncols: usize,
31}
32
33pub struct RecursiveJoinOp {
34 pub child: Box<PhysicalOperator>,
35 pub cfg: RecursiveJoinConfig,
36 results: Option<VecDeque<DataChunk>>,
38}
39
40impl RecursiveJoinOp {
41 pub fn new(child: PhysicalOperator, cfg: RecursiveJoinConfig) -> Self {
42 Self {
43 child: Box::new(child),
44 cfg,
45 results: None,
46 }
47 }
48
49 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
50 if self.results.is_none() {
51 self.results = Some(self.execute(ctx)?);
52 }
53 Ok(self.results.as_mut().unwrap().pop_front())
54 }
55
56 fn execute(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<VecDeque<DataChunk>> {
57 let mut source_rows: Vec<Vec<TypedValue>> = Vec::new();
59 while let Some(chunk) = self.child.next(ctx)? {
60 for row_idx in 0..chunk.num_rows() {
61 source_rows.push(chunk.get_row(row_idx));
62 }
63 }
64
65 let adj = build_adjacency_map(ctx, self.cfg.rel_table_id, self.cfg.direction);
67
68 let dest_lookup = build_node_lookup(ctx, self.cfg.dest_table_id, self.cfg.dest_key_col);
70
71 let src_ncols = source_rows.first().map_or(0, |r| r.len());
73 let total_cols = src_ncols + self.cfg.dest_ncols;
74 let mut result_rows: Vec<Vec<TypedValue>> = Vec::new();
75
76 for src_row in &source_rows {
77 let src_key = &src_row[self.cfg.src_key_col];
78 let reachable = bfs_expand(src_key, &adj, self.cfg.min_hops, self.cfg.max_hops);
79
80 for dest_key in reachable {
81 let mut combined = src_row.clone();
82 if let Some(dest_row) = dest_lookup.get(&dest_key) {
83 combined.extend_from_slice(dest_row);
84 } else {
85 combined.extend(std::iter::repeat_n(TypedValue::Null, self.cfg.dest_ncols));
86 }
87 result_rows.push(combined);
88 }
89 }
90
91 let mut chunks = VecDeque::new();
93 let chunk_size = 2048;
94 for batch in result_rows.chunks(chunk_size) {
95 chunks.push_back(DataChunk::from_rows(batch, total_cols));
96 }
97
98 Ok(chunks)
99 }
100}
101
102pub fn build_adjacency_map(
108 ctx: &ExecutionContext<'_>,
109 rel_table_id: TableId,
110 direction: Direction,
111) -> HashMap<TypedValue, Vec<TypedValue>> {
112 let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
113
114 for chunk in ctx.storage.scan_table(rel_table_id) {
115 for row_idx in 0..chunk.num_rows() {
116 let col0 = chunk.get_value(row_idx, 0);
117 let col1 = chunk.get_value(row_idx, 1);
118
119 match direction {
120 Direction::Right => {
121 adj.entry(col0).or_default().push(col1);
122 }
123 Direction::Left => {
124 adj.entry(col1).or_default().push(col0);
125 }
126 Direction::Both => {
127 adj.entry(col0.clone()).or_default().push(col1.clone());
128 adj.entry(col1).or_default().push(col0);
129 }
130 }
131 }
132 }
133
134 adj
135}
136
137fn build_node_lookup(
139 ctx: &ExecutionContext<'_>,
140 table_id: TableId,
141 key_col: usize,
142) -> HashMap<TypedValue, Vec<TypedValue>> {
143 let mut lookup = HashMap::new();
144
145 for chunk in ctx.storage.scan_table(table_id) {
146 for row_idx in 0..chunk.num_rows() {
147 let key = chunk.get_value(row_idx, key_col);
148 let row = chunk.get_row(row_idx);
149 lookup.insert(key, row);
150 }
151 }
152
153 lookup
154}
155
156fn bfs_expand(
159 src: &TypedValue,
160 adj: &HashMap<TypedValue, Vec<TypedValue>>,
161 min_hops: u32,
162 max_hops: u32,
163) -> Vec<TypedValue> {
164 let mut visited: HashSet<TypedValue> = HashSet::new();
165 visited.insert(src.clone());
166
167 let mut queue: VecDeque<(TypedValue, u32)> = VecDeque::new();
169 queue.push_back((src.clone(), 0));
170
171 let mut results = Vec::new();
172
173 while let Some((node, depth)) = queue.pop_front() {
174 if depth >= max_hops {
175 continue;
176 }
177 if let Some(neighbors) = adj.get(&node) {
178 for neighbor in neighbors {
179 if visited.insert(neighbor.clone()) {
180 let next_depth = depth + 1;
181 if next_depth >= min_hops {
182 results.push(neighbor.clone());
183 }
184 queue.push_back((neighbor.clone(), next_depth));
185 }
186 }
187 }
188 }
189
190 results
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use crate::operators::scan::ScanNodeOp;
    use kyu_catalog::{CatalogContent, NodeTableEntry, Property, RelTableEntry};
    use kyu_common::id::PropertyId;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Shorthand for a string-typed value.
    fn name(s: &str) -> TypedValue {
        TypedValue::String(SmolStr::new(s))
    }

    /// One `Person` row: `(name, age)`.
    fn person(who: &str, age: i64) -> Vec<TypedValue> {
        vec![name(who), TypedValue::Int64(age)]
    }

    /// One `KNOWS` row: `(from, to, since)`.
    fn knows(from: &str, to: &str, since: i64) -> Vec<TypedValue> {
        vec![name(from), name(to), TypedValue::Int64(since)]
    }

    /// Catalog with a `Person` node table (id 0) and a `KNOWS` rel table (id 1).
    fn make_catalog() -> CatalogContent {
        let mut catalog = CatalogContent::new();
        catalog
            .add_node_table(NodeTableEntry {
                table_id: TableId(0),
                name: SmolStr::new("Person"),
                properties: vec![
                    Property::new(PropertyId(0), "name", LogicalType::String, true),
                    Property::new(PropertyId(1), "age", LogicalType::Int64, false),
                ],
                primary_key_idx: 0,
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
            .add_rel_table(RelTableEntry {
                table_id: TableId(1),
                name: SmolStr::new("KNOWS"),
                from_table_id: TableId(0),
                to_table_id: TableId(0),
                properties: vec![Property::new(
                    PropertyId(2),
                    "since",
                    LogicalType::Int64,
                    false,
                )],
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
    }

    /// Four people linked in a chain: Alice -> Bob -> Charlie -> Diana.
    fn make_storage() -> MockStorage {
        let mut storage = MockStorage::new();
        storage.insert_table(
            TableId(0),
            vec![
                person("Alice", 25),
                person("Bob", 30),
                person("Charlie", 35),
                person("Diana", 28),
            ],
        );
        storage.insert_table(
            TableId(1),
            vec![
                knows("Alice", "Bob", 2020),
                knows("Bob", "Charlie", 2021),
                knows("Charlie", "Diana", 2022),
            ],
        );
        storage
    }

    /// Builds the adjacency map for `direction` and expands from `start`.
    fn reachable_from(
        start: &str,
        direction: Direction,
        min_hops: u32,
        max_hops: u32,
    ) -> Vec<TypedValue> {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);
        let adj = build_adjacency_map(&ctx, TableId(1), direction);
        bfs_expand(&name(start), &adj, min_hops, max_hops)
    }

    #[test]
    fn recursive_join_1_hop() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        assert!(adj.contains_key(&name("Alice")));

        let reachable = bfs_expand(&name("Alice"), &adj, 1, 1);
        assert_eq!(reachable, vec![name("Bob")]);
    }

    #[test]
    fn recursive_join_2_hops() {
        let reachable = reachable_from("Alice", Direction::Right, 1, 2);
        assert_eq!(reachable, vec![name("Bob"), name("Charlie")]);
    }

    #[test]
    fn recursive_join_3_hops() {
        let reachable = reachable_from("Alice", Direction::Right, 1, 3);
        assert_eq!(
            reachable,
            vec![name("Bob"), name("Charlie"), name("Diana")]
        );
    }

    #[test]
    fn recursive_join_min_2() {
        // Bob is one hop away and must be filtered out by `min_hops = 2`.
        let reachable = reachable_from("Alice", Direction::Right, 2, 3);
        assert_eq!(reachable, vec![name("Charlie"), name("Diana")]);
    }

    #[test]
    fn recursive_join_operator() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let scan = PhysicalOperator::ScanNode(ScanNodeOp::new(TableId(0)));
        let mut rj = RecursiveJoinOp::new(
            scan,
            RecursiveJoinConfig {
                rel_table_id: TableId(1),
                dest_table_id: TableId(0),
                direction: Direction::Right,
                min_hops: 1,
                max_hops: 1,
                src_key_col: 0,
                dest_key_col: 0,
                dest_ncols: 2,
            },
        );

        let chunk = rj.next(&ctx).unwrap().unwrap();
        // Alice->Bob, Bob->Charlie, Charlie->Diana; Diana has no outgoing edge.
        assert_eq!(chunk.num_rows(), 3);
        // Two source columns followed by two destination columns.
        assert_eq!(chunk.num_columns(), 4);

        assert_eq!(chunk.get_value(0, 0), name("Alice"));
        assert_eq!(chunk.get_value(0, 1), TypedValue::Int64(25));
        assert_eq!(chunk.get_value(0, 2), name("Bob"));
        assert_eq!(chunk.get_value(0, 3), TypedValue::Int64(30));

        // The single chunk exhausts the operator.
        assert!(rj.next(&ctx).unwrap().is_none());
    }

    #[test]
    fn recursive_join_both_direction() {
        let reachable = reachable_from("Bob", Direction::Both, 1, 1);
        // Order depends on scan order, so only membership is checked.
        assert_eq!(reachable.len(), 2);
        assert!(reachable.contains(&name("Alice")));
        assert!(reachable.contains(&name("Charlie")));
    }
}