datafusion_optimizer/
eliminate_duplicated_expr.rs1use crate::optimizer::ApplyOrder;
21use crate::{OptimizerConfig, OptimizerRule};
22use datafusion_common::tree_node::Transformed;
23use datafusion_common::Result;
24use datafusion_expr::logical_plan::LogicalPlan;
25use datafusion_expr::{Aggregate, Expr, Sort, SortExpr};
26use std::hash::{Hash, Hasher};
27
28use indexmap::IndexSet;
29
/// Optimizer rule that removes duplicated expressions from the sort keys of a
/// `LogicalPlan::Sort` and from the grouping expressions of a
/// `LogicalPlan::Aggregate`, keeping the first occurrence of each expression
/// and the original relative order of the survivors.
#[derive(Default, Debug)]
pub struct EliminateDuplicatedExpr;

impl EliminateDuplicatedExpr {
    /// Creates a new [`EliminateDuplicatedExpr`] rule.
    ///
    /// The rule carries no state, so this is equivalent to
    /// [`EliminateDuplicatedExpr::default`].
    pub fn new() -> Self {
        Self
    }
}
40#[derive(Eq, Clone, Debug)]
42struct SortExprWrapper(SortExpr);
43impl PartialEq for SortExprWrapper {
44 fn eq(&self, other: &Self) -> bool {
45 self.0.expr == other.0.expr
46 }
47}
48impl Hash for SortExprWrapper {
49 fn hash<H: Hasher>(&self, state: &mut H) {
50 self.0.expr.hash(state);
51 }
52}
53impl OptimizerRule for EliminateDuplicatedExpr {
54 fn apply_order(&self) -> Option<ApplyOrder> {
55 Some(ApplyOrder::TopDown)
56 }
57
58 fn supports_rewrite(&self) -> bool {
59 true
60 }
61
62 fn rewrite(
63 &self,
64 plan: LogicalPlan,
65 _config: &dyn OptimizerConfig,
66 ) -> Result<Transformed<LogicalPlan>> {
67 match plan {
68 LogicalPlan::Sort(sort) => {
69 let len = sort.expr.len();
70 let unique_exprs: Vec<_> = sort
71 .expr
72 .into_iter()
73 .map(SortExprWrapper)
74 .collect::<IndexSet<_>>()
75 .into_iter()
76 .map(|wrapper| wrapper.0)
77 .collect();
78
79 let transformed = if len != unique_exprs.len() {
80 Transformed::yes
81 } else {
82 Transformed::no
83 };
84
85 Ok(transformed(LogicalPlan::Sort(Sort {
86 expr: unique_exprs,
87 input: sort.input,
88 fetch: sort.fetch,
89 })))
90 }
91 LogicalPlan::Aggregate(agg) => {
92 let len = agg.group_expr.len();
93
94 let unique_exprs: Vec<Expr> = agg
95 .group_expr
96 .into_iter()
97 .collect::<IndexSet<_>>()
98 .into_iter()
99 .collect();
100
101 let transformed = if len != unique_exprs.len() {
102 Transformed::yes
103 } else {
104 Transformed::no
105 };
106
107 Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr)
108 .map(|f| transformed(LogicalPlan::Aggregate(f)))
109 }
110 _ => Ok(Transformed::no(plan)),
111 }
112 }
113 fn name(&self) -> &str {
114 "eliminate_duplicated_expr"
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
    use std::sync::Arc;

    /// Runs [`EliminateDuplicatedExpr`] on `plan` and compares the result with
    /// `expected` via `crate::test::assert_optimized_plan_eq`.
    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
        crate::test::assert_optimized_plan_eq(
            Arc::new(EliminateDuplicatedExpr::new()),
            plan,
            expected,
        )
    }

    /// Repeated sort keys with identical options collapse to a single key.
    #[test]
    fn eliminate_sort_expr() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .sort_by(vec![col("a"), col("a"), col("b"), col("c")])?
            .limit(5, Some(10))?
            .build()?;
        let expected = "Limit: skip=5, fetch=10\
        \n  Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\
        \n    TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }

    /// A later key on an already-sorted expression is dropped even when its
    /// direction or null ordering differs; the first occurrence wins.
    #[test]
    fn eliminate_sort_exprs_with_options() -> Result<()> {
        let table_scan = test_table_scan()?;
        let sort_exprs = vec![
            col("a").sort(true, true),
            col("b").sort(true, false),
            col("a").sort(false, false),
            col("b").sort(false, true),
        ];
        let plan = LogicalPlanBuilder::from(table_scan)
            .sort(sort_exprs)?
            .limit(5, Some(10))?
            .build()?;
        let expected = "Limit: skip=5, fetch=10\
        \n  Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\
        \n    TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }
}