1use crate::{JoinKey, JoinOptions, JoinType};
40use arrow::array::{Array, ArrayRef, RecordBatch};
41use arrow::compute::take;
42use arrow::datatypes::{DataType, Schema};
43use llkv_column_map::store::Projection;
44use llkv_column_map::types::LogicalFieldId;
45use llkv_expr::{Expr, Filter, Operator};
46use llkv_result::{Error, Result as LlkvResult};
47use llkv_storage::pager::Pager;
48use llkv_table::table::{ScanProjection, ScanStreamOptions, Table};
49use llkv_table::types::FieldId;
50use rustc_hash::FxHashMap;
51use simd_r_drive_entry_handle::EntryHandle;
52use std::hash::{Hash, Hasher};
53use std::ops::Bound;
54use std::sync::Arc;
55
/// Composite join key for one row: one [`KeyValue`] per join column, in join-key order.
#[derive(Debug, Clone, Eq)]
struct HashKey {
    values: Vec<KeyValue>,
}

/// One join-column value in a hashable, comparable form.
///
/// Floating-point payloads are stored as raw bit patterns (`u32` / `u64`).
#[derive(Debug, Clone)]
enum KeyValue {
    Null,
    Int8(i8),
    Int16(i16),
    Int32(i32),
    Int64(i64),
    UInt8(u8),
    UInt16(u16),
    UInt32(u32),
    UInt64(u64),
    Float32(u32),
    Float64(u64),
    Utf8(String),
    Binary(Vec<u8>),
}

impl PartialEq for KeyValue {
    /// Two values are equal only when they are the same non-`Null` variant with
    /// equal payloads. `Null` compares unequal to everything, itself included.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Int8(lhs), Self::Int8(rhs)) => lhs == rhs,
            (Self::Int16(lhs), Self::Int16(rhs)) => lhs == rhs,
            (Self::Int32(lhs), Self::Int32(rhs)) => lhs == rhs,
            (Self::Int64(lhs), Self::Int64(rhs)) => lhs == rhs,
            (Self::UInt8(lhs), Self::UInt8(rhs)) => lhs == rhs,
            (Self::UInt16(lhs), Self::UInt16(rhs)) => lhs == rhs,
            (Self::UInt32(lhs), Self::UInt32(rhs)) => lhs == rhs,
            (Self::UInt64(lhs), Self::UInt64(rhs)) => lhs == rhs,
            (Self::Float32(lhs), Self::Float32(rhs)) => lhs == rhs,
            (Self::Float64(lhs), Self::Float64(rhs)) => lhs == rhs,
            (Self::Utf8(lhs), Self::Utf8(rhs)) => lhs == rhs,
            (Self::Binary(lhs), Self::Binary(rhs)) => lhs == rhs,
            // Covers `Null` against anything (including `Null`) and mismatched variants.
            _ => false,
        }
    }
}

impl Eq for KeyValue {}

impl Hash for KeyValue {
    /// Feeds only the payload to the hasher (no variant tag); `Null` hashes as `0u8`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Self::Null => 0u8.hash(state),
            Self::Int8(payload) => payload.hash(state),
            Self::Int16(payload) => payload.hash(state),
            Self::Int32(payload) => payload.hash(state),
            Self::Int64(payload) => payload.hash(state),
            Self::UInt8(payload) => payload.hash(state),
            Self::UInt16(payload) => payload.hash(state),
            Self::UInt32(payload) => payload.hash(state),
            Self::UInt64(payload) => payload.hash(state),
            Self::Float32(bits) => bits.hash(state),
            Self::Float64(bits) => bits.hash(state),
            Self::Utf8(text) => text.hash(state),
            Self::Binary(bytes) => bytes.hash(state),
        }
    }
}

impl PartialEq for HashKey {
    /// Keys are equal when they have the same length and every position compares equal.
    fn eq(&self, other: &Self) -> bool {
        self.values == other.values
    }
}

impl Hash for HashKey {
    /// Hashes each component in order; the length itself is not hashed.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.values
            .iter()
            .for_each(|component| component.hash(&mut *state));
    }
}
139
/// Location of a build-side row: `(index into the collected build batches, row index
/// within that batch)`.
type RowRef = (usize, usize);

/// Maps each build-side join key to every build row that produced it.
type HashTable = FxHashMap<HashKey, Vec<RowRef>>;
145
146pub fn hash_join_stream<P, F>(
148 left: &Table<P>,
149 right: &Table<P>,
150 keys: &[JoinKey],
151 options: &JoinOptions,
152 mut on_batch: F,
153) -> LlkvResult<()>
154where
155 P: Pager<Blob = EntryHandle> + Send + Sync,
156 F: FnMut(RecordBatch),
157{
158 let left_schema = left.schema()?;
160 let right_schema = right.schema()?;
161
162 if keys.len() == 1 {
166 if let (Ok(left_dtype), Ok(right_dtype)) = (
168 get_key_datatype(&left_schema, keys[0].left_field),
169 get_key_datatype(&right_schema, keys[0].right_field),
170 ) && left_dtype == right_dtype
171 {
172 match left_dtype {
173 DataType::Int32 => {
174 return hash_join_i32_fast_path(left, right, keys, options, on_batch);
175 }
176 DataType::Int64 => {
177 return hash_join_i64_fast_path(left, right, keys, options, on_batch);
178 }
179 DataType::UInt32 => {
180 return hash_join_u32_fast_path(left, right, keys, options, on_batch);
181 }
182 DataType::UInt64 => {
183 return hash_join_u64_fast_path(left, right, keys, options, on_batch);
184 }
185 _ => {
186 }
188 }
189 }
190 }
192
193 let left_projections = build_user_projections(left, &left_schema)?;
195 let right_projections = build_user_projections(right, &right_schema)?;
196
197 let output_schema = build_output_schema(&left_schema, &right_schema, options.join_type)?;
199
200 let (hash_table, build_batches) = if right_projections.is_empty() {
203 (HashTable::default(), Vec::new())
204 } else {
205 build_hash_table(right, &right_projections, keys, &right_schema)?
206 };
207
208 let probe_key_indices = if left_projections.is_empty() || right_projections.is_empty() {
210 Vec::new()
211 } else {
212 extract_left_key_indices(keys, &left_schema)?
213 };
214
215 let batch_size = options.batch_size;
217
218 if !left_projections.is_empty() {
219 let filter_expr = build_all_rows_filter(&left_projections)?;
220
221 left.scan_stream(
222 &left_projections,
223 &filter_expr,
224 ScanStreamOptions::default(),
225 |probe_batch| {
226 let result = match options.join_type {
227 JoinType::Inner => process_inner_probe(
228 &probe_batch,
229 &probe_key_indices,
230 &hash_table,
231 &build_batches,
232 &output_schema,
233 keys,
234 batch_size,
235 &mut on_batch,
236 ),
237 JoinType::Left => process_left_probe(
238 &probe_batch,
239 &probe_key_indices,
240 &hash_table,
241 &build_batches,
242 &output_schema,
243 keys,
244 batch_size,
245 &mut on_batch,
246 ),
247 JoinType::Semi => process_semi_probe(
248 &probe_batch,
249 &probe_key_indices,
250 &hash_table,
251 &output_schema,
252 keys,
253 batch_size,
254 &mut on_batch,
255 ),
256 JoinType::Anti => process_anti_probe(
257 &probe_batch,
258 &probe_key_indices,
259 &hash_table,
260 &output_schema,
261 keys,
262 batch_size,
263 &mut on_batch,
264 ),
265 _ => {
266 eprintln!(
267 "Hash join does not yet support {:?}, falling back would cause issues",
268 options.join_type
269 );
270 Ok(())
271 }
272 };
273
274 if let Err(e) = result {
275 eprintln!("Join error: {}", e);
276 }
277 },
278 )?;
279 }
280
281 if matches!(options.join_type, JoinType::Right | JoinType::Full) {
283 return Err(Error::Internal(
284 "Right and Full outer joins not yet implemented for hash join".to_string(),
285 ));
286 }
287
288 Ok(())
289}
290
291fn build_hash_table<P>(
293 table: &Table<P>,
294 projections: &[ScanProjection],
295 join_keys: &[JoinKey],
296 schema: &Arc<Schema>,
297) -> LlkvResult<(HashTable, Vec<RecordBatch>)>
298where
299 P: Pager<Blob = EntryHandle> + Send + Sync,
300{
301 let mut hash_table = HashTable::default();
302 let mut batches = Vec::new();
303 let key_indices = extract_right_key_indices(join_keys, schema)?;
304 let filter_expr = build_all_rows_filter(projections)?;
305
306 table.scan_stream(
307 projections,
308 &filter_expr,
309 ScanStreamOptions::default(),
310 |batch| {
311 let batch_idx = batches.len();
312
313 for row_idx in 0..batch.num_rows() {
315 if let Ok(key) = extract_hash_key(&batch, &key_indices, row_idx, join_keys) {
316 hash_table
317 .entry(key)
318 .or_default()
319 .push((batch_idx, row_idx));
320 }
321 }
322
323 batches.push(batch.clone());
324 },
325 )?;
326
327 Ok((hash_table, batches))
328}
329
330fn extract_hash_key(
332 batch: &RecordBatch,
333 key_indices: &[usize],
334 row_idx: usize,
335 join_keys: &[JoinKey],
336) -> LlkvResult<HashKey> {
337 let mut values = Vec::with_capacity(key_indices.len());
338
339 for (&col_idx, join_key) in key_indices.iter().zip(join_keys) {
340 let column = batch.column(col_idx);
341
342 if column.is_null(row_idx) {
344 if join_key.null_equals_null {
345 values.push(KeyValue::Utf8("<NULL>".to_string())); } else {
347 values.push(KeyValue::Null);
348 }
349 continue;
350 }
351
352 let value = extract_key_value(column, row_idx)?;
353 values.push(value);
354 }
355
356 Ok(HashKey { values })
357}
358
359fn extract_key_value(column: &ArrayRef, row_idx: usize) -> LlkvResult<KeyValue> {
361 use arrow::array::*;
362
363 let value = match column.data_type() {
364 DataType::Int8 => KeyValue::Int8(
365 column
366 .as_any()
367 .downcast_ref::<Int8Array>()
368 .unwrap()
369 .value(row_idx),
370 ),
371 DataType::Int16 => KeyValue::Int16(
372 column
373 .as_any()
374 .downcast_ref::<Int16Array>()
375 .unwrap()
376 .value(row_idx),
377 ),
378 DataType::Int32 => KeyValue::Int32(
379 column
380 .as_any()
381 .downcast_ref::<Int32Array>()
382 .unwrap()
383 .value(row_idx),
384 ),
385 DataType::Int64 => KeyValue::Int64(
386 column
387 .as_any()
388 .downcast_ref::<Int64Array>()
389 .unwrap()
390 .value(row_idx),
391 ),
392 DataType::UInt8 => KeyValue::UInt8(
393 column
394 .as_any()
395 .downcast_ref::<UInt8Array>()
396 .unwrap()
397 .value(row_idx),
398 ),
399 DataType::UInt16 => KeyValue::UInt16(
400 column
401 .as_any()
402 .downcast_ref::<UInt16Array>()
403 .unwrap()
404 .value(row_idx),
405 ),
406 DataType::UInt32 => KeyValue::UInt32(
407 column
408 .as_any()
409 .downcast_ref::<UInt32Array>()
410 .unwrap()
411 .value(row_idx),
412 ),
413 DataType::UInt64 => KeyValue::UInt64(
414 column
415 .as_any()
416 .downcast_ref::<UInt64Array>()
417 .unwrap()
418 .value(row_idx),
419 ),
420 DataType::Float32 => {
421 let val = column
422 .as_any()
423 .downcast_ref::<Float32Array>()
424 .unwrap()
425 .value(row_idx);
426 KeyValue::Float32(val.to_bits())
427 }
428 DataType::Float64 => {
429 let val = column
430 .as_any()
431 .downcast_ref::<Float64Array>()
432 .unwrap()
433 .value(row_idx);
434 KeyValue::Float64(val.to_bits())
435 }
436 DataType::Utf8 => KeyValue::Utf8(
437 column
438 .as_any()
439 .downcast_ref::<StringArray>()
440 .unwrap()
441 .value(row_idx)
442 .to_string(),
443 ),
444 DataType::Binary => KeyValue::Binary(
445 column
446 .as_any()
447 .downcast_ref::<BinaryArray>()
448 .unwrap()
449 .value(row_idx)
450 .to_vec(),
451 ),
452 dt => {
453 return Err(Error::Internal(format!(
454 "Unsupported join key type: {:?}",
455 dt
456 )));
457 }
458 };
459
460 Ok(value)
461}
462
463#[allow(clippy::too_many_arguments)]
465fn process_inner_probe<F>(
466 probe_batch: &RecordBatch,
467 probe_key_indices: &[usize],
468 hash_table: &HashTable,
469 build_batches: &[RecordBatch],
470 output_schema: &Arc<Schema>,
471 join_keys: &[JoinKey],
472 batch_size: usize,
473 on_batch: &mut F,
474) -> LlkvResult<()>
475where
476 F: FnMut(RecordBatch),
477{
478 let mut probe_indices = Vec::new();
479 let mut build_indices = Vec::new();
480
481 for probe_row_idx in 0..probe_batch.num_rows() {
482 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
483 && let Some(build_rows) = hash_table.get(&key)
484 {
485 for &(batch_idx, row_idx) in build_rows {
486 probe_indices.push(probe_row_idx);
487 build_indices.push((batch_idx, row_idx));
488 }
489 }
490
491 if probe_indices.len() >= batch_size {
493 emit_joined_batch(
494 probe_batch,
495 &probe_indices,
496 build_batches,
497 &build_indices,
498 output_schema,
499 on_batch,
500 )?;
501 probe_indices.clear();
502 build_indices.clear();
503 }
504 }
505
506 if !probe_indices.is_empty() {
508 emit_joined_batch(
509 probe_batch,
510 &probe_indices,
511 build_batches,
512 &build_indices,
513 output_schema,
514 on_batch,
515 )?;
516 }
517
518 Ok(())
519}
520
521#[allow(clippy::too_many_arguments)]
523fn process_left_probe<F>(
524 probe_batch: &RecordBatch,
525 probe_key_indices: &[usize],
526 hash_table: &HashTable,
527 build_batches: &[RecordBatch],
528 output_schema: &Arc<Schema>,
529 join_keys: &[JoinKey],
530 batch_size: usize,
531 on_batch: &mut F,
532) -> LlkvResult<()>
533where
534 F: FnMut(RecordBatch),
535{
536 let mut probe_indices = Vec::new();
537 let mut build_indices = Vec::new();
538
539 for probe_row_idx in 0..probe_batch.num_rows() {
540 let mut found_match = false;
541
542 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
543 && let Some(build_rows) = hash_table.get(&key)
544 {
545 for &(batch_idx, row_idx) in build_rows {
546 probe_indices.push(probe_row_idx);
547 build_indices.push(Some((batch_idx, row_idx)));
548 found_match = true;
549 }
550 }
551
552 if !found_match {
553 probe_indices.push(probe_row_idx);
555 build_indices.push(None);
556 }
557
558 if probe_indices.len() >= batch_size {
560 emit_left_joined_batch(
561 probe_batch,
562 &probe_indices,
563 build_batches,
564 &build_indices,
565 output_schema,
566 on_batch,
567 )?;
568 probe_indices.clear();
569 build_indices.clear();
570 }
571 }
572
573 if !probe_indices.is_empty() {
575 emit_left_joined_batch(
576 probe_batch,
577 &probe_indices,
578 build_batches,
579 &build_indices,
580 output_schema,
581 on_batch,
582 )?;
583 }
584
585 Ok(())
586}
587
588#[allow(clippy::too_many_arguments)]
590fn process_semi_probe<F>(
591 probe_batch: &RecordBatch,
592 probe_key_indices: &[usize],
593 hash_table: &HashTable,
594 output_schema: &Arc<Schema>,
595 join_keys: &[JoinKey],
596 batch_size: usize,
597 on_batch: &mut F,
598) -> LlkvResult<()>
599where
600 F: FnMut(RecordBatch),
601{
602 let mut probe_indices = Vec::new();
603
604 for probe_row_idx in 0..probe_batch.num_rows() {
605 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
606 && hash_table.contains_key(&key)
607 {
608 probe_indices.push(probe_row_idx);
609 }
610
611 if probe_indices.len() >= batch_size {
613 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
614 probe_indices.clear();
615 }
616 }
617
618 if !probe_indices.is_empty() {
620 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
621 }
622
623 Ok(())
624}
625
626#[allow(clippy::too_many_arguments)]
628fn process_anti_probe<F>(
629 probe_batch: &RecordBatch,
630 probe_key_indices: &[usize],
631 hash_table: &HashTable,
632 output_schema: &Arc<Schema>,
633 join_keys: &[JoinKey],
634 batch_size: usize,
635 on_batch: &mut F,
636) -> LlkvResult<()>
637where
638 F: FnMut(RecordBatch),
639{
640 let mut probe_indices = Vec::new();
641
642 for probe_row_idx in 0..probe_batch.num_rows() {
643 let mut found = false;
644 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
645 {
646 found = hash_table.contains_key(&key);
647 }
648
649 if !found {
650 probe_indices.push(probe_row_idx);
651 }
652
653 if probe_indices.len() >= batch_size {
655 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
656 probe_indices.clear();
657 }
658 }
659
660 if !probe_indices.is_empty() {
662 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
663 }
664
665 Ok(())
666}
667
668fn emit_joined_batch<F>(
670 probe_batch: &RecordBatch,
671 probe_indices: &[usize],
672 build_batches: &[RecordBatch],
673 build_indices: &[(usize, usize)],
674 output_schema: &Arc<Schema>,
675 on_batch: &mut F,
676) -> LlkvResult<()>
677where
678 F: FnMut(RecordBatch),
679{
680 let probe_arrays = gather_indices(probe_batch, probe_indices)?;
681 let build_arrays = gather_indices_from_batches(build_batches, build_indices)?;
682
683 let output_arrays: Vec<ArrayRef> = probe_arrays.into_iter().chain(build_arrays).collect();
684
685 let output_batch = RecordBatch::try_new(output_schema.clone(), output_arrays)?;
686 on_batch(output_batch);
687 Ok(())
688}
689
690fn emit_left_joined_batch<F>(
692 probe_batch: &RecordBatch,
693 probe_indices: &[usize],
694 build_batches: &[RecordBatch],
695 build_indices: &[Option<(usize, usize)>],
696 output_schema: &Arc<Schema>,
697 on_batch: &mut F,
698) -> LlkvResult<()>
699where
700 F: FnMut(RecordBatch),
701{
702 let probe_arrays = gather_indices(probe_batch, probe_indices)?;
703 let build_arrays = gather_optional_indices_from_batches(build_batches, build_indices)?;
704
705 let output_arrays: Vec<ArrayRef> = probe_arrays.into_iter().chain(build_arrays).collect();
706
707 let output_batch = RecordBatch::try_new(output_schema.clone(), output_arrays)?;
708 on_batch(output_batch);
709 Ok(())
710}
711
712fn emit_semi_batch<F>(
714 probe_batch: &RecordBatch,
715 probe_indices: &[usize],
716 output_schema: &Arc<Schema>,
717 on_batch: &mut F,
718) -> LlkvResult<()>
719where
720 F: FnMut(RecordBatch),
721{
722 let probe_arrays = gather_indices(probe_batch, probe_indices)?;
723 let output_batch = RecordBatch::try_new(output_schema.clone(), probe_arrays)?;
724 on_batch(output_batch);
725 Ok(())
726}
727
728fn build_user_projections<P>(
730 table: &Table<P>,
731 schema: &Arc<Schema>,
732) -> LlkvResult<Vec<ScanProjection>>
733where
734 P: Pager<Blob = EntryHandle> + Send + Sync,
735{
736 let mut projections = Vec::new();
737
738 for field in schema.fields() {
739 if field.name() == "row_id" {
740 continue;
741 }
742
743 if let Some(field_id_str) = field.metadata().get("field_id") {
744 let field_id: u32 = field_id_str.parse().map_err(|_| {
745 Error::Internal(format!("Invalid field_id in schema: {}", field_id_str))
746 })?;
747 let lfid = LogicalFieldId::for_user(table.table_id(), field_id);
748 projections.push(ScanProjection::Column(Projection::with_alias(
749 lfid,
750 field.name().to_string(),
751 )));
752 }
753 }
754
755 Ok(projections)
756}
757
758fn build_all_rows_filter(projections: &[ScanProjection]) -> LlkvResult<Expr<'static, FieldId>> {
759 if projections.is_empty() {
760 return Ok(Expr::Pred(Filter {
761 field_id: 0,
762 op: Operator::Range {
763 lower: Bound::Unbounded,
764 upper: Bound::Unbounded,
765 },
766 }));
767 }
768
769 let first_field = match &projections[0] {
770 ScanProjection::Column(proj) => proj.logical_field_id.field_id(),
771 ScanProjection::Computed { .. } => {
772 return Err(Error::InvalidArgumentError(
773 "join projections cannot include computed columns yet".to_string(),
774 ));
775 }
776 };
777
778 Ok(Expr::Pred(Filter {
779 field_id: first_field,
780 op: Operator::Range {
781 lower: Bound::Unbounded,
782 upper: Bound::Unbounded,
783 },
784 }))
785}
786
787fn extract_left_key_indices(keys: &[JoinKey], schema: &Arc<Schema>) -> LlkvResult<Vec<usize>> {
788 keys.iter()
789 .map(|key| find_field_index(schema, key.left_field))
790 .collect()
791}
792
793fn extract_right_key_indices(keys: &[JoinKey], schema: &Arc<Schema>) -> LlkvResult<Vec<usize>> {
794 keys.iter()
795 .map(|key| find_field_index(schema, key.right_field))
796 .collect()
797}
798
799fn find_field_index(schema: &Schema, target_field_id: FieldId) -> LlkvResult<usize> {
800 let mut user_col_idx = 0;
801
802 for field in schema.fields() {
803 if field.name() == "row_id" {
804 continue;
805 }
806
807 if let Some(field_id_str) = field.metadata().get("field_id") {
808 let field_id: u32 = field_id_str.parse().map_err(|_| {
809 Error::Internal(format!("Invalid field_id in schema: {}", field_id_str))
810 })?;
811
812 if field_id == target_field_id {
813 return Ok(user_col_idx);
814 }
815 }
816
817 user_col_idx += 1;
818 }
819
820 Err(Error::Internal(format!(
821 "field_id {} not found in schema",
822 target_field_id
823 )))
824}
825
826fn get_key_datatype(schema: &Schema, field_id: FieldId) -> LlkvResult<DataType> {
828 for field in schema.fields() {
829 if field.name() == "row_id" {
830 continue;
831 }
832
833 if let Some(field_id_str) = field.metadata().get("field_id") {
834 let fid: u32 = field_id_str.parse().map_err(|_| {
835 Error::Internal(format!("Invalid field_id in schema: {}", field_id_str))
836 })?;
837
838 if fid == field_id {
839 return Ok(field.data_type().clone());
840 }
841 }
842 }
843
844 Err(Error::Internal(format!(
845 "field_id {} not found in schema",
846 field_id
847 )))
848}
849
850fn build_output_schema(
851 left_schema: &Schema,
852 right_schema: &Schema,
853 join_type: JoinType,
854) -> LlkvResult<Arc<Schema>> {
855 let mut fields = Vec::new();
856
857 if matches!(join_type, JoinType::Semi | JoinType::Anti) {
859 for field in left_schema.fields() {
860 if field.name() != "row_id" {
861 fields.push(field.clone());
862 }
863 }
864 return Ok(Arc::new(Schema::new(fields)));
865 }
866
867 for field in left_schema.fields() {
869 if field.name() != "row_id" {
870 fields.push(field.clone());
871 }
872 }
873
874 for field in right_schema.fields() {
875 if field.name() != "row_id" {
876 fields.push(field.clone());
877 }
878 }
879
880 Ok(Arc::new(Schema::new(fields)))
881}
882
883fn gather_indices(batch: &RecordBatch, indices: &[usize]) -> LlkvResult<Vec<ArrayRef>> {
884 let indices_array =
885 arrow::array::UInt32Array::from(indices.iter().map(|&i| i as u32).collect::<Vec<_>>());
886
887 let mut result = Vec::new();
888 for column in batch.columns() {
889 let gathered = take(column.as_ref(), &indices_array, None)?;
890 result.push(gathered);
891 }
892
893 Ok(result)
894}
895
896fn gather_indices_from_batches(
897 batches: &[RecordBatch],
898 indices: &[(usize, usize)],
899) -> LlkvResult<Vec<ArrayRef>> {
900 if batches.is_empty() || indices.is_empty() {
901 return Ok(Vec::new());
902 }
903
904 let num_columns = batches[0].num_columns();
905 let mut result = Vec::with_capacity(num_columns);
906
907 for col_idx in 0..num_columns {
908 let mut column_data: Vec<ArrayRef> = Vec::new();
909
910 for &(batch_idx, row_idx) in indices {
911 let batch = &batches[batch_idx];
912 let column = batch.column(col_idx);
913 let single_row = take(
914 column.as_ref(),
915 &arrow::array::UInt32Array::from(vec![row_idx as u32]),
916 None,
917 )?;
918 column_data.push(single_row);
919 }
920
921 let concatenated =
922 arrow::compute::concat(&column_data.iter().map(|a| a.as_ref()).collect::<Vec<_>>())?;
923 result.push(concatenated);
924 }
925
926 Ok(result)
927}
928
929fn gather_optional_indices_from_batches(
930 batches: &[RecordBatch],
931 indices: &[Option<(usize, usize)>],
932) -> LlkvResult<Vec<ArrayRef>> {
933 if batches.is_empty() {
934 return Ok(Vec::new());
935 }
936
937 let num_columns = batches[0].num_columns();
938 let mut result = Vec::with_capacity(num_columns);
939
940 for col_idx in 0..num_columns {
941 let mut column_data: Vec<ArrayRef> = Vec::new();
942
943 for opt_idx in indices {
944 if let Some((batch_idx, row_idx)) = opt_idx {
945 let batch = &batches[*batch_idx];
946 let column = batch.column(col_idx);
947 let single_row = take(
948 column.as_ref(),
949 &arrow::array::UInt32Array::from(vec![*row_idx as u32]),
950 None,
951 )?;
952 column_data.push(single_row);
953 } else {
954 let column = batches[0].column(col_idx);
956 let null_array = arrow::array::new_null_array(column.data_type(), 1);
957 column_data.push(null_array);
958 }
959 }
960
961 let concatenated =
962 arrow::compute::concat(&column_data.iter().map(|a| a.as_ref()).collect::<Vec<_>>())?;
963 result.push(concatenated);
964 }
965
966 Ok(result)
967}
968
969macro_rules! impl_integer_fast_path {
978 (
979 fast_path_fn: $fast_path_fn:ident,
980 build_fn: $build_fn:ident,
981 inner_probe_fn: $inner_probe_fn:ident,
982 left_probe_fn: $left_probe_fn:ident,
983 semi_probe_fn: $semi_probe_fn:ident,
984 anti_probe_fn: $anti_probe_fn:ident,
985 rust_type: $rust_type:ty,
986 arrow_array: $arrow_array:ty,
987 null_sentinel: $null_sentinel:expr
988 ) => {
989 #[allow(clippy::too_many_arguments)]
994 fn $fast_path_fn<P, F>(
995 left: &Table<P>,
996 right: &Table<P>,
997 keys: &[JoinKey],
998 options: &JoinOptions,
999 mut on_batch: F,
1000 ) -> LlkvResult<()>
1001 where
1002 P: Pager<Blob = EntryHandle> + Send + Sync,
1003 F: FnMut(RecordBatch),
1004 {
1005 let left_schema = left.schema()?;
1006 let right_schema = right.schema()?;
1007
1008 let left_projections = build_user_projections(left, &left_schema)?;
1009 let right_projections = build_user_projections(right, &right_schema)?;
1010
1011 let output_schema =
1012 build_output_schema(&left_schema, &right_schema, options.join_type)?;
1013
1014 let (hash_table, build_batches) = if right_projections.is_empty() {
1015 (FxHashMap::default(), Vec::new())
1016 } else {
1017 $build_fn(right, &right_projections, keys, &right_schema)?
1018 };
1019
1020 let probe_key_idx = if left_projections.is_empty() || right_projections.is_empty() {
1021 0
1022 } else {
1023 find_field_index(&left_schema, keys[0].left_field)?
1024 };
1025
1026 let batch_size = options.batch_size;
1027
1028 if !left_projections.is_empty() {
1029 let filter_expr = build_all_rows_filter(&left_projections)?;
1030 let null_equals_null = keys[0].null_equals_null;
1031
1032 left.scan_stream(
1033 &left_projections,
1034 &filter_expr,
1035 ScanStreamOptions::default(),
1036 |probe_batch| {
1037 let result = match options.join_type {
1038 JoinType::Inner => $inner_probe_fn(
1039 &probe_batch,
1040 probe_key_idx,
1041 &hash_table,
1042 &build_batches,
1043 &output_schema,
1044 null_equals_null,
1045 batch_size,
1046 &mut on_batch,
1047 ),
1048 JoinType::Left => $left_probe_fn(
1049 &probe_batch,
1050 probe_key_idx,
1051 &hash_table,
1052 &build_batches,
1053 &output_schema,
1054 null_equals_null,
1055 batch_size,
1056 &mut on_batch,
1057 ),
1058 JoinType::Semi => $semi_probe_fn(
1059 &probe_batch,
1060 probe_key_idx,
1061 &hash_table,
1062 &output_schema,
1063 null_equals_null,
1064 batch_size,
1065 &mut on_batch,
1066 ),
1067 JoinType::Anti => $anti_probe_fn(
1068 &probe_batch,
1069 probe_key_idx,
1070 &hash_table,
1071 &output_schema,
1072 null_equals_null,
1073 batch_size,
1074 &mut on_batch,
1075 ),
1076 _ => {
1077 eprintln!("Hash join does not yet support {:?}", options.join_type);
1078 Ok(())
1079 }
1080 };
1081
1082 if let Err(e) = result {
1083 eprintln!("Join error: {}", e);
1084 }
1085 },
1086 )?;
1087 }
1088
1089 if matches!(options.join_type, JoinType::Right | JoinType::Full) {
1090 return Err(Error::Internal(
1091 "Right and Full outer joins not yet implemented for hash join".to_string(),
1092 ));
1093 }
1094
1095 Ok(())
1096 }
1097
1098 fn $build_fn<P>(
1100 table: &Table<P>,
1101 projections: &[ScanProjection],
1102 join_keys: &[JoinKey],
1103 schema: &Arc<Schema>,
1104 ) -> LlkvResult<(FxHashMap<$rust_type, Vec<RowRef>>, Vec<RecordBatch>)>
1105 where
1106 P: Pager<Blob = EntryHandle> + Send + Sync,
1107 {
1108 let mut hash_table: FxHashMap<$rust_type, Vec<RowRef>> = FxHashMap::default();
1109 let mut batches = Vec::new();
1110 let key_idx = find_field_index(schema, join_keys[0].right_field)?;
1111 let filter_expr = build_all_rows_filter(projections)?;
1112 let null_equals_null = join_keys[0].null_equals_null;
1113
1114 table.scan_stream(
1115 projections,
1116 &filter_expr,
1117 ScanStreamOptions::default(),
1118 |batch| {
1119 let batch_idx = batches.len();
1120 let key_column = batch.column(key_idx);
1121 let key_array = match key_column.as_any().downcast_ref::<$arrow_array>() {
1122 Some(arr) => arr,
1123 None => {
1124 eprintln!(
1125 "Fast-path: Expected array type but got {:?}",
1126 key_column.data_type()
1127 );
1128 batches.push(batch.clone());
1129 return;
1130 }
1131 };
1132
1133 for row_idx in 0..batch.num_rows() {
1134 if key_array.is_null(row_idx) {
1135 if null_equals_null {
1136 hash_table
1137 .entry($null_sentinel)
1138 .or_default()
1139 .push((batch_idx, row_idx));
1140 }
1141 } else {
1142 let key = key_array.value(row_idx);
1143 hash_table
1144 .entry(key)
1145 .or_default()
1146 .push((batch_idx, row_idx));
1147 }
1148 }
1149
1150 batches.push(batch.clone());
1151 },
1152 )?;
1153
1154 Ok((hash_table, batches))
1155 }
1156
1157 #[allow(clippy::too_many_arguments)]
1159 fn $inner_probe_fn<F>(
1160 probe_batch: &RecordBatch,
1161 probe_key_idx: usize,
1162 hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
1163 build_batches: &[RecordBatch],
1164 output_schema: &Arc<Schema>,
1165 null_equals_null: bool,
1166 batch_size: usize,
1167 on_batch: &mut F,
1168 ) -> LlkvResult<()>
1169 where
1170 F: FnMut(RecordBatch),
1171 {
1172 let probe_keys = match probe_batch
1173 .column(probe_key_idx)
1174 .as_any()
1175 .downcast_ref::<$arrow_array>()
1176 {
1177 Some(arr) => arr,
1178 None => {
1179 return Err(Error::Internal(format!(
1180 "Fast-path: Expected array type at column {} but got {:?}",
1181 probe_key_idx,
1182 probe_batch.column(probe_key_idx).data_type()
1183 )));
1184 }
1185 };
1186 let mut probe_indices = Vec::with_capacity(batch_size);
1187 let mut build_indices = Vec::with_capacity(batch_size);
1188
1189 for probe_row_idx in 0..probe_batch.num_rows() {
1190 let key = if probe_keys.is_null(probe_row_idx) {
1191 if null_equals_null {
1192 $null_sentinel
1193 } else {
1194 continue;
1195 }
1196 } else {
1197 probe_keys.value(probe_row_idx)
1198 };
1199
1200 if let Some(build_rows) = hash_table.get(&key) {
1201 for &row_ref in build_rows {
1202 probe_indices.push(probe_row_idx);
1203 build_indices.push(row_ref);
1204 }
1205 }
1206
1207 if probe_indices.len() >= batch_size {
1208 emit_joined_batch(
1209 probe_batch,
1210 &probe_indices,
1211 build_batches,
1212 &build_indices,
1213 output_schema,
1214 on_batch,
1215 )?;
1216 probe_indices.clear();
1217 build_indices.clear();
1218 }
1219 }
1220
1221 if !probe_indices.is_empty() {
1222 emit_joined_batch(
1223 probe_batch,
1224 &probe_indices,
1225 build_batches,
1226 &build_indices,
1227 output_schema,
1228 on_batch,
1229 )?;
1230 }
1231
1232 Ok(())
1233 }
1234
1235 #[allow(clippy::too_many_arguments)]
1237 fn $left_probe_fn<F>(
1238 probe_batch: &RecordBatch,
1239 probe_key_idx: usize,
1240 hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
1241 build_batches: &[RecordBatch],
1242 output_schema: &Arc<Schema>,
1243 null_equals_null: bool,
1244 batch_size: usize,
1245 on_batch: &mut F,
1246 ) -> LlkvResult<()>
1247 where
1248 F: FnMut(RecordBatch),
1249 {
1250 let probe_keys = match probe_batch
1251 .column(probe_key_idx)
1252 .as_any()
1253 .downcast_ref::<$arrow_array>()
1254 {
1255 Some(arr) => arr,
1256 None => {
1257 return Err(Error::Internal(format!(
1258 "Fast-path: Expected array type at column {} but got {:?}",
1259 probe_key_idx,
1260 probe_batch.column(probe_key_idx).data_type()
1261 )));
1262 }
1263 };
1264 let mut probe_indices = Vec::with_capacity(batch_size);
1265 let mut build_indices = Vec::with_capacity(batch_size);
1266
1267 for probe_row_idx in 0..probe_batch.num_rows() {
1268 let key = if probe_keys.is_null(probe_row_idx) {
1269 if null_equals_null {
1270 $null_sentinel
1271 } else {
1272 probe_indices.push(probe_row_idx);
1273 build_indices.push(None);
1274 continue;
1275 }
1276 } else {
1277 probe_keys.value(probe_row_idx)
1278 };
1279
1280 if let Some(build_rows) = hash_table.get(&key) {
1281 for &row_ref in build_rows {
1282 probe_indices.push(probe_row_idx);
1283 build_indices.push(Some(row_ref));
1284 }
1285 } else {
1286 probe_indices.push(probe_row_idx);
1287 build_indices.push(None);
1288 }
1289
1290 if probe_indices.len() >= batch_size {
1291 emit_left_joined_batch(
1292 probe_batch,
1293 &probe_indices,
1294 build_batches,
1295 &build_indices,
1296 output_schema,
1297 on_batch,
1298 )?;
1299 probe_indices.clear();
1300 build_indices.clear();
1301 }
1302 }
1303
1304 if !probe_indices.is_empty() {
1305 emit_left_joined_batch(
1306 probe_batch,
1307 &probe_indices,
1308 build_batches,
1309 &build_indices,
1310 output_schema,
1311 on_batch,
1312 )?;
1313 }
1314
1315 Ok(())
1316 }
1317
1318 #[allow(clippy::too_many_arguments)]
1320 fn $semi_probe_fn<F>(
1321 probe_batch: &RecordBatch,
1322 probe_key_idx: usize,
1323 hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
1324 output_schema: &Arc<Schema>,
1325 null_equals_null: bool,
1326 batch_size: usize,
1327 on_batch: &mut F,
1328 ) -> LlkvResult<()>
1329 where
1330 F: FnMut(RecordBatch),
1331 {
1332 let probe_keys = match probe_batch
1333 .column(probe_key_idx)
1334 .as_any()
1335 .downcast_ref::<$arrow_array>()
1336 {
1337 Some(arr) => arr,
1338 None => {
1339 return Err(Error::Internal(format!(
1340 "Fast-path: Expected array type at column {} but got {:?}",
1341 probe_key_idx,
1342 probe_batch.column(probe_key_idx).data_type()
1343 )));
1344 }
1345 };
1346 let mut probe_indices = Vec::with_capacity(batch_size);
1347
1348 for probe_row_idx in 0..probe_batch.num_rows() {
1349 let key = if probe_keys.is_null(probe_row_idx) {
1350 if null_equals_null {
1351 $null_sentinel
1352 } else {
1353 continue;
1354 }
1355 } else {
1356 probe_keys.value(probe_row_idx)
1357 };
1358
1359 if hash_table.contains_key(&key) {
1360 probe_indices.push(probe_row_idx);
1361 }
1362
1363 if probe_indices.len() >= batch_size {
1364 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
1365 probe_indices.clear();
1366 }
1367 }
1368
1369 if !probe_indices.is_empty() {
1370 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
1371 }
1372
1373 Ok(())
1374 }
1375
1376 #[allow(clippy::too_many_arguments)]
1378 fn $anti_probe_fn<F>(
1379 probe_batch: &RecordBatch,
1380 probe_key_idx: usize,
1381 hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
1382 output_schema: &Arc<Schema>,
1383 null_equals_null: bool,
1384 batch_size: usize,
1385 on_batch: &mut F,
1386 ) -> LlkvResult<()>
1387 where
1388 F: FnMut(RecordBatch),
1389 {
1390 let probe_keys = match probe_batch
1391 .column(probe_key_idx)
1392 .as_any()
1393 .downcast_ref::<$arrow_array>()
1394 {
1395 Some(arr) => arr,
1396 None => {
1397 return Err(Error::Internal(format!(
1398 "Fast-path: Expected array type at column {} but got {:?}",
1399 probe_key_idx,
1400 probe_batch.column(probe_key_idx).data_type()
1401 )));
1402 }
1403 };
1404 let mut probe_indices = Vec::with_capacity(batch_size);
1405
1406 for probe_row_idx in 0..probe_batch.num_rows() {
1407 let key = if probe_keys.is_null(probe_row_idx) {
1408 if null_equals_null {
1409 $null_sentinel
1410 } else {
1411 probe_indices.push(probe_row_idx);
1412 continue;
1413 }
1414 } else {
1415 probe_keys.value(probe_row_idx)
1416 };
1417
1418 if !hash_table.contains_key(&key) {
1419 probe_indices.push(probe_row_idx);
1420 }
1421
1422 if probe_indices.len() >= batch_size {
1423 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
1424 probe_indices.clear();
1425 }
1426 }
1427
1428 if !probe_indices.is_empty() {
1429 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
1430 }
1431
1432 Ok(())
1433 }
1434 };
1435}
1436
1437impl_integer_fast_path!(
1439 fast_path_fn: hash_join_i32_fast_path,
1440 build_fn: build_i32_hash_table,
1441 inner_probe_fn: process_i32_inner_probe,
1442 left_probe_fn: process_i32_left_probe,
1443 semi_probe_fn: process_i32_semi_probe,
1444 anti_probe_fn: process_i32_anti_probe,
1445 rust_type: i32,
1446 arrow_array: arrow::array::Int32Array,
1447 null_sentinel: i32::MIN
1448);
1449
1450impl_integer_fast_path!(
1451 fast_path_fn: hash_join_i64_fast_path,
1452 build_fn: build_i64_hash_table,
1453 inner_probe_fn: process_i64_inner_probe,
1454 left_probe_fn: process_i64_left_probe,
1455 semi_probe_fn: process_i64_semi_probe,
1456 anti_probe_fn: process_i64_anti_probe,
1457 rust_type: i64,
1458 arrow_array: arrow::array::Int64Array,
1459 null_sentinel: i64::MIN
1460);
1461
1462impl_integer_fast_path!(
1463 fast_path_fn: hash_join_u32_fast_path,
1464 build_fn: build_u32_hash_table,
1465 inner_probe_fn: process_u32_inner_probe,
1466 left_probe_fn: process_u32_left_probe,
1467 semi_probe_fn: process_u32_semi_probe,
1468 anti_probe_fn: process_u32_anti_probe,
1469 rust_type: u32,
1470 arrow_array: arrow::array::UInt32Array,
1471 null_sentinel: u32::MAX
1472);
1473
1474impl_integer_fast_path!(
1475 fast_path_fn: hash_join_u64_fast_path,
1476 build_fn: build_u64_hash_table,
1477 inner_probe_fn: process_u64_inner_probe,
1478 left_probe_fn: process_u64_left_probe,
1479 semi_probe_fn: process_u64_semi_probe,
1480 anti_probe_fn: process_u64_anti_probe,
1481 rust_type: u64,
1482 arrow_array: arrow::array::UInt64Array,
1483 null_sentinel: u64::MAX
1484);