use std::sync::Arc;
use crate::joins::utils::{JoinFilter, JoinOn};
use crate::joins::{
HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use crate::repartition::RepartitionExec;
use crate::test::TestMemoryExec;
use crate::{ExecutionPlan, ExecutionPlanProperties, Partitioning, common};
use arrow::array::{
ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch,
TimestampMillisecondArray, types::IntervalDayTime,
};
use arrow::datatypes::{DataType, Schema};
use arrow::util::pretty::pretty_format_batches;
use datafusion_common::{NullEquality, Result, ScalarValue};
use datafusion_execution::TaskContext;
use datafusion_expr::{JoinType, Operator};
use datafusion_physical_expr::expressions::{binary, cast, col, lit};
use datafusion_physical_expr::intervals::test_utils::{
gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr,
};
use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
/// Asserts that two collections of record batches contain the same rows,
/// ignoring row order.
///
/// Both inputs are pretty-printed, the printed lines are sorted, and the
/// sorted line sequences are compared element-wise. Panics (via `assert_eq!`)
/// on the first difference. Two empty inputs compare equal trivially.
pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) {
    let left_row_num: usize = collected_1.iter().map(|batch| batch.num_rows()).sum();
    let right_row_num: usize = collected_2.iter().map(|batch| batch.num_rows()).sum();
    if left_row_num == 0 && right_row_num == 0 {
        return;
    }
    // Differing row totals can never match; check this up front because the
    // `zip` below would silently ignore surplus lines on the longer side.
    assert_eq!(left_row_num, right_row_num);
    let first_formatted = pretty_format_batches(collected_1).unwrap().to_string();
    let second_formatted = pretty_format_batches(collected_2).unwrap().to_string();
    // Sort the rendered lines so the comparison is insensitive to row order.
    let mut first_lines: Vec<&str> = first_formatted.trim().lines().collect();
    first_lines.sort_unstable();
    let mut second_lines: Vec<&str> = second_formatted.trim().lines().collect();
    second_lines.sort_unstable();
    // Guard against the zip truncating when the line counts differ (e.g. due
    // to differing column widths producing different table borders).
    assert_eq!(first_lines.len(), second_lines.len());
    for (i, (first_line, second_line)) in
        first_lines.iter().zip(&second_lines).enumerate()
    {
        // Include the index so a failure reports where the outputs diverge.
        assert_eq!((i, first_line), (i, second_line));
    }
}
/// Executes a [`SymmetricHashJoinExec`] over four hash partitions of `left`
/// and `right` and returns all non-empty output batches.
///
/// Both inputs are wrapped in a [`RepartitionExec`] with `Partitioning::Hash`
/// on their respective join keys so matching rows land in the same partition.
/// The existing output orderings of the inputs are forwarded to the join.
/// Partitions are executed and drained sequentially.
pub async fn partitioned_sym_join_with_filter(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    on: JoinOn,
    filter: Option<JoinFilter>,
    join_type: &JoinType,
    null_equality: NullEquality,
    context: Arc<TaskContext>,
) -> Result<Vec<RecordBatch>> {
    let partition_count = 4;

    let left_expr = on
        .iter()
        .map(|(l, _)| Arc::clone(l) as _)
        .collect::<Vec<_>>();

    let right_expr = on
        .iter()
        .map(|(_, r)| Arc::clone(r) as _)
        .collect::<Vec<_>>();

    let join = SymmetricHashJoinExec::try_new(
        Arc::new(RepartitionExec::try_new(
            Arc::clone(&left),
            Partitioning::Hash(left_expr, partition_count),
        )?),
        Arc::new(RepartitionExec::try_new(
            Arc::clone(&right),
            Partitioning::Hash(right_expr, partition_count),
        )?),
        on,
        filter,
        join_type,
        null_equality,
        left.output_ordering().cloned(),
        right.output_ordering().cloned(),
        StreamJoinPartitionMode::Partitioned,
    )?;

    let mut batches = vec![];
    for i in 0..partition_count {
        let stream = join.execute(i, Arc::clone(&context))?;
        let more_batches = common::collect(stream).await?;
        // Keep only non-empty batches; extend directly from the iterator
        // rather than collecting into an intermediate Vec first.
        batches.extend(more_batches.into_iter().filter(|b| b.num_rows() > 0));
    }

    Ok(batches)
}
/// Executes a partitioned [`HashJoinExec`] over four hash partitions of
/// `left` and `right` and returns all non-empty output batches.
///
/// Both inputs are wrapped in a [`RepartitionExec`] with `Partitioning::Hash`
/// on their respective join keys. Partitions are executed and drained
/// sequentially.
pub async fn partitioned_hash_join_with_filter(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    on: JoinOn,
    filter: Option<JoinFilter>,
    join_type: &JoinType,
    null_equality: NullEquality,
    context: Arc<TaskContext>,
) -> Result<Vec<RecordBatch>> {
    let partition_count = 4;

    let (left_expr, right_expr) = on
        .iter()
        .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
        .unzip();

    let join = Arc::new(HashJoinExec::try_new(
        Arc::new(RepartitionExec::try_new(
            left,
            Partitioning::Hash(left_expr, partition_count),
        )?),
        Arc::new(RepartitionExec::try_new(
            right,
            Partitioning::Hash(right_expr, partition_count),
        )?),
        on,
        filter,
        join_type,
        // No output projection: keep all columns from both sides.
        None,
        PartitionMode::Partitioned,
        null_equality,
        false,
    )?);

    let mut batches = vec![];
    for i in 0..partition_count {
        let stream = join.execute(i, Arc::clone(&context))?;
        let more_batches = common::collect(stream).await?;
        // Keep only non-empty batches; extend directly from the iterator
        // rather than collecting into an intermediate Vec first.
        batches.extend(more_batches.into_iter().filter(|b| b.num_rows() > 0));
    }

    Ok(batches)
}
/// Splits `batch` into consecutive slices of at most `batch_size` rows.
///
/// The last slice carries the remainder and is omitted when the row count is
/// an exact multiple of `batch_size`, so no empty trailing batch is produced.
/// Slices share buffers with the input batch, making this cheap.
///
/// # Panics
///
/// Panics if `batch_size` is zero (division by zero).
pub fn split_record_batches(
    batch: &RecordBatch,
    batch_size: usize,
) -> Result<Vec<RecordBatch>> {
    let row_num = batch.num_rows();
    let number_of_batch = row_num / batch_size;
    let mut sizes = vec![batch_size; number_of_batch];
    let remainder = row_num - batch_size * number_of_batch;
    // Only add a tail slice when there are leftover rows; the original code
    // unconditionally pushed the remainder, yielding a spurious empty batch
    // whenever `row_num` divided evenly.
    if remainder > 0 {
        sizes.push(remainder);
    }
    let mut result = Vec::with_capacity(sizes.len());
    for (i, size) in sizes.iter().enumerate() {
        result.push(batch.slice(i * batch_size, *size));
    }
    Ok(result)
}
/// Infinite iterator yielding a non-decreasing sequence of random `f64`
/// values bounded above by `max`, seeded deterministically so generated test
/// data is reproducible across runs.
struct AscendingRandomFloatIterator {
    // Most recently yielded value; the lower bound of the next draw.
    prev: f64,
    // Exclusive upper bound for every draw.
    max: f64,
    rng: StdRng,
}

impl AscendingRandomFloatIterator {
    /// Creates the iterator with a fixed seed. The first yielded value is
    /// drawn from `min..max`; each subsequent value from `prev..max`.
    fn new(min: f64, max: f64) -> Self {
        let mut rng = StdRng::seed_from_u64(42);
        let prev = rng.random_range(min..max);
        Self { prev, max, rng }
    }
}

impl Iterator for AscendingRandomFloatIterator {
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        // Draw from [prev, max) so the sequence never decreases, then record
        // the draw as the new lower bound.
        self.prev = self.rng.random_range(self.prev..self.max);
        Some(self.prev)
    }
}
pub fn join_expr_tests_fixture_temporal(
expr_id: usize,
left_col: Arc<dyn PhysicalExpr>,
right_col: Arc<dyn PhysicalExpr>,
schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
match expr_id {
0 => gen_conjunctive_temporal_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
ScalarValue::new_interval_dt(0, 100), ScalarValue::new_interval_dt(0, 200), ScalarValue::new_interval_dt(0, 450), ScalarValue::new_interval_dt(0, 300), schema,
),
1 => gen_conjunctive_temporal_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
ScalarValue::TimestampMillisecond(Some(1672574403000), None), ScalarValue::TimestampMillisecond(Some(1672574401000), None), ScalarValue::TimestampMillisecond(Some(1672574400000), None), ScalarValue::TimestampMillisecond(Some(1672574402000), None), schema,
),
2 => gen_conjunctive_temporal_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
ScalarValue::DurationMillisecond(Some(3000)), ScalarValue::DurationMillisecond(Some(2000)), ScalarValue::DurationMillisecond(Some(5000)), ScalarValue::DurationMillisecond(Some(4000)), schema,
),
_ => unreachable!(),
}
}
// Generates a fixture function `$func_name` that returns one of eight canned
// numerical join-filter expressions over `left_col` and `right_col`, selected
// by `expr_id`. The arms differ in the tuple of four arithmetic operators,
// the four `$SCALAR` constants, and the pair of comparison operators used to
// join the two conjuncts. Any `expr_id` outside `0..=7` panics.
macro_rules! join_expr_tests {
    ($func_name:ident, $type:ty, $SCALAR:ident) => {
        pub fn $func_name(
            expr_id: usize,
            left_col: Arc<dyn PhysicalExpr>,
            right_col: Arc<dyn PhysicalExpr>,
        ) -> Arc<dyn PhysicalExpr> {
            match expr_id {
                0 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $type)),
                    ScalarValue::$SCALAR(Some(5 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    ScalarValue::$SCALAR(Some(10 as $type)),
                    (Operator::Gt, Operator::Lt),
                ),
                1 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    ScalarValue::$SCALAR(Some(15 as $type)),
                    (Operator::Gt, Operator::Lt),
                ),
                2 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $type)),
                    ScalarValue::$SCALAR(Some(5 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    ScalarValue::$SCALAR(Some(10 as $type)),
                    (Operator::Gt, Operator::Lt),
                ),
                3 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(10 as $type)),
                    ScalarValue::$SCALAR(Some(5 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    ScalarValue::$SCALAR(Some(10 as $type)),
                    (Operator::Gt, Operator::Lt),
                ),
                4 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                    ),
                    ScalarValue::$SCALAR(Some(10 as $type)),
                    ScalarValue::$SCALAR(Some(5 as $type)),
                    ScalarValue::$SCALAR(Some(30 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    (Operator::Gt, Operator::Lt),
                ),
                // Arms 5-7 exercise the inclusive comparison variants
                // (GtEq / LtEq) in different combinations.
                5 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Minus,
                    ),
                    ScalarValue::$SCALAR(Some(2 as $type)),
                    ScalarValue::$SCALAR(Some(5 as $type)),
                    ScalarValue::$SCALAR(Some(7 as $type)),
                    ScalarValue::$SCALAR(Some(3 as $type)),
                    (Operator::GtEq, Operator::LtEq),
                ),
                6 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Plus,
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(28 as $type)),
                    ScalarValue::$SCALAR(Some(11 as $type)),
                    ScalarValue::$SCALAR(Some(21 as $type)),
                    ScalarValue::$SCALAR(Some(39 as $type)),
                    (Operator::Gt, Operator::LtEq),
                ),
                7 => gen_conjunctive_numerical_expr(
                    left_col,
                    right_col,
                    (
                        Operator::Plus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(28 as $type)),
                    ScalarValue::$SCALAR(Some(11 as $type)),
                    ScalarValue::$SCALAR(Some(21 as $type)),
                    ScalarValue::$SCALAR(Some(39 as $type)),
                    (Operator::GtEq, Operator::Lt),
                ),
                _ => panic!("No case"),
            }
        }
    };
}
// Instantiate the fixture for i32 and f64 scalar constants.
join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32);
join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64);
/// Builds a deterministic pair of (left, right) record batches used as join
/// test inputs.
///
/// Both sides share the same underlying arrays but expose them under `l*`
/// and `r*` column names, so the two sides are row-for-row identical except
/// for the key columns `lc1`/`rc1`, whose moduli come from
/// `key_cardinality.0` and `key_cardinality.1` respectively.
pub fn build_sides_record_batches(
    table_size: i32,
    key_cardinality: (i32, i32),
) -> Result<(RecordBatch, RecordBatch)> {
    // Fraction of rows that are NULL in the nullable ordered columns.
    let null_ratio: f64 = 0.4;
    // Fraction of duplicated values in the random-ordered column.
    let duplicate_ratio = 0.4;
    let initial_range = 0..table_size;
    // Number of NULL rows; the remaining rows come from `rest_of`.
    let index = (table_size as f64 * null_ratio).round() as i32;
    let rest_of = index..table_size;
    // Strictly ascending values 0..table_size.
    let ordered: ArrayRef = Arc::new(Int32Array::from_iter(
        initial_range.clone().collect::<Vec<i32>>(),
    ));
    // Ascending values with ~40% duplicates (see `generate_ordered_array`).
    let random_ordered = generate_ordered_array(table_size, duplicate_ratio);
    // Strictly descending values table_size-1..0.
    let ordered_des = Arc::new(Int32Array::from_iter(
        initial_range.clone().rev().collect::<Vec<i32>>(),
    ));
    // Low-cardinality column cycling through 0..=3.
    let cardinality = Arc::new(Int32Array::from_iter(
        initial_range.clone().map(|x| x % 4).collect::<Vec<i32>>(),
    ));
    // Join-key columns, one modulus per side.
    let cardinality_key_left = Arc::new(Int32Array::from_iter(
        initial_range
            .clone()
            .map(|x| x % key_cardinality.0)
            .collect::<Vec<i32>>(),
    ));
    let cardinality_key_right = Arc::new(Int32Array::from_iter(
        initial_range
            .clone()
            .map(|x| x % key_cardinality.1)
            .collect::<Vec<i32>>(),
    ));
    // Ascending with the NULLs placed first.
    let ordered_asc_null_first = Arc::new(Int32Array::from_iter({
        std::iter::repeat_n(None, index as usize)
            .chain(rest_of.clone().map(Some))
            .collect::<Vec<Option<i32>>>()
    }));
    // Ascending with the NULLs placed last.
    let ordered_asc_null_last = Arc::new(Int32Array::from_iter({
        rest_of
            .clone()
            .map(Some)
            .chain(std::iter::repeat_n(None, index as usize))
            .collect::<Vec<Option<i32>>>()
    }));
    // Descending with the NULLs placed first.
    let ordered_desc_null_first = Arc::new(Int32Array::from_iter({
        std::iter::repeat_n(None, index as usize)
            .chain(rest_of.rev().map(Some))
            .collect::<Vec<Option<i32>>>()
    }));
    // Ascending millisecond timestamps starting at 1672531200000
    // (2023-01-01T00:00:00Z), one ms apart.
    let time = Arc::new(TimestampMillisecondArray::from(
        initial_range
            .clone()
            .map(|x| x as i64 + 1672531200000) .collect::<Vec<i64>>(),
    ));
    // Ascending day-time intervals in 100 ms steps.
    let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(
        initial_range
            .map(|x| IntervalDayTime {
                days: 0,
                milliseconds: x * 100,
            }) .collect::<Vec<_>>(),
    ));
    // Non-decreasing random floats in [0, table_size).
    let float_asc = Arc::new(Float64Array::from_iter_values(
        AscendingRandomFloatIterator::new(0., table_size as f64)
            .take(table_size as usize),
    ));
    let left = RecordBatch::try_from_iter(vec![
        ("la1", Arc::clone(&ordered)),
        ("lb1", Arc::clone(&cardinality) as ArrayRef),
        ("lc1", cardinality_key_left),
        ("lt1", Arc::clone(&time) as ArrayRef),
        ("la2", Arc::clone(&ordered)),
        ("la1_des", Arc::clone(&ordered_des) as ArrayRef),
        (
            "l_asc_null_first",
            Arc::clone(&ordered_asc_null_first) as ArrayRef,
        ),
        (
            "l_asc_null_last",
            Arc::clone(&ordered_asc_null_last) as ArrayRef,
        ),
        (
            "l_desc_null_first",
            Arc::clone(&ordered_desc_null_first) as ArrayRef,
        ),
        ("li1", Arc::clone(&interval_time)),
        ("l_float", Arc::clone(&float_asc) as ArrayRef),
        ("l_random_ordered", Arc::clone(&random_ordered) as ArrayRef),
    ])?;
    // The right side reuses the same arrays (moving them this time), so the
    // column order here must mirror the left side exactly.
    let right = RecordBatch::try_from_iter(vec![
        ("ra1", Arc::clone(&ordered)),
        ("rb1", cardinality),
        ("rc1", cardinality_key_right),
        ("rt1", time),
        ("ra2", ordered),
        ("ra1_des", ordered_des),
        ("r_asc_null_first", ordered_asc_null_first),
        ("r_asc_null_last", ordered_asc_null_last),
        ("r_desc_null_first", ordered_desc_null_first),
        ("ri1", interval_time),
        ("r_float", float_asc),
        ("r_random_ordered", random_ordered),
    ])?;
    Ok((left, right))
}
/// Wraps the given batch partitions in sorted in-memory execution plans.
///
/// Each side becomes a [`TestMemoryExec`] (one partition) carrying the
/// supplied sort information; the schema is taken from the first batch of
/// each partition, so both inputs must be non-empty.
pub fn create_memory_table(
    left_partition: Vec<RecordBatch>,
    right_partition: Vec<RecordBatch>,
    left_sorted: Vec<LexOrdering>,
    right_sorted: Vec<LexOrdering>,
) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
    // Capture each schema before the partition Vec is moved into the exec.
    let schema = left_partition[0].schema();
    let left_exec = Arc::new(
        TestMemoryExec::try_new(&[left_partition], schema, None)?
            .try_with_sort_information(left_sorted)?,
    );
    let schema = right_partition[0].schema();
    let right_exec = Arc::new(
        TestMemoryExec::try_new(&[right_partition], schema, None)?
            .try_with_sort_information(right_sorted)?,
    );
    Ok((
        Arc::new(TestMemoryExec::update_cache(&left_exec)),
        Arc::new(TestMemoryExec::update_cache(&right_exec)),
    ))
}
pub(crate) fn complicated_filter(
filter_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
let left_expr = binary(
cast(
binary(
col("0", filter_schema)?,
Operator::Plus,
col("1", filter_schema)?,
filter_schema,
)?,
filter_schema,
DataType::Int64,
)?,
Operator::Gt,
binary(
cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
Operator::Plus,
lit(ScalarValue::Int64(Some(10))),
filter_schema,
)?,
filter_schema,
)?;
let right_expr = binary(
cast(
binary(
col("0", filter_schema)?,
Operator::Plus,
col("1", filter_schema)?,
filter_schema,
)?,
filter_schema,
DataType::Int64,
)?,
Operator::Lt,
binary(
cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
Operator::Plus,
lit(ScalarValue::Int64(Some(100))),
filter_schema,
)?,
filter_schema,
)?;
binary(left_expr, Operator::And, right_expr, filter_schema)
}
/// Generates a sorted `Int32Array` of `size` values drawn from `1..500`,
/// where roughly `duplicate_ratio` of the entries duplicate earlier values.
/// Uses a fixed seed so generated test data is reproducible.
///
/// # Panics
///
/// Panics if `duplicate_ratio >= 1.0` while `size > 0` (no unique values
/// exist to duplicate) or if `size` is negative.
fn generate_ordered_array(size: i32, duplicate_ratio: f32) -> Arc<Int32Array> {
    let mut rng = StdRng::seed_from_u64(42);
    let unique_count = (size as f32 * (1.0 - duplicate_ratio)) as i32;
    // Reserve the full final size up front to avoid reallocation.
    let mut values: Vec<i32> = Vec::with_capacity(size as usize);
    values.extend((0..unique_count).map(|_| rng.random_range(1..500)));
    // Pad up to `size` by re-sampling already-generated values.
    for _ in 0..(size - unique_count) {
        let index = rng.random_range(0..unique_count);
        values.push(values[index as usize]);
    }
    // Unstable sort: plain i32s have no identity to preserve and it avoids
    // the stable sort's allocation.
    values.sort_unstable();
    Arc::new(Int32Array::from_iter(values))
}