1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25 FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// A single join predicate captured while flattening a join tree for reordering.
///
/// `left_var` / `right_var` are the variables each side of the predicate refers
/// to (as returned by `extract_variable_from_expr`); `left_expr` / `right_expr`
/// keep the original expressions so the predicate can be handed to the join
/// graph builder and re-attached to the rebuilt plan.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression.
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// Original left-hand expression of the join condition.
    left_expr: LogicalExpression,
    /// Original right-hand expression of the join condition.
    right_expr: LogicalExpression,
}
38
/// A column read by some operator in the plan, as collected by the
/// projection-pushdown analysis.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable binding, by name.
    Variable(String),
    /// A property access `variable.property`, stored as `(variable, property)`.
    Property(String, String),
}
47
/// Rule- and cost-based optimizer for logical query plans.
///
/// `optimize` applies filter pushdown, join reordering and projection pushdown,
/// in that order; each pass can be toggled with the `with_*` builders.
pub struct Optimizer {
    /// Whether `optimize` runs the filter-pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join-reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection-pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used to estimate plan cost and to rank join orders.
    cost_model: CostModel,
    /// Cardinality estimator used together with the cost model.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66 #[must_use]
68 pub fn new() -> Self {
69 Self {
70 enable_filter_pushdown: true,
71 enable_join_reorder: true,
72 enable_projection_pushdown: true,
73 cost_model: CostModel::new(),
74 card_estimator: CardinalityEstimator::new(),
75 }
76 }
77
78 #[must_use]
84 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85 store.ensure_statistics_fresh();
86 let stats = store.statistics();
87 let estimator = CardinalityEstimator::from_statistics(&stats);
88
89 let avg_fanout = if stats.total_nodes > 0 {
91 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
92 } else {
93 10.0
94 };
95
96 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
98 .edge_types
99 .iter()
100 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
101 .collect();
102
103 Self {
104 enable_filter_pushdown: true,
105 enable_join_reorder: true,
106 enable_projection_pushdown: true,
107 cost_model: CostModel::new()
108 .with_avg_fanout(avg_fanout)
109 .with_edge_type_degrees(edge_type_degrees),
110 card_estimator: estimator,
111 }
112 }
113
114 #[must_use]
121 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
122 let stats = store.statistics();
123 let estimator = CardinalityEstimator::from_statistics(&stats);
124
125 let avg_fanout = if stats.total_nodes > 0 {
126 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
127 } else {
128 10.0
129 };
130
131 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
132 .edge_types
133 .iter()
134 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
135 .collect();
136
137 Self {
138 enable_filter_pushdown: true,
139 enable_join_reorder: true,
140 enable_projection_pushdown: true,
141 cost_model: CostModel::new()
142 .with_avg_fanout(avg_fanout)
143 .with_edge_type_degrees(edge_type_degrees),
144 card_estimator: estimator,
145 }
146 }
147
148 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
150 self.enable_filter_pushdown = enabled;
151 self
152 }
153
154 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
156 self.enable_join_reorder = enabled;
157 self
158 }
159
160 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
162 self.enable_projection_pushdown = enabled;
163 self
164 }
165
166 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
168 self.cost_model = cost_model;
169 self
170 }
171
172 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
174 self.card_estimator = estimator;
175 self
176 }
177
178 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
180 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
181 self
182 }
183
184 pub fn cost_model(&self) -> &CostModel {
186 &self.cost_model
187 }
188
189 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
191 &self.card_estimator
192 }
193
194 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
196 let cardinality = self.card_estimator.estimate(&plan.root);
197 self.cost_model.estimate(&plan.root, cardinality)
198 }
199
200 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
202 self.card_estimator.estimate(&plan.root)
203 }
204
205 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
211 let mut root = plan.root;
212
213 if self.enable_filter_pushdown {
215 root = self.push_filters_down(root);
216 }
217
218 if self.enable_join_reorder {
219 root = self.reorder_joins(root);
220 }
221
222 if self.enable_projection_pushdown {
223 root = self.push_projections_down(root);
224 }
225
226 Ok(LogicalPlan::new(root))
227 }
228
229 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
236 let required = self.collect_required_columns(&op);
238
239 self.push_projections_recursive(op, &required)
241 }
242
243 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
245 let mut required = HashSet::new();
246 Self::collect_required_recursive(op, &mut required);
247 required
248 }
249
250 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
252 match op {
253 LogicalOperator::Return(ret) => {
254 for item in &ret.items {
255 Self::collect_from_expression(&item.expression, required);
256 }
257 Self::collect_required_recursive(&ret.input, required);
258 }
259 LogicalOperator::Project(proj) => {
260 for p in &proj.projections {
261 Self::collect_from_expression(&p.expression, required);
262 }
263 Self::collect_required_recursive(&proj.input, required);
264 }
265 LogicalOperator::Filter(filter) => {
266 Self::collect_from_expression(&filter.predicate, required);
267 Self::collect_required_recursive(&filter.input, required);
268 }
269 LogicalOperator::Sort(sort) => {
270 for key in &sort.keys {
271 Self::collect_from_expression(&key.expression, required);
272 }
273 Self::collect_required_recursive(&sort.input, required);
274 }
275 LogicalOperator::Aggregate(agg) => {
276 for expr in &agg.group_by {
277 Self::collect_from_expression(expr, required);
278 }
279 for agg_expr in &agg.aggregates {
280 if let Some(ref expr) = agg_expr.expression {
281 Self::collect_from_expression(expr, required);
282 }
283 }
284 if let Some(ref having) = agg.having {
285 Self::collect_from_expression(having, required);
286 }
287 Self::collect_required_recursive(&agg.input, required);
288 }
289 LogicalOperator::Join(join) => {
290 for cond in &join.conditions {
291 Self::collect_from_expression(&cond.left, required);
292 Self::collect_from_expression(&cond.right, required);
293 }
294 Self::collect_required_recursive(&join.left, required);
295 Self::collect_required_recursive(&join.right, required);
296 }
297 LogicalOperator::Expand(expand) => {
298 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
300 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
301 if let Some(ref edge_var) = expand.edge_variable {
302 required.insert(RequiredColumn::Variable(edge_var.clone()));
303 }
304 Self::collect_required_recursive(&expand.input, required);
305 }
306 LogicalOperator::Limit(limit) => {
307 Self::collect_required_recursive(&limit.input, required);
308 }
309 LogicalOperator::Skip(skip) => {
310 Self::collect_required_recursive(&skip.input, required);
311 }
312 LogicalOperator::Distinct(distinct) => {
313 Self::collect_required_recursive(&distinct.input, required);
314 }
315 LogicalOperator::NodeScan(scan) => {
316 required.insert(RequiredColumn::Variable(scan.variable.clone()));
317 }
318 LogicalOperator::EdgeScan(scan) => {
319 required.insert(RequiredColumn::Variable(scan.variable.clone()));
320 }
321 LogicalOperator::MultiWayJoin(mwj) => {
322 for cond in &mwj.conditions {
323 Self::collect_from_expression(&cond.left, required);
324 Self::collect_from_expression(&cond.right, required);
325 }
326 for input in &mwj.inputs {
327 Self::collect_required_recursive(input, required);
328 }
329 }
330 _ => {}
331 }
332 }
333
334 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
336 match expr {
337 LogicalExpression::Variable(var) => {
338 required.insert(RequiredColumn::Variable(var.clone()));
339 }
340 LogicalExpression::Property { variable, property } => {
341 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
342 required.insert(RequiredColumn::Variable(variable.clone()));
343 }
344 LogicalExpression::Binary { left, right, .. } => {
345 Self::collect_from_expression(left, required);
346 Self::collect_from_expression(right, required);
347 }
348 LogicalExpression::Unary { operand, .. } => {
349 Self::collect_from_expression(operand, required);
350 }
351 LogicalExpression::FunctionCall { args, .. } => {
352 for arg in args {
353 Self::collect_from_expression(arg, required);
354 }
355 }
356 LogicalExpression::List(items) => {
357 for item in items {
358 Self::collect_from_expression(item, required);
359 }
360 }
361 LogicalExpression::Map(pairs) => {
362 for (_, value) in pairs {
363 Self::collect_from_expression(value, required);
364 }
365 }
366 LogicalExpression::IndexAccess { base, index } => {
367 Self::collect_from_expression(base, required);
368 Self::collect_from_expression(index, required);
369 }
370 LogicalExpression::SliceAccess { base, start, end } => {
371 Self::collect_from_expression(base, required);
372 if let Some(s) = start {
373 Self::collect_from_expression(s, required);
374 }
375 if let Some(e) = end {
376 Self::collect_from_expression(e, required);
377 }
378 }
379 LogicalExpression::Case {
380 operand,
381 when_clauses,
382 else_clause,
383 } => {
384 if let Some(op) = operand {
385 Self::collect_from_expression(op, required);
386 }
387 for (cond, result) in when_clauses {
388 Self::collect_from_expression(cond, required);
389 Self::collect_from_expression(result, required);
390 }
391 if let Some(else_expr) = else_clause {
392 Self::collect_from_expression(else_expr, required);
393 }
394 }
395 LogicalExpression::Labels(var)
396 | LogicalExpression::Type(var)
397 | LogicalExpression::Id(var) => {
398 required.insert(RequiredColumn::Variable(var.clone()));
399 }
400 LogicalExpression::ListComprehension {
401 list_expr,
402 filter_expr,
403 map_expr,
404 ..
405 } => {
406 Self::collect_from_expression(list_expr, required);
407 if let Some(filter) = filter_expr {
408 Self::collect_from_expression(filter, required);
409 }
410 Self::collect_from_expression(map_expr, required);
411 }
412 _ => {}
413 }
414 }
415
416 fn push_projections_recursive(
418 &self,
419 op: LogicalOperator,
420 required: &HashSet<RequiredColumn>,
421 ) -> LogicalOperator {
422 match op {
423 LogicalOperator::Return(mut ret) => {
424 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
425 LogicalOperator::Return(ret)
426 }
427 LogicalOperator::Project(mut proj) => {
428 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
429 LogicalOperator::Project(proj)
430 }
431 LogicalOperator::Filter(mut filter) => {
432 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
433 LogicalOperator::Filter(filter)
434 }
435 LogicalOperator::Sort(mut sort) => {
436 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
439 LogicalOperator::Sort(sort)
440 }
441 LogicalOperator::Aggregate(mut agg) => {
442 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
443 LogicalOperator::Aggregate(agg)
444 }
445 LogicalOperator::Join(mut join) => {
446 let left_vars = self.collect_output_variables(&join.left);
449 let right_vars = self.collect_output_variables(&join.right);
450
451 let left_required: HashSet<_> = required
453 .iter()
454 .filter(|c| match c {
455 RequiredColumn::Variable(v) => left_vars.contains(v),
456 RequiredColumn::Property(v, _) => left_vars.contains(v),
457 })
458 .cloned()
459 .collect();
460
461 let right_required: HashSet<_> = required
462 .iter()
463 .filter(|c| match c {
464 RequiredColumn::Variable(v) => right_vars.contains(v),
465 RequiredColumn::Property(v, _) => right_vars.contains(v),
466 })
467 .cloned()
468 .collect();
469
470 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
471 join.right =
472 Box::new(self.push_projections_recursive(*join.right, &right_required));
473 LogicalOperator::Join(join)
474 }
475 LogicalOperator::Expand(mut expand) => {
476 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
477 LogicalOperator::Expand(expand)
478 }
479 LogicalOperator::Limit(mut limit) => {
480 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
481 LogicalOperator::Limit(limit)
482 }
483 LogicalOperator::Skip(mut skip) => {
484 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
485 LogicalOperator::Skip(skip)
486 }
487 LogicalOperator::Distinct(mut distinct) => {
488 distinct.input =
489 Box::new(self.push_projections_recursive(*distinct.input, required));
490 LogicalOperator::Distinct(distinct)
491 }
492 LogicalOperator::MapCollect(mut mc) => {
493 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
494 LogicalOperator::MapCollect(mc)
495 }
496 LogicalOperator::MultiWayJoin(mut mwj) => {
497 mwj.inputs = mwj
498 .inputs
499 .into_iter()
500 .map(|input| self.push_projections_recursive(input, required))
501 .collect();
502 LogicalOperator::MultiWayJoin(mwj)
503 }
504 other => other,
505 }
506 }
507
508 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
515 let op = self.reorder_joins_recursive(op);
517
518 if let Some((relations, conditions)) = self.extract_join_tree(&op)
520 && relations.len() >= 2
521 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
522 {
523 return optimized;
524 }
525
526 op
527 }
528
529 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
531 match op {
532 LogicalOperator::Return(mut ret) => {
533 ret.input = Box::new(self.reorder_joins(*ret.input));
534 LogicalOperator::Return(ret)
535 }
536 LogicalOperator::Project(mut proj) => {
537 proj.input = Box::new(self.reorder_joins(*proj.input));
538 LogicalOperator::Project(proj)
539 }
540 LogicalOperator::Filter(mut filter) => {
541 filter.input = Box::new(self.reorder_joins(*filter.input));
542 LogicalOperator::Filter(filter)
543 }
544 LogicalOperator::Limit(mut limit) => {
545 limit.input = Box::new(self.reorder_joins(*limit.input));
546 LogicalOperator::Limit(limit)
547 }
548 LogicalOperator::Skip(mut skip) => {
549 skip.input = Box::new(self.reorder_joins(*skip.input));
550 LogicalOperator::Skip(skip)
551 }
552 LogicalOperator::Sort(mut sort) => {
553 sort.input = Box::new(self.reorder_joins(*sort.input));
554 LogicalOperator::Sort(sort)
555 }
556 LogicalOperator::Distinct(mut distinct) => {
557 distinct.input = Box::new(self.reorder_joins(*distinct.input));
558 LogicalOperator::Distinct(distinct)
559 }
560 LogicalOperator::Aggregate(mut agg) => {
561 agg.input = Box::new(self.reorder_joins(*agg.input));
562 LogicalOperator::Aggregate(agg)
563 }
564 LogicalOperator::Expand(mut expand) => {
565 expand.input = Box::new(self.reorder_joins(*expand.input));
566 LogicalOperator::Expand(expand)
567 }
568 LogicalOperator::MapCollect(mut mc) => {
569 mc.input = Box::new(self.reorder_joins(*mc.input));
570 LogicalOperator::MapCollect(mc)
571 }
572 LogicalOperator::MultiWayJoin(mut mwj) => {
573 mwj.inputs = mwj
574 .inputs
575 .into_iter()
576 .map(|input| self.reorder_joins(input))
577 .collect();
578 LogicalOperator::MultiWayJoin(mwj)
579 }
580 other => other,
582 }
583 }
584
585 fn extract_join_tree(
589 &self,
590 op: &LogicalOperator,
591 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
592 let mut relations = Vec::new();
593 let mut join_conditions = Vec::new();
594
595 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
596 return None;
597 }
598
599 if relations.len() < 2 {
600 return None;
601 }
602
603 Some((relations, join_conditions))
604 }
605
606 fn collect_join_tree(
610 &self,
611 op: &LogicalOperator,
612 relations: &mut Vec<(String, LogicalOperator)>,
613 conditions: &mut Vec<JoinInfo>,
614 ) -> bool {
615 match op {
616 LogicalOperator::Join(join) => {
617 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
619 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
620
621 for cond in &join.conditions {
623 if let (Some(left_var), Some(right_var)) = (
624 self.extract_variable_from_expr(&cond.left),
625 self.extract_variable_from_expr(&cond.right),
626 ) {
627 conditions.push(JoinInfo {
628 left_var,
629 right_var,
630 left_expr: cond.left.clone(),
631 right_expr: cond.right.clone(),
632 });
633 }
634 }
635
636 left_ok && right_ok
637 }
638 LogicalOperator::NodeScan(scan) => {
639 relations.push((scan.variable.clone(), op.clone()));
640 true
641 }
642 LogicalOperator::EdgeScan(scan) => {
643 relations.push((scan.variable.clone(), op.clone()));
644 true
645 }
646 LogicalOperator::Filter(filter) => {
647 self.collect_join_tree(&filter.input, relations, conditions)
649 }
650 LogicalOperator::Expand(expand) => {
651 relations.push((expand.to_variable.clone(), op.clone()));
654 true
655 }
656 _ => false,
657 }
658 }
659
660 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
662 match expr {
663 LogicalExpression::Variable(v) => Some(v.clone()),
664 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
665 _ => None,
666 }
667 }
668
669 fn optimize_join_order(
672 &self,
673 relations: &[(String, LogicalOperator)],
674 conditions: &[JoinInfo],
675 ) -> Option<LogicalOperator> {
676 use join_order::{DPccp, JoinGraphBuilder};
677
678 let mut builder = JoinGraphBuilder::new();
680
681 for (var, relation) in relations {
682 builder.add_relation(var, relation.clone());
683 }
684
685 for cond in conditions {
686 builder.add_join_condition(
687 &cond.left_var,
688 &cond.right_var,
689 cond.left_expr.clone(),
690 cond.right_expr.clone(),
691 );
692 }
693
694 let graph = builder.build();
695
696 if graph.is_cyclic() && relations.len() >= 3 {
701 let mut var_counts: std::collections::HashMap<&str, usize> =
703 std::collections::HashMap::new();
704 for cond in conditions {
705 *var_counts.entry(&cond.left_var).or_default() += 1;
706 *var_counts.entry(&cond.right_var).or_default() += 1;
707 }
708 let shared_variables: Vec<String> = var_counts
709 .into_iter()
710 .filter(|(_, count)| *count >= 2)
711 .map(|(var, _)| var.to_string())
712 .collect();
713
714 let join_conditions: Vec<JoinCondition> = conditions
715 .iter()
716 .map(|c| JoinCondition {
717 left: c.left_expr.clone(),
718 right: c.right_expr.clone(),
719 })
720 .collect();
721
722 return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
723 inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
724 conditions: join_conditions,
725 shared_variables,
726 }));
727 }
728
729 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
731 let plan = dpccp.optimize()?;
732
733 Some(plan.operator)
734 }
735
736 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
741 match op {
742 LogicalOperator::Filter(filter) => {
744 let optimized_input = self.push_filters_down(*filter.input);
745 self.try_push_filter_into(filter.predicate, optimized_input)
746 }
747 LogicalOperator::Return(mut ret) => {
749 ret.input = Box::new(self.push_filters_down(*ret.input));
750 LogicalOperator::Return(ret)
751 }
752 LogicalOperator::Project(mut proj) => {
753 proj.input = Box::new(self.push_filters_down(*proj.input));
754 LogicalOperator::Project(proj)
755 }
756 LogicalOperator::Limit(mut limit) => {
757 limit.input = Box::new(self.push_filters_down(*limit.input));
758 LogicalOperator::Limit(limit)
759 }
760 LogicalOperator::Skip(mut skip) => {
761 skip.input = Box::new(self.push_filters_down(*skip.input));
762 LogicalOperator::Skip(skip)
763 }
764 LogicalOperator::Sort(mut sort) => {
765 sort.input = Box::new(self.push_filters_down(*sort.input));
766 LogicalOperator::Sort(sort)
767 }
768 LogicalOperator::Distinct(mut distinct) => {
769 distinct.input = Box::new(self.push_filters_down(*distinct.input));
770 LogicalOperator::Distinct(distinct)
771 }
772 LogicalOperator::Expand(mut expand) => {
773 expand.input = Box::new(self.push_filters_down(*expand.input));
774 LogicalOperator::Expand(expand)
775 }
776 LogicalOperator::Join(mut join) => {
777 join.left = Box::new(self.push_filters_down(*join.left));
778 join.right = Box::new(self.push_filters_down(*join.right));
779 LogicalOperator::Join(join)
780 }
781 LogicalOperator::Aggregate(mut agg) => {
782 agg.input = Box::new(self.push_filters_down(*agg.input));
783 LogicalOperator::Aggregate(agg)
784 }
785 LogicalOperator::MapCollect(mut mc) => {
786 mc.input = Box::new(self.push_filters_down(*mc.input));
787 LogicalOperator::MapCollect(mc)
788 }
789 LogicalOperator::MultiWayJoin(mut mwj) => {
790 mwj.inputs = mwj
791 .inputs
792 .into_iter()
793 .map(|input| self.push_filters_down(input))
794 .collect();
795 LogicalOperator::MultiWayJoin(mwj)
796 }
797 other => other,
799 }
800 }
801
802 fn try_push_filter_into(
807 &self,
808 predicate: LogicalExpression,
809 op: LogicalOperator,
810 ) -> LogicalOperator {
811 match op {
812 LogicalOperator::Project(mut proj) => {
814 let predicate_vars = self.extract_variables(&predicate);
815 let computed_vars = self.extract_projection_aliases(&proj.projections);
816
817 if predicate_vars.is_disjoint(&computed_vars) {
819 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
820 LogicalOperator::Project(proj)
821 } else {
822 LogicalOperator::Filter(FilterOp {
824 predicate,
825 pushdown_hint: None,
826 input: Box::new(LogicalOperator::Project(proj)),
827 })
828 }
829 }
830
831 LogicalOperator::Return(mut ret) => {
833 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
834 LogicalOperator::Return(ret)
835 }
836
837 LogicalOperator::Expand(mut expand) => {
839 let predicate_vars = self.extract_variables(&predicate);
840
841 let mut introduced_vars = vec![&expand.to_variable];
846 if let Some(ref edge_var) = expand.edge_variable {
847 introduced_vars.push(edge_var);
848 }
849 if let Some(ref path_alias) = expand.path_alias {
850 introduced_vars.push(path_alias);
851 }
852
853 let uses_introduced_vars =
855 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
856
857 if !uses_introduced_vars {
858 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
860 LogicalOperator::Expand(expand)
861 } else {
862 LogicalOperator::Filter(FilterOp {
864 predicate,
865 pushdown_hint: None,
866 input: Box::new(LogicalOperator::Expand(expand)),
867 })
868 }
869 }
870
871 LogicalOperator::Join(mut join) => {
873 let predicate_vars = self.extract_variables(&predicate);
874 let left_vars = self.collect_output_variables(&join.left);
875 let right_vars = self.collect_output_variables(&join.right);
876
877 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
878 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
879
880 if uses_left && !uses_right {
881 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
883 LogicalOperator::Join(join)
884 } else if uses_right && !uses_left {
885 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
887 LogicalOperator::Join(join)
888 } else {
889 LogicalOperator::Filter(FilterOp {
891 predicate,
892 pushdown_hint: None,
893 input: Box::new(LogicalOperator::Join(join)),
894 })
895 }
896 }
897
898 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
900 predicate,
901 pushdown_hint: None,
902 input: Box::new(LogicalOperator::Aggregate(agg)),
903 }),
904
905 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
907 predicate,
908 pushdown_hint: None,
909 input: Box::new(LogicalOperator::NodeScan(scan)),
910 }),
911
912 other => LogicalOperator::Filter(FilterOp {
914 predicate,
915 pushdown_hint: None,
916 input: Box::new(other),
917 }),
918 }
919 }
920
921 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
923 let mut vars = HashSet::new();
924 Self::collect_output_variables_recursive(op, &mut vars);
925 vars
926 }
927
928 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
930 match op {
931 LogicalOperator::NodeScan(scan) => {
932 vars.insert(scan.variable.clone());
933 }
934 LogicalOperator::EdgeScan(scan) => {
935 vars.insert(scan.variable.clone());
936 }
937 LogicalOperator::Expand(expand) => {
938 vars.insert(expand.to_variable.clone());
939 if let Some(edge_var) = &expand.edge_variable {
940 vars.insert(edge_var.clone());
941 }
942 Self::collect_output_variables_recursive(&expand.input, vars);
943 }
944 LogicalOperator::Filter(filter) => {
945 Self::collect_output_variables_recursive(&filter.input, vars);
946 }
947 LogicalOperator::Project(proj) => {
948 for p in &proj.projections {
949 if let Some(alias) = &p.alias {
950 vars.insert(alias.clone());
951 }
952 }
953 Self::collect_output_variables_recursive(&proj.input, vars);
954 }
955 LogicalOperator::Join(join) => {
956 Self::collect_output_variables_recursive(&join.left, vars);
957 Self::collect_output_variables_recursive(&join.right, vars);
958 }
959 LogicalOperator::Aggregate(agg) => {
960 for expr in &agg.group_by {
961 Self::collect_variables(expr, vars);
962 }
963 for agg_expr in &agg.aggregates {
964 if let Some(alias) = &agg_expr.alias {
965 vars.insert(alias.clone());
966 }
967 }
968 }
969 LogicalOperator::Return(ret) => {
970 Self::collect_output_variables_recursive(&ret.input, vars);
971 }
972 LogicalOperator::Limit(limit) => {
973 Self::collect_output_variables_recursive(&limit.input, vars);
974 }
975 LogicalOperator::Skip(skip) => {
976 Self::collect_output_variables_recursive(&skip.input, vars);
977 }
978 LogicalOperator::Sort(sort) => {
979 Self::collect_output_variables_recursive(&sort.input, vars);
980 }
981 LogicalOperator::Distinct(distinct) => {
982 Self::collect_output_variables_recursive(&distinct.input, vars);
983 }
984 _ => {}
985 }
986 }
987
988 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
990 let mut vars = HashSet::new();
991 Self::collect_variables(expr, &mut vars);
992 vars
993 }
994
995 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
997 match expr {
998 LogicalExpression::Variable(name) => {
999 vars.insert(name.clone());
1000 }
1001 LogicalExpression::Property { variable, .. } => {
1002 vars.insert(variable.clone());
1003 }
1004 LogicalExpression::Binary { left, right, .. } => {
1005 Self::collect_variables(left, vars);
1006 Self::collect_variables(right, vars);
1007 }
1008 LogicalExpression::Unary { operand, .. } => {
1009 Self::collect_variables(operand, vars);
1010 }
1011 LogicalExpression::FunctionCall { args, .. } => {
1012 for arg in args {
1013 Self::collect_variables(arg, vars);
1014 }
1015 }
1016 LogicalExpression::List(items) => {
1017 for item in items {
1018 Self::collect_variables(item, vars);
1019 }
1020 }
1021 LogicalExpression::Map(pairs) => {
1022 for (_, value) in pairs {
1023 Self::collect_variables(value, vars);
1024 }
1025 }
1026 LogicalExpression::IndexAccess { base, index } => {
1027 Self::collect_variables(base, vars);
1028 Self::collect_variables(index, vars);
1029 }
1030 LogicalExpression::SliceAccess { base, start, end } => {
1031 Self::collect_variables(base, vars);
1032 if let Some(s) = start {
1033 Self::collect_variables(s, vars);
1034 }
1035 if let Some(e) = end {
1036 Self::collect_variables(e, vars);
1037 }
1038 }
1039 LogicalExpression::Case {
1040 operand,
1041 when_clauses,
1042 else_clause,
1043 } => {
1044 if let Some(op) = operand {
1045 Self::collect_variables(op, vars);
1046 }
1047 for (cond, result) in when_clauses {
1048 Self::collect_variables(cond, vars);
1049 Self::collect_variables(result, vars);
1050 }
1051 if let Some(else_expr) = else_clause {
1052 Self::collect_variables(else_expr, vars);
1053 }
1054 }
1055 LogicalExpression::Labels(var)
1056 | LogicalExpression::Type(var)
1057 | LogicalExpression::Id(var) => {
1058 vars.insert(var.clone());
1059 }
1060 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1061 LogicalExpression::ListComprehension {
1062 list_expr,
1063 filter_expr,
1064 map_expr,
1065 ..
1066 } => {
1067 Self::collect_variables(list_expr, vars);
1068 if let Some(filter) = filter_expr {
1069 Self::collect_variables(filter, vars);
1070 }
1071 Self::collect_variables(map_expr, vars);
1072 }
1073 LogicalExpression::ListPredicate {
1074 list_expr,
1075 predicate,
1076 ..
1077 } => {
1078 Self::collect_variables(list_expr, vars);
1079 Self::collect_variables(predicate, vars);
1080 }
1081 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1082 }
1084 LogicalExpression::PatternComprehension { projection, .. } => {
1085 Self::collect_variables(projection, vars);
1086 }
1087 LogicalExpression::MapProjection { base, entries } => {
1088 vars.insert(base.clone());
1089 for entry in entries {
1090 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1091 Self::collect_variables(expr, vars);
1092 }
1093 }
1094 }
1095 LogicalExpression::Reduce {
1096 initial,
1097 list,
1098 expression,
1099 ..
1100 } => {
1101 Self::collect_variables(initial, vars);
1102 Self::collect_variables(list, vars);
1103 Self::collect_variables(expression, vars);
1104 }
1105 }
1106 }
1107
1108 fn extract_projection_aliases(
1110 &self,
1111 projections: &[crate::query::plan::Projection],
1112 ) -> HashSet<String> {
1113 projections.iter().filter_map(|p| p.alias.clone()).collect()
1114 }
1115}
1116
1117impl Default for Optimizer {
1118 fn default() -> Self {
1119 Self::new()
1120 }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125 use super::*;
1126 use crate::query::plan::{
1127 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1128 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1129 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1130 };
1131 use grafeo_common::types::Value;
1132
1133 #[test]
1134 fn test_optimizer_filter_pushdown_simple() {
1135 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1140 items: vec![ReturnItem {
1141 expression: LogicalExpression::Variable("n".to_string()),
1142 alias: None,
1143 }],
1144 distinct: false,
1145 input: Box::new(LogicalOperator::Filter(FilterOp {
1146 predicate: LogicalExpression::Binary {
1147 left: Box::new(LogicalExpression::Property {
1148 variable: "n".to_string(),
1149 property: "age".to_string(),
1150 }),
1151 op: BinaryOp::Gt,
1152 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1153 },
1154 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1155 variable: "n".to_string(),
1156 label: Some("Person".to_string()),
1157 input: None,
1158 })),
1159 pushdown_hint: None,
1160 })),
1161 }));
1162
1163 let optimizer = Optimizer::new();
1164 let optimized = optimizer.optimize(plan).unwrap();
1165
1166 if let LogicalOperator::Return(ret) = &optimized.root
1168 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1169 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1170 {
1171 assert_eq!(scan.variable, "n");
1172 return;
1173 }
1174 panic!("Expected Return -> Filter -> NodeScan structure");
1175 }
1176
1177 #[test]
1178 fn test_optimizer_filter_pushdown_through_expand() {
1179 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1183 items: vec![ReturnItem {
1184 expression: LogicalExpression::Variable("b".to_string()),
1185 alias: None,
1186 }],
1187 distinct: false,
1188 input: Box::new(LogicalOperator::Filter(FilterOp {
1189 predicate: LogicalExpression::Binary {
1190 left: Box::new(LogicalExpression::Property {
1191 variable: "a".to_string(),
1192 property: "age".to_string(),
1193 }),
1194 op: BinaryOp::Gt,
1195 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1196 },
1197 pushdown_hint: None,
1198 input: Box::new(LogicalOperator::Expand(ExpandOp {
1199 from_variable: "a".to_string(),
1200 to_variable: "b".to_string(),
1201 edge_variable: None,
1202 direction: ExpandDirection::Outgoing,
1203 edge_types: vec!["KNOWS".to_string()],
1204 min_hops: 1,
1205 max_hops: Some(1),
1206 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1207 variable: "a".to_string(),
1208 label: Some("Person".to_string()),
1209 input: None,
1210 })),
1211 path_alias: None,
1212 path_mode: PathMode::Walk,
1213 })),
1214 })),
1215 }));
1216
1217 let optimizer = Optimizer::new();
1218 let optimized = optimizer.optimize(plan).unwrap();
1219
1220 if let LogicalOperator::Return(ret) = &optimized.root
1223 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1224 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1225 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1226 {
1227 assert_eq!(scan.variable, "a");
1228 assert_eq!(expand.from_variable, "a");
1229 assert_eq!(expand.to_variable, "b");
1230 return;
1231 }
1232 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1233 }
1234
1235 #[test]
1236 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1237 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1241 items: vec![ReturnItem {
1242 expression: LogicalExpression::Variable("a".to_string()),
1243 alias: None,
1244 }],
1245 distinct: false,
1246 input: Box::new(LogicalOperator::Filter(FilterOp {
1247 predicate: LogicalExpression::Binary {
1248 left: Box::new(LogicalExpression::Property {
1249 variable: "b".to_string(),
1250 property: "age".to_string(),
1251 }),
1252 op: BinaryOp::Gt,
1253 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1254 },
1255 pushdown_hint: None,
1256 input: Box::new(LogicalOperator::Expand(ExpandOp {
1257 from_variable: "a".to_string(),
1258 to_variable: "b".to_string(),
1259 edge_variable: None,
1260 direction: ExpandDirection::Outgoing,
1261 edge_types: vec!["KNOWS".to_string()],
1262 min_hops: 1,
1263 max_hops: Some(1),
1264 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1265 variable: "a".to_string(),
1266 label: Some("Person".to_string()),
1267 input: None,
1268 })),
1269 path_alias: None,
1270 path_mode: PathMode::Walk,
1271 })),
1272 })),
1273 }));
1274
1275 let optimizer = Optimizer::new();
1276 let optimized = optimizer.optimize(plan).unwrap();
1277
1278 if let LogicalOperator::Return(ret) = &optimized.root
1281 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1282 {
1283 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1285 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1286 {
1287 assert_eq!(variable, "b");
1288 }
1289
1290 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1291 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1292 {
1293 return;
1294 }
1295 }
1296 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1297 }
1298
1299 #[test]
1300 fn test_optimizer_extract_variables() {
1301 let optimizer = Optimizer::new();
1302
1303 let expr = LogicalExpression::Binary {
1304 left: Box::new(LogicalExpression::Property {
1305 variable: "n".to_string(),
1306 property: "age".to_string(),
1307 }),
1308 op: BinaryOp::Gt,
1309 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1310 };
1311
1312 let vars = optimizer.extract_variables(&expr);
1313 assert_eq!(vars.len(), 1);
1314 assert!(vars.contains("n"));
1315 }
1316
1317 #[test]
1320 fn test_optimizer_default() {
1321 let optimizer = Optimizer::default();
1322 let plan = LogicalPlan::new(LogicalOperator::Empty);
1324 let result = optimizer.optimize(plan);
1325 assert!(result.is_ok());
1326 }
1327
1328 #[test]
1329 fn test_optimizer_with_filter_pushdown_disabled() {
1330 let optimizer = Optimizer::new().with_filter_pushdown(false);
1331
1332 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1333 items: vec![ReturnItem {
1334 expression: LogicalExpression::Variable("n".to_string()),
1335 alias: None,
1336 }],
1337 distinct: false,
1338 input: Box::new(LogicalOperator::Filter(FilterOp {
1339 predicate: LogicalExpression::Literal(Value::Bool(true)),
1340 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1341 variable: "n".to_string(),
1342 label: None,
1343 input: None,
1344 })),
1345 pushdown_hint: None,
1346 })),
1347 }));
1348
1349 let optimized = optimizer.optimize(plan).unwrap();
1350 if let LogicalOperator::Return(ret) = &optimized.root
1352 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1353 {
1354 return;
1355 }
1356 panic!("Expected unchanged structure");
1357 }
1358
1359 #[test]
1360 fn test_optimizer_with_join_reorder_disabled() {
1361 let optimizer = Optimizer::new().with_join_reorder(false);
1362 assert!(
1363 optimizer
1364 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1365 .is_ok()
1366 );
1367 }
1368
1369 #[test]
1370 fn test_optimizer_with_cost_model() {
1371 let cost_model = CostModel::new();
1372 let optimizer = Optimizer::new().with_cost_model(cost_model);
1373 assert!(
1374 optimizer
1375 .cost_model()
1376 .estimate(&LogicalOperator::Empty, 0.0)
1377 .total()
1378 < 0.001
1379 );
1380 }
1381
1382 #[test]
1383 fn test_optimizer_with_cardinality_estimator() {
1384 let mut estimator = CardinalityEstimator::new();
1385 estimator.add_table_stats("Test", TableStats::new(500));
1386 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1387
1388 let scan = LogicalOperator::NodeScan(NodeScanOp {
1389 variable: "n".to_string(),
1390 label: Some("Test".to_string()),
1391 input: None,
1392 });
1393 let plan = LogicalPlan::new(scan);
1394
1395 let cardinality = optimizer.estimate_cardinality(&plan);
1396 assert!((cardinality - 500.0).abs() < 0.001);
1397 }
1398
1399 #[test]
1400 fn test_optimizer_estimate_cost() {
1401 let optimizer = Optimizer::new();
1402 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1403 variable: "n".to_string(),
1404 label: None,
1405 input: None,
1406 }));
1407
1408 let cost = optimizer.estimate_cost(&plan);
1409 assert!(cost.total() > 0.0);
1410 }
1411
1412 #[test]
1415 fn test_filter_pushdown_through_project() {
1416 let optimizer = Optimizer::new();
1417
1418 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1419 predicate: LogicalExpression::Binary {
1420 left: Box::new(LogicalExpression::Property {
1421 variable: "n".to_string(),
1422 property: "age".to_string(),
1423 }),
1424 op: BinaryOp::Gt,
1425 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1426 },
1427 pushdown_hint: None,
1428 input: Box::new(LogicalOperator::Project(ProjectOp {
1429 projections: vec![Projection {
1430 expression: LogicalExpression::Variable("n".to_string()),
1431 alias: None,
1432 }],
1433 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1434 variable: "n".to_string(),
1435 label: None,
1436 input: None,
1437 })),
1438 })),
1439 }));
1440
1441 let optimized = optimizer.optimize(plan).unwrap();
1442
1443 if let LogicalOperator::Project(proj) = &optimized.root
1445 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1446 {
1447 return;
1448 }
1449 panic!("Expected Project -> Filter structure");
1450 }
1451
1452 #[test]
1453 fn test_filter_not_pushed_through_project_with_alias() {
1454 let optimizer = Optimizer::new();
1455
1456 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1458 predicate: LogicalExpression::Binary {
1459 left: Box::new(LogicalExpression::Variable("x".to_string())),
1460 op: BinaryOp::Gt,
1461 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1462 },
1463 pushdown_hint: None,
1464 input: Box::new(LogicalOperator::Project(ProjectOp {
1465 projections: vec![Projection {
1466 expression: LogicalExpression::Property {
1467 variable: "n".to_string(),
1468 property: "age".to_string(),
1469 },
1470 alias: Some("x".to_string()),
1471 }],
1472 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1473 variable: "n".to_string(),
1474 label: None,
1475 input: None,
1476 })),
1477 })),
1478 }));
1479
1480 let optimized = optimizer.optimize(plan).unwrap();
1481
1482 if let LogicalOperator::Filter(filter) = &optimized.root
1484 && let LogicalOperator::Project(_) = filter.input.as_ref()
1485 {
1486 return;
1487 }
1488 panic!("Expected Filter -> Project structure");
1489 }
1490
1491 #[test]
1492 fn test_filter_pushdown_through_limit() {
1493 let optimizer = Optimizer::new();
1494
1495 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1496 predicate: LogicalExpression::Literal(Value::Bool(true)),
1497 pushdown_hint: None,
1498 input: Box::new(LogicalOperator::Limit(LimitOp {
1499 count: 10,
1500 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1501 variable: "n".to_string(),
1502 label: None,
1503 input: None,
1504 })),
1505 })),
1506 }));
1507
1508 let optimized = optimizer.optimize(plan).unwrap();
1509
1510 if let LogicalOperator::Filter(filter) = &optimized.root
1512 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1513 {
1514 return;
1515 }
1516 panic!("Expected Filter -> Limit structure");
1517 }
1518
1519 #[test]
1520 fn test_filter_pushdown_through_sort() {
1521 let optimizer = Optimizer::new();
1522
1523 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1524 predicate: LogicalExpression::Literal(Value::Bool(true)),
1525 pushdown_hint: None,
1526 input: Box::new(LogicalOperator::Sort(SortOp {
1527 keys: vec![SortKey {
1528 expression: LogicalExpression::Variable("n".to_string()),
1529 order: SortOrder::Ascending,
1530 nulls: None,
1531 }],
1532 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1533 variable: "n".to_string(),
1534 label: None,
1535 input: None,
1536 })),
1537 })),
1538 }));
1539
1540 let optimized = optimizer.optimize(plan).unwrap();
1541
1542 if let LogicalOperator::Filter(filter) = &optimized.root
1544 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1545 {
1546 return;
1547 }
1548 panic!("Expected Filter -> Sort structure");
1549 }
1550
1551 #[test]
1552 fn test_filter_pushdown_through_distinct() {
1553 let optimizer = Optimizer::new();
1554
1555 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1556 predicate: LogicalExpression::Literal(Value::Bool(true)),
1557 pushdown_hint: None,
1558 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1559 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1560 variable: "n".to_string(),
1561 label: None,
1562 input: None,
1563 })),
1564 columns: None,
1565 })),
1566 }));
1567
1568 let optimized = optimizer.optimize(plan).unwrap();
1569
1570 if let LogicalOperator::Filter(filter) = &optimized.root
1572 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1573 {
1574 return;
1575 }
1576 panic!("Expected Filter -> Distinct structure");
1577 }
1578
1579 #[test]
1580 fn test_filter_not_pushed_through_aggregate() {
1581 let optimizer = Optimizer::new();
1582
1583 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1584 predicate: LogicalExpression::Binary {
1585 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1586 op: BinaryOp::Gt,
1587 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1588 },
1589 pushdown_hint: None,
1590 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1591 group_by: vec![],
1592 aggregates: vec![AggregateExpr {
1593 function: AggregateFunction::Count,
1594 expression: None,
1595 expression2: None,
1596 distinct: false,
1597 alias: Some("cnt".to_string()),
1598 percentile: None,
1599 }],
1600 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1601 variable: "n".to_string(),
1602 label: None,
1603 input: None,
1604 })),
1605 having: None,
1606 })),
1607 }));
1608
1609 let optimized = optimizer.optimize(plan).unwrap();
1610
1611 if let LogicalOperator::Filter(filter) = &optimized.root
1613 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1614 {
1615 return;
1616 }
1617 panic!("Expected Filter -> Aggregate structure");
1618 }
1619
1620 #[test]
1621 fn test_filter_pushdown_to_left_join_side() {
1622 let optimizer = Optimizer::new();
1623
1624 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1626 predicate: LogicalExpression::Binary {
1627 left: Box::new(LogicalExpression::Property {
1628 variable: "a".to_string(),
1629 property: "age".to_string(),
1630 }),
1631 op: BinaryOp::Gt,
1632 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1633 },
1634 pushdown_hint: None,
1635 input: Box::new(LogicalOperator::Join(JoinOp {
1636 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1637 variable: "a".to_string(),
1638 label: Some("Person".to_string()),
1639 input: None,
1640 })),
1641 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1642 variable: "b".to_string(),
1643 label: Some("Company".to_string()),
1644 input: None,
1645 })),
1646 join_type: JoinType::Inner,
1647 conditions: vec![],
1648 })),
1649 }));
1650
1651 let optimized = optimizer.optimize(plan).unwrap();
1652
1653 if let LogicalOperator::Join(join) = &optimized.root
1655 && let LogicalOperator::Filter(_) = join.left.as_ref()
1656 {
1657 return;
1658 }
1659 panic!("Expected Join with Filter on left side");
1660 }
1661
1662 #[test]
1663 fn test_filter_pushdown_to_right_join_side() {
1664 let optimizer = Optimizer::new();
1665
1666 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1668 predicate: LogicalExpression::Binary {
1669 left: Box::new(LogicalExpression::Property {
1670 variable: "b".to_string(),
1671 property: "name".to_string(),
1672 }),
1673 op: BinaryOp::Eq,
1674 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1675 },
1676 pushdown_hint: None,
1677 input: Box::new(LogicalOperator::Join(JoinOp {
1678 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1679 variable: "a".to_string(),
1680 label: Some("Person".to_string()),
1681 input: None,
1682 })),
1683 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1684 variable: "b".to_string(),
1685 label: Some("Company".to_string()),
1686 input: None,
1687 })),
1688 join_type: JoinType::Inner,
1689 conditions: vec![],
1690 })),
1691 }));
1692
1693 let optimized = optimizer.optimize(plan).unwrap();
1694
1695 if let LogicalOperator::Join(join) = &optimized.root
1697 && let LogicalOperator::Filter(_) = join.right.as_ref()
1698 {
1699 return;
1700 }
1701 panic!("Expected Join with Filter on right side");
1702 }
1703
1704 #[test]
1705 fn test_filter_not_pushed_when_uses_both_join_sides() {
1706 let optimizer = Optimizer::new();
1707
1708 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1710 predicate: LogicalExpression::Binary {
1711 left: Box::new(LogicalExpression::Property {
1712 variable: "a".to_string(),
1713 property: "id".to_string(),
1714 }),
1715 op: BinaryOp::Eq,
1716 right: Box::new(LogicalExpression::Property {
1717 variable: "b".to_string(),
1718 property: "a_id".to_string(),
1719 }),
1720 },
1721 pushdown_hint: None,
1722 input: Box::new(LogicalOperator::Join(JoinOp {
1723 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1724 variable: "a".to_string(),
1725 label: None,
1726 input: None,
1727 })),
1728 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1729 variable: "b".to_string(),
1730 label: None,
1731 input: None,
1732 })),
1733 join_type: JoinType::Inner,
1734 conditions: vec![],
1735 })),
1736 }));
1737
1738 let optimized = optimizer.optimize(plan).unwrap();
1739
1740 if let LogicalOperator::Filter(filter) = &optimized.root
1742 && let LogicalOperator::Join(_) = filter.input.as_ref()
1743 {
1744 return;
1745 }
1746 panic!("Expected Filter -> Join structure");
1747 }
1748
1749 #[test]
1752 fn test_extract_variables_from_variable() {
1753 let optimizer = Optimizer::new();
1754 let expr = LogicalExpression::Variable("x".to_string());
1755 let vars = optimizer.extract_variables(&expr);
1756 assert_eq!(vars.len(), 1);
1757 assert!(vars.contains("x"));
1758 }
1759
1760 #[test]
1761 fn test_extract_variables_from_unary() {
1762 let optimizer = Optimizer::new();
1763 let expr = LogicalExpression::Unary {
1764 op: UnaryOp::Not,
1765 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1766 };
1767 let vars = optimizer.extract_variables(&expr);
1768 assert_eq!(vars.len(), 1);
1769 assert!(vars.contains("x"));
1770 }
1771
1772 #[test]
1773 fn test_extract_variables_from_function_call() {
1774 let optimizer = Optimizer::new();
1775 let expr = LogicalExpression::FunctionCall {
1776 name: "length".to_string(),
1777 args: vec![
1778 LogicalExpression::Variable("a".to_string()),
1779 LogicalExpression::Variable("b".to_string()),
1780 ],
1781 distinct: false,
1782 };
1783 let vars = optimizer.extract_variables(&expr);
1784 assert_eq!(vars.len(), 2);
1785 assert!(vars.contains("a"));
1786 assert!(vars.contains("b"));
1787 }
1788
1789 #[test]
1790 fn test_extract_variables_from_list() {
1791 let optimizer = Optimizer::new();
1792 let expr = LogicalExpression::List(vec![
1793 LogicalExpression::Variable("a".to_string()),
1794 LogicalExpression::Literal(Value::Int64(1)),
1795 LogicalExpression::Variable("b".to_string()),
1796 ]);
1797 let vars = optimizer.extract_variables(&expr);
1798 assert_eq!(vars.len(), 2);
1799 assert!(vars.contains("a"));
1800 assert!(vars.contains("b"));
1801 }
1802
1803 #[test]
1804 fn test_extract_variables_from_map() {
1805 let optimizer = Optimizer::new();
1806 let expr = LogicalExpression::Map(vec![
1807 (
1808 "key1".to_string(),
1809 LogicalExpression::Variable("a".to_string()),
1810 ),
1811 (
1812 "key2".to_string(),
1813 LogicalExpression::Variable("b".to_string()),
1814 ),
1815 ]);
1816 let vars = optimizer.extract_variables(&expr);
1817 assert_eq!(vars.len(), 2);
1818 assert!(vars.contains("a"));
1819 assert!(vars.contains("b"));
1820 }
1821
1822 #[test]
1823 fn test_extract_variables_from_index_access() {
1824 let optimizer = Optimizer::new();
1825 let expr = LogicalExpression::IndexAccess {
1826 base: Box::new(LogicalExpression::Variable("list".to_string())),
1827 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1828 };
1829 let vars = optimizer.extract_variables(&expr);
1830 assert_eq!(vars.len(), 2);
1831 assert!(vars.contains("list"));
1832 assert!(vars.contains("idx"));
1833 }
1834
1835 #[test]
1836 fn test_extract_variables_from_slice_access() {
1837 let optimizer = Optimizer::new();
1838 let expr = LogicalExpression::SliceAccess {
1839 base: Box::new(LogicalExpression::Variable("list".to_string())),
1840 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1841 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1842 };
1843 let vars = optimizer.extract_variables(&expr);
1844 assert_eq!(vars.len(), 3);
1845 assert!(vars.contains("list"));
1846 assert!(vars.contains("s"));
1847 assert!(vars.contains("e"));
1848 }
1849
1850 #[test]
1851 fn test_extract_variables_from_case() {
1852 let optimizer = Optimizer::new();
1853 let expr = LogicalExpression::Case {
1854 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1855 when_clauses: vec![(
1856 LogicalExpression::Literal(Value::Int64(1)),
1857 LogicalExpression::Variable("a".to_string()),
1858 )],
1859 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1860 };
1861 let vars = optimizer.extract_variables(&expr);
1862 assert_eq!(vars.len(), 3);
1863 assert!(vars.contains("x"));
1864 assert!(vars.contains("a"));
1865 assert!(vars.contains("b"));
1866 }
1867
1868 #[test]
1869 fn test_extract_variables_from_labels() {
1870 let optimizer = Optimizer::new();
1871 let expr = LogicalExpression::Labels("n".to_string());
1872 let vars = optimizer.extract_variables(&expr);
1873 assert_eq!(vars.len(), 1);
1874 assert!(vars.contains("n"));
1875 }
1876
1877 #[test]
1878 fn test_extract_variables_from_type() {
1879 let optimizer = Optimizer::new();
1880 let expr = LogicalExpression::Type("e".to_string());
1881 let vars = optimizer.extract_variables(&expr);
1882 assert_eq!(vars.len(), 1);
1883 assert!(vars.contains("e"));
1884 }
1885
1886 #[test]
1887 fn test_extract_variables_from_id() {
1888 let optimizer = Optimizer::new();
1889 let expr = LogicalExpression::Id("n".to_string());
1890 let vars = optimizer.extract_variables(&expr);
1891 assert_eq!(vars.len(), 1);
1892 assert!(vars.contains("n"));
1893 }
1894
1895 #[test]
1896 fn test_extract_variables_from_list_comprehension() {
1897 let optimizer = Optimizer::new();
1898 let expr = LogicalExpression::ListComprehension {
1899 variable: "x".to_string(),
1900 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1901 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1902 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1903 };
1904 let vars = optimizer.extract_variables(&expr);
1905 assert!(vars.contains("items"));
1906 assert!(vars.contains("pred"));
1907 assert!(vars.contains("result"));
1908 }
1909
1910 #[test]
1911 fn test_extract_variables_from_literal_and_parameter() {
1912 let optimizer = Optimizer::new();
1913
1914 let literal = LogicalExpression::Literal(Value::Int64(42));
1915 assert!(optimizer.extract_variables(&literal).is_empty());
1916
1917 let param = LogicalExpression::Parameter("p".to_string());
1918 assert!(optimizer.extract_variables(¶m).is_empty());
1919 }
1920
1921 #[test]
1924 fn test_recursive_filter_pushdown_through_skip() {
1925 let optimizer = Optimizer::new();
1926
1927 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1928 items: vec![ReturnItem {
1929 expression: LogicalExpression::Variable("n".to_string()),
1930 alias: None,
1931 }],
1932 distinct: false,
1933 input: Box::new(LogicalOperator::Filter(FilterOp {
1934 predicate: LogicalExpression::Literal(Value::Bool(true)),
1935 pushdown_hint: None,
1936 input: Box::new(LogicalOperator::Skip(SkipOp {
1937 count: 5,
1938 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1939 variable: "n".to_string(),
1940 label: None,
1941 input: None,
1942 })),
1943 })),
1944 })),
1945 }));
1946
1947 let optimized = optimizer.optimize(plan).unwrap();
1948
1949 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1951 }
1952
1953 #[test]
1954 fn test_nested_filter_pushdown() {
1955 let optimizer = Optimizer::new();
1956
1957 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1959 items: vec![ReturnItem {
1960 expression: LogicalExpression::Variable("n".to_string()),
1961 alias: None,
1962 }],
1963 distinct: false,
1964 input: Box::new(LogicalOperator::Filter(FilterOp {
1965 predicate: LogicalExpression::Binary {
1966 left: Box::new(LogicalExpression::Property {
1967 variable: "n".to_string(),
1968 property: "x".to_string(),
1969 }),
1970 op: BinaryOp::Gt,
1971 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1972 },
1973 pushdown_hint: None,
1974 input: Box::new(LogicalOperator::Filter(FilterOp {
1975 predicate: LogicalExpression::Binary {
1976 left: Box::new(LogicalExpression::Property {
1977 variable: "n".to_string(),
1978 property: "y".to_string(),
1979 }),
1980 op: BinaryOp::Lt,
1981 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1982 },
1983 pushdown_hint: None,
1984 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1985 variable: "n".to_string(),
1986 label: None,
1987 input: None,
1988 })),
1989 })),
1990 })),
1991 }));
1992
1993 let optimized = optimizer.optimize(plan).unwrap();
1994 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1995 }
1996
1997 #[test]
1998 fn test_cyclic_join_produces_multi_way_join() {
1999 use crate::query::plan::JoinCondition;
2000
2001 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2003 variable: "a".to_string(),
2004 label: Some("Person".to_string()),
2005 input: None,
2006 });
2007 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2008 variable: "b".to_string(),
2009 label: Some("Person".to_string()),
2010 input: None,
2011 });
2012 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2013 variable: "c".to_string(),
2014 label: Some("Person".to_string()),
2015 input: None,
2016 });
2017
2018 let join_ab = LogicalOperator::Join(JoinOp {
2020 left: Box::new(scan_a),
2021 right: Box::new(scan_b),
2022 join_type: JoinType::Inner,
2023 conditions: vec![JoinCondition {
2024 left: LogicalExpression::Variable("a".to_string()),
2025 right: LogicalExpression::Variable("b".to_string()),
2026 }],
2027 });
2028
2029 let join_abc = LogicalOperator::Join(JoinOp {
2030 left: Box::new(join_ab),
2031 right: Box::new(scan_c),
2032 join_type: JoinType::Inner,
2033 conditions: vec![
2034 JoinCondition {
2035 left: LogicalExpression::Variable("b".to_string()),
2036 right: LogicalExpression::Variable("c".to_string()),
2037 },
2038 JoinCondition {
2039 left: LogicalExpression::Variable("c".to_string()),
2040 right: LogicalExpression::Variable("a".to_string()),
2041 },
2042 ],
2043 });
2044
2045 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2046 items: vec![ReturnItem {
2047 expression: LogicalExpression::Variable("a".to_string()),
2048 alias: None,
2049 }],
2050 distinct: false,
2051 input: Box::new(join_abc),
2052 }));
2053
2054 let mut optimizer = Optimizer::new();
2055 optimizer
2056 .card_estimator
2057 .add_table_stats("Person", cardinality::TableStats::new(1000));
2058
2059 let optimized = optimizer.optimize(plan).unwrap();
2060
2061 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2063 match op {
2064 LogicalOperator::MultiWayJoin(_) => true,
2065 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2066 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2067 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2068 _ => false,
2069 }
2070 }
2071
2072 assert!(
2073 has_multi_way_join(&optimized.root),
2074 "Expected MultiWayJoin for cyclic triangle pattern"
2075 );
2076 }
2077
2078 #[test]
2079 fn test_acyclic_join_uses_binary_joins() {
2080 use crate::query::plan::JoinCondition;
2081
2082 let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2084 variable: "a".to_string(),
2085 label: Some("Person".to_string()),
2086 input: None,
2087 });
2088 let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2089 variable: "b".to_string(),
2090 label: Some("Person".to_string()),
2091 input: None,
2092 });
2093 let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2094 variable: "c".to_string(),
2095 label: Some("Company".to_string()),
2096 input: None,
2097 });
2098
2099 let join_ab = LogicalOperator::Join(JoinOp {
2100 left: Box::new(scan_a),
2101 right: Box::new(scan_b),
2102 join_type: JoinType::Inner,
2103 conditions: vec![JoinCondition {
2104 left: LogicalExpression::Variable("a".to_string()),
2105 right: LogicalExpression::Variable("b".to_string()),
2106 }],
2107 });
2108
2109 let join_abc = LogicalOperator::Join(JoinOp {
2110 left: Box::new(join_ab),
2111 right: Box::new(scan_c),
2112 join_type: JoinType::Inner,
2113 conditions: vec![JoinCondition {
2114 left: LogicalExpression::Variable("b".to_string()),
2115 right: LogicalExpression::Variable("c".to_string()),
2116 }],
2117 });
2118
2119 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2120 items: vec![ReturnItem {
2121 expression: LogicalExpression::Variable("a".to_string()),
2122 alias: None,
2123 }],
2124 distinct: false,
2125 input: Box::new(join_abc),
2126 }));
2127
2128 let mut optimizer = Optimizer::new();
2129 optimizer
2130 .card_estimator
2131 .add_table_stats("Person", cardinality::TableStats::new(1000));
2132 optimizer
2133 .card_estimator
2134 .add_table_stats("Company", cardinality::TableStats::new(100));
2135
2136 let optimized = optimizer.optimize(plan).unwrap();
2137
2138 fn has_multi_way_join(op: &LogicalOperator) -> bool {
2140 match op {
2141 LogicalOperator::MultiWayJoin(_) => true,
2142 LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2143 LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2144 LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2145 LogicalOperator::Join(j) => {
2146 has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2147 }
2148 _ => false,
2149 }
2150 }
2151
2152 assert!(
2153 !has_multi_way_join(&optimized.root),
2154 "Acyclic join should NOT produce MultiWayJoin"
2155 );
2156 }
2157}