1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// One join condition harvested from a join tree, paired with the variable
/// each side of the condition refers to.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression.
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// Left-hand expression of the join condition.
    left_expr: LogicalExpression,
    /// Right-hand expression of the join condition.
    right_expr: LogicalExpression,
}
38
/// A column that some operator or expression in a plan refers to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable binding, identified by its name.
    Variable(String),
    /// A property access, stored as `(variable name, property name)`.
    Property(String, String),
}
47
/// Rewrites logical plans using filter pushdown, join reordering and
/// projection pushdown; each pass can be switched on or off independently.
pub struct Optimizer {
    /// Whether `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for cost estimates and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for row-count estimates and join ordering.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66 #[must_use]
68 pub fn new() -> Self {
69 Self {
70 enable_filter_pushdown: true,
71 enable_join_reorder: true,
72 enable_projection_pushdown: true,
73 cost_model: CostModel::new(),
74 card_estimator: CardinalityEstimator::new(),
75 }
76 }
77
78 #[must_use]
84 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85 store.ensure_statistics_fresh();
86 let stats = store.statistics();
87 Self::from_statistics(&stats)
88 }
89
90 #[must_use]
97 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98 let stats = store.statistics();
99 Self::from_statistics(&stats)
100 }
101
102 #[must_use]
107 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
108 let estimator = CardinalityEstimator::from_statistics(stats);
109
110 let avg_fanout = if stats.total_nodes > 0 {
111 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
112 } else {
113 10.0
114 };
115
116 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
117 .edge_types
118 .iter()
119 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
120 .collect();
121
122 let label_cardinalities: std::collections::HashMap<String, u64> = stats
123 .labels
124 .iter()
125 .map(|(name, ls)| (name.clone(), ls.node_count))
126 .collect();
127
128 Self {
129 enable_filter_pushdown: true,
130 enable_join_reorder: true,
131 enable_projection_pushdown: true,
132 cost_model: CostModel::new()
133 .with_avg_fanout(avg_fanout)
134 .with_edge_type_degrees(edge_type_degrees)
135 .with_label_cardinalities(label_cardinalities)
136 .with_graph_totals(stats.total_nodes, stats.total_edges),
137 card_estimator: estimator,
138 }
139 }
140
141 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
143 self.enable_filter_pushdown = enabled;
144 self
145 }
146
147 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
149 self.enable_join_reorder = enabled;
150 self
151 }
152
153 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
155 self.enable_projection_pushdown = enabled;
156 self
157 }
158
159 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
161 self.cost_model = cost_model;
162 self
163 }
164
165 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
167 self.card_estimator = estimator;
168 self
169 }
170
171 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
173 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
174 self
175 }
176
    /// Returns a reference to the cost model this optimizer uses.
    pub fn cost_model(&self) -> &CostModel {
        &self.cost_model
    }
181
    /// Returns a reference to the cardinality estimator this optimizer uses.
    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
        &self.card_estimator
    }
186
187 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
193 self.cost_model
194 .estimate_tree(&plan.root, &self.card_estimator)
195 }
196
197 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
199 self.card_estimator.estimate(&plan.root)
200 }
201
202 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
208 let mut root = plan.root;
209
210 if self.enable_filter_pushdown {
212 root = self.push_filters_down(root);
213 }
214
215 if self.enable_join_reorder {
216 root = self.reorder_joins(root);
217 }
218
219 if self.enable_projection_pushdown {
220 root = self.push_projections_down(root);
221 }
222
223 Ok(LogicalPlan {
224 root,
225 explain: plan.explain,
226 profile: plan.profile,
227 })
228 }
229
230 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
237 let required = self.collect_required_columns(&op);
239
240 self.push_projections_recursive(op, &required)
242 }
243
244 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
246 let mut required = HashSet::new();
247 Self::collect_required_recursive(op, &mut required);
248 required
249 }
250
251 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
253 match op {
254 LogicalOperator::Return(ret) => {
255 for item in &ret.items {
256 Self::collect_from_expression(&item.expression, required);
257 }
258 Self::collect_required_recursive(&ret.input, required);
259 }
260 LogicalOperator::Project(proj) => {
261 for p in &proj.projections {
262 Self::collect_from_expression(&p.expression, required);
263 }
264 Self::collect_required_recursive(&proj.input, required);
265 }
266 LogicalOperator::Filter(filter) => {
267 Self::collect_from_expression(&filter.predicate, required);
268 Self::collect_required_recursive(&filter.input, required);
269 }
270 LogicalOperator::Sort(sort) => {
271 for key in &sort.keys {
272 Self::collect_from_expression(&key.expression, required);
273 }
274 Self::collect_required_recursive(&sort.input, required);
275 }
276 LogicalOperator::Aggregate(agg) => {
277 for expr in &agg.group_by {
278 Self::collect_from_expression(expr, required);
279 }
280 for agg_expr in &agg.aggregates {
281 if let Some(ref expr) = agg_expr.expression {
282 Self::collect_from_expression(expr, required);
283 }
284 }
285 if let Some(ref having) = agg.having {
286 Self::collect_from_expression(having, required);
287 }
288 Self::collect_required_recursive(&agg.input, required);
289 }
290 LogicalOperator::Join(join) => {
291 for cond in &join.conditions {
292 Self::collect_from_expression(&cond.left, required);
293 Self::collect_from_expression(&cond.right, required);
294 }
295 Self::collect_required_recursive(&join.left, required);
296 Self::collect_required_recursive(&join.right, required);
297 }
298 LogicalOperator::Expand(expand) => {
299 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
301 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
302 if let Some(ref edge_var) = expand.edge_variable {
303 required.insert(RequiredColumn::Variable(edge_var.clone()));
304 }
305 Self::collect_required_recursive(&expand.input, required);
306 }
307 LogicalOperator::Limit(limit) => {
308 Self::collect_required_recursive(&limit.input, required);
309 }
310 LogicalOperator::Skip(skip) => {
311 Self::collect_required_recursive(&skip.input, required);
312 }
313 LogicalOperator::Distinct(distinct) => {
314 Self::collect_required_recursive(&distinct.input, required);
315 }
316 LogicalOperator::NodeScan(scan) => {
317 required.insert(RequiredColumn::Variable(scan.variable.clone()));
318 }
319 LogicalOperator::EdgeScan(scan) => {
320 required.insert(RequiredColumn::Variable(scan.variable.clone()));
321 }
322 LogicalOperator::MultiWayJoin(mwj) => {
323 for cond in &mwj.conditions {
324 Self::collect_from_expression(&cond.left, required);
325 Self::collect_from_expression(&cond.right, required);
326 }
327 for input in &mwj.inputs {
328 Self::collect_required_recursive(input, required);
329 }
330 }
331 _ => {}
332 }
333 }
334
335 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
337 match expr {
338 LogicalExpression::Variable(var) => {
339 required.insert(RequiredColumn::Variable(var.clone()));
340 }
341 LogicalExpression::Property { variable, property } => {
342 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
343 required.insert(RequiredColumn::Variable(variable.clone()));
344 }
345 LogicalExpression::Binary { left, right, .. } => {
346 Self::collect_from_expression(left, required);
347 Self::collect_from_expression(right, required);
348 }
349 LogicalExpression::Unary { operand, .. } => {
350 Self::collect_from_expression(operand, required);
351 }
352 LogicalExpression::FunctionCall { args, .. } => {
353 for arg in args {
354 Self::collect_from_expression(arg, required);
355 }
356 }
357 LogicalExpression::List(items) => {
358 for item in items {
359 Self::collect_from_expression(item, required);
360 }
361 }
362 LogicalExpression::Map(pairs) => {
363 for (_, value) in pairs {
364 Self::collect_from_expression(value, required);
365 }
366 }
367 LogicalExpression::IndexAccess { base, index } => {
368 Self::collect_from_expression(base, required);
369 Self::collect_from_expression(index, required);
370 }
371 LogicalExpression::SliceAccess { base, start, end } => {
372 Self::collect_from_expression(base, required);
373 if let Some(s) = start {
374 Self::collect_from_expression(s, required);
375 }
376 if let Some(e) = end {
377 Self::collect_from_expression(e, required);
378 }
379 }
380 LogicalExpression::Case {
381 operand,
382 when_clauses,
383 else_clause,
384 } => {
385 if let Some(op) = operand {
386 Self::collect_from_expression(op, required);
387 }
388 for (cond, result) in when_clauses {
389 Self::collect_from_expression(cond, required);
390 Self::collect_from_expression(result, required);
391 }
392 if let Some(else_expr) = else_clause {
393 Self::collect_from_expression(else_expr, required);
394 }
395 }
396 LogicalExpression::Labels(var)
397 | LogicalExpression::Type(var)
398 | LogicalExpression::Id(var) => {
399 required.insert(RequiredColumn::Variable(var.clone()));
400 }
401 LogicalExpression::ListComprehension {
402 list_expr,
403 filter_expr,
404 map_expr,
405 ..
406 } => {
407 Self::collect_from_expression(list_expr, required);
408 if let Some(filter) = filter_expr {
409 Self::collect_from_expression(filter, required);
410 }
411 Self::collect_from_expression(map_expr, required);
412 }
413 _ => {}
414 }
415 }
416
417 fn push_projections_recursive(
419 &self,
420 op: LogicalOperator,
421 required: &HashSet<RequiredColumn>,
422 ) -> LogicalOperator {
423 match op {
424 LogicalOperator::Return(mut ret) => {
425 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
426 LogicalOperator::Return(ret)
427 }
428 LogicalOperator::Project(mut proj) => {
429 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
430 LogicalOperator::Project(proj)
431 }
432 LogicalOperator::Filter(mut filter) => {
433 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
434 LogicalOperator::Filter(filter)
435 }
436 LogicalOperator::Sort(mut sort) => {
437 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
440 LogicalOperator::Sort(sort)
441 }
442 LogicalOperator::Aggregate(mut agg) => {
443 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
444 LogicalOperator::Aggregate(agg)
445 }
446 LogicalOperator::Join(mut join) => {
447 let left_vars = self.collect_output_variables(&join.left);
450 let right_vars = self.collect_output_variables(&join.right);
451
452 let left_required: HashSet<_> = required
454 .iter()
455 .filter(|c| match c {
456 RequiredColumn::Variable(v) => left_vars.contains(v),
457 RequiredColumn::Property(v, _) => left_vars.contains(v),
458 })
459 .cloned()
460 .collect();
461
462 let right_required: HashSet<_> = required
463 .iter()
464 .filter(|c| match c {
465 RequiredColumn::Variable(v) => right_vars.contains(v),
466 RequiredColumn::Property(v, _) => right_vars.contains(v),
467 })
468 .cloned()
469 .collect();
470
471 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
472 join.right =
473 Box::new(self.push_projections_recursive(*join.right, &right_required));
474 LogicalOperator::Join(join)
475 }
476 LogicalOperator::Expand(mut expand) => {
477 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
478 LogicalOperator::Expand(expand)
479 }
480 LogicalOperator::Limit(mut limit) => {
481 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
482 LogicalOperator::Limit(limit)
483 }
484 LogicalOperator::Skip(mut skip) => {
485 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
486 LogicalOperator::Skip(skip)
487 }
488 LogicalOperator::Distinct(mut distinct) => {
489 distinct.input =
490 Box::new(self.push_projections_recursive(*distinct.input, required));
491 LogicalOperator::Distinct(distinct)
492 }
493 LogicalOperator::MapCollect(mut mc) => {
494 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
495 LogicalOperator::MapCollect(mc)
496 }
497 LogicalOperator::MultiWayJoin(mut mwj) => {
498 mwj.inputs = mwj
499 .inputs
500 .into_iter()
501 .map(|input| self.push_projections_recursive(input, required))
502 .collect();
503 LogicalOperator::MultiWayJoin(mwj)
504 }
505 other => other,
506 }
507 }
508
509 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
516 let op = self.reorder_joins_recursive(op);
518
519 if let Some((relations, conditions)) = self.extract_join_tree(&op)
521 && relations.len() >= 2
522 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
523 {
524 return optimized;
525 }
526
527 op
528 }
529
530 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
532 match op {
533 LogicalOperator::Return(mut ret) => {
534 ret.input = Box::new(self.reorder_joins(*ret.input));
535 LogicalOperator::Return(ret)
536 }
537 LogicalOperator::Project(mut proj) => {
538 proj.input = Box::new(self.reorder_joins(*proj.input));
539 LogicalOperator::Project(proj)
540 }
541 LogicalOperator::Filter(mut filter) => {
542 filter.input = Box::new(self.reorder_joins(*filter.input));
543 LogicalOperator::Filter(filter)
544 }
545 LogicalOperator::Limit(mut limit) => {
546 limit.input = Box::new(self.reorder_joins(*limit.input));
547 LogicalOperator::Limit(limit)
548 }
549 LogicalOperator::Skip(mut skip) => {
550 skip.input = Box::new(self.reorder_joins(*skip.input));
551 LogicalOperator::Skip(skip)
552 }
553 LogicalOperator::Sort(mut sort) => {
554 sort.input = Box::new(self.reorder_joins(*sort.input));
555 LogicalOperator::Sort(sort)
556 }
557 LogicalOperator::Distinct(mut distinct) => {
558 distinct.input = Box::new(self.reorder_joins(*distinct.input));
559 LogicalOperator::Distinct(distinct)
560 }
561 LogicalOperator::Aggregate(mut agg) => {
562 agg.input = Box::new(self.reorder_joins(*agg.input));
563 LogicalOperator::Aggregate(agg)
564 }
565 LogicalOperator::Expand(mut expand) => {
566 expand.input = Box::new(self.reorder_joins(*expand.input));
567 LogicalOperator::Expand(expand)
568 }
569 LogicalOperator::MapCollect(mut mc) => {
570 mc.input = Box::new(self.reorder_joins(*mc.input));
571 LogicalOperator::MapCollect(mc)
572 }
573 LogicalOperator::MultiWayJoin(mut mwj) => {
574 mwj.inputs = mwj
575 .inputs
576 .into_iter()
577 .map(|input| self.reorder_joins(input))
578 .collect();
579 LogicalOperator::MultiWayJoin(mwj)
580 }
581 other => other,
583 }
584 }
585
586 fn extract_join_tree(
590 &self,
591 op: &LogicalOperator,
592 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
593 let mut relations = Vec::new();
594 let mut join_conditions = Vec::new();
595
596 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
597 return None;
598 }
599
600 if relations.len() < 2 {
601 return None;
602 }
603
604 Some((relations, join_conditions))
605 }
606
607 fn collect_join_tree(
611 &self,
612 op: &LogicalOperator,
613 relations: &mut Vec<(String, LogicalOperator)>,
614 conditions: &mut Vec<JoinInfo>,
615 ) -> bool {
616 match op {
617 LogicalOperator::Join(join) => {
618 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
620 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
621
622 for cond in &join.conditions {
624 if let (Some(left_var), Some(right_var)) = (
625 self.extract_variable_from_expr(&cond.left),
626 self.extract_variable_from_expr(&cond.right),
627 ) {
628 conditions.push(JoinInfo {
629 left_var,
630 right_var,
631 left_expr: cond.left.clone(),
632 right_expr: cond.right.clone(),
633 });
634 }
635 }
636
637 left_ok && right_ok
638 }
639 LogicalOperator::NodeScan(scan) => {
640 relations.push((scan.variable.clone(), op.clone()));
641 true
642 }
643 LogicalOperator::EdgeScan(scan) => {
644 relations.push((scan.variable.clone(), op.clone()));
645 true
646 }
647 LogicalOperator::Filter(filter) => {
648 self.collect_join_tree(&filter.input, relations, conditions)
650 }
651 LogicalOperator::Expand(expand) => {
652 relations.push((expand.to_variable.clone(), op.clone()));
655 true
656 }
657 _ => false,
658 }
659 }
660
661 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
663 match expr {
664 LogicalExpression::Variable(v) => Some(v.clone()),
665 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
666 _ => None,
667 }
668 }
669
670 fn optimize_join_order(
673 &self,
674 relations: &[(String, LogicalOperator)],
675 conditions: &[JoinInfo],
676 ) -> Option<LogicalOperator> {
677 use join_order::{DPccp, JoinGraphBuilder};
678
679 let mut builder = JoinGraphBuilder::new();
681
682 for (var, relation) in relations {
683 builder.add_relation(var, relation.clone());
684 }
685
686 for cond in conditions {
687 builder.add_join_condition(
688 &cond.left_var,
689 &cond.right_var,
690 cond.left_expr.clone(),
691 cond.right_expr.clone(),
692 );
693 }
694
695 let graph = builder.build();
696
697 if graph.is_cyclic() && relations.len() >= 3 {
702 let mut var_counts: std::collections::HashMap<&str, usize> =
704 std::collections::HashMap::new();
705 for cond in conditions {
706 *var_counts.entry(&cond.left_var).or_default() += 1;
707 *var_counts.entry(&cond.right_var).or_default() += 1;
708 }
709 let shared_variables: Vec<String> = var_counts
710 .into_iter()
711 .filter(|(_, count)| *count >= 2)
712 .map(|(var, _)| var.to_string())
713 .collect();
714
715 let join_conditions: Vec<JoinCondition> = conditions
716 .iter()
717 .map(|c| JoinCondition {
718 left: c.left_expr.clone(),
719 right: c.right_expr.clone(),
720 })
721 .collect();
722
723 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
724 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
725 conditions: join_conditions,
726 shared_variables,
727 }));
728 }
729
730 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
732 let plan = dpccp.optimize()?;
733
734 Some(plan.operator)
735 }
736
737 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
742 match op {
743 LogicalOperator::Filter(filter) => {
745 let optimized_input = self.push_filters_down(*filter.input);
746 self.try_push_filter_into(filter.predicate, optimized_input)
747 }
748 LogicalOperator::Return(mut ret) => {
750 ret.input = Box::new(self.push_filters_down(*ret.input));
751 LogicalOperator::Return(ret)
752 }
753 LogicalOperator::Project(mut proj) => {
754 proj.input = Box::new(self.push_filters_down(*proj.input));
755 LogicalOperator::Project(proj)
756 }
757 LogicalOperator::Limit(mut limit) => {
758 limit.input = Box::new(self.push_filters_down(*limit.input));
759 LogicalOperator::Limit(limit)
760 }
761 LogicalOperator::Skip(mut skip) => {
762 skip.input = Box::new(self.push_filters_down(*skip.input));
763 LogicalOperator::Skip(skip)
764 }
765 LogicalOperator::Sort(mut sort) => {
766 sort.input = Box::new(self.push_filters_down(*sort.input));
767 LogicalOperator::Sort(sort)
768 }
769 LogicalOperator::Distinct(mut distinct) => {
770 distinct.input = Box::new(self.push_filters_down(*distinct.input));
771 LogicalOperator::Distinct(distinct)
772 }
773 LogicalOperator::Expand(mut expand) => {
774 expand.input = Box::new(self.push_filters_down(*expand.input));
775 LogicalOperator::Expand(expand)
776 }
777 LogicalOperator::Join(mut join) => {
778 join.left = Box::new(self.push_filters_down(*join.left));
779 join.right = Box::new(self.push_filters_down(*join.right));
780 LogicalOperator::Join(join)
781 }
782 LogicalOperator::Aggregate(mut agg) => {
783 agg.input = Box::new(self.push_filters_down(*agg.input));
784 LogicalOperator::Aggregate(agg)
785 }
786 LogicalOperator::MapCollect(mut mc) => {
787 mc.input = Box::new(self.push_filters_down(*mc.input));
788 LogicalOperator::MapCollect(mc)
789 }
790 LogicalOperator::MultiWayJoin(mut mwj) => {
791 mwj.inputs = mwj
792 .inputs
793 .into_iter()
794 .map(|input| self.push_filters_down(input))
795 .collect();
796 LogicalOperator::MultiWayJoin(mwj)
797 }
798 other => other,
800 }
801 }
802
803 fn try_push_filter_into(
808 &self,
809 predicate: LogicalExpression,
810 op: LogicalOperator,
811 ) -> LogicalOperator {
812 match op {
813 LogicalOperator::Project(mut proj) => {
815 let predicate_vars = self.extract_variables(&predicate);
816 let computed_vars = self.extract_projection_aliases(&proj.projections);
817
818 if predicate_vars.is_disjoint(&computed_vars) {
820 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
821 LogicalOperator::Project(proj)
822 } else {
823 LogicalOperator::Filter(FilterOp {
825 predicate,
826 pushdown_hint: None,
827 input: Box::new(LogicalOperator::Project(proj)),
828 })
829 }
830 }
831
832 LogicalOperator::Return(mut ret) => {
834 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
835 LogicalOperator::Return(ret)
836 }
837
838 LogicalOperator::Expand(mut expand) => {
840 let predicate_vars = self.extract_variables(&predicate);
841
842 let mut introduced_vars = vec![&expand.to_variable];
847 if let Some(ref edge_var) = expand.edge_variable {
848 introduced_vars.push(edge_var);
849 }
850 if let Some(ref path_alias) = expand.path_alias {
851 introduced_vars.push(path_alias);
852 }
853
854 let uses_introduced_vars =
856 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
857
858 if !uses_introduced_vars {
859 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
861 LogicalOperator::Expand(expand)
862 } else {
863 LogicalOperator::Filter(FilterOp {
865 predicate,
866 pushdown_hint: None,
867 input: Box::new(LogicalOperator::Expand(expand)),
868 })
869 }
870 }
871
872 LogicalOperator::Join(mut join) => {
874 let predicate_vars = self.extract_variables(&predicate);
875 let left_vars = self.collect_output_variables(&join.left);
876 let right_vars = self.collect_output_variables(&join.right);
877
878 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
879 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
880
881 if uses_left && !uses_right {
882 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
884 LogicalOperator::Join(join)
885 } else if uses_right && !uses_left {
886 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
888 LogicalOperator::Join(join)
889 } else {
890 LogicalOperator::Filter(FilterOp {
892 predicate,
893 pushdown_hint: None,
894 input: Box::new(LogicalOperator::Join(join)),
895 })
896 }
897 }
898
899 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
901 predicate,
902 pushdown_hint: None,
903 input: Box::new(LogicalOperator::Aggregate(agg)),
904 }),
905
906 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
908 predicate,
909 pushdown_hint: None,
910 input: Box::new(LogicalOperator::NodeScan(scan)),
911 }),
912
913 other => LogicalOperator::Filter(FilterOp {
915 predicate,
916 pushdown_hint: None,
917 input: Box::new(other),
918 }),
919 }
920 }
921
922 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
924 let mut vars = HashSet::new();
925 Self::collect_output_variables_recursive(op, &mut vars);
926 vars
927 }
928
929 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
931 match op {
932 LogicalOperator::NodeScan(scan) => {
933 vars.insert(scan.variable.clone());
934 }
935 LogicalOperator::EdgeScan(scan) => {
936 vars.insert(scan.variable.clone());
937 }
938 LogicalOperator::Expand(expand) => {
939 vars.insert(expand.to_variable.clone());
940 if let Some(edge_var) = &expand.edge_variable {
941 vars.insert(edge_var.clone());
942 }
943 Self::collect_output_variables_recursive(&expand.input, vars);
944 }
945 LogicalOperator::Filter(filter) => {
946 Self::collect_output_variables_recursive(&filter.input, vars);
947 }
948 LogicalOperator::Project(proj) => {
949 for p in &proj.projections {
950 if let Some(alias) = &p.alias {
951 vars.insert(alias.clone());
952 }
953 }
954 Self::collect_output_variables_recursive(&proj.input, vars);
955 }
956 LogicalOperator::Join(join) => {
957 Self::collect_output_variables_recursive(&join.left, vars);
958 Self::collect_output_variables_recursive(&join.right, vars);
959 }
960 LogicalOperator::Aggregate(agg) => {
961 for expr in &agg.group_by {
962 Self::collect_variables(expr, vars);
963 }
964 for agg_expr in &agg.aggregates {
965 if let Some(alias) = &agg_expr.alias {
966 vars.insert(alias.clone());
967 }
968 }
969 }
970 LogicalOperator::Return(ret) => {
971 Self::collect_output_variables_recursive(&ret.input, vars);
972 }
973 LogicalOperator::Limit(limit) => {
974 Self::collect_output_variables_recursive(&limit.input, vars);
975 }
976 LogicalOperator::Skip(skip) => {
977 Self::collect_output_variables_recursive(&skip.input, vars);
978 }
979 LogicalOperator::Sort(sort) => {
980 Self::collect_output_variables_recursive(&sort.input, vars);
981 }
982 LogicalOperator::Distinct(distinct) => {
983 Self::collect_output_variables_recursive(&distinct.input, vars);
984 }
985 _ => {}
986 }
987 }
988
989 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
991 let mut vars = HashSet::new();
992 Self::collect_variables(expr, &mut vars);
993 vars
994 }
995
996 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
998 match expr {
999 LogicalExpression::Variable(name) => {
1000 vars.insert(name.clone());
1001 }
1002 LogicalExpression::Property { variable, .. } => {
1003 vars.insert(variable.clone());
1004 }
1005 LogicalExpression::Binary { left, right, .. } => {
1006 Self::collect_variables(left, vars);
1007 Self::collect_variables(right, vars);
1008 }
1009 LogicalExpression::Unary { operand, .. } => {
1010 Self::collect_variables(operand, vars);
1011 }
1012 LogicalExpression::FunctionCall { args, .. } => {
1013 for arg in args {
1014 Self::collect_variables(arg, vars);
1015 }
1016 }
1017 LogicalExpression::List(items) => {
1018 for item in items {
1019 Self::collect_variables(item, vars);
1020 }
1021 }
1022 LogicalExpression::Map(pairs) => {
1023 for (_, value) in pairs {
1024 Self::collect_variables(value, vars);
1025 }
1026 }
1027 LogicalExpression::IndexAccess { base, index } => {
1028 Self::collect_variables(base, vars);
1029 Self::collect_variables(index, vars);
1030 }
1031 LogicalExpression::SliceAccess { base, start, end } => {
1032 Self::collect_variables(base, vars);
1033 if let Some(s) = start {
1034 Self::collect_variables(s, vars);
1035 }
1036 if let Some(e) = end {
1037 Self::collect_variables(e, vars);
1038 }
1039 }
1040 LogicalExpression::Case {
1041 operand,
1042 when_clauses,
1043 else_clause,
1044 } => {
1045 if let Some(op) = operand {
1046 Self::collect_variables(op, vars);
1047 }
1048 for (cond, result) in when_clauses {
1049 Self::collect_variables(cond, vars);
1050 Self::collect_variables(result, vars);
1051 }
1052 if let Some(else_expr) = else_clause {
1053 Self::collect_variables(else_expr, vars);
1054 }
1055 }
1056 LogicalExpression::Labels(var)
1057 | LogicalExpression::Type(var)
1058 | LogicalExpression::Id(var) => {
1059 vars.insert(var.clone());
1060 }
1061 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1062 LogicalExpression::ListComprehension {
1063 list_expr,
1064 filter_expr,
1065 map_expr,
1066 ..
1067 } => {
1068 Self::collect_variables(list_expr, vars);
1069 if let Some(filter) = filter_expr {
1070 Self::collect_variables(filter, vars);
1071 }
1072 Self::collect_variables(map_expr, vars);
1073 }
1074 LogicalExpression::ListPredicate {
1075 list_expr,
1076 predicate,
1077 ..
1078 } => {
1079 Self::collect_variables(list_expr, vars);
1080 Self::collect_variables(predicate, vars);
1081 }
1082 LogicalExpression::ExistsSubquery(_)
1083 | LogicalExpression::CountSubquery(_)
1084 | LogicalExpression::ValueSubquery(_) => {
1085 }
1087 LogicalExpression::PatternComprehension { projection, .. } => {
1088 Self::collect_variables(projection, vars);
1089 }
1090 LogicalExpression::MapProjection { base, entries } => {
1091 vars.insert(base.clone());
1092 for entry in entries {
1093 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1094 Self::collect_variables(expr, vars);
1095 }
1096 }
1097 }
1098 LogicalExpression::Reduce {
1099 initial,
1100 list,
1101 expression,
1102 ..
1103 } => {
1104 Self::collect_variables(initial, vars);
1105 Self::collect_variables(list, vars);
1106 Self::collect_variables(expression, vars);
1107 }
1108 }
1109 }
1110
1111 fn extract_projection_aliases(
1113 &self,
1114 projections: &[crate::query::plan::Projection],
1115 ) -> HashSet<String> {
1116 projections.iter().filter_map(|p| p.alias.clone()).collect()
1117 }
1118}
1119
/// The default optimizer is the one produced by [`Optimizer::new`].
impl Default for Optimizer {
    fn default() -> Self {
        Self::new()
    }
}
1125
1126#[cfg(test)]
1127mod tests {
1128 use super::*;
1129 use crate::query::plan::{
1130 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1131 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1132 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1133 };
1134 use grafeo_common::types::Value;
1135
1136 #[test]
1137 fn test_optimizer_filter_pushdown_simple() {
1138 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1143 items: vec![ReturnItem {
1144 expression: LogicalExpression::Variable("n".to_string()),
1145 alias: None,
1146 }],
1147 distinct: false,
1148 input: Box::new(LogicalOperator::Filter(FilterOp {
1149 predicate: LogicalExpression::Binary {
1150 left: Box::new(LogicalExpression::Property {
1151 variable: "n".to_string(),
1152 property: "age".to_string(),
1153 }),
1154 op: BinaryOp::Gt,
1155 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1156 },
1157 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1158 variable: "n".to_string(),
1159 label: Some("Person".to_string()),
1160 input: None,
1161 })),
1162 pushdown_hint: None,
1163 })),
1164 }));
1165
1166 let optimizer = Optimizer::new();
1167 let optimized = optimizer.optimize(plan).unwrap();
1168
1169 if let LogicalOperator::Return(ret) = &optimized.root
1171 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1172 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1173 {
1174 assert_eq!(scan.variable, "n");
1175 return;
1176 }
1177 panic!("Expected Return -> Filter -> NodeScan structure");
1178 }
1179
#[test]
/// Builds `Return -> Filter(a.age > 30) -> Expand(a -> b) -> NodeScan(a)` and
/// checks that the optimized plan is `Return -> Expand -> Filter -> NodeScan`,
/// i.e. the filter on the expand's source variable ends up below the expand.
fn test_optimizer_filter_pushdown_through_expand() {
    const EXPECTED_SHAPE: &str = "Expected Return -> Expand -> Filter -> NodeScan structure";

    let source_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows_expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(source_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    // The predicate only mentions `a`, the variable bound by the scan.
    let source_age_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "a".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(knows_expand),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(source_age_filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{}", EXPECTED_SHAPE);
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("{}", EXPECTED_SHAPE);
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("{}", EXPECTED_SHAPE);
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{}", EXPECTED_SHAPE);
    };

    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1237
#[test]
/// Builds `Return -> Filter(b.age > 30) -> Expand(a -> b) -> NodeScan(a)`, where
/// the predicate references the expand's *target* variable `b`.
///
/// The optimized plan must keep the `Return -> Filter -> Expand -> NodeScan`
/// shape, and the retained filter must still compare a property of `b`.
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
    const EXPECTED_SHAPE: &str = "Expected Return -> Filter -> Expand -> NodeScan structure";

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(LogicalOperator::Filter(FilterOp {
            predicate: LogicalExpression::Binary {
                left: Box::new(LogicalExpression::Property {
                    variable: "b".to_string(),
                    property: "age".to_string(),
                }),
                op: BinaryOp::Gt,
                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
            },
            pushdown_hint: None,
            input: Box::new(LogicalOperator::Expand(ExpandOp {
                from_variable: "a".to_string(),
                to_variable: "b".to_string(),
                edge_variable: None,
                direction: ExpandDirection::Outgoing,
                edge_types: vec!["KNOWS".to_string()],
                min_hops: 1,
                max_hops: Some(1),
                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                    variable: "a".to_string(),
                    label: Some("Person".to_string()),
                    input: None,
                })),
                path_alias: None,
                path_mode: PathMode::Walk,
            })),
        })),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{}", EXPECTED_SHAPE);
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{}", EXPECTED_SHAPE);
    };

    // The predicate check is mandatory: previously it lived inside an `if let`
    // and was silently skipped whenever the predicate had a different shape.
    let LogicalExpression::Binary { left, .. } = &filter.predicate else {
        panic!("Expected the retained filter predicate to be a binary comparison");
    };
    let LogicalExpression::Property { variable, .. } = left.as_ref() else {
        panic!("Expected the left side of the retained predicate to be a property access");
    };
    assert_eq!(variable, "b");

    let LogicalOperator::Expand(expand) = filter.input.as_ref() else {
        panic!("{}", EXPECTED_SHAPE);
    };
    assert!(
        matches!(expand.input.as_ref(), LogicalOperator::NodeScan(_)),
        "{}",
        EXPECTED_SHAPE
    );
}
1301
#[test]
/// `extract_variables` on `n.age > 30` must report exactly one variable, `n`.
fn test_optimizer_extract_variables() {
    let age_property = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let threshold = LogicalExpression::Literal(Value::Int64(30));
    let comparison = LogicalExpression::Binary {
        left: Box::new(age_property),
        op: BinaryOp::Gt,
        right: Box::new(threshold),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&comparison);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1319
#[test]
/// An optimizer obtained through `Default` must optimize an empty plan without error.
fn test_optimizer_default() {
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let optimizer = Optimizer::default();

    let outcome = optimizer.optimize(empty_plan);

    assert!(outcome.is_ok());
}
1330
#[test]
/// With filter pushdown switched off, `Return -> Filter -> NodeScan` must still
/// have a `Filter` directly beneath the `Return` after optimization.
fn test_optimizer_with_filter_pushdown_disabled() {
    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(unlabeled_scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_still_under_return = match &optimized.root {
        LogicalOperator::Return(ret) => {
            matches!(ret.input.as_ref(), LogicalOperator::Filter(_))
        }
        _ => false,
    };
    assert!(filter_still_under_return, "Expected unchanged structure");
}
1361
#[test]
/// An optimizer built with join reordering disabled must still optimize an
/// empty plan successfully.
fn test_optimizer_with_join_reorder_disabled() {
    let optimizer = Optimizer::new().with_join_reorder(false);
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);

    let outcome = optimizer.optimize(empty_plan);

    assert!(outcome.is_ok());
}
1371
#[test]
/// The cost model installed via `with_cost_model` is reachable through
/// `cost_model()` and estimates a total below `0.001` for an `Empty` operator
/// at cardinality `0.0`.
fn test_optimizer_with_cost_model() {
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());

    let empty_total = optimizer
        .cost_model()
        .estimate(&LogicalOperator::Empty, 0.0)
        .total();

    assert!(empty_total < 0.001);
}
1384
#[test]
/// With table stats of 500 registered for label `Test`, the estimated
/// cardinality of a `NodeScan` over that label must be 500 (within `0.001`).
fn test_optimizer_with_cardinality_estimator() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(estimator);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    let deviation = (optimizer.estimate_cardinality(&plan) - 500.0).abs();
    assert!(deviation < 0.001);
}
1401
#[test]
/// The estimated cost of an unlabeled node scan must have a strictly positive total.
fn test_optimizer_estimate_cost() {
    let scan_all = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(scan_all);

    let optimizer = Optimizer::new();
    let total = optimizer.estimate_cost(&plan).total();

    assert!(total > 0.0);
}
1414
#[test]
/// Builds `Filter(n.age > 30) -> Project(n) -> NodeScan(n)` and checks that the
/// optimized root is a `Project` whose input is a `Filter`.
fn test_filter_pushdown_through_project() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    // The projection passes `n` through without an alias.
    let projection = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
        pass_through_input: false,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(projection),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_below_project = match &optimized.root {
        LogicalOperator::Project(proj) => {
            matches!(proj.input.as_ref(), LogicalOperator::Filter(_))
        }
        _ => false,
    };
    assert!(filter_below_project, "Expected Project -> Filter structure");
}
1455
#[test]
/// Builds `Filter(x > 30) -> Project(n.age AS x) -> NodeScan(n)`. The predicate
/// references the alias `x` introduced by the projection, and the test requires
/// the optimized plan to keep the `Filter -> Project` order.
fn test_filter_not_pushed_through_project_with_alias() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let aliased_projection = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
        pass_through_input: false,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Variable("x".to_string())),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(aliased_projection),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_project = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Project(_))
        }
        _ => false,
    };
    assert!(filter_above_project, "Expected Filter -> Project structure");
}
1495
#[test]
/// Builds `Filter(true) -> Limit(10) -> NodeScan(n)`. Despite the test name, the
/// assertion requires the filter to stay *above* the limit after optimization.
fn test_filter_pushdown_through_limit() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limited = LogicalOperator::Limit(LimitOp {
        count: 10.into(),
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(limited),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_limit = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Limit(_))
        }
        _ => false,
    };
    assert!(filter_above_limit, "Expected Filter -> Limit structure");
}
1523
#[test]
/// Builds `Filter(true) -> Sort(n ASC) -> NodeScan(n)`. Despite the test name,
/// the assertion requires the filter to stay *above* the sort after optimization.
fn test_filter_pushdown_through_sort() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let ascending_by_n = SortKey {
        expression: LogicalExpression::Variable("n".to_string()),
        order: SortOrder::Ascending,
        nulls: None,
    };
    let sorted = LogicalOperator::Sort(SortOp {
        keys: vec![ascending_by_n],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(sorted),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_sort = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Sort(_))
        }
        _ => false,
    };
    assert!(filter_above_sort, "Expected Filter -> Sort structure");
}
1555
#[test]
/// Builds `Filter(true) -> Distinct -> NodeScan(n)`. Despite the test name, the
/// assertion requires the filter to stay *above* the distinct after optimization.
fn test_filter_pushdown_through_distinct() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let deduplicated = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(deduplicated),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_distinct = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Distinct(_))
        }
        _ => false,
    };
    assert!(filter_above_distinct, "Expected Filter -> Distinct structure");
}
1583
#[test]
/// Builds `Filter(cnt > 10) -> Aggregate(count AS cnt) -> NodeScan(n)`. The
/// predicate references the aggregate's alias, and the test requires the
/// optimized plan to keep the `Filter -> Aggregate` order.
fn test_filter_not_pushed_through_aggregate() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_as_cnt = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        expression2: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
        separator: None,
    };
    let aggregated = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_as_cnt],
        input: Box::new(scan),
        having: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Variable("cnt".to_string())),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        pushdown_hint: None,
        input: Box::new(aggregated),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_aggregate = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_))
        }
        _ => false,
    };
    assert!(
        filter_above_aggregate,
        "Expected Filter -> Aggregate structure"
    );
}
1625
#[test]
/// Builds `Filter(a.age > 30)` over an inner join of `NodeScan(a:Person)` and
/// `NodeScan(b:Company)`. The predicate only mentions `a`, and the test
/// requires the optimized root to be a `Join` whose left input is a `Filter`.
fn test_filter_pushdown_to_left_join_side() {
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let joined = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "a".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        pushdown_hint: None,
        input: Box::new(joined),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_on_left = match &optimized.root {
        LogicalOperator::Join(join) => {
            matches!(join.left.as_ref(), LogicalOperator::Filter(_))
        }
        _ => false,
    };
    assert!(filter_on_left, "Expected Join with Filter on left side");
}
1667
#[test]
/// Builds `Filter(b.name = "Acme")` over an inner join of `NodeScan(a:Person)`
/// and `NodeScan(b:Company)`. The predicate only mentions `b`, and the test
/// requires the optimized root to be a `Join` whose right input is a `Filter`.
fn test_filter_pushdown_to_right_join_side() {
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let joined = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "b".to_string(),
                property: "name".to_string(),
            }),
            op: BinaryOp::Eq,
            right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
        },
        pushdown_hint: None,
        input: Box::new(joined),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_on_right = match &optimized.root {
        LogicalOperator::Join(join) => {
            matches!(join.right.as_ref(), LogicalOperator::Filter(_))
        }
        _ => false,
    };
    assert!(filter_on_right, "Expected Join with Filter on right side");
}
1709
#[test]
/// Builds `Filter(a.id = b.a_id)` over an inner join of `NodeScan(a)` and
/// `NodeScan(b)`. The predicate mentions variables from both join inputs, and
/// the test requires the optimized plan to keep the `Filter -> Join` order.
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let joined = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let cross_side_predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side_predicate,
        pushdown_hint: None,
        input: Box::new(joined),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_join = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Join(_))
        }
        _ => false,
    };
    assert!(filter_above_join, "Expected Filter -> Join structure");
}
1754
#[test]
/// A bare `Variable("x")` expression yields exactly the variable `x`.
fn test_extract_variables_from_variable() {
    let bare_variable = LogicalExpression::Variable("x".to_string());

    let found = Optimizer::new().extract_variables(&bare_variable);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1765
#[test]
/// `NOT x` yields exactly the variable `x` from the unary operand.
fn test_extract_variables_from_unary() {
    let operand = LogicalExpression::Variable("x".to_string());
    let negated = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let found = Optimizer::new().extract_variables(&negated);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1777
#[test]
/// A function call `length(a, b)` yields both argument variables, `a` and `b`.
fn test_extract_variables_from_function_call() {
    let arguments = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args: arguments,
        distinct: false,
    };

    let found = Optimizer::new().extract_variables(&call);

    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1794
#[test]
/// A list `[a, 1, b]` yields the two variables `a` and `b`; the literal element
/// contributes nothing.
fn test_extract_variables_from_list() {
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let list = LogicalExpression::List(elements);

    let found = Optimizer::new().extract_variables(&list);

    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1808
#[test]
/// A map `{key1: a, key2: b}` yields the two value variables `a` and `b`.
fn test_extract_variables_from_map() {
    let first_entry = (
        "key1".to_string(),
        LogicalExpression::Variable("a".to_string()),
    );
    let second_entry = (
        "key2".to_string(),
        LogicalExpression::Variable("b".to_string()),
    );
    let map = LogicalExpression::Map(vec![first_entry, second_entry]);

    let found = Optimizer::new().extract_variables(&map);

    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1827
#[test]
/// An index access `list[idx]` yields both the base and the index variable.
fn test_extract_variables_from_index_access() {
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let found = Optimizer::new().extract_variables(&access);

    assert_eq!(found.len(), 2);
    for expected in ["list", "idx"] {
        assert!(found.contains(expected));
    }
}
1840
#[test]
/// A slice access `list[s..e]` yields the base plus both bound variables.
fn test_extract_variables_from_slice_access() {
    let base = LogicalExpression::Variable("list".to_string());
    let lower = LogicalExpression::Variable("s".to_string());
    let upper = LogicalExpression::Variable("e".to_string());
    let slice = LogicalExpression::SliceAccess {
        base: Box::new(base),
        start: Some(Box::new(lower)),
        end: Some(Box::new(upper)),
    };

    let found = Optimizer::new().extract_variables(&slice);

    assert_eq!(found.len(), 3);
    for expected in ["list", "s", "e"] {
        assert!(found.contains(expected));
    }
}
1855
#[test]
/// `CASE x WHEN 1 THEN a ELSE b END` yields the operand, the `THEN` result and
/// the `ELSE` result variables: `x`, `a` and `b`.
fn test_extract_variables_from_case() {
    let single_branch = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let case_expr = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![single_branch],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };

    let found = Optimizer::new().extract_variables(&case_expr);

    assert_eq!(found.len(), 3);
    for expected in ["x", "a", "b"] {
        assert!(found.contains(expected));
    }
}
1873
#[test]
/// A `Labels("n")` expression yields exactly the variable `n`.
fn test_extract_variables_from_labels() {
    let labels_of_n = LogicalExpression::Labels("n".to_string());

    let found = Optimizer::new().extract_variables(&labels_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1882
#[test]
/// A `Type("e")` expression yields exactly the variable `e`.
fn test_extract_variables_from_type() {
    let type_of_e = LogicalExpression::Type("e".to_string());

    let found = Optimizer::new().extract_variables(&type_of_e);

    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1891
#[test]
/// An `Id("n")` expression yields exactly the variable `n`.
fn test_extract_variables_from_id() {
    let id_of_n = LogicalExpression::Id("n".to_string());

    let found = Optimizer::new().extract_variables(&id_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1900
#[test]
/// For a list comprehension, the variables used in the source list, the filter
/// and the mapping expression must all be reported. The test makes no claim
/// about the comprehension's own iteration variable or the total count.
fn test_extract_variables_from_list_comprehension() {
    let source = LogicalExpression::Variable("items".to_string());
    let predicate = LogicalExpression::Variable("pred".to_string());
    let mapping = LogicalExpression::Variable("result".to_string());
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: Box::new(source),
        filter_expr: Some(Box::new(predicate)),
        map_expr: Box::new(mapping),
    };

    let found = Optimizer::new().extract_variables(&comprehension);

    for expected in ["items", "pred", "result"] {
        assert!(found.contains(expected));
    }
}
1915
#[test]
/// Neither a literal nor a query parameter references a variable, so
/// `extract_variables` must return an empty set for both.
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // The argument must be the borrow `&param`; a mis-encoded `¶m` token is
    // not valid Rust.
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1926
#[test]
/// Optimizes `Return -> Filter(true) -> Skip(5) -> NodeScan(n)` and checks only
/// that optimization succeeds and the root is still a `Return`; the position of
/// the filter relative to the skip is not asserted.
fn test_recursive_filter_pushdown_through_skip() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skipped = LogicalOperator::Skip(SkipOp {
        count: 5.into(),
        input: Box::new(scan),
    });
    let filtered = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(skipped),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filtered),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1958
#[test]
/// Optimizes `Return -> Filter(n.x > 1) -> Filter(n.y < 10) -> NodeScan(n)` and
/// checks only that optimization succeeds and the root is still a `Return`.
fn test_nested_filter_pushdown() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    // Inner filter: n.y < 10
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "y".to_string(),
            }),
            op: BinaryOp::Lt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        pushdown_hint: None,
        input: Box::new(scan),
    });
    // Outer filter: n.x > 1
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "x".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
        },
        pushdown_hint: None,
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
2002
#[test]
/// Joins three `Person` scans `a`, `b`, `c` with conditions `a = b`, `b = c` and
/// `c = a` (a triangle), registers 1000 rows for `Person`, and checks that the
/// optimized plan contains a `MultiWayJoin` reachable from the root through
/// `Return`, `Filter` and `Project` operators only.
fn test_cyclic_join_produces_multi_way_join() {
    use crate::query::plan::JoinCondition;

    /// Follows the single-input chain of `Return`/`Filter`/`Project` operators
    /// and reports whether it ends at a `MultiWayJoin`.
    fn reaches_multi_way_join(op: &LogicalOperator) -> bool {
        let mut current = op;
        loop {
            current = match current {
                LogicalOperator::MultiWayJoin(_) => return true,
                LogicalOperator::Return(ret) => ret.input.as_ref(),
                LogicalOperator::Filter(f) => f.input.as_ref(),
                LogicalOperator::Project(p) => p.input.as_ref(),
                _ => return false,
            };
        }
    }

    let person_scan = |variable: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some("Person".to_string()),
            input: None,
        })
    };
    let equates = |lhs: &str, rhs: &str| JoinCondition {
        left: LogicalExpression::Variable(lhs.to_string()),
        right: LogicalExpression::Variable(rhs.to_string()),
    };

    // (a JOIN b ON a = b) JOIN c ON b = c AND c = a closes the cycle.
    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan("a")),
        right: Box::new(person_scan("b")),
        join_type: JoinType::Inner,
        conditions: vec![equates("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(person_scan("c")),
        join_type: JoinType::Inner,
        conditions: vec![equates("b", "c"), equates("c", "a")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));

    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));

    let optimized = optimizer.optimize(plan).unwrap();

    let found = reaches_multi_way_join(&optimized.root);
    assert!(found, "Expected MultiWayJoin for cyclic triangle pattern");
}
2083
#[test]
/// Joins `a:Person`, `b:Person` and `c:Company` with conditions `a = b` and
/// `b = c` (a chain, no cycle), registers 1000 rows for `Person` and 100 for
/// `Company`, and checks that no `MultiWayJoin` appears anywhere under
/// `Return`, `Filter`, `Project` or `Join` operators in the optimized plan.
fn test_acyclic_join_uses_binary_joins() {
    use crate::query::plan::JoinCondition;

    /// Searches through `Return`/`Filter`/`Project` inputs and both sides of
    /// `Join` operators for a `MultiWayJoin`.
    fn contains_multi_way_join(op: &LogicalOperator) -> bool {
        match op {
            LogicalOperator::MultiWayJoin(_) => true,
            LogicalOperator::Join(j) => {
                // Left side is inspected first; the right side only if needed.
                if contains_multi_way_join(&j.left) {
                    true
                } else {
                    contains_multi_way_join(&j.right)
                }
            }
            LogicalOperator::Return(ReturnOp { input, .. })
            | LogicalOperator::Filter(FilterOp { input, .. })
            | LogicalOperator::Project(ProjectOp { input, .. }) => contains_multi_way_join(input),
            _ => false,
        }
    }

    let labeled_scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };
    let equates = |lhs: &str, rhs: &str| JoinCondition {
        left: LogicalExpression::Variable(lhs.to_string()),
        right: LogicalExpression::Variable(rhs.to_string()),
    };

    // (a JOIN b ON a = b) JOIN c ON b = c forms a chain.
    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(labeled_scan("a", "Person")),
        right: Box::new(labeled_scan("b", "Person")),
        join_type: JoinType::Inner,
        conditions: vec![equates("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(labeled_scan("c", "Company")),
        join_type: JoinType::Inner,
        conditions: vec![equates("b", "c")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));

    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));
    optimizer
        .card_estimator
        .add_table_stats("Company", cardinality::TableStats::new(100));

    let optimized = optimizer.optimize(plan).unwrap();

    let found = contains_multi_way_join(&optimized.root);
    assert!(!found, "Acyclic join should NOT produce MultiWayJoin");
}
2163}