use std::ops::Deref;
use std::sync::Arc;
use crate::expressions::Column;
use crate::utils::collect_columns;
use crate::PhysicalExpr;
use arrow::datatypes::{Field, Schema, SchemaRef};
use datafusion_common::stats::{ColumnStatistics, Precision};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use indexmap::IndexMap;
use itertools::Itertools;
/// A single projected expression together with its output column name
/// (the `AS` alias).
#[derive(Debug, Clone)]
pub struct ProjectionExpr {
    /// The physical expression evaluated to produce the output column.
    pub expr: Arc<dyn PhysicalExpr>,
    /// The output column name for this expression.
    pub alias: String,
}
impl std::fmt::Display for ProjectionExpr {
    /// Formats as `expr AS alias`, or just `alias` when the expression's
    /// textual form already equals the alias (avoids `a@0 AS a@0` noise).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the expression once and reuse the string; the original
        // formatted `self.expr` a second time on the `AS` branch.
        let rendered = self.expr.to_string();
        if rendered == self.alias {
            write!(f, "{}", self.alias)
        } else {
            write!(f, "{} AS {}", rendered, self.alias)
        }
    }
}
impl ProjectionExpr {
    /// Creates a projection expression with an explicit output alias.
    pub fn new(expr: Arc<dyn PhysicalExpr>, alias: String) -> Self {
        Self { expr, alias }
    }

    /// Creates a projection expression whose alias is derived from the name
    /// of the field the expression produces against `schema`.
    ///
    /// # Errors
    /// Propagates any error from resolving the expression's return field.
    pub fn new_from_expression(
        expr: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Self> {
        let alias = expr.return_field(schema)?.name().to_string();
        Ok(Self { expr, alias })
    }
}
impl From<(Arc<dyn PhysicalExpr>, String)> for ProjectionExpr {
    /// Converts an owned `(expression, alias)` pair into a `ProjectionExpr`.
    fn from(value: (Arc<dyn PhysicalExpr>, String)) -> Self {
        let (expr, alias) = value;
        Self { expr, alias }
    }
}
impl From<&(Arc<dyn PhysicalExpr>, String)> for ProjectionExpr {
    /// Converts a borrowed `(expression, alias)` pair, cloning both parts
    /// (the `Arc` clone is a cheap refcount bump).
    fn from(value: &(Arc<dyn PhysicalExpr>, String)) -> Self {
        let (expr, alias) = value;
        Self::new(Arc::clone(expr), alias.clone())
    }
}
impl From<ProjectionExpr> for (Arc<dyn PhysicalExpr>, String) {
fn from(value: ProjectionExpr) -> Self {
(value.expr, value.alias)
}
}
/// An ordered list of projection expressions, modeling a full projection
/// (i.e. a `SELECT` list) applied to an input schema.
#[derive(Debug, Clone)]
pub struct ProjectionExprs {
    // The projections, in output-column order.
    exprs: Vec<ProjectionExpr>,
}
impl std::fmt::Display for ProjectionExprs {
    /// Formats as `Projection[expr1, expr2, ...]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Projection[")?;
        for (i, proj) in self.exprs.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{proj}")?;
        }
        write!(f, "]")
    }
}
impl From<Vec<ProjectionExpr>> for ProjectionExprs {
    /// Wraps an owned vector of projections without copying.
    fn from(exprs: Vec<ProjectionExpr>) -> Self {
        Self { exprs }
    }
}
impl From<&[ProjectionExpr]> for ProjectionExprs {
    /// Clones a slice of projections into an owned collection.
    fn from(value: &[ProjectionExpr]) -> Self {
        let exprs = Vec::from(value);
        Self { exprs }
    }
}
impl FromIterator<ProjectionExpr> for ProjectionExprs {
    /// Collects projections from any iterator, preserving iteration order.
    fn from_iter<T: IntoIterator<Item = ProjectionExpr>>(exprs: T) -> Self {
        let exprs: Vec<ProjectionExpr> = exprs.into_iter().collect();
        Self { exprs }
    }
}
impl AsRef<[ProjectionExpr]> for ProjectionExprs {
    /// Borrows the projections as a slice.
    fn as_ref(&self) -> &[ProjectionExpr] {
        self.exprs.as_slice()
    }
}
impl ProjectionExprs {
    /// Creates a projection from any iterator of [`ProjectionExpr`].
    pub fn new<I>(exprs: I) -> Self
    where
        I: IntoIterator<Item = ProjectionExpr>,
    {
        Self {
            exprs: exprs.into_iter().collect::<Vec<_>>(),
        }
    }

    /// Builds a pure column-selection projection from the given field
    /// indices of `schema`; each output column keeps its input field name
    /// as its alias.
    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Self {
        let projection_exprs = indices.iter().map(|&i| {
            let field = schema.field(i);
            ProjectionExpr {
                expr: Arc::new(Column::new(field.name(), i)),
                alias: field.name().clone(),
            }
        });
        Self::from_iter(projection_exprs)
    }

    /// Iterates over the projections in output-column order.
    pub fn iter(&self) -> impl Iterator<Item = &ProjectionExpr> {
        self.exprs.iter()
    }

    /// Builds a [`ProjectionMapping`] (source expression -> output columns)
    /// for these projections against `input_schema`.
    pub fn projection_mapping(
        &self,
        input_schema: &SchemaRef,
    ) -> Result<ProjectionMapping> {
        ProjectionMapping::try_new(
            self.exprs
                .iter()
                .map(|p| (Arc::clone(&p.expr), p.alias.clone())),
            input_schema,
        )
    }

    /// Iterates over the physical expressions (cloned `Arc`s), ignoring the
    /// aliases.
    pub fn expr_iter(&self) -> impl Iterator<Item = Arc<dyn PhysicalExpr>> + '_ {
        self.exprs.iter().map(|e| Arc::clone(&e.expr))
    }

    /// Combines `self` (applied first) with `other` (applied on top of
    /// `self`'s output) into one equivalent projection by inlining `self`'s
    /// expressions into `other`'s column references. Aliases come from
    /// `other`.
    ///
    /// # Errors
    /// Returns an internal error if an expression in `other` cannot be
    /// rewritten in terms of `self` (e.g. a column index out of bounds).
    pub fn try_merge(&self, other: &ProjectionExprs) -> Result<ProjectionExprs> {
        let mut new_exprs = Vec::with_capacity(other.exprs.len());
        for proj_expr in &other.exprs {
            // `sync_with_child = true`: column `i` in `other` refers to the
            // i-th output of `self`, so substitute `self.exprs[i].expr`.
            let new_expr = update_expr(&proj_expr.expr, &self.exprs, true)?
                .ok_or_else(|| {
                    internal_datafusion_err!(
                        "Failed to combine projections: expression {} could not be applied on top of existing projections {}",
                        proj_expr.expr,
                        self.exprs.iter().map(|e| format!("{e}")).join(", ")
                    )
                })?;
            new_exprs.push(ProjectionExpr {
                expr: new_expr,
                alias: proj_expr.alias.clone(),
            });
        }
        Ok(ProjectionExprs::new(new_exprs))
    }

    /// Returns the sorted, de-duplicated input column indices referenced
    /// anywhere in the projection expressions.
    pub fn column_indices(&self) -> Vec<usize> {
        self.exprs
            .iter()
            .flat_map(|e| collect_columns(&e.expr).into_iter().map(|col| col.index()))
            .sorted_unstable()
            .dedup()
            .collect_vec()
    }

    /// Returns each projection's input column index, in projection order
    /// (duplicates retained).
    ///
    /// # Panics
    /// Panics if any projection is not a bare [`Column`] reference.
    pub fn ordered_column_indices(&self) -> Vec<usize> {
        self.exprs
            .iter()
            .map(|e| {
                e.expr
                    .as_any()
                    .downcast_ref::<Column>()
                    .expect("Expected column reference in projection")
                    .index()
            })
            .collect()
    }

    /// Computes the output schema of applying this projection to
    /// `input_schema`. Field metadata is taken from each expression's
    /// return field; schema-level metadata is carried over from the input.
    pub fn project_schema(&self, input_schema: &Schema) -> Result<Schema> {
        let fields: Result<Vec<Field>> = self
            .exprs
            .iter()
            .map(|proj_expr| {
                let metadata = proj_expr
                    .expr
                    .return_field(input_schema)?
                    .metadata()
                    .clone();
                let field = Field::new(
                    &proj_expr.alias,
                    proj_expr.expr.data_type(input_schema)?,
                    proj_expr.expr.nullable(input_schema)?,
                )
                .with_metadata(metadata);
                Ok(field)
            })
            .collect();
        Ok(Schema::new_with_metadata(
            fields?,
            input_schema.metadata().clone(),
        ))
    }

    /// Maps input statistics through the projection: bare column references
    /// keep their column statistics, any other expression becomes unknown.
    /// When every output type has a fixed primitive width, `total_byte_size`
    /// is recomputed exactly as `row width * num_rows`; otherwise the input
    /// value is kept as-is.
    pub fn project_statistics(
        &self,
        mut stats: datafusion_common::Statistics,
        input_schema: &Schema,
    ) -> Result<datafusion_common::Statistics> {
        let mut primitive_row_size = 0;
        let mut primitive_row_size_possible = true;
        let mut column_statistics = vec![];
        for proj_expr in &self.exprs {
            let expr = &proj_expr.expr;
            let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                stats.column_statistics[col.index()].clone()
            } else {
                // Non-trivial expression: its statistics cannot be derived
                // from a single input column here.
                ColumnStatistics::new_unknown()
            };
            column_statistics.push(col_stats);
            let data_type = expr.data_type(input_schema)?;
            if let Some(value) = data_type.primitive_width() {
                primitive_row_size += value;
                continue;
            }
            // At least one variable-width output column: we cannot compute
            // an exact byte size.
            primitive_row_size_possible = false;
        }
        if primitive_row_size_possible {
            stats.total_byte_size =
                Precision::Exact(primitive_row_size).multiply(&stats.num_rows);
        }
        stats.column_statistics = column_statistics;
        Ok(stats)
    }
}
impl<'a> IntoIterator for &'a ProjectionExprs {
    type Item = &'a ProjectionExpr;
    type IntoIter = std::slice::Iter<'a, ProjectionExpr>;

    /// Iterates over the projections by reference.
    fn into_iter(self) -> Self::IntoIter {
        self.exprs.as_slice().iter()
    }
}
impl IntoIterator for ProjectionExprs {
    type Item = ProjectionExpr;
    type IntoIter = std::vec::IntoIter<ProjectionExpr>;

    /// Consumes `self`, yielding the projections by value.
    fn into_iter(self) -> Self::IntoIter {
        let Self { exprs } = self;
        exprs.into_iter()
    }
}
/// Rewrites `expr` across the projection described by `projected_exprs`.
///
/// * `sync_with_child == true`: `expr` is defined on the projection's
///   *output*; each column with index `i` is replaced by the projection's
///   i-th expression, yielding an expression valid on the *input*.
/// * `sync_with_child == false`: `expr` is defined on the projection's
///   *input*; each column is replaced by the matching projected output
///   column (alias + output index), if the projection passes it through.
///
/// Returns `Ok(None)` when the expression cannot be fully rewritten (some
/// column is not preserved by the projection).
pub fn update_expr(
    expr: &Arc<dyn PhysicalExpr>,
    projected_exprs: &[ProjectionExpr],
    sync_with_child: bool,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    // Tracks the rewrite status across the bottom-up traversal.
    #[derive(Debug, PartialEq)]
    enum RewriteState {
        /// No column has been rewritten yet.
        Unchanged,
        /// Columns have been rewritten and the result is still valid.
        RewrittenValid,
        /// A column could not be rewritten; the overall result is invalid.
        RewrittenInvalid,
    }
    let mut state = RewriteState::Unchanged;
    let new_expr = Arc::clone(expr)
        .transform_up(|expr| {
            // Once invalid, short-circuit all remaining rewrites.
            if state == RewriteState::RewrittenInvalid {
                return Ok(Transformed::no(expr));
            }
            // Only column leaves are rewritten; other nodes pass through.
            let Some(column) = expr.as_any().downcast_ref::<Column>() else {
                return Ok(Transformed::no(expr));
            };
            if sync_with_child {
                state = RewriteState::RewrittenValid;
                // Substitute the projection expression at the column's index.
                let projected_expr = projected_exprs.get(column.index()).ok_or_else(|| {
                    internal_datafusion_err!(
                        "Column index {} out of bounds for projected expressions of length {}",
                        column.index(),
                        projected_exprs.len()
                    )
                })?;
                Ok(Transformed::yes(Arc::clone(&projected_expr.expr)))
            } else {
                // Assume invalid until a matching projected column is found.
                state = RewriteState::RewrittenInvalid;
                projected_exprs
                    .iter()
                    .enumerate()
                    .find_map(|(index, proj_expr)| {
                        proj_expr.expr.as_any().downcast_ref::<Column>().and_then(
                            |projected_column| {
                                (column.name().eq(projected_column.name())
                                    && column.index() == projected_column.index())
                                .then(|| {
                                    state = RewriteState::RewrittenValid;
                                    Arc::new(Column::new(&proj_expr.alias, index)) as _
                                })
                            },
                        )
                    })
                    .map_or_else(
                        || Ok(Transformed::no(expr)),
                        |c| Ok(Transformed::yes(c)),
                    )
            }
        })
        .data()?;
    // Only a fully valid rewrite is returned; otherwise `None`.
    Ok((state == RewriteState::RewrittenValid).then_some(new_expr))
}
/// The output (target) columns a single source expression maps to in a
/// [`ProjectionMapping`], stored as `(target expression, output index)` pairs.
#[derive(Clone, Debug, Default)]
pub struct ProjectionTargets {
    // Expected to be non-empty once stored in a mapping (`first` panics on
    // an empty list); `ProjectionMapping::try_new` always pushes at least
    // one target per entry.
    exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
}
impl ProjectionTargets {
    /// Returns the first target `(expression, output index)` pair.
    ///
    /// # Panics
    /// Panics if there are no targets. Entries created through
    /// [`ProjectionMapping::try_new`] always contain at least one target.
    pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
        // `expect` with a message documents the invariant better than the
        // original bare `unwrap` if it ever fires.
        self.exprs_indices
            .first()
            .expect("ProjectionTargets should contain at least one target")
    }

    /// Appends a `(target expression, output index)` pair.
    pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
        self.exprs_indices.push(target);
    }
}
impl Deref for ProjectionTargets {
    type Target = [(Arc<dyn PhysicalExpr>, usize)];

    /// Dereferences to the underlying slice of target pairs.
    fn deref(&self) -> &Self::Target {
        self.exprs_indices.as_slice()
    }
}
impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
Self { exprs_indices }
}
}
/// Maps each source expression of a projection to the output columns
/// (targets) it produces, preserving insertion order.
#[derive(Clone, Debug)]
pub struct ProjectionMapping {
    // Keyed by the normalized source expression; values are all output
    // columns that expression feeds.
    map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
}
impl ProjectionMapping {
    /// Builds a mapping from `(expression, output name)` pairs against
    /// `input_schema`.
    ///
    /// The pair at position `i` produces the target column `name@i`. Each
    /// source expression is normalized by rebuilding every embedded
    /// [`Column`] from the input schema field at its index, so that equal
    /// expressions compare (and hash) equal; identical source expressions
    /// then share one entry and accumulate multiple targets.
    ///
    /// # Errors
    /// Returns an internal error if a column's name disagrees with the
    /// input schema field at its index.
    pub fn try_new(
        expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
        input_schema: &SchemaRef,
    ) -> Result<Self> {
        let mut map = IndexMap::<_, ProjectionTargets>::new();
        for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
            let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
            // Normalize columns in the source expression against the input
            // schema, verifying name/index consistency along the way.
            let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
                Some(col) => {
                    let idx = col.index();
                    let matching_field = input_schema.field(idx);
                    let matching_name = matching_field.name();
                    if col.name() != matching_name {
                        return internal_err!(
                            "Input field name {} does not match with the projection expression {}",
                            matching_name,
                            col.name()
                        );
                    }
                    let matching_column = Column::new(matching_name, idx);
                    Ok(Transformed::yes(Arc::new(matching_column)))
                }
                None => Ok(Transformed::no(e)),
            })
            .data()?;
            map.entry(source_expr)
                .or_default()
                .push((target_expr, expr_idx));
        }
        Ok(Self { map })
    }

    /// Builds an identity (column-only) mapping for the given field
    /// `indices` of `schema`.
    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
        let projection_exprs = indices.iter().map(|index| {
            let field = schema.field(*index);
            let column = Arc::new(Column::new(field.name(), *index));
            (column as _, field.name().clone())
        });
        ProjectionMapping::try_new(projection_exprs, schema)
    }
}
impl Deref for ProjectionMapping {
    type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;

    /// Gives read-only access to the underlying source -> targets map.
    fn deref(&self) -> &Self::Target {
        &self.map
    }
}
impl FromIterator<(Arc<dyn PhysicalExpr>, ProjectionTargets)> for ProjectionMapping {
fn from_iter<T: IntoIterator<Item = (Arc<dyn PhysicalExpr>, ProjectionTargets)>>(
iter: T,
) -> Self {
Self {
map: IndexMap::from_iter(iter),
}
}
}
/// Projects every ordering in `orderings` onto `schema`, dropping those
/// for which no leading prefix can be expressed in the new schema.
pub fn project_orderings(
    orderings: &[LexOrdering],
    schema: &SchemaRef,
) -> Vec<LexOrdering> {
    orderings
        .iter()
        .filter_map(|ordering| project_ordering(ordering, schema))
        .collect()
}
/// Projects a lexicographic ordering onto `schema` by renumbering each
/// column reference to its position in `schema`.
///
/// Sort expressions are processed in order; at the first expression that
/// references a column missing from `schema`, the ordering is truncated
/// (a leading prefix of a sort order is still a valid sort order).
/// Returns `None` when no leading expression could be projected.
pub fn project_ordering(
    ordering: &LexOrdering,
    schema: &SchemaRef,
) -> Option<LexOrdering> {
    let mut projected_exprs = vec![];
    for PhysicalSortExpr { expr, options } in ordering.iter() {
        let transformed = Arc::clone(expr).transform_up(|expr| {
            let Some(col) = expr.as_any().downcast_ref::<Column>() else {
                return Ok(Transformed::no(expr));
            };
            let name = col.name();
            if let Some((idx, _)) = schema.column_with_name(name) {
                Ok(Transformed::yes(Arc::new(Column::new(name, idx))))
            } else {
                // The error only signals "stop projecting here" (see the
                // `Err` arm below), but give it a real message instead of
                // the previous empty string for debuggability.
                plan_err!("Sort column {name} does not exist in the projected schema")
            }
        });
        match transformed {
            Ok(transformed) => {
                projected_exprs.push(PhysicalSortExpr::new(transformed.data, *options));
            }
            Err(_) => {
                // A column is missing from the projected schema: keep only
                // the prefix collected so far.
                break;
            }
        }
    }
    LexOrdering::new(projected_exprs)
}
#[cfg(test)]
pub(crate) mod tests {
use std::collections::HashMap;
use super::*;
use crate::equivalence::{convert_to_orderings, EquivalenceProperties};
use crate::expressions::{col, BinaryExpr, Literal};
use crate::utils::tests::TestScalarUDF;
use crate::{PhysicalExprRef, ScalarFunctionExpr};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{ScalarValue, Statistics};
use datafusion_expr::{Operator, ScalarUDF};
use insta::assert_snapshot;
    /// Constructs the output schema produced by applying `mapping` to
    /// `input_schema`: one field per target column, each carrying its
    /// source expression's data type and nullability.
    pub(crate) fn output_schema(
        mapping: &ProjectionMapping,
        input_schema: &Arc<Schema>,
    ) -> Result<SchemaRef> {
        let mut fields = vec![];
        for (source, targets) in mapping.iter() {
            let data_type = source.data_type(input_schema)?;
            let nullable = source.nullable(input_schema)?;
            // A single source expression may feed several output columns.
            for (target, _) in targets.iter() {
                let Some(column) = target.as_any().downcast_ref::<Column>() else {
                    return plan_err!("Expects to have column");
                };
                fields.push(Field::new(column.name(), data_type.clone(), nullable));
            }
        }
        let output_schema = Arc::new(Schema::new_with_metadata(
            fields,
            input_schema.metadata().clone(),
        ));
        Ok(output_schema)
    }
#[test]
fn project_orderings() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
Field::new("c", DataType::Int32, true),
Field::new("d", DataType::Int32, true),
Field::new("e", DataType::Int32, true),
Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
]));
let col_a = &col("a", &schema)?;
let col_b = &col("b", &schema)?;
let col_c = &col("c", &schema)?;
let col_d = &col("d", &schema)?;
let col_e = &col("e", &schema)?;
let col_ts = &col("ts", &schema)?;
let a_plus_b = Arc::new(BinaryExpr::new(
Arc::clone(col_a),
Operator::Plus,
Arc::clone(col_b),
)) as Arc<dyn PhysicalExpr>;
let b_plus_d = Arc::new(BinaryExpr::new(
Arc::clone(col_b),
Operator::Plus,
Arc::clone(col_d),
)) as Arc<dyn PhysicalExpr>;
let b_plus_e = Arc::new(BinaryExpr::new(
Arc::clone(col_b),
Operator::Plus,
Arc::clone(col_e),
)) as Arc<dyn PhysicalExpr>;
let c_plus_d = Arc::new(BinaryExpr::new(
Arc::clone(col_c),
Operator::Plus,
Arc::clone(col_d),
)) as Arc<dyn PhysicalExpr>;
let option_asc = SortOptions {
descending: false,
nulls_first: false,
};
let option_desc = SortOptions {
descending: true,
nulls_first: true,
};
let test_cases = vec![
(
vec![
vec![(col_b, option_asc)],
],
vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
vec![
vec![("b_new", option_asc)],
],
),
(
vec![
],
vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
vec![
],
),
(
vec![
vec![(col_ts, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_ts, "ts_new".to_string()),
],
vec![
vec![("ts_new", option_asc)],
],
),
(
vec![
vec![(col_a, option_asc), (col_ts, option_asc)],
vec![(col_b, option_asc), (col_ts, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_ts, "ts_new".to_string()),
],
vec![
vec![("a_new", option_asc), ("ts_new", option_asc)],
vec![("b_new", option_asc), ("ts_new", option_asc)],
],
),
(
vec![
vec![(&a_plus_b, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(&a_plus_b, "a+b".to_string()),
],
vec![
vec![("a+b", option_asc)],
],
),
(
vec![
vec![(&a_plus_b, option_asc), (col_c, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_c, "c_new".to_string()),
(&a_plus_b, "a+b".to_string()),
],
vec![
vec![("a+b", option_asc), ("c_new", option_asc)],
],
),
(
vec![
vec![(col_a, option_asc), (col_b, option_asc)],
vec![(col_a, option_asc), (col_d, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_d, "d_new".to_string()),
(&b_plus_d, "b+d".to_string()),
],
vec![
vec![("a_new", option_asc), ("b_new", option_asc)],
vec![("a_new", option_asc), ("d_new", option_asc)],
vec![("a_new", option_asc), ("b+d", option_asc)],
],
),
(
vec![
vec![(&b_plus_d, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_d, "d_new".to_string()),
(&b_plus_d, "b+d".to_string()),
],
vec![
vec![("b+d", option_asc)],
],
),
(
vec![
vec![
(col_a, option_asc),
(col_d, option_asc),
(col_b, option_asc),
],
vec![(col_c, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_d, "d_new".to_string()),
(col_c, "c_new".to_string()),
],
vec![
vec![
("a_new", option_asc),
("d_new", option_asc),
("b_new", option_asc),
],
vec![("c_new", option_asc)],
],
),
(
vec![
vec![
(col_a, option_asc),
(col_b, option_asc),
(col_c, option_asc),
],
vec![(col_a, option_asc), (col_d, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_c, "c_new".to_string()),
(&c_plus_d, "c+d".to_string()),
],
vec![
vec![
("a_new", option_asc),
("b_new", option_asc),
("c_new", option_asc),
],
vec![
("a_new", option_asc),
("b_new", option_asc),
("c+d", option_asc),
],
],
),
(
vec![
vec![(col_a, option_asc), (col_b, option_asc)],
vec![(col_a, option_asc), (col_d, option_asc)],
],
vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(&b_plus_d, "b+d".to_string()),
],
vec![
vec![("a_new", option_asc), ("b_new", option_asc)],
vec![("a_new", option_asc), ("b+d", option_asc)],
],
),
(
vec![
vec![
(col_a, option_asc),
(col_b, option_asc),
(col_c, option_asc),
],
],
vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
vec![
vec![("a_new", option_asc)],
],
),
(
vec![
vec![
(col_a, option_asc),
(col_b, option_asc),
(col_c, option_asc),
],
vec![
(col_a, option_asc),
(&a_plus_b, option_asc),
(col_c, option_asc),
],
],
vec![
(col_c, "c_new".to_string()),
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(&a_plus_b, "a+b".to_string()),
],
vec![
vec![
("a_new", option_asc),
("b_new", option_asc),
("c_new", option_asc),
],
vec![
("a_new", option_asc),
("a+b", option_asc),
("c_new", option_asc),
],
],
),
(
vec![
vec![(col_a, option_asc), (col_b, option_asc)],
vec![(col_c, option_asc), (col_b, option_asc)],
vec![(col_d, option_asc), (col_e, option_asc)],
],
vec![
(col_c, "c_new".to_string()),
(col_d, "d_new".to_string()),
(col_a, "a_new".to_string()),
(&b_plus_e, "b+e".to_string()),
],
vec![
vec![
("a_new", option_asc),
("d_new", option_asc),
("b+e", option_asc),
],
vec![
("d_new", option_asc),
("a_new", option_asc),
("b+e", option_asc),
],
vec![
("c_new", option_asc),
("d_new", option_asc),
("b+e", option_asc),
],
vec![
("d_new", option_asc),
("c_new", option_asc),
("b+e", option_asc),
],
],
),
(
vec![
vec![
(col_a, option_asc),
(col_c, option_asc),
(col_b, option_asc),
],
],
vec![
(col_c, "c_new".to_string()),
(col_a, "a_new".to_string()),
(&a_plus_b, "a+b".to_string()),
],
vec![
vec![
("a_new", option_asc),
("c_new", option_asc),
("a+b", option_asc),
],
],
),
(
vec![
vec![(col_a, option_asc), (col_b, option_asc)],
vec![(col_c, option_asc), (col_b, option_desc)],
vec![(col_e, option_asc)],
],
vec![
(col_c, "c_new".to_string()),
(col_a, "a_new".to_string()),
(col_b, "b_new".to_string()),
(&b_plus_e, "b+e".to_string()),
],
vec![
vec![("a_new", option_asc), ("b_new", option_asc)],
vec![("a_new", option_asc), ("b+e", option_asc)],
vec![("c_new", option_asc), ("b_new", option_desc)],
],
),
];
for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
{
let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
let orderings = convert_to_orderings(&orderings);
eq_properties.add_orderings(orderings);
let proj_exprs = proj_exprs
.into_iter()
.map(|(expr, name)| (Arc::clone(expr), name));
let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
let output_schema = output_schema(&projection_mapping, &schema)?;
let expected = expected
.into_iter()
.map(|ordering| {
ordering
.into_iter()
.map(|(name, options)| {
(col(name, &output_schema).unwrap(), options)
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let expected = convert_to_orderings(&expected);
let projected_eq = eq_properties.project(&projection_mapping, output_schema);
let orderings = projected_eq.oeq_class();
let err_msg = format!(
"test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
);
assert_eq!(orderings.len(), expected.len(), "{err_msg}");
for expected_ordering in &expected {
assert!(orderings.contains(expected_ordering), "{}", err_msg)
}
}
Ok(())
}
#[test]
fn project_orderings2() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
Field::new("c", DataType::Int32, true),
Field::new("d", DataType::Int32, true),
Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
]));
let col_a = &col("a", &schema)?;
let col_b = &col("b", &schema)?;
let col_c = &col("c", &schema)?;
let col_ts = &col("ts", &schema)?;
let a_plus_b = Arc::new(BinaryExpr::new(
Arc::clone(col_a),
Operator::Plus,
Arc::clone(col_b),
)) as Arc<dyn PhysicalExpr>;
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let round_c = Arc::new(ScalarFunctionExpr::try_new(
test_fun,
vec![Arc::clone(col_c)],
&schema,
Arc::new(ConfigOptions::default()),
)?) as PhysicalExprRef;
let option_asc = SortOptions {
descending: false,
nulls_first: false,
};
let proj_exprs = vec![
(col_b, "b_new".to_string()),
(col_a, "a_new".to_string()),
(col_c, "c_new".to_string()),
(&round_c, "round_c_res".to_string()),
];
let proj_exprs = proj_exprs
.into_iter()
.map(|(expr, name)| (Arc::clone(expr), name));
let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
let output_schema = output_schema(&projection_mapping, &schema)?;
let col_a_new = &col("a_new", &output_schema)?;
let col_b_new = &col("b_new", &output_schema)?;
let col_c_new = &col("c_new", &output_schema)?;
let col_round_c_res = &col("round_c_res", &output_schema)?;
let a_new_plus_b_new = Arc::new(BinaryExpr::new(
Arc::clone(col_a_new),
Operator::Plus,
Arc::clone(col_b_new),
)) as Arc<dyn PhysicalExpr>;
let test_cases = [
(
vec![
vec![(col_a, option_asc)],
],
vec![
vec![(col_a_new, option_asc)],
],
),
(
vec![
vec![(&a_plus_b, option_asc)],
],
vec![
vec![(&a_new_plus_b_new, option_asc)],
],
),
(
vec![
vec![(col_a, option_asc), (col_ts, option_asc)],
],
vec![
vec![(col_a_new, option_asc)],
],
),
(
vec![
vec![
(col_a, option_asc),
(col_ts, option_asc),
(col_b, option_asc),
],
],
vec![
vec![(col_a_new, option_asc)],
],
),
(
vec![
vec![(col_a, option_asc), (col_c, option_asc)],
],
vec![
vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
vec![(col_a_new, option_asc), (col_c_new, option_asc)],
],
),
(
vec![
vec![(col_c, option_asc), (col_b, option_asc)],
],
vec![
vec![(col_round_c_res, option_asc)],
vec![(col_c_new, option_asc), (col_b_new, option_asc)],
],
),
(
vec![
vec![(&a_plus_b, option_asc), (col_c, option_asc)],
],
vec![
vec![
(&a_new_plus_b_new, option_asc),
(col_round_c_res, option_asc),
],
vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
],
),
];
for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
let orderings = convert_to_orderings(orderings);
eq_properties.add_orderings(orderings);
let expected = convert_to_orderings(expected);
let projected_eq =
eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
let orderings = projected_eq.oeq_class();
let err_msg = format!(
"test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
);
assert_eq!(orderings.len(), expected.len(), "{err_msg}");
for expected_ordering in &expected {
assert!(orderings.contains(expected_ordering), "{}", err_msg)
}
}
Ok(())
}
#[test]
fn project_orderings3() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
Field::new("c", DataType::Int32, true),
Field::new("d", DataType::Int32, true),
Field::new("e", DataType::Int32, true),
Field::new("f", DataType::Int32, true),
]));
let col_a = &col("a", &schema)?;
let col_b = &col("b", &schema)?;
let col_c = &col("c", &schema)?;
let col_d = &col("d", &schema)?;
let col_e = &col("e", &schema)?;
let col_f = &col("f", &schema)?;
let a_plus_b = Arc::new(BinaryExpr::new(
Arc::clone(col_a),
Operator::Plus,
Arc::clone(col_b),
)) as Arc<dyn PhysicalExpr>;
let option_asc = SortOptions {
descending: false,
nulls_first: false,
};
let proj_exprs = vec![
(col_c, "c_new".to_string()),
(col_d, "d_new".to_string()),
(&a_plus_b, "a+b".to_string()),
];
let proj_exprs = proj_exprs
.into_iter()
.map(|(expr, name)| (Arc::clone(expr), name));
let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
let output_schema = output_schema(&projection_mapping, &schema)?;
let col_a_plus_b_new = &col("a+b", &output_schema)?;
let col_c_new = &col("c_new", &output_schema)?;
let col_d_new = &col("d_new", &output_schema)?;
let test_cases = vec![
(
vec![
vec![(col_d, option_asc), (col_b, option_asc)],
vec![(col_c, option_asc), (col_a, option_asc)],
],
vec![],
vec![
vec![
(col_d_new, option_asc),
(col_c_new, option_asc),
(col_a_plus_b_new, option_asc),
],
vec![
(col_c_new, option_asc),
(col_d_new, option_asc),
(col_a_plus_b_new, option_asc),
],
],
),
(
vec![
vec![(col_d, option_asc), (col_b, option_asc)],
vec![(col_c, option_asc), (col_e, option_asc)],
],
vec![(col_e, col_a)],
vec![
vec![
(col_d_new, option_asc),
(col_c_new, option_asc),
(col_a_plus_b_new, option_asc),
],
vec![
(col_c_new, option_asc),
(col_d_new, option_asc),
(col_a_plus_b_new, option_asc),
],
],
),
(
vec![
vec![(col_d, option_asc), (col_b, option_asc)],
vec![(col_c, option_asc), (col_e, option_asc)],
],
vec![(col_a, col_f)],
vec![
vec![(col_d_new, option_asc)],
vec![(col_c_new, option_asc)],
],
),
];
for (orderings, equal_columns, expected) in test_cases {
let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
for (lhs, rhs) in equal_columns {
eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?;
}
let orderings = convert_to_orderings(&orderings);
eq_properties.add_orderings(orderings);
let expected = convert_to_orderings(&expected);
let projected_eq =
eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
let orderings = projected_eq.oeq_class();
let err_msg = format!(
"actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
);
assert_eq!(orderings.len(), expected.len(), "{err_msg}");
for expected_ordering in &expected {
assert!(orderings.contains(expected_ordering), "{}", err_msg)
}
}
Ok(())
}
    /// Sample input statistics for three columns (Int64, Utf8, Float32),
    /// matching the schema from [`get_schema`].
    fn get_stats() -> Statistics {
        Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
                    null_count: Precision::Absent,
                },
            ],
        }
    }
fn get_schema() -> Schema {
let field_0 = Field::new("col0", DataType::Int64, false);
let field_1 = Field::new("col1", DataType::Utf8, false);
let field_2 = Field::new("col2", DataType::Float32, false);
Schema::new(vec![field_0, field_1, field_2])
}
#[test]
fn test_stats_projection_columns_only() {
let source = get_stats();
let schema = get_schema();
let projection = ProjectionExprs::new(vec![
ProjectionExpr {
expr: Arc::new(Column::new("col1", 1)),
alias: "col1".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("col0", 0)),
alias: "col0".to_string(),
},
]);
let result = projection.project_statistics(source, &schema).unwrap();
let expected = Statistics {
num_rows: Precision::Exact(5),
total_byte_size: Precision::Exact(23),
column_statistics: vec![
ColumnStatistics {
distinct_count: Precision::Exact(1),
max_value: Precision::Exact(ScalarValue::from("x")),
min_value: Precision::Exact(ScalarValue::from("a")),
sum_value: Precision::Absent,
null_count: Precision::Exact(3),
},
ColumnStatistics {
distinct_count: Precision::Exact(5),
max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
null_count: Precision::Exact(0),
},
],
};
assert_eq!(result, expected);
}
#[test]
fn test_stats_projection_column_with_primitive_width_only() {
let source = get_stats();
let schema = get_schema();
let projection = ProjectionExprs::new(vec![
ProjectionExpr {
expr: Arc::new(Column::new("col2", 2)),
alias: "col2".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("col0", 0)),
alias: "col0".to_string(),
},
]);
let result = projection.project_statistics(source, &schema).unwrap();
let expected = Statistics {
num_rows: Precision::Exact(5),
total_byte_size: Precision::Exact(60),
column_statistics: vec![
ColumnStatistics {
distinct_count: Precision::Absent,
max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
null_count: Precision::Absent,
},
ColumnStatistics {
distinct_count: Precision::Exact(5),
max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
null_count: Precision::Exact(0),
},
],
};
assert_eq!(result, expected);
}
#[test]
fn test_projection_new() -> Result<()> {
let exprs = vec![
ProjectionExpr {
expr: Arc::new(Column::new("a", 0)),
alias: "a".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("b", 1)),
alias: "b".to_string(),
},
];
let projection = ProjectionExprs::new(exprs.clone());
assert_eq!(projection.as_ref().len(), 2);
Ok(())
}
#[test]
fn test_projection_from_vec() -> Result<()> {
let exprs = vec![ProjectionExpr {
expr: Arc::new(Column::new("x", 0)),
alias: "x".to_string(),
}];
let projection: ProjectionExprs = exprs.clone().into();
assert_eq!(projection.as_ref().len(), 1);
Ok(())
}
#[test]
fn test_projection_as_ref() -> Result<()> {
let exprs = vec![
ProjectionExpr {
expr: Arc::new(Column::new("col1", 0)),
alias: "col1".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("col2", 1)),
alias: "col2".to_string(),
},
];
let projection = ProjectionExprs::new(exprs);
let as_ref: &[ProjectionExpr] = projection.as_ref();
assert_eq!(as_ref.len(), 2);
Ok(())
}
#[test]
fn test_column_indices_multiple_columns() -> Result<()> {
let projection = ProjectionExprs::new(vec![
ProjectionExpr {
expr: Arc::new(Column::new("c", 5)),
alias: "c".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("b", 2)),
alias: "b".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("a", 0)),
alias: "a".to_string(),
},
]);
assert_eq!(projection.column_indices(), vec![0, 2, 5]);
Ok(())
}
#[test]
fn test_column_indices_duplicates() -> Result<()> {
let projection = ProjectionExprs::new(vec![
ProjectionExpr {
expr: Arc::new(Column::new("a", 1)),
alias: "a".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("b", 3)),
alias: "b".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("a2", 1)), alias: "a2".to_string(),
},
]);
assert_eq!(projection.column_indices(), vec![1, 3]);
Ok(())
}
#[test]
fn test_column_indices_unsorted() -> Result<()> {
let projection = ProjectionExprs::new(vec![
ProjectionExpr {
expr: Arc::new(Column::new("c", 5)),
alias: "c".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("a", 1)),
alias: "a".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("b", 3)),
alias: "b".to_string(),
},
]);
assert_eq!(projection.column_indices(), vec![1, 3, 5]);
Ok(())
}
#[test]
fn test_column_indices_complex_expr() -> Result<()> {
let expr = Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 1)),
Operator::Plus,
Arc::new(Column::new("b", 4)),
));
let projection = ProjectionExprs::new(vec![
ProjectionExpr {
expr,
alias: "sum".to_string(),
},
ProjectionExpr {
expr: Arc::new(Column::new("c", 2)),
alias: "c".to_string(),
},
]);
assert_eq!(projection.column_indices(), vec![1, 2, 4]);
Ok(())
}
#[test]
fn test_column_indices_empty() -> Result<()> {
let projection = ProjectionExprs::new(vec![]);
assert_eq!(projection.column_indices(), Vec::<usize>::new());
Ok(())
}
#[test]
fn test_merge_simple_columns() -> Result<()> {
    // Base projection renames c->x, b->y, a->z; the top projection then
    // selects y and x. Merging must substitute the base expressions into
    // the top, producing b AS col2 and c AS col1.
    let base_projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("c", 2)), "x".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("b", 1)), "y".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("a", 0)), "z".to_string()),
    ]);
    let top_projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("y", 1)), "col2".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("x", 0)), "col1".to_string()),
    ]);
    let merged = base_projection.try_merge(&top_projection)?;
    assert_snapshot!(format!("{merged}"), @"Projection[b@1 AS col2, c@2 AS col1]");
    Ok(())
}
#[test]
fn test_merge_with_expressions() -> Result<()> {
    // The top projection contains compound expressions over the base's
    // output; merging must rewrite every column reference inside those
    // expressions back to the base's inputs.
    let base_projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("c", 2)), "x".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("b", 1)), "y".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("a", 0)), "z".to_string()),
    ]);
    // c2 = y + z, c1 = x + 1 in terms of the base projection's output.
    let y_plus_z = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("y", 1)),
        Operator::Plus,
        Arc::new(Column::new("z", 2)),
    ));
    let x_plus_one = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("x", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let top_projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(y_plus_z, "c2".to_string()),
        ProjectionExpr::new(x_plus_one, "c1".to_string()),
    ]);
    let merged = base_projection.try_merge(&top_projection)?;
    assert_snapshot!(format!("{merged}"), @"Projection[b@1 + a@0 AS c2, c@2 + 1 AS c1]");
    Ok(())
}
#[test]
fn try_merge_error() {
    // The top projection references column index 5, but the base only
    // produces two columns — merging must fail with an internal error.
    let base = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("a", 0)), "x".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("b", 1)), "y".to_string()),
    ]);
    let top = ProjectionExprs::new(vec![ProjectionExpr::new(
        Arc::new(Column::new("z", 5)),
        "result".to_string(),
    )]);
    let err_msg = base.try_merge(&top).unwrap_err().to_string();
    assert!(
        err_msg.contains("Internal error: Column index 5 out of bounds for projected expressions of length 2"),
        "Unexpected error message: {err_msg}",
    );
}
#[test]
fn test_project_schema_simple_columns() -> Result<()> {
    // Selecting columns under new aliases must carry the source field's
    // data type while renaming the output field.
    let input_schema = get_schema();
    let projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("col2", 2)), "c".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("col0", 0)), "a".to_string()),
    ]);
    let output_schema = projection.project_schema(&input_schema)?;
    assert_eq!(output_schema.fields().len(), 2);
    // col2 (Float32) projected as "c", col0 (Int64) projected as "a".
    assert_eq!(output_schema.field(0).name(), "c");
    assert_eq!(output_schema.field(0).data_type(), &DataType::Float32);
    assert_eq!(output_schema.field(1).name(), "a");
    assert_eq!(output_schema.field(1).data_type(), &DataType::Int64);
    Ok(())
}
#[test]
fn test_project_schema_with_expressions() -> Result<()> {
    // A computed expression (col0 + 1) must derive its output field type
    // from the expression itself, here Int64.
    let input_schema = get_schema();
    let incremented = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("col0", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
    ));
    let projection = ProjectionExprs::new(vec![ProjectionExpr::new(
        incremented,
        "incremented".to_string(),
    )]);
    let output_schema = projection.project_schema(&input_schema)?;
    assert_eq!(output_schema.fields().len(), 1);
    assert_eq!(output_schema.field(0).name(), "incremented");
    assert_eq!(output_schema.field(0).data_type(), &DataType::Int64);
    Ok(())
}
#[test]
fn test_project_schema_preserves_metadata() -> Result<()> {
    // Field-level metadata on the source column must survive projection,
    // even when the column is renamed.
    let metadata =
        HashMap::from([("key".to_string(), "value".to_string())]);
    let input_schema = Schema::new(vec![
        Field::new("col0", DataType::Int64, false).with_metadata(metadata.clone()),
        Field::new("col1", DataType::Utf8, false),
    ]);
    let projection = ProjectionExprs::new(vec![ProjectionExpr::new(
        Arc::new(Column::new("col0", 0)),
        "renamed".to_string(),
    )]);
    let output_schema = projection.project_schema(&input_schema)?;
    assert_eq!(output_schema.fields().len(), 1);
    assert_eq!(output_schema.field(0).name(), "renamed");
    assert_eq!(output_schema.field(0).metadata(), &metadata);
    Ok(())
}
#[test]
fn test_project_schema_empty() -> Result<()> {
    // Projecting zero expressions yields a schema with zero fields.
    let input_schema = get_schema();
    let projection = ProjectionExprs::new(Vec::new());
    let output_schema = projection.project_schema(&input_schema)?;
    assert_eq!(output_schema.fields().len(), 0);
    Ok(())
}
#[test]
fn test_project_statistics_columns_only() -> Result<()> {
    // Pure column projections must carry over the per-column statistics,
    // reordered to match the output column order (col1 first, col0 second).
    let input_stats = get_stats();
    let input_schema = get_schema();
    let projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("col1", 1)), "text".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("col0", 0)), "num".to_string()),
    ]);
    let output_stats = projection.project_statistics(input_stats, &input_schema)?;
    // Row count is unaffected by projection.
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    assert_eq!(output_stats.column_statistics.len(), 2);
    // Output column 0 holds col1's stats...
    assert_eq!(
        output_stats.column_statistics[0].distinct_count,
        Precision::Exact(1)
    );
    assert_eq!(
        output_stats.column_statistics[0].max_value,
        Precision::Exact(ScalarValue::from("x"))
    );
    // ...and output column 1 holds col0's stats.
    assert_eq!(
        output_stats.column_statistics[1].distinct_count,
        Precision::Exact(5)
    );
    assert_eq!(
        output_stats.column_statistics[1].max_value,
        Precision::Exact(ScalarValue::Int64(Some(21)))
    );
    Ok(())
}
#[test]
fn test_project_statistics_with_expressions() -> Result<()> {
    // A computed expression has no derivable column statistics (Absent),
    // while a plain column reference in the same projection keeps its stats.
    let input_stats = get_stats();
    let input_schema = get_schema();
    let incremented = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("col0", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
    ));
    let projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(incremented, "incremented".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("col1", 1)), "text".to_string()),
    ]);
    let output_stats = projection.project_statistics(input_stats, &input_schema)?;
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    assert_eq!(output_stats.column_statistics.len(), 2);
    // No statistics can be claimed for the computed column.
    assert_eq!(
        output_stats.column_statistics[0].distinct_count,
        Precision::Absent
    );
    assert_eq!(
        output_stats.column_statistics[0].max_value,
        Precision::Absent
    );
    // The plain column keeps its exact distinct count.
    assert_eq!(
        output_stats.column_statistics[1].distinct_count,
        Precision::Exact(1)
    );
    Ok(())
}
#[test]
fn test_project_statistics_primitive_width_only() -> Result<()> {
    // When only fixed-width primitive columns are projected, the total
    // byte size can be recomputed exactly from the column widths.
    let input_stats = get_stats();
    let input_schema = get_schema();
    let projection = ProjectionExprs::new(vec![
        ProjectionExpr::new(Arc::new(Column::new("col2", 2)), "f".to_string()),
        ProjectionExpr::new(Arc::new(Column::new("col0", 0)), "i".to_string()),
    ]);
    let output_stats = projection.project_statistics(input_stats, &input_schema)?;
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    // 5 rows * (4 bytes Float32 + 8 bytes Int64) = 60 bytes.
    assert_eq!(output_stats.total_byte_size, Precision::Exact(60));
    assert_eq!(output_stats.column_statistics.len(), 2);
    Ok(())
}
#[test]
fn test_project_statistics_empty() -> Result<()> {
    // An empty projection keeps the row count but drops all column
    // statistics, leaving a zero byte size.
    let input_stats = get_stats();
    let input_schema = get_schema();
    let projection = ProjectionExprs::new(Vec::new());
    let output_stats = projection.project_statistics(input_stats, &input_schema)?;
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    assert_eq!(output_stats.column_statistics.len(), 0);
    assert_eq!(output_stats.total_byte_size, Precision::Exact(0));
    Ok(())
}
}