use datafusion::common::Result;
use datafusion::logical_expr::JoinType;
use datafusion::physical_expr::expressions::Column as PhysicalColumn;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
use datafusion_common::{DataFusionError, NullEquality};
use std::sync::Arc;
/// Builds an inner equi-join between `left` and `right`, keyed on every column of `left`.
///
/// Each field of the left schema is treated as a key column and is paired, by name,
/// with the column of the same name in the right schema.
///
/// A [`SortMergeJoinExec`] is produced only when both inputs report the same output
/// ordering *and* that ordering lists exactly the join keys, in join-key order, on
/// both sides. In every other case the function falls back to a [`HashJoinExec`] in
/// [`PartitionMode::CollectLeft`]. Both join operators are created with
/// `NullEquality::NullEqualsNull`.
///
/// # Errors
///
/// Returns [`DataFusionError::Plan`] when a column of the left schema has no
/// same-named column in the right schema, and propagates any error raised by the
/// join operator constructors.
pub fn try_create_index_lookup_join(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let left_schema = left.schema();
    let right_schema = right.schema();

    // Pair every left column with the right column that carries the same name.
    let join_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = left_schema
        .fields()
        .iter()
        .enumerate()
        .map(|(i, field)| {
            let right_idx = right_schema.index_of(field.name()).map_err(|_| {
                DataFusionError::Plan(format!(
                    "PK column '{}' not found in right join schema: {:?}",
                    field.name(),
                    right_schema
                ))
            })?;
            Ok((
                Arc::new(PhysicalColumn::new(field.name(), i)) as Arc<dyn PhysicalExpr>,
                Arc::new(PhysicalColumn::new(field.name(), right_idx)) as Arc<dyn PhysicalExpr>,
            ))
        })
        .collect::<Result<Vec<_>>>()?;

    // A sort-merge join is only valid when both inputs are sorted on the join keys
    // themselves. Equal orderings alone are not enough: the ordering may cover a
    // different number of columns (which `SortMergeJoinExec::try_new` rejects) or
    // list the keys in a different order (which would merge unsorted streams).
    let merge_sort_options = match (
        left.properties().output_ordering(),
        right.properties().output_ordering(),
    ) {
        (Some(left_ord), Some(right_ord)) if left_ord.eq(right_ord) => {
            let sort_options: Vec<_> = left_ord.iter().map(|e| e.options).collect();
            let ordered_on_join_keys = sort_options.len() == join_on.len()
                && left_ord
                    .iter()
                    .zip(right_ord.iter())
                    .zip(join_on.iter())
                    .all(|((left_sort, right_sort), (left_key, right_key))| {
                        left_sort.expr.eq(left_key) && right_sort.expr.eq(right_key)
                    });
            ordered_on_join_keys.then_some(sort_options)
        }
        _ => None,
    };

    match merge_sort_options {
        Some(sort_options) => Ok(Arc::new(SortMergeJoinExec::try_new(
            left,
            right,
            join_on,
            None,
            JoinType::Inner,
            sort_options,
            NullEquality::NullEqualsNull,
        )?)),
        // Inputs are not usable for a merge: build a hash table from the left side.
        None => Ok(Arc::new(HashJoinExec::try_new(
            left,
            right,
            join_on,
            None,
            &JoinType::Inner,
            None,
            PartitionMode::CollectLeft,
            NullEquality::NullEqualsNull,
        )?)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physical_plan::create_plan_properties_for_pk_scan;
    use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::common::Statistics;
    use datafusion::execution::context::TaskContext;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
    use std::any::Any;
    use std::fmt;

    /// Leaf plan stub used only for planning-time checks.
    ///
    /// It exposes a schema with a single non-nullable `UInt64` column named `id` and
    /// the plan properties produced by `create_plan_properties_for_pk_scan`. It is
    /// never executed: the execution-related methods are left unimplemented.
    #[derive(Debug)]
    struct MockExec {
        plan_properties: PlanProperties,
        schema: SchemaRef,
    }

    impl MockExec {
        /// Creates the stub, forwarding `ordered` to the plan-properties helper.
        fn new(ordered: bool) -> Self {
            let id_field = Field::new("id", DataType::UInt64, false);
            let schema: SchemaRef = Arc::new(Schema::new(vec![id_field]));
            Self {
                plan_properties: create_plan_properties_for_pk_scan(Arc::clone(&schema), ordered),
                schema,
            }
        }
    }

    impl DisplayAs for MockExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("MockExec")
        }
    }

    impl ExecutionPlan for MockExec {
        fn name(&self) -> &str {
            "MockExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn properties(&self) -> &PlanProperties {
            &self.plan_properties
        }

        // The stub is a leaf: it has no inputs.
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            Vec::new()
        }

        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        fn statistics(&self) -> Result<Statistics> {
            unimplemented!()
        }
    }

    /// Wraps a fresh `MockExec` as a join input.
    fn mock_input(ordered: bool) -> Arc<dyn ExecutionPlan> {
        Arc::new(MockExec::new(ordered))
    }

    #[test]
    fn test_uses_sort_merge_join_for_ordered_inputs() -> Result<()> {
        let join_plan = try_create_index_lookup_join(mock_input(true), mock_input(true))?;
        assert_eq!(join_plan.name(), "SortMergeJoinExec");
        Ok(())
    }

    #[test]
    fn test_uses_hash_join_for_unordered_inputs() -> Result<()> {
        // (left ordered, right ordered): only the right is unordered, then both are.
        for (left_ordered, right_ordered) in [(true, false), (false, false)] {
            let join_plan = try_create_index_lookup_join(
                mock_input(left_ordered),
                mock_input(right_ordered),
            )?;
            assert_eq!(join_plan.name(), "HashJoinExec");
        }
        Ok(())
    }
}