1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
26#[derive(Debug, Clone)]
28struct JoinInfo {
29 left_var: String,
30 right_var: String,
31 left_expr: LogicalExpression,
32 right_expr: LogicalExpression,
33}
34
/// A value that some operator higher up in the plan reads.
///
/// Collected into a `HashSet` by the required-column analysis that precedes
/// projection pushdown, hence the `Eq`/`Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable binding, e.g. `n`.
    Variable(String),
    /// One property of a variable as `(variable, property)`, e.g. `n.age`.
    Property(String, String),
}
43
44pub struct Optimizer {
49 enable_filter_pushdown: bool,
51 enable_join_reorder: bool,
53 enable_projection_pushdown: bool,
55 cost_model: CostModel,
57 card_estimator: CardinalityEstimator,
59}
60
61impl Optimizer {
62 #[must_use]
64 pub fn new() -> Self {
65 Self {
66 enable_filter_pushdown: true,
67 enable_join_reorder: true,
68 enable_projection_pushdown: true,
69 cost_model: CostModel::new(),
70 card_estimator: CardinalityEstimator::new(),
71 }
72 }
73
74 #[must_use]
79 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
80 store.ensure_statistics_fresh();
81 let stats = store.statistics();
82 let estimator = CardinalityEstimator::from_statistics(&stats);
83 Self {
84 enable_filter_pushdown: true,
85 enable_join_reorder: true,
86 enable_projection_pushdown: true,
87 cost_model: CostModel::new(),
88 card_estimator: estimator,
89 }
90 }
91
92 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
94 self.enable_filter_pushdown = enabled;
95 self
96 }
97
98 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
100 self.enable_join_reorder = enabled;
101 self
102 }
103
104 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
106 self.enable_projection_pushdown = enabled;
107 self
108 }
109
110 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
112 self.cost_model = cost_model;
113 self
114 }
115
116 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
118 self.card_estimator = estimator;
119 self
120 }
121
122 pub fn cost_model(&self) -> &CostModel {
124 &self.cost_model
125 }
126
127 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
129 &self.card_estimator
130 }
131
132 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
134 let cardinality = self.card_estimator.estimate(&plan.root);
135 self.cost_model.estimate(&plan.root, cardinality)
136 }
137
138 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
140 self.card_estimator.estimate(&plan.root)
141 }
142
143 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
149 let mut root = plan.root;
150
151 if self.enable_filter_pushdown {
153 root = self.push_filters_down(root);
154 }
155
156 if self.enable_join_reorder {
157 root = self.reorder_joins(root);
158 }
159
160 if self.enable_projection_pushdown {
161 root = self.push_projections_down(root);
162 }
163
164 Ok(LogicalPlan::new(root))
165 }
166
167 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
174 let required = self.collect_required_columns(&op);
176
177 self.push_projections_recursive(op, &required)
179 }
180
181 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
183 let mut required = HashSet::new();
184 Self::collect_required_recursive(op, &mut required);
185 required
186 }
187
188 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
190 match op {
191 LogicalOperator::Return(ret) => {
192 for item in &ret.items {
193 Self::collect_from_expression(&item.expression, required);
194 }
195 Self::collect_required_recursive(&ret.input, required);
196 }
197 LogicalOperator::Project(proj) => {
198 for p in &proj.projections {
199 Self::collect_from_expression(&p.expression, required);
200 }
201 Self::collect_required_recursive(&proj.input, required);
202 }
203 LogicalOperator::Filter(filter) => {
204 Self::collect_from_expression(&filter.predicate, required);
205 Self::collect_required_recursive(&filter.input, required);
206 }
207 LogicalOperator::Sort(sort) => {
208 for key in &sort.keys {
209 Self::collect_from_expression(&key.expression, required);
210 }
211 Self::collect_required_recursive(&sort.input, required);
212 }
213 LogicalOperator::Aggregate(agg) => {
214 for expr in &agg.group_by {
215 Self::collect_from_expression(expr, required);
216 }
217 for agg_expr in &agg.aggregates {
218 if let Some(ref expr) = agg_expr.expression {
219 Self::collect_from_expression(expr, required);
220 }
221 }
222 if let Some(ref having) = agg.having {
223 Self::collect_from_expression(having, required);
224 }
225 Self::collect_required_recursive(&agg.input, required);
226 }
227 LogicalOperator::Join(join) => {
228 for cond in &join.conditions {
229 Self::collect_from_expression(&cond.left, required);
230 Self::collect_from_expression(&cond.right, required);
231 }
232 Self::collect_required_recursive(&join.left, required);
233 Self::collect_required_recursive(&join.right, required);
234 }
235 LogicalOperator::Expand(expand) => {
236 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
238 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
239 if let Some(ref edge_var) = expand.edge_variable {
240 required.insert(RequiredColumn::Variable(edge_var.clone()));
241 }
242 Self::collect_required_recursive(&expand.input, required);
243 }
244 LogicalOperator::Limit(limit) => {
245 Self::collect_required_recursive(&limit.input, required);
246 }
247 LogicalOperator::Skip(skip) => {
248 Self::collect_required_recursive(&skip.input, required);
249 }
250 LogicalOperator::Distinct(distinct) => {
251 Self::collect_required_recursive(&distinct.input, required);
252 }
253 LogicalOperator::NodeScan(scan) => {
254 required.insert(RequiredColumn::Variable(scan.variable.clone()));
255 }
256 LogicalOperator::EdgeScan(scan) => {
257 required.insert(RequiredColumn::Variable(scan.variable.clone()));
258 }
259 _ => {}
260 }
261 }
262
263 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
265 match expr {
266 LogicalExpression::Variable(var) => {
267 required.insert(RequiredColumn::Variable(var.clone()));
268 }
269 LogicalExpression::Property { variable, property } => {
270 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
271 required.insert(RequiredColumn::Variable(variable.clone()));
272 }
273 LogicalExpression::Binary { left, right, .. } => {
274 Self::collect_from_expression(left, required);
275 Self::collect_from_expression(right, required);
276 }
277 LogicalExpression::Unary { operand, .. } => {
278 Self::collect_from_expression(operand, required);
279 }
280 LogicalExpression::FunctionCall { args, .. } => {
281 for arg in args {
282 Self::collect_from_expression(arg, required);
283 }
284 }
285 LogicalExpression::List(items) => {
286 for item in items {
287 Self::collect_from_expression(item, required);
288 }
289 }
290 LogicalExpression::Map(pairs) => {
291 for (_, value) in pairs {
292 Self::collect_from_expression(value, required);
293 }
294 }
295 LogicalExpression::IndexAccess { base, index } => {
296 Self::collect_from_expression(base, required);
297 Self::collect_from_expression(index, required);
298 }
299 LogicalExpression::SliceAccess { base, start, end } => {
300 Self::collect_from_expression(base, required);
301 if let Some(s) = start {
302 Self::collect_from_expression(s, required);
303 }
304 if let Some(e) = end {
305 Self::collect_from_expression(e, required);
306 }
307 }
308 LogicalExpression::Case {
309 operand,
310 when_clauses,
311 else_clause,
312 } => {
313 if let Some(op) = operand {
314 Self::collect_from_expression(op, required);
315 }
316 for (cond, result) in when_clauses {
317 Self::collect_from_expression(cond, required);
318 Self::collect_from_expression(result, required);
319 }
320 if let Some(else_expr) = else_clause {
321 Self::collect_from_expression(else_expr, required);
322 }
323 }
324 LogicalExpression::Labels(var)
325 | LogicalExpression::Type(var)
326 | LogicalExpression::Id(var) => {
327 required.insert(RequiredColumn::Variable(var.clone()));
328 }
329 LogicalExpression::ListComprehension {
330 list_expr,
331 filter_expr,
332 map_expr,
333 ..
334 } => {
335 Self::collect_from_expression(list_expr, required);
336 if let Some(filter) = filter_expr {
337 Self::collect_from_expression(filter, required);
338 }
339 Self::collect_from_expression(map_expr, required);
340 }
341 _ => {}
342 }
343 }
344
345 fn push_projections_recursive(
347 &self,
348 op: LogicalOperator,
349 required: &HashSet<RequiredColumn>,
350 ) -> LogicalOperator {
351 match op {
352 LogicalOperator::Return(mut ret) => {
353 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
354 LogicalOperator::Return(ret)
355 }
356 LogicalOperator::Project(mut proj) => {
357 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
358 LogicalOperator::Project(proj)
359 }
360 LogicalOperator::Filter(mut filter) => {
361 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
362 LogicalOperator::Filter(filter)
363 }
364 LogicalOperator::Sort(mut sort) => {
365 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
368 LogicalOperator::Sort(sort)
369 }
370 LogicalOperator::Aggregate(mut agg) => {
371 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
372 LogicalOperator::Aggregate(agg)
373 }
374 LogicalOperator::Join(mut join) => {
375 let left_vars = self.collect_output_variables(&join.left);
378 let right_vars = self.collect_output_variables(&join.right);
379
380 let left_required: HashSet<_> = required
382 .iter()
383 .filter(|c| match c {
384 RequiredColumn::Variable(v) => left_vars.contains(v),
385 RequiredColumn::Property(v, _) => left_vars.contains(v),
386 })
387 .cloned()
388 .collect();
389
390 let right_required: HashSet<_> = required
391 .iter()
392 .filter(|c| match c {
393 RequiredColumn::Variable(v) => right_vars.contains(v),
394 RequiredColumn::Property(v, _) => right_vars.contains(v),
395 })
396 .cloned()
397 .collect();
398
399 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
400 join.right =
401 Box::new(self.push_projections_recursive(*join.right, &right_required));
402 LogicalOperator::Join(join)
403 }
404 LogicalOperator::Expand(mut expand) => {
405 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
406 LogicalOperator::Expand(expand)
407 }
408 LogicalOperator::Limit(mut limit) => {
409 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
410 LogicalOperator::Limit(limit)
411 }
412 LogicalOperator::Skip(mut skip) => {
413 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
414 LogicalOperator::Skip(skip)
415 }
416 LogicalOperator::Distinct(mut distinct) => {
417 distinct.input =
418 Box::new(self.push_projections_recursive(*distinct.input, required));
419 LogicalOperator::Distinct(distinct)
420 }
421 other => other,
422 }
423 }
424
425 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
432 let op = self.reorder_joins_recursive(op);
434
435 if let Some((relations, conditions)) = self.extract_join_tree(&op)
437 && relations.len() >= 2
438 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
439 {
440 return optimized;
441 }
442
443 op
444 }
445
446 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
448 match op {
449 LogicalOperator::Return(mut ret) => {
450 ret.input = Box::new(self.reorder_joins(*ret.input));
451 LogicalOperator::Return(ret)
452 }
453 LogicalOperator::Project(mut proj) => {
454 proj.input = Box::new(self.reorder_joins(*proj.input));
455 LogicalOperator::Project(proj)
456 }
457 LogicalOperator::Filter(mut filter) => {
458 filter.input = Box::new(self.reorder_joins(*filter.input));
459 LogicalOperator::Filter(filter)
460 }
461 LogicalOperator::Limit(mut limit) => {
462 limit.input = Box::new(self.reorder_joins(*limit.input));
463 LogicalOperator::Limit(limit)
464 }
465 LogicalOperator::Skip(mut skip) => {
466 skip.input = Box::new(self.reorder_joins(*skip.input));
467 LogicalOperator::Skip(skip)
468 }
469 LogicalOperator::Sort(mut sort) => {
470 sort.input = Box::new(self.reorder_joins(*sort.input));
471 LogicalOperator::Sort(sort)
472 }
473 LogicalOperator::Distinct(mut distinct) => {
474 distinct.input = Box::new(self.reorder_joins(*distinct.input));
475 LogicalOperator::Distinct(distinct)
476 }
477 LogicalOperator::Aggregate(mut agg) => {
478 agg.input = Box::new(self.reorder_joins(*agg.input));
479 LogicalOperator::Aggregate(agg)
480 }
481 LogicalOperator::Expand(mut expand) => {
482 expand.input = Box::new(self.reorder_joins(*expand.input));
483 LogicalOperator::Expand(expand)
484 }
485 other => other,
487 }
488 }
489
490 fn extract_join_tree(
494 &self,
495 op: &LogicalOperator,
496 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
497 let mut relations = Vec::new();
498 let mut join_conditions = Vec::new();
499
500 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
501 return None;
502 }
503
504 if relations.len() < 2 {
505 return None;
506 }
507
508 Some((relations, join_conditions))
509 }
510
511 fn collect_join_tree(
515 &self,
516 op: &LogicalOperator,
517 relations: &mut Vec<(String, LogicalOperator)>,
518 conditions: &mut Vec<JoinInfo>,
519 ) -> bool {
520 match op {
521 LogicalOperator::Join(join) => {
522 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
524 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
525
526 for cond in &join.conditions {
528 if let (Some(left_var), Some(right_var)) = (
529 self.extract_variable_from_expr(&cond.left),
530 self.extract_variable_from_expr(&cond.right),
531 ) {
532 conditions.push(JoinInfo {
533 left_var,
534 right_var,
535 left_expr: cond.left.clone(),
536 right_expr: cond.right.clone(),
537 });
538 }
539 }
540
541 left_ok && right_ok
542 }
543 LogicalOperator::NodeScan(scan) => {
544 relations.push((scan.variable.clone(), op.clone()));
545 true
546 }
547 LogicalOperator::EdgeScan(scan) => {
548 relations.push((scan.variable.clone(), op.clone()));
549 true
550 }
551 LogicalOperator::Filter(filter) => {
552 self.collect_join_tree(&filter.input, relations, conditions)
554 }
555 LogicalOperator::Expand(expand) => {
556 relations.push((expand.to_variable.clone(), op.clone()));
559 true
560 }
561 _ => false,
562 }
563 }
564
565 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
567 match expr {
568 LogicalExpression::Variable(v) => Some(v.clone()),
569 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
570 _ => None,
571 }
572 }
573
574 fn optimize_join_order(
576 &self,
577 relations: &[(String, LogicalOperator)],
578 conditions: &[JoinInfo],
579 ) -> Option<LogicalOperator> {
580 use join_order::{DPccp, JoinGraphBuilder};
581
582 let mut builder = JoinGraphBuilder::new();
584
585 for (var, relation) in relations {
586 builder.add_relation(var, relation.clone());
587 }
588
589 for cond in conditions {
590 builder.add_join_condition(
591 &cond.left_var,
592 &cond.right_var,
593 cond.left_expr.clone(),
594 cond.right_expr.clone(),
595 );
596 }
597
598 let graph = builder.build();
599
600 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
602 let plan = dpccp.optimize()?;
603
604 Some(plan.operator)
605 }
606
607 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
612 match op {
613 LogicalOperator::Filter(filter) => {
615 let optimized_input = self.push_filters_down(*filter.input);
616 self.try_push_filter_into(filter.predicate, optimized_input)
617 }
618 LogicalOperator::Return(mut ret) => {
620 ret.input = Box::new(self.push_filters_down(*ret.input));
621 LogicalOperator::Return(ret)
622 }
623 LogicalOperator::Project(mut proj) => {
624 proj.input = Box::new(self.push_filters_down(*proj.input));
625 LogicalOperator::Project(proj)
626 }
627 LogicalOperator::Limit(mut limit) => {
628 limit.input = Box::new(self.push_filters_down(*limit.input));
629 LogicalOperator::Limit(limit)
630 }
631 LogicalOperator::Skip(mut skip) => {
632 skip.input = Box::new(self.push_filters_down(*skip.input));
633 LogicalOperator::Skip(skip)
634 }
635 LogicalOperator::Sort(mut sort) => {
636 sort.input = Box::new(self.push_filters_down(*sort.input));
637 LogicalOperator::Sort(sort)
638 }
639 LogicalOperator::Distinct(mut distinct) => {
640 distinct.input = Box::new(self.push_filters_down(*distinct.input));
641 LogicalOperator::Distinct(distinct)
642 }
643 LogicalOperator::Expand(mut expand) => {
644 expand.input = Box::new(self.push_filters_down(*expand.input));
645 LogicalOperator::Expand(expand)
646 }
647 LogicalOperator::Join(mut join) => {
648 join.left = Box::new(self.push_filters_down(*join.left));
649 join.right = Box::new(self.push_filters_down(*join.right));
650 LogicalOperator::Join(join)
651 }
652 LogicalOperator::Aggregate(mut agg) => {
653 agg.input = Box::new(self.push_filters_down(*agg.input));
654 LogicalOperator::Aggregate(agg)
655 }
656 other => other,
658 }
659 }
660
661 fn try_push_filter_into(
666 &self,
667 predicate: LogicalExpression,
668 op: LogicalOperator,
669 ) -> LogicalOperator {
670 match op {
671 LogicalOperator::Project(mut proj) => {
673 let predicate_vars = self.extract_variables(&predicate);
674 let computed_vars = self.extract_projection_aliases(&proj.projections);
675
676 if predicate_vars.is_disjoint(&computed_vars) {
678 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
679 LogicalOperator::Project(proj)
680 } else {
681 LogicalOperator::Filter(FilterOp {
683 predicate,
684 input: Box::new(LogicalOperator::Project(proj)),
685 })
686 }
687 }
688
689 LogicalOperator::Return(mut ret) => {
691 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
692 LogicalOperator::Return(ret)
693 }
694
695 LogicalOperator::Expand(mut expand) => {
697 let predicate_vars = self.extract_variables(&predicate);
698
699 let mut introduced_vars = vec![&expand.to_variable];
704 if let Some(ref edge_var) = expand.edge_variable {
705 introduced_vars.push(edge_var);
706 }
707 if let Some(ref path_alias) = expand.path_alias {
708 introduced_vars.push(path_alias);
709 }
710
711 let uses_introduced_vars =
713 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
714
715 if !uses_introduced_vars {
716 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
718 LogicalOperator::Expand(expand)
719 } else {
720 LogicalOperator::Filter(FilterOp {
722 predicate,
723 input: Box::new(LogicalOperator::Expand(expand)),
724 })
725 }
726 }
727
728 LogicalOperator::Join(mut join) => {
730 let predicate_vars = self.extract_variables(&predicate);
731 let left_vars = self.collect_output_variables(&join.left);
732 let right_vars = self.collect_output_variables(&join.right);
733
734 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
735 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
736
737 if uses_left && !uses_right {
738 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
740 LogicalOperator::Join(join)
741 } else if uses_right && !uses_left {
742 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
744 LogicalOperator::Join(join)
745 } else {
746 LogicalOperator::Filter(FilterOp {
748 predicate,
749 input: Box::new(LogicalOperator::Join(join)),
750 })
751 }
752 }
753
754 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
756 predicate,
757 input: Box::new(LogicalOperator::Aggregate(agg)),
758 }),
759
760 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
762 predicate,
763 input: Box::new(LogicalOperator::NodeScan(scan)),
764 }),
765
766 other => LogicalOperator::Filter(FilterOp {
768 predicate,
769 input: Box::new(other),
770 }),
771 }
772 }
773
774 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
776 let mut vars = HashSet::new();
777 Self::collect_output_variables_recursive(op, &mut vars);
778 vars
779 }
780
781 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
783 match op {
784 LogicalOperator::NodeScan(scan) => {
785 vars.insert(scan.variable.clone());
786 }
787 LogicalOperator::EdgeScan(scan) => {
788 vars.insert(scan.variable.clone());
789 }
790 LogicalOperator::Expand(expand) => {
791 vars.insert(expand.to_variable.clone());
792 if let Some(edge_var) = &expand.edge_variable {
793 vars.insert(edge_var.clone());
794 }
795 Self::collect_output_variables_recursive(&expand.input, vars);
796 }
797 LogicalOperator::Filter(filter) => {
798 Self::collect_output_variables_recursive(&filter.input, vars);
799 }
800 LogicalOperator::Project(proj) => {
801 for p in &proj.projections {
802 if let Some(alias) = &p.alias {
803 vars.insert(alias.clone());
804 }
805 }
806 Self::collect_output_variables_recursive(&proj.input, vars);
807 }
808 LogicalOperator::Join(join) => {
809 Self::collect_output_variables_recursive(&join.left, vars);
810 Self::collect_output_variables_recursive(&join.right, vars);
811 }
812 LogicalOperator::Aggregate(agg) => {
813 for expr in &agg.group_by {
814 Self::collect_variables(expr, vars);
815 }
816 for agg_expr in &agg.aggregates {
817 if let Some(alias) = &agg_expr.alias {
818 vars.insert(alias.clone());
819 }
820 }
821 }
822 LogicalOperator::Return(ret) => {
823 Self::collect_output_variables_recursive(&ret.input, vars);
824 }
825 LogicalOperator::Limit(limit) => {
826 Self::collect_output_variables_recursive(&limit.input, vars);
827 }
828 LogicalOperator::Skip(skip) => {
829 Self::collect_output_variables_recursive(&skip.input, vars);
830 }
831 LogicalOperator::Sort(sort) => {
832 Self::collect_output_variables_recursive(&sort.input, vars);
833 }
834 LogicalOperator::Distinct(distinct) => {
835 Self::collect_output_variables_recursive(&distinct.input, vars);
836 }
837 _ => {}
838 }
839 }
840
841 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
843 let mut vars = HashSet::new();
844 Self::collect_variables(expr, &mut vars);
845 vars
846 }
847
848 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
850 match expr {
851 LogicalExpression::Variable(name) => {
852 vars.insert(name.clone());
853 }
854 LogicalExpression::Property { variable, .. } => {
855 vars.insert(variable.clone());
856 }
857 LogicalExpression::Binary { left, right, .. } => {
858 Self::collect_variables(left, vars);
859 Self::collect_variables(right, vars);
860 }
861 LogicalExpression::Unary { operand, .. } => {
862 Self::collect_variables(operand, vars);
863 }
864 LogicalExpression::FunctionCall { args, .. } => {
865 for arg in args {
866 Self::collect_variables(arg, vars);
867 }
868 }
869 LogicalExpression::List(items) => {
870 for item in items {
871 Self::collect_variables(item, vars);
872 }
873 }
874 LogicalExpression::Map(pairs) => {
875 for (_, value) in pairs {
876 Self::collect_variables(value, vars);
877 }
878 }
879 LogicalExpression::IndexAccess { base, index } => {
880 Self::collect_variables(base, vars);
881 Self::collect_variables(index, vars);
882 }
883 LogicalExpression::SliceAccess { base, start, end } => {
884 Self::collect_variables(base, vars);
885 if let Some(s) = start {
886 Self::collect_variables(s, vars);
887 }
888 if let Some(e) = end {
889 Self::collect_variables(e, vars);
890 }
891 }
892 LogicalExpression::Case {
893 operand,
894 when_clauses,
895 else_clause,
896 } => {
897 if let Some(op) = operand {
898 Self::collect_variables(op, vars);
899 }
900 for (cond, result) in when_clauses {
901 Self::collect_variables(cond, vars);
902 Self::collect_variables(result, vars);
903 }
904 if let Some(else_expr) = else_clause {
905 Self::collect_variables(else_expr, vars);
906 }
907 }
908 LogicalExpression::Labels(var)
909 | LogicalExpression::Type(var)
910 | LogicalExpression::Id(var) => {
911 vars.insert(var.clone());
912 }
913 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
914 LogicalExpression::ListComprehension {
915 list_expr,
916 filter_expr,
917 map_expr,
918 ..
919 } => {
920 Self::collect_variables(list_expr, vars);
921 if let Some(filter) = filter_expr {
922 Self::collect_variables(filter, vars);
923 }
924 Self::collect_variables(map_expr, vars);
925 }
926 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
927 }
929 }
930 }
931
932 fn extract_projection_aliases(
934 &self,
935 projections: &[crate::query::plan::Projection],
936 ) -> HashSet<String> {
937 projections.iter().filter_map(|p| p.alias.clone()).collect()
938 }
939}
940
941impl Default for Optimizer {
942 fn default() -> Self {
943 Self::new()
944 }
945}
946
947#[cfg(test)]
948mod tests {
949 use super::*;
950 use crate::query::plan::{
951 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
952 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
953 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
954 };
955 use grafeo_common::types::Value;
956
957 #[test]
958 fn test_optimizer_filter_pushdown_simple() {
959 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
964 items: vec![ReturnItem {
965 expression: LogicalExpression::Variable("n".to_string()),
966 alias: None,
967 }],
968 distinct: false,
969 input: Box::new(LogicalOperator::Filter(FilterOp {
970 predicate: LogicalExpression::Binary {
971 left: Box::new(LogicalExpression::Property {
972 variable: "n".to_string(),
973 property: "age".to_string(),
974 }),
975 op: BinaryOp::Gt,
976 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
977 },
978 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
979 variable: "n".to_string(),
980 label: Some("Person".to_string()),
981 input: None,
982 })),
983 })),
984 }));
985
986 let optimizer = Optimizer::new();
987 let optimized = optimizer.optimize(plan).unwrap();
988
989 if let LogicalOperator::Return(ret) = &optimized.root
991 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
992 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
993 {
994 assert_eq!(scan.variable, "n");
995 return;
996 }
997 panic!("Expected Return -> Filter -> NodeScan structure");
998 }
999
1000 #[test]
1001 fn test_optimizer_filter_pushdown_through_expand() {
1002 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1006 items: vec![ReturnItem {
1007 expression: LogicalExpression::Variable("b".to_string()),
1008 alias: None,
1009 }],
1010 distinct: false,
1011 input: Box::new(LogicalOperator::Filter(FilterOp {
1012 predicate: LogicalExpression::Binary {
1013 left: Box::new(LogicalExpression::Property {
1014 variable: "a".to_string(),
1015 property: "age".to_string(),
1016 }),
1017 op: BinaryOp::Gt,
1018 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1019 },
1020 input: Box::new(LogicalOperator::Expand(ExpandOp {
1021 from_variable: "a".to_string(),
1022 to_variable: "b".to_string(),
1023 edge_variable: None,
1024 direction: ExpandDirection::Outgoing,
1025 edge_type: Some("KNOWS".to_string()),
1026 min_hops: 1,
1027 max_hops: Some(1),
1028 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1029 variable: "a".to_string(),
1030 label: Some("Person".to_string()),
1031 input: None,
1032 })),
1033 path_alias: None,
1034 })),
1035 })),
1036 }));
1037
1038 let optimizer = Optimizer::new();
1039 let optimized = optimizer.optimize(plan).unwrap();
1040
1041 if let LogicalOperator::Return(ret) = &optimized.root
1044 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1045 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1046 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1047 {
1048 assert_eq!(scan.variable, "a");
1049 assert_eq!(expand.from_variable, "a");
1050 assert_eq!(expand.to_variable, "b");
1051 return;
1052 }
1053 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1054 }
1055
1056 #[test]
1057 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1058 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1062 items: vec![ReturnItem {
1063 expression: LogicalExpression::Variable("a".to_string()),
1064 alias: None,
1065 }],
1066 distinct: false,
1067 input: Box::new(LogicalOperator::Filter(FilterOp {
1068 predicate: LogicalExpression::Binary {
1069 left: Box::new(LogicalExpression::Property {
1070 variable: "b".to_string(),
1071 property: "age".to_string(),
1072 }),
1073 op: BinaryOp::Gt,
1074 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1075 },
1076 input: Box::new(LogicalOperator::Expand(ExpandOp {
1077 from_variable: "a".to_string(),
1078 to_variable: "b".to_string(),
1079 edge_variable: None,
1080 direction: ExpandDirection::Outgoing,
1081 edge_type: Some("KNOWS".to_string()),
1082 min_hops: 1,
1083 max_hops: Some(1),
1084 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1085 variable: "a".to_string(),
1086 label: Some("Person".to_string()),
1087 input: None,
1088 })),
1089 path_alias: None,
1090 })),
1091 })),
1092 }));
1093
1094 let optimizer = Optimizer::new();
1095 let optimized = optimizer.optimize(plan).unwrap();
1096
1097 if let LogicalOperator::Return(ret) = &optimized.root
1100 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1101 {
1102 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1104 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1105 {
1106 assert_eq!(variable, "b");
1107 }
1108
1109 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1110 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1111 {
1112 return;
1113 }
1114 }
1115 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1116 }
1117
1118 #[test]
1119 fn test_optimizer_extract_variables() {
1120 let optimizer = Optimizer::new();
1121
1122 let expr = LogicalExpression::Binary {
1123 left: Box::new(LogicalExpression::Property {
1124 variable: "n".to_string(),
1125 property: "age".to_string(),
1126 }),
1127 op: BinaryOp::Gt,
1128 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1129 };
1130
1131 let vars = optimizer.extract_variables(&expr);
1132 assert_eq!(vars.len(), 1);
1133 assert!(vars.contains("n"));
1134 }
1135
1136 #[test]
1139 fn test_optimizer_default() {
1140 let optimizer = Optimizer::default();
1141 let plan = LogicalPlan::new(LogicalOperator::Empty);
1143 let result = optimizer.optimize(plan);
1144 assert!(result.is_ok());
1145 }
1146
1147 #[test]
1148 fn test_optimizer_with_filter_pushdown_disabled() {
1149 let optimizer = Optimizer::new().with_filter_pushdown(false);
1150
1151 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1152 items: vec![ReturnItem {
1153 expression: LogicalExpression::Variable("n".to_string()),
1154 alias: None,
1155 }],
1156 distinct: false,
1157 input: Box::new(LogicalOperator::Filter(FilterOp {
1158 predicate: LogicalExpression::Literal(Value::Bool(true)),
1159 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1160 variable: "n".to_string(),
1161 label: None,
1162 input: None,
1163 })),
1164 })),
1165 }));
1166
1167 let optimized = optimizer.optimize(plan).unwrap();
1168 if let LogicalOperator::Return(ret) = &optimized.root
1170 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1171 {
1172 return;
1173 }
1174 panic!("Expected unchanged structure");
1175 }
1176
/// Turning join reordering off must not make optimization of a trivial
/// `Empty` plan fail.
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    let optimizer = Optimizer::new().with_join_reorder(false);
    let plan = LogicalPlan::new(LogicalOperator::Empty);

    let outcome = optimizer.optimize(plan);
    assert!(outcome.is_ok());
}
1186
/// Installs a fresh `CostModel` with `with_cost_model`. Through `cost_model()`,
/// an `Empty` operator estimated with a second argument of `0.0` (presumably
/// the input cardinality) must report a total cost below 0.001.
#[test]
fn test_optimizer_with_cost_model() {
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());

    let empty = LogicalOperator::Empty;
    let empty_cost = optimizer.cost_model().estimate(&empty, 0.0);
    assert!(empty_cost.total() < 0.001);
}
1199
/// Stats registered on a custom `CardinalityEstimator` (label `Test`, built
/// with `TableStats::new(500)`) must drive the optimizer's estimate for a scan
/// of that label, giving 500.
#[test]
fn test_optimizer_with_cardinality_estimator() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(estimator);

    let labeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    });
    let plan = LogicalPlan::new(labeled_scan);

    let expected = 500.0;
    let estimated = optimizer.estimate_cardinality(&plan);
    assert!((estimated - expected).abs() < 0.001);
}
1216
/// An unlabeled node scan must be assigned a strictly positive total cost.
#[test]
fn test_optimizer_estimate_cost() {
    let scan_all = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(scan_all);
    let optimizer = Optimizer::new();

    let estimated = optimizer.estimate_cost(&plan);
    assert!(estimated.total() > 0.0);
}
1229
/// A filter on `n.age` placed above a `Project` that passes `n` through
/// unaliased is expected to be moved below it (`Project -> Filter`).
#[test]
fn test_filter_pushdown_through_project() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(project),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Project(proj)
                if matches!(proj.input.as_ref(), LogicalOperator::Filter(_))
        ),
        "Expected Project -> Filter structure"
    );
}
1268
/// The predicate `x > 30` refers to `x`, a name that only exists as the alias
/// the `Project` gives to `n.age`. The test expects the filter to remain above
/// the projection (`Filter -> Project`).
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let aliased_age = Projection {
        expression: LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        },
        alias: Some("x".to_string()),
    };
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![aliased_age],
        input: Box::new(scan),
    });
    let x_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_over_30,
        input: Box::new(project),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Filter(filter)
                if matches!(filter.input.as_ref(), LogicalOperator::Project(_))
        ),
        "Expected Filter -> Project structure"
    );
}
1306
/// Despite its name, this test requires that a `Filter` over a `Limit` is NOT
/// moved below it: the optimized root must still be `Filter -> Limit`.
/// Filtering before versus after a row limit can select different rows.
#[test]
fn test_filter_pushdown_through_limit() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limited = LogicalOperator::Limit(LimitOp {
        count: 10,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(limited),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Filter(filter)
                if matches!(filter.input.as_ref(), LogicalOperator::Limit(_))
        ),
        "Expected Filter -> Limit structure"
    );
}
1333
/// Requires that a constant-`true` `Filter` over a `Sort` remains directly
/// above it after optimization (`Filter -> Sort`).
/// NOTE(review): the name suggests pushdown, but the assertion pins the
/// opposite; confirm which behavior is intended.
#[test]
fn test_filter_pushdown_through_sort() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let sorted = LogicalOperator::Sort(SortOp {
        keys: vec![SortKey {
            expression: LogicalExpression::Variable("n".to_string()),
            order: SortOrder::Ascending,
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(sorted),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Filter(filter)
                if matches!(filter.input.as_ref(), LogicalOperator::Sort(_))
        ),
        "Expected Filter -> Sort structure"
    );
}
1363
/// Requires that a constant-`true` `Filter` over a `Distinct` remains directly
/// above it after optimization (`Filter -> Distinct`).
/// NOTE(review): the name suggests pushdown, but the assertion pins the
/// opposite; confirm which behavior is intended.
#[test]
fn test_filter_pushdown_through_distinct() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let deduped = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(deduped),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Filter(filter)
                if matches!(filter.input.as_ref(), LogicalOperator::Distinct(_))
        ),
        "Expected Filter -> Distinct structure"
    );
}
1390
/// `cnt > 10` filters on `cnt`, the alias of the `Count` aggregate. The test
/// requires the filter to stay above the `Aggregate` (`Filter -> Aggregate`).
#[test]
fn test_filter_not_pushed_through_aggregate() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_all = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_all],
        input: Box::new(scan),
        having: None,
    });
    let cnt_over_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cnt_over_10,
        input: Box::new(aggregate),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Filter(filter)
                if matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_))
        ),
        "Expected Filter -> Aggregate structure"
    );
}
1429
/// A predicate that references only `a`, the variable bound by the join's
/// left input, is expected to be moved onto that input. The join becomes the
/// root, with a `Filter` as its left child.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    let optimizer = Optimizer::new();

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let a_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: a_age_over_30,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Join(join)
                if matches!(join.left.as_ref(), LogicalOperator::Filter(_))
        ),
        "Expected Join with Filter on left side"
    );
}
1470
/// A predicate that references only `b`, the variable bound by the join's
/// right input, is expected to be moved onto that input. The join becomes the
/// root, with a `Filter` as its right child.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let optimizer = Optimizer::new();

    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let b_is_acme = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: b_is_acme,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Join(join)
                if matches!(join.right.as_ref(), LogicalOperator::Filter(_))
        ),
        "Expected Join with Filter on right side"
    );
}
1511
/// `a.id = b.a_id` references variables from both join inputs, so the test
/// requires the filter to stay above the join (`Filter -> Join`).
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let optimizer = Optimizer::new();

    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });

    let a_id = LogicalExpression::Property {
        variable: "a".to_string(),
        property: "id".to_string(),
    };
    let b_a_id = LogicalExpression::Property {
        variable: "b".to_string(),
        property: "a_id".to_string(),
    };
    let cross_side = LogicalExpression::Binary {
        left: Box::new(a_id),
        op: BinaryOp::Eq,
        right: Box::new(b_a_id),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(
        matches!(
            &optimized.root,
            LogicalOperator::Filter(filter)
                if matches!(filter.input.as_ref(), LogicalOperator::Join(_))
        ),
        "Expected Filter -> Join structure"
    );
}
1555
/// A bare variable reference yields exactly that one variable.
#[test]
fn test_extract_variables_from_variable() {
    let var_ref = LogicalExpression::Variable("x".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&var_ref);
    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1566
/// A unary `NOT` passes its operand's variables through unchanged.
#[test]
fn test_extract_variables_from_unary() {
    let inner = LogicalExpression::Variable("x".to_string());
    let negated = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(inner),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&negated);
    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1578
/// Every variable appearing among a function call's arguments is reported.
#[test]
fn test_extract_variables_from_function_call() {
    let arg_names = ["a", "b"];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args: arg_names
            .iter()
            .map(|arg| LogicalExpression::Variable(arg.to_string()))
            .collect(),
        distinct: false,
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&call);
    assert_eq!(found.len(), 2);
    for arg in arg_names {
        assert!(found.contains(arg));
    }
}
1595
/// List elements contribute their variables. The literal element in the
/// middle adds nothing, so the list yields exactly `a` and `b`.
#[test]
fn test_extract_variables_from_list() {
    let list = LogicalExpression::List(vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ]);
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&list);
    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1609
/// Map values contribute their variables, while the map keys (`key1`, `key2`)
/// do not: exactly `a` and `b` are reported.
#[test]
fn test_extract_variables_from_map() {
    let entries = [("key1", "a"), ("key2", "b")];
    let map = LogicalExpression::Map(
        entries
            .iter()
            .map(|(key, var)| (key.to_string(), LogicalExpression::Variable(var.to_string())))
            .collect(),
    );
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&map);
    assert_eq!(found.len(), 2);
    for (_, var) in entries {
        assert!(found.contains(var));
    }
}
1628
/// Both the indexed base and the index expression contribute variables.
#[test]
fn test_extract_variables_from_index_access() {
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&access);
    assert_eq!(found.len(), 2);
    for name in ["list", "idx"] {
        assert!(found.contains(name));
    }
}
1641
/// A slice contributes variables from its base and from both optional bounds.
#[test]
fn test_extract_variables_from_slice_access() {
    // Small local builder for a boxed variable reference.
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let sliced = LogicalExpression::SliceAccess {
        base: var("list"),
        start: Some(var("s")),
        end: Some(var("e")),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&sliced);
    assert_eq!(found.len(), 3);
    for name in ["list", "s", "e"] {
        assert!(found.contains(name));
    }
}
1656
/// A `CASE` expression contributes variables from its operand, its `WHEN`
/// branches and its `ELSE` branch. The literal `WHEN` condition adds none.
#[test]
fn test_extract_variables_from_case() {
    // Small local builder for a boxed variable reference.
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let case_expr = LogicalExpression::Case {
        operand: Some(var("x")),
        when_clauses: vec![(
            LogicalExpression::Literal(Value::Int64(1)),
            LogicalExpression::Variable("a".to_string()),
        )],
        else_clause: Some(var("b")),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&case_expr);
    assert_eq!(found.len(), 3);
    for name in ["x", "a", "b"] {
        assert!(found.contains(name));
    }
}
1674
/// A `Labels` expression over `n` yields exactly `n`.
#[test]
fn test_extract_variables_from_labels() {
    let labels_of_n = LogicalExpression::Labels("n".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&labels_of_n);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1683
/// A `Type` expression over `e` yields exactly `e`.
#[test]
fn test_extract_variables_from_type() {
    let type_of_e = LogicalExpression::Type("e".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&type_of_e);
    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1692
/// An `Id` expression over `n` yields exactly `n`.
#[test]
fn test_extract_variables_from_id() {
    let id_of_n = LogicalExpression::Id("n".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&id_of_n);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1701
/// Variables referenced by a list comprehension's source list, filter and map
/// expressions are all reported.
/// NOTE(review): the test does not check whether the comprehension's own bound
/// variable `x` is included or excluded, nor the total count.
#[test]
fn test_extract_variables_from_list_comprehension() {
    // Small local builder for a boxed variable reference.
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: var("items"),
        filter_expr: Some(var("pred")),
        map_expr: var("result"),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&comprehension);
    for name in ["items", "pred", "result"] {
        assert!(found.contains(name));
    }
}
1716
/// Literals and query parameters do not reference any query variable, so
/// `extract_variables` must return an empty set for each of them.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // A parameter reference names no query variable either.
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1727
/// Optimizes `Return -> Filter(true) -> Skip -> NodeScan` and checks that the
/// optimizer succeeds and keeps `Return` at the root.
/// NOTE(review): the assertion does not pin where the `Filter` ends up
/// relative to the `Skip`; consider tightening it once the intended shape is
/// settled.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skipped = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skipped),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1758
/// Optimizes two stacked filters on `n` (`n.x > 1` above `n.y < 10`) under a
/// `Return` and checks that optimization succeeds and `Return` stays the root.
/// NOTE(review): the assertion does not check whether the filters are merged
/// or where they end up.
#[test]
fn test_nested_filter_pushdown() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let y_below_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "y".to_string(),
        }),
        op: BinaryOp::Lt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let x_above_1 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "x".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
    };
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: y_below_10,
        input: Box::new(scan),
    });
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: x_above_1,
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1800}