Skip to main content

grafeo_core/execution/operators/
leapfrog_join.rs

1//! Leapfrog TrieJoin operator for worst-case optimal joins.
2//!
3//! This operator wraps the `LeapfrogJoin` algorithm from the trie index module
4//! to provide efficient multi-way joins for cyclic patterns like triangles.
5//!
6//! Traditional binary hash joins cascade O(N²) for triangle patterns; leapfrog
7//! achieves O(N^1.5) by processing all relations simultaneously.
8
9use grafeo_common::types::{EdgeId, LogicalType, NodeId, Value};
10
11use super::{Operator, OperatorError, OperatorResult};
12use crate::execution::DataChunk;
13use crate::execution::chunk::DataChunkBuilder;
14use crate::index::trie::{LeapfrogJoin, TrieIndex};
15
/// Location of one materialized row, as the triple
/// `(input_index, chunk_index, row_index)` used to rebuild output rows.
type RowId = (usize, usize, usize);
18
19/// A multi-way join intersection result.
20struct JoinResult {
21    /// Row identifiers from each input that participated in this match.
22    row_ids: Vec<Vec<RowId>>,
23}
24
25/// Leapfrog TrieJoin operator for worst-case optimal multi-way joins.
26///
27/// Uses the leapfrog algorithm to efficiently find intersections across
28/// multiple sorted inputs without materializing intermediate Cartesian products.
29pub struct LeapfrogJoinOperator {
30    /// Input operators (one per relation in the join).
31    inputs: Vec<Box<dyn Operator>>,
32
33    /// Column indices for join keys in each input.
34    /// Each inner Vec maps to one join variable.
35    join_key_indices: Vec<Vec<usize>>,
36
37    /// Output schema (combined columns from all inputs).
38    output_schema: Vec<LogicalType>,
39
40    /// Mapping from output column index to (input_idx, column_idx).
41    output_column_mapping: Vec<(usize, usize)>,
42
43    // === Materialization state ===
44    /// Materialized input chunks (built once during first next() call).
45    materialized_inputs: Vec<Vec<DataChunk>>,
46
47    /// TrieIndex structures built from materialized inputs.
48    tries: Vec<TrieIndex>,
49
50    /// Whether materialization is complete.
51    materialized: bool,
52
53    // === Iteration state ===
54    /// Pre-computed join results.
55    results: Vec<JoinResult>,
56
57    /// Current position in results.
58    result_position: usize,
59
60    /// Current expansion position within current result's cross product.
61    expansion_indices: Vec<usize>,
62
63    /// Whether iteration is exhausted.
64    exhausted: bool,
65}
66
67impl LeapfrogJoinOperator {
68    /// Creates a new leapfrog join operator.
69    ///
70    /// # Arguments
71    /// * `inputs` - Input operators (one per relation).
72    /// * `join_key_indices` - Column indices for join keys in each input.
73    /// * `output_schema` - Schema of the output columns.
74    /// * `output_column_mapping` - Maps output columns to (input_idx, column_idx).
75    #[must_use]
76    pub fn new(
77        inputs: Vec<Box<dyn Operator>>,
78        join_key_indices: Vec<Vec<usize>>,
79        output_schema: Vec<LogicalType>,
80        output_column_mapping: Vec<(usize, usize)>,
81    ) -> Self {
82        Self {
83            inputs,
84            join_key_indices,
85            output_schema,
86            output_column_mapping,
87            materialized_inputs: Vec::new(),
88            tries: Vec::new(),
89            materialized: false,
90            results: Vec::new(),
91            result_position: 0,
92            expansion_indices: Vec::new(),
93            exhausted: false,
94        }
95    }
96
97    /// Materializes all inputs and builds trie indexes.
98    fn materialize_inputs(&mut self) -> Result<(), OperatorError> {
99        // Phase 1: Collect all chunks from each input
100        for input in &mut self.inputs {
101            let mut chunks = Vec::new();
102            while let Some(chunk) = input.next()? {
103                chunks.push(chunk);
104            }
105            self.materialized_inputs.push(chunks);
106        }
107
108        // Phase 2: Build TrieIndex for each input
109        for (input_idx, chunks) in self.materialized_inputs.iter().enumerate() {
110            let mut trie = TrieIndex::new();
111            let key_indices = &self.join_key_indices[input_idx];
112
113            for (chunk_idx, chunk) in chunks.iter().enumerate() {
114                for row in 0..chunk.row_count() {
115                    // Extract join key values and convert to path
116                    if let Some(path) = self.extract_join_keys(chunk, row, key_indices) {
117                        // Encode row location as EdgeId for trie storage
118                        let row_id = Self::encode_row_id(input_idx, chunk_idx, row);
119                        trie.insert(&path, row_id);
120                    }
121                }
122            }
123            self.tries.push(trie);
124        }
125
126        self.materialized = true;
127        Ok(())
128    }
129
130    /// Extracts join key values from a row and converts to NodeId path.
131    fn extract_join_keys(
132        &self,
133        chunk: &DataChunk,
134        row: usize,
135        key_indices: &[usize],
136    ) -> Option<Vec<NodeId>> {
137        let mut path = Vec::with_capacity(key_indices.len());
138
139        for &col_idx in key_indices {
140            let col = chunk.column(col_idx)?;
141            let node_id = match col.data_type() {
142                LogicalType::Node => col.get_node_id(row),
143                LogicalType::Edge => col.get_edge_id(row).map(|e| NodeId::new(e.as_u64())),
144                LogicalType::Int64 => col.get_int64(row).map(|i| NodeId::new(i as u64)),
145                _ => return None, // Unsupported join key type
146            }?;
147            path.push(node_id);
148        }
149
150        Some(path)
151    }
152
153    /// Encodes a row location as an EdgeId for trie storage.
154    fn encode_row_id(input_idx: usize, chunk_idx: usize, row: usize) -> EdgeId {
155        // Pack: input (8 bits) | chunk (24 bits) | row (32 bits)
156        let encoded = ((input_idx as u64) << 56)
157            | ((chunk_idx as u64 & 0xFFFFFF) << 32)
158            | (row as u64 & 0xFFFFFFFF);
159        EdgeId::new(encoded)
160    }
161
162    /// Decodes a row location from an EdgeId.
163    fn decode_row_id(edge_id: EdgeId) -> RowId {
164        let encoded = edge_id.as_u64();
165        let input_idx = (encoded >> 56) as usize;
166        let chunk_idx = ((encoded >> 32) & 0xFFFFFF) as usize;
167        let row = (encoded & 0xFFFFFFFF) as usize;
168        (input_idx, chunk_idx, row)
169    }
170
171    /// Executes the leapfrog join to find all intersections.
172    fn execute_leapfrog(&mut self) -> Result<(), OperatorError> {
173        if self.tries.is_empty() {
174            return Ok(());
175        }
176
177        // Create iterators for each trie at the first level
178        let iters: Vec<_> = self.tries.iter().map(|t| t.iter()).collect();
179
180        // Create leapfrog join
181        let mut join = LeapfrogJoin::new(iters);
182
183        // Find all intersections at the first level
184        while let Some(key) = join.key() {
185            // Collect all row IDs from each input that match this key
186            let mut row_ids_per_input: Vec<Vec<RowId>> = vec![Vec::new(); self.tries.len()];
187
188            // For each trie, collect all row IDs at this key
189            if let Some(child_iters) = join.open() {
190                for (input_idx, _child_iter) in child_iters.into_iter().enumerate() {
191                    // The child iterator points to the second level of the trie
192                    // We need to collect the edge IDs (our encoded row IDs) at this position
193                    self.collect_row_ids_at_key(
194                        &self.tries[input_idx],
195                        key,
196                        input_idx,
197                        &mut row_ids_per_input[input_idx],
198                    );
199                }
200            }
201
202            // Only add result if all inputs have matching rows
203            if row_ids_per_input.iter().all(|ids| !ids.is_empty()) {
204                self.results.push(JoinResult {
205                    row_ids: row_ids_per_input,
206                });
207            }
208
209            if !join.next() {
210                break;
211            }
212        }
213
214        // Initialize expansion indices if we have results
215        if !self.results.is_empty() {
216            self.expansion_indices = vec![0; self.inputs.len()];
217        }
218
219        Ok(())
220    }
221
222    /// Collects all row IDs from a trie at a specific key.
223    fn collect_row_ids_at_key(
224        &self,
225        trie: &TrieIndex,
226        key: NodeId,
227        input_idx: usize,
228        row_ids: &mut Vec<RowId>,
229    ) {
230        // Get iterator at the key's path
231        if let Some(edges) = trie.get(&[key]) {
232            for &edge_id in edges {
233                let decoded = Self::decode_row_id(edge_id);
234                // Verify input index matches (should always match)
235                if decoded.0 == input_idx {
236                    row_ids.push(decoded);
237                }
238            }
239        }
240
241        // Also check children (for multi-level tries)
242        if let Some(iter) = trie.iter_at(&[key]) {
243            let mut iter = iter;
244            loop {
245                if let Some(child_key) = iter.key()
246                    && let Some(edges) = trie.get(&[key, child_key])
247                {
248                    for &edge_id in edges {
249                        row_ids.push(Self::decode_row_id(edge_id));
250                    }
251                }
252                if !iter.next() {
253                    break;
254                }
255            }
256        }
257    }
258
259    /// Advances to the next combination in the current result's cross product.
260    fn advance_expansion(&mut self) -> bool {
261        if self.result_position >= self.results.len() {
262            return false;
263        }
264
265        let result = &self.results[self.result_position];
266
267        // Try to advance from the rightmost input
268        for i in (0..self.expansion_indices.len()).rev() {
269            self.expansion_indices[i] += 1;
270            if self.expansion_indices[i] < result.row_ids[i].len() {
271                return true;
272            }
273            self.expansion_indices[i] = 0;
274        }
275
276        // All combinations exhausted for this result, move to next
277        self.result_position += 1;
278        if self.result_position < self.results.len() {
279            self.expansion_indices = vec![0; self.inputs.len()];
280            true
281        } else {
282            false
283        }
284    }
285
286    /// Builds an output row from the current expansion position.
287    fn build_output_row(&self, builder: &mut DataChunkBuilder) -> Result<(), OperatorError> {
288        let result = &self.results[self.result_position];
289
290        for (out_col, &(input_idx, in_col)) in self.output_column_mapping.iter().enumerate() {
291            let expansion_idx = self.expansion_indices[input_idx];
292            let (_, chunk_idx, row) = result.row_ids[input_idx][expansion_idx];
293
294            let chunk = &self.materialized_inputs[input_idx][chunk_idx];
295            let col = chunk
296                .column(in_col)
297                .ok_or_else(|| OperatorError::ColumnNotFound(in_col.to_string()))?;
298
299            let out_col_vec = builder
300                .column_mut(out_col)
301                .ok_or_else(|| OperatorError::ColumnNotFound(out_col.to_string()))?;
302
303            // Copy value from input to output
304            if let Some(value) = col.get_value(row) {
305                out_col_vec.push_value(value);
306            } else {
307                out_col_vec.push_value(Value::Null);
308            }
309        }
310
311        builder.advance_row();
312        Ok(())
313    }
314}
315
316impl Operator for LeapfrogJoinOperator {
317    fn next(&mut self) -> OperatorResult {
318        // First call: materialize inputs and execute leapfrog
319        if !self.materialized {
320            self.materialize_inputs()?;
321            self.execute_leapfrog()?;
322        }
323
324        if self.exhausted || self.results.is_empty() {
325            return Ok(None);
326        }
327
328        // Check if we've exhausted all results
329        if self.result_position >= self.results.len() {
330            self.exhausted = true;
331            return Ok(None);
332        }
333
334        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
335
336        while !builder.is_full() {
337            self.build_output_row(&mut builder)?;
338
339            if !self.advance_expansion() {
340                self.exhausted = true;
341                break;
342            }
343        }
344
345        if builder.row_count() > 0 {
346            Ok(Some(builder.finish()))
347        } else {
348            Ok(None)
349        }
350    }
351
352    fn reset(&mut self) {
353        for input in &mut self.inputs {
354            input.reset();
355        }
356        self.materialized_inputs.clear();
357        self.tries.clear();
358        self.materialized = false;
359        self.results.clear();
360        self.result_position = 0;
361        self.expansion_indices.clear();
362        self.exhausted = false;
363    }
364
365    fn name(&self) -> &'static str {
366        "LeapfrogJoin"
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::vector::ValueVector;

    /// Scan stub that yields one fixed chunk per pass.
    ///
    /// The chunk is cloned on every pass rather than moved out, so the
    /// operator can be replayed after `reset()`.
    struct MockScanOperator {
        chunk: DataChunk,
        returned: bool,
    }

    impl MockScanOperator {
        fn new(chunk: DataChunk) -> Self {
            Self {
                chunk,
                returned: false,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.returned {
                return Ok(None);
            }
            self.returned = true;
            Ok(Some(self.chunk.clone()))
        }

        fn reset(&mut self) {
            self.returned = false;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a single-column `Int64` chunk holding `node_ids`.
    fn create_node_chunk(node_ids: &[i64]) -> DataChunk {
        let mut col = ValueVector::with_type(LogicalType::Int64);
        for &id in node_ids {
            col.push_int64(id);
        }
        DataChunk::new(vec![col])
    }

    /// Builds a two-input join on column 0 of each side that outputs both key columns.
    fn binary_join(left: &[i64], right: &[i64]) -> LeapfrogJoinOperator {
        let op1: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(left)));
        let op2: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(right)));

        LeapfrogJoinOperator::new(
            vec![op1, op2],
            vec![vec![0], vec![0]], // Join on first column of each
            vec![LogicalType::Int64, LogicalType::Int64],
            vec![(0, 0), (1, 0)], // Output both columns
        )
    }

    /// Drains `op` and returns every output row as a `(left, right)` pair.
    fn drain_pairs(op: &mut LeapfrogJoinOperator) -> Vec<(i64, i64)> {
        let mut pairs = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in 0..chunk.row_count() {
                let left = chunk.column(0).unwrap().get_int64(row).unwrap();
                let right = chunk.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((left, right));
            }
        }
        pairs
    }

    #[test]
    fn test_leapfrog_binary_intersection() {
        // Input 1: nodes [1, 2, 3, 5]
        // Input 2: nodes [2, 3, 4, 5]
        // Expected intersection: [2, 3, 5]
        let mut leapfrog = binary_join(&[1, 2, 3, 5], &[2, 3, 4, 5]);

        let all_results = drain_pairs(&mut leapfrog);

        // Should find 3 matches: (2,2), (3,3), (5,5)
        assert_eq!(all_results.len(), 3);
        assert!(all_results.contains(&(2, 2)));
        assert!(all_results.contains(&(3, 3)));
        assert!(all_results.contains(&(5, 5)));
    }

    #[test]
    fn test_leapfrog_empty_intersection() {
        // Input 1: nodes [1, 2, 3]
        // Input 2: nodes [4, 5, 6]
        // Expected: empty
        let mut leapfrog = binary_join(&[1, 2, 3], &[4, 5, 6]);

        let result = leapfrog.next().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_leapfrog_reset() {
        let mut leapfrog = binary_join(&[1, 2, 3], &[2, 3, 4]);

        // First pass: consume all results.
        let mut first_pass = drain_pairs(&mut leapfrog);
        first_pass.sort_unstable();
        assert_eq!(first_pass, vec![(2, 2), (3, 3)]);

        // Once drained, the operator keeps reporting end-of-stream.
        assert!(leapfrog.next().unwrap().is_none());

        // Reset clears all cached state...
        leapfrog.reset();
        assert!(!leapfrog.materialized);
        assert!(!leapfrog.exhausted);
        assert!(leapfrog.results.is_empty());
        assert!(leapfrog.tries.is_empty());
        assert!(leapfrog.materialized_inputs.is_empty());
        assert_eq!(leapfrog.result_position, 0);

        // ...and a second pass reproduces the first one.
        let mut second_pass = drain_pairs(&mut leapfrog);
        second_pass.sort_unstable();
        assert_eq!(second_pass, first_pass);
    }

    #[test]
    fn test_encode_decode_row_id() {
        let test_cases = [
            (0, 0, 0),
            (1, 2, 3),
            (255, 16777215, 4294967295), // Max values for each field
        ];

        for (input_idx, chunk_idx, row) in test_cases {
            let encoded = LeapfrogJoinOperator::encode_row_id(input_idx, chunk_idx, row);
            let decoded = LeapfrogJoinOperator::decode_row_id(encoded);
            assert_eq!(decoded, (input_idx, chunk_idx, row));
        }
    }
}
517}