use std::collections::{HashMap, VecDeque};
use std::mem::size_of;
use std::sync::Arc;
use crate::joins::MapOffset;
use crate::joins::join_hash_map::{
contain_hashes, get_matched_indices, get_matched_indices_with_limit_offset,
update_from_iter,
};
use crate::joins::utils::{JoinFilter, JoinHashMapType};
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
use crate::{ExecutionPlan, metrics};
use arrow::array::{
ArrowPrimitiveType, BooleanArray, BooleanBufferBuilder, NativeAdapter,
PrimitiveArray, RecordBatch,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::{HashSet, JoinSide, Result, ScalarValue, arrow_datafusion_err};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use hashbrown::HashTable;
// `JoinHashMapType` facade for [`PruningJoinHashMap`]: every operation is
// forwarded to the free helper functions from `join_hash_map`, adapting the
// `VecDeque`-backed `next` chain to their slice-based signatures.
impl JoinHashMapType for PruningJoinHashMap {
    /// Appends `len` zero ("no further match") entries to the `next` chain.
    fn extend_zero(&mut self, len: usize) {
        self.next.resize(self.next.len() + len, 0)
    }
    /// Inserts `(row index, hash)` pairs from `iter` into the map and chain.
    /// `make_contiguous` may move elements so the helper can receive a single
    /// `&mut [u64]` slice even if the deque has wrapped around.
    fn update_from_iter<'a>(
        &mut self,
        iter: Box<dyn Iterator<Item = (usize, &'a u64)> + Send + 'a>,
        deleted_offset: usize,
    ) {
        let slice: &mut [u64] = self.next.make_contiguous();
        update_from_iter::<u64>(&mut self.map, slice, iter, deleted_offset);
    }
    /// Collects matched (probe, build) index pairs for the hashes in `iter`.
    // NOTE(review): copies the whole `next` deque into a fresh `Vec` on every
    // call because `&self` rules out `make_contiguous`; confirm this per-call
    // O(n) copy is acceptable for the expected buffer sizes.
    fn get_matched_indices<'a>(
        &self,
        iter: Box<dyn Iterator<Item = (usize, &'a u64)> + 'a>,
        deleted_offset: Option<usize>,
    ) -> (Vec<u32>, Vec<u64>) {
        let next: Vec<u64> = self.next.iter().copied().collect();
        get_matched_indices::<u64>(&self.map, &next, iter, deleted_offset)
    }
    /// Limit/offset variant of the lookup (used to cap output batch size);
    /// returns the offset to resume from when more matches remain.
    // NOTE(review): same full copy of `next` as above — see note there.
    fn get_matched_indices_with_limit_offset(
        &self,
        hash_values: &[u64],
        limit: usize,
        offset: MapOffset,
        input_indices: &mut Vec<u32>,
        match_indices: &mut Vec<u64>,
    ) -> Option<MapOffset> {
        let next: Vec<u64> = self.next.iter().copied().collect();
        get_matched_indices_with_limit_offset::<u64>(
            &self.map,
            &next,
            hash_values,
            limit,
            offset,
            input_indices,
            match_indices,
        )
    }
    /// One boolean per input hash: is that hash present in the map?
    fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray {
        contain_hashes(&self.map, hash_values)
    }
    fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
    fn len(&self) -> usize {
        self.map.len()
    }
}
/// Hash-join map that supports pruning of no-longer-needed build-side rows,
/// used by stream joins that keep bounded buffers.
pub struct PruningJoinHashMap {
    // Maps a hash value to the index of the newest (tail) matching row
    // (the tuple is `(hash, tail_index)` — see `prune_hash_values`).
    pub map: HashTable<(u64, u64)>,
    // Chain of previous-match indices (0 terminates a chain); a `VecDeque`
    // so pruned entries can be drained from the front cheaply.
    pub next: VecDeque<u64>,
}
impl PruningJoinHashMap {
    /// Creates a map and chain pre-sized for `capacity` entries.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        PruningJoinHashMap {
            map: HashTable::with_capacity(capacity),
            next: VecDeque::with_capacity(capacity),
        }
    }
    /// Shrinks the hash table when it is less than `1/scale_factor` full,
    /// targeting `(scale_factor - 1) / scale_factor` of the current capacity.
    pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) {
        let capacity = self.map.capacity();
        if capacity > scale_factor * self.map.len() {
            let new_capacity = (capacity * (scale_factor - 1)) / scale_factor;
            // `shrink_to` re-hashes entries, so supply the key extractor.
            self.map.shrink_to(new_capacity, |(hash, _)| *hash)
        }
    }
    /// Estimated memory footprint in bytes: the struct itself, the table's
    /// buckets, and the `next` chain buffer.
    pub(crate) fn size(&self) -> usize {
        let fixed_size = size_of::<PruningJoinHashMap>();
        // `estimate_memory_size` only fails on arithmetic overflow.
        estimate_memory_size::<(u64, u64)>(self.map.capacity(), fixed_size).unwrap()
            + self.next.capacity() * size_of::<u64>()
    }
    /// Drops the first `prune_length` entries of the chain and removes every
    /// map entry whose newest (tail) row index falls inside the pruned range
    /// `[deleting_offset, deleting_offset + prune_length)`, then shrinks the
    /// table if it became sparse.
    pub(crate) fn prune_hash_values(
        &mut self,
        prune_length: usize,
        deleting_offset: u64,
        shrink_factor: usize,
    ) {
        self.next.drain(0..prune_length);
        // Single-pass removal: keep only entries whose tail index is still in
        // the retained window. This replaces the previous two-pass approach
        // (collect removable hashes into a Vec, then `find_entry` + `remove`
        // each), which re-hashed every removable key a second time and
        // allocated an intermediate Vec.
        let prune_bound = prune_length as u64 + deleting_offset;
        self.map
            .retain(|&mut (_, tail_index)| tail_index >= prune_bound);
        self.shrink_if_necessary(shrink_factor);
    }
}
/// Returns `true` when `reference` occurs anywhere inside `expr`'s tree
/// (including `expr` itself).
fn check_filter_expr_contains_sort_information(
    expr: &Arc<dyn PhysicalExpr>,
    reference: &Arc<dyn PhysicalExpr>,
) -> bool {
    // Direct hit at this node?
    if expr.eq(reference) {
        return true;
    }
    // Otherwise recurse into the children.
    expr.children()
        .iter()
        .any(|child| check_filter_expr_contains_sort_information(child, reference))
}
/// Builds a mapping from columns of one child (`side`) schema to the
/// corresponding columns of the join filter's intermediate schema.
///
/// # Errors
/// Fails when a child column cannot be resolved against `schema`.
pub fn map_origin_col_to_filter_col(
    filter: &JoinFilter,
    schema: &SchemaRef,
    side: &JoinSide,
) -> Result<HashMap<Column, Column>> {
    let filter_schema = filter.schema();
    let mut mapping = HashMap::<Column, Column>::new();
    for (filter_idx, column_index) in filter.column_indices().iter().enumerate() {
        // Only columns belonging to the requested side participate.
        if !column_index.side.eq(side) {
            continue;
        }
        let origin_field = schema.field(column_index.index);
        let origin_col = Column::new_with_schema(origin_field.name(), schema.as_ref())?;
        let filter_col =
            Column::new(filter_schema.field(filter_idx).name(), filter_idx);
        mapping.insert(origin_col, filter_col);
    }
    Ok(mapping)
}
/// Rewrites a child sort expression into the filter's intermediate schema.
///
/// Returns `Ok(None)` when the expression references columns the filter does
/// not cover, or when the rewritten expression does not actually appear
/// inside the filter predicate.
pub fn convert_sort_expr_with_filter_schema(
    side: &JoinSide,
    filter: &JoinFilter,
    schema: &SchemaRef,
    sort_expr: &PhysicalSortExpr,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    let column_map = map_origin_col_to_filter_col(filter, schema, side)?;
    let expr = Arc::clone(&sort_expr.expr);
    // Bail out early if any column of the sort expression has no counterpart
    // in the filter schema.
    if collect_columns(&expr)
        .iter()
        .any(|col| !column_map.contains_key(col))
    {
        return Ok(None);
    }
    // Replace each child-schema column with its filter-schema equivalent.
    let converted_filter_expr = expr
        .transform_up(|node| {
            convert_filter_columns(node.as_ref(), &column_map).map(|replacement| {
                if let Some(new_node) = replacement {
                    Transformed::yes(new_node)
                } else {
                    Transformed::no(node)
                }
            })
        })
        .data()?;
    // Useful only if the converted expression occurs inside the predicate.
    if check_filter_expr_contains_sort_information(
        filter.expression(),
        &converted_filter_expr,
    ) {
        Ok(Some(converted_filter_expr))
    } else {
        Ok(None)
    }
}
/// Converts a child ordering into a [`SortedFilterExpr`] over the filter
/// schema, or `None` when the filter does not reflect that ordering.
pub fn build_filter_input_order(
    side: JoinSide,
    filter: &JoinFilter,
    schema: &SchemaRef,
    order: &PhysicalSortExpr,
) -> Result<Option<SortedFilterExpr>> {
    match convert_sort_expr_with_filter_schema(&side, filter, schema, order)? {
        Some(filter_expr) => Ok(Some(SortedFilterExpr::try_new(
            order.clone(),
            filter_expr,
            filter.schema(),
        )?)),
        None => Ok(None),
    }
}
/// Maps a single column expression through `column_map`; non-column inputs
/// and unmapped columns yield `None` (i.e. "leave unchanged").
fn convert_filter_columns(
    input: &dyn PhysicalExpr,
    column_map: &HashMap<Column, Column>,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    let replacement = input
        .as_any()
        .downcast_ref::<Column>()
        .and_then(|col| column_map.get(col))
        .map(|mapped| Arc::new(mapped.clone()) as Arc<dyn PhysicalExpr>);
    Ok(replacement)
}
/// A sort expression of one join child together with its translation into
/// the join filter's intermediate schema and the interval state used for
/// range pruning.
#[derive(Debug, Clone)]
pub struct SortedFilterExpr {
    // Sort expression over the original child schema.
    origin_sorted_expr: PhysicalSortExpr,
    // The same expression rewritten against the filter schema.
    filter_expr: Arc<dyn PhysicalExpr>,
    // Currently known bounds of `filter_expr`; starts unbounded.
    interval: Interval,
    // Index of `filter_expr`'s node inside the `ExprIntervalGraph`.
    node_index: usize,
}
impl SortedFilterExpr {
    /// Builds a `SortedFilterExpr` whose interval starts unbounded with the
    /// data type `filter_expr` evaluates to under `filter_schema`.
    ///
    /// # Errors
    /// Fails if the data type cannot be determined or no unbounded interval
    /// can be constructed for it.
    pub fn try_new(
        origin_sorted_expr: PhysicalSortExpr,
        filter_expr: Arc<dyn PhysicalExpr>,
        filter_schema: &Schema,
    ) -> Result<Self> {
        let dt = filter_expr.data_type(filter_schema)?;
        Ok(Self {
            origin_sorted_expr,
            filter_expr,
            interval: Interval::make_unbounded(&dt)?,
            node_index: 0,
        })
    }
    /// Sort expression over the original child schema.
    pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr {
        &self.origin_sorted_expr
    }
    /// The expression rewritten against the filter schema.
    pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> {
        &self.filter_expr
    }
    /// Currently known bounds of `filter_expr`.
    pub fn interval(&self) -> &Interval {
        &self.interval
    }
    /// Replaces the stored interval.
    pub fn set_interval(&mut self, interval: Interval) {
        self.interval = interval;
    }
    /// Node index of `filter_expr` in the interval graph.
    pub fn node_index(&self) -> usize {
        self.node_index
    }
    /// Stores the node index assigned by the interval graph.
    pub fn set_node_index(&mut self, node_index: usize) {
        self.node_index = node_index;
    }
}
/// Refreshes both sides' filter-expression intervals from the oldest build
/// row and the newest probe row. No-op when either batch is empty.
pub fn calculate_filter_expr_intervals(
    build_input_buffer: &RecordBatch,
    build_sorted_filter_expr: &mut SortedFilterExpr,
    probe_batch: &RecordBatch,
    probe_sorted_filter_expr: &mut SortedFilterExpr,
) -> Result<()> {
    // Nothing to infer from an empty side.
    if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
        return Ok(());
    }
    // Build side: the first (oldest) buffered row.
    let build_first_row = build_input_buffer.slice(0, 1);
    update_filter_expr_interval(&build_first_row, build_sorted_filter_expr)?;
    // Probe side: the last (newest) row of the incoming batch.
    let probe_last_row = probe_batch.slice(probe_batch.num_rows() - 1, 1);
    update_filter_expr_interval(&probe_last_row, probe_sorted_filter_expr)
}
/// Evaluates the sort expression on a single-row `batch` and stores a
/// one-sided interval: `[value, ∞)` for ascending order, `(-∞, value]` for
/// descending.
pub fn update_filter_expr_interval(
    batch: &RecordBatch,
    sorted_expr: &mut SortedFilterExpr,
) -> Result<()> {
    let evaluated = sorted_expr
        .origin_sorted_expr()
        .expr
        .evaluate(batch)?
        .into_array(1)?;
    let bound = ScalarValue::try_from_array(&evaluated, 0)?;
    // A null scalar of the same type represents the unbounded endpoint.
    let unbounded = ScalarValue::try_from(bound.data_type())?;
    let descending = sorted_expr.origin_sorted_expr().options.descending;
    let interval = match descending {
        true => Interval::try_new(unbounded, bound)?,
        false => Interval::try_new(bound, unbounded)?,
    };
    sorted_expr.set_interval(interval);
    Ok(())
}
/// Returns the indices (relative to the pruned section) of the rows in
/// `0..prune_length` that were **never** visited — i.e. anti-join output for
/// a range about to be pruned. `deleted_offset` converts relative indices to
/// the absolute row numbers stored in `visited_rows`.
pub fn get_pruning_anti_indices<T: ArrowPrimitiveType>(
    prune_length: usize,
    deleted_offset: usize,
    visited_rows: &HashSet<usize>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    // Test set membership directly instead of first materializing a
    // `BooleanBufferBuilder` bitmap and then re-reading it: one pass, no
    // intermediate buffer allocation, identical output.
    (0..prune_length)
        .filter_map(|idx| {
            (!visited_rows.contains(&(idx + deleted_offset)))
                .then_some(T::Native::from_usize(idx))
        })
        .collect()
}
/// Returns the indices (relative to the pruned section) of the rows in
/// `0..prune_length` that **were** visited — i.e. semi-join output for a
/// range about to be pruned. `deleted_offset` converts relative indices to
/// the absolute row numbers stored in `visited_rows`.
pub fn get_pruning_semi_indices<T: ArrowPrimitiveType>(
    prune_length: usize,
    deleted_offset: usize,
    visited_rows: &HashSet<usize>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    // Test set membership directly instead of first materializing a
    // `BooleanBufferBuilder` bitmap and then re-reading it: one pass, no
    // intermediate buffer allocation, identical output.
    (0..prune_length)
        .filter_map(|idx| {
            visited_rows
                .contains(&(idx + deleted_offset))
                .then_some(T::Native::from_usize(idx))
        })
        .collect()
}
/// Concatenates up to two optional batches under `output_schema`; a single
/// present batch passes through untouched, two absent batches yield `None`.
pub fn combine_two_batches(
    output_schema: &SchemaRef,
    left_batch: Option<RecordBatch>,
    right_batch: Option<RecordBatch>,
) -> Result<Option<RecordBatch>> {
    match (left_batch, right_batch) {
        (None, None) => Ok(None),
        // Exactly one side present: no concatenation needed.
        (Some(single), None) | (None, Some(single)) => Ok(Some(single)),
        (Some(left), Some(right)) => {
            let combined = concat_batches(output_schema, &[left, right])
                .map_err(|e| arrow_datafusion_err!(e))?;
            Ok(Some(combined))
        }
    }
}
/// Records every index in `indices` (shifted by `offset` into absolute row
/// numbers) as visited.
pub fn record_visited_indices<T: ArrowPrimitiveType>(
    visited: &mut HashSet<usize>,
    offset: usize,
    indices: &PrimitiveArray<T>,
) {
    visited.extend(indices.values().iter().map(|idx| idx.as_usize() + offset));
}
/// Input counters for one side (left or right) of a stream join.
#[derive(Debug)]
pub struct StreamJoinSideMetrics {
    // Batches consumed from this input.
    pub(crate) input_batches: metrics::Count,
    // Rows consumed from this input.
    pub(crate) input_rows: metrics::Count,
}
/// All metrics reported by a stream join operator for one partition.
#[derive(Debug)]
pub struct StreamJoinMetrics {
    // Counters for the left input.
    pub(crate) left: StreamJoinSideMetrics,
    // Counters for the right input.
    pub(crate) right: StreamJoinSideMetrics,
    // Memory used by buffered join state (presumably bytes — set by callers).
    pub(crate) stream_memory_usage: metrics::Gauge,
    // Standard operator metrics (output rows, elapsed compute, ...).
    pub(crate) baseline_metrics: BaselineMetrics,
}
impl StreamJoinMetrics {
    /// Registers all stream-join counters and gauges for `partition` in the
    /// shared metrics set.
    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
        let left = StreamJoinSideMetrics {
            input_batches: MetricBuilder::new(metrics)
                .counter("left_input_batches", partition),
            input_rows: MetricBuilder::new(metrics)
                .counter("left_input_rows", partition),
        };
        let right = StreamJoinSideMetrics {
            input_batches: MetricBuilder::new(metrics)
                .counter("right_input_batches", partition),
            input_rows: MetricBuilder::new(metrics)
                .counter("right_input_rows", partition),
        };
        Self {
            left,
            right,
            stream_memory_usage: MetricBuilder::new(metrics)
                .gauge("stream_memory_usage", partition),
            baseline_metrics: BaselineMetrics::new(metrics, partition),
        }
    }
}
/// Assigns each [`SortedFilterExpr`] the index of its filter expression's
/// node within `graph`, so later interval updates address the right nodes.
fn update_sorted_exprs_with_node_indices(
    graph: &mut ExprIntervalGraph,
    sorted_exprs: &mut [SortedFilterExpr],
) {
    // Gather the filter expressions in the same order as `sorted_exprs`.
    let exprs: Vec<_> = sorted_exprs
        .iter()
        .map(|sorted| Arc::clone(sorted.filter_expr()))
        .collect();
    let node_indices = graph.gather_node_indices(&exprs);
    // `gather_node_indices` preserves order, so a positional zip pairs each
    // expression with its node index.
    for ((_, index), sorted) in node_indices.into_iter().zip(sorted_exprs.iter_mut()) {
        sorted.set_node_index(index);
    }
}
/// Builds the left/right [`SortedFilterExpr`]s and the shared interval graph
/// used for filter-based pruning in stream joins.
///
/// # Errors
/// Fails when the filter does not include either child's leading sort order,
/// or when the interval graph cannot be constructed.
pub fn prepare_sorted_exprs(
    filter: &JoinFilter,
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    left_sort_exprs: &LexOrdering,
    right_sort_exprs: &LexOrdering,
) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
    let missing_order_err = || {
        datafusion_common::plan_datafusion_err!("Filter does not include the child order")
    };
    // Only the leading sort expression of each side is considered.
    let left_expr = build_filter_input_order(
        JoinSide::Left,
        filter,
        &left.schema(),
        &left_sort_exprs[0],
    )?
    .ok_or_else(missing_order_err)?;
    let right_expr = build_filter_input_order(
        JoinSide::Right,
        filter,
        &right.schema(),
        &right_sort_exprs[0],
    )?
    .ok_or_else(missing_order_err)?;
    let mut graph =
        ExprIntervalGraph::try_new(Arc::clone(filter.expression()), filter.schema())?;
    // Assign graph node indices to both expressions in place.
    let mut sorted_exprs = [left_expr, right_expr];
    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
    let [left_sorted_filter_expr, right_sorted_filter_expr] = sorted_exprs;
    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{joins::test_utils::complicated_filter, joins::utils::ColumnIndex};
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{binary, cast, col};

    // Verifies that child sort expressions of both sides are rewritten into
    // the filter's intermediate schema and that the originals round-trip.
    #[test]
    fn test_column_exchange() -> Result<()> {
        let left_child_schema =
            Schema::new(vec![Field::new("left_1", DataType::Int32, true)]);
        let left_child_sort_expr = PhysicalSortExpr {
            expr: col("left_1", &left_child_schema)?,
            options: SortOptions::default(),
        };
        let right_child_schema = Schema::new(vec![
            Field::new("right_1", DataType::Int32, true),
            Field::new("right_2", DataType::Int32, true),
        ]);
        // The right-side ordering is a compound expression (right_1 + right_2).
        let right_child_sort_expr = PhysicalSortExpr {
            expr: binary(
                col("right_1", &right_child_schema)?,
                Operator::Plus,
                col("right_2", &right_child_schema)?,
                &right_child_schema,
            )?,
            options: SortOptions::default(),
        };
        let intermediate_schema = Schema::new(vec![
            Field::new("filter_1", DataType::Int32, true),
            Field::new("filter_2", DataType::Int32, true),
            Field::new("filter_3", DataType::Int32, true),
        ]);
        let filter_left = col("filter_1", &intermediate_schema)?;
        let filter_right = binary(
            col("filter_2", &intermediate_schema)?,
            Operator::Plus,
            col("filter_3", &intermediate_schema)?,
            &intermediate_schema,
        )?;
        let filter_expr = binary(
            Arc::clone(&filter_left),
            Operator::Gt,
            Arc::clone(&filter_right),
            &intermediate_schema,
        )?;
        // filter_1 <- left.0, filter_2 <- right.0, filter_3 <- right.1
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
        let left_sort_filter_expr = build_filter_input_order(
            JoinSide::Left,
            &filter,
            &Arc::new(left_child_schema),
            &left_child_sort_expr,
        )?
        .unwrap();
        assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()));
        let right_sort_filter_expr = build_filter_input_order(
            JoinSide::Right,
            &filter,
            &Arc::new(right_child_schema),
            &right_child_sort_expr,
        )?
        .unwrap();
        assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()));
        // Converted expressions must equal the filter-schema counterparts.
        assert!(filter_left.eq(left_sort_filter_expr.filter_expr()));
        assert!(filter_right.eq(right_sort_filter_expr.filter_expr()));
        Ok(())
    }

    // `collect_columns` should find every distinct column of the filter.
    #[test]
    fn test_column_collector() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&schema)?;
        let columns = collect_columns(&filter_expr);
        assert_eq!(columns.len(), 3);
        Ok(())
    }

    // Exercises `check_filter_expr_contains_sort_information`: misses, hits,
    // a nested compound hit, and a same-name/different-index near miss.
    #[test]
    fn find_expr_inside_expr() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&schema)?;
        // Unknown column name: not contained.
        let expr_1 = Arc::new(Column::new("gnz", 0)) as _;
        assert!(!check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_1
        ));
        let expr_2 = col("1", &schema)? as _;
        assert!(check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_2
        ));
        // A nested compound sub-expression is also found.
        let expr_3 = cast(
            binary(
                col("0", &schema)?,
                Operator::Plus,
                col("1", &schema)?,
                &schema,
            )?,
            &schema,
            DataType::Int64,
        )?;
        assert!(check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_3
        ));
        // Same name but a different index: column equality is index-sensitive.
        let expr_4 = Arc::new(Column::new("1", 42)) as _;
        assert!(!check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_4,
        ));
        Ok(())
    }

    // Only orders on columns referenced by the filter convert to sorted
    // filter expressions; orders on unreferenced columns (lt1, rb1) do not.
    #[test]
    fn build_sorted_expr() -> Result<()> {
        let left_schema = Schema::new(vec![
            Field::new("la1", DataType::Int32, false),
            Field::new("lb1", DataType::Int32, false),
            Field::new("lc1", DataType::Int32, false),
            Field::new("lt1", DataType::Int32, false),
            Field::new("la2", DataType::Int32, false),
            Field::new("la1_des", DataType::Int32, false),
        ]);
        let right_schema = Schema::new(vec![
            Field::new("ra1", DataType::Int32, false),
            Field::new("rb1", DataType::Int32, false),
            Field::new("rc1", DataType::Int32, false),
            Field::new("rt1", DataType::Int32, false),
            Field::new("ra2", DataType::Int32, false),
            Field::new("ra1_des", DataType::Int32, false),
        ]);
        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&intermediate_schema)?;
        // Filter references la1, la2 (left) and ra1 (right) only.
        let column_indices = vec![
            ColumnIndex {
                index: left_schema.index_of("la1")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: left_schema.index_of("la2")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: right_schema.index_of("ra1")?,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
        let left_schema = Arc::new(left_schema);
        let right_schema = Arc::new(right_schema);
        assert!(
            build_filter_input_order(
                JoinSide::Left,
                &filter,
                &left_schema,
                &PhysicalSortExpr {
                    expr: col("la1", left_schema.as_ref())?,
                    options: SortOptions::default(),
                }
            )?
            .is_some()
        );
        assert!(
            build_filter_input_order(
                JoinSide::Left,
                &filter,
                &left_schema,
                &PhysicalSortExpr {
                    expr: col("lt1", left_schema.as_ref())?,
                    options: SortOptions::default(),
                }
            )?
            .is_none()
        );
        assert!(
            build_filter_input_order(
                JoinSide::Right,
                &filter,
                &right_schema,
                &PhysicalSortExpr {
                    expr: col("ra1", right_schema.as_ref())?,
                    options: SortOptions::default(),
                }
            )?
            .is_some()
        );
        assert!(
            build_filter_input_order(
                JoinSide::Right,
                &filter,
                &right_schema,
                &PhysicalSortExpr {
                    expr: col("rb1", right_schema.as_ref())?,
                    options: SortOptions::default(),
                }
            )?
            .is_none()
        );
        Ok(())
    }

    // A sort expression over columns that are absent from the side's filter
    // mapping cannot be converted.
    #[test]
    fn sorted_filter_expr_build() -> Result<()> {
        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
        ]);
        let filter_expr = binary(
            col("0", &intermediate_schema)?,
            Operator::Minus,
            col("1", &intermediate_schema)?,
            &intermediate_schema,
        )?;
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Left,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
        // "a" and "b" are not among the mapped filter columns (indices 0, 1
        // of the child schema map, but a + b as an expression won't appear
        // in the filter predicate).
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let sorted = PhysicalSortExpr {
            expr: binary(
                col("a", &schema)?,
                Operator::Plus,
                col("b", &schema)?,
                &schema,
            )?,
            options: SortOptions::default(),
        };
        let res = convert_sort_expr_with_filter_schema(
            &JoinSide::Left,
            &filter,
            &Arc::new(schema),
            &sorted,
        )?;
        assert!(res.is_none());
        Ok(())
    }

    // Fills the map, removes 3/4 of the entries, then shrinks and checks the
    // resulting capacity bounds.
    #[test]
    fn test_shrink_if_necessary() {
        let scale_factor = 4;
        let mut join_hash_map = PruningJoinHashMap::with_capacity(100);
        let data_size = 2000;
        let deleted_part = 3 * data_size / 4;
        for hash_value in 0..data_size {
            join_hash_map.map.insert_unique(
                hash_value,
                (hash_value, hash_value),
                |(hash, _)| *hash,
            );
        }
        assert_eq!(join_hash_map.map.len(), data_size as usize);
        assert!(join_hash_map.map.capacity() >= data_size as usize);
        for hash_value in 0..deleted_part {
            join_hash_map
                .map
                .find_entry(hash_value, |(hash, _)| hash_value == *hash)
                .unwrap()
                .remove();
        }
        assert_eq!(join_hash_map.map.len(), (data_size - deleted_part) as usize);
        let old_capacity = join_hash_map.map.capacity();
        join_hash_map.shrink_if_necessary(scale_factor);
        // NOTE(review): this bound is derived from the *post*-shrink capacity,
        // so the `>=` assertion below is trivially satisfied; it was likely
        // intended to use `old_capacity` instead — confirm and tighten.
        let new_expected_capacity =
            join_hash_map.map.capacity() * (scale_factor - 1) / scale_factor;
        assert!(join_hash_map.map.capacity() >= new_expected_capacity);
        assert!(join_hash_map.map.capacity() <= old_capacity);
    }
}