use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, LazyLock};
use super::DdlStatement;
use super::dml::CopyTo;
use super::invariants::{
InvariantLevel, assert_always_invariants_at_current_node,
assert_executable_invariants,
};
use crate::builder::{unique_field_aliases, unnest_with_options};
use crate::expr::{
Alias, Placeholder, Sort as SortExpr, WindowFunction, WindowFunctionParams,
intersect_metadata_for_union,
};
use crate::expr_rewriter::{
NamePreserver, create_col_from_scalar_expr, normalize_cols, normalize_sorts,
};
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::{DmlStatement, Statement};
use crate::utils::{
enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction,
};
use crate::{
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable,
LogicalPlanBuilder, Operator, Prepare, TableProviderFilterPushDown, TableSource,
WindowFunctionDefinition, build_join_schema, expr_vec_fmt, requalify_sides_if_needed,
};
use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
use datafusion_common::cse::{NormalizeEq, Normalizeable};
use datafusion_common::format::ExplainFormat;
use datafusion_common::metadata::check_metadata_with_storage_equal;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
use datafusion_common::{
Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency,
FunctionalDependence, FunctionalDependencies, NullEquality, ParamValues, Result,
ScalarValue, Spans, TableReference, UnnestOptions, aggregate_functional_dependencies,
assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
};
use indexmap::IndexSet;
use crate::display::PgJsonVisitor;
pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
pub use datafusion_common::{JoinConstraint, JoinType};
/// A node in a tree of relational operators that together describe a query.
///
/// Each variant wraps the parameter struct for one operator; child plans are
/// reachable via [`LogicalPlan::inputs`] and the output schema via
/// [`LogicalPlan::schema`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum LogicalPlan {
/// Evaluates a list of expressions over its input ([`Projection`]).
Projection(Projection),
/// Filters rows by a boolean predicate ([`Filter`]).
Filter(Filter),
/// Evaluates window functions ([`Window`]).
Window(Window),
/// Grouped aggregation ([`Aggregate`]).
Aggregate(Aggregate),
/// Sorts its input, optionally with a row limit ([`Sort`]).
Sort(Sort),
/// Joins two inputs ([`Join`]).
Join(Join),
/// Repartitions its input ([`Repartition`]).
Repartition(Repartition),
/// Concatenates multiple inputs with a shared schema ([`Union`]).
Union(Union),
/// Leaf node: scans a table provider ([`TableScan`]).
TableScan(TableScan),
/// Leaf node: relation with no scanned data ([`EmptyRelation`]).
EmptyRelation(EmptyRelation),
/// A scalar/correlated subquery expression's plan ([`Subquery`]).
Subquery(Subquery),
/// Aliases a subtree under a new relation name ([`SubqueryAlias`]).
SubqueryAlias(SubqueryAlias),
/// Skip/fetch row limits ([`Limit`]).
Limit(Limit),
/// SQL statements such as PREPARE / EXECUTE ([`Statement`]).
Statement(Statement),
/// Leaf node: literal rows ([`Values`]).
Values(Values),
/// EXPLAIN wrapper ([`Explain`]).
Explain(Explain),
/// ANALYZE wrapper ([`Analyze`]).
Analyze(Analyze),
/// User-defined extension node ([`Extension`]).
Extension(Extension),
/// DISTINCT / DISTINCT ON ([`Distinct`]).
Distinct(Distinct),
/// Data-modification statement ([`DmlStatement`]).
Dml(DmlStatement),
/// Data-definition statement ([`DdlStatement`]).
Ddl(DdlStatement),
/// COPY TO ([`CopyTo`]).
Copy(CopyTo),
/// DESCRIBE table ([`DescribeTable`]).
DescribeTable(DescribeTable),
/// Unnests list/struct columns ([`Unnest`]).
Unnest(Unnest),
/// Recursive CTE ([`RecursiveQuery`]).
RecursiveQuery(RecursiveQuery),
}
impl Default for LogicalPlan {
    /// The default plan is an empty relation that produces no rows.
    fn default() -> Self {
        let schema = Arc::new(DFSchema::empty());
        LogicalPlan::EmptyRelation(EmptyRelation {
            schema,
            produce_one_row: false,
        })
    }
}
// Treat a `LogicalPlan` as a single-element container so the generic
// tree-node apply/map machinery can visit or transform it uniformly.
impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan {
fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
&'a self,
mut f: F,
) -> Result<TreeNodeRecursion> {
// The "container" holds exactly one element: the plan itself.
f(self)
}
fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
mut f: F,
) -> Result<Transformed<Self>> {
f(self)
}
}
impl LogicalPlan {
/// Returns the output schema of this plan node.
///
/// Nodes that do not change their input's schema (e.g. `Filter`, `Sort`,
/// `Repartition`, `Limit`) delegate to their input; schema-producing nodes
/// return the schema stored in their parameter struct.
pub fn schema(&self) -> &DFSchemaRef {
match self {
LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }) => schema,
LogicalPlan::Values(Values { schema, .. }) => schema,
LogicalPlan::TableScan(TableScan {
projected_schema, ..
}) => projected_schema,
LogicalPlan::Projection(Projection { schema, .. }) => schema,
LogicalPlan::Filter(Filter { input, .. }) => input.schema(),
LogicalPlan::Distinct(Distinct::All(input)) => input.schema(),
LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema,
LogicalPlan::Window(Window { schema, .. }) => schema,
LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema,
LogicalPlan::Sort(Sort { input, .. }) => input.schema(),
LogicalPlan::Join(Join { schema, .. }) => schema,
LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(),
LogicalPlan::Limit(Limit { input, .. }) => input.schema(),
LogicalPlan::Statement(statement) => statement.schema(),
LogicalPlan::Subquery(Subquery { subquery, .. }) => subquery.schema(),
LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }) => schema,
LogicalPlan::Explain(explain) => &explain.schema,
LogicalPlan::Analyze(analyze) => &analyze.schema,
LogicalPlan::Extension(extension) => extension.node.schema(),
LogicalPlan::Union(Union { schema, .. }) => schema,
LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => {
output_schema
}
LogicalPlan::Dml(DmlStatement { output_schema, .. }) => output_schema,
LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema,
LogicalPlan::Ddl(ddl) => ddl.schema(),
LogicalPlan::Unnest(Unnest { schema, .. }) => schema,
// A recursive query's schema is that of its non-recursive (static) term.
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
static_term.schema()
}
}
}
/// For node kinds whose output schema differs from their inputs' schemas
/// (projection-like nodes), returns the input schemas to fall back on when
/// normalizing column references; otherwise returns an empty list.
pub fn fallback_normalize_schemas(&self) -> Vec<&DFSchema> {
    let rewrites_schema = matches!(
        self,
        LogicalPlan::Window(_)
            | LogicalPlan::Projection(_)
            | LogicalPlan::Aggregate(_)
            | LogicalPlan::Unnest(_)
            | LogicalPlan::Join(_)
    );
    if rewrites_schema {
        self.inputs()
            .into_iter()
            .map(|input| input.schema().as_ref())
            .collect()
    } else {
        Vec::new()
    }
}
pub fn explain_schema() -> SchemaRef {
SchemaRef::new(Schema::new(vec![
Field::new("plan_type", DataType::Utf8, false),
Field::new("plan", DataType::Utf8, false),
]))
}
/// Returns the three-column Utf8 schema (`column_name`, `data_type`,
/// `is_nullable`) used by `DESCRIBE` output.
pub fn describe_schema() -> Schema {
    let columns = ["column_name", "data_type", "is_nullable"]
        .into_iter()
        .map(|name| Field::new(name, DataType::Utf8, false))
        .collect::<Vec<_>>();
    Schema::new(columns)
}
/// Returns clones of all expressions attached directly to this node
/// (it does NOT recurse into child plans).
pub fn expressions(self: &LogicalPlan) -> Vec<Expr> {
    let mut collected = Vec::new();
    // `apply_expressions` is infallible here because the closure never errors.
    self.apply_expressions(|expr| {
        collected.push(expr.clone());
        Ok(TreeNodeRecursion::Continue)
    })
    .unwrap();
    collected
}
/// Collects all outer-reference expressions from this node and,
/// recursively, from all of its children, deduplicated while preserving
/// first-seen order (hence the linear `contains` check instead of a set).
pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec<Expr> {
let mut exprs = vec![];
self.apply_expressions(|e| {
find_out_reference_exprs(e).into_iter().for_each(|e| {
if !exprs.contains(&e) {
exprs.push(e)
}
});
Ok(TreeNodeRecursion::Continue)
})
// The closure never returns an error, so this cannot fail.
.unwrap();
self.inputs()
.into_iter()
.flat_map(|child| child.all_out_ref_exprs())
.for_each(|e| {
if !exprs.contains(&e) {
exprs.push(e)
}
});
exprs
}
/// Returns the direct child plans of this node; leaf nodes
/// (`TableScan`, `EmptyRelation`, `Values`, `DescribeTable`) return an
/// empty vector.
pub fn inputs(&self) -> Vec<&LogicalPlan> {
    match self {
        LogicalPlan::Projection(Projection { input, .. }) => vec![input],
        LogicalPlan::Filter(Filter { input, .. }) => vec![input],
        LogicalPlan::Repartition(Repartition { input, .. }) => vec![input],
        LogicalPlan::Window(Window { input, .. }) => vec![input],
        LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input],
        LogicalPlan::Sort(Sort { input, .. }) => vec![input],
        LogicalPlan::Join(Join { left, right, .. }) => vec![left, right],
        LogicalPlan::Limit(Limit { input, .. }) => vec![input],
        LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery],
        LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input],
        LogicalPlan::Extension(extension) => extension.node.inputs(),
        LogicalPlan::Union(Union { inputs, .. }) => {
            inputs.iter().map(|arc| arc.as_ref()).collect()
        }
        LogicalPlan::Distinct(
            Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
        ) => vec![input],
        LogicalPlan::Explain(explain) => vec![&explain.plan],
        LogicalPlan::Analyze(analyze) => vec![&analyze.input],
        LogicalPlan::Dml(write) => vec![&write.input],
        // Fixed: this previously read `vec![©.input]` — a mangled HTML
        // entity for `&copy` — which does not compile.
        LogicalPlan::Copy(copy) => vec![&copy.input],
        LogicalPlan::Ddl(ddl) => ddl.inputs(),
        LogicalPlan::Unnest(Unnest { input, .. }) => vec![input],
        LogicalPlan::RecursiveQuery(RecursiveQuery {
            static_term,
            recursive_term,
            ..
        }) => vec![static_term, recursive_term],
        LogicalPlan::Statement(stmt) => stmt.inputs(),
        LogicalPlan::TableScan { .. }
        | LogicalPlan::EmptyRelation { .. }
        | LogicalPlan::Values { .. }
        | LogicalPlan::DescribeTable(_) => vec![],
    }
}
/// Walks the plan (including subqueries) and, for every `USING` join,
/// collects the set of join-key columns from both sides.
///
/// Returns one `HashSet<Column>` per `USING` join encountered. Errors if a
/// `USING` join key is not a plain column reference.
pub fn using_columns(&self) -> Result<Vec<HashSet<Column>>, DataFusionError> {
let mut using_columns: Vec<HashSet<Column>> = vec![];
self.apply_with_subqueries(|plan| {
if let LogicalPlan::Join(Join {
join_constraint: JoinConstraint::Using,
on,
..
}) = plan
{
// Both sides of every equi-pair must be column references.
let columns =
on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| {
let Some(l) = l.get_as_join_column() else {
return internal_err!(
"Invalid join key. Expected column, found {l:?}"
);
};
let Some(r) = r.get_as_join_column() else {
return internal_err!(
"Invalid join key. Expected column, found {r:?}"
);
};
accumu.insert(l.to_owned());
accumu.insert(r.to_owned());
Result::<_, DataFusionError>::Ok(accumu)
})?;
using_columns.push(columns);
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(using_columns)
}
/// Returns the expression producing the first output column of this plan,
/// or `None` when the node has no meaningful "first column" (e.g. DDL,
/// DML, subquery placeholders).
///
/// NOTE(review): the `[0]` indexing assumes projections/aggregates always
/// carry at least one expression — panics otherwise; presumably guaranteed
/// by plan construction upstream.
pub fn head_output_expr(&self) -> Result<Option<Expr>> {
match self {
LogicalPlan::Projection(projection) => {
Ok(Some(projection.expr.as_slice()[0].clone()))
}
LogicalPlan::Aggregate(agg) => {
// Group expressions come first in aggregate output; fall back to
// the first aggregate expression when there is no grouping.
if agg.group_expr.is_empty() {
Ok(Some(agg.aggr_expr.as_slice()[0].clone()))
} else {
Ok(Some(agg.group_expr.as_slice()[0].clone()))
}
}
LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => {
Ok(Some(select_expr[0].clone()))
}
// Pass-through nodes: delegate to the input.
LogicalPlan::Filter(Filter { input, .. })
| LogicalPlan::Distinct(Distinct::All(input))
| LogicalPlan::Sort(Sort { input, .. })
| LogicalPlan::Limit(Limit { input, .. })
| LogicalPlan::Repartition(Repartition { input, .. })
| LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(),
LogicalPlan::Join(Join {
left,
right,
join_type,
..
}) => match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
// A zero-column left side contributes nothing; use the right.
if left.schema().fields().is_empty() {
right.head_output_expr()
} else {
left.head_output_expr()
}
}
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
left.head_output_expr()
}
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
right.head_output_expr()
}
},
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
static_term.head_output_expr()
}
LogicalPlan::Union(union) => Ok(Some(Expr::Column(Column::from(
union.schema.qualified_field(0),
)))),
LogicalPlan::TableScan(table) => Ok(Some(Expr::Column(Column::from(
table.projected_schema.qualified_field(0),
)))),
LogicalPlan::SubqueryAlias(subquery_alias) => {
// Re-qualify the inner head column with the alias name.
let expr_opt = subquery_alias.input.head_output_expr()?;
expr_opt
.map(|expr| {
Ok(Expr::Column(create_col_from_scalar_expr(
&expr,
subquery_alias.alias.to_string(),
)?))
})
.map_or(Ok(None), |v| v.map(Some))
}
LogicalPlan::Subquery(_) => Ok(None),
LogicalPlan::EmptyRelation(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Unnest(_) => Ok(None),
}
}
/// Rebuilds this node's cached schema from its (possibly modified) inputs
/// and expressions, returning an equivalent plan with an up-to-date schema.
///
/// Nodes whose schema is independent of their children (DML, DDL, limits,
/// etc.) are returned unchanged.
pub fn recompute_schema(self) -> Result<Self> {
match self {
LogicalPlan::Projection(Projection {
expr,
input,
schema: _,
}) => Projection::try_new(expr, input).map(LogicalPlan::Projection),
LogicalPlan::Dml(_) => Ok(self),
LogicalPlan::Copy(_) => Ok(self),
LogicalPlan::Values(Values { schema, values }) => {
Ok(LogicalPlan::Values(Values { schema, values }))
}
LogicalPlan::Filter(Filter { predicate, input }) => {
Filter::try_new(predicate, input).map(LogicalPlan::Filter)
}
LogicalPlan::Repartition(_) => Ok(self),
LogicalPlan::Window(Window {
input,
window_expr,
schema: _,
}) => Window::try_new(window_expr, input).map(LogicalPlan::Window),
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema: _,
}) => Aggregate::try_new(input, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate),
LogicalPlan::Sort(_) => Ok(self),
LogicalPlan::Join(Join {
left,
right,
filter,
join_type,
join_constraint,
on,
schema: _,
null_equality,
null_aware,
}) => {
// Rebuild the join schema from both sides, and strip any aliases
// from the equi-join keys.
let schema =
build_join_schema(left.schema(), right.schema(), &join_type)?;
let new_on: Vec<_> = on
.into_iter()
.map(|equi_expr| {
(equi_expr.0.unalias(), equi_expr.1.unalias())
})
.collect();
Ok(LogicalPlan::Join(Join {
left,
right,
join_type,
join_constraint,
on: new_on,
filter,
schema: DFSchemaRef::new(schema),
null_equality,
null_aware,
}))
}
LogicalPlan::Subquery(_) => Ok(self),
LogicalPlan::SubqueryAlias(SubqueryAlias {
input,
alias,
schema: _,
}) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias),
LogicalPlan::Limit(_) => Ok(self),
LogicalPlan::Ddl(_) => Ok(self),
LogicalPlan::Extension(Extension { node }) => {
// Extension nodes rebuild themselves from their own exprs/inputs.
let expr = node.expressions();
let inputs: Vec<_> = node.inputs().into_iter().cloned().collect();
Ok(LogicalPlan::Extension(Extension {
node: node.with_exprs_and_inputs(expr, inputs)?,
}))
}
LogicalPlan::Union(Union { inputs, schema }) => {
// Keep the stored schema while the arity still matches; otherwise
// derive a fresh schema from the inputs.
let first_input_schema = inputs[0].schema();
if schema.fields().len() == first_input_schema.fields().len() {
Ok(LogicalPlan::Union(Union { inputs, schema }))
} else {
Ok(LogicalPlan::Union(Union::try_new(inputs)?))
}
}
LogicalPlan::Distinct(distinct) => {
let distinct = match distinct {
Distinct::All(input) => Distinct::All(input),
Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema: _,
}) => Distinct::On(DistinctOn::try_new(
on_expr,
select_expr,
sort_expr,
input,
)?),
};
Ok(LogicalPlan::Distinct(distinct))
}
LogicalPlan::RecursiveQuery(_) => Ok(self),
LogicalPlan::Analyze(_) => Ok(self),
LogicalPlan::Explain(_) => Ok(self),
LogicalPlan::TableScan(_) => Ok(self),
LogicalPlan::EmptyRelation(_) => Ok(self),
LogicalPlan::Statement(_) => Ok(self),
LogicalPlan::DescribeTable(_) => Ok(self),
LogicalPlan::Unnest(Unnest {
input,
exec_columns,
options,
..
}) => {
unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options)
}
}
}
/// Rebuilds this node with new expressions and new inputs.
///
/// `expr` must contain exactly the expressions this node kind expects, in
/// the same order as returned by [`LogicalPlan::expressions`]; `inputs`
/// must contain exactly the expected number of child plans. Each match arm
/// validates the counts via the `assert_no_*`/`only_*` helpers.
pub fn with_new_exprs(
&self,
mut expr: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<LogicalPlan> {
match self {
LogicalPlan::Projection(Projection { .. }) => {
let input = self.only_input(inputs)?;
Projection::try_new(expr, Arc::new(input)).map(LogicalPlan::Projection)
}
LogicalPlan::Dml(DmlStatement {
table_name,
target,
op,
..
}) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Dml(DmlStatement::new(
table_name.clone(),
Arc::clone(target),
op.clone(),
Arc::new(input),
)))
}
LogicalPlan::Copy(CopyTo {
input: _,
output_url,
file_type,
options,
partition_by,
output_schema: _,
}) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Copy(CopyTo::new(
Arc::new(input),
output_url.clone(),
partition_by.clone(),
Arc::clone(file_type),
options.clone(),
)))
}
LogicalPlan::Values(Values { schema, .. }) => {
self.assert_no_inputs(inputs)?;
Ok(LogicalPlan::Values(Values {
schema: Arc::clone(schema),
// Re-chunk the flat expression list back into rows of
// `schema.fields().len()` columns each.
values: expr
.chunks_exact(schema.fields().len())
.map(|s| s.to_vec())
.collect(),
}))
}
LogicalPlan::Filter { .. } => {
let predicate = self.only_expr(expr)?;
let input = self.only_input(inputs)?;
Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::Filter)
}
LogicalPlan::Repartition(Repartition {
partitioning_scheme,
..
}) => match partitioning_scheme {
Partitioning::RoundRobinBatch(n) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Repartition(Repartition {
partitioning_scheme: Partitioning::RoundRobinBatch(*n),
input: Arc::new(input),
}))
}
Partitioning::Hash(_, n) => {
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Repartition(Repartition {
partitioning_scheme: Partitioning::Hash(expr, *n),
input: Arc::new(input),
}))
}
Partitioning::DistributeBy(_) => {
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Repartition(Repartition {
partitioning_scheme: Partitioning::DistributeBy(expr),
input: Arc::new(input),
}))
}
},
LogicalPlan::Window(Window { window_expr, .. }) => {
assert_eq!(window_expr.len(), expr.len());
let input = self.only_input(inputs)?;
Window::try_new(expr, Arc::new(input)).map(LogicalPlan::Window)
}
LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => {
let input = self.only_input(inputs)?;
// `expr` is [group exprs..., aggr exprs...]; split at the group
// boundary, leaving the group exprs in `expr`.
let agg_expr = expr.split_off(group_expr.len());
Aggregate::try_new(Arc::new(input), expr, agg_expr)
.map(LogicalPlan::Aggregate)
}
LogicalPlan::Sort(Sort {
expr: sort_expr,
fetch,
..
}) => {
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Sort(Sort {
// Keep each sort's direction/null ordering, swap in new exprs.
expr: expr
.into_iter()
.zip(sort_expr.iter())
.map(|(expr, sort)| sort.with_expr(expr))
.collect(),
input: Arc::new(input),
fetch: *fetch,
}))
}
LogicalPlan::Join(Join {
join_type,
join_constraint,
on,
null_equality,
null_aware,
..
}) => {
let (left, right) = self.only_two_inputs(inputs)?;
let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
// `expr` is [left key, right key, ..., optional filter]; the
// filter, if present, is the single trailing expression.
let equi_expr_count = on.len() * 2;
assert!(expr.len() >= equi_expr_count);
let filter_expr = if expr.len() > equi_expr_count {
expr.pop()
} else {
None
};
assert_eq!(expr.len(), equi_expr_count);
let mut new_on = Vec::with_capacity(on.len());
let mut iter = expr.into_iter();
while let Some(left) = iter.next() {
let Some(right) = iter.next() else {
internal_err!(
"Expected a pair of expressions to construct the join on expression"
)?
};
new_on.push((left.unalias(), right.unalias()));
}
Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
right: Arc::new(right),
join_type: *join_type,
join_constraint: *join_constraint,
on: new_on,
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equality: *null_equality,
null_aware: *null_aware,
}))
}
LogicalPlan::Subquery(Subquery {
outer_ref_columns,
spans,
..
}) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
let subquery = LogicalPlanBuilder::from(input).build()?;
Ok(LogicalPlan::Subquery(Subquery {
subquery: Arc::new(subquery),
outer_ref_columns: outer_ref_columns.clone(),
spans: spans.clone(),
}))
}
LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
SubqueryAlias::try_new(Arc::new(input), alias.clone())
.map(LogicalPlan::SubqueryAlias)
}
LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
// One expression per present skip/fetch; popped in reverse order
// (fetch was pushed last).
let old_expr_len = skip.iter().chain(fetch.iter()).count();
assert_eq_or_internal_err!(
old_expr_len,
expr.len(),
"Invalid number of new Limit expressions: expected {}, got {}",
old_expr_len,
expr.len()
);
let new_fetch = fetch.as_ref().and_then(|_| expr.pop());
let new_skip = skip.as_ref().and_then(|_| expr.pop());
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Limit(Limit {
skip: new_skip.map(Box::new),
fetch: new_fetch.map(Box::new),
input: Arc::new(input),
}))
}
LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable {
name,
if_not_exists,
or_replace,
column_defaults,
temporary,
..
})) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(
CreateMemoryTable {
input: Arc::new(input),
// NOTE(review): constraints are reset to default here rather
// than preserved — verify this is intentional.
constraints: Constraints::default(),
name: name.clone(),
if_not_exists: *if_not_exists,
or_replace: *or_replace,
column_defaults: column_defaults.clone(),
temporary: *temporary,
},
)))
}
LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
name,
or_replace,
definition,
temporary,
..
})) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
input: Arc::new(input),
name: name.clone(),
or_replace: *or_replace,
temporary: *temporary,
definition: definition.clone(),
})))
}
LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension {
node: e.node.with_exprs_and_inputs(expr, inputs)?,
})),
LogicalPlan::Union(Union { schema, .. }) => {
self.assert_no_expressions(expr)?;
// Reuse the stored schema only while the arity still matches.
let input_schema = inputs[0].schema();
let schema = if schema.fields().len() == input_schema.fields().len() {
Arc::clone(schema)
} else {
Arc::clone(input_schema)
};
Ok(LogicalPlan::Union(Union {
inputs: inputs.into_iter().map(Arc::new).collect(),
schema,
}))
}
LogicalPlan::Distinct(distinct) => {
let distinct = match distinct {
Distinct::All(_) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Distinct::All(Arc::new(input))
}
Distinct::On(DistinctOn {
on_expr,
select_expr,
..
}) => {
let input = self.only_input(inputs)?;
// `expr` is [on exprs..., select exprs..., sort exprs...].
let sort_expr = expr.split_off(on_expr.len() + select_expr.len());
let select_expr = expr.split_off(on_expr.len());
assert!(
sort_expr.is_empty(),
"with_new_exprs for Distinct does not support sort expressions"
);
Distinct::On(DistinctOn::try_new(
expr,
select_expr,
None, Arc::new(input),
)?)
}
};
Ok(LogicalPlan::Distinct(distinct))
}
LogicalPlan::RecursiveQuery(RecursiveQuery {
name, is_distinct, ..
}) => {
self.assert_no_expressions(expr)?;
let (static_term, recursive_term) = self.only_two_inputs(inputs)?;
Ok(LogicalPlan::RecursiveQuery(RecursiveQuery {
name: name.clone(),
static_term: Arc::new(static_term),
recursive_term: Arc::new(recursive_term),
is_distinct: *is_distinct,
}))
}
LogicalPlan::Analyze(a) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Analyze(Analyze {
verbose: a.verbose,
schema: Arc::clone(&a.schema),
input: Arc::new(input),
}))
}
LogicalPlan::Explain(e) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Explain(Explain {
verbose: e.verbose,
plan: Arc::new(input),
explain_format: e.explain_format.clone(),
stringified_plans: e.stringified_plans.clone(),
schema: Arc::clone(&e.schema),
logical_optimization_succeeded: e.logical_optimization_succeeded,
}))
}
LogicalPlan::Statement(Statement::Prepare(Prepare {
name, fields, ..
})) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Statement(Statement::Prepare(Prepare {
name: name.clone(),
fields: fields.clone(),
input: Arc::new(input),
})))
}
LogicalPlan::Statement(Statement::Execute(Execute { name, .. })) => {
self.assert_no_inputs(inputs)?;
Ok(LogicalPlan::Statement(Statement::Execute(Execute {
name: name.clone(),
parameters: expr,
})))
}
LogicalPlan::TableScan(ts) => {
self.assert_no_inputs(inputs)?;
Ok(LogicalPlan::TableScan(TableScan {
filters: expr,
..ts.clone()
}))
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Statement(_)
| LogicalPlan::DescribeTable(_) => {
self.assert_no_expressions(expr)?;
self.assert_no_inputs(inputs)?;
Ok(self.clone())
}
LogicalPlan::Unnest(Unnest {
exec_columns: columns,
options,
..
}) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
let new_plan =
unnest_with_options(input, columns.clone(), options.clone())?;
Ok(new_plan)
}
}
}
/// Runs the invariant checks for the requested level: `Always` checks hold
/// on every plan; `Executable` additionally checks the plan can be executed.
pub fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
match check {
InvariantLevel::Always => assert_always_invariants_at_current_node(self),
InvariantLevel::Executable => assert_executable_invariants(self),
}
}
/// Errors (internal error) unless `expr` is empty — used by
/// `with_new_exprs` arms for node kinds that carry no expressions.
#[inline]
#[expect(clippy::needless_pass_by_value)] fn assert_no_expressions(&self, expr: Vec<Expr>) -> Result<()> {
assert_or_internal_err!(
expr.is_empty(),
"{self:?} should have no exprs, got {:?}",
expr
);
Ok(())
}
/// Errors (internal error) unless `inputs` is empty — used by
/// `with_new_exprs` arms for leaf node kinds.
#[inline]
#[expect(clippy::needless_pass_by_value)] fn assert_no_inputs(&self, inputs: Vec<LogicalPlan>) -> Result<()> {
assert_or_internal_err!(
inputs.is_empty(),
"{self:?} should have no inputs, got: {:?}",
inputs
);
Ok(())
}
/// Extracts the single expression from `expr`, erroring (internal error)
/// if the vector does not contain exactly one element.
#[inline]
fn only_expr(&self, mut expr: Vec<Expr>) -> Result<Expr> {
assert_eq_or_internal_err!(
expr.len(),
1,
"{self:?} should have exactly one expr, got {:?}",
&expr
);
Ok(expr.remove(0))
}
/// Extracts the single input plan from `inputs`, erroring (internal error)
/// if the vector does not contain exactly one element.
#[inline]
fn only_input(&self, mut inputs: Vec<LogicalPlan>) -> Result<LogicalPlan> {
assert_eq_or_internal_err!(
inputs.len(),
1,
"{self:?} should have exactly one input, got {:?}",
&inputs
);
Ok(inputs.remove(0))
}
/// Extracts exactly two input plans as `(left, right)`, erroring (internal
/// error) on any other count.
#[inline]
fn only_two_inputs(
    &self,
    mut inputs: Vec<LogicalPlan>,
) -> Result<(LogicalPlan, LogicalPlan)> {
    assert_eq_or_internal_err!(
        inputs.len(),
        2,
        "{self:?} should have exactly two inputs, got {:?}",
        &inputs
    );
    // Pop from the back: the last element is the right input.
    let right = inputs.pop().expect("length checked above");
    let left = inputs.pop().expect("length checked above");
    Ok((left, right))
}
/// Substitutes the given parameter values for placeholders throughout the
/// plan. If the plan is a `PREPARE` statement, verifies the values against
/// the declared parameter fields and returns the prepared inner plan.
pub fn with_param_values(
    self,
    param_values: impl Into<ParamValues>,
) -> Result<LogicalPlan> {
    let param_values = param_values.into();
    // Fixed: this previously read `¶m_values` — a mangled HTML entity
    // for `&para` — instead of `&param_values`, which does not compile.
    let plan_with_values = self.replace_params_with_values(&param_values)?;
    Ok(
        if let LogicalPlan::Statement(Statement::Prepare(prepare_lp)) =
            plan_with_values
        {
            param_values.verify_fields(&prepare_lp.fields)?;
            Arc::unwrap_or_clone(prepare_lp.input)
        } else {
            plan_with_values
        },
    )
}
/// Returns an upper bound on the number of rows this plan can produce, or
/// `None` when no bound can be determined.
///
/// Uses checked arithmetic when combining child bounds: the previous code
/// used plain `*` and `+=`, which panic in debug builds (and wrap in
/// release) when two large bounds overflow `usize`; overflow now yields
/// `None` ("unknown"), which is the correct answer for an upper bound.
pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
    match self {
        LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
        LogicalPlan::Filter(filter) => {
            // A provably-scalar predicate limits output to a single row.
            if filter.is_scalar() {
                Some(1)
            } else {
                filter.input.max_rows()
            }
        }
        LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
        LogicalPlan::Aggregate(Aggregate {
            input, group_expr, ..
        }) => {
            // Grouping only by literals produces at most one group.
            if group_expr
                .iter()
                .all(|expr| matches!(expr, Expr::Literal(_, _)))
            {
                Some(1)
            } else {
                input.max_rows()
            }
        }
        LogicalPlan::Sort(Sort { input, fetch, .. }) => {
            match (fetch, input.max_rows()) {
                (Some(fetch_limit), Some(input_max)) => {
                    Some(input_max.min(*fetch_limit))
                }
                (Some(fetch_limit), None) => Some(*fetch_limit),
                (None, Some(input_max)) => Some(input_max),
                (None, None) => None,
            }
        }
        LogicalPlan::Join(Join {
            left,
            right,
            join_type,
            ..
        }) => match join_type {
            JoinType::Inner => left.max_rows()?.checked_mul(right.max_rows()?),
            JoinType::Left | JoinType::Right | JoinType::Full => {
                match (left.max_rows()?, right.max_rows()?, join_type) {
                    (0, 0, _) => Some(0),
                    (max_rows, 0, JoinType::Left | JoinType::Full) => Some(max_rows),
                    (0, max_rows, JoinType::Right | JoinType::Full) => Some(max_rows),
                    (left_max, right_max, _) => left_max.checked_mul(right_max),
                }
            }
            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
                left.max_rows()
            }
            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
                right.max_rows()
            }
        },
        LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(),
        LogicalPlan::Union(Union { inputs, .. }) => {
            // Sum of per-input bounds; unknown (or overflowing) => None.
            inputs
                .iter()
                .try_fold(0usize, |acc, plan| acc.checked_add(plan.max_rows()?))
        }
        LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch,
        LogicalPlan::EmptyRelation(_) => Some(0),
        LogicalPlan::RecursiveQuery(_) => None,
        LogicalPlan::Subquery(_) => None,
        LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(),
        LogicalPlan::Limit(limit) => match limit.get_fetch_type() {
            Ok(FetchType::Literal(s)) => s,
            _ => None,
        },
        LogicalPlan::Distinct(
            Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
        ) => input.max_rows(),
        LogicalPlan::Values(v) => Some(v.values.len()),
        LogicalPlan::Unnest(_) => None,
        LogicalPlan::Ddl(_)
        | LogicalPlan::Explain(_)
        | LogicalPlan::Analyze(_)
        | LogicalPlan::Dml(_)
        | LogicalPlan::Copy(_)
        | LogicalPlan::DescribeTable(_)
        | LogicalPlan::Statement(_)
        | LogicalPlan::Extension(_) => None,
    }
}
/// Returns the literal number of rows this node skips, if it is a `Limit`
/// with a known non-zero skip; `None` otherwise.
///
/// The variants are enumerated explicitly (no `_` catch-all) so adding a
/// new plan kind forces this method to be revisited.
pub fn skip(&self) -> Result<Option<usize>> {
match self {
LogicalPlan::Limit(limit) => match limit.get_skip_type()? {
// A skip of zero is the same as no skip.
SkipType::Literal(0) => Ok(None),
SkipType::Literal(n) => Ok(Some(n)),
SkipType::UnsupportedExpr => Ok(None),
},
LogicalPlan::Sort(_) => Ok(None),
LogicalPlan::TableScan(_) => Ok(None),
LogicalPlan::Projection(_) => Ok(None),
LogicalPlan::Filter(_) => Ok(None),
LogicalPlan::Window(_) => Ok(None),
LogicalPlan::Aggregate(_) => Ok(None),
LogicalPlan::Join(_) => Ok(None),
LogicalPlan::Repartition(_) => Ok(None),
LogicalPlan::Union(_) => Ok(None),
LogicalPlan::EmptyRelation(_) => Ok(None),
LogicalPlan::Subquery(_) => Ok(None),
LogicalPlan::SubqueryAlias(_) => Ok(None),
LogicalPlan::Statement(_) => Ok(None),
LogicalPlan::Values(_) => Ok(None),
LogicalPlan::Explain(_) => Ok(None),
LogicalPlan::Analyze(_) => Ok(None),
LogicalPlan::Extension(_) => Ok(None),
LogicalPlan::Distinct(_) => Ok(None),
LogicalPlan::Dml(_) => Ok(None),
LogicalPlan::Ddl(_) => Ok(None),
LogicalPlan::Copy(_) => Ok(None),
LogicalPlan::DescribeTable(_) => Ok(None),
LogicalPlan::Unnest(_) => Ok(None),
LogicalPlan::RecursiveQuery(_) => Ok(None),
}
}
/// Returns the literal fetch (row-count) limit of this node, if any:
/// `Sort` and `TableScan` carry an optional fetch directly; `Limit` may
/// carry a literal fetch expression. All other nodes have none.
///
/// The variants are enumerated explicitly (no `_` catch-all) so adding a
/// new plan kind forces this method to be revisited.
pub fn fetch(&self) -> Result<Option<usize>> {
match self {
LogicalPlan::Sort(Sort { fetch, .. }) => Ok(*fetch),
LogicalPlan::TableScan(TableScan { fetch, .. }) => Ok(*fetch),
LogicalPlan::Limit(limit) => match limit.get_fetch_type()? {
FetchType::Literal(s) => Ok(s),
FetchType::UnsupportedExpr => Ok(None),
},
LogicalPlan::Projection(_) => Ok(None),
LogicalPlan::Filter(_) => Ok(None),
LogicalPlan::Window(_) => Ok(None),
LogicalPlan::Aggregate(_) => Ok(None),
LogicalPlan::Join(_) => Ok(None),
LogicalPlan::Repartition(_) => Ok(None),
LogicalPlan::Union(_) => Ok(None),
LogicalPlan::EmptyRelation(_) => Ok(None),
LogicalPlan::Subquery(_) => Ok(None),
LogicalPlan::SubqueryAlias(_) => Ok(None),
LogicalPlan::Statement(_) => Ok(None),
LogicalPlan::Values(_) => Ok(None),
LogicalPlan::Explain(_) => Ok(None),
LogicalPlan::Analyze(_) => Ok(None),
LogicalPlan::Extension(_) => Ok(None),
LogicalPlan::Distinct(_) => Ok(None),
LogicalPlan::Dml(_) => Ok(None),
LogicalPlan::Ddl(_) => Ok(None),
LogicalPlan::Copy(_) => Ok(None),
LogicalPlan::DescribeTable(_) => Ok(None),
LogicalPlan::Unnest(_) => Ok(None),
LogicalPlan::RecursiveQuery(_) => Ok(None),
}
}
/// Returns `true` if any expression attached directly to this node
/// contains an outer (correlated) reference; stops at the first match.
pub fn contains_outer_reference(&self) -> bool {
    let mut found = false;
    self.apply_expressions(|expr| {
        if expr.contains_outer() {
            found = true;
            // Short-circuit the traversal as soon as one is found.
            Ok(TreeNodeRecursion::Stop)
        } else {
            Ok(TreeNodeRecursion::Continue)
        }
    })
    // The closure never returns an error, so this cannot fail.
    .unwrap();
    found
}
/// Pairs each output-producing expression of `Aggregate`/`Window` nodes
/// with the output `Column` it becomes in the schema; other node kinds
/// yield an empty list.
pub fn columnized_output_exprs(&self) -> Result<Vec<(&Expr, Column)>> {
match self {
LogicalPlan::Aggregate(aggregate) => Ok(aggregate
.output_expressions()?
.into_iter()
.zip(self.schema().columns())
.collect()),
LogicalPlan::Window(Window {
window_expr,
input,
schema,
}) => {
// Window output = input columns followed by the window exprs, so
// skip the first `input_len` schema columns when zipping.
let mut output_exprs = input.columnized_output_exprs()?;
let input_len = input.schema().fields().len();
output_exprs.extend(
window_expr
.iter()
.zip(schema.columns().into_iter().skip(input_len)),
);
Ok(output_exprs)
}
_ => Ok(vec![]),
}
}
}
impl LogicalPlan {
/// Replaces every `Placeholder` expression in the plan (including inside
/// subqueries) with the corresponding literal from `param_values`, then
/// recomputes schemas bottom-up so inferred types propagate.
pub fn replace_params_with_values(
self,
param_values: &ParamValues,
) -> Result<LogicalPlan> {
self.transform_up_with_subqueries(|plan| {
let schema = Arc::clone(plan.schema());
// Preserve each expression's display name across the rewrite.
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
let (e, has_placeholder) = e.infer_placeholder_types(&schema)?;
if !has_placeholder {
// Skip rewriting entirely when there is nothing to substitute.
Ok(Transformed::no(e))
} else {
let original_name = name_preserver.save(&e);
let transformed_expr = e.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }) = e {
let (value, metadata) = param_values
.get_placeholders_with_values(&id)?
.into_inner();
Ok(Transformed::yes(Expr::Literal(value, metadata)))
} else {
Ok(Transformed::no(e))
}
})?;
Ok(transformed_expr.update_data(|expr| original_name.restore(expr)))
}
})?
// Literal substitution can change types, so refresh the schema.
.map_data(|plan| plan.update_schema_data_type())
})
.map(|res| res.data)
}
/// Refreshes this node's schema after placeholder substitution. `Values`
/// is rebuilt from scratch (its schema is inferred from the literal rows);
/// everything else goes through `recompute_schema`.
fn update_schema_data_type(self) -> Result<LogicalPlan> {
match self {
LogicalPlan::Values(Values { values, schema: _ }) => {
LogicalPlanBuilder::values(values)?.build()
}
plan => plan.recompute_schema(),
}
}
/// Returns the distinct placeholder ids (e.g. `$1`, `$name`) used anywhere
/// in the plan, including inside subqueries.
pub fn get_parameter_names(&self) -> Result<HashSet<String>> {
let mut param_names = HashSet::new();
self.apply_with_subqueries(|plan| {
plan.apply_expressions(|expr| {
expr.apply(|expr| {
if let Expr::Placeholder(Placeholder { id, .. }) = expr {
param_names.insert(id.clone());
}
Ok(TreeNodeRecursion::Continue)
})
})
})
.map(|_| param_names)
}
/// Returns the inferred `DataType` for each placeholder id in the plan
/// (`None` when no type could be inferred). Thin projection over
/// [`Self::get_parameter_fields`].
pub fn get_parameter_types(
    &self,
) -> Result<HashMap<String, Option<DataType>>, DataFusionError> {
    // Consume the owned map directly with `into_iter()`; the previous
    // `let mut … ; drain()` form was a needless mutable binding over a map
    // that was discarded immediately afterwards.
    Ok(self
        .get_parameter_fields()?
        .into_iter()
        .map(|(name, maybe_field)| {
            (name, maybe_field.map(|field| field.data_type().clone()))
        })
        .collect())
}
/// Returns the inferred `Field` for each placeholder id in the plan
/// (`None` when no type information is attached), erroring if two
/// occurrences of the same id carry conflicting types/metadata.
pub fn get_parameter_fields(
&self,
) -> Result<HashMap<String, Option<FieldRef>>, DataFusionError> {
let mut param_types: HashMap<String, Option<FieldRef>> = HashMap::new();
self.apply_with_subqueries(|plan| {
plan.apply_expressions(|expr| {
expr.apply(|expr| {
if let Expr::Placeholder(Placeholder { id, field }) = expr {
let prev = param_types.get(id);
match (prev, field) {
// Seen before with a type: the types must agree.
(Some(Some(prev)), Some(field)) => {
check_metadata_with_storage_equal(
(field.data_type(), Some(field.metadata())),
(prev.data_type(), Some(prev.metadata())),
"parameter",
&format!(": Conflicting types for id {id}"),
)?;
}
// First typed occurrence (or upgrade from untyped).
(_, Some(field)) => {
param_types.insert(id.clone(), Some(Arc::clone(field)));
}
// Untyped occurrence; record the id with no type.
_ => {
param_types.insert(id.clone(), None);
}
}
}
Ok(TreeNodeRecursion::Continue)
})
})
})
.map(|_| param_types)
}
/// Returns a value whose `Display` impl renders the plan as an indented
/// tree, one node per line, without schema information.
pub fn display_indent(&self) -> impl Display + '_ {
    // Wrapper borrows the plan so formatting happens lazily at `fmt` time.
    struct Wrapper<'a>(&'a LogicalPlan);
    impl Display for Wrapper<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            let mut visitor = IndentVisitor::new(f, /* with_schema = */ false);
            self.0
                .visit_with_subqueries(&mut visitor)
                .map_err(|_| fmt::Error)?;
            Ok(())
        }
    }
    Wrapper(self)
}
/// Like [`Self::display_indent`], but each node also shows its output schema.
pub fn display_indent_schema(&self) -> impl Display + '_ {
    struct Wrapper<'a>(&'a LogicalPlan);
    impl Display for Wrapper<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            // Identical to `display_indent` except the schema flag is on.
            let with_schema = true;
            let mut visitor = IndentVisitor::new(f, with_schema);
            match self.0.visit_with_subqueries(&mut visitor) {
                Ok(_) => Ok(()),
                Err(_) => Err(fmt::Error),
            }
        }
    }
    Wrapper(self)
}
/// Returns a `Display`able that renders the plan tree in a Postgres-style
/// JSON form (via [`PgJsonVisitor`]), with schemas included.
pub fn display_pg_json(&self) -> impl Display + '_ {
    struct Wrapper<'a>(&'a LogicalPlan);
    impl Display for Wrapper<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            let mut visitor = PgJsonVisitor::new(f);
            visitor.with_schema(true);
            // Visitor failures surface as a bare `fmt::Error`.
            match self.0.visit_with_subqueries(&mut visitor) {
                Ok(_) => Ok(()),
                Err(_) => Err(fmt::Error),
            }
        }
    }
    Wrapper(self)
}
/// Returns a `Display`able that renders the plan as a Graphviz graph
/// containing two sub-plans: a plain one and a schema-annotated one.
pub fn display_graphviz(&self) -> impl Display + '_ {
    struct Wrapper<'a>(&'a LogicalPlan);
    impl Display for Wrapper<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            let mut visitor = GraphvizVisitor::new(f);
            visitor.start_graph()?;
            // First pass: nodes without schema details.
            visitor.pre_visit_plan("LogicalPlan")?;
            self.0
                .visit_with_subqueries(&mut visitor)
                .map_err(|_| fmt::Error)?;
            visitor.post_visit_plan()?;
            // Second pass over the same tree, now with schemas shown.
            visitor.set_with_schema(true);
            visitor.pre_visit_plan("Detailed LogicalPlan")?;
            self.0
                .visit_with_subqueries(&mut visitor)
                .map_err(|_| fmt::Error)?;
            visitor.post_visit_plan()?;
            visitor.end_graph()?;
            Ok(())
        }
    }
    Wrapper(self)
}
/// Returns a `Display`able that renders only this node (not its children)
/// on a single line. This per-node form is what the tree visitors used by
/// `display_indent`/`display_graphviz` emit for each plan node.
pub fn display(&self) -> impl Display + '_ {
    struct Wrapper<'a>(&'a LogicalPlan);
    impl Display for Wrapper<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            match self.0 {
                LogicalPlan::EmptyRelation(EmptyRelation {
                    produce_one_row,
                    schema: _,
                }) => {
                    let rows = if *produce_one_row { 1 } else { 0 };
                    write!(f, "EmptyRelation: rows={rows}")
                }
                LogicalPlan::RecursiveQuery(RecursiveQuery {
                    is_distinct, ..
                }) => {
                    write!(f, "RecursiveQuery: is_distinct={is_distinct}")
                }
                LogicalPlan::Values(Values { values, .. }) => {
                    // Show at most the first 5 rows, then an "..." marker.
                    let str_values: Vec<_> = values
                        .iter()
                        .take(5)
                        .map(|row| {
                            let item = row
                                .iter()
                                .map(|expr| expr.to_string())
                                .collect::<Vec<_>>()
                                .join(", ");
                            format!("({item})")
                        })
                        .collect();
                    // `eclipse` is a misnomer for the ellipsis marker.
                    let eclipse = if values.len() > 5 { "..." } else { "" };
                    write!(f, "Values: {}{}", str_values.join(", "), eclipse)
                }
                LogicalPlan::TableScan(TableScan {
                    source,
                    table_name,
                    projection,
                    filters,
                    fetch,
                    ..
                }) => {
                    // Resolve projected column indices back to field names.
                    let projected_fields = match projection {
                        Some(indices) => {
                            let schema = source.schema();
                            let names: Vec<&str> = indices
                                .iter()
                                .map(|i| schema.field(*i).name().as_str())
                                .collect();
                            format!(" projection=[{}]", names.join(", "))
                        }
                        _ => "".to_string(),
                    };
                    write!(f, "TableScan: {table_name}{projected_fields}")?;
                    if !filters.is_empty() {
                        // Classify filters by how well the source can push
                        // them down; on classification error, all three
                        // buckets stay empty and no filter text is printed.
                        let mut full_filter = vec![];
                        let mut partial_filter = vec![];
                        let mut unsupported_filters = vec![];
                        let filters: Vec<&Expr> = filters.iter().collect();
                        if let Ok(results) =
                            source.supports_filters_pushdown(&filters)
                        {
                            filters.iter().zip(results.iter()).for_each(
                                |(x, res)| match res {
                                    TableProviderFilterPushDown::Exact => {
                                        full_filter.push(x)
                                    }
                                    TableProviderFilterPushDown::Inexact => {
                                        partial_filter.push(x)
                                    }
                                    TableProviderFilterPushDown::Unsupported => {
                                        unsupported_filters.push(x)
                                    }
                                },
                            );
                        }
                        if !full_filter.is_empty() {
                            write!(
                                f,
                                ", full_filters=[{}]",
                                expr_vec_fmt!(full_filter)
                            )?;
                        };
                        if !partial_filter.is_empty() {
                            write!(
                                f,
                                ", partial_filters=[{}]",
                                expr_vec_fmt!(partial_filter)
                            )?;
                        }
                        if !unsupported_filters.is_empty() {
                            write!(
                                f,
                                ", unsupported_filters=[{}]",
                                expr_vec_fmt!(unsupported_filters)
                            )?;
                        }
                    }
                    if let Some(n) = fetch {
                        write!(f, ", fetch={n}")?;
                    }
                    Ok(())
                }
                LogicalPlan::Projection(Projection { expr, .. }) => {
                    write!(f, "Projection:")?;
                    for (i, expr_item) in expr.iter().enumerate() {
                        if i > 0 {
                            write!(f, ",")?;
                        }
                        write!(f, " {expr_item}")?;
                    }
                    Ok(())
                }
                LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => {
                    write!(f, "Dml: op=[{op}] table=[{table_name}]")
                }
                LogicalPlan::Copy(CopyTo {
                    input: _,
                    output_url,
                    file_type,
                    options,
                    ..
                }) => {
                    let op_str = options
                        .iter()
                        .map(|(k, v)| format!("{k} {v}"))
                        .collect::<Vec<String>>()
                        .join(", ");
                    write!(
                        f,
                        "CopyTo: format={} output_url={output_url} options: ({op_str})",
                        file_type.get_ext()
                    )
                }
                LogicalPlan::Ddl(ddl) => {
                    write!(f, "{}", ddl.display())
                }
                LogicalPlan::Filter(Filter {
                    predicate: expr, ..
                }) => write!(f, "Filter: {expr}"),
                LogicalPlan::Window(Window { window_expr, .. }) => {
                    write!(
                        f,
                        "WindowAggr: windowExpr=[[{}]]",
                        expr_vec_fmt!(window_expr)
                    )
                }
                LogicalPlan::Aggregate(Aggregate {
                    group_expr,
                    aggr_expr,
                    ..
                }) => write!(
                    f,
                    "Aggregate: groupBy=[[{}]], aggr=[[{}]]",
                    expr_vec_fmt!(group_expr),
                    expr_vec_fmt!(aggr_expr)
                ),
                LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
                    write!(f, "Sort: ")?;
                    for (i, expr_item) in expr.iter().enumerate() {
                        if i > 0 {
                            write!(f, ", ")?;
                        }
                        write!(f, "{expr_item}")?;
                    }
                    if let Some(a) = fetch {
                        write!(f, ", fetch={a}")?;
                    }
                    Ok(())
                }
                LogicalPlan::Join(Join {
                    on: keys,
                    filter,
                    join_constraint,
                    join_type,
                    ..
                }) => {
                    let join_expr: Vec<String> =
                        keys.iter().map(|(l, r)| format!("{l} = {r}")).collect();
                    let filter_expr = filter
                        .as_ref()
                        .map(|expr| format!(" Filter: {expr}"))
                        .unwrap_or_else(|| "".to_string());
                    // An inner join with no keys and no filter is shown as
                    // a cross join.
                    let join_type = if filter.is_none()
                        && keys.is_empty()
                        && *join_type == JoinType::Inner
                    {
                        "Cross".to_string()
                    } else {
                        join_type.to_string()
                    };
                    match join_constraint {
                        JoinConstraint::On => {
                            write!(f, "{join_type} Join:",)?;
                            if !join_expr.is_empty() || !filter_expr.is_empty() {
                                write!(
                                    f,
                                    " {}{}",
                                    join_expr.join(", "),
                                    filter_expr
                                )?;
                            }
                            Ok(())
                        }
                        JoinConstraint::Using => {
                            write!(
                                f,
                                "{} Join: Using {}{}",
                                join_type,
                                join_expr.join(", "),
                                filter_expr,
                            )
                        }
                    }
                }
                LogicalPlan::Repartition(Repartition {
                    partitioning_scheme,
                    ..
                }) => match partitioning_scheme {
                    Partitioning::RoundRobinBatch(n) => {
                        write!(f, "Repartition: RoundRobinBatch partition_count={n}")
                    }
                    Partitioning::Hash(expr, n) => {
                        let hash_expr: Vec<String> =
                            expr.iter().map(|e| format!("{e}")).collect();
                        write!(
                            f,
                            "Repartition: Hash({}) partition_count={}",
                            hash_expr.join(", "),
                            n
                        )
                    }
                    Partitioning::DistributeBy(expr) => {
                        let dist_by_expr: Vec<String> =
                            expr.iter().map(|e| format!("{e}")).collect();
                        write!(
                            f,
                            "Repartition: DistributeBy({})",
                            dist_by_expr.join(", "),
                        )
                    }
                },
                LogicalPlan::Limit(limit) => {
                    // Prefer the resolved literal form; fall back to the raw
                    // expression text (or "None") when unresolvable.
                    let skip_str = match limit.get_skip_type() {
                        Ok(SkipType::Literal(n)) => n.to_string(),
                        _ => limit
                            .skip
                            .as_ref()
                            .map_or_else(|| "None".to_string(), |x| x.to_string()),
                    };
                    let fetch_str = match limit.get_fetch_type() {
                        Ok(FetchType::Literal(Some(n))) => n.to_string(),
                        Ok(FetchType::Literal(None)) => "None".to_string(),
                        _ => limit
                            .fetch
                            .as_ref()
                            .map_or_else(|| "None".to_string(), |x| x.to_string()),
                    };
                    write!(f, "Limit: skip={skip_str}, fetch={fetch_str}",)
                }
                LogicalPlan::Subquery(Subquery { .. }) => {
                    write!(f, "Subquery:")
                }
                LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => {
                    write!(f, "SubqueryAlias: {alias}")
                }
                LogicalPlan::Statement(statement) => {
                    write!(f, "{}", statement.display())
                }
                LogicalPlan::Distinct(distinct) => match distinct {
                    Distinct::All(_) => write!(f, "Distinct:"),
                    Distinct::On(DistinctOn {
                        on_expr,
                        select_expr,
                        sort_expr,
                        ..
                    }) => write!(
                        f,
                        "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]",
                        expr_vec_fmt!(on_expr),
                        expr_vec_fmt!(select_expr),
                        if let Some(sort_expr) = sort_expr {
                            expr_vec_fmt!(sort_expr)
                        } else {
                            "".to_string()
                        },
                    ),
                },
                LogicalPlan::Explain { .. } => write!(f, "Explain"),
                LogicalPlan::Analyze { .. } => write!(f, "Analyze"),
                LogicalPlan::Union(_) => write!(f, "Union"),
                // Extensions render themselves.
                LogicalPlan::Extension(e) => e.node.fmt_for_explain(f),
                LogicalPlan::DescribeTable(DescribeTable { .. }) => {
                    write!(f, "DescribeTable")
                }
                LogicalPlan::Unnest(Unnest {
                    input: plan,
                    list_type_columns: list_col_indices,
                    struct_type_columns: struct_col_indices,
                    ..
                }) => {
                    // Map stored column indices back to input column names.
                    let input_columns = plan.schema().columns();
                    let list_type_columns = list_col_indices
                        .iter()
                        .map(|(i, unnest_info)| {
                            format!(
                                "{}|depth={}",
                                &input_columns[*i].to_string(),
                                unnest_info.depth
                            )
                        })
                        .collect::<Vec<String>>();
                    let struct_type_columns = struct_col_indices
                        .iter()
                        .map(|i| &input_columns[*i])
                        .collect::<Vec<&Column>>();
                    write!(
                        f,
                        "Unnest: lists[{}] structs[{}]",
                        expr_vec_fmt!(list_type_columns),
                        expr_vec_fmt!(struct_type_columns)
                    )
                }
            }
        }
    }
    Wrapper(self)
}
}
impl Display for LogicalPlan {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
self.display_indent().fmt(f)
}
}
impl ToStringifiedPlan for LogicalPlan {
    /// Captures the indented (schema-less) rendering of this plan as a
    /// [`StringifiedPlan`] tagged with `plan_type`.
    fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan {
        StringifiedPlan::new(plan_type, self.display_indent().to_string())
    }
}
/// A relation with no underlying table: produces either zero rows or a
/// single empty row.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EmptyRelation {
    /// When `true`, emit exactly one (empty) row; otherwise emit none.
    pub produce_one_row: bool,
    /// Output schema of this node.
    pub schema: DFSchemaRef,
}
impl PartialOrd for EmptyRelation {
    /// Orders by `produce_one_row` only. The trailing `filter` turns a tie
    /// into `None` unless the full structs (schema included) are equal, so
    /// ordering stays consistent with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.produce_one_row
            .partial_cmp(&other.produce_one_row)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// A recursive (WITH RECURSIVE style) query node.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct RecursiveQuery {
    /// Name the recursive relation is referred to by.
    pub name: String,
    /// The non-recursive seed term.
    pub static_term: Arc<LogicalPlan>,
    /// The term evaluated repeatedly against prior results.
    pub recursive_term: Arc<LogicalPlan>,
    /// Whether duplicate rows are eliminated between iterations.
    pub is_distinct: bool,
}
/// A constant relation built from literal rows (SQL `VALUES` clause).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Values {
    /// Output schema describing the literal rows.
    pub schema: DFSchemaRef,
    /// Row-major list of expressions; each inner `Vec` is one row.
    pub values: Vec<Vec<Expr>>,
}
impl PartialOrd for Values {
    /// Orders by the literal rows only; ties resolve to `Equal` only when
    /// the full structs (schema included) are equal, keeping the ordering
    /// consistent with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.values
            .partial_cmp(&other.values)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// Evaluates a list of expressions against its input (SQL `SELECT` list).
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
#[non_exhaustive]
pub struct Projection {
    /// One expression per output column.
    pub expr: Vec<Expr>,
    /// The child plan the expressions are evaluated over.
    pub input: Arc<LogicalPlan>,
    /// Schema derived from `expr`.
    pub schema: DFSchemaRef,
}
impl PartialOrd for Projection {
    /// Orders by `expr`, then `input`; the schema breaks a tie only via the
    /// final full-equality check, keeping consistency with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match self.expr.partial_cmp(&other.expr) {
            Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
            cmp => cmp,
        }
        .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
impl Projection {
    /// Creates a projection, deriving the output schema from `expr` and
    /// `input` via [`projection_schema`].
    pub fn try_new(expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<Self> {
        let projection_schema = projection_schema(&input, &expr)?;
        Self::try_new_with_schema(expr, input, projection_schema)
    }
    /// Creates a projection with a caller-supplied schema.
    ///
    /// # Errors
    /// Returns a plan error when the expression count does not match the
    /// schema's field count (unless a deprecated wildcard is present, which
    /// may expand to multiple fields).
    pub fn try_new_with_schema(
        expr: Vec<Expr>,
        input: Arc<LogicalPlan>,
        schema: DFSchemaRef,
    ) -> Result<Self> {
        #[expect(deprecated)]
        if !expr.iter().any(|e| matches!(e, Expr::Wildcard { .. }))
            && expr.len() != schema.fields().len()
        {
            return plan_err!(
                "Projection has mismatch between number of expressions ({}) and number of fields in schema ({})",
                expr.len(),
                schema.fields().len()
            );
        }
        Ok(Self {
            expr,
            input,
            schema,
        })
    }
    /// Creates an identity projection: one `Expr::Column` per column of
    /// `schema`. Infallible because the expressions mirror the schema.
    pub fn new_from_schema(input: Arc<LogicalPlan>, schema: DFSchemaRef) -> Self {
        let expr: Vec<Expr> = schema.columns().into_iter().map(Expr::Column).collect();
        Self {
            expr,
            input,
            schema,
        }
    }
}
/// Computes the output schema of projecting `exprs` over `input`.
///
/// The input schema's metadata is carried over, and functional dependencies
/// are re-projected through the expressions.
pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result<Arc<DFSchema>> {
    let metadata = input.schema().metadata().clone();
    let schema =
        DFSchema::new_with_metadata(exprlist_to_fields(exprs, input)?, metadata)?
            .with_functional_dependencies(calc_func_dependencies_for_project(
                exprs, input,
            )?)?;
    Ok(Arc::new(schema))
}
/// Assigns an alias (new relation name) to a subquery's output.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct SubqueryAlias {
    /// The aliased child plan.
    pub input: Arc<LogicalPlan>,
    /// The alias assigned to the subquery.
    pub alias: TableReference,
    /// Output schema with all fields re-qualified by `alias`.
    pub schema: DFSchemaRef,
}
impl SubqueryAlias {
    /// Creates an alias node over `plan`.
    ///
    /// If the input schema contains duplicate field names, a projection is
    /// first inserted that disambiguates them with unique aliases (from
    /// `unique_field_aliases`) so the re-qualified schema is valid.
    pub fn try_new(
        plan: Arc<LogicalPlan>,
        alias: impl Into<TableReference>,
    ) -> Result<Self> {
        let alias = alias.into();
        let aliases = unique_field_aliases(plan.schema().fields());
        // Any `Some` alias means at least one field name collides.
        let is_projection_needed = aliases.iter().any(Option::is_some);
        let plan = if is_projection_needed {
            let projection_expressions = aliases
                .iter()
                .zip(plan.schema().iter())
                .map(|(alias, (qualifier, field))| {
                    let column =
                        Expr::Column(Column::new(qualifier.cloned(), field.name()));
                    match alias {
                        // Unique name: pass the column through unchanged.
                        None => column,
                        // Colliding name: re-alias to the generated name.
                        Some(alias) => {
                            Expr::Alias(Alias::new(column, qualifier.cloned(), alias))
                        }
                    }
                })
                .collect();
            let projection = Projection::try_new(projection_expressions, plan)?;
            Arc::new(LogicalPlan::Projection(projection))
        } else {
            plan
        };
        // Strip the old qualifiers, then re-qualify everything under `alias`,
        // preserving metadata and functional dependencies.
        let fields = plan.schema().fields().clone();
        let meta_data = plan.schema().metadata().clone();
        let func_dependencies = plan.schema().functional_dependencies().clone();
        let schema = DFSchema::from_unqualified_fields(fields, meta_data)?;
        let schema = schema.as_arrow();
        let schema = DFSchemaRef::new(
            DFSchema::try_from_qualified_schema(alias.clone(), schema)?
                .with_functional_dependencies(func_dependencies)?,
        );
        Ok(SubqueryAlias {
            input: plan,
            alias,
            schema,
        })
    }
}
impl PartialOrd for SubqueryAlias {
    /// Orders by `input`, then `alias`; the schema participates only through
    /// the final full-equality check (consistency with `PartialEq`).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match self.input.partial_cmp(&other.input) {
            Some(Ordering::Equal) => self.alias.partial_cmp(&other.alias),
            cmp => cmp,
        }
        .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// Filters rows from its input by a boolean predicate (SQL `WHERE`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
#[non_exhaustive]
pub struct Filter {
    /// Boolean predicate; rows for which it is not true are dropped.
    pub predicate: Expr,
    /// The child plan being filtered.
    pub input: Arc<LogicalPlan>,
}
impl Filter {
    /// Creates a filter, validating that the predicate is boolean-typed.
    pub fn try_new(predicate: Expr, input: Arc<LogicalPlan>) -> Result<Self> {
        Self::try_new_internal(predicate, input)
    }
    #[deprecated(since = "48.0.0", note = "Use `try_new` instead")]
    pub fn try_new_with_having(predicate: Expr, input: Arc<LogicalPlan>) -> Result<Self> {
        Self::try_new_internal(predicate, input)
    }
    /// Returns whether `data_type` is acceptable for a filter predicate:
    /// boolean, null, or a dictionary whose value type is (recursively)
    /// acceptable.
    fn is_allowed_filter_type(data_type: &DataType) -> bool {
        match data_type {
            DataType::Boolean | DataType::Null => true,
            DataType::Dictionary(_, value_type) => {
                Filter::is_allowed_filter_type(value_type.as_ref())
            }
            _ => false,
        }
    }
    /// Shared constructor logic. If the predicate's type can be resolved
    /// against the input schema and is not boolean-like, this is rejected;
    /// unresolvable types are let through. The predicate is stored with any
    /// nested aliases stripped.
    fn try_new_internal(predicate: Expr, input: Arc<LogicalPlan>) -> Result<Self> {
        if let Ok(predicate_type) = predicate.get_type(input.schema())
            && !Filter::is_allowed_filter_type(&predicate_type)
        {
            return plan_err!(
                "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}"
            );
        }
        Ok(Self {
            predicate: predicate.unalias_nested().data,
            input,
        })
    }
    /// Returns `true` when this filter is guaranteed to produce at most one
    /// row: every column of some non-nullable unique key (a `Single`-mode
    /// functional dependency covering all fields) is pinned by an equality
    /// predicate against a non-column expression.
    fn is_scalar(&self) -> bool {
        let schema = self.input.schema();
        let functional_dependencies = self.input.schema().functional_dependencies();
        // Candidate keys: non-nullable, single-mode dependencies that
        // determine every field of the schema.
        let unique_keys = functional_dependencies.iter().filter(|dep| {
            let nullable = dep.nullable
                && dep
                    .source_indices
                    .iter()
                    .any(|&source| schema.field(source).is_nullable());
            !nullable
                && dep.mode == Dependency::Single
                && dep.target_indices.len() == schema.fields().len()
        });
        let exprs = split_conjunction(&self.predicate);
        // Indices of columns constrained by `col = <non-column expr>`
        // conjuncts (column-to-column equalities do not pin a value).
        let eq_pred_cols: HashSet<_> = exprs
            .iter()
            .filter_map(|expr| {
                let Expr::BinaryExpr(BinaryExpr {
                    left,
                    op: Operator::Eq,
                    right,
                }) = expr
                else {
                    return None;
                };
                // `x = x` constrains nothing.
                if left == right {
                    return None;
                }
                match (left.as_ref(), right.as_ref()) {
                    (Expr::Column(_), Expr::Column(_)) => None,
                    (Expr::Column(c), _) | (_, Expr::Column(c)) => {
                        Some(schema.index_of_column(c).unwrap())
                    }
                    _ => None,
                }
            })
            .collect();
        for key in unique_keys {
            if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) {
                return true;
            }
        }
        false
    }
}
/// Evaluates one or more window functions over its input; the output schema
/// is the input's fields followed by one field per window expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Window {
    /// The child plan the window functions read from.
    pub input: Arc<LogicalPlan>,
    /// The window function expressions, appended after the input columns.
    pub window_expr: Vec<Expr>,
    /// Schema: input fields + window expression fields.
    pub schema: DFSchemaRef,
}
impl Window {
    /// Creates a window node, deriving the output schema and extending the
    /// input's functional dependencies.
    pub fn try_new(window_expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<Self> {
        // Output fields = input fields followed by one field per window expr.
        let fields: Vec<(Option<TableReference>, Arc<Field>)> = input
            .schema()
            .iter()
            .map(|(q, f)| (q.cloned(), Arc::clone(f)))
            .collect();
        let input_len = fields.len();
        let mut window_fields = fields;
        let expr_fields = exprlist_to_fields(window_expr.as_slice(), &input)?;
        window_fields.extend_from_slice(expr_fields.as_slice());
        let metadata = input.schema().metadata().clone();
        // Existing dependencies still hold; widen their target range to
        // cover the appended window columns.
        let mut window_func_dependencies =
            input.schema().functional_dependencies().clone();
        window_func_dependencies.extend_target_indices(window_fields.len());
        // An unpartitioned `row_number` output is unique per row, so it
        // determines every column — record it as a new dependency.
        let mut new_dependencies = window_expr
            .iter()
            .enumerate()
            .filter_map(|(idx, expr)| {
                let Expr::WindowFunction(window_fun) = expr else {
                    return None;
                };
                let WindowFunction {
                    fun: WindowFunctionDefinition::WindowUDF(udwf),
                    params: WindowFunctionParams { partition_by, .. },
                } = window_fun.as_ref()
                else {
                    return None;
                };
                if udwf.name() == "row_number" && partition_by.is_empty() {
                    // Index of this window column in the output schema.
                    Some(idx + input_len)
                } else {
                    None
                }
            })
            .map(|idx| {
                FunctionalDependence::new(vec![idx], vec![], false)
                    .with_mode(Dependency::Single)
            })
            .collect::<Vec<_>>();
        if !new_dependencies.is_empty() {
            for dependence in new_dependencies.iter_mut() {
                dependence.target_indices = (0..window_fields.len()).collect();
            }
            let new_deps = FunctionalDependencies::new(new_dependencies);
            window_func_dependencies.extend(new_deps);
        }
        // FILTER is only valid on aggregate functions used as windows.
        if let Some(e) = window_expr.iter().find(|e| {
            matches!(
                e,
                Expr::WindowFunction(wf)
                    if !matches!(wf.fun, WindowFunctionDefinition::AggregateUDF(_))
                        && wf.params.filter.is_some()
            )
        }) {
            return plan_err!(
                "FILTER clause can only be used with aggregate window functions. Found in '{e}'"
            );
        }
        Self::try_new_with_schema(
            window_expr,
            input,
            Arc::new(
                DFSchema::new_with_metadata(window_fields, metadata)?
                    .with_functional_dependencies(window_func_dependencies)?,
            ),
        )
    }
    /// Creates a window node with a caller-supplied schema, validating only
    /// that the field count is input fields + window expressions.
    pub fn try_new_with_schema(
        window_expr: Vec<Expr>,
        input: Arc<LogicalPlan>,
        schema: DFSchemaRef,
    ) -> Result<Self> {
        let input_fields_count = input.schema().fields().len();
        if schema.fields().len() != input_fields_count + window_expr.len() {
            return plan_err!(
                "Window schema has wrong number of fields. Expected {} got {}",
                input_fields_count + window_expr.len(),
                schema.fields().len()
            );
        }
        Ok(Window {
            input,
            window_expr,
            schema,
        })
    }
}
impl PartialOrd for Window {
    /// Orders by `input`, then `window_expr`. The schema is not compared
    /// directly; instead a final full-equality check refuses to report
    /// `Equal` for windows that differ elsewhere, keeping the ordering
    /// consistent with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let by_input = self.input.partial_cmp(&other.input)?;
        if by_input != Ordering::Equal {
            return Some(by_input);
        }
        let by_exprs = self.window_expr.partial_cmp(&other.window_expr)?;
        if by_exprs != Ordering::Equal {
            return Some(by_exprs);
        }
        if self == other {
            Some(Ordering::Equal)
        } else {
            None
        }
    }
}
/// Scans rows from a table source, optionally projecting columns, pushing
/// down filters, and limiting the row count.
#[derive(Clone)]
pub struct TableScan {
    /// Name the table is referenced by.
    pub table_name: TableReference,
    /// The table provider; a trait object, so it is deliberately excluded
    /// from `Debug`/`PartialEq`/`PartialOrd`/`Hash` below.
    pub source: Arc<dyn TableSource>,
    /// Optional column indices to project from the source schema.
    pub projection: Option<Vec<usize>>,
    /// Schema after applying `projection`.
    pub projected_schema: DFSchemaRef,
    /// Filters that may be pushed down to the source.
    pub filters: Vec<Expr>,
    /// Optional maximum number of rows to produce.
    pub fetch: Option<usize>,
}
impl Debug for TableScan {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("TableScan")
            .field("table_name", &self.table_name)
            // The trait object has no useful `Debug`; elide it.
            .field("source", &"...")
            .field("projection", &self.projection)
            .field("projected_schema", &self.projected_schema)
            .field("filters", &self.filters)
            .field("fetch", &self.fetch)
            .finish_non_exhaustive()
    }
}
impl PartialEq for TableScan {
    /// Compares all fields except `source`, which cannot be compared.
    fn eq(&self, other: &Self) -> bool {
        self.table_name == other.table_name
            && self.projection == other.projection
            && self.projected_schema == other.projected_schema
            && self.filters == other.filters
            && self.fetch == other.fetch
    }
}
impl Eq for TableScan {}
impl PartialOrd for TableScan {
    /// Orders by the comparable fields (`source` and the schema excluded);
    /// a tie is reported as `Equal` only when the full structs are equal.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Local view over just the orderable fields.
        #[derive(PartialEq, PartialOrd)]
        struct ComparableTableScan<'a> {
            pub table_name: &'a TableReference,
            pub projection: &'a Option<Vec<usize>>,
            pub filters: &'a Vec<Expr>,
            pub fetch: &'a Option<usize>,
        }
        let comparable_self = ComparableTableScan {
            table_name: &self.table_name,
            projection: &self.projection,
            filters: &self.filters,
            fetch: &self.fetch,
        };
        let comparable_other = ComparableTableScan {
            table_name: &other.table_name,
            projection: &other.projection,
            filters: &other.filters,
            fetch: &other.fetch,
        };
        comparable_self
            .partial_cmp(&comparable_other)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
impl Hash for TableScan {
    /// Hashes the same fields `PartialEq` compares (everything but
    /// `source`), keeping `Hash` consistent with `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.table_name.hash(state);
        self.projection.hash(state);
        self.projected_schema.hash(state);
        self.filters.hash(state);
        self.fetch.hash(state);
    }
}
impl TableScan {
    /// Creates a table scan, computing the projected schema and projecting
    /// the source's functional dependencies accordingly.
    ///
    /// # Errors
    /// Returns a plan error when `table_name` is empty, or when schema /
    /// dependency construction fails.
    pub fn try_new(
        table_name: impl Into<TableReference>,
        table_source: Arc<dyn TableSource>,
        projection: Option<Vec<usize>>,
        filters: Vec<Expr>,
        fetch: Option<usize>,
    ) -> Result<Self> {
        let table_name = table_name.into();
        if table_name.table().is_empty() {
            return plan_err!("table_name cannot be empty");
        }
        let schema = table_source.schema();
        // Dependencies implied by the source's declared constraints.
        let func_dependencies = FunctionalDependencies::new_from_constraints(
            table_source.constraints(),
            schema.fields.len(),
        );
        let projected_schema = projection
            .as_ref()
            .map(|p| {
                // Keep only the projected fields (qualified by the table
                // name) and remap dependencies onto the new indices.
                let projected_func_dependencies =
                    func_dependencies.project_functional_dependencies(p, p.len());
                let df_schema = DFSchema::new_with_metadata(
                    p.iter()
                        .map(|i| {
                            (Some(table_name.clone()), Arc::clone(&schema.fields()[*i]))
                        })
                        .collect(),
                    schema.metadata.clone(),
                )?;
                df_schema.with_functional_dependencies(projected_func_dependencies)
            })
            .unwrap_or_else(|| {
                // No projection: qualify the full source schema.
                let df_schema =
                    DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?;
                df_schema.with_functional_dependencies(func_dependencies)
            })?;
        let projected_schema = Arc::new(projected_schema);
        Ok(Self {
            table_name,
            source: table_source,
            projection,
            projected_schema,
            filters,
            fetch,
        })
    }
}
/// Redistributes its input's rows according to a partitioning scheme.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Repartition {
    /// The child plan being repartitioned.
    pub input: Arc<LogicalPlan>,
    /// How rows are distributed (round-robin, hash, or distribute-by).
    pub partitioning_scheme: Partitioning,
}
/// Concatenates the rows of two or more inputs (SQL `UNION ALL` shape).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Union {
    /// The input plans; always at least two (enforced by the constructors).
    pub inputs: Vec<Arc<LogicalPlan>>,
    /// The common output schema derived from the inputs.
    pub schema: DFSchemaRef,
}
impl Union {
    /// Creates a union matching inputs by position, requiring identical
    /// field types across inputs.
    pub fn try_new(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> {
        let schema = Self::derive_schema_from_inputs(&inputs, false, false)?;
        Ok(Union { inputs, schema })
    }
    /// Creates a positional union that tolerates differing field types,
    /// taking each output type from the first input.
    pub fn try_new_with_loose_types(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> {
        let schema = Self::derive_schema_from_inputs(&inputs, true, false)?;
        Ok(Union { inputs, schema })
    }
    /// Creates a union matching columns by name (`UNION BY NAME` shape);
    /// each input is wrapped in a projection that supplies `NULL` for
    /// columns it does not have.
    pub fn try_new_by_name(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> {
        let schema = Self::derive_schema_from_inputs(&inputs, true, true)?;
        let inputs = Self::rewrite_inputs_from_schema(&schema, inputs)?;
        Ok(Union { inputs, schema })
    }
    /// Wraps each input in a projection aligned with `schema`: existing
    /// columns pass through by name; missing ones become aliased NULLs.
    fn rewrite_inputs_from_schema(
        schema: &Arc<DFSchema>,
        inputs: Vec<Arc<LogicalPlan>>,
    ) -> Result<Vec<Arc<LogicalPlan>>> {
        let schema_width = schema.iter().count();
        let mut wrapped_inputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            let mut expr = Vec::with_capacity(schema_width);
            for column in schema.columns() {
                if input
                    .schema()
                    .has_column_with_unqualified_name(column.name())
                {
                    expr.push(Expr::Column(column));
                } else {
                    // Column absent in this input: fill with NULL under the
                    // union's column name.
                    expr.push(
                        Expr::Literal(ScalarValue::Null, None).alias(column.name()),
                    );
                }
            }
            wrapped_inputs.push(Arc::new(LogicalPlan::Projection(
                Projection::try_new_with_schema(expr, input, Arc::clone(schema))?,
            )));
        }
        Ok(wrapped_inputs)
    }
    /// Dispatches schema derivation to by-name or by-position, after
    /// checking that at least two inputs are present.
    fn derive_schema_from_inputs(
        inputs: &[Arc<LogicalPlan>],
        loose_types: bool,
        by_name: bool,
    ) -> Result<DFSchemaRef> {
        if inputs.len() < 2 {
            return plan_err!("UNION requires at least two inputs");
        }
        if by_name {
            Self::derive_schema_from_inputs_by_name(inputs, loose_types)
        } else {
            Self::derive_schema_from_inputs_by_position(inputs, loose_types)
        }
    }
    /// Derives the union schema by column name: columns appear in first-seen
    /// order; a column missing from any input becomes nullable; per-column
    /// metadata is intersected across occurrences.
    fn derive_schema_from_inputs_by_name(
        inputs: &[Arc<LogicalPlan>],
        loose_types: bool,
    ) -> Result<DFSchemaRef> {
        // Per column: (type, nullable-so-far, metadata occurrences, count).
        type FieldData<'a> =
            (&'a DataType, bool, Vec<&'a HashMap<String, String>>, usize);
        let mut cols: Vec<(&str, FieldData)> = Vec::new();
        for input in inputs.iter() {
            for field in input.schema().fields() {
                if let Some((_, (data_type, is_nullable, metadata, occurrences))) =
                    cols.iter_mut().find(|(name, _)| name == field.name())
                {
                    if !loose_types && *data_type != field.data_type() {
                        return plan_err!(
                            "Found different types for field {}",
                            field.name()
                        );
                    }
                    metadata.push(field.metadata());
                    // Nullable if any occurrence is nullable.
                    *is_nullable |= field.is_nullable();
                    *occurrences += 1;
                } else {
                    cols.push((
                        field.name(),
                        (
                            field.data_type(),
                            field.is_nullable(),
                            vec![field.metadata()],
                            1,
                        ),
                    ));
                }
            }
        }
        let union_fields = cols
            .into_iter()
            .map(
                |(name, (data_type, is_nullable, unmerged_metadata, occurrences))| {
                    // A column absent from some input is padded with NULLs,
                    // so it must be nullable in the union schema.
                    let final_is_nullable = if occurrences == inputs.len() {
                        is_nullable
                    } else {
                        true
                    };
                    let mut field =
                        Field::new(name, data_type.clone(), final_is_nullable);
                    field.set_metadata(intersect_metadata_for_union(unmerged_metadata));
                    // Union output fields carry no qualifier.
                    (None, Arc::new(field))
                },
            )
            .collect::<Vec<(Option<TableReference>, _)>>();
        let union_schema_metadata = intersect_metadata_for_union(
            inputs.iter().map(|input| input.schema().metadata()),
        );
        let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?;
        let schema = Arc::new(schema);
        Ok(schema)
    }
    /// Derives the union schema positionally: all inputs must have the same
    /// column count; types must match unless `loose_types`; duplicate output
    /// names are deduplicated with a `_N` suffix.
    fn derive_schema_from_inputs_by_position(
        inputs: &[Arc<LogicalPlan>],
        loose_types: bool,
    ) -> Result<DFSchemaRef> {
        let first_schema = inputs[0].schema();
        let fields_count = first_schema.fields().len();
        for input in inputs.iter().skip(1) {
            if fields_count != input.schema().fields().len() {
                return plan_err!(
                    "UNION queries have different number of columns: \
                    left has {} columns whereas right has {} columns",
                    fields_count,
                    input.schema().fields().len()
                );
            }
        }
        // Tracks how many times each base name has been emitted, for the
        // `_N` dedup suffix below.
        let mut name_counts: HashMap<String, usize> = HashMap::new();
        let union_fields = (0..fields_count)
            .map(|i| {
                let fields = inputs
                    .iter()
                    .map(|input| input.schema().field(i))
                    .collect::<Vec<_>>();
                let first_field = fields[0];
                let base_name = first_field.name().to_string();
                let data_type = if loose_types {
                    // The first input's type wins; coercion is assumed to
                    // happen elsewhere.
                    first_field.data_type()
                } else {
                    fields.iter().skip(1).try_fold(
                        first_field.data_type(),
                        |acc, field| {
                            if acc != field.data_type() {
                                return plan_err!(
                                    "UNION field {i} have different type in inputs: \
                                    left has {} whereas right has {}",
                                    first_field.data_type(),
                                    field.data_type()
                                );
                            }
                            Ok(acc)
                        },
                    )?
                };
                let nullable = fields.iter().any(|field| field.is_nullable());
                // First occurrence keeps the bare name; repeats get `_1`,
                // `_2`, … appended.
                let name = if let Some(count) = name_counts.get_mut(&base_name) {
                    *count += 1;
                    format!("{base_name}_{count}")
                } else {
                    name_counts.insert(base_name.clone(), 0);
                    base_name
                };
                let mut field = Field::new(&name, data_type.clone(), nullable);
                let field_metadata = intersect_metadata_for_union(
                    fields.iter().map(|field| field.metadata()),
                );
                field.set_metadata(field_metadata);
                Ok((None, Arc::new(field)))
            })
            .collect::<Result<_>>()?;
        let union_schema_metadata = intersect_metadata_for_union(
            inputs.iter().map(|input| input.schema().metadata()),
        );
        let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?;
        let schema = Arc::new(schema);
        Ok(schema)
    }
}
impl PartialOrd for Union {
    /// Orders by `inputs` only; a tie counts as `Equal` only when the full
    /// structs (schema included) are equal, consistent with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.inputs
            .partial_cmp(&other.inputs)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// Produces a description (column names, types, nullability) of a table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DescribeTable {
    /// Schema of the table being described.
    pub schema: Arc<Schema>,
    /// Schema of this node's own output rows.
    pub output_schema: DFSchemaRef,
}
impl PartialOrd for DescribeTable {
    /// `DescribeTable` has no meaningful ordering; always `None`.
    fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
        None
    }
}
/// Options controlling how an `EXPLAIN` statement is produced.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExplainOption {
    /// Include more detail in the output.
    pub verbose: bool,
    /// Execute the plan and report actual metrics (`EXPLAIN ANALYZE`).
    pub analyze: bool,
    /// Output rendering format.
    pub format: ExplainFormat,
}
impl Default for ExplainOption {
    /// Non-verbose, non-analyze explain rendered in the `Indent` format.
    fn default() -> Self {
        Self {
            verbose: false,
            analyze: false,
            format: ExplainFormat::Indent,
        }
    }
}
impl ExplainOption {
    /// Builder-style setter for `verbose`.
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose, ..self }
    }
    /// Builder-style setter for `analyze`.
    pub fn with_analyze(self, analyze: bool) -> Self {
        Self { analyze, ..self }
    }
    /// Builder-style setter for the output `format`.
    pub fn with_format(self, format: ExplainFormat) -> Self {
        Self { format, ..self }
    }
}
/// The `EXPLAIN` plan node: carries the explained plan plus the collected
/// stringified representations of each planning stage.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Explain {
    /// Whether verbose output was requested.
    pub verbose: bool,
    /// Output rendering format.
    pub explain_format: ExplainFormat,
    /// The plan being explained.
    pub plan: Arc<LogicalPlan>,
    /// Per-stage textual renderings accumulated during planning.
    pub stringified_plans: Vec<StringifiedPlan>,
    /// Output schema of the EXPLAIN result rows.
    pub schema: DFSchemaRef,
    /// Whether logical optimization completed without error.
    pub logical_optimization_succeeded: bool,
}
impl PartialOrd for Explain {
    /// Orders by the comparable fields (format and schema excluded); a tie
    /// is `Equal` only when the full structs are equal, consistent with
    /// `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Local view over just the orderable fields.
        #[derive(PartialEq, PartialOrd)]
        struct ComparableExplain<'a> {
            pub verbose: &'a bool,
            pub plan: &'a Arc<LogicalPlan>,
            pub stringified_plans: &'a Vec<StringifiedPlan>,
            pub logical_optimization_succeeded: &'a bool,
        }
        let comparable_self = ComparableExplain {
            verbose: &self.verbose,
            plan: &self.plan,
            stringified_plans: &self.stringified_plans,
            logical_optimization_succeeded: &self.logical_optimization_succeeded,
        };
        let comparable_other = ComparableExplain {
            verbose: &other.verbose,
            plan: &other.plan,
            stringified_plans: &other.stringified_plans,
            logical_optimization_succeeded: &other.logical_optimization_succeeded,
        };
        comparable_self
            .partial_cmp(&comparable_other)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// The `EXPLAIN ANALYZE` node: runs `input` and reports execution metrics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Analyze {
    /// Whether verbose metrics were requested.
    pub verbose: bool,
    /// The plan to execute and measure.
    pub input: Arc<LogicalPlan>,
    /// Output schema of the metrics rows.
    pub schema: DFSchemaRef,
}
impl PartialOrd for Analyze {
    /// Orders by `verbose`, then `input`; the schema participates only
    /// through the final full-equality guard, which keeps the ordering
    /// consistent with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = match self.verbose.partial_cmp(&other.verbose) {
            Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
            cmp => cmp,
        }?;
        if ordering != Ordering::Equal || self == other {
            Some(ordering)
        } else {
            None
        }
    }
}
/// A user-defined logical plan node, held behind the
/// [`UserDefinedLogicalNode`] trait object.
#[allow(clippy::allow_attributes)]
// Hash is derived while PartialEq is manual (delegating to the node), hence
// the lint allowance.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Eq, Hash)]
pub struct Extension {
    /// The user-provided node implementation.
    pub node: Arc<dyn UserDefinedLogicalNode>,
}
impl PartialEq for Extension {
    /// Delegates equality to the trait object's own comparison.
    fn eq(&self, other: &Self) -> bool {
        self.node.eq(&other.node)
    }
}
impl PartialOrd for Extension {
    /// Delegates ordering to the trait object's own comparison.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.node.partial_cmp(&other.node)
    }
}
/// Skips and/or limits rows from its input (SQL `OFFSET` / `LIMIT`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Limit {
    /// Number of rows to skip, as an expression; `None` means skip 0.
    pub skip: Option<Box<Expr>>,
    /// Maximum rows to return, as an expression; `None` means no limit.
    pub fetch: Option<Box<Expr>>,
    /// The child plan being limited.
    pub input: Arc<LogicalPlan>,
}
/// Resolved form of [`Limit::skip`].
pub enum SkipType {
    /// Skip a known, non-negative number of rows.
    Literal(usize),
    /// The skip expression could not be resolved to a literal.
    UnsupportedExpr,
}
/// Resolved form of [`Limit::fetch`].
pub enum FetchType {
    /// Fetch at most this many rows; `None` means unbounded.
    Literal(Option<usize>),
    /// The fetch expression could not be resolved to a literal.
    UnsupportedExpr,
}
impl Limit {
    /// Resolves `skip` to a [`SkipType`].
    ///
    /// A missing skip or a NULL `Int64` literal resolves to `Literal(0)`;
    /// a negative literal is a plan error; any non-`Int64`-literal
    /// expression is `UnsupportedExpr`.
    pub fn get_skip_type(&self) -> Result<SkipType> {
        match self.skip.as_deref() {
            Some(expr) => match *expr {
                Expr::Literal(ScalarValue::Int64(s), _) => {
                    // NULL is treated as 0.
                    let s = s.unwrap_or(0);
                    if s >= 0 {
                        Ok(SkipType::Literal(s as usize))
                    } else {
                        plan_err!("OFFSET must be >=0, '{}' was provided", s)
                    }
                }
                _ => Ok(SkipType::UnsupportedExpr),
            },
            None => Ok(SkipType::Literal(0)),
        }
    }
    /// Resolves `fetch` to a [`FetchType`].
    ///
    /// A missing fetch or a NULL `Int64` literal means unbounded
    /// (`Literal(None)`); a negative literal is a plan error; any other
    /// expression is `UnsupportedExpr`.
    pub fn get_fetch_type(&self) -> Result<FetchType> {
        match self.fetch.as_deref() {
            Some(expr) => match *expr {
                Expr::Literal(ScalarValue::Int64(Some(s)), _) => {
                    if s >= 0 {
                        Ok(FetchType::Literal(Some(s as usize)))
                    } else {
                        plan_err!("LIMIT must be >= 0, '{}' was provided", s)
                    }
                }
                Expr::Literal(ScalarValue::Int64(None), _) => {
                    Ok(FetchType::Literal(None))
                }
                _ => Ok(FetchType::UnsupportedExpr),
            },
            None => Ok(FetchType::Literal(None)),
        }
    }
}
/// Removes duplicate rows, either over all columns or over explicit
/// `DISTINCT ON` expressions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum Distinct {
    /// Plain `SELECT DISTINCT`: dedupe over the full row.
    All(Arc<LogicalPlan>),
    /// `SELECT DISTINCT ON (...)`: dedupe over specific expressions.
    On(DistinctOn),
}
impl Distinct {
    /// Returns the child plan, whichever variant this is.
    pub fn input(&self) -> &Arc<LogicalPlan> {
        match self {
            Distinct::All(input) => input,
            Distinct::On(DistinctOn { input, .. }) => input,
        }
    }
}
/// The `SELECT DISTINCT ON (...)` node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DistinctOn {
    /// Expressions duplicates are identified by.
    pub on_expr: Vec<Expr>,
    /// Expressions actually selected for output.
    pub select_expr: Vec<Expr>,
    /// Optional ordering deciding which row of each duplicate group wins.
    pub sort_expr: Option<Vec<SortExpr>>,
    /// The child plan.
    pub input: Arc<LogicalPlan>,
    /// Output schema derived from `select_expr`.
    pub schema: DFSchemaRef,
}
impl DistinctOn {
    /// Creates a `DISTINCT ON` node.
    ///
    /// `on_expr` must be non-empty; the output schema is derived from
    /// `select_expr`. When `sort_expr` is given it is validated via
    /// [`Self::with_sort_expr`].
    pub fn try_new(
        on_expr: Vec<Expr>,
        select_expr: Vec<Expr>,
        sort_expr: Option<Vec<SortExpr>>,
        input: Arc<LogicalPlan>,
    ) -> Result<Self> {
        if on_expr.is_empty() {
            return plan_err!("No `ON` expressions provided");
        }
        let on_expr = normalize_cols(on_expr, input.as_ref())?;
        let qualified_fields = exprlist_to_fields(select_expr.as_slice(), &input)?
            .into_iter()
            .collect();
        let dfschema = DFSchema::new_with_metadata(
            qualified_fields,
            input.schema().metadata().clone(),
        )?;
        let mut distinct_on = DistinctOn {
            on_expr,
            select_expr,
            sort_expr: None,
            input,
            schema: Arc::new(dfschema),
        };
        if let Some(sort_expr) = sort_expr {
            distinct_on = distinct_on.with_sort_expr(sort_expr)?;
        }
        Ok(distinct_on)
    }
    /// Attaches a sort order, enforcing the `DISTINCT ON` rule that the
    /// leading `ORDER BY` expressions must match the `ON` expressions
    /// (and there must be at least as many sort keys as `ON` keys).
    pub fn with_sort_expr(mut self, sort_expr: Vec<SortExpr>) -> Result<Self> {
        let sort_expr = normalize_sorts(sort_expr, self.input.as_ref())?;
        // Check the ON expressions prefix-match the sort expressions.
        let mut matched = true;
        for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) {
            if on != &sort.expr {
                matched = false;
                break;
            }
        }
        if self.on_expr.len() > sort_expr.len() || !matched {
            return plan_err!(
                "SELECT DISTINCT ON expressions must match initial ORDER BY expressions"
            );
        }
        self.sort_expr = Some(sort_expr);
        Ok(self)
    }
}
impl PartialOrd for DistinctOn {
    /// Orders by all fields except the schema; a tie is `Equal` only when
    /// the full structs are equal, consistent with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Local view over just the orderable fields.
        #[derive(PartialEq, PartialOrd)]
        struct ComparableDistinctOn<'a> {
            pub on_expr: &'a Vec<Expr>,
            pub select_expr: &'a Vec<Expr>,
            pub sort_expr: &'a Option<Vec<SortExpr>>,
            pub input: &'a Arc<LogicalPlan>,
        }
        let comparable_self = ComparableDistinctOn {
            on_expr: &self.on_expr,
            select_expr: &self.select_expr,
            sort_expr: &self.sort_expr,
            input: &self.input,
        };
        let comparable_other = ComparableDistinctOn {
            on_expr: &other.on_expr,
            select_expr: &other.select_expr,
            sort_expr: &other.sort_expr,
            input: &other.input,
        };
        comparable_self
            .partial_cmp(&comparable_other)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// Aggregates input rows by `group_expr` and computes `aggr_expr` per group.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct Aggregate {
    /// The child plan
    pub input: Arc<LogicalPlan>,
    /// Grouping expressions (possibly a single `Expr::GroupingSet`)
    pub group_expr: Vec<Expr>,
    /// Aggregate expressions
    pub aggr_expr: Vec<Expr>,
    /// Output schema: group fields, then (for grouping sets) the internal
    /// grouping-id column, then the aggregate fields
    pub schema: DFSchemaRef,
}
impl Aggregate {
    /// Creates an `Aggregate`, deriving the output schema from the grouping
    /// and aggregate expressions.
    pub fn try_new(
        input: Arc<LogicalPlan>,
        group_expr: Vec<Expr>,
        aggr_expr: Vec<Expr>,
    ) -> Result<Self> {
        // Expand CUBE/ROLLUP into the canonical grouping-sets form.
        let group_expr = enumerate_grouping_sets(group_expr)?;
        let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
        let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
        let mut qualified_fields = exprlist_to_fields(grouping_expr, &input)?;
        if is_grouping_set {
            // Any group column can be NULL-ed out for a particular grouping
            // set, so all group fields become nullable.
            qualified_fields = qualified_fields
                .into_iter()
                .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
                .collect::<Vec<_>>();
            // Internal bitmask column identifying the grouping set of each
            // row. Its integer width is sized from the number of group
            // fields; `len()` here is evaluated before the push completes.
            qualified_fields.push((
                None,
                Field::new(
                    Self::INTERNAL_GROUPING_ID,
                    Self::grouping_id_type(qualified_fields.len()),
                    false,
                )
                .into(),
            ));
        }
        qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
        let schema = DFSchema::new_with_metadata(
            qualified_fields,
            input.schema().metadata().clone(),
        )?;
        Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
    }
    /// Creates an `Aggregate` with a caller-provided schema, validating that
    /// the schema arity matches the expressions and attaching functional
    /// dependencies.
    #[expect(clippy::needless_pass_by_value)]
    pub fn try_new_with_schema(
        input: Arc<LogicalPlan>,
        group_expr: Vec<Expr>,
        aggr_expr: Vec<Expr>,
        schema: DFSchemaRef,
    ) -> Result<Self> {
        if group_expr.is_empty() && aggr_expr.is_empty() {
            return plan_err!(
                "Aggregate requires at least one grouping or aggregate expression. \
                Aggregate without grouping expressions nor aggregate expressions is \
                logically equivalent to, but less efficient than, VALUES producing \
                single row. Please use VALUES instead."
            );
        }
        // Grouping sets may expand to more fields than `group_expr.len()`.
        let group_expr_count = grouping_set_expr_count(&group_expr)?;
        if schema.fields().len() != group_expr_count + aggr_expr.len() {
            return plan_err!(
                "Aggregate schema has wrong number of fields. Expected {} got {}",
                group_expr_count + aggr_expr.len(),
                schema.fields().len()
            );
        }
        // Re-derive functional dependencies for the aggregated output.
        let aggregate_func_dependencies =
            calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?;
        let new_schema = schema.as_ref().clone();
        let schema = Arc::new(
            new_schema.with_functional_dependencies(aggregate_func_dependencies)?,
        );
        Ok(Self {
            input,
            group_expr,
            aggr_expr,
            schema,
        })
    }
    /// True when the grouping is a single `Expr::GroupingSet`.
    fn is_grouping_set(&self) -> bool {
        matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
    }
    /// Returns the expressions, in schema order: group exprs, the internal
    /// grouping-id (for grouping sets only), then aggregate exprs.
    fn output_expressions(&self) -> Result<Vec<&Expr>> {
        // Shared, lazily-built column expression for the grouping-id field.
        static INTERNAL_ID_EXPR: LazyLock<Expr> = LazyLock::new(|| {
            Expr::Column(Column::from_name(Aggregate::INTERNAL_GROUPING_ID))
        });
        let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
        if self.is_grouping_set() {
            exprs.push(&INTERNAL_ID_EXPR);
        }
        exprs.extend(self.aggr_expr.iter());
        debug_assert!(exprs.len() == self.schema.fields().len());
        Ok(exprs)
    }
    /// Number of group-by output fields (accounting for grouping sets).
    pub fn group_expr_len(&self) -> Result<usize> {
        grouping_set_expr_count(&self.group_expr)
    }
    /// Smallest unsigned integer type whose bit width can hold one bit per
    /// group expression.
    pub fn grouping_id_type(group_exprs: usize) -> DataType {
        if group_exprs <= 8 {
            DataType::UInt8
        } else if group_exprs <= 16 {
            DataType::UInt16
        } else if group_exprs <= 32 {
            DataType::UInt32
        } else {
            DataType::UInt64
        }
    }
    /// Name of the internal column holding the grouping-set bitmask.
    pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
}
impl PartialOrd for Aggregate {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.input.partial_cmp(&other.input) {
Some(Ordering::Equal) => {
match self.group_expr.partial_cmp(&other.group_expr) {
Some(Ordering::Equal) => self.aggr_expr.partial_cmp(&other.aggr_expr),
cmp => cmp,
}
}
cmp => cmp,
}
.filter(|cmp| *cmp != Ordering::Equal || self == other)
}
}
/// Returns `true` if any of the given expressions is a grouping set.
fn contains_grouping_set(group_expr: &[Expr]) -> bool {
    for expr in group_expr {
        if matches!(expr, Expr::GroupingSet(_)) {
            return true;
        }
    }
    false
}
/// Calculates functional dependencies for the output of an aggregation.
fn calc_func_dependencies_for_aggregate(
    group_expr: &[Expr],
    input: &LogicalPlan,
    aggr_schema: &DFSchema,
) -> Result<FunctionalDependencies> {
    // Grouping sets produce NULL-filled rows per set, which invalidates the
    // uniqueness guarantees functional dependencies rely on.
    if contains_grouping_set(group_expr) {
        return Ok(FunctionalDependencies::empty());
    }
    // Deduplicate group-by expressions by schema name, keeping first-seen order.
    let group_by_expr_names: Vec<String> = group_expr
        .iter()
        .map(|item| item.schema_name().to_string())
        .collect::<IndexSet<_>>()
        .into_iter()
        .collect();
    Ok(aggregate_functional_dependencies(
        input.schema(),
        &group_by_expr_names,
        aggr_schema,
    ))
}
/// Calculates functional dependencies for the output of a projection by
/// mapping each projected expression back to the input column it refers to
/// (matched by display name) and projecting the input's dependencies through
/// those indices.
fn calc_func_dependencies_for_project(
    exprs: &[Expr],
    input: &LogicalPlan,
) -> Result<FunctionalDependencies> {
    let input_fields = input.schema().field_names();
    // Shared lookup: resolves a display name to at most one input index.
    // (Previously duplicated verbatim in the `Alias` and fallback arms.)
    let position_of = |name: &str| -> Vec<usize> {
        input_fields
            .iter()
            .position(|item| item.as_str() == name)
            .map(|i| vec![i])
            .unwrap_or_default()
    };
    // Calculate expression indices (if present) in the input schema.
    let proj_indices = exprs
        .iter()
        .map(|expr| match expr {
            #[expect(deprecated)]
            Expr::Wildcard { qualifier, options } => {
                // A wildcard expands to many input columns; resolve each.
                let wildcard_fields = exprlist_to_fields(
                    vec![&Expr::Wildcard {
                        qualifier: qualifier.clone(),
                        options: options.clone(),
                    }],
                    input,
                )?;
                Ok::<_, DataFusionError>(
                    wildcard_fields
                        .into_iter()
                        .filter_map(|(qualifier, f)| {
                            let flat_name = qualifier
                                .map(|t| format!("{}.{}", t, f.name()))
                                .unwrap_or_else(|| f.name().clone());
                            input_fields.iter().position(|item| *item == flat_name)
                        })
                        .collect::<Vec<_>>(),
                )
            }
            // An alias projects its underlying expression; resolve by the
            // un-aliased expression's display name.
            Expr::Alias(alias) => Ok(position_of(&alias.expr.to_string())),
            _ => Ok(position_of(&expr.to_string())),
        })
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>();
    Ok(input
        .schema()
        .functional_dependencies()
        .project_functional_dependencies(&proj_indices, exprs.len()))
}
/// Sorts its input according to a list of sort expressions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Sort {
    /// The sort expressions, applied in order
    pub expr: Vec<SortExpr>,
    /// The child plan
    pub input: Arc<LogicalPlan>,
    /// Optional limit on the number of rows to produce (top-k)
    pub fetch: Option<usize>,
}
/// Joins two logical plans on equality columns and/or an arbitrary filter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Join {
    /// Left child plan
    pub left: Arc<LogicalPlan>,
    /// Right child plan
    pub right: Arc<LogicalPlan>,
    /// Equijoin key pairs (left expr, right expr)
    pub on: Vec<(Expr, Expr)>,
    /// Optional non-equijoin filter applied to joined rows
    pub filter: Option<Expr>,
    /// The join kind (Inner, Left, Right, Full, Semi, Anti, ...)
    pub join_type: JoinType,
    /// Whether the join came from `ON` or `USING` syntax
    pub join_constraint: JoinConstraint,
    /// Output schema, built from both inputs according to `join_type`
    pub schema: DFSchemaRef,
    /// How NULL = NULL comparisons in the join keys are treated
    pub null_equality: NullEquality,
    /// Whether this join has null-aware (e.g. NOT IN) semantics
    pub null_aware: bool,
}
impl Join {
    /// Creates a new `Join`, computing the output schema from the inputs and
    /// join type.
    #[expect(clippy::too_many_arguments)]
    pub fn try_new(
        left: Arc<LogicalPlan>,
        right: Arc<LogicalPlan>,
        on: Vec<(Expr, Expr)>,
        filter: Option<Expr>,
        join_type: JoinType,
        join_constraint: JoinConstraint,
        null_equality: NullEquality,
        null_aware: bool,
    ) -> Result<Self> {
        let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?;
        Ok(Join {
            left,
            right,
            on,
            filter,
            join_type,
            join_constraint,
            schema: Arc::new(join_schema),
            null_equality,
            null_aware,
        })
    }
    /// Rebuilds a join from `original` with new (possibly projected) inputs
    /// and column-based equijoin keys. Returns the new join and whether the
    /// sides had to be requalified to disambiguate duplicate qualifiers.
    pub fn try_new_with_project_input(
        original: &LogicalPlan,
        left: Arc<LogicalPlan>,
        right: Arc<LogicalPlan>,
        column_on: (Vec<Column>, Vec<Column>),
    ) -> Result<(Self, bool)> {
        let original_join = match original {
            LogicalPlan::Join(join) => join,
            _ => return plan_err!("Could not create join with project input"),
        };
        let mut left_sch = LogicalPlanBuilder::from(Arc::clone(&left));
        let mut right_sch = LogicalPlanBuilder::from(Arc::clone(&right));
        let mut requalified = false;
        // Only the standard join types get automatic requalification here.
        if original_join.join_type == JoinType::Inner
            || original_join.join_type == JoinType::Left
            || original_join.join_type == JoinType::Right
            || original_join.join_type == JoinType::Full
        {
            (left_sch, right_sch, requalified) =
                requalify_sides_if_needed(left_sch.clone(), right_sch.clone())?;
        }
        let on: Vec<(Expr, Expr)> = column_on
            .0
            .into_iter()
            .zip(column_on.1)
            .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
            .collect();
        // NOTE(review): the schema is built from the (possibly requalified)
        // builders while the returned join keeps the original `left`/`right`
        // plans — confirm callers expect that pairing.
        let join_schema = build_join_schema(
            left_sch.schema(),
            right_sch.schema(),
            &original_join.join_type,
        )?;
        Ok((
            Join {
                left,
                right,
                on,
                filter: original_join.filter.clone(),
                join_type: original_join.join_type,
                join_constraint: original_join.join_constraint,
                schema: Arc::new(join_schema),
                null_equality: original_join.null_equality,
                null_aware: original_join.null_aware,
            },
            requalified,
        ))
    }
}
impl PartialOrd for Join {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Compare through a proxy struct so the derived schema is excluded;
        // derived `PartialOrd` compares fields lexicographically.
        // NOTE(review): `null_aware` is also omitted from the proxy; two
        // joins differing only in `null_aware` end up returning `None` via
        // the final filter rather than `Equal` — confirm this is intended.
        #[derive(PartialEq, PartialOrd)]
        struct ComparableJoin<'a> {
            pub left: &'a Arc<LogicalPlan>,
            pub right: &'a Arc<LogicalPlan>,
            pub on: &'a Vec<(Expr, Expr)>,
            pub filter: &'a Option<Expr>,
            pub join_type: &'a JoinType,
            pub join_constraint: &'a JoinConstraint,
            pub null_equality: &'a NullEquality,
        }
        let comparable_self = ComparableJoin {
            left: &self.left,
            right: &self.right,
            on: &self.on,
            filter: &self.filter,
            join_type: &self.join_type,
            join_constraint: &self.join_constraint,
            null_equality: &self.null_equality,
        };
        let comparable_other = ComparableJoin {
            left: &other.left,
            right: &other.right,
            on: &other.on,
            filter: &other.filter,
            join_type: &other.join_type,
            join_constraint: &other.join_constraint,
            null_equality: &other.null_equality,
        };
        // Only report `Equal` when the joins are fully equal (including the
        // fields the proxy skipped).
        comparable_self
            .partial_cmp(&comparable_other)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
/// A subquery referenced from an expression, together with the outer columns
/// it depends on.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Subquery {
    /// The subquery's plan
    pub subquery: Arc<LogicalPlan>,
    /// Outer-reference columns the subquery is correlated on
    pub outer_ref_columns: Vec<Expr>,
    /// Source spans for error reporting
    pub spans: Spans,
}
impl Normalizeable for Subquery {
    // Subqueries are never normalized for common-subexpression elimination.
    fn can_normalize(&self) -> bool {
        false
    }
}
impl NormalizeEq for Subquery {
    /// Two subqueries are normalize-equal when their plans are equal and
    /// their outer-reference columns are pairwise normalize-equal.
    fn normalize_eq(&self, other: &Self) -> bool {
        if *self.subquery != *other.subquery {
            return false;
        }
        if self.outer_ref_columns.len() != other.outer_ref_columns.len() {
            return false;
        }
        self.outer_ref_columns
            .iter()
            .zip(other.outer_ref_columns.iter())
            .all(|(a, b)| a.normalize_eq(b))
    }
}
impl Subquery {
    /// Extracts the `Subquery` from a scalar-subquery expression, looking
    /// through any number of casts.
    pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> {
        if let Expr::ScalarSubquery(subquery) = plan {
            return Ok(subquery);
        }
        if let Expr::Cast(cast) = plan {
            return Subquery::try_from_expr(cast.expr.as_ref());
        }
        plan_err!("Could not coerce into ScalarSubquery!")
    }

    /// Returns a copy of this subquery with a new inner plan; outer
    /// references are kept and spans are reset.
    pub fn with_plan(&self, plan: Arc<LogicalPlan>) -> Subquery {
        Subquery {
            subquery: plan,
            outer_ref_columns: self.outer_ref_columns.clone(),
            spans: Spans::new(),
        }
    }
}
impl Debug for Subquery {
    /// Debug output intentionally elides the plan to keep it short.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("<subquery>")
    }
}
/// How rows are repartitioned by a `Repartition` node.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum Partitioning {
    /// Round-robin into the given number of partitions
    RoundRobinBatch(usize),
    /// Hash the given expressions into the given number of partitions
    Hash(Vec<Expr>, usize),
    /// `DISTRIBUTE BY` the given expressions (partition count decided later)
    DistributeBy(Vec<Expr>),
}
/// One list-column unnest target: the output column and how many nesting
/// levels to unwrap.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
pub struct ColumnUnnestList {
    /// Name of the produced column
    pub output_column: Column,
    /// Number of list levels to unwrap
    pub depth: usize,
}
impl Display for ColumnUnnestList {
    /// Renders as `<column>|depth=<n>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.output_column)?;
        write!(f, "|depth={}", self.depth)
    }
}
/// Unnests list and struct columns of its input into multiple rows/columns.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Unnest {
    /// The child plan
    pub input: Arc<LogicalPlan>,
    /// Columns requested for unnesting, in execution order
    pub exec_columns: Vec<Column>,
    /// (input field index, unnest spec) for list-typed columns
    pub list_type_columns: Vec<(usize, ColumnUnnestList)>,
    /// Input field indices of struct-typed columns
    pub struct_type_columns: Vec<usize>,
    /// For each output column, the input field index it derives from
    pub dependency_indices: Vec<usize>,
    /// Output schema after unnesting
    pub schema: DFSchemaRef,
    /// Options controlling NULL/empty-list behavior and recursion depth
    pub options: UnnestOptions,
}
impl PartialOrd for Unnest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Compare through a proxy struct so the derived schema is excluded;
        // derived `PartialOrd` compares fields lexicographically.
        #[derive(PartialEq, PartialOrd)]
        struct ComparableUnnest<'a> {
            pub input: &'a Arc<LogicalPlan>,
            pub exec_columns: &'a Vec<Column>,
            pub list_type_columns: &'a Vec<(usize, ColumnUnnestList)>,
            pub struct_type_columns: &'a Vec<usize>,
            pub dependency_indices: &'a Vec<usize>,
            pub options: &'a UnnestOptions,
        }
        let comparable_self = ComparableUnnest {
            input: &self.input,
            exec_columns: &self.exec_columns,
            list_type_columns: &self.list_type_columns,
            struct_type_columns: &self.struct_type_columns,
            dependency_indices: &self.dependency_indices,
            options: &self.options,
        };
        let comparable_other = ComparableUnnest {
            input: &other.input,
            exec_columns: &other.exec_columns,
            list_type_columns: &other.list_type_columns,
            struct_type_columns: &other.struct_type_columns,
            dependency_indices: &other.dependency_indices,
            options: &other.options,
        };
        // Only report `Equal` when the nodes are fully equal (including the
        // schema the proxy skipped).
        comparable_self
            .partial_cmp(&comparable_other)
            .filter(|cmp| *cmp != Ordering::Equal || self == other)
    }
}
impl Unnest {
    /// Creates an `Unnest` node, computing the output schema and the
    /// bookkeeping vectors (`list_type_columns`, `struct_type_columns`,
    /// `dependency_indices`) by walking the input schema in field order.
    pub fn try_new(
        input: Arc<LogicalPlan>,
        exec_columns: Vec<Column>,
        options: UnnestOptions,
    ) -> Result<Self> {
        if exec_columns.is_empty() {
            return plan_err!("unnest plan requires at least 1 column to unnest");
        }
        let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![];
        let mut struct_columns = vec![];
        // Map input field index -> requested unnest column.
        let indices_to_unnest = exec_columns
            .iter()
            .map(|c| Ok((input.schema().index_of_column(c)?, c)))
            .collect::<Result<HashMap<usize, &Column>>>()?;
        let input_schema = input.schema();
        let mut dependency_indices = vec![];
        // Walk input fields in order so output columns keep input ordering.
        let fields = input_schema
            .iter()
            .enumerate()
            .map(|(index, (original_qualifier, original_field))| {
                match indices_to_unnest.get(&index) {
                    Some(column_to_unnest) => {
                        // Explicit recursive unnest specs for this column,
                        // if any were provided via the options.
                        let recursions_on_column = options
                            .recursions
                            .iter()
                            .filter(|p| -> bool { &p.input_column == *column_to_unnest })
                            .collect::<Vec<_>>();
                        let mut transformed_columns = recursions_on_column
                            .iter()
                            .map(|r| {
                                list_columns.push((
                                    index,
                                    ColumnUnnestList {
                                        output_column: r.output_column.clone(),
                                        depth: r.depth,
                                    },
                                ));
                                // List unnests produce exactly one column,
                                // hence the `.next().unwrap()`.
                                Ok(get_unnested_columns(
                                    &r.output_column.name,
                                    original_field.data_type(),
                                    r.depth,
                                )?
                                .into_iter()
                                .next()
                                .unwrap())
                            })
                            .collect::<Result<Vec<(Column, Arc<Field>)>>>()?;
                        if transformed_columns.is_empty() {
                            // No explicit recursion: default to depth 1 and
                            // classify the column by its data type.
                            transformed_columns = get_unnested_columns(
                                &column_to_unnest.name,
                                original_field.data_type(),
                                1,
                            )?;
                            match original_field.data_type() {
                                DataType::Struct(_) => {
                                    struct_columns.push(index);
                                }
                                DataType::List(_)
                                | DataType::FixedSizeList(_, _)
                                | DataType::LargeList(_) => {
                                    list_columns.push((
                                        index,
                                        ColumnUnnestList {
                                            output_column: Column::from_name(
                                                &column_to_unnest.name,
                                            ),
                                            depth: 1,
                                        },
                                    ));
                                }
                                _ => {}
                            };
                        }
                        // Every produced column depends on this input field.
                        dependency_indices.extend(std::iter::repeat_n(
                            index,
                            transformed_columns.len(),
                        ));
                        Ok(transformed_columns
                            .iter()
                            .map(|(col, field)| {
                                (col.relation.to_owned(), field.to_owned())
                            })
                            .collect())
                    }
                    None => {
                        // Untouched column: passes through unchanged.
                        dependency_indices.push(index);
                        Ok(vec![(
                            original_qualifier.cloned(),
                            Arc::clone(original_field),
                        )])
                    }
                }
            })
            .collect::<Result<Vec<_>>>()?
            .into_iter()
            .flatten()
            .collect::<Vec<_>>();
        let metadata = input_schema.metadata().clone();
        let df_schema = DFSchema::new_with_metadata(fields, metadata)?;
        // Functional dependencies are preserved across the unnest.
        let deps = input_schema.functional_dependencies().clone();
        let schema = Arc::new(df_schema.with_functional_dependencies(deps)?);
        Ok(Unnest {
            input,
            exec_columns,
            list_type_columns: list_columns,
            struct_type_columns: struct_columns,
            dependency_indices,
            schema,
            options,
        })
    }
}
/// Computes the output column(s) and field(s) produced by unnesting a column
/// of the given `data_type` to `depth` levels.
///
/// * List-like types yield a single column named `col_name` whose type is the
///   element type after `depth` levels of unwrapping; it is always nullable
///   since unnesting can introduce NULLs (e.g. empty lists).
/// * Struct types flatten into one column per struct field, named
///   `"{col_name}.{field_name}"`.
///
/// Returns an internal error for any other data type.
///
/// Note: takes `&str` instead of `&String` (clippy `ptr_arg`); existing
/// `&String` call sites coerce automatically.
fn get_unnested_columns(
    col_name: &str,
    data_type: &DataType,
    depth: usize,
) -> Result<Vec<(Column, Arc<Field>)>> {
    let mut qualified_columns = Vec::with_capacity(1);
    match data_type {
        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
            let data_type = get_unnested_list_datatype_recursive(data_type, depth)?;
            // Always nullable: even a non-null list can be empty.
            let new_field = Arc::new(Field::new(col_name, data_type, true));
            let column = Column::from_name(col_name);
            qualified_columns.push((column, new_field));
        }
        DataType::Struct(fields) => {
            qualified_columns.extend(fields.iter().map(|f| {
                let new_name = format!("{}.{}", col_name, f.name());
                let column = Column::from_name(&new_name);
                let new_field = f.as_ref().clone().with_name(new_name);
                (column, Arc::new(new_field))
            }))
        }
        _ => {
            return internal_err!("trying to unnest on invalid data type {data_type}");
        }
    };
    Ok(qualified_columns)
}
/// Unwraps `depth` levels of list nesting from `data_type`, returning the
/// innermost element type reached. Errors if a non-list type is encountered
/// before `depth` levels have been unwrapped.
fn get_unnested_list_datatype_recursive(
    data_type: &DataType,
    depth: usize,
) -> Result<DataType> {
    if let DataType::List(field)
    | DataType::FixedSizeList(field, _)
    | DataType::LargeList(field) = data_type
    {
        return if depth == 1 {
            Ok(field.data_type().clone())
        } else {
            // Recurse one level deeper with one fewer level to unwrap.
            get_unnested_list_datatype_recursive(field.data_type(), depth - 1)
        };
    }
    internal_err!("trying to unnest on invalid data type {data_type}")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
use crate::select_expr::SelectExpr;
use crate::test::function_stub::{count, count_udaf};
use crate::{
GroupingSet, binary_expr, col, exists, in_subquery, lit, placeholder,
scalar_subquery,
};
use datafusion_common::metadata::ScalarAndMetadata;
use datafusion_common::tree_node::{
TransformedResult, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{Constraint, ScalarValue, not_impl_err};
use insta::{assert_debug_snapshot, assert_snapshot};
use std::hash::DefaultHasher;
/// Five-column employee schema shared by the display tests.
fn employee_schema() -> Schema {
    let columns = [
        ("id", DataType::Int32),
        ("first_name", DataType::Utf8),
        ("last_name", DataType::Utf8),
        ("state", DataType::Utf8),
        ("salary", DataType::Int32),
    ];
    Schema::new(
        columns
            .into_iter()
            .map(|(name, data_type)| Field::new(name, data_type, false))
            .collect::<Vec<_>>(),
    )
}
/// Builds `Projection -> Filter(IN subquery) -> TableScan` used by the
/// display tests below.
fn display_plan() -> Result<LogicalPlan> {
let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
.build()?;
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
.filter(in_subquery(col("state"), Arc::new(plan1)))?
.project(vec![col("id")])?
.build()
}
// Verifies the indented (non-schema) rendering of a plan with a subquery.
#[test]
fn test_display_indent() -> Result<()> {
let plan = display_plan()?;
assert_snapshot!(plan.display_indent(), @r"
Projection: employee_csv.id
Filter: employee_csv.state IN (<subquery>)
Subquery:
TableScan: employee_csv projection=[state]
TableScan: employee_csv projection=[id, state]
");
Ok(())
}
// Verifies the indented rendering that also prints each node's schema.
#[test]
fn test_display_indent_schema() -> Result<()> {
let plan = display_plan()?;
assert_snapshot!(plan.display_indent_schema(), @r"
Projection: employee_csv.id [id:Int32]
Filter: employee_csv.state IN (<subquery>) [id:Int32, state:Utf8]
Subquery: [state:Utf8]
TableScan: employee_csv projection=[state] [state:Utf8]
TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8]
");
Ok(())
}
// Verifies that an aliased EXISTS subquery renders with its alias.
#[test]
fn test_display_subquery_alias() -> Result<()> {
let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
.build()?;
let plan1 = Arc::new(plan1);
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
.project(vec![col("id"), exists(plan1).alias("exists")])?
.build();
assert_snapshot!(plan?.display_indent(), @r"
Projection: employee_csv.id, EXISTS (<subquery>) AS exists
Subquery:
TableScan: employee_csv projection=[state]
TableScan: employee_csv projection=[id, state]
");
Ok(())
}
// Verifies the GraphViz rendering: two subgraphs (plain and detailed) with
// back-edges from parent to child nodes.
#[test]
fn test_display_graphviz() -> Result<()> {
let plan = display_plan()?;
assert_snapshot!(plan.display_graphviz(), @r#"
// Begin DataFusion GraphViz Plan,
// display it online here: https://dreampuf.github.io/GraphvizOnline
digraph {
subgraph cluster_1
{
graph[label="LogicalPlan"]
2[shape=box label="Projection: employee_csv.id"]
3[shape=box label="Filter: employee_csv.state IN (<subquery>)"]
2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]
4[shape=box label="Subquery:"]
3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]
5[shape=box label="TableScan: employee_csv projection=[state]"]
4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]
6[shape=box label="TableScan: employee_csv projection=[id, state]"]
3 -> 6 [arrowhead=none, arrowtail=normal, dir=back]
}
subgraph cluster_7
{
graph[label="Detailed LogicalPlan"]
8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"]
9[shape=box label="Filter: employee_csv.state IN (<subquery>)\nSchema: [id:Int32, state:Utf8]"]
8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]
10[shape=box label="Subquery:\nSchema: [state:Utf8]"]
9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]
11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"]
10 -> 11 [arrowhead=none, arrowtail=normal, dir=back]
12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"]
9 -> 12 [arrowhead=none, arrowtail=normal, dir=back]
}
}
// End DataFusion GraphViz Plan
"#);
Ok(())
}
// Verifies the PostgreSQL-style JSON (EXPLAIN FORMAT JSON) rendering.
#[test]
fn test_display_pg_json() -> Result<()> {
let plan = display_plan()?;
assert_snapshot!(plan.display_pg_json(), @r#"
[
{
"Plan": {
"Expressions": [
"employee_csv.id"
],
"Node Type": "Projection",
"Output": [
"id"
],
"Plans": [
{
"Condition": "employee_csv.state IN (<subquery>)",
"Node Type": "Filter",
"Output": [
"id",
"state"
],
"Plans": [
{
"Node Type": "Subquery",
"Output": [
"state"
],
"Plans": [
{
"Node Type": "TableScan",
"Output": [
"state"
],
"Plans": [],
"Relation Name": "employee_csv"
}
]
},
{
"Node Type": "TableScan",
"Output": [
"id",
"state"
],
"Plans": [],
"Relation Name": "employee_csv"
}
]
}
]
}
}
]
"#);
Ok(())
}
// Visitor that records the pre/post visit order of supported plan nodes.
#[derive(Debug, Default)]
struct OkVisitor {
    // Visit trace, e.g. "pre_visit Projection"
    strings: Vec<String>,
}
impl<'n> TreeNodeVisitor<'n> for OkVisitor {
    type Node = LogicalPlan;

    /// Records a "pre_visit" entry for each supported node kind.
    fn f_down(&mut self, plan: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
        let label = match plan {
            LogicalPlan::Projection { .. } => "pre_visit Projection",
            LogicalPlan::Filter { .. } => "pre_visit Filter",
            LogicalPlan::TableScan { .. } => "pre_visit TableScan",
            _ => return not_impl_err!("unknown plan type"),
        };
        self.strings.push(label.to_string());
        Ok(TreeNodeRecursion::Continue)
    }

    /// Records a "post_visit" entry for each supported node kind.
    fn f_up(&mut self, plan: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
        let label = match plan {
            LogicalPlan::Projection { .. } => "post_visit Projection",
            LogicalPlan::Filter { .. } => "post_visit Filter",
            LogicalPlan::TableScan { .. } => "post_visit TableScan",
            _ => return not_impl_err!("unknown plan type"),
        };
        self.strings.push(label.to_string());
        Ok(TreeNodeRecursion::Continue)
    }
}
// Verifies depth-first pre/post visit ordering over a three-node plan.
#[test]
fn visit_order() {
let mut visitor = OkVisitor::default();
let plan = test_plan();
let res = plan.visit_with_subqueries(&mut visitor);
assert!(res.is_ok());
assert_debug_snapshot!(visitor.strings, @r#"
[
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
"post_visit Filter",
"post_visit Projection",
]
"#);
}
/// Counts down from an initial value; `dec` reports `true` exactly once the
/// counter reaches zero (and keeps reporting `true` thereafter).
#[derive(Debug, Default)]
struct OptionalCounter {
    // `None` (the default) means the counter never fires.
    val: Option<usize>,
}
impl OptionalCounter {
    fn new(val: usize) -> Self {
        OptionalCounter { val: Some(val) }
    }
    /// Decrements once; returns `true` when the counter has hit zero.
    fn dec(&mut self) -> bool {
        match self.val {
            Some(0) => true,
            _ => {
                self.val = self.val.take().map(|i| i - 1);
                false
            }
        }
    }
}
// Visitor that stops traversal after a configurable number of calls.
#[derive(Debug, Default)]
struct StoppingVisitor {
    // Records the visits that happened before the stop
    inner: OkVisitor,
    // When this fires, f_down returns Stop
    return_false_from_pre_in: OptionalCounter,
    // When this fires, f_up returns Stop
    return_false_from_post_in: OptionalCounter,
}
impl<'n> TreeNodeVisitor<'n> for StoppingVisitor {
    type Node = LogicalPlan;

    fn f_down(&mut self, plan: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
        if self.return_false_from_pre_in.dec() {
            return Ok(TreeNodeRecursion::Stop);
        }
        // Delegate for bookkeeping but always continue, discarding the
        // inner visitor's recursion verdict.
        self.inner.f_down(plan).map(|_| TreeNodeRecursion::Continue)
    }

    fn f_up(&mut self, plan: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
        match self.return_false_from_post_in.dec() {
            true => Ok(TreeNodeRecursion::Stop),
            false => self.inner.f_up(plan),
        }
    }
}
// Verifies that returning Stop from f_down halts the traversal after the
// second pre-visit (TableScan is never reached).
#[test]
fn early_stopping_pre_visit() {
let mut visitor = StoppingVisitor {
return_false_from_pre_in: OptionalCounter::new(2),
..Default::default()
};
let plan = test_plan();
let res = plan.visit_with_subqueries(&mut visitor);
assert!(res.is_ok());
assert_debug_snapshot!(
visitor.inner.strings,
@r#"
[
"pre_visit Projection",
"pre_visit Filter",
]
"#
);
}
// Verifies that returning Stop from f_up halts the traversal after the
// first post-visit.
#[test]
fn early_stopping_post_visit() {
let mut visitor = StoppingVisitor {
return_false_from_post_in: OptionalCounter::new(1),
..Default::default()
};
let plan = test_plan();
let res = plan.visit_with_subqueries(&mut visitor);
assert!(res.is_ok());
assert_debug_snapshot!(
visitor.inner.strings,
@r#"
[
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
]
"#
);
}
// Visitor that fails with an error after a configurable number of calls.
#[derive(Debug, Default)]
struct ErrorVisitor {
    // Records the visits that happened before the error
    inner: OkVisitor,
    // When this fires, f_down returns an error
    return_error_from_pre_in: OptionalCounter,
    // When this fires, f_up returns an error
    return_error_from_post_in: OptionalCounter,
}
impl<'n> TreeNodeVisitor<'n> for ErrorVisitor {
type Node = LogicalPlan;
fn f_down(&mut self, plan: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
if self.return_error_from_pre_in.dec() {
return not_impl_err!("Error in pre_visit");
}
self.inner.f_down(plan)
}
fn f_up(&mut self, plan: &'n LogicalPlan) -> Result<TreeNodeRecursion> {
if self.return_error_from_post_in.dec() {
return not_impl_err!("Error in post_visit");
}
self.inner.f_up(plan)
}
}
// Verifies that an error from f_down aborts traversal and is propagated,
// keeping only the visits recorded before the failure.
#[test]
fn error_pre_visit() {
let mut visitor = ErrorVisitor {
return_error_from_pre_in: OptionalCounter::new(2),
..Default::default()
};
let plan = test_plan();
let res = plan.visit_with_subqueries(&mut visitor).unwrap_err();
assert_snapshot!(
res.strip_backtrace(),
@"This feature is not implemented: Error in pre_visit"
);
assert_debug_snapshot!(
visitor.inner.strings,
@r#"
[
"pre_visit Projection",
"pre_visit Filter",
]
"#
);
}
// Same as above, but for an error raised from f_up.
#[test]
fn error_post_visit() {
let mut visitor = ErrorVisitor {
return_error_from_post_in: OptionalCounter::new(1),
..Default::default()
};
let plan = test_plan();
let res = plan.visit_with_subqueries(&mut visitor).unwrap_err();
assert_snapshot!(
res.strip_backtrace(),
@"This feature is not implemented: Error in post_visit"
);
assert_debug_snapshot!(
visitor.inner.strings,
@r#"
[
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
]
"#
);
}
// Verifies that schema metadata participates in Eq/Hash/PartialOrd: plans
// identical except for schema metadata are unequal, hash differently, and
// are incomparable (partial_cmp == None).
#[test]
fn test_partial_eq_hash_and_partial_ord() {
let empty_values = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::new(DFSchema::empty()),
}));
// Builds a COUNT() window node over the empty relation with the given schema.
let count_window_function = |schema| {
Window::try_new_with_schema(
vec![Expr::WindowFunction(Box::new(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![],
)))],
Arc::clone(&empty_values),
Arc::new(schema),
)
.unwrap()
};
let schema_without_metadata = || {
DFSchema::from_unqualified_fields(
vec![Field::new("count", DataType::Int64, false)].into(),
HashMap::new(),
)
.unwrap()
};
let schema_with_metadata = || {
DFSchema::from_unqualified_fields(
vec![Field::new("count", DataType::Int64, false)].into(),
[("key".to_string(), "value".to_string())].into(),
)
.unwrap()
};
// Same schema: equal, same hash, compare Equal.
let f = count_window_function(schema_without_metadata());
let f2 = count_window_function(schema_without_metadata());
assert_eq!(f, f2);
assert_eq!(hash(&f), hash(&f2));
assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
// Schemas differing only in metadata: unequal, different hash, incomparable.
let o = count_window_function(schema_with_metadata());
assert_ne!(f, o);
assert_ne!(hash(&f), hash(&o)); assert_eq!(f.partial_cmp(&o), None);
}
/// Hashes a value with the default hasher (test helper).
fn hash<T: Hash>(value: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}
// Verifies that constructing a Projection whose schema arity does not match
// its expression count is rejected with a planning error.
#[test]
fn projection_expr_schema_mismatch() -> Result<()> {
let empty_schema = Arc::new(DFSchema::empty());
let p = Projection::try_new_with_schema(
vec![col("a")],
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&empty_schema),
})),
empty_schema,
);
assert_snapshot!(p.unwrap_err().strip_backtrace(), @"Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)");
Ok(())
}
/// Builds the `Projection -> Filter -> TableScan` plan used by the visitor
/// tests.
fn test_plan() -> LogicalPlan {
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("state", DataType::Utf8, false),
]);
table_scan(TableReference::none(), &schema, Some(vec![0, 1]))
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
}
// Verifies that parameter substitution rejects malformed placeholder ids:
// empty (""), zero ("$0"), and zero-padded ("$00") — valid ids are $1, $2, ...
#[test]
fn test_replace_invalid_placeholder() {
// Empty placeholder id.
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(placeholder("")))
.unwrap()
.build()
.unwrap();
let param_values = vec![ScalarValue::Int32(Some(42))];
plan.replace_params_with_values(&param_values.clone().into())
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
// Zero placeholder id ($0): placeholders are 1-based.
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(placeholder("$0")))
.unwrap()
.build()
.unwrap();
plan.replace_params_with_values(&param_values.clone().into())
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
// Zero-padded placeholder id ($00).
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(placeholder("$00")))
.unwrap()
.build()
.unwrap();
plan.replace_params_with_values(&param_values.into())
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
}
// Verifies that binding a parameter whose metadata differs from the prepared
// statement's declared field metadata is rejected.
#[test]
fn test_replace_placeholder_mismatched_metadata() {
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(placeholder("$1")))
.unwrap()
.build()
.unwrap();
// Prepared statement declares $1 as a plain Int32 field (no metadata).
let prepared_builder = LogicalPlanBuilder::new(plan)
.prepare(
"".to_string(),
vec![Field::new("", DataType::Int32, true).into()],
)
.unwrap();
// Bound value carries extra metadata, which must not match.
let mut scalar_meta = HashMap::new();
scalar_meta.insert("some_key".to_string(), "some_value".to_string());
let param_values = ParamValues::List(vec![ScalarAndMetadata::new(
ScalarValue::Int32(Some(42)),
Some(scalar_meta.into()),
)]);
prepared_builder
.plan()
.clone()
.with_param_values(param_values)
.expect_err("prepared field metadata mismatch unexpectedly succeeded");
}
// Verifies that unresolved placeholders type as nullable Null, and that
// binding values rewrites them to aliased literals with concrete types.
#[test]
fn test_replace_placeholder_empty_relation_valid_schema() {
let plan = LogicalPlanBuilder::empty(false)
.project(vec![
SelectExpr::from(placeholder("$1")),
SelectExpr::from(placeholder("$2")),
])
.unwrap()
.build()
.unwrap();
assert_snapshot!(plan.display_indent_schema(), @r"
Projection: $1, $2 [$1:Null;N, $2:Null;N]
EmptyRelation: rows=0 []
");
let plan = plan
.with_param_values(vec![ScalarValue::from(1i32), ScalarValue::from("s")])
.unwrap();
assert_snapshot!(plan.display_indent_schema(), @r#"
Projection: Int32(1) AS $1, Utf8("s") AS $2 [$1:Int32, $2:Utf8]
EmptyRelation: rows=0 []
"#);
}
// Verifies that grouping sets make all group columns nullable in the output
// schema, even when the input columns are non-nullable.
#[test]
fn test_nullable_schema_after_grouping_set() {
let schema = Schema::new(vec![
Field::new("foo", DataType::Int32, false),
Field::new("bar", DataType::Int32, false),
]);
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.aggregate(
vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![
vec![col("foo")],
vec![col("bar")],
]))],
vec![count(lit(true))],
)
.unwrap()
.build()
.unwrap();
let output_schema = plan.schema();
assert!(
output_schema
.field_with_name(None, "foo")
.unwrap()
.is_nullable(),
);
assert!(
output_schema
.field_with_name(None, "bar")
.unwrap()
.is_nullable()
);
}
// Verifies Filter::is_scalar: an equality filter on a column is only known
// to yield at most one row when the column carries a unique constraint
// (expressed as a functional dependency on the scan's schema).
#[test]
fn test_filter_is_scalar() {
let schema =
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let source = Arc::new(LogicalTableSource::new(schema));
let schema = Arc::new(
DFSchema::try_from_qualified_schema(
TableReference::bare("tab"),
&source.schema(),
)
.unwrap(),
);
// No uniqueness info: the filter cannot be proven scalar.
let scan = Arc::new(LogicalPlan::TableScan(TableScan {
table_name: TableReference::bare("tab"),
source: Arc::clone(&source) as Arc<dyn TableSource>,
projection: None,
projected_schema: Arc::clone(&schema),
filters: vec![],
fetch: None,
}));
let col = schema.field_names()[0].clone();
let filter = Filter::try_new(
Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)), None)),
scan,
)
.unwrap();
assert!(!filter.is_scalar());
// Same scan but with a unique constraint on column 0.
let unique_schema = Arc::new(
schema
.as_ref()
.clone()
.with_functional_dependencies(
FunctionalDependencies::new_from_constraints(
Some(&Constraints::new_unverified(vec![Constraint::Unique(
vec![0],
)])),
1,
),
)
.unwrap(),
);
let scan = Arc::new(LogicalPlan::TableScan(TableScan {
table_name: TableReference::bare("tab"),
source,
projection: None,
projected_schema: Arc::clone(&unique_schema),
filters: vec![],
fetch: None,
}));
let col = schema.field_names()[0].clone();
let filter =
Filter::try_new(Expr::Column(col.into()).eq(lit(1i32)), scan).unwrap();
assert!(filter.is_scalar());
}
// Verifies that `transform` rewrites nodes inside an Explain wrapper:
// the TableScan child is replaced by Filter -> TableScan.
#[test]
fn test_transform_explain() {
let schema = Schema::new(vec![
Field::new("foo", DataType::Int32, false),
Field::new("bar", DataType::Int32, false),
]);
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.explain(false, false)
.unwrap()
.build()
.unwrap();
let external_filter = col("foo").eq(lit(true));
// Wrap every TableScan in a Filter; leave other nodes untouched.
let plan = plan
.transform(|plan| match plan {
LogicalPlan::TableScan(table) => {
let filter = Filter::try_new(
external_filter.clone(),
Arc::new(LogicalPlan::TableScan(table)),
)
.unwrap();
Ok(Transformed::yes(LogicalPlan::Filter(filter)))
}
x => Ok(Transformed::no(x)),
})
.data()
.unwrap();
let actual = format!("{}", plan.display_indent());
assert_snapshot!(actual, @r"
Explain
Filter: foo = Boolean(true)
TableScan: ?table?
")
}
// Verifies LogicalPlan's PartialOrd: different variants order by variant,
// while two equal-looking DescribeTable nodes are incomparable (None),
// because DescribeTable does not implement a total field comparison.
#[test]
fn test_plan_partial_ord() {
let empty_relation = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
let describe_table = LogicalPlan::DescribeTable(DescribeTable {
schema: Arc::new(Schema::new(vec![Field::new(
"foo",
DataType::Int32,
false,
)])),
output_schema: DFSchemaRef::new(DFSchema::empty()),
});
let describe_table_clone = LogicalPlan::DescribeTable(DescribeTable {
schema: Arc::new(Schema::new(vec![Field::new(
"foo",
DataType::Int32,
false,
)])),
output_schema: DFSchemaRef::new(DFSchema::empty()),
});
assert_eq!(
empty_relation.partial_cmp(&describe_table),
Some(Ordering::Less)
);
assert_eq!(
describe_table.partial_cmp(&empty_relation),
Some(Ordering::Greater)
);
assert_eq!(describe_table.partial_cmp(&describe_table_clone), None);
}
#[test]
fn test_limit_with_new_children() {
    // `with_new_exprs` must round-trip a Limit node unchanged when fed its
    // own expressions and inputs, for every combination of skip/fetch.
    let one = || {
        Some(Box::new(Expr::Literal(
            ScalarValue::new_one(&DataType::UInt32).unwrap(),
            None,
        )))
    };
    let ten = || {
        Some(Box::new(Expr::Literal(
            ScalarValue::new_ten(&DataType::UInt32).unwrap(),
            None,
        )))
    };
    let input = Arc::new(LogicalPlan::Values(Values {
        schema: Arc::new(DFSchema::empty()),
        values: vec![vec![]],
    }));
    // (skip, fetch) combinations: absent/absent, fetch only, skip only, both.
    let combinations = [
        (None, None),
        (None, ten()),
        (ten(), None),
        (one(), ten()),
    ];
    for (skip, fetch) in combinations {
        let limit = LogicalPlan::Limit(Limit {
            skip,
            fetch,
            input: Arc::clone(&input),
        });
        let rebuilt = limit
            .with_new_exprs(
                limit.expressions(),
                limit.inputs().into_iter().cloned().collect(),
            )
            .unwrap();
        assert_eq!(limit, rebuilt);
    }
}
#[test]
// Verifies that returning `TreeNodeRecursion::Jump` from a pre-order (f_down)
// callback on a Projection skips the whole subtree below it — including
// subqueries embedded in its expressions. Both Filter nodes (one in the outer
// plan, one inside the scalar subquery) sit below the Projection, so none of
// the five subquery-aware traversal APIs exercised here may observe a Filter.
fn test_with_subqueries_jump() {
// Subquery plan: scan(sub_id) -> Filter(sub_id = 0).
let subquery_schema =
Schema::new(vec![Field::new("sub_id", DataType::Int32, false)]);
let subquery_plan =
table_scan(TableReference::none(), &subquery_schema, Some(vec![0]))
.unwrap()
.filter(col("sub_id").eq(lit(0)))
.unwrap()
.build()
.unwrap();
// Outer plan: scan(id) -> Filter(id = 0) -> Projection(id, scalar subquery).
// The Projection is the traversal root's only child, so a Jump there skips
// everything below it.
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
let plan = table_scan(TableReference::none(), &schema, Some(vec![0]))
.unwrap()
.filter(col("id").eq(lit(0)))
.unwrap()
.project(vec![col("id"), scalar_subquery(Arc::new(subquery_plan))])
.unwrap()
.build()
.unwrap();
// 1) Closure-based apply: Jump at the Projection must hide both Filters.
let mut filter_found = false;
plan.apply_with_subqueries(|plan| {
match plan {
LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump),
LogicalPlan::Filter(..) => filter_found = true,
_ => {}
}
Ok(TreeNodeRecursion::Continue)
})
.unwrap();
assert!(!filter_found);
// 2) Visitor-based traversal: same expectation via a TreeNodeVisitor impl.
struct ProjectJumpVisitor {
// Set to true if any Filter node is visited (it must not be).
filter_found: bool,
}
impl ProjectJumpVisitor {
fn new() -> Self {
Self {
filter_found: false,
}
}
}
impl<'n> TreeNodeVisitor<'n> for ProjectJumpVisitor {
type Node = LogicalPlan;
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump),
LogicalPlan::Filter(..) => self.filter_found = true,
_ => {}
}
Ok(TreeNodeRecursion::Continue)
}
}
let mut visitor = ProjectJumpVisitor::new();
plan.visit_with_subqueries(&mut visitor).unwrap();
assert!(!visitor.filter_found);
// 3) Top-down transform: a non-transformed result carrying Jump must still
// prune the subtree.
let mut filter_found = false;
plan.clone()
.transform_down_with_subqueries(|plan| {
match plan {
LogicalPlan::Projection(..) => {
return Ok(Transformed::new(
plan,
false,
TreeNodeRecursion::Jump,
));
}
LogicalPlan::Filter(..) => filter_found = true,
_ => {}
}
Ok(Transformed::no(plan))
})
.unwrap();
assert!(!filter_found);
// 4) Combined down/up transform: Jump in the down-phase closure prunes; the
// up-phase closure is a no-op.
let mut filter_found = false;
plan.clone()
.transform_down_up_with_subqueries(
|plan| {
match plan {
LogicalPlan::Projection(..) => {
return Ok(Transformed::new(
plan,
false,
TreeNodeRecursion::Jump,
));
}
LogicalPlan::Filter(..) => filter_found = true,
_ => {}
}
Ok(Transformed::no(plan))
},
|plan| Ok(Transformed::no(plan)),
)
.unwrap();
assert!(!filter_found);
// 5) Rewriter-based traversal: same expectation via a TreeNodeRewriter impl.
struct ProjectJumpRewriter {
// Set to true if any Filter node is rewritten (it must not be).
filter_found: bool,
}
impl ProjectJumpRewriter {
fn new() -> Self {
Self {
filter_found: false,
}
}
}
impl TreeNodeRewriter for ProjectJumpRewriter {
type Node = LogicalPlan;
fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
match node {
LogicalPlan::Projection(..) => {
return Ok(Transformed::new(
node,
false,
TreeNodeRecursion::Jump,
));
}
LogicalPlan::Filter(..) => self.filter_found = true,
_ => {}
}
Ok(Transformed::no(node))
}
}
let mut rewriter = ProjectJumpRewriter::new();
plan.rewrite_with_subqueries(&mut rewriter).unwrap();
assert!(!rewriter.filter_found);
}
#[test]
fn test_with_unresolved_placeholders() {
    // A placeholder used without any type context must be reported by
    // `get_parameter_fields` with an unresolved (`None`) type.
    let field_name = "id";
    let placeholder_value = "$1";
    let schema = Schema::new(vec![Field::new(field_name, DataType::Int32, false)]);
    let plan = table_scan(TableReference::none(), &schema, None)
        .unwrap()
        .filter(col(field_name).eq(placeholder(placeholder_value)))
        .unwrap()
        .build()
        .unwrap();

    // Exactly one placeholder ($1) is discovered.
    let params = plan.get_parameter_fields().unwrap();
    assert_eq!(params.len(), 1);
    // Look the entry up by key directly; cloning the whole map first (as the
    // previous code did) was a pointless allocation.
    let parameter_type = params.get(placeholder_value).unwrap().clone();
    assert_eq!(parameter_type, None);
}
#[test]
// Exercises `LogicalPlan::with_new_exprs` on Join nodes: the flat expression
// list produced by `expressions()` (the `on` pairs flattened, followed by the
// optional filter) must be reassembled into the same `on` pairs and `filter`
// when fed back in — or into new ones when a replacement list is supplied.
fn test_join_with_new_exprs() -> Result<()> {
// Builds an inner join of t1 x t2 (both with Int32 columns a, b) with the
// given equi-join `on` pairs and optional non-equi `filter`.
fn create_test_join(
on: Vec<(Expr, Expr)>,
filter: Option<Expr>,
) -> Result<LogicalPlan> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let left_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;
let right_schema = DFSchema::try_from_qualified_schema("t2", &schema)?;
Ok(LogicalPlan::Join(Join {
left: Arc::new(
table_scan(Some("t1"), left_schema.as_arrow(), None)?.build()?,
),
right: Arc::new(
table_scan(Some("t2"), right_schema.as_arrow(), None)?.build()?,
),
on,
filter,
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
schema: Arc::new(left_schema.join(&right_schema)?),
null_equality: NullEquality::NullEqualsNothing,
null_aware: false,
}))
}
// Round-trip: one `on` pair, no filter.
{
let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None)?;
let LogicalPlan::Join(join) = join.with_new_exprs(
join.expressions(),
join.inputs().into_iter().cloned().collect(),
)?
else {
unreachable!()
};
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
assert_eq!(join.filter, None);
}
// Round-trip: filter only, no `on` pairs.
{
let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))))?;
let LogicalPlan::Join(join) = join.with_new_exprs(
join.expressions(),
join.inputs().into_iter().cloned().collect(),
)?
else {
unreachable!()
};
assert_eq!(join.on, vec![]);
assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a"))));
}
// Round-trip: both an `on` pair and a filter.
{
let join = create_test_join(
vec![(col("t1.a"), (col("t2.a")))],
Some(col("t1.b").gt(col("t2.b"))),
)?;
let LogicalPlan::Join(join) = join.with_new_exprs(
join.expressions(),
join.inputs().into_iter().cloned().collect(),
)?
else {
unreachable!()
};
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
}
// Replacement: a brand-new expression list for a join with two `on` pairs
// and no filter. The first 2*N expressions pair up into `on` (left/right
// alternating) and the trailing expression becomes the filter.
{
let join = create_test_join(
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
None,
)?;
let LogicalPlan::Join(join) = join.with_new_exprs(
vec![
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
binary_expr(col("t2.a"), Operator::Plus, lit(2)),
col("t1.b"),
col("t2.b"),
lit(true),
],
join.inputs().into_iter().cloned().collect(),
)?
else {
unreachable!()
};
assert_eq!(
join.on,
vec![
(
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
binary_expr(col("t2.a"), Operator::Plus, lit(2))
),
(col("t1.b"), (col("t2.b")))
]
);
assert_eq!(join.filter, Some(lit(true)));
}
Ok(())
}
#[test]
// Verifies `Join::try_new` derives the correct output schema for every join
// type: semi/anti joins keep only one side, LeftMark appends a non-nullable
// `mark` column, and outer joins mark the appropriate side(s) nullable. Also
// checks that `on`, `filter`, and the remaining join attributes are stored
// unchanged.
fn test_join_try_new() -> Result<()> {
// Both inputs share the same (a: Int32, b: Int32) layout, qualified as t1/t2.
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let left_scan = table_scan(Some("t1"), &schema, None)?.build()?;
let right_scan = table_scan(Some("t2"), &schema, None)?.build()?;
let join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightSemi,
JoinType::RightAnti,
JoinType::LeftMark,
];
for join_type in join_types {
let join = Join::try_new(
Arc::new(left_scan.clone()),
Arc::new(right_scan.clone()),
vec![(col("t1.a"), col("t2.a"))],
Some(col("t1.b").gt(col("t2.b"))),
join_type,
JoinConstraint::On,
NullEquality::NullEqualsNothing,
false,
)?;
match join_type {
// Left semi/anti joins emit only the left side's columns.
JoinType::LeftSemi | JoinType::LeftAnti => {
assert_eq!(join.schema.fields().len(), 2);
let fields = join.schema.fields();
assert_eq!(
fields[0].name(),
"a",
"First field should be 'a' from left table"
);
assert_eq!(
fields[1].name(),
"b",
"Second field should be 'b' from left table"
);
}
// Right semi/anti joins emit only the right side's columns.
JoinType::RightSemi | JoinType::RightAnti => {
assert_eq!(join.schema.fields().len(), 2);
let fields = join.schema.fields();
assert_eq!(
fields[0].name(),
"a",
"First field should be 'a' from right table"
);
assert_eq!(
fields[1].name(),
"b",
"Second field should be 'b' from right table"
);
}
// LeftMark: left side's columns plus a synthetic boolean `mark`
// column; all three stay non-nullable.
JoinType::LeftMark => {
assert_eq!(join.schema.fields().len(), 3);
let fields = join.schema.fields();
assert_eq!(
fields[0].name(),
"a",
"First field should be 'a' from left table"
);
assert_eq!(
fields[1].name(),
"b",
"Second field should be 'b' from left table"
);
assert_eq!(
fields[2].name(),
"mark",
"Third field should be the mark column"
);
assert!(!fields[0].is_nullable());
assert!(!fields[1].is_nullable());
assert!(!fields[2].is_nullable());
}
// Inner/Left/Right/Full: all four columns, left side first; the
// non-preserved side of an outer join becomes nullable.
_ => {
assert_eq!(join.schema.fields().len(), 4);
let fields = join.schema.fields();
assert_eq!(
fields[0].name(),
"a",
"First field should be 'a' from left table"
);
assert_eq!(
fields[1].name(),
"b",
"Second field should be 'b' from left table"
);
assert_eq!(
fields[2].name(),
"a",
"Third field should be 'a' from right table"
);
assert_eq!(
fields[3].name(),
"b",
"Fourth field should be 'b' from right table"
);
if join_type == JoinType::Left {
assert!(!fields[0].is_nullable());
assert!(!fields[1].is_nullable());
assert!(fields[2].is_nullable());
assert!(fields[3].is_nullable());
} else if join_type == JoinType::Right {
assert!(fields[0].is_nullable());
assert!(fields[1].is_nullable());
assert!(!fields[2].is_nullable());
assert!(!fields[3].is_nullable());
} else if join_type == JoinType::Full {
assert!(fields[0].is_nullable());
assert!(fields[1].is_nullable());
assert!(fields[2].is_nullable());
assert!(fields[3].is_nullable());
}
}
}
// Non-schema attributes are stored verbatim for every join type.
assert_eq!(join.on, vec![(col("t1.a"), col("t2.a"))]);
assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
assert_eq!(join.join_type, join_type);
assert_eq!(join.join_constraint, JoinConstraint::On);
assert_eq!(join.null_equality, NullEquality::NullEqualsNothing);
}
Ok(())
}
#[test]
fn test_join_try_new_with_using_constraint_and_overlapping_columns() -> Result<()> {
    // Both inputs expose `id` and `value`, so the joined schema must retain
    // the duplicated names (disambiguated by qualifier) in left-then-right
    // column order, for both USING and ON constraints.
    let left_schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("value", DataType::Int32, false),
    ]);
    let right_schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("category", DataType::Utf8, false),
        Field::new("value", DataType::Float64, true),
    ]);
    let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?;
    let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?;

    // Case 1: USING constraint — every column from both sides is kept.
    {
        let join = Join::try_new(
            Arc::new(left_plan.clone()),
            Arc::new(right_plan.clone()),
            vec![(col("t1.id"), col("t2.id"))],
            None,
            JoinType::Inner,
            JoinConstraint::Using,
            NullEquality::NullEqualsNothing,
            false,
        )?;
        let fields = join.schema.fields();
        assert_eq!(fields.len(), 6);
        assert_eq!(
            fields[0].name(),
            "id",
            "First field should be 'id' from left table"
        );
        assert_eq!(
            fields[1].name(),
            "name",
            "Second field should be 'name' from left table"
        );
        assert_eq!(
            fields[2].name(),
            "value",
            "Third field should be 'value' from left table"
        );
        assert_eq!(
            fields[3].name(),
            "id",
            "Fourth field should be 'id' from right table"
        );
        assert_eq!(
            fields[4].name(),
            "category",
            "Fifth field should be 'category' from right table"
        );
        assert_eq!(
            fields[5].name(),
            "value",
            "Sixth field should be 'value' from right table"
        );
        assert_eq!(join.join_constraint, JoinConstraint::Using);
    }

    // Case 2: ON constraint with a filter over the overlapping `value`
    // columns — same schema, and the filter is stored unchanged.
    {
        let join = Join::try_new(
            Arc::new(left_plan.clone()),
            Arc::new(right_plan.clone()),
            vec![(col("t1.id"), col("t2.id"))],
            Some(col("t1.value").lt(col("t2.value"))),
            JoinType::Inner,
            JoinConstraint::On,
            NullEquality::NullEqualsNothing,
            false,
        )?;
        let fields = join.schema.fields();
        assert_eq!(fields.len(), 6);
        assert_eq!(
            fields[0].name(),
            "id",
            "First field should be 'id' from left table"
        );
        assert_eq!(
            fields[1].name(),
            "name",
            "Second field should be 'name' from left table"
        );
        assert_eq!(
            fields[2].name(),
            "value",
            "Third field should be 'value' from left table"
        );
        assert_eq!(
            fields[3].name(),
            "id",
            "Fourth field should be 'id' from right table"
        );
        assert_eq!(
            fields[4].name(),
            "category",
            "Fifth field should be 'category' from right table"
        );
        assert_eq!(
            fields[5].name(),
            "value",
            "Sixth field should be 'value' from right table"
        );
        assert_eq!(join.filter, Some(col("t1.value").lt(col("t2.value"))));
    }

    // Case 3: the null-equality setting is stored verbatim.
    {
        let join = Join::try_new(
            Arc::new(left_plan.clone()),
            Arc::new(right_plan.clone()),
            vec![(col("t1.id"), col("t2.id"))],
            None,
            JoinType::Inner,
            JoinConstraint::On,
            NullEquality::NullEqualsNull,
            false,
        )?;
        assert_eq!(join.null_equality, NullEquality::NullEqualsNull);
    }
    Ok(())
}
#[test]
fn test_join_try_new_schema_validation() -> Result<()> {
    // Verifies per-field nullability in schemas produced by `Join::try_new`:
    // Right/Full joins force the left side nullable, Left/Full joins force
    // the right side nullable, and fields that were already nullable in their
    // input schema stay nullable regardless of join type.
    let left_schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("value", DataType::Float64, true),
    ]);
    let right_schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("category", DataType::Utf8, true),
        Field::new("code", DataType::Int16, false),
    ]);
    let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?;
    let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?;

    let join_types = vec![
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::Full,
    ];
    for join_type in join_types {
        let join = Join::try_new(
            Arc::new(left_plan.clone()),
            Arc::new(right_plan.clone()),
            vec![(col("t1.id"), col("t2.id"))],
            Some(col("t1.value").gt(lit(5.0))),
            join_type,
            JoinConstraint::On,
            NullEquality::NullEqualsNothing,
            false,
        )?;
        let fields = join.schema.fields();
        assert_eq!(fields.len(), 6, "Expected 6 fields for {join_type} join");

        // Which sides does this join type force to be nullable?
        let left_forced = matches!(join_type, JoinType::Right | JoinType::Full);
        let right_forced = matches!(join_type, JoinType::Left | JoinType::Full);
        for (i, field) in fields.iter().enumerate() {
            let expected_nullable = match i {
                // Left side: id, name follow the join type; value was
                // already nullable in the left input schema.
                0 | 1 => left_forced,
                2 => true,
                // Right side: id, code follow the join type; category was
                // already nullable in the right input schema.
                3 | 5 => right_forced,
                4 => true,
                _ => unreachable!("join schema has exactly 6 fields"),
            };
            assert_eq!(
                field.is_nullable(),
                expected_nullable,
                "Field {} ({}) nullability incorrect for {:?} join",
                i,
                field.name(),
                join_type
            );
        }
    }

    // A USING join over the same inputs still carries every column.
    let using_join = Join::try_new(
        Arc::new(left_plan.clone()),
        Arc::new(right_plan.clone()),
        vec![(col("t1.id"), col("t2.id"))],
        None,
        JoinType::Inner,
        JoinConstraint::Using,
        NullEquality::NullEqualsNothing,
        false,
    )?;
    assert_eq!(
        using_join.schema.fields().len(),
        6,
        "USING join should have all fields"
    );
    assert_eq!(using_join.join_constraint, JoinConstraint::Using);
    Ok(())
}
}