1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
26#[derive(Debug, Clone)]
28struct JoinInfo {
29 left_var: String,
30 right_var: String,
31 left_expr: LogicalExpression,
32 right_expr: LogicalExpression,
33}
34
/// A column that some operator in the plan needs from its input.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum RequiredColumn {
    /// A whole variable, identified by name.
    Variable(String),
    /// A property of a variable: `(variable name, property name)`.
    Property(String, String),
}
43
/// Rewrites logical plans using filter pushdown, join reordering and
/// projection pushdown.
///
/// Each pass can be switched on or off individually; the cost model and the
/// cardinality estimator are used for cost estimation and join ordering.
pub struct Optimizer {
    /// Whether `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used by `estimate_cost` and by join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used by the estimate methods and by join ordering.
    card_estimator: CardinalityEstimator,
}
60
61impl Optimizer {
62 #[must_use]
64 pub fn new() -> Self {
65 Self {
66 enable_filter_pushdown: true,
67 enable_join_reorder: true,
68 enable_projection_pushdown: true,
69 cost_model: CostModel::new(),
70 card_estimator: CardinalityEstimator::new(),
71 }
72 }
73
74 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
76 self.enable_filter_pushdown = enabled;
77 self
78 }
79
80 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
82 self.enable_join_reorder = enabled;
83 self
84 }
85
86 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
88 self.enable_projection_pushdown = enabled;
89 self
90 }
91
92 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
94 self.cost_model = cost_model;
95 self
96 }
97
98 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
100 self.card_estimator = estimator;
101 self
102 }
103
104 pub fn cost_model(&self) -> &CostModel {
106 &self.cost_model
107 }
108
109 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
111 &self.card_estimator
112 }
113
114 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
116 let cardinality = self.card_estimator.estimate(&plan.root);
117 self.cost_model.estimate(&plan.root, cardinality)
118 }
119
120 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
122 self.card_estimator.estimate(&plan.root)
123 }
124
125 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
131 let mut root = plan.root;
132
133 if self.enable_filter_pushdown {
135 root = self.push_filters_down(root);
136 }
137
138 if self.enable_join_reorder {
139 root = self.reorder_joins(root);
140 }
141
142 if self.enable_projection_pushdown {
143 root = self.push_projections_down(root);
144 }
145
146 Ok(LogicalPlan::new(root))
147 }
148
149 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
156 let required = self.collect_required_columns(&op);
158
159 self.push_projections_recursive(op, &required)
161 }
162
163 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
165 let mut required = HashSet::new();
166 Self::collect_required_recursive(op, &mut required);
167 required
168 }
169
170 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
172 match op {
173 LogicalOperator::Return(ret) => {
174 for item in &ret.items {
175 Self::collect_from_expression(&item.expression, required);
176 }
177 Self::collect_required_recursive(&ret.input, required);
178 }
179 LogicalOperator::Project(proj) => {
180 for p in &proj.projections {
181 Self::collect_from_expression(&p.expression, required);
182 }
183 Self::collect_required_recursive(&proj.input, required);
184 }
185 LogicalOperator::Filter(filter) => {
186 Self::collect_from_expression(&filter.predicate, required);
187 Self::collect_required_recursive(&filter.input, required);
188 }
189 LogicalOperator::Sort(sort) => {
190 for key in &sort.keys {
191 Self::collect_from_expression(&key.expression, required);
192 }
193 Self::collect_required_recursive(&sort.input, required);
194 }
195 LogicalOperator::Aggregate(agg) => {
196 for expr in &agg.group_by {
197 Self::collect_from_expression(expr, required);
198 }
199 for agg_expr in &agg.aggregates {
200 if let Some(ref expr) = agg_expr.expression {
201 Self::collect_from_expression(expr, required);
202 }
203 }
204 if let Some(ref having) = agg.having {
205 Self::collect_from_expression(having, required);
206 }
207 Self::collect_required_recursive(&agg.input, required);
208 }
209 LogicalOperator::Join(join) => {
210 for cond in &join.conditions {
211 Self::collect_from_expression(&cond.left, required);
212 Self::collect_from_expression(&cond.right, required);
213 }
214 Self::collect_required_recursive(&join.left, required);
215 Self::collect_required_recursive(&join.right, required);
216 }
217 LogicalOperator::Expand(expand) => {
218 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
220 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
221 if let Some(ref edge_var) = expand.edge_variable {
222 required.insert(RequiredColumn::Variable(edge_var.clone()));
223 }
224 Self::collect_required_recursive(&expand.input, required);
225 }
226 LogicalOperator::Limit(limit) => {
227 Self::collect_required_recursive(&limit.input, required);
228 }
229 LogicalOperator::Skip(skip) => {
230 Self::collect_required_recursive(&skip.input, required);
231 }
232 LogicalOperator::Distinct(distinct) => {
233 Self::collect_required_recursive(&distinct.input, required);
234 }
235 LogicalOperator::NodeScan(scan) => {
236 required.insert(RequiredColumn::Variable(scan.variable.clone()));
237 }
238 LogicalOperator::EdgeScan(scan) => {
239 required.insert(RequiredColumn::Variable(scan.variable.clone()));
240 }
241 _ => {}
242 }
243 }
244
245 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
247 match expr {
248 LogicalExpression::Variable(var) => {
249 required.insert(RequiredColumn::Variable(var.clone()));
250 }
251 LogicalExpression::Property { variable, property } => {
252 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
253 required.insert(RequiredColumn::Variable(variable.clone()));
254 }
255 LogicalExpression::Binary { left, right, .. } => {
256 Self::collect_from_expression(left, required);
257 Self::collect_from_expression(right, required);
258 }
259 LogicalExpression::Unary { operand, .. } => {
260 Self::collect_from_expression(operand, required);
261 }
262 LogicalExpression::FunctionCall { args, .. } => {
263 for arg in args {
264 Self::collect_from_expression(arg, required);
265 }
266 }
267 LogicalExpression::List(items) => {
268 for item in items {
269 Self::collect_from_expression(item, required);
270 }
271 }
272 LogicalExpression::Map(pairs) => {
273 for (_, value) in pairs {
274 Self::collect_from_expression(value, required);
275 }
276 }
277 LogicalExpression::IndexAccess { base, index } => {
278 Self::collect_from_expression(base, required);
279 Self::collect_from_expression(index, required);
280 }
281 LogicalExpression::SliceAccess { base, start, end } => {
282 Self::collect_from_expression(base, required);
283 if let Some(s) = start {
284 Self::collect_from_expression(s, required);
285 }
286 if let Some(e) = end {
287 Self::collect_from_expression(e, required);
288 }
289 }
290 LogicalExpression::Case {
291 operand,
292 when_clauses,
293 else_clause,
294 } => {
295 if let Some(op) = operand {
296 Self::collect_from_expression(op, required);
297 }
298 for (cond, result) in when_clauses {
299 Self::collect_from_expression(cond, required);
300 Self::collect_from_expression(result, required);
301 }
302 if let Some(else_expr) = else_clause {
303 Self::collect_from_expression(else_expr, required);
304 }
305 }
306 LogicalExpression::Labels(var)
307 | LogicalExpression::Type(var)
308 | LogicalExpression::Id(var) => {
309 required.insert(RequiredColumn::Variable(var.clone()));
310 }
311 LogicalExpression::ListComprehension {
312 list_expr,
313 filter_expr,
314 map_expr,
315 ..
316 } => {
317 Self::collect_from_expression(list_expr, required);
318 if let Some(filter) = filter_expr {
319 Self::collect_from_expression(filter, required);
320 }
321 Self::collect_from_expression(map_expr, required);
322 }
323 _ => {}
324 }
325 }
326
327 fn push_projections_recursive(
329 &self,
330 op: LogicalOperator,
331 required: &HashSet<RequiredColumn>,
332 ) -> LogicalOperator {
333 match op {
334 LogicalOperator::Return(mut ret) => {
335 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
336 LogicalOperator::Return(ret)
337 }
338 LogicalOperator::Project(mut proj) => {
339 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
340 LogicalOperator::Project(proj)
341 }
342 LogicalOperator::Filter(mut filter) => {
343 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
344 LogicalOperator::Filter(filter)
345 }
346 LogicalOperator::Sort(mut sort) => {
347 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
350 LogicalOperator::Sort(sort)
351 }
352 LogicalOperator::Aggregate(mut agg) => {
353 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
354 LogicalOperator::Aggregate(agg)
355 }
356 LogicalOperator::Join(mut join) => {
357 let left_vars = self.collect_output_variables(&join.left);
360 let right_vars = self.collect_output_variables(&join.right);
361
362 let left_required: HashSet<_> = required
364 .iter()
365 .filter(|c| match c {
366 RequiredColumn::Variable(v) => left_vars.contains(v),
367 RequiredColumn::Property(v, _) => left_vars.contains(v),
368 })
369 .cloned()
370 .collect();
371
372 let right_required: HashSet<_> = required
373 .iter()
374 .filter(|c| match c {
375 RequiredColumn::Variable(v) => right_vars.contains(v),
376 RequiredColumn::Property(v, _) => right_vars.contains(v),
377 })
378 .cloned()
379 .collect();
380
381 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
382 join.right =
383 Box::new(self.push_projections_recursive(*join.right, &right_required));
384 LogicalOperator::Join(join)
385 }
386 LogicalOperator::Expand(mut expand) => {
387 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
388 LogicalOperator::Expand(expand)
389 }
390 LogicalOperator::Limit(mut limit) => {
391 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
392 LogicalOperator::Limit(limit)
393 }
394 LogicalOperator::Skip(mut skip) => {
395 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
396 LogicalOperator::Skip(skip)
397 }
398 LogicalOperator::Distinct(mut distinct) => {
399 distinct.input =
400 Box::new(self.push_projections_recursive(*distinct.input, required));
401 LogicalOperator::Distinct(distinct)
402 }
403 other => other,
404 }
405 }
406
407 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
414 let op = self.reorder_joins_recursive(op);
416
417 if let Some((relations, conditions)) = self.extract_join_tree(&op) {
419 if relations.len() >= 2 {
420 if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
421 return optimized;
422 }
423 }
424 }
425
426 op
427 }
428
429 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
431 match op {
432 LogicalOperator::Return(mut ret) => {
433 ret.input = Box::new(self.reorder_joins(*ret.input));
434 LogicalOperator::Return(ret)
435 }
436 LogicalOperator::Project(mut proj) => {
437 proj.input = Box::new(self.reorder_joins(*proj.input));
438 LogicalOperator::Project(proj)
439 }
440 LogicalOperator::Filter(mut filter) => {
441 filter.input = Box::new(self.reorder_joins(*filter.input));
442 LogicalOperator::Filter(filter)
443 }
444 LogicalOperator::Limit(mut limit) => {
445 limit.input = Box::new(self.reorder_joins(*limit.input));
446 LogicalOperator::Limit(limit)
447 }
448 LogicalOperator::Skip(mut skip) => {
449 skip.input = Box::new(self.reorder_joins(*skip.input));
450 LogicalOperator::Skip(skip)
451 }
452 LogicalOperator::Sort(mut sort) => {
453 sort.input = Box::new(self.reorder_joins(*sort.input));
454 LogicalOperator::Sort(sort)
455 }
456 LogicalOperator::Distinct(mut distinct) => {
457 distinct.input = Box::new(self.reorder_joins(*distinct.input));
458 LogicalOperator::Distinct(distinct)
459 }
460 LogicalOperator::Aggregate(mut agg) => {
461 agg.input = Box::new(self.reorder_joins(*agg.input));
462 LogicalOperator::Aggregate(agg)
463 }
464 LogicalOperator::Expand(mut expand) => {
465 expand.input = Box::new(self.reorder_joins(*expand.input));
466 LogicalOperator::Expand(expand)
467 }
468 other => other,
470 }
471 }
472
473 fn extract_join_tree(
477 &self,
478 op: &LogicalOperator,
479 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
480 let mut relations = Vec::new();
481 let mut join_conditions = Vec::new();
482
483 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
484 return None;
485 }
486
487 if relations.len() < 2 {
488 return None;
489 }
490
491 Some((relations, join_conditions))
492 }
493
494 fn collect_join_tree(
498 &self,
499 op: &LogicalOperator,
500 relations: &mut Vec<(String, LogicalOperator)>,
501 conditions: &mut Vec<JoinInfo>,
502 ) -> bool {
503 match op {
504 LogicalOperator::Join(join) => {
505 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
507 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
508
509 for cond in &join.conditions {
511 if let (Some(left_var), Some(right_var)) = (
512 self.extract_variable_from_expr(&cond.left),
513 self.extract_variable_from_expr(&cond.right),
514 ) {
515 conditions.push(JoinInfo {
516 left_var,
517 right_var,
518 left_expr: cond.left.clone(),
519 right_expr: cond.right.clone(),
520 });
521 }
522 }
523
524 left_ok && right_ok
525 }
526 LogicalOperator::NodeScan(scan) => {
527 relations.push((scan.variable.clone(), op.clone()));
528 true
529 }
530 LogicalOperator::EdgeScan(scan) => {
531 relations.push((scan.variable.clone(), op.clone()));
532 true
533 }
534 LogicalOperator::Filter(filter) => {
535 self.collect_join_tree(&filter.input, relations, conditions)
537 }
538 LogicalOperator::Expand(expand) => {
539 relations.push((expand.to_variable.clone(), op.clone()));
542 true
543 }
544 _ => false,
545 }
546 }
547
548 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
550 match expr {
551 LogicalExpression::Variable(v) => Some(v.clone()),
552 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
553 _ => None,
554 }
555 }
556
557 fn optimize_join_order(
559 &self,
560 relations: &[(String, LogicalOperator)],
561 conditions: &[JoinInfo],
562 ) -> Option<LogicalOperator> {
563 use join_order::{DPccp, JoinGraphBuilder};
564
565 let mut builder = JoinGraphBuilder::new();
567
568 for (var, relation) in relations {
569 builder.add_relation(var, relation.clone());
570 }
571
572 for cond in conditions {
573 builder.add_join_condition(
574 &cond.left_var,
575 &cond.right_var,
576 cond.left_expr.clone(),
577 cond.right_expr.clone(),
578 );
579 }
580
581 let graph = builder.build();
582
583 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
585 let plan = dpccp.optimize()?;
586
587 Some(plan.operator)
588 }
589
590 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
595 match op {
596 LogicalOperator::Filter(filter) => {
598 let optimized_input = self.push_filters_down(*filter.input);
599 self.try_push_filter_into(filter.predicate, optimized_input)
600 }
601 LogicalOperator::Return(mut ret) => {
603 ret.input = Box::new(self.push_filters_down(*ret.input));
604 LogicalOperator::Return(ret)
605 }
606 LogicalOperator::Project(mut proj) => {
607 proj.input = Box::new(self.push_filters_down(*proj.input));
608 LogicalOperator::Project(proj)
609 }
610 LogicalOperator::Limit(mut limit) => {
611 limit.input = Box::new(self.push_filters_down(*limit.input));
612 LogicalOperator::Limit(limit)
613 }
614 LogicalOperator::Skip(mut skip) => {
615 skip.input = Box::new(self.push_filters_down(*skip.input));
616 LogicalOperator::Skip(skip)
617 }
618 LogicalOperator::Sort(mut sort) => {
619 sort.input = Box::new(self.push_filters_down(*sort.input));
620 LogicalOperator::Sort(sort)
621 }
622 LogicalOperator::Distinct(mut distinct) => {
623 distinct.input = Box::new(self.push_filters_down(*distinct.input));
624 LogicalOperator::Distinct(distinct)
625 }
626 LogicalOperator::Expand(mut expand) => {
627 expand.input = Box::new(self.push_filters_down(*expand.input));
628 LogicalOperator::Expand(expand)
629 }
630 LogicalOperator::Join(mut join) => {
631 join.left = Box::new(self.push_filters_down(*join.left));
632 join.right = Box::new(self.push_filters_down(*join.right));
633 LogicalOperator::Join(join)
634 }
635 LogicalOperator::Aggregate(mut agg) => {
636 agg.input = Box::new(self.push_filters_down(*agg.input));
637 LogicalOperator::Aggregate(agg)
638 }
639 other => other,
641 }
642 }
643
644 fn try_push_filter_into(
649 &self,
650 predicate: LogicalExpression,
651 op: LogicalOperator,
652 ) -> LogicalOperator {
653 match op {
654 LogicalOperator::Project(mut proj) => {
656 let predicate_vars = self.extract_variables(&predicate);
657 let computed_vars = self.extract_projection_aliases(&proj.projections);
658
659 if predicate_vars.is_disjoint(&computed_vars) {
661 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
662 LogicalOperator::Project(proj)
663 } else {
664 LogicalOperator::Filter(FilterOp {
666 predicate,
667 input: Box::new(LogicalOperator::Project(proj)),
668 })
669 }
670 }
671
672 LogicalOperator::Return(mut ret) => {
674 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
675 LogicalOperator::Return(ret)
676 }
677
678 LogicalOperator::Expand(mut expand) => {
680 let predicate_vars = self.extract_variables(&predicate);
681
682 let uses_only_source = predicate_vars.iter().all(|v| v == &expand.from_variable);
684
685 if uses_only_source {
686 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
688 LogicalOperator::Expand(expand)
689 } else {
690 LogicalOperator::Filter(FilterOp {
692 predicate,
693 input: Box::new(LogicalOperator::Expand(expand)),
694 })
695 }
696 }
697
698 LogicalOperator::Join(mut join) => {
700 let predicate_vars = self.extract_variables(&predicate);
701 let left_vars = self.collect_output_variables(&join.left);
702 let right_vars = self.collect_output_variables(&join.right);
703
704 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
705 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
706
707 if uses_left && !uses_right {
708 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
710 LogicalOperator::Join(join)
711 } else if uses_right && !uses_left {
712 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
714 LogicalOperator::Join(join)
715 } else {
716 LogicalOperator::Filter(FilterOp {
718 predicate,
719 input: Box::new(LogicalOperator::Join(join)),
720 })
721 }
722 }
723
724 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
726 predicate,
727 input: Box::new(LogicalOperator::Aggregate(agg)),
728 }),
729
730 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
732 predicate,
733 input: Box::new(LogicalOperator::NodeScan(scan)),
734 }),
735
736 other => LogicalOperator::Filter(FilterOp {
738 predicate,
739 input: Box::new(other),
740 }),
741 }
742 }
743
744 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
746 let mut vars = HashSet::new();
747 Self::collect_output_variables_recursive(op, &mut vars);
748 vars
749 }
750
751 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
753 match op {
754 LogicalOperator::NodeScan(scan) => {
755 vars.insert(scan.variable.clone());
756 }
757 LogicalOperator::EdgeScan(scan) => {
758 vars.insert(scan.variable.clone());
759 }
760 LogicalOperator::Expand(expand) => {
761 vars.insert(expand.to_variable.clone());
762 if let Some(edge_var) = &expand.edge_variable {
763 vars.insert(edge_var.clone());
764 }
765 Self::collect_output_variables_recursive(&expand.input, vars);
766 }
767 LogicalOperator::Filter(filter) => {
768 Self::collect_output_variables_recursive(&filter.input, vars);
769 }
770 LogicalOperator::Project(proj) => {
771 for p in &proj.projections {
772 if let Some(alias) = &p.alias {
773 vars.insert(alias.clone());
774 }
775 }
776 Self::collect_output_variables_recursive(&proj.input, vars);
777 }
778 LogicalOperator::Join(join) => {
779 Self::collect_output_variables_recursive(&join.left, vars);
780 Self::collect_output_variables_recursive(&join.right, vars);
781 }
782 LogicalOperator::Aggregate(agg) => {
783 for expr in &agg.group_by {
784 Self::collect_variables(expr, vars);
785 }
786 for agg_expr in &agg.aggregates {
787 if let Some(alias) = &agg_expr.alias {
788 vars.insert(alias.clone());
789 }
790 }
791 }
792 LogicalOperator::Return(ret) => {
793 Self::collect_output_variables_recursive(&ret.input, vars);
794 }
795 LogicalOperator::Limit(limit) => {
796 Self::collect_output_variables_recursive(&limit.input, vars);
797 }
798 LogicalOperator::Skip(skip) => {
799 Self::collect_output_variables_recursive(&skip.input, vars);
800 }
801 LogicalOperator::Sort(sort) => {
802 Self::collect_output_variables_recursive(&sort.input, vars);
803 }
804 LogicalOperator::Distinct(distinct) => {
805 Self::collect_output_variables_recursive(&distinct.input, vars);
806 }
807 _ => {}
808 }
809 }
810
811 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
813 let mut vars = HashSet::new();
814 Self::collect_variables(expr, &mut vars);
815 vars
816 }
817
818 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
820 match expr {
821 LogicalExpression::Variable(name) => {
822 vars.insert(name.clone());
823 }
824 LogicalExpression::Property { variable, .. } => {
825 vars.insert(variable.clone());
826 }
827 LogicalExpression::Binary { left, right, .. } => {
828 Self::collect_variables(left, vars);
829 Self::collect_variables(right, vars);
830 }
831 LogicalExpression::Unary { operand, .. } => {
832 Self::collect_variables(operand, vars);
833 }
834 LogicalExpression::FunctionCall { args, .. } => {
835 for arg in args {
836 Self::collect_variables(arg, vars);
837 }
838 }
839 LogicalExpression::List(items) => {
840 for item in items {
841 Self::collect_variables(item, vars);
842 }
843 }
844 LogicalExpression::Map(pairs) => {
845 for (_, value) in pairs {
846 Self::collect_variables(value, vars);
847 }
848 }
849 LogicalExpression::IndexAccess { base, index } => {
850 Self::collect_variables(base, vars);
851 Self::collect_variables(index, vars);
852 }
853 LogicalExpression::SliceAccess { base, start, end } => {
854 Self::collect_variables(base, vars);
855 if let Some(s) = start {
856 Self::collect_variables(s, vars);
857 }
858 if let Some(e) = end {
859 Self::collect_variables(e, vars);
860 }
861 }
862 LogicalExpression::Case {
863 operand,
864 when_clauses,
865 else_clause,
866 } => {
867 if let Some(op) = operand {
868 Self::collect_variables(op, vars);
869 }
870 for (cond, result) in when_clauses {
871 Self::collect_variables(cond, vars);
872 Self::collect_variables(result, vars);
873 }
874 if let Some(else_expr) = else_clause {
875 Self::collect_variables(else_expr, vars);
876 }
877 }
878 LogicalExpression::Labels(var)
879 | LogicalExpression::Type(var)
880 | LogicalExpression::Id(var) => {
881 vars.insert(var.clone());
882 }
883 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
884 LogicalExpression::ListComprehension {
885 list_expr,
886 filter_expr,
887 map_expr,
888 ..
889 } => {
890 Self::collect_variables(list_expr, vars);
891 if let Some(filter) = filter_expr {
892 Self::collect_variables(filter, vars);
893 }
894 Self::collect_variables(map_expr, vars);
895 }
896 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
897 }
899 }
900 }
901
902 fn extract_projection_aliases(
904 &self,
905 projections: &[crate::query::plan::Projection],
906 ) -> HashSet<String> {
907 projections.iter().filter_map(|p| p.alias.clone()).collect()
908 }
909}
910
911impl Default for Optimizer {
912 fn default() -> Self {
913 Self::new()
914 }
915}
916
917#[cfg(test)]
918mod tests {
919 use super::*;
920 use crate::query::plan::{
921 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
922 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
923 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
924 };
925 use grafeo_common::types::Value;
926
927 #[test]
928 fn test_optimizer_filter_pushdown_simple() {
929 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
934 items: vec![ReturnItem {
935 expression: LogicalExpression::Variable("n".to_string()),
936 alias: None,
937 }],
938 distinct: false,
939 input: Box::new(LogicalOperator::Filter(FilterOp {
940 predicate: LogicalExpression::Binary {
941 left: Box::new(LogicalExpression::Property {
942 variable: "n".to_string(),
943 property: "age".to_string(),
944 }),
945 op: BinaryOp::Gt,
946 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
947 },
948 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
949 variable: "n".to_string(),
950 label: Some("Person".to_string()),
951 input: None,
952 })),
953 })),
954 }));
955
956 let optimizer = Optimizer::new();
957 let optimized = optimizer.optimize(plan).unwrap();
958
959 if let LogicalOperator::Return(ret) = &optimized.root {
961 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
962 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
963 assert_eq!(scan.variable, "n");
964 return;
965 }
966 }
967 }
968 panic!("Expected Return -> Filter -> NodeScan structure");
969 }
970
971 #[test]
972 fn test_optimizer_filter_pushdown_through_expand() {
973 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
977 items: vec![ReturnItem {
978 expression: LogicalExpression::Variable("b".to_string()),
979 alias: None,
980 }],
981 distinct: false,
982 input: Box::new(LogicalOperator::Filter(FilterOp {
983 predicate: LogicalExpression::Binary {
984 left: Box::new(LogicalExpression::Property {
985 variable: "a".to_string(),
986 property: "age".to_string(),
987 }),
988 op: BinaryOp::Gt,
989 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
990 },
991 input: Box::new(LogicalOperator::Expand(ExpandOp {
992 from_variable: "a".to_string(),
993 to_variable: "b".to_string(),
994 edge_variable: None,
995 direction: ExpandDirection::Outgoing,
996 edge_type: Some("KNOWS".to_string()),
997 min_hops: 1,
998 max_hops: Some(1),
999 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1000 variable: "a".to_string(),
1001 label: Some("Person".to_string()),
1002 input: None,
1003 })),
1004 path_alias: None,
1005 })),
1006 })),
1007 }));
1008
1009 let optimizer = Optimizer::new();
1010 let optimized = optimizer.optimize(plan).unwrap();
1011
1012 if let LogicalOperator::Return(ret) = &optimized.root {
1015 if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1016 if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1017 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1018 assert_eq!(scan.variable, "a");
1019 assert_eq!(expand.from_variable, "a");
1020 assert_eq!(expand.to_variable, "b");
1021 return;
1022 }
1023 }
1024 }
1025 }
1026 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1027 }
1028
1029 #[test]
1030 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1031 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1035 items: vec![ReturnItem {
1036 expression: LogicalExpression::Variable("a".to_string()),
1037 alias: None,
1038 }],
1039 distinct: false,
1040 input: Box::new(LogicalOperator::Filter(FilterOp {
1041 predicate: LogicalExpression::Binary {
1042 left: Box::new(LogicalExpression::Property {
1043 variable: "b".to_string(),
1044 property: "age".to_string(),
1045 }),
1046 op: BinaryOp::Gt,
1047 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1048 },
1049 input: Box::new(LogicalOperator::Expand(ExpandOp {
1050 from_variable: "a".to_string(),
1051 to_variable: "b".to_string(),
1052 edge_variable: None,
1053 direction: ExpandDirection::Outgoing,
1054 edge_type: Some("KNOWS".to_string()),
1055 min_hops: 1,
1056 max_hops: Some(1),
1057 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1058 variable: "a".to_string(),
1059 label: Some("Person".to_string()),
1060 input: None,
1061 })),
1062 path_alias: None,
1063 })),
1064 })),
1065 }));
1066
1067 let optimizer = Optimizer::new();
1068 let optimized = optimizer.optimize(plan).unwrap();
1069
1070 if let LogicalOperator::Return(ret) = &optimized.root {
1073 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1074 if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1076 if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1077 assert_eq!(variable, "b");
1078 }
1079 }
1080
1081 if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1082 if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1083 return;
1084 }
1085 }
1086 }
1087 }
1088 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1089 }
1090
1091 #[test]
1092 fn test_optimizer_extract_variables() {
1093 let optimizer = Optimizer::new();
1094
1095 let expr = LogicalExpression::Binary {
1096 left: Box::new(LogicalExpression::Property {
1097 variable: "n".to_string(),
1098 property: "age".to_string(),
1099 }),
1100 op: BinaryOp::Gt,
1101 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1102 };
1103
1104 let vars = optimizer.extract_variables(&expr);
1105 assert_eq!(vars.len(), 1);
1106 assert!(vars.contains("n"));
1107 }
1108
1109 #[test]
1112 fn test_optimizer_default() {
1113 let optimizer = Optimizer::default();
1114 let plan = LogicalPlan::new(LogicalOperator::Empty);
1116 let result = optimizer.optimize(plan);
1117 assert!(result.is_ok());
1118 }
1119
1120 #[test]
1121 fn test_optimizer_with_filter_pushdown_disabled() {
1122 let optimizer = Optimizer::new().with_filter_pushdown(false);
1123
1124 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1125 items: vec![ReturnItem {
1126 expression: LogicalExpression::Variable("n".to_string()),
1127 alias: None,
1128 }],
1129 distinct: false,
1130 input: Box::new(LogicalOperator::Filter(FilterOp {
1131 predicate: LogicalExpression::Literal(Value::Bool(true)),
1132 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1133 variable: "n".to_string(),
1134 label: None,
1135 input: None,
1136 })),
1137 })),
1138 }));
1139
1140 let optimized = optimizer.optimize(plan).unwrap();
1141 if let LogicalOperator::Return(ret) = &optimized.root {
1143 if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1144 return;
1145 }
1146 }
1147 panic!("Expected unchanged structure");
1148 }
1149
1150 #[test]
1151 fn test_optimizer_with_join_reorder_disabled() {
1152 let optimizer = Optimizer::new().with_join_reorder(false);
1153 assert!(
1154 optimizer
1155 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1156 .is_ok()
1157 );
1158 }
1159
1160 #[test]
1161 fn test_optimizer_with_cost_model() {
1162 let cost_model = CostModel::new();
1163 let optimizer = Optimizer::new().with_cost_model(cost_model);
1164 assert!(
1165 optimizer
1166 .cost_model()
1167 .estimate(&LogicalOperator::Empty, 0.0)
1168 .total()
1169 < 0.001
1170 );
1171 }
1172
/// Registers table stats of 500 for label `"Test"`, installs the estimator on
/// the optimizer, and expects a scan over that label to be estimated at 500
/// (within 0.001).
#[test]
fn test_optimizer_with_cardinality_estimator() {
    let mut stats = CardinalityEstimator::new();
    stats.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(stats);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    let estimated = optimizer.estimate_cardinality(&plan);
    let deviation = (estimated - 500.0).abs();
    assert!(deviation < 0.001);
}
1189
/// The estimated total cost of an unlabelled node scan must be strictly
/// positive for a default-configured optimizer.
#[test]
fn test_optimizer_estimate_cost() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(scan);

    let total = Optimizer::new().estimate_cost(&plan).total();
    assert!(total > 0.0);
}
1202
/// Starts from `Filter(n.age > 30) -> Project(n) -> NodeScan(n)` and expects
/// the optimized plan to be rooted at the `Project` with the `Filter` as its
/// direct input.
#[test]
fn test_filter_pushdown_through_project() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(project),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_below_project = match &optimized.root {
        LogicalOperator::Project(proj) => {
            matches!(proj.input.as_ref(), LogicalOperator::Filter(_))
        }
        _ => false,
    };
    assert!(filter_below_project, "Expected Project -> Filter structure");
}
1241
/// The filter predicate refers to `x`, which is the alias the projection
/// gives to `n.age`. The test expects the optimized plan to keep the `Filter`
/// at the root with the `Project` as its direct input.
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let aliased_project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
    });
    let alias_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: alias_over_30,
        input: Box::new(aliased_project),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_project = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Project(_))
        }
        _ => false,
    };
    assert!(filter_above_project, "Expected Filter -> Project structure");
}
1279
/// Optimizes `Filter(true) -> Limit(10) -> NodeScan(n)` and asserts that the
/// result is still rooted at the `Filter` with the `Limit` directly below it.
#[test]
fn test_filter_pushdown_through_limit() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limit = LogicalOperator::Limit(LimitOp {
        count: 10,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(limit),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_limit = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Limit(_))
        }
        _ => false,
    };
    assert!(filter_above_limit, "Expected Filter -> Limit structure");
}
1306
/// Optimizes `Filter(true) -> Sort(n ASC) -> NodeScan(n)` and asserts that the
/// result is still rooted at the `Filter` with the `Sort` directly below it.
#[test]
fn test_filter_pushdown_through_sort() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![SortKey {
            expression: LogicalExpression::Variable("n".to_string()),
            order: SortOrder::Ascending,
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(sort),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_sort = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Sort(_))
        }
        _ => false,
    };
    assert!(filter_above_sort, "Expected Filter -> Sort structure");
}
1336
/// Optimizes `Filter(true) -> Distinct -> NodeScan(n)` and asserts that the
/// result is still rooted at the `Filter` with the `Distinct` directly below.
#[test]
fn test_filter_pushdown_through_distinct() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(distinct),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_distinct = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Distinct(_))
        }
        _ => false,
    };
    assert!(filter_above_distinct, "Expected Filter -> Distinct structure");
}
1363
/// The filter predicate (`cnt > 10`) refers to the alias of a `Count`
/// aggregate. The test expects the optimized plan to keep the `Filter` at the
/// root with the `Aggregate` as its direct input.
#[test]
fn test_filter_not_pushed_through_aggregate() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_all = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_all],
        input: Box::new(scan),
        having: None,
    });
    let count_over_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: count_over_10,
        input: Box::new(aggregate),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_aggregate = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_))
        }
        _ => false,
    };
    assert!(filter_above_aggregate, "Expected Filter -> Aggregate structure");
}
1402
/// A filter on `a.age` sits above an inner join of `a:Person` (left) and
/// `b:Company` (right). The test expects the optimized root to be the `Join`
/// with a `Filter` as its left input.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let company_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan),
        right: Box::new(company_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let left_only_predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: left_only_predicate,
        input: Box::new(join),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_on_left = match &optimized.root {
        LogicalOperator::Join(join) => matches!(join.left.as_ref(), LogicalOperator::Filter(_)),
        _ => false,
    };
    assert!(filter_on_left, "Expected Join with Filter on left side");
}
1443
/// A filter on `b.name` sits above an inner join of `a:Person` (left) and
/// `b:Company` (right). The test expects the optimized root to be the `Join`
/// with a `Filter` as its right input.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let company_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan),
        right: Box::new(company_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let right_only_predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: right_only_predicate,
        input: Box::new(join),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_on_right = match &optimized.root {
        LogicalOperator::Join(join) => matches!(join.right.as_ref(), LogicalOperator::Filter(_)),
        _ => false,
    };
    assert!(filter_on_right, "Expected Join with Filter on right side");
}
1484
/// The predicate `a.id = b.a_id` references variables from both join inputs.
/// The test expects the optimized plan to keep the `Filter` at the root with
/// the `Join` as its direct input.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let cross_side_predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side_predicate,
        input: Box::new(join),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let filter_above_join = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Join(_))
        }
        _ => false,
    };
    assert!(filter_above_join, "Expected Filter -> Join structure");
}
1528
/// A bare `Variable("x")` expression yields exactly the one name `x`.
#[test]
fn test_extract_variables_from_variable() {
    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&LogicalExpression::Variable("x".to_string()));

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1539
/// `NOT x` yields exactly the one name `x` from the unary operand.
#[test]
fn test_extract_variables_from_unary() {
    let operand = LogicalExpression::Variable("x".to_string());
    let negated = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&negated);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1551
/// A function call with arguments `a` and `b` yields exactly those two names.
#[test]
fn test_extract_variables_from_function_call() {
    let args = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args,
        distinct: false,
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&call);

    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1568
/// A list literal `[a, 1, b]` yields exactly the two variable names `a` and
/// `b`; the integer literal contributes nothing to the count.
#[test]
fn test_extract_variables_from_list() {
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let list = LogicalExpression::List(elements);

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&list);

    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1582
/// A map literal whose values are the variables `a` and `b` yields exactly
/// those two names.
#[test]
fn test_extract_variables_from_map() {
    let first_entry = (
        "key1".to_string(),
        LogicalExpression::Variable("a".to_string()),
    );
    let second_entry = (
        "key2".to_string(),
        LogicalExpression::Variable("b".to_string()),
    );
    let map = LogicalExpression::Map(vec![first_entry, second_entry]);

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&map);

    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1601
/// `list[idx]` yields exactly the two names `list` and `idx`, one from the
/// base and one from the index expression.
#[test]
fn test_extract_variables_from_index_access() {
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&access);

    assert_eq!(found.len(), 2);
    for expected in ["list", "idx"] {
        assert!(found.contains(expected));
    }
}
1614
/// `list[s..e]` yields exactly the three names `list`, `s` and `e`, covering
/// the base and both bounds.
#[test]
fn test_extract_variables_from_slice_access() {
    let lower = LogicalExpression::Variable("s".to_string());
    let upper = LogicalExpression::Variable("e".to_string());
    let slice = LogicalExpression::SliceAccess {
        base: Box::new(LogicalExpression::Variable("list".to_string())),
        start: Some(Box::new(lower)),
        end: Some(Box::new(upper)),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&slice);

    assert_eq!(found.len(), 3);
    for expected in ["list", "s", "e"] {
        assert!(found.contains(expected));
    }
}
1629
/// A `CASE x WHEN 1 THEN a ELSE b` expression yields exactly the three names
/// `x` (operand), `a` (when-branch result) and `b` (else branch).
#[test]
fn test_extract_variables_from_case() {
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let case = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&case);

    assert_eq!(found.len(), 3);
    for expected in ["x", "a", "b"] {
        assert!(found.contains(expected));
    }
}
1647
/// A `Labels("n")` expression yields exactly the one name `n`.
#[test]
fn test_extract_variables_from_labels() {
    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&LogicalExpression::Labels("n".to_string()));

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1656
/// A `Type("e")` expression yields exactly the one name `e`.
#[test]
fn test_extract_variables_from_type() {
    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&LogicalExpression::Type("e".to_string()));

    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1665
/// An `Id("n")` expression yields exactly the one name `n`.
#[test]
fn test_extract_variables_from_id() {
    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&LogicalExpression::Id("n".to_string()));

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1674
/// For a list comprehension, the names used in the list, filter and map
/// sub-expressions must all be reported. The test makes no assertion about the
/// comprehension's own bound variable `x` or about the total set size.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let source_list = LogicalExpression::Variable("items".to_string());
    let filter = LogicalExpression::Variable("pred".to_string());
    let mapper = LogicalExpression::Variable("result".to_string());
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: Box::new(source_list),
        filter_expr: Some(Box::new(filter)),
        map_expr: Box::new(mapper),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&comprehension);

    for expected in ["items", "pred", "result"] {
        assert!(found.contains(expected));
    }
}
1689
/// Literals and query parameters reference no plan variables, so
/// `extract_variables` must return an empty set for both.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // The argument must be a borrow of `param`; the previous text had the
    // `&param` token corrupted into a pilcrow sequence, which is not valid Rust.
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1700
/// Optimizes `Return -> Filter(true) -> Skip(5) -> NodeScan(n)` and checks
/// only that the optimized plan is still rooted at a `Return`. The shape
/// beneath the root is not asserted.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skip),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1731
/// Optimizes `Return -> Filter(n.x > 1) -> Filter(n.y < 10) -> NodeScan(n)`
/// and checks only that the optimized plan is still rooted at a `Return`. The
/// arrangement of the two filters is not asserted.
#[test]
fn test_nested_filter_pushdown() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "y".to_string(),
            }),
            op: BinaryOp::Lt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        input: Box::new(scan),
    });
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "x".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
        },
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1773}