1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
26#[derive(Debug, Clone)]
28struct JoinInfo {
29 left_var: String,
30 right_var: String,
31 left_expr: LogicalExpression,
32 right_expr: LogicalExpression,
33}
34
/// A column referenced somewhere in a plan, as gathered by the
/// projection-pushdown pass.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum RequiredColumn {
    /// A whole variable binding, e.g. `n`.
    Variable(String),
    /// A property read on a variable, stored as `(variable, property)`,
    /// e.g. `n.age`.
    Property(String, String),
}
43
44pub struct Optimizer {
49 enable_filter_pushdown: bool,
51 enable_join_reorder: bool,
53 enable_projection_pushdown: bool,
55 cost_model: CostModel,
57 card_estimator: CardinalityEstimator,
59}
60
61impl Optimizer {
62 #[must_use]
64 pub fn new() -> Self {
65 Self {
66 enable_filter_pushdown: true,
67 enable_join_reorder: true,
68 enable_projection_pushdown: true,
69 cost_model: CostModel::new(),
70 card_estimator: CardinalityEstimator::new(),
71 }
72 }
73
74 #[must_use]
79 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
80 store.ensure_statistics_fresh();
81 let stats = store.statistics();
82 let estimator = CardinalityEstimator::from_statistics(&stats);
83 Self {
84 enable_filter_pushdown: true,
85 enable_join_reorder: true,
86 enable_projection_pushdown: true,
87 cost_model: CostModel::new(),
88 card_estimator: estimator,
89 }
90 }
91
92 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
94 self.enable_filter_pushdown = enabled;
95 self
96 }
97
98 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
100 self.enable_join_reorder = enabled;
101 self
102 }
103
104 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
106 self.enable_projection_pushdown = enabled;
107 self
108 }
109
110 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
112 self.cost_model = cost_model;
113 self
114 }
115
116 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
118 self.card_estimator = estimator;
119 self
120 }
121
122 pub fn cost_model(&self) -> &CostModel {
124 &self.cost_model
125 }
126
127 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
129 &self.card_estimator
130 }
131
132 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
134 let cardinality = self.card_estimator.estimate(&plan.root);
135 self.cost_model.estimate(&plan.root, cardinality)
136 }
137
138 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
140 self.card_estimator.estimate(&plan.root)
141 }
142
143 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
149 let mut root = plan.root;
150
151 if self.enable_filter_pushdown {
153 root = self.push_filters_down(root);
154 }
155
156 if self.enable_join_reorder {
157 root = self.reorder_joins(root);
158 }
159
160 if self.enable_projection_pushdown {
161 root = self.push_projections_down(root);
162 }
163
164 Ok(LogicalPlan::new(root))
165 }
166
167 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
174 let required = self.collect_required_columns(&op);
176
177 self.push_projections_recursive(op, &required)
179 }
180
181 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
183 let mut required = HashSet::new();
184 Self::collect_required_recursive(op, &mut required);
185 required
186 }
187
188 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
190 match op {
191 LogicalOperator::Return(ret) => {
192 for item in &ret.items {
193 Self::collect_from_expression(&item.expression, required);
194 }
195 Self::collect_required_recursive(&ret.input, required);
196 }
197 LogicalOperator::Project(proj) => {
198 for p in &proj.projections {
199 Self::collect_from_expression(&p.expression, required);
200 }
201 Self::collect_required_recursive(&proj.input, required);
202 }
203 LogicalOperator::Filter(filter) => {
204 Self::collect_from_expression(&filter.predicate, required);
205 Self::collect_required_recursive(&filter.input, required);
206 }
207 LogicalOperator::Sort(sort) => {
208 for key in &sort.keys {
209 Self::collect_from_expression(&key.expression, required);
210 }
211 Self::collect_required_recursive(&sort.input, required);
212 }
213 LogicalOperator::Aggregate(agg) => {
214 for expr in &agg.group_by {
215 Self::collect_from_expression(expr, required);
216 }
217 for agg_expr in &agg.aggregates {
218 if let Some(ref expr) = agg_expr.expression {
219 Self::collect_from_expression(expr, required);
220 }
221 }
222 if let Some(ref having) = agg.having {
223 Self::collect_from_expression(having, required);
224 }
225 Self::collect_required_recursive(&agg.input, required);
226 }
227 LogicalOperator::Join(join) => {
228 for cond in &join.conditions {
229 Self::collect_from_expression(&cond.left, required);
230 Self::collect_from_expression(&cond.right, required);
231 }
232 Self::collect_required_recursive(&join.left, required);
233 Self::collect_required_recursive(&join.right, required);
234 }
235 LogicalOperator::Expand(expand) => {
236 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
238 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
239 if let Some(ref edge_var) = expand.edge_variable {
240 required.insert(RequiredColumn::Variable(edge_var.clone()));
241 }
242 Self::collect_required_recursive(&expand.input, required);
243 }
244 LogicalOperator::Limit(limit) => {
245 Self::collect_required_recursive(&limit.input, required);
246 }
247 LogicalOperator::Skip(skip) => {
248 Self::collect_required_recursive(&skip.input, required);
249 }
250 LogicalOperator::Distinct(distinct) => {
251 Self::collect_required_recursive(&distinct.input, required);
252 }
253 LogicalOperator::NodeScan(scan) => {
254 required.insert(RequiredColumn::Variable(scan.variable.clone()));
255 }
256 LogicalOperator::EdgeScan(scan) => {
257 required.insert(RequiredColumn::Variable(scan.variable.clone()));
258 }
259 _ => {}
260 }
261 }
262
263 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
265 match expr {
266 LogicalExpression::Variable(var) => {
267 required.insert(RequiredColumn::Variable(var.clone()));
268 }
269 LogicalExpression::Property { variable, property } => {
270 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
271 required.insert(RequiredColumn::Variable(variable.clone()));
272 }
273 LogicalExpression::Binary { left, right, .. } => {
274 Self::collect_from_expression(left, required);
275 Self::collect_from_expression(right, required);
276 }
277 LogicalExpression::Unary { operand, .. } => {
278 Self::collect_from_expression(operand, required);
279 }
280 LogicalExpression::FunctionCall { args, .. } => {
281 for arg in args {
282 Self::collect_from_expression(arg, required);
283 }
284 }
285 LogicalExpression::List(items) => {
286 for item in items {
287 Self::collect_from_expression(item, required);
288 }
289 }
290 LogicalExpression::Map(pairs) => {
291 for (_, value) in pairs {
292 Self::collect_from_expression(value, required);
293 }
294 }
295 LogicalExpression::IndexAccess { base, index } => {
296 Self::collect_from_expression(base, required);
297 Self::collect_from_expression(index, required);
298 }
299 LogicalExpression::SliceAccess { base, start, end } => {
300 Self::collect_from_expression(base, required);
301 if let Some(s) = start {
302 Self::collect_from_expression(s, required);
303 }
304 if let Some(e) = end {
305 Self::collect_from_expression(e, required);
306 }
307 }
308 LogicalExpression::Case {
309 operand,
310 when_clauses,
311 else_clause,
312 } => {
313 if let Some(op) = operand {
314 Self::collect_from_expression(op, required);
315 }
316 for (cond, result) in when_clauses {
317 Self::collect_from_expression(cond, required);
318 Self::collect_from_expression(result, required);
319 }
320 if let Some(else_expr) = else_clause {
321 Self::collect_from_expression(else_expr, required);
322 }
323 }
324 LogicalExpression::Labels(var)
325 | LogicalExpression::Type(var)
326 | LogicalExpression::Id(var) => {
327 required.insert(RequiredColumn::Variable(var.clone()));
328 }
329 LogicalExpression::ListComprehension {
330 list_expr,
331 filter_expr,
332 map_expr,
333 ..
334 } => {
335 Self::collect_from_expression(list_expr, required);
336 if let Some(filter) = filter_expr {
337 Self::collect_from_expression(filter, required);
338 }
339 Self::collect_from_expression(map_expr, required);
340 }
341 _ => {}
342 }
343 }
344
345 fn push_projections_recursive(
347 &self,
348 op: LogicalOperator,
349 required: &HashSet<RequiredColumn>,
350 ) -> LogicalOperator {
351 match op {
352 LogicalOperator::Return(mut ret) => {
353 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
354 LogicalOperator::Return(ret)
355 }
356 LogicalOperator::Project(mut proj) => {
357 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
358 LogicalOperator::Project(proj)
359 }
360 LogicalOperator::Filter(mut filter) => {
361 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
362 LogicalOperator::Filter(filter)
363 }
364 LogicalOperator::Sort(mut sort) => {
365 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
368 LogicalOperator::Sort(sort)
369 }
370 LogicalOperator::Aggregate(mut agg) => {
371 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
372 LogicalOperator::Aggregate(agg)
373 }
374 LogicalOperator::Join(mut join) => {
375 let left_vars = self.collect_output_variables(&join.left);
378 let right_vars = self.collect_output_variables(&join.right);
379
380 let left_required: HashSet<_> = required
382 .iter()
383 .filter(|c| match c {
384 RequiredColumn::Variable(v) => left_vars.contains(v),
385 RequiredColumn::Property(v, _) => left_vars.contains(v),
386 })
387 .cloned()
388 .collect();
389
390 let right_required: HashSet<_> = required
391 .iter()
392 .filter(|c| match c {
393 RequiredColumn::Variable(v) => right_vars.contains(v),
394 RequiredColumn::Property(v, _) => right_vars.contains(v),
395 })
396 .cloned()
397 .collect();
398
399 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
400 join.right =
401 Box::new(self.push_projections_recursive(*join.right, &right_required));
402 LogicalOperator::Join(join)
403 }
404 LogicalOperator::Expand(mut expand) => {
405 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
406 LogicalOperator::Expand(expand)
407 }
408 LogicalOperator::Limit(mut limit) => {
409 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
410 LogicalOperator::Limit(limit)
411 }
412 LogicalOperator::Skip(mut skip) => {
413 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
414 LogicalOperator::Skip(skip)
415 }
416 LogicalOperator::Distinct(mut distinct) => {
417 distinct.input =
418 Box::new(self.push_projections_recursive(*distinct.input, required));
419 LogicalOperator::Distinct(distinct)
420 }
421 other => other,
422 }
423 }
424
425 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
432 let op = self.reorder_joins_recursive(op);
434
435 if let Some((relations, conditions)) = self.extract_join_tree(&op) {
437 if relations.len() >= 2 {
438 if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
439 return optimized;
440 }
441 }
442 }
443
444 op
445 }
446
447 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
449 match op {
450 LogicalOperator::Return(mut ret) => {
451 ret.input = Box::new(self.reorder_joins(*ret.input));
452 LogicalOperator::Return(ret)
453 }
454 LogicalOperator::Project(mut proj) => {
455 proj.input = Box::new(self.reorder_joins(*proj.input));
456 LogicalOperator::Project(proj)
457 }
458 LogicalOperator::Filter(mut filter) => {
459 filter.input = Box::new(self.reorder_joins(*filter.input));
460 LogicalOperator::Filter(filter)
461 }
462 LogicalOperator::Limit(mut limit) => {
463 limit.input = Box::new(self.reorder_joins(*limit.input));
464 LogicalOperator::Limit(limit)
465 }
466 LogicalOperator::Skip(mut skip) => {
467 skip.input = Box::new(self.reorder_joins(*skip.input));
468 LogicalOperator::Skip(skip)
469 }
470 LogicalOperator::Sort(mut sort) => {
471 sort.input = Box::new(self.reorder_joins(*sort.input));
472 LogicalOperator::Sort(sort)
473 }
474 LogicalOperator::Distinct(mut distinct) => {
475 distinct.input = Box::new(self.reorder_joins(*distinct.input));
476 LogicalOperator::Distinct(distinct)
477 }
478 LogicalOperator::Aggregate(mut agg) => {
479 agg.input = Box::new(self.reorder_joins(*agg.input));
480 LogicalOperator::Aggregate(agg)
481 }
482 LogicalOperator::Expand(mut expand) => {
483 expand.input = Box::new(self.reorder_joins(*expand.input));
484 LogicalOperator::Expand(expand)
485 }
486 other => other,
488 }
489 }
490
491 fn extract_join_tree(
495 &self,
496 op: &LogicalOperator,
497 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
498 let mut relations = Vec::new();
499 let mut join_conditions = Vec::new();
500
501 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
502 return None;
503 }
504
505 if relations.len() < 2 {
506 return None;
507 }
508
509 Some((relations, join_conditions))
510 }
511
512 fn collect_join_tree(
516 &self,
517 op: &LogicalOperator,
518 relations: &mut Vec<(String, LogicalOperator)>,
519 conditions: &mut Vec<JoinInfo>,
520 ) -> bool {
521 match op {
522 LogicalOperator::Join(join) => {
523 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
525 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
526
527 for cond in &join.conditions {
529 if let (Some(left_var), Some(right_var)) = (
530 self.extract_variable_from_expr(&cond.left),
531 self.extract_variable_from_expr(&cond.right),
532 ) {
533 conditions.push(JoinInfo {
534 left_var,
535 right_var,
536 left_expr: cond.left.clone(),
537 right_expr: cond.right.clone(),
538 });
539 }
540 }
541
542 left_ok && right_ok
543 }
544 LogicalOperator::NodeScan(scan) => {
545 relations.push((scan.variable.clone(), op.clone()));
546 true
547 }
548 LogicalOperator::EdgeScan(scan) => {
549 relations.push((scan.variable.clone(), op.clone()));
550 true
551 }
552 LogicalOperator::Filter(filter) => {
553 self.collect_join_tree(&filter.input, relations, conditions)
555 }
556 LogicalOperator::Expand(expand) => {
557 relations.push((expand.to_variable.clone(), op.clone()));
560 true
561 }
562 _ => false,
563 }
564 }
565
566 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
568 match expr {
569 LogicalExpression::Variable(v) => Some(v.clone()),
570 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
571 _ => None,
572 }
573 }
574
575 fn optimize_join_order(
577 &self,
578 relations: &[(String, LogicalOperator)],
579 conditions: &[JoinInfo],
580 ) -> Option<LogicalOperator> {
581 use join_order::{DPccp, JoinGraphBuilder};
582
583 let mut builder = JoinGraphBuilder::new();
585
586 for (var, relation) in relations {
587 builder.add_relation(var, relation.clone());
588 }
589
590 for cond in conditions {
591 builder.add_join_condition(
592 &cond.left_var,
593 &cond.right_var,
594 cond.left_expr.clone(),
595 cond.right_expr.clone(),
596 );
597 }
598
599 let graph = builder.build();
600
601 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
603 let plan = dpccp.optimize()?;
604
605 Some(plan.operator)
606 }
607
608 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
613 match op {
614 LogicalOperator::Filter(filter) => {
616 let optimized_input = self.push_filters_down(*filter.input);
617 self.try_push_filter_into(filter.predicate, optimized_input)
618 }
619 LogicalOperator::Return(mut ret) => {
621 ret.input = Box::new(self.push_filters_down(*ret.input));
622 LogicalOperator::Return(ret)
623 }
624 LogicalOperator::Project(mut proj) => {
625 proj.input = Box::new(self.push_filters_down(*proj.input));
626 LogicalOperator::Project(proj)
627 }
628 LogicalOperator::Limit(mut limit) => {
629 limit.input = Box::new(self.push_filters_down(*limit.input));
630 LogicalOperator::Limit(limit)
631 }
632 LogicalOperator::Skip(mut skip) => {
633 skip.input = Box::new(self.push_filters_down(*skip.input));
634 LogicalOperator::Skip(skip)
635 }
636 LogicalOperator::Sort(mut sort) => {
637 sort.input = Box::new(self.push_filters_down(*sort.input));
638 LogicalOperator::Sort(sort)
639 }
640 LogicalOperator::Distinct(mut distinct) => {
641 distinct.input = Box::new(self.push_filters_down(*distinct.input));
642 LogicalOperator::Distinct(distinct)
643 }
644 LogicalOperator::Expand(mut expand) => {
645 expand.input = Box::new(self.push_filters_down(*expand.input));
646 LogicalOperator::Expand(expand)
647 }
648 LogicalOperator::Join(mut join) => {
649 join.left = Box::new(self.push_filters_down(*join.left));
650 join.right = Box::new(self.push_filters_down(*join.right));
651 LogicalOperator::Join(join)
652 }
653 LogicalOperator::Aggregate(mut agg) => {
654 agg.input = Box::new(self.push_filters_down(*agg.input));
655 LogicalOperator::Aggregate(agg)
656 }
657 other => other,
659 }
660 }
661
662 fn try_push_filter_into(
667 &self,
668 predicate: LogicalExpression,
669 op: LogicalOperator,
670 ) -> LogicalOperator {
671 match op {
672 LogicalOperator::Project(mut proj) => {
674 let predicate_vars = self.extract_variables(&predicate);
675 let computed_vars = self.extract_projection_aliases(&proj.projections);
676
677 if predicate_vars.is_disjoint(&computed_vars) {
679 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
680 LogicalOperator::Project(proj)
681 } else {
682 LogicalOperator::Filter(FilterOp {
684 predicate,
685 input: Box::new(LogicalOperator::Project(proj)),
686 })
687 }
688 }
689
690 LogicalOperator::Return(mut ret) => {
692 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
693 LogicalOperator::Return(ret)
694 }
695
696 LogicalOperator::Expand(mut expand) => {
698 let predicate_vars = self.extract_variables(&predicate);
699
700 let mut introduced_vars = vec![&expand.to_variable];
705 if let Some(ref edge_var) = expand.edge_variable {
706 introduced_vars.push(edge_var);
707 }
708 if let Some(ref path_alias) = expand.path_alias {
709 introduced_vars.push(path_alias);
710 }
711
712 let uses_introduced_vars =
714 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
715
716 if !uses_introduced_vars {
717 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
719 LogicalOperator::Expand(expand)
720 } else {
721 LogicalOperator::Filter(FilterOp {
723 predicate,
724 input: Box::new(LogicalOperator::Expand(expand)),
725 })
726 }
727 }
728
729 LogicalOperator::Join(mut join) => {
731 let predicate_vars = self.extract_variables(&predicate);
732 let left_vars = self.collect_output_variables(&join.left);
733 let right_vars = self.collect_output_variables(&join.right);
734
735 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
736 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
737
738 if uses_left && !uses_right {
739 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
741 LogicalOperator::Join(join)
742 } else if uses_right && !uses_left {
743 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
745 LogicalOperator::Join(join)
746 } else {
747 LogicalOperator::Filter(FilterOp {
749 predicate,
750 input: Box::new(LogicalOperator::Join(join)),
751 })
752 }
753 }
754
755 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
757 predicate,
758 input: Box::new(LogicalOperator::Aggregate(agg)),
759 }),
760
761 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
763 predicate,
764 input: Box::new(LogicalOperator::NodeScan(scan)),
765 }),
766
767 other => LogicalOperator::Filter(FilterOp {
769 predicate,
770 input: Box::new(other),
771 }),
772 }
773 }
774
775 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
777 let mut vars = HashSet::new();
778 Self::collect_output_variables_recursive(op, &mut vars);
779 vars
780 }
781
782 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
784 match op {
785 LogicalOperator::NodeScan(scan) => {
786 vars.insert(scan.variable.clone());
787 }
788 LogicalOperator::EdgeScan(scan) => {
789 vars.insert(scan.variable.clone());
790 }
791 LogicalOperator::Expand(expand) => {
792 vars.insert(expand.to_variable.clone());
793 if let Some(edge_var) = &expand.edge_variable {
794 vars.insert(edge_var.clone());
795 }
796 Self::collect_output_variables_recursive(&expand.input, vars);
797 }
798 LogicalOperator::Filter(filter) => {
799 Self::collect_output_variables_recursive(&filter.input, vars);
800 }
801 LogicalOperator::Project(proj) => {
802 for p in &proj.projections {
803 if let Some(alias) = &p.alias {
804 vars.insert(alias.clone());
805 }
806 }
807 Self::collect_output_variables_recursive(&proj.input, vars);
808 }
809 LogicalOperator::Join(join) => {
810 Self::collect_output_variables_recursive(&join.left, vars);
811 Self::collect_output_variables_recursive(&join.right, vars);
812 }
813 LogicalOperator::Aggregate(agg) => {
814 for expr in &agg.group_by {
815 Self::collect_variables(expr, vars);
816 }
817 for agg_expr in &agg.aggregates {
818 if let Some(alias) = &agg_expr.alias {
819 vars.insert(alias.clone());
820 }
821 }
822 }
823 LogicalOperator::Return(ret) => {
824 Self::collect_output_variables_recursive(&ret.input, vars);
825 }
826 LogicalOperator::Limit(limit) => {
827 Self::collect_output_variables_recursive(&limit.input, vars);
828 }
829 LogicalOperator::Skip(skip) => {
830 Self::collect_output_variables_recursive(&skip.input, vars);
831 }
832 LogicalOperator::Sort(sort) => {
833 Self::collect_output_variables_recursive(&sort.input, vars);
834 }
835 LogicalOperator::Distinct(distinct) => {
836 Self::collect_output_variables_recursive(&distinct.input, vars);
837 }
838 _ => {}
839 }
840 }
841
842 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
844 let mut vars = HashSet::new();
845 Self::collect_variables(expr, &mut vars);
846 vars
847 }
848
849 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
851 match expr {
852 LogicalExpression::Variable(name) => {
853 vars.insert(name.clone());
854 }
855 LogicalExpression::Property { variable, .. } => {
856 vars.insert(variable.clone());
857 }
858 LogicalExpression::Binary { left, right, .. } => {
859 Self::collect_variables(left, vars);
860 Self::collect_variables(right, vars);
861 }
862 LogicalExpression::Unary { operand, .. } => {
863 Self::collect_variables(operand, vars);
864 }
865 LogicalExpression::FunctionCall { args, .. } => {
866 for arg in args {
867 Self::collect_variables(arg, vars);
868 }
869 }
870 LogicalExpression::List(items) => {
871 for item in items {
872 Self::collect_variables(item, vars);
873 }
874 }
875 LogicalExpression::Map(pairs) => {
876 for (_, value) in pairs {
877 Self::collect_variables(value, vars);
878 }
879 }
880 LogicalExpression::IndexAccess { base, index } => {
881 Self::collect_variables(base, vars);
882 Self::collect_variables(index, vars);
883 }
884 LogicalExpression::SliceAccess { base, start, end } => {
885 Self::collect_variables(base, vars);
886 if let Some(s) = start {
887 Self::collect_variables(s, vars);
888 }
889 if let Some(e) = end {
890 Self::collect_variables(e, vars);
891 }
892 }
893 LogicalExpression::Case {
894 operand,
895 when_clauses,
896 else_clause,
897 } => {
898 if let Some(op) = operand {
899 Self::collect_variables(op, vars);
900 }
901 for (cond, result) in when_clauses {
902 Self::collect_variables(cond, vars);
903 Self::collect_variables(result, vars);
904 }
905 if let Some(else_expr) = else_clause {
906 Self::collect_variables(else_expr, vars);
907 }
908 }
909 LogicalExpression::Labels(var)
910 | LogicalExpression::Type(var)
911 | LogicalExpression::Id(var) => {
912 vars.insert(var.clone());
913 }
914 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
915 LogicalExpression::ListComprehension {
916 list_expr,
917 filter_expr,
918 map_expr,
919 ..
920 } => {
921 Self::collect_variables(list_expr, vars);
922 if let Some(filter) = filter_expr {
923 Self::collect_variables(filter, vars);
924 }
925 Self::collect_variables(map_expr, vars);
926 }
927 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
928 }
930 }
931 }
932
933 fn extract_projection_aliases(
935 &self,
936 projections: &[crate::query::plan::Projection],
937 ) -> HashSet<String> {
938 projections.iter().filter_map(|p| p.alias.clone()).collect()
939 }
940}
941
942impl Default for Optimizer {
943 fn default() -> Self {
944 Self::new()
945 }
946}
947
948#[cfg(test)]
949mod tests {
950 use super::*;
951 use crate::query::plan::{
952 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
953 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
954 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
955 };
956 use grafeo_common::types::Value;
957
958 #[test]
959 fn test_optimizer_filter_pushdown_simple() {
960 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
965 items: vec![ReturnItem {
966 expression: LogicalExpression::Variable("n".to_string()),
967 alias: None,
968 }],
969 distinct: false,
970 input: Box::new(LogicalOperator::Filter(FilterOp {
971 predicate: LogicalExpression::Binary {
972 left: Box::new(LogicalExpression::Property {
973 variable: "n".to_string(),
974 property: "age".to_string(),
975 }),
976 op: BinaryOp::Gt,
977 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
978 },
979 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
980 variable: "n".to_string(),
981 label: Some("Person".to_string()),
982 input: None,
983 })),
984 })),
985 }));
986
987 let optimizer = Optimizer::new();
988 let optimized = optimizer.optimize(plan).unwrap();
989
990 if let LogicalOperator::Return(ret) = &optimized.root {
992 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
993 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
994 assert_eq!(scan.variable, "n");
995 return;
996 }
997 }
998 }
999 panic!("Expected Return -> Filter -> NodeScan structure");
1000 }
1001
1002 #[test]
1003 fn test_optimizer_filter_pushdown_through_expand() {
1004 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1008 items: vec![ReturnItem {
1009 expression: LogicalExpression::Variable("b".to_string()),
1010 alias: None,
1011 }],
1012 distinct: false,
1013 input: Box::new(LogicalOperator::Filter(FilterOp {
1014 predicate: LogicalExpression::Binary {
1015 left: Box::new(LogicalExpression::Property {
1016 variable: "a".to_string(),
1017 property: "age".to_string(),
1018 }),
1019 op: BinaryOp::Gt,
1020 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1021 },
1022 input: Box::new(LogicalOperator::Expand(ExpandOp {
1023 from_variable: "a".to_string(),
1024 to_variable: "b".to_string(),
1025 edge_variable: None,
1026 direction: ExpandDirection::Outgoing,
1027 edge_type: Some("KNOWS".to_string()),
1028 min_hops: 1,
1029 max_hops: Some(1),
1030 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1031 variable: "a".to_string(),
1032 label: Some("Person".to_string()),
1033 input: None,
1034 })),
1035 path_alias: None,
1036 })),
1037 })),
1038 }));
1039
1040 let optimizer = Optimizer::new();
1041 let optimized = optimizer.optimize(plan).unwrap();
1042
1043 if let LogicalOperator::Return(ret) = &optimized.root {
1046 if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1047 if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1048 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1049 assert_eq!(scan.variable, "a");
1050 assert_eq!(expand.from_variable, "a");
1051 assert_eq!(expand.to_variable, "b");
1052 return;
1053 }
1054 }
1055 }
1056 }
1057 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1058 }
1059
1060 #[test]
1061 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1062 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1066 items: vec![ReturnItem {
1067 expression: LogicalExpression::Variable("a".to_string()),
1068 alias: None,
1069 }],
1070 distinct: false,
1071 input: Box::new(LogicalOperator::Filter(FilterOp {
1072 predicate: LogicalExpression::Binary {
1073 left: Box::new(LogicalExpression::Property {
1074 variable: "b".to_string(),
1075 property: "age".to_string(),
1076 }),
1077 op: BinaryOp::Gt,
1078 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1079 },
1080 input: Box::new(LogicalOperator::Expand(ExpandOp {
1081 from_variable: "a".to_string(),
1082 to_variable: "b".to_string(),
1083 edge_variable: None,
1084 direction: ExpandDirection::Outgoing,
1085 edge_type: Some("KNOWS".to_string()),
1086 min_hops: 1,
1087 max_hops: Some(1),
1088 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1089 variable: "a".to_string(),
1090 label: Some("Person".to_string()),
1091 input: None,
1092 })),
1093 path_alias: None,
1094 })),
1095 })),
1096 }));
1097
1098 let optimizer = Optimizer::new();
1099 let optimized = optimizer.optimize(plan).unwrap();
1100
1101 if let LogicalOperator::Return(ret) = &optimized.root {
1104 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1105 if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1107 if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1108 assert_eq!(variable, "b");
1109 }
1110 }
1111
1112 if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1113 if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1114 return;
1115 }
1116 }
1117 }
1118 }
1119 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1120 }
1121
1122 #[test]
1123 fn test_optimizer_extract_variables() {
1124 let optimizer = Optimizer::new();
1125
1126 let expr = LogicalExpression::Binary {
1127 left: Box::new(LogicalExpression::Property {
1128 variable: "n".to_string(),
1129 property: "age".to_string(),
1130 }),
1131 op: BinaryOp::Gt,
1132 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1133 };
1134
1135 let vars = optimizer.extract_variables(&expr);
1136 assert_eq!(vars.len(), 1);
1137 assert!(vars.contains("n"));
1138 }
1139
1140 #[test]
1143 fn test_optimizer_default() {
1144 let optimizer = Optimizer::default();
1145 let plan = LogicalPlan::new(LogicalOperator::Empty);
1147 let result = optimizer.optimize(plan);
1148 assert!(result.is_ok());
1149 }
1150
1151 #[test]
1152 fn test_optimizer_with_filter_pushdown_disabled() {
1153 let optimizer = Optimizer::new().with_filter_pushdown(false);
1154
1155 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1156 items: vec![ReturnItem {
1157 expression: LogicalExpression::Variable("n".to_string()),
1158 alias: None,
1159 }],
1160 distinct: false,
1161 input: Box::new(LogicalOperator::Filter(FilterOp {
1162 predicate: LogicalExpression::Literal(Value::Bool(true)),
1163 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1164 variable: "n".to_string(),
1165 label: None,
1166 input: None,
1167 })),
1168 })),
1169 }));
1170
1171 let optimized = optimizer.optimize(plan).unwrap();
1172 if let LogicalOperator::Return(ret) = &optimized.root {
1174 if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1175 return;
1176 }
1177 }
1178 panic!("Expected unchanged structure");
1179 }
1180
1181 #[test]
1182 fn test_optimizer_with_join_reorder_disabled() {
1183 let optimizer = Optimizer::new().with_join_reorder(false);
1184 assert!(
1185 optimizer
1186 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1187 .is_ok()
1188 );
1189 }
1190
1191 #[test]
1192 fn test_optimizer_with_cost_model() {
1193 let cost_model = CostModel::new();
1194 let optimizer = Optimizer::new().with_cost_model(cost_model);
1195 assert!(
1196 optimizer
1197 .cost_model()
1198 .estimate(&LogicalOperator::Empty, 0.0)
1199 .total()
1200 < 0.001
1201 );
1202 }
1203
1204 #[test]
1205 fn test_optimizer_with_cardinality_estimator() {
1206 let mut estimator = CardinalityEstimator::new();
1207 estimator.add_table_stats("Test", TableStats::new(500));
1208 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1209
1210 let scan = LogicalOperator::NodeScan(NodeScanOp {
1211 variable: "n".to_string(),
1212 label: Some("Test".to_string()),
1213 input: None,
1214 });
1215 let plan = LogicalPlan::new(scan);
1216
1217 let cardinality = optimizer.estimate_cardinality(&plan);
1218 assert!((cardinality - 500.0).abs() < 0.001);
1219 }
1220
1221 #[test]
1222 fn test_optimizer_estimate_cost() {
1223 let optimizer = Optimizer::new();
1224 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1225 variable: "n".to_string(),
1226 label: None,
1227 input: None,
1228 }));
1229
1230 let cost = optimizer.estimate_cost(&plan);
1231 assert!(cost.total() > 0.0);
1232 }
1233
1234 #[test]
1237 fn test_filter_pushdown_through_project() {
1238 let optimizer = Optimizer::new();
1239
1240 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1241 predicate: LogicalExpression::Binary {
1242 left: Box::new(LogicalExpression::Property {
1243 variable: "n".to_string(),
1244 property: "age".to_string(),
1245 }),
1246 op: BinaryOp::Gt,
1247 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1248 },
1249 input: Box::new(LogicalOperator::Project(ProjectOp {
1250 projections: vec![Projection {
1251 expression: LogicalExpression::Variable("n".to_string()),
1252 alias: None,
1253 }],
1254 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1255 variable: "n".to_string(),
1256 label: None,
1257 input: None,
1258 })),
1259 })),
1260 }));
1261
1262 let optimized = optimizer.optimize(plan).unwrap();
1263
1264 if let LogicalOperator::Project(proj) = &optimized.root {
1266 if let LogicalOperator::Filter(_) = proj.input.as_ref() {
1267 return;
1268 }
1269 }
1270 panic!("Expected Project -> Filter structure");
1271 }
1272
1273 #[test]
1274 fn test_filter_not_pushed_through_project_with_alias() {
1275 let optimizer = Optimizer::new();
1276
1277 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1279 predicate: LogicalExpression::Binary {
1280 left: Box::new(LogicalExpression::Variable("x".to_string())),
1281 op: BinaryOp::Gt,
1282 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1283 },
1284 input: Box::new(LogicalOperator::Project(ProjectOp {
1285 projections: vec![Projection {
1286 expression: LogicalExpression::Property {
1287 variable: "n".to_string(),
1288 property: "age".to_string(),
1289 },
1290 alias: Some("x".to_string()),
1291 }],
1292 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1293 variable: "n".to_string(),
1294 label: None,
1295 input: None,
1296 })),
1297 })),
1298 }));
1299
1300 let optimized = optimizer.optimize(plan).unwrap();
1301
1302 if let LogicalOperator::Filter(filter) = &optimized.root {
1304 if let LogicalOperator::Project(_) = filter.input.as_ref() {
1305 return;
1306 }
1307 }
1308 panic!("Expected Filter -> Project structure");
1309 }
1310
1311 #[test]
1312 fn test_filter_pushdown_through_limit() {
1313 let optimizer = Optimizer::new();
1314
1315 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1316 predicate: LogicalExpression::Literal(Value::Bool(true)),
1317 input: Box::new(LogicalOperator::Limit(LimitOp {
1318 count: 10,
1319 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1320 variable: "n".to_string(),
1321 label: None,
1322 input: None,
1323 })),
1324 })),
1325 }));
1326
1327 let optimized = optimizer.optimize(plan).unwrap();
1328
1329 if let LogicalOperator::Filter(filter) = &optimized.root {
1331 if let LogicalOperator::Limit(_) = filter.input.as_ref() {
1332 return;
1333 }
1334 }
1335 panic!("Expected Filter -> Limit structure");
1336 }
1337
1338 #[test]
1339 fn test_filter_pushdown_through_sort() {
1340 let optimizer = Optimizer::new();
1341
1342 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1343 predicate: LogicalExpression::Literal(Value::Bool(true)),
1344 input: Box::new(LogicalOperator::Sort(SortOp {
1345 keys: vec![SortKey {
1346 expression: LogicalExpression::Variable("n".to_string()),
1347 order: SortOrder::Ascending,
1348 }],
1349 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1350 variable: "n".to_string(),
1351 label: None,
1352 input: None,
1353 })),
1354 })),
1355 }));
1356
1357 let optimized = optimizer.optimize(plan).unwrap();
1358
1359 if let LogicalOperator::Filter(filter) = &optimized.root {
1361 if let LogicalOperator::Sort(_) = filter.input.as_ref() {
1362 return;
1363 }
1364 }
1365 panic!("Expected Filter -> Sort structure");
1366 }
1367
1368 #[test]
1369 fn test_filter_pushdown_through_distinct() {
1370 let optimizer = Optimizer::new();
1371
1372 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1373 predicate: LogicalExpression::Literal(Value::Bool(true)),
1374 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1375 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1376 variable: "n".to_string(),
1377 label: None,
1378 input: None,
1379 })),
1380 columns: None,
1381 })),
1382 }));
1383
1384 let optimized = optimizer.optimize(plan).unwrap();
1385
1386 if let LogicalOperator::Filter(filter) = &optimized.root {
1388 if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
1389 return;
1390 }
1391 }
1392 panic!("Expected Filter -> Distinct structure");
1393 }
1394
1395 #[test]
1396 fn test_filter_not_pushed_through_aggregate() {
1397 let optimizer = Optimizer::new();
1398
1399 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1400 predicate: LogicalExpression::Binary {
1401 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1402 op: BinaryOp::Gt,
1403 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1404 },
1405 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1406 group_by: vec![],
1407 aggregates: vec![AggregateExpr {
1408 function: AggregateFunction::Count,
1409 expression: None,
1410 distinct: false,
1411 alias: Some("cnt".to_string()),
1412 percentile: None,
1413 }],
1414 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1415 variable: "n".to_string(),
1416 label: None,
1417 input: None,
1418 })),
1419 having: None,
1420 })),
1421 }));
1422
1423 let optimized = optimizer.optimize(plan).unwrap();
1424
1425 if let LogicalOperator::Filter(filter) = &optimized.root {
1427 if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
1428 return;
1429 }
1430 }
1431 panic!("Expected Filter -> Aggregate structure");
1432 }
1433
1434 #[test]
1435 fn test_filter_pushdown_to_left_join_side() {
1436 let optimizer = Optimizer::new();
1437
1438 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1440 predicate: LogicalExpression::Binary {
1441 left: Box::new(LogicalExpression::Property {
1442 variable: "a".to_string(),
1443 property: "age".to_string(),
1444 }),
1445 op: BinaryOp::Gt,
1446 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1447 },
1448 input: Box::new(LogicalOperator::Join(JoinOp {
1449 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1450 variable: "a".to_string(),
1451 label: Some("Person".to_string()),
1452 input: None,
1453 })),
1454 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1455 variable: "b".to_string(),
1456 label: Some("Company".to_string()),
1457 input: None,
1458 })),
1459 join_type: JoinType::Inner,
1460 conditions: vec![],
1461 })),
1462 }));
1463
1464 let optimized = optimizer.optimize(plan).unwrap();
1465
1466 if let LogicalOperator::Join(join) = &optimized.root {
1468 if let LogicalOperator::Filter(_) = join.left.as_ref() {
1469 return;
1470 }
1471 }
1472 panic!("Expected Join with Filter on left side");
1473 }
1474
1475 #[test]
1476 fn test_filter_pushdown_to_right_join_side() {
1477 let optimizer = Optimizer::new();
1478
1479 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1481 predicate: LogicalExpression::Binary {
1482 left: Box::new(LogicalExpression::Property {
1483 variable: "b".to_string(),
1484 property: "name".to_string(),
1485 }),
1486 op: BinaryOp::Eq,
1487 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1488 },
1489 input: Box::new(LogicalOperator::Join(JoinOp {
1490 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1491 variable: "a".to_string(),
1492 label: Some("Person".to_string()),
1493 input: None,
1494 })),
1495 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1496 variable: "b".to_string(),
1497 label: Some("Company".to_string()),
1498 input: None,
1499 })),
1500 join_type: JoinType::Inner,
1501 conditions: vec![],
1502 })),
1503 }));
1504
1505 let optimized = optimizer.optimize(plan).unwrap();
1506
1507 if let LogicalOperator::Join(join) = &optimized.root {
1509 if let LogicalOperator::Filter(_) = join.right.as_ref() {
1510 return;
1511 }
1512 }
1513 panic!("Expected Join with Filter on right side");
1514 }
1515
1516 #[test]
1517 fn test_filter_not_pushed_when_uses_both_join_sides() {
1518 let optimizer = Optimizer::new();
1519
1520 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1522 predicate: LogicalExpression::Binary {
1523 left: Box::new(LogicalExpression::Property {
1524 variable: "a".to_string(),
1525 property: "id".to_string(),
1526 }),
1527 op: BinaryOp::Eq,
1528 right: Box::new(LogicalExpression::Property {
1529 variable: "b".to_string(),
1530 property: "a_id".to_string(),
1531 }),
1532 },
1533 input: Box::new(LogicalOperator::Join(JoinOp {
1534 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1535 variable: "a".to_string(),
1536 label: None,
1537 input: None,
1538 })),
1539 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1540 variable: "b".to_string(),
1541 label: None,
1542 input: None,
1543 })),
1544 join_type: JoinType::Inner,
1545 conditions: vec![],
1546 })),
1547 }));
1548
1549 let optimized = optimizer.optimize(plan).unwrap();
1550
1551 if let LogicalOperator::Filter(filter) = &optimized.root {
1553 if let LogicalOperator::Join(_) = filter.input.as_ref() {
1554 return;
1555 }
1556 }
1557 panic!("Expected Filter -> Join structure");
1558 }
1559
1560 #[test]
1563 fn test_extract_variables_from_variable() {
1564 let optimizer = Optimizer::new();
1565 let expr = LogicalExpression::Variable("x".to_string());
1566 let vars = optimizer.extract_variables(&expr);
1567 assert_eq!(vars.len(), 1);
1568 assert!(vars.contains("x"));
1569 }
1570
1571 #[test]
1572 fn test_extract_variables_from_unary() {
1573 let optimizer = Optimizer::new();
1574 let expr = LogicalExpression::Unary {
1575 op: UnaryOp::Not,
1576 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1577 };
1578 let vars = optimizer.extract_variables(&expr);
1579 assert_eq!(vars.len(), 1);
1580 assert!(vars.contains("x"));
1581 }
1582
1583 #[test]
1584 fn test_extract_variables_from_function_call() {
1585 let optimizer = Optimizer::new();
1586 let expr = LogicalExpression::FunctionCall {
1587 name: "length".to_string(),
1588 args: vec![
1589 LogicalExpression::Variable("a".to_string()),
1590 LogicalExpression::Variable("b".to_string()),
1591 ],
1592 distinct: false,
1593 };
1594 let vars = optimizer.extract_variables(&expr);
1595 assert_eq!(vars.len(), 2);
1596 assert!(vars.contains("a"));
1597 assert!(vars.contains("b"));
1598 }
1599
1600 #[test]
1601 fn test_extract_variables_from_list() {
1602 let optimizer = Optimizer::new();
1603 let expr = LogicalExpression::List(vec![
1604 LogicalExpression::Variable("a".to_string()),
1605 LogicalExpression::Literal(Value::Int64(1)),
1606 LogicalExpression::Variable("b".to_string()),
1607 ]);
1608 let vars = optimizer.extract_variables(&expr);
1609 assert_eq!(vars.len(), 2);
1610 assert!(vars.contains("a"));
1611 assert!(vars.contains("b"));
1612 }
1613
1614 #[test]
1615 fn test_extract_variables_from_map() {
1616 let optimizer = Optimizer::new();
1617 let expr = LogicalExpression::Map(vec![
1618 (
1619 "key1".to_string(),
1620 LogicalExpression::Variable("a".to_string()),
1621 ),
1622 (
1623 "key2".to_string(),
1624 LogicalExpression::Variable("b".to_string()),
1625 ),
1626 ]);
1627 let vars = optimizer.extract_variables(&expr);
1628 assert_eq!(vars.len(), 2);
1629 assert!(vars.contains("a"));
1630 assert!(vars.contains("b"));
1631 }
1632
1633 #[test]
1634 fn test_extract_variables_from_index_access() {
1635 let optimizer = Optimizer::new();
1636 let expr = LogicalExpression::IndexAccess {
1637 base: Box::new(LogicalExpression::Variable("list".to_string())),
1638 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1639 };
1640 let vars = optimizer.extract_variables(&expr);
1641 assert_eq!(vars.len(), 2);
1642 assert!(vars.contains("list"));
1643 assert!(vars.contains("idx"));
1644 }
1645
1646 #[test]
1647 fn test_extract_variables_from_slice_access() {
1648 let optimizer = Optimizer::new();
1649 let expr = LogicalExpression::SliceAccess {
1650 base: Box::new(LogicalExpression::Variable("list".to_string())),
1651 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1652 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1653 };
1654 let vars = optimizer.extract_variables(&expr);
1655 assert_eq!(vars.len(), 3);
1656 assert!(vars.contains("list"));
1657 assert!(vars.contains("s"));
1658 assert!(vars.contains("e"));
1659 }
1660
1661 #[test]
1662 fn test_extract_variables_from_case() {
1663 let optimizer = Optimizer::new();
1664 let expr = LogicalExpression::Case {
1665 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1666 when_clauses: vec![(
1667 LogicalExpression::Literal(Value::Int64(1)),
1668 LogicalExpression::Variable("a".to_string()),
1669 )],
1670 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1671 };
1672 let vars = optimizer.extract_variables(&expr);
1673 assert_eq!(vars.len(), 3);
1674 assert!(vars.contains("x"));
1675 assert!(vars.contains("a"));
1676 assert!(vars.contains("b"));
1677 }
1678
1679 #[test]
1680 fn test_extract_variables_from_labels() {
1681 let optimizer = Optimizer::new();
1682 let expr = LogicalExpression::Labels("n".to_string());
1683 let vars = optimizer.extract_variables(&expr);
1684 assert_eq!(vars.len(), 1);
1685 assert!(vars.contains("n"));
1686 }
1687
1688 #[test]
1689 fn test_extract_variables_from_type() {
1690 let optimizer = Optimizer::new();
1691 let expr = LogicalExpression::Type("e".to_string());
1692 let vars = optimizer.extract_variables(&expr);
1693 assert_eq!(vars.len(), 1);
1694 assert!(vars.contains("e"));
1695 }
1696
1697 #[test]
1698 fn test_extract_variables_from_id() {
1699 let optimizer = Optimizer::new();
1700 let expr = LogicalExpression::Id("n".to_string());
1701 let vars = optimizer.extract_variables(&expr);
1702 assert_eq!(vars.len(), 1);
1703 assert!(vars.contains("n"));
1704 }
1705
1706 #[test]
1707 fn test_extract_variables_from_list_comprehension() {
1708 let optimizer = Optimizer::new();
1709 let expr = LogicalExpression::ListComprehension {
1710 variable: "x".to_string(),
1711 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1712 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1713 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1714 };
1715 let vars = optimizer.extract_variables(&expr);
1716 assert!(vars.contains("items"));
1717 assert!(vars.contains("pred"));
1718 assert!(vars.contains("result"));
1719 }
1720
1721 #[test]
1722 fn test_extract_variables_from_literal_and_parameter() {
1723 let optimizer = Optimizer::new();
1724
1725 let literal = LogicalExpression::Literal(Value::Int64(42));
1726 assert!(optimizer.extract_variables(&literal).is_empty());
1727
1728 let param = LogicalExpression::Parameter("p".to_string());
1729 assert!(optimizer.extract_variables(¶m).is_empty());
1730 }
1731
1732 #[test]
1735 fn test_recursive_filter_pushdown_through_skip() {
1736 let optimizer = Optimizer::new();
1737
1738 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1739 items: vec![ReturnItem {
1740 expression: LogicalExpression::Variable("n".to_string()),
1741 alias: None,
1742 }],
1743 distinct: false,
1744 input: Box::new(LogicalOperator::Filter(FilterOp {
1745 predicate: LogicalExpression::Literal(Value::Bool(true)),
1746 input: Box::new(LogicalOperator::Skip(SkipOp {
1747 count: 5,
1748 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1749 variable: "n".to_string(),
1750 label: None,
1751 input: None,
1752 })),
1753 })),
1754 })),
1755 }));
1756
1757 let optimized = optimizer.optimize(plan).unwrap();
1758
1759 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1761 }
1762
1763 #[test]
1764 fn test_nested_filter_pushdown() {
1765 let optimizer = Optimizer::new();
1766
1767 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1769 items: vec![ReturnItem {
1770 expression: LogicalExpression::Variable("n".to_string()),
1771 alias: None,
1772 }],
1773 distinct: false,
1774 input: Box::new(LogicalOperator::Filter(FilterOp {
1775 predicate: LogicalExpression::Binary {
1776 left: Box::new(LogicalExpression::Property {
1777 variable: "n".to_string(),
1778 property: "x".to_string(),
1779 }),
1780 op: BinaryOp::Gt,
1781 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1782 },
1783 input: Box::new(LogicalOperator::Filter(FilterOp {
1784 predicate: LogicalExpression::Binary {
1785 left: Box::new(LogicalExpression::Property {
1786 variable: "n".to_string(),
1787 property: "y".to_string(),
1788 }),
1789 op: BinaryOp::Lt,
1790 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1791 },
1792 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1793 variable: "n".to_string(),
1794 label: None,
1795 input: None,
1796 })),
1797 })),
1798 })),
1799 }));
1800
1801 let optimized = optimizer.optimize(plan).unwrap();
1802 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1803 }
1804}