Skip to main content

cynos_query/optimizer/
outer_join_simplification.rs

//! Outer join simplification optimization pass.
//!
//! This pass converts outer joins to inner joins when the WHERE clause
//! contains predicates that would filter out NULL values from the outer side.
//! The filter itself is kept; only the join type below it changes.
//!
//! Example:
//! ```text
//! Filter(orders.amount > 100)     =>    Filter(orders.amount > 100)
//!        |                                     |
//! LeftJoin(users, orders)              InnerJoin(users, orders)
//!                                      (same join condition)
//! ```
//!
//! This optimization is safe because:
//! - LEFT JOIN produces NULL for right-side columns when there's no match
//! - If the WHERE clause filters on right-side columns with non-NULL conditions,
//!   those NULL rows would be filtered out anyway
//! - Converting to INNER JOIN is more efficient (no NULL handling needed)
//!
//! Conditions that reject NULL:
//! - `col = value` (equality with non-NULL value)
//! - `col > value`, `col < value`, etc. (comparisons)
//! - `col IS NOT NULL`
//! - `col IN (...)` with non-NULL values
//! - `col BETWEEN a AND b`

25
26use crate::ast::{BinaryOp, Expr, JoinType, UnaryOp};
27use crate::optimizer::OptimizerPass;
28use crate::planner::LogicalPlan;
29use alloc::boxed::Box;
30use alloc::string::String;
31use hashbrown::HashSet;
32
/// Outer join simplification optimization.
///
/// Converts outer joins to inner joins when predicates reject NULL values.
/// The pass carries no state, so the type is a unit struct that is free to
/// copy and print.
#[derive(Debug, Clone, Copy)]
pub struct OuterJoinSimplification;
37
38impl OptimizerPass for OuterJoinSimplification {
39    fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
40        self.simplify(plan)
41    }
42
43    fn name(&self) -> &'static str {
44        "outer_join_simplification"
45    }
46}
47
48impl OuterJoinSimplification {
49    fn simplify(&self, plan: LogicalPlan) -> LogicalPlan {
50        match plan {
51            // Look for Filter above Join pattern
52            LogicalPlan::Filter { input, predicate } => {
53                let optimized_input = self.simplify(*input);
54
55                // Check if we can simplify an outer join
56                if let LogicalPlan::Join {
57                    left,
58                    right,
59                    condition,
60                    join_type,
61                } = optimized_input
62                {
63                    if let Some(new_join_type) =
64                        self.try_simplify_join(&predicate, &left, &right, join_type)
65                    {
66                        return LogicalPlan::Filter {
67                            input: Box::new(LogicalPlan::Join {
68                                left,
69                                right,
70                                condition,
71                                join_type: new_join_type,
72                            }),
73                            predicate,
74                        };
75                    }
76
77                    // No simplification possible
78                    return LogicalPlan::Filter {
79                        input: Box::new(LogicalPlan::Join {
80                            left,
81                            right,
82                            condition,
83                            join_type,
84                        }),
85                        predicate,
86                    };
87                }
88
89                LogicalPlan::Filter {
90                    input: Box::new(optimized_input),
91                    predicate,
92                }
93            }
94
95            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
96                input: Box::new(self.simplify(*input)),
97                columns,
98            },
99
100            LogicalPlan::Join {
101                left,
102                right,
103                condition,
104                join_type,
105            } => LogicalPlan::Join {
106                left: Box::new(self.simplify(*left)),
107                right: Box::new(self.simplify(*right)),
108                condition,
109                join_type,
110            },
111
112            LogicalPlan::Aggregate {
113                input,
114                group_by,
115                aggregates,
116            } => LogicalPlan::Aggregate {
117                input: Box::new(self.simplify(*input)),
118                group_by,
119                aggregates,
120            },
121
122            LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
123                input: Box::new(self.simplify(*input)),
124                order_by,
125            },
126
127            LogicalPlan::Limit {
128                input,
129                limit,
130                offset,
131            } => LogicalPlan::Limit {
132                input: Box::new(self.simplify(*input)),
133                limit,
134                offset,
135            },
136
137            LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
138                left: Box::new(self.simplify(*left)),
139                right: Box::new(self.simplify(*right)),
140            },
141
142            LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
143                left: Box::new(self.simplify(*left)),
144                right: Box::new(self.simplify(*right)),
145                all,
146            },
147
148            // Leaf nodes - no transformation
149            LogicalPlan::Scan { .. }
150            | LogicalPlan::IndexScan { .. }
151            | LogicalPlan::IndexGet { .. }
152            | LogicalPlan::IndexInGet { .. }
153            | LogicalPlan::GinIndexScan { .. }
154            | LogicalPlan::GinIndexScanMulti { .. }
155            | LogicalPlan::Empty => plan,
156        }
157    }
158
159    /// Try to simplify an outer join to inner join based on the predicate.
160    fn try_simplify_join(
161        &self,
162        predicate: &Expr,
163        left: &LogicalPlan,
164        right: &LogicalPlan,
165        join_type: JoinType,
166    ) -> Option<JoinType> {
167        match join_type {
168            JoinType::LeftOuter => {
169                // For LEFT JOIN, check if predicate rejects NULLs from right side
170                let right_tables = self.extract_tables(right);
171                if self.predicate_rejects_null(predicate, &right_tables) {
172                    return Some(JoinType::Inner);
173                }
174                None
175            }
176
177            JoinType::RightOuter => {
178                // For RIGHT JOIN, check if predicate rejects NULLs from left side
179                let left_tables = self.extract_tables(left);
180                if self.predicate_rejects_null(predicate, &left_tables) {
181                    return Some(JoinType::Inner);
182                }
183                None
184            }
185
186            JoinType::FullOuter => {
187                // For FULL OUTER JOIN, need predicates rejecting NULLs from both sides
188                let left_tables = self.extract_tables(left);
189                let right_tables = self.extract_tables(right);
190
191                let rejects_left_null = self.predicate_rejects_null(predicate, &left_tables);
192                let rejects_right_null = self.predicate_rejects_null(predicate, &right_tables);
193
194                if rejects_left_null && rejects_right_null {
195                    return Some(JoinType::Inner);
196                } else if rejects_right_null {
197                    return Some(JoinType::LeftOuter);
198                } else if rejects_left_null {
199                    return Some(JoinType::RightOuter);
200                }
201                None
202            }
203
204            // Inner and Cross joins don't need simplification
205            JoinType::Inner | JoinType::Cross => None,
206        }
207    }
208
209    /// Check if a predicate would reject NULL values from the given tables.
210    fn predicate_rejects_null(&self, predicate: &Expr, tables: &HashSet<String>) -> bool {
211        match predicate {
212            // IS NOT NULL explicitly rejects NULL
213            Expr::UnaryOp {
214                op: UnaryOp::IsNotNull,
215                expr,
216            } => self.expr_references_tables(expr, tables),
217
218            // Comparisons with literals reject NULL (NULL compared to anything is NULL/false)
219            Expr::BinaryOp { left, op, right } => {
220                match op {
221                    // Equality and comparison operators reject NULL
222                    BinaryOp::Eq
223                    | BinaryOp::Ne
224                    | BinaryOp::Lt
225                    | BinaryOp::Le
226                    | BinaryOp::Gt
227                    | BinaryOp::Ge => {
228                        // Check if one side is a column from target tables and other is literal
229                        let left_refs_tables = self.expr_references_tables(left, tables);
230                        let right_refs_tables = self.expr_references_tables(right, tables);
231                        let left_is_literal = matches!(left.as_ref(), Expr::Literal(_));
232                        let right_is_literal = matches!(right.as_ref(), Expr::Literal(_));
233
234                        // col = literal or literal = col rejects NULL
235                        (left_refs_tables && right_is_literal)
236                            || (right_refs_tables && left_is_literal)
237                            // col = col from same tables also rejects NULL
238                            || (left_refs_tables && right_refs_tables)
239                    }
240
241                    // AND: both sides must reject NULL for the whole predicate to reject NULL
242                    // But if either side rejects NULL, the row is filtered
243                    BinaryOp::And => {
244                        self.predicate_rejects_null(left, tables)
245                            || self.predicate_rejects_null(right, tables)
246                    }
247
248                    // OR: both sides must reject NULL
249                    BinaryOp::Or => {
250                        self.predicate_rejects_null(left, tables)
251                            && self.predicate_rejects_null(right, tables)
252                    }
253
254                    // LIKE rejects NULL
255                    BinaryOp::Like => self.expr_references_tables(left, tables),
256
257                    // IN rejects NULL
258                    BinaryOp::In => self.expr_references_tables(left, tables),
259
260                    // BETWEEN rejects NULL
261                    BinaryOp::Between => self.expr_references_tables(left, tables),
262
263                    _ => false,
264                }
265            }
266
267            // IN expression rejects NULL
268            Expr::In { expr, .. } => self.expr_references_tables(expr, tables),
269
270            // BETWEEN rejects NULL
271            Expr::Between { expr, .. } => self.expr_references_tables(expr, tables),
272
273            // LIKE rejects NULL
274            Expr::Like { expr, .. } => self.expr_references_tables(expr, tables),
275
276            // IS NULL does NOT reject NULL (it accepts NULL)
277            Expr::UnaryOp {
278                op: UnaryOp::IsNull,
279                ..
280            } => false,
281
282            // NOT of something that accepts NULL might reject NULL
283            Expr::UnaryOp {
284                op: UnaryOp::Not,
285                expr,
286            } => {
287                // NOT (IS NULL) = IS NOT NULL, which rejects NULL
288                if let Expr::UnaryOp {
289                    op: UnaryOp::IsNull,
290                    expr: inner,
291                } = expr.as_ref()
292                {
293                    return self.expr_references_tables(inner, tables);
294                }
295                false
296            }
297
298            _ => false,
299        }
300    }
301
302    /// Check if an expression references any of the given tables.
303    fn expr_references_tables(&self, expr: &Expr, tables: &HashSet<String>) -> bool {
304        match expr {
305            Expr::Column(col) => tables.contains(&col.table),
306            Expr::BinaryOp { left, right, .. } => {
307                self.expr_references_tables(left, tables)
308                    || self.expr_references_tables(right, tables)
309            }
310            Expr::UnaryOp { expr, .. } => self.expr_references_tables(expr, tables),
311            Expr::Function { args, .. } => {
312                args.iter().any(|arg| self.expr_references_tables(arg, tables))
313            }
314            Expr::Aggregate { expr, .. } => {
315                expr.as_ref()
316                    .map(|e| self.expr_references_tables(e, tables))
317                    .unwrap_or(false)
318            }
319            Expr::Between { expr, low, high } => {
320                self.expr_references_tables(expr, tables)
321                    || self.expr_references_tables(low, tables)
322                    || self.expr_references_tables(high, tables)
323            }
324            Expr::In { expr, list } => {
325                self.expr_references_tables(expr, tables)
326                    || list.iter().any(|e| self.expr_references_tables(e, tables))
327            }
328            Expr::Like { expr, .. } => self.expr_references_tables(expr, tables),
329            Expr::NotBetween { expr, low, high } => {
330                self.expr_references_tables(expr, tables)
331                    || self.expr_references_tables(low, tables)
332                    || self.expr_references_tables(high, tables)
333            }
334            Expr::NotIn { expr, list } => {
335                self.expr_references_tables(expr, tables)
336                    || list.iter().any(|e| self.expr_references_tables(e, tables))
337            }
338            Expr::NotLike { expr, .. } => self.expr_references_tables(expr, tables),
339            Expr::Match { expr, .. } => self.expr_references_tables(expr, tables),
340            Expr::NotMatch { expr, .. } => self.expr_references_tables(expr, tables),
341            Expr::Literal(_) => false,
342        }
343    }
344
345    /// Extract all table names referenced by a plan.
346    fn extract_tables(&self, plan: &LogicalPlan) -> HashSet<String> {
347        let mut tables = HashSet::new();
348        self.collect_tables(plan, &mut tables);
349        tables
350    }
351
352    fn collect_tables(&self, plan: &LogicalPlan, tables: &mut HashSet<String>) {
353        match plan {
354            LogicalPlan::Scan { table } => {
355                tables.insert(table.clone());
356            }
357            LogicalPlan::IndexScan { table, .. }
358            | LogicalPlan::IndexGet { table, .. }
359            | LogicalPlan::IndexInGet { table, .. }
360            | LogicalPlan::GinIndexScan { table, .. }
361            | LogicalPlan::GinIndexScanMulti { table, .. } => {
362                tables.insert(table.clone());
363            }
364            LogicalPlan::Filter { input, .. }
365            | LogicalPlan::Project { input, .. }
366            | LogicalPlan::Aggregate { input, .. }
367            | LogicalPlan::Sort { input, .. }
368            | LogicalPlan::Limit { input, .. } => {
369                self.collect_tables(input, tables);
370            }
371            LogicalPlan::Join { left, right, .. }
372            | LogicalPlan::CrossProduct { left, right }
373            | LogicalPlan::Union { left, right, .. } => {
374                self.collect_tables(left, tables);
375                self.collect_tables(right, tables);
376            }
377            LogicalPlan::Empty => {}
378        }
379    }
380}
381
382impl Default for OuterJoinSimplification {
383    fn default() -> Self {
384        Self
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Join condition shared by the fixtures: `users.id = orders.user_id`.
    fn users_orders_condition() -> Expr {
        Expr::eq(
            Expr::column("users", "id", 0),
            Expr::column("orders", "user_id", 0),
        )
    }

    /// Builds `Filter(predicate)` directly above `users LEFT JOIN orders`.
    fn filtered_left_join(predicate: Expr) -> LogicalPlan {
        LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            predicate,
        )
    }

    /// Predicate `orders.amount > 100`, used by several tests.
    fn amount_above_100() -> Expr {
        Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64))
    }

    /// Runs the pass on `plan` and returns the type of the join found
    /// directly below the root filter, panicking on any other plan shape.
    fn optimized_join_type(plan: LogicalPlan) -> JoinType {
        let pass = OuterJoinSimplification;
        match pass.optimize(plan) {
            LogicalPlan::Filter { input, .. } => match *input {
                LogicalPlan::Join { join_type, .. } => join_type,
                _ => panic!("Expected Join"),
            },
            _ => panic!("Expected Filter"),
        }
    }

    #[test]
    fn test_left_join_to_inner_with_equality() {
        // Equality on a right-side column rejects NULL, so the LEFT JOIN
        // below the filter becomes an INNER JOIN.
        let plan = filtered_left_join(Expr::eq(
            Expr::column("orders", "amount", 1),
            Expr::literal(100i64),
        ));

        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_to_inner_with_is_not_null() {
        // orders.id IS NOT NULL above the LEFT JOIN
        let plan = filtered_left_join(Expr::is_not_null(Expr::column("orders", "id", 0)));

        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_to_inner_with_comparison() {
        // orders.amount > 100 above the LEFT JOIN
        let plan = filtered_left_join(amount_above_100());

        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_unchanged_with_left_predicate() {
        // A predicate on the preserved (left) side must not change the join.
        let plan = filtered_left_join(Expr::eq(
            Expr::column("users", "active", 1),
            Expr::literal(true),
        ));

        assert_eq!(optimized_join_type(plan), JoinType::LeftOuter);
    }

    #[test]
    fn test_left_join_unchanged_with_is_null() {
        // IS NULL accepts the padded rows, so the LEFT JOIN must stay.
        let plan = filtered_left_join(Expr::is_null(Expr::column("orders", "id", 0)));

        assert_eq!(optimized_join_type(plan), JoinType::LeftOuter);
    }

    #[test]
    fn test_right_join_to_inner() {
        // For a RIGHT JOIN the left side is padded, so a predicate on a
        // left-side column converts it to INNER.
        let right_join = LogicalPlan::Join {
            left: Box::new(LogicalPlan::scan("users")),
            right: Box::new(LogicalPlan::scan("orders")),
            condition: users_orders_condition(),
            join_type: JoinType::RightOuter,
        };
        let plan = LogicalPlan::filter(
            right_join,
            Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        );

        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }

    #[test]
    fn test_and_predicate_rejects_null() {
        // orders.amount > 100 AND orders.status = 'active':
        // either conjunct rejects NULL, so the join is converted.
        let plan = filtered_left_join(Expr::and(
            amount_above_100(),
            Expr::eq(
                Expr::column("orders", "status", 2),
                Expr::literal("active"),
            ),
        ));

        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }

    #[test]
    fn test_inner_join_unchanged() {
        // An inner join has nothing to simplify and stays an inner join.
        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            amount_above_100(),
        );

        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }

    #[test]
    fn test_nested_joins() {
        // users LEFT JOIN (orders LEFT JOIN items), filtered on orders.amount.
        let orders_items = LogicalPlan::left_join(
            LogicalPlan::scan("orders"),
            LogicalPlan::scan("items"),
            Expr::eq(
                Expr::column("orders", "id", 0),
                Expr::column("items", "order_id", 0),
            ),
        );
        let plan = LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                orders_items,
                users_orders_condition(),
            ),
            amount_above_100(),
        );

        // The outer join is converted because the predicate references
        // `orders`, which belongs to its padded (right) side.
        assert_eq!(optimized_join_type(plan), JoinType::Inner);
    }
}
680}