use super::Join;
use crate::error::Result;
use crate::executor::scan::{ColumnData, RecordBatch, Schema};
use crate::parser::ast::{BinaryOperator, Expr};
use std::collections::HashMap;
use std::sync::Arc;
impl Join {
    /// Attempts to evaluate this join through the hash-join path.
    ///
    /// The hash path applies only when the `ON` condition is a single `=`
    /// comparison between two column references, both columns can be
    /// resolved (one in `left`, one in `right`), and the two column types
    /// pass `are_types_hash_compatible`.
    ///
    /// Returns `None` when the hash path does not apply (no condition, a
    /// different condition shape, unresolved columns, or incompatible
    /// types). Otherwise returns `Some` wrapping the outcome of
    /// `hash_join_impl`.
    pub(super) fn try_hash_join(
        &self,
        left: &RecordBatch,
        right: &RecordBatch,
    ) -> Option<Result<RecordBatch>> {
        let condition = self.on_condition.as_ref()?;

        // Only a top-level equality qualifies; anything else is rejected.
        let (lhs, rhs) = match condition {
            Expr::BinaryOp {
                left: lhs,
                op: BinaryOperator::Eq,
                right: rhs,
            } => (lhs.as_ref(), rhs.as_ref()),
            _ => return None,
        };

        // Both operands must be plain column references.
        let (left_table, left_col, right_table, right_col) = match (lhs, rhs) {
            (
                Expr::Column {
                    table: lhs_table,
                    name: lhs_name,
                },
                Expr::Column {
                    table: rhs_table,
                    name: rhs_name,
                },
            ) => (lhs_table, lhs_name, rhs_table, rhs_name),
            _ => return None,
        };

        let (left_idx, right_idx) = self.resolve_join_columns(
            left,
            right,
            left_table.as_deref(),
            left_col,
            right_table.as_deref(),
            right_col,
        )?;

        if !self.are_types_hash_compatible(&left.columns[left_idx], &right.columns[right_idx]) {
            return None;
        }

        Some(self.hash_join_impl(left, right, left_idx, right_idx))
    }

    /// Reports whether two key columns may be joined through the hash path.
    ///
    /// Accepted pairs are: identical `Boolean`, `String`, integer or float
    /// column variants, plus the mixed-width pairs `Int32`/`Int64` and
    /// `Float32`/`Float64` in either order. Every other combination yields
    /// `false`.
    ///
    /// NOTE(review): the mixed-width pairs presumably rely on
    /// `to_hash_key` producing the same key for equal values of different
    /// widths; that helper is not visible here, so confirm.
    pub(super) fn are_types_hash_compatible(&self, left: &ColumnData, right: &ColumnData) -> bool {
        matches!(
            (left, right),
            (ColumnData::Boolean(_), ColumnData::Boolean(_))
                | (ColumnData::Int32(_), ColumnData::Int32(_))
                | (ColumnData::Int32(_), ColumnData::Int64(_))
                | (ColumnData::Int64(_), ColumnData::Int32(_))
                | (ColumnData::Int64(_), ColumnData::Int64(_))
                | (ColumnData::Float32(_), ColumnData::Float32(_))
                | (ColumnData::Float32(_), ColumnData::Float64(_))
                | (ColumnData::Float64(_), ColumnData::Float32(_))
                | (ColumnData::Float64(_), ColumnData::Float64(_))
                | (ColumnData::String(_), ColumnData::String(_))
        )
    }

    /// Maps the two column references of an equality condition onto column
    /// indices of the `left` and `right` batches.
    ///
    /// `(table1, col1)` and `(table2, col2)` are the operands in the order
    /// they appear in the condition. The returned pair is always
    /// `(index in left, index in right)`, regardless of operand order.
    ///
    /// Resolution is attempted in this order; the first success wins:
    /// 1. operand 1 in `left` and operand 2 in `right`, honouring aliases;
    /// 2. operand 2 in `left` and operand 1 in `right`, honouring aliases;
    /// 3. `col1` in `left` and `col2` in `right` by bare name, ignoring
    ///    table qualifiers and aliases.
    ///
    /// Returns `None` if none of the three attempts finds both columns.
    pub(super) fn resolve_join_columns(
        &self,
        left: &RecordBatch,
        right: &RecordBatch,
        table1: Option<&str>,
        col1: &str,
        table2: Option<&str>,
        col2: &str,
    ) -> Option<(usize, usize)> {
        let left_alias = self.left_alias.as_deref();
        let right_alias = self.right_alias.as_deref();

        // Attempt 1: operands appear in the same order as the join inputs.
        let in_order = (
            self.find_column_index(left, table1, col1, left_alias),
            self.find_column_index(right, table2, col2, right_alias),
        );
        if let (Some(l), Some(r)) = in_order {
            return Some((l, r));
        }

        // Attempt 2: operands are written the other way round.
        let col1_in_right = self.find_column_index(right, table1, col1, right_alias);
        let col2_in_left = self.find_column_index(left, table2, col2, left_alias);
        if let (Some(l), Some(r)) = (col2_in_left, col1_in_right) {
            return Some((l, r));
        }

        // Attempt 3: bare-name lookup that disregards qualifiers entirely.
        // NOTE(review): only the in-order orientation is tried here, and a
        // qualifier naming neither input is silently ignored — confirm that
        // this is intended.
        match (left.schema.index_of(col1), right.schema.index_of(col2)) {
            (Some(l), Some(r)) => Some((l, r)),
            _ => None,
        }
    }

    /// Looks up `col_name` in `batch`, taking an optional table qualifier
    /// and the batch's optional alias into account.
    ///
    /// * No qualifier: the bare column name is looked up.
    /// * Qualifier and alias both present: the bare name is looked up only
    ///   when the qualifier equals the alias; otherwise the result is `None`.
    /// * Qualifier without alias: the bare name is tried first and, if
    ///   absent, the qualified name `"<table>.<column>"` is tried. The
    ///   qualifier is not checked when the bare name exists.
    pub(super) fn find_column_index(
        &self,
        batch: &RecordBatch,
        table: Option<&str>,
        col_name: &str,
        alias: Option<&str>,
    ) -> Option<usize> {
        match (table, alias) {
            (None, _) => batch.schema.index_of(col_name),
            (Some(qualifier), Some(batch_alias)) => {
                if qualifier == batch_alias {
                    batch.schema.index_of(col_name)
                } else {
                    None
                }
            }
            (Some(qualifier), None) => batch
                .schema
                .index_of(col_name)
                // The qualified name is only built when the bare lookup fails.
                .or_else(|| batch.schema.index_of(&format!("{}.{}", qualifier, col_name))),
        }
    }

    /// Joins `left` and `right` on equality of the given key columns using
    /// a hash table built from the `left` batch.
    ///
    /// Rows whose key value reports `is_null()` are skipped on both sides,
    /// so they never produce a match. The output schema is all `left`
    /// fields followed by all `right` fields. Matches are emitted while
    /// scanning `right` rows in order; for each right row, matching left
    /// rows are visited in ascending row order.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_column_value`, `append_row` or
    /// `RecordBatch::new`.
    ///
    /// NOTE(review): only matching pairs are emitted. Whether callers
    /// restrict this path to inner joins is not visible here — confirm.
    pub(super) fn hash_join_impl(
        &self,
        left: &RecordBatch,
        right: &RecordBatch,
        left_col_idx: usize,
        right_col_idx: usize,
    ) -> Result<RecordBatch> {
        // Build phase: hash key -> left row indices carrying that key.
        let mut build_side: HashMap<String, Vec<usize>> = HashMap::new();
        for row in 0..left.num_rows {
            let value = self.get_column_value(&left.columns[left_col_idx], row)?;
            if value.is_null() {
                continue;
            }
            build_side.entry(value.to_hash_key()).or_default().push(row);
        }

        // Output schema: left fields first, then right fields.
        let result_fields: Vec<_> = left
            .schema
            .fields
            .iter()
            .chain(right.schema.fields.iter())
            .cloned()
            .collect();

        // One empty buffer per output column, to be filled by `append_row`.
        let mut result_columns: Vec<_> = (0..result_fields.len()).map(|_| Vec::new()).collect();

        // Probe phase: look each right-side key up in the build table.
        let mut result_rows = 0;
        for right_row in 0..right.num_rows {
            let value = self.get_column_value(&right.columns[right_col_idx], right_row)?;
            if value.is_null() {
                continue;
            }
            let key = value.to_hash_key();
            if let Some(left_rows) = build_side.get(&key) {
                for &left_row in left_rows {
                    self.append_row(&mut result_columns, left, right, left_row, right_row)?;
                    result_rows += 1;
                }
            }
        }

        let schema = Arc::new(Schema::new(result_fields));
        let columns = self.convert_columns(result_columns);
        RecordBatch::new(schema, columns, result_rows)
    }
}