1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
30#[derive(Debug, Clone)]
32struct JoinInfo {
33 left_var: String,
34 right_var: String,
35 left_expr: LogicalExpression,
36 right_expr: LogicalExpression,
37}
38
/// A column that some operator in the plan reads, as gathered by projection pushdown.
///
/// `Variable(v)` stands for the binding `v` itself; `Property(v, p)` for property `p` of `v`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum RequiredColumn {
    /// A whole variable binding, e.g. `n`.
    Variable(String),
    /// A property of a variable, stored as `(variable, property)`, e.g. `n.age`.
    Property(String, String),
}
47
48pub struct Optimizer {
53 enable_filter_pushdown: bool,
55 enable_join_reorder: bool,
57 enable_projection_pushdown: bool,
59 cost_model: CostModel,
61 card_estimator: CardinalityEstimator,
63}
64
65impl Optimizer {
66 #[must_use]
68 pub fn new() -> Self {
69 Self {
70 enable_filter_pushdown: true,
71 enable_join_reorder: true,
72 enable_projection_pushdown: true,
73 cost_model: CostModel::new(),
74 card_estimator: CardinalityEstimator::new(),
75 }
76 }
77
78 #[must_use]
84 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85 store.ensure_statistics_fresh();
86 let stats = store.statistics();
87 Self::from_statistics(&stats)
88 }
89
90 #[must_use]
97 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98 let stats = store.statistics();
99 Self::from_statistics(&stats)
100 }
101
/// Builds an optimizer for RDF data from triple-store statistics.
///
/// The cardinality estimator is built from `rdf_stats`. The cost model uses the total
/// triple count for both of its graph totals (the node and edge slots of
/// `with_graph_totals`). All passes are enabled.
#[cfg(feature = "rdf")]
#[must_use]
pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
    // Read the total before `rdf_stats` is moved into the estimator.
    let triple_count = rdf_stats.total_triples;
    let card_estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
    let cost_model = CostModel::new().with_graph_totals(triple_count, triple_count);
    Self {
        cost_model,
        card_estimator,
        enable_filter_pushdown: true,
        enable_join_reorder: true,
        enable_projection_pushdown: true,
    }
}
119
120 #[must_use]
125 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
126 let estimator = CardinalityEstimator::from_statistics(stats);
127
128 let avg_fanout = if stats.total_nodes > 0 {
129 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
130 } else {
131 10.0
132 };
133
134 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
135 .edge_types
136 .iter()
137 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
138 .collect();
139
140 let label_cardinalities: std::collections::HashMap<String, u64> = stats
141 .labels
142 .iter()
143 .map(|(name, ls)| (name.clone(), ls.node_count))
144 .collect();
145
146 Self {
147 enable_filter_pushdown: true,
148 enable_join_reorder: true,
149 enable_projection_pushdown: true,
150 cost_model: CostModel::new()
151 .with_avg_fanout(avg_fanout)
152 .with_edge_type_degrees(edge_type_degrees)
153 .with_label_cardinalities(label_cardinalities)
154 .with_graph_totals(stats.total_nodes, stats.total_edges),
155 card_estimator: estimator,
156 }
157 }
158
159 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
161 self.enable_filter_pushdown = enabled;
162 self
163 }
164
165 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
167 self.enable_join_reorder = enabled;
168 self
169 }
170
171 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
173 self.enable_projection_pushdown = enabled;
174 self
175 }
176
177 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
179 self.cost_model = cost_model;
180 self
181 }
182
183 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
185 self.card_estimator = estimator;
186 self
187 }
188
189 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
191 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
192 self
193 }
194
195 pub fn cost_model(&self) -> &CostModel {
197 &self.cost_model
198 }
199
200 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
202 &self.card_estimator
203 }
204
205 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
211 self.cost_model
212 .estimate_tree(&plan.root, &self.card_estimator)
213 }
214
215 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
217 self.card_estimator.estimate(&plan.root)
218 }
219
220 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
226 let mut root = plan.root;
227
228 if self.enable_filter_pushdown {
230 root = self.push_filters_down(root);
231 }
232
233 if self.enable_join_reorder {
234 root = self.reorder_joins(root);
235 }
236
237 if self.enable_projection_pushdown {
238 root = self.push_projections_down(root);
239 }
240
241 Ok(LogicalPlan {
242 root,
243 explain: plan.explain,
244 profile: plan.profile,
245 })
246 }
247
248 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
255 let required = self.collect_required_columns(&op);
257
258 self.push_projections_recursive(op, &required)
260 }
261
262 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
264 let mut required = HashSet::new();
265 Self::collect_required_recursive(op, &mut required);
266 required
267 }
268
269 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
271 match op {
272 LogicalOperator::Return(ret) => {
273 for item in &ret.items {
274 Self::collect_from_expression(&item.expression, required);
275 }
276 Self::collect_required_recursive(&ret.input, required);
277 }
278 LogicalOperator::Project(proj) => {
279 for p in &proj.projections {
280 Self::collect_from_expression(&p.expression, required);
281 }
282 Self::collect_required_recursive(&proj.input, required);
283 }
284 LogicalOperator::Filter(filter) => {
285 Self::collect_from_expression(&filter.predicate, required);
286 Self::collect_required_recursive(&filter.input, required);
287 }
288 LogicalOperator::Sort(sort) => {
289 for key in &sort.keys {
290 Self::collect_from_expression(&key.expression, required);
291 }
292 Self::collect_required_recursive(&sort.input, required);
293 }
294 LogicalOperator::Aggregate(agg) => {
295 for expr in &agg.group_by {
296 Self::collect_from_expression(expr, required);
297 }
298 for agg_expr in &agg.aggregates {
299 if let Some(ref expr) = agg_expr.expression {
300 Self::collect_from_expression(expr, required);
301 }
302 }
303 if let Some(ref having) = agg.having {
304 Self::collect_from_expression(having, required);
305 }
306 Self::collect_required_recursive(&agg.input, required);
307 }
308 LogicalOperator::Join(join) => {
309 for cond in &join.conditions {
310 Self::collect_from_expression(&cond.left, required);
311 Self::collect_from_expression(&cond.right, required);
312 }
313 Self::collect_required_recursive(&join.left, required);
314 Self::collect_required_recursive(&join.right, required);
315 }
316 LogicalOperator::Expand(expand) => {
317 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
319 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
320 if let Some(ref edge_var) = expand.edge_variable {
321 required.insert(RequiredColumn::Variable(edge_var.clone()));
322 }
323 Self::collect_required_recursive(&expand.input, required);
324 }
325 LogicalOperator::Limit(limit) => {
326 Self::collect_required_recursive(&limit.input, required);
327 }
328 LogicalOperator::Skip(skip) => {
329 Self::collect_required_recursive(&skip.input, required);
330 }
331 LogicalOperator::Distinct(distinct) => {
332 Self::collect_required_recursive(&distinct.input, required);
333 }
334 LogicalOperator::NodeScan(scan) => {
335 required.insert(RequiredColumn::Variable(scan.variable.clone()));
336 }
337 LogicalOperator::EdgeScan(scan) => {
338 required.insert(RequiredColumn::Variable(scan.variable.clone()));
339 }
340 LogicalOperator::MultiWayJoin(mwj) => {
341 for cond in &mwj.conditions {
342 Self::collect_from_expression(&cond.left, required);
343 Self::collect_from_expression(&cond.right, required);
344 }
345 for input in &mwj.inputs {
346 Self::collect_required_recursive(input, required);
347 }
348 }
349 _ => {}
350 }
351 }
352
353 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
355 match expr {
356 LogicalExpression::Variable(var) => {
357 required.insert(RequiredColumn::Variable(var.clone()));
358 }
359 LogicalExpression::Property { variable, property } => {
360 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
361 required.insert(RequiredColumn::Variable(variable.clone()));
362 }
363 LogicalExpression::Binary { left, right, .. } => {
364 Self::collect_from_expression(left, required);
365 Self::collect_from_expression(right, required);
366 }
367 LogicalExpression::Unary { operand, .. } => {
368 Self::collect_from_expression(operand, required);
369 }
370 LogicalExpression::FunctionCall { args, .. } => {
371 for arg in args {
372 Self::collect_from_expression(arg, required);
373 }
374 }
375 LogicalExpression::List(items) => {
376 for item in items {
377 Self::collect_from_expression(item, required);
378 }
379 }
380 LogicalExpression::Map(pairs) => {
381 for (_, value) in pairs {
382 Self::collect_from_expression(value, required);
383 }
384 }
385 LogicalExpression::IndexAccess { base, index } => {
386 Self::collect_from_expression(base, required);
387 Self::collect_from_expression(index, required);
388 }
389 LogicalExpression::SliceAccess { base, start, end } => {
390 Self::collect_from_expression(base, required);
391 if let Some(s) = start {
392 Self::collect_from_expression(s, required);
393 }
394 if let Some(e) = end {
395 Self::collect_from_expression(e, required);
396 }
397 }
398 LogicalExpression::Case {
399 operand,
400 when_clauses,
401 else_clause,
402 } => {
403 if let Some(op) = operand {
404 Self::collect_from_expression(op, required);
405 }
406 for (cond, result) in when_clauses {
407 Self::collect_from_expression(cond, required);
408 Self::collect_from_expression(result, required);
409 }
410 if let Some(else_expr) = else_clause {
411 Self::collect_from_expression(else_expr, required);
412 }
413 }
414 LogicalExpression::Labels(var)
415 | LogicalExpression::Type(var)
416 | LogicalExpression::Id(var) => {
417 required.insert(RequiredColumn::Variable(var.clone()));
418 }
419 LogicalExpression::ListComprehension {
420 list_expr,
421 filter_expr,
422 map_expr,
423 ..
424 } => {
425 Self::collect_from_expression(list_expr, required);
426 if let Some(filter) = filter_expr {
427 Self::collect_from_expression(filter, required);
428 }
429 Self::collect_from_expression(map_expr, required);
430 }
431 _ => {}
432 }
433 }
434
435 fn push_projections_recursive(
437 &self,
438 op: LogicalOperator,
439 required: &HashSet<RequiredColumn>,
440 ) -> LogicalOperator {
441 match op {
442 LogicalOperator::Return(mut ret) => {
443 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
444 LogicalOperator::Return(ret)
445 }
446 LogicalOperator::Project(mut proj) => {
447 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
448 LogicalOperator::Project(proj)
449 }
450 LogicalOperator::Filter(mut filter) => {
451 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
452 LogicalOperator::Filter(filter)
453 }
454 LogicalOperator::Sort(mut sort) => {
455 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
458 LogicalOperator::Sort(sort)
459 }
460 LogicalOperator::Aggregate(mut agg) => {
461 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
462 LogicalOperator::Aggregate(agg)
463 }
464 LogicalOperator::Join(mut join) => {
465 let left_vars = self.collect_output_variables(&join.left);
468 let right_vars = self.collect_output_variables(&join.right);
469
470 let left_required: HashSet<_> = required
472 .iter()
473 .filter(|c| match c {
474 RequiredColumn::Variable(v) => left_vars.contains(v),
475 RequiredColumn::Property(v, _) => left_vars.contains(v),
476 })
477 .cloned()
478 .collect();
479
480 let right_required: HashSet<_> = required
481 .iter()
482 .filter(|c| match c {
483 RequiredColumn::Variable(v) => right_vars.contains(v),
484 RequiredColumn::Property(v, _) => right_vars.contains(v),
485 })
486 .cloned()
487 .collect();
488
489 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
490 join.right =
491 Box::new(self.push_projections_recursive(*join.right, &right_required));
492 LogicalOperator::Join(join)
493 }
494 LogicalOperator::Expand(mut expand) => {
495 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
496 LogicalOperator::Expand(expand)
497 }
498 LogicalOperator::Limit(mut limit) => {
499 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
500 LogicalOperator::Limit(limit)
501 }
502 LogicalOperator::Skip(mut skip) => {
503 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
504 LogicalOperator::Skip(skip)
505 }
506 LogicalOperator::Distinct(mut distinct) => {
507 distinct.input =
508 Box::new(self.push_projections_recursive(*distinct.input, required));
509 LogicalOperator::Distinct(distinct)
510 }
511 LogicalOperator::MapCollect(mut mc) => {
512 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
513 LogicalOperator::MapCollect(mc)
514 }
515 LogicalOperator::MultiWayJoin(mut mwj) => {
516 mwj.inputs = mwj
517 .inputs
518 .into_iter()
519 .map(|input| self.push_projections_recursive(input, required))
520 .collect();
521 LogicalOperator::MultiWayJoin(mwj)
522 }
523 other => other,
524 }
525 }
526
527 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
534 let op = self.reorder_joins_recursive(op);
536
537 if let Some((relations, conditions)) = self.extract_join_tree(&op)
539 && relations.len() >= 2
540 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
541 {
542 return optimized;
543 }
544
545 op
546 }
547
548 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
550 match op {
551 LogicalOperator::Return(mut ret) => {
552 ret.input = Box::new(self.reorder_joins(*ret.input));
553 LogicalOperator::Return(ret)
554 }
555 LogicalOperator::Project(mut proj) => {
556 proj.input = Box::new(self.reorder_joins(*proj.input));
557 LogicalOperator::Project(proj)
558 }
559 LogicalOperator::Filter(mut filter) => {
560 filter.input = Box::new(self.reorder_joins(*filter.input));
561 LogicalOperator::Filter(filter)
562 }
563 LogicalOperator::Limit(mut limit) => {
564 limit.input = Box::new(self.reorder_joins(*limit.input));
565 LogicalOperator::Limit(limit)
566 }
567 LogicalOperator::Skip(mut skip) => {
568 skip.input = Box::new(self.reorder_joins(*skip.input));
569 LogicalOperator::Skip(skip)
570 }
571 LogicalOperator::Sort(mut sort) => {
572 sort.input = Box::new(self.reorder_joins(*sort.input));
573 LogicalOperator::Sort(sort)
574 }
575 LogicalOperator::Distinct(mut distinct) => {
576 distinct.input = Box::new(self.reorder_joins(*distinct.input));
577 LogicalOperator::Distinct(distinct)
578 }
579 LogicalOperator::Aggregate(mut agg) => {
580 agg.input = Box::new(self.reorder_joins(*agg.input));
581 LogicalOperator::Aggregate(agg)
582 }
583 LogicalOperator::Expand(mut expand) => {
584 expand.input = Box::new(self.reorder_joins(*expand.input));
585 LogicalOperator::Expand(expand)
586 }
587 LogicalOperator::MapCollect(mut mc) => {
588 mc.input = Box::new(self.reorder_joins(*mc.input));
589 LogicalOperator::MapCollect(mc)
590 }
591 LogicalOperator::MultiWayJoin(mut mwj) => {
592 mwj.inputs = mwj
593 .inputs
594 .into_iter()
595 .map(|input| self.reorder_joins(input))
596 .collect();
597 LogicalOperator::MultiWayJoin(mwj)
598 }
599 other => other,
601 }
602 }
603
604 fn extract_join_tree(
608 &self,
609 op: &LogicalOperator,
610 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
611 let mut relations = Vec::new();
612 let mut join_conditions = Vec::new();
613
614 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
615 return None;
616 }
617
618 if relations.len() < 2 {
619 return None;
620 }
621
622 Some((relations, join_conditions))
623 }
624
625 fn collect_join_tree(
629 &self,
630 op: &LogicalOperator,
631 relations: &mut Vec<(String, LogicalOperator)>,
632 conditions: &mut Vec<JoinInfo>,
633 ) -> bool {
634 match op {
635 LogicalOperator::Join(join) => {
636 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
638 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
639
640 for cond in &join.conditions {
642 if let (Some(left_var), Some(right_var)) = (
643 self.extract_variable_from_expr(&cond.left),
644 self.extract_variable_from_expr(&cond.right),
645 ) {
646 conditions.push(JoinInfo {
647 left_var,
648 right_var,
649 left_expr: cond.left.clone(),
650 right_expr: cond.right.clone(),
651 });
652 }
653 }
654
655 left_ok && right_ok
656 }
657 LogicalOperator::NodeScan(scan) => {
658 relations.push((scan.variable.clone(), op.clone()));
659 true
660 }
661 LogicalOperator::EdgeScan(scan) => {
662 relations.push((scan.variable.clone(), op.clone()));
663 true
664 }
665 LogicalOperator::Filter(filter) => {
666 self.collect_join_tree(&filter.input, relations, conditions)
668 }
669 LogicalOperator::Expand(expand) => {
670 relations.push((expand.to_variable.clone(), op.clone()));
673 true
674 }
675 _ => false,
676 }
677 }
678
679 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
681 match expr {
682 LogicalExpression::Variable(v) => Some(v.clone()),
683 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
684 _ => None,
685 }
686 }
687
688 fn optimize_join_order(
691 &self,
692 relations: &[(String, LogicalOperator)],
693 conditions: &[JoinInfo],
694 ) -> Option<LogicalOperator> {
695 use join_order::{DPccp, JoinGraphBuilder};
696
697 let mut builder = JoinGraphBuilder::new();
699
700 for (var, relation) in relations {
701 builder.add_relation(var, relation.clone());
702 }
703
704 for cond in conditions {
705 builder.add_join_condition(
706 &cond.left_var,
707 &cond.right_var,
708 cond.left_expr.clone(),
709 cond.right_expr.clone(),
710 );
711 }
712
713 let graph = builder.build();
714
715 if graph.is_cyclic() && relations.len() >= 3 {
720 let mut var_counts: std::collections::HashMap<&str, usize> =
722 std::collections::HashMap::new();
723 for cond in conditions {
724 *var_counts.entry(&cond.left_var).or_default() += 1;
725 *var_counts.entry(&cond.right_var).or_default() += 1;
726 }
727 let shared_variables: Vec<String> = var_counts
728 .into_iter()
729 .filter(|(_, count)| *count >= 2)
730 .map(|(var, _)| var.to_string())
731 .collect();
732
733 let join_conditions: Vec<JoinCondition> = conditions
734 .iter()
735 .map(|c| JoinCondition {
736 left: c.left_expr.clone(),
737 right: c.right_expr.clone(),
738 })
739 .collect();
740
741 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
742 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
743 conditions: join_conditions,
744 shared_variables,
745 }));
746 }
747
748 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
750 let plan = dpccp.optimize()?;
751
752 Some(plan.operator)
753 }
754
755 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
760 match op {
761 LogicalOperator::Filter(filter) => {
763 let optimized_input = self.push_filters_down(*filter.input);
764 self.try_push_filter_into(filter.predicate, optimized_input)
765 }
766 LogicalOperator::Return(mut ret) => {
768 ret.input = Box::new(self.push_filters_down(*ret.input));
769 LogicalOperator::Return(ret)
770 }
771 LogicalOperator::Project(mut proj) => {
772 proj.input = Box::new(self.push_filters_down(*proj.input));
773 LogicalOperator::Project(proj)
774 }
775 LogicalOperator::Limit(mut limit) => {
776 limit.input = Box::new(self.push_filters_down(*limit.input));
777 LogicalOperator::Limit(limit)
778 }
779 LogicalOperator::Skip(mut skip) => {
780 skip.input = Box::new(self.push_filters_down(*skip.input));
781 LogicalOperator::Skip(skip)
782 }
783 LogicalOperator::Sort(mut sort) => {
784 sort.input = Box::new(self.push_filters_down(*sort.input));
785 LogicalOperator::Sort(sort)
786 }
787 LogicalOperator::Distinct(mut distinct) => {
788 distinct.input = Box::new(self.push_filters_down(*distinct.input));
789 LogicalOperator::Distinct(distinct)
790 }
791 LogicalOperator::Expand(mut expand) => {
792 expand.input = Box::new(self.push_filters_down(*expand.input));
793 LogicalOperator::Expand(expand)
794 }
795 LogicalOperator::Join(mut join) => {
796 join.left = Box::new(self.push_filters_down(*join.left));
797 join.right = Box::new(self.push_filters_down(*join.right));
798 LogicalOperator::Join(join)
799 }
800 LogicalOperator::Aggregate(mut agg) => {
801 agg.input = Box::new(self.push_filters_down(*agg.input));
802 LogicalOperator::Aggregate(agg)
803 }
804 LogicalOperator::MapCollect(mut mc) => {
805 mc.input = Box::new(self.push_filters_down(*mc.input));
806 LogicalOperator::MapCollect(mc)
807 }
808 LogicalOperator::MultiWayJoin(mut mwj) => {
809 mwj.inputs = mwj
810 .inputs
811 .into_iter()
812 .map(|input| self.push_filters_down(input))
813 .collect();
814 LogicalOperator::MultiWayJoin(mwj)
815 }
816 other => other,
818 }
819 }
820
821 fn try_push_filter_into(
826 &self,
827 predicate: LogicalExpression,
828 op: LogicalOperator,
829 ) -> LogicalOperator {
830 match op {
831 LogicalOperator::Project(mut proj) => {
833 let predicate_vars = self.extract_variables(&predicate);
834 let computed_vars = self.extract_projection_aliases(&proj.projections);
835
836 if predicate_vars.is_disjoint(&computed_vars) {
838 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
839 LogicalOperator::Project(proj)
840 } else {
841 LogicalOperator::Filter(FilterOp {
843 predicate,
844 pushdown_hint: None,
845 input: Box::new(LogicalOperator::Project(proj)),
846 })
847 }
848 }
849
850 LogicalOperator::Return(mut ret) => {
852 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
853 LogicalOperator::Return(ret)
854 }
855
856 LogicalOperator::Expand(mut expand) => {
858 let predicate_vars = self.extract_variables(&predicate);
859
860 let mut introduced_vars = vec![&expand.to_variable];
865 if let Some(ref edge_var) = expand.edge_variable {
866 introduced_vars.push(edge_var);
867 }
868 if let Some(ref path_alias) = expand.path_alias {
869 introduced_vars.push(path_alias);
870 }
871
872 let uses_introduced_vars =
874 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
875
876 if !uses_introduced_vars {
877 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
879 LogicalOperator::Expand(expand)
880 } else {
881 LogicalOperator::Filter(FilterOp {
883 predicate,
884 pushdown_hint: None,
885 input: Box::new(LogicalOperator::Expand(expand)),
886 })
887 }
888 }
889
890 LogicalOperator::Join(mut join) => {
892 let predicate_vars = self.extract_variables(&predicate);
893 let left_vars = self.collect_output_variables(&join.left);
894 let right_vars = self.collect_output_variables(&join.right);
895
896 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
897 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
898
899 if uses_left && !uses_right {
900 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
902 LogicalOperator::Join(join)
903 } else if uses_right && !uses_left {
904 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
906 LogicalOperator::Join(join)
907 } else {
908 LogicalOperator::Filter(FilterOp {
910 predicate,
911 pushdown_hint: None,
912 input: Box::new(LogicalOperator::Join(join)),
913 })
914 }
915 }
916
917 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
919 predicate,
920 pushdown_hint: None,
921 input: Box::new(LogicalOperator::Aggregate(agg)),
922 }),
923
924 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
926 predicate,
927 pushdown_hint: None,
928 input: Box::new(LogicalOperator::NodeScan(scan)),
929 }),
930
931 other => LogicalOperator::Filter(FilterOp {
933 predicate,
934 pushdown_hint: None,
935 input: Box::new(other),
936 }),
937 }
938 }
939
940 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
942 let mut vars = HashSet::new();
943 Self::collect_output_variables_recursive(op, &mut vars);
944 vars
945 }
946
947 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
949 match op {
950 LogicalOperator::NodeScan(scan) => {
951 vars.insert(scan.variable.clone());
952 }
953 LogicalOperator::EdgeScan(scan) => {
954 vars.insert(scan.variable.clone());
955 }
956 LogicalOperator::Expand(expand) => {
957 vars.insert(expand.to_variable.clone());
958 if let Some(edge_var) = &expand.edge_variable {
959 vars.insert(edge_var.clone());
960 }
961 Self::collect_output_variables_recursive(&expand.input, vars);
962 }
963 LogicalOperator::Filter(filter) => {
964 Self::collect_output_variables_recursive(&filter.input, vars);
965 }
966 LogicalOperator::Project(proj) => {
967 for p in &proj.projections {
968 if let Some(alias) = &p.alias {
969 vars.insert(alias.clone());
970 }
971 }
972 Self::collect_output_variables_recursive(&proj.input, vars);
973 }
974 LogicalOperator::Join(join) => {
975 Self::collect_output_variables_recursive(&join.left, vars);
976 Self::collect_output_variables_recursive(&join.right, vars);
977 }
978 LogicalOperator::Aggregate(agg) => {
979 for expr in &agg.group_by {
980 Self::collect_variables(expr, vars);
981 }
982 for agg_expr in &agg.aggregates {
983 if let Some(alias) = &agg_expr.alias {
984 vars.insert(alias.clone());
985 }
986 }
987 }
988 LogicalOperator::Return(ret) => {
989 Self::collect_output_variables_recursive(&ret.input, vars);
990 }
991 LogicalOperator::Limit(limit) => {
992 Self::collect_output_variables_recursive(&limit.input, vars);
993 }
994 LogicalOperator::Skip(skip) => {
995 Self::collect_output_variables_recursive(&skip.input, vars);
996 }
997 LogicalOperator::Sort(sort) => {
998 Self::collect_output_variables_recursive(&sort.input, vars);
999 }
1000 LogicalOperator::Distinct(distinct) => {
1001 Self::collect_output_variables_recursive(&distinct.input, vars);
1002 }
1003 _ => {}
1004 }
1005 }
1006
1007 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1009 let mut vars = HashSet::new();
1010 Self::collect_variables(expr, &mut vars);
1011 vars
1012 }
1013
1014 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1016 match expr {
1017 LogicalExpression::Variable(name) => {
1018 vars.insert(name.clone());
1019 }
1020 LogicalExpression::Property { variable, .. } => {
1021 vars.insert(variable.clone());
1022 }
1023 LogicalExpression::Binary { left, right, .. } => {
1024 Self::collect_variables(left, vars);
1025 Self::collect_variables(right, vars);
1026 }
1027 LogicalExpression::Unary { operand, .. } => {
1028 Self::collect_variables(operand, vars);
1029 }
1030 LogicalExpression::FunctionCall { args, .. } => {
1031 for arg in args {
1032 Self::collect_variables(arg, vars);
1033 }
1034 }
1035 LogicalExpression::List(items) => {
1036 for item in items {
1037 Self::collect_variables(item, vars);
1038 }
1039 }
1040 LogicalExpression::Map(pairs) => {
1041 for (_, value) in pairs {
1042 Self::collect_variables(value, vars);
1043 }
1044 }
1045 LogicalExpression::IndexAccess { base, index } => {
1046 Self::collect_variables(base, vars);
1047 Self::collect_variables(index, vars);
1048 }
1049 LogicalExpression::SliceAccess { base, start, end } => {
1050 Self::collect_variables(base, vars);
1051 if let Some(s) = start {
1052 Self::collect_variables(s, vars);
1053 }
1054 if let Some(e) = end {
1055 Self::collect_variables(e, vars);
1056 }
1057 }
1058 LogicalExpression::Case {
1059 operand,
1060 when_clauses,
1061 else_clause,
1062 } => {
1063 if let Some(op) = operand {
1064 Self::collect_variables(op, vars);
1065 }
1066 for (cond, result) in when_clauses {
1067 Self::collect_variables(cond, vars);
1068 Self::collect_variables(result, vars);
1069 }
1070 if let Some(else_expr) = else_clause {
1071 Self::collect_variables(else_expr, vars);
1072 }
1073 }
1074 LogicalExpression::Labels(var)
1075 | LogicalExpression::Type(var)
1076 | LogicalExpression::Id(var) => {
1077 vars.insert(var.clone());
1078 }
1079 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1080 LogicalExpression::ListComprehension {
1081 list_expr,
1082 filter_expr,
1083 map_expr,
1084 ..
1085 } => {
1086 Self::collect_variables(list_expr, vars);
1087 if let Some(filter) = filter_expr {
1088 Self::collect_variables(filter, vars);
1089 }
1090 Self::collect_variables(map_expr, vars);
1091 }
1092 LogicalExpression::ListPredicate {
1093 list_expr,
1094 predicate,
1095 ..
1096 } => {
1097 Self::collect_variables(list_expr, vars);
1098 Self::collect_variables(predicate, vars);
1099 }
1100 LogicalExpression::ExistsSubquery(_)
1101 | LogicalExpression::CountSubquery(_)
1102 | LogicalExpression::ValueSubquery(_) => {
1103 }
1105 LogicalExpression::PatternComprehension { projection, .. } => {
1106 Self::collect_variables(projection, vars);
1107 }
1108 LogicalExpression::MapProjection { base, entries } => {
1109 vars.insert(base.clone());
1110 for entry in entries {
1111 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1112 Self::collect_variables(expr, vars);
1113 }
1114 }
1115 }
1116 LogicalExpression::Reduce {
1117 initial,
1118 list,
1119 expression,
1120 ..
1121 } => {
1122 Self::collect_variables(initial, vars);
1123 Self::collect_variables(list, vars);
1124 Self::collect_variables(expression, vars);
1125 }
1126 }
1127 }
1128
1129 fn extract_projection_aliases(
1131 &self,
1132 projections: &[crate::query::plan::Projection],
1133 ) -> HashSet<String> {
1134 projections.iter().filter_map(|p| p.alias.clone()).collect()
1135 }
1136}
1137
1138impl Default for Optimizer {
1139 fn default() -> Self {
1140 Self::new()
1141 }
1142}
1143
1144#[cfg(test)]
1145mod tests {
1146 use super::*;
1147 use crate::query::plan::{
1148 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1149 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1150 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1151 };
1152 use grafeo_common::types::Value;
1153
1154 #[test]
1155 fn test_optimizer_filter_pushdown_simple() {
1156 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1161 items: vec![ReturnItem {
1162 expression: LogicalExpression::Variable("n".to_string()),
1163 alias: None,
1164 }],
1165 distinct: false,
1166 input: Box::new(LogicalOperator::Filter(FilterOp {
1167 predicate: LogicalExpression::Binary {
1168 left: Box::new(LogicalExpression::Property {
1169 variable: "n".to_string(),
1170 property: "age".to_string(),
1171 }),
1172 op: BinaryOp::Gt,
1173 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1174 },
1175 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1176 variable: "n".to_string(),
1177 label: Some("Person".to_string()),
1178 input: None,
1179 })),
1180 pushdown_hint: None,
1181 })),
1182 }));
1183
1184 let optimizer = Optimizer::new();
1185 let optimized = optimizer.optimize(plan).unwrap();
1186
1187 if let LogicalOperator::Return(ret) = &optimized.root
1189 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1190 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1191 {
1192 assert_eq!(scan.variable, "n");
1193 return;
1194 }
1195 panic!("Expected Return -> Filter -> NodeScan structure");
1196 }
1197
1198 #[test]
1199 fn test_optimizer_filter_pushdown_through_expand() {
1200 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1204 items: vec![ReturnItem {
1205 expression: LogicalExpression::Variable("b".to_string()),
1206 alias: None,
1207 }],
1208 distinct: false,
1209 input: Box::new(LogicalOperator::Filter(FilterOp {
1210 predicate: LogicalExpression::Binary {
1211 left: Box::new(LogicalExpression::Property {
1212 variable: "a".to_string(),
1213 property: "age".to_string(),
1214 }),
1215 op: BinaryOp::Gt,
1216 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1217 },
1218 pushdown_hint: None,
1219 input: Box::new(LogicalOperator::Expand(ExpandOp {
1220 from_variable: "a".to_string(),
1221 to_variable: "b".to_string(),
1222 edge_variable: None,
1223 direction: ExpandDirection::Outgoing,
1224 edge_types: vec!["KNOWS".to_string()],
1225 min_hops: 1,
1226 max_hops: Some(1),
1227 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1228 variable: "a".to_string(),
1229 label: Some("Person".to_string()),
1230 input: None,
1231 })),
1232 path_alias: None,
1233 path_mode: PathMode::Walk,
1234 })),
1235 })),
1236 }));
1237
1238 let optimizer = Optimizer::new();
1239 let optimized = optimizer.optimize(plan).unwrap();
1240
1241 if let LogicalOperator::Return(ret) = &optimized.root
1244 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1245 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1246 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1247 {
1248 assert_eq!(scan.variable, "a");
1249 assert_eq!(expand.from_variable, "a");
1250 assert_eq!(expand.to_variable, "b");
1251 return;
1252 }
1253 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1254 }
1255
1256 #[test]
1257 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1258 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1262 items: vec![ReturnItem {
1263 expression: LogicalExpression::Variable("a".to_string()),
1264 alias: None,
1265 }],
1266 distinct: false,
1267 input: Box::new(LogicalOperator::Filter(FilterOp {
1268 predicate: LogicalExpression::Binary {
1269 left: Box::new(LogicalExpression::Property {
1270 variable: "b".to_string(),
1271 property: "age".to_string(),
1272 }),
1273 op: BinaryOp::Gt,
1274 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1275 },
1276 pushdown_hint: None,
1277 input: Box::new(LogicalOperator::Expand(ExpandOp {
1278 from_variable: "a".to_string(),
1279 to_variable: "b".to_string(),
1280 edge_variable: None,
1281 direction: ExpandDirection::Outgoing,
1282 edge_types: vec!["KNOWS".to_string()],
1283 min_hops: 1,
1284 max_hops: Some(1),
1285 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1286 variable: "a".to_string(),
1287 label: Some("Person".to_string()),
1288 input: None,
1289 })),
1290 path_alias: None,
1291 path_mode: PathMode::Walk,
1292 })),
1293 })),
1294 }));
1295
1296 let optimizer = Optimizer::new();
1297 let optimized = optimizer.optimize(plan).unwrap();
1298
1299 if let LogicalOperator::Return(ret) = &optimized.root
1302 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1303 {
1304 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1306 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1307 {
1308 assert_eq!(variable, "b");
1309 }
1310
1311 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1312 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1313 {
1314 return;
1315 }
1316 }
1317 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1318 }
1319
1320 #[test]
1321 fn test_optimizer_extract_variables() {
1322 let optimizer = Optimizer::new();
1323
1324 let expr = LogicalExpression::Binary {
1325 left: Box::new(LogicalExpression::Property {
1326 variable: "n".to_string(),
1327 property: "age".to_string(),
1328 }),
1329 op: BinaryOp::Gt,
1330 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1331 };
1332
1333 let vars = optimizer.extract_variables(&expr);
1334 assert_eq!(vars.len(), 1);
1335 assert!(vars.contains("n"));
1336 }
1337
1338 #[test]
1341 fn test_optimizer_default() {
1342 let optimizer = Optimizer::default();
1343 let plan = LogicalPlan::new(LogicalOperator::Empty);
1345 let result = optimizer.optimize(plan);
1346 assert!(result.is_ok());
1347 }
1348
1349 #[test]
1350 fn test_optimizer_with_filter_pushdown_disabled() {
1351 let optimizer = Optimizer::new().with_filter_pushdown(false);
1352
1353 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1354 items: vec![ReturnItem {
1355 expression: LogicalExpression::Variable("n".to_string()),
1356 alias: None,
1357 }],
1358 distinct: false,
1359 input: Box::new(LogicalOperator::Filter(FilterOp {
1360 predicate: LogicalExpression::Literal(Value::Bool(true)),
1361 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362 variable: "n".to_string(),
1363 label: None,
1364 input: None,
1365 })),
1366 pushdown_hint: None,
1367 })),
1368 }));
1369
1370 let optimized = optimizer.optimize(plan).unwrap();
1371 if let LogicalOperator::Return(ret) = &optimized.root
1373 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1374 {
1375 return;
1376 }
1377 panic!("Expected unchanged structure");
1378 }
1379
1380 #[test]
1381 fn test_optimizer_with_join_reorder_disabled() {
1382 let optimizer = Optimizer::new().with_join_reorder(false);
1383 assert!(
1384 optimizer
1385 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1386 .is_ok()
1387 );
1388 }
1389
1390 #[test]
1391 fn test_optimizer_with_cost_model() {
1392 let cost_model = CostModel::new();
1393 let optimizer = Optimizer::new().with_cost_model(cost_model);
1394 assert!(
1395 optimizer
1396 .cost_model()
1397 .estimate(&LogicalOperator::Empty, 0.0)
1398 .total()
1399 < 0.001
1400 );
1401 }
1402
1403 #[test]
1404 fn test_optimizer_with_cardinality_estimator() {
1405 let mut estimator = CardinalityEstimator::new();
1406 estimator.add_table_stats("Test", TableStats::new(500));
1407 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1408
1409 let scan = LogicalOperator::NodeScan(NodeScanOp {
1410 variable: "n".to_string(),
1411 label: Some("Test".to_string()),
1412 input: None,
1413 });
1414 let plan = LogicalPlan::new(scan);
1415
1416 let cardinality = optimizer.estimate_cardinality(&plan);
1417 assert!((cardinality - 500.0).abs() < 0.001);
1418 }
1419
1420 #[test]
1421 fn test_optimizer_estimate_cost() {
1422 let optimizer = Optimizer::new();
1423 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1424 variable: "n".to_string(),
1425 label: None,
1426 input: None,
1427 }));
1428
1429 let cost = optimizer.estimate_cost(&plan);
1430 assert!(cost.total() > 0.0);
1431 }
1432
1433 #[test]
1436 fn test_filter_pushdown_through_project() {
1437 let optimizer = Optimizer::new();
1438
1439 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1440 predicate: LogicalExpression::Binary {
1441 left: Box::new(LogicalExpression::Property {
1442 variable: "n".to_string(),
1443 property: "age".to_string(),
1444 }),
1445 op: BinaryOp::Gt,
1446 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1447 },
1448 pushdown_hint: None,
1449 input: Box::new(LogicalOperator::Project(ProjectOp {
1450 projections: vec![Projection {
1451 expression: LogicalExpression::Variable("n".to_string()),
1452 alias: None,
1453 }],
1454 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1455 variable: "n".to_string(),
1456 label: None,
1457 input: None,
1458 })),
1459 pass_through_input: false,
1460 })),
1461 }));
1462
1463 let optimized = optimizer.optimize(plan).unwrap();
1464
1465 if let LogicalOperator::Project(proj) = &optimized.root
1467 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1468 {
1469 return;
1470 }
1471 panic!("Expected Project -> Filter structure");
1472 }
1473
1474 #[test]
1475 fn test_filter_not_pushed_through_project_with_alias() {
1476 let optimizer = Optimizer::new();
1477
1478 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1480 predicate: LogicalExpression::Binary {
1481 left: Box::new(LogicalExpression::Variable("x".to_string())),
1482 op: BinaryOp::Gt,
1483 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1484 },
1485 pushdown_hint: None,
1486 input: Box::new(LogicalOperator::Project(ProjectOp {
1487 projections: vec![Projection {
1488 expression: LogicalExpression::Property {
1489 variable: "n".to_string(),
1490 property: "age".to_string(),
1491 },
1492 alias: Some("x".to_string()),
1493 }],
1494 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1495 variable: "n".to_string(),
1496 label: None,
1497 input: None,
1498 })),
1499 pass_through_input: false,
1500 })),
1501 }));
1502
1503 let optimized = optimizer.optimize(plan).unwrap();
1504
1505 if let LogicalOperator::Filter(filter) = &optimized.root
1507 && let LogicalOperator::Project(_) = filter.input.as_ref()
1508 {
1509 return;
1510 }
1511 panic!("Expected Filter -> Project structure");
1512 }
1513
1514 #[test]
1515 fn test_filter_pushdown_through_limit() {
1516 let optimizer = Optimizer::new();
1517
1518 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1519 predicate: LogicalExpression::Literal(Value::Bool(true)),
1520 pushdown_hint: None,
1521 input: Box::new(LogicalOperator::Limit(LimitOp {
1522 count: 10.into(),
1523 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1524 variable: "n".to_string(),
1525 label: None,
1526 input: None,
1527 })),
1528 })),
1529 }));
1530
1531 let optimized = optimizer.optimize(plan).unwrap();
1532
1533 if let LogicalOperator::Filter(filter) = &optimized.root
1535 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1536 {
1537 return;
1538 }
1539 panic!("Expected Filter -> Limit structure");
1540 }
1541
1542 #[test]
1543 fn test_filter_pushdown_through_sort() {
1544 let optimizer = Optimizer::new();
1545
1546 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1547 predicate: LogicalExpression::Literal(Value::Bool(true)),
1548 pushdown_hint: None,
1549 input: Box::new(LogicalOperator::Sort(SortOp {
1550 keys: vec![SortKey {
1551 expression: LogicalExpression::Variable("n".to_string()),
1552 order: SortOrder::Ascending,
1553 nulls: None,
1554 }],
1555 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1556 variable: "n".to_string(),
1557 label: None,
1558 input: None,
1559 })),
1560 })),
1561 }));
1562
1563 let optimized = optimizer.optimize(plan).unwrap();
1564
1565 if let LogicalOperator::Filter(filter) = &optimized.root
1567 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1568 {
1569 return;
1570 }
1571 panic!("Expected Filter -> Sort structure");
1572 }
1573
1574 #[test]
1575 fn test_filter_pushdown_through_distinct() {
1576 let optimizer = Optimizer::new();
1577
1578 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1579 predicate: LogicalExpression::Literal(Value::Bool(true)),
1580 pushdown_hint: None,
1581 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1582 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1583 variable: "n".to_string(),
1584 label: None,
1585 input: None,
1586 })),
1587 columns: None,
1588 })),
1589 }));
1590
1591 let optimized = optimizer.optimize(plan).unwrap();
1592
1593 if let LogicalOperator::Filter(filter) = &optimized.root
1595 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1596 {
1597 return;
1598 }
1599 panic!("Expected Filter -> Distinct structure");
1600 }
1601
1602 #[test]
1603 fn test_filter_not_pushed_through_aggregate() {
1604 let optimizer = Optimizer::new();
1605
1606 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1607 predicate: LogicalExpression::Binary {
1608 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1609 op: BinaryOp::Gt,
1610 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1611 },
1612 pushdown_hint: None,
1613 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1614 group_by: vec![],
1615 aggregates: vec![AggregateExpr {
1616 function: AggregateFunction::Count,
1617 expression: None,
1618 expression2: None,
1619 distinct: false,
1620 alias: Some("cnt".to_string()),
1621 percentile: None,
1622 separator: None,
1623 }],
1624 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1625 variable: "n".to_string(),
1626 label: None,
1627 input: None,
1628 })),
1629 having: None,
1630 })),
1631 }));
1632
1633 let optimized = optimizer.optimize(plan).unwrap();
1634
1635 if let LogicalOperator::Filter(filter) = &optimized.root
1637 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1638 {
1639 return;
1640 }
1641 panic!("Expected Filter -> Aggregate structure");
1642 }
1643
1644 #[test]
1645 fn test_filter_pushdown_to_left_join_side() {
1646 let optimizer = Optimizer::new();
1647
1648 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1650 predicate: LogicalExpression::Binary {
1651 left: Box::new(LogicalExpression::Property {
1652 variable: "a".to_string(),
1653 property: "age".to_string(),
1654 }),
1655 op: BinaryOp::Gt,
1656 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1657 },
1658 pushdown_hint: None,
1659 input: Box::new(LogicalOperator::Join(JoinOp {
1660 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1661 variable: "a".to_string(),
1662 label: Some("Person".to_string()),
1663 input: None,
1664 })),
1665 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1666 variable: "b".to_string(),
1667 label: Some("Company".to_string()),
1668 input: None,
1669 })),
1670 join_type: JoinType::Inner,
1671 conditions: vec![],
1672 })),
1673 }));
1674
1675 let optimized = optimizer.optimize(plan).unwrap();
1676
1677 if let LogicalOperator::Join(join) = &optimized.root
1679 && let LogicalOperator::Filter(_) = join.left.as_ref()
1680 {
1681 return;
1682 }
1683 panic!("Expected Join with Filter on left side");
1684 }
1685
1686 #[test]
1687 fn test_filter_pushdown_to_right_join_side() {
1688 let optimizer = Optimizer::new();
1689
1690 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1692 predicate: LogicalExpression::Binary {
1693 left: Box::new(LogicalExpression::Property {
1694 variable: "b".to_string(),
1695 property: "name".to_string(),
1696 }),
1697 op: BinaryOp::Eq,
1698 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1699 },
1700 pushdown_hint: None,
1701 input: Box::new(LogicalOperator::Join(JoinOp {
1702 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1703 variable: "a".to_string(),
1704 label: Some("Person".to_string()),
1705 input: None,
1706 })),
1707 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1708 variable: "b".to_string(),
1709 label: Some("Company".to_string()),
1710 input: None,
1711 })),
1712 join_type: JoinType::Inner,
1713 conditions: vec![],
1714 })),
1715 }));
1716
1717 let optimized = optimizer.optimize(plan).unwrap();
1718
1719 if let LogicalOperator::Join(join) = &optimized.root
1721 && let LogicalOperator::Filter(_) = join.right.as_ref()
1722 {
1723 return;
1724 }
1725 panic!("Expected Join with Filter on right side");
1726 }
1727
1728 #[test]
1729 fn test_filter_not_pushed_when_uses_both_join_sides() {
1730 let optimizer = Optimizer::new();
1731
1732 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1734 predicate: LogicalExpression::Binary {
1735 left: Box::new(LogicalExpression::Property {
1736 variable: "a".to_string(),
1737 property: "id".to_string(),
1738 }),
1739 op: BinaryOp::Eq,
1740 right: Box::new(LogicalExpression::Property {
1741 variable: "b".to_string(),
1742 property: "a_id".to_string(),
1743 }),
1744 },
1745 pushdown_hint: None,
1746 input: Box::new(LogicalOperator::Join(JoinOp {
1747 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1748 variable: "a".to_string(),
1749 label: None,
1750 input: None,
1751 })),
1752 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1753 variable: "b".to_string(),
1754 label: None,
1755 input: None,
1756 })),
1757 join_type: JoinType::Inner,
1758 conditions: vec![],
1759 })),
1760 }));
1761
1762 let optimized = optimizer.optimize(plan).unwrap();
1763
1764 if let LogicalOperator::Filter(filter) = &optimized.root
1766 && let LogicalOperator::Join(_) = filter.input.as_ref()
1767 {
1768 return;
1769 }
1770 panic!("Expected Filter -> Join structure");
1771 }
1772
1773 #[test]
1776 fn test_extract_variables_from_variable() {
1777 let optimizer = Optimizer::new();
1778 let expr = LogicalExpression::Variable("x".to_string());
1779 let vars = optimizer.extract_variables(&expr);
1780 assert_eq!(vars.len(), 1);
1781 assert!(vars.contains("x"));
1782 }
1783
1784 #[test]
1785 fn test_extract_variables_from_unary() {
1786 let optimizer = Optimizer::new();
1787 let expr = LogicalExpression::Unary {
1788 op: UnaryOp::Not,
1789 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1790 };
1791 let vars = optimizer.extract_variables(&expr);
1792 assert_eq!(vars.len(), 1);
1793 assert!(vars.contains("x"));
1794 }
1795
1796 #[test]
1797 fn test_extract_variables_from_function_call() {
1798 let optimizer = Optimizer::new();
1799 let expr = LogicalExpression::FunctionCall {
1800 name: "length".to_string(),
1801 args: vec![
1802 LogicalExpression::Variable("a".to_string()),
1803 LogicalExpression::Variable("b".to_string()),
1804 ],
1805 distinct: false,
1806 };
1807 let vars = optimizer.extract_variables(&expr);
1808 assert_eq!(vars.len(), 2);
1809 assert!(vars.contains("a"));
1810 assert!(vars.contains("b"));
1811 }
1812
1813 #[test]
1814 fn test_extract_variables_from_list() {
1815 let optimizer = Optimizer::new();
1816 let expr = LogicalExpression::List(vec![
1817 LogicalExpression::Variable("a".to_string()),
1818 LogicalExpression::Literal(Value::Int64(1)),
1819 LogicalExpression::Variable("b".to_string()),
1820 ]);
1821 let vars = optimizer.extract_variables(&expr);
1822 assert_eq!(vars.len(), 2);
1823 assert!(vars.contains("a"));
1824 assert!(vars.contains("b"));
1825 }
1826
1827 #[test]
1828 fn test_extract_variables_from_map() {
1829 let optimizer = Optimizer::new();
1830 let expr = LogicalExpression::Map(vec![
1831 (
1832 "key1".to_string(),
1833 LogicalExpression::Variable("a".to_string()),
1834 ),
1835 (
1836 "key2".to_string(),
1837 LogicalExpression::Variable("b".to_string()),
1838 ),
1839 ]);
1840 let vars = optimizer.extract_variables(&expr);
1841 assert_eq!(vars.len(), 2);
1842 assert!(vars.contains("a"));
1843 assert!(vars.contains("b"));
1844 }
1845
1846 #[test]
1847 fn test_extract_variables_from_index_access() {
1848 let optimizer = Optimizer::new();
1849 let expr = LogicalExpression::IndexAccess {
1850 base: Box::new(LogicalExpression::Variable("list".to_string())),
1851 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1852 };
1853 let vars = optimizer.extract_variables(&expr);
1854 assert_eq!(vars.len(), 2);
1855 assert!(vars.contains("list"));
1856 assert!(vars.contains("idx"));
1857 }
1858
1859 #[test]
1860 fn test_extract_variables_from_slice_access() {
1861 let optimizer = Optimizer::new();
1862 let expr = LogicalExpression::SliceAccess {
1863 base: Box::new(LogicalExpression::Variable("list".to_string())),
1864 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1865 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1866 };
1867 let vars = optimizer.extract_variables(&expr);
1868 assert_eq!(vars.len(), 3);
1869 assert!(vars.contains("list"));
1870 assert!(vars.contains("s"));
1871 assert!(vars.contains("e"));
1872 }
1873
1874 #[test]
1875 fn test_extract_variables_from_case() {
1876 let optimizer = Optimizer::new();
1877 let expr = LogicalExpression::Case {
1878 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1879 when_clauses: vec![(
1880 LogicalExpression::Literal(Value::Int64(1)),
1881 LogicalExpression::Variable("a".to_string()),
1882 )],
1883 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1884 };
1885 let vars = optimizer.extract_variables(&expr);
1886 assert_eq!(vars.len(), 3);
1887 assert!(vars.contains("x"));
1888 assert!(vars.contains("a"));
1889 assert!(vars.contains("b"));
1890 }
1891
1892 #[test]
1893 fn test_extract_variables_from_labels() {
1894 let optimizer = Optimizer::new();
1895 let expr = LogicalExpression::Labels("n".to_string());
1896 let vars = optimizer.extract_variables(&expr);
1897 assert_eq!(vars.len(), 1);
1898 assert!(vars.contains("n"));
1899 }
1900
1901 #[test]
1902 fn test_extract_variables_from_type() {
1903 let optimizer = Optimizer::new();
1904 let expr = LogicalExpression::Type("e".to_string());
1905 let vars = optimizer.extract_variables(&expr);
1906 assert_eq!(vars.len(), 1);
1907 assert!(vars.contains("e"));
1908 }
1909
1910 #[test]
1911 fn test_extract_variables_from_id() {
1912 let optimizer = Optimizer::new();
1913 let expr = LogicalExpression::Id("n".to_string());
1914 let vars = optimizer.extract_variables(&expr);
1915 assert_eq!(vars.len(), 1);
1916 assert!(vars.contains("n"));
1917 }
1918
1919 #[test]
1920 fn test_extract_variables_from_list_comprehension() {
1921 let optimizer = Optimizer::new();
1922 let expr = LogicalExpression::ListComprehension {
1923 variable: "x".to_string(),
1924 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1925 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1926 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1927 };
1928 let vars = optimizer.extract_variables(&expr);
1929 assert!(vars.contains("items"));
1930 assert!(vars.contains("pred"));
1931 assert!(vars.contains("result"));
1932 }
1933
1934 #[test]
1935 fn test_extract_variables_from_literal_and_parameter() {
1936 let optimizer = Optimizer::new();
1937
1938 let literal = LogicalExpression::Literal(Value::Int64(42));
1939 assert!(optimizer.extract_variables(&literal).is_empty());
1940
1941 let param = LogicalExpression::Parameter("p".to_string());
1942 assert!(optimizer.extract_variables(¶m).is_empty());
1943 }
1944
1945 #[test]
1948 fn test_recursive_filter_pushdown_through_skip() {
1949 let optimizer = Optimizer::new();
1950
1951 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1952 items: vec![ReturnItem {
1953 expression: LogicalExpression::Variable("n".to_string()),
1954 alias: None,
1955 }],
1956 distinct: false,
1957 input: Box::new(LogicalOperator::Filter(FilterOp {
1958 predicate: LogicalExpression::Literal(Value::Bool(true)),
1959 pushdown_hint: None,
1960 input: Box::new(LogicalOperator::Skip(SkipOp {
1961 count: 5.into(),
1962 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1963 variable: "n".to_string(),
1964 label: None,
1965 input: None,
1966 })),
1967 })),
1968 })),
1969 }));
1970
1971 let optimized = optimizer.optimize(plan).unwrap();
1972
1973 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1975 }
1976
1977 #[test]
1978 fn test_nested_filter_pushdown() {
1979 let optimizer = Optimizer::new();
1980
1981 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1983 items: vec![ReturnItem {
1984 expression: LogicalExpression::Variable("n".to_string()),
1985 alias: None,
1986 }],
1987 distinct: false,
1988 input: Box::new(LogicalOperator::Filter(FilterOp {
1989 predicate: LogicalExpression::Binary {
1990 left: Box::new(LogicalExpression::Property {
1991 variable: "n".to_string(),
1992 property: "x".to_string(),
1993 }),
1994 op: BinaryOp::Gt,
1995 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1996 },
1997 pushdown_hint: None,
1998 input: Box::new(LogicalOperator::Filter(FilterOp {
1999 predicate: LogicalExpression::Binary {
2000 left: Box::new(LogicalExpression::Property {
2001 variable: "n".to_string(),
2002 property: "y".to_string(),
2003 }),
2004 op: BinaryOp::Lt,
2005 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2006 },
2007 pushdown_hint: None,
2008 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2009 variable: "n".to_string(),
2010 label: None,
2011 input: None,
2012 })),
2013 })),
2014 })),
2015 }));
2016
2017 let optimized = optimizer.optimize(plan).unwrap();
2018 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2019 }
2020
2021 #[test]
2022 fn test_cyclic_join_produces_multi_way_join() {
2023 use crate::query::plan::JoinCondition;
2024
2025 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2027 variable: "a".to_string(),
2028 label: Some("Person".to_string()),
2029 input: None,
2030 });
2031 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2032 variable: "b".to_string(),
2033 label: Some("Person".to_string()),
2034 input: None,
2035 });
2036 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2037 variable: "c".to_string(),
2038 label: Some("Person".to_string()),
2039 input: None,
2040 });
2041
2042 let join_ab = LogicalOperator::Join(JoinOp {
2044 left: Box::new(scan_a),
2045 right: Box::new(scan_b),
2046 join_type: JoinType::Inner,
2047 conditions: vec![JoinCondition {
2048 left: LogicalExpression::Variable("a".to_string()),
2049 right: LogicalExpression::Variable("b".to_string()),
2050 }],
2051 });
2052
2053 let join_abc = LogicalOperator::Join(JoinOp {
2054 left: Box::new(join_ab),
2055 right: Box::new(scan_c),
2056 join_type: JoinType::Inner,
2057 conditions: vec![
2058 JoinCondition {
2059 left: LogicalExpression::Variable("b".to_string()),
2060 right: LogicalExpression::Variable("c".to_string()),
2061 },
2062 JoinCondition {
2063 left: LogicalExpression::Variable("c".to_string()),
2064 right: LogicalExpression::Variable("a".to_string()),
2065 },
2066 ],
2067 });
2068
2069 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2070 items: vec![ReturnItem {
2071 expression: LogicalExpression::Variable("a".to_string()),
2072 alias: None,
2073 }],
2074 distinct: false,
2075 input: Box::new(join_abc),
2076 }));
2077
2078 let mut optimizer = Optimizer::new();
2079 optimizer
2080 .card_estimator
2081 .add_table_stats("Person", cardinality::TableStats::new(1000));
2082
2083 let optimized = optimizer.optimize(plan).unwrap();
2084
2085 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2087 match op {
2088 LogicalOperator::MultiWayJoin(_) => true,
2089 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2090 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2091 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2092 _ => false,
2093 }
2094 }
2095
2096 assert!(
2097 has_multi_way_join(&optimized.root),
2098 "Expected MultiWayJoin for cyclic triangle pattern"
2099 );
2100 }
2101
2102 #[test]
2103 fn test_acyclic_join_uses_binary_joins() {
2104 use crate::query::plan::JoinCondition;
2105
2106 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2108 variable: "a".to_string(),
2109 label: Some("Person".to_string()),
2110 input: None,
2111 });
2112 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2113 variable: "b".to_string(),
2114 label: Some("Person".to_string()),
2115 input: None,
2116 });
2117 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2118 variable: "c".to_string(),
2119 label: Some("Company".to_string()),
2120 input: None,
2121 });
2122
2123 let join_ab = LogicalOperator::Join(JoinOp {
2124 left: Box::new(scan_a),
2125 right: Box::new(scan_b),
2126 join_type: JoinType::Inner,
2127 conditions: vec![JoinCondition {
2128 left: LogicalExpression::Variable("a".to_string()),
2129 right: LogicalExpression::Variable("b".to_string()),
2130 }],
2131 });
2132
2133 let join_abc = LogicalOperator::Join(JoinOp {
2134 left: Box::new(join_ab),
2135 right: Box::new(scan_c),
2136 join_type: JoinType::Inner,
2137 conditions: vec![JoinCondition {
2138 left: LogicalExpression::Variable("b".to_string()),
2139 right: LogicalExpression::Variable("c".to_string()),
2140 }],
2141 });
2142
2143 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2144 items: vec![ReturnItem {
2145 expression: LogicalExpression::Variable("a".to_string()),
2146 alias: None,
2147 }],
2148 distinct: false,
2149 input: Box::new(join_abc),
2150 }));
2151
2152 let mut optimizer = Optimizer::new();
2153 optimizer
2154 .card_estimator
2155 .add_table_stats("Person", cardinality::TableStats::new(1000));
2156 optimizer
2157 .card_estimator
2158 .add_table_stats("Company", cardinality::TableStats::new(100));
2159
2160 let optimized = optimizer.optimize(plan).unwrap();
2161
2162 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2164 match op {
2165 LogicalOperator::MultiWayJoin(_) => true,
2166 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2167 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2168 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2169 LogicalOperator::Join(j) => {
2170 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2171 }
2172 _ => false,
2173 }
2174 }
2175
2176 assert!(
2177 !has_multi_way_join(&optimized.root),
2178 "Acyclic join should NOT produce MultiWayJoin"
2179 );
2180 }
2181}