// kyu_executor/operators/recursive_join.rs
//! Recursive join operator — variable-length path traversal via BFS.
//!
//! Scans relationship table to build an adjacency map, then BFS-expands from
//! each source node for `min_hops..=max_hops` levels. Joins reachable
//! destination nodes with the dest node table to produce combined rows.

7use std::collections::{HashMap, HashSet, VecDeque};
8
9use kyu_common::KyuResult;
10use kyu_common::id::TableId;
11use kyu_parser::ast::Direction;
12use kyu_types::TypedValue;
13
14use crate::context::ExecutionContext;
15use crate::data_chunk::DataChunk;
16use crate::physical_plan::PhysicalOperator;
17
/// Configuration for a recursive join (avoids too-many-arguments).
pub struct RecursiveJoinConfig {
    /// Relationship table scanned to build the adjacency map.
    pub rel_table_id: TableId,
    /// Node table the reachable destination keys are joined against.
    pub dest_table_id: TableId,
    /// Traversal direction over the relationship table:
    /// `Right` = col0 -> col1, `Left` = reversed, `Both` = both ways
    /// (see `build_adjacency_map`).
    pub direction: Direction,
    /// Minimum hop count (inclusive) for a destination to be emitted.
    pub min_hops: u32,
    /// Maximum hop count (inclusive) for BFS expansion.
    pub max_hops: u32,
    /// Column in child rows that holds the source node's primary key.
    pub src_key_col: usize,
    /// Column in dest node table that holds the primary key.
    pub dest_key_col: usize,
    /// Number of columns in destination node table.
    pub dest_ncols: usize,
}
32
/// Physical operator producing `source row ++ destination row` tuples for
/// every destination reachable within the configured hop bounds.
pub struct RecursiveJoinOp {
    /// Upstream operator supplying the source rows.
    pub child: Box<PhysicalOperator>,
    /// Join configuration (tables, direction, hop bounds, key columns).
    pub cfg: RecursiveJoinConfig,
    /// Buffered result chunks.
    /// `None` until the first `next` call materializes all results at once.
    results: Option<VecDeque<DataChunk>>,
}
39
40impl RecursiveJoinOp {
41    pub fn new(child: PhysicalOperator, cfg: RecursiveJoinConfig) -> Self {
42        Self {
43            child: Box::new(child),
44            cfg,
45            results: None,
46        }
47    }
48
49    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
50        if self.results.is_none() {
51            self.results = Some(self.execute(ctx)?);
52        }
53        Ok(self.results.as_mut().unwrap().pop_front())
54    }
55
56    fn execute(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<VecDeque<DataChunk>> {
57        // 1. Drain child to collect source rows.
58        let mut source_rows: Vec<Vec<TypedValue>> = Vec::new();
59        while let Some(chunk) = self.child.next(ctx)? {
60            for row_idx in 0..chunk.num_rows() {
61                source_rows.push(chunk.get_row(row_idx));
62            }
63        }
64
65        // 2. Build adjacency map from relationship table.
66        let adj = build_adjacency_map(ctx, self.cfg.rel_table_id, self.cfg.direction);
67
68        // 3. Build dest node lookup: primary_key -> full row.
69        let dest_lookup = build_node_lookup(ctx, self.cfg.dest_table_id, self.cfg.dest_key_col);
70
71        // 4. BFS from each source, collect result rows.
72        let src_ncols = source_rows.first().map_or(0, |r| r.len());
73        let total_cols = src_ncols + self.cfg.dest_ncols;
74        let mut result_rows: Vec<Vec<TypedValue>> = Vec::new();
75
76        for src_row in &source_rows {
77            let src_key = &src_row[self.cfg.src_key_col];
78            let reachable = bfs_expand(src_key, &adj, self.cfg.min_hops, self.cfg.max_hops);
79
80            for dest_key in reachable {
81                let mut combined = src_row.clone();
82                if let Some(dest_row) = dest_lookup.get(&dest_key) {
83                    combined.extend_from_slice(dest_row);
84                } else {
85                    combined.extend(std::iter::repeat_n(TypedValue::Null, self.cfg.dest_ncols));
86                }
87                result_rows.push(combined);
88            }
89        }
90
91        // 5. Convert to DataChunks (batch of up to 2048 rows).
92        let mut chunks = VecDeque::new();
93        let chunk_size = 2048;
94        for batch in result_rows.chunks(chunk_size) {
95            chunks.push_back(DataChunk::from_rows(batch, total_cols));
96        }
97
98        Ok(chunks)
99    }
100}
101
102/// Build adjacency map: src -> [dst] from a relationship table.
103///
104/// For Right direction: column 0 = src, column 1 = dst.
105/// For Left direction: column 1 = src, column 0 = dst (reversed).
106/// For Both: both directions.
107pub fn build_adjacency_map(
108    ctx: &ExecutionContext<'_>,
109    rel_table_id: TableId,
110    direction: Direction,
111) -> HashMap<TypedValue, Vec<TypedValue>> {
112    let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
113
114    for chunk in ctx.storage.scan_table(rel_table_id) {
115        for row_idx in 0..chunk.num_rows() {
116            let col0 = chunk.get_value(row_idx, 0);
117            let col1 = chunk.get_value(row_idx, 1);
118
119            match direction {
120                Direction::Right => {
121                    adj.entry(col0).or_default().push(col1);
122                }
123                Direction::Left => {
124                    adj.entry(col1).or_default().push(col0);
125                }
126                Direction::Both => {
127                    adj.entry(col0.clone()).or_default().push(col1.clone());
128                    adj.entry(col1).or_default().push(col0);
129                }
130            }
131        }
132    }
133
134    adj
135}
136
137/// Build a lookup table: node primary key -> full row.
138fn build_node_lookup(
139    ctx: &ExecutionContext<'_>,
140    table_id: TableId,
141    key_col: usize,
142) -> HashMap<TypedValue, Vec<TypedValue>> {
143    let mut lookup = HashMap::new();
144
145    for chunk in ctx.storage.scan_table(table_id) {
146        for row_idx in 0..chunk.num_rows() {
147            let key = chunk.get_value(row_idx, key_col);
148            let row = chunk.get_row(row_idx);
149            lookup.insert(key, row);
150        }
151    }
152
153    lookup
154}
155
156/// BFS expansion from a source node through the adjacency map.
157/// Returns all distinct nodes reachable in min_hops..=max_hops steps.
158fn bfs_expand(
159    src: &TypedValue,
160    adj: &HashMap<TypedValue, Vec<TypedValue>>,
161    min_hops: u32,
162    max_hops: u32,
163) -> Vec<TypedValue> {
164    let mut visited: HashSet<TypedValue> = HashSet::new();
165    visited.insert(src.clone());
166
167    // BFS frontier: (node, depth)
168    let mut queue: VecDeque<(TypedValue, u32)> = VecDeque::new();
169    queue.push_back((src.clone(), 0));
170
171    let mut results = Vec::new();
172
173    while let Some((node, depth)) = queue.pop_front() {
174        if depth >= max_hops {
175            continue;
176        }
177        if let Some(neighbors) = adj.get(&node) {
178            for neighbor in neighbors {
179                if visited.insert(neighbor.clone()) {
180                    let next_depth = depth + 1;
181                    if next_depth >= min_hops {
182                        results.push(neighbor.clone());
183                    }
184                    queue.push_back((neighbor.clone(), next_depth));
185                }
186            }
187        }
188    }
189
190    results
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use crate::operators::scan::ScanNodeOp;
    use kyu_catalog::{CatalogContent, NodeTableEntry, Property, RelTableEntry};
    use kyu_common::id::PropertyId;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Catalog with node table `Person(name PK, age)` as TableId(0) and rel
    /// table `KNOWS(since)` Person->Person as TableId(1).
    fn make_catalog() -> CatalogContent {
        let mut catalog = CatalogContent::new();
        catalog
            .add_node_table(NodeTableEntry {
                table_id: TableId(0),
                name: SmolStr::new("Person"),
                properties: vec![
                    Property::new(PropertyId(0), "name", LogicalType::String, true),
                    Property::new(PropertyId(1), "age", LogicalType::Int64, false),
                ],
                primary_key_idx: 0,
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
            .add_rel_table(RelTableEntry {
                table_id: TableId(1),
                name: SmolStr::new("KNOWS"),
                from_table_id: TableId(0),
                to_table_id: TableId(0),
                properties: vec![Property::new(
                    PropertyId(2),
                    "since",
                    LogicalType::Int64,
                    false,
                )],
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
    }

    /// Storage fixture: four Person rows and an acyclic KNOWS chain
    /// Alice -> Bob -> Charlie -> Diana (Diana has no outgoing edges).
    /// NOTE: several assertions below depend on this exact insertion order;
    /// do not reorder the rows.
    fn make_storage() -> MockStorage {
        let mut storage = MockStorage::new();
        // Person: name, age
        storage.insert_table(
            TableId(0),
            vec![
                vec![
                    TypedValue::String(SmolStr::new("Alice")),
                    TypedValue::Int64(25),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::Int64(30),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Charlie")),
                    TypedValue::Int64(35),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Diana")),
                    TypedValue::Int64(28),
                ],
            ],
        );
        // KNOWS: src, dst, since
        // Alice -> Bob, Bob -> Charlie, Charlie -> Diana
        storage.insert_table(
            TableId(1),
            vec![
                vec![
                    TypedValue::String(SmolStr::new("Alice")),
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::Int64(2020),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::String(SmolStr::new("Charlie")),
                    TypedValue::Int64(2021),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Charlie")),
                    TypedValue::String(SmolStr::new("Diana")),
                    TypedValue::Int64(2022),
                ],
            ],
        );
        storage
    }

    /// Exact 1-hop BFS reaches only the direct neighbor.
    #[test]
    fn recursive_join_1_hop() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        // Build adjacency from real storage.
        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        assert!(adj.contains_key(&TypedValue::String(SmolStr::new("Alice"))));

        // BFS from Alice, 1..1 hop.
        let reachable = bfs_expand(&TypedValue::String(SmolStr::new("Alice")), &adj, 1, 1);
        assert_eq!(reachable.len(), 1);
        assert_eq!(reachable[0], TypedValue::String(SmolStr::new("Bob")));
    }

    /// Range 1..=2 collects neighbors at both depths.
    #[test]
    fn recursive_join_2_hops() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        let reachable = bfs_expand(&TypedValue::String(SmolStr::new("Alice")), &adj, 1, 2);
        // 1 hop: Bob, 2 hops: Charlie → 2 results.
        assert_eq!(reachable.len(), 2);
    }

    /// Range 1..=3 exhausts the whole chain.
    #[test]
    fn recursive_join_3_hops() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        let reachable = bfs_expand(&TypedValue::String(SmolStr::new("Alice")), &adj, 1, 3);
        // 1: Bob, 2: Charlie, 3: Diana → 3 results.
        assert_eq!(reachable.len(), 3);
    }

    /// `min_hops` excludes nodes whose shortest distance is below it.
    #[test]
    fn recursive_join_min_2() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        let reachable = bfs_expand(&TypedValue::String(SmolStr::new("Alice")), &adj, 2, 3);
        // min=2: skip Bob, get Charlie (2 hops) and Diana (3 hops).
        assert_eq!(reachable.len(), 2);
    }

    /// End-to-end: scan as child, 1-hop join, verify row/column layout.
    #[test]
    fn recursive_join_operator() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        // Scan all persons as source.
        let scan = PhysicalOperator::ScanNode(ScanNodeOp::new(TableId(0)));
        let mut rj = RecursiveJoinOp::new(
            scan,
            RecursiveJoinConfig {
                rel_table_id: TableId(1),
                dest_table_id: TableId(0),
                direction: Direction::Right,
                min_hops: 1,
                max_hops: 1,
                src_key_col: 0,
                dest_key_col: 0,
                dest_ncols: 2,
            },
        );

        let chunk = rj.next(&ctx).unwrap().unwrap();
        // Alice->Bob, Bob->Charlie, Charlie->Diana = 3 result rows.
        // Diana has no outgoing edges, so no results.
        assert_eq!(chunk.num_rows(), 3);
        // 4 columns: src.name, src.age, dest.name, dest.age
        assert_eq!(chunk.num_columns(), 4);

        // Verify Alice -> Bob (first result row, relies on scan order).
        assert_eq!(
            chunk.get_value(0, 0),
            TypedValue::String(SmolStr::new("Alice"))
        );
        assert_eq!(
            chunk.get_value(0, 2),
            TypedValue::String(SmolStr::new("Bob"))
        );

        // No more chunks.
        assert!(rj.next(&ctx).unwrap().is_none());
    }

    /// `Both` direction makes edges traversable from either endpoint.
    #[test]
    fn recursive_join_both_direction() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Both);
        // Bob with Both direction, 1 hop: Alice + Charlie.
        let reachable = bfs_expand(&TypedValue::String(SmolStr::new("Bob")), &adj, 1, 1);
        assert_eq!(reachable.len(), 2);
    }
}