1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// A join condition pulled out of a `Join` operator while flattening a join tree.
///
/// Besides the original left/right expressions it records the variable each side
/// refers to, which is how the condition is attached to relations in the join graph.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by `left_expr` (a bare variable or the owner of a property).
    left_var: String,
    /// Variable referenced by `right_expr` (a bare variable or the owner of a property).
    right_var: String,
    /// Left-hand expression of the original join condition.
    left_expr: LogicalExpression,
    /// Right-hand expression of the original join condition.
    right_expr: LogicalExpression,
}
38
/// A column referenced somewhere in the plan, collected by the projection-pushdown pass.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable referenced as a whole.
    Variable(String),
    /// A single property of a variable, stored as `(variable, property)`.
    Property(String, String),
}
47
/// Optimizer for logical query plans.
///
/// Combines rule-based rewrites (filter pushdown, projection pushdown) with
/// cost-based join reordering. Each rewrite can be toggled independently.
pub struct Optimizer {
    /// Whether `optimize` runs the filter-pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join-reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection-pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for plan costing and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for plan costing and join ordering.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66 #[must_use]
68 pub fn new() -> Self {
69 Self {
70 enable_filter_pushdown: true,
71 enable_join_reorder: true,
72 enable_projection_pushdown: true,
73 cost_model: CostModel::new(),
74 card_estimator: CardinalityEstimator::new(),
75 }
76 }
77
78 #[must_use]
84 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85 store.ensure_statistics_fresh();
86 let stats = store.statistics();
87 Self::from_statistics(&stats)
88 }
89
90 #[must_use]
97 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98 let stats = store.statistics();
99 Self::from_statistics(&stats)
100 }
101
102 #[must_use]
107 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
108 let estimator = CardinalityEstimator::from_statistics(stats);
109
110 let avg_fanout = if stats.total_nodes > 0 {
111 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
112 } else {
113 10.0
114 };
115
116 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
117 .edge_types
118 .iter()
119 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
120 .collect();
121
122 let label_cardinalities: std::collections::HashMap<String, u64> = stats
123 .labels
124 .iter()
125 .map(|(name, ls)| (name.clone(), ls.node_count))
126 .collect();
127
128 Self {
129 enable_filter_pushdown: true,
130 enable_join_reorder: true,
131 enable_projection_pushdown: true,
132 cost_model: CostModel::new()
133 .with_avg_fanout(avg_fanout)
134 .with_edge_type_degrees(edge_type_degrees)
135 .with_label_cardinalities(label_cardinalities)
136 .with_graph_totals(stats.total_nodes, stats.total_edges),
137 card_estimator: estimator,
138 }
139 }
140
141 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
143 self.enable_filter_pushdown = enabled;
144 self
145 }
146
147 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
149 self.enable_join_reorder = enabled;
150 self
151 }
152
153 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
155 self.enable_projection_pushdown = enabled;
156 self
157 }
158
159 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
161 self.cost_model = cost_model;
162 self
163 }
164
165 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
167 self.card_estimator = estimator;
168 self
169 }
170
171 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
173 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
174 self
175 }
176
177 pub fn cost_model(&self) -> &CostModel {
179 &self.cost_model
180 }
181
182 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
184 &self.card_estimator
185 }
186
187 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
193 self.cost_model
194 .estimate_tree(&plan.root, &self.card_estimator)
195 }
196
197 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
199 self.card_estimator.estimate(&plan.root)
200 }
201
202 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
208 let mut root = plan.root;
209
210 if self.enable_filter_pushdown {
212 root = self.push_filters_down(root);
213 }
214
215 if self.enable_join_reorder {
216 root = self.reorder_joins(root);
217 }
218
219 if self.enable_projection_pushdown {
220 root = self.push_projections_down(root);
221 }
222
223 Ok(LogicalPlan {
224 root,
225 explain: plan.explain,
226 profile: plan.profile,
227 })
228 }
229
230 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
237 let required = self.collect_required_columns(&op);
239
240 self.push_projections_recursive(op, &required)
242 }
243
244 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
246 let mut required = HashSet::new();
247 Self::collect_required_recursive(op, &mut required);
248 required
249 }
250
251 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
253 match op {
254 LogicalOperator::Return(ret) => {
255 for item in &ret.items {
256 Self::collect_from_expression(&item.expression, required);
257 }
258 Self::collect_required_recursive(&ret.input, required);
259 }
260 LogicalOperator::Project(proj) => {
261 for p in &proj.projections {
262 Self::collect_from_expression(&p.expression, required);
263 }
264 Self::collect_required_recursive(&proj.input, required);
265 }
266 LogicalOperator::Filter(filter) => {
267 Self::collect_from_expression(&filter.predicate, required);
268 Self::collect_required_recursive(&filter.input, required);
269 }
270 LogicalOperator::Sort(sort) => {
271 for key in &sort.keys {
272 Self::collect_from_expression(&key.expression, required);
273 }
274 Self::collect_required_recursive(&sort.input, required);
275 }
276 LogicalOperator::Aggregate(agg) => {
277 for expr in &agg.group_by {
278 Self::collect_from_expression(expr, required);
279 }
280 for agg_expr in &agg.aggregates {
281 if let Some(ref expr) = agg_expr.expression {
282 Self::collect_from_expression(expr, required);
283 }
284 }
285 if let Some(ref having) = agg.having {
286 Self::collect_from_expression(having, required);
287 }
288 Self::collect_required_recursive(&agg.input, required);
289 }
290 LogicalOperator::Join(join) => {
291 for cond in &join.conditions {
292 Self::collect_from_expression(&cond.left, required);
293 Self::collect_from_expression(&cond.right, required);
294 }
295 Self::collect_required_recursive(&join.left, required);
296 Self::collect_required_recursive(&join.right, required);
297 }
298 LogicalOperator::Expand(expand) => {
299 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
301 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
302 if let Some(ref edge_var) = expand.edge_variable {
303 required.insert(RequiredColumn::Variable(edge_var.clone()));
304 }
305 Self::collect_required_recursive(&expand.input, required);
306 }
307 LogicalOperator::Limit(limit) => {
308 Self::collect_required_recursive(&limit.input, required);
309 }
310 LogicalOperator::Skip(skip) => {
311 Self::collect_required_recursive(&skip.input, required);
312 }
313 LogicalOperator::Distinct(distinct) => {
314 Self::collect_required_recursive(&distinct.input, required);
315 }
316 LogicalOperator::NodeScan(scan) => {
317 required.insert(RequiredColumn::Variable(scan.variable.clone()));
318 }
319 LogicalOperator::EdgeScan(scan) => {
320 required.insert(RequiredColumn::Variable(scan.variable.clone()));
321 }
322 LogicalOperator::MultiWayJoin(mwj) => {
323 for cond in &mwj.conditions {
324 Self::collect_from_expression(&cond.left, required);
325 Self::collect_from_expression(&cond.right, required);
326 }
327 for input in &mwj.inputs {
328 Self::collect_required_recursive(input, required);
329 }
330 }
331 _ => {}
332 }
333 }
334
335 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
337 match expr {
338 LogicalExpression::Variable(var) => {
339 required.insert(RequiredColumn::Variable(var.clone()));
340 }
341 LogicalExpression::Property { variable, property } => {
342 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
343 required.insert(RequiredColumn::Variable(variable.clone()));
344 }
345 LogicalExpression::Binary { left, right, .. } => {
346 Self::collect_from_expression(left, required);
347 Self::collect_from_expression(right, required);
348 }
349 LogicalExpression::Unary { operand, .. } => {
350 Self::collect_from_expression(operand, required);
351 }
352 LogicalExpression::FunctionCall { args, .. } => {
353 for arg in args {
354 Self::collect_from_expression(arg, required);
355 }
356 }
357 LogicalExpression::List(items) => {
358 for item in items {
359 Self::collect_from_expression(item, required);
360 }
361 }
362 LogicalExpression::Map(pairs) => {
363 for (_, value) in pairs {
364 Self::collect_from_expression(value, required);
365 }
366 }
367 LogicalExpression::IndexAccess { base, index } => {
368 Self::collect_from_expression(base, required);
369 Self::collect_from_expression(index, required);
370 }
371 LogicalExpression::SliceAccess { base, start, end } => {
372 Self::collect_from_expression(base, required);
373 if let Some(s) = start {
374 Self::collect_from_expression(s, required);
375 }
376 if let Some(e) = end {
377 Self::collect_from_expression(e, required);
378 }
379 }
380 LogicalExpression::Case {
381 operand,
382 when_clauses,
383 else_clause,
384 } => {
385 if let Some(op) = operand {
386 Self::collect_from_expression(op, required);
387 }
388 for (cond, result) in when_clauses {
389 Self::collect_from_expression(cond, required);
390 Self::collect_from_expression(result, required);
391 }
392 if let Some(else_expr) = else_clause {
393 Self::collect_from_expression(else_expr, required);
394 }
395 }
396 LogicalExpression::Labels(var)
397 | LogicalExpression::Type(var)
398 | LogicalExpression::Id(var) => {
399 required.insert(RequiredColumn::Variable(var.clone()));
400 }
401 LogicalExpression::ListComprehension {
402 list_expr,
403 filter_expr,
404 map_expr,
405 ..
406 } => {
407 Self::collect_from_expression(list_expr, required);
408 if let Some(filter) = filter_expr {
409 Self::collect_from_expression(filter, required);
410 }
411 Self::collect_from_expression(map_expr, required);
412 }
413 _ => {}
414 }
415 }
416
417 fn push_projections_recursive(
419 &self,
420 op: LogicalOperator,
421 required: &HashSet<RequiredColumn>,
422 ) -> LogicalOperator {
423 match op {
424 LogicalOperator::Return(mut ret) => {
425 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
426 LogicalOperator::Return(ret)
427 }
428 LogicalOperator::Project(mut proj) => {
429 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
430 LogicalOperator::Project(proj)
431 }
432 LogicalOperator::Filter(mut filter) => {
433 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
434 LogicalOperator::Filter(filter)
435 }
436 LogicalOperator::Sort(mut sort) => {
437 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
440 LogicalOperator::Sort(sort)
441 }
442 LogicalOperator::Aggregate(mut agg) => {
443 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
444 LogicalOperator::Aggregate(agg)
445 }
446 LogicalOperator::Join(mut join) => {
447 let left_vars = self.collect_output_variables(&join.left);
450 let right_vars = self.collect_output_variables(&join.right);
451
452 let left_required: HashSet<_> = required
454 .iter()
455 .filter(|c| match c {
456 RequiredColumn::Variable(v) => left_vars.contains(v),
457 RequiredColumn::Property(v, _) => left_vars.contains(v),
458 })
459 .cloned()
460 .collect();
461
462 let right_required: HashSet<_> = required
463 .iter()
464 .filter(|c| match c {
465 RequiredColumn::Variable(v) => right_vars.contains(v),
466 RequiredColumn::Property(v, _) => right_vars.contains(v),
467 })
468 .cloned()
469 .collect();
470
471 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
472 join.right =
473 Box::new(self.push_projections_recursive(*join.right, &right_required));
474 LogicalOperator::Join(join)
475 }
476 LogicalOperator::Expand(mut expand) => {
477 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
478 LogicalOperator::Expand(expand)
479 }
480 LogicalOperator::Limit(mut limit) => {
481 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
482 LogicalOperator::Limit(limit)
483 }
484 LogicalOperator::Skip(mut skip) => {
485 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
486 LogicalOperator::Skip(skip)
487 }
488 LogicalOperator::Distinct(mut distinct) => {
489 distinct.input =
490 Box::new(self.push_projections_recursive(*distinct.input, required));
491 LogicalOperator::Distinct(distinct)
492 }
493 LogicalOperator::MapCollect(mut mc) => {
494 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
495 LogicalOperator::MapCollect(mc)
496 }
497 LogicalOperator::MultiWayJoin(mut mwj) => {
498 mwj.inputs = mwj
499 .inputs
500 .into_iter()
501 .map(|input| self.push_projections_recursive(input, required))
502 .collect();
503 LogicalOperator::MultiWayJoin(mwj)
504 }
505 other => other,
506 }
507 }
508
509 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
516 let op = self.reorder_joins_recursive(op);
518
519 if let Some((relations, conditions)) = self.extract_join_tree(&op)
521 && relations.len() >= 2
522 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
523 {
524 return optimized;
525 }
526
527 op
528 }
529
530 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
532 match op {
533 LogicalOperator::Return(mut ret) => {
534 ret.input = Box::new(self.reorder_joins(*ret.input));
535 LogicalOperator::Return(ret)
536 }
537 LogicalOperator::Project(mut proj) => {
538 proj.input = Box::new(self.reorder_joins(*proj.input));
539 LogicalOperator::Project(proj)
540 }
541 LogicalOperator::Filter(mut filter) => {
542 filter.input = Box::new(self.reorder_joins(*filter.input));
543 LogicalOperator::Filter(filter)
544 }
545 LogicalOperator::Limit(mut limit) => {
546 limit.input = Box::new(self.reorder_joins(*limit.input));
547 LogicalOperator::Limit(limit)
548 }
549 LogicalOperator::Skip(mut skip) => {
550 skip.input = Box::new(self.reorder_joins(*skip.input));
551 LogicalOperator::Skip(skip)
552 }
553 LogicalOperator::Sort(mut sort) => {
554 sort.input = Box::new(self.reorder_joins(*sort.input));
555 LogicalOperator::Sort(sort)
556 }
557 LogicalOperator::Distinct(mut distinct) => {
558 distinct.input = Box::new(self.reorder_joins(*distinct.input));
559 LogicalOperator::Distinct(distinct)
560 }
561 LogicalOperator::Aggregate(mut agg) => {
562 agg.input = Box::new(self.reorder_joins(*agg.input));
563 LogicalOperator::Aggregate(agg)
564 }
565 LogicalOperator::Expand(mut expand) => {
566 expand.input = Box::new(self.reorder_joins(*expand.input));
567 LogicalOperator::Expand(expand)
568 }
569 LogicalOperator::MapCollect(mut mc) => {
570 mc.input = Box::new(self.reorder_joins(*mc.input));
571 LogicalOperator::MapCollect(mc)
572 }
573 LogicalOperator::MultiWayJoin(mut mwj) => {
574 mwj.inputs = mwj
575 .inputs
576 .into_iter()
577 .map(|input| self.reorder_joins(input))
578 .collect();
579 LogicalOperator::MultiWayJoin(mwj)
580 }
581 other => other,
583 }
584 }
585
586 fn extract_join_tree(
590 &self,
591 op: &LogicalOperator,
592 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
593 let mut relations = Vec::new();
594 let mut join_conditions = Vec::new();
595
596 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
597 return None;
598 }
599
600 if relations.len() < 2 {
601 return None;
602 }
603
604 Some((relations, join_conditions))
605 }
606
607 fn collect_join_tree(
611 &self,
612 op: &LogicalOperator,
613 relations: &mut Vec<(String, LogicalOperator)>,
614 conditions: &mut Vec<JoinInfo>,
615 ) -> bool {
616 match op {
617 LogicalOperator::Join(join) => {
618 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
620 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
621
622 for cond in &join.conditions {
624 if let (Some(left_var), Some(right_var)) = (
625 self.extract_variable_from_expr(&cond.left),
626 self.extract_variable_from_expr(&cond.right),
627 ) {
628 conditions.push(JoinInfo {
629 left_var,
630 right_var,
631 left_expr: cond.left.clone(),
632 right_expr: cond.right.clone(),
633 });
634 }
635 }
636
637 left_ok && right_ok
638 }
639 LogicalOperator::NodeScan(scan) => {
640 relations.push((scan.variable.clone(), op.clone()));
641 true
642 }
643 LogicalOperator::EdgeScan(scan) => {
644 relations.push((scan.variable.clone(), op.clone()));
645 true
646 }
647 LogicalOperator::Filter(filter) => {
648 self.collect_join_tree(&filter.input, relations, conditions)
650 }
651 LogicalOperator::Expand(expand) => {
652 relations.push((expand.to_variable.clone(), op.clone()));
655 true
656 }
657 _ => false,
658 }
659 }
660
661 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
663 match expr {
664 LogicalExpression::Variable(v) => Some(v.clone()),
665 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
666 _ => None,
667 }
668 }
669
670 fn optimize_join_order(
673 &self,
674 relations: &[(String, LogicalOperator)],
675 conditions: &[JoinInfo],
676 ) -> Option<LogicalOperator> {
677 use join_order::{DPccp, JoinGraphBuilder};
678
679 let mut builder = JoinGraphBuilder::new();
681
682 for (var, relation) in relations {
683 builder.add_relation(var, relation.clone());
684 }
685
686 for cond in conditions {
687 builder.add_join_condition(
688 &cond.left_var,
689 &cond.right_var,
690 cond.left_expr.clone(),
691 cond.right_expr.clone(),
692 );
693 }
694
695 let graph = builder.build();
696
697 if graph.is_cyclic() && relations.len() >= 3 {
702 let mut var_counts: std::collections::HashMap<&str, usize> =
704 std::collections::HashMap::new();
705 for cond in conditions {
706 *var_counts.entry(&cond.left_var).or_default() += 1;
707 *var_counts.entry(&cond.right_var).or_default() += 1;
708 }
709 let shared_variables: Vec<String> = var_counts
710 .into_iter()
711 .filter(|(_, count)| *count >= 2)
712 .map(|(var, _)| var.to_string())
713 .collect();
714
715 let join_conditions: Vec<JoinCondition> = conditions
716 .iter()
717 .map(|c| JoinCondition {
718 left: c.left_expr.clone(),
719 right: c.right_expr.clone(),
720 })
721 .collect();
722
723 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
724 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
725 conditions: join_conditions,
726 shared_variables,
727 }));
728 }
729
730 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
732 let plan = dpccp.optimize()?;
733
734 Some(plan.operator)
735 }
736
737 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
742 match op {
743 LogicalOperator::Filter(filter) => {
745 let optimized_input = self.push_filters_down(*filter.input);
746 self.try_push_filter_into(filter.predicate, optimized_input)
747 }
748 LogicalOperator::Return(mut ret) => {
750 ret.input = Box::new(self.push_filters_down(*ret.input));
751 LogicalOperator::Return(ret)
752 }
753 LogicalOperator::Project(mut proj) => {
754 proj.input = Box::new(self.push_filters_down(*proj.input));
755 LogicalOperator::Project(proj)
756 }
757 LogicalOperator::Limit(mut limit) => {
758 limit.input = Box::new(self.push_filters_down(*limit.input));
759 LogicalOperator::Limit(limit)
760 }
761 LogicalOperator::Skip(mut skip) => {
762 skip.input = Box::new(self.push_filters_down(*skip.input));
763 LogicalOperator::Skip(skip)
764 }
765 LogicalOperator::Sort(mut sort) => {
766 sort.input = Box::new(self.push_filters_down(*sort.input));
767 LogicalOperator::Sort(sort)
768 }
769 LogicalOperator::Distinct(mut distinct) => {
770 distinct.input = Box::new(self.push_filters_down(*distinct.input));
771 LogicalOperator::Distinct(distinct)
772 }
773 LogicalOperator::Expand(mut expand) => {
774 expand.input = Box::new(self.push_filters_down(*expand.input));
775 LogicalOperator::Expand(expand)
776 }
777 LogicalOperator::Join(mut join) => {
778 join.left = Box::new(self.push_filters_down(*join.left));
779 join.right = Box::new(self.push_filters_down(*join.right));
780 LogicalOperator::Join(join)
781 }
782 LogicalOperator::Aggregate(mut agg) => {
783 agg.input = Box::new(self.push_filters_down(*agg.input));
784 LogicalOperator::Aggregate(agg)
785 }
786 LogicalOperator::MapCollect(mut mc) => {
787 mc.input = Box::new(self.push_filters_down(*mc.input));
788 LogicalOperator::MapCollect(mc)
789 }
790 LogicalOperator::MultiWayJoin(mut mwj) => {
791 mwj.inputs = mwj
792 .inputs
793 .into_iter()
794 .map(|input| self.push_filters_down(input))
795 .collect();
796 LogicalOperator::MultiWayJoin(mwj)
797 }
798 other => other,
800 }
801 }
802
803 fn try_push_filter_into(
808 &self,
809 predicate: LogicalExpression,
810 op: LogicalOperator,
811 ) -> LogicalOperator {
812 match op {
813 LogicalOperator::Project(mut proj) => {
815 let predicate_vars = self.extract_variables(&predicate);
816 let computed_vars = self.extract_projection_aliases(&proj.projections);
817
818 if predicate_vars.is_disjoint(&computed_vars) {
820 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
821 LogicalOperator::Project(proj)
822 } else {
823 LogicalOperator::Filter(FilterOp {
825 predicate,
826 pushdown_hint: None,
827 input: Box::new(LogicalOperator::Project(proj)),
828 })
829 }
830 }
831
832 LogicalOperator::Return(mut ret) => {
834 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
835 LogicalOperator::Return(ret)
836 }
837
838 LogicalOperator::Expand(mut expand) => {
840 let predicate_vars = self.extract_variables(&predicate);
841
842 let mut introduced_vars = vec![&expand.to_variable];
847 if let Some(ref edge_var) = expand.edge_variable {
848 introduced_vars.push(edge_var);
849 }
850 if let Some(ref path_alias) = expand.path_alias {
851 introduced_vars.push(path_alias);
852 }
853
854 let uses_introduced_vars =
856 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
857
858 if !uses_introduced_vars {
859 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
861 LogicalOperator::Expand(expand)
862 } else {
863 LogicalOperator::Filter(FilterOp {
865 predicate,
866 pushdown_hint: None,
867 input: Box::new(LogicalOperator::Expand(expand)),
868 })
869 }
870 }
871
872 LogicalOperator::Join(mut join) => {
874 let predicate_vars = self.extract_variables(&predicate);
875 let left_vars = self.collect_output_variables(&join.left);
876 let right_vars = self.collect_output_variables(&join.right);
877
878 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
879 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
880
881 if uses_left && !uses_right {
882 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
884 LogicalOperator::Join(join)
885 } else if uses_right && !uses_left {
886 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
888 LogicalOperator::Join(join)
889 } else {
890 LogicalOperator::Filter(FilterOp {
892 predicate,
893 pushdown_hint: None,
894 input: Box::new(LogicalOperator::Join(join)),
895 })
896 }
897 }
898
899 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
901 predicate,
902 pushdown_hint: None,
903 input: Box::new(LogicalOperator::Aggregate(agg)),
904 }),
905
906 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
908 predicate,
909 pushdown_hint: None,
910 input: Box::new(LogicalOperator::NodeScan(scan)),
911 }),
912
913 other => LogicalOperator::Filter(FilterOp {
915 predicate,
916 pushdown_hint: None,
917 input: Box::new(other),
918 }),
919 }
920 }
921
922 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
924 let mut vars = HashSet::new();
925 Self::collect_output_variables_recursive(op, &mut vars);
926 vars
927 }
928
929 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
931 match op {
932 LogicalOperator::NodeScan(scan) => {
933 vars.insert(scan.variable.clone());
934 }
935 LogicalOperator::EdgeScan(scan) => {
936 vars.insert(scan.variable.clone());
937 }
938 LogicalOperator::Expand(expand) => {
939 vars.insert(expand.to_variable.clone());
940 if let Some(edge_var) = &expand.edge_variable {
941 vars.insert(edge_var.clone());
942 }
943 Self::collect_output_variables_recursive(&expand.input, vars);
944 }
945 LogicalOperator::Filter(filter) => {
946 Self::collect_output_variables_recursive(&filter.input, vars);
947 }
948 LogicalOperator::Project(proj) => {
949 for p in &proj.projections {
950 if let Some(alias) = &p.alias {
951 vars.insert(alias.clone());
952 }
953 }
954 Self::collect_output_variables_recursive(&proj.input, vars);
955 }
956 LogicalOperator::Join(join) => {
957 Self::collect_output_variables_recursive(&join.left, vars);
958 Self::collect_output_variables_recursive(&join.right, vars);
959 }
960 LogicalOperator::Aggregate(agg) => {
961 for expr in &agg.group_by {
962 Self::collect_variables(expr, vars);
963 }
964 for agg_expr in &agg.aggregates {
965 if let Some(alias) = &agg_expr.alias {
966 vars.insert(alias.clone());
967 }
968 }
969 }
970 LogicalOperator::Return(ret) => {
971 Self::collect_output_variables_recursive(&ret.input, vars);
972 }
973 LogicalOperator::Limit(limit) => {
974 Self::collect_output_variables_recursive(&limit.input, vars);
975 }
976 LogicalOperator::Skip(skip) => {
977 Self::collect_output_variables_recursive(&skip.input, vars);
978 }
979 LogicalOperator::Sort(sort) => {
980 Self::collect_output_variables_recursive(&sort.input, vars);
981 }
982 LogicalOperator::Distinct(distinct) => {
983 Self::collect_output_variables_recursive(&distinct.input, vars);
984 }
985 _ => {}
986 }
987 }
988
989 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
991 let mut vars = HashSet::new();
992 Self::collect_variables(expr, &mut vars);
993 vars
994 }
995
996 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
998 match expr {
999 LogicalExpression::Variable(name) => {
1000 vars.insert(name.clone());
1001 }
1002 LogicalExpression::Property { variable, .. } => {
1003 vars.insert(variable.clone());
1004 }
1005 LogicalExpression::Binary { left, right, .. } => {
1006 Self::collect_variables(left, vars);
1007 Self::collect_variables(right, vars);
1008 }
1009 LogicalExpression::Unary { operand, .. } => {
1010 Self::collect_variables(operand, vars);
1011 }
1012 LogicalExpression::FunctionCall { args, .. } => {
1013 for arg in args {
1014 Self::collect_variables(arg, vars);
1015 }
1016 }
1017 LogicalExpression::List(items) => {
1018 for item in items {
1019 Self::collect_variables(item, vars);
1020 }
1021 }
1022 LogicalExpression::Map(pairs) => {
1023 for (_, value) in pairs {
1024 Self::collect_variables(value, vars);
1025 }
1026 }
1027 LogicalExpression::IndexAccess { base, index } => {
1028 Self::collect_variables(base, vars);
1029 Self::collect_variables(index, vars);
1030 }
1031 LogicalExpression::SliceAccess { base, start, end } => {
1032 Self::collect_variables(base, vars);
1033 if let Some(s) = start {
1034 Self::collect_variables(s, vars);
1035 }
1036 if let Some(e) = end {
1037 Self::collect_variables(e, vars);
1038 }
1039 }
1040 LogicalExpression::Case {
1041 operand,
1042 when_clauses,
1043 else_clause,
1044 } => {
1045 if let Some(op) = operand {
1046 Self::collect_variables(op, vars);
1047 }
1048 for (cond, result) in when_clauses {
1049 Self::collect_variables(cond, vars);
1050 Self::collect_variables(result, vars);
1051 }
1052 if let Some(else_expr) = else_clause {
1053 Self::collect_variables(else_expr, vars);
1054 }
1055 }
1056 LogicalExpression::Labels(var)
1057 | LogicalExpression::Type(var)
1058 | LogicalExpression::Id(var) => {
1059 vars.insert(var.clone());
1060 }
1061 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1062 LogicalExpression::ListComprehension {
1063 list_expr,
1064 filter_expr,
1065 map_expr,
1066 ..
1067 } => {
1068 Self::collect_variables(list_expr, vars);
1069 if let Some(filter) = filter_expr {
1070 Self::collect_variables(filter, vars);
1071 }
1072 Self::collect_variables(map_expr, vars);
1073 }
1074 LogicalExpression::ListPredicate {
1075 list_expr,
1076 predicate,
1077 ..
1078 } => {
1079 Self::collect_variables(list_expr, vars);
1080 Self::collect_variables(predicate, vars);
1081 }
1082 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1083 }
1085 LogicalExpression::PatternComprehension { projection, .. } => {
1086 Self::collect_variables(projection, vars);
1087 }
1088 LogicalExpression::MapProjection { base, entries } => {
1089 vars.insert(base.clone());
1090 for entry in entries {
1091 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1092 Self::collect_variables(expr, vars);
1093 }
1094 }
1095 }
1096 LogicalExpression::Reduce {
1097 initial,
1098 list,
1099 expression,
1100 ..
1101 } => {
1102 Self::collect_variables(initial, vars);
1103 Self::collect_variables(list, vars);
1104 Self::collect_variables(expression, vars);
1105 }
1106 }
1107 }
1108
1109 fn extract_projection_aliases(
1111 &self,
1112 projections: &[crate::query::plan::Projection],
1113 ) -> HashSet<String> {
1114 projections.iter().filter_map(|p| p.alias.clone()).collect()
1115 }
1116}
1117
1118impl Default for Optimizer {
1119 fn default() -> Self {
1120 Self::new()
1121 }
1122}
1123
1124#[cfg(test)]
1125mod tests {
1126 use super::*;
1127 use crate::query::plan::{
1128 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1129 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1130 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1131 };
1132 use grafeo_common::types::Value;
1133
1134 #[test]
1135 fn test_optimizer_filter_pushdown_simple() {
1136 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1141 items: vec![ReturnItem {
1142 expression: LogicalExpression::Variable("n".to_string()),
1143 alias: None,
1144 }],
1145 distinct: false,
1146 input: Box::new(LogicalOperator::Filter(FilterOp {
1147 predicate: LogicalExpression::Binary {
1148 left: Box::new(LogicalExpression::Property {
1149 variable: "n".to_string(),
1150 property: "age".to_string(),
1151 }),
1152 op: BinaryOp::Gt,
1153 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1154 },
1155 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1156 variable: "n".to_string(),
1157 label: Some("Person".to_string()),
1158 input: None,
1159 })),
1160 pushdown_hint: None,
1161 })),
1162 }));
1163
1164 let optimizer = Optimizer::new();
1165 let optimized = optimizer.optimize(plan).unwrap();
1166
1167 if let LogicalOperator::Return(ret) = &optimized.root
1169 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1170 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1171 {
1172 assert_eq!(scan.variable, "n");
1173 return;
1174 }
1175 panic!("Expected Return -> Filter -> NodeScan structure");
1176 }
1177
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
    // Plan shape: Return(b) <- Filter(a.age > 30) <- Expand(a)-[:KNOWS]->(b) <- NodeScan(a:Person).
    // The predicate only mentions the expand's source `a`, so it should sink
    // below the Expand and land directly on the scan.
    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand_op = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    let source_age_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "a".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(expand_op),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(source_age_filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    const EXPECTED: &str = "Expected Return -> Expand -> Filter -> NodeScan structure";
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{EXPECTED}")
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("{EXPECTED}")
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("{EXPECTED}")
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{EXPECTED}")
    };
    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1235
1236 #[test]
1237 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1238 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1242 items: vec![ReturnItem {
1243 expression: LogicalExpression::Variable("a".to_string()),
1244 alias: None,
1245 }],
1246 distinct: false,
1247 input: Box::new(LogicalOperator::Filter(FilterOp {
1248 predicate: LogicalExpression::Binary {
1249 left: Box::new(LogicalExpression::Property {
1250 variable: "b".to_string(),
1251 property: "age".to_string(),
1252 }),
1253 op: BinaryOp::Gt,
1254 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1255 },
1256 pushdown_hint: None,
1257 input: Box::new(LogicalOperator::Expand(ExpandOp {
1258 from_variable: "a".to_string(),
1259 to_variable: "b".to_string(),
1260 edge_variable: None,
1261 direction: ExpandDirection::Outgoing,
1262 edge_types: vec!["KNOWS".to_string()],
1263 min_hops: 1,
1264 max_hops: Some(1),
1265 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1266 variable: "a".to_string(),
1267 label: Some("Person".to_string()),
1268 input: None,
1269 })),
1270 path_alias: None,
1271 path_mode: PathMode::Walk,
1272 })),
1273 })),
1274 }));
1275
1276 let optimizer = Optimizer::new();
1277 let optimized = optimizer.optimize(plan).unwrap();
1278
1279 if let LogicalOperator::Return(ret) = &optimized.root
1282 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1283 {
1284 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1286 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1287 {
1288 assert_eq!(variable, "b");
1289 }
1290
1291 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1292 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1293 {
1294 return;
1295 }
1296 }
1297 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1298 }
1299
#[test]
fn test_optimizer_extract_variables() {
    // `n.age > 30` mentions only the variable `n`; the literal side adds nothing.
    let n_age = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let expr = LogicalExpression::Binary {
        left: Box::new(n_age),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };

    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert!(vars.contains("n"));
    assert_eq!(vars.len(), 1);
}
1317
#[test]
fn test_optimizer_default() {
    // An optimizer built via `Default` must handle an empty plan without error.
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let optimizer = Optimizer::default();
    let outcome = optimizer.optimize(empty_plan);
    assert!(outcome.is_ok());
}
1328
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    // With pushdown turned off, the Filter must remain directly under the Return.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_kept = matches!(
        &optimized.root,
        LogicalOperator::Return(ret) if matches!(ret.input.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(filter_kept, "Expected unchanged structure");
}
1359
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    // Disabling join reordering must not break optimization of a trivial plan.
    let plan = LogicalPlan::new(LogicalOperator::Empty);
    let optimizer = Optimizer::new().with_join_reorder(false);
    let result = optimizer.optimize(plan);
    assert!(result.is_ok());
}
1369
#[test]
fn test_optimizer_with_cost_model() {
    // A cost model installed through the builder is exposed by `cost_model()`;
    // costing an Empty operator at zero input cardinality should be ~0.
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());
    let empty = LogicalOperator::Empty;
    let empty_cost = optimizer.cost_model().estimate(&empty, 0.0);
    assert!(empty_cost.total() < 0.001);
}
1382
#[test]
fn test_optimizer_with_cardinality_estimator() {
    // Registering 500 rows for label "Test" should make a scan of that label
    // estimate to 500.
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(estimator);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    let estimated_rows = optimizer.estimate_cardinality(&plan);
    let deviation = (estimated_rows - 500.0).abs();
    assert!(deviation < 0.001);
}
1399
#[test]
fn test_optimizer_estimate_cost() {
    // Even an unlabeled full scan must have a strictly positive estimated cost.
    let full_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(full_scan);

    let optimizer = Optimizer::new();
    let total = optimizer.estimate_cost(&plan).total();
    assert!(total > 0.0);
}
1412
#[test]
fn test_filter_pushdown_through_project() {
    // Plan shape: Filter(n.age > 30) <- Project(n) <- NodeScan(n).
    // The projection passes `n` through unaliased, so the filter may move beneath it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(project),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let pushed_below = matches!(
        &optimized.root,
        LogicalOperator::Project(proj) if matches!(proj.input.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(pushed_below, "Expected Project -> Filter structure");
}
1452
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    // Plan shape: Filter(x > 30) <- Project(n.age AS x) <- NodeScan(n).
    // `x` only exists after the projection, so the filter must stay above it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let aliased_project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Variable("x".to_string())),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(aliased_project),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter) if matches!(filter.input.as_ref(), LogicalOperator::Project(_))
    );
    assert!(kept_above, "Expected Filter -> Project structure");
}
1491
#[test]
fn test_filter_pushdown_through_limit() {
    // Despite the name, this asserts the filter stays ABOVE the Limit:
    // filtering after a limit is not equivalent to filtering before it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limit = LogicalOperator::Limit(LimitOp {
        count: 10.into(),
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(limit),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter) if matches!(filter.input.as_ref(), LogicalOperator::Limit(_))
    );
    assert!(kept_above, "Expected Filter -> Limit structure");
}
1519
#[test]
fn test_filter_pushdown_through_sort() {
    // Asserts that, after optimization, the filter is still the root with the
    // Sort directly beneath it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![SortKey {
            expression: LogicalExpression::Variable("n".to_string()),
            order: SortOrder::Ascending,
            nulls: None,
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(sort),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter) if matches!(filter.input.as_ref(), LogicalOperator::Sort(_))
    );
    assert!(kept_above, "Expected Filter -> Sort structure");
}
1551
#[test]
fn test_filter_pushdown_through_distinct() {
    // Asserts that, after optimization, the filter is still the root with the
    // Distinct directly beneath it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(distinct),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter) if matches!(filter.input.as_ref(), LogicalOperator::Distinct(_))
    );
    assert!(kept_above, "Expected Filter -> Distinct structure");
}
1579
#[test]
fn test_filter_not_pushed_through_aggregate() {
    // Plan shape: Filter(cnt > 10) <- Aggregate(count(*) AS cnt) <- NodeScan(n).
    // `cnt` is produced by the aggregate, so the filter cannot move below it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_as_cnt = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        expression2: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
        separator: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_as_cnt],
        input: Box::new(scan),
        having: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Variable("cnt".to_string())),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        pushdown_hint: None,
        input: Box::new(aggregate),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter) if matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_))
    );
    assert!(kept_above, "Expected Filter -> Aggregate structure");
}
1621
#[test]
fn test_filter_pushdown_to_left_join_side() {
    // Plan shape: Filter(a.age > 30) <- Join(NodeScan(a:Person), NodeScan(b:Company)).
    // The predicate only touches `a`, so it should move onto the left input.
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let company_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan),
        right: Box::new(company_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "a".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(join),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let on_left = matches!(
        &optimized.root,
        LogicalOperator::Join(join) if matches!(join.left.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(on_left, "Expected Join with Filter on left side");
}
1663
#[test]
fn test_filter_pushdown_to_right_join_side() {
    // Plan shape: Filter(b.name = "Acme") <- Join(NodeScan(a:Person), NodeScan(b:Company)).
    // The predicate only touches `b`, so it should move onto the right input.
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let company_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan),
        right: Box::new(company_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "b".to_string(),
                property: "name".to_string(),
            }),
            op: BinaryOp::Eq,
            right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
        },
        pushdown_hint: None,
        input: Box::new(join),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let on_right = matches!(
        &optimized.root,
        LogicalOperator::Join(join) if matches!(join.right.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(on_right, "Expected Join with Filter on right side");
}
1705
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    // Plan shape: Filter(a.id = b.a_id) <- Join(NodeScan(a), NodeScan(b)).
    // The predicate needs variables from both inputs, so it must stay above the Join.
    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let cross_side_predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side_predicate,
        pushdown_hint: None,
        input: Box::new(join),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter) if matches!(filter.input.as_ref(), LogicalOperator::Join(_))
    );
    assert!(kept_above, "Expected Filter -> Join structure");
}
1750
#[test]
fn test_extract_variables_from_variable() {
    // A bare variable reference yields exactly its own name.
    let expr = LogicalExpression::Variable("x".to_string());
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert!(vars.contains("x"));
    assert_eq!(vars.len(), 1);
}
1761
#[test]
fn test_extract_variables_from_unary() {
    // A unary operator contributes the variables of its operand.
    let operand = LogicalExpression::Variable("x".to_string());
    let expr = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert!(vars.contains("x"));
    assert_eq!(vars.len(), 1);
}
1773
#[test]
fn test_extract_variables_from_function_call() {
    // Every argument of a function call contributes its variables.
    let args = ["a", "b"]
        .iter()
        .map(|&name| LogicalExpression::Variable(name.to_string()))
        .collect();
    let expr = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args,
        distinct: false,
    };
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert_eq!(vars.len(), 2);
    for expected in ["a", "b"] {
        assert!(vars.contains(expected));
    }
}
1790
#[test]
fn test_extract_variables_from_list() {
    // List elements are scanned individually; the literal element adds nothing.
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let expr = LogicalExpression::List(elements);
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert_eq!(vars.len(), 2);
    for expected in ["a", "b"] {
        assert!(vars.contains(expected));
    }
}
1804
#[test]
fn test_extract_variables_from_map() {
    // Only map values are expressions; the string keys are not variables.
    let entries = vec![
        ("key1".to_string(), LogicalExpression::Variable("a".to_string())),
        ("key2".to_string(), LogicalExpression::Variable("b".to_string())),
    ];
    let expr = LogicalExpression::Map(entries);
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert_eq!(vars.len(), 2);
    for expected in ["a", "b"] {
        assert!(vars.contains(expected));
    }
}
1823
#[test]
fn test_extract_variables_from_index_access() {
    // Both the indexed base and the index expression contribute variables.
    let base = Box::new(LogicalExpression::Variable("list".to_string()));
    let index = Box::new(LogicalExpression::Variable("idx".to_string()));
    let expr = LogicalExpression::IndexAccess { base, index };
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert_eq!(vars.len(), 2);
    for expected in ["list", "idx"] {
        assert!(vars.contains(expected));
    }
}
1836
#[test]
fn test_extract_variables_from_slice_access() {
    // Base, start bound and end bound of a slice all contribute variables.
    let boxed_var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let expr = LogicalExpression::SliceAccess {
        base: boxed_var("list"),
        start: Some(boxed_var("s")),
        end: Some(boxed_var("e")),
    };
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert_eq!(vars.len(), 3);
    for expected in ["list", "s", "e"] {
        assert!(vars.contains(expected));
    }
}
1851
#[test]
fn test_extract_variables_from_case() {
    // CASE operand, WHEN results and the ELSE branch all contribute variables.
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let expr = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert_eq!(vars.len(), 3);
    for expected in ["x", "a", "b"] {
        assert!(vars.contains(expected));
    }
}
1869
#[test]
fn test_extract_variables_from_labels() {
    // `labels(n)` references the variable `n`.
    let expr = LogicalExpression::Labels("n".to_string());
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert!(vars.contains("n"));
    assert_eq!(vars.len(), 1);
}
1878
#[test]
fn test_extract_variables_from_type() {
    // `type(e)` references the variable `e`.
    let expr = LogicalExpression::Type("e".to_string());
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert!(vars.contains("e"));
    assert_eq!(vars.len(), 1);
}
1887
#[test]
fn test_extract_variables_from_id() {
    // `id(n)` references the variable `n`.
    let expr = LogicalExpression::Id("n".to_string());
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    assert!(vars.contains("n"));
    assert_eq!(vars.len(), 1);
}
1896
#[test]
fn test_extract_variables_from_list_comprehension() {
    // The source list, filter and map expressions must all be scanned.
    // Whether the comprehension's own binding `x` is reported is not checked.
    let boxed_var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let expr = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: boxed_var("items"),
        filter_expr: Some(boxed_var("pred")),
        map_expr: boxed_var("result"),
    };
    let optimizer = Optimizer::new();
    let vars = optimizer.extract_variables(&expr);
    for expected in ["items", "pred", "result"] {
        assert!(vars.contains(expected));
    }
}
1911
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    // Literals reference no variables.
    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // Parameters are not reported as variables either.
    // (Fixes the mis-encoded `¶m`, which was meant to be `&param`.)
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1922
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    // Plan shape: Return(n) <- Filter(true) <- Skip(5) <- NodeScan(n).
    // Only checks that optimization succeeds and the root is still a Return.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5.into(),
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(skip),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1954
#[test]
fn test_nested_filter_pushdown() {
    // Plan shape: Return(n) <- Filter(n.x > 1) <- Filter(n.y < 10) <- NodeScan(n).
    // Only checks that stacked filters optimize cleanly and the root stays a Return.
    let n_prop = |property: &str| {
        Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: property.to_string(),
        })
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: n_prop("y"),
            op: BinaryOp::Lt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        pushdown_hint: None,
        input: Box::new(scan),
    });
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: n_prop("x"),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
        },
        pushdown_hint: None,
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();
    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1998
#[test]
fn test_cyclic_join_produces_multi_way_join() {
    use crate::query::plan::JoinCondition;

    // Triangle pattern over three Person scans with join conditions a-b, b-c and
    // c-a. The join graph is cyclic, so the optimizer is expected to emit a
    // MultiWayJoin somewhere in the plan.
    let scan_a = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let scan_b = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let scan_c = LogicalOperator::NodeScan(NodeScanOp {
        variable: "c".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });

    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(scan_a),
        right: Box::new(scan_b),
        join_type: JoinType::Inner,
        conditions: vec![JoinCondition {
            left: LogicalExpression::Variable("a".to_string()),
            right: LogicalExpression::Variable("b".to_string()),
        }],
    });

    // The second join carries both b-c and the closing c-a edge.
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(scan_c),
        join_type: JoinType::Inner,
        conditions: vec![
            JoinCondition {
                left: LogicalExpression::Variable("b".to_string()),
                right: LogicalExpression::Variable("c".to_string()),
            },
            JoinCondition {
                left: LogicalExpression::Variable("c".to_string()),
                right: LogicalExpression::Variable("a".to_string()),
            },
        ],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));

    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));

    let optimized = optimizer.optimize(plan).unwrap();

    /// Returns `true` if a `MultiWayJoin` is reachable through Return / Filter /
    /// Project chains or beneath either input of a binary `Join`.
    fn has_multi_way_join(op: &LogicalOperator) -> bool {
        match op {
            LogicalOperator::MultiWayJoin(_) => true,
            LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
            LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
            LogicalOperator::Project(p) => has_multi_way_join(&p.input),
            // Descend into binary joins too, so a MultiWayJoin placed under a
            // residual binary join is not reported as missing.
            LogicalOperator::Join(j) => {
                has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
            }
            _ => false,
        }
    }

    assert!(
        has_multi_way_join(&optimized.root),
        "Expected MultiWayJoin for cyclic triangle pattern"
    );
}
2079
#[test]
fn test_acyclic_join_uses_binary_joins() {
    use crate::query::plan::JoinCondition;

    // Chain pattern a-b, b-c (no closing edge). The join graph is acyclic, so
    // no MultiWayJoin may appear anywhere in the optimized plan.
    let node_scan = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };
    let var_eq = |l: &str, r: &str| JoinCondition {
        left: LogicalExpression::Variable(l.to_string()),
        right: LogicalExpression::Variable(r.to_string()),
    };

    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(node_scan("a", "Person")),
        right: Box::new(node_scan("b", "Person")),
        join_type: JoinType::Inner,
        conditions: vec![var_eq("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(node_scan("c", "Company")),
        join_type: JoinType::Inner,
        conditions: vec![var_eq("b", "c")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));

    let mut optimizer = Optimizer::new();
    for (label, rows) in [("Person", 1000), ("Company", 100)] {
        optimizer
            .card_estimator
            .add_table_stats(label, cardinality::TableStats::new(rows));
    }

    let optimized = optimizer.optimize(plan).unwrap();

    /// Walks Return / Filter / Project chains and both inputs of binary joins,
    /// reporting whether any `MultiWayJoin` is reachable.
    fn has_multi_way_join(root: &LogicalOperator) -> bool {
        let mut pending: Vec<&LogicalOperator> = vec![root];
        while let Some(op) = pending.pop() {
            match op {
                LogicalOperator::MultiWayJoin(_) => return true,
                LogicalOperator::Return(ret) => pending.push(ret.input.as_ref()),
                LogicalOperator::Filter(f) => pending.push(f.input.as_ref()),
                LogicalOperator::Project(p) => pending.push(p.input.as_ref()),
                LogicalOperator::Join(j) => {
                    pending.push(j.left.as_ref());
                    pending.push(j.right.as_ref());
                }
                _ => {}
            }
        }
        false
    }

    assert!(
        !has_multi_way_join(&optimized.root),
        "Acyclic join should NOT produce MultiWayJoin"
    );
}
2159}