Skip to main content

grafeo_core/execution/operators/
leapfrog_join.rs

1//! Leapfrog TrieJoin operator for worst-case optimal joins.
2//!
3//! This operator wraps the `LeapfrogJoin` algorithm from the trie index module
4//! to provide efficient multi-way joins for cyclic patterns like triangles.
5//!
6//! Traditional binary hash joins cascade O(N²) for triangle patterns; leapfrog
7//! achieves O(N^1.5) by processing all relations simultaneously.
8
9use grafeo_common::types::{EdgeId, LogicalType, NodeId, Value};
10
11use super::{Operator, OperatorError, OperatorResult};
12use crate::execution::DataChunk;
13use crate::execution::chunk::DataChunkBuilder;
14use crate::index::trie::{LeapfrogJoin, TrieIndex};
15
16/// Row identifier for reconstructing output: (input_index, chunk_index, row_index).
17type RowId = (usize, usize, usize);
18
19/// A multi-way join intersection result.
20struct JoinResult {
21    /// The matched key value (intersection point).
22    /// Kept for debugging and potential future use in result verification.
23    #[allow(dead_code)]
24    key: NodeId,
25    /// Row identifiers from each input that participated in this match.
26    row_ids: Vec<Vec<RowId>>,
27}
28
29/// Leapfrog TrieJoin operator for worst-case optimal multi-way joins.
30///
31/// Uses the leapfrog algorithm to efficiently find intersections across
32/// multiple sorted inputs without materializing intermediate Cartesian products.
33pub struct LeapfrogJoinOperator {
34    /// Input operators (one per relation in the join).
35    inputs: Vec<Box<dyn Operator>>,
36
37    /// Column indices for join keys in each input.
38    /// Each inner Vec maps to one join variable.
39    join_key_indices: Vec<Vec<usize>>,
40
41    /// Output schema (combined columns from all inputs).
42    output_schema: Vec<LogicalType>,
43
44    /// Mapping from output column index to (input_idx, column_idx).
45    output_column_mapping: Vec<(usize, usize)>,
46
47    // === Materialization state ===
48    /// Materialized input chunks (built once during first next() call).
49    materialized_inputs: Vec<Vec<DataChunk>>,
50
51    /// TrieIndex structures built from materialized inputs.
52    tries: Vec<TrieIndex>,
53
54    /// Whether materialization is complete.
55    materialized: bool,
56
57    // === Iteration state ===
58    /// Pre-computed join results.
59    results: Vec<JoinResult>,
60
61    /// Current position in results.
62    result_position: usize,
63
64    /// Current expansion position within current result's cross product.
65    expansion_indices: Vec<usize>,
66
67    /// Whether iteration is exhausted.
68    exhausted: bool,
69}
70
71impl LeapfrogJoinOperator {
72    /// Creates a new leapfrog join operator.
73    ///
74    /// # Arguments
75    /// * `inputs` - Input operators (one per relation).
76    /// * `join_key_indices` - Column indices for join keys in each input.
77    /// * `output_schema` - Schema of the output columns.
78    /// * `output_column_mapping` - Maps output columns to (input_idx, column_idx).
79    #[must_use]
80    pub fn new(
81        inputs: Vec<Box<dyn Operator>>,
82        join_key_indices: Vec<Vec<usize>>,
83        output_schema: Vec<LogicalType>,
84        output_column_mapping: Vec<(usize, usize)>,
85    ) -> Self {
86        Self {
87            inputs,
88            join_key_indices,
89            output_schema,
90            output_column_mapping,
91            materialized_inputs: Vec::new(),
92            tries: Vec::new(),
93            materialized: false,
94            results: Vec::new(),
95            result_position: 0,
96            expansion_indices: Vec::new(),
97            exhausted: false,
98        }
99    }
100
101    /// Materializes all inputs and builds trie indexes.
102    fn materialize_inputs(&mut self) -> Result<(), OperatorError> {
103        // Phase 1: Collect all chunks from each input
104        for input in &mut self.inputs {
105            let mut chunks = Vec::new();
106            while let Some(chunk) = input.next()? {
107                chunks.push(chunk);
108            }
109            self.materialized_inputs.push(chunks);
110        }
111
112        // Phase 2: Build TrieIndex for each input
113        for (input_idx, chunks) in self.materialized_inputs.iter().enumerate() {
114            let mut trie = TrieIndex::new();
115            let key_indices = &self.join_key_indices[input_idx];
116
117            for (chunk_idx, chunk) in chunks.iter().enumerate() {
118                for row in 0..chunk.row_count() {
119                    // Extract join key values and convert to path
120                    if let Some(path) = self.extract_join_keys(chunk, row, key_indices) {
121                        // Encode row location as EdgeId for trie storage
122                        let row_id = Self::encode_row_id(input_idx, chunk_idx, row);
123                        trie.insert(&path, row_id);
124                    }
125                }
126            }
127            self.tries.push(trie);
128        }
129
130        self.materialized = true;
131        Ok(())
132    }
133
134    /// Extracts join key values from a row and converts to NodeId path.
135    fn extract_join_keys(
136        &self,
137        chunk: &DataChunk,
138        row: usize,
139        key_indices: &[usize],
140    ) -> Option<Vec<NodeId>> {
141        let mut path = Vec::with_capacity(key_indices.len());
142
143        for &col_idx in key_indices {
144            let col = chunk.column(col_idx)?;
145            let node_id = match col.data_type() {
146                LogicalType::Node => col.get_node_id(row),
147                LogicalType::Edge => col.get_edge_id(row).map(|e| NodeId::new(e.as_u64())),
148                LogicalType::Int64 => col.get_int64(row).map(|i| NodeId::new(i as u64)),
149                _ => return None, // Unsupported join key type
150            }?;
151            path.push(node_id);
152        }
153
154        Some(path)
155    }
156
157    /// Encodes a row location as an EdgeId for trie storage.
158    fn encode_row_id(input_idx: usize, chunk_idx: usize, row: usize) -> EdgeId {
159        // Pack: input (8 bits) | chunk (24 bits) | row (32 bits)
160        let encoded = ((input_idx as u64) << 56)
161            | ((chunk_idx as u64 & 0xFFFFFF) << 32)
162            | (row as u64 & 0xFFFFFFFF);
163        EdgeId::new(encoded)
164    }
165
166    /// Decodes a row location from an EdgeId.
167    fn decode_row_id(edge_id: EdgeId) -> RowId {
168        let encoded = edge_id.as_u64();
169        let input_idx = (encoded >> 56) as usize;
170        let chunk_idx = ((encoded >> 32) & 0xFFFFFF) as usize;
171        let row = (encoded & 0xFFFFFFFF) as usize;
172        (input_idx, chunk_idx, row)
173    }
174
175    /// Executes the leapfrog join to find all intersections.
176    fn execute_leapfrog(&mut self) -> Result<(), OperatorError> {
177        if self.tries.is_empty() {
178            return Ok(());
179        }
180
181        // Create iterators for each trie at the first level
182        let iters: Vec<_> = self.tries.iter().map(|t| t.iter()).collect();
183
184        // Create leapfrog join
185        let mut join = LeapfrogJoin::new(iters);
186
187        // Find all intersections at the first level
188        while let Some(key) = join.key() {
189            // Collect all row IDs from each input that match this key
190            let mut row_ids_per_input: Vec<Vec<RowId>> = vec![Vec::new(); self.tries.len()];
191
192            // For each trie, collect all row IDs at this key
193            if let Some(child_iters) = join.open() {
194                for (input_idx, _child_iter) in child_iters.into_iter().enumerate() {
195                    // The child iterator points to the second level of the trie
196                    // We need to collect the edge IDs (our encoded row IDs) at this position
197                    self.collect_row_ids_at_key(
198                        &self.tries[input_idx],
199                        key,
200                        input_idx,
201                        &mut row_ids_per_input[input_idx],
202                    );
203                }
204            }
205
206            // Only add result if all inputs have matching rows
207            if row_ids_per_input.iter().all(|ids| !ids.is_empty()) {
208                self.results.push(JoinResult {
209                    key,
210                    row_ids: row_ids_per_input,
211                });
212            }
213
214            if !join.next() {
215                break;
216            }
217        }
218
219        // Initialize expansion indices if we have results
220        if !self.results.is_empty() {
221            self.expansion_indices = vec![0; self.inputs.len()];
222        }
223
224        Ok(())
225    }
226
227    /// Collects all row IDs from a trie at a specific key.
228    fn collect_row_ids_at_key(
229        &self,
230        trie: &TrieIndex,
231        key: NodeId,
232        input_idx: usize,
233        row_ids: &mut Vec<RowId>,
234    ) {
235        // Get iterator at the key's path
236        if let Some(edges) = trie.get(&[key]) {
237            for &edge_id in edges {
238                let decoded = Self::decode_row_id(edge_id);
239                // Verify input index matches (should always match)
240                if decoded.0 == input_idx {
241                    row_ids.push(decoded);
242                }
243            }
244        }
245
246        // Also check children (for multi-level tries)
247        if let Some(iter) = trie.iter_at(&[key]) {
248            let mut iter = iter;
249            loop {
250                if let Some(child_key) = iter.key() {
251                    if let Some(edges) = trie.get(&[key, child_key]) {
252                        for &edge_id in edges {
253                            row_ids.push(Self::decode_row_id(edge_id));
254                        }
255                    }
256                }
257                if !iter.next() {
258                    break;
259                }
260            }
261        }
262    }
263
264    /// Advances to the next combination in the current result's cross product.
265    fn advance_expansion(&mut self) -> bool {
266        if self.result_position >= self.results.len() {
267            return false;
268        }
269
270        let result = &self.results[self.result_position];
271
272        // Try to advance from the rightmost input
273        for i in (0..self.expansion_indices.len()).rev() {
274            self.expansion_indices[i] += 1;
275            if self.expansion_indices[i] < result.row_ids[i].len() {
276                return true;
277            }
278            self.expansion_indices[i] = 0;
279        }
280
281        // All combinations exhausted for this result, move to next
282        self.result_position += 1;
283        if self.result_position < self.results.len() {
284            self.expansion_indices = vec![0; self.inputs.len()];
285            true
286        } else {
287            false
288        }
289    }
290
291    /// Builds an output row from the current expansion position.
292    fn build_output_row(&self, builder: &mut DataChunkBuilder) -> Result<(), OperatorError> {
293        let result = &self.results[self.result_position];
294
295        for (out_col, &(input_idx, in_col)) in self.output_column_mapping.iter().enumerate() {
296            let expansion_idx = self.expansion_indices[input_idx];
297            let (_, chunk_idx, row) = result.row_ids[input_idx][expansion_idx];
298
299            let chunk = &self.materialized_inputs[input_idx][chunk_idx];
300            let col = chunk
301                .column(in_col)
302                .ok_or_else(|| OperatorError::ColumnNotFound(in_col.to_string()))?;
303
304            let out_col_vec = builder
305                .column_mut(out_col)
306                .ok_or_else(|| OperatorError::ColumnNotFound(out_col.to_string()))?;
307
308            // Copy value from input to output
309            if let Some(value) = col.get_value(row) {
310                out_col_vec.push_value(value);
311            } else {
312                out_col_vec.push_value(Value::Null);
313            }
314        }
315
316        builder.advance_row();
317        Ok(())
318    }
319}
320
321impl Operator for LeapfrogJoinOperator {
322    fn next(&mut self) -> OperatorResult {
323        // First call: materialize inputs and execute leapfrog
324        if !self.materialized {
325            self.materialize_inputs()?;
326            self.execute_leapfrog()?;
327        }
328
329        if self.exhausted || self.results.is_empty() {
330            return Ok(None);
331        }
332
333        // Check if we've exhausted all results
334        if self.result_position >= self.results.len() {
335            self.exhausted = true;
336            return Ok(None);
337        }
338
339        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
340
341        while !builder.is_full() {
342            self.build_output_row(&mut builder)?;
343
344            if !self.advance_expansion() {
345                self.exhausted = true;
346                break;
347            }
348        }
349
350        if builder.row_count() > 0 {
351            Ok(Some(builder.finish()))
352        } else {
353            Ok(None)
354        }
355    }
356
357    fn reset(&mut self) {
358        for input in &mut self.inputs {
359            input.reset();
360        }
361        self.materialized_inputs.clear();
362        self.tries.clear();
363        self.materialized = false;
364        self.results.clear();
365        self.result_position = 0;
366        self.expansion_indices.clear();
367        self.exhausted = false;
368    }
369
370    fn name(&self) -> &'static str {
371        "LeapfrogJoin"
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::vector::ValueVector;

    /// Scan stub that yields one fixed chunk per pass.
    ///
    /// The chunk is cloned on every pass so that `reset()` really makes the
    /// operator replayable.
    struct MockScanOperator {
        chunk: DataChunk,
        returned: bool,
    }

    impl MockScanOperator {
        fn new(chunk: DataChunk) -> Self {
            Self {
                chunk,
                returned: false,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.returned {
                return Ok(None);
            }
            self.returned = true;
            // Hand out a copy; the original stays available for later passes.
            Ok(Some(self.chunk.clone()))
        }

        fn reset(&mut self) {
            self.returned = false;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a single-column `Int64` chunk holding `node_ids`.
    fn create_node_chunk(node_ids: &[i64]) -> DataChunk {
        let mut col = ValueVector::with_type(LogicalType::Int64);
        for &id in node_ids {
            col.push_int64(id);
        }
        DataChunk::new(vec![col])
    }

    /// Builds a two-input join on column 0 that outputs both key columns.
    fn create_binary_join(left: &[i64], right: &[i64]) -> LeapfrogJoinOperator {
        let op1: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(left)));
        let op2: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(right)));

        LeapfrogJoinOperator::new(
            vec![op1, op2],
            vec![vec![0], vec![0]], // Join on first column of each
            vec![LogicalType::Int64, LogicalType::Int64],
            vec![(0, 0), (1, 0)], // Output both columns
        )
    }

    /// Pulls every chunk from `op` and returns all output rows as pairs.
    fn drain_pairs(op: &mut LeapfrogJoinOperator) -> Vec<(i64, i64)> {
        let mut pairs = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in 0..chunk.row_count() {
                let val1 = chunk.column(0).unwrap().get_int64(row).unwrap();
                let val2 = chunk.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((val1, val2));
            }
        }
        pairs
    }

    #[test]
    fn test_leapfrog_binary_intersection() {
        // Input 1: nodes [1, 2, 3, 5]
        // Input 2: nodes [2, 3, 4, 5]
        // Expected intersection: [2, 3, 5]
        let mut leapfrog = create_binary_join(&[1, 2, 3, 5], &[2, 3, 4, 5]);

        let all_results = drain_pairs(&mut leapfrog);

        // Should find 3 matches: (2,2), (3,3), (5,5)
        assert_eq!(all_results.len(), 3);
        assert!(all_results.contains(&(2, 2)));
        assert!(all_results.contains(&(3, 3)));
        assert!(all_results.contains(&(5, 5)));
    }

    #[test]
    fn test_leapfrog_empty_intersection() {
        // Input 1: nodes [1, 2, 3]
        // Input 2: nodes [4, 5, 6]
        // Expected: empty
        let mut leapfrog = create_binary_join(&[1, 2, 3], &[4, 5, 6]);

        let result = leapfrog.next().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_leapfrog_reset() {
        let mut leapfrog = create_binary_join(&[1, 2, 3], &[2, 3, 4]);

        // First pass - consume all results
        let mut first_pass = drain_pairs(&mut leapfrog);
        assert_eq!(first_pass.len(), 2);

        // Reset must drop all cached state...
        leapfrog.reset();
        assert!(!leapfrog.materialized);
        assert!(leapfrog.results.is_empty());

        // ...and a second pass must reproduce the same rows.
        let mut second_pass = drain_pairs(&mut leapfrog);
        first_pass.sort_unstable();
        second_pass.sort_unstable();
        assert_eq!(first_pass, second_pass);
    }

    #[test]
    fn test_encode_decode_row_id() {
        let test_cases = [
            (0, 0, 0),
            (1, 2, 3),
            (255, 16777215, 4294967295), // Max values for each field
        ];

        for (input_idx, chunk_idx, row) in test_cases {
            let encoded = LeapfrogJoinOperator::encode_row_id(input_idx, chunk_idx, row);
            let decoded = LeapfrogJoinOperator::decode_row_id(encoded);
            assert_eq!(decoded, (input_idx, chunk_idx, row));
        }
    }
}