1use crate::ast::{BinaryOp, Expr, JoinType, UnaryOp};
27use crate::optimizer::OptimizerPass;
28use crate::planner::LogicalPlan;
29use alloc::boxed::Box;
30use alloc::string::String;
31use hashbrown::HashSet;
32
/// Optimizer pass that narrows outer joins whose NULL-extended rows are
/// discarded by a `Filter` placed directly above the join.
///
/// For example, a `users LEFT JOIN orders` filtered by `orders.amount > 100`
/// can never return a row whose `orders` columns are NULL, so the left outer
/// join is equivalent to an inner join.
///
/// The pass is stateless; every instance behaves identically.
#[derive(Debug, Clone, Copy)]
pub struct OuterJoinSimplification;
37
38impl OptimizerPass for OuterJoinSimplification {
39 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
40 self.simplify(plan)
41 }
42
43 fn name(&self) -> &'static str {
44 "outer_join_simplification"
45 }
46}
47
48impl OuterJoinSimplification {
49 fn simplify(&self, plan: LogicalPlan) -> LogicalPlan {
50 match plan {
51 LogicalPlan::Filter { input, predicate } => {
53 let optimized_input = self.simplify(*input);
54
55 if let LogicalPlan::Join {
57 left,
58 right,
59 condition,
60 join_type,
61 } = optimized_input
62 {
63 if let Some(new_join_type) =
64 self.try_simplify_join(&predicate, &left, &right, join_type)
65 {
66 return LogicalPlan::Filter {
67 input: Box::new(LogicalPlan::Join {
68 left,
69 right,
70 condition,
71 join_type: new_join_type,
72 }),
73 predicate,
74 };
75 }
76
77 return LogicalPlan::Filter {
79 input: Box::new(LogicalPlan::Join {
80 left,
81 right,
82 condition,
83 join_type,
84 }),
85 predicate,
86 };
87 }
88
89 LogicalPlan::Filter {
90 input: Box::new(optimized_input),
91 predicate,
92 }
93 }
94
95 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
96 input: Box::new(self.simplify(*input)),
97 columns,
98 },
99
100 LogicalPlan::Join {
101 left,
102 right,
103 condition,
104 join_type,
105 } => LogicalPlan::Join {
106 left: Box::new(self.simplify(*left)),
107 right: Box::new(self.simplify(*right)),
108 condition,
109 join_type,
110 },
111
112 LogicalPlan::Aggregate {
113 input,
114 group_by,
115 aggregates,
116 } => LogicalPlan::Aggregate {
117 input: Box::new(self.simplify(*input)),
118 group_by,
119 aggregates,
120 },
121
122 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
123 input: Box::new(self.simplify(*input)),
124 order_by,
125 },
126
127 LogicalPlan::Limit {
128 input,
129 limit,
130 offset,
131 } => LogicalPlan::Limit {
132 input: Box::new(self.simplify(*input)),
133 limit,
134 offset,
135 },
136
137 LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
138 left: Box::new(self.simplify(*left)),
139 right: Box::new(self.simplify(*right)),
140 },
141
142 LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
143 left: Box::new(self.simplify(*left)),
144 right: Box::new(self.simplify(*right)),
145 all,
146 },
147
148 LogicalPlan::Scan { .. }
150 | LogicalPlan::IndexScan { .. }
151 | LogicalPlan::IndexGet { .. }
152 | LogicalPlan::IndexInGet { .. }
153 | LogicalPlan::GinIndexScan { .. }
154 | LogicalPlan::GinIndexScanMulti { .. }
155 | LogicalPlan::Empty => plan,
156 }
157 }
158
159 fn try_simplify_join(
161 &self,
162 predicate: &Expr,
163 left: &LogicalPlan,
164 right: &LogicalPlan,
165 join_type: JoinType,
166 ) -> Option<JoinType> {
167 match join_type {
168 JoinType::LeftOuter => {
169 let right_tables = self.extract_tables(right);
171 if self.predicate_rejects_null(predicate, &right_tables) {
172 return Some(JoinType::Inner);
173 }
174 None
175 }
176
177 JoinType::RightOuter => {
178 let left_tables = self.extract_tables(left);
180 if self.predicate_rejects_null(predicate, &left_tables) {
181 return Some(JoinType::Inner);
182 }
183 None
184 }
185
186 JoinType::FullOuter => {
187 let left_tables = self.extract_tables(left);
189 let right_tables = self.extract_tables(right);
190
191 let rejects_left_null = self.predicate_rejects_null(predicate, &left_tables);
192 let rejects_right_null = self.predicate_rejects_null(predicate, &right_tables);
193
194 if rejects_left_null && rejects_right_null {
195 return Some(JoinType::Inner);
196 } else if rejects_right_null {
197 return Some(JoinType::LeftOuter);
198 } else if rejects_left_null {
199 return Some(JoinType::RightOuter);
200 }
201 None
202 }
203
204 JoinType::Inner | JoinType::Cross => None,
206 }
207 }
208
209 fn predicate_rejects_null(&self, predicate: &Expr, tables: &HashSet<String>) -> bool {
211 match predicate {
212 Expr::UnaryOp {
214 op: UnaryOp::IsNotNull,
215 expr,
216 } => self.expr_references_tables(expr, tables),
217
218 Expr::BinaryOp { left, op, right } => {
220 match op {
221 BinaryOp::Eq
223 | BinaryOp::Ne
224 | BinaryOp::Lt
225 | BinaryOp::Le
226 | BinaryOp::Gt
227 | BinaryOp::Ge => {
228 let left_refs_tables = self.expr_references_tables(left, tables);
230 let right_refs_tables = self.expr_references_tables(right, tables);
231 let left_is_literal = matches!(left.as_ref(), Expr::Literal(_));
232 let right_is_literal = matches!(right.as_ref(), Expr::Literal(_));
233
234 (left_refs_tables && right_is_literal)
236 || (right_refs_tables && left_is_literal)
237 || (left_refs_tables && right_refs_tables)
239 }
240
241 BinaryOp::And => {
244 self.predicate_rejects_null(left, tables)
245 || self.predicate_rejects_null(right, tables)
246 }
247
248 BinaryOp::Or => {
250 self.predicate_rejects_null(left, tables)
251 && self.predicate_rejects_null(right, tables)
252 }
253
254 BinaryOp::Like => self.expr_references_tables(left, tables),
256
257 BinaryOp::In => self.expr_references_tables(left, tables),
259
260 BinaryOp::Between => self.expr_references_tables(left, tables),
262
263 _ => false,
264 }
265 }
266
267 Expr::In { expr, .. } => self.expr_references_tables(expr, tables),
269
270 Expr::Between { expr, .. } => self.expr_references_tables(expr, tables),
272
273 Expr::Like { expr, .. } => self.expr_references_tables(expr, tables),
275
276 Expr::UnaryOp {
278 op: UnaryOp::IsNull,
279 ..
280 } => false,
281
282 Expr::UnaryOp {
284 op: UnaryOp::Not,
285 expr,
286 } => {
287 if let Expr::UnaryOp {
289 op: UnaryOp::IsNull,
290 expr: inner,
291 } = expr.as_ref()
292 {
293 return self.expr_references_tables(inner, tables);
294 }
295 false
296 }
297
298 _ => false,
299 }
300 }
301
302 fn expr_references_tables(&self, expr: &Expr, tables: &HashSet<String>) -> bool {
304 match expr {
305 Expr::Column(col) => tables.contains(&col.table),
306 Expr::BinaryOp { left, right, .. } => {
307 self.expr_references_tables(left, tables)
308 || self.expr_references_tables(right, tables)
309 }
310 Expr::UnaryOp { expr, .. } => self.expr_references_tables(expr, tables),
311 Expr::Function { args, .. } => {
312 args.iter().any(|arg| self.expr_references_tables(arg, tables))
313 }
314 Expr::Aggregate { expr, .. } => {
315 expr.as_ref()
316 .map(|e| self.expr_references_tables(e, tables))
317 .unwrap_or(false)
318 }
319 Expr::Between { expr, low, high } => {
320 self.expr_references_tables(expr, tables)
321 || self.expr_references_tables(low, tables)
322 || self.expr_references_tables(high, tables)
323 }
324 Expr::In { expr, list } => {
325 self.expr_references_tables(expr, tables)
326 || list.iter().any(|e| self.expr_references_tables(e, tables))
327 }
328 Expr::Like { expr, .. } => self.expr_references_tables(expr, tables),
329 Expr::NotBetween { expr, low, high } => {
330 self.expr_references_tables(expr, tables)
331 || self.expr_references_tables(low, tables)
332 || self.expr_references_tables(high, tables)
333 }
334 Expr::NotIn { expr, list } => {
335 self.expr_references_tables(expr, tables)
336 || list.iter().any(|e| self.expr_references_tables(e, tables))
337 }
338 Expr::NotLike { expr, .. } => self.expr_references_tables(expr, tables),
339 Expr::Match { expr, .. } => self.expr_references_tables(expr, tables),
340 Expr::NotMatch { expr, .. } => self.expr_references_tables(expr, tables),
341 Expr::Literal(_) => false,
342 }
343 }
344
345 fn extract_tables(&self, plan: &LogicalPlan) -> HashSet<String> {
347 let mut tables = HashSet::new();
348 self.collect_tables(plan, &mut tables);
349 tables
350 }
351
352 fn collect_tables(&self, plan: &LogicalPlan, tables: &mut HashSet<String>) {
353 match plan {
354 LogicalPlan::Scan { table } => {
355 tables.insert(table.clone());
356 }
357 LogicalPlan::IndexScan { table, .. }
358 | LogicalPlan::IndexGet { table, .. }
359 | LogicalPlan::IndexInGet { table, .. }
360 | LogicalPlan::GinIndexScan { table, .. }
361 | LogicalPlan::GinIndexScanMulti { table, .. } => {
362 tables.insert(table.clone());
363 }
364 LogicalPlan::Filter { input, .. }
365 | LogicalPlan::Project { input, .. }
366 | LogicalPlan::Aggregate { input, .. }
367 | LogicalPlan::Sort { input, .. }
368 | LogicalPlan::Limit { input, .. } => {
369 self.collect_tables(input, tables);
370 }
371 LogicalPlan::Join { left, right, .. }
372 | LogicalPlan::CrossProduct { left, right }
373 | LogicalPlan::Union { left, right, .. } => {
374 self.collect_tables(left, tables);
375 self.collect_tables(right, tables);
376 }
377 LogicalPlan::Empty => {}
378 }
379 }
380}
381
382impl Default for OuterJoinSimplification {
383 fn default() -> Self {
384 Self
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Join condition `users.id = orders.user_id` shared by most cases below.
    fn users_orders_on() -> Expr {
        Expr::eq(
            Expr::column("users", "id", 0),
            Expr::column("orders", "user_id", 0),
        )
    }

    /// Builds `Filter(predicate)` over `users LEFT JOIN orders ON users.id = orders.user_id`.
    fn filtered_users_left_join_orders(predicate: Expr) -> LogicalPlan {
        LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_on(),
            ),
            predicate,
        )
    }

    /// Runs the pass over `plan` and returns the type of the join directly
    /// beneath the top-level filter.
    ///
    /// Panics with "Expected Filter" / "Expected Join" if the optimized plan
    /// does not have that shape.
    fn join_type_after_pass(plan: LogicalPlan) -> JoinType {
        match OuterJoinSimplification.optimize(plan) {
            LogicalPlan::Filter { input, .. } => match *input {
                LogicalPlan::Join { join_type, .. } => join_type,
                _ => panic!("Expected Join"),
            },
            _ => panic!("Expected Filter"),
        }
    }

    #[test]
    fn test_left_join_to_inner_with_equality() {
        let plan = filtered_users_left_join_orders(Expr::eq(
            Expr::column("orders", "amount", 1),
            Expr::literal(100i64),
        ));
        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_to_inner_with_is_not_null() {
        let plan =
            filtered_users_left_join_orders(Expr::is_not_null(Expr::column("orders", "id", 0)));
        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_to_inner_with_comparison() {
        let plan = filtered_users_left_join_orders(Expr::gt(
            Expr::column("orders", "amount", 1),
            Expr::literal(100i64),
        ));
        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_unchanged_with_left_predicate() {
        // A predicate on the preserved side says nothing about NULL-extended rows.
        let plan = filtered_users_left_join_orders(Expr::eq(
            Expr::column("users", "active", 1),
            Expr::literal(true),
        ));
        assert_eq!(join_type_after_pass(plan), JoinType::LeftOuter);
    }

    #[test]
    fn test_left_join_unchanged_with_is_null() {
        // `IS NULL` keeps exactly the NULL-extended rows, so nothing changes.
        let plan = filtered_users_left_join_orders(Expr::is_null(Expr::column("orders", "id", 0)));
        assert_eq!(join_type_after_pass(plan), JoinType::LeftOuter);
    }

    #[test]
    fn test_right_join_to_inner() {
        let join = LogicalPlan::Join {
            left: Box::new(LogicalPlan::scan("users")),
            right: Box::new(LogicalPlan::scan("orders")),
            condition: users_orders_on(),
            join_type: JoinType::RightOuter,
        };
        let plan = LogicalPlan::filter(
            join,
            Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        );
        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }

    #[test]
    fn test_and_predicate_rejects_null() {
        let amount_over_100 =
            Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64));
        let status_active = Expr::eq(
            Expr::column("orders", "status", 2),
            Expr::literal("active"),
        );
        let plan = filtered_users_left_join_orders(Expr::and(amount_over_100, status_active));
        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }

    #[test]
    fn test_inner_join_unchanged() {
        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_on(),
            ),
            Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64)),
        );
        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }

    #[test]
    fn test_nested_joins() {
        let orders_left_join_items = LogicalPlan::left_join(
            LogicalPlan::scan("orders"),
            LogicalPlan::scan("items"),
            Expr::eq(
                Expr::column("orders", "id", 0),
                Expr::column("items", "order_id", 0),
            ),
        );

        let plan = LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                orders_left_join_items,
                users_orders_on(),
            ),
            Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64)),
        );

        assert_eq!(join_type_after_pass(plan), JoinType::Inner);
    }
}