1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
/// A binary join predicate pulled out of a `Join` operator's conditions during
/// join-tree extraction, in the shape handed to `JoinGraphBuilder`.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression (a bare variable, or the
    /// owner of a property access).
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// The original left-hand expression of the join condition.
    left_expr: LogicalExpression,
    /// The original right-hand expression of the join condition.
    right_expr: LogicalExpression,
}
34
/// A column referenced somewhere in the plan, as collected by projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable is referenced (e.g. a scanned node/edge or an expand target).
    Variable(String),
    /// A single property of a variable: `(variable, property)`.
    Property(String, String),
}
43
/// Optimizer for logical query plans.
///
/// Each rewrite pass can be toggled independently. The cost model and the
/// cardinality estimator drive join ordering and the `estimate_*` helpers.
pub struct Optimizer {
    /// Run the filter-pushdown pass.
    enable_filter_pushdown: bool,
    /// Run the join-reordering pass (DPccp enumeration).
    enable_join_reorder: bool,
    /// Run the projection-pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for plan costing and join enumeration.
    cost_model: CostModel,
    /// Cardinality estimator used for plan costing and join enumeration.
    card_estimator: CardinalityEstimator,
}
60
61impl Optimizer {
62 #[must_use]
64 pub fn new() -> Self {
65 Self {
66 enable_filter_pushdown: true,
67 enable_join_reorder: true,
68 enable_projection_pushdown: true,
69 cost_model: CostModel::new(),
70 card_estimator: CardinalityEstimator::new(),
71 }
72 }
73
74 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
76 self.enable_filter_pushdown = enabled;
77 self
78 }
79
80 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
82 self.enable_join_reorder = enabled;
83 self
84 }
85
86 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
88 self.enable_projection_pushdown = enabled;
89 self
90 }
91
92 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
94 self.cost_model = cost_model;
95 self
96 }
97
98 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
100 self.card_estimator = estimator;
101 self
102 }
103
104 pub fn cost_model(&self) -> &CostModel {
106 &self.cost_model
107 }
108
109 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
111 &self.card_estimator
112 }
113
114 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
116 let cardinality = self.card_estimator.estimate(&plan.root);
117 self.cost_model.estimate(&plan.root, cardinality)
118 }
119
120 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
122 self.card_estimator.estimate(&plan.root)
123 }
124
125 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
131 let mut root = plan.root;
132
133 if self.enable_filter_pushdown {
135 root = self.push_filters_down(root);
136 }
137
138 if self.enable_join_reorder {
139 root = self.reorder_joins(root);
140 }
141
142 if self.enable_projection_pushdown {
143 root = self.push_projections_down(root);
144 }
145
146 Ok(LogicalPlan::new(root))
147 }
148
149 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
156 let required = self.collect_required_columns(&op);
158
159 self.push_projections_recursive(op, &required)
161 }
162
163 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
165 let mut required = HashSet::new();
166 Self::collect_required_recursive(op, &mut required);
167 required
168 }
169
170 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
172 match op {
173 LogicalOperator::Return(ret) => {
174 for item in &ret.items {
175 Self::collect_from_expression(&item.expression, required);
176 }
177 Self::collect_required_recursive(&ret.input, required);
178 }
179 LogicalOperator::Project(proj) => {
180 for p in &proj.projections {
181 Self::collect_from_expression(&p.expression, required);
182 }
183 Self::collect_required_recursive(&proj.input, required);
184 }
185 LogicalOperator::Filter(filter) => {
186 Self::collect_from_expression(&filter.predicate, required);
187 Self::collect_required_recursive(&filter.input, required);
188 }
189 LogicalOperator::Sort(sort) => {
190 for key in &sort.keys {
191 Self::collect_from_expression(&key.expression, required);
192 }
193 Self::collect_required_recursive(&sort.input, required);
194 }
195 LogicalOperator::Aggregate(agg) => {
196 for expr in &agg.group_by {
197 Self::collect_from_expression(expr, required);
198 }
199 for agg_expr in &agg.aggregates {
200 if let Some(ref expr) = agg_expr.expression {
201 Self::collect_from_expression(expr, required);
202 }
203 }
204 if let Some(ref having) = agg.having {
205 Self::collect_from_expression(having, required);
206 }
207 Self::collect_required_recursive(&agg.input, required);
208 }
209 LogicalOperator::Join(join) => {
210 for cond in &join.conditions {
211 Self::collect_from_expression(&cond.left, required);
212 Self::collect_from_expression(&cond.right, required);
213 }
214 Self::collect_required_recursive(&join.left, required);
215 Self::collect_required_recursive(&join.right, required);
216 }
217 LogicalOperator::Expand(expand) => {
218 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
220 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
221 if let Some(ref edge_var) = expand.edge_variable {
222 required.insert(RequiredColumn::Variable(edge_var.clone()));
223 }
224 Self::collect_required_recursive(&expand.input, required);
225 }
226 LogicalOperator::Limit(limit) => {
227 Self::collect_required_recursive(&limit.input, required);
228 }
229 LogicalOperator::Skip(skip) => {
230 Self::collect_required_recursive(&skip.input, required);
231 }
232 LogicalOperator::Distinct(distinct) => {
233 Self::collect_required_recursive(&distinct.input, required);
234 }
235 LogicalOperator::NodeScan(scan) => {
236 required.insert(RequiredColumn::Variable(scan.variable.clone()));
237 }
238 LogicalOperator::EdgeScan(scan) => {
239 required.insert(RequiredColumn::Variable(scan.variable.clone()));
240 }
241 _ => {}
242 }
243 }
244
245 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
247 match expr {
248 LogicalExpression::Variable(var) => {
249 required.insert(RequiredColumn::Variable(var.clone()));
250 }
251 LogicalExpression::Property { variable, property } => {
252 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
253 required.insert(RequiredColumn::Variable(variable.clone()));
254 }
255 LogicalExpression::Binary { left, right, .. } => {
256 Self::collect_from_expression(left, required);
257 Self::collect_from_expression(right, required);
258 }
259 LogicalExpression::Unary { operand, .. } => {
260 Self::collect_from_expression(operand, required);
261 }
262 LogicalExpression::FunctionCall { args, .. } => {
263 for arg in args {
264 Self::collect_from_expression(arg, required);
265 }
266 }
267 LogicalExpression::List(items) => {
268 for item in items {
269 Self::collect_from_expression(item, required);
270 }
271 }
272 LogicalExpression::Map(pairs) => {
273 for (_, value) in pairs {
274 Self::collect_from_expression(value, required);
275 }
276 }
277 LogicalExpression::IndexAccess { base, index } => {
278 Self::collect_from_expression(base, required);
279 Self::collect_from_expression(index, required);
280 }
281 LogicalExpression::SliceAccess { base, start, end } => {
282 Self::collect_from_expression(base, required);
283 if let Some(s) = start {
284 Self::collect_from_expression(s, required);
285 }
286 if let Some(e) = end {
287 Self::collect_from_expression(e, required);
288 }
289 }
290 LogicalExpression::Case {
291 operand,
292 when_clauses,
293 else_clause,
294 } => {
295 if let Some(op) = operand {
296 Self::collect_from_expression(op, required);
297 }
298 for (cond, result) in when_clauses {
299 Self::collect_from_expression(cond, required);
300 Self::collect_from_expression(result, required);
301 }
302 if let Some(else_expr) = else_clause {
303 Self::collect_from_expression(else_expr, required);
304 }
305 }
306 LogicalExpression::Labels(var)
307 | LogicalExpression::Type(var)
308 | LogicalExpression::Id(var) => {
309 required.insert(RequiredColumn::Variable(var.clone()));
310 }
311 LogicalExpression::ListComprehension {
312 list_expr,
313 filter_expr,
314 map_expr,
315 ..
316 } => {
317 Self::collect_from_expression(list_expr, required);
318 if let Some(filter) = filter_expr {
319 Self::collect_from_expression(filter, required);
320 }
321 Self::collect_from_expression(map_expr, required);
322 }
323 _ => {}
324 }
325 }
326
327 fn push_projections_recursive(
329 &self,
330 op: LogicalOperator,
331 required: &HashSet<RequiredColumn>,
332 ) -> LogicalOperator {
333 match op {
334 LogicalOperator::Return(mut ret) => {
335 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
336 LogicalOperator::Return(ret)
337 }
338 LogicalOperator::Project(mut proj) => {
339 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
340 LogicalOperator::Project(proj)
341 }
342 LogicalOperator::Filter(mut filter) => {
343 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
344 LogicalOperator::Filter(filter)
345 }
346 LogicalOperator::Sort(mut sort) => {
347 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
350 LogicalOperator::Sort(sort)
351 }
352 LogicalOperator::Aggregate(mut agg) => {
353 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
354 LogicalOperator::Aggregate(agg)
355 }
356 LogicalOperator::Join(mut join) => {
357 let left_vars = self.collect_output_variables(&join.left);
360 let right_vars = self.collect_output_variables(&join.right);
361
362 let left_required: HashSet<_> = required
364 .iter()
365 .filter(|c| match c {
366 RequiredColumn::Variable(v) => left_vars.contains(v),
367 RequiredColumn::Property(v, _) => left_vars.contains(v),
368 })
369 .cloned()
370 .collect();
371
372 let right_required: HashSet<_> = required
373 .iter()
374 .filter(|c| match c {
375 RequiredColumn::Variable(v) => right_vars.contains(v),
376 RequiredColumn::Property(v, _) => right_vars.contains(v),
377 })
378 .cloned()
379 .collect();
380
381 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
382 join.right =
383 Box::new(self.push_projections_recursive(*join.right, &right_required));
384 LogicalOperator::Join(join)
385 }
386 LogicalOperator::Expand(mut expand) => {
387 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
388 LogicalOperator::Expand(expand)
389 }
390 LogicalOperator::Limit(mut limit) => {
391 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
392 LogicalOperator::Limit(limit)
393 }
394 LogicalOperator::Skip(mut skip) => {
395 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
396 LogicalOperator::Skip(skip)
397 }
398 LogicalOperator::Distinct(mut distinct) => {
399 distinct.input =
400 Box::new(self.push_projections_recursive(*distinct.input, required));
401 LogicalOperator::Distinct(distinct)
402 }
403 other => other,
404 }
405 }
406
407 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
414 let op = self.reorder_joins_recursive(op);
416
417 if let Some((relations, conditions)) = self.extract_join_tree(&op) {
419 if relations.len() >= 2 {
420 if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
421 return optimized;
422 }
423 }
424 }
425
426 op
427 }
428
429 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
431 match op {
432 LogicalOperator::Return(mut ret) => {
433 ret.input = Box::new(self.reorder_joins(*ret.input));
434 LogicalOperator::Return(ret)
435 }
436 LogicalOperator::Project(mut proj) => {
437 proj.input = Box::new(self.reorder_joins(*proj.input));
438 LogicalOperator::Project(proj)
439 }
440 LogicalOperator::Filter(mut filter) => {
441 filter.input = Box::new(self.reorder_joins(*filter.input));
442 LogicalOperator::Filter(filter)
443 }
444 LogicalOperator::Limit(mut limit) => {
445 limit.input = Box::new(self.reorder_joins(*limit.input));
446 LogicalOperator::Limit(limit)
447 }
448 LogicalOperator::Skip(mut skip) => {
449 skip.input = Box::new(self.reorder_joins(*skip.input));
450 LogicalOperator::Skip(skip)
451 }
452 LogicalOperator::Sort(mut sort) => {
453 sort.input = Box::new(self.reorder_joins(*sort.input));
454 LogicalOperator::Sort(sort)
455 }
456 LogicalOperator::Distinct(mut distinct) => {
457 distinct.input = Box::new(self.reorder_joins(*distinct.input));
458 LogicalOperator::Distinct(distinct)
459 }
460 LogicalOperator::Aggregate(mut agg) => {
461 agg.input = Box::new(self.reorder_joins(*agg.input));
462 LogicalOperator::Aggregate(agg)
463 }
464 LogicalOperator::Expand(mut expand) => {
465 expand.input = Box::new(self.reorder_joins(*expand.input));
466 LogicalOperator::Expand(expand)
467 }
468 other => other,
470 }
471 }
472
473 fn extract_join_tree(
477 &self,
478 op: &LogicalOperator,
479 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
480 let mut relations = Vec::new();
481 let mut join_conditions = Vec::new();
482
483 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
484 return None;
485 }
486
487 if relations.len() < 2 {
488 return None;
489 }
490
491 Some((relations, join_conditions))
492 }
493
494 fn collect_join_tree(
498 &self,
499 op: &LogicalOperator,
500 relations: &mut Vec<(String, LogicalOperator)>,
501 conditions: &mut Vec<JoinInfo>,
502 ) -> bool {
503 match op {
504 LogicalOperator::Join(join) => {
505 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
507 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
508
509 for cond in &join.conditions {
511 if let (Some(left_var), Some(right_var)) = (
512 self.extract_variable_from_expr(&cond.left),
513 self.extract_variable_from_expr(&cond.right),
514 ) {
515 conditions.push(JoinInfo {
516 left_var,
517 right_var,
518 left_expr: cond.left.clone(),
519 right_expr: cond.right.clone(),
520 });
521 }
522 }
523
524 left_ok && right_ok
525 }
526 LogicalOperator::NodeScan(scan) => {
527 relations.push((scan.variable.clone(), op.clone()));
528 true
529 }
530 LogicalOperator::EdgeScan(scan) => {
531 relations.push((scan.variable.clone(), op.clone()));
532 true
533 }
534 LogicalOperator::Filter(filter) => {
535 self.collect_join_tree(&filter.input, relations, conditions)
537 }
538 LogicalOperator::Expand(expand) => {
539 relations.push((expand.to_variable.clone(), op.clone()));
542 true
543 }
544 _ => false,
545 }
546 }
547
548 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
550 match expr {
551 LogicalExpression::Variable(v) => Some(v.clone()),
552 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
553 _ => None,
554 }
555 }
556
557 fn optimize_join_order(
559 &self,
560 relations: &[(String, LogicalOperator)],
561 conditions: &[JoinInfo],
562 ) -> Option<LogicalOperator> {
563 use join_order::{DPccp, JoinGraphBuilder};
564
565 let mut builder = JoinGraphBuilder::new();
567
568 for (var, relation) in relations {
569 builder.add_relation(var, relation.clone());
570 }
571
572 for cond in conditions {
573 builder.add_join_condition(
574 &cond.left_var,
575 &cond.right_var,
576 cond.left_expr.clone(),
577 cond.right_expr.clone(),
578 );
579 }
580
581 let graph = builder.build();
582
583 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
585 let plan = dpccp.optimize()?;
586
587 Some(plan.operator)
588 }
589
590 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
595 match op {
596 LogicalOperator::Filter(filter) => {
598 let optimized_input = self.push_filters_down(*filter.input);
599 self.try_push_filter_into(filter.predicate, optimized_input)
600 }
601 LogicalOperator::Return(mut ret) => {
603 ret.input = Box::new(self.push_filters_down(*ret.input));
604 LogicalOperator::Return(ret)
605 }
606 LogicalOperator::Project(mut proj) => {
607 proj.input = Box::new(self.push_filters_down(*proj.input));
608 LogicalOperator::Project(proj)
609 }
610 LogicalOperator::Limit(mut limit) => {
611 limit.input = Box::new(self.push_filters_down(*limit.input));
612 LogicalOperator::Limit(limit)
613 }
614 LogicalOperator::Skip(mut skip) => {
615 skip.input = Box::new(self.push_filters_down(*skip.input));
616 LogicalOperator::Skip(skip)
617 }
618 LogicalOperator::Sort(mut sort) => {
619 sort.input = Box::new(self.push_filters_down(*sort.input));
620 LogicalOperator::Sort(sort)
621 }
622 LogicalOperator::Distinct(mut distinct) => {
623 distinct.input = Box::new(self.push_filters_down(*distinct.input));
624 LogicalOperator::Distinct(distinct)
625 }
626 LogicalOperator::Expand(mut expand) => {
627 expand.input = Box::new(self.push_filters_down(*expand.input));
628 LogicalOperator::Expand(expand)
629 }
630 LogicalOperator::Join(mut join) => {
631 join.left = Box::new(self.push_filters_down(*join.left));
632 join.right = Box::new(self.push_filters_down(*join.right));
633 LogicalOperator::Join(join)
634 }
635 LogicalOperator::Aggregate(mut agg) => {
636 agg.input = Box::new(self.push_filters_down(*agg.input));
637 LogicalOperator::Aggregate(agg)
638 }
639 other => other,
641 }
642 }
643
644 fn try_push_filter_into(
649 &self,
650 predicate: LogicalExpression,
651 op: LogicalOperator,
652 ) -> LogicalOperator {
653 match op {
654 LogicalOperator::Project(mut proj) => {
656 let predicate_vars = self.extract_variables(&predicate);
657 let computed_vars = self.extract_projection_aliases(&proj.projections);
658
659 if predicate_vars.is_disjoint(&computed_vars) {
661 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
662 LogicalOperator::Project(proj)
663 } else {
664 LogicalOperator::Filter(FilterOp {
666 predicate,
667 input: Box::new(LogicalOperator::Project(proj)),
668 })
669 }
670 }
671
672 LogicalOperator::Return(mut ret) => {
674 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
675 LogicalOperator::Return(ret)
676 }
677
678 LogicalOperator::Expand(mut expand) => {
680 let predicate_vars = self.extract_variables(&predicate);
681
682 let mut introduced_vars = vec![&expand.to_variable];
687 if let Some(ref edge_var) = expand.edge_variable {
688 introduced_vars.push(edge_var);
689 }
690 if let Some(ref path_alias) = expand.path_alias {
691 introduced_vars.push(path_alias);
692 }
693
694 let uses_introduced_vars =
696 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
697
698 if !uses_introduced_vars {
699 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
701 LogicalOperator::Expand(expand)
702 } else {
703 LogicalOperator::Filter(FilterOp {
705 predicate,
706 input: Box::new(LogicalOperator::Expand(expand)),
707 })
708 }
709 }
710
711 LogicalOperator::Join(mut join) => {
713 let predicate_vars = self.extract_variables(&predicate);
714 let left_vars = self.collect_output_variables(&join.left);
715 let right_vars = self.collect_output_variables(&join.right);
716
717 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
718 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
719
720 if uses_left && !uses_right {
721 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
723 LogicalOperator::Join(join)
724 } else if uses_right && !uses_left {
725 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
727 LogicalOperator::Join(join)
728 } else {
729 LogicalOperator::Filter(FilterOp {
731 predicate,
732 input: Box::new(LogicalOperator::Join(join)),
733 })
734 }
735 }
736
737 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
739 predicate,
740 input: Box::new(LogicalOperator::Aggregate(agg)),
741 }),
742
743 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
745 predicate,
746 input: Box::new(LogicalOperator::NodeScan(scan)),
747 }),
748
749 other => LogicalOperator::Filter(FilterOp {
751 predicate,
752 input: Box::new(other),
753 }),
754 }
755 }
756
757 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
759 let mut vars = HashSet::new();
760 Self::collect_output_variables_recursive(op, &mut vars);
761 vars
762 }
763
764 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
766 match op {
767 LogicalOperator::NodeScan(scan) => {
768 vars.insert(scan.variable.clone());
769 }
770 LogicalOperator::EdgeScan(scan) => {
771 vars.insert(scan.variable.clone());
772 }
773 LogicalOperator::Expand(expand) => {
774 vars.insert(expand.to_variable.clone());
775 if let Some(edge_var) = &expand.edge_variable {
776 vars.insert(edge_var.clone());
777 }
778 Self::collect_output_variables_recursive(&expand.input, vars);
779 }
780 LogicalOperator::Filter(filter) => {
781 Self::collect_output_variables_recursive(&filter.input, vars);
782 }
783 LogicalOperator::Project(proj) => {
784 for p in &proj.projections {
785 if let Some(alias) = &p.alias {
786 vars.insert(alias.clone());
787 }
788 }
789 Self::collect_output_variables_recursive(&proj.input, vars);
790 }
791 LogicalOperator::Join(join) => {
792 Self::collect_output_variables_recursive(&join.left, vars);
793 Self::collect_output_variables_recursive(&join.right, vars);
794 }
795 LogicalOperator::Aggregate(agg) => {
796 for expr in &agg.group_by {
797 Self::collect_variables(expr, vars);
798 }
799 for agg_expr in &agg.aggregates {
800 if let Some(alias) = &agg_expr.alias {
801 vars.insert(alias.clone());
802 }
803 }
804 }
805 LogicalOperator::Return(ret) => {
806 Self::collect_output_variables_recursive(&ret.input, vars);
807 }
808 LogicalOperator::Limit(limit) => {
809 Self::collect_output_variables_recursive(&limit.input, vars);
810 }
811 LogicalOperator::Skip(skip) => {
812 Self::collect_output_variables_recursive(&skip.input, vars);
813 }
814 LogicalOperator::Sort(sort) => {
815 Self::collect_output_variables_recursive(&sort.input, vars);
816 }
817 LogicalOperator::Distinct(distinct) => {
818 Self::collect_output_variables_recursive(&distinct.input, vars);
819 }
820 _ => {}
821 }
822 }
823
824 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
826 let mut vars = HashSet::new();
827 Self::collect_variables(expr, &mut vars);
828 vars
829 }
830
831 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
833 match expr {
834 LogicalExpression::Variable(name) => {
835 vars.insert(name.clone());
836 }
837 LogicalExpression::Property { variable, .. } => {
838 vars.insert(variable.clone());
839 }
840 LogicalExpression::Binary { left, right, .. } => {
841 Self::collect_variables(left, vars);
842 Self::collect_variables(right, vars);
843 }
844 LogicalExpression::Unary { operand, .. } => {
845 Self::collect_variables(operand, vars);
846 }
847 LogicalExpression::FunctionCall { args, .. } => {
848 for arg in args {
849 Self::collect_variables(arg, vars);
850 }
851 }
852 LogicalExpression::List(items) => {
853 for item in items {
854 Self::collect_variables(item, vars);
855 }
856 }
857 LogicalExpression::Map(pairs) => {
858 for (_, value) in pairs {
859 Self::collect_variables(value, vars);
860 }
861 }
862 LogicalExpression::IndexAccess { base, index } => {
863 Self::collect_variables(base, vars);
864 Self::collect_variables(index, vars);
865 }
866 LogicalExpression::SliceAccess { base, start, end } => {
867 Self::collect_variables(base, vars);
868 if let Some(s) = start {
869 Self::collect_variables(s, vars);
870 }
871 if let Some(e) = end {
872 Self::collect_variables(e, vars);
873 }
874 }
875 LogicalExpression::Case {
876 operand,
877 when_clauses,
878 else_clause,
879 } => {
880 if let Some(op) = operand {
881 Self::collect_variables(op, vars);
882 }
883 for (cond, result) in when_clauses {
884 Self::collect_variables(cond, vars);
885 Self::collect_variables(result, vars);
886 }
887 if let Some(else_expr) = else_clause {
888 Self::collect_variables(else_expr, vars);
889 }
890 }
891 LogicalExpression::Labels(var)
892 | LogicalExpression::Type(var)
893 | LogicalExpression::Id(var) => {
894 vars.insert(var.clone());
895 }
896 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
897 LogicalExpression::ListComprehension {
898 list_expr,
899 filter_expr,
900 map_expr,
901 ..
902 } => {
903 Self::collect_variables(list_expr, vars);
904 if let Some(filter) = filter_expr {
905 Self::collect_variables(filter, vars);
906 }
907 Self::collect_variables(map_expr, vars);
908 }
909 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
910 }
912 }
913 }
914
915 fn extract_projection_aliases(
917 &self,
918 projections: &[crate::query::plan::Projection],
919 ) -> HashSet<String> {
920 projections.iter().filter_map(|p| p.alias.clone()).collect()
921 }
922}
923
924impl Default for Optimizer {
925 fn default() -> Self {
926 Self::new()
927 }
928}
929
930#[cfg(test)]
931mod tests {
932 use super::*;
933 use crate::query::plan::{
934 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
935 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
936 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
937 };
938 use grafeo_common::types::Value;
939
940 #[test]
941 fn test_optimizer_filter_pushdown_simple() {
942 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
947 items: vec![ReturnItem {
948 expression: LogicalExpression::Variable("n".to_string()),
949 alias: None,
950 }],
951 distinct: false,
952 input: Box::new(LogicalOperator::Filter(FilterOp {
953 predicate: LogicalExpression::Binary {
954 left: Box::new(LogicalExpression::Property {
955 variable: "n".to_string(),
956 property: "age".to_string(),
957 }),
958 op: BinaryOp::Gt,
959 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
960 },
961 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
962 variable: "n".to_string(),
963 label: Some("Person".to_string()),
964 input: None,
965 })),
966 })),
967 }));
968
969 let optimizer = Optimizer::new();
970 let optimized = optimizer.optimize(plan).unwrap();
971
972 if let LogicalOperator::Return(ret) = &optimized.root {
974 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
975 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
976 assert_eq!(scan.variable, "n");
977 return;
978 }
979 }
980 }
981 panic!("Expected Return -> Filter -> NodeScan structure");
982 }
983
984 #[test]
985 fn test_optimizer_filter_pushdown_through_expand() {
986 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
990 items: vec![ReturnItem {
991 expression: LogicalExpression::Variable("b".to_string()),
992 alias: None,
993 }],
994 distinct: false,
995 input: Box::new(LogicalOperator::Filter(FilterOp {
996 predicate: LogicalExpression::Binary {
997 left: Box::new(LogicalExpression::Property {
998 variable: "a".to_string(),
999 property: "age".to_string(),
1000 }),
1001 op: BinaryOp::Gt,
1002 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1003 },
1004 input: Box::new(LogicalOperator::Expand(ExpandOp {
1005 from_variable: "a".to_string(),
1006 to_variable: "b".to_string(),
1007 edge_variable: None,
1008 direction: ExpandDirection::Outgoing,
1009 edge_type: Some("KNOWS".to_string()),
1010 min_hops: 1,
1011 max_hops: Some(1),
1012 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1013 variable: "a".to_string(),
1014 label: Some("Person".to_string()),
1015 input: None,
1016 })),
1017 path_alias: None,
1018 })),
1019 })),
1020 }));
1021
1022 let optimizer = Optimizer::new();
1023 let optimized = optimizer.optimize(plan).unwrap();
1024
1025 if let LogicalOperator::Return(ret) = &optimized.root {
1028 if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1029 if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1030 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1031 assert_eq!(scan.variable, "a");
1032 assert_eq!(expand.from_variable, "a");
1033 assert_eq!(expand.to_variable, "b");
1034 return;
1035 }
1036 }
1037 }
1038 }
1039 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1040 }
1041
1042 #[test]
1043 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1044 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1048 items: vec![ReturnItem {
1049 expression: LogicalExpression::Variable("a".to_string()),
1050 alias: None,
1051 }],
1052 distinct: false,
1053 input: Box::new(LogicalOperator::Filter(FilterOp {
1054 predicate: LogicalExpression::Binary {
1055 left: Box::new(LogicalExpression::Property {
1056 variable: "b".to_string(),
1057 property: "age".to_string(),
1058 }),
1059 op: BinaryOp::Gt,
1060 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1061 },
1062 input: Box::new(LogicalOperator::Expand(ExpandOp {
1063 from_variable: "a".to_string(),
1064 to_variable: "b".to_string(),
1065 edge_variable: None,
1066 direction: ExpandDirection::Outgoing,
1067 edge_type: Some("KNOWS".to_string()),
1068 min_hops: 1,
1069 max_hops: Some(1),
1070 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1071 variable: "a".to_string(),
1072 label: Some("Person".to_string()),
1073 input: None,
1074 })),
1075 path_alias: None,
1076 })),
1077 })),
1078 }));
1079
1080 let optimizer = Optimizer::new();
1081 let optimized = optimizer.optimize(plan).unwrap();
1082
1083 if let LogicalOperator::Return(ret) = &optimized.root {
1086 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1087 if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1089 if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1090 assert_eq!(variable, "b");
1091 }
1092 }
1093
1094 if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1095 if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1096 return;
1097 }
1098 }
1099 }
1100 }
1101 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1102 }
1103
1104 #[test]
1105 fn test_optimizer_extract_variables() {
1106 let optimizer = Optimizer::new();
1107
1108 let expr = LogicalExpression::Binary {
1109 left: Box::new(LogicalExpression::Property {
1110 variable: "n".to_string(),
1111 property: "age".to_string(),
1112 }),
1113 op: BinaryOp::Gt,
1114 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1115 };
1116
1117 let vars = optimizer.extract_variables(&expr);
1118 assert_eq!(vars.len(), 1);
1119 assert!(vars.contains("n"));
1120 }
1121
1122 #[test]
1125 fn test_optimizer_default() {
1126 let optimizer = Optimizer::default();
1127 let plan = LogicalPlan::new(LogicalOperator::Empty);
1129 let result = optimizer.optimize(plan);
1130 assert!(result.is_ok());
1131 }
1132
1133 #[test]
1134 fn test_optimizer_with_filter_pushdown_disabled() {
1135 let optimizer = Optimizer::new().with_filter_pushdown(false);
1136
1137 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1138 items: vec![ReturnItem {
1139 expression: LogicalExpression::Variable("n".to_string()),
1140 alias: None,
1141 }],
1142 distinct: false,
1143 input: Box::new(LogicalOperator::Filter(FilterOp {
1144 predicate: LogicalExpression::Literal(Value::Bool(true)),
1145 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1146 variable: "n".to_string(),
1147 label: None,
1148 input: None,
1149 })),
1150 })),
1151 }));
1152
1153 let optimized = optimizer.optimize(plan).unwrap();
1154 if let LogicalOperator::Return(ret) = &optimized.root {
1156 if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1157 return;
1158 }
1159 }
1160 panic!("Expected unchanged structure");
1161 }
1162
1163 #[test]
1164 fn test_optimizer_with_join_reorder_disabled() {
1165 let optimizer = Optimizer::new().with_join_reorder(false);
1166 assert!(
1167 optimizer
1168 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1169 .is_ok()
1170 );
1171 }
1172
1173 #[test]
1174 fn test_optimizer_with_cost_model() {
1175 let cost_model = CostModel::new();
1176 let optimizer = Optimizer::new().with_cost_model(cost_model);
1177 assert!(
1178 optimizer
1179 .cost_model()
1180 .estimate(&LogicalOperator::Empty, 0.0)
1181 .total()
1182 < 0.001
1183 );
1184 }
1185
1186 #[test]
1187 fn test_optimizer_with_cardinality_estimator() {
1188 let mut estimator = CardinalityEstimator::new();
1189 estimator.add_table_stats("Test", TableStats::new(500));
1190 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1191
1192 let scan = LogicalOperator::NodeScan(NodeScanOp {
1193 variable: "n".to_string(),
1194 label: Some("Test".to_string()),
1195 input: None,
1196 });
1197 let plan = LogicalPlan::new(scan);
1198
1199 let cardinality = optimizer.estimate_cardinality(&plan);
1200 assert!((cardinality - 500.0).abs() < 0.001);
1201 }
1202
1203 #[test]
1204 fn test_optimizer_estimate_cost() {
1205 let optimizer = Optimizer::new();
1206 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1207 variable: "n".to_string(),
1208 label: None,
1209 input: None,
1210 }));
1211
1212 let cost = optimizer.estimate_cost(&plan);
1213 assert!(cost.total() > 0.0);
1214 }
1215
1216 #[test]
1219 fn test_filter_pushdown_through_project() {
1220 let optimizer = Optimizer::new();
1221
1222 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1223 predicate: LogicalExpression::Binary {
1224 left: Box::new(LogicalExpression::Property {
1225 variable: "n".to_string(),
1226 property: "age".to_string(),
1227 }),
1228 op: BinaryOp::Gt,
1229 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1230 },
1231 input: Box::new(LogicalOperator::Project(ProjectOp {
1232 projections: vec![Projection {
1233 expression: LogicalExpression::Variable("n".to_string()),
1234 alias: None,
1235 }],
1236 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1237 variable: "n".to_string(),
1238 label: None,
1239 input: None,
1240 })),
1241 })),
1242 }));
1243
1244 let optimized = optimizer.optimize(plan).unwrap();
1245
1246 if let LogicalOperator::Project(proj) = &optimized.root {
1248 if let LogicalOperator::Filter(_) = proj.input.as_ref() {
1249 return;
1250 }
1251 }
1252 panic!("Expected Project -> Filter structure");
1253 }
1254
1255 #[test]
1256 fn test_filter_not_pushed_through_project_with_alias() {
1257 let optimizer = Optimizer::new();
1258
1259 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1261 predicate: LogicalExpression::Binary {
1262 left: Box::new(LogicalExpression::Variable("x".to_string())),
1263 op: BinaryOp::Gt,
1264 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1265 },
1266 input: Box::new(LogicalOperator::Project(ProjectOp {
1267 projections: vec![Projection {
1268 expression: LogicalExpression::Property {
1269 variable: "n".to_string(),
1270 property: "age".to_string(),
1271 },
1272 alias: Some("x".to_string()),
1273 }],
1274 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1275 variable: "n".to_string(),
1276 label: None,
1277 input: None,
1278 })),
1279 })),
1280 }));
1281
1282 let optimized = optimizer.optimize(plan).unwrap();
1283
1284 if let LogicalOperator::Filter(filter) = &optimized.root {
1286 if let LogicalOperator::Project(_) = filter.input.as_ref() {
1287 return;
1288 }
1289 }
1290 panic!("Expected Filter -> Project structure");
1291 }
1292
1293 #[test]
1294 fn test_filter_pushdown_through_limit() {
1295 let optimizer = Optimizer::new();
1296
1297 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1298 predicate: LogicalExpression::Literal(Value::Bool(true)),
1299 input: Box::new(LogicalOperator::Limit(LimitOp {
1300 count: 10,
1301 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1302 variable: "n".to_string(),
1303 label: None,
1304 input: None,
1305 })),
1306 })),
1307 }));
1308
1309 let optimized = optimizer.optimize(plan).unwrap();
1310
1311 if let LogicalOperator::Filter(filter) = &optimized.root {
1313 if let LogicalOperator::Limit(_) = filter.input.as_ref() {
1314 return;
1315 }
1316 }
1317 panic!("Expected Filter -> Limit structure");
1318 }
1319
1320 #[test]
1321 fn test_filter_pushdown_through_sort() {
1322 let optimizer = Optimizer::new();
1323
1324 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1325 predicate: LogicalExpression::Literal(Value::Bool(true)),
1326 input: Box::new(LogicalOperator::Sort(SortOp {
1327 keys: vec![SortKey {
1328 expression: LogicalExpression::Variable("n".to_string()),
1329 order: SortOrder::Ascending,
1330 }],
1331 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1332 variable: "n".to_string(),
1333 label: None,
1334 input: None,
1335 })),
1336 })),
1337 }));
1338
1339 let optimized = optimizer.optimize(plan).unwrap();
1340
1341 if let LogicalOperator::Filter(filter) = &optimized.root {
1343 if let LogicalOperator::Sort(_) = filter.input.as_ref() {
1344 return;
1345 }
1346 }
1347 panic!("Expected Filter -> Sort structure");
1348 }
1349
1350 #[test]
1351 fn test_filter_pushdown_through_distinct() {
1352 let optimizer = Optimizer::new();
1353
1354 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1355 predicate: LogicalExpression::Literal(Value::Bool(true)),
1356 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1357 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1358 variable: "n".to_string(),
1359 label: None,
1360 input: None,
1361 })),
1362 columns: None,
1363 })),
1364 }));
1365
1366 let optimized = optimizer.optimize(plan).unwrap();
1367
1368 if let LogicalOperator::Filter(filter) = &optimized.root {
1370 if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
1371 return;
1372 }
1373 }
1374 panic!("Expected Filter -> Distinct structure");
1375 }
1376
1377 #[test]
1378 fn test_filter_not_pushed_through_aggregate() {
1379 let optimizer = Optimizer::new();
1380
1381 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1382 predicate: LogicalExpression::Binary {
1383 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1384 op: BinaryOp::Gt,
1385 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1386 },
1387 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1388 group_by: vec![],
1389 aggregates: vec![AggregateExpr {
1390 function: AggregateFunction::Count,
1391 expression: None,
1392 distinct: false,
1393 alias: Some("cnt".to_string()),
1394 percentile: None,
1395 }],
1396 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1397 variable: "n".to_string(),
1398 label: None,
1399 input: None,
1400 })),
1401 having: None,
1402 })),
1403 }));
1404
1405 let optimized = optimizer.optimize(plan).unwrap();
1406
1407 if let LogicalOperator::Filter(filter) = &optimized.root {
1409 if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
1410 return;
1411 }
1412 }
1413 panic!("Expected Filter -> Aggregate structure");
1414 }
1415
1416 #[test]
1417 fn test_filter_pushdown_to_left_join_side() {
1418 let optimizer = Optimizer::new();
1419
1420 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1422 predicate: LogicalExpression::Binary {
1423 left: Box::new(LogicalExpression::Property {
1424 variable: "a".to_string(),
1425 property: "age".to_string(),
1426 }),
1427 op: BinaryOp::Gt,
1428 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1429 },
1430 input: Box::new(LogicalOperator::Join(JoinOp {
1431 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1432 variable: "a".to_string(),
1433 label: Some("Person".to_string()),
1434 input: None,
1435 })),
1436 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1437 variable: "b".to_string(),
1438 label: Some("Company".to_string()),
1439 input: None,
1440 })),
1441 join_type: JoinType::Inner,
1442 conditions: vec![],
1443 })),
1444 }));
1445
1446 let optimized = optimizer.optimize(plan).unwrap();
1447
1448 if let LogicalOperator::Join(join) = &optimized.root {
1450 if let LogicalOperator::Filter(_) = join.left.as_ref() {
1451 return;
1452 }
1453 }
1454 panic!("Expected Join with Filter on left side");
1455 }
1456
1457 #[test]
1458 fn test_filter_pushdown_to_right_join_side() {
1459 let optimizer = Optimizer::new();
1460
1461 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1463 predicate: LogicalExpression::Binary {
1464 left: Box::new(LogicalExpression::Property {
1465 variable: "b".to_string(),
1466 property: "name".to_string(),
1467 }),
1468 op: BinaryOp::Eq,
1469 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1470 },
1471 input: Box::new(LogicalOperator::Join(JoinOp {
1472 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1473 variable: "a".to_string(),
1474 label: Some("Person".to_string()),
1475 input: None,
1476 })),
1477 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1478 variable: "b".to_string(),
1479 label: Some("Company".to_string()),
1480 input: None,
1481 })),
1482 join_type: JoinType::Inner,
1483 conditions: vec![],
1484 })),
1485 }));
1486
1487 let optimized = optimizer.optimize(plan).unwrap();
1488
1489 if let LogicalOperator::Join(join) = &optimized.root {
1491 if let LogicalOperator::Filter(_) = join.right.as_ref() {
1492 return;
1493 }
1494 }
1495 panic!("Expected Join with Filter on right side");
1496 }
1497
1498 #[test]
1499 fn test_filter_not_pushed_when_uses_both_join_sides() {
1500 let optimizer = Optimizer::new();
1501
1502 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1504 predicate: LogicalExpression::Binary {
1505 left: Box::new(LogicalExpression::Property {
1506 variable: "a".to_string(),
1507 property: "id".to_string(),
1508 }),
1509 op: BinaryOp::Eq,
1510 right: Box::new(LogicalExpression::Property {
1511 variable: "b".to_string(),
1512 property: "a_id".to_string(),
1513 }),
1514 },
1515 input: Box::new(LogicalOperator::Join(JoinOp {
1516 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1517 variable: "a".to_string(),
1518 label: None,
1519 input: None,
1520 })),
1521 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1522 variable: "b".to_string(),
1523 label: None,
1524 input: None,
1525 })),
1526 join_type: JoinType::Inner,
1527 conditions: vec![],
1528 })),
1529 }));
1530
1531 let optimized = optimizer.optimize(plan).unwrap();
1532
1533 if let LogicalOperator::Filter(filter) = &optimized.root {
1535 if let LogicalOperator::Join(_) = filter.input.as_ref() {
1536 return;
1537 }
1538 }
1539 panic!("Expected Filter -> Join structure");
1540 }
1541
1542 #[test]
1545 fn test_extract_variables_from_variable() {
1546 let optimizer = Optimizer::new();
1547 let expr = LogicalExpression::Variable("x".to_string());
1548 let vars = optimizer.extract_variables(&expr);
1549 assert_eq!(vars.len(), 1);
1550 assert!(vars.contains("x"));
1551 }
1552
1553 #[test]
1554 fn test_extract_variables_from_unary() {
1555 let optimizer = Optimizer::new();
1556 let expr = LogicalExpression::Unary {
1557 op: UnaryOp::Not,
1558 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1559 };
1560 let vars = optimizer.extract_variables(&expr);
1561 assert_eq!(vars.len(), 1);
1562 assert!(vars.contains("x"));
1563 }
1564
1565 #[test]
1566 fn test_extract_variables_from_function_call() {
1567 let optimizer = Optimizer::new();
1568 let expr = LogicalExpression::FunctionCall {
1569 name: "length".to_string(),
1570 args: vec![
1571 LogicalExpression::Variable("a".to_string()),
1572 LogicalExpression::Variable("b".to_string()),
1573 ],
1574 distinct: false,
1575 };
1576 let vars = optimizer.extract_variables(&expr);
1577 assert_eq!(vars.len(), 2);
1578 assert!(vars.contains("a"));
1579 assert!(vars.contains("b"));
1580 }
1581
1582 #[test]
1583 fn test_extract_variables_from_list() {
1584 let optimizer = Optimizer::new();
1585 let expr = LogicalExpression::List(vec![
1586 LogicalExpression::Variable("a".to_string()),
1587 LogicalExpression::Literal(Value::Int64(1)),
1588 LogicalExpression::Variable("b".to_string()),
1589 ]);
1590 let vars = optimizer.extract_variables(&expr);
1591 assert_eq!(vars.len(), 2);
1592 assert!(vars.contains("a"));
1593 assert!(vars.contains("b"));
1594 }
1595
1596 #[test]
1597 fn test_extract_variables_from_map() {
1598 let optimizer = Optimizer::new();
1599 let expr = LogicalExpression::Map(vec![
1600 (
1601 "key1".to_string(),
1602 LogicalExpression::Variable("a".to_string()),
1603 ),
1604 (
1605 "key2".to_string(),
1606 LogicalExpression::Variable("b".to_string()),
1607 ),
1608 ]);
1609 let vars = optimizer.extract_variables(&expr);
1610 assert_eq!(vars.len(), 2);
1611 assert!(vars.contains("a"));
1612 assert!(vars.contains("b"));
1613 }
1614
1615 #[test]
1616 fn test_extract_variables_from_index_access() {
1617 let optimizer = Optimizer::new();
1618 let expr = LogicalExpression::IndexAccess {
1619 base: Box::new(LogicalExpression::Variable("list".to_string())),
1620 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1621 };
1622 let vars = optimizer.extract_variables(&expr);
1623 assert_eq!(vars.len(), 2);
1624 assert!(vars.contains("list"));
1625 assert!(vars.contains("idx"));
1626 }
1627
1628 #[test]
1629 fn test_extract_variables_from_slice_access() {
1630 let optimizer = Optimizer::new();
1631 let expr = LogicalExpression::SliceAccess {
1632 base: Box::new(LogicalExpression::Variable("list".to_string())),
1633 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1634 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1635 };
1636 let vars = optimizer.extract_variables(&expr);
1637 assert_eq!(vars.len(), 3);
1638 assert!(vars.contains("list"));
1639 assert!(vars.contains("s"));
1640 assert!(vars.contains("e"));
1641 }
1642
1643 #[test]
1644 fn test_extract_variables_from_case() {
1645 let optimizer = Optimizer::new();
1646 let expr = LogicalExpression::Case {
1647 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1648 when_clauses: vec![(
1649 LogicalExpression::Literal(Value::Int64(1)),
1650 LogicalExpression::Variable("a".to_string()),
1651 )],
1652 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1653 };
1654 let vars = optimizer.extract_variables(&expr);
1655 assert_eq!(vars.len(), 3);
1656 assert!(vars.contains("x"));
1657 assert!(vars.contains("a"));
1658 assert!(vars.contains("b"));
1659 }
1660
1661 #[test]
1662 fn test_extract_variables_from_labels() {
1663 let optimizer = Optimizer::new();
1664 let expr = LogicalExpression::Labels("n".to_string());
1665 let vars = optimizer.extract_variables(&expr);
1666 assert_eq!(vars.len(), 1);
1667 assert!(vars.contains("n"));
1668 }
1669
1670 #[test]
1671 fn test_extract_variables_from_type() {
1672 let optimizer = Optimizer::new();
1673 let expr = LogicalExpression::Type("e".to_string());
1674 let vars = optimizer.extract_variables(&expr);
1675 assert_eq!(vars.len(), 1);
1676 assert!(vars.contains("e"));
1677 }
1678
1679 #[test]
1680 fn test_extract_variables_from_id() {
1681 let optimizer = Optimizer::new();
1682 let expr = LogicalExpression::Id("n".to_string());
1683 let vars = optimizer.extract_variables(&expr);
1684 assert_eq!(vars.len(), 1);
1685 assert!(vars.contains("n"));
1686 }
1687
1688 #[test]
1689 fn test_extract_variables_from_list_comprehension() {
1690 let optimizer = Optimizer::new();
1691 let expr = LogicalExpression::ListComprehension {
1692 variable: "x".to_string(),
1693 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1694 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1695 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1696 };
1697 let vars = optimizer.extract_variables(&expr);
1698 assert!(vars.contains("items"));
1699 assert!(vars.contains("pred"));
1700 assert!(vars.contains("result"));
1701 }
1702
1703 #[test]
1704 fn test_extract_variables_from_literal_and_parameter() {
1705 let optimizer = Optimizer::new();
1706
1707 let literal = LogicalExpression::Literal(Value::Int64(42));
1708 assert!(optimizer.extract_variables(&literal).is_empty());
1709
1710 let param = LogicalExpression::Parameter("p".to_string());
1711 assert!(optimizer.extract_variables(¶m).is_empty());
1712 }
1713
1714 #[test]
1717 fn test_recursive_filter_pushdown_through_skip() {
1718 let optimizer = Optimizer::new();
1719
1720 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1721 items: vec![ReturnItem {
1722 expression: LogicalExpression::Variable("n".to_string()),
1723 alias: None,
1724 }],
1725 distinct: false,
1726 input: Box::new(LogicalOperator::Filter(FilterOp {
1727 predicate: LogicalExpression::Literal(Value::Bool(true)),
1728 input: Box::new(LogicalOperator::Skip(SkipOp {
1729 count: 5,
1730 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1731 variable: "n".to_string(),
1732 label: None,
1733 input: None,
1734 })),
1735 })),
1736 })),
1737 }));
1738
1739 let optimized = optimizer.optimize(plan).unwrap();
1740
1741 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1743 }
1744
1745 #[test]
1746 fn test_nested_filter_pushdown() {
1747 let optimizer = Optimizer::new();
1748
1749 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1751 items: vec![ReturnItem {
1752 expression: LogicalExpression::Variable("n".to_string()),
1753 alias: None,
1754 }],
1755 distinct: false,
1756 input: Box::new(LogicalOperator::Filter(FilterOp {
1757 predicate: LogicalExpression::Binary {
1758 left: Box::new(LogicalExpression::Property {
1759 variable: "n".to_string(),
1760 property: "x".to_string(),
1761 }),
1762 op: BinaryOp::Gt,
1763 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1764 },
1765 input: Box::new(LogicalOperator::Filter(FilterOp {
1766 predicate: LogicalExpression::Binary {
1767 left: Box::new(LogicalExpression::Property {
1768 variable: "n".to_string(),
1769 property: "y".to_string(),
1770 }),
1771 op: BinaryOp::Lt,
1772 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1773 },
1774 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1775 variable: "n".to_string(),
1776 label: None,
1777 input: None,
1778 })),
1779 })),
1780 })),
1781 }));
1782
1783 let optimized = optimizer.optimize(plan).unwrap();
1784 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1785 }
1786}