datafusion_optimizer/
eliminate_duplicated_expr.rs1use crate::optimizer::ApplyOrder;
21use crate::{OptimizerConfig, OptimizerRule};
22use datafusion_common::tree_node::Transformed;
23use datafusion_common::Result;
24use datafusion_expr::logical_plan::LogicalPlan;
25use datafusion_expr::{Aggregate, Expr, Sort, SortExpr};
26use std::hash::{Hash, Hasher};
27
28use indexmap::IndexSet;
29
/// Optimizer rule that removes duplicated expressions from a logical plan.
///
/// When the rule is applied, repeated sort keys of a `Sort` node and repeated
/// grouping expressions of an `Aggregate` node are collapsed so that only the
/// first occurrence of each expression is kept, in its original position.
#[derive(Default, Debug)]
pub struct EliminateDuplicatedExpr;

impl EliminateDuplicatedExpr {
    /// Creates a new `EliminateDuplicatedExpr` rule.
    ///
    /// The rule holds no state, so this is equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self
    }
}
40#[derive(Eq, Clone, Debug)]
42struct SortExprWrapper(SortExpr);
43impl PartialEq for SortExprWrapper {
44 fn eq(&self, other: &Self) -> bool {
45 self.0.expr == other.0.expr
46 }
47}
48impl Hash for SortExprWrapper {
49 fn hash<H: Hasher>(&self, state: &mut H) {
50 self.0.expr.hash(state);
51 }
52}
53impl OptimizerRule for EliminateDuplicatedExpr {
54 fn apply_order(&self) -> Option<ApplyOrder> {
55 Some(ApplyOrder::TopDown)
56 }
57
58 fn supports_rewrite(&self) -> bool {
59 true
60 }
61
62 fn rewrite(
63 &self,
64 plan: LogicalPlan,
65 _config: &dyn OptimizerConfig,
66 ) -> Result<Transformed<LogicalPlan>> {
67 match plan {
68 LogicalPlan::Sort(sort) => {
69 let len = sort.expr.len();
70 let unique_exprs: Vec<_> = sort
71 .expr
72 .into_iter()
73 .map(SortExprWrapper)
74 .collect::<IndexSet<_>>()
75 .into_iter()
76 .map(|wrapper| wrapper.0)
77 .collect();
78
79 let transformed = if len != unique_exprs.len() {
80 Transformed::yes
81 } else {
82 Transformed::no
83 };
84
85 Ok(transformed(LogicalPlan::Sort(Sort {
86 expr: unique_exprs,
87 input: sort.input,
88 fetch: sort.fetch,
89 })))
90 }
91 LogicalPlan::Aggregate(agg) => {
92 let len = agg.group_expr.len();
93
94 let unique_exprs: Vec<Expr> = agg
95 .group_expr
96 .into_iter()
97 .collect::<IndexSet<_>>()
98 .into_iter()
99 .collect();
100
101 let transformed = if len != unique_exprs.len() {
102 Transformed::yes
103 } else {
104 Transformed::no
105 };
106
107 Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr)
108 .map(|f| transformed(LogicalPlan::Aggregate(f)))
109 }
110 _ => Ok(Transformed::no(plan)),
111 }
112 }
113 fn name(&self) -> &str {
114 "eliminate_duplicated_expr"
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use crate::OptimizerContext;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
    use std::sync::Arc;

    /// Optimizes `$plan` with only [`EliminateDuplicatedExpr`] registered and a
    /// single optimizer pass, then hands the result and the inline snapshot
    /// `$expected` to `assert_optimized_plan_eq_snapshot!`.
    ///
    /// The macro expands to an expression, which the tests below use as their
    /// return value.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateDuplicatedExpr::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// A sort key listed twice (`a`) is kept only once, at its first position.
    #[test]
    fn eliminate_sort_expr() -> Result<()> {
        // Propagate setup failures through the test's `Result` instead of panicking.
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .sort_by(vec![col("a"), col("a"), col("b"), col("c")])?
            .limit(5, Some(10))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=5, fetch=10
          Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST
            TableScan: test
        ")
    }

    /// Sort keys over the same column but with different sort options still
    /// count as duplicates; the options of the first occurrence are retained.
    #[test]
    fn eliminate_sort_exprs_with_options() -> Result<()> {
        // Propagate setup failures through the test's `Result` instead of panicking.
        let table_scan = test_table_scan()?;
        let sort_exprs = vec![
            col("a").sort(true, true),
            col("b").sort(true, false),
            col("a").sort(false, false),
            col("b").sort(false, true),
        ];
        let plan = LogicalPlanBuilder::from(table_scan)
            .sort(sort_exprs)?
            .limit(5, Some(10))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Limit: skip=5, fetch=10
          Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST
            TableScan: test
        ")
    }
}