use crate::expr::Alias;
use crate::expr_rewriter::normalize_col;
use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Column, Result};
/// Rewrites every sort expression in `sorts` in terms of the output columns
/// of `plan`, keeping each expression's `asc` / `nulls_first` settings.
///
/// For example, when `plan` is a projection containing `min(c2) AS min_val`,
/// a sort on `min(c2)` is rewritten to a sort on the column `min_val`. When
/// `plan` does not have exactly one input, the sort expressions are returned
/// unchanged.
///
/// # Errors
///
/// Returns the first error produced while rewriting any of the expressions;
/// the remaining expressions are not processed.
pub fn rewrite_sort_cols_by_aggs(
    sorts: impl IntoIterator<Item = impl Into<Sort>>,
    plan: &LogicalPlan,
) -> Result<Vec<Sort>> {
    let mut rewritten = Vec::new();
    for item in sorts {
        let sort: Sort = item.into();
        let expr = rewrite_sort_col_by_aggs(sort.expr, plan)?;
        rewritten.push(Sort::new(expr, sort.asc, sort.nulls_first));
    }
    Ok(rewritten)
}
/// Rewrites a single sort expression in terms of `plan`'s expressions.
///
/// Only plans with exactly one input are rewritten, using that input to
/// normalize column references. Plans with zero or several inputs (e.g.
/// joins, unions) are not handled here, and `expr` is returned as-is.
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
    match plan.inputs().as_slice() {
        [input] => rewrite_in_terms_of_projection(expr, &plan.expressions(), input),
        _ => Ok(expr),
    }
}
/// Rewrites `expr` so that sub-expressions already produced by one of
/// `proj_exprs` are replaced by a reference to that expression's output
/// column, as named by [`Expr::qualified_name`] (the alias name for aliased
/// expressions).
///
/// The tree is rewritten with [`TreeNode::transform`]. Each node goes through
/// the steps below:
///
/// 1. If the node, as written, matches a projection expression (looking
///    through one level of alias), it becomes that expression's output column.
/// 2. Otherwise the node is normalized against `input` (e.g. `c1` -> `t.c1`).
///    If that fails, the node is left unchanged.
/// 3. The normalized node's schema name (e.g. `"min(t.c2)"`) is looked up as
///    an unqualified column among the *top-level* projection expressions. On a
///    hit, the node becomes that expression's output column. If the
///    normalized node was a `CAST`/`TRY_CAST`, the cast is re-applied on top
///    of the column.
fn rewrite_in_terms_of_projection(
    expr: Expr,
    proj_exprs: &[Expr],
    input: &LogicalPlan,
) -> Result<Expr> {
    expr.transform(|expr| {
        // Search for the expression exactly as written first.
        if let Some(found) = proj_exprs.iter().find(|a| expr_match(&expr, a)) {
            return Ok(Transformed::yes(output_column(found)));
        }

        // Otherwise try to match the expression by its output name. It must
        // first be normalized (e.g. "c1" --> "t.c1") against the input.
        let Ok(normalized_expr) = normalize_col(expr.clone(), input) else {
            // The expression cannot be resolved against `input`; skip it.
            return Ok(Transformed::no(expr));
        };

        // `normalized_expr` is an actual expression such as `min(t.c2)`; look
        // for a projection that exposes a column with that name.
        let name = normalized_expr.schema_name().to_string();
        let search_col = Expr::Column(Column::new_unqualified(name));

        // Only top-level projection expressions are searched, so a composite
        // expression such as `min(c2) + max(c2) AS combined` cannot shadow a
        // standalone `min(c2)`.
        let Some(found) = proj_exprs
            .iter()
            .find(|proj_expr| expr_match(&search_col, proj_expr))
        else {
            return Ok(Transformed::no(expr));
        };

        let col = output_column(found);
        // Keep an explicit CAST / TRY_CAST around the resolved column instead
        // of dropping it.
        Ok(Transformed::yes(match normalized_expr {
            Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
                expr: Box::new(col),
                field,
            }),
            Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast {
                expr: Box::new(col),
                field,
            }),
            _ => col,
        }))
    })
    .data()
}

/// Builds a column expression that refers to the output of `proj_expr`,
/// using the qualifier and name reported by [`Expr::qualified_name`].
fn output_column(proj_expr: &Expr) -> Expr {
    let (qualifier, field_name) = proj_expr.qualified_name();
    Expr::Column(Column::new(qualifier, field_name))
}
/// Returns `true` when `expr` equals `needle`, looking through one level of
/// alias: `avg(x)` matches both `avg(x)` and `avg(x) AS a`.
fn expr_match(needle: &Expr, expr: &Expr) -> bool {
    let unaliased = match expr {
        Expr::Alias(Alias { expr: inner, .. }) => inner.as_ref(),
        other => other,
    };
    unaliased == needle
}
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        LogicalPlanBuilder, cast, col, lit, logical_plan::builder::LogicalTableSource,
        try_cast,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::count;
    use crate::test::function_stub::max;
    use crate::test::function_stub::min;
    use crate::test::function_stub::sum;

    #[test]
    fn rewrite_sort_cols_by_agg() {
        // group by c1, aggregate min(c2)
        let agg = make_input()
            .aggregate(
                vec![col("c1")],
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: "c1 + c1 --> c1 + c1",
                input: sort(col("c1") + col("c1")),
                expected: sort(col("c1") + col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)"#,
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                vec![col("c1")],
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            .project(vec![
                col("c1").add(lit(1)).alias("c1"),
                min(col("c2")),
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(
                    col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")),
                ),
            },
            TestCase {
                desc: r#"avg(c3) --> "average" (column *named* "average", from alias)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    #[test]
    fn rewrite_sort_resolves_alias_to_column_ref() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
            .unwrap()
            .project(vec![
                col("c1"),
                min(col("c2")).alias("min_val"),
                max(col("c3")).alias("max_val"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "min(c2) with alias 'min_val' should resolve to col(min_val)",
                input: sort(min(col("c2"))),
                expected: sort(col("min_val")),
            },
            TestCase {
                desc: "max(c3) with alias 'max_val' should resolve to col(max_val)",
                input: sort(max(col("c3"))),
                expected: sort(col("max_val")),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn composite_proj_expr_containing_sort_col_as_subexpr() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
            .unwrap()
            .project(vec![
                col("c1"),
                (min(col("c2")) + max(col("c3"))).alias("range"),
                min(col("c2")).alias("min_val"),
                max(col("c3")).alias("max_val"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "sort by min(c2) should resolve to col(min_val), not col(range)",
                input: sort(min(col("c2"))),
                expected: sort(col("min_val")),
            },
            TestCase {
                desc: "sort by max(c3) should resolve to col(max_val), not col(range)",
                input: sort(max(col("c3"))),
                expected: sort(col("max_val")),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn composite_before_standalone_should_not_shadow() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
            .unwrap()
            .project(vec![
                col("c1"),
                (min(col("c2")) + max(col("c2"))).alias("combined"),
                min(col("c2")),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by min(c2) should resolve to col(min(t.c2)), not col(combined)",
            input: sort(min(col("c2"))),
            expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn duplicate_aggregate_in_multiple_proj_exprs() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2"))])
            .unwrap()
            .project(vec![
                col("c1"),
                min(col("c2")).alias("first_alias"),
                min(col("c2")).alias("second_alias"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by min(c2) with two aliases picks first_alias",
            input: sort(min(col("c2"))),
            expected: sort(col("first_alias")),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn sort_agg_not_in_select_with_aliased_aggs() {
        let plan = make_input()
            .aggregate(
                vec![col("c1")],
                vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
            )
            .unwrap()
            .project(vec![
                col("c1"),
                min(col("c2")).alias("min_val"),
                max(col("c3")).alias("max_val"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by sum(c3) not in projection should not be rewritten",
            input: sort(sum(col("c3"))),
            expected: sort(sum(col("c3"))),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn cast_on_aliased_aggregate() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2"))])
            .unwrap()
            .project(vec![col("c1"), min(col("c2")).alias("min_val")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "CAST on aliased aggregate should preserve cast and resolve alias",
                input: sort(cast(min(col("c2")), DataType::Int64)),
                expected: sort(cast(col("min_val"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast on aliased aggregate should preserve try_cast and resolve alias",
                input: sort(try_cast(min(col("c2")), DataType::Int64)),
                expected: sort(try_cast(col("min_val"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn count_star_with_alias() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![count(lit(1))])
            .unwrap()
            .project(vec![col("c1"), count(lit(1)).alias("cnt")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by count(1) should resolve to cnt alias",
            input: sort(count(lit(1))),
            expected: sort(col("cnt")),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// A single rewrite scenario: `input` rewritten against a plan must
    /// equal `expected`; `desc` is printed to identify the case.
    struct TestCase {
        desc: &'static str,
        input: Sort,
        expected: Sort,
    }

    impl TestCase {
        /// Rewrites `input` against `input_plan` and asserts it equals `expected`.
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            assert_eq!(exprs.len(), 1);
            let rewritten = exprs.pop().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Scan of table `t` with columns `c1: Int32`, `c2: Utf8`, `c3: Float64`.
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Ascending sort with nulls first.
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}