1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// A join condition lifted out of a join tree while preparing join reordering.
///
/// Keeps both original expressions together with the variable each side
/// resolves to, so the condition can later be registered on the join graph
/// or copied into a multi-way join.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable that the left-hand expression refers to.
    left_var: String,
    /// Variable that the right-hand expression refers to.
    right_var: String,
    /// Left-hand expression exactly as it appeared in the join condition.
    left_expr: LogicalExpression,
    /// Right-hand expression exactly as it appeared in the join condition.
    right_expr: LogicalExpression,
}
38
/// A column that some operator in the plan needs from its input.
///
/// Hashable so that requirements can be de-duplicated in a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable, identified by name.
    Variable(String),
    /// A single property, stored as `(variable name, property name)`.
    Property(String, String),
}
47
/// Rewrites logical plans using filter pushdown, join reordering and
/// projection pushdown.
///
/// Each pass can be switched on or off individually; cost and cardinality
/// estimates come from the embedded [`CostModel`] and
/// [`CardinalityEstimator`].
pub struct Optimizer {
    /// Whether `optimize` runs the filter-pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join-reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection-pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for plan costing and handed to the join enumerator.
    cost_model: CostModel,
    /// Cardinality estimator used for plan costing and handed to the join enumerator.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66 #[must_use]
68 pub fn new() -> Self {
69 Self {
70 enable_filter_pushdown: true,
71 enable_join_reorder: true,
72 enable_projection_pushdown: true,
73 cost_model: CostModel::new(),
74 card_estimator: CardinalityEstimator::new(),
75 }
76 }
77
78 #[must_use]
84 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85 store.ensure_statistics_fresh();
86 let stats = store.statistics();
87 Self::from_statistics(&stats)
88 }
89
90 #[must_use]
97 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98 let stats = store.statistics();
99 Self::from_statistics(&stats)
100 }
101
102 #[cfg(feature = "rdf")]
107 #[must_use]
108 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
109 let total = rdf_stats.total_triples;
110 let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
111 Self {
112 enable_filter_pushdown: true,
113 enable_join_reorder: true,
114 enable_projection_pushdown: true,
115 cost_model: CostModel::new().with_graph_totals(total, total),
116 card_estimator: estimator,
117 }
118 }
119
120 #[must_use]
125 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
126 let estimator = CardinalityEstimator::from_statistics(stats);
127
128 let avg_fanout = if stats.total_nodes > 0 {
129 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
130 } else {
131 10.0
132 };
133
134 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
135 .edge_types
136 .iter()
137 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
138 .collect();
139
140 let label_cardinalities: std::collections::HashMap<String, u64> = stats
141 .labels
142 .iter()
143 .map(|(name, ls)| (name.clone(), ls.node_count))
144 .collect();
145
146 Self {
147 enable_filter_pushdown: true,
148 enable_join_reorder: true,
149 enable_projection_pushdown: true,
150 cost_model: CostModel::new()
151 .with_avg_fanout(avg_fanout)
152 .with_edge_type_degrees(edge_type_degrees)
153 .with_label_cardinalities(label_cardinalities)
154 .with_graph_totals(stats.total_nodes, stats.total_edges),
155 card_estimator: estimator,
156 }
157 }
158
159 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
161 self.enable_filter_pushdown = enabled;
162 self
163 }
164
165 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
167 self.enable_join_reorder = enabled;
168 self
169 }
170
171 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
173 self.enable_projection_pushdown = enabled;
174 self
175 }
176
177 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
179 self.cost_model = cost_model;
180 self
181 }
182
183 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
185 self.card_estimator = estimator;
186 self
187 }
188
189 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
191 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
192 self
193 }
194
195 pub fn cost_model(&self) -> &CostModel {
197 &self.cost_model
198 }
199
200 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
202 &self.card_estimator
203 }
204
205 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
211 self.cost_model
212 .estimate_tree(&plan.root, &self.card_estimator)
213 }
214
215 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
217 self.card_estimator.estimate(&plan.root)
218 }
219
220 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
226 let _span = tracing::debug_span!("grafeo::query::optimize").entered();
227 let mut root = plan.root;
228
229 if self.enable_filter_pushdown {
231 root = self.push_filters_down(root);
232 }
233
234 if self.enable_join_reorder {
235 root = self.reorder_joins(root);
236 }
237
238 if self.enable_projection_pushdown {
239 root = self.push_projections_down(root);
240 }
241
242 Ok(LogicalPlan {
243 root,
244 explain: plan.explain,
245 profile: plan.profile,
246 })
247 }
248
249 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
256 let required = self.collect_required_columns(&op);
258
259 self.push_projections_recursive(op, &required)
261 }
262
263 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
265 let mut required = HashSet::new();
266 Self::collect_required_recursive(op, &mut required);
267 required
268 }
269
270 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
272 match op {
273 LogicalOperator::Return(ret) => {
274 for item in &ret.items {
275 Self::collect_from_expression(&item.expression, required);
276 }
277 Self::collect_required_recursive(&ret.input, required);
278 }
279 LogicalOperator::Project(proj) => {
280 for p in &proj.projections {
281 Self::collect_from_expression(&p.expression, required);
282 }
283 Self::collect_required_recursive(&proj.input, required);
284 }
285 LogicalOperator::Filter(filter) => {
286 Self::collect_from_expression(&filter.predicate, required);
287 Self::collect_required_recursive(&filter.input, required);
288 }
289 LogicalOperator::Sort(sort) => {
290 for key in &sort.keys {
291 Self::collect_from_expression(&key.expression, required);
292 }
293 Self::collect_required_recursive(&sort.input, required);
294 }
295 LogicalOperator::Aggregate(agg) => {
296 for expr in &agg.group_by {
297 Self::collect_from_expression(expr, required);
298 }
299 for agg_expr in &agg.aggregates {
300 if let Some(ref expr) = agg_expr.expression {
301 Self::collect_from_expression(expr, required);
302 }
303 }
304 if let Some(ref having) = agg.having {
305 Self::collect_from_expression(having, required);
306 }
307 Self::collect_required_recursive(&agg.input, required);
308 }
309 LogicalOperator::Join(join) => {
310 for cond in &join.conditions {
311 Self::collect_from_expression(&cond.left, required);
312 Self::collect_from_expression(&cond.right, required);
313 }
314 Self::collect_required_recursive(&join.left, required);
315 Self::collect_required_recursive(&join.right, required);
316 }
317 LogicalOperator::Expand(expand) => {
318 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
320 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
321 if let Some(ref edge_var) = expand.edge_variable {
322 required.insert(RequiredColumn::Variable(edge_var.clone()));
323 }
324 Self::collect_required_recursive(&expand.input, required);
325 }
326 LogicalOperator::Limit(limit) => {
327 Self::collect_required_recursive(&limit.input, required);
328 }
329 LogicalOperator::Skip(skip) => {
330 Self::collect_required_recursive(&skip.input, required);
331 }
332 LogicalOperator::Distinct(distinct) => {
333 Self::collect_required_recursive(&distinct.input, required);
334 }
335 LogicalOperator::NodeScan(scan) => {
336 required.insert(RequiredColumn::Variable(scan.variable.clone()));
337 }
338 LogicalOperator::EdgeScan(scan) => {
339 required.insert(RequiredColumn::Variable(scan.variable.clone()));
340 }
341 LogicalOperator::MultiWayJoin(mwj) => {
342 for cond in &mwj.conditions {
343 Self::collect_from_expression(&cond.left, required);
344 Self::collect_from_expression(&cond.right, required);
345 }
346 for input in &mwj.inputs {
347 Self::collect_required_recursive(input, required);
348 }
349 }
350 _ => {}
351 }
352 }
353
354 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
356 match expr {
357 LogicalExpression::Variable(var) => {
358 required.insert(RequiredColumn::Variable(var.clone()));
359 }
360 LogicalExpression::Property { variable, property } => {
361 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
362 required.insert(RequiredColumn::Variable(variable.clone()));
363 }
364 LogicalExpression::Binary { left, right, .. } => {
365 Self::collect_from_expression(left, required);
366 Self::collect_from_expression(right, required);
367 }
368 LogicalExpression::Unary { operand, .. } => {
369 Self::collect_from_expression(operand, required);
370 }
371 LogicalExpression::FunctionCall { args, .. } => {
372 for arg in args {
373 Self::collect_from_expression(arg, required);
374 }
375 }
376 LogicalExpression::List(items) => {
377 for item in items {
378 Self::collect_from_expression(item, required);
379 }
380 }
381 LogicalExpression::Map(pairs) => {
382 for (_, value) in pairs {
383 Self::collect_from_expression(value, required);
384 }
385 }
386 LogicalExpression::IndexAccess { base, index } => {
387 Self::collect_from_expression(base, required);
388 Self::collect_from_expression(index, required);
389 }
390 LogicalExpression::SliceAccess { base, start, end } => {
391 Self::collect_from_expression(base, required);
392 if let Some(s) = start {
393 Self::collect_from_expression(s, required);
394 }
395 if let Some(e) = end {
396 Self::collect_from_expression(e, required);
397 }
398 }
399 LogicalExpression::Case {
400 operand,
401 when_clauses,
402 else_clause,
403 } => {
404 if let Some(op) = operand {
405 Self::collect_from_expression(op, required);
406 }
407 for (cond, result) in when_clauses {
408 Self::collect_from_expression(cond, required);
409 Self::collect_from_expression(result, required);
410 }
411 if let Some(else_expr) = else_clause {
412 Self::collect_from_expression(else_expr, required);
413 }
414 }
415 LogicalExpression::Labels(var)
416 | LogicalExpression::Type(var)
417 | LogicalExpression::Id(var) => {
418 required.insert(RequiredColumn::Variable(var.clone()));
419 }
420 LogicalExpression::ListComprehension {
421 list_expr,
422 filter_expr,
423 map_expr,
424 ..
425 } => {
426 Self::collect_from_expression(list_expr, required);
427 if let Some(filter) = filter_expr {
428 Self::collect_from_expression(filter, required);
429 }
430 Self::collect_from_expression(map_expr, required);
431 }
432 _ => {}
433 }
434 }
435
436 fn push_projections_recursive(
438 &self,
439 op: LogicalOperator,
440 required: &HashSet<RequiredColumn>,
441 ) -> LogicalOperator {
442 match op {
443 LogicalOperator::Return(mut ret) => {
444 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
445 LogicalOperator::Return(ret)
446 }
447 LogicalOperator::Project(mut proj) => {
448 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
449 LogicalOperator::Project(proj)
450 }
451 LogicalOperator::Filter(mut filter) => {
452 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
453 LogicalOperator::Filter(filter)
454 }
455 LogicalOperator::Sort(mut sort) => {
456 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
459 LogicalOperator::Sort(sort)
460 }
461 LogicalOperator::Aggregate(mut agg) => {
462 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
463 LogicalOperator::Aggregate(agg)
464 }
465 LogicalOperator::Join(mut join) => {
466 let left_vars = self.collect_output_variables(&join.left);
469 let right_vars = self.collect_output_variables(&join.right);
470
471 let left_required: HashSet<_> = required
473 .iter()
474 .filter(|c| match c {
475 RequiredColumn::Variable(v) => left_vars.contains(v),
476 RequiredColumn::Property(v, _) => left_vars.contains(v),
477 })
478 .cloned()
479 .collect();
480
481 let right_required: HashSet<_> = required
482 .iter()
483 .filter(|c| match c {
484 RequiredColumn::Variable(v) => right_vars.contains(v),
485 RequiredColumn::Property(v, _) => right_vars.contains(v),
486 })
487 .cloned()
488 .collect();
489
490 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
491 join.right =
492 Box::new(self.push_projections_recursive(*join.right, &right_required));
493 LogicalOperator::Join(join)
494 }
495 LogicalOperator::Expand(mut expand) => {
496 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
497 LogicalOperator::Expand(expand)
498 }
499 LogicalOperator::Limit(mut limit) => {
500 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
501 LogicalOperator::Limit(limit)
502 }
503 LogicalOperator::Skip(mut skip) => {
504 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
505 LogicalOperator::Skip(skip)
506 }
507 LogicalOperator::Distinct(mut distinct) => {
508 distinct.input =
509 Box::new(self.push_projections_recursive(*distinct.input, required));
510 LogicalOperator::Distinct(distinct)
511 }
512 LogicalOperator::MapCollect(mut mc) => {
513 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
514 LogicalOperator::MapCollect(mc)
515 }
516 LogicalOperator::MultiWayJoin(mut mwj) => {
517 mwj.inputs = mwj
518 .inputs
519 .into_iter()
520 .map(|input| self.push_projections_recursive(input, required))
521 .collect();
522 LogicalOperator::MultiWayJoin(mwj)
523 }
524 other => other,
525 }
526 }
527
528 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
535 let op = self.reorder_joins_recursive(op);
537
538 if let Some((relations, conditions)) = self.extract_join_tree(&op)
540 && relations.len() >= 2
541 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
542 {
543 return optimized;
544 }
545
546 op
547 }
548
549 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
551 match op {
552 LogicalOperator::Return(mut ret) => {
553 ret.input = Box::new(self.reorder_joins(*ret.input));
554 LogicalOperator::Return(ret)
555 }
556 LogicalOperator::Project(mut proj) => {
557 proj.input = Box::new(self.reorder_joins(*proj.input));
558 LogicalOperator::Project(proj)
559 }
560 LogicalOperator::Filter(mut filter) => {
561 filter.input = Box::new(self.reorder_joins(*filter.input));
562 LogicalOperator::Filter(filter)
563 }
564 LogicalOperator::Limit(mut limit) => {
565 limit.input = Box::new(self.reorder_joins(*limit.input));
566 LogicalOperator::Limit(limit)
567 }
568 LogicalOperator::Skip(mut skip) => {
569 skip.input = Box::new(self.reorder_joins(*skip.input));
570 LogicalOperator::Skip(skip)
571 }
572 LogicalOperator::Sort(mut sort) => {
573 sort.input = Box::new(self.reorder_joins(*sort.input));
574 LogicalOperator::Sort(sort)
575 }
576 LogicalOperator::Distinct(mut distinct) => {
577 distinct.input = Box::new(self.reorder_joins(*distinct.input));
578 LogicalOperator::Distinct(distinct)
579 }
580 LogicalOperator::Aggregate(mut agg) => {
581 agg.input = Box::new(self.reorder_joins(*agg.input));
582 LogicalOperator::Aggregate(agg)
583 }
584 LogicalOperator::Expand(mut expand) => {
585 expand.input = Box::new(self.reorder_joins(*expand.input));
586 LogicalOperator::Expand(expand)
587 }
588 LogicalOperator::MapCollect(mut mc) => {
589 mc.input = Box::new(self.reorder_joins(*mc.input));
590 LogicalOperator::MapCollect(mc)
591 }
592 LogicalOperator::MultiWayJoin(mut mwj) => {
593 mwj.inputs = mwj
594 .inputs
595 .into_iter()
596 .map(|input| self.reorder_joins(input))
597 .collect();
598 LogicalOperator::MultiWayJoin(mwj)
599 }
600 other => other,
602 }
603 }
604
605 fn extract_join_tree(
609 &self,
610 op: &LogicalOperator,
611 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
612 let mut relations = Vec::new();
613 let mut join_conditions = Vec::new();
614
615 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
616 return None;
617 }
618
619 if relations.len() < 2 {
620 return None;
621 }
622
623 Some((relations, join_conditions))
624 }
625
626 fn collect_join_tree(
630 &self,
631 op: &LogicalOperator,
632 relations: &mut Vec<(String, LogicalOperator)>,
633 conditions: &mut Vec<JoinInfo>,
634 ) -> bool {
635 match op {
636 LogicalOperator::Join(join) => {
637 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
639 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
640
641 for cond in &join.conditions {
643 if let (Some(left_var), Some(right_var)) = (
644 self.extract_variable_from_expr(&cond.left),
645 self.extract_variable_from_expr(&cond.right),
646 ) {
647 conditions.push(JoinInfo {
648 left_var,
649 right_var,
650 left_expr: cond.left.clone(),
651 right_expr: cond.right.clone(),
652 });
653 }
654 }
655
656 left_ok && right_ok
657 }
658 LogicalOperator::NodeScan(scan) => {
659 relations.push((scan.variable.clone(), op.clone()));
660 true
661 }
662 LogicalOperator::EdgeScan(scan) => {
663 relations.push((scan.variable.clone(), op.clone()));
664 true
665 }
666 LogicalOperator::Filter(filter) => {
667 self.collect_join_tree(&filter.input, relations, conditions)
669 }
670 LogicalOperator::Expand(expand) => {
671 relations.push((expand.to_variable.clone(), op.clone()));
674 true
675 }
676 _ => false,
677 }
678 }
679
680 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
682 match expr {
683 LogicalExpression::Variable(v) => Some(v.clone()),
684 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
685 _ => None,
686 }
687 }
688
689 fn optimize_join_order(
692 &self,
693 relations: &[(String, LogicalOperator)],
694 conditions: &[JoinInfo],
695 ) -> Option<LogicalOperator> {
696 use join_order::{DPccp, JoinGraphBuilder};
697
698 let mut builder = JoinGraphBuilder::new();
700
701 for (var, relation) in relations {
702 builder.add_relation(var, relation.clone());
703 }
704
705 for cond in conditions {
706 builder.add_join_condition(
707 &cond.left_var,
708 &cond.right_var,
709 cond.left_expr.clone(),
710 cond.right_expr.clone(),
711 );
712 }
713
714 let graph = builder.build();
715
716 if graph.is_cyclic() && relations.len() >= 3 {
721 let mut var_counts: std::collections::HashMap<&str, usize> =
723 std::collections::HashMap::new();
724 for cond in conditions {
725 *var_counts.entry(&cond.left_var).or_default() += 1;
726 *var_counts.entry(&cond.right_var).or_default() += 1;
727 }
728 let shared_variables: Vec<String> = var_counts
729 .into_iter()
730 .filter(|(_, count)| *count >= 2)
731 .map(|(var, _)| var.to_string())
732 .collect();
733
734 let join_conditions: Vec<JoinCondition> = conditions
735 .iter()
736 .map(|c| JoinCondition {
737 left: c.left_expr.clone(),
738 right: c.right_expr.clone(),
739 })
740 .collect();
741
742 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
743 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
744 conditions: join_conditions,
745 shared_variables,
746 }));
747 }
748
749 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
751 let plan = dpccp.optimize()?;
752
753 Some(plan.operator)
754 }
755
756 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
761 match op {
762 LogicalOperator::Filter(filter) => {
764 let optimized_input = self.push_filters_down(*filter.input);
765 self.try_push_filter_into(filter.predicate, optimized_input)
766 }
767 LogicalOperator::Return(mut ret) => {
769 ret.input = Box::new(self.push_filters_down(*ret.input));
770 LogicalOperator::Return(ret)
771 }
772 LogicalOperator::Project(mut proj) => {
773 proj.input = Box::new(self.push_filters_down(*proj.input));
774 LogicalOperator::Project(proj)
775 }
776 LogicalOperator::Limit(mut limit) => {
777 limit.input = Box::new(self.push_filters_down(*limit.input));
778 LogicalOperator::Limit(limit)
779 }
780 LogicalOperator::Skip(mut skip) => {
781 skip.input = Box::new(self.push_filters_down(*skip.input));
782 LogicalOperator::Skip(skip)
783 }
784 LogicalOperator::Sort(mut sort) => {
785 sort.input = Box::new(self.push_filters_down(*sort.input));
786 LogicalOperator::Sort(sort)
787 }
788 LogicalOperator::Distinct(mut distinct) => {
789 distinct.input = Box::new(self.push_filters_down(*distinct.input));
790 LogicalOperator::Distinct(distinct)
791 }
792 LogicalOperator::Expand(mut expand) => {
793 expand.input = Box::new(self.push_filters_down(*expand.input));
794 LogicalOperator::Expand(expand)
795 }
796 LogicalOperator::Join(mut join) => {
797 join.left = Box::new(self.push_filters_down(*join.left));
798 join.right = Box::new(self.push_filters_down(*join.right));
799 LogicalOperator::Join(join)
800 }
801 LogicalOperator::Aggregate(mut agg) => {
802 agg.input = Box::new(self.push_filters_down(*agg.input));
803 LogicalOperator::Aggregate(agg)
804 }
805 LogicalOperator::MapCollect(mut mc) => {
806 mc.input = Box::new(self.push_filters_down(*mc.input));
807 LogicalOperator::MapCollect(mc)
808 }
809 LogicalOperator::MultiWayJoin(mut mwj) => {
810 mwj.inputs = mwj
811 .inputs
812 .into_iter()
813 .map(|input| self.push_filters_down(input))
814 .collect();
815 LogicalOperator::MultiWayJoin(mwj)
816 }
817 other => other,
819 }
820 }
821
822 fn try_push_filter_into(
827 &self,
828 predicate: LogicalExpression,
829 op: LogicalOperator,
830 ) -> LogicalOperator {
831 match op {
832 LogicalOperator::Project(mut proj) => {
834 let predicate_vars = self.extract_variables(&predicate);
835 let computed_vars = self.extract_projection_aliases(&proj.projections);
836
837 if predicate_vars.is_disjoint(&computed_vars) {
839 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
840 LogicalOperator::Project(proj)
841 } else {
842 LogicalOperator::Filter(FilterOp {
844 predicate,
845 pushdown_hint: None,
846 input: Box::new(LogicalOperator::Project(proj)),
847 })
848 }
849 }
850
851 LogicalOperator::Return(mut ret) => {
853 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
854 LogicalOperator::Return(ret)
855 }
856
857 LogicalOperator::Expand(mut expand) => {
859 let predicate_vars = self.extract_variables(&predicate);
860
861 let mut introduced_vars = vec![&expand.to_variable];
866 if let Some(ref edge_var) = expand.edge_variable {
867 introduced_vars.push(edge_var);
868 }
869 if let Some(ref path_alias) = expand.path_alias {
870 introduced_vars.push(path_alias);
871 }
872
873 let uses_introduced_vars =
875 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
876
877 if !uses_introduced_vars {
878 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
880 LogicalOperator::Expand(expand)
881 } else {
882 LogicalOperator::Filter(FilterOp {
884 predicate,
885 pushdown_hint: None,
886 input: Box::new(LogicalOperator::Expand(expand)),
887 })
888 }
889 }
890
891 LogicalOperator::Join(mut join) => {
893 let predicate_vars = self.extract_variables(&predicate);
894 let left_vars = self.collect_output_variables(&join.left);
895 let right_vars = self.collect_output_variables(&join.right);
896
897 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
898 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
899
900 if uses_left && !uses_right {
901 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
903 LogicalOperator::Join(join)
904 } else if uses_right && !uses_left {
905 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
907 LogicalOperator::Join(join)
908 } else {
909 LogicalOperator::Filter(FilterOp {
911 predicate,
912 pushdown_hint: None,
913 input: Box::new(LogicalOperator::Join(join)),
914 })
915 }
916 }
917
918 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
920 predicate,
921 pushdown_hint: None,
922 input: Box::new(LogicalOperator::Aggregate(agg)),
923 }),
924
925 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
927 predicate,
928 pushdown_hint: None,
929 input: Box::new(LogicalOperator::NodeScan(scan)),
930 }),
931
932 other => LogicalOperator::Filter(FilterOp {
934 predicate,
935 pushdown_hint: None,
936 input: Box::new(other),
937 }),
938 }
939 }
940
941 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
943 let mut vars = HashSet::new();
944 Self::collect_output_variables_recursive(op, &mut vars);
945 vars
946 }
947
948 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
950 match op {
951 LogicalOperator::NodeScan(scan) => {
952 vars.insert(scan.variable.clone());
953 }
954 LogicalOperator::EdgeScan(scan) => {
955 vars.insert(scan.variable.clone());
956 }
957 LogicalOperator::Expand(expand) => {
958 vars.insert(expand.to_variable.clone());
959 if let Some(edge_var) = &expand.edge_variable {
960 vars.insert(edge_var.clone());
961 }
962 Self::collect_output_variables_recursive(&expand.input, vars);
963 }
964 LogicalOperator::Filter(filter) => {
965 Self::collect_output_variables_recursive(&filter.input, vars);
966 }
967 LogicalOperator::Project(proj) => {
968 for p in &proj.projections {
969 if let Some(alias) = &p.alias {
970 vars.insert(alias.clone());
971 }
972 }
973 Self::collect_output_variables_recursive(&proj.input, vars);
974 }
975 LogicalOperator::Join(join) => {
976 Self::collect_output_variables_recursive(&join.left, vars);
977 Self::collect_output_variables_recursive(&join.right, vars);
978 }
979 LogicalOperator::Aggregate(agg) => {
980 for expr in &agg.group_by {
981 Self::collect_variables(expr, vars);
982 }
983 for agg_expr in &agg.aggregates {
984 if let Some(alias) = &agg_expr.alias {
985 vars.insert(alias.clone());
986 }
987 }
988 }
989 LogicalOperator::Return(ret) => {
990 Self::collect_output_variables_recursive(&ret.input, vars);
991 }
992 LogicalOperator::Limit(limit) => {
993 Self::collect_output_variables_recursive(&limit.input, vars);
994 }
995 LogicalOperator::Skip(skip) => {
996 Self::collect_output_variables_recursive(&skip.input, vars);
997 }
998 LogicalOperator::Sort(sort) => {
999 Self::collect_output_variables_recursive(&sort.input, vars);
1000 }
1001 LogicalOperator::Distinct(distinct) => {
1002 Self::collect_output_variables_recursive(&distinct.input, vars);
1003 }
1004 _ => {}
1005 }
1006 }
1007
1008 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1010 let mut vars = HashSet::new();
1011 Self::collect_variables(expr, &mut vars);
1012 vars
1013 }
1014
1015 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1017 match expr {
1018 LogicalExpression::Variable(name) => {
1019 vars.insert(name.clone());
1020 }
1021 LogicalExpression::Property { variable, .. } => {
1022 vars.insert(variable.clone());
1023 }
1024 LogicalExpression::Binary { left, right, .. } => {
1025 Self::collect_variables(left, vars);
1026 Self::collect_variables(right, vars);
1027 }
1028 LogicalExpression::Unary { operand, .. } => {
1029 Self::collect_variables(operand, vars);
1030 }
1031 LogicalExpression::FunctionCall { args, .. } => {
1032 for arg in args {
1033 Self::collect_variables(arg, vars);
1034 }
1035 }
1036 LogicalExpression::List(items) => {
1037 for item in items {
1038 Self::collect_variables(item, vars);
1039 }
1040 }
1041 LogicalExpression::Map(pairs) => {
1042 for (_, value) in pairs {
1043 Self::collect_variables(value, vars);
1044 }
1045 }
1046 LogicalExpression::IndexAccess { base, index } => {
1047 Self::collect_variables(base, vars);
1048 Self::collect_variables(index, vars);
1049 }
1050 LogicalExpression::SliceAccess { base, start, end } => {
1051 Self::collect_variables(base, vars);
1052 if let Some(s) = start {
1053 Self::collect_variables(s, vars);
1054 }
1055 if let Some(e) = end {
1056 Self::collect_variables(e, vars);
1057 }
1058 }
1059 LogicalExpression::Case {
1060 operand,
1061 when_clauses,
1062 else_clause,
1063 } => {
1064 if let Some(op) = operand {
1065 Self::collect_variables(op, vars);
1066 }
1067 for (cond, result) in when_clauses {
1068 Self::collect_variables(cond, vars);
1069 Self::collect_variables(result, vars);
1070 }
1071 if let Some(else_expr) = else_clause {
1072 Self::collect_variables(else_expr, vars);
1073 }
1074 }
1075 LogicalExpression::Labels(var)
1076 | LogicalExpression::Type(var)
1077 | LogicalExpression::Id(var) => {
1078 vars.insert(var.clone());
1079 }
1080 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1081 LogicalExpression::ListComprehension {
1082 list_expr,
1083 filter_expr,
1084 map_expr,
1085 ..
1086 } => {
1087 Self::collect_variables(list_expr, vars);
1088 if let Some(filter) = filter_expr {
1089 Self::collect_variables(filter, vars);
1090 }
1091 Self::collect_variables(map_expr, vars);
1092 }
1093 LogicalExpression::ListPredicate {
1094 list_expr,
1095 predicate,
1096 ..
1097 } => {
1098 Self::collect_variables(list_expr, vars);
1099 Self::collect_variables(predicate, vars);
1100 }
1101 LogicalExpression::ExistsSubquery(_)
1102 | LogicalExpression::CountSubquery(_)
1103 | LogicalExpression::ValueSubquery(_) => {
1104 }
1106 LogicalExpression::PatternComprehension { projection, .. } => {
1107 Self::collect_variables(projection, vars);
1108 }
1109 LogicalExpression::MapProjection { base, entries } => {
1110 vars.insert(base.clone());
1111 for entry in entries {
1112 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1113 Self::collect_variables(expr, vars);
1114 }
1115 }
1116 }
1117 LogicalExpression::Reduce {
1118 initial,
1119 list,
1120 expression,
1121 ..
1122 } => {
1123 Self::collect_variables(initial, vars);
1124 Self::collect_variables(list, vars);
1125 Self::collect_variables(expression, vars);
1126 }
1127 }
1128 }
1129
1130 fn extract_projection_aliases(
1132 &self,
1133 projections: &[crate::query::plan::Projection],
1134 ) -> HashSet<String> {
1135 projections.iter().filter_map(|p| p.alias.clone()).collect()
1136 }
1137}
1138
1139impl Default for Optimizer {
1140 fn default() -> Self {
1141 Self::new()
1142 }
1143}
1144
1145#[cfg(test)]
1146mod tests {
1147 use super::*;
1148 use crate::query::plan::{
1149 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1150 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1151 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1152 };
1153 use grafeo_common::types::Value;
1154
1155 #[test]
1156 fn test_optimizer_filter_pushdown_simple() {
1157 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1162 items: vec![ReturnItem {
1163 expression: LogicalExpression::Variable("n".to_string()),
1164 alias: None,
1165 }],
1166 distinct: false,
1167 input: Box::new(LogicalOperator::Filter(FilterOp {
1168 predicate: LogicalExpression::Binary {
1169 left: Box::new(LogicalExpression::Property {
1170 variable: "n".to_string(),
1171 property: "age".to_string(),
1172 }),
1173 op: BinaryOp::Gt,
1174 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1175 },
1176 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1177 variable: "n".to_string(),
1178 label: Some("Person".to_string()),
1179 input: None,
1180 })),
1181 pushdown_hint: None,
1182 })),
1183 }));
1184
1185 let optimizer = Optimizer::new();
1186 let optimized = optimizer.optimize(plan).unwrap();
1187
1188 if let LogicalOperator::Return(ret) = &optimized.root
1190 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1191 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1192 {
1193 assert_eq!(scan.variable, "n");
1194 return;
1195 }
1196 panic!("Expected Return -> Filter -> NodeScan structure");
1197 }
1198
1199 #[test]
1200 fn test_optimizer_filter_pushdown_through_expand() {
1201 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1205 items: vec![ReturnItem {
1206 expression: LogicalExpression::Variable("b".to_string()),
1207 alias: None,
1208 }],
1209 distinct: false,
1210 input: Box::new(LogicalOperator::Filter(FilterOp {
1211 predicate: LogicalExpression::Binary {
1212 left: Box::new(LogicalExpression::Property {
1213 variable: "a".to_string(),
1214 property: "age".to_string(),
1215 }),
1216 op: BinaryOp::Gt,
1217 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1218 },
1219 pushdown_hint: None,
1220 input: Box::new(LogicalOperator::Expand(ExpandOp {
1221 from_variable: "a".to_string(),
1222 to_variable: "b".to_string(),
1223 edge_variable: None,
1224 direction: ExpandDirection::Outgoing,
1225 edge_types: vec!["KNOWS".to_string()],
1226 min_hops: 1,
1227 max_hops: Some(1),
1228 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1229 variable: "a".to_string(),
1230 label: Some("Person".to_string()),
1231 input: None,
1232 })),
1233 path_alias: None,
1234 path_mode: PathMode::Walk,
1235 })),
1236 })),
1237 }));
1238
1239 let optimizer = Optimizer::new();
1240 let optimized = optimizer.optimize(plan).unwrap();
1241
1242 if let LogicalOperator::Return(ret) = &optimized.root
1245 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1246 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1247 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1248 {
1249 assert_eq!(scan.variable, "a");
1250 assert_eq!(expand.from_variable, "a");
1251 assert_eq!(expand.to_variable, "b");
1252 return;
1253 }
1254 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1255 }
1256
1257 #[test]
1258 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1259 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1263 items: vec![ReturnItem {
1264 expression: LogicalExpression::Variable("a".to_string()),
1265 alias: None,
1266 }],
1267 distinct: false,
1268 input: Box::new(LogicalOperator::Filter(FilterOp {
1269 predicate: LogicalExpression::Binary {
1270 left: Box::new(LogicalExpression::Property {
1271 variable: "b".to_string(),
1272 property: "age".to_string(),
1273 }),
1274 op: BinaryOp::Gt,
1275 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1276 },
1277 pushdown_hint: None,
1278 input: Box::new(LogicalOperator::Expand(ExpandOp {
1279 from_variable: "a".to_string(),
1280 to_variable: "b".to_string(),
1281 edge_variable: None,
1282 direction: ExpandDirection::Outgoing,
1283 edge_types: vec!["KNOWS".to_string()],
1284 min_hops: 1,
1285 max_hops: Some(1),
1286 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1287 variable: "a".to_string(),
1288 label: Some("Person".to_string()),
1289 input: None,
1290 })),
1291 path_alias: None,
1292 path_mode: PathMode::Walk,
1293 })),
1294 })),
1295 }));
1296
1297 let optimizer = Optimizer::new();
1298 let optimized = optimizer.optimize(plan).unwrap();
1299
1300 if let LogicalOperator::Return(ret) = &optimized.root
1303 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1304 {
1305 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1307 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1308 {
1309 assert_eq!(variable, "b");
1310 }
1311
1312 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1313 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1314 {
1315 return;
1316 }
1317 }
1318 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1319 }
1320
1321 #[test]
1322 fn test_optimizer_extract_variables() {
1323 let optimizer = Optimizer::new();
1324
1325 let expr = LogicalExpression::Binary {
1326 left: Box::new(LogicalExpression::Property {
1327 variable: "n".to_string(),
1328 property: "age".to_string(),
1329 }),
1330 op: BinaryOp::Gt,
1331 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1332 };
1333
1334 let vars = optimizer.extract_variables(&expr);
1335 assert_eq!(vars.len(), 1);
1336 assert!(vars.contains("n"));
1337 }
1338
1339 #[test]
1342 fn test_optimizer_default() {
1343 let optimizer = Optimizer::default();
1344 let plan = LogicalPlan::new(LogicalOperator::Empty);
1346 let result = optimizer.optimize(plan);
1347 assert!(result.is_ok());
1348 }
1349
1350 #[test]
1351 fn test_optimizer_with_filter_pushdown_disabled() {
1352 let optimizer = Optimizer::new().with_filter_pushdown(false);
1353
1354 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1355 items: vec![ReturnItem {
1356 expression: LogicalExpression::Variable("n".to_string()),
1357 alias: None,
1358 }],
1359 distinct: false,
1360 input: Box::new(LogicalOperator::Filter(FilterOp {
1361 predicate: LogicalExpression::Literal(Value::Bool(true)),
1362 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1363 variable: "n".to_string(),
1364 label: None,
1365 input: None,
1366 })),
1367 pushdown_hint: None,
1368 })),
1369 }));
1370
1371 let optimized = optimizer.optimize(plan).unwrap();
1372 if let LogicalOperator::Return(ret) = &optimized.root
1374 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1375 {
1376 return;
1377 }
1378 panic!("Expected unchanged structure");
1379 }
1380
1381 #[test]
1382 fn test_optimizer_with_join_reorder_disabled() {
1383 let optimizer = Optimizer::new().with_join_reorder(false);
1384 assert!(
1385 optimizer
1386 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1387 .is_ok()
1388 );
1389 }
1390
1391 #[test]
1392 fn test_optimizer_with_cost_model() {
1393 let cost_model = CostModel::new();
1394 let optimizer = Optimizer::new().with_cost_model(cost_model);
1395 assert!(
1396 optimizer
1397 .cost_model()
1398 .estimate(&LogicalOperator::Empty, 0.0)
1399 .total()
1400 < 0.001
1401 );
1402 }
1403
1404 #[test]
1405 fn test_optimizer_with_cardinality_estimator() {
1406 let mut estimator = CardinalityEstimator::new();
1407 estimator.add_table_stats("Test", TableStats::new(500));
1408 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1409
1410 let scan = LogicalOperator::NodeScan(NodeScanOp {
1411 variable: "n".to_string(),
1412 label: Some("Test".to_string()),
1413 input: None,
1414 });
1415 let plan = LogicalPlan::new(scan);
1416
1417 let cardinality = optimizer.estimate_cardinality(&plan);
1418 assert!((cardinality - 500.0).abs() < 0.001);
1419 }
1420
1421 #[test]
1422 fn test_optimizer_estimate_cost() {
1423 let optimizer = Optimizer::new();
1424 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1425 variable: "n".to_string(),
1426 label: None,
1427 input: None,
1428 }));
1429
1430 let cost = optimizer.estimate_cost(&plan);
1431 assert!(cost.total() > 0.0);
1432 }
1433
1434 #[test]
1437 fn test_filter_pushdown_through_project() {
1438 let optimizer = Optimizer::new();
1439
1440 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1441 predicate: LogicalExpression::Binary {
1442 left: Box::new(LogicalExpression::Property {
1443 variable: "n".to_string(),
1444 property: "age".to_string(),
1445 }),
1446 op: BinaryOp::Gt,
1447 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1448 },
1449 pushdown_hint: None,
1450 input: Box::new(LogicalOperator::Project(ProjectOp {
1451 projections: vec![Projection {
1452 expression: LogicalExpression::Variable("n".to_string()),
1453 alias: None,
1454 }],
1455 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1456 variable: "n".to_string(),
1457 label: None,
1458 input: None,
1459 })),
1460 pass_through_input: false,
1461 })),
1462 }));
1463
1464 let optimized = optimizer.optimize(plan).unwrap();
1465
1466 if let LogicalOperator::Project(proj) = &optimized.root
1468 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1469 {
1470 return;
1471 }
1472 panic!("Expected Project -> Filter structure");
1473 }
1474
1475 #[test]
1476 fn test_filter_not_pushed_through_project_with_alias() {
1477 let optimizer = Optimizer::new();
1478
1479 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1481 predicate: LogicalExpression::Binary {
1482 left: Box::new(LogicalExpression::Variable("x".to_string())),
1483 op: BinaryOp::Gt,
1484 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1485 },
1486 pushdown_hint: None,
1487 input: Box::new(LogicalOperator::Project(ProjectOp {
1488 projections: vec![Projection {
1489 expression: LogicalExpression::Property {
1490 variable: "n".to_string(),
1491 property: "age".to_string(),
1492 },
1493 alias: Some("x".to_string()),
1494 }],
1495 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1496 variable: "n".to_string(),
1497 label: None,
1498 input: None,
1499 })),
1500 pass_through_input: false,
1501 })),
1502 }));
1503
1504 let optimized = optimizer.optimize(plan).unwrap();
1505
1506 if let LogicalOperator::Filter(filter) = &optimized.root
1508 && let LogicalOperator::Project(_) = filter.input.as_ref()
1509 {
1510 return;
1511 }
1512 panic!("Expected Filter -> Project structure");
1513 }
1514
1515 #[test]
1516 fn test_filter_pushdown_through_limit() {
1517 let optimizer = Optimizer::new();
1518
1519 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1520 predicate: LogicalExpression::Literal(Value::Bool(true)),
1521 pushdown_hint: None,
1522 input: Box::new(LogicalOperator::Limit(LimitOp {
1523 count: 10.into(),
1524 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1525 variable: "n".to_string(),
1526 label: None,
1527 input: None,
1528 })),
1529 })),
1530 }));
1531
1532 let optimized = optimizer.optimize(plan).unwrap();
1533
1534 if let LogicalOperator::Filter(filter) = &optimized.root
1536 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1537 {
1538 return;
1539 }
1540 panic!("Expected Filter -> Limit structure");
1541 }
1542
1543 #[test]
1544 fn test_filter_pushdown_through_sort() {
1545 let optimizer = Optimizer::new();
1546
1547 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1548 predicate: LogicalExpression::Literal(Value::Bool(true)),
1549 pushdown_hint: None,
1550 input: Box::new(LogicalOperator::Sort(SortOp {
1551 keys: vec![SortKey {
1552 expression: LogicalExpression::Variable("n".to_string()),
1553 order: SortOrder::Ascending,
1554 nulls: None,
1555 }],
1556 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1557 variable: "n".to_string(),
1558 label: None,
1559 input: None,
1560 })),
1561 })),
1562 }));
1563
1564 let optimized = optimizer.optimize(plan).unwrap();
1565
1566 if let LogicalOperator::Filter(filter) = &optimized.root
1568 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1569 {
1570 return;
1571 }
1572 panic!("Expected Filter -> Sort structure");
1573 }
1574
1575 #[test]
1576 fn test_filter_pushdown_through_distinct() {
1577 let optimizer = Optimizer::new();
1578
1579 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1580 predicate: LogicalExpression::Literal(Value::Bool(true)),
1581 pushdown_hint: None,
1582 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1583 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1584 variable: "n".to_string(),
1585 label: None,
1586 input: None,
1587 })),
1588 columns: None,
1589 })),
1590 }));
1591
1592 let optimized = optimizer.optimize(plan).unwrap();
1593
1594 if let LogicalOperator::Filter(filter) = &optimized.root
1596 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1597 {
1598 return;
1599 }
1600 panic!("Expected Filter -> Distinct structure");
1601 }
1602
1603 #[test]
1604 fn test_filter_not_pushed_through_aggregate() {
1605 let optimizer = Optimizer::new();
1606
1607 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1608 predicate: LogicalExpression::Binary {
1609 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1610 op: BinaryOp::Gt,
1611 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1612 },
1613 pushdown_hint: None,
1614 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1615 group_by: vec![],
1616 aggregates: vec![AggregateExpr {
1617 function: AggregateFunction::Count,
1618 expression: None,
1619 expression2: None,
1620 distinct: false,
1621 alias: Some("cnt".to_string()),
1622 percentile: None,
1623 separator: None,
1624 }],
1625 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1626 variable: "n".to_string(),
1627 label: None,
1628 input: None,
1629 })),
1630 having: None,
1631 })),
1632 }));
1633
1634 let optimized = optimizer.optimize(plan).unwrap();
1635
1636 if let LogicalOperator::Filter(filter) = &optimized.root
1638 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1639 {
1640 return;
1641 }
1642 panic!("Expected Filter -> Aggregate structure");
1643 }
1644
1645 #[test]
1646 fn test_filter_pushdown_to_left_join_side() {
1647 let optimizer = Optimizer::new();
1648
1649 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1651 predicate: LogicalExpression::Binary {
1652 left: Box::new(LogicalExpression::Property {
1653 variable: "a".to_string(),
1654 property: "age".to_string(),
1655 }),
1656 op: BinaryOp::Gt,
1657 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1658 },
1659 pushdown_hint: None,
1660 input: Box::new(LogicalOperator::Join(JoinOp {
1661 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1662 variable: "a".to_string(),
1663 label: Some("Person".to_string()),
1664 input: None,
1665 })),
1666 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1667 variable: "b".to_string(),
1668 label: Some("Company".to_string()),
1669 input: None,
1670 })),
1671 join_type: JoinType::Inner,
1672 conditions: vec![],
1673 })),
1674 }));
1675
1676 let optimized = optimizer.optimize(plan).unwrap();
1677
1678 if let LogicalOperator::Join(join) = &optimized.root
1680 && let LogicalOperator::Filter(_) = join.left.as_ref()
1681 {
1682 return;
1683 }
1684 panic!("Expected Join with Filter on left side");
1685 }
1686
1687 #[test]
1688 fn test_filter_pushdown_to_right_join_side() {
1689 let optimizer = Optimizer::new();
1690
1691 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1693 predicate: LogicalExpression::Binary {
1694 left: Box::new(LogicalExpression::Property {
1695 variable: "b".to_string(),
1696 property: "name".to_string(),
1697 }),
1698 op: BinaryOp::Eq,
1699 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1700 },
1701 pushdown_hint: None,
1702 input: Box::new(LogicalOperator::Join(JoinOp {
1703 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1704 variable: "a".to_string(),
1705 label: Some("Person".to_string()),
1706 input: None,
1707 })),
1708 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1709 variable: "b".to_string(),
1710 label: Some("Company".to_string()),
1711 input: None,
1712 })),
1713 join_type: JoinType::Inner,
1714 conditions: vec![],
1715 })),
1716 }));
1717
1718 let optimized = optimizer.optimize(plan).unwrap();
1719
1720 if let LogicalOperator::Join(join) = &optimized.root
1722 && let LogicalOperator::Filter(_) = join.right.as_ref()
1723 {
1724 return;
1725 }
1726 panic!("Expected Join with Filter on right side");
1727 }
1728
1729 #[test]
1730 fn test_filter_not_pushed_when_uses_both_join_sides() {
1731 let optimizer = Optimizer::new();
1732
1733 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1735 predicate: LogicalExpression::Binary {
1736 left: Box::new(LogicalExpression::Property {
1737 variable: "a".to_string(),
1738 property: "id".to_string(),
1739 }),
1740 op: BinaryOp::Eq,
1741 right: Box::new(LogicalExpression::Property {
1742 variable: "b".to_string(),
1743 property: "a_id".to_string(),
1744 }),
1745 },
1746 pushdown_hint: None,
1747 input: Box::new(LogicalOperator::Join(JoinOp {
1748 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1749 variable: "a".to_string(),
1750 label: None,
1751 input: None,
1752 })),
1753 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1754 variable: "b".to_string(),
1755 label: None,
1756 input: None,
1757 })),
1758 join_type: JoinType::Inner,
1759 conditions: vec![],
1760 })),
1761 }));
1762
1763 let optimized = optimizer.optimize(plan).unwrap();
1764
1765 if let LogicalOperator::Filter(filter) = &optimized.root
1767 && let LogicalOperator::Join(_) = filter.input.as_ref()
1768 {
1769 return;
1770 }
1771 panic!("Expected Filter -> Join structure");
1772 }
1773
1774 #[test]
1777 fn test_extract_variables_from_variable() {
1778 let optimizer = Optimizer::new();
1779 let expr = LogicalExpression::Variable("x".to_string());
1780 let vars = optimizer.extract_variables(&expr);
1781 assert_eq!(vars.len(), 1);
1782 assert!(vars.contains("x"));
1783 }
1784
1785 #[test]
1786 fn test_extract_variables_from_unary() {
1787 let optimizer = Optimizer::new();
1788 let expr = LogicalExpression::Unary {
1789 op: UnaryOp::Not,
1790 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1791 };
1792 let vars = optimizer.extract_variables(&expr);
1793 assert_eq!(vars.len(), 1);
1794 assert!(vars.contains("x"));
1795 }
1796
1797 #[test]
1798 fn test_extract_variables_from_function_call() {
1799 let optimizer = Optimizer::new();
1800 let expr = LogicalExpression::FunctionCall {
1801 name: "length".to_string(),
1802 args: vec![
1803 LogicalExpression::Variable("a".to_string()),
1804 LogicalExpression::Variable("b".to_string()),
1805 ],
1806 distinct: false,
1807 };
1808 let vars = optimizer.extract_variables(&expr);
1809 assert_eq!(vars.len(), 2);
1810 assert!(vars.contains("a"));
1811 assert!(vars.contains("b"));
1812 }
1813
1814 #[test]
1815 fn test_extract_variables_from_list() {
1816 let optimizer = Optimizer::new();
1817 let expr = LogicalExpression::List(vec![
1818 LogicalExpression::Variable("a".to_string()),
1819 LogicalExpression::Literal(Value::Int64(1)),
1820 LogicalExpression::Variable("b".to_string()),
1821 ]);
1822 let vars = optimizer.extract_variables(&expr);
1823 assert_eq!(vars.len(), 2);
1824 assert!(vars.contains("a"));
1825 assert!(vars.contains("b"));
1826 }
1827
1828 #[test]
1829 fn test_extract_variables_from_map() {
1830 let optimizer = Optimizer::new();
1831 let expr = LogicalExpression::Map(vec![
1832 (
1833 "key1".to_string(),
1834 LogicalExpression::Variable("a".to_string()),
1835 ),
1836 (
1837 "key2".to_string(),
1838 LogicalExpression::Variable("b".to_string()),
1839 ),
1840 ]);
1841 let vars = optimizer.extract_variables(&expr);
1842 assert_eq!(vars.len(), 2);
1843 assert!(vars.contains("a"));
1844 assert!(vars.contains("b"));
1845 }
1846
1847 #[test]
1848 fn test_extract_variables_from_index_access() {
1849 let optimizer = Optimizer::new();
1850 let expr = LogicalExpression::IndexAccess {
1851 base: Box::new(LogicalExpression::Variable("list".to_string())),
1852 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1853 };
1854 let vars = optimizer.extract_variables(&expr);
1855 assert_eq!(vars.len(), 2);
1856 assert!(vars.contains("list"));
1857 assert!(vars.contains("idx"));
1858 }
1859
1860 #[test]
1861 fn test_extract_variables_from_slice_access() {
1862 let optimizer = Optimizer::new();
1863 let expr = LogicalExpression::SliceAccess {
1864 base: Box::new(LogicalExpression::Variable("list".to_string())),
1865 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1866 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1867 };
1868 let vars = optimizer.extract_variables(&expr);
1869 assert_eq!(vars.len(), 3);
1870 assert!(vars.contains("list"));
1871 assert!(vars.contains("s"));
1872 assert!(vars.contains("e"));
1873 }
1874
1875 #[test]
1876 fn test_extract_variables_from_case() {
1877 let optimizer = Optimizer::new();
1878 let expr = LogicalExpression::Case {
1879 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1880 when_clauses: vec![(
1881 LogicalExpression::Literal(Value::Int64(1)),
1882 LogicalExpression::Variable("a".to_string()),
1883 )],
1884 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1885 };
1886 let vars = optimizer.extract_variables(&expr);
1887 assert_eq!(vars.len(), 3);
1888 assert!(vars.contains("x"));
1889 assert!(vars.contains("a"));
1890 assert!(vars.contains("b"));
1891 }
1892
1893 #[test]
1894 fn test_extract_variables_from_labels() {
1895 let optimizer = Optimizer::new();
1896 let expr = LogicalExpression::Labels("n".to_string());
1897 let vars = optimizer.extract_variables(&expr);
1898 assert_eq!(vars.len(), 1);
1899 assert!(vars.contains("n"));
1900 }
1901
1902 #[test]
1903 fn test_extract_variables_from_type() {
1904 let optimizer = Optimizer::new();
1905 let expr = LogicalExpression::Type("e".to_string());
1906 let vars = optimizer.extract_variables(&expr);
1907 assert_eq!(vars.len(), 1);
1908 assert!(vars.contains("e"));
1909 }
1910
1911 #[test]
1912 fn test_extract_variables_from_id() {
1913 let optimizer = Optimizer::new();
1914 let expr = LogicalExpression::Id("n".to_string());
1915 let vars = optimizer.extract_variables(&expr);
1916 assert_eq!(vars.len(), 1);
1917 assert!(vars.contains("n"));
1918 }
1919
1920 #[test]
1921 fn test_extract_variables_from_list_comprehension() {
1922 let optimizer = Optimizer::new();
1923 let expr = LogicalExpression::ListComprehension {
1924 variable: "x".to_string(),
1925 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1926 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1927 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1928 };
1929 let vars = optimizer.extract_variables(&expr);
1930 assert!(vars.contains("items"));
1931 assert!(vars.contains("pred"));
1932 assert!(vars.contains("result"));
1933 }
1934
1935 #[test]
1936 fn test_extract_variables_from_literal_and_parameter() {
1937 let optimizer = Optimizer::new();
1938
1939 let literal = LogicalExpression::Literal(Value::Int64(42));
1940 assert!(optimizer.extract_variables(&literal).is_empty());
1941
1942 let param = LogicalExpression::Parameter("p".to_string());
1943 assert!(optimizer.extract_variables(¶m).is_empty());
1944 }
1945
1946 #[test]
1949 fn test_recursive_filter_pushdown_through_skip() {
1950 let optimizer = Optimizer::new();
1951
1952 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1953 items: vec![ReturnItem {
1954 expression: LogicalExpression::Variable("n".to_string()),
1955 alias: None,
1956 }],
1957 distinct: false,
1958 input: Box::new(LogicalOperator::Filter(FilterOp {
1959 predicate: LogicalExpression::Literal(Value::Bool(true)),
1960 pushdown_hint: None,
1961 input: Box::new(LogicalOperator::Skip(SkipOp {
1962 count: 5.into(),
1963 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1964 variable: "n".to_string(),
1965 label: None,
1966 input: None,
1967 })),
1968 })),
1969 })),
1970 }));
1971
1972 let optimized = optimizer.optimize(plan).unwrap();
1973
1974 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1976 }
1977
1978 #[test]
1979 fn test_nested_filter_pushdown() {
1980 let optimizer = Optimizer::new();
1981
1982 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1984 items: vec![ReturnItem {
1985 expression: LogicalExpression::Variable("n".to_string()),
1986 alias: None,
1987 }],
1988 distinct: false,
1989 input: Box::new(LogicalOperator::Filter(FilterOp {
1990 predicate: LogicalExpression::Binary {
1991 left: Box::new(LogicalExpression::Property {
1992 variable: "n".to_string(),
1993 property: "x".to_string(),
1994 }),
1995 op: BinaryOp::Gt,
1996 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1997 },
1998 pushdown_hint: None,
1999 input: Box::new(LogicalOperator::Filter(FilterOp {
2000 predicate: LogicalExpression::Binary {
2001 left: Box::new(LogicalExpression::Property {
2002 variable: "n".to_string(),
2003 property: "y".to_string(),
2004 }),
2005 op: BinaryOp::Lt,
2006 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2007 },
2008 pushdown_hint: None,
2009 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2010 variable: "n".to_string(),
2011 label: None,
2012 input: None,
2013 })),
2014 })),
2015 })),
2016 }));
2017
2018 let optimized = optimizer.optimize(plan).unwrap();
2019 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2020 }
2021
2022 #[test]
2023 fn test_cyclic_join_produces_multi_way_join() {
2024 use crate::query::plan::JoinCondition;
2025
2026 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2028 variable: "a".to_string(),
2029 label: Some("Person".to_string()),
2030 input: None,
2031 });
2032 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2033 variable: "b".to_string(),
2034 label: Some("Person".to_string()),
2035 input: None,
2036 });
2037 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2038 variable: "c".to_string(),
2039 label: Some("Person".to_string()),
2040 input: None,
2041 });
2042
2043 let join_ab = LogicalOperator::Join(JoinOp {
2045 left: Box::new(scan_a),
2046 right: Box::new(scan_b),
2047 join_type: JoinType::Inner,
2048 conditions: vec![JoinCondition {
2049 left: LogicalExpression::Variable("a".to_string()),
2050 right: LogicalExpression::Variable("b".to_string()),
2051 }],
2052 });
2053
2054 let join_abc = LogicalOperator::Join(JoinOp {
2055 left: Box::new(join_ab),
2056 right: Box::new(scan_c),
2057 join_type: JoinType::Inner,
2058 conditions: vec![
2059 JoinCondition {
2060 left: LogicalExpression::Variable("b".to_string()),
2061 right: LogicalExpression::Variable("c".to_string()),
2062 },
2063 JoinCondition {
2064 left: LogicalExpression::Variable("c".to_string()),
2065 right: LogicalExpression::Variable("a".to_string()),
2066 },
2067 ],
2068 });
2069
2070 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2071 items: vec![ReturnItem {
2072 expression: LogicalExpression::Variable("a".to_string()),
2073 alias: None,
2074 }],
2075 distinct: false,
2076 input: Box::new(join_abc),
2077 }));
2078
2079 let mut optimizer = Optimizer::new();
2080 optimizer
2081 .card_estimator
2082 .add_table_stats("Person", cardinality::TableStats::new(1000));
2083
2084 let optimized = optimizer.optimize(plan).unwrap();
2085
2086 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2088 match op {
2089 LogicalOperator::MultiWayJoin(_) => true,
2090 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2091 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2092 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2093 _ => false,
2094 }
2095 }
2096
2097 assert!(
2098 has_multi_way_join(&optimized.root),
2099 "Expected MultiWayJoin for cyclic triangle pattern"
2100 );
2101 }
2102
2103 #[test]
2104 fn test_acyclic_join_uses_binary_joins() {
2105 use crate::query::plan::JoinCondition;
2106
2107 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2109 variable: "a".to_string(),
2110 label: Some("Person".to_string()),
2111 input: None,
2112 });
2113 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2114 variable: "b".to_string(),
2115 label: Some("Person".to_string()),
2116 input: None,
2117 });
2118 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2119 variable: "c".to_string(),
2120 label: Some("Company".to_string()),
2121 input: None,
2122 });
2123
2124 let join_ab = LogicalOperator::Join(JoinOp {
2125 left: Box::new(scan_a),
2126 right: Box::new(scan_b),
2127 join_type: JoinType::Inner,
2128 conditions: vec![JoinCondition {
2129 left: LogicalExpression::Variable("a".to_string()),
2130 right: LogicalExpression::Variable("b".to_string()),
2131 }],
2132 });
2133
2134 let join_abc = LogicalOperator::Join(JoinOp {
2135 left: Box::new(join_ab),
2136 right: Box::new(scan_c),
2137 join_type: JoinType::Inner,
2138 conditions: vec![JoinCondition {
2139 left: LogicalExpression::Variable("b".to_string()),
2140 right: LogicalExpression::Variable("c".to_string()),
2141 }],
2142 });
2143
2144 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2145 items: vec![ReturnItem {
2146 expression: LogicalExpression::Variable("a".to_string()),
2147 alias: None,
2148 }],
2149 distinct: false,
2150 input: Box::new(join_abc),
2151 }));
2152
2153 let mut optimizer = Optimizer::new();
2154 optimizer
2155 .card_estimator
2156 .add_table_stats("Person", cardinality::TableStats::new(1000));
2157 optimizer
2158 .card_estimator
2159 .add_table_stats("Company", cardinality::TableStats::new(100));
2160
2161 let optimized = optimizer.optimize(plan).unwrap();
2162
2163 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2165 match op {
2166 LogicalOperator::MultiWayJoin(_) => true,
2167 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2168 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2169 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2170 LogicalOperator::Join(j) => {
2171 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2172 }
2173 _ => false,
2174 }
2175 }
2176
2177 assert!(
2178 !has_multi_way_join(&optimized.root),
2179 "Acyclic join should NOT produce MultiWayJoin"
2180 );
2181 }
2182}