grafeo_core/execution/operators/
leapfrog_join.rs1use grafeo_common::types::{EdgeId, LogicalType, NodeId, Value};
10
11use super::{Operator, OperatorError, OperatorResult};
12use crate::execution::DataChunk;
13use crate::execution::chunk::DataChunkBuilder;
14use crate::index::trie::{LeapfrogJoin, TrieIndex};
15
/// Location of a materialized input row as `(input_idx, chunk_idx, row)`.
type RowId = (usize, usize, usize);
18
/// A join key found in every input, together with the rows each input holds for it.
struct JoinResult {
    /// The matched key.
    // Only written when a result is recorded and never read elsewhere in this
    // module, so the `dead_code` lint is silenced for this field.
    #[allow(dead_code)]
    key: NodeId,
    /// `row_ids[i]` lists the rows of input `i` stored under `key`. A result is
    /// only recorded when every list is non-empty; the operator emits the cross
    /// product of these lists.
    row_ids: Vec<Vec<RowId>>,
}
28
/// Multi-way join operator that intersects its inputs with [`LeapfrogJoin`].
///
/// On the first call to `next`, every input is pulled to exhaustion. Each
/// input's rows are then indexed in a [`TrieIndex`] keyed by its join-key
/// columns. For every key found in all inputs, the operator emits the cross
/// product of the per-input row lists.
pub struct LeapfrogJoinOperator {
    /// Child operators, one per join side.
    inputs: Vec<Box<dyn Operator>>,

    /// For each input (same order as `inputs`), the column indices whose
    /// values form that input's join-key path.
    join_key_indices: Vec<Vec<usize>>,

    /// Column types of the produced output chunks.
    output_schema: Vec<LogicalType>,

    /// For each output column, the `(input_idx, column_idx)` it is copied from.
    output_column_mapping: Vec<(usize, usize)>,

    /// All chunks pulled from each input, addressed as `[input][chunk]`.
    materialized_inputs: Vec<Vec<DataChunk>>,

    /// One trie per input, mapping join-key paths to encoded row locations.
    tries: Vec<TrieIndex>,

    /// Whether the inputs have been drained and the join computed.
    materialized: bool,

    /// Keys present in every input, together with their matching rows.
    results: Vec<JoinResult>,

    /// Index into `results` of the entry currently being expanded.
    result_position: usize,

    /// Per-input cursor into the current result's row lists. It is advanced
    /// like an odometer (last input fastest) to enumerate the cross product.
    expansion_indices: Vec<usize>,

    /// Set once every result combination has been emitted.
    exhausted: bool,
}
70
71impl LeapfrogJoinOperator {
72 #[must_use]
80 pub fn new(
81 inputs: Vec<Box<dyn Operator>>,
82 join_key_indices: Vec<Vec<usize>>,
83 output_schema: Vec<LogicalType>,
84 output_column_mapping: Vec<(usize, usize)>,
85 ) -> Self {
86 Self {
87 inputs,
88 join_key_indices,
89 output_schema,
90 output_column_mapping,
91 materialized_inputs: Vec::new(),
92 tries: Vec::new(),
93 materialized: false,
94 results: Vec::new(),
95 result_position: 0,
96 expansion_indices: Vec::new(),
97 exhausted: false,
98 }
99 }
100
101 fn materialize_inputs(&mut self) -> Result<(), OperatorError> {
103 for input in &mut self.inputs {
105 let mut chunks = Vec::new();
106 while let Some(chunk) = input.next()? {
107 chunks.push(chunk);
108 }
109 self.materialized_inputs.push(chunks);
110 }
111
112 for (input_idx, chunks) in self.materialized_inputs.iter().enumerate() {
114 let mut trie = TrieIndex::new();
115 let key_indices = &self.join_key_indices[input_idx];
116
117 for (chunk_idx, chunk) in chunks.iter().enumerate() {
118 for row in 0..chunk.row_count() {
119 if let Some(path) = self.extract_join_keys(chunk, row, key_indices) {
121 let row_id = Self::encode_row_id(input_idx, chunk_idx, row);
123 trie.insert(&path, row_id);
124 }
125 }
126 }
127 self.tries.push(trie);
128 }
129
130 self.materialized = true;
131 Ok(())
132 }
133
134 fn extract_join_keys(
136 &self,
137 chunk: &DataChunk,
138 row: usize,
139 key_indices: &[usize],
140 ) -> Option<Vec<NodeId>> {
141 let mut path = Vec::with_capacity(key_indices.len());
142
143 for &col_idx in key_indices {
144 let col = chunk.column(col_idx)?;
145 let node_id = match col.data_type() {
146 LogicalType::Node => col.get_node_id(row),
147 LogicalType::Edge => col.get_edge_id(row).map(|e| NodeId::new(e.as_u64())),
148 LogicalType::Int64 => col.get_int64(row).map(|i| NodeId::new(i as u64)),
149 _ => return None, }?;
151 path.push(node_id);
152 }
153
154 Some(path)
155 }
156
157 fn encode_row_id(input_idx: usize, chunk_idx: usize, row: usize) -> EdgeId {
159 let encoded = ((input_idx as u64) << 56)
161 | ((chunk_idx as u64 & 0xFFFFFF) << 32)
162 | (row as u64 & 0xFFFFFFFF);
163 EdgeId::new(encoded)
164 }
165
166 fn decode_row_id(edge_id: EdgeId) -> RowId {
168 let encoded = edge_id.as_u64();
169 let input_idx = (encoded >> 56) as usize;
170 let chunk_idx = ((encoded >> 32) & 0xFFFFFF) as usize;
171 let row = (encoded & 0xFFFFFFFF) as usize;
172 (input_idx, chunk_idx, row)
173 }
174
175 fn execute_leapfrog(&mut self) -> Result<(), OperatorError> {
177 if self.tries.is_empty() {
178 return Ok(());
179 }
180
181 let iters: Vec<_> = self.tries.iter().map(|t| t.iter()).collect();
183
184 let mut join = LeapfrogJoin::new(iters);
186
187 while let Some(key) = join.key() {
189 let mut row_ids_per_input: Vec<Vec<RowId>> = vec![Vec::new(); self.tries.len()];
191
192 if let Some(child_iters) = join.open() {
194 for (input_idx, _child_iter) in child_iters.into_iter().enumerate() {
195 self.collect_row_ids_at_key(
198 &self.tries[input_idx],
199 key,
200 input_idx,
201 &mut row_ids_per_input[input_idx],
202 );
203 }
204 }
205
206 if row_ids_per_input.iter().all(|ids| !ids.is_empty()) {
208 self.results.push(JoinResult {
209 key,
210 row_ids: row_ids_per_input,
211 });
212 }
213
214 if !join.next() {
215 break;
216 }
217 }
218
219 if !self.results.is_empty() {
221 self.expansion_indices = vec![0; self.inputs.len()];
222 }
223
224 Ok(())
225 }
226
227 fn collect_row_ids_at_key(
229 &self,
230 trie: &TrieIndex,
231 key: NodeId,
232 input_idx: usize,
233 row_ids: &mut Vec<RowId>,
234 ) {
235 if let Some(edges) = trie.get(&[key]) {
237 for &edge_id in edges {
238 let decoded = Self::decode_row_id(edge_id);
239 if decoded.0 == input_idx {
241 row_ids.push(decoded);
242 }
243 }
244 }
245
246 if let Some(iter) = trie.iter_at(&[key]) {
248 let mut iter = iter;
249 loop {
250 if let Some(child_key) = iter.key() {
251 if let Some(edges) = trie.get(&[key, child_key]) {
252 for &edge_id in edges {
253 row_ids.push(Self::decode_row_id(edge_id));
254 }
255 }
256 }
257 if !iter.next() {
258 break;
259 }
260 }
261 }
262 }
263
264 fn advance_expansion(&mut self) -> bool {
266 if self.result_position >= self.results.len() {
267 return false;
268 }
269
270 let result = &self.results[self.result_position];
271
272 for i in (0..self.expansion_indices.len()).rev() {
274 self.expansion_indices[i] += 1;
275 if self.expansion_indices[i] < result.row_ids[i].len() {
276 return true;
277 }
278 self.expansion_indices[i] = 0;
279 }
280
281 self.result_position += 1;
283 if self.result_position < self.results.len() {
284 self.expansion_indices = vec![0; self.inputs.len()];
285 true
286 } else {
287 false
288 }
289 }
290
291 fn build_output_row(&self, builder: &mut DataChunkBuilder) -> Result<(), OperatorError> {
293 let result = &self.results[self.result_position];
294
295 for (out_col, &(input_idx, in_col)) in self.output_column_mapping.iter().enumerate() {
296 let expansion_idx = self.expansion_indices[input_idx];
297 let (_, chunk_idx, row) = result.row_ids[input_idx][expansion_idx];
298
299 let chunk = &self.materialized_inputs[input_idx][chunk_idx];
300 let col = chunk
301 .column(in_col)
302 .ok_or_else(|| OperatorError::ColumnNotFound(in_col.to_string()))?;
303
304 let out_col_vec = builder
305 .column_mut(out_col)
306 .ok_or_else(|| OperatorError::ColumnNotFound(out_col.to_string()))?;
307
308 if let Some(value) = col.get_value(row) {
310 out_col_vec.push_value(value);
311 } else {
312 out_col_vec.push_value(Value::Null);
313 }
314 }
315
316 builder.advance_row();
317 Ok(())
318 }
319}
320
321impl Operator for LeapfrogJoinOperator {
322 fn next(&mut self) -> OperatorResult {
323 if !self.materialized {
325 self.materialize_inputs()?;
326 self.execute_leapfrog()?;
327 }
328
329 if self.exhausted || self.results.is_empty() {
330 return Ok(None);
331 }
332
333 if self.result_position >= self.results.len() {
335 self.exhausted = true;
336 return Ok(None);
337 }
338
339 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
340
341 while !builder.is_full() {
342 self.build_output_row(&mut builder)?;
343
344 if !self.advance_expansion() {
345 self.exhausted = true;
346 break;
347 }
348 }
349
350 if builder.row_count() > 0 {
351 Ok(Some(builder.finish()))
352 } else {
353 Ok(None)
354 }
355 }
356
357 fn reset(&mut self) {
358 for input in &mut self.inputs {
359 input.reset();
360 }
361 self.materialized_inputs.clear();
362 self.tries.clear();
363 self.materialized = false;
364 self.results.clear();
365 self.result_position = 0;
366 self.expansion_indices.clear();
367 self.exhausted = false;
368 }
369
370 fn name(&self) -> &'static str {
371 "LeapfrogJoin"
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::vector::ValueVector;

    /// Source operator that yields one fixed chunk, then end-of-stream.
    ///
    /// The chunk is cloned rather than moved out, so `reset` genuinely rewinds
    /// the operator and a second pass yields the same data.
    struct MockScanOperator {
        chunk: DataChunk,
        returned: bool,
    }

    impl MockScanOperator {
        fn new(chunk: DataChunk) -> Self {
            Self {
                chunk,
                returned: false,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.returned {
                return Ok(None);
            }
            self.returned = true;
            Ok(Some(self.chunk.clone()))
        }

        fn reset(&mut self) {
            self.returned = false;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a single-column `Int64` chunk holding `node_ids`.
    fn create_node_chunk(node_ids: &[i64]) -> DataChunk {
        let mut col = ValueVector::with_type(LogicalType::Int64);
        for &id in node_ids {
            col.push_int64(id);
        }
        DataChunk::new(vec![col])
    }

    /// Two-way join on column 0 of each input, projecting both key columns.
    fn binary_join(left: &[i64], right: &[i64]) -> LeapfrogJoinOperator {
        let op1: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(left)));
        let op2: Box<dyn Operator> = Box::new(MockScanOperator::new(create_node_chunk(right)));

        LeapfrogJoinOperator::new(
            vec![op1, op2],
            vec![vec![0], vec![0]],
            vec![LogicalType::Int64, LogicalType::Int64],
            vec![(0, 0), (1, 0)],
        )
    }

    /// Drains `op` and returns its `(col0, col1)` rows in sorted order.
    fn collect_pairs(op: &mut LeapfrogJoinOperator) -> Vec<(i64, i64)> {
        let mut pairs = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in 0..chunk.row_count() {
                let left = chunk.column(0).unwrap().get_int64(row).unwrap();
                let right = chunk.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((left, right));
            }
        }
        pairs.sort_unstable();
        pairs
    }

    #[test]
    fn test_leapfrog_binary_intersection() {
        let mut leapfrog = binary_join(&[1, 2, 3, 5], &[2, 3, 4, 5]);
        assert_eq!(collect_pairs(&mut leapfrog), vec![(2, 2), (3, 3), (5, 5)]);
    }

    #[test]
    fn test_leapfrog_empty_intersection() {
        let mut leapfrog = binary_join(&[1, 2, 3], &[4, 5, 6]);
        assert!(leapfrog.next().unwrap().is_none());
    }

    #[test]
    fn test_leapfrog_duplicate_keys_produce_cross_product() {
        // Key 1 appears twice on the left and once on the right, giving two
        // output rows. Keys 2 and 3 each appear on one side only and must not
        // be emitted.
        let mut leapfrog = binary_join(&[1, 1, 2], &[1, 3]);
        assert_eq!(collect_pairs(&mut leapfrog), vec![(1, 1), (1, 1)]);
    }

    #[test]
    fn test_leapfrog_reset() {
        let mut leapfrog = binary_join(&[1, 2, 3], &[2, 3, 4]);

        let first_pass = collect_pairs(&mut leapfrog);
        assert_eq!(first_pass, vec![(2, 2), (3, 3)]);

        leapfrog.reset();
        assert!(!leapfrog.materialized);
        assert!(leapfrog.results.is_empty());

        // After a reset the join must be recomputed from the rewound inputs.
        assert_eq!(collect_pairs(&mut leapfrog), first_pass);
    }

    #[test]
    fn test_encode_decode_row_id() {
        // The last case uses the maximum value of each field (8, 24 and 32 bits).
        let test_cases = [(0, 0, 0), (1, 2, 3), (255, 16_777_215, 4_294_967_295)];

        for (input_idx, chunk_idx, row) in test_cases {
            let encoded = LeapfrogJoinOperator::encode_row_id(input_idx, chunk_idx, row);
            let decoded = LeapfrogJoinOperator::decode_row_id(encoded);
            assert_eq!(decoded, (input_idx, chunk_idx, row));
        }
    }
}