1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// A single join condition captured while flattening a join tree.
///
/// Keeps the variable referenced on each side next to the original
/// expressions, so the condition can later be registered with the join-graph
/// builder or rebuilt as a `JoinCondition`.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression.
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// Original left-hand expression of the condition.
    left_expr: LogicalExpression,
    /// Original right-hand expression of the condition.
    right_expr: LogicalExpression,
}
39
/// A column that some operator or expression in the plan refers to.
///
/// Used as the element type of the "required columns" set gathered for
/// projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable binding, identified by its name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
48
/// Rewrites logical plans using filter pushdown, join reordering and
/// projection pushdown, each of which can be switched off individually.
pub struct Optimizer {
    /// Whether `optimize` runs join-predicate propagation and filter pushdown.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs join reordering.
    enable_join_reorder: bool,
    /// Whether `optimize` runs projection pushdown.
    enable_projection_pushdown: bool,
    /// Cost model used for cost estimation and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for cost estimation and join ordering.
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67 #[must_use]
69 pub fn new() -> Self {
70 Self {
71 enable_filter_pushdown: true,
72 enable_join_reorder: true,
73 enable_projection_pushdown: true,
74 cost_model: CostModel::new(),
75 card_estimator: CardinalityEstimator::new(),
76 }
77 }
78
79 #[cfg(feature = "lpg")]
85 #[must_use]
86 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
87 store.ensure_statistics_fresh();
88 let stats = store.statistics();
89 Self::from_statistics(&stats)
90 }
91
92 #[must_use]
99 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
100 let stats = store.statistics();
101 Self::from_statistics(&stats)
102 }
103
104 #[cfg(feature = "triple-store")]
109 #[must_use]
110 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
111 let total = rdf_stats.total_triples;
112 let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
113 Self {
114 enable_filter_pushdown: true,
115 enable_join_reorder: true,
116 enable_projection_pushdown: true,
117 cost_model: CostModel::new().with_graph_totals(total, total),
118 card_estimator: estimator,
119 }
120 }
121
122 #[must_use]
127 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
128 let estimator = CardinalityEstimator::from_statistics(stats);
129
130 let avg_fanout = if stats.total_nodes > 0 {
131 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
132 } else {
133 10.0
134 };
135
136 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
137 .edge_types
138 .iter()
139 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
140 .collect();
141
142 let label_cardinalities: std::collections::HashMap<String, u64> = stats
143 .labels
144 .iter()
145 .map(|(name, ls)| (name.clone(), ls.node_count))
146 .collect();
147
148 Self {
149 enable_filter_pushdown: true,
150 enable_join_reorder: true,
151 enable_projection_pushdown: true,
152 cost_model: CostModel::new()
153 .with_avg_fanout(avg_fanout)
154 .with_edge_type_degrees(edge_type_degrees)
155 .with_label_cardinalities(label_cardinalities)
156 .with_graph_totals(stats.total_nodes, stats.total_edges),
157 card_estimator: estimator,
158 }
159 }
160
161 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
163 self.enable_filter_pushdown = enabled;
164 self
165 }
166
167 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
169 self.enable_join_reorder = enabled;
170 self
171 }
172
173 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
175 self.enable_projection_pushdown = enabled;
176 self
177 }
178
179 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
181 self.cost_model = cost_model;
182 self
183 }
184
185 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
187 self.card_estimator = estimator;
188 self
189 }
190
191 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
193 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
194 self
195 }
196
    /// Returns a reference to the cost model used by this optimizer.
    pub fn cost_model(&self) -> &CostModel {
        &self.cost_model
    }
201
    /// Returns a reference to the cardinality estimator used by this optimizer.
    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
        &self.card_estimator
    }
206
207 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
213 self.cost_model
214 .estimate_tree(&plan.root, &self.card_estimator)
215 }
216
217 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
219 self.card_estimator.estimate(&plan.root)
220 }
221
222 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
228 let _span = grafeo_debug_span!("grafeo::query::optimize");
229 let mut root = plan.root;
230
231 if self.enable_filter_pushdown {
233 root = self.propagate_join_predicates(root);
239 root = self.push_filters_down(root);
240 }
241
242 if self.enable_join_reorder {
243 root = self.reorder_joins(root);
244 }
245
246 if self.enable_projection_pushdown {
247 root = self.push_projections_down(root);
248 }
249
250 Ok(LogicalPlan {
251 root,
252 explain: plan.explain,
253 profile: plan.profile,
254 default_params: plan.default_params,
255 })
256 }
257
258 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
265 let required = self.collect_required_columns(&op);
267
268 self.push_projections_recursive(op, &required)
270 }
271
272 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
274 let mut required = HashSet::new();
275 Self::collect_required_recursive(op, &mut required);
276 required
277 }
278
279 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
281 match op {
282 LogicalOperator::Return(ret) => {
283 for item in &ret.items {
284 Self::collect_from_expression(&item.expression, required);
285 }
286 Self::collect_required_recursive(&ret.input, required);
287 }
288 LogicalOperator::Project(proj) => {
289 for p in &proj.projections {
290 Self::collect_from_expression(&p.expression, required);
291 }
292 Self::collect_required_recursive(&proj.input, required);
293 }
294 LogicalOperator::Filter(filter) => {
295 Self::collect_from_expression(&filter.predicate, required);
296 Self::collect_required_recursive(&filter.input, required);
297 }
298 LogicalOperator::Sort(sort) => {
299 for key in &sort.keys {
300 Self::collect_from_expression(&key.expression, required);
301 }
302 Self::collect_required_recursive(&sort.input, required);
303 }
304 LogicalOperator::Aggregate(agg) => {
305 for expr in &agg.group_by {
306 Self::collect_from_expression(expr, required);
307 }
308 for agg_expr in &agg.aggregates {
309 if let Some(ref expr) = agg_expr.expression {
310 Self::collect_from_expression(expr, required);
311 }
312 }
313 if let Some(ref having) = agg.having {
314 Self::collect_from_expression(having, required);
315 }
316 Self::collect_required_recursive(&agg.input, required);
317 }
318 LogicalOperator::Join(join) => {
319 for cond in &join.conditions {
320 Self::collect_from_expression(&cond.left, required);
321 Self::collect_from_expression(&cond.right, required);
322 }
323 Self::collect_required_recursive(&join.left, required);
324 Self::collect_required_recursive(&join.right, required);
325 }
326 LogicalOperator::Expand(expand) => {
327 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
329 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
330 if let Some(ref edge_var) = expand.edge_variable {
331 required.insert(RequiredColumn::Variable(edge_var.clone()));
332 }
333 Self::collect_required_recursive(&expand.input, required);
334 }
335 LogicalOperator::Limit(limit) => {
336 Self::collect_required_recursive(&limit.input, required);
337 }
338 LogicalOperator::Skip(skip) => {
339 Self::collect_required_recursive(&skip.input, required);
340 }
341 LogicalOperator::Distinct(distinct) => {
342 Self::collect_required_recursive(&distinct.input, required);
343 }
344 LogicalOperator::NodeScan(scan) => {
345 required.insert(RequiredColumn::Variable(scan.variable.clone()));
346 }
347 LogicalOperator::EdgeScan(scan) => {
348 required.insert(RequiredColumn::Variable(scan.variable.clone()));
349 }
350 LogicalOperator::MultiWayJoin(mwj) => {
351 for cond in &mwj.conditions {
352 Self::collect_from_expression(&cond.left, required);
353 Self::collect_from_expression(&cond.right, required);
354 }
355 for input in &mwj.inputs {
356 Self::collect_required_recursive(input, required);
357 }
358 }
359 _ => {}
360 }
361 }
362
363 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
365 match expr {
366 LogicalExpression::Variable(var) => {
367 required.insert(RequiredColumn::Variable(var.clone()));
368 }
369 LogicalExpression::Property { variable, property } => {
370 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
371 required.insert(RequiredColumn::Variable(variable.clone()));
372 }
373 LogicalExpression::Binary { left, right, .. } => {
374 Self::collect_from_expression(left, required);
375 Self::collect_from_expression(right, required);
376 }
377 LogicalExpression::Unary { operand, .. } => {
378 Self::collect_from_expression(operand, required);
379 }
380 LogicalExpression::FunctionCall { args, .. } => {
381 for arg in args {
382 Self::collect_from_expression(arg, required);
383 }
384 }
385 LogicalExpression::List(items) => {
386 for item in items {
387 Self::collect_from_expression(item, required);
388 }
389 }
390 LogicalExpression::Map(pairs) => {
391 for (_, value) in pairs {
392 Self::collect_from_expression(value, required);
393 }
394 }
395 LogicalExpression::IndexAccess { base, index } => {
396 Self::collect_from_expression(base, required);
397 Self::collect_from_expression(index, required);
398 }
399 LogicalExpression::SliceAccess { base, start, end } => {
400 Self::collect_from_expression(base, required);
401 if let Some(s) = start {
402 Self::collect_from_expression(s, required);
403 }
404 if let Some(e) = end {
405 Self::collect_from_expression(e, required);
406 }
407 }
408 LogicalExpression::Case {
409 operand,
410 when_clauses,
411 else_clause,
412 } => {
413 if let Some(op) = operand {
414 Self::collect_from_expression(op, required);
415 }
416 for (cond, result) in when_clauses {
417 Self::collect_from_expression(cond, required);
418 Self::collect_from_expression(result, required);
419 }
420 if let Some(else_expr) = else_clause {
421 Self::collect_from_expression(else_expr, required);
422 }
423 }
424 LogicalExpression::Labels(var)
425 | LogicalExpression::Type(var)
426 | LogicalExpression::Id(var) => {
427 required.insert(RequiredColumn::Variable(var.clone()));
428 }
429 LogicalExpression::ListComprehension {
430 list_expr,
431 filter_expr,
432 map_expr,
433 ..
434 } => {
435 Self::collect_from_expression(list_expr, required);
436 if let Some(filter) = filter_expr {
437 Self::collect_from_expression(filter, required);
438 }
439 Self::collect_from_expression(map_expr, required);
440 }
441 _ => {}
442 }
443 }
444
445 fn push_projections_recursive(
447 &self,
448 op: LogicalOperator,
449 required: &HashSet<RequiredColumn>,
450 ) -> LogicalOperator {
451 match op {
452 LogicalOperator::Return(mut ret) => {
453 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
454 LogicalOperator::Return(ret)
455 }
456 LogicalOperator::Project(mut proj) => {
457 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
458 LogicalOperator::Project(proj)
459 }
460 LogicalOperator::Filter(mut filter) => {
461 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
462 LogicalOperator::Filter(filter)
463 }
464 LogicalOperator::Sort(mut sort) => {
465 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
468 LogicalOperator::Sort(sort)
469 }
470 LogicalOperator::Aggregate(mut agg) => {
471 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
472 LogicalOperator::Aggregate(agg)
473 }
474 LogicalOperator::Join(mut join) => {
475 let left_vars = self.collect_output_variables(&join.left);
478 let right_vars = self.collect_output_variables(&join.right);
479
480 let left_required: HashSet<_> = required
482 .iter()
483 .filter(|c| match c {
484 RequiredColumn::Variable(v) => left_vars.contains(v),
485 RequiredColumn::Property(v, _) => left_vars.contains(v),
486 })
487 .cloned()
488 .collect();
489
490 let right_required: HashSet<_> = required
491 .iter()
492 .filter(|c| match c {
493 RequiredColumn::Variable(v) => right_vars.contains(v),
494 RequiredColumn::Property(v, _) => right_vars.contains(v),
495 })
496 .cloned()
497 .collect();
498
499 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
500 join.right =
501 Box::new(self.push_projections_recursive(*join.right, &right_required));
502 LogicalOperator::Join(join)
503 }
504 LogicalOperator::Expand(mut expand) => {
505 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
506 LogicalOperator::Expand(expand)
507 }
508 LogicalOperator::Limit(mut limit) => {
509 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
510 LogicalOperator::Limit(limit)
511 }
512 LogicalOperator::Skip(mut skip) => {
513 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
514 LogicalOperator::Skip(skip)
515 }
516 LogicalOperator::Distinct(mut distinct) => {
517 distinct.input =
518 Box::new(self.push_projections_recursive(*distinct.input, required));
519 LogicalOperator::Distinct(distinct)
520 }
521 LogicalOperator::MapCollect(mut mc) => {
522 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
523 LogicalOperator::MapCollect(mc)
524 }
525 LogicalOperator::MultiWayJoin(mut mwj) => {
526 mwj.inputs = mwj
527 .inputs
528 .into_iter()
529 .map(|input| self.push_projections_recursive(input, required))
530 .collect();
531 LogicalOperator::MultiWayJoin(mwj)
532 }
533 other => other,
534 }
535 }
536
537 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
544 let op = self.reorder_joins_recursive(op);
546
547 if let Some((relations, conditions)) = self.extract_join_tree(&op)
549 && relations.len() >= 2
550 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
551 {
552 return optimized;
553 }
554
555 op
556 }
557
558 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
560 match op {
561 LogicalOperator::Return(mut ret) => {
562 ret.input = Box::new(self.reorder_joins(*ret.input));
563 LogicalOperator::Return(ret)
564 }
565 LogicalOperator::Project(mut proj) => {
566 proj.input = Box::new(self.reorder_joins(*proj.input));
567 LogicalOperator::Project(proj)
568 }
569 LogicalOperator::Filter(mut filter) => {
570 filter.input = Box::new(self.reorder_joins(*filter.input));
571 LogicalOperator::Filter(filter)
572 }
573 LogicalOperator::Limit(mut limit) => {
574 limit.input = Box::new(self.reorder_joins(*limit.input));
575 LogicalOperator::Limit(limit)
576 }
577 LogicalOperator::Skip(mut skip) => {
578 skip.input = Box::new(self.reorder_joins(*skip.input));
579 LogicalOperator::Skip(skip)
580 }
581 LogicalOperator::Sort(mut sort) => {
582 sort.input = Box::new(self.reorder_joins(*sort.input));
583 LogicalOperator::Sort(sort)
584 }
585 LogicalOperator::Distinct(mut distinct) => {
586 distinct.input = Box::new(self.reorder_joins(*distinct.input));
587 LogicalOperator::Distinct(distinct)
588 }
589 LogicalOperator::Aggregate(mut agg) => {
590 agg.input = Box::new(self.reorder_joins(*agg.input));
591 LogicalOperator::Aggregate(agg)
592 }
593 LogicalOperator::Expand(mut expand) => {
594 expand.input = Box::new(self.reorder_joins(*expand.input));
595 LogicalOperator::Expand(expand)
596 }
597 LogicalOperator::MapCollect(mut mc) => {
598 mc.input = Box::new(self.reorder_joins(*mc.input));
599 LogicalOperator::MapCollect(mc)
600 }
601 LogicalOperator::MultiWayJoin(mut mwj) => {
602 mwj.inputs = mwj
603 .inputs
604 .into_iter()
605 .map(|input| self.reorder_joins(input))
606 .collect();
607 LogicalOperator::MultiWayJoin(mwj)
608 }
609 other => other,
611 }
612 }
613
614 fn extract_join_tree(
618 &self,
619 op: &LogicalOperator,
620 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
621 let mut relations = Vec::new();
622 let mut join_conditions = Vec::new();
623
624 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
625 return None;
626 }
627
628 if relations.len() < 2 {
629 return None;
630 }
631
632 Some((relations, join_conditions))
633 }
634
635 fn collect_join_tree(
639 &self,
640 op: &LogicalOperator,
641 relations: &mut Vec<(String, LogicalOperator)>,
642 conditions: &mut Vec<JoinInfo>,
643 ) -> bool {
644 match op {
645 LogicalOperator::Join(join) => {
646 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
648 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
649
650 for cond in &join.conditions {
652 if let (Some(left_var), Some(right_var)) = (
653 self.extract_variable_from_expr(&cond.left),
654 self.extract_variable_from_expr(&cond.right),
655 ) {
656 conditions.push(JoinInfo {
657 left_var,
658 right_var,
659 left_expr: cond.left.clone(),
660 right_expr: cond.right.clone(),
661 });
662 }
663 }
664
665 left_ok && right_ok
666 }
667 LogicalOperator::NodeScan(scan) => {
668 relations.push((scan.variable.clone(), op.clone()));
669 true
670 }
671 LogicalOperator::EdgeScan(scan) => {
672 relations.push((scan.variable.clone(), op.clone()));
673 true
674 }
675 LogicalOperator::Filter(filter) => {
676 self.collect_join_tree(&filter.input, relations, conditions)
678 }
679 LogicalOperator::Expand(expand) => {
680 relations.push((expand.to_variable.clone(), op.clone()));
683 true
684 }
685 #[cfg(feature = "triple-store")]
686 LogicalOperator::TripleScan(scan) => {
687 let name = scan
691 .subject
692 .as_variable()
693 .or_else(|| scan.predicate.as_variable())
694 .or_else(|| scan.object.as_variable())
695 .map_or_else(|| format!("__tp_{}", relations.len()), String::from);
696 relations.push((name, op.clone()));
697 true
698 }
699 _ => false,
700 }
701 }
702
703 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
705 match expr {
706 LogicalExpression::Variable(v) => Some(v.clone()),
707 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
708 _ => None,
709 }
710 }
711
712 fn optimize_join_order(
715 &self,
716 relations: &[(String, LogicalOperator)],
717 conditions: &[JoinInfo],
718 ) -> Option<LogicalOperator> {
719 use join_order::{DPccp, JoinGraphBuilder};
720
721 let mut builder = JoinGraphBuilder::new();
723
724 for (var, relation) in relations {
725 builder.add_relation(var, relation.clone());
726 }
727
728 for cond in conditions {
729 builder.add_join_condition(
730 &cond.left_var,
731 &cond.right_var,
732 cond.left_expr.clone(),
733 cond.right_expr.clone(),
734 );
735 }
736
737 let graph = builder.build();
738
739 if graph.is_cyclic() && relations.len() >= 3 {
744 let mut var_counts: std::collections::HashMap<&str, usize> =
746 std::collections::HashMap::new();
747 for cond in conditions {
748 *var_counts.entry(&cond.left_var).or_default() += 1;
749 *var_counts.entry(&cond.right_var).or_default() += 1;
750 }
751 let shared_variables: Vec<String> = var_counts
752 .into_iter()
753 .filter(|(_, count)| *count >= 2)
754 .map(|(var, _)| var.to_string())
755 .collect();
756
757 let join_conditions: Vec<JoinCondition> = conditions
758 .iter()
759 .map(|c| JoinCondition {
760 left: c.left_expr.clone(),
761 right: c.right_expr.clone(),
762 })
763 .collect();
764
765 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
766 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
767 conditions: join_conditions,
768 shared_variables,
769 }));
770 }
771
772 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
774 let plan = dpccp.optimize()?;
775
776 Some(plan.operator)
777 }
778
779 fn propagate_join_predicates(&self, op: LogicalOperator) -> LogicalOperator {
798 let op = op.map_children(|child| self.propagate_join_predicates(child));
801
802 let LogicalOperator::LeftJoin(mut left_join) = op else {
803 return op;
804 };
805
806 let left_vars = self.collect_output_variables(&left_join.left);
810 let right_vars = self.collect_output_variables(&left_join.right);
811 let shared_vars: HashSet<String> = left_vars.intersection(&right_vars).cloned().collect();
812
813 if shared_vars.is_empty() {
814 return LogicalOperator::LeftJoin(left_join);
815 }
816
817 let mut shared_filters = Vec::new();
818 self.collect_shared_var_filters(&left_join.left, &shared_vars, &mut shared_filters);
819 for predicate in shared_filters {
820 left_join.right = Box::new(LogicalOperator::Filter(FilterOp {
821 predicate,
822 pushdown_hint: None,
823 input: left_join.right,
824 }));
825 }
826
827 LogicalOperator::LeftJoin(left_join)
828 }
829
830 fn collect_shared_var_filters(
834 &self,
835 op: &LogicalOperator,
836 shared_vars: &HashSet<String>,
837 out: &mut Vec<LogicalExpression>,
838 ) {
839 match op {
840 LogicalOperator::Filter(f) => {
841 let predicate_vars = self.extract_variables(&f.predicate);
842 if !predicate_vars.is_empty()
843 && predicate_vars.iter().all(|v| shared_vars.contains(v))
844 {
845 out.push(f.predicate.clone());
846 }
847 self.collect_shared_var_filters(&f.input, shared_vars, out);
848 }
849 LogicalOperator::Project(p) => {
855 self.collect_shared_var_filters(&p.input, shared_vars, out);
856 }
857 LogicalOperator::Return(r) => {
858 self.collect_shared_var_filters(&r.input, shared_vars, out);
859 }
860 LogicalOperator::Expand(e) => {
861 self.collect_shared_var_filters(&e.input, shared_vars, out);
862 }
863 LogicalOperator::LeftJoin(j) => {
871 self.collect_shared_var_filters(&j.left, shared_vars, out);
872 }
873 LogicalOperator::Join(j) => {
874 self.collect_shared_var_filters(&j.left, shared_vars, out);
875 self.collect_shared_var_filters(&j.right, shared_vars, out);
876 }
877 _ => {}
880 }
881 }
882
883 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
888 match op {
889 LogicalOperator::Filter(filter) => {
891 let optimized_input = self.push_filters_down(*filter.input);
892 self.try_push_filter_into(filter.predicate, optimized_input)
893 }
894 LogicalOperator::Return(mut ret) => {
896 ret.input = Box::new(self.push_filters_down(*ret.input));
897 LogicalOperator::Return(ret)
898 }
899 LogicalOperator::Project(mut proj) => {
900 proj.input = Box::new(self.push_filters_down(*proj.input));
901 LogicalOperator::Project(proj)
902 }
903 LogicalOperator::Limit(mut limit) => {
904 limit.input = Box::new(self.push_filters_down(*limit.input));
905 LogicalOperator::Limit(limit)
906 }
907 LogicalOperator::Skip(mut skip) => {
908 skip.input = Box::new(self.push_filters_down(*skip.input));
909 LogicalOperator::Skip(skip)
910 }
911 LogicalOperator::Sort(mut sort) => {
912 sort.input = Box::new(self.push_filters_down(*sort.input));
913 LogicalOperator::Sort(sort)
914 }
915 LogicalOperator::Distinct(mut distinct) => {
916 distinct.input = Box::new(self.push_filters_down(*distinct.input));
917 LogicalOperator::Distinct(distinct)
918 }
919 LogicalOperator::Expand(mut expand) => {
920 expand.input = Box::new(self.push_filters_down(*expand.input));
921 LogicalOperator::Expand(expand)
922 }
923 LogicalOperator::Join(mut join) => {
924 join.left = Box::new(self.push_filters_down(*join.left));
925 join.right = Box::new(self.push_filters_down(*join.right));
926 LogicalOperator::Join(join)
927 }
928 LogicalOperator::LeftJoin(mut left_join) => {
929 left_join.left = Box::new(self.push_filters_down(*left_join.left));
930 left_join.right = Box::new(self.push_filters_down(*left_join.right));
931 LogicalOperator::LeftJoin(left_join)
932 }
933 LogicalOperator::AntiJoin(mut anti_join) => {
934 anti_join.left = Box::new(self.push_filters_down(*anti_join.left));
935 anti_join.right = Box::new(self.push_filters_down(*anti_join.right));
936 LogicalOperator::AntiJoin(anti_join)
937 }
938 LogicalOperator::Apply(mut apply) => {
939 apply.input = Box::new(self.push_filters_down(*apply.input));
940 apply.subplan = Box::new(self.push_filters_down(*apply.subplan));
941 LogicalOperator::Apply(apply)
942 }
943 LogicalOperator::Union(mut union) => {
944 union.inputs = union
945 .inputs
946 .into_iter()
947 .map(|input| self.push_filters_down(input))
948 .collect();
949 LogicalOperator::Union(union)
950 }
951 LogicalOperator::Unwind(mut unwind) => {
952 unwind.input = Box::new(self.push_filters_down(*unwind.input));
953 LogicalOperator::Unwind(unwind)
954 }
955 LogicalOperator::Aggregate(mut agg) => {
956 agg.input = Box::new(self.push_filters_down(*agg.input));
957 LogicalOperator::Aggregate(agg)
958 }
959 LogicalOperator::MapCollect(mut mc) => {
960 mc.input = Box::new(self.push_filters_down(*mc.input));
961 LogicalOperator::MapCollect(mc)
962 }
963 LogicalOperator::MultiWayJoin(mut mwj) => {
964 mwj.inputs = mwj
965 .inputs
966 .into_iter()
967 .map(|input| self.push_filters_down(input))
968 .collect();
969 LogicalOperator::MultiWayJoin(mwj)
970 }
971 other => other,
973 }
974 }
975
    /// Places `predicate` as deep inside `op` as the rules below allow and
    /// returns the resulting operator.
    ///
    /// Whenever the predicate cannot move past `op`, a new `Filter` (with no
    /// pushdown hint) is created directly above `op`.
    ///
    /// NOTE(review): the join-like arms decide sides using
    /// `collect_output_variables`, which ignores several operator kinds. A
    /// variable bound only by such an operator is therefore not seen as
    /// belonging to its side — confirm that this cannot push a predicate
    /// below the operator that binds one of its variables.
    fn try_push_filter_into(
        &self,
        predicate: LogicalExpression,
        op: LogicalOperator,
    ) -> LogicalOperator {
        match op {
            LogicalOperator::Project(mut proj) => {
                let predicate_vars = self.extract_variables(&predicate);
                let computed_vars = self.extract_projection_aliases(&proj.projections);

                // Only move below the projection when the predicate does not
                // mention any alias the projection itself computes.
                if predicate_vars.is_disjoint(&computed_vars) {
                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
                    LogicalOperator::Project(proj)
                } else {
                    LogicalOperator::Filter(FilterOp {
                        predicate,
                        pushdown_hint: None,
                        input: Box::new(LogicalOperator::Project(proj)),
                    })
                }
            }

            // A `Return` never blocks a predicate.
            LogicalOperator::Return(mut ret) => {
                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
                LogicalOperator::Return(ret)
            }

            LogicalOperator::Expand(mut expand) => {
                let predicate_vars = self.extract_variables(&predicate);

                // Variables that only exist once the expand has run: the target
                // variable plus the optional edge variable and path alias.
                let mut introduced_vars = vec![&expand.to_variable];
                if let Some(ref edge_var) = expand.edge_variable {
                    introduced_vars.push(edge_var);
                }
                if let Some(ref path_alias) = expand.path_alias {
                    introduced_vars.push(path_alias);
                }

                let uses_introduced_vars =
                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));

                if !uses_introduced_vars {
                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
                    LogicalOperator::Expand(expand)
                } else {
                    LogicalOperator::Filter(FilterOp {
                        predicate,
                        pushdown_hint: None,
                        input: Box::new(LogicalOperator::Expand(expand)),
                    })
                }
            }

            LogicalOperator::Join(mut join) => {
                let predicate_vars = self.extract_variables(&predicate);
                let left_vars = self.collect_output_variables(&join.left);
                let right_vars = self.collect_output_variables(&join.right);

                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));

                // Push into one side only when the other side's variables are
                // not mentioned; otherwise (both or neither) stay above the join.
                if uses_left && !uses_right {
                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
                    LogicalOperator::Join(join)
                } else if uses_right && !uses_left {
                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
                    LogicalOperator::Join(join)
                } else {
                    LogicalOperator::Filter(FilterOp {
                        predicate,
                        pushdown_hint: None,
                        input: Box::new(LogicalOperator::Join(join)),
                    })
                }
            }

            LogicalOperator::LeftJoin(mut left_join) => {
                let predicate_vars = self.extract_variables(&predicate);
                let left_vars = self.collect_output_variables(&left_join.left);
                let right_vars = self.collect_output_variables(&left_join.right);

                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));

                if uses_left && !uses_right {
                    // Predicate mentions left-side variables only.
                    left_join.left =
                        Box::new(self.try_push_filter_into(predicate, *left_join.left));
                    LogicalOperator::LeftJoin(left_join)
                } else if uses_left
                    && uses_right
                    && predicate_vars
                        .iter()
                        .all(|v| left_vars.contains(v) && right_vars.contains(v))
                {
                    // Every predicate variable is produced by both sides: apply
                    // the predicate to each input.
                    left_join.left =
                        Box::new(self.try_push_filter_into(predicate.clone(), *left_join.left));
                    left_join.right =
                        Box::new(self.try_push_filter_into(predicate, *left_join.right));
                    LogicalOperator::LeftJoin(left_join)
                } else {
                    // Anything else, including right-only predicates, stays
                    // above the left join.
                    LogicalOperator::Filter(FilterOp {
                        predicate,
                        pushdown_hint: None,
                        input: Box::new(LogicalOperator::LeftJoin(left_join)),
                    })
                }
            }

            LogicalOperator::Apply(mut apply) => {
                let predicate_vars = self.extract_variables(&predicate);
                let input_vars = self.collect_output_variables(&apply.input);
                let subplan_vars = self.collect_output_variables(&apply.subplan);

                let uses_input = predicate_vars.iter().any(|v| input_vars.contains(v));
                let uses_subplan = predicate_vars.iter().any(|v| subplan_vars.contains(v));

                // Only the input side is ever a pushdown target; the predicate
                // is never pushed into the subplan.
                if uses_input && !uses_subplan {
                    apply.input = Box::new(self.try_push_filter_into(predicate, *apply.input));
                    LogicalOperator::Apply(apply)
                } else {
                    LogicalOperator::Filter(FilterOp {
                        predicate,
                        pushdown_hint: None,
                        input: Box::new(LogicalOperator::Apply(apply)),
                    })
                }
            }

            // A predicate is never moved below an aggregate.
            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
                predicate,
                pushdown_hint: None,
                input: Box::new(LogicalOperator::Aggregate(agg)),
            }),

            // A scan is a leaf: the filter sits directly on top of it.
            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
                predicate,
                pushdown_hint: None,
                input: Box::new(LogicalOperator::NodeScan(scan)),
            }),

            LogicalOperator::Filter(inner_filter) => {
                // Swap with the inner filter only when every predicate variable
                // is reported as an output of the inner filter's input.
                let predicate_vars = self.extract_variables(&predicate);
                let inner_input_vars = self.collect_output_variables(&inner_filter.input);
                let safe_to_commute = predicate_vars.iter().all(|v| inner_input_vars.contains(v));
                if safe_to_commute {
                    let mut inner_filter = inner_filter;
                    inner_filter.input =
                        Box::new(self.try_push_filter_into(predicate, *inner_filter.input));
                    LogicalOperator::Filter(inner_filter)
                } else {
                    LogicalOperator::Filter(FilterOp {
                        predicate,
                        pushdown_hint: None,
                        input: Box::new(LogicalOperator::Filter(inner_filter)),
                    })
                }
            }

            // Any other operator blocks the predicate.
            other => LogicalOperator::Filter(FilterOp {
                predicate,
                pushdown_hint: None,
                input: Box::new(other),
            }),
        }
    }
1203
1204 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
1212 let mut vars = HashSet::new();
1213 Self::collect_output_variables_recursive(op, &mut vars);
1214 vars
1215 }
1216
1217 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
1219 match op {
1220 LogicalOperator::NodeScan(scan) => {
1221 vars.insert(scan.variable.clone());
1222 }
1223 LogicalOperator::EdgeScan(scan) => {
1224 vars.insert(scan.variable.clone());
1225 }
1226 LogicalOperator::Expand(expand) => {
1227 vars.insert(expand.to_variable.clone());
1228 if let Some(edge_var) = &expand.edge_variable {
1229 vars.insert(edge_var.clone());
1230 }
1231 Self::collect_output_variables_recursive(&expand.input, vars);
1232 }
1233 LogicalOperator::Filter(filter) => {
1234 Self::collect_output_variables_recursive(&filter.input, vars);
1235 }
1236 LogicalOperator::Project(proj) => {
1237 for p in &proj.projections {
1238 if let Some(alias) = &p.alias {
1239 vars.insert(alias.clone());
1240 }
1241 }
1242 Self::collect_output_variables_recursive(&proj.input, vars);
1243 }
1244 LogicalOperator::Join(join) => {
1245 Self::collect_output_variables_recursive(&join.left, vars);
1246 Self::collect_output_variables_recursive(&join.right, vars);
1247 }
1248 LogicalOperator::Aggregate(agg) => {
1249 for expr in &agg.group_by {
1250 Self::collect_variables(expr, vars);
1251 }
1252 for agg_expr in &agg.aggregates {
1253 if let Some(alias) = &agg_expr.alias {
1254 vars.insert(alias.clone());
1255 }
1256 }
1257 }
1258 LogicalOperator::Return(ret) => {
1259 Self::collect_output_variables_recursive(&ret.input, vars);
1260 }
1261 LogicalOperator::Limit(limit) => {
1262 Self::collect_output_variables_recursive(&limit.input, vars);
1263 }
1264 LogicalOperator::Skip(skip) => {
1265 Self::collect_output_variables_recursive(&skip.input, vars);
1266 }
1267 LogicalOperator::Sort(sort) => {
1268 Self::collect_output_variables_recursive(&sort.input, vars);
1269 }
1270 LogicalOperator::Distinct(distinct) => {
1271 Self::collect_output_variables_recursive(&distinct.input, vars);
1272 }
1273 #[cfg(feature = "triple-store")]
1274 LogicalOperator::TripleScan(scan) => {
1275 if let Some(v) = scan.subject.as_variable() {
1276 vars.insert(v.to_string());
1277 }
1278 if let Some(v) = scan.predicate.as_variable() {
1279 vars.insert(v.to_string());
1280 }
1281 if let Some(v) = scan.object.as_variable() {
1282 vars.insert(v.to_string());
1283 }
1284 if let Some(ref g) = scan.graph
1285 && let Some(v) = g.as_variable()
1286 {
1287 vars.insert(v.to_string());
1288 }
1289 }
1290 _ => {}
1291 }
1292 }
1293
1294 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1296 let mut vars = HashSet::new();
1297 Self::collect_variables(expr, &mut vars);
1298 vars
1299 }
1300
1301 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1303 match expr {
1304 LogicalExpression::Variable(name) => {
1305 vars.insert(name.clone());
1306 }
1307 LogicalExpression::Property { variable, .. } => {
1308 vars.insert(variable.clone());
1309 }
1310 LogicalExpression::Binary { left, right, .. } => {
1311 Self::collect_variables(left, vars);
1312 Self::collect_variables(right, vars);
1313 }
1314 LogicalExpression::Unary { operand, .. } => {
1315 Self::collect_variables(operand, vars);
1316 }
1317 LogicalExpression::FunctionCall { args, .. } => {
1318 for arg in args {
1319 Self::collect_variables(arg, vars);
1320 }
1321 }
1322 LogicalExpression::List(items) => {
1323 for item in items {
1324 Self::collect_variables(item, vars);
1325 }
1326 }
1327 LogicalExpression::Map(pairs) => {
1328 for (_, value) in pairs {
1329 Self::collect_variables(value, vars);
1330 }
1331 }
1332 LogicalExpression::IndexAccess { base, index } => {
1333 Self::collect_variables(base, vars);
1334 Self::collect_variables(index, vars);
1335 }
1336 LogicalExpression::SliceAccess { base, start, end } => {
1337 Self::collect_variables(base, vars);
1338 if let Some(s) = start {
1339 Self::collect_variables(s, vars);
1340 }
1341 if let Some(e) = end {
1342 Self::collect_variables(e, vars);
1343 }
1344 }
1345 LogicalExpression::Case {
1346 operand,
1347 when_clauses,
1348 else_clause,
1349 } => {
1350 if let Some(op) = operand {
1351 Self::collect_variables(op, vars);
1352 }
1353 for (cond, result) in when_clauses {
1354 Self::collect_variables(cond, vars);
1355 Self::collect_variables(result, vars);
1356 }
1357 if let Some(else_expr) = else_clause {
1358 Self::collect_variables(else_expr, vars);
1359 }
1360 }
1361 LogicalExpression::Labels(var)
1362 | LogicalExpression::Type(var)
1363 | LogicalExpression::Id(var) => {
1364 vars.insert(var.clone());
1365 }
1366 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1367 LogicalExpression::ListComprehension {
1368 list_expr,
1369 filter_expr,
1370 map_expr,
1371 ..
1372 } => {
1373 Self::collect_variables(list_expr, vars);
1374 if let Some(filter) = filter_expr {
1375 Self::collect_variables(filter, vars);
1376 }
1377 Self::collect_variables(map_expr, vars);
1378 }
1379 LogicalExpression::ListPredicate {
1380 list_expr,
1381 predicate,
1382 ..
1383 } => {
1384 Self::collect_variables(list_expr, vars);
1385 Self::collect_variables(predicate, vars);
1386 }
1387 LogicalExpression::ExistsSubquery(_)
1388 | LogicalExpression::CountSubquery(_)
1389 | LogicalExpression::ValueSubquery(_) => {
1390 }
1392 LogicalExpression::PatternComprehension { projection, .. } => {
1393 Self::collect_variables(projection, vars);
1394 }
1395 LogicalExpression::MapProjection { base, entries } => {
1396 vars.insert(base.clone());
1397 for entry in entries {
1398 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1399 Self::collect_variables(expr, vars);
1400 }
1401 }
1402 }
1403 LogicalExpression::Reduce {
1404 initial,
1405 list,
1406 expression,
1407 ..
1408 } => {
1409 Self::collect_variables(initial, vars);
1410 Self::collect_variables(list, vars);
1411 Self::collect_variables(expression, vars);
1412 }
1413 }
1414 }
1415
1416 fn extract_projection_aliases(
1418 &self,
1419 projections: &[crate::query::plan::Projection],
1420 ) -> HashSet<String> {
1421 projections.iter().filter_map(|p| p.alias.clone()).collect()
1422 }
1423}
1424
1425impl Default for Optimizer {
1426 fn default() -> Self {
1427 Self::new()
1428 }
1429}
1430
1431#[cfg(test)]
1432mod tests {
1433 use super::*;
1434 use crate::query::plan::{
1435 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1436 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1437 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1438 };
1439 use grafeo_common::types::Value;
1440
1441 #[test]
1442 fn test_optimizer_filter_pushdown_simple() {
1443 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1448 items: vec![ReturnItem {
1449 expression: LogicalExpression::Variable("n".to_string()),
1450 alias: None,
1451 }],
1452 distinct: false,
1453 input: Box::new(LogicalOperator::Filter(FilterOp {
1454 predicate: LogicalExpression::Binary {
1455 left: Box::new(LogicalExpression::Property {
1456 variable: "n".to_string(),
1457 property: "age".to_string(),
1458 }),
1459 op: BinaryOp::Gt,
1460 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1461 },
1462 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1463 variable: "n".to_string(),
1464 label: Some("Person".to_string()),
1465 input: None,
1466 })),
1467 pushdown_hint: None,
1468 })),
1469 }));
1470
1471 let optimizer = Optimizer::new();
1472 let optimized = optimizer.optimize(plan).unwrap();
1473
1474 if let LogicalOperator::Return(ret) = &optimized.root
1476 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1477 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1478 {
1479 assert_eq!(scan.variable, "n");
1480 return;
1481 }
1482 panic!("Expected Return -> Filter -> NodeScan structure");
1483 }
1484
1485 #[test]
1486 fn test_optimizer_filter_pushdown_through_expand() {
1487 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1491 items: vec![ReturnItem {
1492 expression: LogicalExpression::Variable("b".to_string()),
1493 alias: None,
1494 }],
1495 distinct: false,
1496 input: Box::new(LogicalOperator::Filter(FilterOp {
1497 predicate: LogicalExpression::Binary {
1498 left: Box::new(LogicalExpression::Property {
1499 variable: "a".to_string(),
1500 property: "age".to_string(),
1501 }),
1502 op: BinaryOp::Gt,
1503 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1504 },
1505 pushdown_hint: None,
1506 input: Box::new(LogicalOperator::Expand(ExpandOp {
1507 from_variable: "a".to_string(),
1508 to_variable: "b".to_string(),
1509 edge_variable: None,
1510 direction: ExpandDirection::Outgoing,
1511 edge_types: vec!["KNOWS".to_string()],
1512 min_hops: 1,
1513 max_hops: Some(1),
1514 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1515 variable: "a".to_string(),
1516 label: Some("Person".to_string()),
1517 input: None,
1518 })),
1519 path_alias: None,
1520 path_mode: PathMode::Walk,
1521 })),
1522 })),
1523 }));
1524
1525 let optimizer = Optimizer::new();
1526 let optimized = optimizer.optimize(plan).unwrap();
1527
1528 if let LogicalOperator::Return(ret) = &optimized.root
1531 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1532 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1533 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1534 {
1535 assert_eq!(scan.variable, "a");
1536 assert_eq!(expand.from_variable, "a");
1537 assert_eq!(expand.to_variable, "b");
1538 return;
1539 }
1540 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1541 }
1542
1543 #[test]
1544 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1545 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1549 items: vec![ReturnItem {
1550 expression: LogicalExpression::Variable("a".to_string()),
1551 alias: None,
1552 }],
1553 distinct: false,
1554 input: Box::new(LogicalOperator::Filter(FilterOp {
1555 predicate: LogicalExpression::Binary {
1556 left: Box::new(LogicalExpression::Property {
1557 variable: "b".to_string(),
1558 property: "age".to_string(),
1559 }),
1560 op: BinaryOp::Gt,
1561 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1562 },
1563 pushdown_hint: None,
1564 input: Box::new(LogicalOperator::Expand(ExpandOp {
1565 from_variable: "a".to_string(),
1566 to_variable: "b".to_string(),
1567 edge_variable: None,
1568 direction: ExpandDirection::Outgoing,
1569 edge_types: vec!["KNOWS".to_string()],
1570 min_hops: 1,
1571 max_hops: Some(1),
1572 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1573 variable: "a".to_string(),
1574 label: Some("Person".to_string()),
1575 input: None,
1576 })),
1577 path_alias: None,
1578 path_mode: PathMode::Walk,
1579 })),
1580 })),
1581 }));
1582
1583 let optimizer = Optimizer::new();
1584 let optimized = optimizer.optimize(plan).unwrap();
1585
1586 if let LogicalOperator::Return(ret) = &optimized.root
1589 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1590 {
1591 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1593 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1594 {
1595 assert_eq!(variable, "b");
1596 }
1597
1598 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1599 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1600 {
1601 return;
1602 }
1603 }
1604 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1605 }
1606
1607 #[test]
1608 fn test_optimizer_extract_variables() {
1609 let optimizer = Optimizer::new();
1610
1611 let expr = LogicalExpression::Binary {
1612 left: Box::new(LogicalExpression::Property {
1613 variable: "n".to_string(),
1614 property: "age".to_string(),
1615 }),
1616 op: BinaryOp::Gt,
1617 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1618 };
1619
1620 let vars = optimizer.extract_variables(&expr);
1621 assert_eq!(vars.len(), 1);
1622 assert!(vars.contains("n"));
1623 }
1624
1625 #[test]
1628 fn test_optimizer_default() {
1629 let optimizer = Optimizer::default();
1630 let plan = LogicalPlan::new(LogicalOperator::Empty);
1632 let result = optimizer.optimize(plan);
1633 assert!(result.is_ok());
1634 }
1635
1636 #[test]
1637 fn test_optimizer_with_filter_pushdown_disabled() {
1638 let optimizer = Optimizer::new().with_filter_pushdown(false);
1639
1640 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1641 items: vec![ReturnItem {
1642 expression: LogicalExpression::Variable("n".to_string()),
1643 alias: None,
1644 }],
1645 distinct: false,
1646 input: Box::new(LogicalOperator::Filter(FilterOp {
1647 predicate: LogicalExpression::Literal(Value::Bool(true)),
1648 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1649 variable: "n".to_string(),
1650 label: None,
1651 input: None,
1652 })),
1653 pushdown_hint: None,
1654 })),
1655 }));
1656
1657 let optimized = optimizer.optimize(plan).unwrap();
1658 if let LogicalOperator::Return(ret) = &optimized.root
1660 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1661 {
1662 return;
1663 }
1664 panic!("Expected unchanged structure");
1665 }
1666
1667 #[test]
1668 fn test_optimizer_with_join_reorder_disabled() {
1669 let optimizer = Optimizer::new().with_join_reorder(false);
1670 assert!(
1671 optimizer
1672 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1673 .is_ok()
1674 );
1675 }
1676
1677 #[test]
1678 fn test_optimizer_with_cost_model() {
1679 let cost_model = CostModel::new();
1680 let optimizer = Optimizer::new().with_cost_model(cost_model);
1681 assert!(
1682 optimizer
1683 .cost_model()
1684 .estimate(&LogicalOperator::Empty, 0.0)
1685 .total()
1686 < 0.001
1687 );
1688 }
1689
1690 #[test]
1691 fn test_optimizer_with_cardinality_estimator() {
1692 let mut estimator = CardinalityEstimator::new();
1693 estimator.add_table_stats("Test", TableStats::new(500));
1694 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1695
1696 let scan = LogicalOperator::NodeScan(NodeScanOp {
1697 variable: "n".to_string(),
1698 label: Some("Test".to_string()),
1699 input: None,
1700 });
1701 let plan = LogicalPlan::new(scan);
1702
1703 let cardinality = optimizer.estimate_cardinality(&plan);
1704 assert!((cardinality - 500.0).abs() < 0.001);
1705 }
1706
1707 #[test]
1708 fn test_optimizer_estimate_cost() {
1709 let optimizer = Optimizer::new();
1710 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1711 variable: "n".to_string(),
1712 label: None,
1713 input: None,
1714 }));
1715
1716 let cost = optimizer.estimate_cost(&plan);
1717 assert!(cost.total() > 0.0);
1718 }
1719
1720 #[test]
1723 fn test_filter_pushdown_through_project() {
1724 let optimizer = Optimizer::new();
1725
1726 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1727 predicate: LogicalExpression::Binary {
1728 left: Box::new(LogicalExpression::Property {
1729 variable: "n".to_string(),
1730 property: "age".to_string(),
1731 }),
1732 op: BinaryOp::Gt,
1733 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1734 },
1735 pushdown_hint: None,
1736 input: Box::new(LogicalOperator::Project(ProjectOp {
1737 projections: vec![Projection {
1738 expression: LogicalExpression::Variable("n".to_string()),
1739 alias: None,
1740 }],
1741 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1742 variable: "n".to_string(),
1743 label: None,
1744 input: None,
1745 })),
1746 pass_through_input: false,
1747 })),
1748 }));
1749
1750 let optimized = optimizer.optimize(plan).unwrap();
1751
1752 if let LogicalOperator::Project(proj) = &optimized.root
1754 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1755 {
1756 return;
1757 }
1758 panic!("Expected Project -> Filter structure");
1759 }
1760
1761 #[test]
1762 fn test_filter_not_pushed_through_project_with_alias() {
1763 let optimizer = Optimizer::new();
1764
1765 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1767 predicate: LogicalExpression::Binary {
1768 left: Box::new(LogicalExpression::Variable("x".to_string())),
1769 op: BinaryOp::Gt,
1770 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1771 },
1772 pushdown_hint: None,
1773 input: Box::new(LogicalOperator::Project(ProjectOp {
1774 projections: vec![Projection {
1775 expression: LogicalExpression::Property {
1776 variable: "n".to_string(),
1777 property: "age".to_string(),
1778 },
1779 alias: Some("x".to_string()),
1780 }],
1781 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1782 variable: "n".to_string(),
1783 label: None,
1784 input: None,
1785 })),
1786 pass_through_input: false,
1787 })),
1788 }));
1789
1790 let optimized = optimizer.optimize(plan).unwrap();
1791
1792 if let LogicalOperator::Filter(filter) = &optimized.root
1794 && let LogicalOperator::Project(_) = filter.input.as_ref()
1795 {
1796 return;
1797 }
1798 panic!("Expected Filter -> Project structure");
1799 }
1800
1801 #[test]
1802 fn test_filter_pushdown_through_limit() {
1803 let optimizer = Optimizer::new();
1804
1805 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1806 predicate: LogicalExpression::Literal(Value::Bool(true)),
1807 pushdown_hint: None,
1808 input: Box::new(LogicalOperator::Limit(LimitOp {
1809 count: 10.into(),
1810 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1811 variable: "n".to_string(),
1812 label: None,
1813 input: None,
1814 })),
1815 })),
1816 }));
1817
1818 let optimized = optimizer.optimize(plan).unwrap();
1819
1820 if let LogicalOperator::Filter(filter) = &optimized.root
1822 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1823 {
1824 return;
1825 }
1826 panic!("Expected Filter -> Limit structure");
1827 }
1828
1829 #[test]
1830 fn test_filter_pushdown_through_sort() {
1831 let optimizer = Optimizer::new();
1832
1833 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1834 predicate: LogicalExpression::Literal(Value::Bool(true)),
1835 pushdown_hint: None,
1836 input: Box::new(LogicalOperator::Sort(SortOp {
1837 keys: vec![SortKey {
1838 expression: LogicalExpression::Variable("n".to_string()),
1839 order: SortOrder::Ascending,
1840 nulls: None,
1841 }],
1842 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1843 variable: "n".to_string(),
1844 label: None,
1845 input: None,
1846 })),
1847 })),
1848 }));
1849
1850 let optimized = optimizer.optimize(plan).unwrap();
1851
1852 if let LogicalOperator::Filter(filter) = &optimized.root
1854 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1855 {
1856 return;
1857 }
1858 panic!("Expected Filter -> Sort structure");
1859 }
1860
1861 #[test]
1862 fn test_filter_pushdown_through_distinct() {
1863 let optimizer = Optimizer::new();
1864
1865 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1866 predicate: LogicalExpression::Literal(Value::Bool(true)),
1867 pushdown_hint: None,
1868 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1869 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1870 variable: "n".to_string(),
1871 label: None,
1872 input: None,
1873 })),
1874 columns: None,
1875 })),
1876 }));
1877
1878 let optimized = optimizer.optimize(plan).unwrap();
1879
1880 if let LogicalOperator::Filter(filter) = &optimized.root
1882 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1883 {
1884 return;
1885 }
1886 panic!("Expected Filter -> Distinct structure");
1887 }
1888
1889 #[test]
1890 fn test_filter_not_pushed_through_aggregate() {
1891 let optimizer = Optimizer::new();
1892
1893 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1894 predicate: LogicalExpression::Binary {
1895 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1896 op: BinaryOp::Gt,
1897 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1898 },
1899 pushdown_hint: None,
1900 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1901 group_by: vec![],
1902 aggregates: vec![AggregateExpr {
1903 function: AggregateFunction::Count,
1904 expression: None,
1905 expression2: None,
1906 distinct: false,
1907 alias: Some("cnt".to_string()),
1908 percentile: None,
1909 separator: None,
1910 }],
1911 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1912 variable: "n".to_string(),
1913 label: None,
1914 input: None,
1915 })),
1916 having: None,
1917 })),
1918 }));
1919
1920 let optimized = optimizer.optimize(plan).unwrap();
1921
1922 if let LogicalOperator::Filter(filter) = &optimized.root
1924 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1925 {
1926 return;
1927 }
1928 panic!("Expected Filter -> Aggregate structure");
1929 }
1930
1931 #[test]
1932 fn test_filter_pushdown_to_left_join_side() {
1933 let optimizer = Optimizer::new();
1934
1935 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1937 predicate: LogicalExpression::Binary {
1938 left: Box::new(LogicalExpression::Property {
1939 variable: "a".to_string(),
1940 property: "age".to_string(),
1941 }),
1942 op: BinaryOp::Gt,
1943 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1944 },
1945 pushdown_hint: None,
1946 input: Box::new(LogicalOperator::Join(JoinOp {
1947 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1948 variable: "a".to_string(),
1949 label: Some("Person".to_string()),
1950 input: None,
1951 })),
1952 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1953 variable: "b".to_string(),
1954 label: Some("Company".to_string()),
1955 input: None,
1956 })),
1957 join_type: JoinType::Inner,
1958 conditions: vec![],
1959 })),
1960 }));
1961
1962 let optimized = optimizer.optimize(plan).unwrap();
1963
1964 if let LogicalOperator::Join(join) = &optimized.root
1966 && let LogicalOperator::Filter(_) = join.left.as_ref()
1967 {
1968 return;
1969 }
1970 panic!("Expected Join with Filter on left side");
1971 }
1972
1973 #[test]
1974 fn test_filter_pushdown_to_right_join_side() {
1975 let optimizer = Optimizer::new();
1976
1977 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1979 predicate: LogicalExpression::Binary {
1980 left: Box::new(LogicalExpression::Property {
1981 variable: "b".to_string(),
1982 property: "name".to_string(),
1983 }),
1984 op: BinaryOp::Eq,
1985 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1986 },
1987 pushdown_hint: None,
1988 input: Box::new(LogicalOperator::Join(JoinOp {
1989 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1990 variable: "a".to_string(),
1991 label: Some("Person".to_string()),
1992 input: None,
1993 })),
1994 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1995 variable: "b".to_string(),
1996 label: Some("Company".to_string()),
1997 input: None,
1998 })),
1999 join_type: JoinType::Inner,
2000 conditions: vec![],
2001 })),
2002 }));
2003
2004 let optimized = optimizer.optimize(plan).unwrap();
2005
2006 if let LogicalOperator::Join(join) = &optimized.root
2008 && let LogicalOperator::Filter(_) = join.right.as_ref()
2009 {
2010 return;
2011 }
2012 panic!("Expected Join with Filter on right side");
2013 }
2014
2015 #[test]
2016 fn test_filter_not_pushed_when_uses_both_join_sides() {
2017 let optimizer = Optimizer::new();
2018
2019 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
2021 predicate: LogicalExpression::Binary {
2022 left: Box::new(LogicalExpression::Property {
2023 variable: "a".to_string(),
2024 property: "id".to_string(),
2025 }),
2026 op: BinaryOp::Eq,
2027 right: Box::new(LogicalExpression::Property {
2028 variable: "b".to_string(),
2029 property: "a_id".to_string(),
2030 }),
2031 },
2032 pushdown_hint: None,
2033 input: Box::new(LogicalOperator::Join(JoinOp {
2034 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2035 variable: "a".to_string(),
2036 label: None,
2037 input: None,
2038 })),
2039 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2040 variable: "b".to_string(),
2041 label: None,
2042 input: None,
2043 })),
2044 join_type: JoinType::Inner,
2045 conditions: vec![],
2046 })),
2047 }));
2048
2049 let optimized = optimizer.optimize(plan).unwrap();
2050
2051 if let LogicalOperator::Filter(filter) = &optimized.root
2053 && let LogicalOperator::Join(_) = filter.input.as_ref()
2054 {
2055 return;
2056 }
2057 panic!("Expected Filter -> Join structure");
2058 }
2059
2060 #[test]
2063 fn test_extract_variables_from_variable() {
2064 let optimizer = Optimizer::new();
2065 let expr = LogicalExpression::Variable("x".to_string());
2066 let vars = optimizer.extract_variables(&expr);
2067 assert_eq!(vars.len(), 1);
2068 assert!(vars.contains("x"));
2069 }
2070
2071 #[test]
2072 fn test_extract_variables_from_unary() {
2073 let optimizer = Optimizer::new();
2074 let expr = LogicalExpression::Unary {
2075 op: UnaryOp::Not,
2076 operand: Box::new(LogicalExpression::Variable("x".to_string())),
2077 };
2078 let vars = optimizer.extract_variables(&expr);
2079 assert_eq!(vars.len(), 1);
2080 assert!(vars.contains("x"));
2081 }
2082
2083 #[test]
2084 fn test_extract_variables_from_function_call() {
2085 let optimizer = Optimizer::new();
2086 let expr = LogicalExpression::FunctionCall {
2087 name: "length".to_string(),
2088 args: vec![
2089 LogicalExpression::Variable("a".to_string()),
2090 LogicalExpression::Variable("b".to_string()),
2091 ],
2092 distinct: false,
2093 };
2094 let vars = optimizer.extract_variables(&expr);
2095 assert_eq!(vars.len(), 2);
2096 assert!(vars.contains("a"));
2097 assert!(vars.contains("b"));
2098 }
2099
2100 #[test]
2101 fn test_extract_variables_from_list() {
2102 let optimizer = Optimizer::new();
2103 let expr = LogicalExpression::List(vec![
2104 LogicalExpression::Variable("a".to_string()),
2105 LogicalExpression::Literal(Value::Int64(1)),
2106 LogicalExpression::Variable("b".to_string()),
2107 ]);
2108 let vars = optimizer.extract_variables(&expr);
2109 assert_eq!(vars.len(), 2);
2110 assert!(vars.contains("a"));
2111 assert!(vars.contains("b"));
2112 }
2113
2114 #[test]
2115 fn test_extract_variables_from_map() {
2116 let optimizer = Optimizer::new();
2117 let expr = LogicalExpression::Map(vec![
2118 (
2119 "key1".to_string(),
2120 LogicalExpression::Variable("a".to_string()),
2121 ),
2122 (
2123 "key2".to_string(),
2124 LogicalExpression::Variable("b".to_string()),
2125 ),
2126 ]);
2127 let vars = optimizer.extract_variables(&expr);
2128 assert_eq!(vars.len(), 2);
2129 assert!(vars.contains("a"));
2130 assert!(vars.contains("b"));
2131 }
2132
2133 #[test]
2134 fn test_extract_variables_from_index_access() {
2135 let optimizer = Optimizer::new();
2136 let expr = LogicalExpression::IndexAccess {
2137 base: Box::new(LogicalExpression::Variable("list".to_string())),
2138 index: Box::new(LogicalExpression::Variable("idx".to_string())),
2139 };
2140 let vars = optimizer.extract_variables(&expr);
2141 assert_eq!(vars.len(), 2);
2142 assert!(vars.contains("list"));
2143 assert!(vars.contains("idx"));
2144 }
2145
2146 #[test]
2147 fn test_extract_variables_from_slice_access() {
2148 let optimizer = Optimizer::new();
2149 let expr = LogicalExpression::SliceAccess {
2150 base: Box::new(LogicalExpression::Variable("list".to_string())),
2151 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
2152 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
2153 };
2154 let vars = optimizer.extract_variables(&expr);
2155 assert_eq!(vars.len(), 3);
2156 assert!(vars.contains("list"));
2157 assert!(vars.contains("s"));
2158 assert!(vars.contains("e"));
2159 }
2160
2161 #[test]
2162 fn test_extract_variables_from_case() {
2163 let optimizer = Optimizer::new();
2164 let expr = LogicalExpression::Case {
2165 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
2166 when_clauses: vec![(
2167 LogicalExpression::Literal(Value::Int64(1)),
2168 LogicalExpression::Variable("a".to_string()),
2169 )],
2170 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
2171 };
2172 let vars = optimizer.extract_variables(&expr);
2173 assert_eq!(vars.len(), 3);
2174 assert!(vars.contains("x"));
2175 assert!(vars.contains("a"));
2176 assert!(vars.contains("b"));
2177 }
2178
2179 #[test]
2180 fn test_extract_variables_from_labels() {
2181 let optimizer = Optimizer::new();
2182 let expr = LogicalExpression::Labels("n".to_string());
2183 let vars = optimizer.extract_variables(&expr);
2184 assert_eq!(vars.len(), 1);
2185 assert!(vars.contains("n"));
2186 }
2187
2188 #[test]
2189 fn test_extract_variables_from_type() {
2190 let optimizer = Optimizer::new();
2191 let expr = LogicalExpression::Type("e".to_string());
2192 let vars = optimizer.extract_variables(&expr);
2193 assert_eq!(vars.len(), 1);
2194 assert!(vars.contains("e"));
2195 }
2196
2197 #[test]
2198 fn test_extract_variables_from_id() {
2199 let optimizer = Optimizer::new();
2200 let expr = LogicalExpression::Id("n".to_string());
2201 let vars = optimizer.extract_variables(&expr);
2202 assert_eq!(vars.len(), 1);
2203 assert!(vars.contains("n"));
2204 }
2205
2206 #[test]
2207 fn test_extract_variables_from_list_comprehension() {
2208 let optimizer = Optimizer::new();
2209 let expr = LogicalExpression::ListComprehension {
2210 variable: "x".to_string(),
2211 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
2212 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
2213 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
2214 };
2215 let vars = optimizer.extract_variables(&expr);
2216 assert!(vars.contains("items"));
2217 assert!(vars.contains("pred"));
2218 assert!(vars.contains("result"));
2219 }
2220
2221 #[test]
2222 fn test_extract_variables_from_literal_and_parameter() {
2223 let optimizer = Optimizer::new();
2224
2225 let literal = LogicalExpression::Literal(Value::Int64(42));
2226 assert!(optimizer.extract_variables(&literal).is_empty());
2227
2228 let param = LogicalExpression::Parameter("p".to_string());
2229 assert!(optimizer.extract_variables(¶m).is_empty());
2230 }
2231
2232 #[test]
2235 fn test_recursive_filter_pushdown_through_skip() {
2236 let optimizer = Optimizer::new();
2237
2238 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2239 items: vec![ReturnItem {
2240 expression: LogicalExpression::Variable("n".to_string()),
2241 alias: None,
2242 }],
2243 distinct: false,
2244 input: Box::new(LogicalOperator::Filter(FilterOp {
2245 predicate: LogicalExpression::Literal(Value::Bool(true)),
2246 pushdown_hint: None,
2247 input: Box::new(LogicalOperator::Skip(SkipOp {
2248 count: 5.into(),
2249 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2250 variable: "n".to_string(),
2251 label: None,
2252 input: None,
2253 })),
2254 })),
2255 })),
2256 }));
2257
2258 let optimized = optimizer.optimize(plan).unwrap();
2259
2260 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2262 }
2263
2264 #[test]
2265 fn test_nested_filter_pushdown() {
2266 let optimizer = Optimizer::new();
2267
2268 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2270 items: vec![ReturnItem {
2271 expression: LogicalExpression::Variable("n".to_string()),
2272 alias: None,
2273 }],
2274 distinct: false,
2275 input: Box::new(LogicalOperator::Filter(FilterOp {
2276 predicate: LogicalExpression::Binary {
2277 left: Box::new(LogicalExpression::Property {
2278 variable: "n".to_string(),
2279 property: "x".to_string(),
2280 }),
2281 op: BinaryOp::Gt,
2282 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2283 },
2284 pushdown_hint: None,
2285 input: Box::new(LogicalOperator::Filter(FilterOp {
2286 predicate: LogicalExpression::Binary {
2287 left: Box::new(LogicalExpression::Property {
2288 variable: "n".to_string(),
2289 property: "y".to_string(),
2290 }),
2291 op: BinaryOp::Lt,
2292 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2293 },
2294 pushdown_hint: None,
2295 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2296 variable: "n".to_string(),
2297 label: None,
2298 input: None,
2299 })),
2300 })),
2301 })),
2302 }));
2303
2304 let optimized = optimizer.optimize(plan).unwrap();
2305 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2306 }
2307
/// A triangle of join predicates (a-b, b-c, c-a) over three `Person` scans
/// is expected to come out of the optimizer containing a `MultiWayJoin`.
#[test]
fn test_cyclic_join_produces_multi_way_join() {
    use crate::query::plan::JoinCondition;

    /// Reports whether a `MultiWayJoin` is reachable from `op` when descending
    /// only through `Return`, `Filter`, and `Project` operators.
    fn has_multi_way_join(op: &LogicalOperator) -> bool {
        if matches!(op, LogicalOperator::MultiWayJoin(_)) {
            return true;
        }
        match op {
            LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
            LogicalOperator::Filter(filter) => has_multi_way_join(&filter.input),
            LogicalOperator::Project(project) => has_multi_way_join(&project.input),
            _ => false,
        }
    }

    // Small builders so the plan below reads as the shape of the query.
    let person_scan = |name: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: name.to_owned(),
            label: Some("Person".to_owned()),
            input: None,
        })
    };
    let var = |name: &str| LogicalExpression::Variable(name.to_owned());
    let on = |lhs: &str, rhs: &str| JoinCondition {
        left: var(lhs),
        right: var(rhs),
    };

    // Statistics are registered for the single label used by all three scans.
    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));

    // Inner join: a-b.
    let ab = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan("a")),
        right: Box::new(person_scan("b")),
        join_type: JoinType::Inner,
        conditions: vec![on("a", "b")],
    });

    // Outer join adds c with both b-c and c-a, which closes the triangle.
    let abc = LogicalOperator::Join(JoinOp {
        left: Box::new(ab),
        right: Box::new(person_scan("c")),
        join_type: JoinType::Inner,
        conditions: vec![on("b", "c"), on("c", "a")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: var("a"),
            alias: None,
        }],
        distinct: false,
        input: Box::new(abc),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let found = has_multi_way_join(&optimized.root);

    assert!(found, "Expected MultiWayJoin for cyclic triangle pattern");
}
2388
/// A chain of join predicates (a-b, b-c) with no closing edge is expected to
/// come out of the optimizer without any `MultiWayJoin`.
#[test]
fn test_acyclic_join_uses_binary_joins() {
    use crate::query::plan::JoinCondition;

    /// Reports whether a `MultiWayJoin` is reachable from `op` when descending
    /// through `Return`, `Filter`, `Project`, and both sides of `Join`.
    fn has_multi_way_join(op: &LogicalOperator) -> bool {
        if matches!(op, LogicalOperator::MultiWayJoin(_)) {
            return true;
        }
        match op {
            LogicalOperator::Join(join) => [&join.left, &join.right]
                .iter()
                .any(|side| has_multi_way_join(side)),
            LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
            LogicalOperator::Filter(filter) => has_multi_way_join(&filter.input),
            LogicalOperator::Project(project) => has_multi_way_join(&project.input),
            _ => false,
        }
    }

    // Small builders so the plan below reads as the shape of the query.
    let scan = |name: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: name.to_owned(),
            label: Some(label.to_owned()),
            input: None,
        })
    };
    let var = |name: &str| LogicalExpression::Variable(name.to_owned());
    let on = |lhs: &str, rhs: &str| JoinCondition {
        left: var(lhs),
        right: var(rhs),
    };

    // Statistics are registered for both labels that appear in the scans.
    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));
    optimizer
        .card_estimator
        .add_table_stats("Company", cardinality::TableStats::new(100));

    // Inner join: a-b.
    let ab = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("a", "Person")),
        right: Box::new(scan("b", "Person")),
        join_type: JoinType::Inner,
        conditions: vec![on("a", "b")],
    });

    // Outer join adds c via b-c only, so the join graph stays a chain.
    let abc = LogicalOperator::Join(JoinOp {
        left: Box::new(ab),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Inner,
        conditions: vec![on("b", "c")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: var("a"),
            alias: None,
        }],
        distinct: false,
        input: Box::new(abc),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let found = has_multi_way_join(&optimized.root);

    assert!(!found, "Acyclic join should NOT produce MultiWayJoin");
}
2468}