1use crate::utils::{
2 collect_identifiers, collect_tables, combine_conjuncts, split_conjuncts, table_prefix,
3};
4use chryso_core::ast::{BinaryOperator, Expr, Literal};
5use chryso_planner::LogicalPlan;
6use std::collections::BTreeSet;
7
8pub trait Rule {
9 fn name(&self) -> &str;
10 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan>;
11}
12
/// Mutable state shared with rules while they run.
///
/// It currently holds only the string pairs reported through
/// `record_literal_conflicts`, kept ordered and de-duplicated by a `BTreeSet`.
#[derive(Debug, Default)]
pub struct RuleContext {
    // Reported pairs; the BTreeSet keeps them sorted and unique.
    literal_conflicts: BTreeSet<(String, String)>,
}

impl RuleContext {
    /// Records every pair yielded by `pairs`; pairs already present are ignored.
    pub fn record_literal_conflicts<I>(&mut self, pairs: I)
    where
        I: IntoIterator<Item = (String, String)>,
    {
        for pair in pairs {
            self.literal_conflicts.insert(pair);
        }
    }

    /// Removes and returns all recorded pairs in ascending order, leaving the context empty.
    pub fn take_literal_conflicts(&mut self) -> Vec<(String, String)> {
        let drained = std::mem::take(&mut self.literal_conflicts);
        drained.into_iter().collect::<Vec<_>>()
    }
}
33
34pub struct RuleSet {
35 rules: Vec<Box<dyn Rule + Send + Sync>>,
36}
37
38impl RuleSet {
39 pub fn new() -> Self {
40 Self { rules: Vec::new() }
41 }
42
43 pub fn with_rule(mut self, rule: impl Rule + Send + Sync + 'static) -> Self {
44 self.rules.push(Box::new(rule));
45 self
46 }
47
48 pub fn apply_all(&self, plan: &LogicalPlan, ctx: &mut RuleContext) -> Vec<LogicalPlan> {
49 let mut results = Vec::new();
50 for rule in &self.rules {
51 results.extend(rule.apply(plan, ctx));
52 }
53 results
54 }
55
56 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Rule + Send + Sync>> {
57 self.rules.iter()
58 }
59}
60
61impl Default for RuleSet {
62 fn default() -> Self {
63 RuleSet::new()
64 .with_rule(MergeFilters)
65 .with_rule(PruneProjection)
66 .with_rule(MergeProjections)
67 .with_rule(RemoveTrueFilter)
68 .with_rule(FilterPushdown)
69 .with_rule(FilterJoinPushdown)
70 .with_rule(PredicateInference)
71 .with_rule(JoinPredicatePushdown)
72 .with_rule(FilterOrDedup)
73 .with_rule(NormalizePredicates)
74 .with_rule(JoinCommute)
75 .with_rule(AggregatePredicatePushdown)
76 .with_rule(LimitPushdown)
77 .with_rule(TopNRule)
78 }
79}
80
81impl RuleSet {
82 pub fn detect_conflicts(&self) -> Vec<String> {
83 let mut seen = std::collections::HashMap::new();
84 for rule in &self.rules {
85 *seen.entry(rule.name().to_string()).or_insert(0usize) += 1;
86 }
87 seen.into_iter()
88 .filter_map(|(name, count)| if count > 1 { Some(name) } else { None })
89 .collect()
90 }
91}
92
93impl std::fmt::Debug for RuleSet {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("RuleSet")
96 .field("rule_count", &self.rules.len())
97 .finish()
98 }
99}
100
101pub struct NoopRule;
102
103impl Rule for NoopRule {
104 fn name(&self) -> &str {
105 "noop"
106 }
107
108 fn apply(&self, _plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
109 Vec::new()
110 }
111}
112
113pub struct MergeFilters;
114
115impl Rule for MergeFilters {
116 fn name(&self) -> &str {
117 "merge_filters"
118 }
119
120 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
121 let LogicalPlan::Filter { predicate, input } = plan else {
122 return Vec::new();
123 };
124 let LogicalPlan::Filter {
125 predicate: inner_predicate,
126 input: inner_input,
127 } = input.as_ref()
128 else {
129 return Vec::new();
130 };
131 let merged = LogicalPlan::Filter {
132 predicate: chryso_core::ast::Expr::BinaryOp {
133 left: Box::new(inner_predicate.clone()),
134 op: BinaryOperator::And,
135 right: Box::new(predicate.clone()),
136 },
137 input: inner_input.clone(),
138 };
139 vec![merged]
140 }
141}
142
143pub struct PruneProjection;
144
145impl Rule for PruneProjection {
146 fn name(&self) -> &str {
147 "prune_projection"
148 }
149
150 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
151 let LogicalPlan::Projection { exprs, input } = plan else {
152 return Vec::new();
153 };
154 if exprs.len() == 1 && matches!(exprs[0], Expr::Wildcard) {
155 return vec![(*input.clone())];
156 }
157 Vec::new()
158 }
159}
160
161pub struct MergeProjections;
162
163impl Rule for MergeProjections {
164 fn name(&self) -> &str {
165 "merge_projections"
166 }
167
168 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
169 let LogicalPlan::Projection { exprs, input } = plan else {
170 return Vec::new();
171 };
172 let LogicalPlan::Projection {
173 exprs: inner_exprs,
174 input: inner_input,
175 } = input.as_ref()
176 else {
177 return Vec::new();
178 };
179 if projection_subset(exprs, inner_exprs) {
180 return vec![LogicalPlan::Projection {
181 exprs: exprs.clone(),
182 input: inner_input.clone(),
183 }];
184 }
185 Vec::new()
186 }
187}
188
189fn projection_subset(outer: &[Expr], inner: &[Expr]) -> bool {
190 let inner_names = inner
191 .iter()
192 .filter_map(|expr| match expr {
193 Expr::Identifier(name) => Some(name),
194 _ => None,
195 })
196 .collect::<std::collections::HashSet<_>>();
197 if inner_names.is_empty() {
198 return false;
199 }
200 outer.iter().all(|expr| match expr {
201 Expr::Identifier(name) => inner_names.contains(name),
202 Expr::Wildcard => true,
203 _ => false,
204 })
205}
206
207pub struct RemoveTrueFilter;
208
209impl Rule for RemoveTrueFilter {
210 fn name(&self) -> &str {
211 "remove_true_filter"
212 }
213
214 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
215 let LogicalPlan::Filter { predicate, input } = plan else {
216 return Vec::new();
217 };
218 match predicate {
219 Expr::Literal(chryso_core::ast::Literal::Bool(true)) => vec![*input.clone()],
220 _ => Vec::new(),
221 }
222 }
223}
224
225pub struct FilterPushdown;
226
227impl Rule for FilterPushdown {
228 fn name(&self) -> &str {
229 "filter_pushdown"
230 }
231
232 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
233 let LogicalPlan::Filter { predicate, input } = plan else {
234 return Vec::new();
235 };
236 let LogicalPlan::Projection { exprs, input } = input.as_ref() else {
237 if let LogicalPlan::Sort { order_by, input } = input.as_ref() {
238 return vec![LogicalPlan::Sort {
239 order_by: order_by.clone(),
240 input: Box::new(LogicalPlan::Filter {
241 predicate: predicate.clone(),
242 input: input.clone(),
243 }),
244 }];
245 }
246 return Vec::new();
247 };
248 if !projection_is_passthrough(exprs) {
249 return Vec::new();
250 }
251 let pushed = LogicalPlan::Projection {
252 exprs: exprs.clone(),
253 input: Box::new(LogicalPlan::Filter {
254 predicate: predicate.clone(),
255 input: input.clone(),
256 }),
257 };
258 vec![pushed]
259 }
260}
261
262pub struct FilterJoinPushdown;
263
264impl Rule for FilterJoinPushdown {
265 fn name(&self) -> &str {
266 "filter_join_pushdown"
267 }
268
269 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
270 let LogicalPlan::Filter { predicate, input } = plan else {
271 return Vec::new();
272 };
273 let LogicalPlan::Join {
274 join_type,
275 left,
276 right,
277 on,
278 } = input.as_ref()
279 else {
280 return Vec::new();
281 };
282 if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
283 return Vec::new();
284 }
285
286 let left_tables = collect_tables(left.as_ref());
287 let right_tables = collect_tables(right.as_ref());
288 let mut left_preds = Vec::new();
289 let mut right_preds = Vec::new();
290 let mut remaining = Vec::new();
291
292 for conjunct in split_conjuncts(predicate) {
293 let idents = collect_identifiers(&conjunct);
294 if idents.is_empty() {
295 remaining.push(conjunct);
296 continue;
297 }
298 let mut side = None;
299 let mut ambiguous = false;
300 for ident in &idents {
301 let Some(prefix) = table_prefix(ident) else {
302 ambiguous = true;
303 break;
304 };
305 let in_left = left_tables.contains(prefix);
306 let in_right = right_tables.contains(prefix);
307 if in_left && !in_right {
308 side = match side {
309 None => Some(Side::Left),
310 Some(Side::Left) => Some(Side::Left),
311 Some(Side::Right) => {
312 ambiguous = true;
313 break;
314 }
315 };
316 } else if in_right && !in_left {
317 side = match side {
318 None => Some(Side::Right),
319 Some(Side::Right) => Some(Side::Right),
320 Some(Side::Left) => {
321 ambiguous = true;
322 break;
323 }
324 };
325 } else {
326 ambiguous = true;
327 break;
328 }
329 }
330 if ambiguous {
331 remaining.push(conjunct);
332 continue;
333 }
334 match side {
335 Some(Side::Left) => left_preds.push(conjunct),
336 Some(Side::Right) => right_preds.push(conjunct),
337 None => remaining.push(conjunct),
338 }
339 }
340
341 if left_preds.is_empty() && right_preds.is_empty() && remaining.is_empty() {
342 return Vec::new();
343 }
344
345 let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
346 LogicalPlan::Filter {
347 predicate: expr,
348 input: left.clone(),
349 }
350 } else {
351 *left.clone()
352 };
353 let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
354 LogicalPlan::Filter {
355 predicate: expr,
356 input: right.clone(),
357 }
358 } else {
359 *right.clone()
360 };
361 let new_on = if let Some(expr) = combine_conjuncts(remaining) {
362 Expr::BinaryOp {
363 left: Box::new(on.clone()),
364 op: BinaryOperator::And,
365 right: Box::new(expr),
366 }
367 } else {
368 on.clone()
369 };
370 let joined = LogicalPlan::Join {
371 join_type: *join_type,
372 left: Box::new(new_left),
373 right: Box::new(new_right),
374 on: new_on,
375 };
376 vec![joined]
377 }
378}
379
380pub struct JoinPredicatePushdown;
381
382impl Rule for JoinPredicatePushdown {
383 fn name(&self) -> &str {
384 "join_predicate_pushdown"
385 }
386
387 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
388 let LogicalPlan::Join {
389 join_type,
390 left,
391 right,
392 on,
393 } = plan
394 else {
395 return Vec::new();
396 };
397 if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
398 return Vec::new();
399 }
400
401 let left_tables = collect_tables(left.as_ref());
402 let right_tables = collect_tables(right.as_ref());
403 let mut left_preds = Vec::new();
404 let mut right_preds = Vec::new();
405 let mut remaining = Vec::new();
406
407 for conjunct in split_conjuncts(on) {
408 let idents = collect_identifiers(&conjunct);
409 if idents.is_empty() {
410 remaining.push(conjunct);
411 continue;
412 }
413 let mut side = None;
414 let mut ambiguous = false;
415 for ident in &idents {
416 let Some(prefix) = table_prefix(ident) else {
417 ambiguous = true;
418 break;
419 };
420 let in_left = left_tables.contains(prefix);
421 let in_right = right_tables.contains(prefix);
422 if in_left && !in_right {
423 side = match side {
424 None => Some(Side::Left),
425 Some(Side::Left) => Some(Side::Left),
426 Some(Side::Right) => {
427 ambiguous = true;
428 break;
429 }
430 };
431 } else if in_right && !in_left {
432 side = match side {
433 None => Some(Side::Right),
434 Some(Side::Right) => Some(Side::Right),
435 Some(Side::Left) => {
436 ambiguous = true;
437 break;
438 }
439 };
440 } else {
441 ambiguous = true;
442 break;
443 }
444 }
445 if ambiguous {
446 remaining.push(conjunct);
447 continue;
448 }
449 match side {
450 Some(Side::Left) => left_preds.push(conjunct),
451 Some(Side::Right) => right_preds.push(conjunct),
452 None => remaining.push(conjunct),
453 }
454 }
455
456 if left_preds.is_empty() && right_preds.is_empty() {
457 return Vec::new();
458 }
459
460 let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
461 LogicalPlan::Filter {
462 predicate: expr,
463 input: left.clone(),
464 }
465 } else {
466 *left.clone()
467 };
468 let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
469 LogicalPlan::Filter {
470 predicate: expr,
471 input: right.clone(),
472 }
473 } else {
474 *right.clone()
475 };
476 let new_on =
477 combine_conjuncts(remaining).unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
478 vec![LogicalPlan::Join {
479 join_type: *join_type,
480 left: Box::new(new_left),
481 right: Box::new(new_right),
482 on: new_on,
483 }]
484 }
485}
486
487fn projection_is_passthrough(exprs: &[Expr]) -> bool {
488 exprs.iter().all(|expr| match expr {
489 Expr::Identifier(_) | Expr::Wildcard => true,
490 _ => false,
491 })
492}
493
494pub struct NormalizePredicates;
495
496impl Rule for NormalizePredicates {
497 fn name(&self) -> &str {
498 "normalize_predicates"
499 }
500
501 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
502 let LogicalPlan::Filter { predicate, input } = plan else {
503 return Vec::new();
504 };
505 let normalized = predicate.normalize();
506 if normalized.structural_eq(predicate) {
507 return Vec::new();
508 }
509 vec![LogicalPlan::Filter {
510 predicate: normalized,
511 input: input.clone(),
512 }]
513 }
514}
515
516pub struct PredicateInference;
517
518impl Rule for PredicateInference {
519 fn name(&self) -> &str {
520 "predicate_inference"
521 }
522
523 fn apply(&self, plan: &LogicalPlan, ctx: &mut RuleContext) -> Vec<LogicalPlan> {
524 match plan {
525 LogicalPlan::Filter { predicate, input } => match input.as_ref() {
526 LogicalPlan::Join {
527 join_type,
528 left,
529 right,
530 on,
531 } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
532 let combined = Expr::BinaryOp {
533 left: Box::new(predicate.clone()),
534 op: BinaryOperator::And,
535 right: Box::new(on.clone()),
536 };
537 let (inferred, changed) = infer_predicates(&combined, ctx);
538 if !changed {
539 return Vec::new();
540 }
541 let (filter_predicates, join_predicates) =
542 split_predicates_by_source(&inferred, predicate, on);
543 let join_expr = combine_conjuncts(join_predicates)
544 .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
545 let join_plan = LogicalPlan::Join {
546 join_type: *join_type,
547 left: left.clone(),
548 right: right.clone(),
549 on: join_expr,
550 };
551 let filter_expr = combine_conjuncts(filter_predicates)
552 .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
553 vec![LogicalPlan::Filter {
554 predicate: filter_expr,
555 input: Box::new(join_plan),
556 }]
557 }
558 _ => {
559 let (predicate, changed) = infer_predicates(predicate, ctx);
560 if !changed {
561 return Vec::new();
562 }
563 vec![LogicalPlan::Filter {
564 predicate,
565 input: input.clone(),
566 }]
567 }
568 },
569 LogicalPlan::Join {
570 join_type,
571 left,
572 right,
573 on,
574 } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
575 let (on, changed) = infer_predicates(on, ctx);
576 if !changed {
577 return Vec::new();
578 }
579 vec![LogicalPlan::Join {
580 join_type: *join_type,
581 left: left.clone(),
582 right: right.clone(),
583 on,
584 }]
585 }
586 _ => Vec::new(),
587 }
588 }
589}
590
591pub struct JoinCommute;
592
593impl Rule for JoinCommute {
594 fn name(&self) -> &str {
595 "join_commute"
596 }
597
598 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
599 let LogicalPlan::Join {
600 join_type,
601 left,
602 right,
603 on,
604 } = plan
605 else {
606 return Vec::new();
607 };
608 if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
609 return Vec::new();
610 }
611 vec![LogicalPlan::Join {
612 join_type: *join_type,
613 left: right.clone(),
614 right: left.clone(),
615 on: on.clone(),
616 }]
617 }
618}
619
620pub struct FilterOrDedup;
621
622impl Rule for FilterOrDedup {
623 fn name(&self) -> &str {
624 "filter_or_dedup"
625 }
626
627 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
628 let LogicalPlan::Filter { predicate, input } = plan else {
629 return Vec::new();
630 };
631 let Expr::BinaryOp { left, op, right } = predicate else {
632 return Vec::new();
633 };
634 if !matches!(op, BinaryOperator::Or) {
635 return Vec::new();
636 }
637 if left.structural_eq(right) {
638 return vec![LogicalPlan::Filter {
639 predicate: (*left.clone()),
640 input: input.clone(),
641 }];
642 }
643 Vec::new()
644 }
645}
646
647pub struct LimitPushdown;
648
649impl Rule for LimitPushdown {
650 fn name(&self) -> &str {
651 "limit_pushdown"
652 }
653
654 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
655 let LogicalPlan::Limit {
656 limit,
657 offset,
658 input,
659 } = plan
660 else {
661 return Vec::new();
662 };
663 if offset.is_some() {
664 return Vec::new();
665 }
666 let inner = input.as_ref();
667 match inner {
668 LogicalPlan::Filter { predicate, input } => vec![LogicalPlan::Filter {
669 predicate: predicate.clone(),
670 input: Box::new(LogicalPlan::Limit {
671 limit: *limit,
672 offset: *offset,
673 input: input.clone(),
674 }),
675 }],
676 LogicalPlan::Projection { exprs, input } => vec![LogicalPlan::Projection {
677 exprs: exprs.clone(),
678 input: Box::new(LogicalPlan::Limit {
679 limit: *limit,
680 offset: *offset,
681 input: input.clone(),
682 }),
683 }],
684 _ => Vec::new(),
685 }
686 }
687}
688
689pub struct TopNRule;
690
691impl Rule for TopNRule {
692 fn name(&self) -> &str {
693 "topn_rule"
694 }
695
696 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
697 let LogicalPlan::Limit {
698 limit: Some(limit),
699 offset: None,
700 input,
701 } = plan
702 else {
703 return Vec::new();
704 };
705 let LogicalPlan::Sort { order_by, input } = input.as_ref() else {
706 return Vec::new();
707 };
708 vec![LogicalPlan::TopN {
709 order_by: order_by.clone(),
710 limit: *limit,
711 input: input.clone(),
712 }]
713 }
714}
715
716#[cfg(test)]
717mod tests {
718 use super::{
719 FilterJoinPushdown, FilterOrDedup, FilterPushdown, JoinPredicatePushdown, LimitPushdown,
720 MergeFilters, MergeProjections, NormalizePredicates, PredicateInference, PruneProjection,
721 RemoveTrueFilter, Rule, TopNRule,
722 };
723 use crate::utils::split_conjuncts;
724 use chryso_core::ast::{BinaryOperator, Expr};
725 use chryso_planner::LogicalPlan;
726
727 fn apply(rule: &impl Rule, plan: &LogicalPlan) -> Vec<LogicalPlan> {
728 let mut ctx = super::RuleContext::default();
729 rule.apply(plan, &mut ctx)
730 }
731
732 #[test]
733 fn merge_filters_combines_predicates() {
734 let plan = LogicalPlan::Filter {
735 predicate: Expr::Identifier("a".to_string()),
736 input: Box::new(LogicalPlan::Filter {
737 predicate: Expr::Identifier("b".to_string()),
738 input: Box::new(LogicalPlan::Scan {
739 table: "t".to_string(),
740 }),
741 }),
742 };
743 let rule = MergeFilters;
744 let results = apply(&rule, &plan);
745 assert_eq!(results.len(), 1);
746 let LogicalPlan::Filter { predicate, .. } = &results[0] else {
747 panic!("expected filter");
748 };
749 let Expr::BinaryOp { op, .. } = predicate else {
750 panic!("expected binary op");
751 };
752 assert!(matches!(op, BinaryOperator::And));
753 }
754
755 #[test]
756 fn prune_projection_removes_star() {
757 let plan = LogicalPlan::Projection {
758 exprs: vec![Expr::Wildcard],
759 input: Box::new(LogicalPlan::Scan {
760 table: "t".to_string(),
761 }),
762 };
763 let rule = PruneProjection;
764 let results = apply(&rule, &plan);
765 assert_eq!(results.len(), 1);
766 }
767
768 #[test]
769 fn filter_pushdown_under_projection() {
770 let plan = LogicalPlan::Filter {
771 predicate: Expr::Identifier("x".to_string()),
772 input: Box::new(LogicalPlan::Projection {
773 exprs: vec![Expr::Identifier("x".to_string())],
774 input: Box::new(LogicalPlan::Scan {
775 table: "t".to_string(),
776 }),
777 }),
778 };
779 let rule = FilterPushdown;
780 let results = apply(&rule, &plan);
781 assert_eq!(results.len(), 1);
782 }
783
784 #[test]
785 fn filter_pushdown_under_sort() {
786 let plan = LogicalPlan::Filter {
787 predicate: Expr::Identifier("x".to_string()),
788 input: Box::new(LogicalPlan::Sort {
789 order_by: vec![chryso_core::ast::OrderByExpr {
790 expr: Expr::Identifier("id".to_string()),
791 asc: true,
792 nulls_first: None,
793 }],
794 input: Box::new(LogicalPlan::Scan {
795 table: "t".to_string(),
796 }),
797 }),
798 };
799 let rule = FilterPushdown;
800 let results = apply(&rule, &plan);
801 assert_eq!(results.len(), 1);
802 let LogicalPlan::Sort { input, .. } = &results[0] else {
803 panic!("expected sort");
804 };
805 assert!(matches!(input.as_ref(), LogicalPlan::Filter { .. }));
806 }
807
808 #[test]
809 fn normalize_predicates_orders_and() {
810 let plan = LogicalPlan::Filter {
811 predicate: Expr::BinaryOp {
812 left: Box::new(Expr::Identifier("b".to_string())),
813 op: BinaryOperator::And,
814 right: Box::new(Expr::Identifier("a".to_string())),
815 },
816 input: Box::new(LogicalPlan::Scan {
817 table: "t".to_string(),
818 }),
819 };
820 let rule = NormalizePredicates;
821 let results = apply(&rule, &plan);
822 assert_eq!(results.len(), 1);
823 }
824
825 #[test]
826 fn normalize_predicates_removes_true_and() {
827 let plan = LogicalPlan::Filter {
828 predicate: Expr::BinaryOp {
829 left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
830 op: BinaryOperator::And,
831 right: Box::new(Expr::Identifier("x".to_string())),
832 },
833 input: Box::new(LogicalPlan::Scan {
834 table: "t".to_string(),
835 }),
836 };
837 let rule = NormalizePredicates;
838 let results = apply(&rule, &plan);
839 assert_eq!(results.len(), 1);
840 let LogicalPlan::Filter { predicate, .. } = &results[0] else {
841 panic!("expected filter");
842 };
843 assert_eq!(predicate.to_sql(), "x");
844 }
845
846 #[test]
847 fn normalize_predicates_removes_nested_true_and() {
848 let plan = LogicalPlan::Filter {
849 predicate: Expr::BinaryOp {
850 left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
851 op: BinaryOperator::And,
852 right: Box::new(Expr::BinaryOp {
853 left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
854 op: BinaryOperator::And,
855 right: Box::new(Expr::Identifier("x".to_string())),
856 }),
857 },
858 input: Box::new(LogicalPlan::Scan {
859 table: "t".to_string(),
860 }),
861 };
862 let rule = NormalizePredicates;
863 let results = apply(&rule, &plan);
864 assert_eq!(results.len(), 1);
865 let LogicalPlan::Filter { predicate, .. } = &results[0] else {
866 panic!("expected filter");
867 };
868 assert_eq!(predicate.to_sql(), "x");
869 }
870
871 #[test]
872 fn detect_rule_conflicts() {
873 let rules = crate::rules::RuleSet::new()
874 .with_rule(MergeFilters)
875 .with_rule(MergeFilters);
876 let conflicts = rules.detect_conflicts();
877 assert_eq!(conflicts, vec!["merge_filters".to_string()]);
878 }
879
880 #[test]
881 fn topn_rule_rewrites_sort_limit() {
882 let plan = LogicalPlan::Limit {
883 limit: Some(10),
884 offset: None,
885 input: Box::new(LogicalPlan::Sort {
886 order_by: vec![chryso_core::ast::OrderByExpr {
887 expr: Expr::Identifier("id".to_string()),
888 asc: true,
889 nulls_first: None,
890 }],
891 input: Box::new(LogicalPlan::Scan {
892 table: "t".to_string(),
893 }),
894 }),
895 };
896 let rule = TopNRule;
897 let results = apply(&rule, &plan);
898 assert_eq!(results.len(), 1);
899 }
900
901 #[test]
902 fn limit_pushdown_under_projection() {
903 let plan = LogicalPlan::Limit {
904 limit: Some(5),
905 offset: None,
906 input: Box::new(LogicalPlan::Projection {
907 exprs: vec![Expr::Identifier("id".to_string())],
908 input: Box::new(LogicalPlan::Scan {
909 table: "t".to_string(),
910 }),
911 }),
912 };
913 let rule = LimitPushdown;
914 let results = apply(&rule, &plan);
915 assert_eq!(results.len(), 1);
916 }
917
918 #[test]
919 fn merge_projections_keeps_outer() {
920 let plan = LogicalPlan::Projection {
921 exprs: vec![Expr::Identifier("id".to_string())],
922 input: Box::new(LogicalPlan::Projection {
923 exprs: vec![
924 Expr::Identifier("id".to_string()),
925 Expr::Identifier("name".to_string()),
926 ],
927 input: Box::new(LogicalPlan::Scan {
928 table: "t".to_string(),
929 }),
930 }),
931 };
932 let rule = MergeProjections;
933 let results = apply(&rule, &plan);
934 assert_eq!(results.len(), 1);
935 let LogicalPlan::Projection { exprs, .. } = &results[0] else {
936 panic!("expected projection");
937 };
938 assert_eq!(exprs.len(), 1);
939 }
940
941 #[test]
942 fn remove_true_filter() {
943 let plan = LogicalPlan::Filter {
944 predicate: Expr::Literal(chryso_core::ast::Literal::Bool(true)),
945 input: Box::new(LogicalPlan::Scan {
946 table: "t".to_string(),
947 }),
948 };
949 let rule = RemoveTrueFilter;
950 let results = apply(&rule, &plan);
951 assert_eq!(results.len(), 1);
952 assert!(matches!(results[0], LogicalPlan::Scan { .. }));
953 }
954
955 #[test]
956 fn remove_true_filter_ignores_binary_predicate() {
957 let plan = LogicalPlan::Filter {
958 predicate: Expr::BinaryOp {
959 left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
960 op: BinaryOperator::And,
961 right: Box::new(Expr::Identifier("x".to_string())),
962 },
963 input: Box::new(LogicalPlan::Scan {
964 table: "t".to_string(),
965 }),
966 };
967 let rule = RemoveTrueFilter;
968 let results = apply(&rule, &plan);
969 assert!(results.is_empty());
970 }
971
972 #[test]
973 fn filter_join_pushdown_left() {
974 let plan = LogicalPlan::Filter {
975 predicate: Expr::BinaryOp {
976 left: Box::new(Expr::Identifier("t1.id".to_string())),
977 op: BinaryOperator::Eq,
978 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
979 },
980 input: Box::new(LogicalPlan::Join {
981 join_type: chryso_core::ast::JoinType::Inner,
982 left: Box::new(LogicalPlan::Scan {
983 table: "t1".to_string(),
984 }),
985 right: Box::new(LogicalPlan::Scan {
986 table: "t2".to_string(),
987 }),
988 on: Expr::Identifier("t1.id = t2.id".to_string()),
989 }),
990 };
991 let rule = FilterJoinPushdown;
992 let results = apply(&rule, &plan);
993 assert_eq!(results.len(), 1);
994 let LogicalPlan::Join { left, .. } = &results[0] else {
995 panic!("expected join");
996 };
997 assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
998 }
999
1000 #[test]
1001 fn filter_join_pushdown_keeps_cross_predicate() {
1002 let plan = LogicalPlan::Filter {
1003 predicate: Expr::BinaryOp {
1004 left: Box::new(Expr::Identifier("t1.id".to_string())),
1005 op: BinaryOperator::Eq,
1006 right: Box::new(Expr::Identifier("t2.id".to_string())),
1007 },
1008 input: Box::new(LogicalPlan::Join {
1009 join_type: chryso_core::ast::JoinType::Inner,
1010 left: Box::new(LogicalPlan::Scan {
1011 table: "t1".to_string(),
1012 }),
1013 right: Box::new(LogicalPlan::Scan {
1014 table: "t2".to_string(),
1015 }),
1016 on: Expr::Identifier("t1.id = t2.id".to_string()),
1017 }),
1018 };
1019 let rule = FilterJoinPushdown;
1020 let results = apply(&rule, &plan);
1021 assert_eq!(results.len(), 1);
1022 let LogicalPlan::Join { on, .. } = &results[0] else {
1023 panic!("expected join");
1024 };
1025 let Expr::BinaryOp { op, .. } = on else {
1026 panic!("expected binary op");
1027 };
1028 assert!(matches!(op, BinaryOperator::And));
1029 }
1030
1031 #[test]
1032 fn filter_or_dedup_removes_duplicate_predicates() {
1033 let plan = LogicalPlan::Filter {
1034 predicate: Expr::BinaryOp {
1035 left: Box::new(Expr::Identifier("a".to_string())),
1036 op: BinaryOperator::Or,
1037 right: Box::new(Expr::Identifier("a".to_string())),
1038 },
1039 input: Box::new(LogicalPlan::Scan {
1040 table: "t".to_string(),
1041 }),
1042 };
1043 let rule = FilterOrDedup;
1044 let results = apply(&rule, &plan);
1045 assert_eq!(results.len(), 1);
1046 let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1047 panic!("expected filter");
1048 };
1049 assert!(matches!(predicate, Expr::Identifier(name) if name == "a"));
1050 }
1051
1052 #[test]
1053 fn join_predicate_pushdown_splits_single_side() {
1054 let plan = LogicalPlan::Join {
1055 join_type: chryso_core::ast::JoinType::Inner,
1056 left: Box::new(LogicalPlan::Scan {
1057 table: "t1".to_string(),
1058 }),
1059 right: Box::new(LogicalPlan::Scan {
1060 table: "t2".to_string(),
1061 }),
1062 on: Expr::BinaryOp {
1063 left: Box::new(Expr::BinaryOp {
1064 left: Box::new(Expr::Identifier("t1.flag".to_string())),
1065 op: BinaryOperator::Eq,
1066 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
1067 }),
1068 op: BinaryOperator::And,
1069 right: Box::new(Expr::BinaryOp {
1070 left: Box::new(Expr::Identifier("t1.id".to_string())),
1071 op: BinaryOperator::Eq,
1072 right: Box::new(Expr::Identifier("t2.id".to_string())),
1073 }),
1074 },
1075 };
1076 let rule = JoinPredicatePushdown;
1077 let results = apply(&rule, &plan);
1078 assert_eq!(results.len(), 1);
1079 let LogicalPlan::Join { left, on, .. } = &results[0] else {
1080 panic!("expected join");
1081 };
1082 assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
1083 assert_eq!(on.to_sql(), "t1.id = t2.id");
1084 }
1085
1086 #[test]
1087 fn predicate_inference_adds_literal_equivalence() {
1088 let plan = LogicalPlan::Filter {
1089 predicate: Expr::BinaryOp {
1090 left: Box::new(Expr::BinaryOp {
1091 left: Box::new(Expr::Identifier("a".to_string())),
1092 op: BinaryOperator::Eq,
1093 right: Box::new(Expr::Identifier("b".to_string())),
1094 }),
1095 op: BinaryOperator::And,
1096 right: Box::new(Expr::BinaryOp {
1097 left: Box::new(Expr::Identifier("a".to_string())),
1098 op: BinaryOperator::Eq,
1099 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
1100 }),
1101 },
1102 input: Box::new(LogicalPlan::Scan {
1103 table: "t".to_string(),
1104 }),
1105 };
1106 let rule = PredicateInference;
1107 let results = apply(&rule, &plan);
1108 assert_eq!(results.len(), 1);
1109 let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1110 panic!("expected filter");
1111 };
1112 let conjuncts = split_conjuncts(predicate)
1113 .into_iter()
1114 .map(|expr| expr.to_sql())
1115 .collect::<std::collections::HashSet<_>>();
1116 assert!(conjuncts.contains("b = 1"));
1117 }
1118
1119 #[test]
1120 fn predicate_inference_on_join() {
1121 let plan = LogicalPlan::Join {
1122 join_type: chryso_core::ast::JoinType::Inner,
1123 left: Box::new(LogicalPlan::Scan {
1124 table: "t1".to_string(),
1125 }),
1126 right: Box::new(LogicalPlan::Scan {
1127 table: "t2".to_string(),
1128 }),
1129 on: Expr::BinaryOp {
1130 left: Box::new(Expr::BinaryOp {
1131 left: Box::new(Expr::Identifier("t1.id".to_string())),
1132 op: BinaryOperator::Eq,
1133 right: Box::new(Expr::Identifier("t2.id".to_string())),
1134 }),
1135 op: BinaryOperator::And,
1136 right: Box::new(Expr::BinaryOp {
1137 left: Box::new(Expr::Identifier("t2.id".to_string())),
1138 op: BinaryOperator::Eq,
1139 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(5.0))),
1140 }),
1141 },
1142 };
1143 let rule = PredicateInference;
1144 let results = apply(&rule, &plan);
1145 assert_eq!(results.len(), 1);
1146 let LogicalPlan::Join { on, .. } = &results[0] else {
1147 panic!("expected join");
1148 };
1149 let conjuncts = split_conjuncts(on)
1150 .into_iter()
1151 .map(|expr| expr.to_sql())
1152 .collect::<std::collections::HashSet<_>>();
1153 assert!(conjuncts.contains("t1.id = 5"));
1154 }
1155
1156 #[test]
1157 fn predicate_inference_enables_join_pushdown() {
1158 let plan = LogicalPlan::Join {
1159 join_type: chryso_core::ast::JoinType::Inner,
1160 left: Box::new(LogicalPlan::Scan {
1161 table: "t1".to_string(),
1162 }),
1163 right: Box::new(LogicalPlan::Scan {
1164 table: "t2".to_string(),
1165 }),
1166 on: Expr::BinaryOp {
1167 left: Box::new(Expr::BinaryOp {
1168 left: Box::new(Expr::Identifier("t1.id".to_string())),
1169 op: BinaryOperator::Eq,
1170 right: Box::new(Expr::Identifier("t2.id".to_string())),
1171 }),
1172 op: BinaryOperator::And,
1173 right: Box::new(Expr::BinaryOp {
1174 left: Box::new(Expr::Identifier("t1.id".to_string())),
1175 op: BinaryOperator::Eq,
1176 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(42.0))),
1177 }),
1178 },
1179 };
1180 let inferred = apply(&PredicateInference, &plan);
1181 assert_eq!(inferred.len(), 1);
1182 let pushed = apply(&JoinPredicatePushdown, &inferred[0]);
1183 assert_eq!(pushed.len(), 1);
1184 let LogicalPlan::Join { right, .. } = &pushed[0] else {
1185 panic!("expected join");
1186 };
1187 assert!(matches!(right.as_ref(), LogicalPlan::Filter { .. }));
1188 }
1189
1190 #[test]
1191 fn predicate_inference_propagates_across_filter_join() {
1192 let plan = LogicalPlan::Filter {
1193 predicate: Expr::BinaryOp {
1194 left: Box::new(Expr::Identifier("t1.flag".to_string())),
1195 op: BinaryOperator::Eq,
1196 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
1197 },
1198 input: Box::new(LogicalPlan::Join {
1199 join_type: chryso_core::ast::JoinType::Inner,
1200 left: Box::new(LogicalPlan::Scan {
1201 table: "t1".to_string(),
1202 }),
1203 right: Box::new(LogicalPlan::Scan {
1204 table: "t2".to_string(),
1205 }),
1206 on: Expr::BinaryOp {
1207 left: Box::new(Expr::BinaryOp {
1208 left: Box::new(Expr::Identifier("t1.id".to_string())),
1209 op: BinaryOperator::Eq,
1210 right: Box::new(Expr::Identifier("t2.id".to_string())),
1211 }),
1212 op: BinaryOperator::And,
1213 right: Box::new(Expr::BinaryOp {
1214 left: Box::new(Expr::Identifier("t1.id".to_string())),
1215 op: BinaryOperator::Eq,
1216 right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(7.0))),
1217 }),
1218 },
1219 }),
1220 };
1221 let inferred = apply(&PredicateInference, &plan);
1222 assert_eq!(inferred.len(), 1);
1223 let LogicalPlan::Filter { input, predicate } = &inferred[0] else {
1224 panic!("expected filter");
1225 };
1226 let LogicalPlan::Join { on, .. } = input.as_ref() else {
1227 panic!("expected join");
1228 };
1229 let filter_conjuncts = split_conjuncts(predicate)
1230 .into_iter()
1231 .map(|expr| expr.to_sql())
1232 .collect::<std::collections::HashSet<_>>();
1233 let join_conjuncts = split_conjuncts(on)
1234 .into_iter()
1235 .map(|expr| expr.to_sql())
1236 .collect::<std::collections::HashSet<_>>();
1237 assert!(filter_conjuncts.contains("t1.flag = true"));
1238 assert!(join_conjuncts.contains("t2.id = 7"));
1239 }
1240
1241 #[test]
1242 fn predicate_inference_adds_transitive_equivalence() {
1243 let plan = LogicalPlan::Filter {
1244 predicate: Expr::BinaryOp {
1245 left: Box::new(Expr::BinaryOp {
1246 left: Box::new(Expr::Identifier("a".to_string())),
1247 op: BinaryOperator::Eq,
1248 right: Box::new(Expr::Identifier("b".to_string())),
1249 }),
1250 op: BinaryOperator::And,
1251 right: Box::new(Expr::BinaryOp {
1252 left: Box::new(Expr::Identifier("b".to_string())),
1253 op: BinaryOperator::Eq,
1254 right: Box::new(Expr::Identifier("c".to_string())),
1255 }),
1256 },
1257 input: Box::new(LogicalPlan::Scan {
1258 table: "t".to_string(),
1259 }),
1260 };
1261 let rule = PredicateInference;
1262 let results = apply(&rule, &plan);
1263 assert_eq!(results.len(), 1);
1264 let LogicalPlan::Filter { predicate, .. } = &results[0] else {
1265 panic!("expected filter");
1266 };
1267 let conjuncts = split_conjuncts(predicate)
1268 .into_iter()
1269 .map(|expr| expr.to_sql())
1270 .collect::<std::collections::HashSet<_>>();
1271 assert!(conjuncts.contains("a = c") || conjuncts.contains("c = a"));
1272 }
1273}
1274
1275pub struct AggregatePredicatePushdown;
1276
1277impl Rule for AggregatePredicatePushdown {
1278 fn name(&self) -> &str {
1279 "aggregate_predicate_pushdown"
1280 }
1281
1282 fn apply(&self, plan: &LogicalPlan, _ctx: &mut RuleContext) -> Vec<LogicalPlan> {
1283 let LogicalPlan::Filter { predicate, input } = plan else {
1284 return Vec::new();
1285 };
1286 let LogicalPlan::Aggregate {
1287 group_exprs,
1288 aggr_exprs,
1289 input,
1290 } = input.as_ref()
1291 else {
1292 return Vec::new();
1293 };
1294 let predicate_idents = collect_identifiers(predicate);
1295 let group_idents = group_exprs
1296 .iter()
1297 .flat_map(collect_identifiers)
1298 .collect::<std::collections::HashSet<_>>();
1299 if predicate_idents
1300 .iter()
1301 .all(|ident| group_idents.contains(ident))
1302 {
1303 return vec![LogicalPlan::Aggregate {
1304 group_exprs: group_exprs.clone(),
1305 aggr_exprs: aggr_exprs.clone(),
1306 input: Box::new(LogicalPlan::Filter {
1307 predicate: predicate.clone(),
1308 input: input.clone(),
1309 }),
1310 }];
1311 }
1312 Vec::new()
1313 }
1314}
1315
/// Selects one of the two inputs of a binary plan node (presumably a join's
/// left or right child — the users of this type are outside this view).
#[derive(Copy, Clone)]
enum Side {
    /// The left-hand input.
    Left,
    /// The right-hand input.
    Right,
}
1321
1322fn infer_predicates(predicate: &Expr, ctx: &mut RuleContext) -> (Expr, bool) {
1325 let conjuncts = split_conjuncts(predicate);
1327 let mut existing = std::collections::HashSet::new();
1328 for expr in &conjuncts {
1329 existing.insert(expr.to_sql());
1330 }
1331
1332 let mut uf = UnionFind::new();
1333 let mut literal_bindings = std::collections::HashMap::<String, Literal>::new();
1334 let mut literal_conflicts = std::collections::HashSet::<String>::new();
1335 let mut conflict_pairs = std::collections::BTreeSet::<(String, String)>::new();
1336
1337 for conjunct in &conjuncts {
1338 let Expr::BinaryOp { left, op, right } = conjunct else {
1339 continue;
1340 };
1341 if !matches!(op, BinaryOperator::Eq) {
1342 continue;
1343 }
1344 match (left.as_ref(), right.as_ref()) {
1345 (Expr::Identifier(left), Expr::Identifier(right)) => {
1346 uf.union(left, right);
1347 }
1348 (Expr::Identifier(ident), Expr::Literal(literal))
1349 | (Expr::Literal(literal), Expr::Identifier(ident)) => {
1350 uf.add(ident);
1351 if let Some(existing) = literal_bindings.get(ident) {
1352 if !literal_eq(existing, literal) {
1353 literal_conflicts.insert(ident.clone());
1354 let left =
1355 format!("{} = {}", ident, Expr::Literal(existing.clone()).to_sql());
1356 let right =
1357 format!("{} = {}", ident, Expr::Literal(literal.clone()).to_sql());
1358 conflict_pairs.insert((left, right));
1359 }
1360 } else {
1361 literal_bindings.insert(ident.clone(), literal.clone());
1362 }
1363 }
1364 _ => {}
1365 }
1366 }
1367
1368 let mut class_literals = std::collections::HashMap::<String, Option<Literal>>::new();
1369 if !literal_conflicts.is_empty() {
1370 ctx.record_literal_conflicts(conflict_pairs);
1371 }
1372 for ident in &literal_conflicts {
1373 let root = uf.find(ident);
1374 class_literals.insert(root, None);
1375 }
1376 for (ident, literal) in &literal_bindings {
1377 if literal_conflicts.contains(ident) {
1378 continue;
1379 }
1380 let root = uf.find(ident);
1381 match class_literals.get(&root) {
1382 None => {
1383 class_literals.insert(root, Some(literal.clone()));
1384 }
1385 Some(Some(existing_literal)) => {
1386 if !literal_eq(existing_literal, literal) {
1387 class_literals.insert(root, None);
1388 }
1389 }
1390 Some(None) => {}
1391 }
1392 }
1393
1394 let mut groups = std::collections::HashMap::<String, Vec<String>>::new();
1395 for key in uf.keys() {
1396 let root = uf.find(&key);
1397 groups.entry(root).or_default().push(key);
1398 }
1399
1400 let mut inferred = Vec::new();
1401 for members in groups.values() {
1402 if members.len() <= 1 {
1403 continue;
1404 }
1405 let mut members = members.clone();
1406 members.sort();
1407 let canonical = members[0].clone();
1408 for ident in members.into_iter().skip(1) {
1409 let forward = format!("{canonical} = {ident}");
1410 let backward = format!("{ident} = {canonical}");
1411 if existing.contains(&forward) || existing.contains(&backward) {
1412 continue;
1413 }
1414 existing.insert(forward);
1415 inferred.push(Expr::BinaryOp {
1416 left: Box::new(Expr::Identifier(canonical.clone())),
1417 op: BinaryOperator::Eq,
1418 right: Box::new(Expr::Identifier(ident)),
1419 });
1420 }
1421 }
1422
1423 for (root, literal) in class_literals {
1424 let Some(literal) = literal else {
1425 continue;
1426 };
1427 let Some(members) = groups.get(&root) else {
1428 continue;
1429 };
1430 for ident in members {
1431 let expr = Expr::BinaryOp {
1432 left: Box::new(Expr::Identifier(ident.clone())),
1433 op: BinaryOperator::Eq,
1434 right: Box::new(Expr::Literal(literal.clone())),
1435 };
1436 if existing.insert(expr.to_sql()) {
1437 inferred.push(expr);
1438 }
1439 }
1440 }
1441
1442 if inferred.is_empty() {
1443 return (predicate.clone(), false);
1444 }
1445 let mut all = conjuncts;
1446 all.extend(inferred);
1447 (
1448 combine_conjuncts(all).unwrap_or_else(|| predicate.clone()),
1449 true,
1450 )
1451}
1452
1453fn split_predicates_by_source(
1454 combined: &Expr,
1455 filter_predicate: &Expr,
1456 join_predicate: &Expr,
1457) -> (Vec<Expr>, Vec<Expr>) {
1458 let filter_set = split_conjuncts(filter_predicate)
1459 .into_iter()
1460 .map(|expr| expr.to_sql())
1461 .collect::<std::collections::HashSet<_>>();
1462 let join_set = split_conjuncts(join_predicate)
1463 .into_iter()
1464 .map(|expr| expr.to_sql())
1465 .collect::<std::collections::HashSet<_>>();
1466 let filter_prefixes = collect_table_prefixes(filter_predicate);
1467 let mut filter_preds = Vec::new();
1468 let mut join_preds = Vec::new();
1469 for expr in split_conjuncts(combined) {
1470 let sql = expr.to_sql();
1471 if join_set.contains(&sql) {
1472 join_preds.push(expr);
1473 } else if filter_set.contains(&sql) {
1474 filter_preds.push(expr);
1475 } else if is_join_compatible(&expr) {
1476 join_preds.push(expr);
1477 } else if is_same_table_predicate(&expr, &filter_prefixes) {
1478 filter_preds.push(expr);
1479 } else {
1480 join_preds.push(expr);
1481 }
1482 }
1483 (filter_preds, join_preds)
1484}
1485
1486fn is_join_compatible(expr: &Expr) -> bool {
1487 let idents = collect_identifiers(expr);
1488 if idents.len() < 2 {
1489 return false;
1490 }
1491 let mut tables = std::collections::HashSet::new();
1492 for ident in idents {
1493 if let Some(prefix) = table_prefix(&ident) {
1494 tables.insert(prefix.to_string());
1495 }
1496 }
1497 tables.len() >= 2
1498}
1499
1500fn collect_table_prefixes(expr: &Expr) -> std::collections::HashSet<String> {
1501 let idents = collect_identifiers(expr);
1502 let mut tables = std::collections::HashSet::new();
1503 for ident in idents {
1504 if let Some(prefix) = table_prefix(&ident) {
1505 tables.insert(prefix.to_string());
1506 }
1507 }
1508 tables
1509}
1510
1511fn is_same_table_predicate(expr: &Expr, known_tables: &std::collections::HashSet<String>) -> bool {
1512 let idents = collect_identifiers(expr);
1513 if idents.is_empty() {
1514 return false;
1515 }
1516 let mut tables = std::collections::HashSet::new();
1517 for ident in idents {
1518 let Some(prefix) = table_prefix(&ident) else {
1519 return false;
1520 };
1521 tables.insert(prefix.to_string());
1522 }
1523 if tables.len() != 1 {
1524 return false;
1525 }
1526 if known_tables.is_empty() {
1527 return false;
1528 }
1529 tables.is_subset(known_tables)
1530}
1531
1532fn literal_eq(left: &Literal, right: &Literal) -> bool {
1533 match (left, right) {
1534 (Literal::String(left), Literal::String(right)) => left == right,
1535 (Literal::Number(left), Literal::Number(right)) => left == right,
1536 (Literal::Bool(left), Literal::Bool(right)) => left == right,
1537 _ => false,
1538 }
1539}
1540
/// Disjoint-set (union-find) structure over string keys.
///
/// Each key maps to its parent key; a key that is its own parent is the root
/// (representative) of its set. `find` compresses paths; `union` does no
/// rank/size balancing and always hangs the left root under the right root.
struct UnionFind {
    /// Parent link of every registered key; roots point to themselves.
    parent: std::collections::HashMap<String, String>,
}

impl UnionFind {
    /// Creates an empty structure with no keys.
    fn new() -> Self {
        UnionFind {
            parent: Default::default(),
        }
    }

    /// Registers `key` as a singleton set; no-op if it is already known.
    fn add(&mut self, key: &str) {
        if let std::collections::hash_map::Entry::Vacant(slot) =
            self.parent.entry(key.to_owned())
        {
            slot.insert(key.to_owned());
        }
    }

    /// Returns the root of `key`'s set, registering `key` first if unknown.
    ///
    /// Every node on the path from `key` to the root is re-pointed directly
    /// at the root.
    fn find(&mut self, key: &str) -> String {
        self.add(key);

        // First pass: climb parent links until a self-parented node.
        let mut root = key.to_owned();
        loop {
            match self.parent.get(&root) {
                Some(parent) if *parent != root => root = parent.clone(),
                _ => break,
            }
        }

        // Second pass: point each node on the path straight at the root.
        let mut node = key.to_owned();
        while node != root {
            match self.parent.insert(node, root.clone()) {
                Some(next) => node = next,
                None => break,
            }
        }
        root
    }

    /// Merges the sets containing `left` and `right` (registering unknown
    /// keys) by attaching `left`'s root under `right`'s root.
    fn union(&mut self, left: &str, right: &str) {
        let left_root = self.find(left);
        let right_root = self.find(right);
        if left_root == right_root {
            return;
        }
        self.parent.insert(left_root, right_root);
    }

    /// Returns every registered key, in unspecified order.
    fn keys(&self) -> Vec<String> {
        self.parent.keys().map(ToOwned::to_owned).collect()
    }
}