use crate::expr::{Sort, WindowFunction};
use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use crate::expr_visitor::{
inspect_expr_pre, ExprVisitable, ExpressionVisitor, Recursion,
};
use crate::logical_plan::builder::build_join_schema;
use crate::logical_plan::{
Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join,
Limit, Partitioning, Prepare, Projection, Repartition, Sort as SortPlan, Subquery,
SubqueryAlias, Union, Values, Window,
};
use crate::{
BinaryExpr, Cast, DmlStatement, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder,
Operator, TableScan, TryCast,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::sync::Arc;
/// The scalar literal `UInt8(1)`.
///
/// NOTE(review): judging by the name, this is the value that `COUNT(*)` is
/// expanded to (i.e. `COUNT(1)`); the expansion itself is not in this file.
pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1));
/// Collects every column referenced by any expression in `expr` into `accum`.
///
/// Delegates to [`expr_to_columns`] for each expression, in order, and stops
/// at the first error.
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
    expr.iter()
        .try_for_each(|single| expr_to_columns(single, accum))
}
/// Counts the distinct expressions in a `GROUP BY` list.
///
/// When the list starts with a grouping set, the count is the number of
/// distinct expressions inside that grouping set; otherwise it is the length
/// of the list.
///
/// # Errors
///
/// Returns a plan error if a grouping set is the first expression but is not
/// the only one.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
    match group_expr {
        // A lone grouping set: count what it contains.
        [Expr::GroupingSet(grouping_set)] => Ok(grouping_set.distinct_expr().len()),
        // A leading grouping set followed by anything else is rejected.
        [Expr::GroupingSet(_), ..] => Err(DataFusionError::Plan(
            "Invalid group by expressions, GroupingSet must be the only expression"
                .to_string(),
        )),
        _ => Ok(group_expr.len()),
    }
}
/// Flattens a `GROUP BY` list into plain expressions.
///
/// When the list starts with a grouping set, the result is that grouping
/// set's distinct expressions; otherwise it is a copy of the list.
///
/// # Errors
///
/// Returns a plan error if a grouping set is the first expression but is not
/// the only one.
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
    match group_expr {
        // A lone grouping set: expose its distinct members.
        [Expr::GroupingSet(grouping_set)] => Ok(grouping_set.distinct_expr()),
        // A leading grouping set followed by anything else is rejected.
        [Expr::GroupingSet(_), ..] => Err(DataFusionError::Plan(
            "Invalid group by expressions, GroupingSet must be the only expression"
                .to_string(),
        )),
        _ => Ok(group_expr.to_vec()),
    }
}
/// Walks `expr` and inserts every column it references into `accum`.
///
/// `Expr::Column` nodes are inserted as-is. `Expr::ScalarVariable` nodes are
/// inserted as an unqualified column whose name is the variable's name parts
/// joined with `.`. All other node kinds contribute nothing themselves.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    inspect_expr_pre(expr, |node| {
        // The match is deliberately exhaustive so that adding a new `Expr`
        // variant forces a decision here.
        let referenced = match node {
            Expr::Column(column) => Some(column.clone()),
            Expr::ScalarVariable(_, var_names) => {
                Some(Column::from_name(var_names.join(".")))
            }
            Expr::Alias(_, _)
            | Expr::Literal(_)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::ILike { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::Sort { .. }
            | Expr::ScalarFunction { .. }
            | Expr::ScalarUDF { .. }
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::AggregateUDF { .. }
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery { .. }
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard
            | Expr::QualifiedWildcard { .. }
            | Expr::GetIndexedField { .. }
            | Expr::Placeholder { .. } => None,
        };
        if let Some(column) = referenced {
            accum.insert(column);
        }
        Ok(())
    })
}
pub fn expand_wildcard(schema: &DFSchema, plan: &LogicalPlan) -> Result<Vec<Expr>> {
let using_columns = plan.using_columns()?;
let columns_to_skip = using_columns
.into_iter()
.flat_map(|cols| {
let mut cols = cols.into_iter().collect::<Vec<_>>();
cols.sort();
let mut out_column_names: HashSet<String> = HashSet::new();
cols.into_iter()
.filter_map(|c| {
if out_column_names.contains(&c.name) {
Some(c)
} else {
out_column_names.insert(c.name);
None
}
})
.collect::<Vec<_>>()
})
.collect::<HashSet<_>>();
if columns_to_skip.is_empty() {
Ok(schema
.fields()
.iter()
.map(|f| Expr::Column(f.qualified_column()))
.collect::<Vec<Expr>>())
} else {
Ok(schema
.fields()
.iter()
.filter_map(|f| {
let col = f.qualified_column();
if !columns_to_skip.contains(&col) {
Some(Expr::Column(col))
} else {
None
}
})
.collect::<Vec<Expr>>())
}
}
/// Expands a `qualifier.*` wildcard into one column expression per field of
/// `schema` that carries the given qualifier.
///
/// # Errors
///
/// Returns a plan error if no field matches `qualifier`, and propagates any
/// error from building the intermediate schema of the matching fields.
pub fn expand_qualified_wildcard(
    qualifier: &str,
    schema: &DFSchema,
) -> Result<Vec<Expr>> {
    let matching: Vec<DFField> = schema
        .fields_with_qualified(qualifier)
        .into_iter()
        .cloned()
        .collect();
    if matching.is_empty() {
        return Err(DataFusionError::Plan(format!(
            "Invalid qualifier {qualifier}"
        )));
    }

    let qualified_schema =
        DFSchema::new_with_metadata(matching, schema.metadata().clone())?;

    let mut columns = Vec::with_capacity(qualified_schema.fields().len());
    for field in qualified_schema.fields() {
        columns.push(Expr::Column(field.qualified_column()));
    }
    Ok(columns)
}
/// Sort key of a window expression: a list of `(sort expression, flag)` pairs.
/// As produced by `generate_sort_key`, the flag is `true` for keys that come
/// from the `partition_by` list and `false` for keys that come only from
/// `order_by`.
type WindowSortKey = Vec<(Expr, bool)>;
/// Builds the sort key for a window from its `partition_by` and `order_by`
/// expressions.
///
/// Partition expressions come first, each flagged `true`. A partition
/// expression is sorted ascending with nulls last unless `order_by` has a
/// sort on the same expression, in which case that `order_by` entry is used
/// instead. Remaining `order_by` entries follow, flagged `false`. Duplicate
/// keys are emitted only once.
///
/// # Errors
///
/// Returns a plan error if any `order_by` entry is not an `Expr::Sort`.
pub fn generate_sort_key(
    partition_by: &[Expr],
    order_by: &[Expr],
) -> Result<WindowSortKey> {
    /// True when `key` is already present in `keys`.
    fn has_key(keys: &[(Expr, bool)], key: &Expr) -> bool {
        keys.iter().any(|(existing, _)| existing == key)
    }

    // Same expressions as `order_by`, but all with the default direction so
    // they can be compared against the partition expressions.
    let mut normalized_order_by = Vec::with_capacity(order_by.len());
    for key in order_by {
        match key {
            Expr::Sort(Sort { expr, .. }) => {
                normalized_order_by.push(Expr::Sort(Sort::new(expr.clone(), true, false)));
            }
            _ => {
                return Err(DataFusionError::Plan(
                    "Order by only accepts sort expressions".to_string(),
                ))
            }
        }
    }

    let mut sort_keys: WindowSortKey = Vec::new();

    for partition_expr in partition_by {
        let default_key = partition_expr.clone().sort(true, false);
        let matching_order_by = normalized_order_by
            .iter()
            .position(|candidate| candidate.eq(&default_key));
        match matching_order_by {
            Some(pos) => {
                // Prefer the direction the user asked for in ORDER BY.
                let order_by_key = &order_by[pos];
                if !has_key(&sort_keys, order_by_key) {
                    sort_keys.push((order_by_key.clone(), true));
                }
            }
            None => {
                if !has_key(&sort_keys, &default_key) {
                    sort_keys.push((default_key, true));
                }
            }
        }
    }

    for key in order_by {
        if !has_key(&sort_keys, key) {
            sort_keys.push((key.clone(), false));
        }
    }

    Ok(sort_keys)
}
/// Orders two sort expressions relative to `schema`.
///
/// The comparison proceeds through these tie-breakers, in order:
/// 1. the schema indexes referenced by each expression, compared pairwise;
/// 2. the number of referenced indexes, where more references sort first;
/// 3. direction, where descending sorts before ascending;
/// 4. null placement, where nulls-first sorts before nulls-last.
///
/// If either argument is not an `Expr::Sort`, the result is
/// `Ordering::Equal`.
pub fn compare_sort_expr(
    sort_expr_a: &Expr,
    sort_expr_b: &Expr,
    schema: &DFSchemaRef,
) -> Ordering {
    let (sort_a, sort_b) = match (sort_expr_a, sort_expr_b) {
        (Expr::Sort(a), Expr::Sort(b)) => (a, b),
        _ => return Ordering::Equal,
    };

    let indexes_a = find_column_indexes_referenced_by_expr(&sort_a.expr, schema);
    let indexes_b = find_column_indexes_referenced_by_expr(&sort_b.expr, schema);

    // First differing pair of referenced indexes decides, if there is one.
    let by_position = indexes_a
        .iter()
        .zip(indexes_b.iter())
        .map(|(idx_a, idx_b)| idx_a.cmp(idx_b))
        .find(|ordering| *ordering != Ordering::Equal)
        .unwrap_or(Ordering::Equal);

    by_position
        // Reversed on purpose: the longer index list sorts first.
        .then_with(|| indexes_b.len().cmp(&indexes_a.len()))
        // `false < true`, so descending (asc = false) sorts first.
        .then_with(|| sort_a.asc.cmp(&sort_b.asc))
        // Reversed: nulls_first = true sorts first.
        .then_with(|| sort_b.nulls_first.cmp(&sort_a.nulls_first))
}
/// Groups window expressions that share the same sort key.
///
/// Groups appear in order of first occurrence, and expressions keep their
/// input order within each group.
///
/// # Errors
///
/// Returns an internal error if any element is not an
/// `Expr::WindowFunction`, and propagates errors from `generate_sort_key`.
pub fn group_window_expr_by_sort_keys(
    window_expr: &[Expr],
) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
    let mut groups: Vec<(WindowSortKey, Vec<&Expr>)> = Vec::new();

    for expr in window_expr {
        let (partition_by, order_by) = match expr {
            Expr::WindowFunction(WindowFunction {
                partition_by,
                order_by,
                ..
            }) => (partition_by, order_by),
            other => {
                return Err(DataFusionError::Internal(format!(
                    "Impossibly got non-window expr {other:?}",
                )))
            }
        };

        let sort_key = generate_sort_key(partition_by, order_by)?;
        match groups.iter_mut().find(|group| group.0 == sort_key) {
            Some((_, members)) => members.push(expr),
            None => groups.push((sort_key, vec![expr])),
        }
    }

    Ok(groups)
}
/// Collects the distinct aggregate expressions (built-in or user-defined)
/// found anywhere in `exprs`, in order of first appearance.
///
/// The search does not descend into an aggregate once it has matched.
pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
    fn is_aggregate(candidate: &Expr) -> bool {
        matches!(
            candidate,
            Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
        )
    }
    find_exprs_in_exprs(exprs, &is_aggregate)
}
/// Collects the distinct `Expr::Sort` expressions found anywhere in `exprs`,
/// in order of first appearance.
///
/// The search does not descend into a sort expression once it has matched.
pub fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
    fn is_sort(candidate: &Expr) -> bool {
        matches!(candidate, Expr::Sort { .. })
    }
    find_exprs_in_exprs(exprs, &is_sort)
}
/// Collects the distinct `Expr::WindowFunction` expressions found anywhere in
/// `exprs`, in order of first appearance.
///
/// The search does not descend into a window function once it has matched.
pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
    fn is_window(candidate: &Expr) -> bool {
        matches!(candidate, Expr::WindowFunction { .. })
    }
    find_exprs_in_exprs(exprs, &is_window)
}
/// Searches every expression in `exprs` for sub-expressions accepted by
/// `test_fn` and returns them without duplicates, in order of first
/// appearance.
fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut unique: Vec<Expr> = Vec::new();
    for expr in exprs {
        for found in find_exprs_in_expr(expr, test_fn) {
            // Matches are deduplicated across the whole input list.
            if !unique.contains(&found) {
                unique.push(found);
            }
        }
    }
    unique
}
/// Searches one expression tree for sub-expressions accepted by `test_fn`,
/// using a [`Finder`] visitor, and returns the matches it collected.
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let finder = expr
        .accept(Finder::new(test_fn))
        // `Finder::pre_visit` only ever returns `Ok`.
        .expect("no way to return error during recursion");
    finder.exprs
}
/// Expression visitor state used to collect the sub-expressions of a tree
/// that satisfy a predicate.
struct Finder<'a, F>
where
F: Fn(&Expr) -> bool,
{
// Predicate deciding whether an expression should be collected.
test_fn: &'a F,
// Expressions collected so far.
exprs: Vec<Expr>,
}
impl<'a, F> Finder<'a, F>
where
F: Fn(&Expr) -> bool,
{
fn new(test_fn: &'a F) -> Self {
Self {
test_fn,
exprs: Vec::new(),
}
}
}
impl<'a, F> ExpressionVisitor for Finder<'a, F>
where
    F: Fn(&Expr) -> bool,
{
    /// Records `expr` if it satisfies the predicate and has not been recorded
    /// yet, then stops descending; otherwise continues into its children.
    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
        if !(self.test_fn)(expr) {
            return Ok(Recursion::Continue(self));
        }
        if !self.exprs.contains(expr) {
            self.exprs.push(expr.clone());
        }
        // Do not look inside a matching expression.
        Ok(Recursion::Stop(self))
    }
}
/// Builds a new `LogicalPlan` node of the same variant as `plan`, using the
/// supplied expressions and input plans in place of the originals.
///
/// * `plan` - template node; fields that are neither expressions nor inputs
///   (schemas, names, flags, fetch/skip values, ...) are copied from it.
/// * `expr` - replacement expressions; how they are interpreted depends on
///   the variant (see the comments on each arm).
/// * `inputs` - replacement child plans; most arms read `inputs[0]`, the
///   join arms also read `inputs[1]`, and `Union` uses all of them.
///
/// # Errors
///
/// Propagates errors from the constructors it calls. The `Join` arm returns
/// an internal error if one of the leading expressions is not (after
/// unaliasing) an `=` binary expression. The `Explain` arm returns a plan
/// error if `expr` or `inputs` is empty.
///
/// # Panics
///
/// Panics if `inputs` is shorter than the arm requires (direct indexing), if
/// `expr` is shorter than the slices taken in the `Window` / `Aggregate`
/// arms, or if one of the `assert!`s in the `Filter`, `Join`, `Analyze`,
/// `TableScan` and no-expression/no-input arms fails.
pub fn from_plan(
plan: &LogicalPlan,
expr: &[Expr],
inputs: &[LogicalPlan],
) -> Result<LogicalPlan> {
match plan {
// Projection: all of `expr` over the first input, keeping the old schema.
LogicalPlan::Projection(Projection { schema, .. }) => {
Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
expr.to_vec(),
Arc::new(inputs[0].clone()),
schema.clone(),
)?))
}
// DML: only the input is replaced; `expr` is not used.
LogicalPlan::Dml(DmlStatement {
table_name,
table_schema,
op,
..
}) => Ok(LogicalPlan::Dml(DmlStatement {
table_name: table_name.clone(),
table_schema: table_schema.clone(),
op: op.clone(),
input: Arc::new(inputs[0].clone()),
})),
// Values: `expr` is a flat list that is split back into rows, one row per
// `schema.fields().len()` expressions.
LogicalPlan::Values(Values { schema, .. }) => Ok(LogicalPlan::Values(Values {
schema: schema.clone(),
values: expr
.chunks_exact(schema.fields().len())
.map(|s| s.to_vec())
.collect::<Vec<_>>(),
})),
// Filter: exactly one expression, the predicate.
LogicalPlan::Filter { .. } => {
assert_eq!(1, expr.len());
let predicate = expr[0].clone();
// Rewriter that strips aliases from the predicate before the filter is
// rebuilt.
struct RemoveAliases {}
impl ExprRewriter for RemoveAliases {
fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
match expr {
// Subquery expressions are left untouched: the rewriter does not
// descend into them.
Expr::Exists { .. }
| Expr::ScalarSubquery(_)
| Expr::InSubquery { .. } => {
Ok(RewriteRecursion::Stop)
}
Expr::Alias(_, _) => Ok(RewriteRecursion::Mutate),
_ => Ok(RewriteRecursion::Continue),
}
}
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
Ok(expr.unalias())
}
}
let mut remove_aliases = RemoveAliases {};
let predicate = predicate.rewrite(&mut remove_aliases)?;
Ok(LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(inputs[0].clone()),
)?))
}
// Repartition: round-robin keeps its partition count and ignores `expr`;
// hash and distribute-by take all of `expr` as their expressions.
LogicalPlan::Repartition(Repartition {
partitioning_scheme,
..
}) => match partitioning_scheme {
Partitioning::RoundRobinBatch(n) => {
Ok(LogicalPlan::Repartition(Repartition {
partitioning_scheme: Partitioning::RoundRobinBatch(*n),
input: Arc::new(inputs[0].clone()),
}))
}
Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition {
partitioning_scheme: Partitioning::Hash(expr.to_owned(), *n),
input: Arc::new(inputs[0].clone()),
})),
Partitioning::DistributeBy(_) => Ok(LogicalPlan::Repartition(Repartition {
partitioning_scheme: Partitioning::DistributeBy(expr.to_owned()),
input: Arc::new(inputs[0].clone()),
})),
},
// Window: the first `window_expr.len()` expressions become the new window
// expressions; the old schema is kept.
LogicalPlan::Window(Window {
window_expr,
schema,
..
}) => Ok(LogicalPlan::Window(Window {
input: Arc::new(inputs[0].clone()),
window_expr: expr[0..window_expr.len()].to_vec(),
schema: schema.clone(),
})),
// Aggregate: `expr` is the group expressions followed by the aggregate
// expressions; the split point is the old number of group expressions.
LogicalPlan::Aggregate(Aggregate {
group_expr, schema, ..
}) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
Arc::new(inputs[0].clone()),
expr[0..group_expr.len()].to_vec(),
expr[group_expr.len()..].to_vec(),
schema.clone(),
)?)),
// Sort: all of `expr` are the sort expressions; `fetch` is preserved.
LogicalPlan::Sort(SortPlan { fetch, .. }) => Ok(LogicalPlan::Sort(SortPlan {
expr: expr.to_vec(),
input: Arc::new(inputs[0].clone()),
fetch: *fetch,
})),
// Join: the schema is recomputed from the two new inputs. The first
// `on.len()` expressions are the equi-join conditions; if there are more
// expressions than that, the last one is taken as the join filter.
LogicalPlan::Join(Join {
join_type,
join_constraint,
on,
null_equals_null,
..
}) => {
let schema =
build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
let equi_expr_count = on.len();
assert!(expr.len() >= equi_expr_count);
// Each equi-join expression must be `left = right` once any alias is
// removed; it is split back into its `(left, right)` pair.
let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| {
let unalias_expr = equi_expr.clone().unalias();
if let Expr::BinaryExpr(BinaryExpr { left, op:Operator::Eq, right }) = unalias_expr {
Ok((*left, *right))
} else {
Err(DataFusionError::Internal(format!(
"The front part expressions should be an binary equiality expression, actual:{equi_expr}"
)))
}
}).collect::<Result<Vec<(Expr, Expr)>>>()?;
let filter_expr =
(expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone());
Ok(LogicalPlan::Join(Join {
left: Arc::new(inputs[0].clone()),
right: Arc::new(inputs[1].clone()),
join_type: *join_type,
join_constraint: *join_constraint,
on: new_on,
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equals_null: *null_equals_null,
}))
}
// Cross join: rebuilt through the plan builder from the two new inputs.
LogicalPlan::CrossJoin(_) => {
let left = inputs[0].clone();
let right = inputs[1].clone();
LogicalPlanBuilder::from(left).cross_join(right)?.build()
}
// Subquery: wraps the first input, passed through the plan builder.
LogicalPlan::Subquery(_) => {
let subquery = LogicalPlanBuilder::from(inputs[0].clone()).build()?;
Ok(LogicalPlan::Subquery(Subquery {
subquery: Arc::new(subquery),
}))
}
// Subquery alias: the schema is rebuilt from the new input's schema,
// qualified with the alias.
LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => {
let schema = inputs[0].schema().as_ref().clone().into();
let schema =
DFSchemaRef::new(DFSchema::try_from_qualified_schema(alias, &schema)?);
Ok(LogicalPlan::SubqueryAlias(SubqueryAlias {
alias: alias.clone(),
input: Arc::new(inputs[0].clone()),
schema,
}))
}
// The following arms replace only the input and copy everything else.
LogicalPlan::Limit(Limit { skip, fetch, .. }) => Ok(LogicalPlan::Limit(Limit {
skip: *skip,
fetch: *fetch,
input: Arc::new(inputs[0].clone()),
})),
LogicalPlan::CreateMemoryTable(CreateMemoryTable {
name,
if_not_exists,
or_replace,
..
}) => Ok(LogicalPlan::CreateMemoryTable(CreateMemoryTable {
input: Arc::new(inputs[0].clone()),
name: name.clone(),
if_not_exists: *if_not_exists,
or_replace: *or_replace,
})),
LogicalPlan::CreateView(CreateView {
name,
or_replace,
definition,
..
}) => Ok(LogicalPlan::CreateView(CreateView {
input: Arc::new(inputs[0].clone()),
name: name.clone(),
or_replace: *or_replace,
definition: definition.clone(),
})),
// Extension: the user-defined node rebuilds itself from the new
// expressions and inputs.
LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension {
node: e.node.from_template(expr, inputs),
})),
// Union: every element of `inputs` becomes a child; the old schema is kept.
LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union {
inputs: inputs.iter().cloned().map(Arc::new).collect(),
schema: schema.clone(),
})),
LogicalPlan::Distinct(Distinct { .. }) => Ok(LogicalPlan::Distinct(Distinct {
input: Arc::new(inputs[0].clone()),
})),
// Analyze: must be called with no expressions and exactly one input.
LogicalPlan::Analyze(a) => {
assert!(expr.is_empty());
assert_eq!(inputs.len(), 1);
Ok(LogicalPlan::Analyze(Analyze {
verbose: a.verbose,
schema: a.schema.clone(),
input: Arc::new(inputs[0].clone()),
}))
}
// Explain: the node is returned unchanged (neither `expr` nor `inputs` is
// applied), but empty `expr` or `inputs` is reported as a plan error.
LogicalPlan::Explain(_) => {
if expr.is_empty() {
return Err(DataFusionError::Plan(
"Invalid EXPLAIN command. Expression is empty".to_string(),
));
}
if inputs.is_empty() {
return Err(DataFusionError::Plan(
"Invalid EXPLAIN command. Inputs are empty".to_string(),
));
}
Ok(plan.clone())
}
LogicalPlan::Prepare(Prepare {
name, data_types, ..
}) => Ok(LogicalPlan::Prepare(Prepare {
name: name.clone(),
data_types: data_types.clone(),
input: Arc::new(inputs[0].clone()),
})),
// Table scan: `expr` becomes the scan's filters; all other fields are
// copied from the original scan. A scan has no inputs.
LogicalPlan::TableScan(ts) => {
assert!(inputs.is_empty(), "{plan:?} should have no inputs");
Ok(LogicalPlan::TableScan(TableScan {
filters: expr.to_vec(),
..ts.clone()
}))
}
// Nodes with neither expressions nor inputs are cloned as-is.
LogicalPlan::EmptyRelation(_)
| LogicalPlan::CreateExternalTable(_)
| LogicalPlan::DropTable(_)
| LogicalPlan::DropView(_)
| LogicalPlan::SetVariable(_)
| LogicalPlan::CreateCatalogSchema(_)
| LogicalPlan::CreateCatalog(_) => {
assert!(expr.is_empty(), "{plan:?} should have no exprs");
assert!(inputs.is_empty(), "{plan:?} should have no inputs");
Ok(plan.clone())
}
// DescribeTable is cloned as-is without checking `expr` or `inputs`.
LogicalPlan::DescribeTable(_) => Ok(plan.clone()),
}
}
fn agg_cols(agg: &Aggregate) -> Result<Vec<Column>> {
Ok(agg
.aggr_expr
.iter()
.chain(&agg.group_expr)
.flat_map(find_columns_referenced_by_expr)
.collect())
}
/// Resolves each expression in `exprs` to a field, for a plan that is, or
/// sits on top of, the aggregate `agg`.
///
/// A bare column that is referenced by the aggregate's own expressions is
/// resolved against the aggregate's *input* schema; everything else is
/// resolved against `plan`'s schema.
fn exprlist_to_fields_aggregate(
    exprs: &[Expr],
    plan: &LogicalPlan,
    agg: &Aggregate,
) -> Result<Vec<DFField>> {
    let aggregate_columns = agg_cols(agg)?;
    exprs
        .iter()
        .map(|expr| {
            let schema = match expr {
                Expr::Column(c) if aggregate_columns.contains(c) => agg.input.schema(),
                _ => plan.schema(),
            };
            expr.to_field(schema)
        })
        .collect()
}
/// Resolves each expression to a field of the given plan.
///
/// If `plan` is an aggregate, or a window directly over an aggregate, the
/// aggregate-aware resolution in `exprlist_to_fields_aggregate` is used.
/// Otherwise every expression is resolved against `plan`'s schema.
pub fn exprlist_to_fields<'a>(
    expr: impl IntoIterator<Item = &'a Expr>,
    plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
    let exprs: Vec<Expr> = expr.into_iter().cloned().collect();

    // Find the aggregate that governs field resolution, if any.
    let governing_aggregate = match plan {
        LogicalPlan::Aggregate(agg) => Some(agg),
        LogicalPlan::Window(window) => match window.input.as_ref() {
            LogicalPlan::Aggregate(agg) => Some(agg),
            _ => None,
        },
        _ => None,
    };

    match governing_aggregate {
        Some(agg) => exprlist_to_fields_aggregate(&exprs, plan, agg),
        None => {
            let input_schema = plan.schema();
            exprs.iter().map(|e| e.to_field(input_schema)).collect()
        }
    }
}
/// Replaces an expression with a reference to the column of `input_schema`
/// that already carries its value, when such a column exists.
///
/// * Columns and scalar subqueries are returned unchanged.
/// * Aliases, casts and try-casts are kept, and their inner expression is
///   converted recursively.
/// * Any other expression becomes a column reference if `input_schema` has a
///   field whose unqualified name equals the expression's display name;
///   otherwise it is returned unchanged.
pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
    match e {
        Expr::Column(_) | Expr::ScalarSubquery(_) => e,
        Expr::Alias(inner_expr, name) => {
            columnize_expr(*inner_expr, input_schema).alias(name)
        }
        Expr::Cast(Cast { expr, data_type }) => {
            let inner = Box::new(columnize_expr(*expr, input_schema));
            Expr::Cast(Cast {
                expr: inner,
                data_type,
            })
        }
        Expr::TryCast(TryCast { expr, data_type }) => {
            let inner = Box::new(columnize_expr(*expr, input_schema));
            Expr::TryCast(TryCast::new(inner, data_type))
        }
        other => {
            // Failure to compute a name or to find the field both mean
            // "leave the expression as it is".
            let matching_field = other
                .display_name()
                .ok()
                .and_then(|name| input_schema.field_with_unqualified_name(&name).ok());
            match matching_field {
                Some(field) => Expr::Column(field.qualified_column()),
                None => other,
            }
        }
    }
}
pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
exprs
.iter()
.flat_map(find_columns_referenced_by_expr)
.map(Expr::Column)
.collect()
}
/// Returns every column referenced inside `e`, in pre-order traversal order.
/// Duplicates are not removed.
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
    let mut columns = Vec::new();
    let traversal: Result<()> = inspect_expr_pre(e, |node| {
        match node {
            Expr::Column(c) => columns.push(c.clone()),
            _ => {}
        }
        Ok(())
    });
    // The closure above never returns an error.
    traversal.expect("Unexpected error");
    columns
}
/// Converts an expression into a column expression.
///
/// A column is resolved against `plan`'s schema and returned with that
/// field's qualified name. Any other expression becomes an unqualified
/// column named after the expression's display name.
pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    let column = match expr {
        Expr::Column(col) => plan.schema().field_from_column(col)?.qualified_column(),
        _ => Column::from_name(expr.display_name()?),
    };
    Ok(Expr::Column(column))
}
/// Returns the schema indexes of the columns referenced inside `e`, in
/// pre-order traversal order.
///
/// Columns that cannot be found in `schema` are skipped. Each literal
/// contributes `usize::MAX`.
pub(crate) fn find_column_indexes_referenced_by_expr(
    e: &Expr,
    schema: &DFSchemaRef,
) -> Vec<usize> {
    let mut indexes = Vec::new();
    let traversal: Result<()> = inspect_expr_pre(e, |node| {
        if let Expr::Column(qc) = node {
            if let Ok(idx) = schema.index_of_column(qc) {
                indexes.push(idx);
            }
        } else if let Expr::Literal(_) = node {
            indexes.push(usize::MAX);
        }
        Ok(())
    });
    // The closure above never returns an error.
    traversal.unwrap();
    indexes
}
/// Reports whether `data_type` is one of the types this function accepts as
/// hashable.
///
/// Accepted: null, boolean, the signed and unsigned integers, 32- and 64-bit
/// floats, timestamps without a time zone, UTF-8 and large UTF-8 strings,
/// 128-bit decimals, 32- and 64-bit dates, and dictionaries whose values are
/// UTF-8 and whose key type is a valid dictionary key type.
pub fn can_hash(data_type: &DataType) -> bool {
    match data_type {
        DataType::Null
        | DataType::Boolean
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64
        | DataType::Float32
        | DataType::Float64
        | DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Decimal128(_, _)
        | DataType::Date32
        | DataType::Date64 => true,
        // Only timestamps without a time zone are accepted.
        DataType::Timestamp(
            TimeUnit::Second
            | TimeUnit::Millisecond
            | TimeUnit::Microsecond
            | TimeUnit::Nanosecond,
            None,
        ) => true,
        DataType::Dictionary(key_type, value_type) => {
            matches!(value_type.as_ref(), DataType::Utf8)
                && DataType::is_dictionary_key_type(key_type)
        }
        _ => false,
    }
}
/// Returns `true` when every column in `columns` can be resolved in
/// `schema`. This is vacuously `true` for an empty set.
pub fn check_all_column_from_schema(
    columns: &HashSet<Column>,
    schema: DFSchemaRef,
) -> bool {
    for column in columns {
        if schema.index_of_column(column).is_err() {
            return false;
        }
    }
    true
}
/// Decides whether `left_key = right_key` can serve as an equi-join key pair
/// for the given left and right schemas.
///
/// Returns `Some((l, r))` where `l` only references columns of `left_schema`
/// and `r` only references columns of `right_schema`. The two keys are
/// swapped if that is the only orientation that fits. Returns `None` if
/// either key references no columns or if neither orientation fits.
pub fn find_valid_equijoin_key_pair(
    left_key: &Expr,
    right_key: &Expr,
    left_schema: DFSchemaRef,
    right_schema: DFSchemaRef,
) -> Result<Option<(Expr, Expr)>> {
    let left_key_columns = left_key.to_columns()?;
    let right_key_columns = right_key.to_columns()?;

    // A key with no column references cannot be a join key.
    if left_key_columns.is_empty() || right_key_columns.is_empty() {
        return Ok(None);
    }

    let left_key_fits_left =
        check_all_column_from_schema(&left_key_columns, left_schema.clone());
    let right_key_fits_right =
        check_all_column_from_schema(&right_key_columns, right_schema.clone());
    if left_key_fits_left && right_key_fits_right {
        return Ok(Some((left_key.clone(), right_key.clone())));
    }

    // Otherwise try the keys the other way round.
    let fits_when_swapped = check_all_column_from_schema(&right_key_columns, left_schema)
        && check_all_column_from_schema(&left_key_columns, right_schema);
    Ok(fits_when_swapped.then(|| (right_key.clone(), left_key.clone())))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, expr, AggregateFunction, WindowFrame, WindowFunction};

    /// Builds an aggregate window expression over a single column, with no
    /// partitioning and the given ordering and frame.
    fn window_agg(
        fun: AggregateFunction,
        column: &str,
        order_by: Vec<Expr>,
        frame: WindowFrame,
    ) -> Expr {
        Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(fun),
            vec![col(column)],
            vec![],
            order_by,
            frame,
        ))
    }

    /// Builds a sort expression on a single column.
    fn sort_on(column: &str, asc: bool, nulls_first: bool) -> Expr {
        Expr::Sort(expr::Sort::new(Box::new(col(column)), asc, nulls_first))
    }

    #[test]
    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
        let result = group_window_expr_by_sort_keys(&[])?;
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
        // None of these has an ordering, so they all share the empty key.
        let max1 =
            window_agg(AggregateFunction::Max, "name", vec![], WindowFrame::new(false));
        let max2 =
            window_agg(AggregateFunction::Max, "name", vec![], WindowFrame::new(false));
        let min3 =
            window_agg(AggregateFunction::Min, "name", vec![], WindowFrame::new(false));
        let sum4 =
            window_agg(AggregateFunction::Sum, "age", vec![], WindowFrame::new(false));

        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs)?;

        let key = vec![];
        let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
            vec![(key, vec![&max1, &max2, &min3, &sum4])];
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn test_group_window_expr_by_sort_keys() -> Result<()> {
        let age_asc = sort_on("age", true, true);
        let name_desc = sort_on("name", false, true);
        let created_at_desc = sort_on("created_at", false, true);

        let max1 = window_agg(
            AggregateFunction::Max,
            "name",
            vec![age_asc.clone(), name_desc.clone()],
            WindowFrame::new(true),
        );
        let max2 =
            window_agg(AggregateFunction::Max, "name", vec![], WindowFrame::new(false));
        let min3 = window_agg(
            AggregateFunction::Min,
            "name",
            vec![age_asc.clone(), name_desc.clone()],
            WindowFrame::new(true),
        );
        let sum4 = window_agg(
            AggregateFunction::Sum,
            "age",
            vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
            WindowFrame::new(true),
        );

        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs)?;

        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
        let key2 = vec![];
        let key3 = vec![
            (name_desc, false),
            (age_asc, false),
            (created_at_desc, false),
        ];
        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
            (key1, vec![&max1, &min3]),
            (key2, vec![&max2]),
            (key3, vec![&sum4]),
        ];
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn test_find_sort_exprs() -> Result<()> {
        let exprs = &[
            window_agg(
                AggregateFunction::Max,
                "name",
                vec![sort_on("age", true, true), sort_on("name", false, true)],
                WindowFrame::new(true),
            ),
            window_agg(
                AggregateFunction::Sum,
                "age",
                vec![
                    sort_on("name", false, true),
                    sort_on("age", true, true),
                    sort_on("created_at", false, true),
                ],
                WindowFrame::new(true),
            ),
        ];
        // Each sort expression is reported once, in order of first appearance.
        let expected = vec![
            sort_on("age", true, true),
            sort_on("name", false, true),
            sort_on("created_at", false, true),
        ];
        let result = find_sort_exprs(exprs);
        assert_eq!(expected, result);
        Ok(())
    }

    #[test]
    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
        let sort_key = |name: &str, asc: bool, nulls_first: bool| {
            Expr::Sort(Sort {
                expr: Box::new(col(name)),
                asc,
                nulls_first,
            })
        };

        let partition_by = &[col("age"), col("name"), col("created_at")];
        for asc_ in [true, false] {
            for nulls_first_ in [true, false] {
                let order_by = &[
                    sort_key("age", asc_, nulls_first_),
                    sort_key("name", asc_, nulls_first_),
                ];
                // Partition columns that also appear in ORDER BY take the
                // ORDER BY direction; `created_at` gets the default one.
                let expected = vec![
                    (sort_key("age", asc_, nulls_first_), true),
                    (sort_key("name", asc_, nulls_first_), true),
                    (sort_key("created_at", true, false), true),
                ];
                let result = generate_sort_key(partition_by, order_by)?;
                assert_eq!(expected, result);
            }
        }
        Ok(())
    }
}