1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
28#[derive(Debug, Clone)]
30struct JoinInfo {
31 left_var: String,
32 right_var: String,
33 left_expr: LogicalExpression,
34 right_expr: LogicalExpression,
35}
36
/// A column that some operator in the plan refers to.
///
/// Hashable so that requirements can be de-duplicated in a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable, identified by name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
45
46pub struct Optimizer {
51 enable_filter_pushdown: bool,
53 enable_join_reorder: bool,
55 enable_projection_pushdown: bool,
57 cost_model: CostModel,
59 card_estimator: CardinalityEstimator,
61}
62
63impl Optimizer {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 enable_filter_pushdown: true,
69 enable_join_reorder: true,
70 enable_projection_pushdown: true,
71 cost_model: CostModel::new(),
72 card_estimator: CardinalityEstimator::new(),
73 }
74 }
75
76 #[must_use]
81 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
82 store.ensure_statistics_fresh();
83 let stats = store.statistics();
84 let estimator = CardinalityEstimator::from_statistics(&stats);
85 Self {
86 enable_filter_pushdown: true,
87 enable_join_reorder: true,
88 enable_projection_pushdown: true,
89 cost_model: CostModel::new(),
90 card_estimator: estimator,
91 }
92 }
93
94 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
96 self.enable_filter_pushdown = enabled;
97 self
98 }
99
100 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
102 self.enable_join_reorder = enabled;
103 self
104 }
105
106 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
108 self.enable_projection_pushdown = enabled;
109 self
110 }
111
112 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
114 self.cost_model = cost_model;
115 self
116 }
117
118 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
120 self.card_estimator = estimator;
121 self
122 }
123
124 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
126 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
127 self
128 }
129
130 pub fn cost_model(&self) -> &CostModel {
132 &self.cost_model
133 }
134
135 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
137 &self.card_estimator
138 }
139
140 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
142 let cardinality = self.card_estimator.estimate(&plan.root);
143 self.cost_model.estimate(&plan.root, cardinality)
144 }
145
146 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
148 self.card_estimator.estimate(&plan.root)
149 }
150
151 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
157 let mut root = plan.root;
158
159 if self.enable_filter_pushdown {
161 root = self.push_filters_down(root);
162 }
163
164 if self.enable_join_reorder {
165 root = self.reorder_joins(root);
166 }
167
168 if self.enable_projection_pushdown {
169 root = self.push_projections_down(root);
170 }
171
172 Ok(LogicalPlan::new(root))
173 }
174
175 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
182 let required = self.collect_required_columns(&op);
184
185 self.push_projections_recursive(op, &required)
187 }
188
189 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
191 let mut required = HashSet::new();
192 Self::collect_required_recursive(op, &mut required);
193 required
194 }
195
196 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
198 match op {
199 LogicalOperator::Return(ret) => {
200 for item in &ret.items {
201 Self::collect_from_expression(&item.expression, required);
202 }
203 Self::collect_required_recursive(&ret.input, required);
204 }
205 LogicalOperator::Project(proj) => {
206 for p in &proj.projections {
207 Self::collect_from_expression(&p.expression, required);
208 }
209 Self::collect_required_recursive(&proj.input, required);
210 }
211 LogicalOperator::Filter(filter) => {
212 Self::collect_from_expression(&filter.predicate, required);
213 Self::collect_required_recursive(&filter.input, required);
214 }
215 LogicalOperator::Sort(sort) => {
216 for key in &sort.keys {
217 Self::collect_from_expression(&key.expression, required);
218 }
219 Self::collect_required_recursive(&sort.input, required);
220 }
221 LogicalOperator::Aggregate(agg) => {
222 for expr in &agg.group_by {
223 Self::collect_from_expression(expr, required);
224 }
225 for agg_expr in &agg.aggregates {
226 if let Some(ref expr) = agg_expr.expression {
227 Self::collect_from_expression(expr, required);
228 }
229 }
230 if let Some(ref having) = agg.having {
231 Self::collect_from_expression(having, required);
232 }
233 Self::collect_required_recursive(&agg.input, required);
234 }
235 LogicalOperator::Join(join) => {
236 for cond in &join.conditions {
237 Self::collect_from_expression(&cond.left, required);
238 Self::collect_from_expression(&cond.right, required);
239 }
240 Self::collect_required_recursive(&join.left, required);
241 Self::collect_required_recursive(&join.right, required);
242 }
243 LogicalOperator::Expand(expand) => {
244 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
246 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
247 if let Some(ref edge_var) = expand.edge_variable {
248 required.insert(RequiredColumn::Variable(edge_var.clone()));
249 }
250 Self::collect_required_recursive(&expand.input, required);
251 }
252 LogicalOperator::Limit(limit) => {
253 Self::collect_required_recursive(&limit.input, required);
254 }
255 LogicalOperator::Skip(skip) => {
256 Self::collect_required_recursive(&skip.input, required);
257 }
258 LogicalOperator::Distinct(distinct) => {
259 Self::collect_required_recursive(&distinct.input, required);
260 }
261 LogicalOperator::NodeScan(scan) => {
262 required.insert(RequiredColumn::Variable(scan.variable.clone()));
263 }
264 LogicalOperator::EdgeScan(scan) => {
265 required.insert(RequiredColumn::Variable(scan.variable.clone()));
266 }
267 _ => {}
268 }
269 }
270
271 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
273 match expr {
274 LogicalExpression::Variable(var) => {
275 required.insert(RequiredColumn::Variable(var.clone()));
276 }
277 LogicalExpression::Property { variable, property } => {
278 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
279 required.insert(RequiredColumn::Variable(variable.clone()));
280 }
281 LogicalExpression::Binary { left, right, .. } => {
282 Self::collect_from_expression(left, required);
283 Self::collect_from_expression(right, required);
284 }
285 LogicalExpression::Unary { operand, .. } => {
286 Self::collect_from_expression(operand, required);
287 }
288 LogicalExpression::FunctionCall { args, .. } => {
289 for arg in args {
290 Self::collect_from_expression(arg, required);
291 }
292 }
293 LogicalExpression::List(items) => {
294 for item in items {
295 Self::collect_from_expression(item, required);
296 }
297 }
298 LogicalExpression::Map(pairs) => {
299 for (_, value) in pairs {
300 Self::collect_from_expression(value, required);
301 }
302 }
303 LogicalExpression::IndexAccess { base, index } => {
304 Self::collect_from_expression(base, required);
305 Self::collect_from_expression(index, required);
306 }
307 LogicalExpression::SliceAccess { base, start, end } => {
308 Self::collect_from_expression(base, required);
309 if let Some(s) = start {
310 Self::collect_from_expression(s, required);
311 }
312 if let Some(e) = end {
313 Self::collect_from_expression(e, required);
314 }
315 }
316 LogicalExpression::Case {
317 operand,
318 when_clauses,
319 else_clause,
320 } => {
321 if let Some(op) = operand {
322 Self::collect_from_expression(op, required);
323 }
324 for (cond, result) in when_clauses {
325 Self::collect_from_expression(cond, required);
326 Self::collect_from_expression(result, required);
327 }
328 if let Some(else_expr) = else_clause {
329 Self::collect_from_expression(else_expr, required);
330 }
331 }
332 LogicalExpression::Labels(var)
333 | LogicalExpression::Type(var)
334 | LogicalExpression::Id(var) => {
335 required.insert(RequiredColumn::Variable(var.clone()));
336 }
337 LogicalExpression::ListComprehension {
338 list_expr,
339 filter_expr,
340 map_expr,
341 ..
342 } => {
343 Self::collect_from_expression(list_expr, required);
344 if let Some(filter) = filter_expr {
345 Self::collect_from_expression(filter, required);
346 }
347 Self::collect_from_expression(map_expr, required);
348 }
349 _ => {}
350 }
351 }
352
353 fn push_projections_recursive(
355 &self,
356 op: LogicalOperator,
357 required: &HashSet<RequiredColumn>,
358 ) -> LogicalOperator {
359 match op {
360 LogicalOperator::Return(mut ret) => {
361 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
362 LogicalOperator::Return(ret)
363 }
364 LogicalOperator::Project(mut proj) => {
365 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
366 LogicalOperator::Project(proj)
367 }
368 LogicalOperator::Filter(mut filter) => {
369 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
370 LogicalOperator::Filter(filter)
371 }
372 LogicalOperator::Sort(mut sort) => {
373 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
376 LogicalOperator::Sort(sort)
377 }
378 LogicalOperator::Aggregate(mut agg) => {
379 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
380 LogicalOperator::Aggregate(agg)
381 }
382 LogicalOperator::Join(mut join) => {
383 let left_vars = self.collect_output_variables(&join.left);
386 let right_vars = self.collect_output_variables(&join.right);
387
388 let left_required: HashSet<_> = required
390 .iter()
391 .filter(|c| match c {
392 RequiredColumn::Variable(v) => left_vars.contains(v),
393 RequiredColumn::Property(v, _) => left_vars.contains(v),
394 })
395 .cloned()
396 .collect();
397
398 let right_required: HashSet<_> = required
399 .iter()
400 .filter(|c| match c {
401 RequiredColumn::Variable(v) => right_vars.contains(v),
402 RequiredColumn::Property(v, _) => right_vars.contains(v),
403 })
404 .cloned()
405 .collect();
406
407 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
408 join.right =
409 Box::new(self.push_projections_recursive(*join.right, &right_required));
410 LogicalOperator::Join(join)
411 }
412 LogicalOperator::Expand(mut expand) => {
413 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
414 LogicalOperator::Expand(expand)
415 }
416 LogicalOperator::Limit(mut limit) => {
417 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
418 LogicalOperator::Limit(limit)
419 }
420 LogicalOperator::Skip(mut skip) => {
421 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
422 LogicalOperator::Skip(skip)
423 }
424 LogicalOperator::Distinct(mut distinct) => {
425 distinct.input =
426 Box::new(self.push_projections_recursive(*distinct.input, required));
427 LogicalOperator::Distinct(distinct)
428 }
429 other => other,
430 }
431 }
432
433 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
440 let op = self.reorder_joins_recursive(op);
442
443 if let Some((relations, conditions)) = self.extract_join_tree(&op)
445 && relations.len() >= 2
446 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
447 {
448 return optimized;
449 }
450
451 op
452 }
453
454 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
456 match op {
457 LogicalOperator::Return(mut ret) => {
458 ret.input = Box::new(self.reorder_joins(*ret.input));
459 LogicalOperator::Return(ret)
460 }
461 LogicalOperator::Project(mut proj) => {
462 proj.input = Box::new(self.reorder_joins(*proj.input));
463 LogicalOperator::Project(proj)
464 }
465 LogicalOperator::Filter(mut filter) => {
466 filter.input = Box::new(self.reorder_joins(*filter.input));
467 LogicalOperator::Filter(filter)
468 }
469 LogicalOperator::Limit(mut limit) => {
470 limit.input = Box::new(self.reorder_joins(*limit.input));
471 LogicalOperator::Limit(limit)
472 }
473 LogicalOperator::Skip(mut skip) => {
474 skip.input = Box::new(self.reorder_joins(*skip.input));
475 LogicalOperator::Skip(skip)
476 }
477 LogicalOperator::Sort(mut sort) => {
478 sort.input = Box::new(self.reorder_joins(*sort.input));
479 LogicalOperator::Sort(sort)
480 }
481 LogicalOperator::Distinct(mut distinct) => {
482 distinct.input = Box::new(self.reorder_joins(*distinct.input));
483 LogicalOperator::Distinct(distinct)
484 }
485 LogicalOperator::Aggregate(mut agg) => {
486 agg.input = Box::new(self.reorder_joins(*agg.input));
487 LogicalOperator::Aggregate(agg)
488 }
489 LogicalOperator::Expand(mut expand) => {
490 expand.input = Box::new(self.reorder_joins(*expand.input));
491 LogicalOperator::Expand(expand)
492 }
493 other => other,
495 }
496 }
497
498 fn extract_join_tree(
502 &self,
503 op: &LogicalOperator,
504 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
505 let mut relations = Vec::new();
506 let mut join_conditions = Vec::new();
507
508 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
509 return None;
510 }
511
512 if relations.len() < 2 {
513 return None;
514 }
515
516 Some((relations, join_conditions))
517 }
518
519 fn collect_join_tree(
523 &self,
524 op: &LogicalOperator,
525 relations: &mut Vec<(String, LogicalOperator)>,
526 conditions: &mut Vec<JoinInfo>,
527 ) -> bool {
528 match op {
529 LogicalOperator::Join(join) => {
530 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
532 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
533
534 for cond in &join.conditions {
536 if let (Some(left_var), Some(right_var)) = (
537 self.extract_variable_from_expr(&cond.left),
538 self.extract_variable_from_expr(&cond.right),
539 ) {
540 conditions.push(JoinInfo {
541 left_var,
542 right_var,
543 left_expr: cond.left.clone(),
544 right_expr: cond.right.clone(),
545 });
546 }
547 }
548
549 left_ok && right_ok
550 }
551 LogicalOperator::NodeScan(scan) => {
552 relations.push((scan.variable.clone(), op.clone()));
553 true
554 }
555 LogicalOperator::EdgeScan(scan) => {
556 relations.push((scan.variable.clone(), op.clone()));
557 true
558 }
559 LogicalOperator::Filter(filter) => {
560 self.collect_join_tree(&filter.input, relations, conditions)
562 }
563 LogicalOperator::Expand(expand) => {
564 relations.push((expand.to_variable.clone(), op.clone()));
567 true
568 }
569 _ => false,
570 }
571 }
572
573 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
575 match expr {
576 LogicalExpression::Variable(v) => Some(v.clone()),
577 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
578 _ => None,
579 }
580 }
581
582 fn optimize_join_order(
584 &self,
585 relations: &[(String, LogicalOperator)],
586 conditions: &[JoinInfo],
587 ) -> Option<LogicalOperator> {
588 use join_order::{DPccp, JoinGraphBuilder};
589
590 let mut builder = JoinGraphBuilder::new();
592
593 for (var, relation) in relations {
594 builder.add_relation(var, relation.clone());
595 }
596
597 for cond in conditions {
598 builder.add_join_condition(
599 &cond.left_var,
600 &cond.right_var,
601 cond.left_expr.clone(),
602 cond.right_expr.clone(),
603 );
604 }
605
606 let graph = builder.build();
607
608 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
610 let plan = dpccp.optimize()?;
611
612 Some(plan.operator)
613 }
614
615 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
620 match op {
621 LogicalOperator::Filter(filter) => {
623 let optimized_input = self.push_filters_down(*filter.input);
624 self.try_push_filter_into(filter.predicate, optimized_input)
625 }
626 LogicalOperator::Return(mut ret) => {
628 ret.input = Box::new(self.push_filters_down(*ret.input));
629 LogicalOperator::Return(ret)
630 }
631 LogicalOperator::Project(mut proj) => {
632 proj.input = Box::new(self.push_filters_down(*proj.input));
633 LogicalOperator::Project(proj)
634 }
635 LogicalOperator::Limit(mut limit) => {
636 limit.input = Box::new(self.push_filters_down(*limit.input));
637 LogicalOperator::Limit(limit)
638 }
639 LogicalOperator::Skip(mut skip) => {
640 skip.input = Box::new(self.push_filters_down(*skip.input));
641 LogicalOperator::Skip(skip)
642 }
643 LogicalOperator::Sort(mut sort) => {
644 sort.input = Box::new(self.push_filters_down(*sort.input));
645 LogicalOperator::Sort(sort)
646 }
647 LogicalOperator::Distinct(mut distinct) => {
648 distinct.input = Box::new(self.push_filters_down(*distinct.input));
649 LogicalOperator::Distinct(distinct)
650 }
651 LogicalOperator::Expand(mut expand) => {
652 expand.input = Box::new(self.push_filters_down(*expand.input));
653 LogicalOperator::Expand(expand)
654 }
655 LogicalOperator::Join(mut join) => {
656 join.left = Box::new(self.push_filters_down(*join.left));
657 join.right = Box::new(self.push_filters_down(*join.right));
658 LogicalOperator::Join(join)
659 }
660 LogicalOperator::Aggregate(mut agg) => {
661 agg.input = Box::new(self.push_filters_down(*agg.input));
662 LogicalOperator::Aggregate(agg)
663 }
664 other => other,
666 }
667 }
668
669 fn try_push_filter_into(
674 &self,
675 predicate: LogicalExpression,
676 op: LogicalOperator,
677 ) -> LogicalOperator {
678 match op {
679 LogicalOperator::Project(mut proj) => {
681 let predicate_vars = self.extract_variables(&predicate);
682 let computed_vars = self.extract_projection_aliases(&proj.projections);
683
684 if predicate_vars.is_disjoint(&computed_vars) {
686 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
687 LogicalOperator::Project(proj)
688 } else {
689 LogicalOperator::Filter(FilterOp {
691 predicate,
692 input: Box::new(LogicalOperator::Project(proj)),
693 })
694 }
695 }
696
697 LogicalOperator::Return(mut ret) => {
699 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
700 LogicalOperator::Return(ret)
701 }
702
703 LogicalOperator::Expand(mut expand) => {
705 let predicate_vars = self.extract_variables(&predicate);
706
707 let mut introduced_vars = vec![&expand.to_variable];
712 if let Some(ref edge_var) = expand.edge_variable {
713 introduced_vars.push(edge_var);
714 }
715 if let Some(ref path_alias) = expand.path_alias {
716 introduced_vars.push(path_alias);
717 }
718
719 let uses_introduced_vars =
721 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
722
723 if !uses_introduced_vars {
724 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
726 LogicalOperator::Expand(expand)
727 } else {
728 LogicalOperator::Filter(FilterOp {
730 predicate,
731 input: Box::new(LogicalOperator::Expand(expand)),
732 })
733 }
734 }
735
736 LogicalOperator::Join(mut join) => {
738 let predicate_vars = self.extract_variables(&predicate);
739 let left_vars = self.collect_output_variables(&join.left);
740 let right_vars = self.collect_output_variables(&join.right);
741
742 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
743 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
744
745 if uses_left && !uses_right {
746 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
748 LogicalOperator::Join(join)
749 } else if uses_right && !uses_left {
750 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
752 LogicalOperator::Join(join)
753 } else {
754 LogicalOperator::Filter(FilterOp {
756 predicate,
757 input: Box::new(LogicalOperator::Join(join)),
758 })
759 }
760 }
761
762 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
764 predicate,
765 input: Box::new(LogicalOperator::Aggregate(agg)),
766 }),
767
768 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
770 predicate,
771 input: Box::new(LogicalOperator::NodeScan(scan)),
772 }),
773
774 other => LogicalOperator::Filter(FilterOp {
776 predicate,
777 input: Box::new(other),
778 }),
779 }
780 }
781
782 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
784 let mut vars = HashSet::new();
785 Self::collect_output_variables_recursive(op, &mut vars);
786 vars
787 }
788
789 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
791 match op {
792 LogicalOperator::NodeScan(scan) => {
793 vars.insert(scan.variable.clone());
794 }
795 LogicalOperator::EdgeScan(scan) => {
796 vars.insert(scan.variable.clone());
797 }
798 LogicalOperator::Expand(expand) => {
799 vars.insert(expand.to_variable.clone());
800 if let Some(edge_var) = &expand.edge_variable {
801 vars.insert(edge_var.clone());
802 }
803 Self::collect_output_variables_recursive(&expand.input, vars);
804 }
805 LogicalOperator::Filter(filter) => {
806 Self::collect_output_variables_recursive(&filter.input, vars);
807 }
808 LogicalOperator::Project(proj) => {
809 for p in &proj.projections {
810 if let Some(alias) = &p.alias {
811 vars.insert(alias.clone());
812 }
813 }
814 Self::collect_output_variables_recursive(&proj.input, vars);
815 }
816 LogicalOperator::Join(join) => {
817 Self::collect_output_variables_recursive(&join.left, vars);
818 Self::collect_output_variables_recursive(&join.right, vars);
819 }
820 LogicalOperator::Aggregate(agg) => {
821 for expr in &agg.group_by {
822 Self::collect_variables(expr, vars);
823 }
824 for agg_expr in &agg.aggregates {
825 if let Some(alias) = &agg_expr.alias {
826 vars.insert(alias.clone());
827 }
828 }
829 }
830 LogicalOperator::Return(ret) => {
831 Self::collect_output_variables_recursive(&ret.input, vars);
832 }
833 LogicalOperator::Limit(limit) => {
834 Self::collect_output_variables_recursive(&limit.input, vars);
835 }
836 LogicalOperator::Skip(skip) => {
837 Self::collect_output_variables_recursive(&skip.input, vars);
838 }
839 LogicalOperator::Sort(sort) => {
840 Self::collect_output_variables_recursive(&sort.input, vars);
841 }
842 LogicalOperator::Distinct(distinct) => {
843 Self::collect_output_variables_recursive(&distinct.input, vars);
844 }
845 _ => {}
846 }
847 }
848
849 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
851 let mut vars = HashSet::new();
852 Self::collect_variables(expr, &mut vars);
853 vars
854 }
855
856 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
858 match expr {
859 LogicalExpression::Variable(name) => {
860 vars.insert(name.clone());
861 }
862 LogicalExpression::Property { variable, .. } => {
863 vars.insert(variable.clone());
864 }
865 LogicalExpression::Binary { left, right, .. } => {
866 Self::collect_variables(left, vars);
867 Self::collect_variables(right, vars);
868 }
869 LogicalExpression::Unary { operand, .. } => {
870 Self::collect_variables(operand, vars);
871 }
872 LogicalExpression::FunctionCall { args, .. } => {
873 for arg in args {
874 Self::collect_variables(arg, vars);
875 }
876 }
877 LogicalExpression::List(items) => {
878 for item in items {
879 Self::collect_variables(item, vars);
880 }
881 }
882 LogicalExpression::Map(pairs) => {
883 for (_, value) in pairs {
884 Self::collect_variables(value, vars);
885 }
886 }
887 LogicalExpression::IndexAccess { base, index } => {
888 Self::collect_variables(base, vars);
889 Self::collect_variables(index, vars);
890 }
891 LogicalExpression::SliceAccess { base, start, end } => {
892 Self::collect_variables(base, vars);
893 if let Some(s) = start {
894 Self::collect_variables(s, vars);
895 }
896 if let Some(e) = end {
897 Self::collect_variables(e, vars);
898 }
899 }
900 LogicalExpression::Case {
901 operand,
902 when_clauses,
903 else_clause,
904 } => {
905 if let Some(op) = operand {
906 Self::collect_variables(op, vars);
907 }
908 for (cond, result) in when_clauses {
909 Self::collect_variables(cond, vars);
910 Self::collect_variables(result, vars);
911 }
912 if let Some(else_expr) = else_clause {
913 Self::collect_variables(else_expr, vars);
914 }
915 }
916 LogicalExpression::Labels(var)
917 | LogicalExpression::Type(var)
918 | LogicalExpression::Id(var) => {
919 vars.insert(var.clone());
920 }
921 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
922 LogicalExpression::ListComprehension {
923 list_expr,
924 filter_expr,
925 map_expr,
926 ..
927 } => {
928 Self::collect_variables(list_expr, vars);
929 if let Some(filter) = filter_expr {
930 Self::collect_variables(filter, vars);
931 }
932 Self::collect_variables(map_expr, vars);
933 }
934 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
935 }
937 }
938 }
939
940 fn extract_projection_aliases(
942 &self,
943 projections: &[crate::query::plan::Projection],
944 ) -> HashSet<String> {
945 projections.iter().filter_map(|p| p.alias.clone()).collect()
946 }
947}
948
949impl Default for Optimizer {
950 fn default() -> Self {
951 Self::new()
952 }
953}
954
955#[cfg(test)]
956mod tests {
957 use super::*;
958 use crate::query::plan::{
959 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
960 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
961 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
962 };
963 use grafeo_common::types::Value;
964
/// A filter directly above a node scan must stay where it is:
/// `Return -> Filter -> NodeScan`.
#[test]
fn test_optimizer_filter_pushdown_simple() {
    const EXPECTED: &str = "Expected Return -> Filter -> NodeScan structure";

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let age_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(age_filter),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{EXPECTED}");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{EXPECTED}");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{EXPECTED}");
    };
    assert_eq!(scan.variable, "n");
}
1007
/// A filter on the source variable of an expand must end up below the expand:
/// `Return -> Expand -> Filter -> NodeScan`.
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
    const EXPECTED: &str = "Expected Return -> Expand -> Filter -> NodeScan structure";

    let scan_a = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand_a_to_b = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: Some("KNOWS".to_string()),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(scan_a),
        path_alias: None,
    });
    // Predicate only mentions `a`, the source side of the expand.
    let filter_on_a = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "a".to_string(),
                property: "age".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
        },
        input: Box::new(expand_a_to_b),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter_on_a),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{EXPECTED}");
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("{EXPECTED}");
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("{EXPECTED}");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{EXPECTED}");
    };

    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1063
/// A filter on the target variable of an expand must stay above the expand:
/// `Return -> Filter -> Expand -> NodeScan`.
///
/// The predicate check is a hard assertion: if the filter no longer carries
/// the `b.age` comparison the test fails instead of silently skipping it.
#[test]
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
    const EXPECTED: &str = "Expected Return -> Filter -> Expand -> NodeScan structure";

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(LogicalOperator::Filter(FilterOp {
            predicate: LogicalExpression::Binary {
                left: Box::new(LogicalExpression::Property {
                    variable: "b".to_string(),
                    property: "age".to_string(),
                }),
                op: BinaryOp::Gt,
                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
            },
            input: Box::new(LogicalOperator::Expand(ExpandOp {
                from_variable: "a".to_string(),
                to_variable: "b".to_string(),
                edge_variable: None,
                direction: ExpandDirection::Outgoing,
                edge_type: Some("KNOWS".to_string()),
                min_hops: 1,
                max_hops: Some(1),
                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                    variable: "a".to_string(),
                    label: Some("Person".to_string()),
                    input: None,
                })),
                path_alias: None,
            })),
        })),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{EXPECTED}");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{EXPECTED}");
    };

    // The predicate must still be the comparison on `b`.
    let LogicalExpression::Binary { left, .. } = &filter.predicate else {
        panic!("Expected the filter to keep its binary predicate");
    };
    let LogicalExpression::Property { variable, .. } = left.as_ref() else {
        panic!("Expected the left side of the predicate to be a property access");
    };
    assert_eq!(variable, "b");

    let LogicalOperator::Expand(expand) = filter.input.as_ref() else {
        panic!("{EXPECTED}");
    };
    assert!(
        matches!(expand.input.as_ref(), LogicalOperator::NodeScan(_)),
        "{EXPECTED}"
    );
}
1125
/// `extract_variables` on `n.age > 30` yields exactly the variable `n`.
#[test]
fn test_optimizer_extract_variables() {
    let age_of_n = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let thirty = LogicalExpression::Literal(Value::Int64(30));
    let comparison = LogicalExpression::Binary {
        left: Box::new(age_of_n),
        op: BinaryOp::Gt,
        right: Box::new(thirty),
    };

    let found = Optimizer::new().extract_variables(&comparison);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1143
/// The default optimizer accepts an empty plan.
#[test]
fn test_optimizer_default() {
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let outcome = Optimizer::default().optimize(empty_plan);
    assert!(outcome.is_ok());
}
1154
/// With filter pushdown disabled, the filter stays directly under the return.
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let optimized = optimizer.optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected unchanged structure");
    };
    if !matches!(ret.input.as_ref(), LogicalOperator::Filter(_)) {
        panic!("Expected unchanged structure");
    }
}
1184
/// Disabling join reordering must not prevent an empty plan from being
/// optimized successfully.
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    let no_reorder = Optimizer::new().with_join_reorder(false);
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);

    let outcome = no_reorder.optimize(empty_plan);

    assert!(outcome.is_ok());
}
1194
/// An optimizer configured via `with_cost_model` must expose that model
/// through `cost_model()`; estimating `Empty` at cardinality 0.0 must give
/// a total below 0.001.
#[test]
fn test_optimizer_with_cost_model() {
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());
    let empty_op = LogicalOperator::Empty;

    let empty_cost = optimizer.cost_model().estimate(&empty_op, 0.0);

    assert!(empty_cost.total() < 0.001);
}
1207
/// A custom estimator holding 500 rows for label `Test` must make the
/// optimizer estimate roughly 500 rows for a scan over that label.
#[test]
fn test_optimizer_with_cardinality_estimator() {
    let mut custom_estimator = CardinalityEstimator::new();
    custom_estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(custom_estimator);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    let estimated_rows = optimizer.estimate_cardinality(&plan);
    let deviation = (estimated_rows - 500.0).abs();
    assert!(deviation < 0.001);
}
1224
/// The estimated cost of an unlabeled node scan must have a strictly
/// positive total.
#[test]
fn test_optimizer_estimate_cost() {
    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(unlabeled_scan);

    let opt = Optimizer::new();
    let scan_cost = opt.estimate_cost(&plan);

    assert!(scan_cost.total() > 0.0);
}
1237
/// A filter on `n.age > 30` sitting above a projection of plain `n` must end
/// up below the projection: the optimized root is `Project` with a `Filter`
/// as its input.
#[test]
fn test_filter_pushdown_through_project() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project_n = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(project_n),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let LogicalOperator::Project(proj) = &optimized.root else {
        panic!("Expected Project -> Filter structure");
    };
    assert!(
        matches!(proj.input.as_ref(), LogicalOperator::Filter(_)),
        "Expected Project -> Filter structure"
    );
}
1276
/// A filter on `x > 30`, where `x` is an alias introduced by the projection
/// (`n.age AS x`), must stay above that projection: the optimized root is
/// `Filter` with a `Project` as its input.
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let age_as_x = Projection {
        expression: LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        },
        alias: Some("x".to_string()),
    };
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![age_as_x],
        input: Box::new(scan),
    });
    let x_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_over_30,
        input: Box::new(project),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let stayed_above = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Project(_))
        }
        _ => false,
    };
    assert!(stayed_above, "Expected Filter -> Project structure");
}
1314
/// Exercises filter pushdown against a `Limit` operator. Despite the test
/// name, the expected outcome is that the filter stays above the limit:
/// the optimized root is `Filter` with a `Limit` as its input.
#[test]
fn test_filter_pushdown_through_limit() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let first_ten = LogicalOperator::Limit(LimitOp {
        count: 10,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(first_ten),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Limit structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Limit(_)),
        "Expected Filter -> Limit structure"
    );
}
1341
/// Exercises filter pushdown against a `Sort` operator. The asserted result
/// is that the optimized root is `Filter` with a `Sort` as its input.
#[test]
fn test_filter_pushdown_through_sort() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let by_n_ascending = SortKey {
        expression: LogicalExpression::Variable("n".to_string()),
        order: SortOrder::Ascending,
    };
    let sorted = LogicalOperator::Sort(SortOp {
        keys: vec![by_n_ascending],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(sorted),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let filter_over_sort = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Sort(_))
        }
        _ => false,
    };
    assert!(filter_over_sort, "Expected Filter -> Sort structure");
}
1371
/// Exercises filter pushdown against a `Distinct` operator. The asserted
/// result is that the optimized root is `Filter` with a `Distinct` as its
/// input.
#[test]
fn test_filter_pushdown_through_distinct() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let deduplicated = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(deduplicated),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Distinct structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Distinct(_)),
        "Expected Filter -> Distinct structure"
    );
}
1398
/// A filter on the aggregate alias `cnt` must stay above the aggregation:
/// the optimized root is `Filter` with an `Aggregate` as its input.
#[test]
fn test_filter_not_pushed_through_aggregate() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_as_cnt = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
    };
    let aggregation = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_as_cnt],
        input: Box::new(scan),
        having: None,
    });
    let cnt_over_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cnt_over_10,
        input: Box::new(aggregation),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let stayed_above = match &optimized.root {
        LogicalOperator::Filter(filter) => {
            matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_))
        }
        _ => false,
    };
    assert!(stayed_above, "Expected Filter -> Aggregate structure");
}
1437
/// A filter that only mentions `a` (bound by the left input of an inner
/// join) must be moved into the left side: the optimized root is `Join`
/// whose left child is a `Filter`.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let inner_join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let a_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: a_age_over_30,
        input: Box::new(inner_join),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let LogicalOperator::Join(join) = &optimized.root else {
        panic!("Expected Join with Filter on left side");
    };
    assert!(
        matches!(join.left.as_ref(), LogicalOperator::Filter(_)),
        "Expected Join with Filter on left side"
    );
}
1478
/// A filter that only mentions `b` (bound by the right input of an inner
/// join) must be moved into the right side: the optimized root is `Join`
/// whose right child is a `Filter`.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let inner_join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let b_named_acme = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: b_named_acme,
        input: Box::new(inner_join),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let pushed_right = match &optimized.root {
        LogicalOperator::Join(join) => {
            matches!(join.right.as_ref(), LogicalOperator::Filter(_))
        }
        _ => false,
    };
    assert!(pushed_right, "Expected Join with Filter on right side");
}
1519
/// A filter comparing `a.id` with `b.a_id` references both join inputs and
/// must therefore stay above the join: the optimized root is `Filter` with
/// a `Join` as its input.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let scan_a = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let scan_b = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let inner_join = LogicalOperator::Join(JoinOp {
        left: Box::new(scan_a),
        right: Box::new(scan_b),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let a_id = LogicalExpression::Property {
        variable: "a".to_string(),
        property: "id".to_string(),
    };
    let b_a_id = LogicalExpression::Property {
        variable: "b".to_string(),
        property: "a_id".to_string(),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(a_id),
            op: BinaryOp::Eq,
            right: Box::new(b_a_id),
        },
        input: Box::new(inner_join),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Join structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Join(_)),
        "Expected Filter -> Join structure"
    );
}
1563
/// A bare `Variable("x")` expression must yield exactly the variable `x`.
#[test]
fn test_extract_variables_from_variable() {
    let bare_x = LogicalExpression::Variable("x".to_string());

    let opt = Optimizer::new();
    let found = opt.extract_variables(&bare_x);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1574
/// `NOT x` must yield exactly the variable `x` from its operand.
#[test]
fn test_extract_variables_from_unary() {
    let operand = LogicalExpression::Variable("x".to_string());
    let not_x = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let opt = Optimizer::new();
    let found = opt.extract_variables(&not_x);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1586
/// A function call `length(a, b)` must yield both argument variables.
#[test]
fn test_extract_variables_from_function_call() {
    let arguments = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args: arguments,
        distinct: false,
    };

    let opt = Optimizer::new();
    let found = opt.extract_variables(&call);

    assert_eq!(found.len(), 2);
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1603
/// A list `[a, 1, b]` must yield the two variables and ignore the literal.
#[test]
fn test_extract_variables_from_list() {
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let list = LogicalExpression::List(elements);

    let opt = Optimizer::new();
    let found = opt.extract_variables(&list);

    assert_eq!(found.len(), 2);
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1617
/// A map `{key1: a, key2: b}` must yield the variables used as values; the
/// asserted set size of 2 shows the keys are not counted.
#[test]
fn test_extract_variables_from_map() {
    let first_entry = (
        "key1".to_string(),
        LogicalExpression::Variable("a".to_string()),
    );
    let second_entry = (
        "key2".to_string(),
        LogicalExpression::Variable("b".to_string()),
    );
    let map = LogicalExpression::Map(vec![first_entry, second_entry]);

    let opt = Optimizer::new();
    let found = opt.extract_variables(&map);

    assert_eq!(found.len(), 2);
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1636
/// `list[idx]` must yield both the base variable and the index variable.
#[test]
fn test_extract_variables_from_index_access() {
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let opt = Optimizer::new();
    let found = opt.extract_variables(&access);

    assert_eq!(found.len(), 2);
    assert!(found.contains("list"));
    assert!(found.contains("idx"));
}
1649
/// `list[s..e]` must yield the base variable and both bound variables.
#[test]
fn test_extract_variables_from_slice_access() {
    let base = LogicalExpression::Variable("list".to_string());
    let lower = LogicalExpression::Variable("s".to_string());
    let upper = LogicalExpression::Variable("e".to_string());
    let slice = LogicalExpression::SliceAccess {
        base: Box::new(base),
        start: Some(Box::new(lower)),
        end: Some(Box::new(upper)),
    };

    let opt = Optimizer::new();
    let found = opt.extract_variables(&slice);

    assert_eq!(found.len(), 3);
    assert!(found.contains("list"));
    assert!(found.contains("s"));
    assert!(found.contains("e"));
}
1664
/// A `CASE x WHEN 1 THEN a ELSE b` expression must yield the operand
/// variable, the `THEN` result variable and the `ELSE` variable.
#[test]
fn test_extract_variables_from_case() {
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let case = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };

    let opt = Optimizer::new();
    let found = opt.extract_variables(&case);

    assert_eq!(found.len(), 3);
    assert!(found.contains("x"));
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1682
/// A `Labels("n")` expression must yield exactly the variable `n`.
#[test]
fn test_extract_variables_from_labels() {
    let labels_of_n = LogicalExpression::Labels("n".to_string());

    let opt = Optimizer::new();
    let found = opt.extract_variables(&labels_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1691
/// A `Type("e")` expression must yield exactly the variable `e`.
#[test]
fn test_extract_variables_from_type() {
    let type_of_e = LogicalExpression::Type("e".to_string());

    let opt = Optimizer::new();
    let found = opt.extract_variables(&type_of_e);

    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1700
/// An `Id("n")` expression must yield exactly the variable `n`.
#[test]
fn test_extract_variables_from_id() {
    let id_of_n = LogicalExpression::Id("n".to_string());

    let opt = Optimizer::new();
    let found = opt.extract_variables(&id_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1709
/// A list comprehension must yield the variables used in its list, filter
/// and map sub-expressions. No assertion is made about the bound variable
/// `x` or about the size of the resulting set.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let source_list = LogicalExpression::Variable("items".to_string());
    let predicate = LogicalExpression::Variable("pred".to_string());
    let mapped = LogicalExpression::Variable("result".to_string());
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: Box::new(source_list),
        filter_expr: Some(Box::new(predicate)),
        map_expr: Box::new(mapped),
    };

    let opt = Optimizer::new();
    let found = opt.extract_variables(&comprehension);

    assert!(found.contains("items"));
    assert!(found.contains("pred"));
    assert!(found.contains("result"));
}
1724
/// Literals and query parameters reference no plan variables, so
/// `extract_variables` must return an empty set for both.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    // A constant value binds nothing.
    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // A parameter named `p` must not be reported as a variable.
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1735
/// Optimizes a `Return -> Filter -> Skip -> NodeScan` plan and checks only
/// that the root is still a `Return`; where the filter ends up relative to
/// the `Skip` is not asserted.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip_five = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skip_five),
    });
    let returned_n = ReturnItem {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![returned_n],
        distinct: false,
        input: Box::new(always_true),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1766
/// Optimizes a plan with two stacked filters (`n.x > 1` above `n.y < 10`)
/// beneath a `Return` and checks only that the root is still a `Return`.
#[test]
fn test_nested_filter_pushdown() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let y_under_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "y".to_string(),
        }),
        op: BinaryOp::Lt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: y_under_10,
        input: Box::new(scan),
    });
    let x_over_1 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "x".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
    };
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: x_over_1,
        input: Box::new(inner_filter),
    });
    let returned_n = ReturnItem {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![returned_n],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let opt = Optimizer::new();
    let optimized = opt.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1808}