use crate::dml::CopyOptions;
use crate::expr::{Alias, Exists, InSubquery, Placeholder};
use crate::expr_rewriter::create_col_from_scalar_expr;
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::{DmlStatement, Statement};
use crate::utils::{
enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre,
};
use crate::{
build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource,
};
use crate::{
expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, LogicalPlanBuilder, Operator,
};
use super::dml::CopyTo;
use super::DdlStatement;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::tree_node::{
RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor,
VisitRecursion,
};
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
OwnedTableReference, Result, ScalarValue, UnnestOptions,
};
pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
pub use datafusion_common::{JoinConstraint, JoinType};
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// A node in a tree of relational operators (a logical query plan).
///
/// Each variant wraps a struct holding that operator's state; children (where
/// any) are reachable via [`LogicalPlan::inputs`] and the node's own
/// expressions via [`LogicalPlan::expressions`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum LogicalPlan {
    /// Evaluates a list of expressions against the input (SELECT list).
    Projection(Projection),
    /// Filters rows by a boolean predicate (WHERE / HAVING).
    Filter(Filter),
    /// Evaluates window functions over the input.
    Window(Window),
    /// Groups rows and evaluates aggregate expressions (GROUP BY).
    Aggregate(Aggregate),
    /// Sorts the input by a list of sort expressions.
    Sort(Sort),
    /// Joins two inputs on equality keys and/or a filter expression.
    Join(Join),
    /// Cartesian product of two inputs.
    CrossJoin(CrossJoin),
    /// Redistributes the input according to a partitioning scheme.
    Repartition(Repartition),
    /// Concatenates the rows of several inputs with a common schema.
    Union(Union),
    /// Reads rows from a table source, with optional projection/filters/limit.
    TableScan(TableScan),
    /// Produces zero rows (or one empty row) — see [`EmptyRelation`].
    EmptyRelation(EmptyRelation),
    /// A correlated or scalar subquery plan.
    Subquery(Subquery),
    /// Re-qualifies the input's columns under an alias.
    SubqueryAlias(SubqueryAlias),
    /// Skips and/or limits the number of rows (OFFSET / LIMIT).
    Limit(Limit),
    /// A transaction / SET-style statement.
    Statement(Statement),
    /// Literal row values (VALUES clause).
    Values(Values),
    /// EXPLAIN wrapper around a plan.
    Explain(Explain),
    /// EXPLAIN ANALYZE wrapper around a plan.
    Analyze(Analyze),
    /// User-defined extension node.
    Extension(Extension),
    /// Removes duplicate rows (SELECT DISTINCT).
    Distinct(Distinct),
    /// A PREPARE statement with declared parameter types.
    Prepare(Prepare),
    /// Data-modification statement (INSERT/UPDATE/DELETE).
    Dml(DmlStatement),
    /// Data-definition statement (CREATE TABLE/VIEW, DROP, ...).
    Ddl(DdlStatement),
    /// COPY TO an external destination.
    Copy(CopyTo),
    /// DESCRIBE a table's columns.
    DescribeTable(DescribeTable),
    /// Unnests an array/list column into one row per element.
    Unnest(Unnest),
}
impl LogicalPlan {
    /// Returns the output schema of this plan node.
    ///
    /// Nodes that do not change their input's shape (Filter, Sort, Limit,
    /// Repartition, Distinct, Prepare, Copy) simply forward the input schema;
    /// all others carry their own precomputed schema.
    pub fn schema(&self) -> &DFSchemaRef {
        match self {
            LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }) => schema,
            LogicalPlan::Values(Values { schema, .. }) => schema,
            // A scan's output reflects any projection pushed into it.
            LogicalPlan::TableScan(TableScan {
                projected_schema, ..
            }) => projected_schema,
            LogicalPlan::Projection(Projection { schema, .. }) => schema,
            LogicalPlan::Filter(Filter { input, .. }) => input.schema(),
            LogicalPlan::Distinct(Distinct { input }) => input.schema(),
            LogicalPlan::Window(Window { schema, .. }) => schema,
            LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema,
            LogicalPlan::Sort(Sort { input, .. }) => input.schema(),
            LogicalPlan::Join(Join { schema, .. }) => schema,
            LogicalPlan::CrossJoin(CrossJoin { schema, .. }) => schema,
            LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(),
            LogicalPlan::Limit(Limit { input, .. }) => input.schema(),
            LogicalPlan::Statement(statement) => statement.schema(),
            LogicalPlan::Subquery(Subquery { subquery, .. }) => subquery.schema(),
            LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }) => schema,
            LogicalPlan::Prepare(Prepare { input, .. }) => input.schema(),
            LogicalPlan::Explain(explain) => &explain.schema,
            LogicalPlan::Analyze(analyze) => &analyze.schema,
            LogicalPlan::Extension(extension) => extension.node.schema(),
            LogicalPlan::Union(Union { schema, .. }) => schema,
            LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => {
                output_schema
            }
            // DML reports the target table's schema, not the input's.
            LogicalPlan::Dml(DmlStatement { table_schema, .. }) => table_schema,
            LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(),
            LogicalPlan::Ddl(ddl) => ddl.schema(),
            LogicalPlan::Unnest(Unnest { schema, .. }) => schema,
        }
    }
pub fn fallback_normalize_schemas(&self) -> Vec<&DFSchema> {
match self {
LogicalPlan::Window(_)
| LogicalPlan::Projection(_)
| LogicalPlan::Aggregate(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_) => self
.inputs()
.iter()
.map(|input| input.schema().as_ref())
.collect(),
_ => vec![],
}
}
    /// Returns all schemas associated with this node (its own plus, for some
    /// variants, its inputs').
    ///
    /// Deprecated: prefer [`Self::schema`] / [`Self::fallback_normalize_schemas`].
    #[deprecated(since = "20.0.0")]
    pub fn all_schemas(&self) -> Vec<&DFSchemaRef> {
        match self {
            // Nodes whose output schema differs from their inputs':
            // report own schema first, then every input's.
            LogicalPlan::Window(_)
            | LogicalPlan::Projection(_)
            | LogicalPlan::Aggregate(_)
            | LogicalPlan::Unnest(_)
            | LogicalPlan::Join(_)
            | LogicalPlan::CrossJoin(_) => {
                let mut schemas = vec![self.schema()];
                self.inputs().iter().for_each(|input| {
                    schemas.push(input.schema());
                });
                schemas
            }
            // Nodes with a single authoritative schema of their own.
            LogicalPlan::Explain(_)
            | LogicalPlan::Analyze(_)
            | LogicalPlan::EmptyRelation(_)
            | LogicalPlan::Ddl(_)
            | LogicalPlan::Dml(_)
            | LogicalPlan::Copy(_)
            | LogicalPlan::Values(_)
            | LogicalPlan::SubqueryAlias(_)
            | LogicalPlan::Union(_)
            | LogicalPlan::Extension(_)
            | LogicalPlan::TableScan(_) => {
                vec![self.schema()]
            }
            // Pass-through nodes: their schemas are exactly their inputs'.
            LogicalPlan::Limit(_)
            | LogicalPlan::Subquery(_)
            | LogicalPlan::Repartition(_)
            | LogicalPlan::Sort(_)
            | LogicalPlan::Filter(_)
            | LogicalPlan::Distinct(_)
            | LogicalPlan::Prepare(_) => {
                self.inputs().iter().map(|p| p.schema()).collect()
            }
            LogicalPlan::Statement(_) | LogicalPlan::DescribeTable(_) => vec![],
        }
    }
pub fn explain_schema() -> SchemaRef {
SchemaRef::new(Schema::new(vec![
Field::new("plan_type", DataType::Utf8, false),
Field::new("plan", DataType::Utf8, false),
]))
}
pub fn describe_schema() -> Schema {
Schema::new(vec![
Field::new("column_name", DataType::Utf8, false),
Field::new("data_type", DataType::Utf8, false),
Field::new("is_nullable", DataType::Utf8, false),
])
}
pub fn expressions(self: &LogicalPlan) -> Vec<Expr> {
let mut exprs = vec![];
self.inspect_expressions(|e| {
exprs.push(e.clone());
Ok(()) as Result<()>
})
.unwrap();
exprs
}
pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec<Expr> {
let mut exprs = vec![];
self.inspect_expressions(|e| {
find_out_reference_exprs(e).into_iter().for_each(|e| {
if !exprs.contains(&e) {
exprs.push(e)
}
});
Ok(()) as Result<(), DataFusionError>
})
.unwrap();
self.inputs()
.into_iter()
.flat_map(|child| child.all_out_ref_exprs())
.for_each(|e| {
if !exprs.contains(&e) {
exprs.push(e)
}
});
exprs
}
    /// Invokes `f` once per expression owned directly by this node.
    ///
    /// Children are NOT visited. Join `on` key pairs are presented to `f` as
    /// synthetic `l = r` equality expressions, followed by the optional join
    /// filter; an `Unnest` presents its column as a synthetic `Expr::Column`.
    /// Stops and propagates the first error `f` returns.
    pub fn inspect_expressions<F, E>(self: &LogicalPlan, mut f: F) -> Result<(), E>
    where
        F: FnMut(&Expr) -> Result<(), E>,
    {
        match self {
            LogicalPlan::Projection(Projection { expr, .. }) => {
                expr.iter().try_for_each(f)
            }
            // Every cell of every VALUES row.
            LogicalPlan::Values(Values { values, .. }) => {
                values.iter().flatten().try_for_each(f)
            }
            LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate),
            LogicalPlan::Repartition(Repartition {
                partitioning_scheme,
                ..
            }) => match partitioning_scheme {
                Partitioning::Hash(expr, _) => expr.iter().try_for_each(f),
                Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f),
                // Round-robin carries no expressions.
                Partitioning::RoundRobinBatch(_) => Ok(()),
            },
            LogicalPlan::Window(Window { window_expr, .. }) => {
                window_expr.iter().try_for_each(f)
            }
            // Group keys first, then aggregates — callers rely on this order
            // (see `with_new_exprs`, which splits at `group_expr.len()`).
            LogicalPlan::Aggregate(Aggregate {
                group_expr,
                aggr_expr,
                ..
            }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f),
            LogicalPlan::Join(Join { on, filter, .. }) => {
                // Rebuild each equi-join key pair as an `l = r` expression.
                on.iter()
                    .map(|(l, r)| Expr::eq(l.clone(), r.clone()))
                    .try_for_each(|e| f(&e))?;
                if let Some(filter) = filter.as_ref() {
                    f(filter)
                } else {
                    Ok(())
                }
            }
            LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f),
            LogicalPlan::Extension(extension) => {
                extension.node.expressions().iter().try_for_each(f)
            }
            LogicalPlan::TableScan(TableScan { filters, .. }) => {
                filters.iter().try_for_each(f)
            }
            LogicalPlan::Unnest(Unnest { column, .. }) => {
                f(&Expr::Column(column.clone()))
            }
            // Variants with no expressions of their own.
            LogicalPlan::EmptyRelation(_)
            | LogicalPlan::Subquery(_)
            | LogicalPlan::SubqueryAlias(_)
            | LogicalPlan::Limit(_)
            | LogicalPlan::Statement(_)
            | LogicalPlan::CrossJoin(_)
            | LogicalPlan::Analyze(_)
            | LogicalPlan::Explain(_)
            | LogicalPlan::Union(_)
            | LogicalPlan::Distinct(_)
            | LogicalPlan::Dml(_)
            | LogicalPlan::Ddl(_)
            | LogicalPlan::Copy(_)
            | LogicalPlan::DescribeTable(_)
            | LogicalPlan::Prepare(_) => Ok(()),
        }
    }
    /// Returns the direct child plans of this node; leaves return an empty list.
    pub fn inputs(&self) -> Vec<&LogicalPlan> {
        match self {
            LogicalPlan::Projection(Projection { input, .. }) => vec![input],
            LogicalPlan::Filter(Filter { input, .. }) => vec![input],
            LogicalPlan::Repartition(Repartition { input, .. }) => vec![input],
            LogicalPlan::Window(Window { input, .. }) => vec![input],
            LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input],
            LogicalPlan::Sort(Sort { input, .. }) => vec![input],
            LogicalPlan::Join(Join { left, right, .. }) => vec![left, right],
            LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => vec![left, right],
            LogicalPlan::Limit(Limit { input, .. }) => vec![input],
            LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery],
            LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input],
            LogicalPlan::Extension(extension) => extension.node.inputs(),
            LogicalPlan::Union(Union { inputs, .. }) => {
                inputs.iter().map(|arc| arc.as_ref()).collect()
            }
            LogicalPlan::Distinct(Distinct { input }) => vec![input],
            LogicalPlan::Explain(explain) => vec![&explain.plan],
            LogicalPlan::Analyze(analyze) => vec![&analyze.input],
            LogicalPlan::Dml(write) => vec![&write.input],
            LogicalPlan::Copy(copy) => vec![&copy.input],
            LogicalPlan::Ddl(ddl) => ddl.inputs(),
            LogicalPlan::Unnest(Unnest { input, .. }) => vec![input],
            LogicalPlan::Prepare(Prepare { input, .. }) => vec![input],
            // Leaf nodes: no children.
            LogicalPlan::TableScan { .. }
            | LogicalPlan::Statement { .. }
            | LogicalPlan::EmptyRelation { .. }
            | LogicalPlan::Values { .. }
            | LogicalPlan::DescribeTable(_) => vec![],
        }
    }
    /// Walks the whole plan tree and, for every `USING`-constrained join,
    /// collects the set of columns named by its `on` key pairs (both sides).
    ///
    /// Returns one `HashSet<Column>` per `USING` join encountered; errors if
    /// any key expression cannot be converted to a column.
    pub fn using_columns(&self) -> Result<Vec<HashSet<Column>>, DataFusionError> {
        let mut using_columns: Vec<HashSet<Column>> = vec![];
        self.apply(&mut |plan| {
            if let LogicalPlan::Join(Join {
                join_constraint: JoinConstraint::Using,
                on,
                ..
            }) = plan
            {
                // Both sides of each key pair belong to the USING set.
                let columns =
                    on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| {
                        accumu.insert(l.try_into_col()?);
                        accumu.insert(r.try_into_col()?);
                        Result::<_, DataFusionError>::Ok(accumu)
                    })?;
                using_columns.push(columns);
            }
            Ok(VisitRecursion::Continue)
        })?;
        Ok(using_columns)
    }
pub fn head_output_expr(&self) -> Result<Option<Expr>> {
match self {
LogicalPlan::Projection(projection) => {
Ok(Some(projection.expr.as_slice()[0].clone()))
}
LogicalPlan::Aggregate(agg) => {
if agg.group_expr.is_empty() {
Ok(Some(agg.aggr_expr.as_slice()[0].clone()))
} else {
Ok(Some(agg.group_expr.as_slice()[0].clone()))
}
}
LogicalPlan::Filter(Filter { input, .. })
| LogicalPlan::Distinct(Distinct { input, .. })
| LogicalPlan::Sort(Sort { input, .. })
| LogicalPlan::Limit(Limit { input, .. })
| LogicalPlan::Repartition(Repartition { input, .. })
| LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(),
LogicalPlan::Join(Join {
left,
right,
join_type,
..
}) => match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
if left.schema().fields().is_empty() {
right.head_output_expr()
} else {
left.head_output_expr()
}
}
JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(),
JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(),
},
LogicalPlan::CrossJoin(cross) => {
if cross.left.schema().fields().is_empty() {
cross.right.head_output_expr()
} else {
cross.left.head_output_expr()
}
}
LogicalPlan::Union(union) => Ok(Some(Expr::Column(
union.schema.fields()[0].qualified_column(),
))),
LogicalPlan::TableScan(table) => Ok(Some(Expr::Column(
table.projected_schema.fields()[0].qualified_column(),
))),
LogicalPlan::SubqueryAlias(subquery_alias) => {
let expr_opt = subquery_alias.input.head_output_expr()?;
expr_opt
.map(|expr| {
Ok(Expr::Column(create_col_from_scalar_expr(
&expr,
subquery_alias.alias.to_string(),
)?))
})
.map_or(Ok(None), |v| v.map(Some))
}
LogicalPlan::Subquery(_) => Ok(None),
LogicalPlan::EmptyRelation(_)
| LogicalPlan::Prepare(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Unnest(_) => Ok(None),
}
}
    /// Returns a copy of this node with `inputs` as its new children.
    ///
    /// Projection, Window, and Aggregate are rebuilt keeping their existing
    /// output schema; every other variant delegates to
    /// [`Self::with_new_exprs`] with the node's current expressions.
    pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result<LogicalPlan> {
        match &self {
            LogicalPlan::Projection(projection) => {
                Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
                    projection.expr.to_vec(),
                    Arc::new(inputs[0].clone()),
                    projection.schema.clone(),
                )?))
            }
            LogicalPlan::Window(Window {
                window_expr,
                schema,
                ..
            }) => Ok(LogicalPlan::Window(Window {
                input: Arc::new(inputs[0].clone()),
                window_expr: window_expr.to_vec(),
                schema: schema.clone(),
            })),
            LogicalPlan::Aggregate(Aggregate {
                group_expr,
                aggr_expr,
                schema,
                ..
            }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
                Arc::new(inputs[0].clone()),
                group_expr.to_vec(),
                aggr_expr.to_vec(),
                schema.clone(),
            )?)),
            _ => self.with_new_exprs(self.expressions(), inputs),
        }
    }
    /// Returns a node of the same variant rebuilt from `expr` and `inputs`.
    ///
    /// `expr` must be in the same order as [`Self::expressions`] returns for
    /// this node (e.g. for Aggregate: group keys first, then aggregates; for
    /// Join: equi-key equalities first, optional filter last). `inputs` must
    /// supply as many children as [`Self::inputs`] returns.
    pub fn with_new_exprs(
        &self,
        mut expr: Vec<Expr>,
        inputs: &[LogicalPlan],
    ) -> Result<LogicalPlan> {
        match self {
            LogicalPlan::Projection(Projection { schema, .. }) => {
                Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
                    expr,
                    Arc::new(inputs[0].clone()),
                    schema.clone(),
                )?))
            }
            LogicalPlan::Dml(DmlStatement {
                table_name,
                table_schema,
                op,
                ..
            }) => Ok(LogicalPlan::Dml(DmlStatement {
                table_name: table_name.clone(),
                table_schema: table_schema.clone(),
                op: op.clone(),
                input: Arc::new(inputs[0].clone()),
            })),
            LogicalPlan::Copy(CopyTo {
                input: _,
                output_url,
                file_format,
                copy_options,
                single_file_output,
            }) => Ok(LogicalPlan::Copy(CopyTo {
                input: Arc::new(inputs[0].clone()),
                output_url: output_url.clone(),
                file_format: file_format.clone(),
                single_file_output: *single_file_output,
                copy_options: copy_options.clone(),
            })),
            LogicalPlan::Values(Values { schema, .. }) => {
                // `expr` is a flat list of cells; re-chunk into rows of the
                // schema's arity.
                Ok(LogicalPlan::Values(Values {
                    schema: schema.clone(),
                    values: expr
                        .chunks_exact(schema.fields().len())
                        .map(|s| s.to_vec())
                        .collect::<Vec<_>>(),
                }))
            }
            LogicalPlan::Filter { .. } => {
                assert_eq!(1, expr.len());
                let predicate = expr.pop().unwrap();
                // Strip aliases from the predicate, but leave subquery
                // expressions untouched (their inner plans own any aliases).
                struct RemoveAliases {}
                impl TreeNodeRewriter for RemoveAliases {
                    type N = Expr;
                    fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
                        match expr {
                            Expr::Exists { .. }
                            | Expr::ScalarSubquery(_)
                            | Expr::InSubquery(_) => {
                                // Do not descend into subqueries.
                                Ok(RewriteRecursion::Stop)
                            }
                            Expr::Alias(_) => Ok(RewriteRecursion::Mutate),
                            _ => Ok(RewriteRecursion::Continue),
                        }
                    }
                    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
                        Ok(expr.unalias())
                    }
                }
                let mut remove_aliases = RemoveAliases {};
                let predicate = predicate.rewrite(&mut remove_aliases)?;
                Ok(LogicalPlan::Filter(Filter::try_new(
                    predicate,
                    Arc::new(inputs[0].clone()),
                )?))
            }
            LogicalPlan::Repartition(Repartition {
                partitioning_scheme,
                ..
            }) => match partitioning_scheme {
                Partitioning::RoundRobinBatch(n) => {
                    Ok(LogicalPlan::Repartition(Repartition {
                        partitioning_scheme: Partitioning::RoundRobinBatch(*n),
                        input: Arc::new(inputs[0].clone()),
                    }))
                }
                Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition {
                    partitioning_scheme: Partitioning::Hash(expr, *n),
                    input: Arc::new(inputs[0].clone()),
                })),
                Partitioning::DistributeBy(_) => {
                    Ok(LogicalPlan::Repartition(Repartition {
                        partitioning_scheme: Partitioning::DistributeBy(expr),
                        input: Arc::new(inputs[0].clone()),
                    }))
                }
            },
            LogicalPlan::Window(Window {
                window_expr,
                schema,
                ..
            }) => {
                assert_eq!(window_expr.len(), expr.len());
                Ok(LogicalPlan::Window(Window {
                    input: Arc::new(inputs[0].clone()),
                    window_expr: expr,
                    schema: schema.clone(),
                }))
            }
            LogicalPlan::Aggregate(Aggregate {
                group_expr, schema, ..
            }) => {
                // `expr` = [group keys..., aggregates...]; split at the
                // original group-key count.
                let agg_expr = expr.split_off(group_expr.len());
                Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
                    Arc::new(inputs[0].clone()),
                    expr,
                    agg_expr,
                    schema.clone(),
                )?))
            }
            LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort {
                expr,
                input: Arc::new(inputs[0].clone()),
                fetch: *fetch,
            })),
            LogicalPlan::Join(Join {
                join_type,
                join_constraint,
                on,
                null_equals_null,
                ..
            }) => {
                // Recompute the join schema from the (possibly new) inputs.
                let schema =
                    build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
                let equi_expr_count = on.len();
                assert!(expr.len() >= equi_expr_count);
                // The trailing expression (if any) is the non-equi filter.
                let filter_expr = if expr.len() > equi_expr_count {
                    expr.pop()
                } else {
                    None
                };
                assert_eq!(expr.len(), equi_expr_count);
                // Each remaining expression must be an `l = r` equality;
                // unpack it back into an `on` key pair.
                let new_on: Vec<(Expr, Expr)> = expr.into_iter().map(|equi_expr| {
                    let unalias_expr = equi_expr.clone().unalias();
                    if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr {
                        Ok((*left, *right))
                    } else {
                        internal_err!(
                            "The front part expressions should be an binary equality expression, actual:{equi_expr}"
                        )
                    }
                }).collect::<Result<Vec<(Expr, Expr)>>>()?;
                Ok(LogicalPlan::Join(Join {
                    left: Arc::new(inputs[0].clone()),
                    right: Arc::new(inputs[1].clone()),
                    join_type: *join_type,
                    join_constraint: *join_constraint,
                    on: new_on,
                    filter: filter_expr,
                    schema: DFSchemaRef::new(schema),
                    null_equals_null: *null_equals_null,
                }))
            }
            LogicalPlan::CrossJoin(_) => {
                let left = inputs[0].clone();
                let right = inputs[1].clone();
                LogicalPlanBuilder::from(left).cross_join(right)?.build()
            }
            LogicalPlan::Subquery(Subquery {
                outer_ref_columns, ..
            }) => {
                let subquery = LogicalPlanBuilder::from(inputs[0].clone()).build()?;
                Ok(LogicalPlan::Subquery(Subquery {
                    subquery: Arc::new(subquery),
                    outer_ref_columns: outer_ref_columns.clone(),
                }))
            }
            LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => {
                Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
                    inputs[0].clone(),
                    alias.clone(),
                )?))
            }
            LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
                Ok(LogicalPlan::Limit(Limit {
                    skip: *skip,
                    fetch: *fetch,
                    input: Arc::new(inputs[0].clone()),
                }))
            }
            LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable {
                name,
                if_not_exists,
                or_replace,
                ..
            })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(
                CreateMemoryTable {
                    input: Arc::new(inputs[0].clone()),
                    // NOTE(review): the original constraints are discarded
                    // here and replaced with an empty set — confirm intentional.
                    constraints: Constraints::empty(),
                    name: name.clone(),
                    if_not_exists: *if_not_exists,
                    or_replace: *or_replace,
                },
            ))),
            LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
                name,
                or_replace,
                definition,
                ..
            })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
                input: Arc::new(inputs[0].clone()),
                name: name.clone(),
                or_replace: *or_replace,
                definition: definition.clone(),
            }))),
            LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension {
                node: e.node.from_template(&expr, inputs),
            })),
            LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union {
                inputs: inputs.iter().cloned().map(Arc::new).collect(),
                schema: schema.clone(),
            })),
            LogicalPlan::Distinct(Distinct { .. }) => {
                Ok(LogicalPlan::Distinct(Distinct {
                    input: Arc::new(inputs[0].clone()),
                }))
            }
            LogicalPlan::Analyze(a) => {
                assert!(expr.is_empty());
                assert_eq!(inputs.len(), 1);
                Ok(LogicalPlan::Analyze(Analyze {
                    verbose: a.verbose,
                    schema: a.schema.clone(),
                    input: Arc::new(inputs[0].clone()),
                }))
            }
            LogicalPlan::Explain(_) => {
                // Explain is returned unchanged; the arguments are only
                // validated for shape.
                if expr.is_empty() {
                    return plan_err!("Invalid EXPLAIN command. Expression is empty");
                }
                if inputs.is_empty() {
                    return plan_err!("Invalid EXPLAIN command. Inputs are empty");
                }
                Ok(self.clone())
            }
            LogicalPlan::Prepare(Prepare {
                name, data_types, ..
            }) => Ok(LogicalPlan::Prepare(Prepare {
                name: name.clone(),
                data_types: data_types.clone(),
                input: Arc::new(inputs[0].clone()),
            })),
            LogicalPlan::TableScan(ts) => {
                assert!(inputs.is_empty(), "{self:?} should have no inputs");
                Ok(LogicalPlan::TableScan(TableScan {
                    filters: expr,
                    ..ts.clone()
                }))
            }
            LogicalPlan::EmptyRelation(_)
            | LogicalPlan::Ddl(_)
            | LogicalPlan::Statement(_) => {
                // Leaf-like variants: nothing to rebuild.
                assert!(expr.is_empty(), "{self:?} should have no exprs");
                assert!(inputs.is_empty(), "{self:?} should have no inputs");
                Ok(self.clone())
            }
            LogicalPlan::DescribeTable(_) => Ok(self.clone()),
            LogicalPlan::Unnest(Unnest {
                column,
                schema,
                options,
                ..
            }) => {
                // Rebuild the output schema: the unnested column's field is
                // substituted into the (possibly new) input schema.
                let input = Arc::new(inputs[0].clone());
                let nested_field = input.schema().field_from_column(column)?;
                let unnested_field = schema.field_from_column(column)?;
                let fields = input
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| {
                        if f == nested_field {
                            unnested_field.clone()
                        } else {
                            f.clone()
                        }
                    })
                    .collect::<Vec<_>>();
                let schema = Arc::new(
                    DFSchema::new_with_metadata(
                        fields,
                        input.schema().metadata().clone(),
                    )?
                    .with_functional_dependencies(
                        input.schema().functional_dependencies().clone(),
                    ),
                );
                Ok(LogicalPlan::Unnest(Unnest {
                    input,
                    column: column.clone(),
                    schema,
                    options: options.clone(),
                }))
            }
        }
    }
pub fn with_param_values(
self,
param_values: Vec<ScalarValue>,
) -> Result<LogicalPlan> {
match self {
LogicalPlan::Prepare(prepare_lp) => {
if prepare_lp.data_types.len() != param_values.len() {
return plan_err!(
"Expected {} parameters, got {}",
prepare_lp.data_types.len(),
param_values.len()
);
}
let iter = prepare_lp.data_types.iter().zip(param_values.iter());
for (i, (param_type, value)) in iter.enumerate() {
if *param_type != value.data_type() {
return plan_err!(
"Expected parameter of type {:?}, got {:?} at index {}",
param_type,
value.data_type(),
i
);
}
}
let input_plan = prepare_lp.input;
input_plan.replace_params_with_values(¶m_values)
}
_ => Ok(self),
}
}
    /// Returns an upper bound on the number of rows this plan can produce,
    /// or `None` when no bound can be determined.
    pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
        match self {
            LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
            LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
            LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
            LogicalPlan::Aggregate(Aggregate {
                input, group_expr, ..
            }) => {
                // All-literal group keys collapse into a single group.
                if group_expr
                    .iter()
                    .all(|expr| matches!(expr, Expr::Literal(_)))
                {
                    Some(1)
                } else {
                    input.max_rows()
                }
            }
            LogicalPlan::Sort(Sort { input, fetch, .. }) => {
                // The bound is the smaller of the fetch limit and the
                // input's bound, when either is known.
                match (fetch, input.max_rows()) {
                    (Some(fetch_limit), Some(input_max)) => {
                        Some(input_max.min(*fetch_limit))
                    }
                    (Some(fetch_limit), None) => Some(*fetch_limit),
                    (None, Some(input_max)) => Some(input_max),
                    (None, None) => None,
                }
            }
            LogicalPlan::Join(Join {
                left,
                right,
                join_type,
                ..
            }) => match join_type {
                JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
                    match (left.max_rows(), right.max_rows()) {
                        (Some(left_max), Some(right_max)) => {
                            // Outer joins can never emit fewer rows than the
                            // preserved side(s); the product is the ceiling.
                            let min_rows = match join_type {
                                JoinType::Left => left_max,
                                JoinType::Right => right_max,
                                JoinType::Full => left_max + right_max,
                                _ => 0,
                            };
                            Some((left_max * right_max).max(min_rows))
                        }
                        _ => None,
                    }
                }
                // Semi/anti joins emit at most one side's rows.
                JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(),
                JoinType::RightSemi | JoinType::RightAnti => right.max_rows(),
            },
            LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
                match (left.max_rows(), right.max_rows()) {
                    (Some(left_max), Some(right_max)) => Some(left_max * right_max),
                    _ => None,
                }
            }
            LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(),
            // Sum of all branch bounds; unknown if any branch is unbounded.
            LogicalPlan::Union(Union { inputs, .. }) => inputs
                .iter()
                .map(|plan| plan.max_rows())
                .try_fold(0usize, |mut acc, input_max| {
                    if let Some(i_max) = input_max {
                        acc += i_max;
                        Some(acc)
                    } else {
                        None
                    }
                }),
            LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch,
            LogicalPlan::EmptyRelation(_) => Some(0),
            LogicalPlan::Subquery(_) => None,
            LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(),
            LogicalPlan::Limit(Limit { fetch, .. }) => *fetch,
            LogicalPlan::Distinct(Distinct { input }) => input.max_rows(),
            LogicalPlan::Values(v) => Some(v.values.len()),
            LogicalPlan::Unnest(_) => None,
            LogicalPlan::Ddl(_)
            | LogicalPlan::Explain(_)
            | LogicalPlan::Analyze(_)
            | LogicalPlan::Dml(_)
            | LogicalPlan::Copy(_)
            | LogicalPlan::DescribeTable(_)
            | LogicalPlan::Prepare(_)
            | LogicalPlan::Statement(_)
            | LogicalPlan::Extension(_) => None,
        }
    }
}
impl LogicalPlan {
pub(crate) fn apply_subqueries<F>(&self, op: &mut F) -> datafusion_common::Result<()>
where
F: FnMut(&Self) -> datafusion_common::Result<VisitRecursion>,
{
self.inspect_expressions(|expr| {
inspect_expr_pre(expr, |expr| {
match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
let synthetic_plan = LogicalPlan::Subquery(subquery.clone());
synthetic_plan.apply(op)?;
}
_ => {}
}
Ok::<(), DataFusionError>(())
})
})?;
Ok(())
}
pub(crate) fn visit_subqueries<V>(&self, v: &mut V) -> datafusion_common::Result<()>
where
V: TreeNodeVisitor<N = LogicalPlan>,
{
self.inspect_expressions(|expr| {
inspect_expr_pre(expr, |expr| {
match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
let synthetic_plan = LogicalPlan::Subquery(subquery.clone());
synthetic_plan.visit(v)?;
}
_ => {}
}
Ok::<(), DataFusionError>(())
})
})?;
Ok(())
}
pub fn replace_params_with_values(
&self,
param_values: &[ScalarValue],
) -> Result<LogicalPlan> {
let new_exprs = self
.expressions()
.into_iter()
.map(|e| Self::replace_placeholders_with_values(e, param_values))
.collect::<Result<Vec<_>>>()?;
let new_inputs_with_values = self
.inputs()
.into_iter()
.map(|inp| inp.replace_params_with_values(param_values))
.collect::<Result<Vec<_>>>()?;
self.with_new_exprs(new_exprs, &new_inputs_with_values)
}
    /// Walks the whole plan tree (including each node's expressions) and
    /// collects the inferred data type of every `Placeholder` parameter.
    ///
    /// Returns a map from placeholder id to its type. Errors if the same id
    /// occurs with two different concrete types. Note: ids that only ever
    /// appear without a type are not inserted into the map.
    pub fn get_parameter_types(
        &self,
    ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> {
        let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();
        self.apply(&mut |plan| {
            plan.inspect_expressions(|expr| {
                expr.apply(&mut |expr| {
                    if let Expr::Placeholder(Placeholder { id, data_type }) = expr {
                        let prev = param_types.get(id);
                        match (prev, data_type) {
                            // Seen before with a concrete type: must agree.
                            (Some(Some(prev)), Some(dt)) => {
                                if prev != dt {
                                    plan_err!("Conflicting types for {id}")?;
                                }
                            }
                            // First (or repeated identical) concrete type.
                            (_, Some(dt)) => {
                                param_types.insert(id.clone(), Some(dt.clone()));
                            }
                            // Untyped occurrence: nothing recorded.
                            _ => {}
                        }
                    }
                    Ok(VisitRecursion::Continue)
                })?;
                Ok::<(), DataFusionError>(())
            })?;
            Ok(VisitRecursion::Continue)
        })?;
        Ok(param_types)
    }
fn replace_placeholders_with_values(
expr: Expr,
param_values: &[ScalarValue],
) -> Result<Expr> {
expr.transform(&|expr| {
match &expr {
Expr::Placeholder(Placeholder { id, data_type }) => {
if id.is_empty() || id == "$0" {
return plan_err!("Empty placeholder id");
}
let idx = id[1..].parse::<usize>().map_err(|e| {
DataFusionError::Internal(format!(
"Failed to parse placeholder id: {e}"
))
})? - 1;
let value = param_values.get(idx).ok_or_else(|| {
DataFusionError::Internal(format!(
"No value found for placeholder with id {id}"
))
})?;
if Some(value.data_type()) != *data_type {
return internal_err!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.data_type()
);
}
Ok(Transformed::Yes(Expr::Literal(value.clone())))
}
Expr::ScalarSubquery(qry) => {
let subquery =
Arc::new(qry.subquery.replace_params_with_values(param_values)?);
Ok(Transformed::Yes(Expr::ScalarSubquery(Subquery {
subquery,
outer_ref_columns: qry.outer_ref_columns.clone(),
})))
}
_ => Ok(Transformed::No(expr)),
}
})
}
}
impl LogicalPlan {
pub fn display_indent(&self) -> impl Display + '_ {
struct Wrapper<'a>(&'a LogicalPlan);
impl<'a> Display for Wrapper<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let with_schema = false;
let mut visitor = IndentVisitor::new(f, with_schema);
match self.0.visit(&mut visitor) {
Ok(_) => Ok(()),
Err(_) => Err(fmt::Error),
}
}
}
Wrapper(self)
}
pub fn display_indent_schema(&self) -> impl Display + '_ {
struct Wrapper<'a>(&'a LogicalPlan);
impl<'a> Display for Wrapper<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let with_schema = true;
let mut visitor = IndentVisitor::new(f, with_schema);
match self.0.visit(&mut visitor) {
Ok(_) => Ok(()),
Err(_) => Err(fmt::Error),
}
}
}
Wrapper(self)
}
    /// Returns a `Display`able that renders the plan in `dot` (graphviz)
    /// format.
    ///
    /// The graph contains two renderings of the same plan: a compact
    /// "LogicalPlan" pass and a "Detailed LogicalPlan" pass with schemas.
    pub fn display_graphviz(&self) -> impl Display + '_ {
        struct Wrapper<'a>(&'a LogicalPlan);
        impl<'a> Display for Wrapper<'a> {
            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                let mut visitor = GraphvizVisitor::new(f);
                visitor.start_graph()?;
                // First pass: compact node labels.
                visitor.pre_visit_plan("LogicalPlan")?;
                self.0.visit(&mut visitor).map_err(|_| fmt::Error)?;
                visitor.post_visit_plan()?;
                // Second pass: same plan, labels include schemas.
                visitor.set_with_schema(true);
                visitor.pre_visit_plan("Detailed LogicalPlan")?;
                self.0.visit(&mut visitor).map_err(|_| fmt::Error)?;
                visitor.post_visit_plan()?;
                visitor.end_graph()?;
                Ok(())
            }
        }
        Wrapper(self)
    }
    /// Returns a `Display`able that renders ONLY this node (no children),
    /// producing the single-line form used by EXPLAIN output.
    pub fn display(&self) -> impl Display + '_ {
        struct Wrapper<'a>(&'a LogicalPlan);
        impl<'a> Display for Wrapper<'a> {
            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                match self.0 {
                    LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"),
                    LogicalPlan::Values(Values { ref values, .. }) => {
                        // Show at most the first 5 rows, then an ellipsis.
                        let str_values: Vec<_> = values
                            .iter()
                            .take(5)
                            .map(|row| {
                                let item = row
                                    .iter()
                                    .map(|expr| expr.to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");
                                format!("({item})")
                            })
                            .collect();
                        let elipse = if values.len() > 5 { "..." } else { "" };
                        write!(f, "Values: {}{}", str_values.join(", "), elipse)
                    }
                    LogicalPlan::TableScan(TableScan {
                        ref source,
                        ref table_name,
                        ref projection,
                        ref filters,
                        ref fetch,
                        ..
                    }) => {
                        // Resolve projected column indices back to names.
                        let projected_fields = match projection {
                            Some(indices) => {
                                let schema = source.schema();
                                let names: Vec<&str> = indices
                                    .iter()
                                    .map(|i| schema.field(*i).name().as_str())
                                    .collect();
                                format!(" projection=[{}]", names.join(", "))
                            }
                            _ => "".to_string(),
                        };
                        write!(f, "TableScan: {table_name}{projected_fields}")?;
                        if !filters.is_empty() {
                            // Bucket filters by how far the source can push
                            // them down, then print each non-empty bucket.
                            let mut full_filter = vec![];
                            let mut partial_filter = vec![];
                            let mut unsupported_filters = vec![];
                            let filters: Vec<&Expr> = filters.iter().collect();
                            if let Ok(results) =
                                source.supports_filters_pushdown(&filters)
                            {
                                filters.iter().zip(results.iter()).for_each(
                                    |(x, res)| match res {
                                        TableProviderFilterPushDown::Exact => {
                                            full_filter.push(x)
                                        }
                                        TableProviderFilterPushDown::Inexact => {
                                            partial_filter.push(x)
                                        }
                                        TableProviderFilterPushDown::Unsupported => {
                                            unsupported_filters.push(x)
                                        }
                                    },
                                );
                            }
                            if !full_filter.is_empty() {
                                write!(
                                    f,
                                    ", full_filters=[{}]",
                                    expr_vec_fmt!(full_filter)
                                )?;
                            };
                            if !partial_filter.is_empty() {
                                write!(
                                    f,
                                    ", partial_filters=[{}]",
                                    expr_vec_fmt!(partial_filter)
                                )?;
                            }
                            if !unsupported_filters.is_empty() {
                                write!(
                                    f,
                                    ", unsupported_filters=[{}]",
                                    expr_vec_fmt!(unsupported_filters)
                                )?;
                            }
                        }
                        if let Some(n) = fetch {
                            write!(f, ", fetch={n}")?;
                        }
                        Ok(())
                    }
                    LogicalPlan::Projection(Projection { ref expr, .. }) => {
                        write!(f, "Projection: ")?;
                        for (i, expr_item) in expr.iter().enumerate() {
                            if i > 0 {
                                write!(f, ", ")?;
                            }
                            write!(f, "{expr_item}")?;
                        }
                        Ok(())
                    }
                    LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => {
                        write!(f, "Dml: op=[{op}] table=[{table_name}]")
                    }
                    LogicalPlan::Copy(CopyTo {
                        input: _,
                        output_url,
                        file_format,
                        single_file_output,
                        copy_options,
                    }) => {
                        // SQL options render as `key value` pairs; writer
                        // options render as an empty string.
                        let op_str = match copy_options {
                            CopyOptions::SQLOptions(statement) => statement
                                .clone()
                                .into_inner()
                                .iter()
                                .map(|(k, v)| format!("{k} {v}"))
                                .collect::<Vec<String>>()
                                .join(", "),
                            CopyOptions::WriterOptions(_) => "".into(),
                        };
                        write!(f, "CopyTo: format={file_format} output_url={output_url} single_file_output={single_file_output} options: ({op_str})")
                    }
                    LogicalPlan::Ddl(ddl) => {
                        write!(f, "{}", ddl.display())
                    }
                    LogicalPlan::Filter(Filter {
                        predicate: ref expr,
                        ..
                    }) => write!(f, "Filter: {expr}"),
                    LogicalPlan::Window(Window {
                        ref window_expr, ..
                    }) => {
                        write!(
                            f,
                            "WindowAggr: windowExpr=[[{}]]",
                            expr_vec_fmt!(window_expr)
                        )
                    }
                    LogicalPlan::Aggregate(Aggregate {
                        ref group_expr,
                        ref aggr_expr,
                        ..
                    }) => write!(
                        f,
                        "Aggregate: groupBy=[[{}]], aggr=[[{}]]",
                        expr_vec_fmt!(group_expr),
                        expr_vec_fmt!(aggr_expr)
                    ),
                    LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
                        write!(f, "Sort: ")?;
                        for (i, expr_item) in expr.iter().enumerate() {
                            if i > 0 {
                                write!(f, ", ")?;
                            }
                            write!(f, "{expr_item}")?;
                        }
                        if let Some(a) = fetch {
                            write!(f, ", fetch={a}")?;
                        }
                        Ok(())
                    }
                    LogicalPlan::Join(Join {
                        on: ref keys,
                        filter,
                        join_constraint,
                        join_type,
                        ..
                    }) => {
                        let join_expr: Vec<String> =
                            keys.iter().map(|(l, r)| format!("{l} = {r}")).collect();
                        let filter_expr = filter
                            .as_ref()
                            .map(|expr| format!(" Filter: {expr}"))
                            .unwrap_or_else(|| "".to_string());
                        match join_constraint {
                            JoinConstraint::On => {
                                write!(
                                    f,
                                    "{} Join: {}{}",
                                    join_type,
                                    join_expr.join(", "),
                                    filter_expr
                                )
                            }
                            JoinConstraint::Using => {
                                write!(
                                    f,
                                    "{} Join: Using {}{}",
                                    join_type,
                                    join_expr.join(", "),
                                    filter_expr,
                                )
                            }
                        }
                    }
                    LogicalPlan::CrossJoin(_) => {
                        write!(f, "CrossJoin:")
                    }
                    LogicalPlan::Repartition(Repartition {
                        partitioning_scheme,
                        ..
                    }) => match partitioning_scheme {
                        Partitioning::RoundRobinBatch(n) => {
                            write!(f, "Repartition: RoundRobinBatch partition_count={n}")
                        }
                        Partitioning::Hash(expr, n) => {
                            let hash_expr: Vec<String> =
                                expr.iter().map(|e| format!("{e}")).collect();
                            write!(
                                f,
                                "Repartition: Hash({}) partition_count={}",
                                hash_expr.join(", "),
                                n
                            )
                        }
                        Partitioning::DistributeBy(expr) => {
                            let dist_by_expr: Vec<String> =
                                expr.iter().map(|e| format!("{e}")).collect();
                            write!(
                                f,
                                "Repartition: DistributeBy({})",
                                dist_by_expr.join(", "),
                            )
                        }
                    },
                    LogicalPlan::Limit(Limit {
                        ref skip,
                        ref fetch,
                        ..
                    }) => {
                        write!(
                            f,
                            "Limit: skip={}, fetch={}",
                            skip,
                            fetch.map_or_else(|| "None".to_string(), |x| x.to_string())
                        )
                    }
                    LogicalPlan::Subquery(Subquery { .. }) => {
                        write!(f, "Subquery:")
                    }
                    LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => {
                        write!(f, "SubqueryAlias: {alias}")
                    }
                    LogicalPlan::Statement(statement) => {
                        write!(f, "{}", statement.display())
                    }
                    LogicalPlan::Distinct(Distinct { .. }) => {
                        write!(f, "Distinct:")
                    }
                    LogicalPlan::Explain { .. } => write!(f, "Explain"),
                    LogicalPlan::Analyze { .. } => write!(f, "Analyze"),
                    LogicalPlan::Union(_) => write!(f, "Union"),
                    LogicalPlan::Extension(e) => e.node.fmt_for_explain(f),
                    LogicalPlan::Prepare(Prepare {
                        name, data_types, ..
                    }) => {
                        write!(f, "Prepare: {name:?} {data_types:?} ")
                    }
                    LogicalPlan::DescribeTable(DescribeTable { .. }) => {
                        write!(f, "DescribeTable")
                    }
                    LogicalPlan::Unnest(Unnest { column, .. }) => {
                        write!(f, "Unnest: {column}")
                    }
                }
            }
        }
        Wrapper(self)
    }
}
impl Debug for LogicalPlan {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
self.display_indent().fmt(f)
}
}
impl ToStringifiedPlan for LogicalPlan {
    /// Renders this plan with indented formatting and pairs it with
    /// `plan_type` in a [`StringifiedPlan`].
    fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan {
        StringifiedPlan::new(plan_type, self.display_indent().to_string())
    }
}
/// A relation with no inputs: produces either zero rows or a single empty row.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct EmptyRelation {
    /// Whether to produce one (empty) output row rather than none.
    pub produce_one_row: bool,
    /// The schema description of the output.
    pub schema: DFSchemaRef,
}
/// A constant table of literal expression rows (e.g. SQL `VALUES (1, 2), (3, 4)`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Values {
    /// The schema of the produced rows.
    pub schema: DFSchemaRef,
    /// The rows; one inner `Vec<Expr>` per output row.
    pub values: Vec<Vec<Expr>>,
}
/// Evaluates an arbitrary list of expressions on its input (SQL `SELECT` list).
#[derive(Clone, PartialEq, Eq, Hash)]
// `#[non_exhaustive]`: construct via `try_new`/`try_new_with_schema` so the
// expression/schema invariants are enforced.
#[non_exhaustive]
pub struct Projection {
    /// The list of expressions to evaluate, one per output field.
    pub expr: Vec<Expr>,
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// The schema of the projection output.
    pub schema: DFSchemaRef,
}
impl Projection {
    /// Create a new `Projection`, deriving the output schema from `expr`
    /// and carrying over the input schema's metadata.
    pub fn try_new(expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<Self> {
        let schema = Arc::new(DFSchema::new_with_metadata(
            exprlist_to_fields(&expr, &input)?,
            input.schema().metadata().clone(),
        )?);
        Self::try_new_with_schema(expr, input, schema)
    }

    /// Create a new `Projection` using the provided output `schema`.
    ///
    /// Errors if the number of expressions does not match the number of
    /// schema fields. The schema's functional dependencies are recomputed
    /// by projecting the input's dependencies.
    pub fn try_new_with_schema(
        expr: Vec<Expr>,
        input: Arc<LogicalPlan>,
        schema: DFSchemaRef,
    ) -> Result<Self> {
        if expr.len() != schema.fields().len() {
            return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len());
        }
        // Propagate the input's functional dependencies through the projection.
        let id_key_groups = calc_func_dependencies_for_project(&expr, &input)?;
        let schema = schema.as_ref().clone();
        let schema = Arc::new(schema.with_functional_dependencies(id_key_groups));
        Ok(Self {
            expr,
            input,
            schema,
        })
    }

    /// Create a projection that selects every column of `schema` verbatim.
    ///
    /// Performs no validation; `schema` is assumed to correspond to `input`.
    pub fn new_from_schema(input: Arc<LogicalPlan>, schema: DFSchemaRef) -> Self {
        let expr: Vec<Expr> = schema
            .fields()
            .iter()
            .map(|field| field.qualified_column())
            .map(Expr::Column)
            .collect();
        Self {
            expr,
            input,
            schema,
        }
    }
}
/// Aliases its input relation under a new (possibly qualified) name.
#[derive(Clone, PartialEq, Eq, Hash)]
// `#[non_exhaustive]`: construct via `try_new` so the schema is requalified.
#[non_exhaustive]
pub struct SubqueryAlias {
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// The alias the input is exposed under.
    pub alias: OwnedTableReference,
    /// The input schema, re-qualified with `alias`.
    pub schema: DFSchemaRef,
}
impl SubqueryAlias {
    /// Wrap `plan` under `alias`, re-qualifying its schema with the alias
    /// while preserving the plan's functional dependencies.
    pub fn try_new(
        plan: LogicalPlan,
        alias: impl Into<OwnedTableReference>,
    ) -> Result<Self> {
        let alias = alias.into();
        // Drop the old qualifiers, then re-qualify every field with `alias`.
        let schema: Schema = plan.schema().as_ref().clone().into();
        let func_dependencies = plan.schema().functional_dependencies().clone();
        let schema = DFSchemaRef::new(
            DFSchema::try_from_qualified_schema(&alias, &schema)?
                .with_functional_dependencies(func_dependencies),
        );
        Ok(SubqueryAlias {
            input: Arc::new(plan),
            alias,
            schema,
        })
    }
}
/// Filters rows from its input for which the predicate evaluates to true
/// (SQL `WHERE`/`HAVING`). The schema is unchanged from the input.
#[derive(Clone, PartialEq, Eq, Hash)]
// `#[non_exhaustive]`: construct via `try_new` so the predicate is validated.
#[non_exhaustive]
pub struct Filter {
    /// The boolean predicate expression.
    pub predicate: Expr,
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
}
impl Filter {
    /// Create a new filter operator.
    ///
    /// Errors if the predicate resolves to a non-boolean type, or if the
    /// predicate is an aliased expression.
    pub fn try_new(predicate: Expr, input: Arc<LogicalPlan>) -> Result<Self> {
        // Best-effort type check: only reject when the predicate's type can
        // actually be resolved against the input schema; if `get_type` fails
        // the check is skipped rather than propagated.
        if let Ok(predicate_type) = predicate.get_type(input.schema()) {
            if predicate_type != DataType::Boolean {
                return plan_err!(
                    "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}"
                );
            }
        }
        // An aliased predicate is rejected outright.
        if let Expr::Alias(Alias { expr, name, .. }) = predicate {
            return plan_err!(
                "Attempted to create Filter predicate with \
                expression `{expr}` aliased as '{name}'. Filter predicates should not be \
                aliased."
            );
        }
        Ok(Self { predicate, input })
    }
}
/// Evaluates window functions over its input, appending one output field
/// per window expression after the input's fields.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Window {
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// The window function expressions.
    pub window_expr: Vec<Expr>,
    /// The schema: input fields followed by the window expression fields.
    pub schema: DFSchemaRef,
}
impl Window {
    /// Create a new window operator, extending the input schema with one
    /// field per window expression.
    pub fn try_new(window_expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<Self> {
        // Output fields = input fields + window expression fields.
        let mut window_fields: Vec<DFField> = input.schema().fields().clone();
        window_fields
            .extend_from_slice(&exprlist_to_fields(window_expr.iter(), input.as_ref())?);
        let metadata = input.schema().metadata().clone();
        // Extend the input's functional dependencies so their target indices
        // cover the widened schema.
        let mut window_func_dependencies =
            input.schema().functional_dependencies().clone();
        window_func_dependencies.extend_target_indices(window_fields.len());
        Ok(Window {
            input,
            window_expr,
            schema: Arc::new(
                DFSchema::new_with_metadata(window_fields, metadata)?
                    .with_functional_dependencies(window_func_dependencies),
            ),
        })
    }
}
/// Produces rows from a table source, with optional projection, pushed-down
/// filters, and row limit.
#[derive(Clone)]
pub struct TableScan {
    /// The name of the table (possibly qualified).
    pub table_name: OwnedTableReference,
    /// The source of the table data.
    pub source: Arc<dyn TableSource>,
    /// Optional indices of the columns to read, into the source schema.
    pub projection: Option<Vec<usize>>,
    /// The schema after the projection has been applied.
    pub projected_schema: DFSchemaRef,
    /// Filter expressions pushed down to the scan.
    pub filters: Vec<Expr>,
    /// Optional maximum number of rows to read.
    pub fetch: Option<usize>,
}
// Manual impl: `source` is an `Arc<dyn TableSource>` and is deliberately
// excluded from the comparison; two scans are equal when all the remaining
// fields match.
impl PartialEq for TableScan {
    fn eq(&self, other: &Self) -> bool {
        self.table_name == other.table_name
            && self.projection == other.projection
            && self.projected_schema == other.projected_schema
            && self.filters == other.filters
            && self.fetch == other.fetch
    }
}
// Marker impl: the manual `PartialEq` compares concrete fields only, so the
// relation is a full equivalence and `Eq` is sound.
impl Eq for TableScan {}
// Hashes exactly the fields the manual `PartialEq` compares (again skipping
// `source`), keeping `Hash` consistent with `Eq`.
impl Hash for TableScan {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.table_name.hash(state);
        self.projection.hash(state);
        self.projected_schema.hash(state);
        self.filters.hash(state);
        self.fetch.hash(state);
    }
}
/// The Cartesian product of two inputs (no join keys).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct CrossJoin {
    /// The left input.
    pub left: Arc<LogicalPlan>,
    /// The right input.
    pub right: Arc<LogicalPlan>,
    /// The combined output schema.
    pub schema: DFSchemaRef,
}
/// Redistributes its input's rows according to a [`Partitioning`] scheme;
/// the schema is unchanged.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Repartition {
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// How the rows are distributed across partitions.
    pub partitioning_scheme: Partitioning,
}
/// Concatenates the rows of multiple inputs sharing a common schema
/// (SQL `UNION ALL`-style combination).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Union {
    /// The plans whose outputs are combined.
    pub inputs: Vec<Arc<LogicalPlan>>,
    /// The common output schema.
    pub schema: DFSchemaRef,
}
/// A prepared statement: a named plan with typed placeholder parameters.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Prepare {
    /// The name under which the statement is prepared.
    pub name: String,
    /// The declared data types of the statement's parameters.
    pub data_types: Vec<DataType>,
    /// The plan of the statement itself.
    pub input: Arc<LogicalPlan>,
}
/// `DESCRIBE <table>`: reports metadata about a table's columns.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct DescribeTable {
    /// The schema of the table being described.
    pub schema: Arc<Schema>,
    /// The schema of the rows the DESCRIBE itself produces.
    pub output_schema: DFSchemaRef,
}
/// `EXPLAIN`: produces textual renderings of a plan instead of executing it.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Explain {
    /// Whether to include detailed (verbose) output.
    pub verbose: bool,
    /// The plan being explained.
    pub plan: Arc<LogicalPlan>,
    /// The already-rendered representations of the plan.
    pub stringified_plans: Vec<StringifiedPlan>,
    /// The schema of the EXPLAIN output rows.
    pub schema: DFSchemaRef,
    /// Whether logical optimization of `plan` completed without error.
    pub logical_optimization_succeeded: bool,
}
/// `EXPLAIN ANALYZE`: wraps a plan so that execution statistics are reported.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Analyze {
    /// Whether to include detailed (verbose) output.
    pub verbose: bool,
    /// The plan being analyzed.
    pub input: Arc<LogicalPlan>,
    /// The schema of the ANALYZE output rows.
    pub schema: DFSchemaRef,
}
/// A user-defined logical plan node, dispatched through the
/// [`UserDefinedLogicalNode`] trait object.
// Hash is derived while PartialEq is written by hand (both ultimately
// delegate to the node's dynamic impls), hence the clippy allowance.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Eq, Hash)]
pub struct Extension {
    /// The extension node itself.
    pub node: Arc<dyn UserDefinedLogicalNode>,
}
// Manual impl: equality is delegated to the trait object's own comparison
// (derive is impossible for `Arc<dyn ...>` fields).
impl PartialEq for Extension {
    fn eq(&self, other: &Self) -> bool {
        self.node.eq(&other.node)
    }
}
/// Skips a number of input rows and optionally caps the rows returned
/// (SQL `OFFSET` / `LIMIT`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Limit {
    /// Number of rows to skip before emitting output.
    pub skip: usize,
    /// Maximum number of rows to return; `None` means no limit.
    pub fetch: Option<usize>,
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
}
/// Removes duplicate rows from its input (SQL `SELECT DISTINCT`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Distinct {
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
}
/// Groups its input by `group_expr` and evaluates `aggr_expr` per group
/// (SQL `GROUP BY` with aggregate functions).
#[derive(Clone, PartialEq, Eq, Hash)]
// `#[non_exhaustive]`: construct via `try_new`/`try_new_with_schema` so the
// schema invariants are enforced.
#[non_exhaustive]
pub struct Aggregate {
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// The grouping expressions (may contain grouping sets).
    pub group_expr: Vec<Expr>,
    /// The aggregate expressions.
    pub aggr_expr: Vec<Expr>,
    /// The schema: grouping fields followed by aggregate fields.
    pub schema: DFSchemaRef,
}
impl Aggregate {
    /// Create a new aggregate operator, deriving the output schema from the
    /// grouping and aggregate expressions.
    pub fn try_new(
        input: Arc<LogicalPlan>,
        group_expr: Vec<Expr>,
        aggr_expr: Vec<Expr>,
    ) -> Result<Self> {
        // Expand grouping-set style expressions into their enumerated form
        // before deriving the schema.
        let group_expr = enumerate_grouping_sets(group_expr)?;
        let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
        // Output fields: grouping expressions first, then aggregates.
        let all_expr = grouping_expr.iter().chain(aggr_expr.iter());
        let schema = DFSchema::new_with_metadata(
            exprlist_to_fields(all_expr, &input)?,
            input.schema().metadata().clone(),
        )?;
        Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
    }

    /// Create a new aggregate operator using the provided output `schema`.
    ///
    /// Errors if no grouping or aggregate expression is given, or if the
    /// schema's field count does not equal the grouping expression count
    /// plus the aggregate expression count.
    pub fn try_new_with_schema(
        input: Arc<LogicalPlan>,
        group_expr: Vec<Expr>,
        aggr_expr: Vec<Expr>,
        schema: DFSchemaRef,
    ) -> Result<Self> {
        if group_expr.is_empty() && aggr_expr.is_empty() {
            return plan_err!(
                "Aggregate requires at least one grouping or aggregate expression"
            );
        }
        let group_expr_count = grouping_set_expr_count(&group_expr)?;
        if schema.fields().len() != group_expr_count + aggr_expr.len() {
            return plan_err!(
                "Aggregate schema has wrong number of fields. Expected {} got {}",
                group_expr_count + aggr_expr.len(),
                schema.fields().len()
            );
        }
        // Derive the output's functional dependencies from the grouping
        // expressions and attach them to the schema.
        let aggregate_func_dependencies =
            calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?;
        let new_schema = schema.as_ref().clone();
        let schema = Arc::new(
            new_schema.with_functional_dependencies(aggregate_func_dependencies),
        );
        Ok(Self {
            input,
            group_expr,
            aggr_expr,
            schema,
        })
    }
}
/// Returns `true` if any expression in `group_expr` is a grouping set.
fn contains_grouping_set(group_expr: &[Expr]) -> bool {
    for expr in group_expr {
        if let Expr::GroupingSet(_) = expr {
            return true;
        }
    }
    false
}
/// Computes the functional dependencies of an aggregate's output schema.
///
/// When the grouping expressions contain a grouping set, no dependencies
/// can be derived and the empty set is returned.
fn calc_func_dependencies_for_aggregate(
    group_expr: &[Expr],
    input: &LogicalPlan,
    aggr_schema: &DFSchema,
) -> Result<FunctionalDependencies> {
    // Grouping sets defeat dependency analysis entirely.
    if contains_grouping_set(group_expr) {
        return Ok(FunctionalDependencies::empty());
    }
    let group_by_expr_names = group_expr
        .iter()
        .map(|item| item.display_name())
        .collect::<Result<Vec<_>>>()?;
    Ok(aggregate_functional_dependencies(
        input.schema(),
        &group_by_expr_names,
        aggr_schema,
    ))
}
/// Computes the functional dependencies of a projection's output schema by
/// mapping each projected expression back to the index of the input field
/// it refers to, then projecting the input's dependencies onto those
/// indices.
///
/// Expressions that do not match any input field by qualified name are
/// simply dropped from the mapping (`filter_map`).
fn calc_func_dependencies_for_project(
    exprs: &[Expr],
    input: &LogicalPlan,
) -> Result<FunctionalDependencies> {
    let input_fields = input.schema().fields();
    // Collect indices of input fields that appear in the projection.
    // Idiom fix: use `to_string()` instead of `format!("{}", ..)`.
    let proj_indices = exprs
        .iter()
        .filter_map(|expr| {
            // An aliased expression is matched on the underlying expression,
            // not on its alias name.
            let expr_name = match expr {
                Expr::Alias(alias) => alias.expr.to_string(),
                _ => expr.to_string(),
            };
            input_fields
                .iter()
                .position(|item| item.qualified_name() == expr_name)
        })
        .collect::<Vec<_>>();
    Ok(input
        .schema()
        .functional_dependencies()
        .project_functional_dependencies(&proj_indices, exprs.len()))
}
/// Sorts its input by the given sort expressions.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Sort {
    /// The sort expressions.
    pub expr: Vec<Expr>,
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// Optional maximum number of rows to return after sorting.
    pub fetch: Option<usize>,
}
/// Joins two inputs on equijoin key pairs with an optional extra filter.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Join {
    /// The left input.
    pub left: Arc<LogicalPlan>,
    /// The right input.
    pub right: Arc<LogicalPlan>,
    /// Equijoin key pairs: `(left expression, right expression)`.
    pub on: Vec<(Expr, Expr)>,
    /// Optional non-equijoin filter applied to joined rows.
    pub filter: Option<Expr>,
    /// The kind of join (inner, left, right, ...).
    pub join_type: JoinType,
    /// Whether the join came from an `ON` or `USING` clause.
    pub join_constraint: JoinConstraint,
    /// The combined output schema.
    pub schema: DFSchemaRef,
    /// If true, `null` join keys compare as equal to each other.
    pub null_equals_null: bool,
}
impl Join {
    /// Rebuild a join from `original` with new `left`/`right` inputs and new
    /// equijoin key columns.
    ///
    /// `column_on` holds the left and right key columns, zipped pairwise into
    /// `on`. The filter, join type, constraint, and null-equality semantics
    /// are copied from `original`, which must be a `LogicalPlan::Join`.
    pub fn try_new_with_project_input(
        original: &LogicalPlan,
        left: Arc<LogicalPlan>,
        right: Arc<LogicalPlan>,
        column_on: (Vec<Column>, Vec<Column>),
    ) -> Result<Self> {
        let original_join = match original {
            LogicalPlan::Join(join) => join,
            _ => return plan_err!("Could not create join with project input"),
        };
        let on: Vec<(Expr, Expr)> = column_on
            .0
            .into_iter()
            .zip(column_on.1)
            .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
            .collect();
        // The schema is rebuilt from the (possibly new) input schemas.
        let join_schema =
            build_join_schema(left.schema(), right.schema(), &original_join.join_type)?;
        Ok(Join {
            left,
            right,
            on,
            filter: original_join.filter.clone(),
            join_type: original_join.join_type,
            join_constraint: original_join.join_constraint,
            schema: Arc::new(join_schema),
            null_equals_null: original_join.null_equals_null,
        })
    }
}
/// A subquery plan together with the outer-query columns it references.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Subquery {
    /// The plan of the subquery itself.
    pub subquery: Arc<LogicalPlan>,
    /// Expressions referencing columns of the outer query (correlation).
    pub outer_ref_columns: Vec<Expr>,
}
impl Subquery {
    /// Extract the `Subquery` from a scalar-subquery expression, looking
    /// through any wrapping casts; errors for other expression kinds.
    pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> {
        match plan {
            Expr::ScalarSubquery(it) => Ok(it),
            // Recurse through casts so `CAST(<subquery> AS ..)` also works.
            Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()),
            _ => plan_err!("Could not coerce into ScalarSubquery!"),
        }
    }

    /// Return a copy of this subquery with its plan replaced by `plan`,
    /// keeping the outer-reference columns.
    pub fn with_plan(&self, plan: Arc<LogicalPlan>) -> Subquery {
        Subquery {
            subquery: plan,
            outer_ref_columns: self.outer_ref_columns.clone(),
        }
    }
}
impl Debug for Subquery {
    /// Prints a fixed placeholder rather than the nested plan.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "<subquery>")
    }
}
/// How rows are redistributed by a [`Repartition`] node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Partitioning {
    /// Round-robin distribution of batches across the given partition count.
    RoundRobinBatch(usize),
    /// Hash-partition on the given expressions into the given partition count.
    Hash(Vec<Expr>, usize),
    /// Distribute by the given expressions (SQL `DISTRIBUTE BY`).
    DistributeBy(Vec<Expr>),
}
/// Unnests a column, producing one output row per element of its value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Unnest {
    /// The incoming logical plan.
    pub input: Arc<LogicalPlan>,
    /// The column to unnest.
    pub column: Column,
    /// The output schema.
    pub schema: DFSchemaRef,
    /// Options controlling the unnest behavior.
    pub options: UnnestOptions,
}
#[cfg(test)]
mod tests {
    //! Unit tests for plan display formatting, tree-node visitation
    //! (ordering, early stopping, error propagation), constructor
    //! validation, and placeholder replacement.
    use super::*;
    use crate::logical_plan::table_scan;
    use crate::{col, exists, in_subquery, lit};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeVisitor;
    use datafusion_common::{not_impl_err, DFSchema, TableReference};
    use std::collections::HashMap;

    /// Five-column employee schema used as a fixture by the display tests.
    fn employee_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("first_name", DataType::Utf8, false),
            Field::new("last_name", DataType::Utf8, false),
            Field::new("state", DataType::Utf8, false),
            Field::new("salary", DataType::Int32, false),
        ])
    }

    /// Builds `Projection <- Filter(state IN <subquery>) <- TableScan`.
    fn display_plan() -> Result<LogicalPlan> {
        let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
            .build()?;

        table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
            .filter(in_subquery(col("state"), Arc::new(plan1)))?
            .project(vec![col("id")])?
            .build()
    }

    #[test]
    fn test_display_indent() -> Result<()> {
        let plan = display_plan()?;

        let expected = "Projection: employee_csv.id\
        \n  Filter: employee_csv.state IN (<subquery>)\
        \n    Subquery:\
        \n      TableScan: employee_csv projection=[state]\
        \n  TableScan: employee_csv projection=[id, state]";

        assert_eq!(expected, format!("{}", plan.display_indent()));
        Ok(())
    }

    #[test]
    fn test_display_indent_schema() -> Result<()> {
        let plan = display_plan()?;

        let expected = "Projection: employee_csv.id [id:Int32]\
        \n  Filter: employee_csv.state IN (<subquery>) [id:Int32, state:Utf8]\
        \n    Subquery: [state:Utf8]\
        \n      TableScan: employee_csv projection=[state] [state:Utf8]\
        \n  TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8]";

        assert_eq!(expected, format!("{}", plan.display_indent_schema()));
        Ok(())
    }

    #[test]
    fn test_display_subquery_alias() -> Result<()> {
        let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
            .build()?;
        let plan1 = Arc::new(plan1);

        let plan =
            table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
                .project(vec![col("id"), exists(plan1).alias("exists")])?
                .build();

        let expected = "Projection: employee_csv.id, EXISTS (<subquery>) AS exists\
        \n  Subquery:\
        \n    TableScan: employee_csv projection=[state]\
        \n  TableScan: employee_csv projection=[id, state]";

        assert_eq!(expected, format!("{}", plan?.display_indent()));
        Ok(())
    }

    #[test]
    fn test_display_graphviz() -> Result<()> {
        let plan = display_plan()?;

        // Full expected output; node ids are assigned in visitation order.
        let expected_graphviz = r#"
// Begin DataFusion GraphViz Plan,
// display it online here: https://dreampuf.github.io/GraphvizOnline
digraph {
  subgraph cluster_1
  {
    graph[label="LogicalPlan"]
    2[shape=box label="Projection: employee_csv.id"]
    3[shape=box label="Filter: employee_csv.state IN (<subquery>)"]
    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]
    4[shape=box label="Subquery:"]
    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]
    5[shape=box label="TableScan: employee_csv projection=[state]"]
    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]
    6[shape=box label="TableScan: employee_csv projection=[id, state]"]
    3 -> 6 [arrowhead=none, arrowtail=normal, dir=back]
  }
  subgraph cluster_7
  {
    graph[label="Detailed LogicalPlan"]
    8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"]
    9[shape=box label="Filter: employee_csv.state IN (<subquery>)\nSchema: [id:Int32, state:Utf8]"]
    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]
    10[shape=box label="Subquery:\nSchema: [state:Utf8]"]
    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]
    11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"]
    10 -> 11 [arrowhead=none, arrowtail=normal, dir=back]
    12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"]
    9 -> 12 [arrowhead=none, arrowtail=normal, dir=back]
  }
}
// End DataFusion GraphViz Plan
"#;

        let graphviz = format!("{}", plan.display_graphviz());
        assert_eq!(expected_graphviz, graphviz);
        Ok(())
    }

    /// Visitor that records the order in which nodes are visited.
    #[derive(Debug, Default)]
    struct OkVisitor {
        strings: Vec<String>,
    }

    impl TreeNodeVisitor for OkVisitor {
        type N = LogicalPlan;

        fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
            let s = match plan {
                LogicalPlan::Projection { .. } => "pre_visit Projection",
                LogicalPlan::Filter { .. } => "pre_visit Filter",
                LogicalPlan::TableScan { .. } => "pre_visit TableScan",
                _ => {
                    return not_impl_err!("unknown plan type");
                }
            };
            self.strings.push(s.into());
            Ok(VisitRecursion::Continue)
        }

        fn post_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
            let s = match plan {
                LogicalPlan::Projection { .. } => "post_visit Projection",
                LogicalPlan::Filter { .. } => "post_visit Filter",
                LogicalPlan::TableScan { .. } => "post_visit TableScan",
                _ => {
                    return not_impl_err!("unknown plan type");
                }
            };
            self.strings.push(s.into());
            Ok(VisitRecursion::Continue)
        }
    }

    #[test]
    fn visit_order() {
        let mut visitor = OkVisitor::default();
        let plan = test_plan();
        let res = plan.visit(&mut visitor);
        assert!(res.is_ok());

        // pre_visit is top-down, post_visit bottom-up.
        assert_eq!(
            visitor.strings,
            vec![
                "pre_visit Projection",
                "pre_visit Filter",
                "pre_visit TableScan",
                "post_visit TableScan",
                "post_visit Filter",
                "post_visit Projection",
            ]
        );
    }

    /// Countdown helper: `dec` returns `true` once the counter reaches zero.
    #[derive(Debug, Default)]
    struct OptionalCounter {
        val: Option<usize>,
    }

    impl OptionalCounter {
        fn new(val: usize) -> Self {
            Self { val: Some(val) }
        }
        // Returns true if the counter is at zero; otherwise decrements.
        fn dec(&mut self) -> bool {
            if Some(0) == self.val {
                true
            } else {
                self.val = self.val.take().map(|i| i - 1);
                false
            }
        }
    }

    /// Visitor that requests `Stop` after a configured number of calls.
    #[derive(Debug, Default)]
    struct StoppingVisitor {
        inner: OkVisitor,
        /// When Some(N), return Stop from the Nth pre_visit call.
        return_false_from_pre_in: OptionalCounter,
        /// When Some(N), return Stop from the Nth post_visit call.
        return_false_from_post_in: OptionalCounter,
    }

    impl TreeNodeVisitor for StoppingVisitor {
        type N = LogicalPlan;

        fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
            if self.return_false_from_pre_in.dec() {
                return Ok(VisitRecursion::Stop);
            }
            self.inner.pre_visit(plan)?;
            Ok(VisitRecursion::Continue)
        }

        fn post_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
            if self.return_false_from_post_in.dec() {
                return Ok(VisitRecursion::Stop);
            }
            self.inner.post_visit(plan)
        }
    }

    #[test]
    fn early_stopping_pre_visit() {
        let mut visitor = StoppingVisitor {
            return_false_from_pre_in: OptionalCounter::new(2),
            ..Default::default()
        };
        let plan = test_plan();
        let res = plan.visit(&mut visitor);
        assert!(res.is_ok());

        assert_eq!(
            visitor.inner.strings,
            vec!["pre_visit Projection", "pre_visit Filter"]
        );
    }

    #[test]
    fn early_stopping_post_visit() {
        let mut visitor = StoppingVisitor {
            return_false_from_post_in: OptionalCounter::new(1),
            ..Default::default()
        };
        let plan = test_plan();
        let res = plan.visit(&mut visitor);
        assert!(res.is_ok());

        assert_eq!(
            visitor.inner.strings,
            vec![
                "pre_visit Projection",
                "pre_visit Filter",
                "pre_visit TableScan",
                "post_visit TableScan",
            ]
        );
    }

    /// Visitor that returns an error after a configured number of calls.
    #[derive(Debug, Default)]
    struct ErrorVisitor {
        inner: OkVisitor,
        /// When Some(N), return an error from the Nth pre_visit call.
        return_error_from_pre_in: OptionalCounter,
        /// When Some(N), return an error from the Nth post_visit call.
        return_error_from_post_in: OptionalCounter,
    }

    impl TreeNodeVisitor for ErrorVisitor {
        type N = LogicalPlan;

        fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
            if self.return_error_from_pre_in.dec() {
                return not_impl_err!("Error in pre_visit");
            }
            self.inner.pre_visit(plan)
        }

        fn post_visit(&mut self, plan: &LogicalPlan) -> Result<VisitRecursion> {
            if self.return_error_from_post_in.dec() {
                return not_impl_err!("Error in post_visit");
            }
            self.inner.post_visit(plan)
        }
    }

    #[test]
    fn error_pre_visit() {
        let mut visitor = ErrorVisitor {
            return_error_from_pre_in: OptionalCounter::new(2),
            ..Default::default()
        };
        let plan = test_plan();
        let res = plan.visit(&mut visitor).unwrap_err();
        assert_eq!(
            "This feature is not implemented: Error in pre_visit",
            res.strip_backtrace()
        );
        assert_eq!(
            visitor.inner.strings,
            vec!["pre_visit Projection", "pre_visit Filter"]
        );
    }

    #[test]
    fn error_post_visit() {
        let mut visitor = ErrorVisitor {
            return_error_from_post_in: OptionalCounter::new(1),
            ..Default::default()
        };
        let plan = test_plan();
        let res = plan.visit(&mut visitor).unwrap_err();
        assert_eq!(
            "This feature is not implemented: Error in post_visit",
            res.strip_backtrace()
        );
        assert_eq!(
            visitor.inner.strings,
            vec![
                "pre_visit Projection",
                "pre_visit Filter",
                "pre_visit TableScan",
                "post_visit TableScan",
            ]
        );
    }

    #[test]
    fn projection_expr_schema_mismatch() -> Result<()> {
        // One expression against an empty schema must be rejected.
        let empty_schema = Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new())?);
        let p = Projection::try_new_with_schema(
            vec![col("a")],
            Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
                produce_one_row: false,
                schema: empty_schema.clone(),
            })),
            empty_schema,
        );
        assert_eq!(p.err().unwrap().strip_backtrace(), "Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)");
        Ok(())
    }

    /// Small Projection <- Filter <- TableScan plan used by the visitor tests.
    fn test_plan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("state", DataType::Utf8, false),
        ]);

        table_scan(TableReference::none(), &schema, Some(vec![0, 1]))
            .unwrap()
            .filter(col("state").eq(lit("CO")))
            .unwrap()
            .project(vec![col("id")])
            .unwrap()
            .build()
            .unwrap()
    }

    /// Extension node stub: only `schema()` is usable; everything else
    /// panics or is unimplemented.
    #[derive(Debug)]
    struct NoChildExtension {
        empty_schema: DFSchemaRef,
    }

    impl NoChildExtension {
        fn empty() -> Self {
            Self {
                empty_schema: Arc::new(DFSchema::empty()),
            }
        }
    }

    impl UserDefinedLogicalNode for NoChildExtension {
        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn name(&self) -> &str {
            unimplemented!()
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            panic!("Should not be called")
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            unimplemented!()
        }

        fn fmt_for_explain(&self, _: &mut fmt::Formatter) -> fmt::Result {
            unimplemented!()
        }

        fn from_template(
            &self,
            _: &[Expr],
            _: &[LogicalPlan],
        ) -> Arc<dyn UserDefinedLogicalNode> {
            unimplemented!()
        }

        fn dyn_hash(&self, _: &mut dyn Hasher) {
            unimplemented!()
        }

        fn dyn_eq(&self, _: &dyn UserDefinedLogicalNode) -> bool {
            unimplemented!()
        }
    }

    #[test]
    #[allow(deprecated)]
    fn test_extension_all_schemas() {
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(NoChildExtension::empty()),
        });

        let schemas = plan.all_schemas();
        assert_eq!(1, schemas.len());
        assert_eq!(0, schemas[0].fields().len());
    }

    #[test]
    fn test_replace_invalid_placeholder() {
        // Placeholder with empty name is not allowed.
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let plan = table_scan(TableReference::none(), &schema, None)
            .unwrap()
            .filter(col("id").eq(Expr::Placeholder(Placeholder::new(
                "".into(),
                Some(DataType::Int32),
            ))))
            .unwrap()
            .build()
            .unwrap();
        plan.replace_params_with_values(&[42i32.into()])
            .expect_err("unexpectedly succeeded to replace an invalid placeholder");

        // Placeholder `$0` (zero-indexed) is likewise rejected.
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let plan = table_scan(TableReference::none(), &schema, None)
            .unwrap()
            .filter(col("id").eq(Expr::Placeholder(Placeholder::new(
                "$0".into(),
                Some(DataType::Int32),
            ))))
            .unwrap()
            .build()
            .unwrap();
        plan.replace_params_with_values(&[42i32.into()])
            .expect_err("unexpectedly succeeded to replace an invalid placeholder");
    }
}