1use crate::utils::{collect_identifiers, collect_tables, combine_conjuncts, split_conjuncts, table_prefix};
2use chryso_core::ast::{BinaryOperator, Expr, Literal};
3use chryso_planner::LogicalPlan;
4
5pub trait Rule {
6 fn name(&self) -> &str;
7 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan>;
8}
9
10pub struct RuleSet {
11 rules: Vec<Box<dyn Rule + Send + Sync>>,
12}
13
14impl RuleSet {
15 pub fn new() -> Self {
16 Self { rules: Vec::new() }
17 }
18
19 pub fn with_rule(mut self, rule: impl Rule + Send + Sync + 'static) -> Self {
20 self.rules.push(Box::new(rule));
21 self
22 }
23
24 pub fn apply_all(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
25 let mut results = Vec::new();
26 for rule in &self.rules {
27 results.extend(rule.apply(plan));
28 }
29 results
30 }
31
32 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn Rule + Send + Sync>> {
33 self.rules.iter()
34 }
35}
36
37impl Default for RuleSet {
38 fn default() -> Self {
39 RuleSet::new()
40 .with_rule(MergeFilters)
41 .with_rule(PruneProjection)
42 .with_rule(MergeProjections)
43 .with_rule(RemoveTrueFilter)
44 .with_rule(FilterPushdown)
45 .with_rule(FilterJoinPushdown)
46 .with_rule(PredicateInference)
47 .with_rule(JoinPredicatePushdown)
48 .with_rule(FilterOrDedup)
49 .with_rule(NormalizePredicates)
50 .with_rule(JoinCommute)
51 .with_rule(AggregatePredicatePushdown)
52 .with_rule(LimitPushdown)
53 .with_rule(TopNRule)
54 }
55}
56
57impl RuleSet {
58 pub fn detect_conflicts(&self) -> Vec<String> {
59 let mut seen = std::collections::HashMap::new();
60 for rule in &self.rules {
61 *seen.entry(rule.name().to_string()).or_insert(0usize) += 1;
62 }
63 seen.into_iter()
64 .filter_map(|(name, count)| if count > 1 { Some(name) } else { None })
65 .collect()
66 }
67}
68
69impl std::fmt::Debug for RuleSet {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("RuleSet")
72 .field("rule_count", &self.rules.len())
73 .finish()
74 }
75}
76
77pub struct NoopRule;
78
79impl Rule for NoopRule {
80 fn name(&self) -> &str {
81 "noop"
82 }
83
84 fn apply(&self, _plan: &LogicalPlan) -> Vec<LogicalPlan> {
85 Vec::new()
86 }
87}
88
89pub struct MergeFilters;
90
91impl Rule for MergeFilters {
92 fn name(&self) -> &str {
93 "merge_filters"
94 }
95
96 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
97 let LogicalPlan::Filter { predicate, input } = plan else {
98 return Vec::new();
99 };
100 let LogicalPlan::Filter {
101 predicate: inner_predicate,
102 input: inner_input,
103 } = input.as_ref()
104 else {
105 return Vec::new();
106 };
107 let merged = LogicalPlan::Filter {
108 predicate: chryso_core::ast::Expr::BinaryOp {
109 left: Box::new(inner_predicate.clone()),
110 op: BinaryOperator::And,
111 right: Box::new(predicate.clone()),
112 },
113 input: inner_input.clone(),
114 };
115 vec![merged]
116 }
117}
118
119pub struct PruneProjection;
120
121impl Rule for PruneProjection {
122 fn name(&self) -> &str {
123 "prune_projection"
124 }
125
126 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
127 let LogicalPlan::Projection { exprs, input } = plan else {
128 return Vec::new();
129 };
130 if exprs.len() == 1 && matches!(exprs[0], Expr::Wildcard) {
131 return vec![(*input.clone())];
132 }
133 Vec::new()
134 }
135}
136
137pub struct MergeProjections;
138
139impl Rule for MergeProjections {
140 fn name(&self) -> &str {
141 "merge_projections"
142 }
143
144 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
145 let LogicalPlan::Projection { exprs, input } = plan else {
146 return Vec::new();
147 };
148 let LogicalPlan::Projection {
149 exprs: inner_exprs,
150 input: inner_input,
151 } = input.as_ref()
152 else {
153 return Vec::new();
154 };
155 if projection_subset(exprs, inner_exprs) {
156 return vec![LogicalPlan::Projection {
157 exprs: exprs.clone(),
158 input: inner_input.clone(),
159 }];
160 }
161 Vec::new()
162 }
163}
164
165fn projection_subset(outer: &[Expr], inner: &[Expr]) -> bool {
166 let inner_names = inner
167 .iter()
168 .filter_map(|expr| match expr {
169 Expr::Identifier(name) => Some(name),
170 _ => None,
171 })
172 .collect::<std::collections::HashSet<_>>();
173 if inner_names.is_empty() {
174 return false;
175 }
176 outer.iter().all(|expr| match expr {
177 Expr::Identifier(name) => inner_names.contains(name),
178 Expr::Wildcard => true,
179 _ => false,
180 })
181}
182
183pub struct RemoveTrueFilter;
184
185impl Rule for RemoveTrueFilter {
186 fn name(&self) -> &str {
187 "remove_true_filter"
188 }
189
190 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
191 let LogicalPlan::Filter { predicate, input } = plan else {
192 return Vec::new();
193 };
194 match predicate {
195 Expr::Literal(chryso_core::ast::Literal::Bool(true)) => vec![*input.clone()],
196 _ => Vec::new(),
197 }
198 }
199}
200
201pub struct FilterPushdown;
202
203impl Rule for FilterPushdown {
204 fn name(&self) -> &str {
205 "filter_pushdown"
206 }
207
208 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
209 let LogicalPlan::Filter { predicate, input } = plan else {
210 return Vec::new();
211 };
212 let LogicalPlan::Projection { exprs, input } = input.as_ref() else {
213 if let LogicalPlan::Sort { order_by, input } = input.as_ref() {
214 return vec![LogicalPlan::Sort {
215 order_by: order_by.clone(),
216 input: Box::new(LogicalPlan::Filter {
217 predicate: predicate.clone(),
218 input: input.clone(),
219 }),
220 }];
221 }
222 return Vec::new();
223 };
224 if !projection_is_passthrough(exprs) {
225 return Vec::new();
226 }
227 let pushed = LogicalPlan::Projection {
228 exprs: exprs.clone(),
229 input: Box::new(LogicalPlan::Filter {
230 predicate: predicate.clone(),
231 input: input.clone(),
232 }),
233 };
234 vec![pushed]
235 }
236}
237
238pub struct FilterJoinPushdown;
239
240impl Rule for FilterJoinPushdown {
241 fn name(&self) -> &str {
242 "filter_join_pushdown"
243 }
244
245 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
246 let LogicalPlan::Filter { predicate, input } = plan else {
247 return Vec::new();
248 };
249 let LogicalPlan::Join {
250 join_type,
251 left,
252 right,
253 on,
254 } = input.as_ref()
255 else {
256 return Vec::new();
257 };
258 if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
259 return Vec::new();
260 }
261
262 let left_tables = collect_tables(left.as_ref());
263 let right_tables = collect_tables(right.as_ref());
264 let mut left_preds = Vec::new();
265 let mut right_preds = Vec::new();
266 let mut remaining = Vec::new();
267
268 for conjunct in split_conjuncts(predicate) {
269 let idents = collect_identifiers(&conjunct);
270 if idents.is_empty() {
271 remaining.push(conjunct);
272 continue;
273 }
274 let mut side = None;
275 let mut ambiguous = false;
276 for ident in &idents {
277 let Some(prefix) = table_prefix(ident) else {
278 ambiguous = true;
279 break;
280 };
281 let in_left = left_tables.contains(prefix);
282 let in_right = right_tables.contains(prefix);
283 if in_left && !in_right {
284 side = match side {
285 None => Some(Side::Left),
286 Some(Side::Left) => Some(Side::Left),
287 Some(Side::Right) => {
288 ambiguous = true;
289 break;
290 }
291 };
292 } else if in_right && !in_left {
293 side = match side {
294 None => Some(Side::Right),
295 Some(Side::Right) => Some(Side::Right),
296 Some(Side::Left) => {
297 ambiguous = true;
298 break;
299 }
300 };
301 } else {
302 ambiguous = true;
303 break;
304 }
305 }
306 if ambiguous {
307 remaining.push(conjunct);
308 continue;
309 }
310 match side {
311 Some(Side::Left) => left_preds.push(conjunct),
312 Some(Side::Right) => right_preds.push(conjunct),
313 None => remaining.push(conjunct),
314 }
315 }
316
317 if left_preds.is_empty() && right_preds.is_empty() && remaining.is_empty() {
318 return Vec::new();
319 }
320
321 let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
322 LogicalPlan::Filter {
323 predicate: expr,
324 input: left.clone(),
325 }
326 } else {
327 *left.clone()
328 };
329 let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
330 LogicalPlan::Filter {
331 predicate: expr,
332 input: right.clone(),
333 }
334 } else {
335 *right.clone()
336 };
337 let new_on = if let Some(expr) = combine_conjuncts(remaining) {
338 Expr::BinaryOp {
339 left: Box::new(on.clone()),
340 op: BinaryOperator::And,
341 right: Box::new(expr),
342 }
343 } else {
344 on.clone()
345 };
346 let joined = LogicalPlan::Join {
347 join_type: *join_type,
348 left: Box::new(new_left),
349 right: Box::new(new_right),
350 on: new_on,
351 };
352 vec![joined]
353 }
354}
355
356pub struct JoinPredicatePushdown;
357
358impl Rule for JoinPredicatePushdown {
359 fn name(&self) -> &str {
360 "join_predicate_pushdown"
361 }
362
363 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
364 let LogicalPlan::Join {
365 join_type,
366 left,
367 right,
368 on,
369 } = plan
370 else {
371 return Vec::new();
372 };
373 if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
374 return Vec::new();
375 }
376
377 let left_tables = collect_tables(left.as_ref());
378 let right_tables = collect_tables(right.as_ref());
379 let mut left_preds = Vec::new();
380 let mut right_preds = Vec::new();
381 let mut remaining = Vec::new();
382
383 for conjunct in split_conjuncts(on) {
384 let idents = collect_identifiers(&conjunct);
385 if idents.is_empty() {
386 remaining.push(conjunct);
387 continue;
388 }
389 let mut side = None;
390 let mut ambiguous = false;
391 for ident in &idents {
392 let Some(prefix) = table_prefix(ident) else {
393 ambiguous = true;
394 break;
395 };
396 let in_left = left_tables.contains(prefix);
397 let in_right = right_tables.contains(prefix);
398 if in_left && !in_right {
399 side = match side {
400 None => Some(Side::Left),
401 Some(Side::Left) => Some(Side::Left),
402 Some(Side::Right) => {
403 ambiguous = true;
404 break;
405 }
406 };
407 } else if in_right && !in_left {
408 side = match side {
409 None => Some(Side::Right),
410 Some(Side::Right) => Some(Side::Right),
411 Some(Side::Left) => {
412 ambiguous = true;
413 break;
414 }
415 };
416 } else {
417 ambiguous = true;
418 break;
419 }
420 }
421 if ambiguous {
422 remaining.push(conjunct);
423 continue;
424 }
425 match side {
426 Some(Side::Left) => left_preds.push(conjunct),
427 Some(Side::Right) => right_preds.push(conjunct),
428 None => remaining.push(conjunct),
429 }
430 }
431
432 if left_preds.is_empty() && right_preds.is_empty() {
433 return Vec::new();
434 }
435
436 let new_left = if let Some(expr) = combine_conjuncts(left_preds) {
437 LogicalPlan::Filter {
438 predicate: expr,
439 input: left.clone(),
440 }
441 } else {
442 *left.clone()
443 };
444 let new_right = if let Some(expr) = combine_conjuncts(right_preds) {
445 LogicalPlan::Filter {
446 predicate: expr,
447 input: right.clone(),
448 }
449 } else {
450 *right.clone()
451 };
452 let new_on =
453 combine_conjuncts(remaining).unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
454 vec![LogicalPlan::Join {
455 join_type: *join_type,
456 left: Box::new(new_left),
457 right: Box::new(new_right),
458 on: new_on,
459 }]
460 }
461}
462
463fn projection_is_passthrough(exprs: &[Expr]) -> bool {
464 exprs.iter().all(|expr| match expr {
465 Expr::Identifier(_) | Expr::Wildcard => true,
466 _ => false,
467 })
468}
469
470pub struct NormalizePredicates;
471
472impl Rule for NormalizePredicates {
473 fn name(&self) -> &str {
474 "normalize_predicates"
475 }
476
477 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
478 let LogicalPlan::Filter { predicate, input } = plan else {
479 return Vec::new();
480 };
481 let normalized = predicate.normalize();
482 if normalized.structural_eq(predicate) {
483 return Vec::new();
484 }
485 vec![LogicalPlan::Filter {
486 predicate: normalized,
487 input: input.clone(),
488 }]
489 }
490}
491
492pub struct PredicateInference;
493
494impl Rule for PredicateInference {
495 fn name(&self) -> &str {
496 "predicate_inference"
497 }
498
499 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
500 match plan {
501 LogicalPlan::Filter {
502 predicate,
503 input,
504 } => match input.as_ref() {
505 LogicalPlan::Join {
506 join_type,
507 left,
508 right,
509 on,
510 } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
511 let combined = Expr::BinaryOp {
512 left: Box::new(predicate.clone()),
513 op: BinaryOperator::And,
514 right: Box::new(on.clone()),
515 };
516 let (inferred, changed) = infer_predicates(&combined);
517 if !changed {
518 return Vec::new();
519 }
520 let (filter_predicates, join_predicates) =
521 split_predicates_by_source(&inferred, predicate, on);
522 let join_expr = combine_conjuncts(join_predicates)
523 .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
524 let join_plan = LogicalPlan::Join {
525 join_type: *join_type,
526 left: left.clone(),
527 right: right.clone(),
528 on: join_expr,
529 };
530 let filter_expr = combine_conjuncts(filter_predicates)
531 .unwrap_or_else(|| Expr::Literal(Literal::Bool(true)));
532 vec![LogicalPlan::Filter {
533 predicate: filter_expr,
534 input: Box::new(join_plan),
535 }]
536 }
537 _ => {
538 let (predicate, changed) = infer_predicates(predicate);
539 if !changed {
540 return Vec::new();
541 }
542 vec![LogicalPlan::Filter {
543 predicate,
544 input: input.clone(),
545 }]
546 }
547 },
548 LogicalPlan::Join {
549 join_type,
550 left,
551 right,
552 on,
553 } if matches!(join_type, chryso_core::ast::JoinType::Inner) => {
554 let (on, changed) = infer_predicates(on);
555 if !changed {
556 return Vec::new();
557 }
558 vec![LogicalPlan::Join {
559 join_type: *join_type,
560 left: left.clone(),
561 right: right.clone(),
562 on,
563 }]
564 }
565 _ => Vec::new(),
566 }
567 }
568}
569
570pub struct JoinCommute;
571
572impl Rule for JoinCommute {
573 fn name(&self) -> &str {
574 "join_commute"
575 }
576
577 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
578 let LogicalPlan::Join {
579 join_type,
580 left,
581 right,
582 on,
583 } = plan
584 else {
585 return Vec::new();
586 };
587 if !matches!(join_type, chryso_core::ast::JoinType::Inner) {
588 return Vec::new();
589 }
590 vec![LogicalPlan::Join {
591 join_type: *join_type,
592 left: right.clone(),
593 right: left.clone(),
594 on: on.clone(),
595 }]
596 }
597}
598
599pub struct FilterOrDedup;
600
601impl Rule for FilterOrDedup {
602 fn name(&self) -> &str {
603 "filter_or_dedup"
604 }
605
606 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
607 let LogicalPlan::Filter { predicate, input } = plan else {
608 return Vec::new();
609 };
610 let Expr::BinaryOp { left, op, right } = predicate else {
611 return Vec::new();
612 };
613 if !matches!(op, BinaryOperator::Or) {
614 return Vec::new();
615 }
616 if left.structural_eq(right) {
617 return vec![LogicalPlan::Filter {
618 predicate: (*left.clone()),
619 input: input.clone(),
620 }];
621 }
622 Vec::new()
623 }
624}
625
626pub struct LimitPushdown;
627
628impl Rule for LimitPushdown {
629 fn name(&self) -> &str {
630 "limit_pushdown"
631 }
632
633 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
634 let LogicalPlan::Limit {
635 limit,
636 offset,
637 input,
638 } = plan
639 else {
640 return Vec::new();
641 };
642 if offset.is_some() {
643 return Vec::new();
644 }
645 let inner = input.as_ref();
646 match inner {
647 LogicalPlan::Filter { predicate, input } => vec![LogicalPlan::Filter {
648 predicate: predicate.clone(),
649 input: Box::new(LogicalPlan::Limit {
650 limit: *limit,
651 offset: *offset,
652 input: input.clone(),
653 }),
654 }],
655 LogicalPlan::Projection { exprs, input } => vec![LogicalPlan::Projection {
656 exprs: exprs.clone(),
657 input: Box::new(LogicalPlan::Limit {
658 limit: *limit,
659 offset: *offset,
660 input: input.clone(),
661 }),
662 }],
663 _ => Vec::new(),
664 }
665 }
666}
667
668pub struct TopNRule;
669
670impl Rule for TopNRule {
671 fn name(&self) -> &str {
672 "topn_rule"
673 }
674
675 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
676 let LogicalPlan::Limit {
677 limit: Some(limit),
678 offset: None,
679 input,
680 } = plan
681 else {
682 return Vec::new();
683 };
684 let LogicalPlan::Sort { order_by, input } = input.as_ref() else {
685 return Vec::new();
686 };
687 vec![LogicalPlan::TopN {
688 order_by: order_by.clone(),
689 limit: *limit,
690 input: input.clone(),
691 }]
692 }
693}
694
// Unit tests for the rewrite rules in this module. Most tests build a small `LogicalPlan` by
// hand, call one rule's `apply` directly and inspect the rewritten plan(s).
#[cfg(test)]
mod tests {
    use super::{
        FilterJoinPushdown, FilterOrDedup, FilterPushdown, JoinPredicatePushdown, LimitPushdown,
        MergeFilters, MergeProjections, NormalizePredicates, PredicateInference, PruneProjection,
        RemoveTrueFilter, Rule, TopNRule,
    };
    use crate::utils::split_conjuncts;
    use chryso_core::ast::{BinaryOperator, Expr};
    use chryso_planner::LogicalPlan;

    // Filter(a) over Filter(b) collapses into one Filter whose predicate is an AND.
    #[test]
    fn merge_filters_combines_predicates() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::Identifier("a".to_string()),
            input: Box::new(LogicalPlan::Filter {
                predicate: Expr::Identifier("b".to_string()),
                input: Box::new(LogicalPlan::Scan {
                    table: "t".to_string(),
                }),
            }),
        };
        let rule = MergeFilters;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
            panic!("expected filter");
        };
        let Expr::BinaryOp { op, .. } = predicate else {
            panic!("expected binary op");
        };
        assert!(matches!(op, BinaryOperator::And));
    }

    // A projection of just `*` is removed.
    #[test]
    fn prune_projection_removes_star() {
        let plan = LogicalPlan::Projection {
            exprs: vec![Expr::Wildcard],
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = PruneProjection;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
    }

    // A filter above an identifier-only projection produces one rewritten plan.
    #[test]
    fn filter_pushdown_under_projection() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::Identifier("x".to_string()),
            input: Box::new(LogicalPlan::Projection {
                exprs: vec![Expr::Identifier("x".to_string())],
                input: Box::new(LogicalPlan::Scan {
                    table: "t".to_string(),
                }),
            }),
        };
        let rule = FilterPushdown;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
    }

    // A filter above a sort ends up as Sort(Filter(..)).
    #[test]
    fn filter_pushdown_under_sort() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::Identifier("x".to_string()),
            input: Box::new(LogicalPlan::Sort {
                order_by: vec![chryso_core::ast::OrderByExpr {
                    expr: Expr::Identifier("id".to_string()),
                    asc: true,
                    nulls_first: None,
                }],
                input: Box::new(LogicalPlan::Scan {
                    table: "t".to_string(),
                }),
            }),
        };
        let rule = FilterPushdown;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Sort { input, .. } = &results[0] else {
            panic!("expected sort");
        };
        assert!(matches!(input.as_ref(), LogicalPlan::Filter { .. }));
    }

    // `b AND a` is expected to change under normalization, so the rule emits a plan.
    #[test]
    fn normalize_predicates_orders_and() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Identifier("b".to_string())),
                op: BinaryOperator::And,
                right: Box::new(Expr::Identifier("a".to_string())),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = NormalizePredicates;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
    }

    // `true AND x` normalizes to `x`.
    #[test]
    fn normalize_predicates_removes_true_and() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
                op: BinaryOperator::And,
                right: Box::new(Expr::Identifier("x".to_string())),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = NormalizePredicates;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
            panic!("expected filter");
        };
        assert_eq!(predicate.to_sql(), "x");
    }

    // `true AND (true AND x)` normalizes to `x`.
    #[test]
    fn normalize_predicates_removes_nested_true_and() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
                op: BinaryOperator::And,
                right: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
                    op: BinaryOperator::And,
                    right: Box::new(Expr::Identifier("x".to_string())),
                }),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = NormalizePredicates;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
            panic!("expected filter");
        };
        assert_eq!(predicate.to_sql(), "x");
    }

    // Registering the same rule twice is reported by name.
    #[test]
    fn detect_rule_conflicts() {
        let rules = crate::rules::RuleSet::new()
            .with_rule(MergeFilters)
            .with_rule(MergeFilters);
        let conflicts = rules.detect_conflicts();
        assert_eq!(conflicts, vec!["merge_filters".to_string()]);
    }

    // Limit(10) over Sort is rewritten (into a TopN node by the rule).
    #[test]
    fn topn_rule_rewrites_sort_limit() {
        let plan = LogicalPlan::Limit {
            limit: Some(10),
            offset: None,
            input: Box::new(LogicalPlan::Sort {
                order_by: vec![chryso_core::ast::OrderByExpr {
                    expr: Expr::Identifier("id".to_string()),
                    asc: true,
                    nulls_first: None,
                }],
                input: Box::new(LogicalPlan::Scan {
                    table: "t".to_string(),
                }),
            }),
        };
        let rule = TopNRule;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
    }

    // A limit without offset above a projection produces one rewritten plan.
    #[test]
    fn limit_pushdown_under_projection() {
        let plan = LogicalPlan::Limit {
            limit: Some(5),
            offset: None,
            input: Box::new(LogicalPlan::Projection {
                exprs: vec![Expr::Identifier("id".to_string())],
                input: Box::new(LogicalPlan::Scan {
                    table: "t".to_string(),
                }),
            }),
        };
        let rule = LimitPushdown;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
    }

    // Projection[id] over Projection[id, name] keeps only the outer expression list.
    #[test]
    fn merge_projections_keeps_outer() {
        let plan = LogicalPlan::Projection {
            exprs: vec![Expr::Identifier("id".to_string())],
            input: Box::new(LogicalPlan::Projection {
                exprs: vec![Expr::Identifier("id".to_string()), Expr::Identifier("name".to_string())],
                input: Box::new(LogicalPlan::Scan {
                    table: "t".to_string(),
                }),
            }),
        };
        let rule = MergeProjections;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Projection { exprs, .. } = &results[0] else {
            panic!("expected projection");
        };
        assert_eq!(exprs.len(), 1);
    }

    // A literal `true` filter is replaced by its input.
    #[test]
    fn remove_true_filter() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::Literal(chryso_core::ast::Literal::Bool(true)),
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = RemoveTrueFilter;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0], LogicalPlan::Scan { .. }));
    }

    // `true AND x` is not the bare literal, so RemoveTrueFilter leaves it alone.
    #[test]
    fn remove_true_filter_ignores_binary_predicate() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
                op: BinaryOperator::And,
                right: Box::new(Expr::Identifier("x".to_string())),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = RemoveTrueFilter;
        let results = rule.apply(&plan);
        assert!(results.is_empty());
    }

    // A predicate that only references t1 becomes a Filter on the join's left input.
    // The join's `on` here is a placeholder: a single identifier, not a parsed comparison.
    #[test]
    fn filter_join_pushdown_left() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Identifier("t1.id".to_string())),
                op: BinaryOperator::Eq,
                right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
            },
            input: Box::new(LogicalPlan::Join {
                join_type: chryso_core::ast::JoinType::Inner,
                left: Box::new(LogicalPlan::Scan {
                    table: "t1".to_string(),
                }),
                right: Box::new(LogicalPlan::Scan {
                    table: "t2".to_string(),
                }),
                on: Expr::Identifier("t1.id = t2.id".to_string()),
            }),
        };
        let rule = FilterJoinPushdown;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Join { left, .. } = &results[0] else {
            panic!("expected join");
        };
        assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
    }

    // A predicate spanning both sides is ANDed onto the join condition instead.
    #[test]
    fn filter_join_pushdown_keeps_cross_predicate() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Identifier("t1.id".to_string())),
                op: BinaryOperator::Eq,
                right: Box::new(Expr::Identifier("t2.id".to_string())),
            },
            input: Box::new(LogicalPlan::Join {
                join_type: chryso_core::ast::JoinType::Inner,
                left: Box::new(LogicalPlan::Scan {
                    table: "t1".to_string(),
                }),
                right: Box::new(LogicalPlan::Scan {
                    table: "t2".to_string(),
                }),
                on: Expr::Identifier("t1.id = t2.id".to_string()),
            }),
        };
        let rule = FilterJoinPushdown;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Join { on, .. } = &results[0] else {
            panic!("expected join");
        };
        let Expr::BinaryOp { op, .. } = on else {
            panic!("expected binary op");
        };
        assert!(matches!(op, BinaryOperator::And));
    }

    // `a OR a` simplifies to `a`.
    #[test]
    fn filter_or_dedup_removes_duplicate_predicates() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Identifier("a".to_string())),
                op: BinaryOperator::Or,
                right: Box::new(Expr::Identifier("a".to_string())),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = FilterOrDedup;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
            panic!("expected filter");
        };
        assert!(matches!(predicate, Expr::Identifier(name) if name == "a"));
    }

    // `t1.flag = true` moves out of ON into a left-side Filter; `t1.id = t2.id` stays in ON.
    #[test]
    fn join_predicate_pushdown_splits_single_side() {
        let plan = LogicalPlan::Join {
            join_type: chryso_core::ast::JoinType::Inner,
            left: Box::new(LogicalPlan::Scan {
                table: "t1".to_string(),
            }),
            right: Box::new(LogicalPlan::Scan {
                table: "t2".to_string(),
            }),
            on: Expr::BinaryOp {
                left: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("t1.flag".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
                }),
                op: BinaryOperator::And,
                right: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("t1.id".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Identifier("t2.id".to_string())),
                }),
            },
        };
        let rule = JoinPredicatePushdown;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Join { left, on, .. } = &results[0] else {
            panic!("expected join");
        };
        assert!(matches!(left.as_ref(), LogicalPlan::Filter { .. }));
        assert_eq!(on.to_sql(), "t1.id = t2.id");
    }

    // From `a = b AND a = 1`, inference adds `b = 1`.
    #[test]
    fn predicate_inference_adds_literal_equivalence() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("a".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Identifier("b".to_string())),
                }),
                op: BinaryOperator::And,
                right: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("a".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(1.0))),
                }),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = PredicateInference;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
            panic!("expected filter");
        };
        let conjuncts = split_conjuncts(predicate)
            .into_iter()
            .map(|expr| expr.to_sql())
            .collect::<std::collections::HashSet<_>>();
        assert!(conjuncts.contains("b = 1"));
    }

    // Inference inside a join condition: `t1.id = t2.id AND t2.id = 5` gains `t1.id = 5`.
    #[test]
    fn predicate_inference_on_join() {
        let plan = LogicalPlan::Join {
            join_type: chryso_core::ast::JoinType::Inner,
            left: Box::new(LogicalPlan::Scan {
                table: "t1".to_string(),
            }),
            right: Box::new(LogicalPlan::Scan {
                table: "t2".to_string(),
            }),
            on: Expr::BinaryOp {
                left: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("t1.id".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Identifier("t2.id".to_string())),
                }),
                op: BinaryOperator::And,
                right: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("t2.id".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(5.0))),
                }),
            },
        };
        let rule = PredicateInference;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Join { on, .. } = &results[0] else {
            panic!("expected join");
        };
        let conjuncts = split_conjuncts(on)
            .into_iter()
            .map(|expr| expr.to_sql())
            .collect::<std::collections::HashSet<_>>();
        assert!(conjuncts.contains("t1.id = 5"));
    }

    // Chaining two rules: the predicate inferred for t2 is then pushed to the right input.
    #[test]
    fn predicate_inference_enables_join_pushdown() {
        let plan = LogicalPlan::Join {
            join_type: chryso_core::ast::JoinType::Inner,
            left: Box::new(LogicalPlan::Scan {
                table: "t1".to_string(),
            }),
            right: Box::new(LogicalPlan::Scan {
                table: "t2".to_string(),
            }),
            on: Expr::BinaryOp {
                left: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("t1.id".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Identifier("t2.id".to_string())),
                }),
                op: BinaryOperator::And,
                right: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("t1.id".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(42.0))),
                }),
            },
        };
        let inferred = PredicateInference.apply(&plan);
        assert_eq!(inferred.len(), 1);
        let pushed = JoinPredicatePushdown.apply(&inferred[0]);
        assert_eq!(pushed.len(), 1);
        let LogicalPlan::Join { right, .. } = &pushed[0] else {
            panic!("expected join");
        };
        assert!(matches!(right.as_ref(), LogicalPlan::Filter { .. }));
    }

    // Filter over join: the filter keeps its own predicate and the join gains `t2.id = 7`.
    #[test]
    fn predicate_inference_propagates_across_filter_join() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Identifier("t1.flag".to_string())),
                op: BinaryOperator::Eq,
                right: Box::new(Expr::Literal(chryso_core::ast::Literal::Bool(true))),
            },
            input: Box::new(LogicalPlan::Join {
                join_type: chryso_core::ast::JoinType::Inner,
                left: Box::new(LogicalPlan::Scan {
                    table: "t1".to_string(),
                }),
                right: Box::new(LogicalPlan::Scan {
                    table: "t2".to_string(),
                }),
                on: Expr::BinaryOp {
                    left: Box::new(Expr::BinaryOp {
                        left: Box::new(Expr::Identifier("t1.id".to_string())),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Identifier("t2.id".to_string())),
                    }),
                    op: BinaryOperator::And,
                    right: Box::new(Expr::BinaryOp {
                        left: Box::new(Expr::Identifier("t1.id".to_string())),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Literal(chryso_core::ast::Literal::Number(7.0))),
                    }),
                },
            }),
        };
        let inferred = PredicateInference.apply(&plan);
        assert_eq!(inferred.len(), 1);
        let LogicalPlan::Filter { input, predicate } = &inferred[0] else {
            panic!("expected filter");
        };
        let LogicalPlan::Join { on, .. } = input.as_ref() else {
            panic!("expected join");
        };
        let filter_conjuncts = split_conjuncts(predicate)
            .into_iter()
            .map(|expr| expr.to_sql())
            .collect::<std::collections::HashSet<_>>();
        let join_conjuncts = split_conjuncts(on)
            .into_iter()
            .map(|expr| expr.to_sql())
            .collect::<std::collections::HashSet<_>>();
        assert!(filter_conjuncts.contains("t1.flag = true"));
        assert!(join_conjuncts.contains("t2.id = 7"));
    }

    // From `a = b AND b = c`, inference adds `a = c` (either orientation is accepted).
    #[test]
    fn predicate_inference_adds_transitive_equivalence() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("a".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Identifier("b".to_string())),
                }),
                op: BinaryOperator::And,
                right: Box::new(Expr::BinaryOp {
                    left: Box::new(Expr::Identifier("b".to_string())),
                    op: BinaryOperator::Eq,
                    right: Box::new(Expr::Identifier("c".to_string())),
                }),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rule = PredicateInference;
        let results = rule.apply(&plan);
        assert_eq!(results.len(), 1);
        let LogicalPlan::Filter { predicate, .. } = &results[0] else {
            panic!("expected filter");
        };
        let conjuncts = split_conjuncts(predicate)
            .into_iter()
            .map(|expr| expr.to_sql())
            .collect::<std::collections::HashSet<_>>();
        assert!(conjuncts.contains("a = c") || conjuncts.contains("c = a"));
    }
}
1245
1246pub struct AggregatePredicatePushdown;
1247
1248impl Rule for AggregatePredicatePushdown {
1249 fn name(&self) -> &str {
1250 "aggregate_predicate_pushdown"
1251 }
1252
1253 fn apply(&self, plan: &LogicalPlan) -> Vec<LogicalPlan> {
1254 let LogicalPlan::Filter { predicate, input } = plan else {
1255 return Vec::new();
1256 };
1257 let LogicalPlan::Aggregate {
1258 group_exprs,
1259 aggr_exprs,
1260 input,
1261 } = input.as_ref()
1262 else {
1263 return Vec::new();
1264 };
1265 let predicate_idents = collect_identifiers(predicate);
1266 let group_idents = group_exprs
1267 .iter()
1268 .flat_map(collect_identifiers)
1269 .collect::<std::collections::HashSet<_>>();
1270 if predicate_idents
1271 .iter()
1272 .all(|ident| group_idents.contains(ident))
1273 {
1274 return vec![LogicalPlan::Aggregate {
1275 group_exprs: group_exprs.clone(),
1276 aggr_exprs: aggr_exprs.clone(),
1277 input: Box::new(LogicalPlan::Filter {
1278 predicate: predicate.clone(),
1279 input: input.clone(),
1280 }),
1281 }];
1282 }
1283 Vec::new()
1284 }
1285}
1286
/// Selects the left or right input of a two-input plan node.
///
/// Its users lie outside this chunk; presumably it distinguishes join inputs.
#[derive(Copy, Clone)]
enum Side {
    /// The left-hand input.
    Left,
    /// The right-hand input.
    Right,
}
1292
1293fn infer_predicates(predicate: &Expr) -> (Expr, bool) {
1296 let conjuncts = split_conjuncts(predicate);
1297 let mut existing = std::collections::HashSet::new();
1298 for expr in &conjuncts {
1299 existing.insert(expr.to_sql());
1300 }
1301
1302 let mut uf = UnionFind::new();
1303 let mut literal_bindings = std::collections::HashMap::<String, Literal>::new();
1304 let mut literal_conflicts = std::collections::HashSet::<String>::new();
1305
1306 for conjunct in &conjuncts {
1307 let Expr::BinaryOp { left, op, right } = conjunct else {
1308 continue;
1309 };
1310 if !matches!(op, BinaryOperator::Eq) {
1311 continue;
1312 }
1313 match (left.as_ref(), right.as_ref()) {
1314 (Expr::Identifier(left), Expr::Identifier(right)) => {
1315 uf.union(left, right);
1316 }
1317 (Expr::Identifier(ident), Expr::Literal(literal))
1318 | (Expr::Literal(literal), Expr::Identifier(ident)) => {
1319 uf.add(ident);
1320 if let Some(existing) = literal_bindings.get(ident) {
1321 if !literal_eq(existing, literal) {
1322 literal_conflicts.insert(ident.clone());
1323 }
1324 } else {
1325 literal_bindings.insert(ident.clone(), literal.clone());
1326 }
1327 }
1328 _ => {}
1329 }
1330 }
1331
1332 let mut class_literals = std::collections::HashMap::<String, Option<Literal>>::new();
1333 if !literal_conflicts.is_empty() {
1334 }
1336 for ident in &literal_conflicts {
1337 let root = uf.find(ident);
1338 class_literals.insert(root, None);
1339 }
1340 for (ident, literal) in &literal_bindings {
1341 if literal_conflicts.contains(ident) {
1342 continue;
1343 }
1344 let root = uf.find(ident);
1345 match class_literals.get(&root) {
1346 None => {
1347 class_literals.insert(root, Some(literal.clone()));
1348 }
1349 Some(Some(existing_literal)) => {
1350 if !literal_eq(existing_literal, literal) {
1351 class_literals.insert(root, None);
1352 }
1353 }
1354 Some(None) => {}
1355 }
1356 }
1357
1358 let mut groups = std::collections::HashMap::<String, Vec<String>>::new();
1359 for key in uf.keys() {
1360 let root = uf.find(&key);
1361 groups.entry(root).or_default().push(key);
1362 }
1363
1364 let mut inferred = Vec::new();
1365 for members in groups.values() {
1366 if members.len() <= 1 {
1367 continue;
1368 }
1369 let mut members = members.clone();
1370 members.sort();
1371 let canonical = members[0].clone();
1372 for ident in members.into_iter().skip(1) {
1373 let forward = format!("{canonical} = {ident}");
1374 let backward = format!("{ident} = {canonical}");
1375 if existing.contains(&forward) || existing.contains(&backward) {
1376 continue;
1377 }
1378 existing.insert(forward);
1379 inferred.push(Expr::BinaryOp {
1380 left: Box::new(Expr::Identifier(canonical.clone())),
1381 op: BinaryOperator::Eq,
1382 right: Box::new(Expr::Identifier(ident)),
1383 });
1384 }
1385 }
1386
1387 for (root, literal) in class_literals {
1388 let Some(literal) = literal else {
1389 continue;
1390 };
1391 let Some(members) = groups.get(&root) else {
1392 continue;
1393 };
1394 for ident in members {
1395 let expr = Expr::BinaryOp {
1396 left: Box::new(Expr::Identifier(ident.clone())),
1397 op: BinaryOperator::Eq,
1398 right: Box::new(Expr::Literal(literal.clone())),
1399 };
1400 if existing.insert(expr.to_sql()) {
1401 inferred.push(expr);
1402 }
1403 }
1404 }
1405
1406 if inferred.is_empty() {
1407 return (predicate.clone(), false);
1408 }
1409 let mut all = conjuncts;
1410 all.extend(inferred);
1411 (combine_conjuncts(all).unwrap_or_else(|| predicate.clone()), true)
1412}
1413
1414fn split_predicates_by_source(
1415 combined: &Expr,
1416 filter_predicate: &Expr,
1417 join_predicate: &Expr,
1418) -> (Vec<Expr>, Vec<Expr>) {
1419 let filter_set = split_conjuncts(filter_predicate)
1420 .into_iter()
1421 .map(|expr| expr.to_sql())
1422 .collect::<std::collections::HashSet<_>>();
1423 let join_set = split_conjuncts(join_predicate)
1424 .into_iter()
1425 .map(|expr| expr.to_sql())
1426 .collect::<std::collections::HashSet<_>>();
1427 let filter_prefixes = collect_table_prefixes(filter_predicate);
1428 let mut filter_preds = Vec::new();
1429 let mut join_preds = Vec::new();
1430 for expr in split_conjuncts(combined) {
1431 let sql = expr.to_sql();
1432 if join_set.contains(&sql) {
1433 join_preds.push(expr);
1434 } else if filter_set.contains(&sql) {
1435 filter_preds.push(expr);
1436 } else if is_join_compatible(&expr) {
1437 join_preds.push(expr);
1438 } else if is_same_table_predicate(&expr, &filter_prefixes) {
1439 filter_preds.push(expr);
1440 } else {
1441 join_preds.push(expr);
1442 }
1443 }
1444 (filter_preds, join_preds)
1445}
1446
1447fn is_join_compatible(expr: &Expr) -> bool {
1448 let idents = collect_identifiers(expr);
1449 if idents.len() < 2 {
1450 return false;
1451 }
1452 let mut tables = std::collections::HashSet::new();
1453 for ident in idents {
1454 if let Some(prefix) = table_prefix(&ident) {
1455 tables.insert(prefix.to_string());
1456 }
1457 }
1458 tables.len() >= 2
1459}
1460
1461fn collect_table_prefixes(expr: &Expr) -> std::collections::HashSet<String> {
1462 let idents = collect_identifiers(expr);
1463 let mut tables = std::collections::HashSet::new();
1464 for ident in idents {
1465 if let Some(prefix) = table_prefix(&ident) {
1466 tables.insert(prefix.to_string());
1467 }
1468 }
1469 tables
1470}
1471
1472fn is_same_table_predicate(expr: &Expr, known_tables: &std::collections::HashSet<String>) -> bool {
1473 let idents = collect_identifiers(expr);
1474 if idents.is_empty() {
1475 return false;
1476 }
1477 let mut tables = std::collections::HashSet::new();
1478 for ident in idents {
1479 let Some(prefix) = table_prefix(&ident) else {
1480 return false;
1481 };
1482 tables.insert(prefix.to_string());
1483 }
1484 if tables.len() != 1 {
1485 return false;
1486 }
1487 if known_tables.is_empty() {
1488 return false;
1489 }
1490 tables.is_subset(known_tables)
1491}
1492
1493fn literal_eq(left: &Literal, right: &Literal) -> bool {
1494 match (left, right) {
1495 (Literal::String(left), Literal::String(right)) => left == right,
1496 (Literal::Number(left), Literal::Number(right)) => left == right,
1497 (Literal::Bool(left), Literal::Bool(right)) => left == right,
1498 _ => false,
1499 }
1500}
1501
/// Disjoint-set (union-find) forest over string keys.
///
/// Keys are registered lazily: `find` and `union` insert unseen keys as
/// singleton sets. `find` compresses paths. `union` does no balancing and
/// always attaches the root of `left` beneath the root of `right`.
struct UnionFind {
    /// Parent link of every registered key; a root is its own parent.
    parent: std::collections::HashMap<String, String>,
}

impl UnionFind {
    /// Creates an empty forest.
    fn new() -> Self {
        Self {
            parent: std::collections::HashMap::new(),
        }
    }

    /// Registers `key` as a singleton set unless it is already known.
    fn add(&mut self, key: &str) {
        if !self.parent.contains_key(key) {
            self.parent.insert(key.to_string(), key.to_string());
        }
    }

    /// Returns the root of the set containing `key`, registering `key` first
    /// if needed.
    ///
    /// Every node visited between `key` and the root is re-pointed directly
    /// at the root.
    fn find(&mut self, key: &str) -> String {
        self.add(key);

        // Follow parent links until reaching a node that is its own parent.
        // A missing entry is treated as a root.
        let mut root = key.to_string();
        while let Some(parent) = self.parent.get(&root) {
            if *parent == root {
                break;
            }
            root = parent.clone();
        }

        // Path compression: point every node on the path at `root`.
        let mut node = key.to_string();
        while node != root {
            let Some(next) = self.parent.insert(node, root.clone()) else {
                break;
            };
            node = next;
        }
        root
    }

    /// Merges the sets containing `left` and `right`.
    ///
    /// Both keys are registered if unseen.
    fn union(&mut self, left: &str, right: &str) {
        let left_root = self.find(left);
        let right_root = self.find(right);
        if left_root != right_root {
            self.parent.insert(left_root, right_root);
        }
    }

    /// Returns every registered key, in unspecified order.
    fn keys(&self) -> Vec<String> {
        self.parent.keys().cloned().collect()
    }
}