Skip to main content

datafusion_expr/expr_rewriter/
order_by.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Rewrite for order by expressions
19
20use crate::expr::Alias;
21use crate::expr_rewriter::normalize_col;
22use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
23
24use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25use datafusion_common::{Column, Result};
26
27/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
28/// For example, `max(x)` is written to `col("max(x)")`
29pub fn rewrite_sort_cols_by_aggs(
30    sorts: impl IntoIterator<Item = impl Into<Sort>>,
31    plan: &LogicalPlan,
32) -> Result<Vec<Sort>> {
33    sorts
34        .into_iter()
35        .map(|e| {
36            let sort = e.into();
37            Ok(Sort::new(
38                rewrite_sort_col_by_aggs(sort.expr, plan)?,
39                sort.asc,
40                sort.nulls_first,
41            ))
42        })
43        .collect()
44}
45
46fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
47    let plan_inputs = plan.inputs();
48
49    // Joins, and Unions are not yet handled (should have a projection
50    // on top of them)
51    if plan_inputs.len() == 1 {
52        let proj_exprs = plan.expressions();
53        rewrite_in_terms_of_projection(expr, &proj_exprs, plan_inputs[0])
54    } else {
55        Ok(expr)
56    }
57}
58
59/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
60///
61/// Example:
62///
63/// Given an input expression such as `col(a) + col(b) + col(c)`
64///
65/// into `col(a) + col("b + c")`
66///
67/// Remember that:
68/// 1. given a projection with exprs: [a, b + c]
69/// 2. t produces an output schema with two columns "a", "b + c"
70fn rewrite_in_terms_of_projection(
71    expr: Expr,
72    proj_exprs: &[Expr],
73    input: &LogicalPlan,
74) -> Result<Expr> {
75    // assumption is that each item in exprs, such as "b + c" is
76    // available as an output column named "b + c"
77    expr.transform(|expr| {
78        // search for unnormalized names first such as "c1" (such as aliases).
79        // Also look inside aliases so e.g. `count(Int64(1))` matches
80        // `count(Int64(1)) AS count(*)`.
81        if let Some(found) = proj_exprs.iter().find(|a| expr_match(&expr, a)) {
82            let (qualifier, field_name) = found.qualified_name();
83            let col = Expr::Column(Column::new(qualifier, field_name));
84            return Ok(Transformed::yes(col));
85        }
86
87        // if that doesn't work, try to match the expression as an
88        // output column -- however first it must be "normalized"
89        // (e.g. "c1" --> "t.c1") because that normalization is done
90        // at the input of the aggregate.
91
92        let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
93            e
94        } else {
95            // The expr is not based on Aggregate plan output. Skip it.
96            return Ok(Transformed::no(expr));
97        };
98
99        // expr is an actual expr like min(t.c2), but we are looking
100        // for a column with the same "MIN(C2)", so translate there
101        let name = normalized_expr.schema_name().to_string();
102
103        let search_col = Expr::Column(Column::new_unqualified(name));
104
105        // Search only top-level projection expressions for a match.
106        // We intentionally avoid a recursive search (e.g. `apply`) to
107        // prevent matching sub-expressions of composites like
108        // `min(c2) + max(c3)` when the ORDER BY is just `min(c2)`.
109        let found = proj_exprs
110            .iter()
111            .find(|proj_expr| expr_match(&search_col, proj_expr));
112
113        if let Some(found) = found {
114            let (qualifier, field_name) = found.qualified_name();
115            let col = Expr::Column(Column::new(qualifier, field_name));
116            return Ok(Transformed::yes(match normalized_expr {
117                Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
118                    expr: Box::new(col),
119                    field,
120                }),
121                Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast {
122                    expr: Box::new(col),
123                    field,
124                }),
125                _ => col,
126            }));
127        }
128
129        Ok(Transformed::no(expr))
130    })
131    .data()
132}
133
134/// Does the underlying expr match e?
135/// so avg(c) as average will match avgc
136fn expr_match(needle: &Expr, expr: &Expr) -> bool {
137    // check inside aliases
138    if let Expr::Alias(Alias { expr, .. }) = &expr {
139        expr.as_ref() == needle
140    } else {
141        expr == needle
142    }
143}
144
145#[cfg(test)]
146mod test {
147    use std::ops::Add;
148    use std::sync::Arc;
149
150    use arrow::datatypes::{DataType, Field, Schema};
151
152    use crate::{
153        LogicalPlanBuilder, cast, col, lit, logical_plan::builder::LogicalTableSource,
154        try_cast,
155    };
156
157    use super::*;
158    use crate::test::function_stub::avg;
159    use crate::test::function_stub::count;
160    use crate::test::function_stub::max;
161    use crate::test::function_stub::min;
162    use crate::test::function_stub::sum;
163
164    #[test]
165    fn rewrite_sort_cols_by_agg() {
166        //  gby c1, agg: min(c2)
167        let agg = make_input()
168            .aggregate(
169                // gby: c1
170                vec![col("c1")],
171                // agg: min(c2)
172                vec![min(col("c2"))],
173            )
174            .unwrap()
175            .build()
176            .unwrap();
177
178        let cases = vec![
179            TestCase {
180                desc: "c1 --> c1",
181                input: sort(col("c1")),
182                expected: sort(col("c1")),
183            },
184            TestCase {
185                desc: "c1 + c2 --> c1 + c2",
186                input: sort(col("c1") + col("c1")),
187                expected: sort(col("c1") + col("c1")),
188            },
189            TestCase {
190                desc: r#"min(c2) --> "min(c2)"#,
191                input: sort(min(col("c2"))),
192                expected: sort(min(col("c2"))),
193            },
194            TestCase {
195                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
196                input: sort(col("c1") + min(col("c2"))),
197                expected: sort(col("c1") + min(col("c2"))),
198            },
199        ];
200
201        for case in cases {
202            case.run(&agg)
203        }
204    }
205
206    #[test]
207    fn rewrite_sort_cols_by_agg_alias() {
208        let agg = make_input()
209            .aggregate(
210                // gby c1
211                vec![col("c1")],
212                // agg: min(c2), avg(c3)
213                vec![min(col("c2")), avg(col("c3"))],
214            )
215            .unwrap()
216            //  projects out an expression "c1" that is different than the column "c1"
217            .project(vec![
218                // c1 + 1 as c1,
219                col("c1").add(lit(1)).alias("c1"),
220                // min(c2)
221                min(col("c2")),
222                // avg("c3") as average
223                avg(col("c3")).alias("average"),
224            ])
225            .unwrap()
226            .build()
227            .unwrap();
228
229        let cases = vec![
230            TestCase {
231                desc: "c1 --> c1  -- column *named* c1 that came out of the projection, (not t.c1)",
232                input: sort(col("c1")),
233                // should be "c1" not t.c1
234                expected: sort(col("c1")),
235            },
236            TestCase {
237                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
238                input: sort(min(col("c2"))),
239                expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
240            },
241            TestCase {
242                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
243                input: sort(col("c1") + min(col("c2"))),
244                expected: sort(
245                    col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")),
246                ),
247            },
248            TestCase {
249                desc: r#"avg(c3) --> "average" (column *named* "average", from alias)"#,
250                input: sort(avg(col("c3"))),
251                expected: sort(col("average")),
252            },
253        ];
254
255        for case in cases {
256            case.run(&agg)
257        }
258    }
259
260    /// When an aggregate is aliased in the projection,
261    /// ORDER BY on the original aggregate expression should resolve to
262    /// a Column reference using the alias name — not leak the inner
263    /// Alias expression node or resolve to a descendant subtree.
264    #[test]
265    fn rewrite_sort_resolves_alias_to_column_ref() {
266        let plan = make_input()
267            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
268            .unwrap()
269            .project(vec![
270                col("c1"),
271                min(col("c2")).alias("min_val"),
272                max(col("c3")).alias("max_val"),
273            ])
274            .unwrap()
275            .build()
276            .unwrap();
277
278        let cases = vec![
279            TestCase {
280                desc: "min(c2) with alias 'min_val' should resolve to col(min_val)",
281                input: sort(min(col("c2"))),
282                expected: sort(col("min_val")),
283            },
284            TestCase {
285                desc: "max(c3) with alias 'max_val' should resolve to col(max_val)",
286                input: sort(max(col("c3"))),
287                expected: sort(col("max_val")),
288            },
289        ];
290
291        for case in cases {
292            case.run(&plan)
293        }
294    }
295
296    #[test]
297    fn composite_proj_expr_containing_sort_col_as_subexpr() {
298        let plan = make_input()
299            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
300            .unwrap()
301            .project(vec![
302                col("c1"),
303                (min(col("c2")) + max(col("c3"))).alias("range"),
304                min(col("c2")).alias("min_val"),
305                max(col("c3")).alias("max_val"),
306            ])
307            .unwrap()
308            .build()
309            .unwrap();
310
311        let cases = vec![
312            TestCase {
313                desc: "sort by min(c2) should resolve to col(min_val), not col(range)",
314                input: sort(min(col("c2"))),
315                expected: sort(col("min_val")),
316            },
317            TestCase {
318                desc: "sort by max(c3) should resolve to col(max_val), not col(range)",
319                input: sort(max(col("c3"))),
320                expected: sort(col("max_val")),
321            },
322        ];
323
324        for case in cases {
325            case.run(&plan)
326        }
327    }
328
329    #[test]
330    fn composite_before_standalone_should_not_shadow() {
331        let plan = make_input()
332            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
333            .unwrap()
334            .project(vec![
335                col("c1"),
336                (min(col("c2")) + max(col("c2"))).alias("combined"),
337                min(col("c2")),
338            ])
339            .unwrap()
340            .build()
341            .unwrap();
342
343        let cases = vec![TestCase {
344            desc: "sort by min(c2) should resolve to col(min(t.c2)), not col(combined)",
345            input: sort(min(col("c2"))),
346            expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
347        }];
348
349        for case in cases {
350            case.run(&plan)
351        }
352    }
353
354    #[test]
355    fn duplicate_aggregate_in_multiple_proj_exprs() {
356        let plan = make_input()
357            .aggregate(vec![col("c1")], vec![min(col("c2"))])
358            .unwrap()
359            .project(vec![
360                col("c1"),
361                min(col("c2")).alias("first_alias"),
362                min(col("c2")).alias("second_alias"),
363            ])
364            .unwrap()
365            .build()
366            .unwrap();
367
368        let cases = vec![TestCase {
369            desc: "sort by min(c2) with two aliases picks first_alias",
370            input: sort(min(col("c2"))),
371            expected: sort(col("first_alias")),
372        }];
373
374        for case in cases {
375            case.run(&plan)
376        }
377    }
378
379    #[test]
380    fn sort_agg_not_in_select_with_aliased_aggs() {
381        let plan = make_input()
382            .aggregate(
383                vec![col("c1")],
384                vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
385            )
386            .unwrap()
387            .project(vec![
388                col("c1"),
389                min(col("c2")).alias("min_val"),
390                max(col("c3")).alias("max_val"),
391            ])
392            .unwrap()
393            .build()
394            .unwrap();
395
396        let cases = vec![TestCase {
397            desc: "sort by sum(c3) not in projection should not be rewritten",
398            input: sort(sum(col("c3"))),
399            expected: sort(sum(col("c3"))),
400        }];
401
402        for case in cases {
403            case.run(&plan)
404        }
405    }
406
407    #[test]
408    fn cast_on_aliased_aggregate() {
409        let plan = make_input()
410            .aggregate(vec![col("c1")], vec![min(col("c2"))])
411            .unwrap()
412            .project(vec![col("c1"), min(col("c2")).alias("min_val")])
413            .unwrap()
414            .build()
415            .unwrap();
416
417        let cases = vec![
418            TestCase {
419                desc: "CAST on aliased aggregate should preserve cast and resolve alias",
420                input: sort(cast(min(col("c2")), DataType::Int64)),
421                expected: sort(cast(col("min_val"), DataType::Int64)),
422            },
423            TestCase {
424                desc: "TryCast on aliased aggregate should preserve try_cast and resolve alias",
425                input: sort(try_cast(min(col("c2")), DataType::Int64)),
426                expected: sort(try_cast(col("min_val"), DataType::Int64)),
427            },
428        ];
429
430        for case in cases {
431            case.run(&plan)
432        }
433    }
434
435    #[test]
436    fn count_star_with_alias() {
437        let plan = make_input()
438            .aggregate(vec![col("c1")], vec![count(lit(1))])
439            .unwrap()
440            .project(vec![col("c1"), count(lit(1)).alias("cnt")])
441            .unwrap()
442            .build()
443            .unwrap();
444
445        let cases = vec![TestCase {
446            desc: "sort by count(1) should resolve to cnt alias",
447            input: sort(count(lit(1))),
448            expected: sort(col("cnt")),
449        }];
450
451        for case in cases {
452            case.run(&plan)
453        }
454    }
455
456    #[test]
457    fn preserve_cast() {
458        let plan = make_input()
459            .project(vec![col("c2").alias("c2")])
460            .unwrap()
461            .project(vec![col("c2").alias("c2")])
462            .unwrap()
463            .build()
464            .unwrap();
465
466        let cases = vec![
467            TestCase {
468                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
469                input: sort(cast(col("c2"), DataType::Int64)),
470                expected: sort(cast(col("c2"), DataType::Int64)),
471            },
472            TestCase {
473                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
474                input: sort(try_cast(col("c2"), DataType::Int64)),
475                expected: sort(try_cast(col("c2"), DataType::Int64)),
476            },
477        ];
478
479        for case in cases {
480            case.run(&plan)
481        }
482    }
483
484    struct TestCase {
485        desc: &'static str,
486        input: Sort,
487        expected: Sort,
488    }
489
490    impl TestCase {
491        /// calls rewrite_sort_cols_by_aggs for expr and compares it to expected_expr
492        fn run(self, input_plan: &LogicalPlan) {
493            let Self {
494                desc,
495                input,
496                expected,
497            } = self;
498
499            println!("running: '{desc}'");
500            let mut exprs =
501                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();
502
503            assert_eq!(exprs.len(), 1);
504            let rewritten = exprs.pop().unwrap();
505
506            assert_eq!(
507                rewritten, expected,
508                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
509            );
510        }
511    }
512
513    /// Scan of a table: t(c1 int, c2 varchar, c3 float)
514    fn make_input() -> LogicalPlanBuilder {
515        let schema = Arc::new(Schema::new(vec![
516            Field::new("c1", DataType::Int32, true),
517            Field::new("c2", DataType::Utf8, true),
518            Field::new("c3", DataType::Float64, true),
519        ]));
520        let projection = None;
521        LogicalPlanBuilder::scan(
522            "t",
523            Arc::new(LogicalTableSource::new(schema)),
524            projection,
525        )
526        .unwrap()
527    }
528
529    fn sort(expr: Expr) -> Sort {
530        let asc = true;
531        let nulls_first = true;
532        expr.sort(asc, nulls_first)
533    }
534}