1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::DFSchema;
23use datafusion_common::Result;
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// One equijoin key pair extracted from a join filter.
///
/// The first expression's type is resolved against the join's left input schema
/// and the second against the right input schema (see
/// `split_eq_and_noneq_join_predicate`).
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that moves equality predicates out of a join's `filter` and
/// into its equijoin keys (`on`).
///
/// A conjunct of the filter is moved when it is an `=` comparison whose operands
/// `find_valid_equijoin_key_pair` can assign to the left and right inputs, and
/// both resulting key expressions have hashable types. All other conjuncts stay
/// in the filter. For example:
///
/// ```text
/// Left Join: Filter: t1.a = t2.a AND t1.b < Int32(100)
/// ```
///
/// becomes
///
/// ```text
/// Left Join: t1.a = t2.a Filter: t1.b < Int32(100)
/// ```
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;
43
44impl ExtractEquijoinPredicate {
45 #[allow(missing_docs)]
46 pub fn new() -> Self {
47 Self {}
48 }
49}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn name(&self) -> &str {
57 "extract_equijoin_predicate"
58 }
59
60 fn apply_order(&self) -> Option<ApplyOrder> {
61 Some(ApplyOrder::BottomUp)
62 }
63
64 fn rewrite(
65 &self,
66 plan: LogicalPlan,
67 _config: &dyn OptimizerConfig,
68 ) -> Result<Transformed<LogicalPlan>> {
69 match plan {
70 LogicalPlan::Join(Join {
71 left,
72 right,
73 mut on,
74 filter: Some(expr),
75 join_type,
76 join_constraint,
77 schema,
78 null_equals_null,
79 }) => {
80 let left_schema = left.schema();
81 let right_schema = right.schema();
82 let (equijoin_predicates, non_equijoin_expr) =
83 split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85 if !equijoin_predicates.is_empty() {
86 on.extend(equijoin_predicates);
87 Ok(Transformed::yes(LogicalPlan::Join(Join {
88 left,
89 right,
90 on,
91 filter: non_equijoin_expr,
92 join_type,
93 join_constraint,
94 schema,
95 null_equals_null,
96 })))
97 } else {
98 Ok(Transformed::no(LogicalPlan::Join(Join {
99 left,
100 right,
101 on,
102 filter: non_equijoin_expr,
103 join_type,
104 join_constraint,
105 schema,
106 null_equals_null,
107 })))
108 }
109 }
110 _ => Ok(Transformed::no(plan)),
111 }
112 }
113}
114
115fn split_eq_and_noneq_join_predicate(
116 filter: Expr,
117 left_schema: &DFSchema,
118 right_schema: &DFSchema,
119) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
120 let exprs = split_conjunction_owned(filter);
121
122 let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
123 let mut accum_filters: Vec<Expr> = vec![];
124 for expr in exprs {
125 match expr {
126 Expr::BinaryExpr(BinaryExpr {
127 ref left,
128 op: Operator::Eq,
129 ref right,
130 }) => {
131 let join_key_pair =
132 find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
133
134 if let Some((left_expr, right_expr)) = join_key_pair {
135 let left_expr_type = left_expr.get_type(left_schema)?;
136 let right_expr_type = right_expr.get_type(right_schema)?;
137
138 if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
139 accum_join_keys.push((left_expr, right_expr));
140 } else {
141 accum_filters.push(expr);
142 }
143 } else {
144 accum_filters.push(expr);
145 }
146 }
147 _ => accum_filters.push(expr),
148 }
149 }
150
151 let result_filter = accum_filters.into_iter().reduce(Expr::and);
152 Ok((accum_join_keys, result_filter))
153}
154
// Plan-shape tests for `ExtractEquijoinPredicate`: each test builds a join with
// `LogicalPlanBuilder`, runs the rule, and compares the optimized plan against
// an expected rendering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
    };
    use std::sync::Arc;

    /// Runs `ExtractEquijoinPredicate` on `plan` through
    /// `assert_optimized_plan_eq_display_indent` and checks the result against
    /// `expected`.
    ///
    /// Always returns `Ok(())`; mismatches are reported by that helper
    /// (presumably by panicking — confirm in `crate::test`).
    fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq_display_indent(
            Arc::new(ExtractEquijoinPredicate {}),
            plan,
            expected,
        );

        Ok(())
    }

    /// A lone `t1.a = t2.a` filter becomes the join key and the filter disappears.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;
        let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An equality between non-column expressions over each input is also extracted.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;
        let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// A filter with no `=` conjunct is left entirely in the join filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;
        let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An extracted key is appended after the join's existing keys; the
    /// non-equi remainder stays as the filter.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;
        let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// Only top-level AND-ed equalities are extracted; an `=` nested inside an
    /// OR stays in the filter.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;
        let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// Both the inner and the outer join of a three-table plan are rewritten;
    /// a predicate spanning all three tables stays in the outer filter.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;
        let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// `t2.c = t3.c` references only the outer join's right input, so it is not
    /// an equijoin key for the outer join and stays in its filter.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;
        let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An aliased equality filter is still extracted, and the alias does not
    /// appear in the resulting join keys.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;
        let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }
}