1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// One join condition gathered while flattening a tree of `Join` operators.
///
/// Built in `Optimizer::collect_join_tree` and read back in
/// `Optimizer::optimize_join_order`.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable named by the left-hand expression of the condition.
    left_var: String,
    /// Variable named by the right-hand expression of the condition.
    right_var: String,
    /// Left-hand expression of the condition, cloned unchanged.
    left_expr: LogicalExpression,
    /// Right-hand expression of the condition, cloned unchanged.
    right_expr: LogicalExpression,
}
39
/// A column referenced by some operator or expression of a plan.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable, identified by its name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
48
/// Rewrites a [`LogicalPlan`] through up to three passes (filter pushdown,
/// join reordering, projection pushdown), each of which can be switched off.
pub struct Optimizer {
    /// When `true`, `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// When `true`, `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// When `true`, `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used by `estimate_cost` and by join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used by `estimate_cardinality`, `estimate_cost`
    /// and join ordering.
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67 #[must_use]
69 pub fn new() -> Self {
70 Self {
71 enable_filter_pushdown: true,
72 enable_join_reorder: true,
73 enable_projection_pushdown: true,
74 cost_model: CostModel::new(),
75 card_estimator: CardinalityEstimator::new(),
76 }
77 }
78
79 #[cfg(feature = "lpg")]
85 #[must_use]
86 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
87 store.ensure_statistics_fresh();
88 let stats = store.statistics();
89 Self::from_statistics(&stats)
90 }
91
92 #[must_use]
99 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
100 let stats = store.statistics();
101 Self::from_statistics(&stats)
102 }
103
104 #[cfg(feature = "triple-store")]
109 #[must_use]
110 pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
111 let total = rdf_stats.total_triples;
112 let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
113 Self {
114 enable_filter_pushdown: true,
115 enable_join_reorder: true,
116 enable_projection_pushdown: true,
117 cost_model: CostModel::new().with_graph_totals(total, total),
118 card_estimator: estimator,
119 }
120 }
121
122 #[must_use]
127 fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
128 let estimator = CardinalityEstimator::from_statistics(stats);
129
130 let avg_fanout = if stats.total_nodes > 0 {
131 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
132 } else {
133 10.0
134 };
135
136 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
137 .edge_types
138 .iter()
139 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
140 .collect();
141
142 let label_cardinalities: std::collections::HashMap<String, u64> = stats
143 .labels
144 .iter()
145 .map(|(name, ls)| (name.clone(), ls.node_count))
146 .collect();
147
148 Self {
149 enable_filter_pushdown: true,
150 enable_join_reorder: true,
151 enable_projection_pushdown: true,
152 cost_model: CostModel::new()
153 .with_avg_fanout(avg_fanout)
154 .with_edge_type_degrees(edge_type_degrees)
155 .with_label_cardinalities(label_cardinalities)
156 .with_graph_totals(stats.total_nodes, stats.total_edges),
157 card_estimator: estimator,
158 }
159 }
160
161 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
163 self.enable_filter_pushdown = enabled;
164 self
165 }
166
167 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
169 self.enable_join_reorder = enabled;
170 self
171 }
172
173 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
175 self.enable_projection_pushdown = enabled;
176 self
177 }
178
179 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
181 self.cost_model = cost_model;
182 self
183 }
184
185 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
187 self.card_estimator = estimator;
188 self
189 }
190
191 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
193 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
194 self
195 }
196
197 pub fn cost_model(&self) -> &CostModel {
199 &self.cost_model
200 }
201
202 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
204 &self.card_estimator
205 }
206
207 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
213 self.cost_model
214 .estimate_tree(&plan.root, &self.card_estimator)
215 }
216
217 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
219 self.card_estimator.estimate(&plan.root)
220 }
221
222 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
228 let _span = grafeo_debug_span!("grafeo::query::optimize");
229 let mut root = plan.root;
230
231 if self.enable_filter_pushdown {
233 root = self.push_filters_down(root);
234 }
235
236 if self.enable_join_reorder {
237 root = self.reorder_joins(root);
238 }
239
240 if self.enable_projection_pushdown {
241 root = self.push_projections_down(root);
242 }
243
244 Ok(LogicalPlan {
245 root,
246 explain: plan.explain,
247 profile: plan.profile,
248 default_params: plan.default_params,
249 })
250 }
251
252 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
259 let required = self.collect_required_columns(&op);
261
262 self.push_projections_recursive(op, &required)
264 }
265
266 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
268 let mut required = HashSet::new();
269 Self::collect_required_recursive(op, &mut required);
270 required
271 }
272
273 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
275 match op {
276 LogicalOperator::Return(ret) => {
277 for item in &ret.items {
278 Self::collect_from_expression(&item.expression, required);
279 }
280 Self::collect_required_recursive(&ret.input, required);
281 }
282 LogicalOperator::Project(proj) => {
283 for p in &proj.projections {
284 Self::collect_from_expression(&p.expression, required);
285 }
286 Self::collect_required_recursive(&proj.input, required);
287 }
288 LogicalOperator::Filter(filter) => {
289 Self::collect_from_expression(&filter.predicate, required);
290 Self::collect_required_recursive(&filter.input, required);
291 }
292 LogicalOperator::Sort(sort) => {
293 for key in &sort.keys {
294 Self::collect_from_expression(&key.expression, required);
295 }
296 Self::collect_required_recursive(&sort.input, required);
297 }
298 LogicalOperator::Aggregate(agg) => {
299 for expr in &agg.group_by {
300 Self::collect_from_expression(expr, required);
301 }
302 for agg_expr in &agg.aggregates {
303 if let Some(ref expr) = agg_expr.expression {
304 Self::collect_from_expression(expr, required);
305 }
306 }
307 if let Some(ref having) = agg.having {
308 Self::collect_from_expression(having, required);
309 }
310 Self::collect_required_recursive(&agg.input, required);
311 }
312 LogicalOperator::Join(join) => {
313 for cond in &join.conditions {
314 Self::collect_from_expression(&cond.left, required);
315 Self::collect_from_expression(&cond.right, required);
316 }
317 Self::collect_required_recursive(&join.left, required);
318 Self::collect_required_recursive(&join.right, required);
319 }
320 LogicalOperator::Expand(expand) => {
321 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
323 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
324 if let Some(ref edge_var) = expand.edge_variable {
325 required.insert(RequiredColumn::Variable(edge_var.clone()));
326 }
327 Self::collect_required_recursive(&expand.input, required);
328 }
329 LogicalOperator::Limit(limit) => {
330 Self::collect_required_recursive(&limit.input, required);
331 }
332 LogicalOperator::Skip(skip) => {
333 Self::collect_required_recursive(&skip.input, required);
334 }
335 LogicalOperator::Distinct(distinct) => {
336 Self::collect_required_recursive(&distinct.input, required);
337 }
338 LogicalOperator::NodeScan(scan) => {
339 required.insert(RequiredColumn::Variable(scan.variable.clone()));
340 }
341 LogicalOperator::EdgeScan(scan) => {
342 required.insert(RequiredColumn::Variable(scan.variable.clone()));
343 }
344 LogicalOperator::MultiWayJoin(mwj) => {
345 for cond in &mwj.conditions {
346 Self::collect_from_expression(&cond.left, required);
347 Self::collect_from_expression(&cond.right, required);
348 }
349 for input in &mwj.inputs {
350 Self::collect_required_recursive(input, required);
351 }
352 }
353 _ => {}
354 }
355 }
356
357 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
359 match expr {
360 LogicalExpression::Variable(var) => {
361 required.insert(RequiredColumn::Variable(var.clone()));
362 }
363 LogicalExpression::Property { variable, property } => {
364 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
365 required.insert(RequiredColumn::Variable(variable.clone()));
366 }
367 LogicalExpression::Binary { left, right, .. } => {
368 Self::collect_from_expression(left, required);
369 Self::collect_from_expression(right, required);
370 }
371 LogicalExpression::Unary { operand, .. } => {
372 Self::collect_from_expression(operand, required);
373 }
374 LogicalExpression::FunctionCall { args, .. } => {
375 for arg in args {
376 Self::collect_from_expression(arg, required);
377 }
378 }
379 LogicalExpression::List(items) => {
380 for item in items {
381 Self::collect_from_expression(item, required);
382 }
383 }
384 LogicalExpression::Map(pairs) => {
385 for (_, value) in pairs {
386 Self::collect_from_expression(value, required);
387 }
388 }
389 LogicalExpression::IndexAccess { base, index } => {
390 Self::collect_from_expression(base, required);
391 Self::collect_from_expression(index, required);
392 }
393 LogicalExpression::SliceAccess { base, start, end } => {
394 Self::collect_from_expression(base, required);
395 if let Some(s) = start {
396 Self::collect_from_expression(s, required);
397 }
398 if let Some(e) = end {
399 Self::collect_from_expression(e, required);
400 }
401 }
402 LogicalExpression::Case {
403 operand,
404 when_clauses,
405 else_clause,
406 } => {
407 if let Some(op) = operand {
408 Self::collect_from_expression(op, required);
409 }
410 for (cond, result) in when_clauses {
411 Self::collect_from_expression(cond, required);
412 Self::collect_from_expression(result, required);
413 }
414 if let Some(else_expr) = else_clause {
415 Self::collect_from_expression(else_expr, required);
416 }
417 }
418 LogicalExpression::Labels(var)
419 | LogicalExpression::Type(var)
420 | LogicalExpression::Id(var) => {
421 required.insert(RequiredColumn::Variable(var.clone()));
422 }
423 LogicalExpression::ListComprehension {
424 list_expr,
425 filter_expr,
426 map_expr,
427 ..
428 } => {
429 Self::collect_from_expression(list_expr, required);
430 if let Some(filter) = filter_expr {
431 Self::collect_from_expression(filter, required);
432 }
433 Self::collect_from_expression(map_expr, required);
434 }
435 _ => {}
436 }
437 }
438
439 fn push_projections_recursive(
441 &self,
442 op: LogicalOperator,
443 required: &HashSet<RequiredColumn>,
444 ) -> LogicalOperator {
445 match op {
446 LogicalOperator::Return(mut ret) => {
447 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
448 LogicalOperator::Return(ret)
449 }
450 LogicalOperator::Project(mut proj) => {
451 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
452 LogicalOperator::Project(proj)
453 }
454 LogicalOperator::Filter(mut filter) => {
455 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
456 LogicalOperator::Filter(filter)
457 }
458 LogicalOperator::Sort(mut sort) => {
459 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
462 LogicalOperator::Sort(sort)
463 }
464 LogicalOperator::Aggregate(mut agg) => {
465 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
466 LogicalOperator::Aggregate(agg)
467 }
468 LogicalOperator::Join(mut join) => {
469 let left_vars = self.collect_output_variables(&join.left);
472 let right_vars = self.collect_output_variables(&join.right);
473
474 let left_required: HashSet<_> = required
476 .iter()
477 .filter(|c| match c {
478 RequiredColumn::Variable(v) => left_vars.contains(v),
479 RequiredColumn::Property(v, _) => left_vars.contains(v),
480 })
481 .cloned()
482 .collect();
483
484 let right_required: HashSet<_> = required
485 .iter()
486 .filter(|c| match c {
487 RequiredColumn::Variable(v) => right_vars.contains(v),
488 RequiredColumn::Property(v, _) => right_vars.contains(v),
489 })
490 .cloned()
491 .collect();
492
493 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
494 join.right =
495 Box::new(self.push_projections_recursive(*join.right, &right_required));
496 LogicalOperator::Join(join)
497 }
498 LogicalOperator::Expand(mut expand) => {
499 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
500 LogicalOperator::Expand(expand)
501 }
502 LogicalOperator::Limit(mut limit) => {
503 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
504 LogicalOperator::Limit(limit)
505 }
506 LogicalOperator::Skip(mut skip) => {
507 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
508 LogicalOperator::Skip(skip)
509 }
510 LogicalOperator::Distinct(mut distinct) => {
511 distinct.input =
512 Box::new(self.push_projections_recursive(*distinct.input, required));
513 LogicalOperator::Distinct(distinct)
514 }
515 LogicalOperator::MapCollect(mut mc) => {
516 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
517 LogicalOperator::MapCollect(mc)
518 }
519 LogicalOperator::MultiWayJoin(mut mwj) => {
520 mwj.inputs = mwj
521 .inputs
522 .into_iter()
523 .map(|input| self.push_projections_recursive(input, required))
524 .collect();
525 LogicalOperator::MultiWayJoin(mwj)
526 }
527 other => other,
528 }
529 }
530
531 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
538 let op = self.reorder_joins_recursive(op);
540
541 if let Some((relations, conditions)) = self.extract_join_tree(&op)
543 && relations.len() >= 2
544 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
545 {
546 return optimized;
547 }
548
549 op
550 }
551
552 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
554 match op {
555 LogicalOperator::Return(mut ret) => {
556 ret.input = Box::new(self.reorder_joins(*ret.input));
557 LogicalOperator::Return(ret)
558 }
559 LogicalOperator::Project(mut proj) => {
560 proj.input = Box::new(self.reorder_joins(*proj.input));
561 LogicalOperator::Project(proj)
562 }
563 LogicalOperator::Filter(mut filter) => {
564 filter.input = Box::new(self.reorder_joins(*filter.input));
565 LogicalOperator::Filter(filter)
566 }
567 LogicalOperator::Limit(mut limit) => {
568 limit.input = Box::new(self.reorder_joins(*limit.input));
569 LogicalOperator::Limit(limit)
570 }
571 LogicalOperator::Skip(mut skip) => {
572 skip.input = Box::new(self.reorder_joins(*skip.input));
573 LogicalOperator::Skip(skip)
574 }
575 LogicalOperator::Sort(mut sort) => {
576 sort.input = Box::new(self.reorder_joins(*sort.input));
577 LogicalOperator::Sort(sort)
578 }
579 LogicalOperator::Distinct(mut distinct) => {
580 distinct.input = Box::new(self.reorder_joins(*distinct.input));
581 LogicalOperator::Distinct(distinct)
582 }
583 LogicalOperator::Aggregate(mut agg) => {
584 agg.input = Box::new(self.reorder_joins(*agg.input));
585 LogicalOperator::Aggregate(agg)
586 }
587 LogicalOperator::Expand(mut expand) => {
588 expand.input = Box::new(self.reorder_joins(*expand.input));
589 LogicalOperator::Expand(expand)
590 }
591 LogicalOperator::MapCollect(mut mc) => {
592 mc.input = Box::new(self.reorder_joins(*mc.input));
593 LogicalOperator::MapCollect(mc)
594 }
595 LogicalOperator::MultiWayJoin(mut mwj) => {
596 mwj.inputs = mwj
597 .inputs
598 .into_iter()
599 .map(|input| self.reorder_joins(input))
600 .collect();
601 LogicalOperator::MultiWayJoin(mwj)
602 }
603 other => other,
605 }
606 }
607
608 fn extract_join_tree(
612 &self,
613 op: &LogicalOperator,
614 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
615 let mut relations = Vec::new();
616 let mut join_conditions = Vec::new();
617
618 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
619 return None;
620 }
621
622 if relations.len() < 2 {
623 return None;
624 }
625
626 Some((relations, join_conditions))
627 }
628
629 fn collect_join_tree(
633 &self,
634 op: &LogicalOperator,
635 relations: &mut Vec<(String, LogicalOperator)>,
636 conditions: &mut Vec<JoinInfo>,
637 ) -> bool {
638 match op {
639 LogicalOperator::Join(join) => {
640 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
642 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
643
644 for cond in &join.conditions {
646 if let (Some(left_var), Some(right_var)) = (
647 self.extract_variable_from_expr(&cond.left),
648 self.extract_variable_from_expr(&cond.right),
649 ) {
650 conditions.push(JoinInfo {
651 left_var,
652 right_var,
653 left_expr: cond.left.clone(),
654 right_expr: cond.right.clone(),
655 });
656 }
657 }
658
659 left_ok && right_ok
660 }
661 LogicalOperator::NodeScan(scan) => {
662 relations.push((scan.variable.clone(), op.clone()));
663 true
664 }
665 LogicalOperator::EdgeScan(scan) => {
666 relations.push((scan.variable.clone(), op.clone()));
667 true
668 }
669 LogicalOperator::Filter(filter) => {
670 self.collect_join_tree(&filter.input, relations, conditions)
672 }
673 LogicalOperator::Expand(expand) => {
674 relations.push((expand.to_variable.clone(), op.clone()));
677 true
678 }
679 #[cfg(feature = "triple-store")]
680 LogicalOperator::TripleScan(scan) => {
681 let name = scan
685 .subject
686 .as_variable()
687 .or_else(|| scan.predicate.as_variable())
688 .or_else(|| scan.object.as_variable())
689 .map_or_else(|| format!("__tp_{}", relations.len()), String::from);
690 relations.push((name, op.clone()));
691 true
692 }
693 _ => false,
694 }
695 }
696
697 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
699 match expr {
700 LogicalExpression::Variable(v) => Some(v.clone()),
701 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
702 _ => None,
703 }
704 }
705
706 fn optimize_join_order(
709 &self,
710 relations: &[(String, LogicalOperator)],
711 conditions: &[JoinInfo],
712 ) -> Option<LogicalOperator> {
713 use join_order::{DPccp, JoinGraphBuilder};
714
715 let mut builder = JoinGraphBuilder::new();
717
718 for (var, relation) in relations {
719 builder.add_relation(var, relation.clone());
720 }
721
722 for cond in conditions {
723 builder.add_join_condition(
724 &cond.left_var,
725 &cond.right_var,
726 cond.left_expr.clone(),
727 cond.right_expr.clone(),
728 );
729 }
730
731 let graph = builder.build();
732
733 if graph.is_cyclic() && relations.len() >= 3 {
738 let mut var_counts: std::collections::HashMap<&str, usize> =
740 std::collections::HashMap::new();
741 for cond in conditions {
742 *var_counts.entry(&cond.left_var).or_default() += 1;
743 *var_counts.entry(&cond.right_var).or_default() += 1;
744 }
745 let shared_variables: Vec<String> = var_counts
746 .into_iter()
747 .filter(|(_, count)| *count >= 2)
748 .map(|(var, _)| var.to_string())
749 .collect();
750
751 let join_conditions: Vec<JoinCondition> = conditions
752 .iter()
753 .map(|c| JoinCondition {
754 left: c.left_expr.clone(),
755 right: c.right_expr.clone(),
756 })
757 .collect();
758
759 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
760 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
761 conditions: join_conditions,
762 shared_variables,
763 }));
764 }
765
766 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
768 let plan = dpccp.optimize()?;
769
770 Some(plan.operator)
771 }
772
773 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
778 match op {
779 LogicalOperator::Filter(filter) => {
781 let optimized_input = self.push_filters_down(*filter.input);
782 self.try_push_filter_into(filter.predicate, optimized_input)
783 }
784 LogicalOperator::Return(mut ret) => {
786 ret.input = Box::new(self.push_filters_down(*ret.input));
787 LogicalOperator::Return(ret)
788 }
789 LogicalOperator::Project(mut proj) => {
790 proj.input = Box::new(self.push_filters_down(*proj.input));
791 LogicalOperator::Project(proj)
792 }
793 LogicalOperator::Limit(mut limit) => {
794 limit.input = Box::new(self.push_filters_down(*limit.input));
795 LogicalOperator::Limit(limit)
796 }
797 LogicalOperator::Skip(mut skip) => {
798 skip.input = Box::new(self.push_filters_down(*skip.input));
799 LogicalOperator::Skip(skip)
800 }
801 LogicalOperator::Sort(mut sort) => {
802 sort.input = Box::new(self.push_filters_down(*sort.input));
803 LogicalOperator::Sort(sort)
804 }
805 LogicalOperator::Distinct(mut distinct) => {
806 distinct.input = Box::new(self.push_filters_down(*distinct.input));
807 LogicalOperator::Distinct(distinct)
808 }
809 LogicalOperator::Expand(mut expand) => {
810 expand.input = Box::new(self.push_filters_down(*expand.input));
811 LogicalOperator::Expand(expand)
812 }
813 LogicalOperator::Join(mut join) => {
814 join.left = Box::new(self.push_filters_down(*join.left));
815 join.right = Box::new(self.push_filters_down(*join.right));
816 LogicalOperator::Join(join)
817 }
818 LogicalOperator::Aggregate(mut agg) => {
819 agg.input = Box::new(self.push_filters_down(*agg.input));
820 LogicalOperator::Aggregate(agg)
821 }
822 LogicalOperator::MapCollect(mut mc) => {
823 mc.input = Box::new(self.push_filters_down(*mc.input));
824 LogicalOperator::MapCollect(mc)
825 }
826 LogicalOperator::MultiWayJoin(mut mwj) => {
827 mwj.inputs = mwj
828 .inputs
829 .into_iter()
830 .map(|input| self.push_filters_down(input))
831 .collect();
832 LogicalOperator::MultiWayJoin(mwj)
833 }
834 other => other,
836 }
837 }
838
839 fn try_push_filter_into(
844 &self,
845 predicate: LogicalExpression,
846 op: LogicalOperator,
847 ) -> LogicalOperator {
848 match op {
849 LogicalOperator::Project(mut proj) => {
851 let predicate_vars = self.extract_variables(&predicate);
852 let computed_vars = self.extract_projection_aliases(&proj.projections);
853
854 if predicate_vars.is_disjoint(&computed_vars) {
856 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
857 LogicalOperator::Project(proj)
858 } else {
859 LogicalOperator::Filter(FilterOp {
861 predicate,
862 pushdown_hint: None,
863 input: Box::new(LogicalOperator::Project(proj)),
864 })
865 }
866 }
867
868 LogicalOperator::Return(mut ret) => {
870 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
871 LogicalOperator::Return(ret)
872 }
873
874 LogicalOperator::Expand(mut expand) => {
876 let predicate_vars = self.extract_variables(&predicate);
877
878 let mut introduced_vars = vec![&expand.to_variable];
883 if let Some(ref edge_var) = expand.edge_variable {
884 introduced_vars.push(edge_var);
885 }
886 if let Some(ref path_alias) = expand.path_alias {
887 introduced_vars.push(path_alias);
888 }
889
890 let uses_introduced_vars =
892 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
893
894 if !uses_introduced_vars {
895 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
897 LogicalOperator::Expand(expand)
898 } else {
899 LogicalOperator::Filter(FilterOp {
901 predicate,
902 pushdown_hint: None,
903 input: Box::new(LogicalOperator::Expand(expand)),
904 })
905 }
906 }
907
908 LogicalOperator::Join(mut join) => {
910 let predicate_vars = self.extract_variables(&predicate);
911 let left_vars = self.collect_output_variables(&join.left);
912 let right_vars = self.collect_output_variables(&join.right);
913
914 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
915 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
916
917 if uses_left && !uses_right {
918 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
920 LogicalOperator::Join(join)
921 } else if uses_right && !uses_left {
922 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
924 LogicalOperator::Join(join)
925 } else {
926 LogicalOperator::Filter(FilterOp {
928 predicate,
929 pushdown_hint: None,
930 input: Box::new(LogicalOperator::Join(join)),
931 })
932 }
933 }
934
935 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
937 predicate,
938 pushdown_hint: None,
939 input: Box::new(LogicalOperator::Aggregate(agg)),
940 }),
941
942 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
944 predicate,
945 pushdown_hint: None,
946 input: Box::new(LogicalOperator::NodeScan(scan)),
947 }),
948
949 other => LogicalOperator::Filter(FilterOp {
951 predicate,
952 pushdown_hint: None,
953 input: Box::new(other),
954 }),
955 }
956 }
957
958 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
966 let mut vars = HashSet::new();
967 Self::collect_output_variables_recursive(op, &mut vars);
968 vars
969 }
970
971 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
973 match op {
974 LogicalOperator::NodeScan(scan) => {
975 vars.insert(scan.variable.clone());
976 }
977 LogicalOperator::EdgeScan(scan) => {
978 vars.insert(scan.variable.clone());
979 }
980 LogicalOperator::Expand(expand) => {
981 vars.insert(expand.to_variable.clone());
982 if let Some(edge_var) = &expand.edge_variable {
983 vars.insert(edge_var.clone());
984 }
985 Self::collect_output_variables_recursive(&expand.input, vars);
986 }
987 LogicalOperator::Filter(filter) => {
988 Self::collect_output_variables_recursive(&filter.input, vars);
989 }
990 LogicalOperator::Project(proj) => {
991 for p in &proj.projections {
992 if let Some(alias) = &p.alias {
993 vars.insert(alias.clone());
994 }
995 }
996 Self::collect_output_variables_recursive(&proj.input, vars);
997 }
998 LogicalOperator::Join(join) => {
999 Self::collect_output_variables_recursive(&join.left, vars);
1000 Self::collect_output_variables_recursive(&join.right, vars);
1001 }
1002 LogicalOperator::Aggregate(agg) => {
1003 for expr in &agg.group_by {
1004 Self::collect_variables(expr, vars);
1005 }
1006 for agg_expr in &agg.aggregates {
1007 if let Some(alias) = &agg_expr.alias {
1008 vars.insert(alias.clone());
1009 }
1010 }
1011 }
1012 LogicalOperator::Return(ret) => {
1013 Self::collect_output_variables_recursive(&ret.input, vars);
1014 }
1015 LogicalOperator::Limit(limit) => {
1016 Self::collect_output_variables_recursive(&limit.input, vars);
1017 }
1018 LogicalOperator::Skip(skip) => {
1019 Self::collect_output_variables_recursive(&skip.input, vars);
1020 }
1021 LogicalOperator::Sort(sort) => {
1022 Self::collect_output_variables_recursive(&sort.input, vars);
1023 }
1024 LogicalOperator::Distinct(distinct) => {
1025 Self::collect_output_variables_recursive(&distinct.input, vars);
1026 }
1027 #[cfg(feature = "triple-store")]
1028 LogicalOperator::TripleScan(scan) => {
1029 if let Some(v) = scan.subject.as_variable() {
1030 vars.insert(v.to_string());
1031 }
1032 if let Some(v) = scan.predicate.as_variable() {
1033 vars.insert(v.to_string());
1034 }
1035 if let Some(v) = scan.object.as_variable() {
1036 vars.insert(v.to_string());
1037 }
1038 if let Some(ref g) = scan.graph
1039 && let Some(v) = g.as_variable()
1040 {
1041 vars.insert(v.to_string());
1042 }
1043 }
1044 _ => {}
1045 }
1046 }
1047
1048 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1050 let mut vars = HashSet::new();
1051 Self::collect_variables(expr, &mut vars);
1052 vars
1053 }
1054
1055 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1057 match expr {
1058 LogicalExpression::Variable(name) => {
1059 vars.insert(name.clone());
1060 }
1061 LogicalExpression::Property { variable, .. } => {
1062 vars.insert(variable.clone());
1063 }
1064 LogicalExpression::Binary { left, right, .. } => {
1065 Self::collect_variables(left, vars);
1066 Self::collect_variables(right, vars);
1067 }
1068 LogicalExpression::Unary { operand, .. } => {
1069 Self::collect_variables(operand, vars);
1070 }
1071 LogicalExpression::FunctionCall { args, .. } => {
1072 for arg in args {
1073 Self::collect_variables(arg, vars);
1074 }
1075 }
1076 LogicalExpression::List(items) => {
1077 for item in items {
1078 Self::collect_variables(item, vars);
1079 }
1080 }
1081 LogicalExpression::Map(pairs) => {
1082 for (_, value) in pairs {
1083 Self::collect_variables(value, vars);
1084 }
1085 }
1086 LogicalExpression::IndexAccess { base, index } => {
1087 Self::collect_variables(base, vars);
1088 Self::collect_variables(index, vars);
1089 }
1090 LogicalExpression::SliceAccess { base, start, end } => {
1091 Self::collect_variables(base, vars);
1092 if let Some(s) = start {
1093 Self::collect_variables(s, vars);
1094 }
1095 if let Some(e) = end {
1096 Self::collect_variables(e, vars);
1097 }
1098 }
1099 LogicalExpression::Case {
1100 operand,
1101 when_clauses,
1102 else_clause,
1103 } => {
1104 if let Some(op) = operand {
1105 Self::collect_variables(op, vars);
1106 }
1107 for (cond, result) in when_clauses {
1108 Self::collect_variables(cond, vars);
1109 Self::collect_variables(result, vars);
1110 }
1111 if let Some(else_expr) = else_clause {
1112 Self::collect_variables(else_expr, vars);
1113 }
1114 }
1115 LogicalExpression::Labels(var)
1116 | LogicalExpression::Type(var)
1117 | LogicalExpression::Id(var) => {
1118 vars.insert(var.clone());
1119 }
1120 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1121 LogicalExpression::ListComprehension {
1122 list_expr,
1123 filter_expr,
1124 map_expr,
1125 ..
1126 } => {
1127 Self::collect_variables(list_expr, vars);
1128 if let Some(filter) = filter_expr {
1129 Self::collect_variables(filter, vars);
1130 }
1131 Self::collect_variables(map_expr, vars);
1132 }
1133 LogicalExpression::ListPredicate {
1134 list_expr,
1135 predicate,
1136 ..
1137 } => {
1138 Self::collect_variables(list_expr, vars);
1139 Self::collect_variables(predicate, vars);
1140 }
1141 LogicalExpression::ExistsSubquery(_)
1142 | LogicalExpression::CountSubquery(_)
1143 | LogicalExpression::ValueSubquery(_) => {
1144 }
1146 LogicalExpression::PatternComprehension { projection, .. } => {
1147 Self::collect_variables(projection, vars);
1148 }
1149 LogicalExpression::MapProjection { base, entries } => {
1150 vars.insert(base.clone());
1151 for entry in entries {
1152 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1153 Self::collect_variables(expr, vars);
1154 }
1155 }
1156 }
1157 LogicalExpression::Reduce {
1158 initial,
1159 list,
1160 expression,
1161 ..
1162 } => {
1163 Self::collect_variables(initial, vars);
1164 Self::collect_variables(list, vars);
1165 Self::collect_variables(expression, vars);
1166 }
1167 }
1168 }
1169
1170 fn extract_projection_aliases(
1172 &self,
1173 projections: &[crate::query::plan::Projection],
1174 ) -> HashSet<String> {
1175 projections.iter().filter_map(|p| p.alias.clone()).collect()
1176 }
1177}
1178
1179impl Default for Optimizer {
1180 fn default() -> Self {
1181 Self::new()
1182 }
1183}
1184
1185#[cfg(test)]
1186mod tests {
1187 use super::*;
1188 use crate::query::plan::{
1189 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1190 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1191 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1192 };
1193 use grafeo_common::types::Value;
1194
1195 #[test]
1196 fn test_optimizer_filter_pushdown_simple() {
1197 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1202 items: vec![ReturnItem {
1203 expression: LogicalExpression::Variable("n".to_string()),
1204 alias: None,
1205 }],
1206 distinct: false,
1207 input: Box::new(LogicalOperator::Filter(FilterOp {
1208 predicate: LogicalExpression::Binary {
1209 left: Box::new(LogicalExpression::Property {
1210 variable: "n".to_string(),
1211 property: "age".to_string(),
1212 }),
1213 op: BinaryOp::Gt,
1214 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1215 },
1216 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1217 variable: "n".to_string(),
1218 label: Some("Person".to_string()),
1219 input: None,
1220 })),
1221 pushdown_hint: None,
1222 })),
1223 }));
1224
1225 let optimizer = Optimizer::new();
1226 let optimized = optimizer.optimize(plan).unwrap();
1227
1228 if let LogicalOperator::Return(ret) = &optimized.root
1230 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1231 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1232 {
1233 assert_eq!(scan.variable, "n");
1234 return;
1235 }
1236 panic!("Expected Return -> Filter -> NodeScan structure");
1237 }
1238
1239 #[test]
1240 fn test_optimizer_filter_pushdown_through_expand() {
1241 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1245 items: vec![ReturnItem {
1246 expression: LogicalExpression::Variable("b".to_string()),
1247 alias: None,
1248 }],
1249 distinct: false,
1250 input: Box::new(LogicalOperator::Filter(FilterOp {
1251 predicate: LogicalExpression::Binary {
1252 left: Box::new(LogicalExpression::Property {
1253 variable: "a".to_string(),
1254 property: "age".to_string(),
1255 }),
1256 op: BinaryOp::Gt,
1257 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1258 },
1259 pushdown_hint: None,
1260 input: Box::new(LogicalOperator::Expand(ExpandOp {
1261 from_variable: "a".to_string(),
1262 to_variable: "b".to_string(),
1263 edge_variable: None,
1264 direction: ExpandDirection::Outgoing,
1265 edge_types: vec!["KNOWS".to_string()],
1266 min_hops: 1,
1267 max_hops: Some(1),
1268 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1269 variable: "a".to_string(),
1270 label: Some("Person".to_string()),
1271 input: None,
1272 })),
1273 path_alias: None,
1274 path_mode: PathMode::Walk,
1275 })),
1276 })),
1277 }));
1278
1279 let optimizer = Optimizer::new();
1280 let optimized = optimizer.optimize(plan).unwrap();
1281
1282 if let LogicalOperator::Return(ret) = &optimized.root
1285 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1286 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1287 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1288 {
1289 assert_eq!(scan.variable, "a");
1290 assert_eq!(expand.from_variable, "a");
1291 assert_eq!(expand.to_variable, "b");
1292 return;
1293 }
1294 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1295 }
1296
1297 #[test]
1298 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1299 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1303 items: vec![ReturnItem {
1304 expression: LogicalExpression::Variable("a".to_string()),
1305 alias: None,
1306 }],
1307 distinct: false,
1308 input: Box::new(LogicalOperator::Filter(FilterOp {
1309 predicate: LogicalExpression::Binary {
1310 left: Box::new(LogicalExpression::Property {
1311 variable: "b".to_string(),
1312 property: "age".to_string(),
1313 }),
1314 op: BinaryOp::Gt,
1315 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1316 },
1317 pushdown_hint: None,
1318 input: Box::new(LogicalOperator::Expand(ExpandOp {
1319 from_variable: "a".to_string(),
1320 to_variable: "b".to_string(),
1321 edge_variable: None,
1322 direction: ExpandDirection::Outgoing,
1323 edge_types: vec!["KNOWS".to_string()],
1324 min_hops: 1,
1325 max_hops: Some(1),
1326 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1327 variable: "a".to_string(),
1328 label: Some("Person".to_string()),
1329 input: None,
1330 })),
1331 path_alias: None,
1332 path_mode: PathMode::Walk,
1333 })),
1334 })),
1335 }));
1336
1337 let optimizer = Optimizer::new();
1338 let optimized = optimizer.optimize(plan).unwrap();
1339
1340 if let LogicalOperator::Return(ret) = &optimized.root
1343 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1344 {
1345 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1347 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1348 {
1349 assert_eq!(variable, "b");
1350 }
1351
1352 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1353 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1354 {
1355 return;
1356 }
1357 }
1358 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1359 }
1360
1361 #[test]
1362 fn test_optimizer_extract_variables() {
1363 let optimizer = Optimizer::new();
1364
1365 let expr = LogicalExpression::Binary {
1366 left: Box::new(LogicalExpression::Property {
1367 variable: "n".to_string(),
1368 property: "age".to_string(),
1369 }),
1370 op: BinaryOp::Gt,
1371 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1372 };
1373
1374 let vars = optimizer.extract_variables(&expr);
1375 assert_eq!(vars.len(), 1);
1376 assert!(vars.contains("n"));
1377 }
1378
1379 #[test]
1382 fn test_optimizer_default() {
1383 let optimizer = Optimizer::default();
1384 let plan = LogicalPlan::new(LogicalOperator::Empty);
1386 let result = optimizer.optimize(plan);
1387 assert!(result.is_ok());
1388 }
1389
1390 #[test]
1391 fn test_optimizer_with_filter_pushdown_disabled() {
1392 let optimizer = Optimizer::new().with_filter_pushdown(false);
1393
1394 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1395 items: vec![ReturnItem {
1396 expression: LogicalExpression::Variable("n".to_string()),
1397 alias: None,
1398 }],
1399 distinct: false,
1400 input: Box::new(LogicalOperator::Filter(FilterOp {
1401 predicate: LogicalExpression::Literal(Value::Bool(true)),
1402 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1403 variable: "n".to_string(),
1404 label: None,
1405 input: None,
1406 })),
1407 pushdown_hint: None,
1408 })),
1409 }));
1410
1411 let optimized = optimizer.optimize(plan).unwrap();
1412 if let LogicalOperator::Return(ret) = &optimized.root
1414 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1415 {
1416 return;
1417 }
1418 panic!("Expected unchanged structure");
1419 }
1420
1421 #[test]
1422 fn test_optimizer_with_join_reorder_disabled() {
1423 let optimizer = Optimizer::new().with_join_reorder(false);
1424 assert!(
1425 optimizer
1426 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1427 .is_ok()
1428 );
1429 }
1430
1431 #[test]
1432 fn test_optimizer_with_cost_model() {
1433 let cost_model = CostModel::new();
1434 let optimizer = Optimizer::new().with_cost_model(cost_model);
1435 assert!(
1436 optimizer
1437 .cost_model()
1438 .estimate(&LogicalOperator::Empty, 0.0)
1439 .total()
1440 < 0.001
1441 );
1442 }
1443
1444 #[test]
1445 fn test_optimizer_with_cardinality_estimator() {
1446 let mut estimator = CardinalityEstimator::new();
1447 estimator.add_table_stats("Test", TableStats::new(500));
1448 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1449
1450 let scan = LogicalOperator::NodeScan(NodeScanOp {
1451 variable: "n".to_string(),
1452 label: Some("Test".to_string()),
1453 input: None,
1454 });
1455 let plan = LogicalPlan::new(scan);
1456
1457 let cardinality = optimizer.estimate_cardinality(&plan);
1458 assert!((cardinality - 500.0).abs() < 0.001);
1459 }
1460
1461 #[test]
1462 fn test_optimizer_estimate_cost() {
1463 let optimizer = Optimizer::new();
1464 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1465 variable: "n".to_string(),
1466 label: None,
1467 input: None,
1468 }));
1469
1470 let cost = optimizer.estimate_cost(&plan);
1471 assert!(cost.total() > 0.0);
1472 }
1473
1474 #[test]
1477 fn test_filter_pushdown_through_project() {
1478 let optimizer = Optimizer::new();
1479
1480 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1481 predicate: LogicalExpression::Binary {
1482 left: Box::new(LogicalExpression::Property {
1483 variable: "n".to_string(),
1484 property: "age".to_string(),
1485 }),
1486 op: BinaryOp::Gt,
1487 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1488 },
1489 pushdown_hint: None,
1490 input: Box::new(LogicalOperator::Project(ProjectOp {
1491 projections: vec![Projection {
1492 expression: LogicalExpression::Variable("n".to_string()),
1493 alias: None,
1494 }],
1495 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1496 variable: "n".to_string(),
1497 label: None,
1498 input: None,
1499 })),
1500 pass_through_input: false,
1501 })),
1502 }));
1503
1504 let optimized = optimizer.optimize(plan).unwrap();
1505
1506 if let LogicalOperator::Project(proj) = &optimized.root
1508 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1509 {
1510 return;
1511 }
1512 panic!("Expected Project -> Filter structure");
1513 }
1514
1515 #[test]
1516 fn test_filter_not_pushed_through_project_with_alias() {
1517 let optimizer = Optimizer::new();
1518
1519 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1521 predicate: LogicalExpression::Binary {
1522 left: Box::new(LogicalExpression::Variable("x".to_string())),
1523 op: BinaryOp::Gt,
1524 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1525 },
1526 pushdown_hint: None,
1527 input: Box::new(LogicalOperator::Project(ProjectOp {
1528 projections: vec![Projection {
1529 expression: LogicalExpression::Property {
1530 variable: "n".to_string(),
1531 property: "age".to_string(),
1532 },
1533 alias: Some("x".to_string()),
1534 }],
1535 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1536 variable: "n".to_string(),
1537 label: None,
1538 input: None,
1539 })),
1540 pass_through_input: false,
1541 })),
1542 }));
1543
1544 let optimized = optimizer.optimize(plan).unwrap();
1545
1546 if let LogicalOperator::Filter(filter) = &optimized.root
1548 && let LogicalOperator::Project(_) = filter.input.as_ref()
1549 {
1550 return;
1551 }
1552 panic!("Expected Filter -> Project structure");
1553 }
1554
1555 #[test]
1556 fn test_filter_pushdown_through_limit() {
1557 let optimizer = Optimizer::new();
1558
1559 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1560 predicate: LogicalExpression::Literal(Value::Bool(true)),
1561 pushdown_hint: None,
1562 input: Box::new(LogicalOperator::Limit(LimitOp {
1563 count: 10.into(),
1564 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1565 variable: "n".to_string(),
1566 label: None,
1567 input: None,
1568 })),
1569 })),
1570 }));
1571
1572 let optimized = optimizer.optimize(plan).unwrap();
1573
1574 if let LogicalOperator::Filter(filter) = &optimized.root
1576 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1577 {
1578 return;
1579 }
1580 panic!("Expected Filter -> Limit structure");
1581 }
1582
1583 #[test]
1584 fn test_filter_pushdown_through_sort() {
1585 let optimizer = Optimizer::new();
1586
1587 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1588 predicate: LogicalExpression::Literal(Value::Bool(true)),
1589 pushdown_hint: None,
1590 input: Box::new(LogicalOperator::Sort(SortOp {
1591 keys: vec![SortKey {
1592 expression: LogicalExpression::Variable("n".to_string()),
1593 order: SortOrder::Ascending,
1594 nulls: None,
1595 }],
1596 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1597 variable: "n".to_string(),
1598 label: None,
1599 input: None,
1600 })),
1601 })),
1602 }));
1603
1604 let optimized = optimizer.optimize(plan).unwrap();
1605
1606 if let LogicalOperator::Filter(filter) = &optimized.root
1608 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1609 {
1610 return;
1611 }
1612 panic!("Expected Filter -> Sort structure");
1613 }
1614
1615 #[test]
1616 fn test_filter_pushdown_through_distinct() {
1617 let optimizer = Optimizer::new();
1618
1619 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1620 predicate: LogicalExpression::Literal(Value::Bool(true)),
1621 pushdown_hint: None,
1622 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1623 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1624 variable: "n".to_string(),
1625 label: None,
1626 input: None,
1627 })),
1628 columns: None,
1629 })),
1630 }));
1631
1632 let optimized = optimizer.optimize(plan).unwrap();
1633
1634 if let LogicalOperator::Filter(filter) = &optimized.root
1636 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1637 {
1638 return;
1639 }
1640 panic!("Expected Filter -> Distinct structure");
1641 }
1642
1643 #[test]
1644 fn test_filter_not_pushed_through_aggregate() {
1645 let optimizer = Optimizer::new();
1646
1647 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1648 predicate: LogicalExpression::Binary {
1649 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1650 op: BinaryOp::Gt,
1651 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1652 },
1653 pushdown_hint: None,
1654 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1655 group_by: vec![],
1656 aggregates: vec![AggregateExpr {
1657 function: AggregateFunction::Count,
1658 expression: None,
1659 expression2: None,
1660 distinct: false,
1661 alias: Some("cnt".to_string()),
1662 percentile: None,
1663 separator: None,
1664 }],
1665 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1666 variable: "n".to_string(),
1667 label: None,
1668 input: None,
1669 })),
1670 having: None,
1671 })),
1672 }));
1673
1674 let optimized = optimizer.optimize(plan).unwrap();
1675
1676 if let LogicalOperator::Filter(filter) = &optimized.root
1678 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1679 {
1680 return;
1681 }
1682 panic!("Expected Filter -> Aggregate structure");
1683 }
1684
1685 #[test]
1686 fn test_filter_pushdown_to_left_join_side() {
1687 let optimizer = Optimizer::new();
1688
1689 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1691 predicate: LogicalExpression::Binary {
1692 left: Box::new(LogicalExpression::Property {
1693 variable: "a".to_string(),
1694 property: "age".to_string(),
1695 }),
1696 op: BinaryOp::Gt,
1697 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1698 },
1699 pushdown_hint: None,
1700 input: Box::new(LogicalOperator::Join(JoinOp {
1701 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1702 variable: "a".to_string(),
1703 label: Some("Person".to_string()),
1704 input: None,
1705 })),
1706 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1707 variable: "b".to_string(),
1708 label: Some("Company".to_string()),
1709 input: None,
1710 })),
1711 join_type: JoinType::Inner,
1712 conditions: vec![],
1713 })),
1714 }));
1715
1716 let optimized = optimizer.optimize(plan).unwrap();
1717
1718 if let LogicalOperator::Join(join) = &optimized.root
1720 && let LogicalOperator::Filter(_) = join.left.as_ref()
1721 {
1722 return;
1723 }
1724 panic!("Expected Join with Filter on left side");
1725 }
1726
1727 #[test]
1728 fn test_filter_pushdown_to_right_join_side() {
1729 let optimizer = Optimizer::new();
1730
1731 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1733 predicate: LogicalExpression::Binary {
1734 left: Box::new(LogicalExpression::Property {
1735 variable: "b".to_string(),
1736 property: "name".to_string(),
1737 }),
1738 op: BinaryOp::Eq,
1739 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1740 },
1741 pushdown_hint: None,
1742 input: Box::new(LogicalOperator::Join(JoinOp {
1743 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1744 variable: "a".to_string(),
1745 label: Some("Person".to_string()),
1746 input: None,
1747 })),
1748 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1749 variable: "b".to_string(),
1750 label: Some("Company".to_string()),
1751 input: None,
1752 })),
1753 join_type: JoinType::Inner,
1754 conditions: vec![],
1755 })),
1756 }));
1757
1758 let optimized = optimizer.optimize(plan).unwrap();
1759
1760 if let LogicalOperator::Join(join) = &optimized.root
1762 && let LogicalOperator::Filter(_) = join.right.as_ref()
1763 {
1764 return;
1765 }
1766 panic!("Expected Join with Filter on right side");
1767 }
1768
1769 #[test]
1770 fn test_filter_not_pushed_when_uses_both_join_sides() {
1771 let optimizer = Optimizer::new();
1772
1773 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1775 predicate: LogicalExpression::Binary {
1776 left: Box::new(LogicalExpression::Property {
1777 variable: "a".to_string(),
1778 property: "id".to_string(),
1779 }),
1780 op: BinaryOp::Eq,
1781 right: Box::new(LogicalExpression::Property {
1782 variable: "b".to_string(),
1783 property: "a_id".to_string(),
1784 }),
1785 },
1786 pushdown_hint: None,
1787 input: Box::new(LogicalOperator::Join(JoinOp {
1788 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1789 variable: "a".to_string(),
1790 label: None,
1791 input: None,
1792 })),
1793 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1794 variable: "b".to_string(),
1795 label: None,
1796 input: None,
1797 })),
1798 join_type: JoinType::Inner,
1799 conditions: vec![],
1800 })),
1801 }));
1802
1803 let optimized = optimizer.optimize(plan).unwrap();
1804
1805 if let LogicalOperator::Filter(filter) = &optimized.root
1807 && let LogicalOperator::Join(_) = filter.input.as_ref()
1808 {
1809 return;
1810 }
1811 panic!("Expected Filter -> Join structure");
1812 }
1813
1814 #[test]
1817 fn test_extract_variables_from_variable() {
1818 let optimizer = Optimizer::new();
1819 let expr = LogicalExpression::Variable("x".to_string());
1820 let vars = optimizer.extract_variables(&expr);
1821 assert_eq!(vars.len(), 1);
1822 assert!(vars.contains("x"));
1823 }
1824
1825 #[test]
1826 fn test_extract_variables_from_unary() {
1827 let optimizer = Optimizer::new();
1828 let expr = LogicalExpression::Unary {
1829 op: UnaryOp::Not,
1830 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1831 };
1832 let vars = optimizer.extract_variables(&expr);
1833 assert_eq!(vars.len(), 1);
1834 assert!(vars.contains("x"));
1835 }
1836
1837 #[test]
1838 fn test_extract_variables_from_function_call() {
1839 let optimizer = Optimizer::new();
1840 let expr = LogicalExpression::FunctionCall {
1841 name: "length".to_string(),
1842 args: vec![
1843 LogicalExpression::Variable("a".to_string()),
1844 LogicalExpression::Variable("b".to_string()),
1845 ],
1846 distinct: false,
1847 };
1848 let vars = optimizer.extract_variables(&expr);
1849 assert_eq!(vars.len(), 2);
1850 assert!(vars.contains("a"));
1851 assert!(vars.contains("b"));
1852 }
1853
1854 #[test]
1855 fn test_extract_variables_from_list() {
1856 let optimizer = Optimizer::new();
1857 let expr = LogicalExpression::List(vec![
1858 LogicalExpression::Variable("a".to_string()),
1859 LogicalExpression::Literal(Value::Int64(1)),
1860 LogicalExpression::Variable("b".to_string()),
1861 ]);
1862 let vars = optimizer.extract_variables(&expr);
1863 assert_eq!(vars.len(), 2);
1864 assert!(vars.contains("a"));
1865 assert!(vars.contains("b"));
1866 }
1867
1868 #[test]
1869 fn test_extract_variables_from_map() {
1870 let optimizer = Optimizer::new();
1871 let expr = LogicalExpression::Map(vec![
1872 (
1873 "key1".to_string(),
1874 LogicalExpression::Variable("a".to_string()),
1875 ),
1876 (
1877 "key2".to_string(),
1878 LogicalExpression::Variable("b".to_string()),
1879 ),
1880 ]);
1881 let vars = optimizer.extract_variables(&expr);
1882 assert_eq!(vars.len(), 2);
1883 assert!(vars.contains("a"));
1884 assert!(vars.contains("b"));
1885 }
1886
1887 #[test]
1888 fn test_extract_variables_from_index_access() {
1889 let optimizer = Optimizer::new();
1890 let expr = LogicalExpression::IndexAccess {
1891 base: Box::new(LogicalExpression::Variable("list".to_string())),
1892 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1893 };
1894 let vars = optimizer.extract_variables(&expr);
1895 assert_eq!(vars.len(), 2);
1896 assert!(vars.contains("list"));
1897 assert!(vars.contains("idx"));
1898 }
1899
1900 #[test]
1901 fn test_extract_variables_from_slice_access() {
1902 let optimizer = Optimizer::new();
1903 let expr = LogicalExpression::SliceAccess {
1904 base: Box::new(LogicalExpression::Variable("list".to_string())),
1905 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1906 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1907 };
1908 let vars = optimizer.extract_variables(&expr);
1909 assert_eq!(vars.len(), 3);
1910 assert!(vars.contains("list"));
1911 assert!(vars.contains("s"));
1912 assert!(vars.contains("e"));
1913 }
1914
1915 #[test]
1916 fn test_extract_variables_from_case() {
1917 let optimizer = Optimizer::new();
1918 let expr = LogicalExpression::Case {
1919 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1920 when_clauses: vec![(
1921 LogicalExpression::Literal(Value::Int64(1)),
1922 LogicalExpression::Variable("a".to_string()),
1923 )],
1924 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1925 };
1926 let vars = optimizer.extract_variables(&expr);
1927 assert_eq!(vars.len(), 3);
1928 assert!(vars.contains("x"));
1929 assert!(vars.contains("a"));
1930 assert!(vars.contains("b"));
1931 }
1932
1933 #[test]
1934 fn test_extract_variables_from_labels() {
1935 let optimizer = Optimizer::new();
1936 let expr = LogicalExpression::Labels("n".to_string());
1937 let vars = optimizer.extract_variables(&expr);
1938 assert_eq!(vars.len(), 1);
1939 assert!(vars.contains("n"));
1940 }
1941
1942 #[test]
1943 fn test_extract_variables_from_type() {
1944 let optimizer = Optimizer::new();
1945 let expr = LogicalExpression::Type("e".to_string());
1946 let vars = optimizer.extract_variables(&expr);
1947 assert_eq!(vars.len(), 1);
1948 assert!(vars.contains("e"));
1949 }
1950
1951 #[test]
1952 fn test_extract_variables_from_id() {
1953 let optimizer = Optimizer::new();
1954 let expr = LogicalExpression::Id("n".to_string());
1955 let vars = optimizer.extract_variables(&expr);
1956 assert_eq!(vars.len(), 1);
1957 assert!(vars.contains("n"));
1958 }
1959
1960 #[test]
1961 fn test_extract_variables_from_list_comprehension() {
1962 let optimizer = Optimizer::new();
1963 let expr = LogicalExpression::ListComprehension {
1964 variable: "x".to_string(),
1965 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1966 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1967 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1968 };
1969 let vars = optimizer.extract_variables(&expr);
1970 assert!(vars.contains("items"));
1971 assert!(vars.contains("pred"));
1972 assert!(vars.contains("result"));
1973 }
1974
1975 #[test]
1976 fn test_extract_variables_from_literal_and_parameter() {
1977 let optimizer = Optimizer::new();
1978
1979 let literal = LogicalExpression::Literal(Value::Int64(42));
1980 assert!(optimizer.extract_variables(&literal).is_empty());
1981
1982 let param = LogicalExpression::Parameter("p".to_string());
1983 assert!(optimizer.extract_variables(¶m).is_empty());
1984 }
1985
1986 #[test]
1989 fn test_recursive_filter_pushdown_through_skip() {
1990 let optimizer = Optimizer::new();
1991
1992 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1993 items: vec![ReturnItem {
1994 expression: LogicalExpression::Variable("n".to_string()),
1995 alias: None,
1996 }],
1997 distinct: false,
1998 input: Box::new(LogicalOperator::Filter(FilterOp {
1999 predicate: LogicalExpression::Literal(Value::Bool(true)),
2000 pushdown_hint: None,
2001 input: Box::new(LogicalOperator::Skip(SkipOp {
2002 count: 5.into(),
2003 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2004 variable: "n".to_string(),
2005 label: None,
2006 input: None,
2007 })),
2008 })),
2009 })),
2010 }));
2011
2012 let optimized = optimizer.optimize(plan).unwrap();
2013
2014 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2016 }
2017
2018 #[test]
2019 fn test_nested_filter_pushdown() {
2020 let optimizer = Optimizer::new();
2021
2022 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2024 items: vec![ReturnItem {
2025 expression: LogicalExpression::Variable("n".to_string()),
2026 alias: None,
2027 }],
2028 distinct: false,
2029 input: Box::new(LogicalOperator::Filter(FilterOp {
2030 predicate: LogicalExpression::Binary {
2031 left: Box::new(LogicalExpression::Property {
2032 variable: "n".to_string(),
2033 property: "x".to_string(),
2034 }),
2035 op: BinaryOp::Gt,
2036 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2037 },
2038 pushdown_hint: None,
2039 input: Box::new(LogicalOperator::Filter(FilterOp {
2040 predicate: LogicalExpression::Binary {
2041 left: Box::new(LogicalExpression::Property {
2042 variable: "n".to_string(),
2043 property: "y".to_string(),
2044 }),
2045 op: BinaryOp::Lt,
2046 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2047 },
2048 pushdown_hint: None,
2049 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2050 variable: "n".to_string(),
2051 label: None,
2052 input: None,
2053 })),
2054 })),
2055 })),
2056 }));
2057
2058 let optimized = optimizer.optimize(plan).unwrap();
2059 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2060 }
2061
2062 #[test]
2063 fn test_cyclic_join_produces_multi_way_join() {
2064 use crate::query::plan::JoinCondition;
2065
2066 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2068 variable: "a".to_string(),
2069 label: Some("Person".to_string()),
2070 input: None,
2071 });
2072 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2073 variable: "b".to_string(),
2074 label: Some("Person".to_string()),
2075 input: None,
2076 });
2077 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2078 variable: "c".to_string(),
2079 label: Some("Person".to_string()),
2080 input: None,
2081 });
2082
2083 let join_ab = LogicalOperator::Join(JoinOp {
2085 left: Box::new(scan_a),
2086 right: Box::new(scan_b),
2087 join_type: JoinType::Inner,
2088 conditions: vec![JoinCondition {
2089 left: LogicalExpression::Variable("a".to_string()),
2090 right: LogicalExpression::Variable("b".to_string()),
2091 }],
2092 });
2093
2094 let join_abc = LogicalOperator::Join(JoinOp {
2095 left: Box::new(join_ab),
2096 right: Box::new(scan_c),
2097 join_type: JoinType::Inner,
2098 conditions: vec![
2099 JoinCondition {
2100 left: LogicalExpression::Variable("b".to_string()),
2101 right: LogicalExpression::Variable("c".to_string()),
2102 },
2103 JoinCondition {
2104 left: LogicalExpression::Variable("c".to_string()),
2105 right: LogicalExpression::Variable("a".to_string()),
2106 },
2107 ],
2108 });
2109
2110 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2111 items: vec![ReturnItem {
2112 expression: LogicalExpression::Variable("a".to_string()),
2113 alias: None,
2114 }],
2115 distinct: false,
2116 input: Box::new(join_abc),
2117 }));
2118
2119 let mut optimizer = Optimizer::new();
2120 optimizer
2121 .card_estimator
2122 .add_table_stats("Person", cardinality::TableStats::new(1000));
2123
2124 let optimized = optimizer.optimize(plan).unwrap();
2125
2126 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2128 match op {
2129 LogicalOperator::MultiWayJoin(_) => true,
2130 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2131 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2132 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2133 _ => false,
2134 }
2135 }
2136
2137 assert!(
2138 has_multi_way_join(&optimized.root),
2139 "Expected MultiWayJoin for cyclic triangle pattern"
2140 );
2141 }
2142
2143 #[test]
2144 fn test_acyclic_join_uses_binary_joins() {
2145 use crate::query::plan::JoinCondition;
2146
2147 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2149 variable: "a".to_string(),
2150 label: Some("Person".to_string()),
2151 input: None,
2152 });
2153 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2154 variable: "b".to_string(),
2155 label: Some("Person".to_string()),
2156 input: None,
2157 });
2158 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2159 variable: "c".to_string(),
2160 label: Some("Company".to_string()),
2161 input: None,
2162 });
2163
2164 let join_ab = LogicalOperator::Join(JoinOp {
2165 left: Box::new(scan_a),
2166 right: Box::new(scan_b),
2167 join_type: JoinType::Inner,
2168 conditions: vec![JoinCondition {
2169 left: LogicalExpression::Variable("a".to_string()),
2170 right: LogicalExpression::Variable("b".to_string()),
2171 }],
2172 });
2173
2174 let join_abc = LogicalOperator::Join(JoinOp {
2175 left: Box::new(join_ab),
2176 right: Box::new(scan_c),
2177 join_type: JoinType::Inner,
2178 conditions: vec![JoinCondition {
2179 left: LogicalExpression::Variable("b".to_string()),
2180 right: LogicalExpression::Variable("c".to_string()),
2181 }],
2182 });
2183
2184 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2185 items: vec![ReturnItem {
2186 expression: LogicalExpression::Variable("a".to_string()),
2187 alias: None,
2188 }],
2189 distinct: false,
2190 input: Box::new(join_abc),
2191 }));
2192
2193 let mut optimizer = Optimizer::new();
2194 optimizer
2195 .card_estimator
2196 .add_table_stats("Person", cardinality::TableStats::new(1000));
2197 optimizer
2198 .card_estimator
2199 .add_table_stats("Company", cardinality::TableStats::new(100));
2200
2201 let optimized = optimizer.optimize(plan).unwrap();
2202
2203 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2205 match op {
2206 LogicalOperator::MultiWayJoin(_) => true,
2207 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2208 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2209 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2210 LogicalOperator::Join(j) => {
2211 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2212 }
2213 _ => false,
2214 }
2215 }
2216
2217 assert!(
2218 !has_multi_way_join(&optimized.root),
2219 "Acyclic join should NOT produce MultiWayJoin"
2220 );
2221 }
2222}