use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;
use crate::execution_plan::{EmissionType, boundedness_from_children};
use crate::expressions::PhysicalSortExpr;
use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
use crate::joins::utils::{
JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid,
estimate_join_statistics, reorder_output_after_swap,
symmetric_join_output_partitioning,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{
ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
physical_to_column_exprs, update_join_on,
};
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PlanProperties, SendableRecordBatchStream, Statistics, check_if_same_properties,
};
use arrow::compute::SortOptions;
use arrow::datatypes::SchemaRef;
use datafusion_common::{
JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err,
plan_err,
};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
/// Physical plan node that joins two inputs with a sort-merge join on a list
/// of `(left, right)` key expression pairs.
///
/// The node requires each input to be ordered on its join keys (see
/// `required_input_ordering`) and hash-partitioned on them (see
/// `required_input_distribution`).
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
/// Left input plan.
pub left: Arc<dyn ExecutionPlan>,
/// Right input plan.
pub right: Arc<dyn ExecutionPlan>,
/// Join key pairs; the first element of each pair is evaluated against the
/// left input and the second against the right input.
pub on: JoinOn,
/// Optional additional join filter.
pub filter: Option<JoinFilter>,
/// Kind of join to perform.
pub join_type: JoinType,
/// Output schema, derived in `try_new` via `build_join_schema`.
schema: SchemaRef,
/// Execution metrics collected for this node.
metrics: ExecutionPlanMetricsSet,
/// Ordering required from the left input: the left keys of `on` combined
/// with `sort_options`.
left_sort_exprs: LexOrdering,
/// Ordering required from the right input: the right keys of `on` combined
/// with `sort_options`.
right_sort_exprs: LexOrdering,
/// Sort options, one per entry of `on` (length is validated in `try_new`).
pub sort_options: Vec<SortOptions>,
/// How nulls in join keys are compared; forwarded to the join stream.
pub null_equality: NullEquality,
/// Plan properties computed once at construction time.
cache: Arc<PlanProperties>,
}
impl SortMergeJoinExec {
/// Creates a sort-merge join of `left` and `right` on the key pairs in `on`.
///
/// `sort_options` must contain exactly one entry per pair in `on`. Each entry
/// is applied to both the left and the right key of its pair when building
/// the orderings this node requires from its inputs.
///
/// # Errors
///
/// Returns an error when:
/// * `check_join_is_valid` rejects `on` for the two input schemas,
/// * `sort_options.len()` differs from `on.len()`,
/// * `LexOrdering::new` returns `None` for either side's sort expressions,
/// * computing the plan properties fails.
pub fn try_new(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    on: JoinOn,
    filter: Option<JoinFilter>,
    join_type: JoinType,
    sort_options: Vec<SortOptions>,
    null_equality: NullEquality,
) -> Result<Self> {
    let left_schema = left.schema();
    let right_schema = right.schema();
    check_join_is_valid(&left_schema, &right_schema, &on)?;

    let expected_options = on.len();
    let actual_options = sort_options.len();
    if expected_options != actual_options {
        return plan_err!(
            "Expected number of sort options: {}, actual: {}",
            expected_options,
            actual_options
        );
    }

    // Pair every join key with the sort options given for its position.
    let mut left_keys = Vec::with_capacity(expected_options);
    let mut right_keys = Vec::with_capacity(expected_options);
    for ((left_key, right_key), options) in on.iter().zip(&sort_options) {
        left_keys.push(PhysicalSortExpr {
            expr: Arc::clone(left_key),
            options: *options,
        });
        right_keys.push(PhysicalSortExpr {
            expr: Arc::clone(right_key),
            options: *options,
        });
    }

    let left_sort_exprs = match LexOrdering::new(left_keys) {
        Some(ordering) => ordering,
        None => {
            return plan_err!(
                "SortMergeJoinExec requires valid sort expressions for its left side"
            );
        }
    };
    let right_sort_exprs = match LexOrdering::new(right_keys) {
        Some(ordering) => ordering,
        None => {
            return plan_err!(
                "SortMergeJoinExec requires valid sort expressions for its right side"
            );
        }
    };

    let (join_schema, _) = build_join_schema(&left_schema, &right_schema, &join_type);
    let schema = Arc::new(join_schema);
    let properties =
        Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;

    Ok(Self {
        left,
        right,
        on,
        filter,
        join_type,
        schema,
        metrics: ExecutionPlanMetricsSet::new(),
        left_sort_exprs,
        right_sort_exprs,
        sort_options,
        null_equality,
        cache: Arc::new(properties),
    })
}
/// Returns which input acts as the probe side for `join_type`.
///
/// The right-flavoured joins (`Right`, `RightSemi`, `RightAnti`, `RightMark`)
/// probe with the right input; every other join type probes with the left.
/// `execute` streams the probe side and buffers the other one.
pub fn probe_side(join_type: &JoinType) -> JoinSide {
    // Kept exhaustive on purpose so a new `JoinType` variant fails to compile
    // here instead of silently picking a side.
    match *join_type {
        JoinType::Inner
        | JoinType::Left
        | JoinType::Full
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::LeftMark => JoinSide::Left,
        JoinType::Right
        | JoinType::RightSemi
        | JoinType::RightAnti
        | JoinType::RightMark => JoinSide::Right,
    }
}
/// Returns `[left, right]` flags telling which input's ordering the join
/// reports as maintained for `join_type`.
///
/// Inner and left-flavoured joins report the left input, right-flavoured
/// joins report the right input, and any other join type reports neither.
fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
    let (keeps_left, keeps_right) = match join_type {
        JoinType::Inner
        | JoinType::Left
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::LeftMark => (true, false),
        JoinType::Right
        | JoinType::RightSemi
        | JoinType::RightAnti
        | JoinType::RightMark => (false, true),
        _ => (false, false),
    };
    vec![keeps_left, keeps_right]
}
/// Returns the join key pairs as `(left expression, right expression)`.
pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
    &self.on[..]
}
/// Returns the right input plan.
pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
&self.right
}
/// Returns the join type of this node.
pub fn join_type(&self) -> JoinType {
self.join_type
}
/// Returns the left input plan.
pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
&self.left
}
/// Returns the optional join filter, or `None` when the join has no filter.
pub fn filter(&self) -> &Option<JoinFilter> {
&self.filter
}
/// Returns the sort options, one per join key pair in `on`.
pub fn sort_options(&self) -> &[SortOptions] {
    self.sort_options.as_slice()
}
/// Returns how nulls in join keys are compared by this join.
pub fn null_equality(&self) -> NullEquality {
self.null_equality
}
/// Derives the `PlanProperties` cached on the node.
///
/// Equivalence properties come from `join_equivalence_properties`, fed with
/// both inputs' equivalence properties, the maintained-order flags and the
/// probe side for `join_type`. Output partitioning comes from
/// `symmetric_join_output_partitioning`. The emission type is always
/// `EmissionType::Incremental`, and boundedness is derived from both
/// children.
///
/// # Errors
///
/// Propagates failures from `join_equivalence_properties` and
/// `symmetric_join_output_partitioning`.
fn compute_properties(
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    schema: SchemaRef,
    join_type: JoinType,
    join_on: JoinOnRef,
) -> Result<PlanProperties> {
    let left_equivalences = left.equivalence_properties().clone();
    let right_equivalences = right.equivalence_properties().clone();
    let preserved_orderings = Self::maintains_input_order(join_type);
    let probe = Self::probe_side(&join_type);

    let eq_properties = join_equivalence_properties(
        left_equivalences,
        right_equivalences,
        &join_type,
        schema,
        &preserved_orderings,
        Some(probe),
        join_on,
    )?;

    let output_partitioning =
        symmetric_join_output_partitioning(left, right, &join_type)?;
    let boundedness = boundedness_from_children([left, right]);

    Ok(PlanProperties::new(
        eq_properties,
        output_partitioning,
        EmissionType::Incremental,
        boundedness,
    ))
}
/// Builds a plan in which the left and right inputs trade places.
///
/// Each key pair in `on` is flipped, the filter (if any) goes through
/// `JoinFilter::swap`, and the join type goes through `JoinType::swap`. Sort
/// options and null equality are carried over unchanged.
///
/// For semi and anti joins the swapped join is returned directly. Every other
/// join type is passed through `reorder_output_after_swap` together with the
/// original left and right schemas (presumably to restore the original
/// column order; confirm against that helper).
///
/// NOTE(review): `LeftMark` / `RightMark` take the reorder path here; confirm
/// that is intended for mark joins.
///
/// # Errors
///
/// Propagates failures from `try_new` and `reorder_output_after_swap`.
pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
    let flipped_on: Vec<_> = self
        .on
        .iter()
        .map(|(left_key, right_key)| (Arc::clone(right_key), Arc::clone(left_key)))
        .collect();
    let flipped_filter = self.filter.as_ref().map(JoinFilter::swap);

    let swapped: Arc<dyn ExecutionPlan> = Arc::new(SortMergeJoinExec::try_new(
        Arc::clone(&self.right),
        Arc::clone(&self.left),
        flipped_on,
        flipped_filter,
        self.join_type.swap(),
        self.sort_options.clone(),
        self.null_equality,
    )?);

    match self.join_type {
        JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti => Ok(swapped),
        _ => reorder_output_after_swap(
            swapped,
            &self.left.schema(),
            &self.right.schema(),
        ),
    }
}
/// Builds a copy of this node that uses `children` as its inputs.
///
/// All other fields, including the cached `PlanProperties`, are cloned from
/// `self`, except `metrics`, which is replaced by a fresh set.
///
/// Expects `children` to be exactly `[left, right]`.
///
/// # Panics
///
/// Panics (inside `swap_remove`) if fewer than two children are supplied.
fn with_new_children_and_same_properties(
&self,
mut children: Vec<Arc<dyn ExecutionPlan>>,
) -> Self {
// With two elements, `swap_remove(0)` returns the first child and moves the
// second into slot 0, so the following call returns the second child.
let left = children.swap_remove(0);
let right = children.swap_remove(0);
Self {
left,
right,
metrics: ExecutionPlanMetricsSet::new(),
..Self::clone(self)
}
}
}
impl DisplayAs for SortMergeJoinExec {
/// Writes a description of this join to `f` in the requested format.
///
/// * `Default` / `Verbose`: a single line of the form
///   `<name>: join_type=<type>, on=[(l, r), ...]`, followed by
///   `, filter=<expr>` when a filter is present and `, NullsEqual: true`
///   when null equality is `NullEqualsNull`.
/// * `TreeRender`: one line per property. The `join_type` line is skipped
///   for inner joins, and the `NullsEqual` line is written only when null
///   equality is `NullEqualsNull`.
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
    let nulls_equal = self.null_equality == NullEquality::NullEqualsNull;

    match t {
        DisplayFormatType::Default | DisplayFormatType::Verbose => {
            let key_pairs: Vec<String> = self
                .on
                .iter()
                .map(|(left_key, right_key)| format!("({left_key}, {right_key})"))
                .collect();
            let on = key_pairs.join(", ");

            let filter_part = match &self.filter {
                Some(filter) => format!(", filter={}", filter.expression()),
                None => String::new(),
            };
            let nulls_part = if nulls_equal {
                ", NullsEqual: true"
            } else {
                ""
            };

            write!(
                f,
                "{}: join_type={:?}, on=[{}]{}{}",
                Self::static_name(),
                self.join_type,
                on,
                filter_part,
                nulls_part,
            )
        }
        DisplayFormatType::TreeRender => {
            let key_pairs: Vec<String> = self
                .on
                .iter()
                .map(|(left_key, right_key)| {
                    format!(
                        "({} = {})",
                        fmt_sql(left_key.as_ref()),
                        fmt_sql(right_key.as_ref())
                    )
                })
                .collect();
            let on = key_pairs.join(", ");

            // Inner is treated as the default and is not printed.
            if !matches!(self.join_type, JoinType::Inner) {
                writeln!(f, "join_type={:?}", self.join_type)?;
            }
            writeln!(f, "on={on}")?;
            if nulls_equal {
                writeln!(f, "NullsEqual: true")?;
            }
            Ok(())
        }
    }
}
}
impl ExecutionPlan for SortMergeJoinExec {
/// Returns the operator name, `"SortMergeJoinExec"`.
fn name(&self) -> &'static str {
"SortMergeJoinExec"
}
/// Returns `self` as `&dyn Any` so callers can downcast to the concrete type.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the plan properties that were computed and cached when the node
/// was built.
fn properties(&self) -> &Arc<PlanProperties> {
&self.cache
}
/// Requires each input to be hash-partitioned on its side of the join keys:
/// the left input on the left expressions of `on`, the right input on the
/// right expressions.
fn required_input_distribution(&self) -> Vec<Distribution> {
    let left_keys: Vec<_> = self.on.iter().map(|(key, _)| Arc::clone(key)).collect();
    let right_keys: Vec<_> = self.on.iter().map(|(_, key)| Arc::clone(key)).collect();
    vec![
        Distribution::HashPartitioned(left_keys),
        Distribution::HashPartitioned(right_keys),
    ]
}
/// Requires each input to be ordered by the sort expressions built for its
/// side in `try_new`; the result is `[left requirement, right requirement]`.
fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
    [&self.left_sort_exprs, &self.right_sort_exprs]
        .into_iter()
        .map(|ordering| Some(OrderingRequirements::from(ordering.clone())))
        .collect()
}
/// Returns `[left, right]` flags telling which input's ordering is reported
/// as maintained, delegating to the inherent per-join-type function.
fn maintains_input_order(&self) -> Vec<bool> {
Self::maintains_input_order(self.join_type)
}
/// Returns the two inputs in `[left, right]` order.
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.left, &self.right]
}
/// Returns a new join over `children`, which must be exactly `[left, right]`.
///
/// `check_if_same_properties!` runs first and may return early (presumably
/// reusing the cached properties when the children's are unchanged; confirm
/// against the macro). Otherwise the node is rebuilt through `try_new` with
/// the same keys, filter, join type, sort options and null equality.
///
/// # Errors
///
/// Returns an internal error when `children` does not hold exactly two
/// plans, and propagates any error from `try_new`.
fn with_new_children(
    self: Arc<Self>,
    children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
    check_if_same_properties!(self, children);

    let [new_left, new_right] = &children[..] else {
        return internal_err!("SortMergeJoin wrong number of children");
    };

    Ok(Arc::new(SortMergeJoinExec::try_new(
        Arc::clone(new_left),
        Arc::clone(new_right),
        self.on.clone(),
        self.filter.clone(),
        self.join_type,
        self.sort_options.clone(),
        self.null_equality,
    )?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let left_partitions = self.left.output_partitioning().partition_count();
let right_partitions = self.right.output_partitioning().partition_count();
assert_eq_or_internal_err!(
left_partitions,
right_partitions,
"Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
consider using RepartitionExec"
);
let (on_left, on_right) = self.on.iter().cloned().unzip();
let (streamed, buffered, on_streamed, on_buffered) =
if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
(
Arc::clone(&self.left),
Arc::clone(&self.right),
on_left,
on_right,
)
} else {
(
Arc::clone(&self.right),
Arc::clone(&self.left),
on_right,
on_left,
)
};
let streamed = streamed.execute(partition, Arc::clone(&context))?;
let buffered = buffered.execute(partition, Arc::clone(&context))?;
let batch_size = context.session_config().batch_size();
let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
.register(context.memory_pool());
Ok(Box::pin(SortMergeJoinStream::try_new(
context.session_config().spill_compression(),
Arc::clone(&self.schema),
self.sort_options.clone(),
self.null_equality,
streamed,
buffered,
on_streamed,
on_buffered,
self.filter.clone(),
self.join_type,
batch_size,
SortMergeJoinMetrics::new(partition, &self.metrics),
reservation,
context.runtime_env(),
)?))
}
/// Returns the metrics recorded for this node, obtained through
/// `clone_inner`; always `Some`.
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
/// Estimates output statistics for `partition`.
///
/// Statistics are fetched from the left and then the right input using the
/// same `partition` argument, and combined by `estimate_join_statistics`
/// with the join keys, join type and output schema.
///
/// # Errors
///
/// Propagates errors from either child or from the estimator.
fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
    let left_stats = self.left.partition_statistics(partition)?;
    let right_stats = self.right.partition_statistics(partition)?;
    estimate_join_statistics(
        left_stats,
        right_stats,
        &self.on,
        &self.join_type,
        &self.schema,
    )
}
/// Tries to push `projection` below this join.
///
/// Returns `Ok(None)` and leaves the plan untouched when:
/// * the projection cannot be expressed as columns
///   (`physical_to_column_exprs` returns `None`),
/// * `join_allows_pushdown` rejects it, or
/// * the join keys cannot be rewritten (`update_join_on` returns `None`).
///
/// Otherwise returns a new `SortMergeJoinExec` over the children produced by
/// `new_join_children`, using the rewritten keys and the same join type,
/// sort options and null equality.
///
/// NOTE(review): the filter is cloned as-is; confirm its column references
/// remain valid against the projected children.
///
/// # Errors
///
/// Propagates failures from `new_join_children` and `try_new`.
fn try_swapping_with_projection(
    &self,
    projection: &ProjectionExec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let projection_as_columns = match physical_to_column_exprs(projection.expr()) {
        Some(columns) => columns,
        None => return Ok(None),
    };

    let left_field_count = self.left.schema().fields().len();
    let (far_right_left_col_ind, far_left_right_col_ind) =
        join_table_borders(left_field_count, &projection_as_columns);

    let pushdown_allowed = join_allows_pushdown(
        &projection_as_columns,
        &self.schema(),
        far_right_left_col_ind,
        far_left_right_col_ind,
    );
    if !pushdown_allowed {
        return Ok(None);
    }

    // Split the projected columns at the two borders found above.
    let left_projection = &projection_as_columns[0..=far_right_left_col_ind as _];
    let right_projection = &projection_as_columns[far_left_right_col_ind as _..];
    let new_on = match update_join_on(
        left_projection,
        right_projection,
        self.on(),
        left_field_count,
    ) {
        Some(rewritten) => rewritten,
        None => return Ok(None),
    };

    let (new_left, new_right) = new_join_children(
        &projection_as_columns,
        far_right_left_col_ind,
        far_left_right_col_ind,
        &self.left,
        &self.right,
    )?;

    let pushed_down = SortMergeJoinExec::try_new(
        Arc::new(new_left),
        Arc::new(new_right),
        new_on,
        self.filter.clone(),
        self.join_type,
        self.sort_options.clone(),
        self.null_equality,
    )?;
    Ok(Some(Arc::new(pushed_down)))
}
}