// grafeo_core/execution/operators/leapfrog_join.rs
use grafeo_common::types::{EdgeId, LogicalType, NodeId, Value};

use super::{Operator, OperatorError, OperatorResult};
use crate::execution::DataChunk;
use crate::execution::chunk::DataChunkBuilder;
use crate::index::trie::{LeapfrogJoin, TrieIndex};

/// Location of one materialized input row: `(input index, chunk index, row index)`.
type RowId = (usize, usize, usize);

/// One join key that is present in every input, together with the rows each input
/// contributes to it.
///
/// `row_ids[i]` lists the matching rows of input `i`; the operator emits the cartesian
/// product of these lists.
#[derive(Debug, Clone)]
struct JoinResult {
    /// Matching rows, one list per input, indexed like the operator's inputs.
    row_ids: Vec<Vec<RowId>>,
}

25pub struct LeapfrogJoinOperator {
30 inputs: Vec<Box<dyn Operator>>,
32
33 join_key_indices: Vec<Vec<usize>>,
36
37 output_schema: Vec<LogicalType>,
39
40 output_column_mapping: Vec<(usize, usize)>,
42
43 materialized_inputs: Vec<Vec<DataChunk>>,
46
47 tries: Vec<TrieIndex>,
49
50 materialized: bool,
52
53 results: Vec<JoinResult>,
56
57 result_position: usize,
59
60 expansion_indices: Vec<usize>,
62
63 exhausted: bool,
65}
66
67impl LeapfrogJoinOperator {
68 #[must_use]
76 pub fn new(
77 inputs: Vec<Box<dyn Operator>>,
78 join_key_indices: Vec<Vec<usize>>,
79 output_schema: Vec<LogicalType>,
80 output_column_mapping: Vec<(usize, usize)>,
81 ) -> Self {
82 Self {
83 inputs,
84 join_key_indices,
85 output_schema,
86 output_column_mapping,
87 materialized_inputs: Vec::new(),
88 tries: Vec::new(),
89 materialized: false,
90 results: Vec::new(),
91 result_position: 0,
92 expansion_indices: Vec::new(),
93 exhausted: false,
94 }
95 }
96
97 fn materialize_inputs(&mut self) -> Result<(), OperatorError> {
99 for input in &mut self.inputs {
101 let mut chunks = Vec::new();
102 while let Some(chunk) = input.next()? {
103 chunks.push(chunk);
104 }
105 self.materialized_inputs.push(chunks);
106 }
107
108 for (input_idx, chunks) in self.materialized_inputs.iter().enumerate() {
110 let mut trie = TrieIndex::new();
111 let key_indices = &self.join_key_indices[input_idx];
112
113 for (chunk_idx, chunk) in chunks.iter().enumerate() {
114 for row in 0..chunk.row_count() {
115 if let Some(path) = self.extract_join_keys(chunk, row, key_indices) {
117 let row_id = Self::encode_row_id(input_idx, chunk_idx, row);
119 trie.insert(&path, row_id);
120 }
121 }
122 }
123 self.tries.push(trie);
124 }
125
126 self.materialized = true;
127 Ok(())
128 }
129
130 fn extract_join_keys(
132 &self,
133 chunk: &DataChunk,
134 row: usize,
135 key_indices: &[usize],
136 ) -> Option<Vec<NodeId>> {
137 let mut path = Vec::with_capacity(key_indices.len());
138
139 for &col_idx in key_indices {
140 let col = chunk.column(col_idx)?;
141 let node_id = match col.data_type() {
142 LogicalType::Node => col.get_node_id(row),
143 LogicalType::Edge => col.get_edge_id(row).map(|e| NodeId::new(e.as_u64())),
144 LogicalType::Int64 => col.get_int64(row).map(|i| NodeId::new(i as u64)),
145 _ => return None, }?;
147 path.push(node_id);
148 }
149
150 Some(path)
151 }
152
153 fn encode_row_id(input_idx: usize, chunk_idx: usize, row: usize) -> EdgeId {
155 let encoded = ((input_idx as u64) << 56)
157 | ((chunk_idx as u64 & 0xFFFFFF) << 32)
158 | (row as u64 & 0xFFFFFFFF);
159 EdgeId::new(encoded)
160 }
161
162 fn decode_row_id(edge_id: EdgeId) -> RowId {
164 let encoded = edge_id.as_u64();
165 let input_idx = (encoded >> 56) as usize;
166 let chunk_idx = ((encoded >> 32) & 0xFFFFFF) as usize;
167 let row = (encoded & 0xFFFFFFFF) as usize;
168 (input_idx, chunk_idx, row)
169 }
170
171 fn execute_leapfrog(&mut self) -> Result<(), OperatorError> {
173 if self.tries.is_empty() {
174 return Ok(());
175 }
176
177 let iters: Vec<_> = self.tries.iter().map(|t| t.iter()).collect();
179
180 let mut join = LeapfrogJoin::new(iters);
182
183 while let Some(key) = join.key() {
185 let mut row_ids_per_input: Vec<Vec<RowId>> = vec![Vec::new(); self.tries.len()];
187
188 if let Some(child_iters) = join.open() {
190 for (input_idx, _child_iter) in child_iters.into_iter().enumerate() {
191 self.collect_row_ids_at_key(
194 &self.tries[input_idx],
195 key,
196 input_idx,
197 &mut row_ids_per_input[input_idx],
198 );
199 }
200 }
201
202 if row_ids_per_input.iter().all(|ids| !ids.is_empty()) {
204 self.results.push(JoinResult {
205 row_ids: row_ids_per_input,
206 });
207 }
208
209 if !join.next() {
210 break;
211 }
212 }
213
214 if !self.results.is_empty() {
216 self.expansion_indices = vec![0; self.inputs.len()];
217 }
218
219 Ok(())
220 }
221
222 fn collect_row_ids_at_key(
224 &self,
225 trie: &TrieIndex,
226 key: NodeId,
227 input_idx: usize,
228 row_ids: &mut Vec<RowId>,
229 ) {
230 if let Some(edges) = trie.get(&[key]) {
232 for &edge_id in edges {
233 let decoded = Self::decode_row_id(edge_id);
234 if decoded.0 == input_idx {
236 row_ids.push(decoded);
237 }
238 }
239 }
240
241 if let Some(iter) = trie.iter_at(&[key]) {
243 let mut iter = iter;
244 loop {
245 if let Some(child_key) = iter.key()
246 && let Some(edges) = trie.get(&[key, child_key])
247 {
248 for &edge_id in edges {
249 row_ids.push(Self::decode_row_id(edge_id));
250 }
251 }
252 if !iter.next() {
253 break;
254 }
255 }
256 }
257 }
258
259 fn advance_expansion(&mut self) -> bool {
261 if self.result_position >= self.results.len() {
262 return false;
263 }
264
265 let result = &self.results[self.result_position];
266
267 for i in (0..self.expansion_indices.len()).rev() {
269 self.expansion_indices[i] += 1;
270 if self.expansion_indices[i] < result.row_ids[i].len() {
271 return true;
272 }
273 self.expansion_indices[i] = 0;
274 }
275
276 self.result_position += 1;
278 if self.result_position < self.results.len() {
279 self.expansion_indices = vec![0; self.inputs.len()];
280 true
281 } else {
282 false
283 }
284 }
285
286 fn build_output_row(&self, builder: &mut DataChunkBuilder) -> Result<(), OperatorError> {
288 let result = &self.results[self.result_position];
289
290 for (out_col, &(input_idx, in_col)) in self.output_column_mapping.iter().enumerate() {
291 let expansion_idx = self.expansion_indices[input_idx];
292 let (_, chunk_idx, row) = result.row_ids[input_idx][expansion_idx];
293
294 let chunk = &self.materialized_inputs[input_idx][chunk_idx];
295 let col = chunk
296 .column(in_col)
297 .ok_or_else(|| OperatorError::ColumnNotFound(in_col.to_string()))?;
298
299 let out_col_vec = builder
300 .column_mut(out_col)
301 .ok_or_else(|| OperatorError::ColumnNotFound(out_col.to_string()))?;
302
303 if let Some(value) = col.get_value(row) {
305 out_col_vec.push_value(value);
306 } else {
307 out_col_vec.push_value(Value::Null);
308 }
309 }
310
311 builder.advance_row();
312 Ok(())
313 }
314}
315
316impl Operator for LeapfrogJoinOperator {
317 fn next(&mut self) -> OperatorResult {
318 if !self.materialized {
320 self.materialize_inputs()?;
321 self.execute_leapfrog()?;
322 }
323
324 if self.exhausted || self.results.is_empty() {
325 return Ok(None);
326 }
327
328 if self.result_position >= self.results.len() {
330 self.exhausted = true;
331 return Ok(None);
332 }
333
334 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
335
336 while !builder.is_full() {
337 self.build_output_row(&mut builder)?;
338
339 if !self.advance_expansion() {
340 self.exhausted = true;
341 break;
342 }
343 }
344
345 if builder.row_count() > 0 {
346 Ok(Some(builder.finish()))
347 } else {
348 Ok(None)
349 }
350 }
351
352 fn reset(&mut self) {
353 for input in &mut self.inputs {
354 input.reset();
355 }
356 self.materialized_inputs.clear();
357 self.tries.clear();
358 self.materialized = false;
359 self.results.clear();
360 self.result_position = 0;
361 self.expansion_indices.clear();
362 self.exhausted = false;
363 }
364
365 fn name(&self) -> &'static str {
366 "LeapfrogJoin"
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::vector::ValueVector;

    /// Operator stub that yields one fixed chunk and then `None`; `reset` replays the chunk.
    struct MockScanOperator {
        chunk: DataChunk,
        returned: bool,
    }

    impl MockScanOperator {
        fn new(chunk: DataChunk) -> Self {
            Self {
                chunk,
                returned: false,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.returned {
                Ok(None)
            } else {
                self.returned = true;
                // Clone instead of `take` so a reset scan can produce the chunk again.
                Ok(Some(self.chunk.clone()))
            }
        }

        fn reset(&mut self) {
            self.returned = false;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a single-column `Int64` chunk holding `node_ids`.
    fn create_node_chunk(node_ids: &[i64]) -> DataChunk {
        let mut col = ValueVector::with_type(LogicalType::Int64);
        for &id in node_ids {
            col.push_int64(id);
        }
        DataChunk::new(vec![col])
    }

    /// Joins two single-column inputs on their only column and projects both key columns.
    fn binary_join(left: &[i64], right: &[i64]) -> LeapfrogJoinOperator {
        let op1: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(left)));
        let op2: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(right)));
        LeapfrogJoinOperator::new(
            vec![op1, op2],
            vec![vec![0], vec![0]],
            vec![LogicalType::Int64, LogicalType::Int64],
            vec![(0, 0), (1, 0)],
        )
    }

    /// Drains `leapfrog` and returns every output row as a sorted list of `(left, right)`.
    fn collect_pairs(leapfrog: &mut LeapfrogJoinOperator) -> Vec<(i64, i64)> {
        let mut pairs = Vec::new();
        while let Some(chunk) = leapfrog.next().unwrap() {
            for row in 0..chunk.row_count() {
                let left = chunk.column(0).unwrap().get_int64(row).unwrap();
                let right = chunk.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((left, right));
            }
        }
        pairs.sort_unstable();
        pairs
    }

    #[test]
    fn test_leapfrog_binary_intersection() {
        let mut leapfrog = binary_join(&[1, 2, 3, 5], &[2, 3, 4, 5]);
        assert_eq!(collect_pairs(&mut leapfrog), vec![(2, 2), (3, 3), (5, 5)]);
    }

    #[test]
    fn test_leapfrog_empty_intersection() {
        let mut leapfrog = binary_join(&[1, 2, 3], &[4, 5, 6]);
        assert!(leapfrog.next().unwrap().is_none());
    }

    #[test]
    fn test_leapfrog_duplicate_keys_expand_to_cartesian_product() {
        // Key 2 matches 2 x 1 rows and key 3 matches 1 x 2 rows.
        let mut leapfrog = binary_join(&[2, 2, 3], &[2, 3, 3]);
        assert_eq!(
            collect_pairs(&mut leapfrog),
            vec![(2, 2), (2, 2), (3, 3), (3, 3)]
        );
    }

    #[test]
    fn test_leapfrog_reset() {
        let mut leapfrog = binary_join(&[1, 2, 3], &[2, 3, 4]);
        let first_run = collect_pairs(&mut leapfrog);
        assert_eq!(first_run, vec![(2, 2), (3, 3)]);

        leapfrog.reset();
        assert!(!leapfrog.materialized);
        assert!(leapfrog.results.is_empty());

        // After a reset the operator must re-read its inputs and produce the same rows.
        assert_eq!(collect_pairs(&mut leapfrog), first_run);
    }

    #[test]
    fn test_encode_decode_row_id() {
        let test_cases = [(0, 0, 0), (1, 2, 3), (255, 16777215, 4294967295)];

        for (input_idx, chunk_idx, row) in test_cases {
            let encoded = LeapfrogJoinOperator::encode_row_id(input_idx, chunk_idx, row);
            let decoded = LeapfrogJoinOperator::decode_row_id(encoded);
            assert_eq!(decoded, (input_idx, chunk_idx, row));
        }
    }
}