datafusion_expr/expr_rewriter/
order_by.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Rewrite for order by expressions
19
20use crate::expr::Alias;
21use crate::expr_rewriter::normalize_col;
22use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
23
24use datafusion_common::tree_node::{
25    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
26};
27use datafusion_common::{Column, Result};
28
29/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
30/// For example, `max(x)` is written to `col("max(x)")`
31pub fn rewrite_sort_cols_by_aggs(
32    sorts: impl IntoIterator<Item = impl Into<Sort>>,
33    plan: &LogicalPlan,
34) -> Result<Vec<Sort>> {
35    sorts
36        .into_iter()
37        .map(|e| {
38            let sort = e.into();
39            Ok(Sort::new(
40                rewrite_sort_col_by_aggs(sort.expr, plan)?,
41                sort.asc,
42                sort.nulls_first,
43            ))
44        })
45        .collect()
46}
47
48fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
49    let plan_inputs = plan.inputs();
50
51    // Joins, and Unions are not yet handled (should have a projection
52    // on top of them)
53    if plan_inputs.len() == 1 {
54        let proj_exprs = plan.expressions();
55        rewrite_in_terms_of_projection(expr, &proj_exprs, plan_inputs[0])
56    } else {
57        Ok(expr)
58    }
59}
60
61/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
62///
63/// Example:
64///
65/// Given an input expression such as `col(a) + col(b) + col(c)`
66///
67/// into `col(a) + col("b + c")`
68///
69/// Remember that:
70/// 1. given a projection with exprs: [a, b + c]
71/// 2. t produces an output schema with two columns "a", "b + c"
72fn rewrite_in_terms_of_projection(
73    expr: Expr,
74    proj_exprs: &[Expr],
75    input: &LogicalPlan,
76) -> Result<Expr> {
77    // assumption is that each item in exprs, such as "b + c" is
78    // available as an output column named "b + c"
79    expr.transform(|expr| {
80        // search for unnormalized names first such as "c1" (such as aliases)
81        if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
82            let (qualifier, field_name) = found.qualified_name();
83            let col = Expr::Column(Column::new(qualifier, field_name));
84            return Ok(Transformed::yes(col));
85        }
86
87        // if that doesn't work, try to match the expression as an
88        // output column -- however first it must be "normalized"
89        // (e.g. "c1" --> "t.c1") because that normalization is done
90        // at the input of the aggregate.
91
92        let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
93            e
94        } else {
95            // The expr is not based on Aggregate plan output. Skip it.
96            return Ok(Transformed::no(expr));
97        };
98
99        // expr is an actual expr like min(t.c2), but we are looking
100        // for a column with the same "MIN(C2)", so translate there
101        let name = normalized_expr.schema_name().to_string();
102
103        let search_col = Expr::Column(Column::new_unqualified(name));
104
105        // look for the column named the same as this expr
106        let mut found = None;
107        for proj_expr in proj_exprs {
108            proj_expr.apply(|e| {
109                if expr_match(&search_col, e) {
110                    found = Some(e.clone());
111                    return Ok(TreeNodeRecursion::Stop);
112                }
113                Ok(TreeNodeRecursion::Continue)
114            })?;
115        }
116
117        if let Some(found) = found {
118            return Ok(Transformed::yes(match normalized_expr {
119                Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
120                    expr: Box::new(found),
121                    data_type,
122                }),
123                Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast {
124                    expr: Box::new(found),
125                    data_type,
126                }),
127                _ => found,
128            }));
129        }
130
131        Ok(Transformed::no(expr))
132    })
133    .data()
134}
135
136/// Does the underlying expr match e?
137/// so avg(c) as average will match avgc
138fn expr_match(needle: &Expr, expr: &Expr) -> bool {
139    // check inside aliases
140    if let Expr::Alias(Alias { expr, .. }) = &expr {
141        expr.as_ref() == needle
142    } else {
143        expr == needle
144    }
145}
146
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        LogicalPlanBuilder, cast, col, lit, logical_plan::builder::LogicalTableSource,
        try_cast,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::min;

    #[test]
    fn rewrite_sort_cols_by_agg() {
        //  gby c1, agg: min(c2)
        let agg = make_input()
            .aggregate(
                // gby: c1
                vec![col("c1")],
                // agg: min(c2)
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        // Sorting directly on the Aggregate (no projection on top of it):
        // none of these expressions is rewritten.
        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: "c1 + c2 --> c1 + c2",
                input: sort(col("c1") + col("c2")),
                expected: sort(col("c1") + col("c2")),
            },
            TestCase {
                desc: "min(c2) --> min(c2) (unchanged)",
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: "c1 + min(c2) --> c1 + min(c2) (unchanged)",
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                // gby c1
                vec![col("c1")],
                // agg: min(c2), avg(c3)
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            //  projects out an expression "c1" that is different than the column "c1"
            .project(vec![
                // c1 + 1 as c1,
                col("c1").add(lit(1)).alias("c1"),
                // min(c2)
                min(col("c2")),
                // avg("c3") as average
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1  -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                // should be "c1" not t.c1
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(col("min(t.c2)")),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                // should be "c1" not t.c1
                expected: sort(col("c1") + col("min(t.c2)")),
            },
            TestCase {
                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("avg(t.c3)").alias("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// One rewrite scenario: `input` should be rewritten to `expected`.
    struct TestCase {
        desc: &'static str,
        input: Sort,
        expected: Sort,
    }

    impl TestCase {
        /// Calls `rewrite_sort_cols_by_aggs` on `input` against `input_plan` and
        /// asserts the single rewritten sort equals `expected`.
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            assert_eq!(exprs.len(), 1, "case: '{desc}'");
            let rewritten = exprs.pop().unwrap();

            // Include the case description so a failure identifies the scenario.
            assert_eq!(
                rewritten, expected,
                "\n\ncase: '{desc}'\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Scan of a table: t(c1 int, c2 varchar, c3 float)
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Ascending sort with nulls first on `expr`.
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}