use std::collections::HashSet;
use std::sync::Arc;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::SchemaRef;
use datafusion_common::{
get_required_group_by_exprs_indices, internal_err, Column, DFSchema, DFSchemaRef,
JoinType, Result,
};
use datafusion_expr::expr::{Alias, ScalarFunction};
use datafusion_expr::{
logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct,
Expr, Projection, TableScan, Window,
};
use datafusion_expr::utils::inspect_expr_pre;
use hashbrown::HashMap;
use itertools::{izip, Itertools};
/// Optimizer rule that narrows a logical plan to the columns that are actually
/// required. Its [`OptimizerRule::try_optimize`] implementation delegates to
/// `optimize_projections`, starting from "every output column of the root is
/// required".
///
/// The type carries no state; `new()` and `default()` are equivalent.
#[derive(Default, Debug)]
pub struct OptimizeProjections {}

impl OptimizeProjections {
    /// Creates a new, stateless `OptimizeProjections` rule.
    pub fn new() -> Self {
        Self {}
    }
}
impl OptimizerRule for OptimizeProjections {
    /// Entry point of the rule.
    ///
    /// Nothing sits above the root `plan`, so every one of its output columns is
    /// treated as required; the recursive `optimize_projections` routine then
    /// works out what each descendant has to produce.
    ///
    /// Returns whatever `optimize_projections` returns: `Ok(Some(plan))` for a
    /// rewritten plan, `Ok(None)` when nothing changed.
    fn try_optimize(
        &self,
        plan: &LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        let field_count = plan.schema().fields().len();
        let all_output_indices: Vec<usize> = (0..field_count).collect();
        optimize_projections(plan, config, all_output_indices.as_slice())
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "optimize_projections"
    }

    /// No apply order is requested.
    // NOTE(review): presumably because `optimize_projections` recurses into the
    // children itself -- confirm against the `OptimizerRule` documentation.
    fn apply_order(&self) -> Option<ApplyOrder> {
        None
    }
}
/// Recursively rewrites `plan` so that it (and its descendants) only produce the
/// columns that are needed, given that the parent of `plan` requires the output
/// columns at positions `indices`.
///
/// # Arguments
/// * `plan` - the node to optimize.
/// * `config` - optimizer configuration; in this function it is only forwarded
///   to recursive calls and to `rewrite_projection_given_requirements`.
/// * `indices` - positions (into `plan`'s output schema) required by the parent.
///
/// # Returns
/// * `Ok(Some(new_plan))` when a rewritten plan was produced.
/// * `Ok(None)` when the node kind is left alone, when an extension node does
///   not report its requirements, or when no child changed.
/// * `Err(..)` when any helper or plan constructor fails.
fn optimize_projections(
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
indices: &[usize],
) -> Result<Option<LogicalPlan>> {
// One entry per child: (column indices required from that child, whether a
// projection may be inserted on top of that child). Arms that completely
// handle their node `return` directly instead of producing this vector.
let child_required_indices: Vec<(Vec<usize>, bool)> = match plan {
LogicalPlan::Sort(_)
| LogicalPlan::Filter(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::Union(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Distinct(Distinct::On(_)) => {
// Each child must provide what the parent asked for plus whatever this
// node's own expressions reference; a projection may be added (`true`).
let exprs = plan.expressions();
plan.inputs()
.into_iter()
.map(|input| {
get_all_required_indices(indices, input, exprs.iter())
.map(|idxs| (idxs, true))
})
.collect::<Result<_>>()?
}
LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => {
// Same requirement computation as above, but no projection is inserted
// on top of the child (`false`).
let exprs = plan.expressions();
plan.inputs()
.into_iter()
.map(|input| {
get_all_required_indices(indices, input, exprs.iter())
.map(|idxs| (idxs, false))
})
.collect::<Result<_>>()?
}
LogicalPlan::Copy(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::Distinct(Distinct::All(_)) => {
// These nodes request every column of every child and never get a
// projection inserted below them.
plan.inputs()
.iter()
.map(|input| ((0..input.schema().fields().len()).collect_vec(), false))
.collect::<Vec<_>>()
}
LogicalPlan::Extension(extension) => {
// Ask the user-defined node which child columns it needs for the requested
// outputs; if it cannot tell (`None`), leave the subtree untouched.
let necessary_children_indices = if let Some(necessary_children_indices) =
extension.node.necessary_children_exprs(indices)
{
necessary_children_indices
} else {
return Ok(None);
};
let children = extension.node.inputs();
if children.len() != necessary_children_indices.len() {
return internal_err!("Inconsistent length between children and necessary children indices. \
Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \
consistent with actual children length for the node.");
}
let exprs = plan.expressions();
children
.into_iter()
.zip(necessary_children_indices)
.map(|(child, necessary_indices)| {
// Add the columns referenced by the node's own expressions to the
// columns the node reported as necessary for this child.
let child_schema = child.schema();
let child_req_indices =
indices_referred_by_exprs(child_schema, exprs.iter())?;
Ok((merge_slices(&necessary_indices, &child_req_indices), false))
})
.collect::<Result<Vec<_>>>()?
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::DescribeTable(_) => {
// These node kinds are returned unchanged.
return Ok(None);
}
LogicalPlan::Projection(proj) => {
// First try to fuse this projection with a projection directly below it,
// then trim whichever projection results to the required columns. If the
// fused projection needs no further rewrite it is still returned as the
// new plan.
return if let Some(proj) = merge_consecutive_projections(proj)? {
Ok(Some(
rewrite_projection_given_requirements(&proj, config, indices)?
.unwrap_or_else(|| LogicalPlan::Projection(proj)),
))
} else {
rewrite_projection_given_requirements(proj, config, indices)
};
}
LogicalPlan::Aggregate(aggregate) => {
// Required output indices below `n_group_exprs` select group-by
// expressions; the rest, after subtracting `n_group_exprs`, select
// aggregate expressions.
let n_group_exprs = aggregate.group_expr_len()?;
let (group_by_reqs, mut aggregate_reqs): (Vec<usize>, Vec<usize>) =
indices.iter().partition(|&&idx| idx < n_group_exprs);
for idx in aggregate_reqs.iter_mut() {
*idx -= n_group_exprs;
}
let group_by_expr_existing = aggregate
.group_expr
.iter()
.map(|group_by_expr| group_by_expr.display_name())
.collect::<Result<Vec<_>>>()?;
// NOTE(review): `get_required_group_by_exprs_indices` presumably returns a
// reduced set of group-by positions that still determines the grouping --
// confirm against its documentation. When it returns `None`, all group-by
// expressions are kept.
let new_group_bys = if let Some(simplest_groupby_indices) =
get_required_group_by_exprs_indices(
aggregate.input.schema(),
&group_by_expr_existing,
) {
let required_indices =
merge_slices(&simplest_groupby_indices, &group_by_reqs);
get_at_indices(&aggregate.group_expr, &required_indices)
} else {
aggregate.group_expr.clone()
};
let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs);
// If nothing at all would remain but the original node had aggregate
// expressions, keep the first one so the aggregate is not left empty.
if new_aggr_expr.is_empty()
&& new_group_bys.is_empty()
&& !aggregate.aggr_expr.is_empty()
{
new_aggr_expr = vec![aggregate.aggr_expr[0].clone()];
}
// Input columns referenced by the surviving group-by/aggregate expressions.
let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
let schema = aggregate.input.schema();
let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?;
let aggregate_input = if let Some(input) =
optimize_projections(&aggregate.input, config, &necessary_indices)?
{
input
} else {
aggregate.input.as_ref().clone()
};
// Put a projection above the (possibly rewritten) input when that narrows it.
let necessary_exprs = get_required_exprs(schema, &necessary_indices);
let (aggregate_input, _) =
add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?;
return Aggregate::try_new(
Arc::new(aggregate_input),
new_group_bys,
new_aggr_expr,
)
.map(|aggregate| Some(LogicalPlan::Aggregate(aggregate)));
}
LogicalPlan::Window(window) => {
// Required output indices below `n_input_fields` refer to input columns;
// the rest, after subtracting `n_input_fields`, select window expressions.
let n_input_fields = window.input.schema().fields().len();
let (child_reqs, mut window_reqs): (Vec<usize>, Vec<usize>) =
indices.iter().partition(|&&idx| idx < n_input_fields);
for idx in window_reqs.iter_mut() {
*idx -= n_input_fields;
}
let new_window_expr = get_at_indices(&window.window_expr, &window_reqs);
// The child must provide the pass-through columns plus everything the
// surviving window expressions reference.
let required_indices = get_all_required_indices(
&child_reqs,
&window.input,
new_window_expr.iter(),
)?;
let window_child = if let Some(new_window_child) =
optimize_projections(&window.input, config, &required_indices)?
{
new_window_child
} else {
window.input.as_ref().clone()
};
// With no window expression left, the window node is dropped and its
// child is returned in its place.
return if new_window_expr.is_empty() {
Ok(Some(window_child))
} else {
let required_exprs =
get_required_exprs(window.input.schema(), &required_indices);
let (window_child, _) =
add_projection_on_top_if_helpful(window_child, required_exprs)?;
Window::try_new(new_window_expr, Arc::new(window_child))
.map(|window| Some(LogicalPlan::Window(window)))
};
}
LogicalPlan::Join(join) => {
// Split the parent's requirements between the two sides, then add the
// columns referenced by the join's own expressions to each side.
let left_len = join.left.schema().fields().len();
let (left_req_indices, right_req_indices) =
split_join_requirements(left_len, indices, &join.join_type);
let exprs = plan.expressions();
let left_indices =
get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?;
let right_indices =
get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?;
vec![(left_indices, true), (right_indices, true)]
}
LogicalPlan::CrossJoin(cross_join) => {
// Requirements are split the same way as for an inner join.
let left_len = cross_join.left.schema().fields().len();
let (left_child_indices, right_child_indices) =
split_join_requirements(left_len, indices, &JoinType::Inner);
vec![(left_child_indices, true), (right_child_indices, true)]
}
LogicalPlan::TableScan(table_scan) => {
// Translate the required output positions through the scan's current
// projection (or through the identity mapping over the source schema when
// the scan has no projection) and rebuild the scan with that projection.
let schema = table_scan.source.schema();
let projection = with_indices(&table_scan.projection, schema, |map| {
indices.iter().map(|&idx| map[idx]).collect()
});
return TableScan::try_new(
table_scan.table_name.clone(),
table_scan.source.clone(),
Some(projection),
table_scan.filters.clone(),
table_scan.fetch,
)
.map(|table| Some(LogicalPlan::TableScan(table)));
}
};
// Generic path: optimize every child with its requirements. An entry is
// `Some(new_child)` only if the child was rewritten or a projection was added.
let new_inputs = izip!(child_required_indices, plan.inputs().into_iter())
.map(|((required_indices, projection_beneficial), child)| {
let (input, is_changed) = if let Some(new_input) =
optimize_projections(child, config, &required_indices)?
{
(new_input, true)
} else {
(child.clone(), false)
};
let project_exprs = get_required_exprs(child.schema(), &required_indices);
let (input, proj_added) = if projection_beneficial {
add_projection_on_top_if_helpful(input, project_exprs)?
} else {
(input, false)
};
Ok((is_changed || proj_added).then_some(input))
})
.collect::<Result<Vec<_>>>()?;
if new_inputs.iter().all(|child| child.is_none()) {
// No child changed, so this node does not change either.
Ok(None)
} else {
// Rebuild the node with the same expressions, substituting the original
// child wherever a child was left unchanged.
let new_inputs = izip!(new_inputs, plan.inputs())
.map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone()))
.collect();
let exprs = plan.expressions();
plan.with_new_exprs(exprs, new_inputs).map(Some)
}
}
/// Calls `f` with the index mapping of a scan and returns its result.
///
/// When `proj_indices` is `Some`, `f` receives exactly those indices. When it
/// is `None`, `f` receives the identity mapping `0..schema.fields.len()`.
fn with_indices<F>(
    proj_indices: &Option<Vec<usize>>,
    schema: SchemaRef,
    mut f: F,
) -> Vec<usize>
where
    F: FnMut(&[usize]) -> Vec<usize>,
{
    if let Some(existing) = proj_indices {
        return f(&existing[..]);
    }
    // No projection yet: every field of the schema maps to itself.
    let every_field = (0..schema.fields.len()).collect::<Vec<usize>>();
    f(&every_field[..])
}
/// Tries to fuse `proj` with the projection directly beneath it.
///
/// # Returns
/// * `Ok(Some(merged))` - a single projection over the grandchild input whose
///   expressions are `proj`'s expressions with every column replaced by the
///   expression that produced it in the lower projection. Original output
///   names are kept via `alias_if_changed`.
/// * `Ok(None)` - the input of `proj` is not a projection, a non-trivial lower
///   expression is referenced more than once (merging would duplicate its
///   computation), or some expression is of a kind `rewrite_expr` cannot
///   rewrite.
///
/// # Errors
/// Propagates failures from column extraction, from resolving a column
/// against the lower projection's schema, from naming, and from
/// `Projection::try_new`.
fn merge_consecutive_projections(proj: &Projection) -> Result<Option<Projection>> {
    let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else {
        return Ok(None);
    };

    // Count, per column of the lower projection, how many of the upper
    // projection's expressions reference it.
    let mut column_referral_map = HashMap::<Column, usize>::new();
    for expr in proj.expr.iter() {
        for col in expr.to_columns()? {
            *column_referral_map.entry(col).or_default() += 1;
        }
    }

    // A column used more than once must map to a trivial expression;
    // otherwise inlining would evaluate the same expression several times.
    for (col, usage) in column_referral_map {
        if usage > 1 {
            // Resolve with `?` rather than panicking, matching `rewrite_expr`.
            let idx = prev_projection.schema.index_of_column(&col)?;
            if !is_expr_trivial(&prev_projection.expr[idx]) {
                return Ok(None);
            }
        }
    }

    // Inline the lower projection's expressions into the upper ones. A single
    // unsupported expression makes the whole merge impossible.
    let new_exprs = proj
        .expr
        .iter()
        .map(|expr| rewrite_expr(expr, prev_projection))
        .collect::<Result<Option<Vec<_>>>>()?;
    let Some(new_exprs) = new_exprs else {
        return Ok(None);
    };

    // Keep the output schema stable: re-alias every rewritten expression whose
    // name changed to the name of the expression it replaces.
    let new_exprs = new_exprs
        .into_iter()
        .zip(proj.expr.iter())
        .map(|(new_expr, old_expr)| {
            new_expr.alias_if_changed(old_expr.name_for_alias()?)
        })
        .collect::<Result<Vec<_>>>()?;
    Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some)
}
/// Strips every enclosing `Expr::Alias` layer and returns the innermost
/// non-alias expression. Expressions that are not aliases are returned as-is.
fn trim_expr(expr: Expr) -> Expr {
    let mut current = expr;
    while let Expr::Alias(alias) = current {
        current = *alias.expr;
    }
    current
}
/// Returns `true` when `expr` is a plain column reference or a literal, and
/// `false` for every other expression kind.
fn is_expr_trivial(expr: &Expr) -> bool {
    let is_column = matches!(expr, Expr::Column(_));
    let is_literal = matches!(expr, Expr::Literal(_));
    is_column || is_literal
}
/// Calls `rewrite_expr($expr, $input)` and evaluates to the rewritten
/// expression. An `Err` is propagated with `?`; a `None` result makes the
/// *enclosing function* return `Ok(None)` immediately, so this macro may only
/// be used inside functions returning `Result<Option<_>>`.
macro_rules! rewrite_expr_with_check {
    ($expr:expr, $input:expr) => {
        match rewrite_expr($expr, $input)? {
            Some(rewritten) => rewritten,
            None => return Ok(None),
        }
    };
}
/// Rewrites `expr`, which is expressed over the output of the projection
/// `input`, so that it is expressed over that projection's own input instead.
///
/// Every column is replaced by the expression of `input` that produces it.
/// Only literals, columns, aliases, binary expressions, casts and scalar
/// functions are supported.
///
/// # Returns
/// * `Ok(Some(expr))` - the rewritten expression.
/// * `Ok(None)` - `expr` (or one of its sub-expressions) is of an unsupported
///   kind.
///
/// # Errors
/// Fails when a column cannot be resolved in `input`'s schema.
fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
    let rewritten = match expr {
        // Literals do not depend on the input at all.
        Expr::Literal(_) => expr.clone(),
        Expr::Column(col) => {
            let position = input.schema.index_of_column(col)?;
            input.expr[position].clone()
        }
        Expr::Alias(alias) => {
            // Aliases coming from the lower projection are dropped; the alias
            // of the upper expression is kept.
            let inner = trim_expr(rewrite_expr_with_check!(&alias.expr, input));
            Expr::Alias(Alias::new(
                inner,
                alias.relation.clone(),
                alias.name.clone(),
            ))
        }
        Expr::BinaryExpr(binary) => {
            let lhs = trim_expr(rewrite_expr_with_check!(&binary.left, input));
            let rhs = trim_expr(rewrite_expr_with_check!(&binary.right, input));
            Expr::BinaryExpr(BinaryExpr::new(Box::new(lhs), binary.op, Box::new(rhs)))
        }
        Expr::Cast(cast) => {
            let inner = rewrite_expr_with_check!(&cast.expr, input);
            Expr::Cast(Cast::new(Box::new(inner), cast.data_type.clone()))
        }
        Expr::ScalarFunction(scalar_fn) => {
            // Every argument must be rewritable, otherwise the call is not.
            let rewritten_args = scalar_fn
                .args
                .iter()
                .map(|arg| rewrite_expr(arg, input))
                .collect::<Result<Option<Vec<_>>>>()?;
            let Some(new_args) = rewritten_args else {
                return Ok(None);
            };
            Expr::ScalarFunction(ScalarFunction::new_func_def(
                scalar_fn.func_def.clone(),
                new_args,
            ))
        }
        _ => return Ok(None),
    };
    Ok(Some(rewritten))
}
/// Walks `expr` and inserts into `columns` every outer-reference column it
/// finds: columns of `Expr::OuterReferenceColumn` nodes, and outer references
/// found (recursively) in the `outer_ref_columns` of scalar, `EXISTS` and
/// `IN` subqueries.
fn outer_columns(expr: &Expr, columns: &mut HashSet<Column>) {
    inspect_expr_pre(expr, |node| -> Result<()> {
        let nested_outer_refs = match node {
            Expr::OuterReferenceColumn(_, col) => {
                columns.insert(col.clone());
                return Ok(());
            }
            Expr::ScalarSubquery(subquery) => &subquery.outer_ref_columns,
            Expr::Exists(exists) => &exists.subquery.outer_ref_columns,
            Expr::InSubquery(insubquery) => &insubquery.subquery.outer_ref_columns,
            _ => return Ok(()),
        };
        outer_columns_helper_multi(nested_outer_refs, columns);
        Ok(())
    })
    // The visitor closure above never returns `Err`.
    .unwrap();
}
/// Runs `outer_columns` on each expression in `exprs`, accumulating all
/// results into the same `columns` set.
fn outer_columns_helper_multi<'a>(
    exprs: impl IntoIterator<Item = &'a Expr>,
    columns: &mut HashSet<Column>,
) {
    for expr in exprs {
        outer_columns(expr, columns);
    }
}
/// Builds one qualified `Expr::Column` per entry of `indices`, in the same
/// order, taking each column from the field at that position of
/// `input_schema`.
///
/// # Panics
/// Panics if an index is out of bounds for the schema's fields.
fn get_required_exprs(input_schema: &Arc<DFSchema>, indices: &[usize]) -> Vec<Expr> {
    let mut required = Vec::with_capacity(indices.len());
    for &position in indices {
        let field = &input_schema.fields()[position];
        required.push(Expr::Column(field.qualified_column()));
    }
    required
}
/// Returns the sorted, duplicate-free positions in `input_schema` of all
/// columns referenced by any expression in `exprs`.
///
/// # Errors
/// Returns the first error produced by `indices_referred_by_expr`.
fn indices_referred_by_exprs<'a>(
    input_schema: &DFSchemaRef,
    exprs: impl Iterator<Item = &'a Expr>,
) -> Result<Vec<usize>> {
    let mut merged: Vec<usize> = Vec::new();
    for expr in exprs {
        merged.extend(indices_referred_by_expr(input_schema, expr)?);
    }
    merged.sort();
    merged.dedup();
    Ok(merged)
}
/// Returns the positions in `input_schema` of the columns referenced by
/// `expr`, including outer-reference columns collected by `outer_columns`.
///
/// Columns that cannot be resolved in `input_schema` are skipped rather than
/// reported as errors. The order of the result is unspecified, since it
/// follows the iteration order of a hash set.
fn indices_referred_by_expr(
    input_schema: &DFSchemaRef,
    expr: &Expr,
) -> Result<Vec<usize>> {
    let mut referenced = expr.to_columns()?;
    outer_columns(expr, &mut referenced);

    let mut positions = Vec::new();
    for col in referenced.iter() {
        // Lookup failures are deliberately ignored here.
        if let Ok(position) = input_schema.index_of_column(col) {
            positions.push(position);
        }
    }
    Ok(positions)
}
/// Computes everything that must be requested from `input`: the indices the
/// parent already requires (`parent_required_indices`) merged with the
/// indices of `input`'s columns referenced by `exprs`. The result is sorted
/// and free of duplicates.
fn get_all_required_indices<'a>(
    parent_required_indices: &[usize],
    input: &LogicalPlan,
    exprs: impl Iterator<Item = &'a Expr>,
) -> Result<Vec<usize>> {
    let referred = indices_referred_by_exprs(input.schema(), exprs)?;
    Ok(merge_slices(parent_required_indices, &referred))
}
/// Clones the expressions found at `indices`, preserving the order of
/// `indices`. Out-of-range indices are silently skipped.
fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec<Expr> {
    let mut selected = Vec::new();
    for &position in indices {
        if let Some(expr) = exprs.get(position) {
            selected.push(expr.clone());
        }
    }
    selected
}
/// Returns the union of `left` and `right` as a sorted vector without
/// duplicates.
///
/// The elements are cloned into a single vector sized for both inputs, which
/// is then sorted and deduplicated in place, so only one allocation is made.
fn merge_slices<T: Clone + Ord>(left: &[T], right: &[T]) -> Vec<T> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    merged.extend_from_slice(left);
    merged.extend_from_slice(right);
    // Stable sort followed by removal of adjacent equal elements.
    merged.sort();
    merged.dedup();
    merged
}
/// Splits the column indices required from a join's output into the indices
/// required from its left child and from its right child.
///
/// * Inner/Left/Right/Full: indices below `left_len` go to the left child
///   unchanged; the others go to the right child shifted down by `left_len`.
/// * LeftSemi/LeftAnti: all indices go to the left child.
/// * RightSemi/RightAnti: all indices go to the right child.
///
/// Returns `(left_indices, right_indices)`.
fn split_join_requirements(
    left_len: usize,
    indices: &[usize],
    join_type: &JoinType,
) -> (Vec<usize>, Vec<usize>) {
    match join_type {
        JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), Vec::new()),
        JoinType::RightSemi | JoinType::RightAnti => (Vec::new(), indices.to_vec()),
        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
            let mut left_side = Vec::new();
            let mut right_side = Vec::new();
            for &idx in indices {
                if idx < left_len {
                    left_side.push(idx);
                } else {
                    // Re-base onto the right child's own schema.
                    right_side.push(idx - left_len);
                }
            }
            (left_side, right_side)
        }
    }
}
/// Wraps `plan` in a projection of `project_exprs`, but only when that
/// projection has fewer expressions than `plan` has output fields.
///
/// Returns the resulting plan together with a flag telling whether a
/// projection was actually added.
fn add_projection_on_top_if_helpful(
    plan: LogicalPlan,
    project_exprs: Vec<Expr>,
) -> Result<(LogicalPlan, bool)> {
    // A projection that does not reduce the number of columns is not added.
    if project_exprs.len() >= plan.schema().fields().len() {
        return Ok((plan, false));
    }
    let projection = Projection::try_new(project_exprs, Arc::new(plan))?;
    Ok((LogicalPlan::Projection(projection), true))
}
/// Trims the projection `proj` down to the expressions at positions `indices`
/// (those its parent requires) and optimizes its input accordingly.
///
/// # Returns
/// * `Ok(Some(plan))` - the input was rewritten and/or expressions were
///   dropped. `plan` is the input itself when the remaining projection would
///   be a no-op (see `is_projection_unnecessary`), otherwise a new projection.
/// * `Ok(None)` - the input did not change and no expression was dropped.
fn rewrite_projection_given_requirements(
    proj: &Projection,
    config: &dyn OptimizerConfig,
    indices: &[usize],
) -> Result<Option<LogicalPlan>> {
    let exprs_used = get_at_indices(&proj.expr, indices);
    let required_indices =
        indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?;

    match optimize_projections(&proj.input, config, &required_indices)? {
        // The input changed: rebuild on top of it, or drop the projection.
        Some(new_input) => {
            if is_projection_unnecessary(&new_input, &exprs_used)? {
                Ok(Some(new_input))
            } else {
                let trimmed = Projection::try_new(exprs_used, Arc::new(new_input))?;
                Ok(Some(LogicalPlan::Projection(trimmed)))
            }
        }
        // The input is unchanged but some expressions are no longer needed.
        None if exprs_used.len() < proj.expr.len() => {
            if is_projection_unnecessary(&proj.input, &exprs_used)? {
                Ok(Some(proj.input.as_ref().clone()))
            } else {
                let trimmed = Projection::try_new(exprs_used, proj.input.clone())?;
                Ok(Some(LogicalPlan::Projection(trimmed)))
            }
        }
        None => Ok(None),
    }
}
/// Tells whether projecting `proj_exprs` on top of `input` would be a no-op:
/// the projected schema equals `input`'s schema and every expression is
/// trivial (a column or a literal).
///
/// # Errors
/// Fails when the projected schema cannot be computed.
fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result<bool> {
    let projected_schema = projection_schema(input, proj_exprs)?;
    let schema_unchanged = &projected_schema == input.schema();
    Ok(schema_unchanged && proj_exprs.iter().all(is_expr_trivial))
}
// Unit tests for the `OptimizeProjections` rule. Each test builds a logical
// plan, runs the rule through `assert_optimized_plan_equal`, and compares the
// optimized plan against an expected textual rendering.
#[cfg(test)]
mod tests {
use std::fmt::Formatter;
use std::sync::Arc;
use crate::optimize_projections::OptimizeProjections;
use crate::test::{
assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name,
};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference};
use datafusion_expr::{
binary_expr, build_join_schema, col, count, lit,
logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when,
BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
UserDefinedLogicalNodeCore,
};
/// Runs `OptimizeProjections` on `plan` via `assert_optimized_plan_eq` and
/// checks the outcome against `expected`.
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected)
}
/// Single-input user-defined node for the extension tests. It keeps the schema
/// it was constructed with and, for any requested output columns, reports
/// needing exactly those same column indices from its child.
#[derive(Debug, Hash, PartialEq, Eq)]
struct NoOpUserDefined {
exprs: Vec<Expr>,
schema: DFSchemaRef,
input: Arc<LogicalPlan>,
}
impl NoOpUserDefined {
/// Creates a node without expressions over `input`, exposing `schema`.
fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
Self {
exprs: vec![],
schema,
input,
}
}
/// Builder-style setter replacing the node's expressions.
fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
self.exprs = exprs;
self
}
}
impl UserDefinedLogicalNodeCore for NoOpUserDefined {
fn name(&self) -> &str {
"NoOpUserDefined"
}
fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}
fn schema(&self) -> &DFSchemaRef {
&self.schema
}
fn expressions(&self) -> Vec<Expr> {
self.exprs.clone()
}
fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "NoOpUserDefined")
}
// Rebuilds the node from new expressions and the first new input; the schema
// is carried over unchanged.
fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
Self {
exprs: exprs.to_vec(),
input: Arc::new(inputs[0].clone()),
schema: self.schema.clone(),
}
}
// The only child must provide exactly the requested output columns.
fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
Some(vec![output_columns.to_vec()])
}
}
/// Two-input user-defined node for the extension tests. Its schema is the
/// inner-join schema of its children, built with `build_join_schema`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct UserDefinedCrossJoin {
exprs: Vec<Expr>,
schema: DFSchemaRef,
left_child: Arc<LogicalPlan>,
right_child: Arc<LogicalPlan>,
}
impl UserDefinedCrossJoin {
/// Creates the node over the two children, deriving the output schema from
/// their schemas. Panics if the join schema cannot be built.
fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
let left_schema = left_child.schema();
let right_schema = right_child.schema();
let schema = Arc::new(
build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
);
Self {
exprs: vec![],
schema,
left_child,
right_child,
}
}
}
impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
fn name(&self) -> &str {
"UserDefinedCrossJoin"
}
fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.left_child, &self.right_child]
}
fn schema(&self) -> &DFSchemaRef {
&self.schema
}
fn expressions(&self) -> Vec<Expr> {
self.exprs.clone()
}
fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "UserDefinedCrossJoin")
}
// Rebuilds the node from exactly two new inputs; the schema is carried over.
fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
assert_eq!(inputs.len(), 2);
Self {
exprs: exprs.to_vec(),
left_child: Arc::new(inputs[0].clone()),
right_child: Arc::new(inputs[1].clone()),
schema: self.schema.clone(),
}
}
// Output columns below the left child's field count belong to the left child;
// the remaining ones belong to the right child, re-based to its own schema.
fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
let left_child_len = self.left_child.schema().fields().len();
let mut left_reqs = vec![];
let mut right_reqs = vec![];
for &out_idx in output_columns {
if out_idx < left_child_len {
left_reqs.push(out_idx);
} else {
right_reqs.push(out_idx - left_child_len)
}
}
Some(vec![left_reqs, right_reqs])
}
}
/// Two stacked projections (`a`, then `1 + a`): the expected plan has a single
/// projection over a scan of column `a` only.
#[test]
fn merge_two_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
.build()?;
let expected = "Projection: Int32(1) + test.a\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// Three stacked projections: the expected plan again has a single projection
/// over a scan of column `a` only.
#[test]
fn merge_three_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.project(vec![col("a")])?
.project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
.build()?;
let expected = "Projection: Int32(1) + test.a\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// A projection of `a` followed by `a AS alias`: the expected plan keeps only
/// the aliasing projection.
#[test]
fn merge_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.project(vec![col("a").alias("alias")])?
.build()?;
let expected = "Projection: test.a AS alias\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// The lower projection aliases `a` twice (`alias1`, `alias2`) and the upper
/// one re-aliases it: only the outermost alias appears in the expected plan.
#[test]
fn merge_nested_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("alias1").alias("alias2")])?
.project(vec![col("alias2").alias("alias")])?
.build()?;
let expected = "Projection: test.a AS alias\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// Two nested `COUNT(1)` aggregates without group-by: the expected plan has an
/// empty projection between the aggregates and a scan with an empty
/// projection.
#[test]
fn test_nested_count() -> Result<()> {
let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
let groups: Vec<Expr> = vec![];
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.aggregate(groups.clone(), vec![count(lit(1))])
.unwrap()
.aggregate(groups, vec![count(lit(1))])
.unwrap()
.build()
.unwrap();
let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
\n Projection: \
\n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
\n TableScan: ?table? projection=[]";
assert_optimized_plan_equal(&plan, expected)
}
/// Accessing field `x` of struct column `s`: the expected scan reads only `s`.
#[test]
fn test_struct_field_push_down() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new_struct(
"s",
vec![
Field::new("x", DataType::Int64, false),
Field::new("y", DataType::Int64, false),
],
false,
),
]));
let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("s").field("x")])?
.build()?;
let expected = "Projection: (?table?.s)[x]\
\n TableScan: ?table? projection=[s]";
assert_optimized_plan_equal(&plan, expected)
}
/// Negation of `a`: the expected scan reads only `a`.
#[test]
fn test_neg_push_down() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![-col("a")])?
.build()?;
let expected = "Projection: (- test.a)\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS NULL`: the expected scan reads only `a`.
#[test]
fn test_is_null() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_null()])?
.build()?;
let expected = "Projection: test.a IS NULL\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS NOT NULL`: the expected scan reads only `a`.
#[test]
fn test_is_not_null() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_not_null()])?
.build()?;
let expected = "Projection: test.a IS NOT NULL\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS TRUE`: the expected scan reads only `a`.
#[test]
fn test_is_true() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_true()])?
.build()?;
let expected = "Projection: test.a IS TRUE\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS NOT TRUE`: the expected scan reads only `a`.
#[test]
fn test_is_not_true() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_not_true()])?
.build()?;
let expected = "Projection: test.a IS NOT TRUE\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS FALSE`: the expected scan reads only `a`.
#[test]
fn test_is_false() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_false()])?
.build()?;
let expected = "Projection: test.a IS FALSE\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS NOT FALSE`: the expected scan reads only `a`.
#[test]
fn test_is_not_false() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_not_false()])?
.build()?;
let expected = "Projection: test.a IS NOT FALSE\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS UNKNOWN`: the expected scan reads only `a`.
#[test]
fn test_is_unknown() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_unknown()])?
.build()?;
let expected = "Projection: test.a IS UNKNOWN\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a IS NOT UNKNOWN`: the expected scan reads only `a`.
#[test]
fn test_is_not_unknown() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").is_not_unknown()])?
.build()?;
let expected = "Projection: test.a IS NOT UNKNOWN\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `NOT a`: the expected scan reads only `a`.
#[test]
fn test_not() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![not(col("a"))])?
.build()?;
let expected = "Projection: NOT test.a\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `TRY_CAST(a AS Float64)`: the expected scan reads only `a`.
#[test]
fn test_try_cast() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![try_cast(col("a"), DataType::Float64)])?
.build()?;
let expected = "Projection: TRY_CAST(test.a AS Float64)\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a SIMILAR TO '[0-9]'`: the expected scan reads only `a`.
#[test]
fn test_similar_to() -> Result<()> {
let table_scan = test_table_scan()?;
let expr = Box::new(col("a"));
let pattern = Box::new(lit("[0-9]"));
let similar_to_expr =
Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![similar_to_expr])?
.build()?;
let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// `a BETWEEN 1 AND 3`: the expected scan reads only `a`.
#[test]
fn test_between() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").between(lit(1), lit(3))])?
.build()?;
let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// The upper projection uses a `CASE` expression over the derived column `d`.
/// The expected plan keeps both projections (`CASE` is not among the
/// expression kinds `rewrite_expr` rewrites) while the scan still reads only
/// `a`.
#[test]
fn test_derived_column() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), lit(0).alias("d")])?
.project(vec![
col("a"),
when(col("a").eq(lit(1)), lit(10))
.otherwise(col("d"))?
.alias("d"),
])?
.build()?;
let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\
\n Projection: test.a, Int32(0) AS d\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// Extension node without expressions below a projection that uses only `a`:
/// the expected scan under the node reads only `a`.
#[test]
fn test_user_defined_logical_plan_node() -> Result<()> {
let table_scan = test_table_scan()?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoOpUserDefined::new(
table_scan.schema().clone(),
Arc::new(table_scan.clone()),
)),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("a"), lit(0).alias("d")])?
.build()?;
let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
/// Extension node whose own expression references `b`: the expected scan
/// reads `a` (needed by the projection) and `b` (needed by the node).
#[test]
fn test_user_defined_logical_plan_node2() -> Result<()> {
let table_scan = test_table_scan()?;
let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(
NoOpUserDefined::new(
table_scan.schema().clone(),
Arc::new(table_scan.clone()),
)
.with_exprs(exprs),
),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("a"), lit(0).alias("d")])?
.build()?;
let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a, b]";
assert_optimized_plan_equal(&plan, expected)
}
/// Extension node whose own expression is `b + c`: the expected scan reads
/// `a`, `b` and `c`.
#[test]
fn test_user_defined_logical_plan_node3() -> Result<()> {
let table_scan = test_table_scan()?;
let left_expr = Expr::Column(Column::from_qualified_name("b"));
let right_expr = Expr::Column(Column::from_qualified_name("c"));
let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
Operator::Plus,
Box::new(right_expr),
));
let exprs = vec![binary_expr];
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(
NoOpUserDefined::new(
table_scan.schema().clone(),
Arc::new(table_scan.clone()),
)
.with_exprs(exprs),
),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("a"), lit(0).alias("d")])?
.build()?;
let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a, b, c]";
assert_optimized_plan_equal(&plan, expected)
}
/// Two-input extension node: the projection uses `l.a`, `l.c` and `r.a`, so the
/// expected left scan reads `a, c` and the expected right scan reads `a`.
#[test]
fn test_user_defined_logical_plan_node4() -> Result<()> {
let left_table = test_table_scan_with_name("l")?;
let right_table = test_table_scan_with_name("r")?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(UserDefinedCrossJoin::new(
Arc::new(left_table.clone()),
Arc::new(right_table.clone()),
)),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
.build()?;
let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\
\n UserDefinedCrossJoin\
\n TableScan: l projection=[a, c]\
\n TableScan: r projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
}