1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
/// One join condition recorded while a join tree is being flattened.
///
/// Holds both sides of the condition together with the variable each side
/// was resolved to, so the condition can be re-registered on a join graph.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable the left-hand expression refers to.
    left_var: String,
    /// Variable the right-hand expression refers to.
    right_var: String,
    /// Original left-hand expression of the condition.
    left_expr: LogicalExpression,
    /// Original right-hand expression of the condition.
    right_expr: LogicalExpression,
}
34
/// A column that some operator or expression in the plan refers to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable, identified by name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
43
/// Rewrites logical plans using filter pushdown, join reordering and
/// projection pushdown; each pass can be switched off individually.
pub struct Optimizer {
    /// When `true`, `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// When `true`, `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// When `true`, `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for cost estimates and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for cost estimates and join ordering.
    card_estimator: CardinalityEstimator,
}
60
61impl Optimizer {
62 #[must_use]
64 pub fn new() -> Self {
65 Self {
66 enable_filter_pushdown: true,
67 enable_join_reorder: true,
68 enable_projection_pushdown: true,
69 cost_model: CostModel::new(),
70 card_estimator: CardinalityEstimator::new(),
71 }
72 }
73
74 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
76 self.enable_filter_pushdown = enabled;
77 self
78 }
79
80 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
82 self.enable_join_reorder = enabled;
83 self
84 }
85
86 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
88 self.enable_projection_pushdown = enabled;
89 self
90 }
91
92 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
94 self.cost_model = cost_model;
95 self
96 }
97
98 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
100 self.card_estimator = estimator;
101 self
102 }
103
104 pub fn cost_model(&self) -> &CostModel {
106 &self.cost_model
107 }
108
109 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
111 &self.card_estimator
112 }
113
114 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
116 let cardinality = self.card_estimator.estimate(&plan.root);
117 self.cost_model.estimate(&plan.root, cardinality)
118 }
119
120 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
122 self.card_estimator.estimate(&plan.root)
123 }
124
125 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
131 let mut root = plan.root;
132
133 if self.enable_filter_pushdown {
135 root = self.push_filters_down(root);
136 }
137
138 if self.enable_join_reorder {
139 root = self.reorder_joins(root);
140 }
141
142 if self.enable_projection_pushdown {
143 root = self.push_projections_down(root);
144 }
145
146 Ok(LogicalPlan::new(root))
147 }
148
149 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
156 let required = self.collect_required_columns(&op);
158
159 self.push_projections_recursive(op, &required)
161 }
162
163 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
165 let mut required = HashSet::new();
166 self.collect_required_recursive(op, &mut required);
167 required
168 }
169
170 fn collect_required_recursive(
172 &self,
173 op: &LogicalOperator,
174 required: &mut HashSet<RequiredColumn>,
175 ) {
176 match op {
177 LogicalOperator::Return(ret) => {
178 for item in &ret.items {
179 self.collect_from_expression(&item.expression, required);
180 }
181 self.collect_required_recursive(&ret.input, required);
182 }
183 LogicalOperator::Project(proj) => {
184 for p in &proj.projections {
185 self.collect_from_expression(&p.expression, required);
186 }
187 self.collect_required_recursive(&proj.input, required);
188 }
189 LogicalOperator::Filter(filter) => {
190 self.collect_from_expression(&filter.predicate, required);
191 self.collect_required_recursive(&filter.input, required);
192 }
193 LogicalOperator::Sort(sort) => {
194 for key in &sort.keys {
195 self.collect_from_expression(&key.expression, required);
196 }
197 self.collect_required_recursive(&sort.input, required);
198 }
199 LogicalOperator::Aggregate(agg) => {
200 for expr in &agg.group_by {
201 self.collect_from_expression(expr, required);
202 }
203 for agg_expr in &agg.aggregates {
204 if let Some(ref expr) = agg_expr.expression {
205 self.collect_from_expression(expr, required);
206 }
207 }
208 if let Some(ref having) = agg.having {
209 self.collect_from_expression(having, required);
210 }
211 self.collect_required_recursive(&agg.input, required);
212 }
213 LogicalOperator::Join(join) => {
214 for cond in &join.conditions {
215 self.collect_from_expression(&cond.left, required);
216 self.collect_from_expression(&cond.right, required);
217 }
218 self.collect_required_recursive(&join.left, required);
219 self.collect_required_recursive(&join.right, required);
220 }
221 LogicalOperator::Expand(expand) => {
222 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
224 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
225 if let Some(ref edge_var) = expand.edge_variable {
226 required.insert(RequiredColumn::Variable(edge_var.clone()));
227 }
228 self.collect_required_recursive(&expand.input, required);
229 }
230 LogicalOperator::Limit(limit) => {
231 self.collect_required_recursive(&limit.input, required);
232 }
233 LogicalOperator::Skip(skip) => {
234 self.collect_required_recursive(&skip.input, required);
235 }
236 LogicalOperator::Distinct(distinct) => {
237 self.collect_required_recursive(&distinct.input, required);
238 }
239 LogicalOperator::NodeScan(scan) => {
240 required.insert(RequiredColumn::Variable(scan.variable.clone()));
241 }
242 LogicalOperator::EdgeScan(scan) => {
243 required.insert(RequiredColumn::Variable(scan.variable.clone()));
244 }
245 _ => {}
246 }
247 }
248
249 fn collect_from_expression(
251 &self,
252 expr: &LogicalExpression,
253 required: &mut HashSet<RequiredColumn>,
254 ) {
255 match expr {
256 LogicalExpression::Variable(var) => {
257 required.insert(RequiredColumn::Variable(var.clone()));
258 }
259 LogicalExpression::Property { variable, property } => {
260 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
261 required.insert(RequiredColumn::Variable(variable.clone()));
262 }
263 LogicalExpression::Binary { left, right, .. } => {
264 self.collect_from_expression(left, required);
265 self.collect_from_expression(right, required);
266 }
267 LogicalExpression::Unary { operand, .. } => {
268 self.collect_from_expression(operand, required);
269 }
270 LogicalExpression::FunctionCall { args, .. } => {
271 for arg in args {
272 self.collect_from_expression(arg, required);
273 }
274 }
275 LogicalExpression::List(items) => {
276 for item in items {
277 self.collect_from_expression(item, required);
278 }
279 }
280 LogicalExpression::Map(pairs) => {
281 for (_, value) in pairs {
282 self.collect_from_expression(value, required);
283 }
284 }
285 LogicalExpression::IndexAccess { base, index } => {
286 self.collect_from_expression(base, required);
287 self.collect_from_expression(index, required);
288 }
289 LogicalExpression::SliceAccess { base, start, end } => {
290 self.collect_from_expression(base, required);
291 if let Some(s) = start {
292 self.collect_from_expression(s, required);
293 }
294 if let Some(e) = end {
295 self.collect_from_expression(e, required);
296 }
297 }
298 LogicalExpression::Case {
299 operand,
300 when_clauses,
301 else_clause,
302 } => {
303 if let Some(op) = operand {
304 self.collect_from_expression(op, required);
305 }
306 for (cond, result) in when_clauses {
307 self.collect_from_expression(cond, required);
308 self.collect_from_expression(result, required);
309 }
310 if let Some(else_expr) = else_clause {
311 self.collect_from_expression(else_expr, required);
312 }
313 }
314 LogicalExpression::Labels(var)
315 | LogicalExpression::Type(var)
316 | LogicalExpression::Id(var) => {
317 required.insert(RequiredColumn::Variable(var.clone()));
318 }
319 LogicalExpression::ListComprehension {
320 list_expr,
321 filter_expr,
322 map_expr,
323 ..
324 } => {
325 self.collect_from_expression(list_expr, required);
326 if let Some(filter) = filter_expr {
327 self.collect_from_expression(filter, required);
328 }
329 self.collect_from_expression(map_expr, required);
330 }
331 _ => {}
332 }
333 }
334
335 fn push_projections_recursive(
337 &self,
338 op: LogicalOperator,
339 required: &HashSet<RequiredColumn>,
340 ) -> LogicalOperator {
341 match op {
342 LogicalOperator::Return(mut ret) => {
343 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
344 LogicalOperator::Return(ret)
345 }
346 LogicalOperator::Project(mut proj) => {
347 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
348 LogicalOperator::Project(proj)
349 }
350 LogicalOperator::Filter(mut filter) => {
351 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
352 LogicalOperator::Filter(filter)
353 }
354 LogicalOperator::Sort(mut sort) => {
355 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
358 LogicalOperator::Sort(sort)
359 }
360 LogicalOperator::Aggregate(mut agg) => {
361 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
362 LogicalOperator::Aggregate(agg)
363 }
364 LogicalOperator::Join(mut join) => {
365 let left_vars = self.collect_output_variables(&join.left);
368 let right_vars = self.collect_output_variables(&join.right);
369
370 let left_required: HashSet<_> = required
372 .iter()
373 .filter(|c| match c {
374 RequiredColumn::Variable(v) => left_vars.contains(v),
375 RequiredColumn::Property(v, _) => left_vars.contains(v),
376 })
377 .cloned()
378 .collect();
379
380 let right_required: HashSet<_> = required
381 .iter()
382 .filter(|c| match c {
383 RequiredColumn::Variable(v) => right_vars.contains(v),
384 RequiredColumn::Property(v, _) => right_vars.contains(v),
385 })
386 .cloned()
387 .collect();
388
389 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
390 join.right =
391 Box::new(self.push_projections_recursive(*join.right, &right_required));
392 LogicalOperator::Join(join)
393 }
394 LogicalOperator::Expand(mut expand) => {
395 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
396 LogicalOperator::Expand(expand)
397 }
398 LogicalOperator::Limit(mut limit) => {
399 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
400 LogicalOperator::Limit(limit)
401 }
402 LogicalOperator::Skip(mut skip) => {
403 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
404 LogicalOperator::Skip(skip)
405 }
406 LogicalOperator::Distinct(mut distinct) => {
407 distinct.input =
408 Box::new(self.push_projections_recursive(*distinct.input, required));
409 LogicalOperator::Distinct(distinct)
410 }
411 other => other,
412 }
413 }
414
415 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
422 let op = self.reorder_joins_recursive(op);
424
425 if let Some((relations, conditions)) = self.extract_join_tree(&op) {
427 if relations.len() >= 2 {
428 if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
429 return optimized;
430 }
431 }
432 }
433
434 op
435 }
436
437 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
439 match op {
440 LogicalOperator::Return(mut ret) => {
441 ret.input = Box::new(self.reorder_joins(*ret.input));
442 LogicalOperator::Return(ret)
443 }
444 LogicalOperator::Project(mut proj) => {
445 proj.input = Box::new(self.reorder_joins(*proj.input));
446 LogicalOperator::Project(proj)
447 }
448 LogicalOperator::Filter(mut filter) => {
449 filter.input = Box::new(self.reorder_joins(*filter.input));
450 LogicalOperator::Filter(filter)
451 }
452 LogicalOperator::Limit(mut limit) => {
453 limit.input = Box::new(self.reorder_joins(*limit.input));
454 LogicalOperator::Limit(limit)
455 }
456 LogicalOperator::Skip(mut skip) => {
457 skip.input = Box::new(self.reorder_joins(*skip.input));
458 LogicalOperator::Skip(skip)
459 }
460 LogicalOperator::Sort(mut sort) => {
461 sort.input = Box::new(self.reorder_joins(*sort.input));
462 LogicalOperator::Sort(sort)
463 }
464 LogicalOperator::Distinct(mut distinct) => {
465 distinct.input = Box::new(self.reorder_joins(*distinct.input));
466 LogicalOperator::Distinct(distinct)
467 }
468 LogicalOperator::Aggregate(mut agg) => {
469 agg.input = Box::new(self.reorder_joins(*agg.input));
470 LogicalOperator::Aggregate(agg)
471 }
472 LogicalOperator::Expand(mut expand) => {
473 expand.input = Box::new(self.reorder_joins(*expand.input));
474 LogicalOperator::Expand(expand)
475 }
476 other => other,
478 }
479 }
480
481 fn extract_join_tree(
485 &self,
486 op: &LogicalOperator,
487 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
488 let mut relations = Vec::new();
489 let mut join_conditions = Vec::new();
490
491 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
492 return None;
493 }
494
495 if relations.len() < 2 {
496 return None;
497 }
498
499 Some((relations, join_conditions))
500 }
501
502 fn collect_join_tree(
506 &self,
507 op: &LogicalOperator,
508 relations: &mut Vec<(String, LogicalOperator)>,
509 conditions: &mut Vec<JoinInfo>,
510 ) -> bool {
511 match op {
512 LogicalOperator::Join(join) => {
513 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
515 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
516
517 for cond in &join.conditions {
519 if let (Some(left_var), Some(right_var)) = (
520 self.extract_variable_from_expr(&cond.left),
521 self.extract_variable_from_expr(&cond.right),
522 ) {
523 conditions.push(JoinInfo {
524 left_var,
525 right_var,
526 left_expr: cond.left.clone(),
527 right_expr: cond.right.clone(),
528 });
529 }
530 }
531
532 left_ok && right_ok
533 }
534 LogicalOperator::NodeScan(scan) => {
535 relations.push((scan.variable.clone(), op.clone()));
536 true
537 }
538 LogicalOperator::EdgeScan(scan) => {
539 relations.push((scan.variable.clone(), op.clone()));
540 true
541 }
542 LogicalOperator::Filter(filter) => {
543 self.collect_join_tree(&filter.input, relations, conditions)
545 }
546 LogicalOperator::Expand(expand) => {
547 relations.push((expand.to_variable.clone(), op.clone()));
550 true
551 }
552 _ => false,
553 }
554 }
555
556 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
558 match expr {
559 LogicalExpression::Variable(v) => Some(v.clone()),
560 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
561 _ => None,
562 }
563 }
564
565 fn optimize_join_order(
567 &self,
568 relations: &[(String, LogicalOperator)],
569 conditions: &[JoinInfo],
570 ) -> Option<LogicalOperator> {
571 use join_order::{DPccp, JoinGraphBuilder};
572
573 let mut builder = JoinGraphBuilder::new();
575
576 for (var, relation) in relations {
577 builder.add_relation(var, relation.clone());
578 }
579
580 for cond in conditions {
581 builder.add_join_condition(
582 &cond.left_var,
583 &cond.right_var,
584 cond.left_expr.clone(),
585 cond.right_expr.clone(),
586 );
587 }
588
589 let graph = builder.build();
590
591 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
593 let plan = dpccp.optimize()?;
594
595 Some(plan.operator)
596 }
597
598 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
603 match op {
604 LogicalOperator::Filter(filter) => {
606 let optimized_input = self.push_filters_down(*filter.input);
607 self.try_push_filter_into(filter.predicate, optimized_input)
608 }
609 LogicalOperator::Return(mut ret) => {
611 ret.input = Box::new(self.push_filters_down(*ret.input));
612 LogicalOperator::Return(ret)
613 }
614 LogicalOperator::Project(mut proj) => {
615 proj.input = Box::new(self.push_filters_down(*proj.input));
616 LogicalOperator::Project(proj)
617 }
618 LogicalOperator::Limit(mut limit) => {
619 limit.input = Box::new(self.push_filters_down(*limit.input));
620 LogicalOperator::Limit(limit)
621 }
622 LogicalOperator::Skip(mut skip) => {
623 skip.input = Box::new(self.push_filters_down(*skip.input));
624 LogicalOperator::Skip(skip)
625 }
626 LogicalOperator::Sort(mut sort) => {
627 sort.input = Box::new(self.push_filters_down(*sort.input));
628 LogicalOperator::Sort(sort)
629 }
630 LogicalOperator::Distinct(mut distinct) => {
631 distinct.input = Box::new(self.push_filters_down(*distinct.input));
632 LogicalOperator::Distinct(distinct)
633 }
634 LogicalOperator::Expand(mut expand) => {
635 expand.input = Box::new(self.push_filters_down(*expand.input));
636 LogicalOperator::Expand(expand)
637 }
638 LogicalOperator::Join(mut join) => {
639 join.left = Box::new(self.push_filters_down(*join.left));
640 join.right = Box::new(self.push_filters_down(*join.right));
641 LogicalOperator::Join(join)
642 }
643 LogicalOperator::Aggregate(mut agg) => {
644 agg.input = Box::new(self.push_filters_down(*agg.input));
645 LogicalOperator::Aggregate(agg)
646 }
647 other => other,
649 }
650 }
651
652 fn try_push_filter_into(
657 &self,
658 predicate: LogicalExpression,
659 op: LogicalOperator,
660 ) -> LogicalOperator {
661 match op {
662 LogicalOperator::Project(mut proj) => {
664 let predicate_vars = self.extract_variables(&predicate);
665 let computed_vars = self.extract_projection_aliases(&proj.projections);
666
667 if predicate_vars.is_disjoint(&computed_vars) {
669 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
670 LogicalOperator::Project(proj)
671 } else {
672 LogicalOperator::Filter(FilterOp {
674 predicate,
675 input: Box::new(LogicalOperator::Project(proj)),
676 })
677 }
678 }
679
680 LogicalOperator::Return(mut ret) => {
682 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
683 LogicalOperator::Return(ret)
684 }
685
686 LogicalOperator::Expand(mut expand) => {
688 let predicate_vars = self.extract_variables(&predicate);
689
690 let uses_only_source = predicate_vars.iter().all(|v| v == &expand.from_variable);
692
693 if uses_only_source {
694 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
696 LogicalOperator::Expand(expand)
697 } else {
698 LogicalOperator::Filter(FilterOp {
700 predicate,
701 input: Box::new(LogicalOperator::Expand(expand)),
702 })
703 }
704 }
705
706 LogicalOperator::Join(mut join) => {
708 let predicate_vars = self.extract_variables(&predicate);
709 let left_vars = self.collect_output_variables(&join.left);
710 let right_vars = self.collect_output_variables(&join.right);
711
712 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
713 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
714
715 if uses_left && !uses_right {
716 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
718 LogicalOperator::Join(join)
719 } else if uses_right && !uses_left {
720 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
722 LogicalOperator::Join(join)
723 } else {
724 LogicalOperator::Filter(FilterOp {
726 predicate,
727 input: Box::new(LogicalOperator::Join(join)),
728 })
729 }
730 }
731
732 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
734 predicate,
735 input: Box::new(LogicalOperator::Aggregate(agg)),
736 }),
737
738 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
740 predicate,
741 input: Box::new(LogicalOperator::NodeScan(scan)),
742 }),
743
744 other => LogicalOperator::Filter(FilterOp {
746 predicate,
747 input: Box::new(other),
748 }),
749 }
750 }
751
752 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
754 let mut vars = HashSet::new();
755 Self::collect_output_variables_recursive(op, &mut vars);
756 vars
757 }
758
759 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
761 match op {
762 LogicalOperator::NodeScan(scan) => {
763 vars.insert(scan.variable.clone());
764 }
765 LogicalOperator::EdgeScan(scan) => {
766 vars.insert(scan.variable.clone());
767 }
768 LogicalOperator::Expand(expand) => {
769 vars.insert(expand.to_variable.clone());
770 if let Some(edge_var) = &expand.edge_variable {
771 vars.insert(edge_var.clone());
772 }
773 Self::collect_output_variables_recursive(&expand.input, vars);
774 }
775 LogicalOperator::Filter(filter) => {
776 Self::collect_output_variables_recursive(&filter.input, vars);
777 }
778 LogicalOperator::Project(proj) => {
779 for p in &proj.projections {
780 if let Some(alias) = &p.alias {
781 vars.insert(alias.clone());
782 }
783 }
784 Self::collect_output_variables_recursive(&proj.input, vars);
785 }
786 LogicalOperator::Join(join) => {
787 Self::collect_output_variables_recursive(&join.left, vars);
788 Self::collect_output_variables_recursive(&join.right, vars);
789 }
790 LogicalOperator::Aggregate(agg) => {
791 for expr in &agg.group_by {
792 Self::collect_variables(expr, vars);
793 }
794 for agg_expr in &agg.aggregates {
795 if let Some(alias) = &agg_expr.alias {
796 vars.insert(alias.clone());
797 }
798 }
799 }
800 LogicalOperator::Return(ret) => {
801 Self::collect_output_variables_recursive(&ret.input, vars);
802 }
803 LogicalOperator::Limit(limit) => {
804 Self::collect_output_variables_recursive(&limit.input, vars);
805 }
806 LogicalOperator::Skip(skip) => {
807 Self::collect_output_variables_recursive(&skip.input, vars);
808 }
809 LogicalOperator::Sort(sort) => {
810 Self::collect_output_variables_recursive(&sort.input, vars);
811 }
812 LogicalOperator::Distinct(distinct) => {
813 Self::collect_output_variables_recursive(&distinct.input, vars);
814 }
815 _ => {}
816 }
817 }
818
819 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
821 let mut vars = HashSet::new();
822 Self::collect_variables(expr, &mut vars);
823 vars
824 }
825
826 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
828 match expr {
829 LogicalExpression::Variable(name) => {
830 vars.insert(name.clone());
831 }
832 LogicalExpression::Property { variable, .. } => {
833 vars.insert(variable.clone());
834 }
835 LogicalExpression::Binary { left, right, .. } => {
836 Self::collect_variables(left, vars);
837 Self::collect_variables(right, vars);
838 }
839 LogicalExpression::Unary { operand, .. } => {
840 Self::collect_variables(operand, vars);
841 }
842 LogicalExpression::FunctionCall { args, .. } => {
843 for arg in args {
844 Self::collect_variables(arg, vars);
845 }
846 }
847 LogicalExpression::List(items) => {
848 for item in items {
849 Self::collect_variables(item, vars);
850 }
851 }
852 LogicalExpression::Map(pairs) => {
853 for (_, value) in pairs {
854 Self::collect_variables(value, vars);
855 }
856 }
857 LogicalExpression::IndexAccess { base, index } => {
858 Self::collect_variables(base, vars);
859 Self::collect_variables(index, vars);
860 }
861 LogicalExpression::SliceAccess { base, start, end } => {
862 Self::collect_variables(base, vars);
863 if let Some(s) = start {
864 Self::collect_variables(s, vars);
865 }
866 if let Some(e) = end {
867 Self::collect_variables(e, vars);
868 }
869 }
870 LogicalExpression::Case {
871 operand,
872 when_clauses,
873 else_clause,
874 } => {
875 if let Some(op) = operand {
876 Self::collect_variables(op, vars);
877 }
878 for (cond, result) in when_clauses {
879 Self::collect_variables(cond, vars);
880 Self::collect_variables(result, vars);
881 }
882 if let Some(else_expr) = else_clause {
883 Self::collect_variables(else_expr, vars);
884 }
885 }
886 LogicalExpression::Labels(var)
887 | LogicalExpression::Type(var)
888 | LogicalExpression::Id(var) => {
889 vars.insert(var.clone());
890 }
891 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
892 LogicalExpression::ListComprehension {
893 list_expr,
894 filter_expr,
895 map_expr,
896 ..
897 } => {
898 Self::collect_variables(list_expr, vars);
899 if let Some(filter) = filter_expr {
900 Self::collect_variables(filter, vars);
901 }
902 Self::collect_variables(map_expr, vars);
903 }
904 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
905 }
907 }
908 }
909
910 fn extract_projection_aliases(
912 &self,
913 projections: &[crate::query::plan::Projection],
914 ) -> HashSet<String> {
915 projections.iter().filter_map(|p| p.alias.clone()).collect()
916 }
917}
918
919impl Default for Optimizer {
920 fn default() -> Self {
921 Self::new()
922 }
923}
924
925#[cfg(test)]
926mod tests {
927 use super::*;
928 use crate::query::plan::{
929 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
930 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
931 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
932 };
933 use grafeo_common::types::Value;
934
935 #[test]
936 fn test_optimizer_filter_pushdown_simple() {
937 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
942 items: vec![ReturnItem {
943 expression: LogicalExpression::Variable("n".to_string()),
944 alias: None,
945 }],
946 distinct: false,
947 input: Box::new(LogicalOperator::Filter(FilterOp {
948 predicate: LogicalExpression::Binary {
949 left: Box::new(LogicalExpression::Property {
950 variable: "n".to_string(),
951 property: "age".to_string(),
952 }),
953 op: BinaryOp::Gt,
954 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
955 },
956 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
957 variable: "n".to_string(),
958 label: Some("Person".to_string()),
959 input: None,
960 })),
961 })),
962 }));
963
964 let optimizer = Optimizer::new();
965 let optimized = optimizer.optimize(plan).unwrap();
966
967 if let LogicalOperator::Return(ret) = &optimized.root {
969 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
970 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
971 assert_eq!(scan.variable, "n");
972 return;
973 }
974 }
975 }
976 panic!("Expected Return -> Filter -> NodeScan structure");
977 }
978
979 #[test]
980 fn test_optimizer_filter_pushdown_through_expand() {
981 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
985 items: vec![ReturnItem {
986 expression: LogicalExpression::Variable("b".to_string()),
987 alias: None,
988 }],
989 distinct: false,
990 input: Box::new(LogicalOperator::Filter(FilterOp {
991 predicate: LogicalExpression::Binary {
992 left: Box::new(LogicalExpression::Property {
993 variable: "a".to_string(),
994 property: "age".to_string(),
995 }),
996 op: BinaryOp::Gt,
997 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
998 },
999 input: Box::new(LogicalOperator::Expand(ExpandOp {
1000 from_variable: "a".to_string(),
1001 to_variable: "b".to_string(),
1002 edge_variable: None,
1003 direction: ExpandDirection::Outgoing,
1004 edge_type: Some("KNOWS".to_string()),
1005 min_hops: 1,
1006 max_hops: Some(1),
1007 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1008 variable: "a".to_string(),
1009 label: Some("Person".to_string()),
1010 input: None,
1011 })),
1012 path_alias: None,
1013 })),
1014 })),
1015 }));
1016
1017 let optimizer = Optimizer::new();
1018 let optimized = optimizer.optimize(plan).unwrap();
1019
1020 if let LogicalOperator::Return(ret) = &optimized.root {
1023 if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1024 if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1025 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1026 assert_eq!(scan.variable, "a");
1027 assert_eq!(expand.from_variable, "a");
1028 assert_eq!(expand.to_variable, "b");
1029 return;
1030 }
1031 }
1032 }
1033 }
1034 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1035 }
1036
1037 #[test]
1038 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1039 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1043 items: vec![ReturnItem {
1044 expression: LogicalExpression::Variable("a".to_string()),
1045 alias: None,
1046 }],
1047 distinct: false,
1048 input: Box::new(LogicalOperator::Filter(FilterOp {
1049 predicate: LogicalExpression::Binary {
1050 left: Box::new(LogicalExpression::Property {
1051 variable: "b".to_string(),
1052 property: "age".to_string(),
1053 }),
1054 op: BinaryOp::Gt,
1055 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1056 },
1057 input: Box::new(LogicalOperator::Expand(ExpandOp {
1058 from_variable: "a".to_string(),
1059 to_variable: "b".to_string(),
1060 edge_variable: None,
1061 direction: ExpandDirection::Outgoing,
1062 edge_type: Some("KNOWS".to_string()),
1063 min_hops: 1,
1064 max_hops: Some(1),
1065 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1066 variable: "a".to_string(),
1067 label: Some("Person".to_string()),
1068 input: None,
1069 })),
1070 path_alias: None,
1071 })),
1072 })),
1073 }));
1074
1075 let optimizer = Optimizer::new();
1076 let optimized = optimizer.optimize(plan).unwrap();
1077
1078 if let LogicalOperator::Return(ret) = &optimized.root {
1081 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1082 if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1084 if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1085 assert_eq!(variable, "b");
1086 }
1087 }
1088
1089 if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1090 if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1091 return;
1092 }
1093 }
1094 }
1095 }
1096 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1097 }
1098
1099 #[test]
1100 fn test_optimizer_extract_variables() {
1101 let optimizer = Optimizer::new();
1102
1103 let expr = LogicalExpression::Binary {
1104 left: Box::new(LogicalExpression::Property {
1105 variable: "n".to_string(),
1106 property: "age".to_string(),
1107 }),
1108 op: BinaryOp::Gt,
1109 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1110 };
1111
1112 let vars = optimizer.extract_variables(&expr);
1113 assert_eq!(vars.len(), 1);
1114 assert!(vars.contains("n"));
1115 }
1116
1117 #[test]
1120 fn test_optimizer_default() {
1121 let optimizer = Optimizer::default();
1122 let plan = LogicalPlan::new(LogicalOperator::Empty);
1124 let result = optimizer.optimize(plan);
1125 assert!(result.is_ok());
1126 }
1127
1128 #[test]
1129 fn test_optimizer_with_filter_pushdown_disabled() {
1130 let optimizer = Optimizer::new().with_filter_pushdown(false);
1131
1132 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1133 items: vec![ReturnItem {
1134 expression: LogicalExpression::Variable("n".to_string()),
1135 alias: None,
1136 }],
1137 distinct: false,
1138 input: Box::new(LogicalOperator::Filter(FilterOp {
1139 predicate: LogicalExpression::Literal(Value::Bool(true)),
1140 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1141 variable: "n".to_string(),
1142 label: None,
1143 input: None,
1144 })),
1145 })),
1146 }));
1147
1148 let optimized = optimizer.optimize(plan).unwrap();
1149 if let LogicalOperator::Return(ret) = &optimized.root {
1151 if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1152 return;
1153 }
1154 }
1155 panic!("Expected unchanged structure");
1156 }
1157
1158 #[test]
1159 fn test_optimizer_with_join_reorder_disabled() {
1160 let optimizer = Optimizer::new().with_join_reorder(false);
1161 assert!(
1162 optimizer
1163 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1164 .is_ok()
1165 );
1166 }
1167
1168 #[test]
1169 fn test_optimizer_with_cost_model() {
1170 let cost_model = CostModel::new();
1171 let optimizer = Optimizer::new().with_cost_model(cost_model);
1172 assert!(
1173 optimizer
1174 .cost_model()
1175 .estimate(&LogicalOperator::Empty, 0.0)
1176 .total()
1177 < 0.001
1178 );
1179 }
1180
/// Registers table stats of 500 for label `Test` on a custom estimator and
/// checks that the optimizer's cardinality estimate for a scan over that
/// label is 500 (within `0.001`).
#[test]
fn test_optimizer_with_cardinality_estimator() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(estimator);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    let deviation = (optimizer.estimate_cardinality(&plan) - 500.0).abs();
    assert!(deviation < 0.001);
}
1197
/// The estimated cost of an unlabelled node scan has a strictly positive
/// total.
#[test]
fn test_optimizer_estimate_cost() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(scan);

    let total = optimizer.estimate_cost(&plan).total();
    assert!(total > 0.0);
}
1210
/// A filter on `n.age > 30` sitting above a `Project` that passes `n`
/// through unaliased must end up beneath the projection: the optimized
/// root is `Project` with a `Filter` as its direct input.
#[test]
fn test_filter_pushdown_through_project() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(project),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    match &optimized.root {
        LogicalOperator::Project(proj)
            if matches!(proj.input.as_ref(), LogicalOperator::Filter(_)) => {}
        _ => panic!("Expected Project -> Filter structure"),
    }
}
1249
/// The filter predicate references `x`, which only exists as the alias the
/// `Project` gives to `n.age`. The test expects the plan to keep its
/// `Filter -> Project` shape.
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    // The projection introduces `x` as an alias for `n.age`.
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
    });
    let alias_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: alias_over_30,
        input: Box::new(project),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    let shape_kept = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter)
            if matches!(filter.input.as_ref(), LogicalOperator::Project(_))
    );
    assert!(shape_kept, "Expected Filter -> Project structure");
}
1287
/// Builds `Filter -> Limit(10) -> NodeScan` and asserts the optimized plan
/// keeps the filter as root with the `Limit` directly beneath it (despite
/// the test's name, the expected outcome is that the filter stays above).
#[test]
fn test_filter_pushdown_through_limit() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limit = LogicalOperator::Limit(LimitOp {
        count: 10,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(limit),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    match &optimized.root {
        LogicalOperator::Filter(filter)
            if matches!(filter.input.as_ref(), LogicalOperator::Limit(_)) => {}
        _ => panic!("Expected Filter -> Limit structure"),
    }
}
1314
/// Builds `Filter -> Sort -> NodeScan` and asserts the optimized plan keeps
/// the filter as root with the `Sort` directly beneath it (despite the
/// test's name, the expected outcome is that the filter stays above).
#[test]
fn test_filter_pushdown_through_sort() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![SortKey {
            expression: LogicalExpression::Variable("n".to_string()),
            order: SortOrder::Ascending,
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(sort),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    let shape_kept = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter)
            if matches!(filter.input.as_ref(), LogicalOperator::Sort(_))
    );
    assert!(shape_kept, "Expected Filter -> Sort structure");
}
1344
/// Builds `Filter -> Distinct -> NodeScan` and asserts the optimized plan
/// keeps the filter as root with the `Distinct` directly beneath it
/// (despite the test's name, the expected outcome is that the filter stays
/// above).
#[test]
fn test_filter_pushdown_through_distinct() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(distinct),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    match &optimized.root {
        LogicalOperator::Filter(filter)
            if matches!(filter.input.as_ref(), LogicalOperator::Distinct(_)) => {}
        _ => panic!("Expected Filter -> Distinct structure"),
    }
}
1371
/// The filter predicate `cnt > 10` refers to the alias of a `Count`
/// aggregate. The test expects the plan to keep its `Filter -> Aggregate`
/// shape.
#[test]
fn test_filter_not_pushed_through_aggregate() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    // Ungrouped count, exposed under the alias `cnt`.
    let count_all = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_all],
        input: Box::new(scan),
        having: None,
    });
    let cnt_over_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cnt_over_10,
        input: Box::new(aggregate),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    let shape_kept = matches!(
        &optimized.root,
        LogicalOperator::Filter(filter)
            if matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_))
    );
    assert!(shape_kept, "Expected Filter -> Aggregate structure");
}
1410
/// A filter on `a.age > 30` above an inner join of `a:Person` (left) and
/// `b:Company` (right) must be moved onto the left input: the optimized
/// root is the `Join` and its left child is a `Filter`.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    let optimizer = Optimizer::new();

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    // The predicate only touches the left-hand variable `a`.
    let left_only = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: left_only,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    match &optimized.root {
        LogicalOperator::Join(join)
            if matches!(join.left.as_ref(), LogicalOperator::Filter(_)) => {}
        _ => panic!("Expected Join with Filter on left side"),
    }
}
1451
/// A filter on `b.name = "Acme"` above an inner join of `a:Person` (left)
/// and `b:Company` (right) must be moved onto the right input: the
/// optimized root is the `Join` and its right child is a `Filter`.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let optimizer = Optimizer::new();

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    // The predicate only touches the right-hand variable `b`.
    let right_only = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: right_only,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    let pushed_right = matches!(
        &optimized.root,
        LogicalOperator::Join(join) if matches!(join.right.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(pushed_right, "Expected Join with Filter on right side");
}
1492
/// The predicate `a.id = b.a_id` references variables from both join
/// inputs. The test expects the filter to remain the root with the `Join`
/// directly beneath it.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let optimizer = Optimizer::new();

    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    // Equality between a left-side property and a right-side property.
    let cross_side = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    match &optimized.root {
        LogicalOperator::Filter(filter)
            if matches!(filter.input.as_ref(), LogicalOperator::Join(_)) => {}
        _ => panic!("Expected Filter -> Join structure"),
    }
}
1536
/// `extract_variables` on a bare `Variable("x")` yields exactly one name,
/// `x`.
#[test]
fn test_extract_variables_from_variable() {
    let optimizer = Optimizer::new();
    let expr = LogicalExpression::Variable("x".to_string());

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 1);
    assert!(names.contains("x"));
}
1547
/// `extract_variables` looks inside a unary `Not` and yields exactly the
/// operand's variable `x`.
#[test]
fn test_extract_variables_from_unary() {
    let optimizer = Optimizer::new();
    let operand = LogicalExpression::Variable("x".to_string());
    let expr = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 1);
    assert!(names.contains("x"));
}
1559
/// `extract_variables` collects the variables used as function-call
/// arguments: `length(a, b)` yields exactly `a` and `b`.
#[test]
fn test_extract_variables_from_function_call() {
    let optimizer = Optimizer::new();
    let args = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let expr = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args,
        distinct: false,
    };

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 2);
    for expected in ["a", "b"] {
        assert!(names.contains(expected));
    }
}
1576
/// `extract_variables` on a list of `[a, 1, b]` yields exactly `a` and `b`;
/// the literal element contributes nothing.
#[test]
fn test_extract_variables_from_list() {
    let optimizer = Optimizer::new();
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let expr = LogicalExpression::List(elements);

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 2);
    for expected in ["a", "b"] {
        assert!(names.contains(expected));
    }
}
1590
/// `extract_variables` on a map literal yields the variables used as entry
/// values (`a`, `b`); with two results asserted, the keys `key1`/`key2` are
/// not counted.
#[test]
fn test_extract_variables_from_map() {
    let optimizer = Optimizer::new();
    let first_entry = (
        "key1".to_string(),
        LogicalExpression::Variable("a".to_string()),
    );
    let second_entry = (
        "key2".to_string(),
        LogicalExpression::Variable("b".to_string()),
    );
    let expr = LogicalExpression::Map(vec![first_entry, second_entry]);

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 2);
    for expected in ["a", "b"] {
        assert!(names.contains(expected));
    }
}
1609
/// `extract_variables` on `list[idx]` yields both the base variable and the
/// index variable.
#[test]
fn test_extract_variables_from_index_access() {
    let optimizer = Optimizer::new();
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let expr = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 2);
    for expected in ["list", "idx"] {
        assert!(names.contains(expected));
    }
}
1622
/// `extract_variables` on `list[s..e]` yields the base variable plus both
/// bound variables.
#[test]
fn test_extract_variables_from_slice_access() {
    let optimizer = Optimizer::new();
    let base = LogicalExpression::Variable("list".to_string());
    let lower = LogicalExpression::Variable("s".to_string());
    let upper = LogicalExpression::Variable("e".to_string());
    let expr = LogicalExpression::SliceAccess {
        base: Box::new(base),
        start: Some(Box::new(lower)),
        end: Some(Box::new(upper)),
    };

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 3);
    for expected in ["list", "s", "e"] {
        assert!(names.contains(expected));
    }
}
1637
/// `extract_variables` on a `CASE` expression yields the variables from the
/// operand (`x`), the single WHEN branch result (`a`) and the ELSE branch
/// (`b`).
#[test]
fn test_extract_variables_from_case() {
    let optimizer = Optimizer::new();
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let expr = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 3);
    for expected in ["x", "a", "b"] {
        assert!(names.contains(expected));
    }
}
1655
/// `extract_variables` on `Labels("n")` yields exactly the variable `n`.
#[test]
fn test_extract_variables_from_labels() {
    let optimizer = Optimizer::new();
    let expr = LogicalExpression::Labels("n".to_string());

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 1);
    assert!(names.contains("n"));
}
1664
/// `extract_variables` on `Type("e")` yields exactly the variable `e`.
#[test]
fn test_extract_variables_from_type() {
    let optimizer = Optimizer::new();
    let expr = LogicalExpression::Type("e".to_string());

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 1);
    assert!(names.contains("e"));
}
1673
/// `extract_variables` on `Id("n")` yields exactly the variable `n`.
#[test]
fn test_extract_variables_from_id() {
    let optimizer = Optimizer::new();
    let expr = LogicalExpression::Id("n".to_string());

    let names = optimizer.extract_variables(&expr);

    assert_eq!(names.len(), 1);
    assert!(names.contains("n"));
}
1682
/// `extract_variables` on a list comprehension yields the variables used in
/// its source list, filter and mapping expressions. The total count is not
/// asserted, so nothing is checked about the comprehension's own binding
/// `x`.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let optimizer = Optimizer::new();
    let source = LogicalExpression::Variable("items".to_string());
    let predicate = LogicalExpression::Variable("pred".to_string());
    let mapping = LogicalExpression::Variable("result".to_string());
    let expr = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: Box::new(source),
        filter_expr: Some(Box::new(predicate)),
        map_expr: Box::new(mapping),
    };

    let names = optimizer.extract_variables(&expr);

    for expected in ["items", "pred", "result"] {
        assert!(names.contains(expected));
    }
}
1697
/// `extract_variables` must yield an empty collection both for a literal
/// (`42`) and for a parameter expression (`p`).
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1708
/// Optimizes `Return -> Filter -> Skip(5) -> NodeScan` and checks that the
/// root is still a `Return`. Only the root is asserted; where the filter
/// ends up relative to the `Skip` is not checked here.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skip),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1739
/// Optimizes a `Return` over two stacked filters (`n.x > 1` above
/// `n.y < 10`) over a node scan and checks that the root is still a
/// `Return`. Only the root is asserted; the arrangement of the filters is
/// not checked here.
#[test]
fn test_nested_filter_pushdown() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    // Inner filter: n.y < 10
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "y".to_string(),
            }),
            op: BinaryOp::Lt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        input: Box::new(scan),
    });
    // Outer filter: n.x > 1
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "x".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
        },
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1781}