1use std::collections::HashMap;
33use std::sync::Arc;
34use xlog_core::{RelId, Schema};
35use xlog_ir::{CompareOp, Expr, JoinType, RirNode};
36use xlog_stats::StatsManager;
37
/// Tuning knobs for the plan optimizer.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate obtains one via
/// `OptimizerConfig::default()` and then overrides individual fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OptimizerConfig {
    /// Plans referencing more distinct relations than this are reported by
    /// `Optimizer::should_use_greedy` as candidates for greedy join ordering.
    pub dp_threshold: usize,

    /// Heat threshold forwarded to `StatsManager::hot_relations` when recommending indexes.
    pub index_heat_threshold: f32,

    /// Whether `Optimizer::optimize` runs the predicate-pushdown rewrite.
    pub enable_pushdown: bool,

    /// Fallback selectivity for predicate shapes the estimator does not model
    /// (anything other than comparisons and AND / OR / NOT combinations).
    pub default_filter_selectivity: f64,

    /// Weight of a single data transfer when collapsing a plan cost to a scalar.
    /// NOTE(review): not read in this section — presumably passed as the
    /// `transfer_weight` argument of `PlanCost::total_cost`; confirm.
    pub transfer_cost_multiplier: f64,

    /// Assumed row size in bytes, used when estimating memory footprints.
    pub default_bytes_per_row: u64,
}

impl Default for OptimizerConfig {
    /// Baseline settings: pushdown enabled, greedy ordering above 10 relations,
    /// 10% fallback selectivity, and 32-byte rows.
    fn default() -> Self {
        Self {
            default_bytes_per_row: 32,
            transfer_cost_multiplier: 100.0,
            default_filter_selectivity: 0.1,
            enable_pushdown: true,
            index_heat_threshold: 0.7,
            dp_threshold: 10,
        }
    }
}
94
/// Estimated cost of executing a plan (sub)tree.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PlanCost {
    /// Estimated number of output rows.
    pub rows: u64,

    /// Abstract compute work accumulated over the subtree.
    pub cpu_cost: f64,

    /// Estimated memory footprint in bytes.
    pub gpu_mem: u64,

    /// Number of data transfers the plan is expected to perform.
    pub transfers: u32,
}

impl PlanCost {
    /// A cost that only records an output cardinality; every other component is zero.
    pub fn with_rows(rows: u64) -> Self {
        Self {
            rows,
            ..Default::default()
        }
    }

    /// Collapses the cost into a single scalar for plan comparison:
    /// `cpu_cost + 0.001 * gpu_mem + transfer_weight * transfers`.
    pub fn total_cost(&self, transfer_weight: f64) -> f64 {
        self.cpu_cost + (self.gpu_mem as f64 * 0.001) + (self.transfers as f64 * transfer_weight)
    }

    /// Sequences `self` followed by `other`.
    ///
    /// The result produces `other`'s rows, adds up CPU work and transfers, and takes
    /// the larger of the two memory footprints (peak, not sum). The transfer count
    /// saturates instead of overflowing, matching the other cost combinators.
    pub fn then(self, other: PlanCost) -> PlanCost {
        PlanCost {
            rows: other.rows,
            cpu_cost: self.cpu_cost + other.cpu_cost,
            gpu_mem: self.gpu_mem.max(other.gpu_mem),
            transfers: self.transfers.saturating_add(other.transfers),
        }
    }
}
158
/// Plan optimizer: applies logical rewrites (predicate pushdown) and estimates
/// plan costs from shared statistics.
pub struct Optimizer {
    /// Shared statistics used for cardinality and selectivity estimates.
    stats: Arc<StatsManager>,
    /// Tuning knobs (see `OptimizerConfig`).
    config: OptimizerConfig,
    /// Known relation schemas keyed by relation id; empty until `set_schemas` is called.
    /// Consulted first when determining the arity of a scanned relation.
    schemas: HashMap<RelId, Schema>,
}
170
171impl Optimizer {
172 pub fn new(stats: Arc<StatsManager>) -> Self {
178 Self {
179 stats,
180 config: OptimizerConfig::default(),
181 schemas: HashMap::new(),
182 }
183 }
184
185 pub fn with_config(stats: Arc<StatsManager>, config: OptimizerConfig) -> Self {
192 Self {
193 stats,
194 config,
195 schemas: HashMap::new(),
196 }
197 }
198
199 pub fn set_schemas(&mut self, schemas: HashMap<RelId, Schema>) {
204 self.schemas = schemas;
205 }
206
207 pub fn config(&self) -> &OptimizerConfig {
209 &self.config
210 }
211
212 pub fn stats(&self) -> &Arc<StatsManager> {
214 &self.stats
215 }
216
217 pub fn optimize(&self, node: RirNode) -> RirNode {
235 if self.config.enable_pushdown {
236 self.predicate_pushdown(node)
237 } else {
238 node
239 }
240 }
241
242 fn predicate_pushdown(&self, node: RirNode) -> RirNode {
260 match node {
261 RirNode::Unit => RirNode::Unit,
263 RirNode::Scan { rel } => RirNode::Scan { rel },
264
265 RirNode::Filter { input, predicate } => {
267 let optimized_input = self.predicate_pushdown(*input);
269
270 match optimized_input {
271 RirNode::Filter {
273 input: inner_input,
274 predicate: inner_pred,
275 } => {
276 let merged = Expr::And(vec![inner_pred, predicate]);
277 RirNode::Filter {
278 input: inner_input,
279 predicate: merged,
280 }
281 }
282
283 RirNode::Project {
285 input: proj_input,
286 columns,
287 } => {
288 if let Some(remapped) =
290 self.remap_predicate_through_project(&predicate, &columns)
291 {
292 RirNode::Project {
294 input: Box::new(RirNode::Filter {
295 input: proj_input,
296 predicate: remapped,
297 }),
298 columns,
299 }
300 } else {
301 RirNode::Filter {
303 input: Box::new(RirNode::Project {
304 input: proj_input,
305 columns,
306 }),
307 predicate,
308 }
309 }
310 }
311
312 RirNode::Join {
314 left,
315 right,
316 left_keys,
317 right_keys,
318 join_type,
319 } => {
320 let left_width = self.estimate_width(&left);
321 let (left_preds, right_preds, remaining) =
322 self.split_predicate_for_join(&predicate, left_width);
323
324 let new_left = if !left_preds.is_empty() {
326 Box::new(RirNode::Filter {
327 input: left,
328 predicate: Self::conjoin(left_preds),
329 })
330 } else {
331 left
332 };
333
334 let new_right = if !right_preds.is_empty() {
335 Box::new(RirNode::Filter {
336 input: right,
337 predicate: Self::conjoin(right_preds),
338 })
339 } else {
340 right
341 };
342
343 let join_node = RirNode::Join {
344 left: new_left,
345 right: new_right,
346 left_keys,
347 right_keys,
348 join_type,
349 };
350
351 if !remaining.is_empty() {
353 RirNode::Filter {
354 input: Box::new(join_node),
355 predicate: Self::conjoin(remaining),
356 }
357 } else {
358 join_node
359 }
360 }
361
362 other => RirNode::Filter {
364 input: Box::new(other),
365 predicate,
366 },
367 }
368 }
369
370 RirNode::Project { input, columns } => RirNode::Project {
372 input: Box::new(self.predicate_pushdown(*input)),
373 columns,
374 },
375
376 RirNode::Join {
378 left,
379 right,
380 left_keys,
381 right_keys,
382 join_type,
383 } => RirNode::Join {
384 left: Box::new(self.predicate_pushdown(*left)),
385 right: Box::new(self.predicate_pushdown(*right)),
386 left_keys,
387 right_keys,
388 join_type,
389 },
390
391 RirNode::GroupBy {
393 input,
394 key_cols,
395 aggs,
396 } => RirNode::GroupBy {
397 input: Box::new(self.predicate_pushdown(*input)),
398 key_cols,
399 aggs,
400 },
401
402 RirNode::Union { inputs } => RirNode::Union {
404 inputs: inputs
405 .into_iter()
406 .map(|i| self.predicate_pushdown(i))
407 .collect(),
408 },
409
410 RirNode::Distinct { input, key_cols } => RirNode::Distinct {
412 input: Box::new(self.predicate_pushdown(*input)),
413 key_cols,
414 },
415
416 RirNode::Diff { left, right } => RirNode::Diff {
418 left: Box::new(self.predicate_pushdown(*left)),
419 right: Box::new(self.predicate_pushdown(*right)),
420 },
421
422 RirNode::Fixpoint {
424 scc_id,
425 base,
426 recursive,
427 delta_rel,
428 full_rel,
429 } => RirNode::Fixpoint {
430 scc_id,
431 base: Box::new(self.predicate_pushdown(*base)),
432 recursive: Box::new(self.predicate_pushdown(*recursive)),
433 delta_rel,
434 full_rel,
435 },
436
437 RirNode::TensorMaskedJoin { .. } => node, RirNode::MultiWayJoin { .. } | RirNode::ChainJoin { .. } => node,
443 }
444 }
445
446 fn remap_predicate_through_project(
452 &self,
453 predicate: &Expr,
454 columns: &[xlog_ir::ProjectExpr],
455 ) -> Option<Expr> {
456 let mut output_to_input: std::collections::HashMap<usize, usize> =
459 std::collections::HashMap::new();
460
461 for (out_idx, proj_expr) in columns.iter().enumerate() {
462 if let xlog_ir::ProjectExpr::Column(in_idx) = proj_expr {
463 output_to_input.insert(out_idx, *in_idx);
464 }
465 }
466
467 self.remap_expr(predicate, &output_to_input)
468 }
469
470 fn remap_expr(
472 &self,
473 expr: &Expr,
474 mapping: &std::collections::HashMap<usize, usize>,
475 ) -> Option<Expr> {
476 match expr {
477 Expr::Column(idx) => mapping.get(idx).map(|&new_idx| Expr::Column(new_idx)),
478
479 Expr::Const(val) => Some(Expr::Const(val.clone())),
480
481 Expr::Compare { left, op, right } => {
482 let new_left = self.remap_expr(left, mapping)?;
483 let new_right = self.remap_expr(right, mapping)?;
484 Some(Expr::Compare {
485 left: Box::new(new_left),
486 op: *op,
487 right: Box::new(new_right),
488 })
489 }
490
491 Expr::And(exprs) => {
492 let remapped: Option<Vec<_>> =
493 exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
494 remapped.map(Expr::And)
495 }
496
497 Expr::Or(exprs) => {
498 let remapped: Option<Vec<_>> =
499 exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
500 remapped.map(Expr::Or)
501 }
502
503 Expr::Not(inner) => {
504 let remapped = self.remap_expr(inner, mapping)?;
505 Some(Expr::Not(Box::new(remapped)))
506 }
507
508 Expr::Add(l, r) => {
510 let new_l = self.remap_expr(l, mapping)?;
511 let new_r = self.remap_expr(r, mapping)?;
512 Some(Expr::Add(Box::new(new_l), Box::new(new_r)))
513 }
514 Expr::Sub(l, r) => {
515 let new_l = self.remap_expr(l, mapping)?;
516 let new_r = self.remap_expr(r, mapping)?;
517 Some(Expr::Sub(Box::new(new_l), Box::new(new_r)))
518 }
519 Expr::Mul(l, r) => {
520 let new_l = self.remap_expr(l, mapping)?;
521 let new_r = self.remap_expr(r, mapping)?;
522 Some(Expr::Mul(Box::new(new_l), Box::new(new_r)))
523 }
524 Expr::Div(l, r) => {
525 let new_l = self.remap_expr(l, mapping)?;
526 let new_r = self.remap_expr(r, mapping)?;
527 Some(Expr::Div(Box::new(new_l), Box::new(new_r)))
528 }
529 Expr::Mod(l, r) => {
530 let new_l = self.remap_expr(l, mapping)?;
531 let new_r = self.remap_expr(r, mapping)?;
532 Some(Expr::Mod(Box::new(new_l), Box::new(new_r)))
533 }
534
535 Expr::Abs(inner) => {
537 let remapped = self.remap_expr(inner, mapping)?;
538 Some(Expr::Abs(Box::new(remapped)))
539 }
540 Expr::Min(l, r) => {
541 let new_l = self.remap_expr(l, mapping)?;
542 let new_r = self.remap_expr(r, mapping)?;
543 Some(Expr::Min(Box::new(new_l), Box::new(new_r)))
544 }
545 Expr::Max(l, r) => {
546 let new_l = self.remap_expr(l, mapping)?;
547 let new_r = self.remap_expr(r, mapping)?;
548 Some(Expr::Max(Box::new(new_l), Box::new(new_r)))
549 }
550 Expr::Pow(l, r) => {
551 let new_l = self.remap_expr(l, mapping)?;
552 let new_r = self.remap_expr(r, mapping)?;
553 Some(Expr::Pow(Box::new(new_l), Box::new(new_r)))
554 }
555 Expr::Cast(inner, scalar_type) => {
556 let remapped = self.remap_expr(inner, mapping)?;
557 Some(Expr::Cast(Box::new(remapped), *scalar_type))
558 }
559 Expr::Conditional {
560 condition,
561 then_expr,
562 else_expr,
563 } => {
564 let new_condition = self.remap_expr(condition, mapping)?;
565 let new_then = self.remap_expr(then_expr, mapping)?;
566 let new_else = self.remap_expr(else_expr, mapping)?;
567 Some(Expr::Conditional {
568 condition: Box::new(new_condition),
569 then_expr: Box::new(new_then),
570 else_expr: Box::new(new_else),
571 })
572 }
573 }
574 }
575
576 fn estimate_width(&self, node: &RirNode) -> usize {
578 match node {
579 RirNode::Unit => 0,
580 RirNode::Scan { rel } => {
581 if let Some(schema) = self.schemas.get(rel) {
583 schema.arity()
584 } else if let Some(stats) = self.stats.get_relation_stats(*rel) {
585 stats.column_stats.len().max(1)
586 } else {
587 4 }
589 }
590 RirNode::Filter { input, .. } => self.estimate_width(input),
591 RirNode::Project { columns, .. } => columns.len(),
592 RirNode::Join { left, right, .. } => {
593 self.estimate_width(left) + self.estimate_width(right)
594 }
595 RirNode::ChainJoin { output_columns, .. } => output_columns.len(),
596 RirNode::GroupBy { key_cols, aggs, .. } => key_cols.len() + aggs.len(),
597 RirNode::Union { inputs } => {
598 inputs.first().map(|i| self.estimate_width(i)).unwrap_or(0)
599 }
600 RirNode::Distinct { input, .. } => self.estimate_width(input),
601 RirNode::Diff { left, .. } => self.estimate_width(left),
602 RirNode::Fixpoint { base, .. } => self.estimate_width(base),
603 RirNode::TensorMaskedJoin { head_rel_id, .. } => self
606 .schemas
607 .get(head_rel_id)
608 .map(|s| s.arity())
609 .unwrap_or(2),
610 RirNode::MultiWayJoin { output_columns, .. } => output_columns.len(),
613 }
614 }
615
616 fn split_predicate_for_join(
620 &self,
621 predicate: &Expr,
622 left_width: usize,
623 ) -> (Vec<Expr>, Vec<Expr>, Vec<Expr>) {
624 let mut left_preds = Vec::new();
625 let mut right_preds = Vec::new();
626 let mut remaining = Vec::new();
627
628 let conjuncts = Self::flatten_and(predicate);
630
631 for conj in conjuncts {
632 let cols = Self::collect_columns(&conj);
633 let max_col = cols.iter().copied().max().unwrap_or(0);
634 let min_col = cols.iter().copied().min().unwrap_or(0);
635
636 if cols.is_empty() {
637 left_preds.push(conj);
639 } else if max_col < left_width {
640 left_preds.push(conj);
642 } else if min_col >= left_width {
643 let remapped = Self::remap_columns(&conj, |c| c - left_width);
645 right_preds.push(remapped);
646 } else {
647 remaining.push(conj);
649 }
650 }
651
652 (left_preds, right_preds, remaining)
653 }
654
655 fn flatten_and(expr: &Expr) -> Vec<Expr> {
657 match expr {
658 Expr::And(exprs) => exprs.iter().flat_map(Self::flatten_and).collect(),
659 other => vec![other.clone()],
660 }
661 }
662
663 fn collect_columns(expr: &Expr) -> Vec<usize> {
665 match expr {
666 Expr::Column(idx) => vec![*idx],
667 Expr::Const(_) => vec![],
668 Expr::Compare { left, right, .. } => {
669 let mut cols = Self::collect_columns(left);
670 cols.extend(Self::collect_columns(right));
671 cols
672 }
673 Expr::And(exprs) | Expr::Or(exprs) => {
674 exprs.iter().flat_map(Self::collect_columns).collect()
675 }
676 Expr::Not(inner) | Expr::Abs(inner) | Expr::Cast(inner, _) => {
677 Self::collect_columns(inner)
678 }
679 Expr::Add(l, r)
680 | Expr::Sub(l, r)
681 | Expr::Mul(l, r)
682 | Expr::Div(l, r)
683 | Expr::Mod(l, r)
684 | Expr::Min(l, r)
685 | Expr::Max(l, r)
686 | Expr::Pow(l, r) => {
687 let mut cols = Self::collect_columns(l);
688 cols.extend(Self::collect_columns(r));
689 cols
690 }
691 Expr::Conditional {
692 condition,
693 then_expr,
694 else_expr,
695 } => {
696 let mut cols = Self::collect_columns(condition);
697 cols.extend(Self::collect_columns(then_expr));
698 cols.extend(Self::collect_columns(else_expr));
699 cols
700 }
701 }
702 }
703
704 fn remap_columns<F: Fn(usize) -> usize + Copy>(expr: &Expr, f: F) -> Expr {
706 match expr {
707 Expr::Column(idx) => Expr::Column(f(*idx)),
708 Expr::Const(v) => Expr::Const(v.clone()),
709 Expr::Compare { left, op, right } => Expr::Compare {
710 left: Box::new(Self::remap_columns(left, f)),
711 op: *op,
712 right: Box::new(Self::remap_columns(right, f)),
713 },
714 Expr::And(exprs) => {
715 Expr::And(exprs.iter().map(|e| Self::remap_columns(e, f)).collect())
716 }
717 Expr::Or(exprs) => Expr::Or(exprs.iter().map(|e| Self::remap_columns(e, f)).collect()),
718 Expr::Not(inner) => Expr::Not(Box::new(Self::remap_columns(inner, f))),
719 Expr::Add(l, r) => Expr::Add(
720 Box::new(Self::remap_columns(l, f)),
721 Box::new(Self::remap_columns(r, f)),
722 ),
723 Expr::Sub(l, r) => Expr::Sub(
724 Box::new(Self::remap_columns(l, f)),
725 Box::new(Self::remap_columns(r, f)),
726 ),
727 Expr::Mul(l, r) => Expr::Mul(
728 Box::new(Self::remap_columns(l, f)),
729 Box::new(Self::remap_columns(r, f)),
730 ),
731 Expr::Div(l, r) => Expr::Div(
732 Box::new(Self::remap_columns(l, f)),
733 Box::new(Self::remap_columns(r, f)),
734 ),
735 Expr::Mod(l, r) => Expr::Mod(
736 Box::new(Self::remap_columns(l, f)),
737 Box::new(Self::remap_columns(r, f)),
738 ),
739 Expr::Abs(inner) => Expr::Abs(Box::new(Self::remap_columns(inner, f))),
740 Expr::Min(l, r) => Expr::Min(
741 Box::new(Self::remap_columns(l, f)),
742 Box::new(Self::remap_columns(r, f)),
743 ),
744 Expr::Max(l, r) => Expr::Max(
745 Box::new(Self::remap_columns(l, f)),
746 Box::new(Self::remap_columns(r, f)),
747 ),
748 Expr::Pow(l, r) => Expr::Pow(
749 Box::new(Self::remap_columns(l, f)),
750 Box::new(Self::remap_columns(r, f)),
751 ),
752 Expr::Cast(inner, t) => Expr::Cast(Box::new(Self::remap_columns(inner, f)), *t),
753 Expr::Conditional {
754 condition,
755 then_expr,
756 else_expr,
757 } => Expr::Conditional {
758 condition: Box::new(Self::remap_columns(condition, f)),
759 then_expr: Box::new(Self::remap_columns(then_expr, f)),
760 else_expr: Box::new(Self::remap_columns(else_expr, f)),
761 },
762 }
763 }
764
765 fn conjoin(predicates: Vec<Expr>) -> Expr {
767 debug_assert!(!predicates.is_empty());
768 if predicates.len() == 1 {
769 predicates.into_iter().next().unwrap()
770 } else {
771 Expr::And(predicates)
772 }
773 }
774
775 pub fn estimate_cost(&self, node: &RirNode) -> PlanCost {
788 match node {
789 RirNode::Unit => PlanCost {
790 rows: 1,
791 cpu_cost: 0.0,
792 gpu_mem: 0,
793 transfers: 0,
794 },
795 RirNode::Scan { rel } => self.estimate_scan_cost(*rel),
796
797 RirNode::Filter { input, predicate } => {
798 let input_cost = self.estimate_cost(input);
799 self.estimate_filter_cost(input_cost, predicate, input)
800 }
801
802 RirNode::Project { input, columns } => {
803 let input_cost = self.estimate_cost(input);
804 self.estimate_project_cost(input_cost, columns)
805 }
806
807 RirNode::Join {
808 left,
809 right,
810 left_keys,
811 right_keys,
812 join_type,
813 } => {
814 let left_cost = self.estimate_cost(left);
815 let right_cost = self.estimate_cost(right);
816 self.estimate_join_cost(
817 left_cost, right_cost, left, right, left_keys, right_keys, *join_type,
818 )
819 }
820
821 RirNode::ChainJoin {
822 left,
823 right,
824 left_key,
825 right_key,
826 output_columns,
827 ..
828 } => {
829 let left_cost = self.estimate_cost(left);
830 let right_cost = self.estimate_cost(right);
831 let join_cost = self.estimate_join_cost(
832 left_cost,
833 right_cost,
834 left,
835 right,
836 &[*left_key],
837 &[*right_key],
838 JoinType::Inner,
839 );
840 self.estimate_project_cost(join_cost, output_columns)
841 }
842
843 RirNode::GroupBy {
844 input,
845 key_cols,
846 aggs,
847 } => {
848 let input_cost = self.estimate_cost(input);
849 self.estimate_groupby_cost(input_cost, key_cols, aggs)
850 }
851
852 RirNode::Union { inputs } => {
853 let costs: Vec<_> = inputs.iter().map(|i| self.estimate_cost(i)).collect();
854 self.estimate_union_cost(costs)
855 }
856
857 RirNode::Distinct { input, key_cols } => {
858 let input_cost = self.estimate_cost(input);
859 self.estimate_distinct_cost(input_cost, key_cols)
860 }
861
862 RirNode::Diff { left, right } => {
863 let left_cost = self.estimate_cost(left);
864 let right_cost = self.estimate_cost(right);
865 self.estimate_diff_cost(left_cost, right_cost)
866 }
867
868 RirNode::Fixpoint {
869 base, recursive, ..
870 } => {
871 let base_cost = self.estimate_cost(base);
872 let recursive_cost = self.estimate_cost(recursive);
873 self.estimate_fixpoint_cost(base_cost, recursive_cost)
874 }
875
876 RirNode::TensorMaskedJoin {
877 max_active_rules, ..
878 } => PlanCost {
879 rows: *max_active_rules as u64,
880 cpu_cost: *max_active_rules as f64 * 100.0,
881 gpu_mem: *max_active_rules as u64 * 1024,
882 transfers: 1,
883 },
884 RirNode::MultiWayJoin { inputs, .. } => {
889 let mut total = PlanCost::default();
890 for inp in inputs {
891 let c = self.estimate_cost(inp);
892 total.rows = total.rows.saturating_add(c.rows);
893 total.cpu_cost += c.cpu_cost;
894 total.gpu_mem = total.gpu_mem.saturating_add(c.gpu_mem);
895 total.transfers = total.transfers.saturating_add(c.transfers);
896 }
897 total
898 }
899 }
900 }
901
902 fn estimate_scan_cost(&self, rel: RelId) -> PlanCost {
904 if let Some(stats) = self.stats.get_relation_stats(rel) {
905 PlanCost {
906 rows: stats.cardinality,
907 cpu_cost: stats.cardinality as f64 * 0.01, gpu_mem: stats
909 .byte_size
910 .max(stats.cardinality * self.config.default_bytes_per_row),
911 transfers: 0, }
913 } else {
914 let default_rows = 1000;
916 PlanCost {
917 rows: default_rows,
918 cpu_cost: default_rows as f64 * 0.01,
919 gpu_mem: default_rows * self.config.default_bytes_per_row,
920 transfers: 0,
921 }
922 }
923 }
924
925 fn estimate_filter_cost(
927 &self,
928 input_cost: PlanCost,
929 predicate: &Expr,
930 input: &RirNode,
931 ) -> PlanCost {
932 let selectivity = self.estimate_predicate_selectivity(predicate, input);
933 let output_rows = ((input_cost.rows as f64 * selectivity) as u64).max(1);
934
935 PlanCost {
936 rows: output_rows,
937 cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.02, gpu_mem: input_cost.gpu_mem, transfers: input_cost.transfers,
940 }
941 }
942
943 fn estimate_project_cost(
945 &self,
946 input_cost: PlanCost,
947 columns: &[xlog_ir::ProjectExpr],
948 ) -> PlanCost {
949 let computed_count = columns
951 .iter()
952 .filter(|c| matches!(c, xlog_ir::ProjectExpr::Computed(_, _)))
953 .count();
954
955 let compute_cost = computed_count as f64 * input_cost.rows as f64 * 0.05;
957
958 let output_width_ratio = columns.len() as f64 / (columns.len() + 2) as f64; PlanCost {
962 rows: input_cost.rows,
963 cpu_cost: input_cost.cpu_cost + compute_cost,
964 gpu_mem: (input_cost.gpu_mem as f64 * output_width_ratio) as u64,
965 transfers: input_cost.transfers,
966 }
967 }
968
969 #[allow(clippy::too_many_arguments)]
971 fn estimate_join_cost(
972 &self,
973 left_cost: PlanCost,
974 right_cost: PlanCost,
975 left: &RirNode,
976 right: &RirNode,
977 left_keys: &[usize],
978 right_keys: &[usize],
979 join_type: JoinType,
980 ) -> PlanCost {
981 let output_rows = match join_type {
984 JoinType::Semi => {
985 ((left_cost.rows as f64 * 0.5) as u64).max(1)
987 }
988 JoinType::Anti => {
989 ((left_cost.rows as f64 * 0.5) as u64).max(1)
991 }
992 JoinType::Inner | JoinType::LeftOuter => {
993 let left_rels = left.referenced_relations();
995 let right_rels = right.referenced_relations();
996
997 if left_rels.len() == 1 && right_rels.len() == 1 {
998 let estimated = self.stats.estimate_join_cardinality(
1000 left_rels[0],
1001 right_rels[0],
1002 left_keys,
1003 right_keys,
1004 );
1005
1006 match join_type {
1007 JoinType::LeftOuter => estimated.max(left_cost.rows),
1008 _ => estimated,
1009 }
1010 } else {
1011 match join_type {
1013 JoinType::Inner => {
1014 ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64).max(1)
1016 }
1017 JoinType::LeftOuter => {
1018 left_cost.rows.max(
1020 ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64)
1021 .max(1),
1022 )
1023 }
1024 _ => unreachable!(),
1025 }
1026 }
1027 }
1028 };
1029
1030 let build_cost = right_cost.rows as f64 * 1.0; let probe_cost = left_cost.rows as f64 * 0.5; let cpu_cost = left_cost.cpu_cost + right_cost.cpu_cost + build_cost + probe_cost;
1034
1035 let hash_table_overhead = right_cost.gpu_mem / 2; let gpu_mem = left_cost.gpu_mem + right_cost.gpu_mem + hash_table_overhead;
1038
1039 PlanCost {
1040 rows: output_rows,
1041 cpu_cost,
1042 gpu_mem,
1043 transfers: left_cost.transfers + right_cost.transfers,
1044 }
1045 }
1046
1047 fn estimate_groupby_cost(
1049 &self,
1050 input_cost: PlanCost,
1051 key_cols: &[usize],
1052 _aggs: &[(usize, xlog_core::AggOp)],
1053 ) -> PlanCost {
1054 let estimated_groups = if key_cols.is_empty() {
1057 1 } else {
1059 ((input_cost.rows as f64).sqrt() as u64).max(1)
1061 };
1062
1063 PlanCost {
1064 rows: estimated_groups,
1065 cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.5, gpu_mem: input_cost.gpu_mem + estimated_groups * self.config.default_bytes_per_row,
1067 transfers: input_cost.transfers,
1068 }
1069 }
1070
1071 fn estimate_union_cost(&self, input_costs: Vec<PlanCost>) -> PlanCost {
1073 let total_rows: u64 = input_costs.iter().map(|c| c.rows).sum();
1074 let total_cpu: f64 = input_costs.iter().map(|c| c.cpu_cost).sum();
1075 let max_gpu: u64 = input_costs.iter().map(|c| c.gpu_mem).max().unwrap_or(0);
1076 let total_transfers: u32 = input_costs.iter().map(|c| c.transfers).sum();
1077
1078 PlanCost {
1079 rows: total_rows,
1080 cpu_cost: total_cpu + total_rows as f64 * 0.01, gpu_mem: max_gpu, transfers: total_transfers,
1083 }
1084 }
1085
1086 fn estimate_distinct_cost(&self, input_cost: PlanCost, _key_cols: &[usize]) -> PlanCost {
1088 let estimated_distinct = (input_cost.rows as f64 * 0.7) as u64;
1090
1091 PlanCost {
1092 rows: estimated_distinct.max(1),
1093 cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.3, gpu_mem: input_cost.gpu_mem + input_cost.rows * 8, transfers: input_cost.transfers,
1096 }
1097 }
1098
1099 fn estimate_diff_cost(&self, left_cost: PlanCost, right_cost: PlanCost) -> PlanCost {
1101 let estimated_remaining = (left_cost.rows as f64 * 0.5) as u64;
1103
1104 PlanCost {
1105 rows: estimated_remaining.max(1),
1106 cpu_cost: left_cost.cpu_cost + right_cost.cpu_cost + right_cost.rows as f64 * 0.5,
1107 gpu_mem: left_cost.gpu_mem + right_cost.gpu_mem,
1108 transfers: left_cost.transfers + right_cost.transfers,
1109 }
1110 }
1111
1112 fn estimate_fixpoint_cost(&self, base_cost: PlanCost, recursive_cost: PlanCost) -> PlanCost {
1114 let estimated_iterations = ((base_cost.rows as f64).log2().ceil() as u64).max(1);
1117
1118 PlanCost {
1119 rows: base_cost.rows * estimated_iterations, cpu_cost: base_cost.cpu_cost + recursive_cost.cpu_cost * estimated_iterations as f64,
1121 gpu_mem: (base_cost.gpu_mem + recursive_cost.gpu_mem) * 2, transfers: base_cost.transfers + recursive_cost.transfers * estimated_iterations as u32,
1123 }
1124 }
1125
1126 fn estimate_predicate_selectivity(&self, predicate: &Expr, input: &RirNode) -> f64 {
1128 match predicate {
1129 Expr::Compare { left, op, right } => {
1130 self.estimate_compare_selectivity(left, *op, right, input)
1131 }
1132 Expr::And(exprs) => {
1133 exprs
1135 .iter()
1136 .map(|e| self.estimate_predicate_selectivity(e, input))
1137 .product()
1138 }
1139 Expr::Or(exprs) => {
1140 exprs
1143 .iter()
1144 .map(|e| self.estimate_predicate_selectivity(e, input))
1145 .fold(0.0, f64::max)
1146 }
1147 Expr::Not(inner) => 1.0 - self.estimate_predicate_selectivity(inner, input),
1148 _ => self.config.default_filter_selectivity,
1149 }
1150 }
1151
1152 fn estimate_compare_selectivity(
1154 &self,
1155 left: &Expr,
1156 op: CompareOp,
1157 right: &Expr,
1158 input: &RirNode,
1159 ) -> f64 {
1160 if let (Expr::Column(col_idx), Expr::Const(_)) | (Expr::Const(_), Expr::Column(col_idx)) =
1162 (left, right)
1163 {
1164 if let Some(rel_id) = self.find_column_relation(input, *col_idx) {
1166 if let Some(stats) = self.stats.get_relation_stats(rel_id) {
1167 if let Some(col_stats) = stats.get_column(*col_idx) {
1168 return match op {
1169 CompareOp::Eq => col_stats.equality_selectivity(stats.cardinality),
1170 CompareOp::Ne => {
1171 1.0 - col_stats.equality_selectivity(stats.cardinality)
1172 }
1173 CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => {
1174 0.33
1176 }
1177 };
1178 }
1179 }
1180 }
1181 }
1182
1183 match op {
1185 CompareOp::Eq => 0.1, CompareOp::Ne => 0.9, CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => 0.33, }
1189 }
1190
1191 fn find_column_relation(&self, node: &RirNode, col_idx: usize) -> Option<RelId> {
1193 match node {
1194 RirNode::Scan { rel } => Some(*rel),
1195 RirNode::Filter { input, .. } => self.find_column_relation(input, col_idx),
1196 RirNode::Project { input, columns } => {
1197 if col_idx < columns.len() {
1199 if let xlog_ir::ProjectExpr::Column(src_idx) = &columns[col_idx] {
1200 return self.find_column_relation(input, *src_idx);
1201 }
1202 }
1203 None
1204 }
1205 RirNode::Join { left, right, .. } => {
1206 let left_width = self.estimate_width(left);
1207 if col_idx < left_width {
1208 self.find_column_relation(left, col_idx)
1209 } else {
1210 self.find_column_relation(right, col_idx - left_width)
1211 }
1212 }
1213 RirNode::MultiWayJoin { .. } => None,
1219 _ => None, }
1221 }
1222
1223 pub fn recommend_indexes(&self) -> Vec<RelId> {
1228 self.stats.hot_relations(self.config.index_heat_threshold)
1229 }
1230
1231 pub fn should_use_greedy(&self, node: &RirNode) -> bool {
1235 let rels = node.referenced_relations();
1236 let unique_rels: std::collections::HashSet<_> = rels.iter().collect();
1237 unique_rels.len() > self.config.dp_threshold
1238 }
1239}
1240
1241pub mod selectivity_pass {
1253 use std::collections::HashMap;
1301 use xlog_core::RelId;
1302 use xlog_ir::ExecutionPlan;
1303 use xlog_stats::StatsManager;
1304
1305 pub fn run(plan: &mut ExecutionPlan, stats: &StatsManager, rel_ids: &HashMap<String, RelId>) {
1315 let _ = rel_ids;
1321 for rules in plan.rules_by_scc.iter_mut() {
1322 for rule in rules.iter_mut() {
1323 if let Some(rewritten) = super::reorder::try_reorder_triangle(&rule.body, stats) {
1324 rule.body = rewritten;
1325 continue;
1326 }
1327 if let Some(rewritten) = super::reorder::try_reorder_4cycle(&rule.body, stats) {
1328 rule.body = rewritten;
1329 }
1330 }
1331 }
1332 }
1333}
1334
1335pub mod helper_split_pass {
1337 use std::collections::{HashMap, HashSet};
1338
1339 use xlog_core::{RelId, ScalarType, Schema};
1340 use xlog_ir::rir::{HelperSplitSpec, KCliqueVariableOrder};
1341 use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
1342 use xlog_stats::StatsManager;
1343
    /// Skew ratio threshold for this pass.
    /// NOTE(review): not referenced in this section — presumably consulted during
    /// candidate selection to decide when a relation counts as heavily skewed; confirm.
    const HEAVY_SKEW_RATIO: f64 = 10.0;
1345
    /// Description of a helper relation introduced by this pass and returned to the
    /// caller so it can be registered.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct HelperRelationSpec {
        /// Name produced by the caller's allocator; also the head of the inserted helper rule.
        pub name: String,
        /// Relation id produced by the caller's allocator.
        pub rel_id: RelId,
        /// Schema of the helper relation (the schema passed to the allocator).
        pub schema: Schema,
        /// The two source relations the helper is derived from.
        pub source_rels: [RelId; 2],
    }
1358
    /// Key columns of one join step in a linearized rule body.
    /// NOTE(review): presumably equi-join keys pairing `left_keys[i]` with
    /// `right_keys[i]`; populated outside this section — confirm.
    struct JoinStep {
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
    }
1363
    /// A rule body flattened into a chain of scans joined step by step, followed by
    /// a final projection. Built by `linearize_project_body` (not shown here).
    struct LinearBody {
        /// Scanned relations; `Candidate::pair_start` indexes adjacent pairs in this list.
        leaves: Vec<RelId>,
        /// Per-leaf column class ids.
        /// NOTE(review): presumably equivalence classes of join variables; confirm.
        leaf_classes: Vec<Vec<u32>>,
        /// Join steps connecting the leaves.
        /// NOTE(review): presumably `joins[i]` joins the prefix with `leaves[i + 1]`; confirm.
        joins: Vec<JoinStep>,
        /// Final projection of the original body.
        project: Vec<ProjectExpr>,
        /// Column class ids of the joined output before projection (assumed; confirm).
        final_classes: Vec<u32>,
    }
1371
    /// A join tree flattened into its leaves plus collected column equalities.
    /// NOTE(review): constructed and consumed outside this section; field meanings
    /// below are inferred from names — confirm.
    struct FlatJoin {
        /// Scanned relations, in flattening order.
        leaves: Vec<RelId>,
        /// Output column indices of the flattened join.
        output_cols: Vec<usize>,
        /// Pairs of column indices constrained to be equal.
        equalities: Vec<(usize, usize)>,
    }
1377
    /// A chosen pair of adjacent leaves to materialize as a helper relation.
    struct Candidate {
        /// Index of the first leaf of the pair; the pair is
        /// `leaves[pair_start]` and `leaves[pair_start + 1]`.
        pair_start: usize,
        /// Schema handed to the allocator and recorded in the helper spec.
        helper_schema: Schema,
        /// Projection used when building the helper body (assumed; see `build_helper_body`).
        helper_project: Vec<ProjectExpr>,
        /// Join keys for joining the helper back into the outer body (assumed; confirm).
        helper_join_left_keys: Vec<usize>,
        helper_join_right_keys: Vec<usize>,
        /// Column class ids exposed by the helper (assumed; confirm).
        exposed_classes: Vec<u32>,
    }
1386
    /// Result of rewriting one rule.
    struct Rewrite {
        /// Body of the new helper rule, inserted just before the rewritten rule.
        helper_body: RirNode,
        /// Replacement body for the original rule.
        outer_body: RirNode,
        /// Description of the helper relation, returned to the caller.
        spec: HelperRelationSpec,
    }
1392
    /// One edge scan of a k-clique `MultiWayJoin`.
    #[derive(Clone, Copy)]
    struct KCliqueHelperEdge {
        /// Index into the `MultiWayJoin` inputs.
        slot: usize,
        /// Relation scanned at that input.
        rel: RelId,
        /// Endpoint variables of the edge, as returned by `kclique_edge_pair`.
        left: usize,
        right: usize,
    }
1400
1401 pub fn run<F>(
1403 plan: &mut ExecutionPlan,
1404 schemas: &HashMap<RelId, Schema>,
1405 stats: &StatsManager,
1406 mut allocate: F,
1407 ) -> Vec<HelperRelationSpec>
1408 where
1409 F: FnMut(Schema) -> (String, RelId),
1410 {
1411 let mut specs = Vec::new();
1412 for scc_idx in 0..plan.rules_by_scc.len() {
1413 let mut rule_idx = 0;
1414 while rule_idx < plan.rules_by_scc[scc_idx].len() {
1415 let rewrite = {
1416 let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1417 try_rewrite_rule(rule, schemas, stats, &mut allocate)
1418 };
1419 if let Some(rewrite) = rewrite {
1420 let helper_rule = CompiledRule {
1421 head: rewrite.spec.name.clone(),
1422 body: rewrite.helper_body,
1423 meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1424 };
1425 plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1426 rule_idx += 1;
1427 plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1428 add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1429 specs.push(rewrite.spec);
1430 }
1431 rule_idx += 1;
1432 }
1433 }
1434 specs
1435 }
1436
1437 pub fn run_kclique_specs<F>(
1443 plan: &mut ExecutionPlan,
1444 schemas: &HashMap<RelId, Schema>,
1445 mut allocate: F,
1446 ) -> Vec<HelperRelationSpec>
1447 where
1448 F: FnMut(Schema) -> (String, RelId),
1449 {
1450 let mut specs = Vec::new();
1451 for scc_idx in 0..plan.rules_by_scc.len() {
1452 let mut rule_idx = 0;
1453 while rule_idx < plan.rules_by_scc[scc_idx].len() {
1454 let rewrite = {
1455 let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1456 try_rewrite_kclique_rule(rule, schemas, &mut allocate)
1457 };
1458 if let Some(rewrite) = rewrite {
1459 let helper_rule = CompiledRule {
1460 head: rewrite.spec.name.clone(),
1461 body: rewrite.helper_body,
1462 meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1463 };
1464 plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1465 rule_idx += 1;
1466 plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1467 add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1468 specs.push(rewrite.spec);
1469 }
1470 rule_idx += 1;
1471 }
1472 }
1473 specs
1474 }
1475
1476 fn add_helper_to_scc(sccs: &mut [Scc], scc_idx: usize, helper: &str) {
1477 if let Some(scc) = sccs.get_mut(scc_idx) {
1478 if !scc.predicates.iter().any(|p| p == helper) {
1479 scc.predicates.push(helper.to_string());
1480 }
1481 }
1482 }
1483
1484 fn try_rewrite_rule<F>(
1485 rule: &CompiledRule,
1486 schemas: &HashMap<RelId, Schema>,
1487 stats: &StatsManager,
1488 allocate: &mut F,
1489 ) -> Option<Rewrite>
1490 where
1491 F: FnMut(Schema) -> (String, RelId),
1492 {
1493 let linear = linearize_project_body(&rule.body, schemas)?;
1494 let candidate = choose_candidate(&linear, schemas, stats)?;
1495 let (helper_name, helper_rel) = allocate(candidate.helper_schema.clone());
1496 let helper_body = build_helper_body(&linear, &candidate);
1497 let outer_body = build_outer_body(&linear, &candidate, helper_rel)?;
1498 Some(Rewrite {
1499 helper_body,
1500 outer_body,
1501 spec: HelperRelationSpec {
1502 name: helper_name,
1503 rel_id: helper_rel,
1504 schema: candidate.helper_schema,
1505 source_rels: [
1506 linear.leaves[candidate.pair_start],
1507 linear.leaves[candidate.pair_start + 1],
1508 ],
1509 },
1510 })
1511 }
1512
1513 fn try_rewrite_kclique_rule<F>(
1514 rule: &CompiledRule,
1515 schemas: &HashMap<RelId, Schema>,
1516 allocate: &mut F,
1517 ) -> Option<Rewrite>
1518 where
1519 F: FnMut(Schema) -> (String, RelId),
1520 {
1521 let mut outer_body = rule.body.clone();
1522 let RirNode::MultiWayJoin {
1523 inputs, var_order, ..
1524 } = &mut outer_body
1525 else {
1526 return None;
1527 };
1528 let kclique = var_order.as_ref()?.kclique.as_ref()?;
1529 let spec = kclique.helper_split_specs.first()?;
1530 let (hot_left, hot_right, target) = kclique_helper_edges(inputs, kclique, spec)?;
1531 let helper_schema = schemas.get(&target.rel)?.clone();
1532 let (helper_name, helper_rel) = allocate(helper_schema.clone());
1533 let helper_body = build_kclique_helper_body(spec, hot_left, hot_right, target)?;
1534 *inputs.get_mut(target.slot)? = RirNode::Scan { rel: helper_rel };
1535 Some(Rewrite {
1536 helper_body,
1537 outer_body,
1538 spec: HelperRelationSpec {
1539 name: helper_name,
1540 rel_id: helper_rel,
1541 schema: helper_schema,
1542 source_rels: [hot_left.rel, hot_right.rel],
1543 },
1544 })
1545 }
1546
1547 fn kclique_helper_edges(
1548 inputs: &[RirNode],
1549 kclique: &KCliqueVariableOrder,
1550 spec: &HelperSplitSpec,
1551 ) -> Option<(KCliqueHelperEdge, KCliqueHelperEdge, KCliqueHelperEdge)> {
1552 let k = usize::from(kclique.k);
1553 let hot = usize::from(spec.variable);
1554 let mut hot_edges = Vec::new();
1555 let mut target = None;
1556 for &slot in &spec.edge_slots {
1557 let slot = usize::from(slot);
1558 let (left, right) = kclique_edge_pair(slot, k)?;
1559 let RirNode::Scan { rel } = inputs.get(slot)? else {
1560 return None;
1561 };
1562 let edge = KCliqueHelperEdge {
1563 slot,
1564 rel: *rel,
1565 left,
1566 right,
1567 };
1568 if left == hot || right == hot {
1569 hot_edges.push(edge);
1570 } else {
1571 target = Some(edge);
1572 }
1573 }
1574 if hot_edges.len() != 2 {
1575 return None;
1576 }
1577 Some((hot_edges[0], hot_edges[1], target?))
1578 }
1579
1580 fn build_kclique_helper_body(
1581 spec: &HelperSplitSpec,
1582 hot_left: KCliqueHelperEdge,
1583 hot_right: KCliqueHelperEdge,
1584 target: KCliqueHelperEdge,
1585 ) -> Option<RirNode> {
1586 let hot = usize::from(spec.variable);
1587 let target_left = target.left;
1588 let target_right = target.right;
1589 let first_other = kclique_other_endpoint(hot_left, hot)?;
1590 let second_other = kclique_other_endpoint(hot_right, hot)?;
1591 if ![first_other, second_other].contains(&target_left)
1592 || ![first_other, second_other].contains(&target_right)
1593 {
1594 return None;
1595 }
1596
1597 let first_scan = RirNode::Scan { rel: hot_left.rel };
1598 let second_scan = RirNode::Scan { rel: hot_right.rel };
1599 let target_scan = RirNode::Scan { rel: target.rel };
1600 let first_hot_col = kclique_endpoint_col(hot_left, hot)?;
1601 let second_hot_col = kclique_endpoint_col(hot_right, hot)?;
1602 let first_other_col = kclique_endpoint_col(hot_left, first_other)?;
1603 let second_other_col = 2 + kclique_endpoint_col(hot_right, second_other)?;
1604
1605 let target_left_in_join = if first_other == target_left {
1606 first_other_col
1607 } else {
1608 second_other_col
1609 };
1610 let target_right_in_join = if first_other == target_right {
1611 first_other_col
1612 } else {
1613 second_other_col
1614 };
1615 let target_left_col = kclique_endpoint_col(target, target_left)?;
1616 let target_right_col = kclique_endpoint_col(target, target_right)?;
1617
1618 let hot_join = RirNode::Join {
1619 left: Box::new(first_scan),
1620 right: Box::new(second_scan),
1621 left_keys: vec![first_hot_col],
1622 right_keys: vec![second_hot_col],
1623 join_type: JoinType::Inner,
1624 };
1625 let helper_join = RirNode::Join {
1626 left: Box::new(hot_join),
1627 right: Box::new(target_scan),
1628 left_keys: vec![target_left_in_join, target_right_in_join],
1629 right_keys: vec![target_left_col, target_right_col],
1630 join_type: JoinType::Inner,
1631 };
1632 Some(RirNode::Project {
1633 input: Box::new(helper_join),
1634 columns: vec![ProjectExpr::Column(4), ProjectExpr::Column(5)],
1635 })
1636 }
1637
1638 fn kclique_edge_pair(edge_idx: usize, k: usize) -> Option<(usize, usize)> {
1639 let mut idx = 0usize;
1640 for left in 0..k {
1641 for right in (left + 1)..k {
1642 if idx == edge_idx {
1643 return Some((left, right));
1644 }
1645 idx += 1;
1646 }
1647 }
1648 None
1649 }
1650
1651 fn kclique_endpoint_col(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1652 if edge.left == variable {
1653 Some(0)
1654 } else if edge.right == variable {
1655 Some(1)
1656 } else {
1657 None
1658 }
1659 }
1660
1661 fn kclique_other_endpoint(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1662 if edge.left == variable {
1663 Some(edge.right)
1664 } else if edge.right == variable {
1665 Some(edge.left)
1666 } else {
1667 None
1668 }
1669 }
1670
1671 fn linearize_project_body(
1672 body: &RirNode,
1673 schemas: &HashMap<RelId, Schema>,
1674 ) -> Option<LinearBody> {
1675 let RirNode::Project { input, columns } = body else {
1676 return None;
1677 };
1678 let flat = collect_join_graph(input, schemas)?;
1679 if flat.leaves.len() < 6 {
1680 return None;
1681 }
1682 let mut offsets = Vec::with_capacity(flat.leaves.len());
1683 let mut total_cols = 0usize;
1684 for rel in &flat.leaves {
1685 offsets.push(total_cols);
1686 total_cols += schemas.get(rel)?.arity();
1687 }
1688 let mut uf = UnionFind::new(total_cols);
1689 for (left, right) in flat.equalities {
1690 if left >= total_cols || right >= total_cols {
1691 return None;
1692 }
1693 uf.union(left, right);
1694 }
1695 let mut leaf_classes: Vec<Vec<u32>> = Vec::with_capacity(flat.leaves.len());
1696 for (leaf_idx, rel) in flat.leaves.iter().enumerate() {
1697 let arity = schemas.get(rel)?.arity();
1698 let offset = offsets[leaf_idx];
1699 leaf_classes.push((0..arity).map(|col| uf.find(offset + col) as u32).collect());
1700 }
1701 let final_classes = flat
1702 .output_cols
1703 .iter()
1704 .map(|col| uf.find(*col) as u32)
1705 .collect();
1706 let joins = derive_left_deep_steps(&leaf_classes)?;
1707 Some(LinearBody {
1708 leaves: flat.leaves,
1709 leaf_classes,
1710 joins,
1711 project: columns.clone(),
1712 final_classes,
1713 })
1714 }
1715
1716 fn collect_join_graph(node: &RirNode, schemas: &HashMap<RelId, Schema>) -> Option<FlatJoin> {
1717 match node {
1718 RirNode::Scan { rel } => Some(FlatJoin {
1719 leaves: vec![*rel],
1720 output_cols: (0..schemas.get(rel)?.arity()).collect(),
1721 equalities: Vec::new(),
1722 }),
1723 RirNode::Join {
1724 left,
1725 right,
1726 left_keys,
1727 right_keys,
1728 join_type,
1729 } if *join_type == JoinType::Inner => {
1730 let left_flat = collect_join_graph(left, schemas)?;
1731 let right_flat = collect_join_graph(right, schemas)?;
1732 if left_keys.len() != right_keys.len() {
1733 return None;
1734 }
1735 let right_shift = total_width(&left_flat.leaves, schemas)?;
1736 let mut leaves = left_flat.leaves;
1737 leaves.extend(right_flat.leaves);
1738 let right_output_cols: Vec<usize> = right_flat
1739 .output_cols
1740 .iter()
1741 .map(|col| col + right_shift)
1742 .collect();
1743 let mut equalities = left_flat.equalities;
1744 equalities.extend(
1745 right_flat
1746 .equalities
1747 .iter()
1748 .map(|(left, right)| (left + right_shift, right + right_shift)),
1749 );
1750 for (&left_key, &right_key) in left_keys.iter().zip(right_keys.iter()) {
1751 equalities.push((
1752 *left_flat.output_cols.get(left_key)?,
1753 *right_output_cols.get(right_key)?,
1754 ));
1755 }
1756 let mut output_cols = left_flat.output_cols;
1757 output_cols.extend(right_output_cols);
1758 Some(FlatJoin {
1759 leaves,
1760 output_cols,
1761 equalities,
1762 })
1763 }
1764 _ => None,
1765 }
1766 }
1767
1768 fn total_width(leaves: &[RelId], schemas: &HashMap<RelId, Schema>) -> Option<usize> {
1769 leaves
1770 .iter()
1771 .map(|rel| schemas.get(rel).map(Schema::arity))
1772 .try_fold(0usize, |acc, width| width.map(|width| acc + width))
1773 }
1774
1775 fn derive_left_deep_steps(leaf_classes: &[Vec<u32>]) -> Option<Vec<JoinStep>> {
1776 let mut joins = Vec::with_capacity(leaf_classes.len().saturating_sub(1));
1777 let mut current = leaf_classes.first()?.clone();
1778 for classes in leaf_classes.iter().skip(1) {
1779 let mut left_keys = Vec::new();
1780 let mut right_keys = Vec::new();
1781 for (right_col, class) in classes.iter().enumerate() {
1782 if let Some(left_col) = current
1783 .iter()
1784 .position(|current_class| current_class == class)
1785 {
1786 left_keys.push(left_col);
1787 right_keys.push(right_col);
1788 }
1789 }
1790 if left_keys.is_empty() {
1791 return None;
1792 }
1793 joins.push(JoinStep {
1794 left_keys,
1795 right_keys,
1796 });
1797 current.extend(classes.iter().copied());
1798 }
1799 Some(joins)
1800 }
1801
1802 fn choose_candidate(
1803 linear: &LinearBody,
1804 schemas: &HashMap<RelId, Schema>,
1805 stats: &StatsManager,
1806 ) -> Option<Candidate> {
1807 for pair_start in 3..linear.leaves.len().saturating_sub(1) {
1808 let candidate = build_candidate(linear, schemas, pair_start)?;
1809 if skew_ratio_for_candidate(linear, stats, &candidate) >= HEAVY_SKEW_RATIO {
1810 return Some(candidate);
1811 }
1812 }
1813 None
1814 }
1815
    /// Builds the helper-split candidate for the adjacent leaves
    /// `pair_start` and `pair_start + 1`.
    ///
    /// The helper joins the two leaves on the keys of the step that originally
    /// attached leaf `pair_start + 1`. It projects every "exposed" class: one
    /// that is not an internal join class, is needed by another leaf or by the
    /// final projection, and has not already been exposed.
    ///
    /// Returns `None` when:
    /// - a schema or join step is missing;
    /// - an internal key's class does not occur in leaf `pair_start`;
    /// - the final projection contains a non-column expression;
    /// - the pair does not expose exactly two classes.
    fn build_candidate(
        linear: &LinearBody,
        schemas: &HashMap<RelId, Schema>,
        pair_start: usize,
    ) -> Option<Candidate> {
        let left_rel = linear.leaves[pair_start];
        let right_rel = linear.leaves[pair_start + 1];
        let left_schema = schemas.get(&left_rel)?;
        let right_schema = schemas.get(&right_rel)?;
        // joins[pair_start] is the step that attached leaf pair_start + 1.
        let internal_step = linear.joins.get(pair_start)?;
        let mut helper_left_keys = Vec::new();
        let mut helper_right_keys = Vec::new();
        for (&left_key, &right_key) in internal_step
            .left_keys
            .iter()
            .zip(internal_step.right_keys.iter())
        {
            // Left keys index the whole prefix; re-express each one as a column
            // of leaf pair_start alone (fails if the class is not in that leaf).
            let class = class_at_state(linear, pair_start + 1, left_key)?;
            let left_col = linear.leaf_classes[pair_start]
                .iter()
                .position(|c| *c == class)?;
            helper_left_keys.push(left_col);
            helper_right_keys.push(right_key);
        }
        // Classes consumed by the helper's own join; they are never exposed.
        let internal: HashSet<u32> = helper_left_keys
            .iter()
            .map(|col| linear.leaf_classes[pair_start][*col])
            .collect();
        let outside = outside_classes(linear, pair_start);
        let output = projected_classes(linear)?;
        let mut exposed_classes = Vec::new();
        let mut helper_project = Vec::new();
        let mut helper_columns = Vec::new();
        for (col, class) in linear.leaf_classes[pair_start].iter().copied().enumerate() {
            if !internal.contains(&class)
                && (outside.contains(&class) || output.contains(&class))
                && !exposed_classes.contains(&class)
            {
                exposed_classes.push(class);
                helper_project.push(ProjectExpr::Column(col));
                // Falls back to U32 when the source schema reports no type.
                let ty = left_schema.column_type(col).unwrap_or(ScalarType::U32);
                helper_columns.push((format!("c{}", helper_columns.len()), ty));
            }
        }
        // In the helper join's output, the right leaf's columns follow the left leaf's.
        let right_offset = left_schema.arity();
        for (col, class) in linear.leaf_classes[pair_start + 1]
            .iter()
            .copied()
            .enumerate()
        {
            if !internal.contains(&class)
                && (outside.contains(&class) || output.contains(&class))
                && !exposed_classes.contains(&class)
            {
                exposed_classes.push(class);
                helper_project.push(ProjectExpr::Column(right_offset + col));
                let ty = right_schema.column_type(col).unwrap_or(ScalarType::U32);
                helper_columns.push((format!("c{}", helper_columns.len()), ty));
            }
        }
        // Only binary helper relations are produced.
        if exposed_classes.len() != 2 {
            return None;
        }
        Some(Candidate {
            pair_start,
            helper_schema: Schema::new(helper_columns),
            helper_project,
            helper_join_left_keys: helper_left_keys,
            helper_join_right_keys: helper_right_keys,
            exposed_classes,
        })
    }
1888
1889 fn class_at_state(linear: &LinearBody, leaf_count: usize, col: usize) -> Option<u32> {
1890 let mut idx = col;
1891 for leaf_idx in 0..leaf_count {
1892 let classes = &linear.leaf_classes[leaf_idx];
1893 if idx < classes.len() {
1894 return Some(classes[idx]);
1895 }
1896 idx -= classes.len();
1897 }
1898 None
1899 }
1900
1901 fn outside_classes(linear: &LinearBody, pair_start: usize) -> HashSet<u32> {
1902 linear
1903 .leaf_classes
1904 .iter()
1905 .enumerate()
1906 .filter(|(idx, _)| *idx != pair_start && *idx != pair_start + 1)
1907 .flat_map(|(_, classes)| classes.iter().copied())
1908 .collect()
1909 }
1910
1911 fn projected_classes(linear: &LinearBody) -> Option<HashSet<u32>> {
1912 let mut out = HashSet::new();
1913 for expr in &linear.project {
1914 let ProjectExpr::Column(col) = expr else {
1915 return None;
1916 };
1917 out.insert(*linear.final_classes.get(*col)?);
1918 }
1919 Some(out)
1920 }
1921
1922 fn skew_ratio_for_candidate(
1923 linear: &LinearBody,
1924 stats: &StatsManager,
1925 candidate: &Candidate,
1926 ) -> f64 {
1927 let rel = linear.leaves[candidate.pair_start];
1928 let Some(rel_stats) = stats.get_relation_stats(rel) else {
1929 return 0.0;
1930 };
1931 let mut ratio: f64 = 0.0;
1932 for (col, class) in linear.leaf_classes[candidate.pair_start]
1933 .iter()
1934 .copied()
1935 .enumerate()
1936 {
1937 if !candidate.exposed_classes.contains(&class) {
1938 continue;
1939 }
1940 let Some(col_stats) = rel_stats.get_column(col) else {
1941 continue;
1942 };
1943 if col_stats.distinct_estimate == 0 {
1944 continue;
1945 }
1946 ratio = ratio.max(rel_stats.cardinality as f64 / col_stats.distinct_estimate as f64);
1947 }
1948 ratio
1949 }
1950
1951 fn build_helper_body(linear: &LinearBody, candidate: &Candidate) -> RirNode {
1952 let left = RirNode::Scan {
1953 rel: linear.leaves[candidate.pair_start],
1954 };
1955 let right = RirNode::Scan {
1956 rel: linear.leaves[candidate.pair_start + 1],
1957 };
1958 RirNode::Project {
1959 input: Box::new(RirNode::Join {
1960 left: Box::new(left),
1961 right: Box::new(right),
1962 left_keys: candidate.helper_join_left_keys.clone(),
1963 right_keys: candidate.helper_join_right_keys.clone(),
1964 join_type: JoinType::Inner,
1965 }),
1966 columns: candidate.helper_project.clone(),
1967 }
1968 }
1969
    /// Rebuilds the rule body with the leaf pair at `candidate.pair_start`
    /// replaced by a scan of `helper_rel`.
    ///
    /// `classes` tracks the class of each output column of the node built so
    /// far. Keys and projections are remapped through it by class (first
    /// occurrence).
    ///
    /// Returns `None` when:
    /// - a needed class cannot be found in the new layout (e.g. a class the
    ///   helper consumed internally);
    /// - the final projection contains a non-column expression.
    ///
    /// Assumes `candidate.pair_start >= 1`, because `joins[pair_start - 1]` is
    /// indexed; `choose_candidate` only proposes pairs starting at 3.
    fn build_outer_body(
        linear: &LinearBody,
        candidate: &Candidate,
        helper_rel: RelId,
    ) -> Option<RirNode> {
        let mut node = RirNode::Scan {
            rel: linear.leaves[0],
        };
        let mut classes = linear.leaf_classes[0].clone();
        // Prefix leaves keep their original join steps unchanged: the column
        // layout up to this point is identical to the original plan.
        for leaf_idx in 1..candidate.pair_start {
            let step = &linear.joins[leaf_idx - 1];
            node = RirNode::Join {
                left: Box::new(node),
                right: Box::new(RirNode::Scan {
                    rel: linear.leaves[leaf_idx],
                }),
                left_keys: step.left_keys.clone(),
                right_keys: step.right_keys.clone(),
                join_type: JoinType::Inner,
            };
            classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
        }
        // Attach the helper using the step that originally attached leaf
        // pair_start. Its right keys are remapped from that leaf's columns to
        // helper columns, which follow `exposed_classes` order.
        let prefix_step = &linear.joins[candidate.pair_start - 1];
        let mut helper_right_keys = Vec::new();
        for &rk in &prefix_step.right_keys {
            let class = linear.leaf_classes[candidate.pair_start][rk];
            helper_right_keys.push(candidate.exposed_classes.iter().position(|c| *c == class)?);
        }
        node = RirNode::Join {
            left: Box::new(node),
            right: Box::new(RirNode::Scan { rel: helper_rel }),
            left_keys: prefix_step.left_keys.clone(),
            right_keys: helper_right_keys,
            join_type: JoinType::Inner,
        };
        classes.extend(candidate.exposed_classes.iter().copied());
        // Suffix leaves: their original left keys index the old layout, so each
        // key is mapped to a class and then to its position in the new layout.
        for leaf_idx in candidate.pair_start + 2..linear.leaves.len() {
            let step = &linear.joins[leaf_idx - 1];
            let mut left_keys = Vec::new();
            for &lk in &step.left_keys {
                let class = class_at_state(linear, leaf_idx, lk)?;
                left_keys.push(classes.iter().position(|c| *c == class)?);
            }
            node = RirNode::Join {
                left: Box::new(node),
                right: Box::new(RirNode::Scan {
                    rel: linear.leaves[leaf_idx],
                }),
                left_keys,
                right_keys: step.right_keys.clone(),
                join_type: JoinType::Inner,
            };
            classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
        }
        // Re-express the original projection in terms of the new layout.
        let mut project = Vec::with_capacity(linear.project.len());
        for expr in &linear.project {
            let ProjectExpr::Column(col) = expr else {
                return None;
            };
            let class = *linear.final_classes.get(*col)?;
            let mapped = classes.iter().position(|c| *c == class)?;
            project.push(ProjectExpr::Column(mapped));
        }
        Some(RirNode::Project {
            input: Box::new(node),
            columns: project,
        })
    }
2038
2039 struct UnionFind {
2040 parent: Vec<usize>,
2041 }
2042
2043 impl UnionFind {
2044 fn new(len: usize) -> Self {
2045 Self {
2046 parent: (0..len).collect(),
2047 }
2048 }
2049
2050 fn find(&mut self, x: usize) -> usize {
2051 let p = self.parent[x];
2052 if p == x {
2053 x
2054 } else {
2055 let root = self.find(p);
2056 self.parent[x] = root;
2057 root
2058 }
2059 }
2060
2061 fn union(&mut self, a: usize, b: usize) {
2062 let ra = self.find(a);
2063 let rb = self.find(b);
2064 if ra != rb {
2065 self.parent[rb] = ra;
2066 }
2067 }
2068 }
2069}
2070
// The `stream_schedule_pass` submodule. The `#[path]` attribute loads its
// source from `optimizer/stream_schedule_pass.rs`.
#[path = "optimizer/stream_schedule_pass.rs"]
pub mod stream_schedule_pass;
2073
// Unit tests for `helper_split_pass::run` on a small left-deep join fixture.
#[cfg(test)]
mod helper_split_pass_tests {
    use std::collections::HashMap;

    use super::helper_split_pass;
    use xlog_core::{RelId, ScalarType, Schema};
    use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
    use xlog_stats::{ColumnStats, StatsManager};

    // Binary U32 schema shared by every fixture relation.
    fn edge_schema() -> Schema {
        Schema::new(vec![
            ("c0".to_string(), ScalarType::U32),
            ("c1".to_string(), ScalarType::U32),
        ])
    }

    // Schema the helper relation is expected to get: two columns named c0, c1.
    fn helper_schema() -> Schema {
        Schema::new(vec![
            ("c0".to_string(), ScalarType::U32),
            ("c1".to_string(), ScalarType::U32),
        ])
    }

    // Schemas for RelId(0)..RelId(5), all binary.
    fn schemas() -> HashMap<RelId, Schema> {
        (0..6)
            .map(|idx| (RelId(idx), edge_schema()))
            .collect::<HashMap<_, _>>()
    }

    // Left-deep body for the 6-cycle a-b-c-d-e-f-a:
    // R0(a,b) ⋈ R1(b,c) ⋈ R2(c,d) ⋈ R3(d,e) ⋈ R4(e,f) ⋈ R5(a,f),
    // projected to (a, b, c, d, f); `e` is not in the output.
    fn left_deep_fixture_body() -> RirNode {
        // b: R0.col1 = R1.col0
        let ab_bc = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: RelId(0) }),
            right: Box::new(RirNode::Scan { rel: RelId(1) }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // c: prefix col 3 (R1.col1) = R2.col0
        let with_cd = RirNode::Join {
            left: Box::new(ab_bc),
            right: Box::new(RirNode::Scan { rel: RelId(2) }),
            left_keys: vec![3],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // d: prefix col 5 (R2.col1) = R3.col0
        let with_de = RirNode::Join {
            left: Box::new(with_cd),
            right: Box::new(RirNode::Scan { rel: RelId(3) }),
            left_keys: vec![5],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // e: prefix col 7 (R3.col1) = R4.col0
        let with_ef = RirNode::Join {
            left: Box::new(with_de),
            right: Box::new(RirNode::Scan { rel: RelId(4) }),
            left_keys: vec![7],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // a and f close the cycle: prefix cols 0 (a) and 9 (f) = R5.(col0, col1)
        let with_af = RirNode::Join {
            left: Box::new(with_ef),
            right: Box::new(RirNode::Scan { rel: RelId(5) }),
            left_keys: vec![0, 9],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        RirNode::Project {
            input: Box::new(with_af),
            columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
                ProjectExpr::Column(5),
                ProjectExpr::Column(9),
            ],
        }
    }

    // A single non-recursive SCC holding one rule, `out`, over the fixture body.
    fn plan() -> ExecutionPlan {
        ExecutionPlan {
            sccs: vec![Scc {
                id: 0,
                predicates: vec!["out".to_string()],
                is_recursive: false,
            }],
            strata: vec![],
            rules_by_scc: vec![vec![CompiledRule {
                head: "out".to_string(),
                body: left_deep_fixture_body(),
                meta: RirMeta::with_schema(Schema::new(vec![
                    ("a".to_string(), ScalarType::U32),
                    ("b".to_string(), ScalarType::U32),
                    ("c".to_string(), ScalarType::U32),
                    ("d".to_string(), ScalarType::U32),
                    ("f".to_string(), ScalarType::U32),
                ])),
            }]],
            est_memory_peak: 0,
        }
    }

    // Statistics with cardinality 8192 for every relation and a distinct count
    // of `distinct_d` for R3's column 0 (variable d). The pass's skew ratio for
    // R3 is then 8192 / distinct_d (presumably compared against the module's
    // HEAVY_SKEW_RATIO).
    fn stats_for_de(distinct_d: u64) -> StatsManager {
        let mut stats = StatsManager::new();
        for idx in 0..6 {
            stats.register_relation(RelId(idx));
            stats.update_cardinality(RelId(idx), 8192);
        }
        let mut d_col = ColumnStats::new(0, ScalarType::U32);
        d_col.update_distinct(distinct_d);
        stats.add_column_stats(RelId(3), d_col);
        stats
    }

    // True when `node` scans `rel` anywhere in its subtree. A TensorMaskedJoin
    // counts if `rel` appears in its `rel_index`. The match is exhaustive, so
    // new RirNode variants must be handled here.
    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
        match node {
            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
                contains_scan(left, rel) || contains_scan(right, rel)
            }
            RirNode::Project { input, .. }
            | RirNode::Filter { input, .. }
            | RirNode::Distinct { input, .. }
            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
            RirNode::Fixpoint {
                base, recursive, ..
            } => contains_scan(base, rel) || contains_scan(recursive, rel),
            RirNode::MultiWayJoin { inputs, .. } => {
                inputs.iter().any(|input| contains_scan(input, rel))
            }
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
            }
            RirNode::Unit => false,
        }
    }

    // With d maximally skewed (one distinct value), the buried pair (R3, R4) is
    // folded into a helper rule. The helper rule is inserted before `out`, the
    // rewritten `out` scans it, and it joins the SCC.
    #[test]
    fn helper_split_extracts_buried_pair() {
        let mut plan = plan();
        let schemas = schemas();
        let stats = stats_for_de(1);
        let specs = helper_split_pass::run(&mut plan, &schemas, &stats, |_| {
            ("__w37_helper_6".to_string(), RelId(6))
        });

        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].name, "__w37_helper_6");
        assert_eq!(specs[0].rel_id, RelId(6));
        assert_eq!(specs[0].schema, helper_schema());
        assert_eq!(specs[0].source_rels, [RelId(3), RelId(4)]);
        assert_eq!(plan.rules_by_scc[0].len(), 2);
        assert_eq!(plan.rules_by_scc[0][0].head, "__w37_helper_6");
        assert_eq!(plan.rules_by_scc[0][1].head, "out");
        assert!(contains_scan(&plan.rules_by_scc[0][1].body, RelId(6)));
        assert!(plan.sccs[0]
            .predicates
            .iter()
            .any(|predicate| predicate == "__w37_helper_6"));
    }

    // With d uniformly distributed (ratio 1.0), the plan is left untouched.
    #[test]
    fn helper_split_ignores_flat_distribution() {
        let mut plan = plan();
        let schemas = schemas();
        let stats = stats_for_de(8192);
        let specs = helper_split_pass::run(&mut plan, &schemas, &stats, |_| {
            ("__w37_helper_6".to_string(), RelId(6))
        });

        assert!(specs.is_empty());
        assert_eq!(plan.rules_by_scc[0].len(), 1);
        assert!(!contains_scan(&plan.rules_by_scc[0][0].body, RelId(6)));
    }
}
2249
2250mod reorder {
2254 use std::collections::HashMap;
2255 use xlog_core::RelId;
2256 use xlog_ir::rir::ProjectExpr;
2257 use xlog_ir::{JoinType, RirNode};
2258 use xlog_stats::StatsManager;
2259
2260 fn ac3(atom: u8, col: u8) -> u8 {
2261 atom * 2 + col
2262 }
2263 fn ac4(atom: u8, col: u8) -> u8 {
2264 atom * 2 + col
2265 }
2266 fn uf_find_n<const N: usize>(parent: &mut [u8; N], x: u8) -> u8 {
2267 let mut root = x;
2268 while parent[root as usize] != root {
2269 root = parent[root as usize];
2270 }
2271 let mut cur = x;
2272 while parent[cur as usize] != root {
2273 let next = parent[cur as usize];
2274 parent[cur as usize] = root;
2275 cur = next;
2276 }
2277 root
2278 }
2279 fn uf_union_n<const N: usize>(parent: &mut [u8; N], a: u8, b: u8) {
2280 let ra = uf_find_n(parent, a);
2281 let rb = uf_find_n(parent, b);
2282 if ra != rb {
2283 parent[rb as usize] = ra;
2284 }
2285 }
2286
2287 fn populated_card(stats: &StatsManager, rel: RelId) -> Option<u64> {
2288 stats
2289 .get_relation_stats(rel)
2290 .map(|s| s.cardinality)
2291 .filter(|c| *c > 0)
2292 }
2293
    /// The three relations of a triangle body over head variables `(x, y, z)`.
    ///
    /// `build_triangle_body` reads each relation's two columns in the order
    /// given by its field name (e.g. `rel_xy` as `(x, y)`).
    struct TriangleSemantics {
        // Relation binding (x, y).
        rel_xy: RelId,
        // Relation binding (y, z).
        rel_yz: RelId,
        // Relation binding (x, z).
        rel_xz: RelId,
    }
2303
2304 fn match_and_infer_triangle(body: &RirNode) -> Option<TriangleSemantics> {
2305 let RirNode::Project {
2306 input: outer_input,
2307 columns,
2308 } = body
2309 else {
2310 return None;
2311 };
2312 let RirNode::Join {
2313 left: l1,
2314 right: r1,
2315 left_keys: lk1,
2316 right_keys: rk1,
2317 join_type: jt1,
2318 } = outer_input.as_ref()
2319 else {
2320 return None;
2321 };
2322 if !matches!(jt1, JoinType::Inner) {
2323 return None;
2324 }
2325 let RirNode::Scan { rel: rel_third } = r1.as_ref() else {
2326 return None;
2327 };
2328 let RirNode::Join {
2329 left: l2,
2330 right: r2,
2331 left_keys: lk2,
2332 right_keys: rk2,
2333 join_type: jt2,
2334 } = l1.as_ref()
2335 else {
2336 return None;
2337 };
2338 if !matches!(jt2, JoinType::Inner) {
2339 return None;
2340 }
2341 let RirNode::Scan { rel: rel_inner_l } = l2.as_ref() else {
2342 return None;
2343 };
2344 let RirNode::Scan { rel: rel_inner_r } = r2.as_ref() else {
2345 return None;
2346 };
2347 if lk2.len() != 1 || rk2.len() != 1 || lk1.len() != 2 || rk1.len() != 2 {
2348 return None;
2349 }
2350 if columns.len() != 3 {
2351 return None;
2352 }
2353 if lk2[0] >= 2 || rk2[0] >= 2 {
2354 return None;
2355 }
2356 if lk1.iter().any(|k| *k >= 4) || rk1.iter().any(|k| *k >= 2) {
2357 return None;
2358 }
2359
2360 let mut parent = [0u8, 1, 2, 3, 4, 5];
2361 uf_union_n::<6>(&mut parent, ac3(0, lk2[0] as u8), ac3(1, rk2[0] as u8));
2362 for i in 0..2 {
2363 let inner_ac = match lk1[i] {
2364 0 => (0u8, 0u8),
2365 1 => (0, 1),
2366 2 => (1, 0),
2367 3 => (1, 1),
2368 _ => return None,
2369 };
2370 uf_union_n::<6>(
2371 &mut parent,
2372 ac3(inner_ac.0, inner_ac.1),
2373 ac3(2, rk1[i] as u8),
2374 );
2375 }
2376 let roots: [u8; 6] = std::array::from_fn(|i| uf_find_n::<6>(&mut parent, i as u8));
2377 let mut counts: HashMap<u8, u8> = HashMap::new();
2378 for r in &roots {
2379 *counts.entry(*r).or_insert(0) += 1;
2380 }
2381 if counts.len() != 3 || counts.values().any(|c| *c != 2) {
2382 return None;
2383 }
2384 let mut head_classes: [u8; 3] = [0; 3];
2385 for (i, pc) in columns.iter().enumerate() {
2386 let ProjectExpr::Column(k) = pc else {
2387 return None;
2388 };
2389 let outer_ac = match *k {
2390 0 => (0u8, 0u8),
2391 1 => (0, 1),
2392 2 => (1, 0),
2393 3 => (1, 1),
2394 4 => (2, 0),
2395 5 => (2, 1),
2396 _ => return None,
2397 };
2398 head_classes[i] = uf_find_n::<6>(&mut parent, ac3(outer_ac.0, outer_ac.1));
2399 }
2400 if head_classes[0] == head_classes[1]
2401 || head_classes[0] == head_classes[2]
2402 || head_classes[1] == head_classes[2]
2403 {
2404 return None;
2405 }
2406 let x_class = head_classes[0];
2407 let y_class = head_classes[1];
2408 let z_class = head_classes[2];
2409 let atom_classes = |a: u8| (roots[ac3(a, 0) as usize], roots[ac3(a, 1) as usize]);
2410 let atom_rels = [*rel_inner_l, *rel_inner_r, *rel_third];
2411 let mut rel_xy = None;
2412 let mut rel_yz = None;
2413 let mut rel_xz = None;
2414 for atom_idx in 0..3u8 {
2415 let (c0, c1) = atom_classes(atom_idx);
2416 let bx = c0 == x_class || c1 == x_class;
2417 let by = c0 == y_class || c1 == y_class;
2418 let bz = c0 == z_class || c1 == z_class;
2419 match (bx, by, bz) {
2420 (true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2421 (false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2422 (true, false, true) => rel_xz = Some(atom_rels[atom_idx as usize]),
2423 _ => return None,
2424 }
2425 }
2426 Some(TriangleSemantics {
2427 rel_xy: rel_xy?,
2428 rel_yz: rel_yz?,
2429 rel_xz: rel_xz?,
2430 })
2431 }
2432
    /// Which two triangle relations `build_triangle_body` joins first. Each
    /// variant is named after the variable that inner pair shares.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    // Every variant ends in `Shared`, which trips clippy::enum_variant_names.
    #[allow(clippy::enum_variant_names)]
    enum TriangleInnerPair {
        // rel_xy ⋈ rel_yz on y, then rel_xz.
        YShared,
        // rel_xy ⋈ rel_xz on x, then rel_yz.
        XShared,
        // rel_xz ⋈ rel_yz on z, then rel_xy.
        ZShared,
    }
2440
    /// Builds `Project(Join(Join(a, b), c))` for the triangle in `s`.
    ///
    /// The first two relations are the inner pair selected by `inner_pair`.
    /// The projection always yields `(x, y, z)`. Each relation's columns are
    /// read in its field-name order (`rel_xy` = `(x, y)`, etc.).
    fn build_triangle_body(s: &TriangleSemantics, inner_pair: TriangleInnerPair) -> RirNode {
        let mk_scan = |r: RelId| RirNode::Scan { rel: r };
        match inner_pair {
            TriangleInnerPair::YShared => {
                // Inner columns: x, y, y, z.
                let inner = RirNode::Join {
                    left: Box::new(mk_scan(s.rel_xy)),
                    right: Box::new(mk_scan(s.rel_yz)),
                    left_keys: vec![1],
                    right_keys: vec![0],
                    join_type: JoinType::Inner,
                };
                // Close the triangle on (x, z) = inner cols (0, 3).
                let outer = RirNode::Join {
                    left: Box::new(inner),
                    right: Box::new(mk_scan(s.rel_xz)),
                    left_keys: vec![0, 3],
                    right_keys: vec![0, 1],
                    join_type: JoinType::Inner,
                };
                RirNode::Project {
                    input: Box::new(outer),
                    columns: vec![
                        ProjectExpr::Column(0),
                        ProjectExpr::Column(1),
                        ProjectExpr::Column(3),
                    ],
                }
            }
            TriangleInnerPair::XShared => {
                // Inner columns: x, y, x, z.
                let inner = RirNode::Join {
                    left: Box::new(mk_scan(s.rel_xy)),
                    right: Box::new(mk_scan(s.rel_xz)),
                    left_keys: vec![0],
                    right_keys: vec![0],
                    join_type: JoinType::Inner,
                };
                // Close the triangle on (y, z) = inner cols (1, 3).
                let outer = RirNode::Join {
                    left: Box::new(inner),
                    right: Box::new(mk_scan(s.rel_yz)),
                    left_keys: vec![1, 3],
                    right_keys: vec![0, 1],
                    join_type: JoinType::Inner,
                };
                RirNode::Project {
                    input: Box::new(outer),
                    columns: vec![
                        ProjectExpr::Column(0),
                        ProjectExpr::Column(1),
                        ProjectExpr::Column(3),
                    ],
                }
            }
            TriangleInnerPair::ZShared => {
                // Inner columns: x, z, y, z.
                let inner = RirNode::Join {
                    left: Box::new(mk_scan(s.rel_xz)),
                    right: Box::new(mk_scan(s.rel_yz)),
                    left_keys: vec![1],
                    right_keys: vec![1],
                    join_type: JoinType::Inner,
                };
                // Close the triangle on (x, y) = inner cols (0, 2).
                let outer = RirNode::Join {
                    left: Box::new(inner),
                    right: Box::new(mk_scan(s.rel_xy)),
                    left_keys: vec![0, 2],
                    right_keys: vec![0, 1],
                    join_type: JoinType::Inner,
                };
                RirNode::Project {
                    input: Box::new(outer),
                    columns: vec![
                        ProjectExpr::Column(0),
                        ProjectExpr::Column(2),
                        ProjectExpr::Column(3),
                    ],
                }
            }
        }
    }
2518
2519 pub fn try_reorder_triangle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2520 let s = match_and_infer_triangle(body)?;
2521 let _ = (
2522 populated_card(stats, s.rel_xy)?,
2523 populated_card(stats, s.rel_yz)?,
2524 populated_card(stats, s.rel_xz)?,
2525 );
2526 let est_y = stats.estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0]);
2527 let est_x = stats.estimate_join_cardinality(s.rel_xy, s.rel_xz, &[0], &[0]);
2528 let est_z = stats.estimate_join_cardinality(s.rel_yz, s.rel_xz, &[1], &[1]);
2529 let mut best = (TriangleInnerPair::YShared, est_y);
2530 if est_x < best.1 {
2531 best = (TriangleInnerPair::XShared, est_x);
2532 }
2533 if est_z < best.1 {
2534 best = (TriangleInnerPair::ZShared, est_z);
2535 }
2536 let candidate = build_triangle_body(&s, best.0);
2537 if format!("{:?}", candidate) == format!("{:?}", body) {
2543 return None;
2544 }
2545 Some(candidate)
2546 }
2547
    /// The four relations of a 4-cycle body over head variables `(w, x, y, z)`,
    /// as inferred by `match_and_infer_4cycle`.
    ///
    /// Each field is named after the variable pair its relation binds.
    /// NOTE(review): column orientation per field is presumably the field-name
    /// order, as for `TriangleSemantics` — confirm against the 4-cycle builder.
    struct Cycle4Semantics {
        // Relation binding w and x.
        rel_wx: RelId,
        // Relation binding x and y.
        rel_xy: RelId,
        // Relation binding y and z.
        rel_yz: RelId,
        // Relation binding z and w.
        rel_zw: RelId,
    }
2558
2559 fn match_and_infer_4cycle(body: &RirNode) -> Option<Cycle4Semantics> {
2560 let RirNode::Project {
2561 input: outer_input,
2562 columns,
2563 } = body
2564 else {
2565 return None;
2566 };
2567 let RirNode::Join {
2568 left: outer_l,
2569 right: outer_r,
2570 left_keys: olk,
2571 right_keys: ork,
2572 join_type: ojt,
2573 } = outer_input.as_ref()
2574 else {
2575 return None;
2576 };
2577 if !matches!(ojt, JoinType::Inner) {
2578 return None;
2579 }
2580 let RirNode::Join {
2581 left: ll,
2582 right: lr,
2583 left_keys: ilk_l,
2584 right_keys: irk_l,
2585 join_type: ijt_l,
2586 } = outer_l.as_ref()
2587 else {
2588 return None;
2589 };
2590 if !matches!(ijt_l, JoinType::Inner) {
2591 return None;
2592 }
2593 let RirNode::Scan { rel: rel_ll } = ll.as_ref() else {
2594 return None;
2595 };
2596 let RirNode::Scan { rel: rel_lr } = lr.as_ref() else {
2597 return None;
2598 };
2599 let RirNode::Join {
2600 left: rl,
2601 right: rr,
2602 left_keys: ilk_r,
2603 right_keys: irk_r,
2604 join_type: ijt_r,
2605 } = outer_r.as_ref()
2606 else {
2607 return None;
2608 };
2609 if !matches!(ijt_r, JoinType::Inner) {
2610 return None;
2611 }
2612 let RirNode::Scan { rel: rel_rl } = rl.as_ref() else {
2613 return None;
2614 };
2615 let RirNode::Scan { rel: rel_rr } = rr.as_ref() else {
2616 return None;
2617 };
2618 if ilk_l.len() != 1 || irk_l.len() != 1 || ilk_r.len() != 1 || irk_r.len() != 1 {
2619 return None;
2620 }
2621 if olk.len() != 2 || ork.len() != 2 || columns.len() != 4 {
2622 return None;
2623 }
2624 if ilk_l[0] >= 2 || irk_l[0] >= 2 || ilk_r[0] >= 2 || irk_r[0] >= 2 {
2625 return None;
2626 }
2627 if olk.iter().any(|k| *k >= 4) || ork.iter().any(|k| *k >= 4) {
2628 return None;
2629 }
2630
2631 let mut parent = [0u8, 1, 2, 3, 4, 5, 6, 7];
2632 uf_union_n::<8>(&mut parent, ac4(0, ilk_l[0] as u8), ac4(1, irk_l[0] as u8));
2633 uf_union_n::<8>(&mut parent, ac4(2, ilk_r[0] as u8), ac4(3, irk_r[0] as u8));
2634 for i in 0..2 {
2635 let l_ac = match olk[i] {
2636 0 => (0u8, 0u8),
2637 1 => (0, 1),
2638 2 => (1, 0),
2639 3 => (1, 1),
2640 _ => return None,
2641 };
2642 let r_ac = match ork[i] {
2643 0 => (2u8, 0u8),
2644 1 => (2, 1),
2645 2 => (3, 0),
2646 3 => (3, 1),
2647 _ => return None,
2648 };
2649 uf_union_n::<8>(&mut parent, ac4(l_ac.0, l_ac.1), ac4(r_ac.0, r_ac.1));
2650 }
2651 let roots: [u8; 8] = std::array::from_fn(|i| uf_find_n::<8>(&mut parent, i as u8));
2652 let mut counts: HashMap<u8, u8> = HashMap::new();
2653 for r in &roots {
2654 *counts.entry(*r).or_insert(0) += 1;
2655 }
2656 if counts.len() != 4 || counts.values().any(|c| *c != 2) {
2657 return None;
2658 }
2659
2660 let mut head_classes: [u8; 4] = [0; 4];
2661 for (i, pc) in columns.iter().enumerate() {
2662 let ProjectExpr::Column(k) = pc else {
2663 return None;
2664 };
2665 let ac = match *k {
2666 0 => (0u8, 0u8),
2667 1 => (0, 1),
2668 2 => (1, 0),
2669 3 => (1, 1),
2670 4 => (2, 0),
2671 5 => (2, 1),
2672 6 => (3, 0),
2673 7 => (3, 1),
2674 _ => return None,
2675 };
2676 head_classes[i] = uf_find_n::<8>(&mut parent, ac4(ac.0, ac.1));
2677 }
2678 for i in 0..4 {
2679 for j in (i + 1)..4 {
2680 if head_classes[i] == head_classes[j] {
2681 return None;
2682 }
2683 }
2684 }
2685 let w_class = head_classes[0];
2686 let x_class = head_classes[1];
2687 let y_class = head_classes[2];
2688 let z_class = head_classes[3];
2689 let atom_classes = |a: u8| (roots[ac4(a, 0) as usize], roots[ac4(a, 1) as usize]);
2690 let atom_rels = [*rel_ll, *rel_lr, *rel_rl, *rel_rr];
2691 let mut rel_wx = None;
2692 let mut rel_xy = None;
2693 let mut rel_yz = None;
2694 let mut rel_zw = None;
2695 for atom_idx in 0..4u8 {
2696 let (c0, c1) = atom_classes(atom_idx);
2697 let bw = c0 == w_class || c1 == w_class;
2698 let bx = c0 == x_class || c1 == x_class;
2699 let by = c0 == y_class || c1 == y_class;
2700 let bz = c0 == z_class || c1 == z_class;
2701 match (bw, bx, by, bz) {
2702 (true, true, false, false) => rel_wx = Some(atom_rels[atom_idx as usize]),
2703 (false, true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2704 (false, false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2705 (true, false, false, true) => rel_zw = Some(atom_rels[atom_idx as usize]),
2706 _ => return None,
2707 }
2708 }
2709 Some(Cycle4Semantics {
2710 rel_wx: rel_wx?,
2711 rel_xy: rel_xy?,
2712 rel_yz: rel_yz?,
2713 rel_zw: rel_zw?,
2714 })
2715 }
2716
    /// How `build_4cycle_body` pairs the four cycle edges into its two inner joins.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Cycle4Grouping {
        /// `(wx join xy) join (yz join zw)`: the inner joins bind `X` and `Z`,
        /// and the outer join closes the cycle on `W` and `Y`.
        Default,
        /// `(xy join yz) join (zw join wx)`: the inner joins bind `Y` and `W`,
        /// and the outer join closes the cycle on `X` and `Z`.
        Alt,
    }
2722
2723 fn build_4cycle_body(s: &Cycle4Semantics, g: Cycle4Grouping) -> RirNode {
2724 let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2725 match g {
2726 Cycle4Grouping::Default => {
2727 let il = RirNode::Join {
2728 left: Box::new(mk_scan(s.rel_wx)),
2729 right: Box::new(mk_scan(s.rel_xy)),
2730 left_keys: vec![1],
2731 right_keys: vec![0],
2732 join_type: JoinType::Inner,
2733 };
2734 let ir = RirNode::Join {
2735 left: Box::new(mk_scan(s.rel_yz)),
2736 right: Box::new(mk_scan(s.rel_zw)),
2737 left_keys: vec![1],
2738 right_keys: vec![0],
2739 join_type: JoinType::Inner,
2740 };
2741 let outer = RirNode::Join {
2742 left: Box::new(il),
2743 right: Box::new(ir),
2744 left_keys: vec![0, 3],
2745 right_keys: vec![3, 0],
2746 join_type: JoinType::Inner,
2747 };
2748 RirNode::Project {
2749 input: Box::new(outer),
2750 columns: vec![
2751 ProjectExpr::Column(0),
2752 ProjectExpr::Column(1),
2753 ProjectExpr::Column(3),
2754 ProjectExpr::Column(5),
2755 ],
2756 }
2757 }
2758 Cycle4Grouping::Alt => {
2759 let il = RirNode::Join {
2760 left: Box::new(mk_scan(s.rel_xy)),
2761 right: Box::new(mk_scan(s.rel_yz)),
2762 left_keys: vec![1],
2763 right_keys: vec![0],
2764 join_type: JoinType::Inner,
2765 };
2766 let ir = RirNode::Join {
2767 left: Box::new(mk_scan(s.rel_zw)),
2768 right: Box::new(mk_scan(s.rel_wx)),
2769 left_keys: vec![1],
2770 right_keys: vec![0],
2771 join_type: JoinType::Inner,
2772 };
2773 let outer = RirNode::Join {
2774 left: Box::new(il),
2775 right: Box::new(ir),
2776 left_keys: vec![0, 3],
2777 right_keys: vec![3, 0],
2778 join_type: JoinType::Inner,
2779 };
2780 RirNode::Project {
2781 input: Box::new(outer),
2782 columns: vec![
2783 ProjectExpr::Column(5),
2784 ProjectExpr::Column(0),
2785 ProjectExpr::Column(1),
2786 ProjectExpr::Column(3),
2787 ],
2788 }
2789 }
2790 }
2791 }
2792
2793 pub fn try_reorder_4cycle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2794 let s = match_and_infer_4cycle(body)?;
2795 let _ = (
2796 populated_card(stats, s.rel_wx)?,
2797 populated_card(stats, s.rel_xy)?,
2798 populated_card(stats, s.rel_yz)?,
2799 populated_card(stats, s.rel_zw)?,
2800 );
2801 let est_default = stats
2802 .estimate_join_cardinality(s.rel_wx, s.rel_xy, &[1], &[0])
2803 .saturating_add(stats.estimate_join_cardinality(s.rel_yz, s.rel_zw, &[1], &[0]));
2804 let est_alt = stats
2805 .estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0])
2806 .saturating_add(stats.estimate_join_cardinality(s.rel_zw, s.rel_wx, &[1], &[0]));
2807 let chosen = if est_alt < est_default {
2808 Cycle4Grouping::Alt
2809 } else {
2810 Cycle4Grouping::Default
2811 };
2812 let candidate = build_4cycle_body(&s, chosen);
2813 if format!("{:?}", candidate) == format!("{:?}", body) {
2814 return None;
2815 }
2816 Some(candidate)
2817 }
2818}
2819
#[cfg(test)]
mod selectivity_pass_tests {
    // Tests for `selectivity_pass::run`, in two groups:
    // - With an empty `StatsManager`, compiled triangle, 4-cycle and recursive
    //   plans must come out unchanged.
    // - On hand-built (synthetic) triangle and 4-cycle plans, the chosen join
    //   grouping must follow the seeded relation cardinalities.
    use super::selectivity_pass;
    use crate::Compiler;
    use xlog_stats::StatsManager;

    /// Debug-format every rule body of `plan`, in `rules_by_scc` order, so two
    /// plans can be compared byte-for-byte.
    fn body_snapshots(plan: &xlog_ir::ExecutionPlan) -> Vec<String> {
        plan.rules_by_scc
            .iter()
            .flatten()
            .map(|r| format!("{:?}", r.body))
            .collect()
    }

    /// Compiled triangle rule + empty stats: the pass must not touch any body.
    #[test]
    fn selectivity_pass_is_noop_for_triangle_plan() {
        let mut compiler = Compiler::new();
        let plan = compiler
            .compile("tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z).")
            .expect("compile");
        let before = body_snapshots(&plan);
        let stats = StatsManager::new();
        let mut plan2 = plan.clone();
        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
        let after = body_snapshots(&plan2);
        assert_eq!(
            before, after,
            "selectivity_pass must preserve every triangle rule body byte-for-byte"
        );
    }

    /// Compiled 4-cycle rule + empty stats: the pass must not touch any body.
    #[test]
    fn selectivity_pass_is_noop_for_4cycle_plan() {
        let mut compiler = Compiler::new();
        let plan = compiler
            .compile("cycle4(W, X, Y, Z) :- e1(W, X), e2(X, Y), e3(Y, Z), e4(Z, W).")
            .expect("compile");
        let before = body_snapshots(&plan);
        let stats = StatsManager::new();
        let mut plan2 = plan.clone();
        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
        let after = body_snapshots(&plan2);
        assert_eq!(
            before, after,
            "selectivity_pass must preserve every 4-cycle rule body byte-for-byte"
        );
    }

    /// Recursive reachability program + empty stats: bodies stay unchanged.
    #[test]
    fn selectivity_pass_is_noop_for_recursive_scc() {
        let mut compiler = Compiler::new();
        let plan = compiler
            .compile(
                "edge(1, 2). edge(2, 3). \
                 reach(X, Y) :- edge(X, Y). \
                 reach(X, Z) :- reach(X, Y), edge(Y, Z).",
            )
            .expect("compile");
        let before = body_snapshots(&plan);
        let stats = StatsManager::new();
        let mut plan2 = plan.clone();
        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
        let after = body_snapshots(&plan2);
        assert_eq!(
            before, after,
            "selectivity_pass must preserve recursive SCC bodies byte-for-byte"
        );
    }

    use xlog_core::RelId;
    use xlog_ir::plan::{CompiledRule, PlanBuilder, Scc};
    use xlog_ir::rir::ProjectExpr;
    use xlog_ir::{ExecutionPlan, JoinType, RirNode};

    /// Hand-built plan for `tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z)`,
    /// where `RelId(1)`, `RelId(2)` and `RelId(3)` stand for e1, e2 and e3.
    /// The inner join binds `Y` (e1.1 = e2.0). The outer join then matches the
    /// inner join's columns 0 (X) and 3 (Z) against e3.
    fn synth_triangle_plan() -> ExecutionPlan {
        let inner = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: RelId(1) }),
            right: Box::new(RirNode::Scan { rel: RelId(2) }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let outer = RirNode::Join {
            left: Box::new(inner),
            right: Box::new(RirNode::Scan { rel: RelId(3) }),
            left_keys: vec![0, 3],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        let body = RirNode::Project {
            input: Box::new(outer),
            columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
            ],
        };
        let mut builder = PlanBuilder::new();
        builder.add_scc(Scc {
            id: 0,
            predicates: vec!["tri".to_string()],
            is_recursive: false,
        });
        builder.add_rule(
            0,
            CompiledRule {
                head: "tri".to_string(),
                body,
                meta: Default::default(),
            },
        );
        builder.build()
    }

    /// Register `RelId(1..=3)` with cardinalities `c1`, `c2`, `c3`.
    fn seed_triangle_stats(c1: u64, c2: u64, c3: u64) -> StatsManager {
        let mut stats = StatsManager::new();
        for (rid, card) in [(RelId(1), c1), (RelId(2), c2), (RelId(3), c3)] {
            stats.register_relation(rid);
            stats.update_cardinality(rid, card);
        }
        stats
    }

    /// Return the two relations scanned by the inner join of the plan's first
    /// rule body, looking through a `MultiWayJoin`'s `fallback` if one wraps it.
    /// Returns `None` if the body is not `Project(Join(Join(Scan, Scan), _))`.
    fn inspect_triangle_inner_pair(plan: &xlog_ir::ExecutionPlan) -> Option<(RelId, RelId)> {
        let body = &plan.rules_by_scc.iter().flatten().next()?.body;
        let body = match body {
            xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
            other => other,
        };
        let xlog_ir::RirNode::Project { input, .. } = body else {
            return None;
        };
        let xlog_ir::RirNode::Join { left, .. } = input.as_ref() else {
            return None;
        };
        let xlog_ir::RirNode::Join {
            left: l2,
            right: r2,
            ..
        } = left.as_ref()
        else {
            return None;
        };
        let xlog_ir::RirNode::Scan { rel: rel_l } = l2.as_ref() else {
            return None;
        };
        let xlog_ir::RirNode::Scan { rel: rel_r } = r2.as_ref() else {
            return None;
        };
        Some((*rel_l, *rel_r))
    }

    /// e1 and e2 tiny, e3 huge: the inner join should pair e1 with e2 (either order).
    #[test]
    fn selectivity_pass_picks_y_shared_inner_when_e1_e2_smallest() {
        let mut plan = synth_triangle_plan();
        let stats = seed_triangle_stats(10, 10, 100_000);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
        assert!(
            pair == (RelId(1), RelId(2)) || pair == (RelId(2), RelId(1)),
            "expected (RelId(1), RelId(2)) for Y-shared; got {:?}",
            pair
        );
    }

    /// e1 and e3 tiny, e2 huge: the inner join should pair e1 with e3 (either order).
    #[test]
    fn selectivity_pass_picks_x_shared_inner_when_e1_e3_smallest() {
        let mut plan = synth_triangle_plan();
        let stats = seed_triangle_stats(10, 100_000, 10);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
        assert!(
            pair == (RelId(1), RelId(3)) || pair == (RelId(3), RelId(1)),
            "expected (RelId(1), RelId(3)) for X-shared; got {:?}",
            pair
        );
    }

    /// e2 and e3 tiny, e1 huge: the inner join should pair e2 with e3 (either order).
    #[test]
    fn selectivity_pass_picks_z_shared_inner_when_e2_e3_smallest() {
        let mut plan = synth_triangle_plan();
        let stats = seed_triangle_stats(100_000, 10, 10);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
        assert!(
            pair == (RelId(2), RelId(3)) || pair == (RelId(3), RelId(2)),
            "expected (RelId(2), RelId(3)) for Z-shared; got {:?}",
            pair
        );
    }

    /// Two opposite cardinality snapshots must yield different (order-normalised)
    /// inner pairs, i.e. the choice really depends on the stats.
    #[test]
    fn selectivity_pass_two_snapshots_produce_different_inner_pairs() {
        let mut plan_a = synth_triangle_plan();
        let stats_a = seed_triangle_stats(10, 10, 100_000); selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
        let pair_a = inspect_triangle_inner_pair(&plan_a).expect("snapshot A pair");

        let mut plan_b = synth_triangle_plan();
        let stats_b = seed_triangle_stats(100_000, 10, 10); selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
        let pair_b = inspect_triangle_inner_pair(&plan_b).expect("snapshot B pair");

        // Compare pairs as unordered sets by sorting on the raw id.
        let normalize = |(a, b): (RelId, RelId)| -> (RelId, RelId) {
            if a.0 <= b.0 {
                (a, b)
            } else {
                (b, a)
            }
        };
        assert_ne!(
            normalize(pair_a),
            normalize(pair_b),
            "two different stats snapshots must produce different inner pairs; \
             got A = {:?}, B = {:?}",
            pair_a,
            pair_b
        );
    }

    /// Equal cardinalities give no preferred pair. This only checks that the
    /// pass runs without panicking; the resulting pair is deliberately not asserted.
    #[test]
    fn selectivity_pass_with_only_relation_cards_may_pick_arbitrary_pair() {
        let mut plan = synth_triangle_plan();
        let stats = seed_triangle_stats(100, 100, 100);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let _ = inspect_triangle_inner_pair(&plan);
    }

    /// Hand-built plan for `cyc(W, X, Y, Z) :- e1(W, X), e2(X, Y), e3(Y, Z), e4(Z, W)`
    /// in the Default grouping `(e1 join e2) join (e3 join e4)`, where
    /// `RelId(1..=4)` stand for e1..e4. The inner joins bind `X` and `Z`. The
    /// outer join closes `W` (left 0 = right 3) and `Y` (left 3 = right 0).
    fn synth_4cycle_plan() -> ExecutionPlan {
        let inner_left = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: RelId(1) }),
            right: Box::new(RirNode::Scan { rel: RelId(2) }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let inner_right = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: RelId(3) }),
            right: Box::new(RirNode::Scan { rel: RelId(4) }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let outer = RirNode::Join {
            left: Box::new(inner_left),
            right: Box::new(inner_right),
            left_keys: vec![0, 3],
            right_keys: vec![3, 0],
            join_type: JoinType::Inner,
        };
        let body = RirNode::Project {
            input: Box::new(outer),
            columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
                ProjectExpr::Column(5),
            ],
        };
        let mut builder = PlanBuilder::new();
        builder.add_scc(Scc {
            id: 0,
            predicates: vec!["cyc".to_string()],
            is_recursive: false,
        });
        builder.add_rule(
            0,
            CompiledRule {
                head: "cyc".to_string(),
                body,
                meta: Default::default(),
            },
        );
        builder.build()
    }

    /// Register `RelId(1..=4)` with cardinalities `c1`..`c4`.
    fn seed_4cycle_stats(c1: u64, c2: u64, c3: u64, c4: u64) -> StatsManager {
        let mut stats = StatsManager::new();
        for (rid, card) in [
            (RelId(1), c1),
            (RelId(2), c2),
            (RelId(3), c3),
            (RelId(4), c4),
        ] {
            stats.register_relation(rid);
            stats.update_cardinality(rid, card);
        }
        stats
    }

    /// Return the four scans of the plan's first rule body in tree order
    /// (left-left, left-right, right-left, right-right), looking through a
    /// `MultiWayJoin`'s `fallback` if present. Returns `None` if the body is not
    /// `Project(Join(Join(Scan, Scan), Join(Scan, Scan)))`.
    fn inspect_4cycle_grouping(
        plan: &xlog_ir::ExecutionPlan,
    ) -> Option<(RelId, RelId, RelId, RelId)> {
        let body = &plan.rules_by_scc.iter().flatten().next()?.body;
        let body = match body {
            xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
            other => other,
        };
        let xlog_ir::RirNode::Project { input, .. } = body else {
            return None;
        };
        let xlog_ir::RirNode::Join { left, right, .. } = input.as_ref() else {
            return None;
        };
        let xlog_ir::RirNode::Join {
            left: ll,
            right: lr,
            ..
        } = left.as_ref()
        else {
            return None;
        };
        let xlog_ir::RirNode::Join {
            left: rl,
            right: rr,
            ..
        } = right.as_ref()
        else {
            return None;
        };
        let xlog_ir::RirNode::Scan { rel: r_ll } = ll.as_ref() else {
            return None;
        };
        let xlog_ir::RirNode::Scan { rel: r_lr } = lr.as_ref() else {
            return None;
        };
        let xlog_ir::RirNode::Scan { rel: r_rl } = rl.as_ref() else {
            return None;
        };
        let xlog_ir::RirNode::Scan { rel: r_rr } = rr.as_ref() else {
            return None;
        };
        Some((*r_ll, *r_lr, *r_rl, *r_rr))
    }

    /// e1 and e4 tiny, e2 and e3 huge. Each Default inner join (e1-e2, e3-e4)
    /// contains a tiny relation, while Alt would join e2 with e3. Expect the
    /// Default layout (1, 2, 3, 4).
    #[test]
    fn selectivity_pass_4cycle_picks_default_grouping_when_corners_smallest() {
        let mut plan = synth_4cycle_plan();
        let stats = seed_4cycle_stats(10, 10_000, 10_000, 10);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
        assert_eq!(
            (ll, lr, rl, rr),
            (RelId(1), RelId(2), RelId(3), RelId(4)),
            "expected Default grouping"
        );
    }

    /// e3 and e4 tiny, e1 and e2 huge. Default would join e1 with e2, while each
    /// Alt inner join (e2-e3, e4-e1) contains a tiny relation. Expect the Alt
    /// layout (2, 3, 4, 1).
    #[test]
    fn selectivity_pass_4cycle_picks_alt_grouping_when_diagonals_smallest() {
        let mut plan = synth_4cycle_plan();
        let stats = seed_4cycle_stats(10_000, 10_000, 10, 10);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
        assert_eq!(
            (ll, lr, rl, rr),
            (RelId(2), RelId(3), RelId(4), RelId(1)),
            "expected Alt grouping"
        );
    }

    /// The two snapshots above must produce different 4-cycle layouts.
    #[test]
    fn selectivity_pass_4cycle_two_snapshots_produce_different_groupings() {
        let mut plan_a = synth_4cycle_plan();
        let stats_a = seed_4cycle_stats(10, 10_000, 10_000, 10); selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
        let g_a = inspect_4cycle_grouping(&plan_a).expect("grouping a");

        let mut plan_b = synth_4cycle_plan();
        let stats_b = seed_4cycle_stats(10_000, 10_000, 10, 10); selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
        let g_b = inspect_4cycle_grouping(&plan_b).expect("grouping b");

        assert_ne!(
            g_a, g_b,
            "two different stats snapshots must produce different 4-cycle groupings; \
             got A = {:?}, B = {:?}",
            g_a, g_b
        );
    }

    /// Only e1..e3 are registered (RelId(4) has no stats), so the pass must
    /// leave the 4-cycle body exactly as built.
    #[test]
    fn selectivity_pass_4cycle_skips_when_card_missing() {
        let mut plan = synth_4cycle_plan();
        let mut stats = StatsManager::new();
        for rid in [RelId(1), RelId(2), RelId(3)] {
            stats.register_relation(rid);
            stats.update_cardinality(rid, 100);
        }
        let before = format!("{:?}", plan.rules_by_scc[0][0].body);
        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
        let after = format!("{:?}", plan.rules_by_scc[0][0].body);
        assert_eq!(
            before, after,
            "missing-stats safety floor must leave body unchanged"
        );
    }
}
3327
3328#[cfg(test)]
3329mod tests {
3330 use super::*;
3331 use xlog_core::ScalarType;
3332 use xlog_ir::{ConstValue, ProjectExpr};
3333 use xlog_stats::ColumnStats;
3334
3335 fn make_stats_manager() -> Arc<StatsManager> {
3336 let mut mgr = StatsManager::new();
3337
3338 mgr.register_relation(RelId(1));
3340 mgr.update_cardinality(RelId(1), 10_000);
3341 mgr.update_byte_size(RelId(1), 320_000); mgr.register_relation(RelId(2));
3344 mgr.update_cardinality(RelId(2), 5_000);
3345 mgr.update_byte_size(RelId(2), 160_000);
3346
3347 mgr.register_relation(RelId(3));
3348 mgr.update_cardinality(RelId(3), 1_000);
3349 mgr.update_byte_size(RelId(3), 32_000);
3350
3351 let mut col0 = ColumnStats::new(0, ScalarType::I64);
3353 col0.update_distinct(1000);
3354 col0.update_range(0, 10000);
3355 mgr.add_column_stats(RelId(1), col0);
3356
3357 let mut col1 = ColumnStats::new(1, ScalarType::I64);
3358 col1.update_distinct(100);
3359 mgr.add_column_stats(RelId(1), col1);
3360
3361 Arc::new(mgr)
3362 }
3363
3364 #[test]
3365 fn test_optimizer_new() {
3366 let stats = make_stats_manager();
3367 let optimizer = Optimizer::new(stats);
3368
3369 assert_eq!(optimizer.config().dp_threshold, 10);
3370 assert!(optimizer.config().enable_pushdown);
3371 }
3372
3373 #[test]
3374 fn test_optimizer_with_config() {
3375 let stats = make_stats_manager();
3376 let config = OptimizerConfig {
3377 dp_threshold: 5,
3378 enable_pushdown: false,
3379 ..Default::default()
3380 };
3381 let optimizer = Optimizer::with_config(stats, config);
3382
3383 assert_eq!(optimizer.config().dp_threshold, 5);
3384 assert!(!optimizer.config().enable_pushdown);
3385 }
3386
3387 #[test]
3388 fn test_estimate_scan_cost() {
3389 let stats = make_stats_manager();
3390 let optimizer = Optimizer::new(stats);
3391
3392 let scan = RirNode::Scan { rel: RelId(1) };
3393 let cost = optimizer.estimate_cost(&scan);
3394
3395 assert_eq!(cost.rows, 10_000);
3396 assert!(cost.gpu_mem > 0);
3397 assert_eq!(cost.transfers, 0); }
3399
3400 #[test]
3401 fn test_estimate_scan_cost_unknown_relation() {
3402 let stats = Arc::new(StatsManager::new());
3403 let optimizer = Optimizer::new(stats);
3404
3405 let scan = RirNode::Scan { rel: RelId(999) };
3406 let cost = optimizer.estimate_cost(&scan);
3407
3408 assert_eq!(cost.rows, 1000);
3410 }
3411
3412 #[test]
3413 fn test_estimate_filter_cost() {
3414 let stats = make_stats_manager();
3415 let optimizer = Optimizer::new(stats);
3416
3417 let filter = RirNode::Filter {
3418 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3419 predicate: Expr::Compare {
3420 left: Box::new(Expr::Column(0)),
3421 op: CompareOp::Eq,
3422 right: Box::new(Expr::Const(ConstValue::I64(42))),
3423 },
3424 };
3425
3426 let cost = optimizer.estimate_cost(&filter);
3427
3428 assert!(cost.rows < 10_000);
3430 assert!(cost.rows >= 1);
3431 }
3432
3433 #[test]
3434 fn test_estimate_join_cost() {
3435 let stats = make_stats_manager();
3436 let optimizer = Optimizer::new(stats);
3437
3438 let join = RirNode::Join {
3439 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3440 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3441 left_keys: vec![0],
3442 right_keys: vec![0],
3443 join_type: JoinType::Inner,
3444 };
3445
3446 let cost = optimizer.estimate_cost(&join);
3447
3448 assert!(cost.rows > 0);
3450 assert!(cost.cpu_cost > 0.0);
3451 assert!(cost.gpu_mem > 0);
3452 }
3453
3454 #[test]
3455 fn test_estimate_join_cost_with_selectivity() {
3456 let mut mgr = StatsManager::new();
3457 mgr.register_relation(RelId(1));
3458 mgr.register_relation(RelId(2));
3459 mgr.update_cardinality(RelId(1), 1000);
3460 mgr.update_cardinality(RelId(2), 500);
3461
3462 mgr.record_join_result(RelId(1), RelId(2), vec![0], vec![0], 500_000, 2500);
3464
3465 let optimizer = Optimizer::new(Arc::new(mgr));
3466
3467 let join = RirNode::Join {
3468 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3469 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3470 left_keys: vec![0],
3471 right_keys: vec![0],
3472 join_type: JoinType::Inner,
3473 };
3474
3475 let cost = optimizer.estimate_cost(&join);
3476
3477 assert!(cost.rows > 0);
3479 }
3480
3481 #[test]
3482 fn test_predicate_pushdown_simple_scan() {
3483 let stats = make_stats_manager();
3484 let optimizer = Optimizer::new(stats);
3485
3486 let scan = RirNode::Scan { rel: RelId(1) };
3487 let optimized = optimizer.optimize(scan);
3488
3489 assert!(matches!(optimized, RirNode::Scan { rel: RelId(1) }));
3491 }
3492
3493 #[test]
3494 fn test_predicate_pushdown_filter_on_scan() {
3495 let stats = make_stats_manager();
3496 let optimizer = Optimizer::new(stats);
3497
3498 let filter = RirNode::Filter {
3499 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3500 predicate: Expr::Compare {
3501 left: Box::new(Expr::Column(0)),
3502 op: CompareOp::Eq,
3503 right: Box::new(Expr::Const(ConstValue::I64(42))),
3504 },
3505 };
3506
3507 let optimized = optimizer.optimize(filter);
3508
3509 assert!(matches!(optimized, RirNode::Filter { .. }));
3511 }
3512
3513 #[test]
3514 fn test_predicate_pushdown_merges_filters() {
3515 let stats = make_stats_manager();
3516 let optimizer = Optimizer::new(stats);
3517
3518 let nested_filter = RirNode::Filter {
3519 input: Box::new(RirNode::Filter {
3520 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3521 predicate: Expr::Compare {
3522 left: Box::new(Expr::Column(0)),
3523 op: CompareOp::Gt,
3524 right: Box::new(Expr::Const(ConstValue::I64(0))),
3525 },
3526 }),
3527 predicate: Expr::Compare {
3528 left: Box::new(Expr::Column(0)),
3529 op: CompareOp::Lt,
3530 right: Box::new(Expr::Const(ConstValue::I64(100))),
3531 },
3532 };
3533
3534 let optimized = optimizer.optimize(nested_filter);
3535
3536 if let RirNode::Filter { predicate, .. } = optimized {
3538 assert!(matches!(predicate, Expr::And(_)));
3539 } else {
3540 panic!("Expected Filter node");
3541 }
3542 }
3543
3544 #[test]
3545 fn test_predicate_pushdown_through_project() {
3546 let stats = make_stats_manager();
3547 let optimizer = Optimizer::new(stats);
3548
3549 let plan = RirNode::Filter {
3551 input: Box::new(RirNode::Project {
3552 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3553 columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(1)],
3554 }),
3555 predicate: Expr::Compare {
3556 left: Box::new(Expr::Column(0)),
3557 op: CompareOp::Eq,
3558 right: Box::new(Expr::Const(ConstValue::I64(42))),
3559 },
3560 };
3561
3562 let optimized = optimizer.optimize(plan);
3563
3564 assert!(matches!(optimized, RirNode::Project { .. }));
3566 if let RirNode::Project { input, .. } = optimized {
3567 assert!(matches!(*input, RirNode::Filter { .. }));
3568 }
3569 }
3570
3571 #[test]
3572 fn test_predicate_pushdown_into_join() {
3573 let stats = make_stats_manager();
3574 let optimizer = Optimizer::new(stats);
3575
3576 let plan = RirNode::Filter {
3578 input: Box::new(RirNode::Join {
3579 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3580 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3581 left_keys: vec![0],
3582 right_keys: vec![0],
3583 join_type: JoinType::Inner,
3584 }),
3585 predicate: Expr::Compare {
3586 left: Box::new(Expr::Column(0)), op: CompareOp::Eq,
3588 right: Box::new(Expr::Const(ConstValue::I64(42))),
3589 },
3590 };
3591
3592 let optimized = optimizer.optimize(plan);
3593
3594 if let RirNode::Join { left, .. } = optimized {
3596 assert!(matches!(*left, RirNode::Filter { .. }));
3597 } else {
3598 panic!("Expected Join node");
3599 }
3600 }
3601
3602 #[test]
3603 fn test_plan_cost_total() {
3604 let cost = PlanCost {
3605 rows: 1000,
3606 cpu_cost: 100.0,
3607 gpu_mem: 1_000_000,
3608 transfers: 2,
3609 };
3610
3611 let total = cost.total_cost(100.0);
3612
3613 assert!((total - 1300.0).abs() < 0.001);
3616 }
3617
3618 #[test]
3619 fn test_plan_cost_then() {
3620 let cost1 = PlanCost {
3621 rows: 1000,
3622 cpu_cost: 50.0,
3623 gpu_mem: 500,
3624 transfers: 1,
3625 };
3626
3627 let cost2 = PlanCost {
3628 rows: 500,
3629 cpu_cost: 25.0,
3630 gpu_mem: 800,
3631 transfers: 1,
3632 };
3633
3634 let combined = cost1.then(cost2);
3635
3636 assert_eq!(combined.rows, 500); assert_eq!(combined.cpu_cost, 75.0);
3638 assert_eq!(combined.gpu_mem, 800); assert_eq!(combined.transfers, 2);
3640 }
3641
3642 #[test]
3643 fn test_optimizer_config_default() {
3644 let config = OptimizerConfig::default();
3645
3646 assert_eq!(config.dp_threshold, 10);
3647 assert!((config.index_heat_threshold - 0.7).abs() < 0.001);
3648 assert!(config.enable_pushdown);
3649 assert!((config.default_filter_selectivity - 0.1).abs() < 0.001);
3650 }
3651
3652 #[test]
3653 fn test_should_use_greedy() {
3654 let stats = make_stats_manager();
3655 let config = OptimizerConfig {
3656 dp_threshold: 2,
3657 ..Default::default()
3658 };
3659 let optimizer = Optimizer::with_config(stats, config);
3660
3661 let single = RirNode::Scan { rel: RelId(1) };
3663 assert!(!optimizer.should_use_greedy(&single));
3664
3665 let multi = RirNode::Join {
3667 left: Box::new(RirNode::Join {
3668 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3669 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3670 left_keys: vec![0],
3671 right_keys: vec![0],
3672 join_type: JoinType::Inner,
3673 }),
3674 right: Box::new(RirNode::Scan { rel: RelId(3) }),
3675 left_keys: vec![0],
3676 right_keys: vec![0],
3677 join_type: JoinType::Inner,
3678 };
3679 assert!(optimizer.should_use_greedy(&multi));
3680 }
3681
3682 #[test]
3683 fn test_recommend_indexes() {
3684 let mut mgr = StatsManager::new();
3685 mgr.register_relation(RelId(1));
3686 mgr.register_relation(RelId(2));
3687
3688 for _ in 0..50 {
3690 mgr.record_access(RelId(1));
3691 }
3692
3693 let optimizer = Optimizer::new(Arc::new(mgr));
3694 let recommendations = optimizer.recommend_indexes();
3695
3696 assert!(recommendations.contains(&RelId(1)));
3697 assert!(!recommendations.contains(&RelId(2)));
3698 }
3699
3700 #[test]
3701 fn test_estimate_groupby_cost() {
3702 let stats = make_stats_manager();
3703 let optimizer = Optimizer::new(stats);
3704
3705 let groupby = RirNode::GroupBy {
3706 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3707 key_cols: vec![0],
3708 aggs: vec![(1, xlog_core::AggOp::Sum)],
3709 };
3710
3711 let cost = optimizer.estimate_cost(&groupby);
3712
3713 assert!(cost.rows < 10_000);
3715 assert!(cost.rows >= 1);
3716 }
3717
3718 #[test]
3719 fn test_estimate_union_cost() {
3720 let stats = make_stats_manager();
3721 let optimizer = Optimizer::new(stats);
3722
3723 let union = RirNode::Union {
3724 inputs: vec![
3725 RirNode::Scan { rel: RelId(1) },
3726 RirNode::Scan { rel: RelId(2) },
3727 ],
3728 };
3729
3730 let cost = optimizer.estimate_cost(&union);
3731
3732 assert_eq!(cost.rows, 15_000); }
3735
3736 #[test]
3737 fn test_estimate_distinct_cost() {
3738 let stats = make_stats_manager();
3739 let optimizer = Optimizer::new(stats);
3740
3741 let distinct = RirNode::Distinct {
3742 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3743 key_cols: vec![0],
3744 };
3745
3746 let cost = optimizer.estimate_cost(&distinct);
3747
3748 assert!(cost.rows <= 10_000);
3750 assert!(cost.rows >= 1);
3751 }
3752
3753 #[test]
3754 fn test_estimate_diff_cost() {
3755 let stats = make_stats_manager();
3756 let optimizer = Optimizer::new(stats);
3757
3758 let diff = RirNode::Diff {
3759 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3760 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3761 };
3762
3763 let cost = optimizer.estimate_cost(&diff);
3764
3765 assert!(cost.rows <= 10_000);
3767 assert!(cost.rows >= 1);
3768 }
3769
3770 #[test]
3771 fn test_estimate_fixpoint_cost() {
3772 let stats = make_stats_manager();
3773 let optimizer = Optimizer::new(stats);
3774
3775 let fixpoint = RirNode::Fixpoint {
3776 scc_id: 0,
3777 base: Box::new(RirNode::Scan { rel: RelId(1) }),
3778 recursive: Box::new(RirNode::Scan { rel: RelId(1) }),
3779 delta_rel: RelId(10),
3780 full_rel: RelId(11),
3781 };
3782
3783 let cost = optimizer.estimate_cost(&fixpoint);
3784
3785 assert!(cost.rows >= 10_000);
3787 }
3788
3789 #[test]
3790 fn test_predicate_selectivity_equality() {
3791 let stats = make_stats_manager();
3792 let optimizer = Optimizer::new(stats);
3793
3794 let scan = RirNode::Scan { rel: RelId(1) };
3795
3796 let eq_pred = Expr::Compare {
3798 left: Box::new(Expr::Column(0)),
3799 op: CompareOp::Eq,
3800 right: Box::new(Expr::Const(ConstValue::I64(42))),
3801 };
3802
3803 let selectivity = optimizer.estimate_predicate_selectivity(&eq_pred, &scan);
3804
3805 assert!(selectivity < 0.01);
3807 assert!(selectivity > 0.0);
3808 }
3809
3810 #[test]
3811 fn test_predicate_selectivity_and() {
3812 let stats = make_stats_manager();
3813 let optimizer = Optimizer::new(stats);
3814
3815 let scan = RirNode::Scan { rel: RelId(1) };
3816
3817 let and_pred = Expr::And(vec![
3819 Expr::Compare {
3820 left: Box::new(Expr::Column(0)),
3821 op: CompareOp::Gt,
3822 right: Box::new(Expr::Const(ConstValue::I64(0))),
3823 },
3824 Expr::Compare {
3825 left: Box::new(Expr::Column(0)),
3826 op: CompareOp::Lt,
3827 right: Box::new(Expr::Const(ConstValue::I64(100))),
3828 },
3829 ]);
3830
3831 let selectivity = optimizer.estimate_predicate_selectivity(&and_pred, &scan);
3832
3833 assert!(selectivity < 0.5);
3835 assert!(selectivity > 0.0);
3836 }
3837
3838 #[test]
3839 fn test_predicate_selectivity_not() {
3840 let stats = make_stats_manager();
3841 let optimizer = Optimizer::new(stats);
3842
3843 let scan = RirNode::Scan { rel: RelId(1) };
3844
3845 let not_pred = Expr::Not(Box::new(Expr::Compare {
3847 left: Box::new(Expr::Column(0)),
3848 op: CompareOp::Eq,
3849 right: Box::new(Expr::Const(ConstValue::I64(42))),
3850 }));
3851
3852 let selectivity = optimizer.estimate_predicate_selectivity(¬_pred, &scan);
3853
3854 assert!(selectivity > 0.9);
3856 }
3857
3858 #[test]
3859 fn test_join_type_semi() {
3860 let stats = make_stats_manager();
3861 let optimizer = Optimizer::new(stats);
3862
3863 let semi_join = RirNode::Join {
3864 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3865 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3866 left_keys: vec![0],
3867 right_keys: vec![0],
3868 join_type: JoinType::Semi,
3869 };
3870
3871 let cost = optimizer.estimate_cost(&semi_join);
3872
3873 assert!(cost.rows <= 10_000);
3875 }
3876
3877 #[test]
3878 fn test_join_type_anti() {
3879 let stats = make_stats_manager();
3880 let optimizer = Optimizer::new(stats);
3881
3882 let anti_join = RirNode::Join {
3883 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3884 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3885 left_keys: vec![0],
3886 right_keys: vec![0],
3887 join_type: JoinType::Anti,
3888 };
3889
3890 let cost = optimizer.estimate_cost(&anti_join);
3891
3892 assert!(cost.rows <= 10_000);
3894 }
3895
3896 #[test]
3897 fn test_pushdown_disabled() {
3898 let stats = make_stats_manager();
3899 let config = OptimizerConfig {
3900 enable_pushdown: false,
3901 ..Default::default()
3902 };
3903 let optimizer = Optimizer::with_config(stats, config);
3904
3905 let plan = RirNode::Filter {
3907 input: Box::new(RirNode::Filter {
3908 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3909 predicate: Expr::Compare {
3910 left: Box::new(Expr::Column(0)),
3911 op: CompareOp::Gt,
3912 right: Box::new(Expr::Const(ConstValue::I64(0))),
3913 },
3914 }),
3915 predicate: Expr::Compare {
3916 left: Box::new(Expr::Column(0)),
3917 op: CompareOp::Lt,
3918 right: Box::new(Expr::Const(ConstValue::I64(100))),
3919 },
3920 };
3921
3922 let optimized = optimizer.optimize(plan.clone());
3923
3924 if let RirNode::Filter { input, .. } = optimized {
3927 assert!(matches!(*input, RirNode::Filter { .. }));
3928 } else {
3929 panic!("Expected Filter node");
3930 }
3931 }
3932
3933 #[test]
3934 fn test_collect_columns() {
3935 let expr = Expr::And(vec![
3936 Expr::Compare {
3937 left: Box::new(Expr::Column(0)),
3938 op: CompareOp::Eq,
3939 right: Box::new(Expr::Column(2)),
3940 },
3941 Expr::Compare {
3942 left: Box::new(Expr::Column(1)),
3943 op: CompareOp::Gt,
3944 right: Box::new(Expr::Const(ConstValue::I64(0))),
3945 },
3946 ]);
3947
3948 let cols = Optimizer::collect_columns(&expr);
3949
3950 assert!(cols.contains(&0));
3951 assert!(cols.contains(&1));
3952 assert!(cols.contains(&2));
3953 }
3954
3955 #[test]
3956 fn test_flatten_and() {
3957 let nested = Expr::And(vec![
3958 Expr::And(vec![
3959 Expr::Compare {
3960 left: Box::new(Expr::Column(0)),
3961 op: CompareOp::Eq,
3962 right: Box::new(Expr::Const(ConstValue::I64(1))),
3963 },
3964 Expr::Compare {
3965 left: Box::new(Expr::Column(1)),
3966 op: CompareOp::Eq,
3967 right: Box::new(Expr::Const(ConstValue::I64(2))),
3968 },
3969 ]),
3970 Expr::Compare {
3971 left: Box::new(Expr::Column(2)),
3972 op: CompareOp::Eq,
3973 right: Box::new(Expr::Const(ConstValue::I64(3))),
3974 },
3975 ]);
3976
3977 let flattened = Optimizer::flatten_and(&nested);
3978
3979 assert_eq!(flattened.len(), 3);
3980 }
3981
3982 #[test]
3983 fn test_conjoin_single() {
3984 let single = vec![Expr::Compare {
3985 left: Box::new(Expr::Column(0)),
3986 op: CompareOp::Eq,
3987 right: Box::new(Expr::Const(ConstValue::I64(42))),
3988 }];
3989
3990 let result = Optimizer::conjoin(single);
3991
3992 assert!(matches!(result, Expr::Compare { .. }));
3993 }
3994
3995 #[test]
3996 fn test_conjoin_multiple() {
3997 let multiple = vec![
3998 Expr::Compare {
3999 left: Box::new(Expr::Column(0)),
4000 op: CompareOp::Eq,
4001 right: Box::new(Expr::Const(ConstValue::I64(1))),
4002 },
4003 Expr::Compare {
4004 left: Box::new(Expr::Column(1)),
4005 op: CompareOp::Eq,
4006 right: Box::new(Expr::Const(ConstValue::I64(2))),
4007 },
4008 ];
4009
4010 let result = Optimizer::conjoin(multiple);
4011
4012 assert!(matches!(result, Expr::And(_)));
4013 }
4014
4015 #[test]
4016 fn test_predicate_pushdown_with_schemas() {
4017 let stats = make_stats_manager();
4020 let mut optimizer = Optimizer::new(stats);
4021
4022 let left_schema = Schema::new(vec![
4024 ("c0".to_string(), xlog_core::ScalarType::Symbol),
4025 ("c1".to_string(), xlog_core::ScalarType::Symbol),
4026 ("c2".to_string(), xlog_core::ScalarType::Symbol),
4027 ]);
4028 let right_schema = Schema::new(vec![
4029 ("c0".to_string(), xlog_core::ScalarType::Symbol),
4030 ("c1".to_string(), xlog_core::ScalarType::Symbol),
4031 ("c2".to_string(), xlog_core::ScalarType::U32),
4032 ]);
4033
4034 let mut schemas = HashMap::new();
4035 schemas.insert(RelId(1), left_schema);
4036 schemas.insert(RelId(2), right_schema);
4037 optimizer.set_schemas(schemas);
4038
4039 let plan = RirNode::Filter {
4041 input: Box::new(RirNode::Join {
4042 left: Box::new(RirNode::Scan { rel: RelId(1) }),
4043 right: Box::new(RirNode::Scan { rel: RelId(2) }),
4044 left_keys: vec![0],
4045 right_keys: vec![0],
4046 join_type: JoinType::Inner,
4047 }),
4048 predicate: Expr::Compare {
4049 left: Box::new(Expr::Column(5)), op: CompareOp::Ge,
4051 right: Box::new(Expr::Const(ConstValue::U32(4))),
4052 },
4053 };
4054
4055 let optimized = optimizer.optimize(plan);
4056
4057 if let RirNode::Join { right, .. } = optimized {
4059 if let RirNode::Filter { predicate, .. } = *right {
4060 if let Expr::Compare { left, .. } = predicate {
4061 if let Expr::Column(idx) = *left {
4062 assert_eq!(
4063 idx, 2,
4064 "Column should be remapped to 2 (5 - left_width(3) = 2)"
4065 );
4066 } else {
4067 panic!("Expected Column expression");
4068 }
4069 } else {
4070 panic!("Expected Compare predicate");
4071 }
4072 } else {
4073 panic!("Expected Filter on right side of join");
4074 }
4075 } else {
4076 panic!("Expected Join node");
4077 }
4078 }
4079
4080 fn build_canonical_triangle_multiway() -> RirNode {
4096 let scan_xy = RirNode::Scan { rel: RelId(1) };
4097 let scan_yz = RirNode::Scan { rel: RelId(2) };
4098 let scan_xz = RirNode::Scan { rel: RelId(3) };
4099 let inner_join = RirNode::Join {
4100 left: Box::new(scan_xy.clone()),
4101 right: Box::new(scan_yz.clone()),
4102 left_keys: vec![1],
4103 right_keys: vec![0],
4104 join_type: JoinType::Inner,
4105 };
4106 let outer_join = RirNode::Join {
4107 left: Box::new(inner_join),
4108 right: Box::new(scan_xz.clone()),
4109 left_keys: vec![0, 3],
4110 right_keys: vec![0, 1],
4111 join_type: JoinType::Inner,
4112 };
4113 let fallback = RirNode::Project {
4114 input: Box::new(outer_join),
4115 columns: vec![
4116 ProjectExpr::Column(0),
4117 ProjectExpr::Column(1),
4118 ProjectExpr::Column(3),
4119 ],
4120 };
4121 RirNode::MultiWayJoin {
4122 inputs: vec![scan_xy, scan_yz, scan_xz],
4123 slot_vars: vec![
4124 vec![Some(0), Some(1)],
4125 vec![Some(1), Some(2)],
4126 vec![Some(0), Some(2)],
4127 ],
4128 output_columns: vec![
4129 ProjectExpr::Column(0),
4130 ProjectExpr::Column(1),
4131 ProjectExpr::Column(3),
4132 ],
4133 fallback: Box::new(fallback),
4134 plan: None,
4135 var_order: None,
4136 }
4137 }
4138
4139 fn build_4input_multiway() -> RirNode {
4148 let scans = [RelId(1), RelId(2), RelId(3), RelId(1)]
4149 .map(|rel| RirNode::Scan { rel })
4150 .to_vec();
4151 let slot_vars = vec![
4153 vec![Some(0u32), Some(1)],
4154 vec![Some(1u32), Some(2)],
4155 vec![Some(2u32), Some(3)],
4156 vec![Some(0u32), Some(3)],
4157 ];
4158 let output_columns = vec![
4161 ProjectExpr::Column(0),
4162 ProjectExpr::Column(1),
4163 ProjectExpr::Column(2),
4164 ProjectExpr::Column(3),
4165 ];
4166 let fallback = RirNode::Unit;
4169 RirNode::MultiWayJoin {
4170 inputs: scans,
4171 slot_vars,
4172 output_columns,
4173 fallback: Box::new(fallback),
4174 plan: None,
4175 var_order: None,
4176 }
4177 }
4178
4179 #[test]
4180 fn optimize_returns_multiway_unchanged() {
4181 let optimizer = Optimizer::new(make_stats_manager());
4182 for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4183 let optimized = optimizer.optimize(node.clone());
4184 match (&node, &optimized) {
4185 (
4186 RirNode::MultiWayJoin {
4187 inputs: a_in,
4188 output_columns: a_out,
4189 ..
4190 },
4191 RirNode::MultiWayJoin {
4192 inputs: b_in,
4193 output_columns: b_out,
4194 ..
4195 },
4196 ) => {
4197 assert_eq!(a_in.len(), b_in.len());
4198 assert_eq!(a_out.len(), b_out.len());
4199 }
4200 _ => panic!("optimize() must return a MultiWayJoin"),
4201 }
4202 }
4203 }
4204
4205 #[test]
4206 fn estimate_width_uses_output_columns_arity() {
4207 let optimizer = Optimizer::new(make_stats_manager());
4208 assert_eq!(
4210 optimizer.estimate_width(&build_canonical_triangle_multiway()),
4211 3
4212 );
4213 assert_eq!(optimizer.estimate_width(&build_4input_multiway()), 4);
4217 }
4218
4219 #[test]
4220 fn estimate_cost_sums_input_costs() {
4221 let optimizer = Optimizer::new(make_stats_manager());
4222
4223 let cost_tri = optimizer.estimate_cost(&build_canonical_triangle_multiway());
4226 assert!(
4227 cost_tri.rows >= 16_000,
4228 "expected cost.rows >= 16000, got {}",
4229 cost_tri.rows
4230 );
4231
4232 let cost_4 = optimizer.estimate_cost(&build_4input_multiway());
4237 assert!(
4238 cost_4.rows >= 26_000,
4239 "expected 4-input cost.rows >= 26000, got {}",
4240 cost_4.rows
4241 );
4242 assert!(
4243 cost_4.rows > cost_tri.rows,
4244 "4-input cost ({}) must exceed triangle cost ({})",
4245 cost_4.rows,
4246 cost_tri.rows
4247 );
4248 }
4249
4250 #[test]
4251 fn find_column_relation_returns_none_for_multiway() {
4252 let optimizer = Optimizer::new(make_stats_manager());
4253 for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4260 for col in 0..node.referenced_relations().len() {
4261 assert!(
4262 optimizer.find_column_relation(&node, col).is_none(),
4263 "find_column_relation must return None for any \
4264 MultiWayJoin column (col={})",
4265 col,
4266 );
4267 }
4268 }
4269 }
4270}