// kyu_executor/operators/recursive_join.rs
1//! Recursive join operator — variable-length path traversal via BFS.
2//!
3//! Scans relationship table to build an adjacency map, then BFS-expands from
4//! each source node for `min_hops..=max_hops` levels. Joins reachable
5//! destination nodes with the dest node table to produce combined rows.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8
9use kyu_common::id::TableId;
10use kyu_common::KyuResult;
11use kyu_parser::ast::Direction;
12use kyu_types::TypedValue;
13
14use crate::context::ExecutionContext;
15use crate::data_chunk::DataChunk;
16use crate::physical_plan::PhysicalOperator;
17
/// Configuration for a recursive join (avoids too-many-arguments).
pub struct RecursiveJoinConfig {
    /// Relationship table scanned to build the adjacency map.
    pub rel_table_id: TableId,
    /// Node table the reachable destination keys are joined against.
    pub dest_table_id: TableId,
    /// Traversal direction (Right: col0 -> col1, Left: reversed, Both: undirected).
    pub direction: Direction,
    /// Minimum hop count (inclusive) for a destination to be reported.
    pub min_hops: u32,
    /// Maximum hop count (inclusive) for the BFS expansion.
    pub max_hops: u32,
    /// Column in child rows that holds the source node's primary key.
    pub src_key_col: usize,
    /// Column in dest node table that holds the primary key.
    pub dest_key_col: usize,
    /// Number of columns in destination node table.
    pub dest_ncols: usize,
}
32
/// Physical operator for variable-length path traversal: joins each child row
/// with every destination node reachable within the configured hop range.
pub struct RecursiveJoinOp {
    pub child: Box<PhysicalOperator>,
    pub cfg: RecursiveJoinConfig,
    /// Buffered result chunks; `None` until the first `next` call runs the
    /// full (materializing) execution.
    results: Option<VecDeque<DataChunk>>,
}
39
40impl RecursiveJoinOp {
41    pub fn new(child: PhysicalOperator, cfg: RecursiveJoinConfig) -> Self {
42        Self {
43            child: Box::new(child),
44            cfg,
45            results: None,
46        }
47    }
48
49    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
50        if self.results.is_none() {
51            self.results = Some(self.execute(ctx)?);
52        }
53        Ok(self.results.as_mut().unwrap().pop_front())
54    }
55
56    fn execute(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<VecDeque<DataChunk>> {
57        // 1. Drain child to collect source rows.
58        let mut source_rows: Vec<Vec<TypedValue>> = Vec::new();
59        while let Some(chunk) = self.child.next(ctx)? {
60            for row_idx in 0..chunk.num_rows() {
61                source_rows.push(chunk.get_row(row_idx));
62            }
63        }
64
65        // 2. Build adjacency map from relationship table.
66        let adj = build_adjacency_map(ctx, self.cfg.rel_table_id, self.cfg.direction);
67
68        // 3. Build dest node lookup: primary_key -> full row.
69        let dest_lookup =
70            build_node_lookup(ctx, self.cfg.dest_table_id, self.cfg.dest_key_col);
71
72        // 4. BFS from each source, collect result rows.
73        let src_ncols = source_rows.first().map_or(0, |r| r.len());
74        let total_cols = src_ncols + self.cfg.dest_ncols;
75        let mut result_rows: Vec<Vec<TypedValue>> = Vec::new();
76
77        for src_row in &source_rows {
78            let src_key = &src_row[self.cfg.src_key_col];
79            let reachable =
80                bfs_expand(src_key, &adj, self.cfg.min_hops, self.cfg.max_hops);
81
82            for dest_key in reachable {
83                let mut combined = src_row.clone();
84                if let Some(dest_row) = dest_lookup.get(&dest_key) {
85                    combined.extend_from_slice(dest_row);
86                } else {
87                    combined.extend(
88                        std::iter::repeat_n(TypedValue::Null, self.cfg.dest_ncols),
89                    );
90                }
91                result_rows.push(combined);
92            }
93        }
94
95        // 5. Convert to DataChunks (batch of up to 2048 rows).
96        let mut chunks = VecDeque::new();
97        let chunk_size = 2048;
98        for batch in result_rows.chunks(chunk_size) {
99            chunks.push_back(DataChunk::from_rows(batch, total_cols));
100        }
101
102        Ok(chunks)
103    }
104}
105
106/// Build adjacency map: src -> [dst] from a relationship table.
107///
108/// For Right direction: column 0 = src, column 1 = dst.
109/// For Left direction: column 1 = src, column 0 = dst (reversed).
110/// For Both: both directions.
111pub fn build_adjacency_map(
112    ctx: &ExecutionContext<'_>,
113    rel_table_id: TableId,
114    direction: Direction,
115) -> HashMap<TypedValue, Vec<TypedValue>> {
116    let mut adj: HashMap<TypedValue, Vec<TypedValue>> = HashMap::new();
117
118    for chunk in ctx.storage.scan_table(rel_table_id) {
119        for row_idx in 0..chunk.num_rows() {
120            let col0 = chunk.get_value(row_idx, 0);
121            let col1 = chunk.get_value(row_idx, 1);
122
123            match direction {
124                Direction::Right => {
125                    adj.entry(col0).or_default().push(col1);
126                }
127                Direction::Left => {
128                    adj.entry(col1).or_default().push(col0);
129                }
130                Direction::Both => {
131                    adj.entry(col0.clone()).or_default().push(col1.clone());
132                    adj.entry(col1).or_default().push(col0);
133                }
134            }
135        }
136    }
137
138    adj
139}
140
141/// Build a lookup table: node primary key -> full row.
142fn build_node_lookup(
143    ctx: &ExecutionContext<'_>,
144    table_id: TableId,
145    key_col: usize,
146) -> HashMap<TypedValue, Vec<TypedValue>> {
147    let mut lookup = HashMap::new();
148
149    for chunk in ctx.storage.scan_table(table_id) {
150        for row_idx in 0..chunk.num_rows() {
151            let key = chunk.get_value(row_idx, key_col);
152            let row = chunk.get_row(row_idx);
153            lookup.insert(key, row);
154        }
155    }
156
157    lookup
158}
159
160/// BFS expansion from a source node through the adjacency map.
161/// Returns all distinct nodes reachable in min_hops..=max_hops steps.
162fn bfs_expand(
163    src: &TypedValue,
164    adj: &HashMap<TypedValue, Vec<TypedValue>>,
165    min_hops: u32,
166    max_hops: u32,
167) -> Vec<TypedValue> {
168    let mut visited: HashSet<TypedValue> = HashSet::new();
169    visited.insert(src.clone());
170
171    // BFS frontier: (node, depth)
172    let mut queue: VecDeque<(TypedValue, u32)> = VecDeque::new();
173    queue.push_back((src.clone(), 0));
174
175    let mut results = Vec::new();
176
177    while let Some((node, depth)) = queue.pop_front() {
178        if depth >= max_hops {
179            continue;
180        }
181        if let Some(neighbors) = adj.get(&node) {
182            for neighbor in neighbors {
183                if visited.insert(neighbor.clone()) {
184                    let next_depth = depth + 1;
185                    if next_depth >= min_hops {
186                        results.push(neighbor.clone());
187                    }
188                    queue.push_back((neighbor.clone(), next_depth));
189                }
190            }
191        }
192    }
193
194    results
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use crate::operators::scan::ScanNodeOp;
    use kyu_catalog::{CatalogContent, NodeTableEntry, Property, RelTableEntry};
    use kyu_common::id::PropertyId;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Catalog with one node table Person(name PK, age) and one rel table
    /// KNOWS(since) from Person to Person.
    fn make_catalog() -> CatalogContent {
        let mut catalog = CatalogContent::new();
        catalog
            .add_node_table(NodeTableEntry {
                table_id: TableId(0),
                name: SmolStr::new("Person"),
                properties: vec![
                    Property::new(PropertyId(0), "name", LogicalType::String, true),
                    Property::new(PropertyId(1), "age", LogicalType::Int64, false),
                ],
                primary_key_idx: 0,
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
            .add_rel_table(RelTableEntry {
                table_id: TableId(1),
                name: SmolStr::new("KNOWS"),
                from_table_id: TableId(0),
                to_table_id: TableId(0),
                properties: vec![Property::new(
                    PropertyId(2),
                    "since",
                    LogicalType::Int64,
                    false,
                )],
                num_rows: 0,
                comment: None,
            })
            .unwrap();
        catalog
    }

    /// Storage with four Person rows and a simple KNOWS chain:
    /// Alice -> Bob -> Charlie -> Diana (Diana has no outgoing edges).
    fn make_storage() -> MockStorage {
        let mut storage = MockStorage::new();
        // Person: name, age
        storage.insert_table(
            TableId(0),
            vec![
                vec![TypedValue::String(SmolStr::new("Alice")), TypedValue::Int64(25)],
                vec![TypedValue::String(SmolStr::new("Bob")), TypedValue::Int64(30)],
                vec![TypedValue::String(SmolStr::new("Charlie")), TypedValue::Int64(35)],
                vec![TypedValue::String(SmolStr::new("Diana")), TypedValue::Int64(28)],
            ],
        );
        // KNOWS: src, dst, since
        // Alice -> Bob, Bob -> Charlie, Charlie -> Diana
        storage.insert_table(
            TableId(1),
            vec![
                vec![
                    TypedValue::String(SmolStr::new("Alice")),
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::Int64(2020),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::String(SmolStr::new("Charlie")),
                    TypedValue::Int64(2021),
                ],
                vec![
                    TypedValue::String(SmolStr::new("Charlie")),
                    TypedValue::String(SmolStr::new("Diana")),
                    TypedValue::Int64(2022),
                ],
            ],
        );
        storage
    }

    // Exactly one hop from Alice reaches only Bob.
    #[test]
    fn recursive_join_1_hop() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        // Build adjacency from real storage.
        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        assert!(adj.contains_key(&TypedValue::String(SmolStr::new("Alice"))));

        // BFS from Alice, 1..1 hop.
        let reachable = bfs_expand(
            &TypedValue::String(SmolStr::new("Alice")),
            &adj,
            1,
            1,
        );
        assert_eq!(reachable.len(), 1);
        assert_eq!(reachable[0], TypedValue::String(SmolStr::new("Bob")));
    }

    // Up to two hops from Alice reaches Bob and Charlie.
    #[test]
    fn recursive_join_2_hops() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        let reachable = bfs_expand(
            &TypedValue::String(SmolStr::new("Alice")),
            &adj,
            1,
            2,
        );
        // 1 hop: Bob, 2 hops: Charlie → 2 results.
        assert_eq!(reachable.len(), 2);
    }

    // Up to three hops reaches the entire chain.
    #[test]
    fn recursive_join_3_hops() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        let reachable = bfs_expand(
            &TypedValue::String(SmolStr::new("Alice")),
            &adj,
            1,
            3,
        );
        // 1: Bob, 2: Charlie, 3: Diana → 3 results.
        assert_eq!(reachable.len(), 3);
    }

    // A min_hops lower bound excludes nodes closer than the bound.
    #[test]
    fn recursive_join_min_2() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Right);
        let reachable = bfs_expand(
            &TypedValue::String(SmolStr::new("Alice")),
            &adj,
            2,
            3,
        );
        // min=2: skip Bob, get Charlie (2 hops) and Diana (3 hops).
        assert_eq!(reachable.len(), 2);
    }

    // End-to-end: operator drains a scan, expands 1 hop, and emits
    // src-columns + dest-columns rows.
    #[test]
    fn recursive_join_operator() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        // Scan all persons as source.
        let scan = PhysicalOperator::ScanNode(ScanNodeOp::new(TableId(0)));
        let mut rj = RecursiveJoinOp::new(
            scan,
            RecursiveJoinConfig {
                rel_table_id: TableId(1),
                dest_table_id: TableId(0),
                direction: Direction::Right,
                min_hops: 1,
                max_hops: 1,
                src_key_col: 0,
                dest_key_col: 0,
                dest_ncols: 2,
            },
        );

        let chunk = rj.next(&ctx).unwrap().unwrap();
        // Alice->Bob, Bob->Charlie, Charlie->Diana = 3 result rows.
        // Diana has no outgoing edges, so no results.
        assert_eq!(chunk.num_rows(), 3);
        // 4 columns: src.name, src.age, dest.name, dest.age
        assert_eq!(chunk.num_columns(), 4);

        // Verify Alice -> Bob
        assert_eq!(chunk.get_value(0, 0), TypedValue::String(SmolStr::new("Alice")));
        assert_eq!(chunk.get_value(0, 2), TypedValue::String(SmolStr::new("Bob")));

        // No more chunks.
        assert!(rj.next(&ctx).unwrap().is_none());
    }

    // Undirected (Both) adjacency makes edges traversable both ways.
    #[test]
    fn recursive_join_both_direction() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(make_catalog(), &storage);

        let adj = build_adjacency_map(&ctx, TableId(1), Direction::Both);
        // Bob with Both direction, 1 hop: Alice + Charlie.
        let reachable = bfs_expand(
            &TypedValue::String(SmolStr::new("Bob")),
            &adj,
            1,
            1,
        );
        assert_eq!(reachable.len(), 2);
    }
}