1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
31#[derive(Debug, Clone)]
33struct JoinInfo {
34 left_var: String,
35 right_var: String,
36 left_expr: LogicalExpression,
37 right_expr: LogicalExpression,
38}
39
/// A column that some operator in the plan reads, as gathered by projection
/// pushdown.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum RequiredColumn {
    /// A whole variable, e.g. `n`.
    Variable(String),
    /// A property of a variable, stored as `(variable, property)`, e.g. `n.age`.
    Property(String, String),
}
48
49pub struct Optimizer {
54 enable_filter_pushdown: bool,
56 enable_join_reorder: bool,
58 enable_projection_pushdown: bool,
60 cost_model: CostModel,
62 card_estimator: CardinalityEstimator,
64}
65
66impl Optimizer {
67 #[must_use]
69 pub fn new() -> Self {
70 Self {
71 enable_filter_pushdown: true,
72 enable_join_reorder: true,
73 enable_projection_pushdown: true,
74 cost_model: CostModel::new(),
75 card_estimator: CardinalityEstimator::new(),
76 }
77 }
78
79 #[must_use]
85 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
86 store.ensure_statistics_fresh();
87 let stats = store.statistics();
88 Self::from_statistics(&stats)
89 }
90
91 #[must_use]
98 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
99 let stats = store.statistics();
100 Self::from_statistics(&stats)
101 }
102
103 #[cfg(feature = "rdf")]
108 #[must_use]
109 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
110 let total = rdf_stats.total_triples;
111 let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
112 Self {
113 enable_filter_pushdown: true,
114 enable_join_reorder: true,
115 enable_projection_pushdown: true,
116 cost_model: CostModel::new().with_graph_totals(total, total),
117 card_estimator: estimator,
118 }
119 }
120
121 #[must_use]
126 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
127 let estimator = CardinalityEstimator::from_statistics(stats);
128
129 let avg_fanout = if stats.total_nodes > 0 {
130 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
131 } else {
132 10.0
133 };
134
135 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
136 .edge_types
137 .iter()
138 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
139 .collect();
140
141 let label_cardinalities: std::collections::HashMap<String, u64> = stats
142 .labels
143 .iter()
144 .map(|(name, ls)| (name.clone(), ls.node_count))
145 .collect();
146
147 Self {
148 enable_filter_pushdown: true,
149 enable_join_reorder: true,
150 enable_projection_pushdown: true,
151 cost_model: CostModel::new()
152 .with_avg_fanout(avg_fanout)
153 .with_edge_type_degrees(edge_type_degrees)
154 .with_label_cardinalities(label_cardinalities)
155 .with_graph_totals(stats.total_nodes, stats.total_edges),
156 card_estimator: estimator,
157 }
158 }
159
160 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
162 self.enable_filter_pushdown = enabled;
163 self
164 }
165
166 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
168 self.enable_join_reorder = enabled;
169 self
170 }
171
172 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
174 self.enable_projection_pushdown = enabled;
175 self
176 }
177
178 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
180 self.cost_model = cost_model;
181 self
182 }
183
184 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
186 self.card_estimator = estimator;
187 self
188 }
189
190 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
192 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
193 self
194 }
195
196 pub fn cost_model(&self) -> &CostModel {
198 &self.cost_model
199 }
200
201 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
203 &self.card_estimator
204 }
205
206 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
212 self.cost_model
213 .estimate_tree(&plan.root, &self.card_estimator)
214 }
215
216 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
218 self.card_estimator.estimate(&plan.root)
219 }
220
221 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
227 let _span = grafeo_debug_span!("grafeo::query::optimize");
228 let mut root = plan.root;
229
230 if self.enable_filter_pushdown {
232 root = self.push_filters_down(root);
233 }
234
235 if self.enable_join_reorder {
236 root = self.reorder_joins(root);
237 }
238
239 if self.enable_projection_pushdown {
240 root = self.push_projections_down(root);
241 }
242
243 Ok(LogicalPlan {
244 root,
245 explain: plan.explain,
246 profile: plan.profile,
247 default_params: plan.default_params,
248 })
249 }
250
251 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
258 let required = self.collect_required_columns(&op);
260
261 self.push_projections_recursive(op, &required)
263 }
264
265 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
267 let mut required = HashSet::new();
268 Self::collect_required_recursive(op, &mut required);
269 required
270 }
271
272 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
274 match op {
275 LogicalOperator::Return(ret) => {
276 for item in &ret.items {
277 Self::collect_from_expression(&item.expression, required);
278 }
279 Self::collect_required_recursive(&ret.input, required);
280 }
281 LogicalOperator::Project(proj) => {
282 for p in &proj.projections {
283 Self::collect_from_expression(&p.expression, required);
284 }
285 Self::collect_required_recursive(&proj.input, required);
286 }
287 LogicalOperator::Filter(filter) => {
288 Self::collect_from_expression(&filter.predicate, required);
289 Self::collect_required_recursive(&filter.input, required);
290 }
291 LogicalOperator::Sort(sort) => {
292 for key in &sort.keys {
293 Self::collect_from_expression(&key.expression, required);
294 }
295 Self::collect_required_recursive(&sort.input, required);
296 }
297 LogicalOperator::Aggregate(agg) => {
298 for expr in &agg.group_by {
299 Self::collect_from_expression(expr, required);
300 }
301 for agg_expr in &agg.aggregates {
302 if let Some(ref expr) = agg_expr.expression {
303 Self::collect_from_expression(expr, required);
304 }
305 }
306 if let Some(ref having) = agg.having {
307 Self::collect_from_expression(having, required);
308 }
309 Self::collect_required_recursive(&agg.input, required);
310 }
311 LogicalOperator::Join(join) => {
312 for cond in &join.conditions {
313 Self::collect_from_expression(&cond.left, required);
314 Self::collect_from_expression(&cond.right, required);
315 }
316 Self::collect_required_recursive(&join.left, required);
317 Self::collect_required_recursive(&join.right, required);
318 }
319 LogicalOperator::Expand(expand) => {
320 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
322 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
323 if let Some(ref edge_var) = expand.edge_variable {
324 required.insert(RequiredColumn::Variable(edge_var.clone()));
325 }
326 Self::collect_required_recursive(&expand.input, required);
327 }
328 LogicalOperator::Limit(limit) => {
329 Self::collect_required_recursive(&limit.input, required);
330 }
331 LogicalOperator::Skip(skip) => {
332 Self::collect_required_recursive(&skip.input, required);
333 }
334 LogicalOperator::Distinct(distinct) => {
335 Self::collect_required_recursive(&distinct.input, required);
336 }
337 LogicalOperator::NodeScan(scan) => {
338 required.insert(RequiredColumn::Variable(scan.variable.clone()));
339 }
340 LogicalOperator::EdgeScan(scan) => {
341 required.insert(RequiredColumn::Variable(scan.variable.clone()));
342 }
343 LogicalOperator::MultiWayJoin(mwj) => {
344 for cond in &mwj.conditions {
345 Self::collect_from_expression(&cond.left, required);
346 Self::collect_from_expression(&cond.right, required);
347 }
348 for input in &mwj.inputs {
349 Self::collect_required_recursive(input, required);
350 }
351 }
352 _ => {}
353 }
354 }
355
356 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
358 match expr {
359 LogicalExpression::Variable(var) => {
360 required.insert(RequiredColumn::Variable(var.clone()));
361 }
362 LogicalExpression::Property { variable, property } => {
363 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
364 required.insert(RequiredColumn::Variable(variable.clone()));
365 }
366 LogicalExpression::Binary { left, right, .. } => {
367 Self::collect_from_expression(left, required);
368 Self::collect_from_expression(right, required);
369 }
370 LogicalExpression::Unary { operand, .. } => {
371 Self::collect_from_expression(operand, required);
372 }
373 LogicalExpression::FunctionCall { args, .. } => {
374 for arg in args {
375 Self::collect_from_expression(arg, required);
376 }
377 }
378 LogicalExpression::List(items) => {
379 for item in items {
380 Self::collect_from_expression(item, required);
381 }
382 }
383 LogicalExpression::Map(pairs) => {
384 for (_, value) in pairs {
385 Self::collect_from_expression(value, required);
386 }
387 }
388 LogicalExpression::IndexAccess { base, index } => {
389 Self::collect_from_expression(base, required);
390 Self::collect_from_expression(index, required);
391 }
392 LogicalExpression::SliceAccess { base, start, end } => {
393 Self::collect_from_expression(base, required);
394 if let Some(s) = start {
395 Self::collect_from_expression(s, required);
396 }
397 if let Some(e) = end {
398 Self::collect_from_expression(e, required);
399 }
400 }
401 LogicalExpression::Case {
402 operand,
403 when_clauses,
404 else_clause,
405 } => {
406 if let Some(op) = operand {
407 Self::collect_from_expression(op, required);
408 }
409 for (cond, result) in when_clauses {
410 Self::collect_from_expression(cond, required);
411 Self::collect_from_expression(result, required);
412 }
413 if let Some(else_expr) = else_clause {
414 Self::collect_from_expression(else_expr, required);
415 }
416 }
417 LogicalExpression::Labels(var)
418 | LogicalExpression::Type(var)
419 | LogicalExpression::Id(var) => {
420 required.insert(RequiredColumn::Variable(var.clone()));
421 }
422 LogicalExpression::ListComprehension {
423 list_expr,
424 filter_expr,
425 map_expr,
426 ..
427 } => {
428 Self::collect_from_expression(list_expr, required);
429 if let Some(filter) = filter_expr {
430 Self::collect_from_expression(filter, required);
431 }
432 Self::collect_from_expression(map_expr, required);
433 }
434 _ => {}
435 }
436 }
437
438 fn push_projections_recursive(
440 &self,
441 op: LogicalOperator,
442 required: &HashSet<RequiredColumn>,
443 ) -> LogicalOperator {
444 match op {
445 LogicalOperator::Return(mut ret) => {
446 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
447 LogicalOperator::Return(ret)
448 }
449 LogicalOperator::Project(mut proj) => {
450 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
451 LogicalOperator::Project(proj)
452 }
453 LogicalOperator::Filter(mut filter) => {
454 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
455 LogicalOperator::Filter(filter)
456 }
457 LogicalOperator::Sort(mut sort) => {
458 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
461 LogicalOperator::Sort(sort)
462 }
463 LogicalOperator::Aggregate(mut agg) => {
464 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
465 LogicalOperator::Aggregate(agg)
466 }
467 LogicalOperator::Join(mut join) => {
468 let left_vars = self.collect_output_variables(&join.left);
471 let right_vars = self.collect_output_variables(&join.right);
472
473 let left_required: HashSet<_> = required
475 .iter()
476 .filter(|c| match c {
477 RequiredColumn::Variable(v) => left_vars.contains(v),
478 RequiredColumn::Property(v, _) => left_vars.contains(v),
479 })
480 .cloned()
481 .collect();
482
483 let right_required: HashSet<_> = required
484 .iter()
485 .filter(|c| match c {
486 RequiredColumn::Variable(v) => right_vars.contains(v),
487 RequiredColumn::Property(v, _) => right_vars.contains(v),
488 })
489 .cloned()
490 .collect();
491
492 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
493 join.right =
494 Box::new(self.push_projections_recursive(*join.right, &right_required));
495 LogicalOperator::Join(join)
496 }
497 LogicalOperator::Expand(mut expand) => {
498 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
499 LogicalOperator::Expand(expand)
500 }
501 LogicalOperator::Limit(mut limit) => {
502 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
503 LogicalOperator::Limit(limit)
504 }
505 LogicalOperator::Skip(mut skip) => {
506 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
507 LogicalOperator::Skip(skip)
508 }
509 LogicalOperator::Distinct(mut distinct) => {
510 distinct.input =
511 Box::new(self.push_projections_recursive(*distinct.input, required));
512 LogicalOperator::Distinct(distinct)
513 }
514 LogicalOperator::MapCollect(mut mc) => {
515 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
516 LogicalOperator::MapCollect(mc)
517 }
518 LogicalOperator::MultiWayJoin(mut mwj) => {
519 mwj.inputs = mwj
520 .inputs
521 .into_iter()
522 .map(|input| self.push_projections_recursive(input, required))
523 .collect();
524 LogicalOperator::MultiWayJoin(mwj)
525 }
526 other => other,
527 }
528 }
529
530 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
537 let op = self.reorder_joins_recursive(op);
539
540 if let Some((relations, conditions)) = self.extract_join_tree(&op)
542 && relations.len() >= 2
543 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
544 {
545 return optimized;
546 }
547
548 op
549 }
550
551 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
553 match op {
554 LogicalOperator::Return(mut ret) => {
555 ret.input = Box::new(self.reorder_joins(*ret.input));
556 LogicalOperator::Return(ret)
557 }
558 LogicalOperator::Project(mut proj) => {
559 proj.input = Box::new(self.reorder_joins(*proj.input));
560 LogicalOperator::Project(proj)
561 }
562 LogicalOperator::Filter(mut filter) => {
563 filter.input = Box::new(self.reorder_joins(*filter.input));
564 LogicalOperator::Filter(filter)
565 }
566 LogicalOperator::Limit(mut limit) => {
567 limit.input = Box::new(self.reorder_joins(*limit.input));
568 LogicalOperator::Limit(limit)
569 }
570 LogicalOperator::Skip(mut skip) => {
571 skip.input = Box::new(self.reorder_joins(*skip.input));
572 LogicalOperator::Skip(skip)
573 }
574 LogicalOperator::Sort(mut sort) => {
575 sort.input = Box::new(self.reorder_joins(*sort.input));
576 LogicalOperator::Sort(sort)
577 }
578 LogicalOperator::Distinct(mut distinct) => {
579 distinct.input = Box::new(self.reorder_joins(*distinct.input));
580 LogicalOperator::Distinct(distinct)
581 }
582 LogicalOperator::Aggregate(mut agg) => {
583 agg.input = Box::new(self.reorder_joins(*agg.input));
584 LogicalOperator::Aggregate(agg)
585 }
586 LogicalOperator::Expand(mut expand) => {
587 expand.input = Box::new(self.reorder_joins(*expand.input));
588 LogicalOperator::Expand(expand)
589 }
590 LogicalOperator::MapCollect(mut mc) => {
591 mc.input = Box::new(self.reorder_joins(*mc.input));
592 LogicalOperator::MapCollect(mc)
593 }
594 LogicalOperator::MultiWayJoin(mut mwj) => {
595 mwj.inputs = mwj
596 .inputs
597 .into_iter()
598 .map(|input| self.reorder_joins(input))
599 .collect();
600 LogicalOperator::MultiWayJoin(mwj)
601 }
602 other => other,
604 }
605 }
606
607 fn extract_join_tree(
611 &self,
612 op: &LogicalOperator,
613 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
614 let mut relations = Vec::new();
615 let mut join_conditions = Vec::new();
616
617 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
618 return None;
619 }
620
621 if relations.len() < 2 {
622 return None;
623 }
624
625 Some((relations, join_conditions))
626 }
627
628 fn collect_join_tree(
632 &self,
633 op: &LogicalOperator,
634 relations: &mut Vec<(String, LogicalOperator)>,
635 conditions: &mut Vec<JoinInfo>,
636 ) -> bool {
637 match op {
638 LogicalOperator::Join(join) => {
639 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
641 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
642
643 for cond in &join.conditions {
645 if let (Some(left_var), Some(right_var)) = (
646 self.extract_variable_from_expr(&cond.left),
647 self.extract_variable_from_expr(&cond.right),
648 ) {
649 conditions.push(JoinInfo {
650 left_var,
651 right_var,
652 left_expr: cond.left.clone(),
653 right_expr: cond.right.clone(),
654 });
655 }
656 }
657
658 left_ok && right_ok
659 }
660 LogicalOperator::NodeScan(scan) => {
661 relations.push((scan.variable.clone(), op.clone()));
662 true
663 }
664 LogicalOperator::EdgeScan(scan) => {
665 relations.push((scan.variable.clone(), op.clone()));
666 true
667 }
668 LogicalOperator::Filter(filter) => {
669 self.collect_join_tree(&filter.input, relations, conditions)
671 }
672 LogicalOperator::Expand(expand) => {
673 relations.push((expand.to_variable.clone(), op.clone()));
676 true
677 }
678 _ => false,
679 }
680 }
681
682 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
684 match expr {
685 LogicalExpression::Variable(v) => Some(v.clone()),
686 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
687 _ => None,
688 }
689 }
690
691 fn optimize_join_order(
694 &self,
695 relations: &[(String, LogicalOperator)],
696 conditions: &[JoinInfo],
697 ) -> Option<LogicalOperator> {
698 use join_order::{DPccp, JoinGraphBuilder};
699
700 let mut builder = JoinGraphBuilder::new();
702
703 for (var, relation) in relations {
704 builder.add_relation(var, relation.clone());
705 }
706
707 for cond in conditions {
708 builder.add_join_condition(
709 &cond.left_var,
710 &cond.right_var,
711 cond.left_expr.clone(),
712 cond.right_expr.clone(),
713 );
714 }
715
716 let graph = builder.build();
717
718 if graph.is_cyclic() && relations.len() >= 3 {
723 let mut var_counts: std::collections::HashMap<&str, usize> =
725 std::collections::HashMap::new();
726 for cond in conditions {
727 *var_counts.entry(&cond.left_var).or_default() += 1;
728 *var_counts.entry(&cond.right_var).or_default() += 1;
729 }
730 let shared_variables: Vec<String> = var_counts
731 .into_iter()
732 .filter(|(_, count)| *count >= 2)
733 .map(|(var, _)| var.to_string())
734 .collect();
735
736 let join_conditions: Vec<JoinCondition> = conditions
737 .iter()
738 .map(|c| JoinCondition {
739 left: c.left_expr.clone(),
740 right: c.right_expr.clone(),
741 })
742 .collect();
743
744 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
745 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
746 conditions: join_conditions,
747 shared_variables,
748 }));
749 }
750
751 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
753 let plan = dpccp.optimize()?;
754
755 Some(plan.operator)
756 }
757
758 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
763 match op {
764 LogicalOperator::Filter(filter) => {
766 let optimized_input = self.push_filters_down(*filter.input);
767 self.try_push_filter_into(filter.predicate, optimized_input)
768 }
769 LogicalOperator::Return(mut ret) => {
771 ret.input = Box::new(self.push_filters_down(*ret.input));
772 LogicalOperator::Return(ret)
773 }
774 LogicalOperator::Project(mut proj) => {
775 proj.input = Box::new(self.push_filters_down(*proj.input));
776 LogicalOperator::Project(proj)
777 }
778 LogicalOperator::Limit(mut limit) => {
779 limit.input = Box::new(self.push_filters_down(*limit.input));
780 LogicalOperator::Limit(limit)
781 }
782 LogicalOperator::Skip(mut skip) => {
783 skip.input = Box::new(self.push_filters_down(*skip.input));
784 LogicalOperator::Skip(skip)
785 }
786 LogicalOperator::Sort(mut sort) => {
787 sort.input = Box::new(self.push_filters_down(*sort.input));
788 LogicalOperator::Sort(sort)
789 }
790 LogicalOperator::Distinct(mut distinct) => {
791 distinct.input = Box::new(self.push_filters_down(*distinct.input));
792 LogicalOperator::Distinct(distinct)
793 }
794 LogicalOperator::Expand(mut expand) => {
795 expand.input = Box::new(self.push_filters_down(*expand.input));
796 LogicalOperator::Expand(expand)
797 }
798 LogicalOperator::Join(mut join) => {
799 join.left = Box::new(self.push_filters_down(*join.left));
800 join.right = Box::new(self.push_filters_down(*join.right));
801 LogicalOperator::Join(join)
802 }
803 LogicalOperator::Aggregate(mut agg) => {
804 agg.input = Box::new(self.push_filters_down(*agg.input));
805 LogicalOperator::Aggregate(agg)
806 }
807 LogicalOperator::MapCollect(mut mc) => {
808 mc.input = Box::new(self.push_filters_down(*mc.input));
809 LogicalOperator::MapCollect(mc)
810 }
811 LogicalOperator::MultiWayJoin(mut mwj) => {
812 mwj.inputs = mwj
813 .inputs
814 .into_iter()
815 .map(|input| self.push_filters_down(input))
816 .collect();
817 LogicalOperator::MultiWayJoin(mwj)
818 }
819 other => other,
821 }
822 }
823
824 fn try_push_filter_into(
829 &self,
830 predicate: LogicalExpression,
831 op: LogicalOperator,
832 ) -> LogicalOperator {
833 match op {
834 LogicalOperator::Project(mut proj) => {
836 let predicate_vars = self.extract_variables(&predicate);
837 let computed_vars = self.extract_projection_aliases(&proj.projections);
838
839 if predicate_vars.is_disjoint(&computed_vars) {
841 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
842 LogicalOperator::Project(proj)
843 } else {
844 LogicalOperator::Filter(FilterOp {
846 predicate,
847 pushdown_hint: None,
848 input: Box::new(LogicalOperator::Project(proj)),
849 })
850 }
851 }
852
853 LogicalOperator::Return(mut ret) => {
855 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
856 LogicalOperator::Return(ret)
857 }
858
859 LogicalOperator::Expand(mut expand) => {
861 let predicate_vars = self.extract_variables(&predicate);
862
863 let mut introduced_vars = vec![&expand.to_variable];
868 if let Some(ref edge_var) = expand.edge_variable {
869 introduced_vars.push(edge_var);
870 }
871 if let Some(ref path_alias) = expand.path_alias {
872 introduced_vars.push(path_alias);
873 }
874
875 let uses_introduced_vars =
877 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
878
879 if !uses_introduced_vars {
880 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
882 LogicalOperator::Expand(expand)
883 } else {
884 LogicalOperator::Filter(FilterOp {
886 predicate,
887 pushdown_hint: None,
888 input: Box::new(LogicalOperator::Expand(expand)),
889 })
890 }
891 }
892
893 LogicalOperator::Join(mut join) => {
895 let predicate_vars = self.extract_variables(&predicate);
896 let left_vars = self.collect_output_variables(&join.left);
897 let right_vars = self.collect_output_variables(&join.right);
898
899 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
900 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
901
902 if uses_left && !uses_right {
903 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
905 LogicalOperator::Join(join)
906 } else if uses_right && !uses_left {
907 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
909 LogicalOperator::Join(join)
910 } else {
911 LogicalOperator::Filter(FilterOp {
913 predicate,
914 pushdown_hint: None,
915 input: Box::new(LogicalOperator::Join(join)),
916 })
917 }
918 }
919
920 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
922 predicate,
923 pushdown_hint: None,
924 input: Box::new(LogicalOperator::Aggregate(agg)),
925 }),
926
927 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
929 predicate,
930 pushdown_hint: None,
931 input: Box::new(LogicalOperator::NodeScan(scan)),
932 }),
933
934 other => LogicalOperator::Filter(FilterOp {
936 predicate,
937 pushdown_hint: None,
938 input: Box::new(other),
939 }),
940 }
941 }
942
943 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
945 let mut vars = HashSet::new();
946 Self::collect_output_variables_recursive(op, &mut vars);
947 vars
948 }
949
950 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
952 match op {
953 LogicalOperator::NodeScan(scan) => {
954 vars.insert(scan.variable.clone());
955 }
956 LogicalOperator::EdgeScan(scan) => {
957 vars.insert(scan.variable.clone());
958 }
959 LogicalOperator::Expand(expand) => {
960 vars.insert(expand.to_variable.clone());
961 if let Some(edge_var) = &expand.edge_variable {
962 vars.insert(edge_var.clone());
963 }
964 Self::collect_output_variables_recursive(&expand.input, vars);
965 }
966 LogicalOperator::Filter(filter) => {
967 Self::collect_output_variables_recursive(&filter.input, vars);
968 }
969 LogicalOperator::Project(proj) => {
970 for p in &proj.projections {
971 if let Some(alias) = &p.alias {
972 vars.insert(alias.clone());
973 }
974 }
975 Self::collect_output_variables_recursive(&proj.input, vars);
976 }
977 LogicalOperator::Join(join) => {
978 Self::collect_output_variables_recursive(&join.left, vars);
979 Self::collect_output_variables_recursive(&join.right, vars);
980 }
981 LogicalOperator::Aggregate(agg) => {
982 for expr in &agg.group_by {
983 Self::collect_variables(expr, vars);
984 }
985 for agg_expr in &agg.aggregates {
986 if let Some(alias) = &agg_expr.alias {
987 vars.insert(alias.clone());
988 }
989 }
990 }
991 LogicalOperator::Return(ret) => {
992 Self::collect_output_variables_recursive(&ret.input, vars);
993 }
994 LogicalOperator::Limit(limit) => {
995 Self::collect_output_variables_recursive(&limit.input, vars);
996 }
997 LogicalOperator::Skip(skip) => {
998 Self::collect_output_variables_recursive(&skip.input, vars);
999 }
1000 LogicalOperator::Sort(sort) => {
1001 Self::collect_output_variables_recursive(&sort.input, vars);
1002 }
1003 LogicalOperator::Distinct(distinct) => {
1004 Self::collect_output_variables_recursive(&distinct.input, vars);
1005 }
1006 _ => {}
1007 }
1008 }
1009
1010 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1012 let mut vars = HashSet::new();
1013 Self::collect_variables(expr, &mut vars);
1014 vars
1015 }
1016
1017 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1019 match expr {
1020 LogicalExpression::Variable(name) => {
1021 vars.insert(name.clone());
1022 }
1023 LogicalExpression::Property { variable, .. } => {
1024 vars.insert(variable.clone());
1025 }
1026 LogicalExpression::Binary { left, right, .. } => {
1027 Self::collect_variables(left, vars);
1028 Self::collect_variables(right, vars);
1029 }
1030 LogicalExpression::Unary { operand, .. } => {
1031 Self::collect_variables(operand, vars);
1032 }
1033 LogicalExpression::FunctionCall { args, .. } => {
1034 for arg in args {
1035 Self::collect_variables(arg, vars);
1036 }
1037 }
1038 LogicalExpression::List(items) => {
1039 for item in items {
1040 Self::collect_variables(item, vars);
1041 }
1042 }
1043 LogicalExpression::Map(pairs) => {
1044 for (_, value) in pairs {
1045 Self::collect_variables(value, vars);
1046 }
1047 }
1048 LogicalExpression::IndexAccess { base, index } => {
1049 Self::collect_variables(base, vars);
1050 Self::collect_variables(index, vars);
1051 }
1052 LogicalExpression::SliceAccess { base, start, end } => {
1053 Self::collect_variables(base, vars);
1054 if let Some(s) = start {
1055 Self::collect_variables(s, vars);
1056 }
1057 if let Some(e) = end {
1058 Self::collect_variables(e, vars);
1059 }
1060 }
1061 LogicalExpression::Case {
1062 operand,
1063 when_clauses,
1064 else_clause,
1065 } => {
1066 if let Some(op) = operand {
1067 Self::collect_variables(op, vars);
1068 }
1069 for (cond, result) in when_clauses {
1070 Self::collect_variables(cond, vars);
1071 Self::collect_variables(result, vars);
1072 }
1073 if let Some(else_expr) = else_clause {
1074 Self::collect_variables(else_expr, vars);
1075 }
1076 }
1077 LogicalExpression::Labels(var)
1078 | LogicalExpression::Type(var)
1079 | LogicalExpression::Id(var) => {
1080 vars.insert(var.clone());
1081 }
1082 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1083 LogicalExpression::ListComprehension {
1084 list_expr,
1085 filter_expr,
1086 map_expr,
1087 ..
1088 } => {
1089 Self::collect_variables(list_expr, vars);
1090 if let Some(filter) = filter_expr {
1091 Self::collect_variables(filter, vars);
1092 }
1093 Self::collect_variables(map_expr, vars);
1094 }
1095 LogicalExpression::ListPredicate {
1096 list_expr,
1097 predicate,
1098 ..
1099 } => {
1100 Self::collect_variables(list_expr, vars);
1101 Self::collect_variables(predicate, vars);
1102 }
1103 LogicalExpression::ExistsSubquery(_)
1104 | LogicalExpression::CountSubquery(_)
1105 | LogicalExpression::ValueSubquery(_) => {
1106 }
1108 LogicalExpression::PatternComprehension { projection, .. } => {
1109 Self::collect_variables(projection, vars);
1110 }
1111 LogicalExpression::MapProjection { base, entries } => {
1112 vars.insert(base.clone());
1113 for entry in entries {
1114 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1115 Self::collect_variables(expr, vars);
1116 }
1117 }
1118 }
1119 LogicalExpression::Reduce {
1120 initial,
1121 list,
1122 expression,
1123 ..
1124 } => {
1125 Self::collect_variables(initial, vars);
1126 Self::collect_variables(list, vars);
1127 Self::collect_variables(expression, vars);
1128 }
1129 }
1130 }
1131
1132 fn extract_projection_aliases(
1134 &self,
1135 projections: &[crate::query::plan::Projection],
1136 ) -> HashSet<String> {
1137 projections.iter().filter_map(|p| p.alias.clone()).collect()
1138 }
1139}
1140
1141impl Default for Optimizer {
1142 fn default() -> Self {
1143 Self::new()
1144 }
1145}
1146
1147#[cfg(test)]
1148mod tests {
1149 use super::*;
1150 use crate::query::plan::{
1151 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1152 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1153 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1154 };
1155 use grafeo_common::types::Value;
1156
1157 #[test]
1158 fn test_optimizer_filter_pushdown_simple() {
1159 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1164 items: vec![ReturnItem {
1165 expression: LogicalExpression::Variable("n".to_string()),
1166 alias: None,
1167 }],
1168 distinct: false,
1169 input: Box::new(LogicalOperator::Filter(FilterOp {
1170 predicate: LogicalExpression::Binary {
1171 left: Box::new(LogicalExpression::Property {
1172 variable: "n".to_string(),
1173 property: "age".to_string(),
1174 }),
1175 op: BinaryOp::Gt,
1176 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1177 },
1178 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1179 variable: "n".to_string(),
1180 label: Some("Person".to_string()),
1181 input: None,
1182 })),
1183 pushdown_hint: None,
1184 })),
1185 }));
1186
1187 let optimizer = Optimizer::new();
1188 let optimized = optimizer.optimize(plan).unwrap();
1189
1190 if let LogicalOperator::Return(ret) = &optimized.root
1192 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1193 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1194 {
1195 assert_eq!(scan.variable, "n");
1196 return;
1197 }
1198 panic!("Expected Return -> Filter -> NodeScan structure");
1199 }
1200
1201 #[test]
1202 fn test_optimizer_filter_pushdown_through_expand() {
1203 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1207 items: vec![ReturnItem {
1208 expression: LogicalExpression::Variable("b".to_string()),
1209 alias: None,
1210 }],
1211 distinct: false,
1212 input: Box::new(LogicalOperator::Filter(FilterOp {
1213 predicate: LogicalExpression::Binary {
1214 left: Box::new(LogicalExpression::Property {
1215 variable: "a".to_string(),
1216 property: "age".to_string(),
1217 }),
1218 op: BinaryOp::Gt,
1219 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1220 },
1221 pushdown_hint: None,
1222 input: Box::new(LogicalOperator::Expand(ExpandOp {
1223 from_variable: "a".to_string(),
1224 to_variable: "b".to_string(),
1225 edge_variable: None,
1226 direction: ExpandDirection::Outgoing,
1227 edge_types: vec!["KNOWS".to_string()],
1228 min_hops: 1,
1229 max_hops: Some(1),
1230 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1231 variable: "a".to_string(),
1232 label: Some("Person".to_string()),
1233 input: None,
1234 })),
1235 path_alias: None,
1236 path_mode: PathMode::Walk,
1237 })),
1238 })),
1239 }));
1240
1241 let optimizer = Optimizer::new();
1242 let optimized = optimizer.optimize(plan).unwrap();
1243
1244 if let LogicalOperator::Return(ret) = &optimized.root
1247 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1248 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1249 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1250 {
1251 assert_eq!(scan.variable, "a");
1252 assert_eq!(expand.from_variable, "a");
1253 assert_eq!(expand.to_variable, "b");
1254 return;
1255 }
1256 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1257 }
1258
1259 #[test]
1260 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1261 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1265 items: vec![ReturnItem {
1266 expression: LogicalExpression::Variable("a".to_string()),
1267 alias: None,
1268 }],
1269 distinct: false,
1270 input: Box::new(LogicalOperator::Filter(FilterOp {
1271 predicate: LogicalExpression::Binary {
1272 left: Box::new(LogicalExpression::Property {
1273 variable: "b".to_string(),
1274 property: "age".to_string(),
1275 }),
1276 op: BinaryOp::Gt,
1277 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1278 },
1279 pushdown_hint: None,
1280 input: Box::new(LogicalOperator::Expand(ExpandOp {
1281 from_variable: "a".to_string(),
1282 to_variable: "b".to_string(),
1283 edge_variable: None,
1284 direction: ExpandDirection::Outgoing,
1285 edge_types: vec!["KNOWS".to_string()],
1286 min_hops: 1,
1287 max_hops: Some(1),
1288 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1289 variable: "a".to_string(),
1290 label: Some("Person".to_string()),
1291 input: None,
1292 })),
1293 path_alias: None,
1294 path_mode: PathMode::Walk,
1295 })),
1296 })),
1297 }));
1298
1299 let optimizer = Optimizer::new();
1300 let optimized = optimizer.optimize(plan).unwrap();
1301
1302 if let LogicalOperator::Return(ret) = &optimized.root
1305 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1306 {
1307 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1309 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1310 {
1311 assert_eq!(variable, "b");
1312 }
1313
1314 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1315 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1316 {
1317 return;
1318 }
1319 }
1320 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1321 }
1322
1323 #[test]
1324 fn test_optimizer_extract_variables() {
1325 let optimizer = Optimizer::new();
1326
1327 let expr = LogicalExpression::Binary {
1328 left: Box::new(LogicalExpression::Property {
1329 variable: "n".to_string(),
1330 property: "age".to_string(),
1331 }),
1332 op: BinaryOp::Gt,
1333 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1334 };
1335
1336 let vars = optimizer.extract_variables(&expr);
1337 assert_eq!(vars.len(), 1);
1338 assert!(vars.contains("n"));
1339 }
1340
1341 #[test]
1344 fn test_optimizer_default() {
1345 let optimizer = Optimizer::default();
1346 let plan = LogicalPlan::new(LogicalOperator::Empty);
1348 let result = optimizer.optimize(plan);
1349 assert!(result.is_ok());
1350 }
1351
1352 #[test]
1353 fn test_optimizer_with_filter_pushdown_disabled() {
1354 let optimizer = Optimizer::new().with_filter_pushdown(false);
1355
1356 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1357 items: vec![ReturnItem {
1358 expression: LogicalExpression::Variable("n".to_string()),
1359 alias: None,
1360 }],
1361 distinct: false,
1362 input: Box::new(LogicalOperator::Filter(FilterOp {
1363 predicate: LogicalExpression::Literal(Value::Bool(true)),
1364 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1365 variable: "n".to_string(),
1366 label: None,
1367 input: None,
1368 })),
1369 pushdown_hint: None,
1370 })),
1371 }));
1372
1373 let optimized = optimizer.optimize(plan).unwrap();
1374 if let LogicalOperator::Return(ret) = &optimized.root
1376 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1377 {
1378 return;
1379 }
1380 panic!("Expected unchanged structure");
1381 }
1382
1383 #[test]
1384 fn test_optimizer_with_join_reorder_disabled() {
1385 let optimizer = Optimizer::new().with_join_reorder(false);
1386 assert!(
1387 optimizer
1388 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1389 .is_ok()
1390 );
1391 }
1392
1393 #[test]
1394 fn test_optimizer_with_cost_model() {
1395 let cost_model = CostModel::new();
1396 let optimizer = Optimizer::new().with_cost_model(cost_model);
1397 assert!(
1398 optimizer
1399 .cost_model()
1400 .estimate(&LogicalOperator::Empty, 0.0)
1401 .total()
1402 < 0.001
1403 );
1404 }
1405
1406 #[test]
1407 fn test_optimizer_with_cardinality_estimator() {
1408 let mut estimator = CardinalityEstimator::new();
1409 estimator.add_table_stats("Test", TableStats::new(500));
1410 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1411
1412 let scan = LogicalOperator::NodeScan(NodeScanOp {
1413 variable: "n".to_string(),
1414 label: Some("Test".to_string()),
1415 input: None,
1416 });
1417 let plan = LogicalPlan::new(scan);
1418
1419 let cardinality = optimizer.estimate_cardinality(&plan);
1420 assert!((cardinality - 500.0).abs() < 0.001);
1421 }
1422
1423 #[test]
1424 fn test_optimizer_estimate_cost() {
1425 let optimizer = Optimizer::new();
1426 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1427 variable: "n".to_string(),
1428 label: None,
1429 input: None,
1430 }));
1431
1432 let cost = optimizer.estimate_cost(&plan);
1433 assert!(cost.total() > 0.0);
1434 }
1435
1436 #[test]
1439 fn test_filter_pushdown_through_project() {
1440 let optimizer = Optimizer::new();
1441
1442 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1443 predicate: LogicalExpression::Binary {
1444 left: Box::new(LogicalExpression::Property {
1445 variable: "n".to_string(),
1446 property: "age".to_string(),
1447 }),
1448 op: BinaryOp::Gt,
1449 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1450 },
1451 pushdown_hint: None,
1452 input: Box::new(LogicalOperator::Project(ProjectOp {
1453 projections: vec![Projection {
1454 expression: LogicalExpression::Variable("n".to_string()),
1455 alias: None,
1456 }],
1457 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1458 variable: "n".to_string(),
1459 label: None,
1460 input: None,
1461 })),
1462 pass_through_input: false,
1463 })),
1464 }));
1465
1466 let optimized = optimizer.optimize(plan).unwrap();
1467
1468 if let LogicalOperator::Project(proj) = &optimized.root
1470 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1471 {
1472 return;
1473 }
1474 panic!("Expected Project -> Filter structure");
1475 }
1476
1477 #[test]
1478 fn test_filter_not_pushed_through_project_with_alias() {
1479 let optimizer = Optimizer::new();
1480
1481 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1483 predicate: LogicalExpression::Binary {
1484 left: Box::new(LogicalExpression::Variable("x".to_string())),
1485 op: BinaryOp::Gt,
1486 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1487 },
1488 pushdown_hint: None,
1489 input: Box::new(LogicalOperator::Project(ProjectOp {
1490 projections: vec![Projection {
1491 expression: LogicalExpression::Property {
1492 variable: "n".to_string(),
1493 property: "age".to_string(),
1494 },
1495 alias: Some("x".to_string()),
1496 }],
1497 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1498 variable: "n".to_string(),
1499 label: None,
1500 input: None,
1501 })),
1502 pass_through_input: false,
1503 })),
1504 }));
1505
1506 let optimized = optimizer.optimize(plan).unwrap();
1507
1508 if let LogicalOperator::Filter(filter) = &optimized.root
1510 && let LogicalOperator::Project(_) = filter.input.as_ref()
1511 {
1512 return;
1513 }
1514 panic!("Expected Filter -> Project structure");
1515 }
1516
1517 #[test]
1518 fn test_filter_pushdown_through_limit() {
1519 let optimizer = Optimizer::new();
1520
1521 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1522 predicate: LogicalExpression::Literal(Value::Bool(true)),
1523 pushdown_hint: None,
1524 input: Box::new(LogicalOperator::Limit(LimitOp {
1525 count: 10.into(),
1526 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1527 variable: "n".to_string(),
1528 label: None,
1529 input: None,
1530 })),
1531 })),
1532 }));
1533
1534 let optimized = optimizer.optimize(plan).unwrap();
1535
1536 if let LogicalOperator::Filter(filter) = &optimized.root
1538 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1539 {
1540 return;
1541 }
1542 panic!("Expected Filter -> Limit structure");
1543 }
1544
1545 #[test]
1546 fn test_filter_pushdown_through_sort() {
1547 let optimizer = Optimizer::new();
1548
1549 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1550 predicate: LogicalExpression::Literal(Value::Bool(true)),
1551 pushdown_hint: None,
1552 input: Box::new(LogicalOperator::Sort(SortOp {
1553 keys: vec![SortKey {
1554 expression: LogicalExpression::Variable("n".to_string()),
1555 order: SortOrder::Ascending,
1556 nulls: None,
1557 }],
1558 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1559 variable: "n".to_string(),
1560 label: None,
1561 input: None,
1562 })),
1563 })),
1564 }));
1565
1566 let optimized = optimizer.optimize(plan).unwrap();
1567
1568 if let LogicalOperator::Filter(filter) = &optimized.root
1570 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1571 {
1572 return;
1573 }
1574 panic!("Expected Filter -> Sort structure");
1575 }
1576
1577 #[test]
1578 fn test_filter_pushdown_through_distinct() {
1579 let optimizer = Optimizer::new();
1580
1581 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1582 predicate: LogicalExpression::Literal(Value::Bool(true)),
1583 pushdown_hint: None,
1584 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1585 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1586 variable: "n".to_string(),
1587 label: None,
1588 input: None,
1589 })),
1590 columns: None,
1591 })),
1592 }));
1593
1594 let optimized = optimizer.optimize(plan).unwrap();
1595
1596 if let LogicalOperator::Filter(filter) = &optimized.root
1598 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1599 {
1600 return;
1601 }
1602 panic!("Expected Filter -> Distinct structure");
1603 }
1604
1605 #[test]
1606 fn test_filter_not_pushed_through_aggregate() {
1607 let optimizer = Optimizer::new();
1608
1609 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1610 predicate: LogicalExpression::Binary {
1611 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1612 op: BinaryOp::Gt,
1613 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1614 },
1615 pushdown_hint: None,
1616 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1617 group_by: vec![],
1618 aggregates: vec![AggregateExpr {
1619 function: AggregateFunction::Count,
1620 expression: None,
1621 expression2: None,
1622 distinct: false,
1623 alias: Some("cnt".to_string()),
1624 percentile: None,
1625 separator: None,
1626 }],
1627 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1628 variable: "n".to_string(),
1629 label: None,
1630 input: None,
1631 })),
1632 having: None,
1633 })),
1634 }));
1635
1636 let optimized = optimizer.optimize(plan).unwrap();
1637
1638 if let LogicalOperator::Filter(filter) = &optimized.root
1640 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1641 {
1642 return;
1643 }
1644 panic!("Expected Filter -> Aggregate structure");
1645 }
1646
1647 #[test]
1648 fn test_filter_pushdown_to_left_join_side() {
1649 let optimizer = Optimizer::new();
1650
1651 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1653 predicate: LogicalExpression::Binary {
1654 left: Box::new(LogicalExpression::Property {
1655 variable: "a".to_string(),
1656 property: "age".to_string(),
1657 }),
1658 op: BinaryOp::Gt,
1659 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1660 },
1661 pushdown_hint: None,
1662 input: Box::new(LogicalOperator::Join(JoinOp {
1663 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1664 variable: "a".to_string(),
1665 label: Some("Person".to_string()),
1666 input: None,
1667 })),
1668 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1669 variable: "b".to_string(),
1670 label: Some("Company".to_string()),
1671 input: None,
1672 })),
1673 join_type: JoinType::Inner,
1674 conditions: vec![],
1675 })),
1676 }));
1677
1678 let optimized = optimizer.optimize(plan).unwrap();
1679
1680 if let LogicalOperator::Join(join) = &optimized.root
1682 && let LogicalOperator::Filter(_) = join.left.as_ref()
1683 {
1684 return;
1685 }
1686 panic!("Expected Join with Filter on left side");
1687 }
1688
1689 #[test]
1690 fn test_filter_pushdown_to_right_join_side() {
1691 let optimizer = Optimizer::new();
1692
1693 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1695 predicate: LogicalExpression::Binary {
1696 left: Box::new(LogicalExpression::Property {
1697 variable: "b".to_string(),
1698 property: "name".to_string(),
1699 }),
1700 op: BinaryOp::Eq,
1701 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1702 },
1703 pushdown_hint: None,
1704 input: Box::new(LogicalOperator::Join(JoinOp {
1705 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1706 variable: "a".to_string(),
1707 label: Some("Person".to_string()),
1708 input: None,
1709 })),
1710 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1711 variable: "b".to_string(),
1712 label: Some("Company".to_string()),
1713 input: None,
1714 })),
1715 join_type: JoinType::Inner,
1716 conditions: vec![],
1717 })),
1718 }));
1719
1720 let optimized = optimizer.optimize(plan).unwrap();
1721
1722 if let LogicalOperator::Join(join) = &optimized.root
1724 && let LogicalOperator::Filter(_) = join.right.as_ref()
1725 {
1726 return;
1727 }
1728 panic!("Expected Join with Filter on right side");
1729 }
1730
1731 #[test]
1732 fn test_filter_not_pushed_when_uses_both_join_sides() {
1733 let optimizer = Optimizer::new();
1734
1735 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1737 predicate: LogicalExpression::Binary {
1738 left: Box::new(LogicalExpression::Property {
1739 variable: "a".to_string(),
1740 property: "id".to_string(),
1741 }),
1742 op: BinaryOp::Eq,
1743 right: Box::new(LogicalExpression::Property {
1744 variable: "b".to_string(),
1745 property: "a_id".to_string(),
1746 }),
1747 },
1748 pushdown_hint: None,
1749 input: Box::new(LogicalOperator::Join(JoinOp {
1750 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1751 variable: "a".to_string(),
1752 label: None,
1753 input: None,
1754 })),
1755 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1756 variable: "b".to_string(),
1757 label: None,
1758 input: None,
1759 })),
1760 join_type: JoinType::Inner,
1761 conditions: vec![],
1762 })),
1763 }));
1764
1765 let optimized = optimizer.optimize(plan).unwrap();
1766
1767 if let LogicalOperator::Filter(filter) = &optimized.root
1769 && let LogicalOperator::Join(_) = filter.input.as_ref()
1770 {
1771 return;
1772 }
1773 panic!("Expected Filter -> Join structure");
1774 }
1775
1776 #[test]
1779 fn test_extract_variables_from_variable() {
1780 let optimizer = Optimizer::new();
1781 let expr = LogicalExpression::Variable("x".to_string());
1782 let vars = optimizer.extract_variables(&expr);
1783 assert_eq!(vars.len(), 1);
1784 assert!(vars.contains("x"));
1785 }
1786
1787 #[test]
1788 fn test_extract_variables_from_unary() {
1789 let optimizer = Optimizer::new();
1790 let expr = LogicalExpression::Unary {
1791 op: UnaryOp::Not,
1792 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1793 };
1794 let vars = optimizer.extract_variables(&expr);
1795 assert_eq!(vars.len(), 1);
1796 assert!(vars.contains("x"));
1797 }
1798
1799 #[test]
1800 fn test_extract_variables_from_function_call() {
1801 let optimizer = Optimizer::new();
1802 let expr = LogicalExpression::FunctionCall {
1803 name: "length".to_string(),
1804 args: vec![
1805 LogicalExpression::Variable("a".to_string()),
1806 LogicalExpression::Variable("b".to_string()),
1807 ],
1808 distinct: false,
1809 };
1810 let vars = optimizer.extract_variables(&expr);
1811 assert_eq!(vars.len(), 2);
1812 assert!(vars.contains("a"));
1813 assert!(vars.contains("b"));
1814 }
1815
1816 #[test]
1817 fn test_extract_variables_from_list() {
1818 let optimizer = Optimizer::new();
1819 let expr = LogicalExpression::List(vec![
1820 LogicalExpression::Variable("a".to_string()),
1821 LogicalExpression::Literal(Value::Int64(1)),
1822 LogicalExpression::Variable("b".to_string()),
1823 ]);
1824 let vars = optimizer.extract_variables(&expr);
1825 assert_eq!(vars.len(), 2);
1826 assert!(vars.contains("a"));
1827 assert!(vars.contains("b"));
1828 }
1829
1830 #[test]
1831 fn test_extract_variables_from_map() {
1832 let optimizer = Optimizer::new();
1833 let expr = LogicalExpression::Map(vec![
1834 (
1835 "key1".to_string(),
1836 LogicalExpression::Variable("a".to_string()),
1837 ),
1838 (
1839 "key2".to_string(),
1840 LogicalExpression::Variable("b".to_string()),
1841 ),
1842 ]);
1843 let vars = optimizer.extract_variables(&expr);
1844 assert_eq!(vars.len(), 2);
1845 assert!(vars.contains("a"));
1846 assert!(vars.contains("b"));
1847 }
1848
1849 #[test]
1850 fn test_extract_variables_from_index_access() {
1851 let optimizer = Optimizer::new();
1852 let expr = LogicalExpression::IndexAccess {
1853 base: Box::new(LogicalExpression::Variable("list".to_string())),
1854 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1855 };
1856 let vars = optimizer.extract_variables(&expr);
1857 assert_eq!(vars.len(), 2);
1858 assert!(vars.contains("list"));
1859 assert!(vars.contains("idx"));
1860 }
1861
1862 #[test]
1863 fn test_extract_variables_from_slice_access() {
1864 let optimizer = Optimizer::new();
1865 let expr = LogicalExpression::SliceAccess {
1866 base: Box::new(LogicalExpression::Variable("list".to_string())),
1867 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1868 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1869 };
1870 let vars = optimizer.extract_variables(&expr);
1871 assert_eq!(vars.len(), 3);
1872 assert!(vars.contains("list"));
1873 assert!(vars.contains("s"));
1874 assert!(vars.contains("e"));
1875 }
1876
1877 #[test]
1878 fn test_extract_variables_from_case() {
1879 let optimizer = Optimizer::new();
1880 let expr = LogicalExpression::Case {
1881 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1882 when_clauses: vec![(
1883 LogicalExpression::Literal(Value::Int64(1)),
1884 LogicalExpression::Variable("a".to_string()),
1885 )],
1886 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1887 };
1888 let vars = optimizer.extract_variables(&expr);
1889 assert_eq!(vars.len(), 3);
1890 assert!(vars.contains("x"));
1891 assert!(vars.contains("a"));
1892 assert!(vars.contains("b"));
1893 }
1894
1895 #[test]
1896 fn test_extract_variables_from_labels() {
1897 let optimizer = Optimizer::new();
1898 let expr = LogicalExpression::Labels("n".to_string());
1899 let vars = optimizer.extract_variables(&expr);
1900 assert_eq!(vars.len(), 1);
1901 assert!(vars.contains("n"));
1902 }
1903
1904 #[test]
1905 fn test_extract_variables_from_type() {
1906 let optimizer = Optimizer::new();
1907 let expr = LogicalExpression::Type("e".to_string());
1908 let vars = optimizer.extract_variables(&expr);
1909 assert_eq!(vars.len(), 1);
1910 assert!(vars.contains("e"));
1911 }
1912
1913 #[test]
1914 fn test_extract_variables_from_id() {
1915 let optimizer = Optimizer::new();
1916 let expr = LogicalExpression::Id("n".to_string());
1917 let vars = optimizer.extract_variables(&expr);
1918 assert_eq!(vars.len(), 1);
1919 assert!(vars.contains("n"));
1920 }
1921
1922 #[test]
1923 fn test_extract_variables_from_list_comprehension() {
1924 let optimizer = Optimizer::new();
1925 let expr = LogicalExpression::ListComprehension {
1926 variable: "x".to_string(),
1927 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1928 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1929 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1930 };
1931 let vars = optimizer.extract_variables(&expr);
1932 assert!(vars.contains("items"));
1933 assert!(vars.contains("pred"));
1934 assert!(vars.contains("result"));
1935 }
1936
1937 #[test]
1938 fn test_extract_variables_from_literal_and_parameter() {
1939 let optimizer = Optimizer::new();
1940
1941 let literal = LogicalExpression::Literal(Value::Int64(42));
1942 assert!(optimizer.extract_variables(&literal).is_empty());
1943
1944 let param = LogicalExpression::Parameter("p".to_string());
1945 assert!(optimizer.extract_variables(¶m).is_empty());
1946 }
1947
1948 #[test]
1951 fn test_recursive_filter_pushdown_through_skip() {
1952 let optimizer = Optimizer::new();
1953
1954 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1955 items: vec![ReturnItem {
1956 expression: LogicalExpression::Variable("n".to_string()),
1957 alias: None,
1958 }],
1959 distinct: false,
1960 input: Box::new(LogicalOperator::Filter(FilterOp {
1961 predicate: LogicalExpression::Literal(Value::Bool(true)),
1962 pushdown_hint: None,
1963 input: Box::new(LogicalOperator::Skip(SkipOp {
1964 count: 5.into(),
1965 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1966 variable: "n".to_string(),
1967 label: None,
1968 input: None,
1969 })),
1970 })),
1971 })),
1972 }));
1973
1974 let optimized = optimizer.optimize(plan).unwrap();
1975
1976 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1978 }
1979
1980 #[test]
1981 fn test_nested_filter_pushdown() {
1982 let optimizer = Optimizer::new();
1983
1984 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1986 items: vec![ReturnItem {
1987 expression: LogicalExpression::Variable("n".to_string()),
1988 alias: None,
1989 }],
1990 distinct: false,
1991 input: Box::new(LogicalOperator::Filter(FilterOp {
1992 predicate: LogicalExpression::Binary {
1993 left: Box::new(LogicalExpression::Property {
1994 variable: "n".to_string(),
1995 property: "x".to_string(),
1996 }),
1997 op: BinaryOp::Gt,
1998 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1999 },
2000 pushdown_hint: None,
2001 input: Box::new(LogicalOperator::Filter(FilterOp {
2002 predicate: LogicalExpression::Binary {
2003 left: Box::new(LogicalExpression::Property {
2004 variable: "n".to_string(),
2005 property: "y".to_string(),
2006 }),
2007 op: BinaryOp::Lt,
2008 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2009 },
2010 pushdown_hint: None,
2011 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2012 variable: "n".to_string(),
2013 label: None,
2014 input: None,
2015 })),
2016 })),
2017 })),
2018 }));
2019
2020 let optimized = optimizer.optimize(plan).unwrap();
2021 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2022 }
2023
2024 #[test]
2025 fn test_cyclic_join_produces_multi_way_join() {
2026 use crate::query::plan::JoinCondition;
2027
2028 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2030 variable: "a".to_string(),
2031 label: Some("Person".to_string()),
2032 input: None,
2033 });
2034 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2035 variable: "b".to_string(),
2036 label: Some("Person".to_string()),
2037 input: None,
2038 });
2039 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2040 variable: "c".to_string(),
2041 label: Some("Person".to_string()),
2042 input: None,
2043 });
2044
2045 let join_ab = LogicalOperator::Join(JoinOp {
2047 left: Box::new(scan_a),
2048 right: Box::new(scan_b),
2049 join_type: JoinType::Inner,
2050 conditions: vec![JoinCondition {
2051 left: LogicalExpression::Variable("a".to_string()),
2052 right: LogicalExpression::Variable("b".to_string()),
2053 }],
2054 });
2055
2056 let join_abc = LogicalOperator::Join(JoinOp {
2057 left: Box::new(join_ab),
2058 right: Box::new(scan_c),
2059 join_type: JoinType::Inner,
2060 conditions: vec![
2061 JoinCondition {
2062 left: LogicalExpression::Variable("b".to_string()),
2063 right: LogicalExpression::Variable("c".to_string()),
2064 },
2065 JoinCondition {
2066 left: LogicalExpression::Variable("c".to_string()),
2067 right: LogicalExpression::Variable("a".to_string()),
2068 },
2069 ],
2070 });
2071
2072 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2073 items: vec![ReturnItem {
2074 expression: LogicalExpression::Variable("a".to_string()),
2075 alias: None,
2076 }],
2077 distinct: false,
2078 input: Box::new(join_abc),
2079 }));
2080
2081 let mut optimizer = Optimizer::new();
2082 optimizer
2083 .card_estimator
2084 .add_table_stats("Person", cardinality::TableStats::new(1000));
2085
2086 let optimized = optimizer.optimize(plan).unwrap();
2087
2088 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2090 match op {
2091 LogicalOperator::MultiWayJoin(_) => true,
2092 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2093 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2094 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2095 _ => false,
2096 }
2097 }
2098
2099 assert!(
2100 has_multi_way_join(&optimized.root),
2101 "Expected MultiWayJoin for cyclic triangle pattern"
2102 );
2103 }
2104
2105 #[test]
2106 fn test_acyclic_join_uses_binary_joins() {
2107 use crate::query::plan::JoinCondition;
2108
2109 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2111 variable: "a".to_string(),
2112 label: Some("Person".to_string()),
2113 input: None,
2114 });
2115 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2116 variable: "b".to_string(),
2117 label: Some("Person".to_string()),
2118 input: None,
2119 });
2120 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2121 variable: "c".to_string(),
2122 label: Some("Company".to_string()),
2123 input: None,
2124 });
2125
2126 let join_ab = LogicalOperator::Join(JoinOp {
2127 left: Box::new(scan_a),
2128 right: Box::new(scan_b),
2129 join_type: JoinType::Inner,
2130 conditions: vec![JoinCondition {
2131 left: LogicalExpression::Variable("a".to_string()),
2132 right: LogicalExpression::Variable("b".to_string()),
2133 }],
2134 });
2135
2136 let join_abc = LogicalOperator::Join(JoinOp {
2137 left: Box::new(join_ab),
2138 right: Box::new(scan_c),
2139 join_type: JoinType::Inner,
2140 conditions: vec![JoinCondition {
2141 left: LogicalExpression::Variable("b".to_string()),
2142 right: LogicalExpression::Variable("c".to_string()),
2143 }],
2144 });
2145
2146 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2147 items: vec![ReturnItem {
2148 expression: LogicalExpression::Variable("a".to_string()),
2149 alias: None,
2150 }],
2151 distinct: false,
2152 input: Box::new(join_abc),
2153 }));
2154
2155 let mut optimizer = Optimizer::new();
2156 optimizer
2157 .card_estimator
2158 .add_table_stats("Person", cardinality::TableStats::new(1000));
2159 optimizer
2160 .card_estimator
2161 .add_table_stats("Company", cardinality::TableStats::new(100));
2162
2163 let optimized = optimizer.optimize(plan).unwrap();
2164
2165 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2167 match op {
2168 LogicalOperator::MultiWayJoin(_) => true,
2169 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2170 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2171 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2172 LogicalOperator::Join(j) => {
2173 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2174 }
2175 _ => false,
2176 }
2177 }
2178
2179 assert!(
2180 !has_multi_way_join(&optimized.root),
2181 "Acyclic join should NOT produce MultiWayJoin"
2182 );
2183 }
2184}