1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// A join condition captured while flattening a join tree.
///
/// Keeps both the original expressions and the variable name extracted from
/// each side, so the condition can be re-attached after reordering.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression.
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// Left-hand expression of the condition, as written in the plan.
    left_expr: LogicalExpression,
    /// Right-hand expression of the condition, as written in the plan.
    right_expr: LogicalExpression,
}
39
/// A column that some operator or expression in the plan reads.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable binding, identified by its name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
48
/// Rule- and cost-based rewriter for logical query plans.
///
/// Each rewrite pass can be toggled independently; see [`Optimizer::optimize`]
/// for the order in which the enabled passes run.
pub struct Optimizer {
    /// Whether `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for plan costing and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for plan costing and join ordering.
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67 #[must_use]
69 pub fn new() -> Self {
70 Self {
71 enable_filter_pushdown: true,
72 enable_join_reorder: true,
73 enable_projection_pushdown: true,
74 cost_model: CostModel::new(),
75 card_estimator: CardinalityEstimator::new(),
76 }
77 }
78
79 #[must_use]
85 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
86 store.ensure_statistics_fresh();
87 let stats = store.statistics();
88 Self::from_statistics(&stats)
89 }
90
91 #[must_use]
98 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
99 let stats = store.statistics();
100 Self::from_statistics(&stats)
101 }
102
103 #[cfg(feature = "rdf")]
108 #[must_use]
109 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
110 let total = rdf_stats.total_triples;
111 let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
112 Self {
113 enable_filter_pushdown: true,
114 enable_join_reorder: true,
115 enable_projection_pushdown: true,
116 cost_model: CostModel::new().with_graph_totals(total, total),
117 card_estimator: estimator,
118 }
119 }
120
121 #[must_use]
126 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
127 let estimator = CardinalityEstimator::from_statistics(stats);
128
129 let avg_fanout = if stats.total_nodes > 0 {
130 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
131 } else {
132 10.0
133 };
134
135 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
136 .edge_types
137 .iter()
138 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
139 .collect();
140
141 let label_cardinalities: std::collections::HashMap<String, u64> = stats
142 .labels
143 .iter()
144 .map(|(name, ls)| (name.clone(), ls.node_count))
145 .collect();
146
147 Self {
148 enable_filter_pushdown: true,
149 enable_join_reorder: true,
150 enable_projection_pushdown: true,
151 cost_model: CostModel::new()
152 .with_avg_fanout(avg_fanout)
153 .with_edge_type_degrees(edge_type_degrees)
154 .with_label_cardinalities(label_cardinalities)
155 .with_graph_totals(stats.total_nodes, stats.total_edges),
156 card_estimator: estimator,
157 }
158 }
159
160 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
162 self.enable_filter_pushdown = enabled;
163 self
164 }
165
166 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
168 self.enable_join_reorder = enabled;
169 self
170 }
171
172 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
174 self.enable_projection_pushdown = enabled;
175 self
176 }
177
178 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
180 self.cost_model = cost_model;
181 self
182 }
183
184 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
186 self.card_estimator = estimator;
187 self
188 }
189
190 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
192 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
193 self
194 }
195
196 pub fn cost_model(&self) -> &CostModel {
198 &self.cost_model
199 }
200
201 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
203 &self.card_estimator
204 }
205
206 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
212 self.cost_model
213 .estimate_tree(&plan.root, &self.card_estimator)
214 }
215
216 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
218 self.card_estimator.estimate(&plan.root)
219 }
220
221 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
227 let _span = grafeo_debug_span!("grafeo::query::optimize");
228 let mut root = plan.root;
229
230 if self.enable_filter_pushdown {
232 root = self.push_filters_down(root);
233 }
234
235 if self.enable_join_reorder {
236 root = self.reorder_joins(root);
237 }
238
239 if self.enable_projection_pushdown {
240 root = self.push_projections_down(root);
241 }
242
243 Ok(LogicalPlan {
244 root,
245 explain: plan.explain,
246 profile: plan.profile,
247 })
248 }
249
250 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
257 let required = self.collect_required_columns(&op);
259
260 self.push_projections_recursive(op, &required)
262 }
263
264 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
266 let mut required = HashSet::new();
267 Self::collect_required_recursive(op, &mut required);
268 required
269 }
270
271 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
273 match op {
274 LogicalOperator::Return(ret) => {
275 for item in &ret.items {
276 Self::collect_from_expression(&item.expression, required);
277 }
278 Self::collect_required_recursive(&ret.input, required);
279 }
280 LogicalOperator::Project(proj) => {
281 for p in &proj.projections {
282 Self::collect_from_expression(&p.expression, required);
283 }
284 Self::collect_required_recursive(&proj.input, required);
285 }
286 LogicalOperator::Filter(filter) => {
287 Self::collect_from_expression(&filter.predicate, required);
288 Self::collect_required_recursive(&filter.input, required);
289 }
290 LogicalOperator::Sort(sort) => {
291 for key in &sort.keys {
292 Self::collect_from_expression(&key.expression, required);
293 }
294 Self::collect_required_recursive(&sort.input, required);
295 }
296 LogicalOperator::Aggregate(agg) => {
297 for expr in &agg.group_by {
298 Self::collect_from_expression(expr, required);
299 }
300 for agg_expr in &agg.aggregates {
301 if let Some(ref expr) = agg_expr.expression {
302 Self::collect_from_expression(expr, required);
303 }
304 }
305 if let Some(ref having) = agg.having {
306 Self::collect_from_expression(having, required);
307 }
308 Self::collect_required_recursive(&agg.input, required);
309 }
310 LogicalOperator::Join(join) => {
311 for cond in &join.conditions {
312 Self::collect_from_expression(&cond.left, required);
313 Self::collect_from_expression(&cond.right, required);
314 }
315 Self::collect_required_recursive(&join.left, required);
316 Self::collect_required_recursive(&join.right, required);
317 }
318 LogicalOperator::Expand(expand) => {
319 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
321 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
322 if let Some(ref edge_var) = expand.edge_variable {
323 required.insert(RequiredColumn::Variable(edge_var.clone()));
324 }
325 Self::collect_required_recursive(&expand.input, required);
326 }
327 LogicalOperator::Limit(limit) => {
328 Self::collect_required_recursive(&limit.input, required);
329 }
330 LogicalOperator::Skip(skip) => {
331 Self::collect_required_recursive(&skip.input, required);
332 }
333 LogicalOperator::Distinct(distinct) => {
334 Self::collect_required_recursive(&distinct.input, required);
335 }
336 LogicalOperator::NodeScan(scan) => {
337 required.insert(RequiredColumn::Variable(scan.variable.clone()));
338 }
339 LogicalOperator::EdgeScan(scan) => {
340 required.insert(RequiredColumn::Variable(scan.variable.clone()));
341 }
342 LogicalOperator::MultiWayJoin(mwj) => {
343 for cond in &mwj.conditions {
344 Self::collect_from_expression(&cond.left, required);
345 Self::collect_from_expression(&cond.right, required);
346 }
347 for input in &mwj.inputs {
348 Self::collect_required_recursive(input, required);
349 }
350 }
351 _ => {}
352 }
353 }
354
355 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
357 match expr {
358 LogicalExpression::Variable(var) => {
359 required.insert(RequiredColumn::Variable(var.clone()));
360 }
361 LogicalExpression::Property { variable, property } => {
362 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
363 required.insert(RequiredColumn::Variable(variable.clone()));
364 }
365 LogicalExpression::Binary { left, right, .. } => {
366 Self::collect_from_expression(left, required);
367 Self::collect_from_expression(right, required);
368 }
369 LogicalExpression::Unary { operand, .. } => {
370 Self::collect_from_expression(operand, required);
371 }
372 LogicalExpression::FunctionCall { args, .. } => {
373 for arg in args {
374 Self::collect_from_expression(arg, required);
375 }
376 }
377 LogicalExpression::List(items) => {
378 for item in items {
379 Self::collect_from_expression(item, required);
380 }
381 }
382 LogicalExpression::Map(pairs) => {
383 for (_, value) in pairs {
384 Self::collect_from_expression(value, required);
385 }
386 }
387 LogicalExpression::IndexAccess { base, index } => {
388 Self::collect_from_expression(base, required);
389 Self::collect_from_expression(index, required);
390 }
391 LogicalExpression::SliceAccess { base, start, end } => {
392 Self::collect_from_expression(base, required);
393 if let Some(s) = start {
394 Self::collect_from_expression(s, required);
395 }
396 if let Some(e) = end {
397 Self::collect_from_expression(e, required);
398 }
399 }
400 LogicalExpression::Case {
401 operand,
402 when_clauses,
403 else_clause,
404 } => {
405 if let Some(op) = operand {
406 Self::collect_from_expression(op, required);
407 }
408 for (cond, result) in when_clauses {
409 Self::collect_from_expression(cond, required);
410 Self::collect_from_expression(result, required);
411 }
412 if let Some(else_expr) = else_clause {
413 Self::collect_from_expression(else_expr, required);
414 }
415 }
416 LogicalExpression::Labels(var)
417 | LogicalExpression::Type(var)
418 | LogicalExpression::Id(var) => {
419 required.insert(RequiredColumn::Variable(var.clone()));
420 }
421 LogicalExpression::ListComprehension {
422 list_expr,
423 filter_expr,
424 map_expr,
425 ..
426 } => {
427 Self::collect_from_expression(list_expr, required);
428 if let Some(filter) = filter_expr {
429 Self::collect_from_expression(filter, required);
430 }
431 Self::collect_from_expression(map_expr, required);
432 }
433 _ => {}
434 }
435 }
436
437 fn push_projections_recursive(
439 &self,
440 op: LogicalOperator,
441 required: &HashSet<RequiredColumn>,
442 ) -> LogicalOperator {
443 match op {
444 LogicalOperator::Return(mut ret) => {
445 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
446 LogicalOperator::Return(ret)
447 }
448 LogicalOperator::Project(mut proj) => {
449 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
450 LogicalOperator::Project(proj)
451 }
452 LogicalOperator::Filter(mut filter) => {
453 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
454 LogicalOperator::Filter(filter)
455 }
456 LogicalOperator::Sort(mut sort) => {
457 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
460 LogicalOperator::Sort(sort)
461 }
462 LogicalOperator::Aggregate(mut agg) => {
463 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
464 LogicalOperator::Aggregate(agg)
465 }
466 LogicalOperator::Join(mut join) => {
467 let left_vars = self.collect_output_variables(&join.left);
470 let right_vars = self.collect_output_variables(&join.right);
471
472 let left_required: HashSet<_> = required
474 .iter()
475 .filter(|c| match c {
476 RequiredColumn::Variable(v) => left_vars.contains(v),
477 RequiredColumn::Property(v, _) => left_vars.contains(v),
478 })
479 .cloned()
480 .collect();
481
482 let right_required: HashSet<_> = required
483 .iter()
484 .filter(|c| match c {
485 RequiredColumn::Variable(v) => right_vars.contains(v),
486 RequiredColumn::Property(v, _) => right_vars.contains(v),
487 })
488 .cloned()
489 .collect();
490
491 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
492 join.right =
493 Box::new(self.push_projections_recursive(*join.right, &right_required));
494 LogicalOperator::Join(join)
495 }
496 LogicalOperator::Expand(mut expand) => {
497 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
498 LogicalOperator::Expand(expand)
499 }
500 LogicalOperator::Limit(mut limit) => {
501 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
502 LogicalOperator::Limit(limit)
503 }
504 LogicalOperator::Skip(mut skip) => {
505 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
506 LogicalOperator::Skip(skip)
507 }
508 LogicalOperator::Distinct(mut distinct) => {
509 distinct.input =
510 Box::new(self.push_projections_recursive(*distinct.input, required));
511 LogicalOperator::Distinct(distinct)
512 }
513 LogicalOperator::MapCollect(mut mc) => {
514 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
515 LogicalOperator::MapCollect(mc)
516 }
517 LogicalOperator::MultiWayJoin(mut mwj) => {
518 mwj.inputs = mwj
519 .inputs
520 .into_iter()
521 .map(|input| self.push_projections_recursive(input, required))
522 .collect();
523 LogicalOperator::MultiWayJoin(mwj)
524 }
525 other => other,
526 }
527 }
528
529 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
536 let op = self.reorder_joins_recursive(op);
538
539 if let Some((relations, conditions)) = self.extract_join_tree(&op)
541 && relations.len() >= 2
542 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
543 {
544 return optimized;
545 }
546
547 op
548 }
549
550 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
552 match op {
553 LogicalOperator::Return(mut ret) => {
554 ret.input = Box::new(self.reorder_joins(*ret.input));
555 LogicalOperator::Return(ret)
556 }
557 LogicalOperator::Project(mut proj) => {
558 proj.input = Box::new(self.reorder_joins(*proj.input));
559 LogicalOperator::Project(proj)
560 }
561 LogicalOperator::Filter(mut filter) => {
562 filter.input = Box::new(self.reorder_joins(*filter.input));
563 LogicalOperator::Filter(filter)
564 }
565 LogicalOperator::Limit(mut limit) => {
566 limit.input = Box::new(self.reorder_joins(*limit.input));
567 LogicalOperator::Limit(limit)
568 }
569 LogicalOperator::Skip(mut skip) => {
570 skip.input = Box::new(self.reorder_joins(*skip.input));
571 LogicalOperator::Skip(skip)
572 }
573 LogicalOperator::Sort(mut sort) => {
574 sort.input = Box::new(self.reorder_joins(*sort.input));
575 LogicalOperator::Sort(sort)
576 }
577 LogicalOperator::Distinct(mut distinct) => {
578 distinct.input = Box::new(self.reorder_joins(*distinct.input));
579 LogicalOperator::Distinct(distinct)
580 }
581 LogicalOperator::Aggregate(mut agg) => {
582 agg.input = Box::new(self.reorder_joins(*agg.input));
583 LogicalOperator::Aggregate(agg)
584 }
585 LogicalOperator::Expand(mut expand) => {
586 expand.input = Box::new(self.reorder_joins(*expand.input));
587 LogicalOperator::Expand(expand)
588 }
589 LogicalOperator::MapCollect(mut mc) => {
590 mc.input = Box::new(self.reorder_joins(*mc.input));
591 LogicalOperator::MapCollect(mc)
592 }
593 LogicalOperator::MultiWayJoin(mut mwj) => {
594 mwj.inputs = mwj
595 .inputs
596 .into_iter()
597 .map(|input| self.reorder_joins(input))
598 .collect();
599 LogicalOperator::MultiWayJoin(mwj)
600 }
601 other => other,
603 }
604 }
605
606 fn extract_join_tree(
610 &self,
611 op: &LogicalOperator,
612 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
613 let mut relations = Vec::new();
614 let mut join_conditions = Vec::new();
615
616 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
617 return None;
618 }
619
620 if relations.len() < 2 {
621 return None;
622 }
623
624 Some((relations, join_conditions))
625 }
626
627 fn collect_join_tree(
631 &self,
632 op: &LogicalOperator,
633 relations: &mut Vec<(String, LogicalOperator)>,
634 conditions: &mut Vec<JoinInfo>,
635 ) -> bool {
636 match op {
637 LogicalOperator::Join(join) => {
638 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
640 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
641
642 for cond in &join.conditions {
644 if let (Some(left_var), Some(right_var)) = (
645 self.extract_variable_from_expr(&cond.left),
646 self.extract_variable_from_expr(&cond.right),
647 ) {
648 conditions.push(JoinInfo {
649 left_var,
650 right_var,
651 left_expr: cond.left.clone(),
652 right_expr: cond.right.clone(),
653 });
654 }
655 }
656
657 left_ok && right_ok
658 }
659 LogicalOperator::NodeScan(scan) => {
660 relations.push((scan.variable.clone(), op.clone()));
661 true
662 }
663 LogicalOperator::EdgeScan(scan) => {
664 relations.push((scan.variable.clone(), op.clone()));
665 true
666 }
667 LogicalOperator::Filter(filter) => {
668 self.collect_join_tree(&filter.input, relations, conditions)
670 }
671 LogicalOperator::Expand(expand) => {
672 relations.push((expand.to_variable.clone(), op.clone()));
675 true
676 }
677 _ => false,
678 }
679 }
680
681 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
683 match expr {
684 LogicalExpression::Variable(v) => Some(v.clone()),
685 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
686 _ => None,
687 }
688 }
689
690 fn optimize_join_order(
693 &self,
694 relations: &[(String, LogicalOperator)],
695 conditions: &[JoinInfo],
696 ) -> Option<LogicalOperator> {
697 use join_order::{DPccp, JoinGraphBuilder};
698
699 let mut builder = JoinGraphBuilder::new();
701
702 for (var, relation) in relations {
703 builder.add_relation(var, relation.clone());
704 }
705
706 for cond in conditions {
707 builder.add_join_condition(
708 &cond.left_var,
709 &cond.right_var,
710 cond.left_expr.clone(),
711 cond.right_expr.clone(),
712 );
713 }
714
715 let graph = builder.build();
716
717 if graph.is_cyclic() && relations.len() >= 3 {
722 let mut var_counts: std::collections::HashMap<&str, usize> =
724 std::collections::HashMap::new();
725 for cond in conditions {
726 *var_counts.entry(&cond.left_var).or_default() += 1;
727 *var_counts.entry(&cond.right_var).or_default() += 1;
728 }
729 let shared_variables: Vec<String> = var_counts
730 .into_iter()
731 .filter(|(_, count)| *count >= 2)
732 .map(|(var, _)| var.to_string())
733 .collect();
734
735 let join_conditions: Vec<JoinCondition> = conditions
736 .iter()
737 .map(|c| JoinCondition {
738 left: c.left_expr.clone(),
739 right: c.right_expr.clone(),
740 })
741 .collect();
742
743 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
744 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
745 conditions: join_conditions,
746 shared_variables,
747 }));
748 }
749
750 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
752 let plan = dpccp.optimize()?;
753
754 Some(plan.operator)
755 }
756
757 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
762 match op {
763 LogicalOperator::Filter(filter) => {
765 let optimized_input = self.push_filters_down(*filter.input);
766 self.try_push_filter_into(filter.predicate, optimized_input)
767 }
768 LogicalOperator::Return(mut ret) => {
770 ret.input = Box::new(self.push_filters_down(*ret.input));
771 LogicalOperator::Return(ret)
772 }
773 LogicalOperator::Project(mut proj) => {
774 proj.input = Box::new(self.push_filters_down(*proj.input));
775 LogicalOperator::Project(proj)
776 }
777 LogicalOperator::Limit(mut limit) => {
778 limit.input = Box::new(self.push_filters_down(*limit.input));
779 LogicalOperator::Limit(limit)
780 }
781 LogicalOperator::Skip(mut skip) => {
782 skip.input = Box::new(self.push_filters_down(*skip.input));
783 LogicalOperator::Skip(skip)
784 }
785 LogicalOperator::Sort(mut sort) => {
786 sort.input = Box::new(self.push_filters_down(*sort.input));
787 LogicalOperator::Sort(sort)
788 }
789 LogicalOperator::Distinct(mut distinct) => {
790 distinct.input = Box::new(self.push_filters_down(*distinct.input));
791 LogicalOperator::Distinct(distinct)
792 }
793 LogicalOperator::Expand(mut expand) => {
794 expand.input = Box::new(self.push_filters_down(*expand.input));
795 LogicalOperator::Expand(expand)
796 }
797 LogicalOperator::Join(mut join) => {
798 join.left = Box::new(self.push_filters_down(*join.left));
799 join.right = Box::new(self.push_filters_down(*join.right));
800 LogicalOperator::Join(join)
801 }
802 LogicalOperator::Aggregate(mut agg) => {
803 agg.input = Box::new(self.push_filters_down(*agg.input));
804 LogicalOperator::Aggregate(agg)
805 }
806 LogicalOperator::MapCollect(mut mc) => {
807 mc.input = Box::new(self.push_filters_down(*mc.input));
808 LogicalOperator::MapCollect(mc)
809 }
810 LogicalOperator::MultiWayJoin(mut mwj) => {
811 mwj.inputs = mwj
812 .inputs
813 .into_iter()
814 .map(|input| self.push_filters_down(input))
815 .collect();
816 LogicalOperator::MultiWayJoin(mwj)
817 }
818 other => other,
820 }
821 }
822
823 fn try_push_filter_into(
828 &self,
829 predicate: LogicalExpression,
830 op: LogicalOperator,
831 ) -> LogicalOperator {
832 match op {
833 LogicalOperator::Project(mut proj) => {
835 let predicate_vars = self.extract_variables(&predicate);
836 let computed_vars = self.extract_projection_aliases(&proj.projections);
837
838 if predicate_vars.is_disjoint(&computed_vars) {
840 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
841 LogicalOperator::Project(proj)
842 } else {
843 LogicalOperator::Filter(FilterOp {
845 predicate,
846 pushdown_hint: None,
847 input: Box::new(LogicalOperator::Project(proj)),
848 })
849 }
850 }
851
852 LogicalOperator::Return(mut ret) => {
854 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
855 LogicalOperator::Return(ret)
856 }
857
858 LogicalOperator::Expand(mut expand) => {
860 let predicate_vars = self.extract_variables(&predicate);
861
862 let mut introduced_vars = vec![&expand.to_variable];
867 if let Some(ref edge_var) = expand.edge_variable {
868 introduced_vars.push(edge_var);
869 }
870 if let Some(ref path_alias) = expand.path_alias {
871 introduced_vars.push(path_alias);
872 }
873
874 let uses_introduced_vars =
876 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
877
878 if !uses_introduced_vars {
879 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
881 LogicalOperator::Expand(expand)
882 } else {
883 LogicalOperator::Filter(FilterOp {
885 predicate,
886 pushdown_hint: None,
887 input: Box::new(LogicalOperator::Expand(expand)),
888 })
889 }
890 }
891
892 LogicalOperator::Join(mut join) => {
894 let predicate_vars = self.extract_variables(&predicate);
895 let left_vars = self.collect_output_variables(&join.left);
896 let right_vars = self.collect_output_variables(&join.right);
897
898 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
899 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
900
901 if uses_left && !uses_right {
902 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
904 LogicalOperator::Join(join)
905 } else if uses_right && !uses_left {
906 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
908 LogicalOperator::Join(join)
909 } else {
910 LogicalOperator::Filter(FilterOp {
912 predicate,
913 pushdown_hint: None,
914 input: Box::new(LogicalOperator::Join(join)),
915 })
916 }
917 }
918
919 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
921 predicate,
922 pushdown_hint: None,
923 input: Box::new(LogicalOperator::Aggregate(agg)),
924 }),
925
926 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
928 predicate,
929 pushdown_hint: None,
930 input: Box::new(LogicalOperator::NodeScan(scan)),
931 }),
932
933 other => LogicalOperator::Filter(FilterOp {
935 predicate,
936 pushdown_hint: None,
937 input: Box::new(other),
938 }),
939 }
940 }
941
942 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
944 let mut vars = HashSet::new();
945 Self::collect_output_variables_recursive(op, &mut vars);
946 vars
947 }
948
949 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
951 match op {
952 LogicalOperator::NodeScan(scan) => {
953 vars.insert(scan.variable.clone());
954 }
955 LogicalOperator::EdgeScan(scan) => {
956 vars.insert(scan.variable.clone());
957 }
958 LogicalOperator::Expand(expand) => {
959 vars.insert(expand.to_variable.clone());
960 if let Some(edge_var) = &expand.edge_variable {
961 vars.insert(edge_var.clone());
962 }
963 Self::collect_output_variables_recursive(&expand.input, vars);
964 }
965 LogicalOperator::Filter(filter) => {
966 Self::collect_output_variables_recursive(&filter.input, vars);
967 }
968 LogicalOperator::Project(proj) => {
969 for p in &proj.projections {
970 if let Some(alias) = &p.alias {
971 vars.insert(alias.clone());
972 }
973 }
974 Self::collect_output_variables_recursive(&proj.input, vars);
975 }
976 LogicalOperator::Join(join) => {
977 Self::collect_output_variables_recursive(&join.left, vars);
978 Self::collect_output_variables_recursive(&join.right, vars);
979 }
980 LogicalOperator::Aggregate(agg) => {
981 for expr in &agg.group_by {
982 Self::collect_variables(expr, vars);
983 }
984 for agg_expr in &agg.aggregates {
985 if let Some(alias) = &agg_expr.alias {
986 vars.insert(alias.clone());
987 }
988 }
989 }
990 LogicalOperator::Return(ret) => {
991 Self::collect_output_variables_recursive(&ret.input, vars);
992 }
993 LogicalOperator::Limit(limit) => {
994 Self::collect_output_variables_recursive(&limit.input, vars);
995 }
996 LogicalOperator::Skip(skip) => {
997 Self::collect_output_variables_recursive(&skip.input, vars);
998 }
999 LogicalOperator::Sort(sort) => {
1000 Self::collect_output_variables_recursive(&sort.input, vars);
1001 }
1002 LogicalOperator::Distinct(distinct) => {
1003 Self::collect_output_variables_recursive(&distinct.input, vars);
1004 }
1005 _ => {}
1006 }
1007 }
1008
1009 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1011 let mut vars = HashSet::new();
1012 Self::collect_variables(expr, &mut vars);
1013 vars
1014 }
1015
1016 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1018 match expr {
1019 LogicalExpression::Variable(name) => {
1020 vars.insert(name.clone());
1021 }
1022 LogicalExpression::Property { variable, .. } => {
1023 vars.insert(variable.clone());
1024 }
1025 LogicalExpression::Binary { left, right, .. } => {
1026 Self::collect_variables(left, vars);
1027 Self::collect_variables(right, vars);
1028 }
1029 LogicalExpression::Unary { operand, .. } => {
1030 Self::collect_variables(operand, vars);
1031 }
1032 LogicalExpression::FunctionCall { args, .. } => {
1033 for arg in args {
1034 Self::collect_variables(arg, vars);
1035 }
1036 }
1037 LogicalExpression::List(items) => {
1038 for item in items {
1039 Self::collect_variables(item, vars);
1040 }
1041 }
1042 LogicalExpression::Map(pairs) => {
1043 for (_, value) in pairs {
1044 Self::collect_variables(value, vars);
1045 }
1046 }
1047 LogicalExpression::IndexAccess { base, index } => {
1048 Self::collect_variables(base, vars);
1049 Self::collect_variables(index, vars);
1050 }
1051 LogicalExpression::SliceAccess { base, start, end } => {
1052 Self::collect_variables(base, vars);
1053 if let Some(s) = start {
1054 Self::collect_variables(s, vars);
1055 }
1056 if let Some(e) = end {
1057 Self::collect_variables(e, vars);
1058 }
1059 }
1060 LogicalExpression::Case {
1061 operand,
1062 when_clauses,
1063 else_clause,
1064 } => {
1065 if let Some(op) = operand {
1066 Self::collect_variables(op, vars);
1067 }
1068 for (cond, result) in when_clauses {
1069 Self::collect_variables(cond, vars);
1070 Self::collect_variables(result, vars);
1071 }
1072 if let Some(else_expr) = else_clause {
1073 Self::collect_variables(else_expr, vars);
1074 }
1075 }
1076 LogicalExpression::Labels(var)
1077 | LogicalExpression::Type(var)
1078 | LogicalExpression::Id(var) => {
1079 vars.insert(var.clone());
1080 }
1081 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1082 LogicalExpression::ListComprehension {
1083 list_expr,
1084 filter_expr,
1085 map_expr,
1086 ..
1087 } => {
1088 Self::collect_variables(list_expr, vars);
1089 if let Some(filter) = filter_expr {
1090 Self::collect_variables(filter, vars);
1091 }
1092 Self::collect_variables(map_expr, vars);
1093 }
1094 LogicalExpression::ListPredicate {
1095 list_expr,
1096 predicate,
1097 ..
1098 } => {
1099 Self::collect_variables(list_expr, vars);
1100 Self::collect_variables(predicate, vars);
1101 }
1102 LogicalExpression::ExistsSubquery(_)
1103 | LogicalExpression::CountSubquery(_)
1104 | LogicalExpression::ValueSubquery(_) => {
1105 }
1107 LogicalExpression::PatternComprehension { projection, .. } => {
1108 Self::collect_variables(projection, vars);
1109 }
1110 LogicalExpression::MapProjection { base, entries } => {
1111 vars.insert(base.clone());
1112 for entry in entries {
1113 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1114 Self::collect_variables(expr, vars);
1115 }
1116 }
1117 }
1118 LogicalExpression::Reduce {
1119 initial,
1120 list,
1121 expression,
1122 ..
1123 } => {
1124 Self::collect_variables(initial, vars);
1125 Self::collect_variables(list, vars);
1126 Self::collect_variables(expression, vars);
1127 }
1128 }
1129 }
1130
1131 fn extract_projection_aliases(
1133 &self,
1134 projections: &[crate::query::plan::Projection],
1135 ) -> HashSet<String> {
1136 projections.iter().filter_map(|p| p.alias.clone()).collect()
1137 }
1138}
1139
1140impl Default for Optimizer {
1141 fn default() -> Self {
1142 Self::new()
1143 }
1144}
1145
1146#[cfg(test)]
1147mod tests {
1148 use super::*;
1149 use crate::query::plan::{
1150 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1151 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1152 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1153 };
1154 use grafeo_common::types::Value;
1155
    /// A filter directly above a node scan has nowhere lower to go, so the
    /// optimized plan must still be `Return -> Filter -> NodeScan`.
    #[test]
    fn test_optimizer_filter_pushdown_simple() {
        // Input plan: Return(n) <- Filter(n.age > 30) <- NodeScan(n:Person)
        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("n".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(LogicalOperator::Filter(FilterOp {
                predicate: LogicalExpression::Binary {
                    left: Box::new(LogicalExpression::Property {
                        variable: "n".to_string(),
                        property: "age".to_string(),
                    }),
                    op: BinaryOp::Gt,
                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
                },
                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                    variable: "n".to_string(),
                    label: Some("Person".to_string()),
                    input: None,
                })),
                pushdown_hint: None,
            })),
        }));

        let optimizer = Optimizer::new();
        let optimized = optimizer.optimize(plan).unwrap();

        // Any other shape falls through to the panic below.
        if let LogicalOperator::Return(ret) = &optimized.root
            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
        {
            assert_eq!(scan.variable, "n");
            return;
        }
        panic!("Expected Return -> Filter -> NodeScan structure");
    }
1199
1200 #[test]
1201 fn test_optimizer_filter_pushdown_through_expand() {
1202 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1206 items: vec![ReturnItem {
1207 expression: LogicalExpression::Variable("b".to_string()),
1208 alias: None,
1209 }],
1210 distinct: false,
1211 input: Box::new(LogicalOperator::Filter(FilterOp {
1212 predicate: LogicalExpression::Binary {
1213 left: Box::new(LogicalExpression::Property {
1214 variable: "a".to_string(),
1215 property: "age".to_string(),
1216 }),
1217 op: BinaryOp::Gt,
1218 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1219 },
1220 pushdown_hint: None,
1221 input: Box::new(LogicalOperator::Expand(ExpandOp {
1222 from_variable: "a".to_string(),
1223 to_variable: "b".to_string(),
1224 edge_variable: None,
1225 direction: ExpandDirection::Outgoing,
1226 edge_types: vec!["KNOWS".to_string()],
1227 min_hops: 1,
1228 max_hops: Some(1),
1229 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1230 variable: "a".to_string(),
1231 label: Some("Person".to_string()),
1232 input: None,
1233 })),
1234 path_alias: None,
1235 path_mode: PathMode::Walk,
1236 })),
1237 })),
1238 }));
1239
1240 let optimizer = Optimizer::new();
1241 let optimized = optimizer.optimize(plan).unwrap();
1242
1243 if let LogicalOperator::Return(ret) = &optimized.root
1246 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1247 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1248 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1249 {
1250 assert_eq!(scan.variable, "a");
1251 assert_eq!(expand.from_variable, "a");
1252 assert_eq!(expand.to_variable, "b");
1253 return;
1254 }
1255 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1256 }
1257
1258 #[test]
1259 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1260 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1264 items: vec![ReturnItem {
1265 expression: LogicalExpression::Variable("a".to_string()),
1266 alias: None,
1267 }],
1268 distinct: false,
1269 input: Box::new(LogicalOperator::Filter(FilterOp {
1270 predicate: LogicalExpression::Binary {
1271 left: Box::new(LogicalExpression::Property {
1272 variable: "b".to_string(),
1273 property: "age".to_string(),
1274 }),
1275 op: BinaryOp::Gt,
1276 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1277 },
1278 pushdown_hint: None,
1279 input: Box::new(LogicalOperator::Expand(ExpandOp {
1280 from_variable: "a".to_string(),
1281 to_variable: "b".to_string(),
1282 edge_variable: None,
1283 direction: ExpandDirection::Outgoing,
1284 edge_types: vec!["KNOWS".to_string()],
1285 min_hops: 1,
1286 max_hops: Some(1),
1287 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1288 variable: "a".to_string(),
1289 label: Some("Person".to_string()),
1290 input: None,
1291 })),
1292 path_alias: None,
1293 path_mode: PathMode::Walk,
1294 })),
1295 })),
1296 }));
1297
1298 let optimizer = Optimizer::new();
1299 let optimized = optimizer.optimize(plan).unwrap();
1300
1301 if let LogicalOperator::Return(ret) = &optimized.root
1304 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1305 {
1306 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1308 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1309 {
1310 assert_eq!(variable, "b");
1311 }
1312
1313 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1314 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1315 {
1316 return;
1317 }
1318 }
1319 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1320 }
1321
1322 #[test]
1323 fn test_optimizer_extract_variables() {
1324 let optimizer = Optimizer::new();
1325
1326 let expr = LogicalExpression::Binary {
1327 left: Box::new(LogicalExpression::Property {
1328 variable: "n".to_string(),
1329 property: "age".to_string(),
1330 }),
1331 op: BinaryOp::Gt,
1332 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1333 };
1334
1335 let vars = optimizer.extract_variables(&expr);
1336 assert_eq!(vars.len(), 1);
1337 assert!(vars.contains("n"));
1338 }
1339
1340 #[test]
1343 fn test_optimizer_default() {
1344 let optimizer = Optimizer::default();
1345 let plan = LogicalPlan::new(LogicalOperator::Empty);
1347 let result = optimizer.optimize(plan);
1348 assert!(result.is_ok());
1349 }
1350
1351 #[test]
1352 fn test_optimizer_with_filter_pushdown_disabled() {
1353 let optimizer = Optimizer::new().with_filter_pushdown(false);
1354
1355 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1356 items: vec![ReturnItem {
1357 expression: LogicalExpression::Variable("n".to_string()),
1358 alias: None,
1359 }],
1360 distinct: false,
1361 input: Box::new(LogicalOperator::Filter(FilterOp {
1362 predicate: LogicalExpression::Literal(Value::Bool(true)),
1363 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1364 variable: "n".to_string(),
1365 label: None,
1366 input: None,
1367 })),
1368 pushdown_hint: None,
1369 })),
1370 }));
1371
1372 let optimized = optimizer.optimize(plan).unwrap();
1373 if let LogicalOperator::Return(ret) = &optimized.root
1375 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1376 {
1377 return;
1378 }
1379 panic!("Expected unchanged structure");
1380 }
1381
1382 #[test]
1383 fn test_optimizer_with_join_reorder_disabled() {
1384 let optimizer = Optimizer::new().with_join_reorder(false);
1385 assert!(
1386 optimizer
1387 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1388 .is_ok()
1389 );
1390 }
1391
1392 #[test]
1393 fn test_optimizer_with_cost_model() {
1394 let cost_model = CostModel::new();
1395 let optimizer = Optimizer::new().with_cost_model(cost_model);
1396 assert!(
1397 optimizer
1398 .cost_model()
1399 .estimate(&LogicalOperator::Empty, 0.0)
1400 .total()
1401 < 0.001
1402 );
1403 }
1404
1405 #[test]
1406 fn test_optimizer_with_cardinality_estimator() {
1407 let mut estimator = CardinalityEstimator::new();
1408 estimator.add_table_stats("Test", TableStats::new(500));
1409 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1410
1411 let scan = LogicalOperator::NodeScan(NodeScanOp {
1412 variable: "n".to_string(),
1413 label: Some("Test".to_string()),
1414 input: None,
1415 });
1416 let plan = LogicalPlan::new(scan);
1417
1418 let cardinality = optimizer.estimate_cardinality(&plan);
1419 assert!((cardinality - 500.0).abs() < 0.001);
1420 }
1421
1422 #[test]
1423 fn test_optimizer_estimate_cost() {
1424 let optimizer = Optimizer::new();
1425 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1426 variable: "n".to_string(),
1427 label: None,
1428 input: None,
1429 }));
1430
1431 let cost = optimizer.estimate_cost(&plan);
1432 assert!(cost.total() > 0.0);
1433 }
1434
1435 #[test]
1438 fn test_filter_pushdown_through_project() {
1439 let optimizer = Optimizer::new();
1440
1441 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1442 predicate: LogicalExpression::Binary {
1443 left: Box::new(LogicalExpression::Property {
1444 variable: "n".to_string(),
1445 property: "age".to_string(),
1446 }),
1447 op: BinaryOp::Gt,
1448 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1449 },
1450 pushdown_hint: None,
1451 input: Box::new(LogicalOperator::Project(ProjectOp {
1452 projections: vec![Projection {
1453 expression: LogicalExpression::Variable("n".to_string()),
1454 alias: None,
1455 }],
1456 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1457 variable: "n".to_string(),
1458 label: None,
1459 input: None,
1460 })),
1461 pass_through_input: false,
1462 })),
1463 }));
1464
1465 let optimized = optimizer.optimize(plan).unwrap();
1466
1467 if let LogicalOperator::Project(proj) = &optimized.root
1469 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1470 {
1471 return;
1472 }
1473 panic!("Expected Project -> Filter structure");
1474 }
1475
1476 #[test]
1477 fn test_filter_not_pushed_through_project_with_alias() {
1478 let optimizer = Optimizer::new();
1479
1480 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1482 predicate: LogicalExpression::Binary {
1483 left: Box::new(LogicalExpression::Variable("x".to_string())),
1484 op: BinaryOp::Gt,
1485 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1486 },
1487 pushdown_hint: None,
1488 input: Box::new(LogicalOperator::Project(ProjectOp {
1489 projections: vec![Projection {
1490 expression: LogicalExpression::Property {
1491 variable: "n".to_string(),
1492 property: "age".to_string(),
1493 },
1494 alias: Some("x".to_string()),
1495 }],
1496 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1497 variable: "n".to_string(),
1498 label: None,
1499 input: None,
1500 })),
1501 pass_through_input: false,
1502 })),
1503 }));
1504
1505 let optimized = optimizer.optimize(plan).unwrap();
1506
1507 if let LogicalOperator::Filter(filter) = &optimized.root
1509 && let LogicalOperator::Project(_) = filter.input.as_ref()
1510 {
1511 return;
1512 }
1513 panic!("Expected Filter -> Project structure");
1514 }
1515
1516 #[test]
1517 fn test_filter_pushdown_through_limit() {
1518 let optimizer = Optimizer::new();
1519
1520 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1521 predicate: LogicalExpression::Literal(Value::Bool(true)),
1522 pushdown_hint: None,
1523 input: Box::new(LogicalOperator::Limit(LimitOp {
1524 count: 10.into(),
1525 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1526 variable: "n".to_string(),
1527 label: None,
1528 input: None,
1529 })),
1530 })),
1531 }));
1532
1533 let optimized = optimizer.optimize(plan).unwrap();
1534
1535 if let LogicalOperator::Filter(filter) = &optimized.root
1537 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1538 {
1539 return;
1540 }
1541 panic!("Expected Filter -> Limit structure");
1542 }
1543
1544 #[test]
1545 fn test_filter_pushdown_through_sort() {
1546 let optimizer = Optimizer::new();
1547
1548 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1549 predicate: LogicalExpression::Literal(Value::Bool(true)),
1550 pushdown_hint: None,
1551 input: Box::new(LogicalOperator::Sort(SortOp {
1552 keys: vec![SortKey {
1553 expression: LogicalExpression::Variable("n".to_string()),
1554 order: SortOrder::Ascending,
1555 nulls: None,
1556 }],
1557 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1558 variable: "n".to_string(),
1559 label: None,
1560 input: None,
1561 })),
1562 })),
1563 }));
1564
1565 let optimized = optimizer.optimize(plan).unwrap();
1566
1567 if let LogicalOperator::Filter(filter) = &optimized.root
1569 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1570 {
1571 return;
1572 }
1573 panic!("Expected Filter -> Sort structure");
1574 }
1575
1576 #[test]
1577 fn test_filter_pushdown_through_distinct() {
1578 let optimizer = Optimizer::new();
1579
1580 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1581 predicate: LogicalExpression::Literal(Value::Bool(true)),
1582 pushdown_hint: None,
1583 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1584 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1585 variable: "n".to_string(),
1586 label: None,
1587 input: None,
1588 })),
1589 columns: None,
1590 })),
1591 }));
1592
1593 let optimized = optimizer.optimize(plan).unwrap();
1594
1595 if let LogicalOperator::Filter(filter) = &optimized.root
1597 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1598 {
1599 return;
1600 }
1601 panic!("Expected Filter -> Distinct structure");
1602 }
1603
1604 #[test]
1605 fn test_filter_not_pushed_through_aggregate() {
1606 let optimizer = Optimizer::new();
1607
1608 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1609 predicate: LogicalExpression::Binary {
1610 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1611 op: BinaryOp::Gt,
1612 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1613 },
1614 pushdown_hint: None,
1615 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1616 group_by: vec![],
1617 aggregates: vec![AggregateExpr {
1618 function: AggregateFunction::Count,
1619 expression: None,
1620 expression2: None,
1621 distinct: false,
1622 alias: Some("cnt".to_string()),
1623 percentile: None,
1624 separator: None,
1625 }],
1626 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1627 variable: "n".to_string(),
1628 label: None,
1629 input: None,
1630 })),
1631 having: None,
1632 })),
1633 }));
1634
1635 let optimized = optimizer.optimize(plan).unwrap();
1636
1637 if let LogicalOperator::Filter(filter) = &optimized.root
1639 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1640 {
1641 return;
1642 }
1643 panic!("Expected Filter -> Aggregate structure");
1644 }
1645
1646 #[test]
1647 fn test_filter_pushdown_to_left_join_side() {
1648 let optimizer = Optimizer::new();
1649
1650 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1652 predicate: LogicalExpression::Binary {
1653 left: Box::new(LogicalExpression::Property {
1654 variable: "a".to_string(),
1655 property: "age".to_string(),
1656 }),
1657 op: BinaryOp::Gt,
1658 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1659 },
1660 pushdown_hint: None,
1661 input: Box::new(LogicalOperator::Join(JoinOp {
1662 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1663 variable: "a".to_string(),
1664 label: Some("Person".to_string()),
1665 input: None,
1666 })),
1667 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1668 variable: "b".to_string(),
1669 label: Some("Company".to_string()),
1670 input: None,
1671 })),
1672 join_type: JoinType::Inner,
1673 conditions: vec![],
1674 })),
1675 }));
1676
1677 let optimized = optimizer.optimize(plan).unwrap();
1678
1679 if let LogicalOperator::Join(join) = &optimized.root
1681 && let LogicalOperator::Filter(_) = join.left.as_ref()
1682 {
1683 return;
1684 }
1685 panic!("Expected Join with Filter on left side");
1686 }
1687
1688 #[test]
1689 fn test_filter_pushdown_to_right_join_side() {
1690 let optimizer = Optimizer::new();
1691
1692 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1694 predicate: LogicalExpression::Binary {
1695 left: Box::new(LogicalExpression::Property {
1696 variable: "b".to_string(),
1697 property: "name".to_string(),
1698 }),
1699 op: BinaryOp::Eq,
1700 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1701 },
1702 pushdown_hint: None,
1703 input: Box::new(LogicalOperator::Join(JoinOp {
1704 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1705 variable: "a".to_string(),
1706 label: Some("Person".to_string()),
1707 input: None,
1708 })),
1709 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1710 variable: "b".to_string(),
1711 label: Some("Company".to_string()),
1712 input: None,
1713 })),
1714 join_type: JoinType::Inner,
1715 conditions: vec![],
1716 })),
1717 }));
1718
1719 let optimized = optimizer.optimize(plan).unwrap();
1720
1721 if let LogicalOperator::Join(join) = &optimized.root
1723 && let LogicalOperator::Filter(_) = join.right.as_ref()
1724 {
1725 return;
1726 }
1727 panic!("Expected Join with Filter on right side");
1728 }
1729
1730 #[test]
1731 fn test_filter_not_pushed_when_uses_both_join_sides() {
1732 let optimizer = Optimizer::new();
1733
1734 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1736 predicate: LogicalExpression::Binary {
1737 left: Box::new(LogicalExpression::Property {
1738 variable: "a".to_string(),
1739 property: "id".to_string(),
1740 }),
1741 op: BinaryOp::Eq,
1742 right: Box::new(LogicalExpression::Property {
1743 variable: "b".to_string(),
1744 property: "a_id".to_string(),
1745 }),
1746 },
1747 pushdown_hint: None,
1748 input: Box::new(LogicalOperator::Join(JoinOp {
1749 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1750 variable: "a".to_string(),
1751 label: None,
1752 input: None,
1753 })),
1754 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1755 variable: "b".to_string(),
1756 label: None,
1757 input: None,
1758 })),
1759 join_type: JoinType::Inner,
1760 conditions: vec![],
1761 })),
1762 }));
1763
1764 let optimized = optimizer.optimize(plan).unwrap();
1765
1766 if let LogicalOperator::Filter(filter) = &optimized.root
1768 && let LogicalOperator::Join(_) = filter.input.as_ref()
1769 {
1770 return;
1771 }
1772 panic!("Expected Filter -> Join structure");
1773 }
1774
1775 #[test]
1778 fn test_extract_variables_from_variable() {
1779 let optimizer = Optimizer::new();
1780 let expr = LogicalExpression::Variable("x".to_string());
1781 let vars = optimizer.extract_variables(&expr);
1782 assert_eq!(vars.len(), 1);
1783 assert!(vars.contains("x"));
1784 }
1785
1786 #[test]
1787 fn test_extract_variables_from_unary() {
1788 let optimizer = Optimizer::new();
1789 let expr = LogicalExpression::Unary {
1790 op: UnaryOp::Not,
1791 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1792 };
1793 let vars = optimizer.extract_variables(&expr);
1794 assert_eq!(vars.len(), 1);
1795 assert!(vars.contains("x"));
1796 }
1797
1798 #[test]
1799 fn test_extract_variables_from_function_call() {
1800 let optimizer = Optimizer::new();
1801 let expr = LogicalExpression::FunctionCall {
1802 name: "length".to_string(),
1803 args: vec![
1804 LogicalExpression::Variable("a".to_string()),
1805 LogicalExpression::Variable("b".to_string()),
1806 ],
1807 distinct: false,
1808 };
1809 let vars = optimizer.extract_variables(&expr);
1810 assert_eq!(vars.len(), 2);
1811 assert!(vars.contains("a"));
1812 assert!(vars.contains("b"));
1813 }
1814
1815 #[test]
1816 fn test_extract_variables_from_list() {
1817 let optimizer = Optimizer::new();
1818 let expr = LogicalExpression::List(vec![
1819 LogicalExpression::Variable("a".to_string()),
1820 LogicalExpression::Literal(Value::Int64(1)),
1821 LogicalExpression::Variable("b".to_string()),
1822 ]);
1823 let vars = optimizer.extract_variables(&expr);
1824 assert_eq!(vars.len(), 2);
1825 assert!(vars.contains("a"));
1826 assert!(vars.contains("b"));
1827 }
1828
1829 #[test]
1830 fn test_extract_variables_from_map() {
1831 let optimizer = Optimizer::new();
1832 let expr = LogicalExpression::Map(vec![
1833 (
1834 "key1".to_string(),
1835 LogicalExpression::Variable("a".to_string()),
1836 ),
1837 (
1838 "key2".to_string(),
1839 LogicalExpression::Variable("b".to_string()),
1840 ),
1841 ]);
1842 let vars = optimizer.extract_variables(&expr);
1843 assert_eq!(vars.len(), 2);
1844 assert!(vars.contains("a"));
1845 assert!(vars.contains("b"));
1846 }
1847
1848 #[test]
1849 fn test_extract_variables_from_index_access() {
1850 let optimizer = Optimizer::new();
1851 let expr = LogicalExpression::IndexAccess {
1852 base: Box::new(LogicalExpression::Variable("list".to_string())),
1853 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1854 };
1855 let vars = optimizer.extract_variables(&expr);
1856 assert_eq!(vars.len(), 2);
1857 assert!(vars.contains("list"));
1858 assert!(vars.contains("idx"));
1859 }
1860
1861 #[test]
1862 fn test_extract_variables_from_slice_access() {
1863 let optimizer = Optimizer::new();
1864 let expr = LogicalExpression::SliceAccess {
1865 base: Box::new(LogicalExpression::Variable("list".to_string())),
1866 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1867 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1868 };
1869 let vars = optimizer.extract_variables(&expr);
1870 assert_eq!(vars.len(), 3);
1871 assert!(vars.contains("list"));
1872 assert!(vars.contains("s"));
1873 assert!(vars.contains("e"));
1874 }
1875
1876 #[test]
1877 fn test_extract_variables_from_case() {
1878 let optimizer = Optimizer::new();
1879 let expr = LogicalExpression::Case {
1880 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1881 when_clauses: vec![(
1882 LogicalExpression::Literal(Value::Int64(1)),
1883 LogicalExpression::Variable("a".to_string()),
1884 )],
1885 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1886 };
1887 let vars = optimizer.extract_variables(&expr);
1888 assert_eq!(vars.len(), 3);
1889 assert!(vars.contains("x"));
1890 assert!(vars.contains("a"));
1891 assert!(vars.contains("b"));
1892 }
1893
1894 #[test]
1895 fn test_extract_variables_from_labels() {
1896 let optimizer = Optimizer::new();
1897 let expr = LogicalExpression::Labels("n".to_string());
1898 let vars = optimizer.extract_variables(&expr);
1899 assert_eq!(vars.len(), 1);
1900 assert!(vars.contains("n"));
1901 }
1902
1903 #[test]
1904 fn test_extract_variables_from_type() {
1905 let optimizer = Optimizer::new();
1906 let expr = LogicalExpression::Type("e".to_string());
1907 let vars = optimizer.extract_variables(&expr);
1908 assert_eq!(vars.len(), 1);
1909 assert!(vars.contains("e"));
1910 }
1911
1912 #[test]
1913 fn test_extract_variables_from_id() {
1914 let optimizer = Optimizer::new();
1915 let expr = LogicalExpression::Id("n".to_string());
1916 let vars = optimizer.extract_variables(&expr);
1917 assert_eq!(vars.len(), 1);
1918 assert!(vars.contains("n"));
1919 }
1920
1921 #[test]
1922 fn test_extract_variables_from_list_comprehension() {
1923 let optimizer = Optimizer::new();
1924 let expr = LogicalExpression::ListComprehension {
1925 variable: "x".to_string(),
1926 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1927 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1928 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1929 };
1930 let vars = optimizer.extract_variables(&expr);
1931 assert!(vars.contains("items"));
1932 assert!(vars.contains("pred"));
1933 assert!(vars.contains("result"));
1934 }
1935
1936 #[test]
1937 fn test_extract_variables_from_literal_and_parameter() {
1938 let optimizer = Optimizer::new();
1939
1940 let literal = LogicalExpression::Literal(Value::Int64(42));
1941 assert!(optimizer.extract_variables(&literal).is_empty());
1942
1943 let param = LogicalExpression::Parameter("p".to_string());
1944 assert!(optimizer.extract_variables(¶m).is_empty());
1945 }
1946
1947 #[test]
1950 fn test_recursive_filter_pushdown_through_skip() {
1951 let optimizer = Optimizer::new();
1952
1953 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1954 items: vec![ReturnItem {
1955 expression: LogicalExpression::Variable("n".to_string()),
1956 alias: None,
1957 }],
1958 distinct: false,
1959 input: Box::new(LogicalOperator::Filter(FilterOp {
1960 predicate: LogicalExpression::Literal(Value::Bool(true)),
1961 pushdown_hint: None,
1962 input: Box::new(LogicalOperator::Skip(SkipOp {
1963 count: 5.into(),
1964 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1965 variable: "n".to_string(),
1966 label: None,
1967 input: None,
1968 })),
1969 })),
1970 })),
1971 }));
1972
1973 let optimized = optimizer.optimize(plan).unwrap();
1974
1975 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1977 }
1978
1979 #[test]
1980 fn test_nested_filter_pushdown() {
1981 let optimizer = Optimizer::new();
1982
1983 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1985 items: vec![ReturnItem {
1986 expression: LogicalExpression::Variable("n".to_string()),
1987 alias: None,
1988 }],
1989 distinct: false,
1990 input: Box::new(LogicalOperator::Filter(FilterOp {
1991 predicate: LogicalExpression::Binary {
1992 left: Box::new(LogicalExpression::Property {
1993 variable: "n".to_string(),
1994 property: "x".to_string(),
1995 }),
1996 op: BinaryOp::Gt,
1997 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1998 },
1999 pushdown_hint: None,
2000 input: Box::new(LogicalOperator::Filter(FilterOp {
2001 predicate: LogicalExpression::Binary {
2002 left: Box::new(LogicalExpression::Property {
2003 variable: "n".to_string(),
2004 property: "y".to_string(),
2005 }),
2006 op: BinaryOp::Lt,
2007 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2008 },
2009 pushdown_hint: None,
2010 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2011 variable: "n".to_string(),
2012 label: None,
2013 input: None,
2014 })),
2015 })),
2016 })),
2017 }));
2018
2019 let optimized = optimizer.optimize(plan).unwrap();
2020 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2021 }
2022
2023 #[test]
2024 fn test_cyclic_join_produces_multi_way_join() {
2025 use crate::query::plan::JoinCondition;
2026
2027 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2029 variable: "a".to_string(),
2030 label: Some("Person".to_string()),
2031 input: None,
2032 });
2033 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2034 variable: "b".to_string(),
2035 label: Some("Person".to_string()),
2036 input: None,
2037 });
2038 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2039 variable: "c".to_string(),
2040 label: Some("Person".to_string()),
2041 input: None,
2042 });
2043
2044 let join_ab = LogicalOperator::Join(JoinOp {
2046 left: Box::new(scan_a),
2047 right: Box::new(scan_b),
2048 join_type: JoinType::Inner,
2049 conditions: vec![JoinCondition {
2050 left: LogicalExpression::Variable("a".to_string()),
2051 right: LogicalExpression::Variable("b".to_string()),
2052 }],
2053 });
2054
2055 let join_abc = LogicalOperator::Join(JoinOp {
2056 left: Box::new(join_ab),
2057 right: Box::new(scan_c),
2058 join_type: JoinType::Inner,
2059 conditions: vec![
2060 JoinCondition {
2061 left: LogicalExpression::Variable("b".to_string()),
2062 right: LogicalExpression::Variable("c".to_string()),
2063 },
2064 JoinCondition {
2065 left: LogicalExpression::Variable("c".to_string()),
2066 right: LogicalExpression::Variable("a".to_string()),
2067 },
2068 ],
2069 });
2070
2071 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2072 items: vec![ReturnItem {
2073 expression: LogicalExpression::Variable("a".to_string()),
2074 alias: None,
2075 }],
2076 distinct: false,
2077 input: Box::new(join_abc),
2078 }));
2079
2080 let mut optimizer = Optimizer::new();
2081 optimizer
2082 .card_estimator
2083 .add_table_stats("Person", cardinality::TableStats::new(1000));
2084
2085 let optimized = optimizer.optimize(plan).unwrap();
2086
2087 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2089 match op {
2090 LogicalOperator::MultiWayJoin(_) => true,
2091 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2092 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2093 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2094 _ => false,
2095 }
2096 }
2097
2098 assert!(
2099 has_multi_way_join(&optimized.root),
2100 "Expected MultiWayJoin for cyclic triangle pattern"
2101 );
2102 }
2103
2104 #[test]
2105 fn test_acyclic_join_uses_binary_joins() {
2106 use crate::query::plan::JoinCondition;
2107
2108 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2110 variable: "a".to_string(),
2111 label: Some("Person".to_string()),
2112 input: None,
2113 });
2114 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2115 variable: "b".to_string(),
2116 label: Some("Person".to_string()),
2117 input: None,
2118 });
2119 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2120 variable: "c".to_string(),
2121 label: Some("Company".to_string()),
2122 input: None,
2123 });
2124
2125 let join_ab = LogicalOperator::Join(JoinOp {
2126 left: Box::new(scan_a),
2127 right: Box::new(scan_b),
2128 join_type: JoinType::Inner,
2129 conditions: vec![JoinCondition {
2130 left: LogicalExpression::Variable("a".to_string()),
2131 right: LogicalExpression::Variable("b".to_string()),
2132 }],
2133 });
2134
2135 let join_abc = LogicalOperator::Join(JoinOp {
2136 left: Box::new(join_ab),
2137 right: Box::new(scan_c),
2138 join_type: JoinType::Inner,
2139 conditions: vec![JoinCondition {
2140 left: LogicalExpression::Variable("b".to_string()),
2141 right: LogicalExpression::Variable("c".to_string()),
2142 }],
2143 });
2144
2145 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2146 items: vec![ReturnItem {
2147 expression: LogicalExpression::Variable("a".to_string()),
2148 alias: None,
2149 }],
2150 distinct: false,
2151 input: Box::new(join_abc),
2152 }));
2153
2154 let mut optimizer = Optimizer::new();
2155 optimizer
2156 .card_estimator
2157 .add_table_stats("Person", cardinality::TableStats::new(1000));
2158 optimizer
2159 .card_estimator
2160 .add_table_stats("Company", cardinality::TableStats::new(100));
2161
2162 let optimized = optimizer.optimize(plan).unwrap();
2163
2164 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2166 match op {
2167 LogicalOperator::MultiWayJoin(_) => true,
2168 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2169 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2170 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2171 LogicalOperator::Join(j) => {
2172 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2173 }
2174 _ => false,
2175 }
2176 }
2177
2178 assert!(
2179 !has_multi_way_join(&optimized.root),
2180 "Acyclic join should NOT produce MultiWayJoin"
2181 );
2182 }
2183}