use std::fmt;
use std::sync::Arc;
use crate::ExecutionPlan;
use crate::ExecutionPlanProperties;
use crate::joins::Map;
use crate::joins::PartitionMode;
use crate::joins::hash_join::exec::HASH_JOIN_SEED;
use crate::joins::hash_join::inlist_builder::build_struct_fields;
use crate::joins::hash_join::partitioned_hash_eval::{
HashExpr, HashTableLookupExpr, SeededRandomState,
};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{DataFusionError, Result, ScalarValue, SharedResult};
use datafusion_expr::Operator;
use datafusion_functions::core::r#struct as struct_func;
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, DynamicFilterPhysicalExpr, InListExpr, lit,
};
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, ScalarFunctionExpr};
use parking_lot::Mutex;
use tokio::sync::Notify;
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct ColumnBounds {
pub(crate) min: ScalarValue,
pub(crate) max: ScalarValue,
}
impl ColumnBounds {
pub(crate) fn new(min: ScalarValue, max: ScalarValue) -> Self {
Self { min, max }
}
}
#[derive(Debug, Clone)]
pub(crate) struct PartitionBounds {
column_bounds: Vec<ColumnBounds>,
}
impl PartitionBounds {
pub(crate) fn new(column_bounds: Vec<ColumnBounds>) -> Self {
Self { column_bounds }
}
pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> {
self.column_bounds.get(index)
}
}
fn create_membership_predicate(
on_right: &[PhysicalExprRef],
pushdown: PushdownStrategy,
random_state: &SeededRandomState,
schema: &Schema,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
match pushdown {
PushdownStrategy::InList(in_list_array) => {
let expr = if on_right.len() == 1 {
Arc::clone(&on_right[0])
} else {
let fields = build_struct_fields(
on_right
.iter()
.map(|r| r.data_type(schema))
.collect::<Result<Vec<_>>>()?
.as_ref(),
)?;
let return_field =
Arc::new(Field::new("struct", DataType::Struct(fields), true));
Arc::new(ScalarFunctionExpr::new(
"struct",
struct_func(),
on_right.to_vec(),
return_field,
Arc::new(ConfigOptions::default()),
)) as Arc<dyn PhysicalExpr>
};
Ok(Some(Arc::new(InListExpr::try_new_from_array(
expr,
in_list_array,
false,
schema,
)?)))
}
PushdownStrategy::Map(hash_map) => Ok(Some(Arc::new(HashTableLookupExpr::new(
on_right.to_vec(),
random_state.clone(),
hash_map,
"hash_lookup".to_string(),
)) as Arc<dyn PhysicalExpr>)),
PushdownStrategy::Empty => Ok(None),
}
}
fn create_bounds_predicate(
on_right: &[PhysicalExprRef],
bounds: &PartitionBounds,
) -> Option<Arc<dyn PhysicalExpr>> {
let mut column_predicates = Vec::new();
for (col_idx, right_expr) in on_right.iter().enumerate() {
if let Some(column_bounds) = bounds.get_column_bounds(col_idx) {
let min_expr = Arc::new(BinaryExpr::new(
Arc::clone(right_expr),
Operator::GtEq,
lit(column_bounds.min.clone()),
)) as Arc<dyn PhysicalExpr>;
let max_expr = Arc::new(BinaryExpr::new(
Arc::clone(right_expr),
Operator::LtEq,
lit(column_bounds.max.clone()),
)) as Arc<dyn PhysicalExpr>;
let range_expr = Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
as Arc<dyn PhysicalExpr>;
column_predicates.push(range_expr);
}
}
if column_predicates.is_empty() {
None
} else {
Some(
column_predicates
.into_iter()
.reduce(|acc, pred| {
Arc::new(BinaryExpr::new(acc, Operator::And, pred))
as Arc<dyn PhysicalExpr>
})
.unwrap(),
)
}
}
fn combine_membership_and_bounds(
membership_expr: Option<Arc<dyn PhysicalExpr>>,
bounds_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn PhysicalExpr>> {
match (membership_expr, bounds_expr) {
(Some(membership), Some(bounds)) => {
Some(Arc::new(BinaryExpr::new(bounds, Operator::And, membership))
as Arc<dyn PhysicalExpr>)
}
(Some(membership), None) => Some(membership),
(None, Some(bounds)) => Some(bounds),
(None, None) => None,
}
}
pub(crate) struct SharedBuildAccumulator {
inner: Mutex<AccumulatorState>,
completion_notify: Notify,
dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
on_right: Vec<PhysicalExprRef>,
repartition_random_state: SeededRandomState,
probe_schema: Arc<Schema>,
}
#[derive(Clone)]
pub(crate) enum PushdownStrategy {
InList(ArrayRef),
Map(Arc<Map>),
Empty,
}
pub(crate) enum PartitionBuildData {
Partitioned {
partition_id: usize,
pushdown: PushdownStrategy,
bounds: PartitionBounds,
},
CollectLeft {
pushdown: PushdownStrategy,
bounds: PartitionBounds,
},
}
#[derive(Clone)]
struct PartitionData {
bounds: PartitionBounds,
pushdown: PushdownStrategy,
}
enum AccumulatedBuildData {
Partitioned {
partitions: Vec<PartitionStatus>,
completed_partitions: usize,
},
CollectLeft {
data: PartitionStatus,
reported_count: usize,
expected_reports: usize,
},
}
enum CompletionState {
Pending,
Finalizing,
Ready(SharedResult<()>),
}
struct AccumulatorState {
data: AccumulatedBuildData,
completion: CompletionState,
}
#[derive(Clone)]
enum PartitionStatus {
Pending,
Reported(PartitionData),
CanceledUnknown,
}
#[derive(Clone)]
enum FinalizeInput {
Partitioned(Vec<PartitionStatus>),
CollectLeft(PartitionStatus),
}
impl SharedBuildAccumulator {
pub(crate) fn new_from_partition_mode(
partition_mode: PartitionMode,
left_child: &dyn ExecutionPlan,
right_child: &dyn ExecutionPlan,
dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
on_right: Vec<PhysicalExprRef>,
repartition_random_state: SeededRandomState,
) -> Self {
let expected_calls = match partition_mode {
PartitionMode::CollectLeft => {
right_child.output_partitioning().partition_count()
}
PartitionMode::Partitioned => {
left_child.output_partitioning().partition_count()
}
PartitionMode::Auto => unreachable!(
"PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"
),
};
let mode_data = match partition_mode {
PartitionMode::Partitioned => AccumulatedBuildData::Partitioned {
partitions: vec![
PartitionStatus::Pending;
left_child.output_partitioning().partition_count()
],
completed_partitions: 0,
},
PartitionMode::CollectLeft => AccumulatedBuildData::CollectLeft {
data: PartitionStatus::Pending,
reported_count: 0,
expected_reports: expected_calls,
},
PartitionMode::Auto => unreachable!(
"PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"
),
};
Self {
inner: Mutex::new(AccumulatorState {
data: mode_data,
completion: CompletionState::Pending,
}),
completion_notify: Notify::new(),
dynamic_filter,
on_right,
repartition_random_state,
probe_schema: right_child.schema(),
}
}
pub(crate) async fn report_build_data(&self, data: PartitionBuildData) -> Result<()> {
let finalize_input = {
let mut guard = self.inner.lock();
self.store_build_data(&mut guard, data)?;
self.take_finalize_input_if_ready(&mut guard)
};
if let Some(finalize_input) = finalize_input {
self.finish(finalize_input);
}
self.wait_for_completion().await
}
pub(crate) fn report_canceled_partition(&self, partition_id: usize) {
let finalize_input = {
let mut guard = self.inner.lock();
self.store_canceled_partition(&mut guard, partition_id);
self.take_finalize_input_if_ready(&mut guard)
};
if let Some(finalize_input) = finalize_input {
self.finish(finalize_input);
}
}
fn store_build_data(
&self,
guard: &mut AccumulatorState,
data: PartitionBuildData,
) -> Result<()> {
match (data, &mut guard.data) {
(
PartitionBuildData::Partitioned {
partition_id,
pushdown,
bounds,
},
AccumulatedBuildData::Partitioned {
partitions,
completed_partitions,
},
) => {
if matches!(partitions[partition_id], PartitionStatus::Pending) {
*completed_partitions += 1;
}
partitions[partition_id] =
PartitionStatus::Reported(PartitionData { pushdown, bounds });
}
(
PartitionBuildData::CollectLeft { pushdown, bounds },
AccumulatedBuildData::CollectLeft {
data,
reported_count,
..
},
) => {
if matches!(data, PartitionStatus::Pending) {
*data = PartitionStatus::Reported(PartitionData { pushdown, bounds });
}
*reported_count += 1;
}
_ => {
return datafusion_common::internal_err!(
"Build data mode mismatch in report_build_data"
);
}
}
Ok(())
}
fn store_canceled_partition(
&self,
guard: &mut AccumulatorState,
partition_id: usize,
) {
if let AccumulatedBuildData::Partitioned {
partitions,
completed_partitions,
} = &mut guard.data
&& matches!(partitions[partition_id], PartitionStatus::Pending)
{
partitions[partition_id] = PartitionStatus::CanceledUnknown;
*completed_partitions += 1;
}
}
fn take_finalize_input_if_ready(
&self,
guard: &mut AccumulatorState,
) -> Option<FinalizeInput> {
if !matches!(guard.completion, CompletionState::Pending) {
return None;
}
let finalize_input = match &guard.data {
AccumulatedBuildData::Partitioned {
partitions,
completed_partitions,
} if *completed_partitions == partitions.len() => {
Some(FinalizeInput::Partitioned(partitions.clone()))
}
AccumulatedBuildData::CollectLeft {
data,
reported_count,
expected_reports,
} if *reported_count == *expected_reports => {
Some(FinalizeInput::CollectLeft(data.clone()))
}
_ => None,
}?;
guard.completion = CompletionState::Finalizing;
Some(finalize_input)
}
fn finish(&self, finalize_input: FinalizeInput) {
let result = self.build_filter(finalize_input).map_err(Arc::new);
self.dynamic_filter.mark_complete();
let mut guard = self.inner.lock();
guard.completion = CompletionState::Ready(result);
drop(guard);
self.completion_notify.notify_waiters();
}
async fn wait_for_completion(&self) -> Result<()> {
loop {
let notified = {
let guard = self.inner.lock();
match &guard.completion {
CompletionState::Ready(Ok(())) => return Ok(()),
CompletionState::Ready(Err(err)) => {
return Err(DataFusionError::Shared(Arc::clone(err)));
}
CompletionState::Pending | CompletionState::Finalizing => {
self.completion_notify.notified()
}
}
};
notified.await;
}
}
fn build_filter(&self, finalize_input: FinalizeInput) -> Result<()> {
match finalize_input {
FinalizeInput::CollectLeft(partition) => match partition {
PartitionStatus::Reported(partition_data) => {
let membership_expr = create_membership_predicate(
&self.on_right,
partition_data.pushdown.clone(),
&HASH_JOIN_SEED,
self.probe_schema.as_ref(),
)?;
let bounds_expr =
create_bounds_predicate(&self.on_right, &partition_data.bounds);
if let Some(filter_expr) =
combine_membership_and_bounds(membership_expr, bounds_expr)
{
self.dynamic_filter.update(filter_expr)?;
}
}
PartitionStatus::Pending => {
return datafusion_common::internal_err!(
"attempted to finalize collect-left dynamic filter without reported build data"
);
}
PartitionStatus::CanceledUnknown => {
return datafusion_common::internal_err!(
"collect-left dynamic filter cannot finalize with canceled build data"
);
}
},
FinalizeInput::Partitioned(partitions) => {
let num_partitions = partitions.len();
let routing_hash_expr = Arc::new(HashExpr::new(
self.on_right.clone(),
self.repartition_random_state.clone(),
"hash_repartition".to_string(),
)) as Arc<dyn PhysicalExpr>;
let modulo_expr = Arc::new(BinaryExpr::new(
routing_hash_expr,
Operator::Modulo,
lit(ScalarValue::UInt64(Some(num_partitions as u64))),
)) as Arc<dyn PhysicalExpr>;
let mut real_branches = Vec::new();
let mut empty_partition_ids = Vec::new();
let mut has_canceled_unknown = false;
for (partition_id, partition) in partitions.iter().enumerate() {
match partition {
PartitionStatus::Reported(partition)
if matches!(partition.pushdown, PushdownStrategy::Empty) =>
{
empty_partition_ids.push(partition_id);
}
PartitionStatus::Reported(partition) => {
let membership_expr = create_membership_predicate(
&self.on_right,
partition.pushdown.clone(),
&HASH_JOIN_SEED,
self.probe_schema.as_ref(),
)?;
let bounds_expr = create_bounds_predicate(
&self.on_right,
&partition.bounds,
);
let then_expr = combine_membership_and_bounds(
membership_expr,
bounds_expr,
)
.unwrap_or_else(|| lit(true));
real_branches.push((
lit(ScalarValue::UInt64(Some(partition_id as u64))),
then_expr,
));
}
PartitionStatus::CanceledUnknown => {
has_canceled_unknown = true;
}
PartitionStatus::Pending => {
return datafusion_common::internal_err!(
"attempted to finalize dynamic filter with pending partition"
);
}
}
}
let filter_expr = if has_canceled_unknown {
let mut when_then_branches = empty_partition_ids
.into_iter()
.map(|partition_id| {
(
lit(ScalarValue::UInt64(Some(partition_id as u64))),
lit(false),
)
})
.collect::<Vec<_>>();
when_then_branches.extend(real_branches);
if when_then_branches.is_empty() {
lit(true)
} else {
Arc::new(CaseExpr::try_new(
Some(modulo_expr),
when_then_branches,
Some(lit(true)),
)?) as Arc<dyn PhysicalExpr>
}
} else if real_branches.is_empty() {
lit(false)
} else if real_branches.len() == 1
&& empty_partition_ids.len() + 1 == num_partitions
{
Arc::clone(&real_branches[0].1)
} else {
Arc::new(CaseExpr::try_new(
Some(modulo_expr),
real_branches,
Some(lit(false)),
)?) as Arc<dyn PhysicalExpr>
};
self.dynamic_filter.update(filter_expr)?;
}
}
Ok(())
}
}
impl fmt::Debug for SharedBuildAccumulator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SharedBuildAccumulator")
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_partitioned_accumulator(num_partitions: usize) -> SharedBuildAccumulator {
let probe_schema = Arc::new(Schema::new(vec![Field::new(
"probe_key",
DataType::Int32,
false,
)]));
let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true)));
SharedBuildAccumulator {
inner: Mutex::new(AccumulatorState {
data: AccumulatedBuildData::Partitioned {
partitions: vec![PartitionStatus::Pending; num_partitions],
completed_partitions: 0,
},
completion: CompletionState::Pending,
}),
completion_notify: Notify::new(),
dynamic_filter,
on_right: vec![],
repartition_random_state: SeededRandomState::with_seed(1),
probe_schema,
}
}
fn partitioned_state(acc: &SharedBuildAccumulator) -> (Vec<PartitionStatus>, usize) {
let guard = acc.inner.lock();
let AccumulatedBuildData::Partitioned {
partitions,
completed_partitions,
} = &guard.data
else {
panic!("expected partitioned accumulator");
};
(partitions.clone(), *completed_partitions)
}
#[test]
fn report_canceled_partition_is_noop_after_report() {
let acc = make_partitioned_accumulator(2);
{
let mut guard = acc.inner.lock();
acc.store_build_data(
&mut guard,
PartitionBuildData::Partitioned {
partition_id: 0,
pushdown: PushdownStrategy::Empty,
bounds: PartitionBounds::new(vec![]),
},
)
.unwrap();
}
let (partitions, completed) = partitioned_state(&acc);
assert!(matches!(partitions[0], PartitionStatus::Reported(_)));
assert_eq!(completed, 1);
acc.report_canceled_partition(0);
let (partitions, completed) = partitioned_state(&acc);
assert!(
matches!(partitions[0], PartitionStatus::Reported(_)),
"late cancel must not overwrite a prior Reported status"
);
assert_eq!(completed, 1, "late cancel must not double-count completion");
}
#[test]
fn report_canceled_partition_marks_pending_partition_canceled() {
let acc = make_partitioned_accumulator(2);
acc.report_canceled_partition(0);
let (partitions, completed) = partitioned_state(&acc);
assert!(matches!(partitions[0], PartitionStatus::CanceledUnknown));
assert_eq!(completed, 1);
acc.report_canceled_partition(0);
let (partitions, completed) = partitioned_state(&acc);
assert!(matches!(partitions[0], PartitionStatus::CanceledUnknown));
assert_eq!(completed, 1);
}
}