datafusion_expr/expr_rewriter/
order_by.rs1use crate::expr::Alias;
21use crate::expr_rewriter::normalize_col;
22use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
23
24use datafusion_common::tree_node::{
25 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
26};
27use datafusion_common::{Column, Result};
28
29pub fn rewrite_sort_cols_by_aggs(
32 sorts: impl IntoIterator<Item = impl Into<Sort>>,
33 plan: &LogicalPlan,
34) -> Result<Vec<Sort>> {
35 sorts
36 .into_iter()
37 .map(|e| {
38 let sort = e.into();
39 Ok(Sort::new(
40 rewrite_sort_col_by_aggs(sort.expr, plan)?,
41 sort.asc,
42 sort.nulls_first,
43 ))
44 })
45 .collect()
46}
47
48fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
49 let plan_inputs = plan.inputs();
50
51 if plan_inputs.len() == 1 {
54 let proj_exprs = plan.expressions();
55 rewrite_in_terms_of_projection(expr, &proj_exprs, plan_inputs[0])
56 } else {
57 Ok(expr)
58 }
59}
60
61fn rewrite_in_terms_of_projection(
73 expr: Expr,
74 proj_exprs: &[Expr],
75 input: &LogicalPlan,
76) -> Result<Expr> {
77 expr.transform(|expr| {
80 if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
82 let (qualifier, field_name) = found.qualified_name();
83 let col = Expr::Column(Column::new(qualifier, field_name));
84 return Ok(Transformed::yes(col));
85 }
86
87 let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
93 e
94 } else {
95 return Ok(Transformed::no(expr));
97 };
98
99 let name = normalized_expr.schema_name().to_string();
102
103 let search_col = Expr::Column(Column::new_unqualified(name));
104
105 let mut found = None;
107 for proj_expr in proj_exprs {
108 proj_expr.apply(|e| {
109 if expr_match(&search_col, e) {
110 found = Some(e.clone());
111 return Ok(TreeNodeRecursion::Stop);
112 }
113 Ok(TreeNodeRecursion::Continue)
114 })?;
115 }
116
117 if let Some(found) = found {
118 return Ok(Transformed::yes(match normalized_expr {
119 Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
120 expr: Box::new(found),
121 data_type,
122 }),
123 Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast {
124 expr: Box::new(found),
125 data_type,
126 }),
127 _ => found,
128 }));
129 }
130
131 Ok(Transformed::no(expr))
132 })
133 .data()
134}
135
136fn expr_match(needle: &Expr, expr: &Expr) -> bool {
139 if let Expr::Alias(Alias { expr, .. }) = &expr {
141 expr.as_ref() == needle
142 } else {
143 expr == needle
144 }
145}
146
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        LogicalPlanBuilder, cast, col, lit, logical_plan::builder::LogicalTableSource,
        try_cast,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::min;

    /// Sort keys rewritten directly against an aggregate
    /// (group by `c1`, aggregate `min(c2)`) are expected to come back
    /// unchanged in every case listed here.
    #[test]
    fn rewrite_sort_cols_by_agg() {
        let agg = make_input()
            .aggregate(
                // group by
                vec![col("c1")],
                // aggregate
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                // The description names the expression that is actually
                // exercised (`c1 + c1`); it previously said `c1 + c2`.
                desc: "c1 + c1 --> c1 + c1",
                input: sort(col("c1") + col("c1")),
                expected: sort(col("c1") + col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)"#,
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// Sort keys rewritten against a projection that sits on top of an
    /// aggregate: aggregate calls are expected to turn into references to
    /// the projection's output columns.
    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                // group by
                vec![col("c1")],
                // aggregates
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            .project(vec![
                // `c1 + 1`, exposed under the name `c1`
                col("c1").add(lit(1)).alias("c1"),
                // un-aliased aggregate
                min(col("c2")),
                // aliased aggregate
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(col("min(t.c2)")),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + col("min(t.c2)")),
            },
            TestCase {
                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("avg(t.c3)").alias("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// A `Cast` / `TryCast` around a sort expression must survive the
    /// rewrite: the expected result is the matched projection expression
    /// (alias included) wrapped in the same cast to the same data type.
    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// One rewrite scenario: a sort key and the sort key it should become.
    struct TestCase {
        /// Human-readable description, printed before the case runs.
        desc: &'static str,
        /// Sort key handed to `rewrite_sort_cols_by_aggs`.
        input: Sort,
        /// Sort key the rewrite is expected to produce.
        expected: Sort,
    }

    impl TestCase {
        /// Rewrites `self.input` against `input_plan` and asserts that
        /// exactly one sort key comes back and that it equals
        /// `self.expected`.
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            // `input` is cloned because it is reported again in the
            // assertion message below.
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            assert_eq!(exprs.len(), 1);
            let rewritten = exprs.pop().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Starts a plan builder that scans a table named `t` with the nullable
    /// columns `c1: Int32`, `c2: Utf8` and `c3: Float64`, with no scan
    /// projection.
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Wraps `expr` in an ascending, nulls-first sort key.
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}