1use crate::{JoinKey, JoinOptions, JoinType};
35use arrow::array::{Array, ArrayRef, RecordBatch};
36use arrow::datatypes::{DataType, Schema};
37use llkv_column_map::gather::{
38 gather_indices, gather_indices_from_batches, gather_optional_indices_from_batches,
39};
40use llkv_column_map::store::Projection;
41use llkv_expr::{Expr, Filter, Operator};
42use llkv_result::{Error, Result as LlkvResult};
43use llkv_storage::pager::Pager;
44use llkv_table::schema_ext::CachedSchema;
45use llkv_table::table::{ScanProjection, ScanStreamOptions, Table};
46use llkv_table::types::FieldId;
47use llkv_threading::with_thread_pool;
48use llkv_types::LogicalFieldId;
49use rayon::prelude::*;
50use rustc_hash::FxHashMap;
51use simd_r_drive_entry_handle::EntryHandle;
52use std::hash::{Hash, Hasher};
53use std::ops::Bound;
54use std::sync::Arc;
55
56#[derive(Debug, Clone, Eq)]
58struct HashKey {
59 values: Vec<KeyValue>,
60}
61
/// A single join-key component in a hashable, comparable form.
///
/// Equality and hashing are implemented by hand for this type.
#[derive(Clone, Debug)]
enum KeyValue {
    /// SQL NULL. The hand-written `PartialEq` makes this variant unequal to
    /// everything, including another `Null`.
    Null,
    Int8(i8),
    Int16(i16),
    Int32(i32),
    Int64(i64),
    UInt8(u8),
    UInt16(u16),
    UInt32(u32),
    UInt64(u64),
    /// Raw bit pattern of an `f32` (produced with `to_bits`).
    Float32(u32),
    /// Raw bit pattern of an `f64` (produced with `to_bits`).
    Float64(u64),
    Utf8(String),
    Binary(Vec<u8>),
}
79
80impl PartialEq for KeyValue {
81 fn eq(&self, other: &Self) -> bool {
82 match (self, other) {
83 (KeyValue::Null, KeyValue::Null) => false, (KeyValue::Int8(a), KeyValue::Int8(b)) => a == b,
85 (KeyValue::Int16(a), KeyValue::Int16(b)) => a == b,
86 (KeyValue::Int32(a), KeyValue::Int32(b)) => a == b,
87 (KeyValue::Int64(a), KeyValue::Int64(b)) => a == b,
88 (KeyValue::UInt8(a), KeyValue::UInt8(b)) => a == b,
89 (KeyValue::UInt16(a), KeyValue::UInt16(b)) => a == b,
90 (KeyValue::UInt32(a), KeyValue::UInt32(b)) => a == b,
91 (KeyValue::UInt64(a), KeyValue::UInt64(b)) => a == b,
92 (KeyValue::Float32(a), KeyValue::Float32(b)) => a == b,
93 (KeyValue::Float64(a), KeyValue::Float64(b)) => a == b,
94 (KeyValue::Utf8(a), KeyValue::Utf8(b)) => a == b,
95 (KeyValue::Binary(a), KeyValue::Binary(b)) => a == b,
96 _ => false,
97 }
98 }
99}
100
101impl Eq for KeyValue {}
102
103impl Hash for KeyValue {
104 fn hash<H: Hasher>(&self, state: &mut H) {
105 match self {
106 KeyValue::Null => 0u8.hash(state),
107 KeyValue::Int8(v) => v.hash(state),
108 KeyValue::Int16(v) => v.hash(state),
109 KeyValue::Int32(v) => v.hash(state),
110 KeyValue::Int64(v) => v.hash(state),
111 KeyValue::UInt8(v) => v.hash(state),
112 KeyValue::UInt16(v) => v.hash(state),
113 KeyValue::UInt32(v) => v.hash(state),
114 KeyValue::UInt64(v) => v.hash(state),
115 KeyValue::Float32(v) => v.hash(state),
116 KeyValue::Float64(v) => v.hash(state),
117 KeyValue::Utf8(v) => v.hash(state),
118 KeyValue::Binary(v) => v.hash(state),
119 }
120 }
121}
122
123impl PartialEq for HashKey {
124 fn eq(&self, other: &Self) -> bool {
125 if self.values.len() != other.values.len() {
126 return false;
127 }
128 self.values.iter().zip(&other.values).all(|(a, b)| a == b)
129 }
130}
131
132impl Hash for HashKey {
133 fn hash<H: Hasher>(&self, state: &mut H) {
134 for value in &self.values {
135 value.hash(state);
136 }
137 }
138}
139
140type RowRef = (usize, usize);
142
143type HashTable = FxHashMap<HashKey, Vec<RowRef>>;
145
146pub fn hash_join_stream<P, F>(
152 left: &Table<P>,
153 right: &Table<P>,
154 keys: &[JoinKey],
155 options: &JoinOptions,
156 mut on_batch: F,
157) -> LlkvResult<()>
158where
159 P: Pager<Blob = EntryHandle> + Send + Sync,
160 F: FnMut(RecordBatch),
161{
162 if keys.is_empty() {
164 return cross_product_stream(left, right, options, on_batch);
165 }
166
167 let left_schema = left.schema()?;
169 let right_schema = right.schema()?;
170
171 if keys.len() == 1 {
175 if let (Ok(left_dtype), Ok(right_dtype)) = (
177 get_key_datatype(&left_schema, keys[0].left_field),
178 get_key_datatype(&right_schema, keys[0].right_field),
179 ) && left_dtype == right_dtype
180 {
181 match left_dtype {
182 DataType::Int32 => {
183 return hash_join_i32_fast_path(left, right, keys, options, on_batch);
184 }
185 DataType::Int64 => {
186 return hash_join_i64_fast_path(left, right, keys, options, on_batch);
187 }
188 DataType::UInt32 => {
189 return hash_join_u32_fast_path(left, right, keys, options, on_batch);
190 }
191 DataType::UInt64 => {
192 return hash_join_u64_fast_path(left, right, keys, options, on_batch);
193 }
194 _ => {
195 }
197 }
198 }
199 }
201
202 let left_projections = build_user_projections(left, &left_schema)?;
204 let right_projections = build_user_projections(right, &right_schema)?;
205
206 let output_schema = build_output_schema(&left_schema, &right_schema, options.join_type)?;
208
209 let (hash_table, build_batches) = if right_projections.is_empty() {
212 (HashTable::default(), Vec::new())
213 } else {
214 build_hash_table(right, &right_projections, keys, &right_schema)?
215 };
216
217 let probe_key_indices = if left_projections.is_empty() || right_projections.is_empty() {
219 Vec::new()
220 } else {
221 extract_left_key_indices(keys, &left_schema)?
222 };
223
224 let batch_size = options.batch_size;
226
227 if !left_projections.is_empty() {
228 let filter_expr = build_all_rows_filter(&left_projections)?;
229
230 let mut probe_batches = Vec::new();
235 left.scan_stream(
236 &left_projections,
237 &filter_expr,
238 ScanStreamOptions::default(),
239 |probe_batch| probe_batches.push(probe_batch.clone()),
240 )?;
241
242 let mut probe_tasks = Vec::new();
243 for (batch_idx, probe_batch) in probe_batches.into_iter().enumerate() {
244 let rows = probe_batch.num_rows();
245 if rows == 0 {
246 continue;
247 }
248 let mut start = 0;
249 while start < rows {
250 let len = (start + batch_size).min(rows) - start;
251 let slice = probe_batch.slice(start, len);
252 probe_tasks.push(((batch_idx, start), slice));
253 start += len;
254 }
255 }
256
257 let mut parallel_results: Vec<((usize, usize), Vec<RecordBatch>)> = with_thread_pool(
258 || {
259 probe_tasks
260 .into_par_iter()
261 .map(|(key, probe_batch)| -> LlkvResult<_> {
262 let mut local_batches = Vec::new();
263 let result = match options.join_type {
264 JoinType::Inner => process_inner_probe(
265 &probe_batch,
266 &probe_key_indices,
267 &hash_table,
268 &build_batches,
269 &output_schema,
270 keys,
271 batch_size,
272 &mut |batch| local_batches.push(batch),
273 ),
274 JoinType::Left => process_left_probe(
275 &probe_batch,
276 &probe_key_indices,
277 &hash_table,
278 &build_batches,
279 &output_schema,
280 keys,
281 batch_size,
282 &mut |batch| local_batches.push(batch),
283 ),
284 JoinType::Semi => process_semi_probe(
285 &probe_batch,
286 &probe_key_indices,
287 &hash_table,
288 &output_schema,
289 keys,
290 batch_size,
291 &mut |batch| local_batches.push(batch),
292 ),
293 JoinType::Anti => process_anti_probe(
294 &probe_batch,
295 &probe_key_indices,
296 &hash_table,
297 &output_schema,
298 keys,
299 batch_size,
300 &mut |batch| local_batches.push(batch),
301 ),
302 _ => {
303 tracing::debug!(
304 join_type = ?options.join_type,
305 "Hash join does not yet support this join type; skipping batch processing"
306 );
307 Ok(())
308 }
309 };
310
311 result?;
312 Ok((key, local_batches))
313 })
314 .collect::<LlkvResult<Vec<_>>>()
315 },
316 )?;
317
318 parallel_results.sort_by_key(|(key, _)| *key);
320 for (_, batches) in parallel_results {
321 for batch in batches {
322 on_batch(batch);
323 }
324 }
325 }
326
327 if matches!(options.join_type, JoinType::Right | JoinType::Full) {
329 return Err(Error::Internal(
330 "Right and Full outer joins not yet implemented for hash join".to_string(),
331 ));
332 }
333
334 Ok(())
335}
336
337fn build_hash_table<P>(
339 table: &Table<P>,
340 projections: &[ScanProjection],
341 join_keys: &[JoinKey],
342 schema: &Arc<Schema>,
343) -> LlkvResult<(HashTable, Vec<RecordBatch>)>
344where
345 P: Pager<Blob = EntryHandle> + Send + Sync,
346{
347 let mut hash_table = HashTable::default();
348 let mut batches = Vec::new();
349 let key_indices = extract_right_key_indices(join_keys, schema)?;
350 let filter_expr = build_all_rows_filter(projections)?;
351
352 table.scan_stream(
353 projections,
354 &filter_expr,
355 ScanStreamOptions::default(),
356 |batch| {
357 let batch_idx = batches.len();
358
359 for row_idx in 0..batch.num_rows() {
361 if let Ok(key) = extract_hash_key(&batch, &key_indices, row_idx, join_keys) {
362 hash_table
363 .entry(key)
364 .or_default()
365 .push((batch_idx, row_idx));
366 }
367 }
368
369 batches.push(batch.clone());
370 },
371 )?;
372
373 Ok((hash_table, batches))
374}
375
376fn extract_hash_key(
378 batch: &RecordBatch,
379 key_indices: &[usize],
380 row_idx: usize,
381 join_keys: &[JoinKey],
382) -> LlkvResult<HashKey> {
383 let mut values = Vec::with_capacity(key_indices.len());
384
385 for (&col_idx, join_key) in key_indices.iter().zip(join_keys) {
386 let column = batch.column(col_idx);
387
388 if column.is_null(row_idx) {
390 if join_key.null_equals_null {
391 values.push(KeyValue::Utf8("<NULL>".to_string())); } else {
393 values.push(KeyValue::Null);
394 }
395 continue;
396 }
397
398 let value = extract_key_value(column, row_idx)?;
399 values.push(value);
400 }
401
402 Ok(HashKey { values })
403}
404
405fn extract_key_value(column: &ArrayRef, row_idx: usize) -> LlkvResult<KeyValue> {
407 use arrow::array::*;
408
409 let value = match column.data_type() {
410 DataType::Int8 => KeyValue::Int8(
411 column
412 .as_any()
413 .downcast_ref::<Int8Array>()
414 .unwrap()
415 .value(row_idx),
416 ),
417 DataType::Int16 => KeyValue::Int16(
418 column
419 .as_any()
420 .downcast_ref::<Int16Array>()
421 .unwrap()
422 .value(row_idx),
423 ),
424 DataType::Int32 => KeyValue::Int32(
425 column
426 .as_any()
427 .downcast_ref::<Int32Array>()
428 .unwrap()
429 .value(row_idx),
430 ),
431 DataType::Int64 => KeyValue::Int64(
432 column
433 .as_any()
434 .downcast_ref::<Int64Array>()
435 .unwrap()
436 .value(row_idx),
437 ),
438 DataType::UInt8 => KeyValue::UInt8(
439 column
440 .as_any()
441 .downcast_ref::<UInt8Array>()
442 .unwrap()
443 .value(row_idx),
444 ),
445 DataType::UInt16 => KeyValue::UInt16(
446 column
447 .as_any()
448 .downcast_ref::<UInt16Array>()
449 .unwrap()
450 .value(row_idx),
451 ),
452 DataType::UInt32 => KeyValue::UInt32(
453 column
454 .as_any()
455 .downcast_ref::<UInt32Array>()
456 .unwrap()
457 .value(row_idx),
458 ),
459 DataType::UInt64 => KeyValue::UInt64(
460 column
461 .as_any()
462 .downcast_ref::<UInt64Array>()
463 .unwrap()
464 .value(row_idx),
465 ),
466 DataType::Float32 => {
467 let val = column
468 .as_any()
469 .downcast_ref::<Float32Array>()
470 .unwrap()
471 .value(row_idx);
472 KeyValue::Float32(val.to_bits())
473 }
474 DataType::Float64 => {
475 let val = column
476 .as_any()
477 .downcast_ref::<Float64Array>()
478 .unwrap()
479 .value(row_idx);
480 KeyValue::Float64(val.to_bits())
481 }
482 DataType::Utf8 => KeyValue::Utf8(
483 column
484 .as_any()
485 .downcast_ref::<StringArray>()
486 .unwrap()
487 .value(row_idx)
488 .to_string(),
489 ),
490 DataType::Binary => KeyValue::Binary(
491 column
492 .as_any()
493 .downcast_ref::<BinaryArray>()
494 .unwrap()
495 .value(row_idx)
496 .to_vec(),
497 ),
498 dt => {
499 return Err(Error::Internal(format!(
500 "Unsupported join key type: {:?}",
501 dt
502 )));
503 }
504 };
505
506 Ok(value)
507}
508
509#[allow(clippy::too_many_arguments)]
511fn process_inner_probe<F>(
512 probe_batch: &RecordBatch,
513 probe_key_indices: &[usize],
514 hash_table: &HashTable,
515 build_batches: &[RecordBatch],
516 output_schema: &Arc<Schema>,
517 join_keys: &[JoinKey],
518 batch_size: usize,
519 on_batch: &mut F,
520) -> LlkvResult<()>
521where
522 F: FnMut(RecordBatch),
523{
524 let mut probe_indices = Vec::new();
525 let mut build_indices = Vec::new();
526
527 for probe_row_idx in 0..probe_batch.num_rows() {
528 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
529 && let Some(build_rows) = hash_table.get(&key)
530 {
531 for &(batch_idx, row_idx) in build_rows {
532 probe_indices.push(probe_row_idx);
533 build_indices.push((batch_idx, row_idx));
534 }
535 }
536
537 if probe_indices.len() >= batch_size {
539 emit_joined_batch(
540 probe_batch,
541 &probe_indices,
542 build_batches,
543 &build_indices,
544 output_schema,
545 on_batch,
546 )?;
547 probe_indices.clear();
548 build_indices.clear();
549 }
550 }
551
552 if !probe_indices.is_empty() {
554 emit_joined_batch(
555 probe_batch,
556 &probe_indices,
557 build_batches,
558 &build_indices,
559 output_schema,
560 on_batch,
561 )?;
562 }
563
564 Ok(())
565}
566
567#[allow(clippy::too_many_arguments)]
569fn process_left_probe<F>(
570 probe_batch: &RecordBatch,
571 probe_key_indices: &[usize],
572 hash_table: &HashTable,
573 build_batches: &[RecordBatch],
574 output_schema: &Arc<Schema>,
575 join_keys: &[JoinKey],
576 batch_size: usize,
577 on_batch: &mut F,
578) -> LlkvResult<()>
579where
580 F: FnMut(RecordBatch),
581{
582 let mut probe_indices = Vec::new();
583 let mut build_indices = Vec::new();
584
585 for probe_row_idx in 0..probe_batch.num_rows() {
586 let mut found_match = false;
587
588 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
589 && let Some(build_rows) = hash_table.get(&key)
590 {
591 for &(batch_idx, row_idx) in build_rows {
592 probe_indices.push(probe_row_idx);
593 build_indices.push(Some((batch_idx, row_idx)));
594 found_match = true;
595 }
596 }
597
598 if !found_match {
599 probe_indices.push(probe_row_idx);
601 build_indices.push(None);
602 }
603
604 if probe_indices.len() >= batch_size {
606 emit_left_joined_batch(
607 probe_batch,
608 &probe_indices,
609 build_batches,
610 &build_indices,
611 output_schema,
612 on_batch,
613 )?;
614 probe_indices.clear();
615 build_indices.clear();
616 }
617 }
618
619 if !probe_indices.is_empty() {
621 emit_left_joined_batch(
622 probe_batch,
623 &probe_indices,
624 build_batches,
625 &build_indices,
626 output_schema,
627 on_batch,
628 )?;
629 }
630
631 Ok(())
632}
633
634#[allow(clippy::too_many_arguments)]
636fn process_semi_probe<F>(
637 probe_batch: &RecordBatch,
638 probe_key_indices: &[usize],
639 hash_table: &HashTable,
640 output_schema: &Arc<Schema>,
641 join_keys: &[JoinKey],
642 batch_size: usize,
643 on_batch: &mut F,
644) -> LlkvResult<()>
645where
646 F: FnMut(RecordBatch),
647{
648 let mut probe_indices = Vec::new();
649
650 for probe_row_idx in 0..probe_batch.num_rows() {
651 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
652 && hash_table.contains_key(&key)
653 {
654 probe_indices.push(probe_row_idx);
655 }
656
657 if probe_indices.len() >= batch_size {
659 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
660 probe_indices.clear();
661 }
662 }
663
664 if !probe_indices.is_empty() {
666 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
667 }
668
669 Ok(())
670}
671
672#[allow(clippy::too_many_arguments)]
674fn process_anti_probe<F>(
675 probe_batch: &RecordBatch,
676 probe_key_indices: &[usize],
677 hash_table: &HashTable,
678 output_schema: &Arc<Schema>,
679 join_keys: &[JoinKey],
680 batch_size: usize,
681 on_batch: &mut F,
682) -> LlkvResult<()>
683where
684 F: FnMut(RecordBatch),
685{
686 let mut probe_indices = Vec::new();
687
688 for probe_row_idx in 0..probe_batch.num_rows() {
689 let mut found = false;
690 if let Ok(key) = extract_hash_key(probe_batch, probe_key_indices, probe_row_idx, join_keys)
691 {
692 found = hash_table.contains_key(&key);
693 }
694
695 if !found {
696 probe_indices.push(probe_row_idx);
697 }
698
699 if probe_indices.len() >= batch_size {
701 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
702 probe_indices.clear();
703 }
704 }
705
706 if !probe_indices.is_empty() {
708 emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
709 }
710
711 Ok(())
712}
713
714fn emit_joined_batch<F>(
716 probe_batch: &RecordBatch,
717 probe_indices: &[usize],
718 build_batches: &[RecordBatch],
719 build_indices: &[(usize, usize)],
720 output_schema: &Arc<Schema>,
721 on_batch: &mut F,
722) -> LlkvResult<()>
723where
724 F: FnMut(RecordBatch),
725{
726 let probe_arrays = gather_indices(probe_batch, probe_indices)?;
727 let build_arrays = gather_indices_from_batches(build_batches, build_indices)?;
728
729 let output_arrays: Vec<ArrayRef> = probe_arrays.into_iter().chain(build_arrays).collect();
730
731 let output_batch = RecordBatch::try_new(output_schema.clone(), output_arrays)?;
732 on_batch(output_batch);
733 Ok(())
734}
735
736fn emit_left_joined_batch<F>(
738 probe_batch: &RecordBatch,
739 probe_indices: &[usize],
740 build_batches: &[RecordBatch],
741 build_indices: &[Option<(usize, usize)>],
742 output_schema: &Arc<Schema>,
743 on_batch: &mut F,
744) -> LlkvResult<()>
745where
746 F: FnMut(RecordBatch),
747{
748 let probe_arrays = gather_indices(probe_batch, probe_indices)?;
749 let build_arrays = gather_optional_indices_from_batches(build_batches, build_indices)?;
750
751 let output_arrays: Vec<ArrayRef> = probe_arrays.into_iter().chain(build_arrays).collect();
752
753 let output_batch = RecordBatch::try_new(output_schema.clone(), output_arrays)?;
754 on_batch(output_batch);
755 Ok(())
756}
757
758fn emit_semi_batch<F>(
760 probe_batch: &RecordBatch,
761 probe_indices: &[usize],
762 output_schema: &Arc<Schema>,
763 on_batch: &mut F,
764) -> LlkvResult<()>
765where
766 F: FnMut(RecordBatch),
767{
768 let probe_arrays = gather_indices(probe_batch, probe_indices)?;
769 let output_batch = RecordBatch::try_new(output_schema.clone(), probe_arrays)?;
770 on_batch(output_batch);
771 Ok(())
772}
773
774fn build_user_projections<P>(
776 table: &Table<P>,
777 schema: &Arc<Schema>,
778) -> LlkvResult<Vec<ScanProjection>>
779where
780 P: Pager<Blob = EntryHandle> + Send + Sync,
781{
782 let cached = CachedSchema::new(Arc::clone(schema));
783 let mut projections = Vec::new();
784
785 for (idx, field) in schema.fields().iter().enumerate() {
786 let Some(field_id) = cached.field_id(idx) else {
788 continue;
789 };
790
791 let lfid = LogicalFieldId::for_user(table.table_id(), field_id);
792 projections.push(ScanProjection::Column(Projection::with_alias(
793 lfid,
794 field.name().to_string(),
795 )));
796 }
797
798 Ok(projections)
799}
800
801fn build_all_rows_filter(projections: &[ScanProjection]) -> LlkvResult<Expr<'static, FieldId>> {
802 if projections.is_empty() {
803 return Ok(Expr::Pred(Filter {
804 field_id: 0,
805 op: Operator::Range {
806 lower: Bound::Unbounded,
807 upper: Bound::Unbounded,
808 },
809 }));
810 }
811
812 let first_field = match &projections[0] {
813 ScanProjection::Column(proj) => proj.logical_field_id.field_id(),
814 ScanProjection::Computed { .. } => {
815 return Err(Error::InvalidArgumentError(
816 "join projections cannot include computed columns yet".to_string(),
817 ));
818 }
819 };
820
821 Ok(Expr::Pred(Filter {
822 field_id: first_field,
823 op: Operator::Range {
824 lower: Bound::Unbounded,
825 upper: Bound::Unbounded,
826 },
827 }))
828}
829
830fn extract_left_key_indices(keys: &[JoinKey], schema: &Arc<Schema>) -> LlkvResult<Vec<usize>> {
831 keys.iter()
832 .map(|key| find_field_index(schema, key.left_field))
833 .collect()
834}
835
836fn extract_right_key_indices(keys: &[JoinKey], schema: &Arc<Schema>) -> LlkvResult<Vec<usize>> {
837 keys.iter()
838 .map(|key| find_field_index(schema, key.right_field))
839 .collect()
840}
841
842fn find_field_index(schema: &Schema, target_field_id: FieldId) -> LlkvResult<usize> {
843 let cached = CachedSchema::new(Arc::new(schema.clone()));
847
848 let schema_index = cached.index_of_field_id(target_field_id).ok_or_else(|| {
850 Error::Internal(format!("field_id {} not found in schema", target_field_id))
851 })?;
852
853 let mut user_col_idx = 0;
856 for idx in 0..schema_index {
857 if cached.field_id(idx).is_some() {
858 user_col_idx += 1;
859 }
860 }
861
862 Ok(user_col_idx)
863}
864
865fn get_key_datatype(schema: &Schema, field_id: FieldId) -> LlkvResult<DataType> {
867 let cached = CachedSchema::new(Arc::new(schema.clone()));
869
870 let index = cached
871 .index_of_field_id(field_id)
872 .ok_or_else(|| Error::Internal(format!("field_id {} not found in schema", field_id)))?;
873
874 Ok(schema.field(index).data_type().clone())
875}
876
877fn build_output_schema(
878 left_schema: &Schema,
879 right_schema: &Schema,
880 join_type: JoinType,
881) -> LlkvResult<Arc<Schema>> {
882 let mut fields = Vec::new();
883 let mut field_names: std::collections::HashSet<String> = std::collections::HashSet::new();
884
885 if matches!(join_type, JoinType::Semi | JoinType::Anti) {
887 for field in left_schema.fields() {
888 if field
889 .metadata()
890 .get(llkv_column_map::store::FIELD_ID_META_KEY)
891 .is_some()
892 {
893 fields.push(field.clone());
894 field_names.insert(field.name().clone());
895 }
896 }
897 return Ok(Arc::new(Schema::new(fields)));
898 }
899
900 for field in left_schema.fields() {
903 if field
904 .metadata()
905 .get(llkv_column_map::store::FIELD_ID_META_KEY)
906 .is_some()
907 {
908 fields.push(field.clone());
909 field_names.insert(field.name().clone());
910 }
911 }
912
913 for field in right_schema.fields() {
915 if field
916 .metadata()
917 .get(llkv_column_map::store::FIELD_ID_META_KEY)
918 .is_some()
919 {
920 let field_name = field.name();
921 let new_name = if field_names.contains(field_name) {
923 format!("{}_1", field_name)
924 } else {
925 field_name.clone()
926 };
927
928 let new_field = Arc::new(
929 arrow::datatypes::Field::new(
930 new_name.clone(),
931 field.data_type().clone(),
932 field.is_nullable(),
933 )
934 .with_metadata(field.metadata().clone()),
935 );
936
937 fields.push(new_field);
938 field_names.insert(new_name);
939 }
940 }
941
942 Ok(Arc::new(Schema::new(fields)))
943}
944
/// Generates a single-key hash join specialised for one integer key type.
///
/// Each expansion defines six functions, named by the caller:
///
/// * `fast_path_fn` - entry point with the same signature as the generic
///   join; builds from `right`, then probes each `left` batch as it is scanned;
/// * `build_fn` - builds `FxHashMap<rust_type, Vec<RowRef>>` from the build side;
/// * `inner_probe_fn`, `left_probe_fn`, `semi_probe_fn`, `anti_probe_fn` -
///   per-batch probes for the corresponding join types.
///
/// `arrow_array` is the concrete array type the key column is downcast to.
/// `null_sentinel` is the key under which NULL rows are stored and looked up
/// when the join key has `null_equals_null` set.
///
/// NOTE(review): the sentinel is an ordinary value of `rust_type`, so with
/// `null_equals_null` a NULL key also matches rows whose real key equals the
/// sentinel.
///
/// Compared with silently continuing, the generated code reports the first
/// probe or build failure as the join's error, treats a `batch_size` of zero
/// as one, and never pre-allocates more than one slot per probe row.
macro_rules! impl_integer_fast_path {
    (
        fast_path_fn: $fast_path_fn:ident,
        build_fn: $build_fn:ident,
        inner_probe_fn: $inner_probe_fn:ident,
        left_probe_fn: $left_probe_fn:ident,
        semi_probe_fn: $semi_probe_fn:ident,
        anti_probe_fn: $anti_probe_fn:ident,
        rust_type: $rust_type:ty,
        arrow_array: $arrow_array:ty,
        null_sentinel: $null_sentinel:expr
    ) => {
        /// Fast-path join entry point for a single integer key.
        ///
        /// # Errors
        ///
        /// Returns `Error::Internal` for `Right` and `Full` joins, propagates
        /// schema, build and scan errors, and returns the first error raised
        /// while probing a batch (later batches are then skipped).
        #[allow(clippy::too_many_arguments)]
        fn $fast_path_fn<P, F>(
            left: &Table<P>,
            right: &Table<P>,
            keys: &[JoinKey],
            options: &JoinOptions,
            mut on_batch: F,
        ) -> LlkvResult<()>
        where
            P: Pager<Blob = EntryHandle> + Send + Sync,
            F: FnMut(RecordBatch),
        {
            let left_schema = left.schema()?;
            let right_schema = right.schema()?;

            // Unsupported join types produce no output and always end in this
            // error, so reject them before scanning anything.
            if matches!(options.join_type, JoinType::Right | JoinType::Full) {
                return Err(Error::Internal(
                    "Right and Full outer joins not yet implemented for hash join".to_string(),
                ));
            }

            let left_projections = build_user_projections(left, &left_schema)?;
            let right_projections = build_user_projections(right, &right_schema)?;

            let output_schema =
                build_output_schema(&left_schema, &right_schema, options.join_type)?;

            let (hash_table, build_batches) = if right_projections.is_empty() {
                (FxHashMap::default(), Vec::new())
            } else {
                $build_fn(right, &right_projections, keys, &right_schema)?
            };

            let probe_key_idx = if left_projections.is_empty() || right_projections.is_empty() {
                0
            } else {
                find_field_index(&left_schema, keys[0].left_field)?
            };

            // A zero threshold would flush an empty batch for every probe row.
            let batch_size = options.batch_size.max(1);

            if !left_projections.is_empty() {
                let filter_expr = build_all_rows_filter(&left_projections)?;
                let null_equals_null = keys[0].null_equals_null;
                // The scan callback cannot return an error; keep the first one.
                let mut first_error: Option<Error> = None;

                left.scan_stream(
                    &left_projections,
                    &filter_expr,
                    ScanStreamOptions::default(),
                    |probe_batch| {
                        if first_error.is_some() {
                            return;
                        }

                        let result = match options.join_type {
                            JoinType::Inner => $inner_probe_fn(
                                &probe_batch,
                                probe_key_idx,
                                &hash_table,
                                &build_batches,
                                &output_schema,
                                null_equals_null,
                                batch_size,
                                &mut on_batch,
                            ),
                            JoinType::Left => $left_probe_fn(
                                &probe_batch,
                                probe_key_idx,
                                &hash_table,
                                &build_batches,
                                &output_schema,
                                null_equals_null,
                                batch_size,
                                &mut on_batch,
                            ),
                            JoinType::Semi => $semi_probe_fn(
                                &probe_batch,
                                probe_key_idx,
                                &hash_table,
                                &output_schema,
                                null_equals_null,
                                batch_size,
                                &mut on_batch,
                            ),
                            JoinType::Anti => $anti_probe_fn(
                                &probe_batch,
                                probe_key_idx,
                                &hash_table,
                                &output_schema,
                                null_equals_null,
                                batch_size,
                                &mut on_batch,
                            ),
                            _ => {
                                tracing::debug!(
                                    join_type = ?options.join_type,
                                    "Hash join does not yet support this join type; skipping batch processing"
                                );
                                Ok(())
                            }
                        };

                        if let Err(err) = result {
                            tracing::debug!(error = %err, "Hash join batch processing failed");
                            first_error = Some(err);
                        }
                    },
                )?;

                if let Some(err) = first_error {
                    return Err(err);
                }
            }

            Ok(())
        }

        /// Scans the build side and maps each key value to the rows carrying it.
        ///
        /// NULL keys are stored under the sentinel when `null_equals_null` is
        /// set and skipped otherwise. Every scanned batch is returned so that
        /// `RowRef`s can index into the list.
        ///
        /// # Errors
        ///
        /// Propagates lookup and scan errors, and returns `Error::Internal`
        /// when a key column is not of the expected array type.
        fn $build_fn<P>(
            table: &Table<P>,
            projections: &[ScanProjection],
            join_keys: &[JoinKey],
            schema: &Arc<Schema>,
        ) -> LlkvResult<(FxHashMap<$rust_type, Vec<RowRef>>, Vec<RecordBatch>)>
        where
            P: Pager<Blob = EntryHandle> + Send + Sync,
        {
            let mut hash_table: FxHashMap<$rust_type, Vec<RowRef>> = FxHashMap::default();
            let mut batches = Vec::new();
            let mut first_error: Option<Error> = None;
            let key_idx = find_field_index(schema, join_keys[0].right_field)?;
            let filter_expr = build_all_rows_filter(projections)?;
            let null_equals_null = join_keys[0].null_equals_null;
            let sentinel: $rust_type = $null_sentinel;

            table.scan_stream(
                projections,
                &filter_expr,
                ScanStreamOptions::default(),
                |batch| {
                    let batch_idx = batches.len();
                    let key_column = batch.column(key_idx);

                    match key_column.as_any().downcast_ref::<$arrow_array>() {
                        Some(key_array) => {
                            for row_idx in 0..batch.num_rows() {
                                let key = if key_array.is_null(row_idx) {
                                    if !null_equals_null {
                                        continue;
                                    }
                                    sentinel
                                } else {
                                    key_array.value(row_idx)
                                };
                                hash_table
                                    .entry(key)
                                    .or_default()
                                    .push((batch_idx, row_idx));
                            }
                        }
                        None => {
                            tracing::debug!(
                                expected_array = stringify!($arrow_array),
                                actual_type = ?key_column.data_type(),
                                "Fast-path build: key column has an unexpected array type"
                            );
                            if first_error.is_none() {
                                first_error = Some(Error::Internal(format!(
                                    "Fast-path: Expected array type at column {} but got {:?}",
                                    key_idx,
                                    key_column.data_type()
                                )));
                            }
                        }
                    }

                    // Always keep the batch so `batch_idx` stays aligned.
                    batches.push(batch.clone());
                },
            )?;

            match first_error {
                Some(err) => Err(err),
                None => Ok((hash_table, batches)),
            }
        }

        /// Inner-join probe: one output row per (probe row, matching build row).
        ///
        /// NULL probe keys are looked up under the sentinel when
        /// `null_equals_null` is set and skipped otherwise.
        #[allow(clippy::too_many_arguments)]
        fn $inner_probe_fn<F>(
            probe_batch: &RecordBatch,
            probe_key_idx: usize,
            hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
            build_batches: &[RecordBatch],
            output_schema: &Arc<Schema>,
            null_equals_null: bool,
            batch_size: usize,
            on_batch: &mut F,
        ) -> LlkvResult<()>
        where
            F: FnMut(RecordBatch),
        {
            let probe_keys = match probe_batch
                .column(probe_key_idx)
                .as_any()
                .downcast_ref::<$arrow_array>()
            {
                Some(arr) => arr,
                None => {
                    return Err(Error::Internal(format!(
                        "Fast-path: Expected array type at column {} but got {:?}",
                        probe_key_idx,
                        probe_batch.column(probe_key_idx).data_type()
                    )));
                }
            };
            let batch_size = batch_size.max(1);
            let sentinel: $rust_type = $null_sentinel;
            // Bounded so that a huge `batch_size` cannot over-allocate.
            let capacity = batch_size.min(probe_batch.num_rows());
            let mut probe_indices = Vec::with_capacity(capacity);
            let mut build_indices = Vec::with_capacity(capacity);

            for probe_row_idx in 0..probe_batch.num_rows() {
                let key = if probe_keys.is_null(probe_row_idx) {
                    if !null_equals_null {
                        // Nothing was added, so no flush check is needed.
                        continue;
                    }
                    sentinel
                } else {
                    probe_keys.value(probe_row_idx)
                };

                if let Some(build_rows) = hash_table.get(&key) {
                    for &row_ref in build_rows {
                        probe_indices.push(probe_row_idx);
                        build_indices.push(row_ref);
                    }
                }

                if probe_indices.len() >= batch_size {
                    emit_joined_batch(
                        probe_batch,
                        &probe_indices,
                        build_batches,
                        &build_indices,
                        output_schema,
                        on_batch,
                    )?;
                    probe_indices.clear();
                    build_indices.clear();
                }
            }

            if !probe_indices.is_empty() {
                emit_joined_batch(
                    probe_batch,
                    &probe_indices,
                    build_batches,
                    &build_indices,
                    output_schema,
                    on_batch,
                )?;
            }

            Ok(())
        }

        /// Left-join probe: every probe row is emitted, once per matching
        /// build row or once with no build row when nothing matches.
        #[allow(clippy::too_many_arguments)]
        fn $left_probe_fn<F>(
            probe_batch: &RecordBatch,
            probe_key_idx: usize,
            hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
            build_batches: &[RecordBatch],
            output_schema: &Arc<Schema>,
            null_equals_null: bool,
            batch_size: usize,
            on_batch: &mut F,
        ) -> LlkvResult<()>
        where
            F: FnMut(RecordBatch),
        {
            let probe_keys = match probe_batch
                .column(probe_key_idx)
                .as_any()
                .downcast_ref::<$arrow_array>()
            {
                Some(arr) => arr,
                None => {
                    return Err(Error::Internal(format!(
                        "Fast-path: Expected array type at column {} but got {:?}",
                        probe_key_idx,
                        probe_batch.column(probe_key_idx).data_type()
                    )));
                }
            };
            let batch_size = batch_size.max(1);
            let sentinel: $rust_type = $null_sentinel;
            let capacity = batch_size.min(probe_batch.num_rows());
            let mut probe_indices = Vec::with_capacity(capacity);
            let mut build_indices = Vec::with_capacity(capacity);

            for probe_row_idx in 0..probe_batch.num_rows() {
                let matches = if probe_keys.is_null(probe_row_idx) {
                    if null_equals_null {
                        hash_table.get(&sentinel)
                    } else {
                        None
                    }
                } else {
                    hash_table.get(&probe_keys.value(probe_row_idx))
                };

                match matches {
                    Some(build_rows) => {
                        for &row_ref in build_rows {
                            probe_indices.push(probe_row_idx);
                            build_indices.push(Some(row_ref));
                        }
                    }
                    None => {
                        probe_indices.push(probe_row_idx);
                        build_indices.push(None);
                    }
                }

                // Checked for every row, including unmatched NULL keys.
                if probe_indices.len() >= batch_size {
                    emit_left_joined_batch(
                        probe_batch,
                        &probe_indices,
                        build_batches,
                        &build_indices,
                        output_schema,
                        on_batch,
                    )?;
                    probe_indices.clear();
                    build_indices.clear();
                }
            }

            if !probe_indices.is_empty() {
                emit_left_joined_batch(
                    probe_batch,
                    &probe_indices,
                    build_batches,
                    &build_indices,
                    output_schema,
                    on_batch,
                )?;
            }

            Ok(())
        }

        /// Semi-join probe: emits each probe row whose key is in the table.
        #[allow(clippy::too_many_arguments)]
        fn $semi_probe_fn<F>(
            probe_batch: &RecordBatch,
            probe_key_idx: usize,
            hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
            output_schema: &Arc<Schema>,
            null_equals_null: bool,
            batch_size: usize,
            on_batch: &mut F,
        ) -> LlkvResult<()>
        where
            F: FnMut(RecordBatch),
        {
            let probe_keys = match probe_batch
                .column(probe_key_idx)
                .as_any()
                .downcast_ref::<$arrow_array>()
            {
                Some(arr) => arr,
                None => {
                    return Err(Error::Internal(format!(
                        "Fast-path: Expected array type at column {} but got {:?}",
                        probe_key_idx,
                        probe_batch.column(probe_key_idx).data_type()
                    )));
                }
            };
            let batch_size = batch_size.max(1);
            let sentinel: $rust_type = $null_sentinel;
            let mut probe_indices = Vec::with_capacity(batch_size.min(probe_batch.num_rows()));

            for probe_row_idx in 0..probe_batch.num_rows() {
                let key = if probe_keys.is_null(probe_row_idx) {
                    if !null_equals_null {
                        // Nothing was added, so no flush check is needed.
                        continue;
                    }
                    sentinel
                } else {
                    probe_keys.value(probe_row_idx)
                };

                if hash_table.contains_key(&key) {
                    probe_indices.push(probe_row_idx);
                }

                if probe_indices.len() >= batch_size {
                    emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
                    probe_indices.clear();
                }
            }

            if !probe_indices.is_empty() {
                emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
            }

            Ok(())
        }

        /// Anti-join probe: emits each probe row whose key is not in the
        /// table; a NULL key without `null_equals_null` counts as not found.
        #[allow(clippy::too_many_arguments)]
        fn $anti_probe_fn<F>(
            probe_batch: &RecordBatch,
            probe_key_idx: usize,
            hash_table: &FxHashMap<$rust_type, Vec<RowRef>>,
            output_schema: &Arc<Schema>,
            null_equals_null: bool,
            batch_size: usize,
            on_batch: &mut F,
        ) -> LlkvResult<()>
        where
            F: FnMut(RecordBatch),
        {
            let probe_keys = match probe_batch
                .column(probe_key_idx)
                .as_any()
                .downcast_ref::<$arrow_array>()
            {
                Some(arr) => arr,
                None => {
                    return Err(Error::Internal(format!(
                        "Fast-path: Expected array type at column {} but got {:?}",
                        probe_key_idx,
                        probe_batch.column(probe_key_idx).data_type()
                    )));
                }
            };
            let batch_size = batch_size.max(1);
            let sentinel: $rust_type = $null_sentinel;
            let mut probe_indices = Vec::with_capacity(batch_size.min(probe_batch.num_rows()));

            for probe_row_idx in 0..probe_batch.num_rows() {
                let matched = if probe_keys.is_null(probe_row_idx) {
                    null_equals_null && hash_table.contains_key(&sentinel)
                } else {
                    hash_table.contains_key(&probe_keys.value(probe_row_idx))
                };

                if !matched {
                    probe_indices.push(probe_row_idx);
                }

                // Checked for every row, including unmatched NULL keys.
                if probe_indices.len() >= batch_size {
                    emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
                    probe_indices.clear();
                }
            }

            if !probe_indices.is_empty() {
                emit_semi_batch(probe_batch, &probe_indices, output_schema, on_batch)?;
            }

            Ok(())
        }
    };
}
1418
// Instantiates the integer fast-path function family for `i32` join keys held in
// an `Int32Array`. The generated functions take the names listed below.
// `i32::MIN` is the value NULL probe keys are looked up under when
// `null_equals_null` is enabled (see the semi/anti probe templates).
// NOTE(review): a genuine `i32::MIN` key then shares the lookup slot with NULL;
// confirm the build side treats this consistently.
impl_integer_fast_path!(
    fast_path_fn: hash_join_i32_fast_path,
    build_fn: build_i32_hash_table,
    inner_probe_fn: process_i32_inner_probe,
    left_probe_fn: process_i32_left_probe,
    semi_probe_fn: process_i32_semi_probe,
    anti_probe_fn: process_i32_anti_probe,
    rust_type: i32,
    arrow_array: arrow::array::Int32Array,
    null_sentinel: i32::MIN
);
1431
// Instantiates the integer fast-path function family for `i64` join keys held in
// an `Int64Array`. The generated functions take the names listed below.
// `i64::MIN` is the value NULL probe keys are looked up under when
// `null_equals_null` is enabled (see the semi/anti probe templates).
// NOTE(review): a genuine `i64::MIN` key then shares the lookup slot with NULL;
// confirm the build side treats this consistently.
impl_integer_fast_path!(
    fast_path_fn: hash_join_i64_fast_path,
    build_fn: build_i64_hash_table,
    inner_probe_fn: process_i64_inner_probe,
    left_probe_fn: process_i64_left_probe,
    semi_probe_fn: process_i64_semi_probe,
    anti_probe_fn: process_i64_anti_probe,
    rust_type: i64,
    arrow_array: arrow::array::Int64Array,
    null_sentinel: i64::MIN
);
1443
// Instantiates the integer fast-path function family for `u32` join keys held in
// a `UInt32Array`. The generated functions take the names listed below.
// `u32::MAX` is the value NULL probe keys are looked up under when
// `null_equals_null` is enabled (see the semi/anti probe templates).
// NOTE(review): a genuine `u32::MAX` key then shares the lookup slot with NULL;
// confirm the build side treats this consistently.
impl_integer_fast_path!(
    fast_path_fn: hash_join_u32_fast_path,
    build_fn: build_u32_hash_table,
    inner_probe_fn: process_u32_inner_probe,
    left_probe_fn: process_u32_left_probe,
    semi_probe_fn: process_u32_semi_probe,
    anti_probe_fn: process_u32_anti_probe,
    rust_type: u32,
    arrow_array: arrow::array::UInt32Array,
    null_sentinel: u32::MAX
);
1455
// Instantiates the integer fast-path function family for `u64` join keys held in
// a `UInt64Array`. The generated functions take the names listed below.
// `u64::MAX` is the value NULL probe keys are looked up under when
// `null_equals_null` is enabled (see the semi/anti probe templates).
// NOTE(review): a genuine `u64::MAX` key then shares the lookup slot with NULL;
// confirm the build side treats this consistently.
impl_integer_fast_path!(
    fast_path_fn: hash_join_u64_fast_path,
    build_fn: build_u64_hash_table,
    inner_probe_fn: process_u64_inner_probe,
    left_probe_fn: process_u64_left_probe,
    semi_probe_fn: process_u64_semi_probe,
    anti_probe_fn: process_u64_anti_probe,
    rust_type: u64,
    arrow_array: arrow::array::UInt64Array,
    null_sentinel: u64::MAX
);
1467
1468fn synthesize_left_join_nulls(
1471 left_batch: &RecordBatch,
1472 output_schema: &Arc<Schema>,
1473) -> LlkvResult<RecordBatch> {
1474 use arrow::array::new_null_array;
1475
1476 let left_col_count = left_batch.num_columns();
1477 let right_col_count = output_schema.fields().len() - left_col_count;
1478 let row_count = left_batch.num_rows();
1479
1480 let mut columns: Vec<ArrayRef> = Vec::with_capacity(output_schema.fields().len());
1481
1482 for col in left_batch.columns() {
1484 columns.push(Arc::clone(col));
1485 }
1486
1487 for field_idx in left_col_count..(left_col_count + right_col_count) {
1489 let field = output_schema.field(field_idx);
1490 let null_array = new_null_array(field.data_type(), row_count);
1491 columns.push(null_array);
1492 }
1493
1494 RecordBatch::try_new(Arc::clone(output_schema), columns).map_err(|err| {
1495 Error::InvalidArgumentError(format!("Failed to create LEFT JOIN null batch: {}", err))
1496 })
1497}
1498
1499fn cross_product_stream<P, F>(
1501 left: &Table<P>,
1502 right: &Table<P>,
1503 options: &JoinOptions,
1504 mut on_batch: F,
1505) -> LlkvResult<()>
1506where
1507 P: Pager<Blob = EntryHandle> + Send + Sync,
1508 F: FnMut(RecordBatch),
1509{
1510 let left_schema = left.schema()?;
1511 let right_schema = right.schema()?;
1512
1513 let left_projections = build_user_projections(left, &left_schema)?;
1515 let right_projections = build_user_projections(right, &right_schema)?;
1516
1517 let output_schema = build_output_schema(&left_schema, &right_schema, options.join_type)?;
1519
1520 let mut right_batches = Vec::new();
1521 if !right_projections.is_empty() {
1522 let filter_expr = build_all_rows_filter(&right_projections)?;
1523 right.scan_stream(
1524 &right_projections,
1525 &filter_expr,
1526 ScanStreamOptions::default(),
1527 |batch| {
1528 right_batches.push(batch);
1529 },
1530 )?;
1531 }
1532
1533 let right_is_empty =
1536 right_batches.is_empty() || right_batches.iter().all(|b| b.num_rows() == 0);
1537
1538 if right_is_empty && options.join_type == JoinType::Inner {
1539 return Ok(());
1540 }
1541
1542 if left_projections.is_empty() {
1543 return Ok(());
1544 }
1545
1546 let filter_expr = build_all_rows_filter(&left_projections)?;
1547 let mut error: Option<Error> = None;
1548
1549 left.scan_stream(
1550 &left_projections,
1551 &filter_expr,
1552 ScanStreamOptions::default(),
1553 |left_batch| {
1554 if error.is_some() || left_batch.num_rows() == 0 {
1555 return;
1556 }
1557
1558 if right_is_empty && options.join_type == JoinType::Left {
1560 match synthesize_left_join_nulls(&left_batch, &output_schema) {
1561 Ok(result) => on_batch(result),
1562 Err(err) => {
1563 error = Some(err);
1564 }
1565 }
1566 return;
1567 }
1568
1569 for right_batch in &right_batches {
1570 if right_batch.num_rows() == 0 {
1571 continue;
1572 }
1573
1574 match crate::cartesian::cross_join_pair(&left_batch, right_batch, &output_schema) {
1575 Ok(result) => on_batch(result),
1576 Err(err) => {
1577 error = Some(err);
1578 break;
1579 }
1580 }
1581 }
1582 },
1583 )?;
1584
1585 if let Some(err) = error {
1586 return Err(err);
1587 }
1588
1589 Ok(())
1590}