use crate::error::{DbxError, DbxResult};
use crate::sql::executor::operators::PhysicalOperator;
use crate::sql::planner::JoinType;
use ahash::AHashMap;
use arrow::array::*;
use arrow::compute;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use rayon::prelude::*;
use smallvec::{SmallVec, smallvec};
use std::sync::Arc;
/// Hash-join physical operator.
///
/// Materializes the left (build) input into an in-memory hash table keyed on
/// the serialized join columns, then probes it with the right input, emitting
/// one output batch per `next()` call. Output columns are the left input's
/// columns followed by the right input's columns, matching `schema`.
pub struct HashJoinOperator {
    // Build-side child operator.
    left: Box<dyn PhysicalOperator>,
    // Probe-side child operator.
    right: Box<dyn PhysicalOperator>,
    // Output schema (left columns then right columns).
    schema: Arc<Schema>,
    // Join-key column pairs as (left_column_index, right_column_index).
    on: Vec<(usize, usize)>,
    // NOTE(review): this field IS read via `matches!` in build/probe code;
    // the `dead_code` allow looks stale — confirm before removing it.
    #[allow(dead_code)]
    join_type: JoinType,
    // Serialized join key -> build-side row indices. `None` until the build
    // phase has run; `next()` uses this to detect the first call.
    build_table: Option<AHashMap<Vec<u8>, Vec<usize>>>,
    // All build-side batches merged into one batch so hash-table entries are
    // plain row indices (`None` when either input produced no rows).
    left_batch: Option<RecordBatch>,
    // Buffered probe-side batches, consumed one at a time by `next()`.
    right_batches: Option<Vec<RecordBatch>>,
    // Index of the next probe batch to process.
    right_batch_idx: usize,
    // Set once all probe batches are exhausted; later calls return `None`.
    done: bool,
}
impl HashJoinOperator {
    /// Creates a hash-join operator over two child operators.
    ///
    /// `on` pairs are `(left_column_index, right_column_index)`; `schema` must
    /// describe the concatenation of the left and right columns in that order.
    pub fn new(
        left: Box<dyn PhysicalOperator>,
        right: Box<dyn PhysicalOperator>,
        schema: Arc<Schema>,
        on: Vec<(usize, usize)>,
        join_type: JoinType,
    ) -> Self {
        Self {
            left,
            right,
            schema,
            on,
            join_type,
            build_table: None,
            left_batch: None,
            right_batches: None,
            right_batch_idx: 0,
            done: false,
        }
    }

    /// Build phase: drains both children, merges the left (build) input into a
    /// single batch, and hashes its join-key columns into `build_table`.
    ///
    /// The probe phase in `next()` assumes the build side is always the LEFT
    /// input: it probes with the right-hand key columns of `self.on` and emits
    /// left columns before right columns. A previous "build on the smaller
    /// side" swap for inner joins violated both assumptions (it stored right
    /// data in `left_batch` while `next()` kept probing with right-side key
    /// columns and left-then-right output order), producing wrong results, so
    /// the swap was removed: we always build on the left.
    fn build_phase(&mut self) -> DbxResult<()> {
        let mut left_batches: SmallVec<[RecordBatch; 8]> = smallvec![];
        while let Some(batch) = self.left.next()? {
            if batch.num_rows() > 0 {
                left_batches.push(batch);
            }
        }
        let mut right_batches: SmallVec<[RecordBatch; 8]> = smallvec![];
        while let Some(batch) = self.right.next()? {
            if batch.num_rows() > 0 {
                right_batches.push(batch);
            }
        }
        if left_batches.is_empty() || right_batches.is_empty() {
            // Either side empty: no inner matches are possible.
            // NOTE(review): for a LEFT join with an empty right side (or a
            // RIGHT join with an empty left side) this drops the surviving
            // rows instead of null-padding them — TODO confirm intended
            // semantics and fix together with `next()`.
            self.build_table = Some(AHashMap::new());
            self.left_batch = None;
            self.right_batches = Some(Vec::new());
            return Ok(());
        }
        // Merge all build-side batches into one so hash-table entries can be
        // plain row indices into a single batch.
        let schema = left_batches[0].schema();
        let merged = super::super::concat_batches(&schema, left_batches.as_slice())?;
        // Hash on the LEFT column of each `on` pair (build side is the left).
        let key_columns: Vec<usize> = self.on.iter().map(|(left_col, _)| *left_col).collect();
        // Above this row count the key hashing is parallelized with rayon.
        const PARALLEL_THRESHOLD: usize = 1000;
        let hash_table: AHashMap<Vec<u8>, Vec<usize>> = if merged.num_rows() >= PARALLEL_THRESHOLD {
            use dashmap::DashMap;
            // NOTE: the order of row indices within one key's bucket is
            // nondeterministic here (threads race on the bucket push), so
            // output row order may vary run to run for duplicate keys.
            let parallel_table: DashMap<Vec<u8>, Vec<usize>> = DashMap::new();
            (0..merged.num_rows()).into_par_iter().for_each(|row_idx| {
                let key = extract_join_key(&merged, &key_columns, row_idx);
                parallel_table.entry(key).or_default().push(row_idx);
            });
            parallel_table.into_iter().collect()
        } else {
            let mut hash_table: AHashMap<Vec<u8>, Vec<usize>> = AHashMap::new();
            for row_idx in 0..merged.num_rows() {
                let key = extract_join_key(&merged, &key_columns, row_idx);
                hash_table.entry(key).or_default().push(row_idx);
            }
            hash_table
        };
        self.left_batch = Some(merged);
        self.right_batches = Some(right_batches.into_vec());
        self.build_table = Some(hash_table);
        Ok(())
    }
}
/// Serializes the values of `key_columns` at `row_idx` into one composite
/// byte key suitable for hash-table lookup (see `append_value_to_key` for the
/// per-value encoding).
fn extract_join_key(batch: &RecordBatch, key_columns: &[usize], row_idx: usize) -> Vec<u8> {
    key_columns.iter().fold(Vec::new(), |mut key, &col_idx| {
        append_value_to_key(&mut key, batch.column(col_idx), row_idx);
        key
    })
}
fn append_value_to_key(key: &mut Vec<u8>, col: &ArrayRef, row_idx: usize) {
if col.is_null(row_idx) {
key.push(0); return;
}
key.push(1); match col.data_type() {
DataType::Int32 => {
let arr = col.as_any().downcast_ref::<Int32Array>().unwrap();
key.extend_from_slice(&arr.value(row_idx).to_le_bytes());
}
DataType::Int64 => {
let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
key.extend_from_slice(&arr.value(row_idx).to_le_bytes());
}
DataType::Float64 => {
let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
key.extend_from_slice(&arr.value(row_idx).to_le_bytes());
}
DataType::Utf8 => {
let arr = col.as_any().downcast_ref::<StringArray>().unwrap();
let s = arr.value(row_idx);
key.extend_from_slice(&(s.len() as u32).to_le_bytes());
key.extend_from_slice(s.as_bytes());
}
_ => {
key.extend_from_slice(format!("{:?}", col).as_bytes());
}
}
}
fn create_column_with_nulls(
source_col: &ArrayRef,
indices: &[u32],
null_sentinel: u32,
) -> DbxResult<ArrayRef> {
let num_rows = indices.len();
match source_col.data_type() {
DataType::Int32 => {
let source = source_col.as_any().downcast_ref::<Int32Array>().unwrap();
let mut builder = Int32Builder::with_capacity(num_rows);
for &idx in indices {
if idx == null_sentinel {
builder.append_null();
} else {
builder.append_value(source.value(idx as usize));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Int64 => {
let source = source_col.as_any().downcast_ref::<Int64Array>().unwrap();
let mut builder = Int64Builder::with_capacity(num_rows);
for &idx in indices {
if idx == null_sentinel {
builder.append_null();
} else {
builder.append_value(source.value(idx as usize));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Float64 => {
let source = source_col.as_any().downcast_ref::<Float64Array>().unwrap();
let mut builder = Float64Builder::with_capacity(num_rows);
for &idx in indices {
if idx == null_sentinel {
builder.append_null();
} else {
builder.append_value(source.value(idx as usize));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Utf8 => {
let source = source_col.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 10);
for &idx in indices {
if idx == null_sentinel {
builder.append_null();
} else {
builder.append_value(source.value(idx as usize));
}
}
Ok(Arc::new(builder.finish()))
}
DataType::Boolean => {
let source = source_col.as_any().downcast_ref::<BooleanArray>().unwrap();
let mut builder = BooleanBuilder::with_capacity(num_rows);
for &idx in indices {
if idx == null_sentinel {
builder.append_null();
} else {
builder.append_value(source.value(idx as usize));
}
}
Ok(Arc::new(builder.finish()))
}
_ => Err(DbxError::SqlExecution {
message: format!(
"Unsupported data type for NULL handling: {:?}",
source_col.data_type()
),
context: "create_column_with_nulls".to_string(),
}),
}
}
impl PhysicalOperator for HashJoinOperator {
fn schema(&self) -> &Schema {
&self.schema
}
fn next(&mut self) -> DbxResult<Option<RecordBatch>> {
if self.done {
return Ok(None);
}
if self.build_table.is_none() {
self.build_phase()?;
}
let build_table = self.build_table.as_ref().unwrap();
let left_batch = match &self.left_batch {
Some(b) => b.clone(),
None => {
self.done = true;
return Ok(None);
}
};
let right_batches = self.right_batches.as_ref().unwrap();
while self.right_batch_idx < right_batches.len() {
let right_batch = &right_batches[self.right_batch_idx];
self.right_batch_idx += 1;
if right_batch.num_rows() == 0 {
continue;
}
let mut left_indices = Vec::new();
let mut right_indices = Vec::new();
let mut matched_left_rows = if matches!(self.join_type, JoinType::Left) {
Some(std::collections::HashSet::new())
} else {
None
};
let mut matched_right_rows = if matches!(self.join_type, JoinType::Right) {
Some(vec![false; right_batch.num_rows()])
} else {
None
};
let right_key_columns: Vec<usize> =
self.on.iter().map(|(_, right_col)| *right_col).collect();
const PARALLEL_THRESHOLD: usize = 1000;
if right_batch.num_rows() >= PARALLEL_THRESHOLD {
use dashmap::DashMap;
let parallel_matches: DashMap<usize, Vec<usize>> = DashMap::new();
(0..right_batch.num_rows())
.into_par_iter()
.for_each(|right_row| {
let key = extract_join_key(right_batch, &right_key_columns, right_row);
if let Some(left_rows) = build_table.get(&key) {
parallel_matches.insert(right_row, left_rows.clone());
}
});
for (right_row, left_rows) in parallel_matches.into_iter() {
for &left_row in &left_rows {
left_indices.push(left_row as u32);
right_indices.push(right_row as u32);
if let Some(ref mut matched) = matched_left_rows {
matched.insert(left_row);
}
if let Some(ref mut matched) = matched_right_rows {
matched[right_row] = true;
}
}
}
} else {
for right_row in 0..right_batch.num_rows() {
let key = extract_join_key(right_batch, &right_key_columns, right_row);
if let Some(left_rows) = build_table.get(&key) {
for &left_row in left_rows {
left_indices.push(left_row as u32);
right_indices.push(right_row as u32);
if let Some(ref mut matched) = matched_left_rows {
matched.insert(left_row);
}
if let Some(ref mut matched) = matched_right_rows {
matched[right_row] = true;
}
}
} else if matches!(self.join_type, JoinType::Right) {
}
}
}
if let Some(matched) = matched_left_rows {
for left_row in 0..left_batch.num_rows() {
if !matched.contains(&left_row) {
left_indices.push(left_row as u32);
right_indices.push(u32::MAX);
}
}
}
if let Some(matched) = matched_right_rows {
for (right_row, &was_matched) in matched.iter().enumerate() {
if !was_matched {
left_indices.push(u32::MAX);
right_indices.push(right_row as u32);
}
}
}
if left_indices.is_empty() {
continue;
}
let mut output_columns: Vec<ArrayRef> = Vec::new();
for col in left_batch.columns() {
let filtered_indices: Vec<u32> = left_indices
.iter()
.filter(|&&idx| idx != u32::MAX)
.copied()
.collect();
if filtered_indices.len() == left_indices.len() {
let left_idx_arr = UInt32Array::from(left_indices.clone());
output_columns.push(compute::take(col.as_ref(), &left_idx_arr, None)?);
} else {
output_columns.push(create_column_with_nulls(col, &left_indices, u32::MAX)?);
}
}
for col in right_batch.columns() {
let filtered_indices: Vec<u32> = right_indices
.iter()
.filter(|&&idx| idx != u32::MAX)
.copied()
.collect();
if filtered_indices.len() == right_indices.len() {
let right_idx_arr = UInt32Array::from(right_indices.clone());
output_columns.push(compute::take(col.as_ref(), &right_idx_arr, None)?);
} else {
output_columns.push(create_column_with_nulls(col, &right_indices, u32::MAX)?);
}
}
let result = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)?;
if result.num_rows() > 0 {
return Ok(Some(result));
}
}
self.done = true;
Ok(None)
}
fn reset(&mut self) -> DbxResult<()> {
self.build_table = None;
self.left_batch = None;
self.done = false;
self.left.reset()?;
self.right.reset()
}
}