Skip to main content

datafusion_optimizer/
eliminate_duplicated_expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateDuplicatedExpr`] Removes redundant expressions
19
20use crate::optimizer::ApplyOrder;
21use crate::{OptimizerConfig, OptimizerRule};
22use datafusion_common::tree_node::Transformed;
23use datafusion_common::{Result, get_required_sort_exprs_indices, internal_err};
24use datafusion_expr::logical_plan::LogicalPlan;
25use datafusion_expr::{Aggregate, Expr, Sort, SortExpr};
26use std::hash::{Hash, Hasher};
27
28use indexmap::IndexSet;
29
/// Optimizer rule that eliminates duplicated expressions.
///
/// The rule holds no state. Its [`OptimizerRule`] implementation rewrites
/// `Sort` nodes (repeated sort expressions) and `Aggregate` nodes (repeated
/// group expressions); every other plan node is left unchanged.
#[derive(Debug, Default)]
pub struct EliminateDuplicatedExpr;
33
34impl EliminateDuplicatedExpr {
35    #[expect(missing_docs)]
36    pub fn new() -> Self {
37        Self {}
38    }
39}
40// use this structure to avoid initial clone
41#[derive(Eq, Clone, Debug)]
42struct SortExprWrapper(SortExpr);
43impl PartialEq for SortExprWrapper {
44    fn eq(&self, other: &Self) -> bool {
45        self.0.expr == other.0.expr
46    }
47}
48impl Hash for SortExprWrapper {
49    fn hash<H: Hasher>(&self, state: &mut H) {
50        self.0.expr.hash(state);
51    }
52}
53impl OptimizerRule for EliminateDuplicatedExpr {
54    fn apply_order(&self) -> Option<ApplyOrder> {
55        Some(ApplyOrder::TopDown)
56    }
57
58    fn supports_rewrite(&self) -> bool {
59        true
60    }
61
62    fn rewrite(
63        &self,
64        plan: LogicalPlan,
65        _config: &dyn OptimizerConfig,
66    ) -> Result<Transformed<LogicalPlan>> {
67        match plan {
68            LogicalPlan::Sort(sort) => {
69                let len = sort.expr.len();
70                let unique_exprs: Vec<_> = sort
71                    .expr
72                    .into_iter()
73                    .map(SortExprWrapper)
74                    .collect::<IndexSet<_>>()
75                    .into_iter()
76                    .map(|wrapper| wrapper.0)
77                    .collect();
78
79                let sort_expr_names = unique_exprs
80                    .iter()
81                    .map(|sort_expr| sort_expr.expr.schema_name().to_string())
82                    .collect::<Vec<_>>();
83                let required_indices = get_required_sort_exprs_indices(
84                    sort.input.schema().as_ref(),
85                    &sort_expr_names,
86                );
87
88                let unique_exprs = if required_indices.len() < unique_exprs.len() {
89                    required_indices
90                        .into_iter()
91                        .map(|idx| unique_exprs[idx].clone())
92                        .collect()
93                } else {
94                    unique_exprs
95                };
96
97                let transformed = if len != unique_exprs.len() {
98                    Transformed::yes
99                } else {
100                    Transformed::no
101                };
102
103                if unique_exprs.is_empty() {
104                    return internal_err!(
105                        "FD pruning unexpectedly removed all ORDER BY expressions"
106                    );
107                }
108
109                Ok(transformed(LogicalPlan::Sort(Sort {
110                    expr: unique_exprs,
111                    input: sort.input,
112                    fetch: sort.fetch,
113                })))
114            }
115            LogicalPlan::Aggregate(agg) => {
116                let len = agg.group_expr.len();
117
118                let unique_exprs: Vec<Expr> = agg
119                    .group_expr
120                    .into_iter()
121                    .collect::<IndexSet<_>>()
122                    .into_iter()
123                    .collect();
124
125                let transformed = if len != unique_exprs.len() {
126                    Transformed::yes
127                } else {
128                    Transformed::no
129                };
130
131                Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr)
132                    .map(|f| transformed(LogicalPlan::Aggregate(f)))
133            }
134            _ => Ok(Transformed::no(plan)),
135        }
136    }
137    fn name(&self) -> &str {
138        "eliminate_duplicated_expr"
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
    use std::sync::Arc;

    // Runs `EliminateDuplicatedExpr` alone, limited to a single optimizer
    // pass, and compares the resulting plan with the inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateDuplicatedExpr::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    // `a` is listed twice in the sort keys; only one `test.a` key remains.
    #[test]
    fn eliminate_sort_expr() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let sort_keys = vec![col("a"), col("a"), col("b"), col("c")];
        let plan = LogicalPlanBuilder::from(scan)
            .sort_by(sort_keys)?
            .limit(5, Some(10))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=5, fetch=10
          Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST
            TableScan: test
        ")
    }

    // `a` and `b` each appear twice with different sort options; the options
    // of the first occurrence are the ones that survive.
    #[test]
    fn eliminate_sort_exprs_with_options() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let sort_keys = vec![
            col("a").sort(true, true),
            col("b").sort(true, false),
            col("a").sort(false, false),
            col("b").sort(false, true),
        ];
        let plan = LogicalPlanBuilder::from(scan)
            .sort(sort_keys)?
            .limit(5, Some(10))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=5, fetch=10
          Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST
            TableScan: test
        ")
    }
}
203}