datafusion_optimizer/
eliminate_duplicated_expr.rs1use crate::optimizer::ApplyOrder;
21use crate::{OptimizerConfig, OptimizerRule};
22use datafusion_common::tree_node::Transformed;
23use datafusion_common::{Result, get_required_sort_exprs_indices, internal_err};
24use datafusion_expr::logical_plan::LogicalPlan;
25use datafusion_expr::{Aggregate, Expr, Sort, SortExpr};
26use std::hash::{Hash, Hasher};
27
28use indexmap::IndexSet;
29
/// Optimizer rule that drops repeated expressions from `Sort` and
/// `Aggregate` logical plan nodes.
///
/// The rule carries no state, so it is declared as a unit struct.
#[derive(Debug, Default)]
pub struct EliminateDuplicatedExpr;
33
34impl EliminateDuplicatedExpr {
35 #[expect(missing_docs)]
36 pub fn new() -> Self {
37 Self {}
38 }
39}
40#[derive(Eq, Clone, Debug)]
42struct SortExprWrapper(SortExpr);
43impl PartialEq for SortExprWrapper {
44 fn eq(&self, other: &Self) -> bool {
45 self.0.expr == other.0.expr
46 }
47}
48impl Hash for SortExprWrapper {
49 fn hash<H: Hasher>(&self, state: &mut H) {
50 self.0.expr.hash(state);
51 }
52}
53impl OptimizerRule for EliminateDuplicatedExpr {
54 fn apply_order(&self) -> Option<ApplyOrder> {
55 Some(ApplyOrder::TopDown)
56 }
57
58 fn supports_rewrite(&self) -> bool {
59 true
60 }
61
62 fn rewrite(
63 &self,
64 plan: LogicalPlan,
65 _config: &dyn OptimizerConfig,
66 ) -> Result<Transformed<LogicalPlan>> {
67 match plan {
68 LogicalPlan::Sort(sort) => {
69 let len = sort.expr.len();
70 let unique_exprs: Vec<_> = sort
71 .expr
72 .into_iter()
73 .map(SortExprWrapper)
74 .collect::<IndexSet<_>>()
75 .into_iter()
76 .map(|wrapper| wrapper.0)
77 .collect();
78
79 let sort_expr_names = unique_exprs
80 .iter()
81 .map(|sort_expr| sort_expr.expr.schema_name().to_string())
82 .collect::<Vec<_>>();
83 let required_indices = get_required_sort_exprs_indices(
84 sort.input.schema().as_ref(),
85 &sort_expr_names,
86 );
87
88 let unique_exprs = if required_indices.len() < unique_exprs.len() {
89 required_indices
90 .into_iter()
91 .map(|idx| unique_exprs[idx].clone())
92 .collect()
93 } else {
94 unique_exprs
95 };
96
97 let transformed = if len != unique_exprs.len() {
98 Transformed::yes
99 } else {
100 Transformed::no
101 };
102
103 if unique_exprs.is_empty() {
104 return internal_err!(
105 "FD pruning unexpectedly removed all ORDER BY expressions"
106 );
107 }
108
109 Ok(transformed(LogicalPlan::Sort(Sort {
110 expr: unique_exprs,
111 input: sort.input,
112 fetch: sort.fetch,
113 })))
114 }
115 LogicalPlan::Aggregate(agg) => {
116 let len = agg.group_expr.len();
117
118 let unique_exprs: Vec<Expr> = agg
119 .group_expr
120 .into_iter()
121 .collect::<IndexSet<_>>()
122 .into_iter()
123 .collect();
124
125 let transformed = if len != unique_exprs.len() {
126 Transformed::yes
127 } else {
128 Transformed::no
129 };
130
131 Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr)
132 .map(|f| transformed(LogicalPlan::Aggregate(f)))
133 }
134 _ => Ok(Transformed::no(plan)),
135 }
136 }
137 fn name(&self) -> &str {
138 "eliminate_duplicated_expr"
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
    use std::sync::Arc;

    // Runs `EliminateDuplicatedExpr` as the only rule, with the optimizer
    // limited to one pass, and compares the resulting plan against an inline
    // snapshot literal.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateDuplicatedExpr::new())];
            let single_pass_ctx = OptimizerContext::new().with_max_passes(1);
            assert_optimized_plan_eq_snapshot!(
                single_pass_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    // NOTE(review): the snapshot literals below keep the relative (flat)
    // indentation they have in the reviewed copy of this file; confirm them
    // against the rendered plan if a snapshot comparison fails.

    // A column listed twice in `sort_by` must appear once in the sorted plan.
    #[test]
    fn eliminate_sort_expr() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let builder = LogicalPlanBuilder::from(scan);
        let sorted = builder.sort_by(vec![col("a"), col("a"), col("b"), col("c")])?;
        let plan = sorted.limit(5, Some(10))?.build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=5, fetch=10
        Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST
        TableScan: test
        ")
    }

    // Repeating a column with different sort options still counts as a
    // duplicate; the options of the first occurrence are the ones kept.
    #[test]
    fn eliminate_sort_exprs_with_options() -> Result<()> {
        let scan = test_table_scan().unwrap();
        let ordering = vec![
            col("a").sort(true, true),
            col("b").sort(true, false),
            col("a").sort(false, false),
            col("b").sort(false, true),
        ];
        let builder = LogicalPlanBuilder::from(scan);
        let sorted = builder.sort(ordering)?;
        let plan = sorted.limit(5, Some(10))?.build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=5, fetch=10
        Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST
        TableScan: test
        ")
    }
}