chryso_optimizer/
rules.rs

1use crate::utils::{collect_identifiers, collect_tables, combine_conjuncts, split_conjuncts, table_prefix};
2use chryso_core::ast::{BinaryOperator, Expr, Literal};
3use chryso_planner::LogicalPlan;
4
/// A plan rewrite that can be registered in a [`RuleSet`].
pub trait Rule {
    /// Returns the rule's name.
    ///
    /// [`RuleSet::detect_conflicts`] reports names that are used by more than
    /// one registered rule.
    fn name(&self) -> &str;
    /// Returns the alternative plans this rule derives from `plan`.
    ///
    /// The implementations in this file return an empty vector when the rule
    /// does not match the shape of `plan`.
    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan>;
}
9
10pub struct RuleSet {
11    rules: Vec<Box<dyn Rule + Send + Sync>>,
12}
13
14impl RuleSet {
15    pub fn new() -> Self {
16        Self { rules: Vec::new() }
17    }
18
19    pub fn with_rule(mut self, rule: impl Rule + Send + Sync + 'static) -> Self {
20        self.rules.push(Box::new(rule));
21        self
22    }
23
24    pub fn apply_all(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
25        let mut results = Vec::new();
26        for rule in &self.rules {
27            results.extend(rule.apply(plan));
28        }
29        results
30    }
31
32    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Rule + Send + Sync>> {
33        self.rules.iter()
34    }
35}
36
37impl Default for RuleSet {
38    fn default() -> Self {
39        RuleSet::new()
40            .with_rule(MergeFilters)
41            .with_rule(PruneProjection)
42            .with_rule(MergeProjections)
43            .with_rule(RemoveTrueFilter)
44            .with_rule(FilterPushdown)
45            .with_rule(FilterJoinPushdown)
46            .with_rule(PredicateInference)
47            .with_rule(JoinPredicatePushdown)
48            .with_rule(FilterOrDedup)
49            .with_rule(NormalizePredicates)
50            .with_rule(JoinCommute)
51            .with_rule(AggregatePredicatePushdown)
52            .with_rule(LimitPushdown)
53            .with_rule(TopNRule)
54    }
55}
56
57impl RuleSet {
58    pub fn detect_conflicts(&self) -> Vec<String> {
59        let mut seen = std::collections::HashMap::new();
60        for rule in &self.rules {
61            *seen.entry(rule.name().to_string()).or_insert(0usize) += 1;
62        }
63        seen.into_iter()
64            .filter_map(|(name, count)| if count > 1 { Some(name) } else { None })
65            .collect()
66    }
67}
68
69impl std::fmt::Debug for RuleSet {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("RuleSet")
72            .field("rule_count", &self.rules.len())
73            .finish()
74    }
75}
76
77pub struct NoopRule;
78
79impl Rule for NoopRule {
80    fn name(&self) -> &str {
81        "noop"
82    }
83
84    fn apply(&self, _plan: &LogicalPlan) -> Vec<LogicalPlan> {
85        Vec::new()
86    }
87}
88
89pub struct MergeFilters;
90
91impl Rule for MergeFilters {
92    fn name(&self) -> &str {
93        "merge_filters"
94    }
95
96    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
97        let LogicalPlan::Filter { predicate, input } = plan else {
98            return Vec::new();
99        };
100        let LogicalPlan::Filter {
101            predicate: inner_predicate,
102            input: inner_input,
103        } = input.as_ref()
104        else {
105            return Vec::new();
106        };
107        let merged = LogicalPlan::Filter {
108            predicate: chryso_core::ast::Expr::BinaryOp {
109                left: Box::new(inner_predicate.clone()),
110                op: BinaryOperator::And,
111                right: Box::new(predicate.clone()),
112            },
113            input: inner_input.clone(),
114        };
115        vec![merged]
116    }
117}
118
119pub struct PruneProjection;
120
121impl Rule for PruneProjection {
122    fn name(&self) -> &str {
123        "prune_projection"
124    }
125
126    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
127        let LogicalPlan::Projection { exprs, input } = plan else {
128            return Vec::new();
129        };
130        if exprs.len() == 1 && matches!(exprs[0], Expr::Wildcard) {
131            return vec![(*input.clone())];
132        }
133        Vec::new()
134    }
135}
136
137pub struct MergeProjections;
138
139impl Rule for MergeProjections {
140    fn name(&self) -> &str {
141        "merge_projections"
142    }
143
144    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
145        let LogicalPlan::Projection { exprs, input } = plan else {
146            return Vec::new();
147        };
148        let LogicalPlan::Projection {
149            exprs: inner_exprs,
150            input: inner_input,
151        } = input.as_ref()
152        else {
153            return Vec::new();
154        };
155        if projection_subset(exprs, inner_exprs) {
156            return vec![LogicalPlan::Projection {
157                exprs: exprs.clone(),
158                input: inner_input.clone(),
159            }];
160        }
161        Vec::new()
162    }
163}
164
165fn projection_subset(outer: &[Expr], inner: &[Expr]) -> bool {
166    let inner_names = inner
167        .iter()
168        .filter_map(|expr| match expr {
169            Expr::Identifier(name) => Some(name),
170            _ => None,
171        })
172        .collect::<std::collections::HashSet<_>>();
173    if inner_names.is_empty() {
174        return false;
175    }
176    outer.iter().all(|expr| match expr {
177        Expr::Identifier(name) => inner_names.contains(name),
178        Expr::Wildcard => true,
179        _ => false,
180    })
181}
182
183pub struct RemoveTrueFilter;
184
185impl Rule for RemoveTrueFilter {
186    fn name(&self) -> &str {
187        "remove_true_filter"
188    }
189
190    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
191        let LogicalPlan::Filter { predicate, input } = plan else {
192            return Vec::new();
193        };
194        match predicate {
195            Expr::Literal(chryso_core::ast::Literal::Bool(true)) => vec![*input.clone()],
196            _ => Vec::new(),
197        }
198    }
199}
200
201pub struct FilterPushdown;
202
203impl Rule for FilterPushdown {
204    fn name(&self) -> &str {
205        "filter_pushdown"
206    }
207
208    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
209        let LogicalPlan::Filter { predicate, input } = plan else {
210            return Vec::new();
211        };
212        let LogicalPlan::Projection { exprs, input } = input.as_ref() else {
213            if let LogicalPlan::Sort { order_by, input } = input.as_ref() {
214                return vec![LogicalPlan::Sort {
215                    order_by: order_by.clone(),
216                    input: Box::new(LogicalPlan::Filter {
217                        predicate: predicate.clone(),
218                        input: input.clone(),
219                    }),
220                }];
221            }
222            return Vec::new();
223        };
224        if !projection_is_passthrough(exprs) {
225            return Vec::new();
226        }
227        let pushed = LogicalPlan::Projection {
228            exprs: exprs.clone(),
229            input: Box::new(LogicalPlan::Filter {
230                predicate: predicate.clone(),
231                input: input.clone(),
232            }),
233        };
234        vec![pushed]
235    }
236}
237
238pub struct FilterJoinPushdown;
239
240impl Rule for FilterJoinPushdown {
241    fn name(&self) -> &str {
242        "filter_join_pushdown"
243    }
244
245    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
246        let LogicalPlan::Filter { predicate, input } = plan else {
247            return Vec::new();
248        };
249        let LogicalPlan::Join {
250            join_type,
251            left,
252            right,
253            on,
254        } = input.as_ref()
255        else {
256            return Vec::new();
257        };
258        if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
259            return Vec::new();
260        }
261
262        let left_tables = collect_tables(left.as_ref());
263        let right_tables = collect_tables(right.as_ref());
264        let mut left_preds = Vec::new();
265        let mut right_preds = Vec::new();
266        let mut remaining = Vec::new();
267
268        for conjunct in split_conjuncts(predicate) {
269            let idents = collect_identifiers(&conjunct);
270            if idents.is_empty() {
271                remaining.push(conjunct);
272                continue;
273            }
274            let mut side = None;
275            let mut ambiguous = false;
276            for ident in &idents {
277                let Some(prefix) = table_prefix(ident) else {
278                    ambiguous = true;
279                    break;
280                };
281                let in_left = left_tables.contains(prefix);
282                let in_right = right_tables.contains(prefix);
283                if in_left && !in_right {
284                    side = match side {
285                        None => Some(Side::Left),
286                        Some(Side::Left) => Some(Side::Left),
287                        Some(Side::Right) => {
288                            ambiguous = true;
289                            break;
290                        }
291                    };
292                } else if in_right && !in_left {
293                    side = match side {
294                        None => Some(Side::Right),
295                        Some(Side::Right) => Some(Side::Right),
296                        Some(Side::Left) => {
297                            ambiguous = true;
298                            break;
299                        }
300                    };
301                } else {
302                    ambiguous = true;
303                    break;
304                }
305            }
306            if ambiguous {
307                remaining.push(conjunct);
308                continue;
309            }
310            match side {
311                Some(Side::Left) => left_preds.push(conjunct),
312                Some(Side::Right) => right_preds.push(conjunct),
313                None => remaining.push(conjunct),
314            }
315        }
316
317        if left_preds.is_empty() && right_preds.is_empty() && remaining.is_empty() {
318            return Vec::new();
319        }
320
321        let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
322            LogicalPlan::Filter {
323                predicate: expr,
324                input: left.clone(),
325            }
326        } else {
327            *left.clone()
328        };
329        let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
330            LogicalPlan::Filter {
331                predicate: expr,
332                input: right.clone(),
333            }
334        } else {
335            *right.clone()
336        };
337        let new_on = if let Some(expr) = combine_conjuncts(remaining) {
338            Expr::BinaryOp {
339                left: Box::new(on.clone()),
340                op: BinaryOperator::And,
341                right: Box::new(expr),
342            }
343        } else {
344            on.clone()
345        };
346        let joined = LogicalPlan::Join {
347            join_type: *join_type,
348            left: Box::new(new_left),
349            right: Box::new(new_right),
350            on: new_on,
351        };
352        vec![joined]
353    }
354}
355
356pub struct JoinPredicatePushdown;
357
358impl Rule for JoinPredicatePushdown {
359    fn name(&self) -> &str {
360        "join_predicate_pushdown"
361    }
362
363    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
364        let LogicalPlan::Join {
365            join_type,
366            left,
367            right,
368            on,
369        } = plan
370        else {
371            return Vec::new();
372        };
373        if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
374            return Vec::new();
375        }
376
377        let left_tables = collect_tables(left.as_ref());
378        let right_tables = collect_tables(right.as_ref());
379        let mut left_preds = Vec::new();
380        let mut right_preds = Vec::new();
381        let mut remaining = Vec::new();
382
383        for conjunct in split_conjuncts(on) {
384            let idents = collect_identifiers(&conjunct);
385            if idents.is_empty() {
386                remaining.push(conjunct);
387                continue;
388            }
389            let mut side = None;
390            let mut ambiguous = false;
391            for ident in &idents {
392                let Some(prefix) = table_prefix(ident) else {
393                    ambiguous = true;
394                    break;
395                };
396                let in_left = left_tables.contains(prefix);
397                let in_right = right_tables.contains(prefix);
398                if in_left && !in_right {
399                    side = match side {
400                        None => Some(Side::Left),
401                        Some(Side::Left) => Some(Side::Left),
402                        Some(Side::Right) => {
403                            ambiguous = true;
404                            break;
405                        }
406                    };
407                } else if in_right && !in_left {
408                    side = match side {
409                        None => Some(Side::Right),
410                        Some(Side::Right) => Some(Side::Right),
411                        Some(Side::Left) => {
412                            ambiguous = true;
413                            break;
414                        }
415                    };
416                } else {
417                    ambiguous = true;
418                    break;
419                }
420            }
421            if ambiguous {
422                remaining.push(conjunct);
423                continue;
424            }
425            match side {
426                Some(Side::Left) => left_preds.push(conjunct),
427                Some(Side::Right) => right_preds.push(conjunct),
428                None => remaining.push(conjunct),
429            }
430        }
431
432        if left_preds.is_empty() && right_preds.is_empty() {
433            return Vec::new();
434        }
435
436        let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
437            LogicalPlan::Filter {
438                predicate: expr,
439                input: left.clone(),
440            }
441        } else {
442            *left.clone()
443        };
444        let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
445            LogicalPlan::Filter {
446                predicate: expr,
447                input: right.clone(),
448            }
449        } else {
450            *right.clone()
451        };
452        let new_on =
453            combine_conjuncts(remaining).unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
454        vec![LogicalPlan::Join {
455            join_type: *join_type,
456            left: Box::new(new_left),
457            right: Box::new(new_right),
458            on: new_on,
459        }]
460    }
461}
462
463fn projection_is_passthrough(exprs: &[Expr]) -> bool {
464    exprs.iter().all(|expr| match expr {
465        Expr::Identifier(_) | Expr::Wildcard => true,
466        _ => false,
467    })
468}
469
470pub struct NormalizePredicates;
471
472impl Rule for NormalizePredicates {
473    fn name(&self) -> &str {
474        "normalize_predicates"
475    }
476
477    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
478        let LogicalPlan::Filter { predicate, input } = plan else {
479            return Vec::new();
480        };
481        let normalized = predicate.normalize();
482        if normalized.structural_eq(predicate) {
483            return Vec::new();
484        }
485        vec![LogicalPlan::Filter {
486            predicate: normalized,
487            input: input.clone(),
488        }]
489    }
490}
491
492pub struct PredicateInference;
493
494impl Rule for PredicateInference {
495    fn name(&self) -> &str {
496        "predicate_inference"
497    }
498
499    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
500        match plan {
501            LogicalPlan::Filter {
502                predicate,
503                input,
504            } => match input.as_ref() {
505                LogicalPlan::Join {
506                    join_type,
507                    left,
508                    right,
509                    on,
510                } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
511                    let combined = Expr::BinaryOp {
512                        left: Box::new(predicate.clone()),
513                        op: BinaryOperator::And,
514                        right: Box::new(on.clone()),
515                    };
516                    let (inferred, changed) = infer_predicates(&combined);
517                    if !changed {
518                        return Vec::new();
519                    }
520                    let (filter_predicates, join_predicates) =
521                        split_predicates_by_source(&inferred, predicate, on);
522                    let join_expr = combine_conjuncts(join_predicates)
523                        .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
524                    let join_plan = LogicalPlan::Join {
525                        join_type: *join_type,
526                        left: left.clone(),
527                        right: right.clone(),
528                        on: join_expr,
529                    };
530                    let filter_expr = combine_conjuncts(filter_predicates)
531                        .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
532                    vec![LogicalPlan::Filter {
533                        predicate: filter_expr,
534                        input: Box::new(join_plan),
535                    }]
536                }
537                _ => {
538                    let (predicate, changed) = infer_predicates(predicate);
539                    if !changed {
540                        return Vec::new();
541                    }
542                    vec![LogicalPlan::Filter {
543                        predicate,
544                        input: input.clone(),
545                    }]
546                }
547            },
548            LogicalPlan::Join {
549                join_type,
550                left,
551                right,
552                on,
553            } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
554                let (on, changed) = infer_predicates(on);
555                if !changed {
556                    return Vec::new();
557                }
558                vec![LogicalPlan::Join {
559                    join_type: *join_type,
560                    left: left.clone(),
561                    right: right.clone(),
562                    on,
563                }]
564            }
565            _ => Vec::new(),
566        }
567    }
568}
569
570pub struct JoinCommute;
571
572impl Rule for JoinCommute {
573    fn name(&self) -> &str {
574        "join_commute"
575    }
576
577    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
578        let LogicalPlan::Join {
579            join_type,
580            left,
581            right,
582            on,
583        } = plan
584        else {
585            return Vec::new();
586        };
587        if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
588            return Vec::new();
589        }
590        vec![LogicalPlan::Join {
591            join_type: *join_type,
592            left: right.clone(),
593            right: left.clone(),
594            on: on.clone(),
595        }]
596    }
597}
598
599pub struct FilterOrDedup;
600
601impl Rule for FilterOrDedup {
602    fn name(&self) -> &str {
603        "filter_or_dedup"
604    }
605
606    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
607        let LogicalPlan::Filter { predicate, input } = plan else {
608            return Vec::new();
609        };
610        let Expr::BinaryOp { left, op, right } = predicate else {
611            return Vec::new();
612        };
613        if !matches!(op, BinaryOperator::Or) {
614            return Vec::new();
615        }
616        if left.structural_eq(right) {
617            return vec![LogicalPlan::Filter {
618                predicate: (*left.clone()),
619                input: input.clone(),
620            }];
621        }
622        Vec::new()
623    }
624}
625
626pub struct LimitPushdown;
627
628impl Rule for LimitPushdown {
629    fn name(&self) -> &str {
630        "limit_pushdown"
631    }
632
633    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
634        let LogicalPlan::Limit {
635            limit,
636            offset,
637            input,
638        } = plan
639        else {
640            return Vec::new();
641        };
642        if offset.is_some() {
643            return Vec::new();
644        }
645        let inner = input.as_ref();
646        match inner {
647            LogicalPlan::Filter { predicate, input } => vec![LogicalPlan::Filter {
648                predicate: predicate.clone(),
649                input: Box::new(LogicalPlan::Limit {
650                    limit: *limit,
651                    offset: *offset,
652                    input: input.clone(),
653                }),
654            }],
655            LogicalPlan::Projection { exprs, input } => vec![LogicalPlan::Projection {
656                exprs: exprs.clone(),
657                input: Box::new(LogicalPlan::Limit {
658                    limit: *limit,
659                    offset: *offset,
660                    input: input.clone(),
661                }),
662            }],
663            _ => Vec::new(),
664        }
665    }
666}
667
668pub struct TopNRule;
669
670impl Rule for TopNRule {
671    fn name(&self) -> &str {
672        "topn_rule"
673    }
674
675    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
676        let LogicalPlan::Limit {
677            limit: Some(limit),
678            offset: None,
679            input,
680        } = plan
681        else {
682            return Vec::new();
683        };
684        let LogicalPlan::Sort { order_by, input } = input.as_ref() else {
685            return Vec::new();
686        };
687        vec![LogicalPlan::TopN {
688            order_by: order_by.clone(),
689            limit: *limit,
690            input: input.clone(),
691        }]
692    }
693}
694
695#[cfg(test)]
696mod tests {
697    use super::{
698        FilterJoinPushdown, FilterOrDedup, FilterPushdown, JoinPredicatePushdown, LimitPushdown,
699        MergeFilters, MergeProjections, NormalizePredicates, PredicateInference, PruneProjection,
700        RemoveTrueFilter, Rule, TopNRule,
701    };
702    use crate::utils::split_conjuncts;
703    use chryso_core::ast::{BinaryOperator, Expr};
704    use chryso_planner::LogicalPlan;
705
706    #[test]
707    fn merge_filters_combines_predicates() {
708        let plan = LogicalPlan::Filter {
709            predicate: Expr::Identifier("a".to_string()),
710            input: Box::new(LogicalPlan::Filter {
711                predicate: Expr::Identifier("b".to_string()),
712                input: Box::new(LogicalPlan::Scan {
713                    table: "t".to_string(),
714                }),
715            }),
716        };
717        let rule = MergeFilters;
718        let results = rule.apply(&plan);
719        assert_eq!(results.len(), 1);
720        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
721            panic!("expected filter");
722        };
723        let Expr::BinaryOp { op, .. } = predicate else {
724            panic!("expected binary op");
725        };
726        assert!(matches!(op, BinaryOperator::And));
727    }
728
729    #[test]
730    fn prune_projection_removes_star() {
731        let plan = LogicalPlan::Projection {
732            exprs: vec![Expr::Wildcard],
733            input: Box::new(LogicalPlan::Scan {
734                table: "t".to_string(),
735            }),
736        };
737        let rule = PruneProjection;
738        let results = rule.apply(&plan);
739        assert_eq!(results.len(), 1);
740    }
741
742    #[test]
743    fn filter_pushdown_under_projection() {
744        let plan = LogicalPlan::Filter {
745            predicate: Expr::Identifier("x".to_string()),
746            input: Box::new(LogicalPlan::Projection {
747                exprs: vec![Expr::Identifier("x".to_string())],
748                input: Box::new(LogicalPlan::Scan {
749                    table: "t".to_string(),
750                }),
751            }),
752        };
753        let rule = FilterPushdown;
754        let results = rule.apply(&plan);
755        assert_eq!(results.len(), 1);
756    }
757
758    #[test]
759    fn filter_pushdown_under_sort() {
760        let plan = LogicalPlan::Filter {
761            predicate: Expr::Identifier("x".to_string()),
762            input: Box::new(LogicalPlan::Sort {
763                order_by: vec![chryso_core::ast::OrderByExpr {
764                    expr: Expr::Identifier("id".to_string()),
765                    asc: true,
766                    nulls_first: None,
767                }],
768                input: Box::new(LogicalPlan::Scan {
769                    table: "t".to_string(),
770                }),
771            }),
772        };
773        let rule = FilterPushdown;
774        let results = rule.apply(&plan);
775        assert_eq!(results.len(), 1);
776        let LogicalPlan::Sort { input, .. } = &results[0] else {
777            panic!("expected sort");
778        };
779        assert!(matches!(input.as_ref(), LogicalPlan::Filter { .. }));
780    }
781
782    #[test]
783    fn normalize_predicates_orders_and() {
784        let plan = LogicalPlan::Filter {
785            predicate: Expr::BinaryOp {
786                left: Box::new(Expr::Identifier("b".to_string())),
787                op: BinaryOperator::And,
788                right: Box::new(Expr::Identifier("a".to_string())),
789            },
790            input: Box::new(LogicalPlan::Scan {
791                table: "t".to_string(),
792            }),
793        };
794        let rule = NormalizePredicates;
795        let results = rule.apply(&plan);
796        assert_eq!(results.len(), 1);
797    }
798
799    #[test]
800    fn normalize_predicates_removes_true_and() {
801        let plan = LogicalPlan::Filter {
802            predicate: Expr::BinaryOp {
803                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
804                op: BinaryOperator::And,
805                right: Box::new(Expr::Identifier("x".to_string())),
806            },
807            input: Box::new(LogicalPlan::Scan {
808                table: "t".to_string(),
809            }),
810        };
811        let rule = NormalizePredicates;
812        let results = rule.apply(&plan);
813        assert_eq!(results.len(), 1);
814        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
815            panic!("expected filter");
816        };
817        assert_eq!(predicate.to_sql(), "x");
818    }
819
820    #[test]
821    fn normalize_predicates_removes_nested_true_and() {
822        let plan = LogicalPlan::Filter {
823            predicate: Expr::BinaryOp {
824                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
825                op: BinaryOperator::And,
826                right: Box::new(Expr::BinaryOp {
827                    left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
828                    op: BinaryOperator::And,
829                    right: Box::new(Expr::Identifier("x".to_string())),
830                }),
831            },
832            input: Box::new(LogicalPlan::Scan {
833                table: "t".to_string(),
834            }),
835        };
836        let rule = NormalizePredicates;
837        let results = rule.apply(&plan);
838        assert_eq!(results.len(), 1);
839        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
840            panic!("expected filter");
841        };
842        assert_eq!(predicate.to_sql(), "x");
843    }
844
845    #[test]
846    fn detect_rule_conflicts() {
847        let rules = crate::rules::RuleSet::new()
848            .with_rule(MergeFilters)
849            .with_rule(MergeFilters);
850        let conflicts = rules.detect_conflicts();
851        assert_eq!(conflicts, vec!["merge_filters".to_string()]);
852    }
853
854    #[test]
855    fn topn_rule_rewrites_sort_limit() {
856        let plan = LogicalPlan::Limit {
857            limit: Some(10),
858            offset: None,
859            input: Box::new(LogicalPlan::Sort {
860                order_by: vec![chryso_core::ast::OrderByExpr {
861                    expr: Expr::Identifier("id".to_string()),
862                    asc: true,
863                    nulls_first: None,
864                }],
865                input: Box::new(LogicalPlan::Scan {
866                    table: "t".to_string(),
867                }),
868            }),
869        };
870        let rule = TopNRule;
871        let results = rule.apply(&plan);
872        assert_eq!(results.len(), 1);
873    }
874
875    #[test]
876    fn limit_pushdown_under_projection() {
877        let plan = LogicalPlan::Limit {
878            limit: Some(5),
879            offset: None,
880            input: Box::new(LogicalPlan::Projection {
881                exprs: vec![Expr::Identifier("id".to_string())],
882                input: Box::new(LogicalPlan::Scan {
883                    table: "t".to_string(),
884                }),
885            }),
886        };
887        let rule = LimitPushdown;
888        let results = rule.apply(&plan);
889        assert_eq!(results.len(), 1);
890    }
891
892    #[test]
893    fn merge_projections_keeps_outer() {
894        let plan = LogicalPlan::Projection {
895            exprs: vec![Expr::Identifier("id".to_string())],
896            input: Box::new(LogicalPlan::Projection {
897                exprs: vec![Expr::Identifier("id".to_string()), Expr::Identifier("name".to_string())],
898                input: Box::new(LogicalPlan::Scan {
899                    table: "t".to_string(),
900                }),
901            }),
902        };
903        let rule = MergeProjections;
904        let results = rule.apply(&plan);
905        assert_eq!(results.len(), 1);
906        let LogicalPlan::Projection { exprs, .. } = &results[0] else {
907            panic!("expected projection");
908        };
909        assert_eq!(exprs.len(), 1);
910    }
911
912    #[test]
913    fn remove_true_filter() {
914        let plan = LogicalPlan::Filter {
915            predicate: Expr::Literal(chryso_core::ast::Literal::Bool(true)),
916            input: Box::new(LogicalPlan::Scan {
917                table: "t".to_string(),
918            }),
919        };
920        let rule = RemoveTrueFilter;
921        let results = rule.apply(&plan);
922        assert_eq!(results.len(), 1);
923        assert!(matches!(results[0], LogicalPlan::Scan { .. }));
924    }
925
926    #[test]
927    fn remove_true_filter_ignores_binary_predicate() {
928        let plan = LogicalPlan::Filter {
929            predicate: Expr::BinaryOp {
930                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
931                op: BinaryOperator::And,
932                right: Box::new(Expr::Identifier("x".to_string())),
933            },
934            input: Box::new(LogicalPlan::Scan {
935                table: "t".to_string(),
936            }),
937        };
938        let rule = RemoveTrueFilter;
939        let results = rule.apply(&plan);
940        assert!(results.is_empty());
941    }
942
943    #[test]
944    fn filter_join_pushdown_left() {
945        let plan = LogicalPlan::Filter {
946            predicate: Expr::BinaryOp {
947                left: Box::new(Expr::Identifier("t1.id".to_string())),
948                op: BinaryOperator::Eq,
949                right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
950            },
951            input: Box::new(LogicalPlan::Join {
952                join_type: chryso_core::ast::JoinType::Inner,
953                left: Box::new(LogicalPlan::Scan {
954                    table: "t1".to_string(),
955                }),
956                right: Box::new(LogicalPlan::Scan {
957                    table: "t2".to_string(),
958                }),
959                on: Expr::Identifier("t1.id = t2.id".to_string()),
960            }),
961        };
962        let rule = FilterJoinPushdown;
963        let results = rule.apply(&plan);
964        assert_eq!(results.len(), 1);
965        let LogicalPlan::Join { left, .. } = &results[0] else {
966            panic!("expected join");
967        };
968        assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
969    }
970
971    #[test]
972    fn filter_join_pushdown_keeps_cross_predicate() {
973        let plan = LogicalPlan::Filter {
974            predicate: Expr::BinaryOp {
975                left: Box::new(Expr::Identifier("t1.id".to_string())),
976                op: BinaryOperator::Eq,
977                right: Box::new(Expr::Identifier("t2.id".to_string())),
978            },
979            input: Box::new(LogicalPlan::Join {
980                join_type: chryso_core::ast::JoinType::Inner,
981                left: Box::new(LogicalPlan::Scan {
982                    table: "t1".to_string(),
983                }),
984                right: Box::new(LogicalPlan::Scan {
985                    table: "t2".to_string(),
986                }),
987                on: Expr::Identifier("t1.id = t2.id".to_string()),
988            }),
989        };
990        let rule = FilterJoinPushdown;
991        let results = rule.apply(&plan);
992        assert_eq!(results.len(), 1);
993        let LogicalPlan::Join { on, .. } = &results[0] else {
994            panic!("expected join");
995        };
996        let Expr::BinaryOp { op, .. } = on else {
997            panic!("expected binary op");
998        };
999        assert!(matches!(op, BinaryOperator::And));
1000    }
1001
1002    #[test]
1003    fn filter_or_dedup_removes_duplicate_predicates() {
1004        let plan = LogicalPlan::Filter {
1005            predicate: Expr::BinaryOp {
1006                left: Box::new(Expr::Identifier("a".to_string())),
1007                op: BinaryOperator::Or,
1008                right: Box::new(Expr::Identifier("a".to_string())),
1009            },
1010            input: Box::new(LogicalPlan::Scan {
1011                table: "t".to_string(),
1012            }),
1013        };
1014        let rule = FilterOrDedup;
1015        let results = rule.apply(&plan);
1016        assert_eq!(results.len(), 1);
1017        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1018            panic!("expected filter");
1019        };
1020        assert!(matches!(predicate, Expr::Identifier(name) if name == "a"));
1021    }
1022
1023    #[test]
1024    fn join_predicate_pushdown_splits_single_side() {
1025        let plan = LogicalPlan::Join {
1026            join_type: chryso_core::ast::JoinType::Inner,
1027            left: Box::new(LogicalPlan::Scan {
1028                table: "t1".to_string(),
1029            }),
1030            right: Box::new(LogicalPlan::Scan {
1031                table: "t2".to_string(),
1032            }),
1033            on: Expr::BinaryOp {
1034                left: Box::new(Expr::BinaryOp {
1035                    left: Box::new(Expr::Identifier("t1.flag".to_string())),
1036                    op: BinaryOperator::Eq,
1037                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
1038                }),
1039                op: BinaryOperator::And,
1040                right: Box::new(Expr::BinaryOp {
1041                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1042                    op: BinaryOperator::Eq,
1043                    right: Box::new(Expr::Identifier("t2.id".to_string())),
1044                }),
1045            },
1046        };
1047        let rule = JoinPredicatePushdown;
1048        let results = rule.apply(&plan);
1049        assert_eq!(results.len(), 1);
1050        let LogicalPlan::Join { left, on, .. } = &results[0] else {
1051            panic!("expected join");
1052        };
1053        assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
1054        assert_eq!(on.to_sql(), "t1.id = t2.id");
1055    }
1056
1057    #[test]
1058    fn predicate_inference_adds_literal_equivalence() {
1059        let plan = LogicalPlan::Filter {
1060            predicate: Expr::BinaryOp {
1061                left: Box::new(Expr::BinaryOp {
1062                    left: Box::new(Expr::Identifier("a".to_string())),
1063                    op: BinaryOperator::Eq,
1064                    right: Box::new(Expr::Identifier("b".to_string())),
1065                }),
1066                op: BinaryOperator::And,
1067                right: Box::new(Expr::BinaryOp {
1068                    left: Box::new(Expr::Identifier("a".to_string())),
1069                    op: BinaryOperator::Eq,
1070                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
1071                }),
1072            },
1073            input: Box::new(LogicalPlan::Scan {
1074                table: "t".to_string(),
1075            }),
1076        };
1077        let rule = PredicateInference;
1078        let results = rule.apply(&plan);
1079        assert_eq!(results.len(), 1);
1080        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1081            panic!("expected filter");
1082        };
1083        let conjuncts = split_conjuncts(predicate)
1084            .into_iter()
1085            .map(|expr| expr.to_sql())
1086            .collect::<std::collections::HashSet<_>>();
1087        assert!(conjuncts.contains("b = 1"));
1088    }
1089
1090    #[test]
1091    fn predicate_inference_on_join() {
1092        let plan = LogicalPlan::Join {
1093            join_type: chryso_core::ast::JoinType::Inner,
1094            left: Box::new(LogicalPlan::Scan {
1095                table: "t1".to_string(),
1096            }),
1097            right: Box::new(LogicalPlan::Scan {
1098                table: "t2".to_string(),
1099            }),
1100            on: Expr::BinaryOp {
1101                left: Box::new(Expr::BinaryOp {
1102                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1103                    op: BinaryOperator::Eq,
1104                    right: Box::new(Expr::Identifier("t2.id".to_string())),
1105                }),
1106                op: BinaryOperator::And,
1107                right: Box::new(Expr::BinaryOp {
1108                    left: Box::new(Expr::Identifier("t2.id".to_string())),
1109                    op: BinaryOperator::Eq,
1110                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(5.0))),
1111                }),
1112            },
1113        };
1114        let rule = PredicateInference;
1115        let results = rule.apply(&plan);
1116        assert_eq!(results.len(), 1);
1117        let LogicalPlan::Join { on, .. } = &results[0] else {
1118            panic!("expected join");
1119        };
1120        let conjuncts = split_conjuncts(on)
1121            .into_iter()
1122            .map(|expr| expr.to_sql())
1123            .collect::<std::collections::HashSet<_>>();
1124        assert!(conjuncts.contains("t1.id = 5"));
1125    }
1126
1127    #[test]
1128    fn predicate_inference_enables_join_pushdown() {
1129        let plan = LogicalPlan::Join {
1130            join_type: chryso_core::ast::JoinType::Inner,
1131            left: Box::new(LogicalPlan::Scan {
1132                table: "t1".to_string(),
1133            }),
1134            right: Box::new(LogicalPlan::Scan {
1135                table: "t2".to_string(),
1136            }),
1137            on: Expr::BinaryOp {
1138                left: Box::new(Expr::BinaryOp {
1139                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1140                    op: BinaryOperator::Eq,
1141                    right: Box::new(Expr::Identifier("t2.id".to_string())),
1142                }),
1143                op: BinaryOperator::And,
1144                right: Box::new(Expr::BinaryOp {
1145                    left: Box::new(Expr::Identifier("t1.id".to_string())),
1146                    op: BinaryOperator::Eq,
1147                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(42.0))),
1148                }),
1149            },
1150        };
1151        let inferred = PredicateInference.apply(&plan);
1152        assert_eq!(inferred.len(), 1);
1153        let pushed = JoinPredicatePushdown.apply(&inferred[0]);
1154        assert_eq!(pushed.len(), 1);
1155        let LogicalPlan::Join { right, .. } = &pushed[0] else {
1156            panic!("expected join");
1157        };
1158        assert!(matches!(right.as_ref(), LogicalPlan::Filter { .. }));
1159    }
1160
1161    #[test]
1162    fn predicate_inference_propagates_across_filter_join() {
1163        let plan = LogicalPlan::Filter {
1164            predicate: Expr::BinaryOp {
1165                left: Box::new(Expr::Identifier("t1.flag".to_string())),
1166                op: BinaryOperator::Eq,
1167                right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
1168            },
1169            input: Box::new(LogicalPlan::Join {
1170                join_type: chryso_core::ast::JoinType::Inner,
1171                left: Box::new(LogicalPlan::Scan {
1172                    table: "t1".to_string(),
1173                }),
1174                right: Box::new(LogicalPlan::Scan {
1175                    table: "t2".to_string(),
1176                }),
1177                on: Expr::BinaryOp {
1178                    left: Box::new(Expr::BinaryOp {
1179                        left: Box::new(Expr::Identifier("t1.id".to_string())),
1180                        op: BinaryOperator::Eq,
1181                        right: Box::new(Expr::Identifier("t2.id".to_string())),
1182                    }),
1183                    op: BinaryOperator::And,
1184                    right: Box::new(Expr::BinaryOp {
1185                        left: Box::new(Expr::Identifier("t1.id".to_string())),
1186                        op: BinaryOperator::Eq,
1187                        right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(7.0))),
1188                    }),
1189                },
1190            }),
1191        };
1192        let inferred = PredicateInference.apply(&plan);
1193        assert_eq!(inferred.len(), 1);
1194        let LogicalPlan::Filter { input, predicate } = &inferred[0] else {
1195            panic!("expected filter");
1196        };
1197        let LogicalPlan::Join { on, .. } = input.as_ref() else {
1198            panic!("expected join");
1199        };
1200        let filter_conjuncts = split_conjuncts(predicate)
1201            .into_iter()
1202            .map(|expr| expr.to_sql())
1203            .collect::<std::collections::HashSet<_>>();
1204        let join_conjuncts = split_conjuncts(on)
1205            .into_iter()
1206            .map(|expr| expr.to_sql())
1207            .collect::<std::collections::HashSet<_>>();
1208        assert!(filter_conjuncts.contains("t1.flag = true"));
1209        assert!(join_conjuncts.contains("t2.id = 7"));
1210    }
1211
1212    #[test]
1213    fn predicate_inference_adds_transitive_equivalence() {
1214        let plan = LogicalPlan::Filter {
1215            predicate: Expr::BinaryOp {
1216                left: Box::new(Expr::BinaryOp {
1217                    left: Box::new(Expr::Identifier("a".to_string())),
1218                    op: BinaryOperator::Eq,
1219                    right: Box::new(Expr::Identifier("b".to_string())),
1220                }),
1221                op: BinaryOperator::And,
1222                right: Box::new(Expr::BinaryOp {
1223                    left: Box::new(Expr::Identifier("b".to_string())),
1224                    op: BinaryOperator::Eq,
1225                    right: Box::new(Expr::Identifier("c".to_string())),
1226                }),
1227            },
1228            input: Box::new(LogicalPlan::Scan {
1229                table: "t".to_string(),
1230            }),
1231        };
1232        let rule = PredicateInference;
1233        let results = rule.apply(&plan);
1234        assert_eq!(results.len(), 1);
1235        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1236            panic!("expected filter");
1237        };
1238        let conjuncts = split_conjuncts(predicate)
1239            .into_iter()
1240            .map(|expr| expr.to_sql())
1241            .collect::<std::collections::HashSet<_>>();
1242        assert!(conjuncts.contains("a = c") || conjuncts.contains("c = a"));
1243    }
1244}
1245
1246pub struct AggregatePredicatePushdown;
1247
1248impl Rule for AggregatePredicatePushdown {
1249    fn name(&self) -> &str {
1250        "aggregate_predicate_pushdown"
1251    }
1252
1253    fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
1254        let LogicalPlan::Filter { predicate, input } = plan else {
1255            return Vec::new();
1256        };
1257        let LogicalPlan::Aggregate {
1258            group_exprs,
1259            aggr_exprs,
1260            input,
1261        } = input.as_ref()
1262        else {
1263            return Vec::new();
1264        };
1265        let predicate_idents = collect_identifiers(predicate);
1266        let group_idents = group_exprs
1267            .iter()
1268            .flat_map(collect_identifiers)
1269            .collect::<std::collections::HashSet<_>>();
1270        if predicate_idents
1271            .iter()
1272            .all(|ident| group_idents.contains(ident))
1273        {
1274            return vec![LogicalPlan::Aggregate {
1275                group_exprs: group_exprs.clone(),
1276                aggr_exprs: aggr_exprs.clone(),
1277                input: Box::new(LogicalPlan::Filter {
1278                    predicate: predicate.clone(),
1279                    input: input.clone(),
1280                }),
1281            }];
1282        }
1283        Vec::new()
1284    }
1285}
1286
/// Two-valued left/right marker.
///
/// NOTE(review): presumably names which input of a join a predicate belongs
/// to; the uses are outside this section — confirm against callers.
#[derive(Clone, Copy)]
enum Side {
    Left,
    Right,
}
1292
1293// shared helpers are in utils.rs
1294
1295fn infer_predicates(predicate: &Expr) -> (Expr, bool) {
1296    let conjuncts = split_conjuncts(predicate);
1297    let mut existing = std::collections::HashSet::new();
1298    for expr in &conjuncts {
1299        existing.insert(expr.to_sql());
1300    }
1301
1302    let mut uf = UnionFind::new();
1303    let mut literal_bindings = std::collections::HashMap::<String, Literal>::new();
1304    let mut literal_conflicts = std::collections::HashSet::<String>::new();
1305
1306    for conjunct in &conjuncts {
1307        let Expr::BinaryOp { left, op, right } = conjunct else {
1308            continue;
1309        };
1310        if !matches!(op, BinaryOperator::Eq) {
1311            continue;
1312        }
1313        match (left.as_ref(), right.as_ref()) {
1314            (Expr::Identifier(left), Expr::Identifier(right)) => {
1315                uf.union(left, right);
1316            }
1317            (Expr::Identifier(ident), Expr::Literal(literal))
1318            | (Expr::Literal(literal), Expr::Identifier(ident)) => {
1319                uf.add(ident);
1320                if let Some(existing) = literal_bindings.get(ident) {
1321                    if !literal_eq(existing, literal) {
1322                        literal_conflicts.insert(ident.clone());
1323                    }
1324                } else {
1325                    literal_bindings.insert(ident.clone(), literal.clone());
1326                }
1327            }
1328            _ => {}
1329        }
1330    }
1331
1332    let mut class_literals = std::collections::HashMap::<String, Option<Literal>>::new();
1333    if !literal_conflicts.is_empty() {
1334        // TODO: surface conflicting literal bindings in optimizer trace.
1335    }
1336    for ident in &literal_conflicts {
1337        let root = uf.find(ident);
1338        class_literals.insert(root, None);
1339    }
1340    for (ident, literal) in &literal_bindings {
1341        if literal_conflicts.contains(ident) {
1342            continue;
1343        }
1344        let root = uf.find(ident);
1345        match class_literals.get(&root) {
1346            None => {
1347                class_literals.insert(root, Some(literal.clone()));
1348            }
1349            Some(Some(existing_literal)) => {
1350                if !literal_eq(existing_literal, literal) {
1351                    class_literals.insert(root, None);
1352                }
1353            }
1354            Some(None) => {}
1355        }
1356    }
1357
1358    let mut groups = std::collections::HashMap::<String, Vec<String>>::new();
1359    for key in uf.keys() {
1360        let root = uf.find(&key);
1361        groups.entry(root).or_default().push(key);
1362    }
1363
1364    let mut inferred = Vec::new();
1365    for members in groups.values() {
1366        if members.len() <= 1 {
1367            continue;
1368        }
1369        let mut members = members.clone();
1370        members.sort();
1371        let canonical = members[0].clone();
1372        for ident in members.into_iter().skip(1) {
1373            let forward = format!("{canonical} = {ident}");
1374            let backward = format!("{ident} = {canonical}");
1375            if existing.contains(&forward) || existing.contains(&backward) {
1376                continue;
1377            }
1378            existing.insert(forward);
1379            inferred.push(Expr::BinaryOp {
1380                left: Box::new(Expr::Identifier(canonical.clone())),
1381                op: BinaryOperator::Eq,
1382                right: Box::new(Expr::Identifier(ident)),
1383            });
1384        }
1385    }
1386
1387    for (root, literal) in class_literals {
1388        let Some(literal) = literal else {
1389            continue;
1390        };
1391        let Some(members) = groups.get(&root) else {
1392            continue;
1393        };
1394        for ident in members {
1395            let expr = Expr::BinaryOp {
1396                left: Box::new(Expr::Identifier(ident.clone())),
1397                op: BinaryOperator::Eq,
1398                right: Box::new(Expr::Literal(literal.clone())),
1399            };
1400            if existing.insert(expr.to_sql()) {
1401                inferred.push(expr);
1402            }
1403        }
1404    }
1405
1406    if inferred.is_empty() {
1407        return (predicate.clone(), false);
1408    }
1409    let mut all = conjuncts;
1410    all.extend(inferred);
1411    (combine_conjuncts(all).unwrap_or_else(|| predicate.clone()), true)
1412}
1413
1414fn split_predicates_by_source(
1415    combined: &Expr,
1416    filter_predicate: &Expr,
1417    join_predicate: &Expr,
1418) -> (Vec<Expr>, Vec<Expr>) {
1419    let filter_set = split_conjuncts(filter_predicate)
1420        .into_iter()
1421        .map(|expr| expr.to_sql())
1422        .collect::<std::collections::HashSet<_>>();
1423    let join_set = split_conjuncts(join_predicate)
1424        .into_iter()
1425        .map(|expr| expr.to_sql())
1426        .collect::<std::collections::HashSet<_>>();
1427    let filter_prefixes = collect_table_prefixes(filter_predicate);
1428    let mut filter_preds = Vec::new();
1429    let mut join_preds = Vec::new();
1430    for expr in split_conjuncts(combined) {
1431        let sql = expr.to_sql();
1432        if join_set.contains(&sql) {
1433            join_preds.push(expr);
1434        } else if filter_set.contains(&sql) {
1435            filter_preds.push(expr);
1436        } else if is_join_compatible(&expr) {
1437            join_preds.push(expr);
1438        } else if is_same_table_predicate(&expr, &filter_prefixes) {
1439            filter_preds.push(expr);
1440        } else {
1441            join_preds.push(expr);
1442        }
1443    }
1444    (filter_preds, join_preds)
1445}
1446
1447fn is_join_compatible(expr: &Expr) -> bool {
1448    let idents = collect_identifiers(expr);
1449    if idents.len() < 2 {
1450        return false;
1451    }
1452    let mut tables = std::collections::HashSet::new();
1453    for ident in idents {
1454        if let Some(prefix) = table_prefix(&ident) {
1455            tables.insert(prefix.to_string());
1456        }
1457    }
1458    tables.len() >= 2
1459}
1460
1461fn collect_table_prefixes(expr: &Expr) -> std::collections::HashSet<String> {
1462    let idents = collect_identifiers(expr);
1463    let mut tables = std::collections::HashSet::new();
1464    for ident in idents {
1465        if let Some(prefix) = table_prefix(&ident) {
1466            tables.insert(prefix.to_string());
1467        }
1468    }
1469    tables
1470}
1471
1472fn is_same_table_predicate(expr: &Expr, known_tables: &std::collections::HashSet<String>) -> bool {
1473    let idents = collect_identifiers(expr);
1474    if idents.is_empty() {
1475        return false;
1476    }
1477    let mut tables = std::collections::HashSet::new();
1478    for ident in idents {
1479        let Some(prefix) = table_prefix(&ident) else {
1480            return false;
1481        };
1482        tables.insert(prefix.to_string());
1483    }
1484    if tables.len() != 1 {
1485        return false;
1486    }
1487    if known_tables.is_empty() {
1488        return false;
1489    }
1490    tables.is_subset(known_tables)
1491}
1492
1493fn literal_eq(left: &Literal, right: &Literal) -> bool {
1494    match (left, right) {
1495        (Literal::String(left), Literal::String(right)) => left == right,
1496        (Literal::Number(left), Literal::Number(right)) => left == right,
1497        (Literal::Bool(left), Literal::Bool(right)) => left == right,
1498        _ => false,
1499    }
1500}
1501
/// Disjoint-set structure over string keys.
///
/// Each key maps to a parent key; a key that maps to itself is the root of
/// its set.
struct UnionFind {
    parent: std::collections::HashMap<String, String>,
}

impl UnionFind {
    /// Creates an empty structure with no keys.
    fn new() -> Self {
        Self {
            parent: Default::default(),
        }
    }

    /// Registers `key` as its own singleton set unless it is already known.
    fn add(&mut self, key: &str) {
        if let std::collections::hash_map::Entry::Vacant(slot) = self.parent.entry(key.to_owned()) {
            slot.insert(key.to_owned());
        }
    }

    /// Returns the root of the set containing `key`, registering `key` first
    /// if needed. Every key visited on the way is re-pointed at the root.
    fn find(&mut self, key: &str) -> String {
        self.add(key);

        // Walk up the parent links until a self-parented key is reached.
        let mut root = key.to_owned();
        while let Some(next) = self.parent.get(&root) {
            if *next == root {
                break;
            }
            root = next.clone();
        }

        // Walk the same chain again, pointing each visited key at the root.
        let mut node = key.to_owned();
        while node != root {
            let previous = self.parent.insert(node, root.clone());
            node = previous.unwrap_or_else(|| root.clone());
        }
        root
    }

    /// Merges the sets of `left` and `right`; the root of `right`'s set
    /// becomes the root of the merged set.
    fn union(&mut self, left: &str, right: &str) {
        let (lhs_root, rhs_root) = (self.find(left), self.find(right));
        if lhs_root == rhs_root {
            return;
        }
        self.parent.insert(lhs_root, rhs_root);
    }

    /// Returns a copy of every registered key, in the map's iteration order.
    fn keys(&self) -> Vec<String> {
        self.parent.keys().map(|key| key.to_owned()).collect()
    }
}