1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// One join condition recovered from a join tree while it is being flattened for reordering.
///
/// `left_var` / `right_var` are the variables the two sides of the condition refer to (as
/// extracted by `Optimizer::extract_variable_from_expr`); the original expressions are kept
/// so the reordered plan can rebuild the condition verbatim.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression.
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// Left-hand expression of the condition, exactly as it appeared in the plan.
    left_expr: LogicalExpression,
    /// Right-hand expression of the condition, exactly as it appeared in the plan.
    right_expr: LogicalExpression,
}
39
/// A value that some operator in the plan reads, as gathered by projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable, e.g. `n`.
    Variable(String),
    /// A property of a variable as `(variable, property)`, e.g. `("n", "age")` for `n.age`.
    Property(String, String),
}
48
/// Logical-plan optimizer.
///
/// `optimize` runs, in order, filter pushdown, join reordering and projection pushdown; each
/// pass can be switched off independently. Join reordering is priced with `cost_model` and
/// `card_estimator`.
pub struct Optimizer {
    /// Whether `optimize` runs the filter-pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join-reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection-pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used to compare candidate join orders.
    cost_model: CostModel,
    /// Cardinality estimator used alongside the cost model.
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67 #[must_use]
69 pub fn new() -> Self {
70 Self {
71 enable_filter_pushdown: true,
72 enable_join_reorder: true,
73 enable_projection_pushdown: true,
74 cost_model: CostModel::new(),
75 card_estimator: CardinalityEstimator::new(),
76 }
77 }
78
79 #[cfg(feature = "lpg")]
85 #[must_use]
86 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
87 store.ensure_statistics_fresh();
88 let stats = store.statistics();
89 Self::from_statistics(&stats)
90 }
91
92 #[must_use]
99 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
100 let stats = store.statistics();
101 Self::from_statistics(&stats)
102 }
103
104 #[cfg(feature = "triple-store")]
109 #[must_use]
110 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
111 let total = rdf_stats.total_triples;
112 let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
113 Self {
114 enable_filter_pushdown: true,
115 enable_join_reorder: true,
116 enable_projection_pushdown: true,
117 cost_model: CostModel::new().with_graph_totals(total, total),
118 card_estimator: estimator,
119 }
120 }
121
122 #[must_use]
127 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
128 let estimator = CardinalityEstimator::from_statistics(stats);
129
130 let avg_fanout = if stats.total_nodes > 0 {
131 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
132 } else {
133 10.0
134 };
135
136 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
137 .edge_types
138 .iter()
139 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
140 .collect();
141
142 let label_cardinalities: std::collections::HashMap<String, u64> = stats
143 .labels
144 .iter()
145 .map(|(name, ls)| (name.clone(), ls.node_count))
146 .collect();
147
148 Self {
149 enable_filter_pushdown: true,
150 enable_join_reorder: true,
151 enable_projection_pushdown: true,
152 cost_model: CostModel::new()
153 .with_avg_fanout(avg_fanout)
154 .with_edge_type_degrees(edge_type_degrees)
155 .with_label_cardinalities(label_cardinalities)
156 .with_graph_totals(stats.total_nodes, stats.total_edges),
157 card_estimator: estimator,
158 }
159 }
160
161 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
163 self.enable_filter_pushdown = enabled;
164 self
165 }
166
167 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
169 self.enable_join_reorder = enabled;
170 self
171 }
172
173 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
175 self.enable_projection_pushdown = enabled;
176 self
177 }
178
179 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
181 self.cost_model = cost_model;
182 self
183 }
184
185 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
187 self.card_estimator = estimator;
188 self
189 }
190
191 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
193 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
194 self
195 }
196
197 pub fn cost_model(&self) -> &CostModel {
199 &self.cost_model
200 }
201
202 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
204 &self.card_estimator
205 }
206
207 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
213 self.cost_model
214 .estimate_tree(&plan.root, &self.card_estimator)
215 }
216
217 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
219 self.card_estimator.estimate(&plan.root)
220 }
221
222 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
228 let _span = grafeo_debug_span!("grafeo::query::optimize");
229 let mut root = plan.root;
230
231 if self.enable_filter_pushdown {
233 root = self.push_filters_down(root);
234 }
235
236 if self.enable_join_reorder {
237 root = self.reorder_joins(root);
238 }
239
240 if self.enable_projection_pushdown {
241 root = self.push_projections_down(root);
242 }
243
244 Ok(LogicalPlan {
245 root,
246 explain: plan.explain,
247 profile: plan.profile,
248 default_params: plan.default_params,
249 })
250 }
251
252 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
259 let required = self.collect_required_columns(&op);
261
262 self.push_projections_recursive(op, &required)
264 }
265
266 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
268 let mut required = HashSet::new();
269 Self::collect_required_recursive(op, &mut required);
270 required
271 }
272
273 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
275 match op {
276 LogicalOperator::Return(ret) => {
277 for item in &ret.items {
278 Self::collect_from_expression(&item.expression, required);
279 }
280 Self::collect_required_recursive(&ret.input, required);
281 }
282 LogicalOperator::Project(proj) => {
283 for p in &proj.projections {
284 Self::collect_from_expression(&p.expression, required);
285 }
286 Self::collect_required_recursive(&proj.input, required);
287 }
288 LogicalOperator::Filter(filter) => {
289 Self::collect_from_expression(&filter.predicate, required);
290 Self::collect_required_recursive(&filter.input, required);
291 }
292 LogicalOperator::Sort(sort) => {
293 for key in &sort.keys {
294 Self::collect_from_expression(&key.expression, required);
295 }
296 Self::collect_required_recursive(&sort.input, required);
297 }
298 LogicalOperator::Aggregate(agg) => {
299 for expr in &agg.group_by {
300 Self::collect_from_expression(expr, required);
301 }
302 for agg_expr in &agg.aggregates {
303 if let Some(ref expr) = agg_expr.expression {
304 Self::collect_from_expression(expr, required);
305 }
306 }
307 if let Some(ref having) = agg.having {
308 Self::collect_from_expression(having, required);
309 }
310 Self::collect_required_recursive(&agg.input, required);
311 }
312 LogicalOperator::Join(join) => {
313 for cond in &join.conditions {
314 Self::collect_from_expression(&cond.left, required);
315 Self::collect_from_expression(&cond.right, required);
316 }
317 Self::collect_required_recursive(&join.left, required);
318 Self::collect_required_recursive(&join.right, required);
319 }
320 LogicalOperator::Expand(expand) => {
321 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
323 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
324 if let Some(ref edge_var) = expand.edge_variable {
325 required.insert(RequiredColumn::Variable(edge_var.clone()));
326 }
327 Self::collect_required_recursive(&expand.input, required);
328 }
329 LogicalOperator::Limit(limit) => {
330 Self::collect_required_recursive(&limit.input, required);
331 }
332 LogicalOperator::Skip(skip) => {
333 Self::collect_required_recursive(&skip.input, required);
334 }
335 LogicalOperator::Distinct(distinct) => {
336 Self::collect_required_recursive(&distinct.input, required);
337 }
338 LogicalOperator::NodeScan(scan) => {
339 required.insert(RequiredColumn::Variable(scan.variable.clone()));
340 }
341 LogicalOperator::EdgeScan(scan) => {
342 required.insert(RequiredColumn::Variable(scan.variable.clone()));
343 }
344 LogicalOperator::MultiWayJoin(mwj) => {
345 for cond in &mwj.conditions {
346 Self::collect_from_expression(&cond.left, required);
347 Self::collect_from_expression(&cond.right, required);
348 }
349 for input in &mwj.inputs {
350 Self::collect_required_recursive(input, required);
351 }
352 }
353 _ => {}
354 }
355 }
356
357 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
359 match expr {
360 LogicalExpression::Variable(var) => {
361 required.insert(RequiredColumn::Variable(var.clone()));
362 }
363 LogicalExpression::Property { variable, property } => {
364 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
365 required.insert(RequiredColumn::Variable(variable.clone()));
366 }
367 LogicalExpression::Binary { left, right, .. } => {
368 Self::collect_from_expression(left, required);
369 Self::collect_from_expression(right, required);
370 }
371 LogicalExpression::Unary { operand, .. } => {
372 Self::collect_from_expression(operand, required);
373 }
374 LogicalExpression::FunctionCall { args, .. } => {
375 for arg in args {
376 Self::collect_from_expression(arg, required);
377 }
378 }
379 LogicalExpression::List(items) => {
380 for item in items {
381 Self::collect_from_expression(item, required);
382 }
383 }
384 LogicalExpression::Map(pairs) => {
385 for (_, value) in pairs {
386 Self::collect_from_expression(value, required);
387 }
388 }
389 LogicalExpression::IndexAccess { base, index } => {
390 Self::collect_from_expression(base, required);
391 Self::collect_from_expression(index, required);
392 }
393 LogicalExpression::SliceAccess { base, start, end } => {
394 Self::collect_from_expression(base, required);
395 if let Some(s) = start {
396 Self::collect_from_expression(s, required);
397 }
398 if let Some(e) = end {
399 Self::collect_from_expression(e, required);
400 }
401 }
402 LogicalExpression::Case {
403 operand,
404 when_clauses,
405 else_clause,
406 } => {
407 if let Some(op) = operand {
408 Self::collect_from_expression(op, required);
409 }
410 for (cond, result) in when_clauses {
411 Self::collect_from_expression(cond, required);
412 Self::collect_from_expression(result, required);
413 }
414 if let Some(else_expr) = else_clause {
415 Self::collect_from_expression(else_expr, required);
416 }
417 }
418 LogicalExpression::Labels(var)
419 | LogicalExpression::Type(var)
420 | LogicalExpression::Id(var) => {
421 required.insert(RequiredColumn::Variable(var.clone()));
422 }
423 LogicalExpression::ListComprehension {
424 list_expr,
425 filter_expr,
426 map_expr,
427 ..
428 } => {
429 Self::collect_from_expression(list_expr, required);
430 if let Some(filter) = filter_expr {
431 Self::collect_from_expression(filter, required);
432 }
433 Self::collect_from_expression(map_expr, required);
434 }
435 _ => {}
436 }
437 }
438
439 fn push_projections_recursive(
441 &self,
442 op: LogicalOperator,
443 required: &HashSet<RequiredColumn>,
444 ) -> LogicalOperator {
445 match op {
446 LogicalOperator::Return(mut ret) => {
447 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
448 LogicalOperator::Return(ret)
449 }
450 LogicalOperator::Project(mut proj) => {
451 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
452 LogicalOperator::Project(proj)
453 }
454 LogicalOperator::Filter(mut filter) => {
455 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
456 LogicalOperator::Filter(filter)
457 }
458 LogicalOperator::Sort(mut sort) => {
459 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
462 LogicalOperator::Sort(sort)
463 }
464 LogicalOperator::Aggregate(mut agg) => {
465 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
466 LogicalOperator::Aggregate(agg)
467 }
468 LogicalOperator::Join(mut join) => {
469 let left_vars = self.collect_output_variables(&join.left);
472 let right_vars = self.collect_output_variables(&join.right);
473
474 let left_required: HashSet<_> = required
476 .iter()
477 .filter(|c| match c {
478 RequiredColumn::Variable(v) => left_vars.contains(v),
479 RequiredColumn::Property(v, _) => left_vars.contains(v),
480 })
481 .cloned()
482 .collect();
483
484 let right_required: HashSet<_> = required
485 .iter()
486 .filter(|c| match c {
487 RequiredColumn::Variable(v) => right_vars.contains(v),
488 RequiredColumn::Property(v, _) => right_vars.contains(v),
489 })
490 .cloned()
491 .collect();
492
493 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
494 join.right =
495 Box::new(self.push_projections_recursive(*join.right, &right_required));
496 LogicalOperator::Join(join)
497 }
498 LogicalOperator::Expand(mut expand) => {
499 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
500 LogicalOperator::Expand(expand)
501 }
502 LogicalOperator::Limit(mut limit) => {
503 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
504 LogicalOperator::Limit(limit)
505 }
506 LogicalOperator::Skip(mut skip) => {
507 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
508 LogicalOperator::Skip(skip)
509 }
510 LogicalOperator::Distinct(mut distinct) => {
511 distinct.input =
512 Box::new(self.push_projections_recursive(*distinct.input, required));
513 LogicalOperator::Distinct(distinct)
514 }
515 LogicalOperator::MapCollect(mut mc) => {
516 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
517 LogicalOperator::MapCollect(mc)
518 }
519 LogicalOperator::MultiWayJoin(mut mwj) => {
520 mwj.inputs = mwj
521 .inputs
522 .into_iter()
523 .map(|input| self.push_projections_recursive(input, required))
524 .collect();
525 LogicalOperator::MultiWayJoin(mwj)
526 }
527 other => other,
528 }
529 }
530
531 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
538 let op = self.reorder_joins_recursive(op);
540
541 if let Some((relations, conditions)) = self.extract_join_tree(&op)
543 && relations.len() >= 2
544 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
545 {
546 return optimized;
547 }
548
549 op
550 }
551
552 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
554 match op {
555 LogicalOperator::Return(mut ret) => {
556 ret.input = Box::new(self.reorder_joins(*ret.input));
557 LogicalOperator::Return(ret)
558 }
559 LogicalOperator::Project(mut proj) => {
560 proj.input = Box::new(self.reorder_joins(*proj.input));
561 LogicalOperator::Project(proj)
562 }
563 LogicalOperator::Filter(mut filter) => {
564 filter.input = Box::new(self.reorder_joins(*filter.input));
565 LogicalOperator::Filter(filter)
566 }
567 LogicalOperator::Limit(mut limit) => {
568 limit.input = Box::new(self.reorder_joins(*limit.input));
569 LogicalOperator::Limit(limit)
570 }
571 LogicalOperator::Skip(mut skip) => {
572 skip.input = Box::new(self.reorder_joins(*skip.input));
573 LogicalOperator::Skip(skip)
574 }
575 LogicalOperator::Sort(mut sort) => {
576 sort.input = Box::new(self.reorder_joins(*sort.input));
577 LogicalOperator::Sort(sort)
578 }
579 LogicalOperator::Distinct(mut distinct) => {
580 distinct.input = Box::new(self.reorder_joins(*distinct.input));
581 LogicalOperator::Distinct(distinct)
582 }
583 LogicalOperator::Aggregate(mut agg) => {
584 agg.input = Box::new(self.reorder_joins(*agg.input));
585 LogicalOperator::Aggregate(agg)
586 }
587 LogicalOperator::Expand(mut expand) => {
588 expand.input = Box::new(self.reorder_joins(*expand.input));
589 LogicalOperator::Expand(expand)
590 }
591 LogicalOperator::MapCollect(mut mc) => {
592 mc.input = Box::new(self.reorder_joins(*mc.input));
593 LogicalOperator::MapCollect(mc)
594 }
595 LogicalOperator::MultiWayJoin(mut mwj) => {
596 mwj.inputs = mwj
597 .inputs
598 .into_iter()
599 .map(|input| self.reorder_joins(input))
600 .collect();
601 LogicalOperator::MultiWayJoin(mwj)
602 }
603 other => other,
605 }
606 }
607
608 fn extract_join_tree(
612 &self,
613 op: &LogicalOperator,
614 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
615 let mut relations = Vec::new();
616 let mut join_conditions = Vec::new();
617
618 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
619 return None;
620 }
621
622 if relations.len() < 2 {
623 return None;
624 }
625
626 Some((relations, join_conditions))
627 }
628
629 fn collect_join_tree(
633 &self,
634 op: &LogicalOperator,
635 relations: &mut Vec<(String, LogicalOperator)>,
636 conditions: &mut Vec<JoinInfo>,
637 ) -> bool {
638 match op {
639 LogicalOperator::Join(join) => {
640 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
642 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
643
644 for cond in &join.conditions {
646 if let (Some(left_var), Some(right_var)) = (
647 self.extract_variable_from_expr(&cond.left),
648 self.extract_variable_from_expr(&cond.right),
649 ) {
650 conditions.push(JoinInfo {
651 left_var,
652 right_var,
653 left_expr: cond.left.clone(),
654 right_expr: cond.right.clone(),
655 });
656 }
657 }
658
659 left_ok && right_ok
660 }
661 LogicalOperator::NodeScan(scan) => {
662 relations.push((scan.variable.clone(), op.clone()));
663 true
664 }
665 LogicalOperator::EdgeScan(scan) => {
666 relations.push((scan.variable.clone(), op.clone()));
667 true
668 }
669 LogicalOperator::Filter(filter) => {
670 self.collect_join_tree(&filter.input, relations, conditions)
672 }
673 LogicalOperator::Expand(expand) => {
674 relations.push((expand.to_variable.clone(), op.clone()));
677 true
678 }
679 _ => false,
680 }
681 }
682
683 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
685 match expr {
686 LogicalExpression::Variable(v) => Some(v.clone()),
687 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
688 _ => None,
689 }
690 }
691
692 fn optimize_join_order(
695 &self,
696 relations: &[(String, LogicalOperator)],
697 conditions: &[JoinInfo],
698 ) -> Option<LogicalOperator> {
699 use join_order::{DPccp, JoinGraphBuilder};
700
701 let mut builder = JoinGraphBuilder::new();
703
704 for (var, relation) in relations {
705 builder.add_relation(var, relation.clone());
706 }
707
708 for cond in conditions {
709 builder.add_join_condition(
710 &cond.left_var,
711 &cond.right_var,
712 cond.left_expr.clone(),
713 cond.right_expr.clone(),
714 );
715 }
716
717 let graph = builder.build();
718
719 if graph.is_cyclic() && relations.len() >= 3 {
724 let mut var_counts: std::collections::HashMap<&str, usize> =
726 std::collections::HashMap::new();
727 for cond in conditions {
728 *var_counts.entry(&cond.left_var).or_default() += 1;
729 *var_counts.entry(&cond.right_var).or_default() += 1;
730 }
731 let shared_variables: Vec<String> = var_counts
732 .into_iter()
733 .filter(|(_, count)| *count >= 2)
734 .map(|(var, _)| var.to_string())
735 .collect();
736
737 let join_conditions: Vec<JoinCondition> = conditions
738 .iter()
739 .map(|c| JoinCondition {
740 left: c.left_expr.clone(),
741 right: c.right_expr.clone(),
742 })
743 .collect();
744
745 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
746 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
747 conditions: join_conditions,
748 shared_variables,
749 }));
750 }
751
752 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
754 let plan = dpccp.optimize()?;
755
756 Some(plan.operator)
757 }
758
759 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
764 match op {
765 LogicalOperator::Filter(filter) => {
767 let optimized_input = self.push_filters_down(*filter.input);
768 self.try_push_filter_into(filter.predicate, optimized_input)
769 }
770 LogicalOperator::Return(mut ret) => {
772 ret.input = Box::new(self.push_filters_down(*ret.input));
773 LogicalOperator::Return(ret)
774 }
775 LogicalOperator::Project(mut proj) => {
776 proj.input = Box::new(self.push_filters_down(*proj.input));
777 LogicalOperator::Project(proj)
778 }
779 LogicalOperator::Limit(mut limit) => {
780 limit.input = Box::new(self.push_filters_down(*limit.input));
781 LogicalOperator::Limit(limit)
782 }
783 LogicalOperator::Skip(mut skip) => {
784 skip.input = Box::new(self.push_filters_down(*skip.input));
785 LogicalOperator::Skip(skip)
786 }
787 LogicalOperator::Sort(mut sort) => {
788 sort.input = Box::new(self.push_filters_down(*sort.input));
789 LogicalOperator::Sort(sort)
790 }
791 LogicalOperator::Distinct(mut distinct) => {
792 distinct.input = Box::new(self.push_filters_down(*distinct.input));
793 LogicalOperator::Distinct(distinct)
794 }
795 LogicalOperator::Expand(mut expand) => {
796 expand.input = Box::new(self.push_filters_down(*expand.input));
797 LogicalOperator::Expand(expand)
798 }
799 LogicalOperator::Join(mut join) => {
800 join.left = Box::new(self.push_filters_down(*join.left));
801 join.right = Box::new(self.push_filters_down(*join.right));
802 LogicalOperator::Join(join)
803 }
804 LogicalOperator::Aggregate(mut agg) => {
805 agg.input = Box::new(self.push_filters_down(*agg.input));
806 LogicalOperator::Aggregate(agg)
807 }
808 LogicalOperator::MapCollect(mut mc) => {
809 mc.input = Box::new(self.push_filters_down(*mc.input));
810 LogicalOperator::MapCollect(mc)
811 }
812 LogicalOperator::MultiWayJoin(mut mwj) => {
813 mwj.inputs = mwj
814 .inputs
815 .into_iter()
816 .map(|input| self.push_filters_down(input))
817 .collect();
818 LogicalOperator::MultiWayJoin(mwj)
819 }
820 other => other,
822 }
823 }
824
825 fn try_push_filter_into(
830 &self,
831 predicate: LogicalExpression,
832 op: LogicalOperator,
833 ) -> LogicalOperator {
834 match op {
835 LogicalOperator::Project(mut proj) => {
837 let predicate_vars = self.extract_variables(&predicate);
838 let computed_vars = self.extract_projection_aliases(&proj.projections);
839
840 if predicate_vars.is_disjoint(&computed_vars) {
842 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
843 LogicalOperator::Project(proj)
844 } else {
845 LogicalOperator::Filter(FilterOp {
847 predicate,
848 pushdown_hint: None,
849 input: Box::new(LogicalOperator::Project(proj)),
850 })
851 }
852 }
853
854 LogicalOperator::Return(mut ret) => {
856 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
857 LogicalOperator::Return(ret)
858 }
859
860 LogicalOperator::Expand(mut expand) => {
862 let predicate_vars = self.extract_variables(&predicate);
863
864 let mut introduced_vars = vec![&expand.to_variable];
869 if let Some(ref edge_var) = expand.edge_variable {
870 introduced_vars.push(edge_var);
871 }
872 if let Some(ref path_alias) = expand.path_alias {
873 introduced_vars.push(path_alias);
874 }
875
876 let uses_introduced_vars =
878 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
879
880 if !uses_introduced_vars {
881 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
883 LogicalOperator::Expand(expand)
884 } else {
885 LogicalOperator::Filter(FilterOp {
887 predicate,
888 pushdown_hint: None,
889 input: Box::new(LogicalOperator::Expand(expand)),
890 })
891 }
892 }
893
894 LogicalOperator::Join(mut join) => {
896 let predicate_vars = self.extract_variables(&predicate);
897 let left_vars = self.collect_output_variables(&join.left);
898 let right_vars = self.collect_output_variables(&join.right);
899
900 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
901 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
902
903 if uses_left && !uses_right {
904 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
906 LogicalOperator::Join(join)
907 } else if uses_right && !uses_left {
908 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
910 LogicalOperator::Join(join)
911 } else {
912 LogicalOperator::Filter(FilterOp {
914 predicate,
915 pushdown_hint: None,
916 input: Box::new(LogicalOperator::Join(join)),
917 })
918 }
919 }
920
921 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
923 predicate,
924 pushdown_hint: None,
925 input: Box::new(LogicalOperator::Aggregate(agg)),
926 }),
927
928 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
930 predicate,
931 pushdown_hint: None,
932 input: Box::new(LogicalOperator::NodeScan(scan)),
933 }),
934
935 other => LogicalOperator::Filter(FilterOp {
937 predicate,
938 pushdown_hint: None,
939 input: Box::new(other),
940 }),
941 }
942 }
943
944 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
946 let mut vars = HashSet::new();
947 Self::collect_output_variables_recursive(op, &mut vars);
948 vars
949 }
950
951 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
953 match op {
954 LogicalOperator::NodeScan(scan) => {
955 vars.insert(scan.variable.clone());
956 }
957 LogicalOperator::EdgeScan(scan) => {
958 vars.insert(scan.variable.clone());
959 }
960 LogicalOperator::Expand(expand) => {
961 vars.insert(expand.to_variable.clone());
962 if let Some(edge_var) = &expand.edge_variable {
963 vars.insert(edge_var.clone());
964 }
965 Self::collect_output_variables_recursive(&expand.input, vars);
966 }
967 LogicalOperator::Filter(filter) => {
968 Self::collect_output_variables_recursive(&filter.input, vars);
969 }
970 LogicalOperator::Project(proj) => {
971 for p in &proj.projections {
972 if let Some(alias) = &p.alias {
973 vars.insert(alias.clone());
974 }
975 }
976 Self::collect_output_variables_recursive(&proj.input, vars);
977 }
978 LogicalOperator::Join(join) => {
979 Self::collect_output_variables_recursive(&join.left, vars);
980 Self::collect_output_variables_recursive(&join.right, vars);
981 }
982 LogicalOperator::Aggregate(agg) => {
983 for expr in &agg.group_by {
984 Self::collect_variables(expr, vars);
985 }
986 for agg_expr in &agg.aggregates {
987 if let Some(alias) = &agg_expr.alias {
988 vars.insert(alias.clone());
989 }
990 }
991 }
992 LogicalOperator::Return(ret) => {
993 Self::collect_output_variables_recursive(&ret.input, vars);
994 }
995 LogicalOperator::Limit(limit) => {
996 Self::collect_output_variables_recursive(&limit.input, vars);
997 }
998 LogicalOperator::Skip(skip) => {
999 Self::collect_output_variables_recursive(&skip.input, vars);
1000 }
1001 LogicalOperator::Sort(sort) => {
1002 Self::collect_output_variables_recursive(&sort.input, vars);
1003 }
1004 LogicalOperator::Distinct(distinct) => {
1005 Self::collect_output_variables_recursive(&distinct.input, vars);
1006 }
1007 _ => {}
1008 }
1009 }
1010
1011 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1013 let mut vars = HashSet::new();
1014 Self::collect_variables(expr, &mut vars);
1015 vars
1016 }
1017
1018 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1020 match expr {
1021 LogicalExpression::Variable(name) => {
1022 vars.insert(name.clone());
1023 }
1024 LogicalExpression::Property { variable, .. } => {
1025 vars.insert(variable.clone());
1026 }
1027 LogicalExpression::Binary { left, right, .. } => {
1028 Self::collect_variables(left, vars);
1029 Self::collect_variables(right, vars);
1030 }
1031 LogicalExpression::Unary { operand, .. } => {
1032 Self::collect_variables(operand, vars);
1033 }
1034 LogicalExpression::FunctionCall { args, .. } => {
1035 for arg in args {
1036 Self::collect_variables(arg, vars);
1037 }
1038 }
1039 LogicalExpression::List(items) => {
1040 for item in items {
1041 Self::collect_variables(item, vars);
1042 }
1043 }
1044 LogicalExpression::Map(pairs) => {
1045 for (_, value) in pairs {
1046 Self::collect_variables(value, vars);
1047 }
1048 }
1049 LogicalExpression::IndexAccess { base, index } => {
1050 Self::collect_variables(base, vars);
1051 Self::collect_variables(index, vars);
1052 }
1053 LogicalExpression::SliceAccess { base, start, end } => {
1054 Self::collect_variables(base, vars);
1055 if let Some(s) = start {
1056 Self::collect_variables(s, vars);
1057 }
1058 if let Some(e) = end {
1059 Self::collect_variables(e, vars);
1060 }
1061 }
1062 LogicalExpression::Case {
1063 operand,
1064 when_clauses,
1065 else_clause,
1066 } => {
1067 if let Some(op) = operand {
1068 Self::collect_variables(op, vars);
1069 }
1070 for (cond, result) in when_clauses {
1071 Self::collect_variables(cond, vars);
1072 Self::collect_variables(result, vars);
1073 }
1074 if let Some(else_expr) = else_clause {
1075 Self::collect_variables(else_expr, vars);
1076 }
1077 }
1078 LogicalExpression::Labels(var)
1079 | LogicalExpression::Type(var)
1080 | LogicalExpression::Id(var) => {
1081 vars.insert(var.clone());
1082 }
1083 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1084 LogicalExpression::ListComprehension {
1085 list_expr,
1086 filter_expr,
1087 map_expr,
1088 ..
1089 } => {
1090 Self::collect_variables(list_expr, vars);
1091 if let Some(filter) = filter_expr {
1092 Self::collect_variables(filter, vars);
1093 }
1094 Self::collect_variables(map_expr, vars);
1095 }
1096 LogicalExpression::ListPredicate {
1097 list_expr,
1098 predicate,
1099 ..
1100 } => {
1101 Self::collect_variables(list_expr, vars);
1102 Self::collect_variables(predicate, vars);
1103 }
1104 LogicalExpression::ExistsSubquery(_)
1105 | LogicalExpression::CountSubquery(_)
1106 | LogicalExpression::ValueSubquery(_) => {
1107 }
1109 LogicalExpression::PatternComprehension { projection, .. } => {
1110 Self::collect_variables(projection, vars);
1111 }
1112 LogicalExpression::MapProjection { base, entries } => {
1113 vars.insert(base.clone());
1114 for entry in entries {
1115 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1116 Self::collect_variables(expr, vars);
1117 }
1118 }
1119 }
1120 LogicalExpression::Reduce {
1121 initial,
1122 list,
1123 expression,
1124 ..
1125 } => {
1126 Self::collect_variables(initial, vars);
1127 Self::collect_variables(list, vars);
1128 Self::collect_variables(expression, vars);
1129 }
1130 }
1131 }
1132
1133 fn extract_projection_aliases(
1135 &self,
1136 projections: &[crate::query::plan::Projection],
1137 ) -> HashSet<String> {
1138 projections.iter().filter_map(|p| p.alias.clone()).collect()
1139 }
1140}
1141
1142impl Default for Optimizer {
1143 fn default() -> Self {
1144 Self::new()
1145 }
1146}
1147
1148#[cfg(test)]
1149mod tests {
1150 use super::*;
1151 use crate::query::plan::{
1152 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1153 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1154 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1155 };
1156 use grafeo_common::types::Value;
1157
1158 #[test]
1159 fn test_optimizer_filter_pushdown_simple() {
1160 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1165 items: vec![ReturnItem {
1166 expression: LogicalExpression::Variable("n".to_string()),
1167 alias: None,
1168 }],
1169 distinct: false,
1170 input: Box::new(LogicalOperator::Filter(FilterOp {
1171 predicate: LogicalExpression::Binary {
1172 left: Box::new(LogicalExpression::Property {
1173 variable: "n".to_string(),
1174 property: "age".to_string(),
1175 }),
1176 op: BinaryOp::Gt,
1177 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1178 },
1179 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1180 variable: "n".to_string(),
1181 label: Some("Person".to_string()),
1182 input: None,
1183 })),
1184 pushdown_hint: None,
1185 })),
1186 }));
1187
1188 let optimizer = Optimizer::new();
1189 let optimized = optimizer.optimize(plan).unwrap();
1190
1191 if let LogicalOperator::Return(ret) = &optimized.root
1193 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1194 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1195 {
1196 assert_eq!(scan.variable, "n");
1197 return;
1198 }
1199 panic!("Expected Return -> Filter -> NodeScan structure");
1200 }
1201
1202 #[test]
1203 fn test_optimizer_filter_pushdown_through_expand() {
1204 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1208 items: vec![ReturnItem {
1209 expression: LogicalExpression::Variable("b".to_string()),
1210 alias: None,
1211 }],
1212 distinct: false,
1213 input: Box::new(LogicalOperator::Filter(FilterOp {
1214 predicate: LogicalExpression::Binary {
1215 left: Box::new(LogicalExpression::Property {
1216 variable: "a".to_string(),
1217 property: "age".to_string(),
1218 }),
1219 op: BinaryOp::Gt,
1220 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1221 },
1222 pushdown_hint: None,
1223 input: Box::new(LogicalOperator::Expand(ExpandOp {
1224 from_variable: "a".to_string(),
1225 to_variable: "b".to_string(),
1226 edge_variable: None,
1227 direction: ExpandDirection::Outgoing,
1228 edge_types: vec!["KNOWS".to_string()],
1229 min_hops: 1,
1230 max_hops: Some(1),
1231 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1232 variable: "a".to_string(),
1233 label: Some("Person".to_string()),
1234 input: None,
1235 })),
1236 path_alias: None,
1237 path_mode: PathMode::Walk,
1238 })),
1239 })),
1240 }));
1241
1242 let optimizer = Optimizer::new();
1243 let optimized = optimizer.optimize(plan).unwrap();
1244
1245 if let LogicalOperator::Return(ret) = &optimized.root
1248 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1249 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1250 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1251 {
1252 assert_eq!(scan.variable, "a");
1253 assert_eq!(expand.from_variable, "a");
1254 assert_eq!(expand.to_variable, "b");
1255 return;
1256 }
1257 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1258 }
1259
1260 #[test]
1261 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1262 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1266 items: vec![ReturnItem {
1267 expression: LogicalExpression::Variable("a".to_string()),
1268 alias: None,
1269 }],
1270 distinct: false,
1271 input: Box::new(LogicalOperator::Filter(FilterOp {
1272 predicate: LogicalExpression::Binary {
1273 left: Box::new(LogicalExpression::Property {
1274 variable: "b".to_string(),
1275 property: "age".to_string(),
1276 }),
1277 op: BinaryOp::Gt,
1278 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1279 },
1280 pushdown_hint: None,
1281 input: Box::new(LogicalOperator::Expand(ExpandOp {
1282 from_variable: "a".to_string(),
1283 to_variable: "b".to_string(),
1284 edge_variable: None,
1285 direction: ExpandDirection::Outgoing,
1286 edge_types: vec!["KNOWS".to_string()],
1287 min_hops: 1,
1288 max_hops: Some(1),
1289 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1290 variable: "a".to_string(),
1291 label: Some("Person".to_string()),
1292 input: None,
1293 })),
1294 path_alias: None,
1295 path_mode: PathMode::Walk,
1296 })),
1297 })),
1298 }));
1299
1300 let optimizer = Optimizer::new();
1301 let optimized = optimizer.optimize(plan).unwrap();
1302
1303 if let LogicalOperator::Return(ret) = &optimized.root
1306 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1307 {
1308 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1310 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1311 {
1312 assert_eq!(variable, "b");
1313 }
1314
1315 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1316 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1317 {
1318 return;
1319 }
1320 }
1321 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1322 }
1323
1324 #[test]
1325 fn test_optimizer_extract_variables() {
1326 let optimizer = Optimizer::new();
1327
1328 let expr = LogicalExpression::Binary {
1329 left: Box::new(LogicalExpression::Property {
1330 variable: "n".to_string(),
1331 property: "age".to_string(),
1332 }),
1333 op: BinaryOp::Gt,
1334 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1335 };
1336
1337 let vars = optimizer.extract_variables(&expr);
1338 assert_eq!(vars.len(), 1);
1339 assert!(vars.contains("n"));
1340 }
1341
1342 #[test]
1345 fn test_optimizer_default() {
1346 let optimizer = Optimizer::default();
1347 let plan = LogicalPlan::new(LogicalOperator::Empty);
1349 let result = optimizer.optimize(plan);
1350 assert!(result.is_ok());
1351 }
1352
1353 #[test]
1354 fn test_optimizer_with_filter_pushdown_disabled() {
1355 let optimizer = Optimizer::new().with_filter_pushdown(false);
1356
1357 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1358 items: vec![ReturnItem {
1359 expression: LogicalExpression::Variable("n".to_string()),
1360 alias: None,
1361 }],
1362 distinct: false,
1363 input: Box::new(LogicalOperator::Filter(FilterOp {
1364 predicate: LogicalExpression::Literal(Value::Bool(true)),
1365 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1366 variable: "n".to_string(),
1367 label: None,
1368 input: None,
1369 })),
1370 pushdown_hint: None,
1371 })),
1372 }));
1373
1374 let optimized = optimizer.optimize(plan).unwrap();
1375 if let LogicalOperator::Return(ret) = &optimized.root
1377 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1378 {
1379 return;
1380 }
1381 panic!("Expected unchanged structure");
1382 }
1383
1384 #[test]
1385 fn test_optimizer_with_join_reorder_disabled() {
1386 let optimizer = Optimizer::new().with_join_reorder(false);
1387 assert!(
1388 optimizer
1389 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1390 .is_ok()
1391 );
1392 }
1393
1394 #[test]
1395 fn test_optimizer_with_cost_model() {
1396 let cost_model = CostModel::new();
1397 let optimizer = Optimizer::new().with_cost_model(cost_model);
1398 assert!(
1399 optimizer
1400 .cost_model()
1401 .estimate(&LogicalOperator::Empty, 0.0)
1402 .total()
1403 < 0.001
1404 );
1405 }
1406
1407 #[test]
1408 fn test_optimizer_with_cardinality_estimator() {
1409 let mut estimator = CardinalityEstimator::new();
1410 estimator.add_table_stats("Test", TableStats::new(500));
1411 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1412
1413 let scan = LogicalOperator::NodeScan(NodeScanOp {
1414 variable: "n".to_string(),
1415 label: Some("Test".to_string()),
1416 input: None,
1417 });
1418 let plan = LogicalPlan::new(scan);
1419
1420 let cardinality = optimizer.estimate_cardinality(&plan);
1421 assert!((cardinality - 500.0).abs() < 0.001);
1422 }
1423
1424 #[test]
1425 fn test_optimizer_estimate_cost() {
1426 let optimizer = Optimizer::new();
1427 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1428 variable: "n".to_string(),
1429 label: None,
1430 input: None,
1431 }));
1432
1433 let cost = optimizer.estimate_cost(&plan);
1434 assert!(cost.total() > 0.0);
1435 }
1436
1437 #[test]
1440 fn test_filter_pushdown_through_project() {
1441 let optimizer = Optimizer::new();
1442
1443 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1444 predicate: LogicalExpression::Binary {
1445 left: Box::new(LogicalExpression::Property {
1446 variable: "n".to_string(),
1447 property: "age".to_string(),
1448 }),
1449 op: BinaryOp::Gt,
1450 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1451 },
1452 pushdown_hint: None,
1453 input: Box::new(LogicalOperator::Project(ProjectOp {
1454 projections: vec![Projection {
1455 expression: LogicalExpression::Variable("n".to_string()),
1456 alias: None,
1457 }],
1458 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1459 variable: "n".to_string(),
1460 label: None,
1461 input: None,
1462 })),
1463 pass_through_input: false,
1464 })),
1465 }));
1466
1467 let optimized = optimizer.optimize(plan).unwrap();
1468
1469 if let LogicalOperator::Project(proj) = &optimized.root
1471 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1472 {
1473 return;
1474 }
1475 panic!("Expected Project -> Filter structure");
1476 }
1477
1478 #[test]
1479 fn test_filter_not_pushed_through_project_with_alias() {
1480 let optimizer = Optimizer::new();
1481
1482 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1484 predicate: LogicalExpression::Binary {
1485 left: Box::new(LogicalExpression::Variable("x".to_string())),
1486 op: BinaryOp::Gt,
1487 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1488 },
1489 pushdown_hint: None,
1490 input: Box::new(LogicalOperator::Project(ProjectOp {
1491 projections: vec![Projection {
1492 expression: LogicalExpression::Property {
1493 variable: "n".to_string(),
1494 property: "age".to_string(),
1495 },
1496 alias: Some("x".to_string()),
1497 }],
1498 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1499 variable: "n".to_string(),
1500 label: None,
1501 input: None,
1502 })),
1503 pass_through_input: false,
1504 })),
1505 }));
1506
1507 let optimized = optimizer.optimize(plan).unwrap();
1508
1509 if let LogicalOperator::Filter(filter) = &optimized.root
1511 && let LogicalOperator::Project(_) = filter.input.as_ref()
1512 {
1513 return;
1514 }
1515 panic!("Expected Filter -> Project structure");
1516 }
1517
1518 #[test]
1519 fn test_filter_pushdown_through_limit() {
1520 let optimizer = Optimizer::new();
1521
1522 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1523 predicate: LogicalExpression::Literal(Value::Bool(true)),
1524 pushdown_hint: None,
1525 input: Box::new(LogicalOperator::Limit(LimitOp {
1526 count: 10.into(),
1527 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1528 variable: "n".to_string(),
1529 label: None,
1530 input: None,
1531 })),
1532 })),
1533 }));
1534
1535 let optimized = optimizer.optimize(plan).unwrap();
1536
1537 if let LogicalOperator::Filter(filter) = &optimized.root
1539 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1540 {
1541 return;
1542 }
1543 panic!("Expected Filter -> Limit structure");
1544 }
1545
1546 #[test]
1547 fn test_filter_pushdown_through_sort() {
1548 let optimizer = Optimizer::new();
1549
1550 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1551 predicate: LogicalExpression::Literal(Value::Bool(true)),
1552 pushdown_hint: None,
1553 input: Box::new(LogicalOperator::Sort(SortOp {
1554 keys: vec![SortKey {
1555 expression: LogicalExpression::Variable("n".to_string()),
1556 order: SortOrder::Ascending,
1557 nulls: None,
1558 }],
1559 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1560 variable: "n".to_string(),
1561 label: None,
1562 input: None,
1563 })),
1564 })),
1565 }));
1566
1567 let optimized = optimizer.optimize(plan).unwrap();
1568
1569 if let LogicalOperator::Filter(filter) = &optimized.root
1571 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1572 {
1573 return;
1574 }
1575 panic!("Expected Filter -> Sort structure");
1576 }
1577
1578 #[test]
1579 fn test_filter_pushdown_through_distinct() {
1580 let optimizer = Optimizer::new();
1581
1582 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1583 predicate: LogicalExpression::Literal(Value::Bool(true)),
1584 pushdown_hint: None,
1585 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1586 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1587 variable: "n".to_string(),
1588 label: None,
1589 input: None,
1590 })),
1591 columns: None,
1592 })),
1593 }));
1594
1595 let optimized = optimizer.optimize(plan).unwrap();
1596
1597 if let LogicalOperator::Filter(filter) = &optimized.root
1599 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1600 {
1601 return;
1602 }
1603 panic!("Expected Filter -> Distinct structure");
1604 }
1605
1606 #[test]
1607 fn test_filter_not_pushed_through_aggregate() {
1608 let optimizer = Optimizer::new();
1609
1610 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1611 predicate: LogicalExpression::Binary {
1612 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1613 op: BinaryOp::Gt,
1614 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1615 },
1616 pushdown_hint: None,
1617 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1618 group_by: vec![],
1619 aggregates: vec![AggregateExpr {
1620 function: AggregateFunction::Count,
1621 expression: None,
1622 expression2: None,
1623 distinct: false,
1624 alias: Some("cnt".to_string()),
1625 percentile: None,
1626 separator: None,
1627 }],
1628 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1629 variable: "n".to_string(),
1630 label: None,
1631 input: None,
1632 })),
1633 having: None,
1634 })),
1635 }));
1636
1637 let optimized = optimizer.optimize(plan).unwrap();
1638
1639 if let LogicalOperator::Filter(filter) = &optimized.root
1641 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1642 {
1643 return;
1644 }
1645 panic!("Expected Filter -> Aggregate structure");
1646 }
1647
1648 #[test]
1649 fn test_filter_pushdown_to_left_join_side() {
1650 let optimizer = Optimizer::new();
1651
1652 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1654 predicate: LogicalExpression::Binary {
1655 left: Box::new(LogicalExpression::Property {
1656 variable: "a".to_string(),
1657 property: "age".to_string(),
1658 }),
1659 op: BinaryOp::Gt,
1660 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1661 },
1662 pushdown_hint: None,
1663 input: Box::new(LogicalOperator::Join(JoinOp {
1664 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1665 variable: "a".to_string(),
1666 label: Some("Person".to_string()),
1667 input: None,
1668 })),
1669 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1670 variable: "b".to_string(),
1671 label: Some("Company".to_string()),
1672 input: None,
1673 })),
1674 join_type: JoinType::Inner,
1675 conditions: vec![],
1676 })),
1677 }));
1678
1679 let optimized = optimizer.optimize(plan).unwrap();
1680
1681 if let LogicalOperator::Join(join) = &optimized.root
1683 && let LogicalOperator::Filter(_) = join.left.as_ref()
1684 {
1685 return;
1686 }
1687 panic!("Expected Join with Filter on left side");
1688 }
1689
1690 #[test]
1691 fn test_filter_pushdown_to_right_join_side() {
1692 let optimizer = Optimizer::new();
1693
1694 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1696 predicate: LogicalExpression::Binary {
1697 left: Box::new(LogicalExpression::Property {
1698 variable: "b".to_string(),
1699 property: "name".to_string(),
1700 }),
1701 op: BinaryOp::Eq,
1702 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1703 },
1704 pushdown_hint: None,
1705 input: Box::new(LogicalOperator::Join(JoinOp {
1706 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1707 variable: "a".to_string(),
1708 label: Some("Person".to_string()),
1709 input: None,
1710 })),
1711 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1712 variable: "b".to_string(),
1713 label: Some("Company".to_string()),
1714 input: None,
1715 })),
1716 join_type: JoinType::Inner,
1717 conditions: vec![],
1718 })),
1719 }));
1720
1721 let optimized = optimizer.optimize(plan).unwrap();
1722
1723 if let LogicalOperator::Join(join) = &optimized.root
1725 && let LogicalOperator::Filter(_) = join.right.as_ref()
1726 {
1727 return;
1728 }
1729 panic!("Expected Join with Filter on right side");
1730 }
1731
1732 #[test]
1733 fn test_filter_not_pushed_when_uses_both_join_sides() {
1734 let optimizer = Optimizer::new();
1735
1736 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1738 predicate: LogicalExpression::Binary {
1739 left: Box::new(LogicalExpression::Property {
1740 variable: "a".to_string(),
1741 property: "id".to_string(),
1742 }),
1743 op: BinaryOp::Eq,
1744 right: Box::new(LogicalExpression::Property {
1745 variable: "b".to_string(),
1746 property: "a_id".to_string(),
1747 }),
1748 },
1749 pushdown_hint: None,
1750 input: Box::new(LogicalOperator::Join(JoinOp {
1751 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1752 variable: "a".to_string(),
1753 label: None,
1754 input: None,
1755 })),
1756 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1757 variable: "b".to_string(),
1758 label: None,
1759 input: None,
1760 })),
1761 join_type: JoinType::Inner,
1762 conditions: vec![],
1763 })),
1764 }));
1765
1766 let optimized = optimizer.optimize(plan).unwrap();
1767
1768 if let LogicalOperator::Filter(filter) = &optimized.root
1770 && let LogicalOperator::Join(_) = filter.input.as_ref()
1771 {
1772 return;
1773 }
1774 panic!("Expected Filter -> Join structure");
1775 }
1776
1777 #[test]
1780 fn test_extract_variables_from_variable() {
1781 let optimizer = Optimizer::new();
1782 let expr = LogicalExpression::Variable("x".to_string());
1783 let vars = optimizer.extract_variables(&expr);
1784 assert_eq!(vars.len(), 1);
1785 assert!(vars.contains("x"));
1786 }
1787
1788 #[test]
1789 fn test_extract_variables_from_unary() {
1790 let optimizer = Optimizer::new();
1791 let expr = LogicalExpression::Unary {
1792 op: UnaryOp::Not,
1793 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1794 };
1795 let vars = optimizer.extract_variables(&expr);
1796 assert_eq!(vars.len(), 1);
1797 assert!(vars.contains("x"));
1798 }
1799
1800 #[test]
1801 fn test_extract_variables_from_function_call() {
1802 let optimizer = Optimizer::new();
1803 let expr = LogicalExpression::FunctionCall {
1804 name: "length".to_string(),
1805 args: vec![
1806 LogicalExpression::Variable("a".to_string()),
1807 LogicalExpression::Variable("b".to_string()),
1808 ],
1809 distinct: false,
1810 };
1811 let vars = optimizer.extract_variables(&expr);
1812 assert_eq!(vars.len(), 2);
1813 assert!(vars.contains("a"));
1814 assert!(vars.contains("b"));
1815 }
1816
1817 #[test]
1818 fn test_extract_variables_from_list() {
1819 let optimizer = Optimizer::new();
1820 let expr = LogicalExpression::List(vec![
1821 LogicalExpression::Variable("a".to_string()),
1822 LogicalExpression::Literal(Value::Int64(1)),
1823 LogicalExpression::Variable("b".to_string()),
1824 ]);
1825 let vars = optimizer.extract_variables(&expr);
1826 assert_eq!(vars.len(), 2);
1827 assert!(vars.contains("a"));
1828 assert!(vars.contains("b"));
1829 }
1830
1831 #[test]
1832 fn test_extract_variables_from_map() {
1833 let optimizer = Optimizer::new();
1834 let expr = LogicalExpression::Map(vec![
1835 (
1836 "key1".to_string(),
1837 LogicalExpression::Variable("a".to_string()),
1838 ),
1839 (
1840 "key2".to_string(),
1841 LogicalExpression::Variable("b".to_string()),
1842 ),
1843 ]);
1844 let vars = optimizer.extract_variables(&expr);
1845 assert_eq!(vars.len(), 2);
1846 assert!(vars.contains("a"));
1847 assert!(vars.contains("b"));
1848 }
1849
1850 #[test]
1851 fn test_extract_variables_from_index_access() {
1852 let optimizer = Optimizer::new();
1853 let expr = LogicalExpression::IndexAccess {
1854 base: Box::new(LogicalExpression::Variable("list".to_string())),
1855 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1856 };
1857 let vars = optimizer.extract_variables(&expr);
1858 assert_eq!(vars.len(), 2);
1859 assert!(vars.contains("list"));
1860 assert!(vars.contains("idx"));
1861 }
1862
1863 #[test]
1864 fn test_extract_variables_from_slice_access() {
1865 let optimizer = Optimizer::new();
1866 let expr = LogicalExpression::SliceAccess {
1867 base: Box::new(LogicalExpression::Variable("list".to_string())),
1868 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1869 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1870 };
1871 let vars = optimizer.extract_variables(&expr);
1872 assert_eq!(vars.len(), 3);
1873 assert!(vars.contains("list"));
1874 assert!(vars.contains("s"));
1875 assert!(vars.contains("e"));
1876 }
1877
1878 #[test]
1879 fn test_extract_variables_from_case() {
1880 let optimizer = Optimizer::new();
1881 let expr = LogicalExpression::Case {
1882 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1883 when_clauses: vec![(
1884 LogicalExpression::Literal(Value::Int64(1)),
1885 LogicalExpression::Variable("a".to_string()),
1886 )],
1887 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1888 };
1889 let vars = optimizer.extract_variables(&expr);
1890 assert_eq!(vars.len(), 3);
1891 assert!(vars.contains("x"));
1892 assert!(vars.contains("a"));
1893 assert!(vars.contains("b"));
1894 }
1895
1896 #[test]
1897 fn test_extract_variables_from_labels() {
1898 let optimizer = Optimizer::new();
1899 let expr = LogicalExpression::Labels("n".to_string());
1900 let vars = optimizer.extract_variables(&expr);
1901 assert_eq!(vars.len(), 1);
1902 assert!(vars.contains("n"));
1903 }
1904
1905 #[test]
1906 fn test_extract_variables_from_type() {
1907 let optimizer = Optimizer::new();
1908 let expr = LogicalExpression::Type("e".to_string());
1909 let vars = optimizer.extract_variables(&expr);
1910 assert_eq!(vars.len(), 1);
1911 assert!(vars.contains("e"));
1912 }
1913
1914 #[test]
1915 fn test_extract_variables_from_id() {
1916 let optimizer = Optimizer::new();
1917 let expr = LogicalExpression::Id("n".to_string());
1918 let vars = optimizer.extract_variables(&expr);
1919 assert_eq!(vars.len(), 1);
1920 assert!(vars.contains("n"));
1921 }
1922
1923 #[test]
1924 fn test_extract_variables_from_list_comprehension() {
1925 let optimizer = Optimizer::new();
1926 let expr = LogicalExpression::ListComprehension {
1927 variable: "x".to_string(),
1928 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1929 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1930 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1931 };
1932 let vars = optimizer.extract_variables(&expr);
1933 assert!(vars.contains("items"));
1934 assert!(vars.contains("pred"));
1935 assert!(vars.contains("result"));
1936 }
1937
1938 #[test]
1939 fn test_extract_variables_from_literal_and_parameter() {
1940 let optimizer = Optimizer::new();
1941
1942 let literal = LogicalExpression::Literal(Value::Int64(42));
1943 assert!(optimizer.extract_variables(&literal).is_empty());
1944
1945 let param = LogicalExpression::Parameter("p".to_string());
1946 assert!(optimizer.extract_variables(¶m).is_empty());
1947 }
1948
1949 #[test]
1952 fn test_recursive_filter_pushdown_through_skip() {
1953 let optimizer = Optimizer::new();
1954
1955 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1956 items: vec![ReturnItem {
1957 expression: LogicalExpression::Variable("n".to_string()),
1958 alias: None,
1959 }],
1960 distinct: false,
1961 input: Box::new(LogicalOperator::Filter(FilterOp {
1962 predicate: LogicalExpression::Literal(Value::Bool(true)),
1963 pushdown_hint: None,
1964 input: Box::new(LogicalOperator::Skip(SkipOp {
1965 count: 5.into(),
1966 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1967 variable: "n".to_string(),
1968 label: None,
1969 input: None,
1970 })),
1971 })),
1972 })),
1973 }));
1974
1975 let optimized = optimizer.optimize(plan).unwrap();
1976
1977 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1979 }
1980
1981 #[test]
1982 fn test_nested_filter_pushdown() {
1983 let optimizer = Optimizer::new();
1984
1985 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1987 items: vec![ReturnItem {
1988 expression: LogicalExpression::Variable("n".to_string()),
1989 alias: None,
1990 }],
1991 distinct: false,
1992 input: Box::new(LogicalOperator::Filter(FilterOp {
1993 predicate: LogicalExpression::Binary {
1994 left: Box::new(LogicalExpression::Property {
1995 variable: "n".to_string(),
1996 property: "x".to_string(),
1997 }),
1998 op: BinaryOp::Gt,
1999 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2000 },
2001 pushdown_hint: None,
2002 input: Box::new(LogicalOperator::Filter(FilterOp {
2003 predicate: LogicalExpression::Binary {
2004 left: Box::new(LogicalExpression::Property {
2005 variable: "n".to_string(),
2006 property: "y".to_string(),
2007 }),
2008 op: BinaryOp::Lt,
2009 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2010 },
2011 pushdown_hint: None,
2012 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2013 variable: "n".to_string(),
2014 label: None,
2015 input: None,
2016 })),
2017 })),
2018 })),
2019 }));
2020
2021 let optimized = optimizer.optimize(plan).unwrap();
2022 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2023 }
2024
2025 #[test]
2026 fn test_cyclic_join_produces_multi_way_join() {
2027 use crate::query::plan::JoinCondition;
2028
2029 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2031 variable: "a".to_string(),
2032 label: Some("Person".to_string()),
2033 input: None,
2034 });
2035 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2036 variable: "b".to_string(),
2037 label: Some("Person".to_string()),
2038 input: None,
2039 });
2040 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2041 variable: "c".to_string(),
2042 label: Some("Person".to_string()),
2043 input: None,
2044 });
2045
2046 let join_ab = LogicalOperator::Join(JoinOp {
2048 left: Box::new(scan_a),
2049 right: Box::new(scan_b),
2050 join_type: JoinType::Inner,
2051 conditions: vec![JoinCondition {
2052 left: LogicalExpression::Variable("a".to_string()),
2053 right: LogicalExpression::Variable("b".to_string()),
2054 }],
2055 });
2056
2057 let join_abc = LogicalOperator::Join(JoinOp {
2058 left: Box::new(join_ab),
2059 right: Box::new(scan_c),
2060 join_type: JoinType::Inner,
2061 conditions: vec![
2062 JoinCondition {
2063 left: LogicalExpression::Variable("b".to_string()),
2064 right: LogicalExpression::Variable("c".to_string()),
2065 },
2066 JoinCondition {
2067 left: LogicalExpression::Variable("c".to_string()),
2068 right: LogicalExpression::Variable("a".to_string()),
2069 },
2070 ],
2071 });
2072
2073 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2074 items: vec![ReturnItem {
2075 expression: LogicalExpression::Variable("a".to_string()),
2076 alias: None,
2077 }],
2078 distinct: false,
2079 input: Box::new(join_abc),
2080 }));
2081
2082 let mut optimizer = Optimizer::new();
2083 optimizer
2084 .card_estimator
2085 .add_table_stats("Person", cardinality::TableStats::new(1000));
2086
2087 let optimized = optimizer.optimize(plan).unwrap();
2088
2089 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2091 match op {
2092 LogicalOperator::MultiWayJoin(_) => true,
2093 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2094 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2095 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2096 _ => false,
2097 }
2098 }
2099
2100 assert!(
2101 has_multi_way_join(&optimized.root),
2102 "Expected MultiWayJoin for cyclic triangle pattern"
2103 );
2104 }
2105
2106 #[test]
2107 fn test_acyclic_join_uses_binary_joins() {
2108 use crate::query::plan::JoinCondition;
2109
2110 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2112 variable: "a".to_string(),
2113 label: Some("Person".to_string()),
2114 input: None,
2115 });
2116 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2117 variable: "b".to_string(),
2118 label: Some("Person".to_string()),
2119 input: None,
2120 });
2121 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2122 variable: "c".to_string(),
2123 label: Some("Company".to_string()),
2124 input: None,
2125 });
2126
2127 let join_ab = LogicalOperator::Join(JoinOp {
2128 left: Box::new(scan_a),
2129 right: Box::new(scan_b),
2130 join_type: JoinType::Inner,
2131 conditions: vec![JoinCondition {
2132 left: LogicalExpression::Variable("a".to_string()),
2133 right: LogicalExpression::Variable("b".to_string()),
2134 }],
2135 });
2136
2137 let join_abc = LogicalOperator::Join(JoinOp {
2138 left: Box::new(join_ab),
2139 right: Box::new(scan_c),
2140 join_type: JoinType::Inner,
2141 conditions: vec![JoinCondition {
2142 left: LogicalExpression::Variable("b".to_string()),
2143 right: LogicalExpression::Variable("c".to_string()),
2144 }],
2145 });
2146
2147 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2148 items: vec![ReturnItem {
2149 expression: LogicalExpression::Variable("a".to_string()),
2150 alias: None,
2151 }],
2152 distinct: false,
2153 input: Box::new(join_abc),
2154 }));
2155
2156 let mut optimizer = Optimizer::new();
2157 optimizer
2158 .card_estimator
2159 .add_table_stats("Person", cardinality::TableStats::new(1000));
2160 optimizer
2161 .card_estimator
2162 .add_table_stats("Company", cardinality::TableStats::new(100));
2163
2164 let optimized = optimizer.optimize(plan).unwrap();
2165
2166 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2168 match op {
2169 LogicalOperator::MultiWayJoin(_) => true,
2170 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2171 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2172 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2173 LogicalOperator::Join(j) => {
2174 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2175 }
2176 _ => false,
2177 }
2178 }
2179
2180 assert!(
2181 !has_multi_way_join(&optimized.root),
2182 "Acyclic join should NOT produce MultiWayJoin"
2183 );
2184 }
2185}