use std::any::Any;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::iter::once;
use std::sync::Arc;
use crate::dml::CopyTo;
use crate::expr::{Alias, PlannedReplaceSelectItem, Sort as SortExpr};
use crate::expr_rewriter::{
coerce_plan_expr_for_schema, normalize_col,
normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts,
rewrite_sort_cols_by_aggs,
};
use crate::logical_plan::{
Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
use crate::select_expr::SelectExpr;
use crate::utils::{
can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard,
expand_wildcard, expr_to_columns, find_valid_equijoin_key_pair,
group_window_expr_by_sort_keys,
};
use crate::{
DmlStatement, ExplainOption, Expr, ExprSchemable, Operator, RecursiveQuery,
Statement, TableProviderFilterPushDown, TableSource, WriteOp, and, binary_expr, lit,
};
use super::dml::InsertOp;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::metadata::FieldMetadata;
use datafusion_common::{
Column, Constraints, DFSchema, DFSchemaRef, NullEquality, Result, ScalarValue,
TableReference, ToDFSchema, UnnestOptions, exec_err,
get_target_functional_dependencies, internal_datafusion_err, plan_datafusion_err,
plan_err,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
use indexmap::IndexSet;
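/// Default table name used for a [`TableScan`] built without an explicit name.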
pub const UNNAMED_TABLE: &str = "?table?";
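/// Options that alter how [`LogicalPlanBuilder`] constructs plans.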
#[derive(Default, Debug, Clone)]
pub struct LogicalPlanBuilderOptions {
add_implicit_group_by_exprs: bool,
}
impl LogicalPlanBuilderOptions {
pub fn new() -> Self {
Default::default()
}
pub fn with_add_implicit_group_by_exprs(mut self, add: bool) -> Self {
self.add_implicit_group_by_exprs = add;
self
}
}
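/// Builder for [`LogicalPlan`]s.
///
/// Each method consumes the builder and returns a new builder wrapping the
/// extended plan; call [`Self::build`] to obtain the finished [`LogicalPlan`].
/// The plan is kept behind an [`Arc`], so cloning a builder does not deep-copy
/// the plan tree.
///
/// # Example
///
/// An illustrative sketch (assumes an `employee_schema()` helper like the one
/// defined in the tests at the bottom of this file):
///
/// ```ignore
/// let plan = table_scan(Some("employee_csv"), &employee_schema(), None)?
///     .filter(col("state").eq(lit("CO")))?
///     .project(vec![col("id")])?
///     .build()?;
/// ```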
#[derive(Debug, Clone)]
pub struct LogicalPlanBuilder {
plan: Arc<LogicalPlan>,
options: LogicalPlanBuilderOptions,
}
impl LogicalPlanBuilder {
pub fn new(plan: LogicalPlan) -> Self {
Self {
plan: Arc::new(plan),
options: LogicalPlanBuilderOptions::default(),
}
}
pub fn new_from_arc(plan: Arc<LogicalPlan>) -> Self {
Self {
plan,
options: LogicalPlanBuilderOptions::default(),
}
}
pub fn with_options(mut self, options: LogicalPlanBuilderOptions) -> Self {
self.options = options;
self
}
pub fn schema(&self) -> &DFSchemaRef {
self.plan.schema()
}
pub fn plan(&self) -> &LogicalPlan {
&self.plan
}
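    /// Create an empty relation: zero rows, or a single zero-column row when
    /// `produce_one_row` is true (as used to plan `SELECT` without `FROM`).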
pub fn empty(produce_one_row: bool) -> Self {
Self::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row,
schema: DFSchemaRef::new(DFSchema::empty()),
}))
}
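    /// Convert this plan into the static (non-recursive) term of a recursive
    /// query named `name`, pairing it with `recursive_term`. Both terms must
    /// produce the same number of columns; the recursive term is coerced to
    /// the static term's schema.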
pub fn to_recursive_query(
self,
name: String,
recursive_term: LogicalPlan,
is_distinct: bool,
) -> Result<Self> {
let static_fields_len = self.plan.schema().fields().len();
let recursive_fields_len = recursive_term.schema().fields().len();
if static_fields_len != recursive_fields_len {
return plan_err!(
"Non-recursive term and recursive term must have the same number of columns ({} != {})",
static_fields_len,
recursive_fields_len
);
}
let coerced_recursive_term =
coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?;
Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery {
name,
static_term: self.plan,
recursive_term: Arc::new(coerced_recursive_term),
is_distinct,
})))
}
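    /// Create a values-list relation (SQL `VALUES ...`), inferring each
    /// column's type from the row expressions. Columns are named `column1`,
    /// `column2`, and so on; every row must have the same non-zero number of
    /// columns.
    ///
    /// An illustrative sketch:
    ///
    /// ```ignore
    /// // VALUES (1, 'a'), (2, 'b')
    /// let plan = LogicalPlanBuilder::values(vec![
    ///     vec![lit(1), lit("a")],
    ///     vec![lit(2), lit("b")],
    /// ])?
    /// .build()?;
    /// ```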
pub fn values(values: Vec<Vec<Expr>>) -> Result<Self> {
if values.is_empty() {
return plan_err!("Values list cannot be empty");
}
let n_cols = values[0].len();
if n_cols == 0 {
return plan_err!("Values list cannot be zero length");
}
for (i, row) in values.iter().enumerate() {
if row.len() != n_cols {
return plan_err!(
"Inconsistent data length across values list: got {} values in row {} but expected {}",
row.len(),
i,
n_cols
);
}
}
Self::infer_data(values)
}
pub fn values_with_schema(
values: Vec<Vec<Expr>>,
schema: &DFSchemaRef,
) -> Result<Self> {
if values.is_empty() {
return plan_err!("Values list cannot be empty");
}
let n_cols = schema.fields().len();
if n_cols == 0 {
return plan_err!("Values list cannot be zero length");
}
for (i, row) in values.iter().enumerate() {
if row.len() != n_cols {
return plan_err!(
"Inconsistent data length across values list: got {} values in row {} but expected {}",
row.len(),
i,
n_cols
);
}
}
Self::infer_values_from_schema(values, schema)
}
fn infer_values_from_schema(
values: Vec<Vec<Expr>>,
schema: &DFSchema,
) -> Result<Self> {
let n_cols = values[0].len();
let mut fields = ValuesFields::new();
for j in 0..n_cols {
let field_type = schema.field(j).data_type();
let field_nullable = schema.field(j).is_nullable();
for row in values.iter() {
let value = &row[j];
let data_type = value.get_type(schema)?;
if !data_type.equals_datatype(field_type)
&& !can_cast_types(&data_type, field_type)
{
                    return exec_err!(
                        "type mismatch: value of type {} cannot be cast to field type {}",
                        data_type,
                        field_type
                    );
}
}
fields.push(field_type.to_owned(), field_nullable);
}
Self::infer_inner(values, fields, schema)
}
fn infer_data(values: Vec<Vec<Expr>>) -> Result<Self> {
let n_cols = values[0].len();
let schema = DFSchema::empty();
let mut fields = ValuesFields::new();
for j in 0..n_cols {
let mut common_type: Option<DataType> = None;
let mut common_metadata: Option<FieldMetadata> = None;
for (i, row) in values.iter().enumerate() {
let value = &row[j];
let metadata = value.metadata(&schema)?;
if let Some(ref cm) = common_metadata {
if &metadata != cm {
return plan_err!(
"Inconsistent metadata across values list at row {i} column {j}. Was {:?} but found {:?}",
cm,
metadata
);
}
} else {
common_metadata = Some(metadata.clone());
}
let data_type = value.get_type(&schema)?;
if data_type == DataType::Null {
continue;
}
if let Some(prev_type) = common_type {
let data_types = vec![prev_type.clone(), data_type.clone()];
let Some(new_type) = type_union_resolution(&data_types) else {
return plan_err!(
"Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}"
);
};
common_type = Some(new_type);
} else {
common_type = Some(data_type);
}
}
fields.push_with_metadata(
common_type.unwrap_or(DataType::Null),
true,
common_metadata,
);
}
Self::infer_inner(values, fields, &schema)
}
fn infer_inner(
mut values: Vec<Vec<Expr>>,
fields: ValuesFields,
schema: &DFSchema,
) -> Result<Self> {
let fields = fields.into_fields();
for row in &mut values {
for (j, field_type) in fields.iter().map(|f| f.data_type()).enumerate() {
if let Expr::Literal(ScalarValue::Null, metadata) = &row[j] {
row[j] = Expr::Literal(
ScalarValue::try_from(field_type)?,
metadata.clone(),
);
} else {
row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?;
}
}
}
let dfschema = DFSchema::from_unqualified_fields(fields, HashMap::new())?;
let schema = DFSchemaRef::new(dfschema);
Ok(Self::new(LogicalPlan::Values(Values { schema, values })))
}
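    /// Convert a table source into a builder whose plan is a [`TableScan`],
    /// optionally reading only the columns at the given `projection` indices.
    ///
    /// An illustrative sketch (`table_source` is the helper defined later in
    /// this file):
    ///
    /// ```ignore
    /// let source = table_source(&employee_schema());
    /// let plan = LogicalPlanBuilder::scan("employee_csv", source, Some(vec![0, 3]))?
    ///     .build()?;
    /// ```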
pub fn scan(
table_name: impl Into<TableReference>,
table_source: Arc<dyn TableSource>,
projection: Option<Vec<usize>>,
) -> Result<Self> {
Self::scan_with_filters(table_name, table_source, projection, vec![])
}
pub fn copy_to(
input: LogicalPlan,
output_url: String,
file_type: Arc<dyn FileType>,
options: HashMap<String, String>,
partition_by: Vec<String>,
) -> Result<Self> {
Ok(Self::new(LogicalPlan::Copy(CopyTo::new(
Arc::new(input),
output_url,
partition_by,
file_type,
options,
))))
}
pub fn insert_into(
input: LogicalPlan,
table_name: impl Into<TableReference>,
target: Arc<dyn TableSource>,
insert_op: InsertOp,
) -> Result<Self> {
Ok(Self::new(LogicalPlan::Dml(DmlStatement::new(
table_name.into(),
target,
WriteOp::Insert(insert_op),
Arc::new(input),
))))
}
pub fn scan_with_filters(
table_name: impl Into<TableReference>,
table_source: Arc<dyn TableSource>,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
) -> Result<Self> {
Self::scan_with_filters_inner(table_name, table_source, projection, filters, None)
}
pub fn scan_with_filters_fetch(
table_name: impl Into<TableReference>,
table_source: Arc<dyn TableSource>,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
fetch: Option<usize>,
) -> Result<Self> {
Self::scan_with_filters_inner(
table_name,
table_source,
projection,
filters,
fetch,
)
}
fn scan_with_filters_inner(
table_name: impl Into<TableReference>,
table_source: Arc<dyn TableSource>,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
fetch: Option<usize>,
) -> Result<Self> {
let table_scan =
TableScan::try_new(table_name, table_source, projection, filters, fetch)?;
if table_scan.filters.is_empty()
&& let Some(p) = table_scan.source.get_logical_plan()
{
let sub_plan = p.into_owned();
if let Some(proj) = table_scan.projection {
let projection_exprs = proj
.into_iter()
.map(|i| {
Expr::Column(Column::from(sub_plan.schema().qualified_field(i)))
})
.collect::<Vec<_>>();
return Self::new(sub_plan)
.project(projection_exprs)?
.alias(table_scan.table_name);
}
return Self::new(sub_plan).alias(table_scan.table_name);
}
Ok(Self::new(LogicalPlan::TableScan(table_scan)))
}
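    /// Apply window expressions to `input`, creating one [`Window`] node per
    /// group of expressions that share the same PARTITION BY / ORDER BY sort
    /// keys. Groups are ordered by their sort keys (ties broken by key
    /// length) so that adjacent windows can share sorts during physical
    /// planning.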
pub fn window_plan(
input: LogicalPlan,
window_exprs: impl IntoIterator<Item = Expr>,
) -> Result<LogicalPlan> {
let mut plan = input;
let mut groups = group_window_expr_by_sort_keys(window_exprs)?;
groups.sort_by(|(key_a, _), (key_b, _)| {
for ((first, _), (second, _)) in key_a.iter().zip(key_b.iter()) {
let key_ordering = compare_sort_expr(first, second, plan.schema());
match key_ordering {
Ordering::Less => {
return Ordering::Less;
}
Ordering::Greater => {
return Ordering::Greater;
}
Ordering::Equal => {}
}
}
key_b.len().cmp(&key_a.len())
});
for (_, exprs) in groups {
let window_exprs = exprs.into_iter().collect::<Vec<_>>();
plan = LogicalPlanBuilder::from(plan)
.window(window_exprs)?
.build()?;
}
Ok(plan)
}
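    /// Apply a projection, expanding any wildcards and validating that the
    /// resulting output names are unique.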
pub fn project(
self,
expr: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<Self> {
project(Arc::unwrap_or_clone(self.plan), expr).map(Self::new)
}
pub fn project_with_validation(
self,
expr: Vec<(impl Into<SelectExpr>, bool)>,
) -> Result<Self> {
project_with_validation(Arc::unwrap_or_clone(self.plan), expr).map(Self::new)
}
pub fn select(self, indices: impl IntoIterator<Item = usize>) -> Result<Self> {
let exprs: Vec<_> = indices
.into_iter()
.map(|x| Expr::Column(Column::from(self.plan.schema().qualified_field(x))))
.collect();
self.project(exprs)
}
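    /// Apply a filter, keeping only rows for which `expr` evaluates to
    /// `true`. Column references in `expr` are normalized against the
    /// current schema first.
    ///
    /// ```ignore
    /// // WHERE state = 'CO'
    /// let plan = builder.filter(col("state").eq(lit("CO")))?.build()?;
    /// ```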
pub fn filter(self, expr: impl Into<Expr>) -> Result<Self> {
let expr = normalize_col(expr.into(), &self.plan)?;
Filter::try_new(expr, self.plan)
.map(LogicalPlan::Filter)
.map(Self::new)
}
pub fn having(self, expr: impl Into<Expr>) -> Result<Self> {
let expr = normalize_col(expr.into(), &self.plan)?;
Filter::try_new(expr, self.plan)
.map(LogicalPlan::Filter)
.map(Self::from)
}
pub fn prepare(self, name: String, fields: Vec<FieldRef>) -> Result<Self> {
Ok(Self::new(LogicalPlan::Statement(Statement::Prepare(
Prepare {
name,
fields,
input: self.plan,
},
))))
}
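    /// Skip the first `skip` rows, then return at most `fetch` rows
    /// (`None` means no limit), i.e. SQL `OFFSET skip` / `LIMIT fetch`.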
pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<Self> {
let skip_expr = if skip == 0 {
None
} else {
Some(lit(skip as i64))
};
let fetch_expr = fetch.map(|f| lit(f as i64));
self.limit_by_expr(skip_expr, fetch_expr)
}
pub fn limit_by_expr(self, skip: Option<Expr>, fetch: Option<Expr>) -> Result<Self> {
Ok(Self::new(LogicalPlan::Limit(Limit {
skip: skip.map(Box::new),
fetch: fetch.map(Box::new),
input: self.plan,
})))
}
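    /// Wrap the plan in a [`SubqueryAlias`], giving its output a new
    /// table-qualified name (SQL `AS alias`).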
pub fn alias(self, alias: impl Into<TableReference>) -> Result<Self> {
subquery_alias(Arc::unwrap_or_clone(self.plan), alias).map(Self::new)
}
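    /// Add columns that are referenced (for example by an `ORDER BY`) but
    /// absent from the output of a projection, recursing through the plan
    /// until an input that can provide them is found. For `DISTINCT` inputs,
    /// an ambiguity check rejects columns not derivable from the select list.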
fn add_missing_columns(
curr_plan: LogicalPlan,
missing_cols: &IndexSet<Column>,
is_distinct: bool,
) -> Result<LogicalPlan> {
match curr_plan {
LogicalPlan::Projection(Projection {
input,
mut expr,
schema: _,
}) if missing_cols.iter().all(|c| input.schema().has_column(c)) => {
let mut missing_exprs = missing_cols
.iter()
.map(|c| normalize_col(Expr::Column(c.clone()), &input))
.collect::<Result<Vec<_>>>()?;
missing_exprs.retain(|e| !expr.contains(e));
if is_distinct {
Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?;
}
expr.extend(missing_exprs);
project(Arc::unwrap_or_clone(input), expr)
}
_ => {
let is_distinct =
is_distinct || matches!(curr_plan, LogicalPlan::Distinct(_));
let new_inputs = curr_plan
.inputs()
.into_iter()
.map(|input_plan| {
Self::add_missing_columns(
(*input_plan).clone(),
missing_cols,
is_distinct,
)
})
.collect::<Result<Vec<_>>>()?;
curr_plan.with_new_exprs(curr_plan.expressions(), new_inputs)
}
}
}
fn ambiguous_distinct_check(
missing_exprs: &[Expr],
missing_cols: &IndexSet<Column>,
projection_exprs: &[Expr],
) -> Result<()> {
if missing_exprs.is_empty() {
return Ok(());
}
let all_aliases = missing_exprs.iter().all(|e| {
projection_exprs.iter().any(|proj_expr| {
if let Expr::Alias(Alias { expr, .. }) = proj_expr {
e == expr.as_ref()
} else {
false
}
})
});
if all_aliases {
return Ok(());
}
        let missing_col_names = missing_cols
            .iter()
            .map(|col| col.flat_name())
            .collect::<Vec<_>>()
            .join(", ");
plan_err!(
"For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list"
)
}
pub fn sort_by(
self,
expr: impl IntoIterator<Item = impl Into<Expr>> + Clone,
) -> Result<Self> {
self.sort(
expr.into_iter()
.map(|e| e.into().sort(true, false))
.collect::<Vec<SortExpr>>(),
)
}
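    /// Sort the plan by the given sort expressions. If a sort expression
    /// references a column the current projection does not output, the
    /// column is temporarily added to the projection and a final projection
    /// restores the original schema after the sort.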
pub fn sort(
self,
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
) -> Result<Self> {
self.sort_with_limit(sorts, None)
}
pub fn sort_with_limit(
self,
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
fetch: Option<usize>,
) -> Result<Self> {
let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?;
let schema = self.plan.schema();
let mut missing_cols: IndexSet<Column> = IndexSet::new();
sorts.iter().try_for_each::<_, Result<()>>(|sort| {
let columns = sort.expr.column_refs();
missing_cols.extend(
columns
.into_iter()
.filter(|c| !schema.has_column(c))
.cloned(),
);
Ok(())
})?;
if missing_cols.is_empty() {
return Ok(Self::new(LogicalPlan::Sort(Sort {
expr: normalize_sorts(sorts, &self.plan)?,
input: self.plan,
fetch,
})));
}
let new_expr = schema.columns().into_iter().map(Expr::Column).collect();
let is_distinct = false;
let plan = Self::add_missing_columns(
Arc::unwrap_or_clone(self.plan),
&missing_cols,
is_distinct,
)?;
let sort_plan = LogicalPlan::Sort(Sort {
expr: normalize_sorts(sorts, &plan)?,
input: Arc::new(plan),
fetch,
});
Projection::try_new(new_expr, Arc::new(sort_plan))
.map(LogicalPlan::Projection)
.map(Self::new)
}
pub fn union(self, plan: LogicalPlan) -> Result<Self> {
union(Arc::unwrap_or_clone(self.plan), plan).map(Self::new)
}
pub fn union_by_name(self, plan: LogicalPlan) -> Result<Self> {
union_by_name(Arc::unwrap_or_clone(self.plan), plan).map(Self::new)
}
pub fn union_by_name_distinct(self, plan: LogicalPlan) -> Result<Self> {
let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan);
let right_plan: LogicalPlan = plan;
Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new(
union_by_name(left_plan, right_plan)?,
)))))
}
pub fn union_distinct(self, plan: LogicalPlan) -> Result<Self> {
let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan);
let right_plan: LogicalPlan = plan;
Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new(
union(left_plan, right_plan)?,
)))))
}
pub fn distinct(self) -> Result<Self> {
Ok(Self::new(LogicalPlan::Distinct(Distinct::All(self.plan))))
}
pub fn distinct_on(
self,
on_expr: Vec<Expr>,
select_expr: Vec<Expr>,
sort_expr: Option<Vec<SortExpr>>,
) -> Result<Self> {
Ok(Self::new(LogicalPlan::Distinct(Distinct::On(
DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?,
))))
}
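    /// Join with `right` on equality of the given key columns, plus an
    /// optional `filter` expression (which, for outer joins, only restricts
    /// matched rows). Join keys compare `NULL` as not equal to `NULL`; use
    /// [`Self::join_detailed`] to control that behavior.
    ///
    /// An illustrative sketch:
    ///
    /// ```ignore
    /// // ... JOIN right ON left.id = right.id
    /// let plan = LogicalPlanBuilder::from(left)
    ///     .join(right, JoinType::Inner, (vec!["id"], vec!["id"]), None)?
    ///     .build()?;
    /// ```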
pub fn join(
self,
right: LogicalPlan,
join_type: JoinType,
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
filter: Option<Expr>,
) -> Result<Self> {
self.join_detailed(
right,
join_type,
join_keys,
filter,
NullEquality::NullEqualsNothing,
)
}
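    /// Join with `right` using arbitrary boolean expressions, combined with
    /// `AND`, as the join condition instead of explicit key columns.
    ///
    /// ```ignore
    /// // ... JOIN t2 ON t1.a = t2.a AND t1.b > t2.b
    /// let plan = LogicalPlanBuilder::from(t1)
    ///     .join_on(t2, JoinType::Inner, [
    ///         col("t1.a").eq(col("t2.a")),
    ///         col("t1.b").gt(col("t2.b")),
    ///     ])?
    ///     .build()?;
    /// ```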
pub fn join_on(
self,
right: LogicalPlan,
join_type: JoinType,
on_exprs: impl IntoIterator<Item = Expr>,
) -> Result<Self> {
let filter = on_exprs.into_iter().reduce(Expr::and);
self.join_detailed(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
filter,
NullEquality::NullEqualsNothing,
)
}
pub(crate) fn normalize(plan: &LogicalPlan, column: Column) -> Result<Column> {
if column.relation.is_some() {
return Ok(column);
}
let schema = plan.schema();
let fallback_schemas = plan.fallback_normalize_schemas();
let using_columns = plan.using_columns()?;
column.normalize_with_schemas_and_ambiguity_check(
&[&[schema], &fallback_schemas],
&using_columns,
)
}
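    /// Like [`Self::join`], but with an explicit `null_equality` that
    /// controls whether `NULL` join keys are considered equal to each other.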
pub fn join_detailed(
self,
right: LogicalPlan,
join_type: JoinType,
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
filter: Option<Expr>,
null_equality: NullEquality,
) -> Result<Self> {
self.join_detailed_with_options(
right,
join_type,
join_keys,
filter,
null_equality,
false,
)
}
pub fn join_detailed_with_options(
self,
right: LogicalPlan,
join_type: JoinType,
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
filter: Option<Expr>,
null_equality: NullEquality,
null_aware: bool,
) -> Result<Self> {
if join_keys.0.len() != join_keys.1.len() {
return plan_err!("left_keys and right_keys were not the same length");
}
let filter = if let Some(expr) = filter {
let filter = normalize_col_with_schemas_and_ambiguity_check(
expr,
&[&[self.schema(), right.schema()]],
&[],
)?;
Some(filter)
} else {
None
};
let (left_keys, right_keys): (Vec<Result<Column>>, Vec<Result<Column>>) =
join_keys
.0
.into_iter()
.zip(join_keys.1)
.map(|(l, r)| {
let l = l.into();
let r = r.into();
match (&l.relation, &r.relation) {
(Some(lr), Some(rr)) => {
let l_is_left =
self.plan.schema().field_with_qualified_name(lr, &l.name);
let l_is_right =
right.schema().field_with_qualified_name(lr, &l.name);
let r_is_left =
self.plan.schema().field_with_qualified_name(rr, &r.name);
let r_is_right =
right.schema().field_with_qualified_name(rr, &r.name);
match (l_is_left, l_is_right, r_is_left, r_is_right) {
(_, Ok(_), Ok(_), _) => (Ok(r), Ok(l)),
(Ok(_), _, _, Ok(_)) => (Ok(l), Ok(r)),
_ => (
Self::normalize(&self.plan, l),
Self::normalize(&right, r),
),
}
}
(Some(lr), None) => {
let l_is_left =
self.plan.schema().field_with_qualified_name(lr, &l.name);
let l_is_right =
right.schema().field_with_qualified_name(lr, &l.name);
match (l_is_left, l_is_right) {
(Ok(_), _) => (Ok(l), Self::normalize(&right, r)),
(_, Ok(_)) => (Self::normalize(&self.plan, r), Ok(l)),
_ => (
Self::normalize(&self.plan, l),
Self::normalize(&right, r),
),
}
}
(None, Some(rr)) => {
let r_is_left =
self.plan.schema().field_with_qualified_name(rr, &r.name);
let r_is_right =
right.schema().field_with_qualified_name(rr, &r.name);
match (r_is_left, r_is_right) {
(Ok(_), _) => (Ok(r), Self::normalize(&right, l)),
(_, Ok(_)) => (Self::normalize(&self.plan, l), Ok(r)),
_ => (
Self::normalize(&self.plan, l),
Self::normalize(&right, r),
),
}
}
(None, None) => {
let mut swap = false;
let left_key = Self::normalize(&self.plan, l.clone())
.or_else(|_| {
swap = true;
Self::normalize(&right, l)
});
if swap {
(Self::normalize(&self.plan, r), left_key)
} else {
(left_key, Self::normalize(&right, r))
}
}
}
})
.unzip();
let left_keys = left_keys.into_iter().collect::<Result<Vec<Column>>>()?;
let right_keys = right_keys.into_iter().collect::<Result<Vec<Column>>>()?;
let on: Vec<_> = left_keys
.into_iter()
.zip(right_keys)
.map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
.collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
if join_type != JoinType::Inner && on.is_empty() && filter.is_none() {
return plan_err!("join condition should not be empty");
}
Ok(Self::new(LogicalPlan::Join(Join {
left: self.plan,
right: Arc::new(right),
on,
filter,
join_type,
join_constraint: JoinConstraint::On,
schema: DFSchemaRef::new(join_schema),
null_equality,
null_aware,
})))
}
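    /// Join with `right` on `USING`-style shared columns. Key pairs whose
    /// types cannot be hashed are demoted to an equality filter; if no
    /// hashable key remains, the join becomes a cross join plus that filter.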
pub fn join_using(
self,
right: LogicalPlan,
join_type: JoinType,
using_keys: Vec<Column>,
) -> Result<Self> {
let left_keys: Vec<Column> = using_keys
.clone()
.into_iter()
.map(|c| Self::normalize(&self.plan, c))
.collect::<Result<_>>()?;
let right_keys: Vec<Column> = using_keys
.into_iter()
.map(|c| Self::normalize(&right, c))
.collect::<Result<_>>()?;
let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys).collect();
let mut join_on: Vec<(Expr, Expr)> = vec![];
let mut filters: Option<Expr> = None;
for (l, r) in &on {
if self.plan.schema().has_column(l)
&& right.schema().has_column(r)
&& can_hash(
datafusion_common::ExprSchema::field_from_column(
self.plan.schema(),
l,
)?
.data_type(),
)
{
join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone())));
} else if self.plan.schema().has_column(l)
&& right.schema().has_column(r)
&& can_hash(
datafusion_common::ExprSchema::field_from_column(
self.plan.schema(),
r,
)?
.data_type(),
)
{
join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone())));
} else {
let expr = binary_expr(
Expr::Column(l.clone()),
Operator::Eq,
Expr::Column(r.clone()),
);
match filters {
None => filters = Some(expr),
Some(filter_expr) => filters = Some(and(expr, filter_expr)),
}
}
}
if join_on.is_empty() {
let join = Self::from(self.plan).cross_join(right)?;
join.filter(filters.ok_or_else(|| {
internal_datafusion_err!("filters should not be None here")
})?)
} else {
let join = Join::try_new(
self.plan,
Arc::new(right),
join_on,
filters,
join_type,
JoinConstraint::Using,
NullEquality::NullEqualsNothing,
                false,
            )?;
Ok(Self::new(LogicalPlan::Join(join)))
}
}
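    /// Apply a cross join: every row of the left input paired with every row
    /// of `right` (represented as an inner join with no condition).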
pub fn cross_join(self, right: LogicalPlan) -> Result<Self> {
let join = Join::try_new(
self.plan,
Arc::new(right),
vec![],
None,
JoinType::Inner,
JoinConstraint::On,
NullEquality::NullEqualsNothing,
            false,
        )?;
Ok(Self::new(LogicalPlan::Join(join)))
}
pub fn repartition(self, partitioning_scheme: Partitioning) -> Result<Self> {
Ok(Self::new(LogicalPlan::Repartition(Repartition {
input: self.plan,
partitioning_scheme,
})))
}
pub fn window(
self,
window_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
let window_expr = normalize_cols(window_expr, &self.plan)?;
validate_unique_names("Windows", &window_expr)?;
Ok(Self::new(LogicalPlan::Window(Window::try_new(
window_expr,
self.plan,
)?)))
}
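    /// Apply an aggregation with the given grouping and aggregate
    /// expressions. When `add_implicit_group_by_exprs` is enabled in the
    /// builder options, columns functionally dependent on the group-by
    /// columns (e.g. via a primary key) are added to the grouping
    /// automatically.
    ///
    /// An illustrative sketch (`sum` stands in for any aggregate expression
    /// constructor):
    ///
    /// ```ignore
    /// // SELECT state, sum(salary) FROM employee_csv GROUP BY state
    /// let plan = builder
    ///     .aggregate(vec![col("state")], vec![sum(col("salary"))])?
    ///     .build()?;
    /// ```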
pub fn aggregate(
self,
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
let group_expr = if self.options.add_implicit_group_by_exprs {
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?
} else {
group_expr
};
Aggregate::try_new(self.plan, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::new)
}
pub fn explain(self, verbose: bool, analyze: bool) -> Result<Self> {
self.explain_option_format(
ExplainOption::default()
.with_verbose(verbose)
.with_analyze(analyze),
)
}
pub fn explain_option_format(self, explain_option: ExplainOption) -> Result<Self> {
let schema = LogicalPlan::explain_schema();
let schema = schema.to_dfschema_ref()?;
if explain_option.analyze {
Ok(Self::new(LogicalPlan::Analyze(Analyze {
verbose: explain_option.verbose,
input: self.plan,
schema,
})))
} else {
let stringified_plans =
vec![self.plan.to_stringified(PlanType::InitialLogicalPlan)];
Ok(Self::new(LogicalPlan::Explain(Explain {
verbose: explain_option.verbose,
plan: self.plan,
explain_format: explain_option.format,
stringified_plans,
schema,
logical_optimization_succeeded: false,
})))
}
}
pub fn intersect(
left_plan: LogicalPlan,
right_plan: LogicalPlan,
is_all: bool,
) -> Result<LogicalPlan> {
LogicalPlanBuilder::intersect_or_except(
left_plan,
right_plan,
JoinType::LeftSemi,
is_all,
)
}
pub fn except(
left_plan: LogicalPlan,
right_plan: LogicalPlan,
is_all: bool,
) -> Result<LogicalPlan> {
LogicalPlanBuilder::intersect_or_except(
left_plan,
right_plan,
JoinType::LeftAnti,
is_all,
)
}
fn intersect_or_except(
left_plan: LogicalPlan,
right_plan: LogicalPlan,
join_type: JoinType,
is_all: bool,
) -> Result<LogicalPlan> {
let left_len = left_plan.schema().fields().len();
let right_len = right_plan.schema().fields().len();
if left_len != right_len {
return plan_err!(
"INTERSECT/EXCEPT query must have the same number of columns. Left is {left_len} and right is {right_len}."
);
}
let left_builder = LogicalPlanBuilder::from(left_plan);
let right_builder = LogicalPlanBuilder::from(right_plan);
let (left_builder, right_builder, _requalified) =
requalify_sides_if_needed(left_builder, right_builder)?;
let left_plan = left_builder.build()?;
let right_plan = right_builder.build()?;
let join_keys = left_plan
.schema()
.fields()
.iter()
.zip(right_plan.schema().fields().iter())
.map(|(left_field, right_field)| {
(
(Column::from_name(left_field.name())),
(Column::from_name(right_field.name())),
)
})
.unzip();
if is_all {
LogicalPlanBuilder::from(left_plan)
.join_detailed(
right_plan,
join_type,
join_keys,
None,
NullEquality::NullEqualsNull,
)?
.build()
} else {
LogicalPlanBuilder::from(left_plan)
.distinct()?
.join_detailed(
right_plan,
join_type,
join_keys,
None,
NullEquality::NullEqualsNull,
)?
.build()
}
}
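    /// Build the plan, consuming the builder. The underlying plan is cloned
    /// only if other references to its [`Arc`] are still alive.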
pub fn build(self) -> Result<LogicalPlan> {
Ok(Arc::unwrap_or_clone(self.plan))
}
pub fn join_with_expr_keys(
self,
right: LogicalPlan,
join_type: JoinType,
equi_exprs: (Vec<impl Into<Expr>>, Vec<impl Into<Expr>>),
filter: Option<Expr>,
) -> Result<Self> {
if equi_exprs.0.len() != equi_exprs.1.len() {
return plan_err!("left_keys and right_keys were not the same length");
}
let join_key_pairs = equi_exprs
.0
.into_iter()
.zip(equi_exprs.1)
.map(|(l, r)| {
let left_key = l.into();
let right_key = r.into();
let mut left_using_columns = HashSet::new();
expr_to_columns(&left_key, &mut left_using_columns)?;
let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check(
left_key,
&[&[self.plan.schema()]],
&[],
)?;
let mut right_using_columns = HashSet::new();
expr_to_columns(&right_key, &mut right_using_columns)?;
let normalized_right_key = normalize_col_with_schemas_and_ambiguity_check(
right_key,
&[&[right.schema()]],
&[],
)?;
find_valid_equijoin_key_pair(
&normalized_left_key,
&normalized_right_key,
self.plan.schema(),
right.schema(),
)?.ok_or_else(||
plan_datafusion_err!(
"can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})"
))
})
.collect::<Result<Vec<_>>>()?;
let join = Join::try_new(
self.plan,
Arc::new(right),
join_key_pairs,
filter,
join_type,
JoinConstraint::On,
NullEquality::NullEqualsNothing,
            false,
        )?;
Ok(Self::new(LogicalPlan::Join(join)))
}
pub fn unnest_column(self, column: impl Into<Column>) -> Result<Self> {
unnest(Arc::unwrap_or_clone(self.plan), vec![column.into()]).map(Self::new)
}
pub fn unnest_column_with_options(
self,
column: impl Into<Column>,
options: UnnestOptions,
) -> Result<Self> {
unnest_with_options(
Arc::unwrap_or_clone(self.plan),
vec![column.into()],
options,
)
.map(Self::new)
}
pub fn unnest_columns_with_options(
self,
columns: Vec<Column>,
options: UnnestOptions,
) -> Result<Self> {
unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options)
.map(Self::new)
}
}
impl From<LogicalPlan> for LogicalPlanBuilder {
fn from(plan: LogicalPlan) -> Self {
LogicalPlanBuilder::new(plan)
}
}
impl From<Arc<LogicalPlan>> for LogicalPlanBuilder {
fn from(plan: Arc<LogicalPlan>) -> Self {
LogicalPlanBuilder::new_from_arc(plan)
}
}
#[derive(Default)]
struct ValuesFields {
inner: Vec<Field>,
}
impl ValuesFields {
pub fn new() -> Self {
Self::default()
}
pub fn push(&mut self, data_type: DataType, nullable: bool) {
self.push_with_metadata(data_type, nullable, None);
}
pub fn push_with_metadata(
&mut self,
data_type: DataType,
nullable: bool,
metadata: Option<FieldMetadata>,
) {
let name = format!("column{}", self.inner.len() + 1);
let mut field = Field::new(name, data_type, nullable);
if let Some(metadata) = metadata {
field.set_metadata(metadata.to_hashmap());
}
self.inner.push(field);
}
pub fn into_fields(self) -> Fields {
self.inner.into()
}
}
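/// Compute aliases that make duplicated field names unique: the first
/// occurrence of each name is kept as-is (`None`), and later occurrences are
/// aliased `name:1`, `name:2`, and so on, skipping over any candidate that
/// collides with a name already emitted.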
pub fn unique_field_aliases(fields: &Fields) -> Vec<Option<String>> {
let mut name_map = HashMap::<&str, usize>::new();
let mut seen = HashSet::<Cow<String>>::new();
fields
.iter()
.map(|field| {
let original_name = field.name();
let mut name = Cow::Borrowed(original_name);
let count = name_map.entry(original_name).or_insert(0);
while seen.contains(&name) {
*count += 1;
name = Cow::Owned(format!("{original_name}:{count}"));
}
seen.insert(name.clone());
match name {
Cow::Borrowed(_) => None,
Cow::Owned(alias) => Some(alias),
}
})
.collect()
}
fn mark_field(schema: &DFSchema) -> (Option<TableReference>, Arc<Field>) {
let mut table_references = schema
.iter()
.filter_map(|(qualifier, _)| qualifier)
.collect::<Vec<_>>();
table_references.dedup();
let table_reference = if table_references.len() == 1 {
table_references.pop().cloned()
} else {
None
};
(
table_reference,
Arc::new(Field::new("mark", DataType::Boolean, false)),
)
}
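/// Create the output schema of joining `left` and `right` with `join_type`:
/// outer joins mark the non-preserved side's fields as nullable, semi/anti
/// joins keep only one side, and mark joins append a boolean `mark` column.
/// Functional dependencies are joined, and schema metadata is merged, with
/// the right input winning duplicate keys for right-variant joins and the
/// left input winning otherwise.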
pub fn build_join_schema(
left: &DFSchema,
right: &DFSchema,
join_type: &JoinType,
) -> Result<DFSchema> {
fn nullify_fields<'a>(
fields: impl Iterator<Item = (Option<&'a TableReference>, &'a Arc<Field>)>,
) -> Vec<(Option<TableReference>, Arc<Field>)> {
fields
.map(|(q, f)| {
let field = f.as_ref().clone().with_nullable(true);
(q.cloned(), Arc::new(field))
})
.collect()
}
let right_fields = right.iter();
let left_fields = left.iter();
let qualified_fields: Vec<(Option<TableReference>, Arc<Field>)> = match join_type {
JoinType::Inner => {
let left_fields = left_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect::<Vec<_>>();
let right_fields = right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect::<Vec<_>>();
left_fields.into_iter().chain(right_fields).collect()
}
JoinType::Left => {
let left_fields = left_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect::<Vec<_>>();
left_fields
.into_iter()
.chain(nullify_fields(right_fields))
.collect()
}
JoinType::Right => {
let right_fields = right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect::<Vec<_>>();
nullify_fields(left_fields)
.into_iter()
.chain(right_fields)
.collect()
}
JoinType::Full => {
nullify_fields(left_fields)
.into_iter()
.chain(nullify_fields(right_fields))
.collect()
}
JoinType::LeftSemi | JoinType::LeftAnti => {
left_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect()
}
JoinType::LeftMark => left_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(right)))
.collect(),
JoinType::RightSemi | JoinType::RightAnti => {
right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect()
}
JoinType::RightMark => right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(left)))
.collect(),
};
let func_dependencies = left.functional_dependencies().join(
right.functional_dependencies(),
join_type,
left.fields().len(),
);
let (schema1, schema2) = match join_type {
JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti
| JoinType::RightMark => (left, right),
_ => (right, left),
};
let metadata = schema1
.metadata()
.clone()
.into_iter()
.chain(schema2.metadata().clone())
.collect();
let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?;
dfschema.with_functional_dependencies(func_dependencies)
}
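/// If the two inputs share a column name whose qualifiers would not
/// disambiguate it after a join, alias the inputs as `left` and `right`.
/// Returns the (possibly re-aliased) builders and whether requalification
/// took place.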
pub fn requalify_sides_if_needed(
left: LogicalPlanBuilder,
right: LogicalPlanBuilder,
) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder, bool)> {
let left_cols = left.schema().columns();
let right_cols = right.schema().columns();
for l in &left_cols {
for r in &right_cols {
if l.name != r.name {
continue;
}
match (&l.relation, &r.relation) {
(Some(l_rel), Some(r_rel)) if l_rel == r_rel => {
return Ok((
left.alias(TableReference::bare("left"))?,
right.alias(TableReference::bare("right"))?,
true,
));
}
(None, None) => {
return Ok((
left.alias(TableReference::bare("left"))?,
right.alias(TableReference::bare("right"))?,
true,
));
}
(Some(_), None) | (None, Some(_)) => {
return Ok((
left.alias(TableReference::bare("left"))?,
right.alias(TableReference::bare("right"))?,
true,
));
}
_ => {}
}
}
}
Ok((left, right, false))
}
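/// Extend `group_expr` with any columns that are functionally determined by
/// the existing grouping columns (for example, the remaining columns of a
/// table whose primary key is already being grouped on).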
pub fn add_group_by_exprs_from_dependencies(
mut group_expr: Vec<Expr>,
schema: &DFSchemaRef,
) -> Result<Vec<Expr>> {
let mut group_by_field_names = group_expr
.iter()
.map(|e| e.schema_name().to_string())
.collect::<Vec<_>>();
if let Some(target_indices) =
get_target_functional_dependencies(schema, &group_by_field_names)
{
for idx in target_indices {
let expr = Expr::Column(Column::from(schema.qualified_field(idx)));
let expr_name = expr.schema_name().to_string();
if !group_by_field_names.contains(&expr_name) {
group_by_field_names.push(expr_name);
group_expr.push(expr);
}
}
}
Ok(group_expr)
}
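/// Return an error if two expressions would produce the same output schema
/// name, since `node_name` nodes require unique output names.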
pub fn validate_unique_names<'a>(
node_name: &str,
expressions: impl IntoIterator<Item = &'a Expr>,
) -> Result<()> {
let mut unique_names = HashMap::new();
expressions.into_iter().enumerate().try_for_each(|(position, expr)| {
let name = expr.schema_name().to_string();
match unique_names.get(&name) {
None => {
unique_names.insert(name, (position, expr));
Ok(())
},
Some((existing_position, existing_expr)) => {
plan_err!("{node_name} require unique expression names \
but the expression \"{existing_expr}\" at position {existing_position} and \"{expr}\" \
at position {position} have the same name. Consider aliasing (\"AS\") one of them."
)
}
}
})
}
pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan> {
Ok(LogicalPlan::Union(Union::try_new_with_loose_types(vec![
Arc::new(left_plan),
Arc::new(right_plan),
])?))
}
pub fn union_by_name(
left_plan: LogicalPlan,
right_plan: LogicalPlan,
) -> Result<LogicalPlan> {
Ok(LogicalPlan::Union(Union::try_new_by_name(vec![
Arc::new(left_plan),
Arc::new(right_plan),
])?))
}
pub fn project(
plan: LogicalPlan,
expr: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<LogicalPlan> {
project_with_validation(plan, expr.into_iter().map(|e| (e, true)))
}
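/// Like [`project`], but each expression carries a flag: when `true`, the
/// expression is normalized and columnized against `plan`; when `false`, it
/// is passed through unchanged (for callers that have already prepared it).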
fn project_with_validation(
plan: LogicalPlan,
expr: impl IntoIterator<Item = (impl Into<SelectExpr>, bool)>,
) -> Result<LogicalPlan> {
let mut projected_expr = vec![];
for (e, validate) in expr {
let e = e.into();
match e {
SelectExpr::Wildcard(opt) => {
let expanded = expand_wildcard(plan.schema(), &plan, Some(&opt))?;
let expanded = if let Some(replace) = opt.replace {
replace_columns(expanded, &replace)?
} else {
expanded
};
for e in expanded {
if validate {
projected_expr
.push(columnize_expr(normalize_col(e, &plan)?, &plan)?)
} else {
projected_expr.push(e)
}
}
}
SelectExpr::QualifiedWildcard(table_ref, opt) => {
let expanded =
expand_qualified_wildcard(&table_ref, plan.schema(), Some(&opt))?;
let expanded = if let Some(replace) = opt.replace {
replace_columns(expanded, &replace)?
} else {
expanded
};
for e in expanded {
if validate {
projected_expr
.push(columnize_expr(normalize_col(e, &plan)?, &plan)?)
} else {
projected_expr.push(e)
}
}
}
SelectExpr::Expression(e) => {
if validate {
projected_expr.push(columnize_expr(normalize_col(e, &plan)?, &plan)?)
} else {
projected_expr.push(e)
}
}
}
}
validate_unique_names("Projections", projected_expr.iter())?;
Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection)
}
fn replace_columns(
mut exprs: Vec<Expr>,
replace: &PlannedReplaceSelectItem,
) -> Result<Vec<Expr>> {
for expr in exprs.iter_mut() {
if let Expr::Column(Column { name, .. }) = expr
&& let Some((_, new_expr)) = replace
.items()
.iter()
.zip(replace.expressions().iter())
.find(|(item, _)| item.column_name.value == *name)
{
*expr = new_expr.clone().alias(name.clone())
}
}
Ok(exprs)
}
pub fn subquery_alias(
plan: LogicalPlan,
alias: impl Into<TableReference>,
) -> Result<LogicalPlan> {
SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias)
}
pub fn table_scan(
name: Option<impl Into<TableReference>>,
table_schema: &Schema,
projection: Option<Vec<usize>>,
) -> Result<LogicalPlanBuilder> {
table_scan_with_filters(name, table_schema, projection, vec![])
}
pub fn table_scan_with_filters(
name: Option<impl Into<TableReference>>,
table_schema: &Schema,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
) -> Result<LogicalPlanBuilder> {
let table_source = table_source(table_schema);
let name = name
.map(|n| n.into())
.unwrap_or_else(|| TableReference::bare(UNNAMED_TABLE));
LogicalPlanBuilder::scan_with_filters(name, table_source, projection, filters)
}
pub fn table_scan_with_filter_and_fetch(
name: Option<impl Into<TableReference>>,
table_schema: &Schema,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
fetch: Option<usize>,
) -> Result<LogicalPlanBuilder> {
let table_source = table_source(table_schema);
let name = name
.map(|n| n.into())
.unwrap_or_else(|| TableReference::bare(UNNAMED_TABLE));
LogicalPlanBuilder::scan_with_filters_fetch(
name,
table_source,
projection,
filters,
fetch,
)
}
pub fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource {
table_schema,
constraints: Default::default(),
})
}
pub fn table_source_with_constraints(
table_schema: &Schema,
constraints: Constraints,
) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource {
table_schema,
constraints,
})
}
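/// If any join key is a non-column expression (for example a cast), wrap
/// `input` in a projection that computes those expressions so that the join
/// can refer to them as plain columns. Returns the (possibly wrapped) plan,
/// the columns to join on, and whether a projection was added.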
pub fn wrap_projection_for_join_if_necessary(
join_keys: &[Expr],
input: LogicalPlan,
) -> Result<(LogicalPlan, Vec<Column>, bool)> {
let input_schema = input.schema();
let alias_join_keys: Vec<Expr> = join_keys
.iter()
.map(|key| {
if matches!(key, Expr::Cast(_)) || matches!(key, Expr::TryCast(_)) {
let alias = format!("{key}");
key.clone().alias(alias)
} else {
key.clone()
}
})
.collect::<Vec<_>>();
let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_)));
let plan = if need_project {
let mut projection = input_schema
.columns()
.into_iter()
.map(Expr::Column)
.collect::<Vec<_>>();
        let join_key_items = alias_join_keys
            .iter()
            .filter(|expr| expr.try_as_col().is_none())
            .cloned()
            .collect::<HashSet<Expr>>();
projection.extend(join_key_items);
LogicalPlanBuilder::from(input)
.project(projection.into_iter().map(SelectExpr::from))?
.build()?
} else {
input
};
let join_on = alias_join_keys
.into_iter()
.map(|key| {
if let Some(col) = key.try_as_col() {
Ok(col.clone())
} else {
let name = key.schema_name().to_string();
Ok(Column::from_name(name))
}
})
.collect::<Result<Vec<_>>>()?;
Ok((plan, join_on, need_project))
}
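/// Basic [`TableSource`] implementation intended for tests and
/// documentation; it reports exact filter pushdown support for every filter.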
pub struct LogicalTableSource {
table_schema: SchemaRef,
constraints: Constraints,
}
impl LogicalTableSource {
pub fn new(table_schema: SchemaRef) -> Self {
Self {
table_schema,
constraints: Constraints::default(),
}
}
pub fn with_constraints(mut self, constraints: Constraints) -> Self {
self.constraints = constraints;
self
}
}
impl TableSource for LogicalTableSource {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.table_schema)
}
fn constraints(&self) -> Option<&Constraints> {
Some(&self.constraints)
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
}
}
pub fn unnest(input: LogicalPlan, columns: Vec<Column>) -> Result<LogicalPlan> {
unnest_with_options(input, columns, UnnestOptions::default())
}
pub fn get_struct_unnested_columns(
    col_name: &str,
    inner_fields: &Fields,
) -> Vec<Column> {
inner_fields
.iter()
.map(|f| Column::from_name(format!("{}.{}", col_name, f.name())))
.collect()
}
pub fn unnest_with_options(
input: LogicalPlan,
columns_to_unnest: Vec<Column>,
options: UnnestOptions,
) -> Result<LogicalPlan> {
Ok(LogicalPlan::Unnest(Unnest::try_new(
Arc::new(input),
columns_to_unnest,
options,
)?))
}
#[cfg(test)]
mod tests {
use std::vec;
use super::*;
use crate::lit_with_metadata;
use crate::logical_plan::StringifiedPlan;
use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};
use crate::test::function_stub::sum;
use datafusion_common::{
Constraint, DataFusionError, RecursionUnnestOption, SchemaError,
};
use insta::assert_snapshot;
#[test]
fn plan_builder_simple() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
.filter(col("state").eq(lit("CO")))?
.project(vec![col("id")])?
.build()?;
assert_snapshot!(plan, @r#"
Projection: employee_csv.id
Filter: employee_csv.state = Utf8("CO")
TableScan: employee_csv projection=[id, state]
"#);
Ok(())
}
#[test]
fn plan_builder_schema() {
let schema = employee_schema();
let projection = None;
let plan =
LogicalPlanBuilder::scan("employee_csv", table_source(&schema), projection)
.unwrap();
assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}");
let projection = None;
let plan =
LogicalPlanBuilder::scan("EMPLOYEE_CSV", table_source(&schema), projection)
.unwrap();
assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}");
}
#[test]
fn plan_builder_empty_name() {
let schema = employee_schema();
let projection = None;
let err =
LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err();
assert_snapshot!(
err.strip_backtrace(),
@"Error during planning: table_name cannot be empty"
);
}
#[test]
fn plan_builder_sort() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?
.sort(vec![
expr::Sort::new(col("state"), true, true),
expr::Sort::new(col("salary"), false, false),
])?
.build()?;
assert_snapshot!(plan, @r"
Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST
TableScan: employee_csv projection=[state, salary]
");
Ok(())
}
#[test]
fn plan_builder_union() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?;
let plan = plan
.clone()
.union(plan.clone().build()?)?
.union(plan.clone().build()?)?
.union(plan.build()?)?
.build()?;
assert_snapshot!(plan, @r"
Union
Union
Union
TableScan: employee_csv projection=[state, salary]
TableScan: employee_csv projection=[state, salary]
TableScan: employee_csv projection=[state, salary]
TableScan: employee_csv projection=[state, salary]
");
Ok(())
}
#[test]
fn plan_builder_union_distinct() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?;
let plan = plan
.clone()
.union_distinct(plan.clone().build()?)?
.union_distinct(plan.clone().build()?)?
.union_distinct(plan.build()?)?
.build()?;
assert_snapshot!(plan, @r"
Distinct:
Union
Distinct:
Union
Distinct:
Union
TableScan: employee_csv projection=[state, salary]
TableScan: employee_csv projection=[state, salary]
TableScan: employee_csv projection=[state, salary]
TableScan: employee_csv projection=[state, salary]
");
Ok(())
}
#[test]
fn plan_builder_simple_distinct() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
.filter(col("state").eq(lit("CO")))?
.project(vec![col("id")])?
.distinct()?
.build()?;
assert_snapshot!(plan, @r#"
Distinct:
Projection: employee_csv.id
Filter: employee_csv.state = Utf8("CO")
TableScan: employee_csv projection=[id, state]
"#);
Ok(())
}
#[test]
fn exists_subquery() -> Result<()> {
let foo = test_table_scan_with_name("foo")?;
let bar = test_table_scan_with_name("bar")?;
let subquery = LogicalPlanBuilder::from(foo)
.project(vec![col("a")])?
.filter(col("a").eq(col("bar.a")))?
.build()?;
let outer_query = LogicalPlanBuilder::from(bar)
.project(vec![col("a")])?
.filter(exists(Arc::new(subquery)))?
.build()?;
assert_snapshot!(outer_query, @r"
Filter: EXISTS (<subquery>)
Subquery:
Filter: foo.a = bar.a
Projection: foo.a
TableScan: foo
Projection: bar.a
TableScan: bar
");
Ok(())
}
#[test]
fn filter_in_subquery() -> Result<()> {
let foo = test_table_scan_with_name("foo")?;
let bar = test_table_scan_with_name("bar")?;
let subquery = LogicalPlanBuilder::from(foo)
.project(vec![col("a")])?
.filter(col("a").eq(col("bar.a")))?
.build()?;
let outer_query = LogicalPlanBuilder::from(bar)
.project(vec![col("a")])?
.filter(in_subquery(col("a"), Arc::new(subquery)))?
.build()?;
assert_snapshot!(outer_query, @r"
Filter: bar.a IN (<subquery>)
Subquery:
Filter: foo.a = bar.a
Projection: foo.a
TableScan: foo
Projection: bar.a
TableScan: bar
");
Ok(())
}
#[test]
fn select_scalar_subquery() -> Result<()> {
let foo = test_table_scan_with_name("foo")?;
let bar = test_table_scan_with_name("bar")?;
let subquery = LogicalPlanBuilder::from(foo)
.project(vec![col("b")])?
.filter(col("a").eq(col("bar.a")))?
.build()?;
let outer_query = LogicalPlanBuilder::from(bar)
.project(vec![scalar_subquery(Arc::new(subquery))])?
.build()?;
assert_snapshot!(outer_query, @r"
Projection: (<subquery>)
Subquery:
Filter: foo.a = bar.a
Projection: foo.b
TableScan: foo
TableScan: bar
");
Ok(())
}
#[test]
fn projection_non_unique_names() -> Result<()> {
let plan = table_scan(
Some("employee_csv"),
&employee_schema(),
Some(vec![0, 1]),
)?
.project(vec![col("id"), col("first_name").alias("id")]);
match plan {
Err(DataFusionError::SchemaError(err, _)) => {
if let SchemaError::AmbiguousReference { field } = *err {
let Column {
relation,
name,
spans: _,
} = *field;
let Some(TableReference::Bare { table }) = relation else {
return plan_err!(
"wrong relation: {relation:?}, expected table name"
);
};
assert_eq!(*"employee_csv", *table);
assert_eq!("id", &name);
Ok(())
} else {
plan_err!("Plan should have returned an DataFusionError::SchemaError")
}
}
_ => plan_err!("Plan should have returned an DataFusionError::SchemaError"),
}
}
fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Int32, false),
])
}
#[test]
fn stringified_plan() {
let stringified_plan =
StringifiedPlan::new(PlanType::InitialLogicalPlan, "...the plan...");
assert!(stringified_plan.should_display(true));
assert!(!stringified_plan.should_display(false));
let stringified_plan =
StringifiedPlan::new(PlanType::FinalLogicalPlan, "...the plan...");
assert!(stringified_plan.should_display(true));
assert!(stringified_plan.should_display(false));
let stringified_plan =
StringifiedPlan::new(PlanType::InitialPhysicalPlan, "...the plan...");
assert!(stringified_plan.should_display(true));
assert!(!stringified_plan.should_display(false));
let stringified_plan =
StringifiedPlan::new(PlanType::FinalPhysicalPlan, "...the plan...");
assert!(stringified_plan.should_display(true));
assert!(stringified_plan.should_display(false));
let stringified_plan = StringifiedPlan::new(
PlanType::OptimizedLogicalPlan {
optimizer_name: "random opt pass".into(),
},
"...the plan...",
);
assert!(stringified_plan.should_display(true));
assert!(!stringified_plan.should_display(false));
}
fn test_table_scan_with_name(name: &str) -> Result<LogicalPlan> {
let schema = Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::UInt32, false),
Field::new("c", DataType::UInt32, false),
]);
table_scan(Some(name), &schema, None)?.build()
}
#[test]
fn plan_builder_intersect_different_num_columns_error() -> Result<()> {
let plan1 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?;
let plan2 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?;
let err_msg1 =
LogicalPlanBuilder::intersect(plan1.build()?, plan2.build()?, true)
.unwrap_err();
assert_snapshot!(err_msg1.strip_backtrace(), @"Error during planning: INTERSECT/EXCEPT query must have the same number of columns. Left is 1 and right is 2.");
Ok(())
}
#[test]
fn plan_builder_unnest() -> Result<()> {
let err = nested_table_scan("test_table")?
.unnest_column("scalar")
.unwrap_err();
let DataFusionError::Internal(desc) = err else {
return plan_err!("Plan should have returned an DataFusionError::Internal");
};
        let desc = desc
            .split(DataFusionError::BACK_TRACE_SEP)
            .next()
            .unwrap_or("")
            .to_string();
assert_snapshot!(desc, @"trying to unnest on invalid data type UInt32");
let plan = nested_table_scan("test_table")?
.unnest_column("strings")?
.build()?;
assert_snapshot!(plan, @r"
Unnest: lists[test_table.strings|depth=1] structs[]
TableScan: test_table
");
let field = plan.schema().field_with_name(None, "strings").unwrap();
assert_eq!(&DataType::Utf8, field.data_type());
let plan = nested_table_scan("test_table")?
.unnest_column("struct_singular")?
.build()?;
assert_snapshot!(plan, @r"
Unnest: lists[] structs[test_table.struct_singular]
TableScan: test_table
");
for field_name in &["a", "b"] {
let field = plan
.schema()
.field_with_name(None, &format!("struct_singular.{field_name}"))
.unwrap();
assert_eq!(&DataType::UInt32, field.data_type());
}
let plan = nested_table_scan("test_table")?
.unnest_column("strings")?
.unnest_column("structs")?
.unnest_column("struct_singular")?
.build()?;
assert_snapshot!(plan, @r"
Unnest: lists[] structs[test_table.struct_singular]
Unnest: lists[test_table.structs|depth=1] structs[]
Unnest: lists[test_table.strings|depth=1] structs[]
TableScan: test_table
");
let field = plan.schema().field_with_name(None, "structs").unwrap();
assert!(matches!(field.data_type(), DataType::Struct(_)));
let cols = vec!["strings", "structs", "struct_singular"]
.into_iter()
.map(|c| c.into())
.collect();
let plan = nested_table_scan("test_table")?
.unnest_columns_with_options(cols, UnnestOptions::default())?
.build()?;
assert_snapshot!(plan, @r"
Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular]
TableScan: test_table
");
let plan = nested_table_scan("test_table")?.unnest_column("missing");
assert!(plan.is_err());
let plan = nested_table_scan("test_table")?
.unnest_columns_with_options(
vec!["stringss".into(), "struct_singular".into()],
UnnestOptions::default()
.with_recursions(RecursionUnnestOption {
input_column: "stringss".into(),
output_column: "stringss_depth_1".into(),
depth: 1,
})
.with_recursions(RecursionUnnestOption {
input_column: "stringss".into(),
output_column: "stringss_depth_2".into(),
depth: 2,
}),
)?
.build()?;
assert_snapshot!(plan, @r"
Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular]
TableScan: test_table
");
let field = plan
.schema()
.field_with_name(None, "stringss_depth_1")
.unwrap();
assert_eq!(
&DataType::new_list(DataType::Utf8, false),
field.data_type()
);
let field = plan
.schema()
.field_with_name(None, "stringss_depth_2")
.unwrap();
assert_eq!(&DataType::Utf8, field.data_type());
for field_name in &["a", "b"] {
let field = plan
.schema()
.field_with_name(None, &format!("struct_singular.{field_name}"))
.unwrap();
assert_eq!(&DataType::UInt32, field.data_type());
}
Ok(())
}
fn nested_table_scan(table_name: &str) -> Result<LogicalPlanBuilder> {
let struct_field_in_list = Field::new_struct(
"item",
vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::UInt32, false),
],
false,
);
let string_field = Field::new_list_field(DataType::Utf8, false);
let strings_field = Field::new_list("item", string_field.clone(), false);
let schema = Schema::new(vec![
Field::new("scalar", DataType::UInt32, false),
Field::new_list("strings", string_field, false),
Field::new_list("structs", struct_field_in_list, false),
Field::new(
"struct_singular",
DataType::Struct(Fields::from(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::UInt32, false),
])),
false,
),
Field::new_list("stringss", strings_field, false),
]);
table_scan(Some(table_name), &schema, None)
}
#[test]
fn test_union_after_join() -> Result<()> {
let values = vec![vec![lit(1)]];
let left = LogicalPlanBuilder::values(values.clone())?
.alias("left")?
.build()?;
let right = LogicalPlanBuilder::values(values)?
.alias("right")?
.build()?;
let join = LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
let plan = LogicalPlanBuilder::from(join.clone())
.union(join)?
.build()?;
assert_snapshot!(plan, @r"
Union
Cross Join:
SubqueryAlias: left
Values: (Int32(1))
SubqueryAlias: right
Values: (Int32(1))
Cross Join:
SubqueryAlias: left
Values: (Int32(1))
SubqueryAlias: right
Values: (Int32(1))
");
Ok(())
}
#[test]
fn plan_builder_from_logical_plan() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?
.sort(vec![
expr::Sort::new(col("state"), true, true),
expr::Sort::new(col("salary"), false, false),
])?
.build()?;
let plan_expected = format!("{plan}");
let plan_builder: LogicalPlanBuilder = Arc::new(plan).into();
assert_eq!(plan_expected, format!("{}", plan_builder.plan));
Ok(())
}
#[test]
fn plan_builder_aggregate_without_implicit_group_by_exprs() -> Result<()> {
let constraints =
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
let table_source = table_source_with_constraints(&employee_schema(), constraints);
let plan =
LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))?
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
.build()?;
assert_snapshot!(plan, @r"
Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]
TableScan: employee_csv projection=[id, state, salary]
");
Ok(())
}
#[test]
fn plan_builder_aggregate_with_implicit_group_by_exprs() -> Result<()> {
let constraints =
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
let table_source = table_source_with_constraints(&employee_schema(), constraints);
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan =
LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))?
.with_options(options)
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
.build()?;
assert_snapshot!(plan, @r"
Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]
TableScan: employee_csv projection=[id, state, salary]
");
Ok(())
}
#[test]
fn test_join_metadata() -> Result<()> {
let left_schema = DFSchema::new_with_metadata(
vec![(None, Arc::new(Field::new("a", DataType::Int32, false)))],
HashMap::from([("key".to_string(), "left".to_string())]),
)?;
let right_schema = DFSchema::new_with_metadata(
vec![(None, Arc::new(Field::new("b", DataType::Int32, false)))],
HashMap::from([("key".to_string(), "right".to_string())]),
)?;
let join_schema =
build_join_schema(&left_schema, &right_schema, &JoinType::Left)?;
assert_eq!(
join_schema.metadata(),
&HashMap::from([("key".to_string(), "left".to_string())])
);
let join_schema =
build_join_schema(&left_schema, &right_schema, &JoinType::Right)?;
assert_eq!(
join_schema.metadata(),
&HashMap::from([("key".to_string(), "right".to_string())])
);
Ok(())
}
#[test]
fn test_values_metadata() -> Result<()> {
let metadata: HashMap<String, String> =
[("ARROW:extension:metadata".to_string(), "test".to_string())]
.into_iter()
.collect();
let metadata = FieldMetadata::from(metadata);
let values = LogicalPlanBuilder::values(vec![
vec![lit_with_metadata(1, Some(metadata.clone()))],
vec![lit_with_metadata(2, Some(metadata.clone()))],
])?
.build()?;
assert_eq!(*values.schema().field(0).metadata(), metadata.to_hashmap());
let metadata2: HashMap<String, String> =
[("ARROW:extension:metadata".to_string(), "test2".to_string())]
.into_iter()
.collect();
let metadata2 = FieldMetadata::from(metadata2);
assert!(
LogicalPlanBuilder::values(vec![
vec![lit_with_metadata(1, Some(metadata.clone()))],
vec![lit_with_metadata(2, Some(metadata2.clone()))],
])
.is_err()
);
Ok(())
}
#[test]
fn test_unique_field_aliases() {
let t1_field_1 = Field::new("a", DataType::Int32, false);
let t2_field_1 = Field::new("a", DataType::Int32, false);
let t2_field_3 = Field::new("a", DataType::Int32, false);
let t2_field_4 = Field::new("a:1", DataType::Int32, false);
let t1_field_2 = Field::new("b", DataType::Int32, false);
let t2_field_2 = Field::new("b", DataType::Int32, false);
let fields = vec![
t1_field_1, t2_field_1, t1_field_2, t2_field_2, t2_field_3, t2_field_4,
];
let fields = Fields::from(fields);
let remove_redundant = unique_field_aliases(&fields);
assert_eq!(
remove_redundant,
vec![
None,
Some("a:1".to_string()),
None,
Some("b:1".to_string()),
Some("a:2".to_string()),
Some("a:1:1".to_string()),
]
);
}
}