Skip to main content

chryso_optimizer/
rules.rs

1use crate::utils::{
2    collect_identifiers, collect_tables, combine_conjuncts, split_conjuncts, table_prefix,
3};
4use chryso_core::ast::{BinaryOperator, Expr, Literal};
5use chryso_planner::LogicalPlan;
6use std::collections::BTreeSet;
7
8pub trait Rule {
9    fn name(&self) -> &str;
10    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan>;
11}
12
/// Mutable state threaded through rule applications.
///
/// It currently accumulates only the literal-conflict pairs that rules report.
#[derive(Default, Debug)]
pub struct RuleContext {
    // An ordered set: duplicate pairs collapse, and iteration order is sorted.
    literal_conflicts: BTreeSet<(String, String)>,
}

impl RuleContext {
    /// Records every pair yielded by `pairs`; pairs already recorded are ignored.
    pub fn record_literal_conflicts<I>(&mut self, pairs: I)
    where
        I: IntoIterator<Item = (String, String)>,
    {
        for pair in pairs {
            self.literal_conflicts.insert(pair);
        }
    }

    /// Removes and returns all recorded pairs in ascending order, leaving the
    /// context empty.
    pub fn take_literal_conflicts(&mut self) -> Vec<(String, String)> {
        let drained = std::mem::take(&mut self.literal_conflicts);
        // Draining the ordered set keeps trace output deterministic.
        drained.into_iter().collect()
    }
}
33
34pub struct RuleSet {
35    rules: Vec<Box<dyn Rule + Send + Sync>>,
36}
37
38impl RuleSet {
39    pub fn new() -> Self {
40        Self { rules: Vec::new() }
41    }
42
43    pub fn with_rule(mut self, rule: impl Rule + Send + Sync + 'static) -> Self {
44        self.rules.push(Box::new(rule));
45        self
46    }
47
48    pub fn apply_all(&self, plan: &LogicalPlan, ctx: &mut RuleContext) -> Vec<LogicalPlan> {
49        let mut results = Vec::new();
50        for rule in &self.rules {
51            results.extend(rule.apply(plan, ctx));
52        }
53        results
54    }
55
56    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Rule + Send + Sync>> {
57        self.rules.iter()
58    }
59}
60
61impl Default for RuleSet {
62    fn default() -> Self {
63        RuleSet::new()
64            .with_rule(MergeFilters)
65            .with_rule(PruneProjection)
66            .with_rule(MergeProjections)
67            .with_rule(RemoveTrueFilter)
68            .with_rule(FilterPushdown)
69            .with_rule(FilterJoinPushdown)
70            .with_rule(PredicateInference)
71            .with_rule(JoinPredicatePushdown)
72            .with_rule(FilterOrDedup)
73            .with_rule(NormalizePredicates)
74            .with_rule(JoinCommute)
75            .with_rule(AggregatePredicatePushdown)
76            .with_rule(LimitPushdown)
77            .with_rule(TopNRule)
78    }
79}
80
81impl RuleSet {
82    pub fn detect_conflicts(&self) -> Vec<String> {
83        let mut seen = std::collections::HashMap::new();
84        for rule in &self.rules {
85            *seen.entry(rule.name().to_string()).or_insert(0usize) += 1;
86        }
87        seen.into_iter()
88            .filter_map(|(name, count)| if count > 1 { Some(name) } else { None })
89            .collect()
90    }
91}
92
93impl std::fmt::Debug for RuleSet {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.debug_struct("RuleSet")
96            .field("rule_count", &self.rules.len())
97            .finish()
98    }
99}
100
101pub struct NoopRule;
102
103impl Rule for NoopRule {
104    fn name(&self) -> &str {
105        "noop"
106    }
107
108    fn apply(&self, _plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
109        Vec::new()
110    }
111}
112
113pub struct MergeFilters;
114
115impl Rule for MergeFilters {
116    fn name(&self) -> &str {
117        "merge_filters"
118    }
119
120    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
121        let LogicalPlan::Filter { predicate, input } = plan else {
122            return Vec::new();
123        };
124        let LogicalPlan::Filter {
125            predicate: inner_predicate,
126            input: inner_input,
127        } = input.as_ref()
128        else {
129            return Vec::new();
130        };
131        let merged = LogicalPlan::Filter {
132            predicate: chryso_core::ast::Expr::BinaryOp {
133                left: Box::new(inner_predicate.clone()),
134                op: BinaryOperator::And,
135                right: Box::new(predicate.clone()),
136            },
137            input: inner_input.clone(),
138        };
139        vec![merged]
140    }
141}
142
143pub struct PruneProjection;
144
145impl Rule for PruneProjection {
146    fn name(&self) -> &str {
147        "prune_projection"
148    }
149
150    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
151        let LogicalPlan::Projection { exprs, input } = plan else {
152            return Vec::new();
153        };
154        if exprs.len() == 1 && matches!(exprs[0], Expr::Wildcard) {
155            return vec![(*input.clone())];
156        }
157        Vec::new()
158    }
159}
160
161pub struct MergeProjections;
162
163impl Rule for MergeProjections {
164    fn name(&self) -> &str {
165        "merge_projections"
166    }
167
168    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
169        let LogicalPlan::Projection { exprs, input } = plan else {
170            return Vec::new();
171        };
172        let LogicalPlan::Projection {
173            exprs: inner_exprs,
174            input: inner_input,
175        } = input.as_ref()
176        else {
177            return Vec::new();
178        };
179        if projection_subset(exprs, inner_exprs) {
180            return vec![LogicalPlan::Projection {
181                exprs: exprs.clone(),
182                input: inner_input.clone(),
183            }];
184        }
185        Vec::new()
186    }
187}
188
189fn projection_subset(outer: &[Expr], inner: &[Expr]) -> bool {
190    let inner_names = inner
191        .iter()
192        .filter_map(|expr| match expr {
193            Expr::Identifier(name) => Some(name),
194            _ => None,
195        })
196        .collect::<std::collections::HashSet<_>>();
197    if inner_names.is_empty() {
198        return false;
199    }
200    outer.iter().all(|expr| match expr {
201        Expr::Identifier(name) => inner_names.contains(name),
202        Expr::Wildcard => true,
203        _ => false,
204    })
205}
206
207pub struct RemoveTrueFilter;
208
209impl Rule for RemoveTrueFilter {
210    fn name(&self) -> &str {
211        "remove_true_filter"
212    }
213
214    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
215        let LogicalPlan::Filter { predicate, input } = plan else {
216            return Vec::new();
217        };
218        match predicate {
219            Expr::Literal(chryso_core::ast::Literal::Bool(true)) => vec![*input.clone()],
220            _ => Vec::new(),
221        }
222    }
223}
224
225pub struct FilterPushdown;
226
227impl Rule for FilterPushdown {
228    fn name(&self) -> &str {
229        "filter_pushdown"
230    }
231
232    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
233        let LogicalPlan::Filter { predicate, input } = plan else {
234            return Vec::new();
235        };
236        let LogicalPlan::Projection { exprs, input } = input.as_ref() else {
237            if let LogicalPlan::Sort { order_by, input } = input.as_ref() {
238                return vec![LogicalPlan::Sort {
239                    order_by: order_by.clone(),
240                    input: Box::new(LogicalPlan::Filter {
241                        predicate: predicate.clone(),
242                        input: input.clone(),
243                    }),
244                }];
245            }
246            return Vec::new();
247        };
248        if !projection_is_passthrough(exprs) {
249            return Vec::new();
250        }
251        let pushed = LogicalPlan::Projection {
252            exprs: exprs.clone(),
253            input: Box::new(LogicalPlan::Filter {
254                predicate: predicate.clone(),
255                input: input.clone(),
256            }),
257        };
258        vec![pushed]
259    }
260}
261
262pub struct FilterJoinPushdown;
263
264impl Rule for FilterJoinPushdown {
265    fn name(&self) -> &str {
266        "filter_join_pushdown"
267    }
268
269    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
270        let LogicalPlan::Filter { predicate, input } = plan else {
271            return Vec::new();
272        };
273        let LogicalPlan::Join {
274            join_type,
275            left,
276            right,
277            on,
278        } = input.as_ref()
279        else {
280            return Vec::new();
281        };
282        if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
283            return Vec::new();
284        }
285
286        let left_tables = collect_tables(left.as_ref());
287        let right_tables = collect_tables(right.as_ref());
288        let mut left_preds = Vec::new();
289        let mut right_preds = Vec::new();
290        let mut remaining = Vec::new();
291
292        for conjunct in split_conjuncts(predicate) {
293            let idents = collect_identifiers(&conjunct);
294            if idents.is_empty() {
295                remaining.push(conjunct);
296                continue;
297            }
298            let mut side = None;
299            let mut ambiguous = false;
300            for ident in &idents {
301                let Some(prefix) = table_prefix(ident) else {
302                    ambiguous = true;
303                    break;
304                };
305                let in_left = left_tables.contains(prefix);
306                let in_right = right_tables.contains(prefix);
307                if in_left && !in_right {
308                    side = match side {
309                        None => Some(Side::Left),
310                        Some(Side::Left) => Some(Side::Left),
311                        Some(Side::Right) => {
312                            ambiguous = true;
313                            break;
314                        }
315                    };
316                } else if in_right && !in_left {
317                    side = match side {
318                        None => Some(Side::Right),
319                        Some(Side::Right) => Some(Side::Right),
320                        Some(Side::Left) => {
321                            ambiguous = true;
322                            break;
323                        }
324                    };
325                } else {
326                    ambiguous = true;
327                    break;
328                }
329            }
330            if ambiguous {
331                remaining.push(conjunct);
332                continue;
333            }
334            match side {
335                Some(Side::Left) => left_preds.push(conjunct),
336                Some(Side::Right) => right_preds.push(conjunct),
337                None => remaining.push(conjunct),
338            }
339        }
340
341        if left_preds.is_empty() && right_preds.is_empty() && remaining.is_empty() {
342            return Vec::new();
343        }
344
345        let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
346            LogicalPlan::Filter {
347                predicate: expr,
348                input: left.clone(),
349            }
350        } else {
351            *left.clone()
352        };
353        let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
354            LogicalPlan::Filter {
355                predicate: expr,
356                input: right.clone(),
357            }
358        } else {
359            *right.clone()
360        };
361        let new_on = if let Some(expr) = combine_conjuncts(remaining) {
362            Expr::BinaryOp {
363                left: Box::new(on.clone()),
364                op: BinaryOperator::And,
365                right: Box::new(expr),
366            }
367        } else {
368            on.clone()
369        };
370        let joined = LogicalPlan::Join {
371            join_type: *join_type,
372            left: Box::new(new_left),
373            right: Box::new(new_right),
374            on: new_on,
375        };
376        vec![joined]
377    }
378}
379
380pub struct JoinPredicatePushdown;
381
382impl Rule for JoinPredicatePushdown {
383    fn name(&self) -> &str {
384        "join_predicate_pushdown"
385    }
386
387    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
388        let LogicalPlan::Join {
389            join_type,
390            left,
391            right,
392            on,
393        } = plan
394        else {
395            return Vec::new();
396        };
397        if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
398            return Vec::new();
399        }
400
401        let left_tables = collect_tables(left.as_ref());
402        let right_tables = collect_tables(right.as_ref());
403        let mut left_preds = Vec::new();
404        let mut right_preds = Vec::new();
405        let mut remaining = Vec::new();
406
407        for conjunct in split_conjuncts(on) {
408            let idents = collect_identifiers(&conjunct);
409            if idents.is_empty() {
410                remaining.push(conjunct);
411                continue;
412            }
413            let mut side = None;
414            let mut ambiguous = false;
415            for ident in &idents {
416                let Some(prefix) = table_prefix(ident) else {
417                    ambiguous = true;
418                    break;
419                };
420                let in_left = left_tables.contains(prefix);
421                let in_right = right_tables.contains(prefix);
422                if in_left && !in_right {
423                    side = match side {
424                        None => Some(Side::Left),
425                        Some(Side::Left) => Some(Side::Left),
426                        Some(Side::Right) => {
427                            ambiguous = true;
428                            break;
429                        }
430                    };
431                } else if in_right && !in_left {
432                    side = match side {
433                        None => Some(Side::Right),
434                        Some(Side::Right) => Some(Side::Right),
435                        Some(Side::Left) => {
436                            ambiguous = true;
437                            break;
438                        }
439                    };
440                } else {
441                    ambiguous = true;
442                    break;
443                }
444            }
445            if ambiguous {
446                remaining.push(conjunct);
447                continue;
448            }
449            match side {
450                Some(Side::Left) => left_preds.push(conjunct),
451                Some(Side::Right) => right_preds.push(conjunct),
452                None => remaining.push(conjunct),
453            }
454        }
455
456        if left_preds.is_empty() && right_preds.is_empty() {
457            return Vec::new();
458        }
459
460        let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
461            LogicalPlan::Filter {
462                predicate: expr,
463                input: left.clone(),
464            }
465        } else {
466            *left.clone()
467        };
468        let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
469            LogicalPlan::Filter {
470                predicate: expr,
471                input: right.clone(),
472            }
473        } else {
474            *right.clone()
475        };
476        let new_on =
477            combine_conjuncts(remaining).unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
478        vec![LogicalPlan::Join {
479            join_type: *join_type,
480            left: Box::new(new_left),
481            right: Box::new(new_right),
482            on: new_on,
483        }]
484    }
485}
486
487fn projection_is_passthrough(exprs: &[Expr]) -> bool {
488    exprs.iter().all(|expr| match expr {
489        Expr::Identifier(_) | Expr::Wildcard => true,
490        _ => false,
491    })
492}
493
494pub struct NormalizePredicates;
495
496impl Rule for NormalizePredicates {
497    fn name(&self) -> &str {
498        "normalize_predicates"
499    }
500
501    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
502        let LogicalPlan::Filter { predicate, input } = plan else {
503            return Vec::new();
504        };
505        let normalized = predicate.normalize();
506        if normalized.structural_eq(predicate) {
507            return Vec::new();
508        }
509        vec![LogicalPlan::Filter {
510            predicate: normalized,
511            input: input.clone(),
512        }]
513    }
514}
515
516pub struct PredicateInference;
517
518impl Rule for PredicateInference {
519    fn name(&self) -> &str {
520        "predicate_inference"
521    }
522
523    fn apply(&self, plan: &LogicalPlan, ctx: &mut RuleContext) -> Vec<LogicalPlan> {
524        match plan {
525            LogicalPlan::Filter { predicate, input } => match input.as_ref() {
526                LogicalPlan::Join {
527                    join_type,
528                    left,
529                    right,
530                    on,
531                } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
532                    let combined = Expr::BinaryOp {
533                        left: Box::new(predicate.clone()),
534                        op: BinaryOperator::And,
535                        right: Box::new(on.clone()),
536                    };
537                    let (inferred, changed) = infer_predicates(&combined, ctx);
538                    if !changed {
539                        return Vec::new();
540                    }
541                    let (filter_predicates, join_predicates) =
542                        split_predicates_by_source(&inferred, predicate, on);
543                    let join_expr = combine_conjuncts(join_predicates)
544                        .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
545                    let join_plan = LogicalPlan::Join {
546                        join_type: *join_type,
547                        left: left.clone(),
548                        right: right.clone(),
549                        on: join_expr,
550                    };
551                    let filter_expr = combine_conjuncts(filter_predicates)
552                        .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
553                    vec![LogicalPlan::Filter {
554                        predicate: filter_expr,
555                        input: Box::new(join_plan),
556                    }]
557                }
558                _ => {
559                    let (predicate, changed) = infer_predicates(predicate, ctx);
560                    if !changed {
561                        return Vec::new();
562                    }
563                    vec![LogicalPlan::Filter {
564                        predicate,
565                        input: input.clone(),
566                    }]
567                }
568            },
569            LogicalPlan::Join {
570                join_type,
571                left,
572                right,
573                on,
574            } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
575                let (on, changed) = infer_predicates(on, ctx);
576                if !changed {
577                    return Vec::new();
578                }
579                vec![LogicalPlan::Join {
580                    join_type: *join_type,
581                    left: left.clone(),
582                    right: right.clone(),
583                    on,
584                }]
585            }
586            _ => Vec::new(),
587        }
588    }
589}
590
591pub struct JoinCommute;
592
593impl Rule for JoinCommute {
594    fn name(&self) -> &str {
595        "join_commute"
596    }
597
598    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
599        let LogicalPlan::Join {
600            join_type,
601            left,
602            right,
603            on,
604        } = plan
605        else {
606            return Vec::new();
607        };
608        if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
609            return Vec::new();
610        }
611        vec![LogicalPlan::Join {
612            join_type: *join_type,
613            left: right.clone(),
614            right: left.clone(),
615            on: on.clone(),
616        }]
617    }
618}
619
620pub struct FilterOrDedup;
621
622impl Rule for FilterOrDedup {
623    fn name(&self) -> &str {
624        "filter_or_dedup"
625    }
626
627    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
628        let LogicalPlan::Filter { predicate, input } = plan else {
629            return Vec::new();
630        };
631        let Expr::BinaryOp { left, op, right } = predicate else {
632            return Vec::new();
633        };
634        if !matches!(op, BinaryOperator::Or) {
635            return Vec::new();
636        }
637        if left.structural_eq(right) {
638            return vec![LogicalPlan::Filter {
639                predicate: (*left.clone()),
640                input: input.clone(),
641            }];
642        }
643        Vec::new()
644    }
645}
646
647pub struct LimitPushdown;
648
649impl Rule for LimitPushdown {
650    fn name(&self) -> &str {
651        "limit_pushdown"
652    }
653
654    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
655        let LogicalPlan::Limit {
656            limit,
657            offset,
658            input,
659        } = plan
660        else {
661            return Vec::new();
662        };
663        if offset.is_some() {
664            return Vec::new();
665        }
666        let inner = input.as_ref();
667        match inner {
668            LogicalPlan::Filter { predicate, input } => vec![LogicalPlan::Filter {
669                predicate: predicate.clone(),
670                input: Box::new(LogicalPlan::Limit {
671                    limit: *limit,
672                    offset: *offset,
673                    input: input.clone(),
674                }),
675            }],
676            LogicalPlan::Projection { exprs, input } => vec![LogicalPlan::Projection {
677                exprs: exprs.clone(),
678                input: Box::new(LogicalPlan::Limit {
679                    limit: *limit,
680                    offset: *offset,
681                    input: input.clone(),
682                }),
683            }],
684            _ => Vec::new(),
685        }
686    }
687}
688
689pub struct TopNRule;
690
691impl Rule for TopNRule {
692    fn name(&self) -> &str {
693        "topn_rule"
694    }
695
696    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
697        let LogicalPlan::Limit {
698            limit: Some(limit),
699            offset: None,
700            input,
701        } = plan
702        else {
703            return Vec::new();
704        };
705        let LogicalPlan::Sort { order_by, input } = input.as_ref() else {
706            return Vec::new();
707        };
708        vec![LogicalPlan::TopN {
709            order_by: order_by.clone(),
710            limit: *limit,
711            input: input.clone(),
712        }]
713    }
714}
715
716#[cfg(test)]
717mod tests {
718    use super::{
719        FilterJoinPushdown, FilterOrDedup, FilterPushdown, JoinPredicatePushdown, LimitPushdown,
720        MergeFilters, MergeProjections, NormalizePredicates, PredicateInference, PruneProjection,
721        RemoveTrueFilter, Rule, TopNRule,
722    };
723    use crate::utils::split_conjuncts;
724    use chryso_core::ast::{BinaryOperator, Expr};
725    use chryso_planner::LogicalPlan;
726
727    fn apply(rule: &impl Rule, plan: &LogicalPlan) -> Vec<LogicalPlan> {
728        let mut ctx = super::RuleContext::default();
729        rule.apply(plan, &mut ctx)
730    }
731
732    #[test]
733    fn merge_filters_combines_predicates() {
734        let plan = LogicalPlan::Filter {
735            predicate: Expr::Identifier("a".to_string()),
736            input: Box::new(LogicalPlan::Filter {
737                predicate: Expr::Identifier("b".to_string()),
738                input: Box::new(LogicalPlan::Scan {
739                    table: "t".to_string(),
740                }),
741            }),
742        };
743        let rule = MergeFilters;
744        let results = apply(&rule, &plan);
745        assert_eq!(results.len(), 1);
746        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
747            panic!("expected filter");
748        };
749        let Expr::BinaryOp { op, .. } = predicate else {
750            panic!("expected binary op");
751        };
752        assert!(matches!(op, BinaryOperator::And));
753    }
754
755    #[test]
756    fn prune_projection_removes_star() {
757        let plan = LogicalPlan::Projection {
758            exprs: vec![Expr::Wildcard],
759            input: Box::new(LogicalPlan::Scan {
760                table: "t".to_string(),
761            }),
762        };
763        let rule = PruneProjection;
764        let results = apply(&rule, &plan);
765        assert_eq!(results.len(), 1);
766    }
767
768    #[test]
769    fn filter_pushdown_under_projection() {
770        let plan = LogicalPlan::Filter {
771            predicate: Expr::Identifier("x".to_string()),
772            input: Box::new(LogicalPlan::Projection {
773                exprs: vec![Expr::Identifier("x".to_string())],
774                input: Box::new(LogicalPlan::Scan {
775                    table: "t".to_string(),
776                }),
777            }),
778        };
779        let rule = FilterPushdown;
780        let results = apply(&rule, &plan);
781        assert_eq!(results.len(), 1);
782    }
783
784    #[test]
785    fn filter_pushdown_under_sort() {
786        let plan = LogicalPlan::Filter {
787            predicate: Expr::Identifier("x".to_string()),
788            input: Box::new(LogicalPlan::Sort {
789                order_by: vec![chryso_core::ast::OrderByExpr {
790                    expr: Expr::Identifier("id".to_string()),
791                    asc: true,
792                    nulls_first: None,
793                }],
794                input: Box::new(LogicalPlan::Scan {
795                    table: "t".to_string(),
796                }),
797            }),
798        };
799        let rule = FilterPushdown;
800        let results = apply(&rule, &plan);
801        assert_eq!(results.len(), 1);
802        let LogicalPlan::Sort { input, .. } = &results[0] else {
803            panic!("expected sort");
804        };
805        assert!(matches!(input.as_ref(), LogicalPlan::Filter { .. }));
806    }
807
808    #[test]
809    fn normalize_predicates_orders_and() {
810        let plan = LogicalPlan::Filter {
811            predicate: Expr::BinaryOp {
812                left: Box::new(Expr::Identifier("b".to_string())),
813                op: BinaryOperator::And,
814                right: Box::new(Expr::Identifier("a".to_string())),
815            },
816            input: Box::new(LogicalPlan::Scan {
817                table: "t".to_string(),
818            }),
819        };
820        let rule = NormalizePredicates;
821        let results = apply(&rule, &plan);
822        assert_eq!(results.len(), 1);
823    }
824
825    #[test]
826    fn normalize_predicates_removes_true_and() {
827        let plan = LogicalPlan::Filter {
828            predicate: Expr::BinaryOp {
829                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
830                op: BinaryOperator::And,
831                right: Box::new(Expr::Identifier("x".to_string())),
832            },
833            input: Box::new(LogicalPlan::Scan {
834                table: "t".to_string(),
835            }),
836        };
837        let rule = NormalizePredicates;
838        let results = apply(&rule, &plan);
839        assert_eq!(results.len(), 1);
840        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
841            panic!("expected filter");
842        };
843        assert_eq!(predicate.to_sql(), "x");
844    }
845
846    #[test]
847    fn normalize_predicates_removes_nested_true_and() {
848        let plan = LogicalPlan::Filter {
849            predicate: Expr::BinaryOp {
850                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
851                op: BinaryOperator::And,
852                right: Box::new(Expr::BinaryOp {
853                    left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
854                    op: BinaryOperator::And,
855                    right: Box::new(Expr::Identifier("x".to_string())),
856                }),
857            },
858            input: Box::new(LogicalPlan::Scan {
859                table: "t".to_string(),
860            }),
861        };
862        let rule = NormalizePredicates;
863        let results = apply(&rule, &plan);
864        assert_eq!(results.len(), 1);
865        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
866            panic!("expected filter");
867        };
868        assert_eq!(predicate.to_sql(), "x");
869    }
870
871    #[test]
872    fn detect_rule_conflicts() {
873        let rules = crate::rules::RuleSet::new()
874            .with_rule(MergeFilters)
875            .with_rule(MergeFilters);
876        let conflicts = rules.detect_conflicts();
877        assert_eq!(conflicts, vec!["merge_filters".to_string()]);
878    }
879
880    #[test]
881    fn topn_rule_rewrites_sort_limit() {
882        let plan = LogicalPlan::Limit {
883            limit: Some(10),
884            offset: None,
885            input: Box::new(LogicalPlan::Sort {
886                order_by: vec![chryso_core::ast::OrderByExpr {
887                    expr: Expr::Identifier("id".to_string()),
888                    asc: true,
889                    nulls_first: None,
890                }],
891                input: Box::new(LogicalPlan::Scan {
892                    table: "t".to_string(),
893                }),
894            }),
895        };
896        let rule = TopNRule;
897        let results = apply(&rule, &plan);
898        assert_eq!(results.len(), 1);
899    }
900
901    #[test]
902    fn limit_pushdown_under_projection() {
903        let plan = LogicalPlan::Limit {
904            limit: Some(5),
905            offset: None,
906            input: Box::new(LogicalPlan::Projection {
907                exprs: vec![Expr::Identifier("id".to_string())],
908                input: Box::new(LogicalPlan::Scan {
909                    table: "t".to_string(),
910                }),
911            }),
912        };
913        let rule = LimitPushdown;
914        let results = apply(&rule, &plan);
915        assert_eq!(results.len(), 1);
916    }
917
918    #[test]
919    fn merge_projections_keeps_outer() {
920        let plan = LogicalPlan::Projection {
921            exprs: vec![Expr::Identifier("id".to_string())],
922            input: Box::new(LogicalPlan::Projection {
923                exprs: vec![
924                    Expr::Identifier("id".to_string()),
925                    Expr::Identifier("name".to_string()),
926                ],
927                input: Box::new(LogicalPlan::Scan {
928                    table: "t".to_string(),
929                }),
930            }),
931        };
932        let rule = MergeProjections;
933        let results = apply(&rule, &plan);
934        assert_eq!(results.len(), 1);
935        let LogicalPlan::Projection { exprs, .. } = &results[0] else {
936            panic!("expected projection");
937        };
938        assert_eq!(exprs.len(), 1);
939    }
940
941    #[test]
942    fn remove_true_filter() {
943        let plan = LogicalPlan::Filter {
944            predicate: Expr::Literal(chryso_core::ast::Literal::Bool(true)),
945            input: Box::new(LogicalPlan::Scan {
946                table: "t".to_string(),
947            }),
948        };
949        let rule = RemoveTrueFilter;
950        let results = apply(&rule, &plan);
951        assert_eq!(results.len(), 1);
952        assert!(matches!(results[0], LogicalPlan::Scan { .. }));
953    }
954
955    #[test]
956    fn remove_true_filter_ignores_binary_predicate() {
957        let plan = LogicalPlan::Filter {
958            predicate: Expr::BinaryOp {
959                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
960                op: BinaryOperator::And,
961                right: Box::new(Expr::Identifier("x".to_string())),
962            },
963            input: Box::new(LogicalPlan::Scan {
964                table: "t".to_string(),
965            }),
966        };
967        let rule = RemoveTrueFilter;
968        let results = apply(&rule, &plan);
969        assert!(results.is_empty());
970    }
971
972    #[test]
973    fn filter_join_pushdown_left() {
974        let plan = LogicalPlan::Filter {
975            predicate: Expr::BinaryOp {
976                left: Box::new(Expr::Identifier("t1.id".to_string())),
977                op: BinaryOperator::Eq,
978                right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
979            },
980            input: Box::new(LogicalPlan::Join {
981                join_type: chryso_core::ast::JoinType::Inner,
982                left: Box::new(LogicalPlan::Scan {
983                    table: "t1".to_string(),
984                }),
985                right: Box::new(LogicalPlan::Scan {
986                    table: "t2".to_string(),
987                }),
988                on: Expr::Identifier("t1.id = t2.id".to_string()),
989            }),
990        };
991        let rule = FilterJoinPushdown;
992        let results = apply(&rule, &plan);
993        assert_eq!(results.len(), 1);
994        let LogicalPlan::Join { left, .. } = &results[0] else {
995            panic!("expected join");
996        };
997        assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
998    }
999
1000    #[test]
1001    fn filter_join_pushdown_keeps_cross_predicate() {
1002        let plan = LogicalPlan::Filter {
1003            predicate: Expr::BinaryOp {
1004                left: Box::new(Expr::Identifier("t1.id".to_string())),
1005                op: BinaryOperator::Eq,
1006                right: Box::new(Expr::Identifier("t2.id".to_string())),
1007            },
1008            input: Box::new(LogicalPlan::Join {
1009                join_type: chryso_core::ast::JoinType::Inner,
1010                left: Box::new(LogicalPlan::Scan {
1011                    table: "t1".to_string(),
1012                }),
1013                right: Box::new(LogicalPlan::Scan {
1014                    table: "t2".to_string(),
1015                }),
1016                on: Expr::Identifier("t1.id = t2.id".to_string()),
1017            }),
1018        };
1019        let rule = FilterJoinPushdown;
1020        let results = apply(&rule, &plan);
1021        assert_eq!(results.len(), 1);
1022        let LogicalPlan::Join { on, .. } = &results[0] else {
1023            panic!("expected join");
1024        };
1025        let Expr::BinaryOp { op, .. } = on else {
1026            panic!("expected binary op");
1027        };
1028        assert!(matches!(op, BinaryOperator::And));
1029    }
1030
1031    #[test]
1032    fn filter_or_dedup_removes_duplicate_predicates() {
1033        let plan = LogicalPlan::Filter {
1034            predicate: Expr::BinaryOp {
1035                left: Box::new(Expr::Identifier("a".to_string())),
1036                op: BinaryOperator::Or,
1037                right: Box::new(Expr::Identifier("a".to_string())),
1038            },
1039            input: Box::new(LogicalPlan::Scan {
1040                table: "t".to_string(),
1041            }),
1042        };
1043        let rule = FilterOrDedup;
1044        let results = apply(&rule, &plan);
1045        assert_eq!(results.len(), 1);
1046        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1047            panic!("expected filter");
1048        };
1049        assert!(matches!(predicate, Expr::Identifier(name) if name == "a"));
1050    }
1051
1052    #[test]
1053    fn join_predicate_pushdown_splits_single_side() {
1054        let plan = LogicalPlan::Join {
1055            join_type: chryso_core::ast::JoinType::Inner,
1056            left: Box::new(LogicalPlan::Scan {
1057                table: "t1".to_string(),
1058            }),
1059            right: Box::new(LogicalPlan::Scan {
1060                table: "t2".to_string(),
1061            }),
1062            on: Expr::BinaryOp {
1063                left: Box::new(Expr::BinaryOp {
1064                    left: Box::new(Expr::Identifier("t1.flag".to_string())),
1065                    op: BinaryOperator::Eq,
1066                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
1067                }),
1068                op: BinaryOperator::And,
1069                right: Box::new(Expr::BinaryOp {
1070                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1071                    op: BinaryOperator::Eq,
1072                    right: Box::new(Expr::Identifier("t2.id".to_string())),
1073                }),
1074            },
1075        };
1076        let rule = JoinPredicatePushdown;
1077        let results = apply(&rule, &plan);
1078        assert_eq!(results.len(), 1);
1079        let LogicalPlan::Join { left, on, .. } = &results[0] else {
1080            panic!("expected join");
1081        };
1082        assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
1083        assert_eq!(on.to_sql(), "t1.id = t2.id");
1084    }
1085
1086    #[test]
1087    fn predicate_inference_adds_literal_equivalence() {
1088        let plan = LogicalPlan::Filter {
1089            predicate: Expr::BinaryOp {
1090                left: Box::new(Expr::BinaryOp {
1091                    left: Box::new(Expr::Identifier("a".to_string())),
1092                    op: BinaryOperator::Eq,
1093                    right: Box::new(Expr::Identifier("b".to_string())),
1094                }),
1095                op: BinaryOperator::And,
1096                right: Box::new(Expr::BinaryOp {
1097                    left: Box::new(Expr::Identifier("a".to_string())),
1098                    op: BinaryOperator::Eq,
1099                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
1100                }),
1101            },
1102            input: Box::new(LogicalPlan::Scan {
1103                table: "t".to_string(),
1104            }),
1105        };
1106        let rule = PredicateInference;
1107        let results = apply(&rule, &plan);
1108        assert_eq!(results.len(), 1);
1109        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1110            panic!("expected filter");
1111        };
1112        let conjuncts = split_conjuncts(predicate)
1113            .into_iter()
1114            .map(|expr| expr.to_sql())
1115            .collect::<std::collections::HashSet<_>>();
1116        assert!(conjuncts.contains("b = 1"));
1117    }
1118
1119    #[test]
1120    fn predicate_inference_on_join() {
1121        let plan = LogicalPlan::Join {
1122            join_type: chryso_core::ast::JoinType::Inner,
1123            left: Box::new(LogicalPlan::Scan {
1124                table: "t1".to_string(),
1125            }),
1126            right: Box::new(LogicalPlan::Scan {
1127                table: "t2".to_string(),
1128            }),
1129            on: Expr::BinaryOp {
1130                left: Box::new(Expr::BinaryOp {
1131                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1132                    op: BinaryOperator::Eq,
1133                    right: Box::new(Expr::Identifier("t2.id".to_string())),
1134                }),
1135                op: BinaryOperator::And,
1136                right: Box::new(Expr::BinaryOp {
1137                    left: Box::new(Expr::Identifier("t2.id".to_string())),
1138                    op: BinaryOperator::Eq,
1139                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(5.0))),
1140                }),
1141            },
1142        };
1143        let rule = PredicateInference;
1144        let results = apply(&rule, &plan);
1145        assert_eq!(results.len(), 1);
1146        let LogicalPlan::Join { on, .. } = &results[0] else {
1147            panic!("expected join");
1148        };
1149        let conjuncts = split_conjuncts(on)
1150            .into_iter()
1151            .map(|expr| expr.to_sql())
1152            .collect::<std::collections::HashSet<_>>();
1153        assert!(conjuncts.contains("t1.id = 5"));
1154    }
1155
1156    #[test]
1157    fn predicate_inference_enables_join_pushdown() {
1158        let plan = LogicalPlan::Join {
1159            join_type: chryso_core::ast::JoinType::Inner,
1160            left: Box::new(LogicalPlan::Scan {
1161                table: "t1".to_string(),
1162            }),
1163            right: Box::new(LogicalPlan::Scan {
1164                table: "t2".to_string(),
1165            }),
1166            on: Expr::BinaryOp {
1167                left: Box::new(Expr::BinaryOp {
1168                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1169                    op: BinaryOperator::Eq,
1170                    right: Box::new(Expr::Identifier("t2.id".to_string())),
1171                }),
1172                op: BinaryOperator::And,
1173                right: Box::new(Expr::BinaryOp {
1174                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1175                    op: BinaryOperator::Eq,
1176                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(42.0))),
1177                }),
1178            },
1179        };
1180        let inferred = apply(&PredicateInference, &plan);
1181        assert_eq!(inferred.len(), 1);
1182        let pushed = apply(&JoinPredicatePushdown, &inferred[0]);
1183        assert_eq!(pushed.len(), 1);
1184        let LogicalPlan::Join { right, .. } = &pushed[0] else {
1185            panic!("expected join");
1186        };
1187        assert!(matches!(right.as_ref(), LogicalPlan::Filter { .. }));
1188    }
1189
1190    #[test]
1191    fn predicate_inference_propagates_across_filter_join() {
1192        let plan = LogicalPlan::Filter {
1193            predicate: Expr::BinaryOp {
1194                left: Box::new(Expr::Identifier("t1.flag".to_string())),
1195                op: BinaryOperator::Eq,
1196                right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
1197            },
1198            input: Box::new(LogicalPlan::Join {
1199                join_type: chryso_core::ast::JoinType::Inner,
1200                left: Box::new(LogicalPlan::Scan {
1201                    table: "t1".to_string(),
1202                }),
1203                right: Box::new(LogicalPlan::Scan {
1204                    table: "t2".to_string(),
1205                }),
1206                on: Expr::BinaryOp {
1207                    left: Box::new(Expr::BinaryOp {
1208                        left: Box::new(Expr::Identifier("t1.id".to_string())),
1209                        op: BinaryOperator::Eq,
1210                        right: Box::new(Expr::Identifier("t2.id".to_string())),
1211                    }),
1212                    op: BinaryOperator::And,
1213                    right: Box::new(Expr::BinaryOp {
1214                        left: Box::new(Expr::Identifier("t1.id".to_string())),
1215                        op: BinaryOperator::Eq,
1216                        right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(7.0))),
1217                    }),
1218                },
1219            }),
1220        };
1221        let inferred = apply(&PredicateInference, &plan);
1222        assert_eq!(inferred.len(), 1);
1223        let LogicalPlan::Filter { input, predicate } = &inferred[0] else {
1224            panic!("expected filter");
1225        };
1226        let LogicalPlan::Join { on, .. } = input.as_ref() else {
1227            panic!("expected join");
1228        };
1229        let filter_conjuncts = split_conjuncts(predicate)
1230            .into_iter()
1231            .map(|expr| expr.to_sql())
1232            .collect::<std::collections::HashSet<_>>();
1233        let join_conjuncts = split_conjuncts(on)
1234            .into_iter()
1235            .map(|expr| expr.to_sql())
1236            .collect::<std::collections::HashSet<_>>();
1237        assert!(filter_conjuncts.contains("t1.flag = true"));
1238        assert!(join_conjuncts.contains("t2.id = 7"));
1239    }
1240
1241    #[test]
1242    fn predicate_inference_adds_transitive_equivalence() {
1243        let plan = LogicalPlan::Filter {
1244            predicate: Expr::BinaryOp {
1245                left: Box::new(Expr::BinaryOp {
1246                    left: Box::new(Expr::Identifier("a".to_string())),
1247                    op: BinaryOperator::Eq,
1248                    right: Box::new(Expr::Identifier("b".to_string())),
1249                }),
1250                op: BinaryOperator::And,
1251                right: Box::new(Expr::BinaryOp {
1252                    left: Box::new(Expr::Identifier("b".to_string())),
1253                    op: BinaryOperator::Eq,
1254                    right: Box::new(Expr::Identifier("c".to_string())),
1255                }),
1256            },
1257            input: Box::new(LogicalPlan::Scan {
1258                table: "t".to_string(),
1259            }),
1260        };
1261        let rule = PredicateInference;
1262        let results = apply(&rule, &plan);
1263        assert_eq!(results.len(), 1);
1264        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1265            panic!("expected filter");
1266        };
1267        let conjuncts = split_conjuncts(predicate)
1268            .into_iter()
1269            .map(|expr| expr.to_sql())
1270            .collect::<std::collections::HashSet<_>>();
1271        assert!(conjuncts.contains("a = c") || conjuncts.contains("c = a"));
1272    }
1273}
1274
1275pub struct AggregatePredicatePushdown;
1276
1277impl Rule for AggregatePredicatePushdown {
1278    fn name(&self) -> &str {
1279        "aggregate_predicate_pushdown"
1280    }
1281
1282    fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
1283        let LogicalPlan::Filter { predicate, input } = plan else {
1284            return Vec::new();
1285        };
1286        let LogicalPlan::Aggregate {
1287            group_exprs,
1288            aggr_exprs,
1289            input,
1290        } = input.as_ref()
1291        else {
1292            return Vec::new();
1293        };
1294        let predicate_idents = collect_identifiers(predicate);
1295        let group_idents = group_exprs
1296            .iter()
1297            .flat_map(collect_identifiers)
1298            .collect::<std::collections::HashSet<_>>();
1299        if predicate_idents
1300            .iter()
1301            .all(|ident| group_idents.contains(ident))
1302        {
1303            return vec![LogicalPlan::Aggregate {
1304                group_exprs: group_exprs.clone(),
1305                aggr_exprs: aggr_exprs.clone(),
1306                input: Box::new(LogicalPlan::Filter {
1307                    predicate: predicate.clone(),
1308                    input: input.clone(),
1309                }),
1310            }];
1311        }
1312        Vec::new()
1313    }
1314}
1315
/// One of the two inputs of a binary plan node. Presumably a join's left/right child;
/// its users are outside this view, so confirm there.
#[derive(Copy, Clone)]
enum Side {
    /// The first (left) input.
    Left,
    /// The second (right) input.
    Right,
}
1321
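// Generic expression helpers (split_conjuncts, combine_conjuncts, collect_identifiers, table_prefix, ...) live in utils.rs; the private helpers below are local to this file.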
1323
1324fn infer_predicates(predicate: &Expr, ctx: &mut RuleContext) -> (Expr, bool) {
1325    // Infer equalities via union-find and surface literal conflicts for trace/debugging.
1326    let conjuncts = split_conjuncts(predicate);
1327    let mut existing = std::collections::HashSet::new();
1328    for expr in &conjuncts {
1329        existing.insert(expr.to_sql());
1330    }
1331
1332    let mut uf = UnionFind::new();
1333    let mut literal_bindings = std::collections::HashMap::<String, Literal>::new();
1334    let mut literal_conflicts = std::collections::HashSet::<String>::new();
1335    let mut conflict_pairs = std::collections::BTreeSet::<(String, String)>::new();
1336
1337    for conjunct in &conjuncts {
1338        let Expr::BinaryOp { left, op, right } = conjunct else {
1339            continue;
1340        };
1341        if !matches!(op, BinaryOperator::Eq) {
1342            continue;
1343        }
1344        match (left.as_ref(), right.as_ref()) {
1345            (Expr::Identifier(left), Expr::Identifier(right)) => {
1346                uf.union(left, right);
1347            }
1348            (Expr::Identifier(ident), Expr::Literal(literal))
1349            | (Expr::Literal(literal), Expr::Identifier(ident)) => {
1350                uf.add(ident);
1351                if let Some(existing) = literal_bindings.get(ident) {
1352                    if !literal_eq(existing, literal) {
1353                        literal_conflicts.insert(ident.clone());
1354                        let left =
1355                            format!("{} = {}", ident, Expr::Literal(existing.clone()).to_sql());
1356                        let right =
1357                            format!("{} = {}", ident, Expr::Literal(literal.clone()).to_sql());
1358                        conflict_pairs.insert((left, right));
1359                    }
1360                } else {
1361                    literal_bindings.insert(ident.clone(), literal.clone());
1362                }
1363            }
1364            _ => {}
1365        }
1366    }
1367
1368    let mut class_literals = std::collections::HashMap::<String, Option<Literal>>::new();
1369    if !literal_conflicts.is_empty() {
1370        ctx.record_literal_conflicts(conflict_pairs);
1371    }
1372    for ident in &literal_conflicts {
1373        let root = uf.find(ident);
1374        class_literals.insert(root, None);
1375    }
1376    for (ident, literal) in &literal_bindings {
1377        if literal_conflicts.contains(ident) {
1378            continue;
1379        }
1380        let root = uf.find(ident);
1381        match class_literals.get(&root) {
1382            None => {
1383                class_literals.insert(root, Some(literal.clone()));
1384            }
1385            Some(Some(existing_literal)) => {
1386                if !literal_eq(existing_literal, literal) {
1387                    class_literals.insert(root, None);
1388                }
1389            }
1390            Some(None) => {}
1391        }
1392    }
1393
1394    let mut groups = std::collections::HashMap::<String, Vec<String>>::new();
1395    for key in uf.keys() {
1396        let root = uf.find(&key);
1397        groups.entry(root).or_default().push(key);
1398    }
1399
1400    let mut inferred = Vec::new();
1401    for members in groups.values() {
1402        if members.len() <= 1 {
1403            continue;
1404        }
1405        let mut members = members.clone();
1406        members.sort();
1407        let canonical = members[0].clone();
1408        for ident in members.into_iter().skip(1) {
1409            let forward = format!("{canonical} = {ident}");
1410            let backward = format!("{ident} = {canonical}");
1411            if existing.contains(&forward) || existing.contains(&backward) {
1412                continue;
1413            }
1414            existing.insert(forward);
1415            inferred.push(Expr::BinaryOp {
1416                left: Box::new(Expr::Identifier(canonical.clone())),
1417                op: BinaryOperator::Eq,
1418                right: Box::new(Expr::Identifier(ident)),
1419            });
1420        }
1421    }
1422
1423    for (root, literal) in class_literals {
1424        let Some(literal) = literal else {
1425            continue;
1426        };
1427        let Some(members) = groups.get(&root) else {
1428            continue;
1429        };
1430        for ident in members {
1431            let expr = Expr::BinaryOp {
1432                left: Box::new(Expr::Identifier(ident.clone())),
1433                op: BinaryOperator::Eq,
1434                right: Box::new(Expr::Literal(literal.clone())),
1435            };
1436            if existing.insert(expr.to_sql()) {
1437                inferred.push(expr);
1438            }
1439        }
1440    }
1441
1442    if inferred.is_empty() {
1443        return (predicate.clone(), false);
1444    }
1445    let mut all = conjuncts;
1446    all.extend(inferred);
1447    (
1448        combine_conjuncts(all).unwrap_or_else(|| predicate.clone()),
1449        true,
1450    )
1451}
1452
1453fn split_predicates_by_source(
1454    combined: &Expr,
1455    filter_predicate: &Expr,
1456    join_predicate: &Expr,
1457) -> (Vec<Expr>, Vec<Expr>) {
1458    let filter_set = split_conjuncts(filter_predicate)
1459        .into_iter()
1460        .map(|expr| expr.to_sql())
1461        .collect::<std::collections::HashSet<_>>();
1462    let join_set = split_conjuncts(join_predicate)
1463        .into_iter()
1464        .map(|expr| expr.to_sql())
1465        .collect::<std::collections::HashSet<_>>();
1466    let filter_prefixes = collect_table_prefixes(filter_predicate);
1467    let mut filter_preds = Vec::new();
1468    let mut join_preds = Vec::new();
1469    for expr in split_conjuncts(combined) {
1470        let sql = expr.to_sql();
1471        if join_set.contains(&sql) {
1472            join_preds.push(expr);
1473        } else if filter_set.contains(&sql) {
1474            filter_preds.push(expr);
1475        } else if is_join_compatible(&expr) {
1476            join_preds.push(expr);
1477        } else if is_same_table_predicate(&expr, &filter_prefixes) {
1478            filter_preds.push(expr);
1479        } else {
1480            join_preds.push(expr);
1481        }
1482    }
1483    (filter_preds, join_preds)
1484}
1485
1486fn is_join_compatible(expr: &Expr) -> bool {
1487    let idents = collect_identifiers(expr);
1488    if idents.len() < 2 {
1489        return false;
1490    }
1491    let mut tables = std::collections::HashSet::new();
1492    for ident in idents {
1493        if let Some(prefix) = table_prefix(&ident) {
1494            tables.insert(prefix.to_string());
1495        }
1496    }
1497    tables.len() >= 2
1498}
1499
1500fn collect_table_prefixes(expr: &Expr) -> std::collections::HashSet<String> {
1501    let idents = collect_identifiers(expr);
1502    let mut tables = std::collections::HashSet::new();
1503    for ident in idents {
1504        if let Some(prefix) = table_prefix(&ident) {
1505            tables.insert(prefix.to_string());
1506        }
1507    }
1508    tables
1509}
1510
1511fn is_same_table_predicate(expr: &Expr, known_tables: &std::collections::HashSet<String>) -> bool {
1512    let idents = collect_identifiers(expr);
1513    if idents.is_empty() {
1514        return false;
1515    }
1516    let mut tables = std::collections::HashSet::new();
1517    for ident in idents {
1518        let Some(prefix) = table_prefix(&ident) else {
1519            return false;
1520        };
1521        tables.insert(prefix.to_string());
1522    }
1523    if tables.len() != 1 {
1524        return false;
1525    }
1526    if known_tables.is_empty() {
1527        return false;
1528    }
1529    tables.is_subset(known_tables)
1530}
1531
1532fn literal_eq(left: &Literal, right: &Literal) -> bool {
1533    match (left, right) {
1534        (Literal::String(left), Literal::String(right)) => left == right,
1535        (Literal::Number(left), Literal::Number(right)) => left == right,
1536        (Literal::Bool(left), Literal::Bool(right)) => left == right,
1537        _ => false,
1538    }
1539}
1540
/// Disjoint-set forest over string keys.
///
/// Each registered key maps to a parent key. A key that is its own parent is the
/// representative ("root") of its set. Keys are registered lazily by `add`, `find`, and
/// `union`.
struct UnionFind {
    parent: std::collections::HashMap<String, String>,
}

impl UnionFind {
    /// Creates an empty forest.
    fn new() -> Self {
        UnionFind {
            parent: Default::default(),
        }
    }

    /// Registers `key` as a singleton set; does nothing if `key` is already known.
    fn add(&mut self, key: &str) {
        if !self.parent.contains_key(key) {
            self.parent.insert(key.to_owned(), key.to_owned());
        }
    }

    /// Returns the representative of the set containing `key`, registering `key` first if
    /// needed. Every node on the walk is re-pointed directly at the representative (path
    /// compression).
    fn find(&mut self, key: &str) -> String {
        self.add(key);

        // First pass: climb parent links until reaching a self-parented node.
        let mut root = key.to_owned();
        loop {
            let next = match self.parent.get(&root) {
                Some(up) if *up != root => up.clone(),
                _ => break,
            };
            root = next;
        }

        // Second pass: walk the same path again, linking each node straight to the root.
        let mut node = key.to_owned();
        while node != root {
            match self.parent.insert(node, root.clone()) {
                Some(up) => node = up,
                None => break,
            }
        }

        root
    }

    /// Merges the sets containing `left` and `right`, hanging `left`'s representative
    /// under `right`'s. Both keys are registered if new.
    fn union(&mut self, left: &str, right: &str) {
        let left_rep = self.find(left);
        let right_rep = self.find(right);
        if left_rep != right_rep {
            self.parent.insert(left_rep, right_rep);
        }
    }

    /// Returns every registered key, in unspecified (hash-map) order.
    fn keys(&self) -> Vec<String> {
        self.parent.keys().map(String::clone).collect()
    }
}