1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::DFSchema;
23use datafusion_common::Result;
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// One equi-join key pair, `(left_expr, right_expr)`: the first expression is
/// typed against the join's left input schema and the second against the
/// right input schema. Pairs of this shape are appended to a join's `on` list.
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that pulls equality predicates out of a join's filter
/// expression and turns them into equi-join keys (`on` pairs).
///
/// The type carries no state; all work happens in its `OptimizerRule`
/// implementation.
#[derive(Debug, Default)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn name(&self) -> &str {
57 "extract_equijoin_predicate"
58 }
59
60 fn apply_order(&self) -> Option<ApplyOrder> {
61 Some(ApplyOrder::BottomUp)
62 }
63
64 fn rewrite(
65 &self,
66 plan: LogicalPlan,
67 _config: &dyn OptimizerConfig,
68 ) -> Result<Transformed<LogicalPlan>> {
69 match plan {
70 LogicalPlan::Join(Join {
71 left,
72 right,
73 mut on,
74 filter: Some(expr),
75 join_type,
76 join_constraint,
77 schema,
78 null_equality,
79 }) => {
80 let left_schema = left.schema();
81 let right_schema = right.schema();
82 let (equijoin_predicates, non_equijoin_expr) =
83 split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85 if !equijoin_predicates.is_empty() {
86 on.extend(equijoin_predicates);
87 Ok(Transformed::yes(LogicalPlan::Join(Join {
88 left,
89 right,
90 on,
91 filter: non_equijoin_expr,
92 join_type,
93 join_constraint,
94 schema,
95 null_equality,
96 })))
97 } else {
98 Ok(Transformed::no(LogicalPlan::Join(Join {
99 left,
100 right,
101 on,
102 filter: non_equijoin_expr,
103 join_type,
104 join_constraint,
105 schema,
106 null_equality,
107 })))
108 }
109 }
110 _ => Ok(Transformed::no(plan)),
111 }
112 }
113}
114
115fn split_eq_and_noneq_join_predicate(
116 filter: Expr,
117 left_schema: &DFSchema,
118 right_schema: &DFSchema,
119) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
120 let exprs = split_conjunction_owned(filter);
121
122 let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
123 let mut accum_filters: Vec<Expr> = vec![];
124 for expr in exprs {
125 match expr {
126 Expr::BinaryExpr(BinaryExpr {
127 ref left,
128 op: Operator::Eq,
129 ref right,
130 }) => {
131 let join_key_pair =
132 find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
133
134 if let Some((left_expr, right_expr)) = join_key_pair {
135 let left_expr_type = left_expr.get_type(left_schema)?;
136 let right_expr_type = right_expr.get_type(right_schema)?;
137
138 if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
139 accum_join_keys.push((left_expr, right_expr));
140 } else {
141 accum_filters.push(expr);
142 }
143 } else {
144 accum_filters.push(expr);
145 }
146 }
147 _ => accum_filters.push(expr),
148 }
149 }
150
151 let result_filter = accum_filters.into_iter().reduce(Expr::and);
152 Ok((accum_join_keys, result_filter))
153}
154
// Unit tests for `ExtractEquijoinPredicate`. Each test builds a join whose
// condition is supplied as a filter expression, runs the rule, and compares
// the indented plan display against an inline snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
    };
    use std::sync::Arc;

    // Wraps `assert_optimized_plan_eq_display_indent_snapshot!` so that every
    // test runs with `ExtractEquijoinPredicate` as the rule and only has to
    // supply the plan and the expected inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    // A filter that is a single column equality ends up entirely in the join
    // keys; the expected plan shows no residual `Filter:`.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The equality operands may be arbitrary expressions rather than bare
    // columns; the whole equality still becomes a join key.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Neither conjunct uses `=` (`>=` and `<`), so nothing is extracted and
    // the whole condition stays in the join filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The join already has an explicit key pair. The equality found in the
    // filter is appended after the existing key, and the remaining `<`
    // conjunct stays behind as the filter.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality nested under `OR` is not a conjunct of its own, so the
    // whole disjunction stays in the filter, while the two AND-ed equalities
    // become join keys.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Two nested joins: the rule rewrites both the inner (t2/t3) and the
    // outer (t1/inner) join, extracting one equality from each and leaving
    // the comparison conjuncts as filters.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // In the outer join, `t2.c = t3.c` is an equality but is expected to stay
    // in the filter. NOTE(review): presumably because both columns come from
    // the right-hand input, so no valid left/right key pair exists - confirm
    // against `find_valid_equijoin_key_pair`.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The equality is wrapped in an alias. The expected plan shows the
    // un-aliased equality as a join key and no residual filter.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        // Equality between two cast-containing expressions, aliased as a whole.
        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}