1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
28#[derive(Debug, Clone)]
30struct JoinInfo {
31 left_var: String,
32 right_var: String,
33 left_expr: LogicalExpression,
34 right_expr: LogicalExpression,
35}
36
/// A column that some operator or expression in the plan reads.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum RequiredColumn {
    /// A whole variable binding, identified by its name.
    Variable(String),
    /// A single property, stored as `(variable, property)`.
    Property(String, String),
}
45
46pub struct Optimizer {
51 enable_filter_pushdown: bool,
53 enable_join_reorder: bool,
55 enable_projection_pushdown: bool,
57 cost_model: CostModel,
59 card_estimator: CardinalityEstimator,
61}
62
63impl Optimizer {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 enable_filter_pushdown: true,
69 enable_join_reorder: true,
70 enable_projection_pushdown: true,
71 cost_model: CostModel::new(),
72 card_estimator: CardinalityEstimator::new(),
73 }
74 }
75
76 #[must_use]
81 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
82 store.ensure_statistics_fresh();
83 let stats = store.statistics();
84 let estimator = CardinalityEstimator::from_statistics(&stats);
85
86 let avg_fanout = if stats.total_nodes > 0 {
88 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
89 } else {
90 10.0
91 };
92
93 Self {
94 enable_filter_pushdown: true,
95 enable_join_reorder: true,
96 enable_projection_pushdown: true,
97 cost_model: CostModel::new().with_avg_fanout(avg_fanout),
98 card_estimator: estimator,
99 }
100 }
101
102 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
104 self.enable_filter_pushdown = enabled;
105 self
106 }
107
108 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
110 self.enable_join_reorder = enabled;
111 self
112 }
113
114 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
116 self.enable_projection_pushdown = enabled;
117 self
118 }
119
120 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
122 self.cost_model = cost_model;
123 self
124 }
125
126 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
128 self.card_estimator = estimator;
129 self
130 }
131
132 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
134 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
135 self
136 }
137
138 pub fn cost_model(&self) -> &CostModel {
140 &self.cost_model
141 }
142
143 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
145 &self.card_estimator
146 }
147
148 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
150 let cardinality = self.card_estimator.estimate(&plan.root);
151 self.cost_model.estimate(&plan.root, cardinality)
152 }
153
154 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
156 self.card_estimator.estimate(&plan.root)
157 }
158
159 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
165 let mut root = plan.root;
166
167 if self.enable_filter_pushdown {
169 root = self.push_filters_down(root);
170 }
171
172 if self.enable_join_reorder {
173 root = self.reorder_joins(root);
174 }
175
176 if self.enable_projection_pushdown {
177 root = self.push_projections_down(root);
178 }
179
180 Ok(LogicalPlan::new(root))
181 }
182
183 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
190 let required = self.collect_required_columns(&op);
192
193 self.push_projections_recursive(op, &required)
195 }
196
197 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
199 let mut required = HashSet::new();
200 Self::collect_required_recursive(op, &mut required);
201 required
202 }
203
204 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
206 match op {
207 LogicalOperator::Return(ret) => {
208 for item in &ret.items {
209 Self::collect_from_expression(&item.expression, required);
210 }
211 Self::collect_required_recursive(&ret.input, required);
212 }
213 LogicalOperator::Project(proj) => {
214 for p in &proj.projections {
215 Self::collect_from_expression(&p.expression, required);
216 }
217 Self::collect_required_recursive(&proj.input, required);
218 }
219 LogicalOperator::Filter(filter) => {
220 Self::collect_from_expression(&filter.predicate, required);
221 Self::collect_required_recursive(&filter.input, required);
222 }
223 LogicalOperator::Sort(sort) => {
224 for key in &sort.keys {
225 Self::collect_from_expression(&key.expression, required);
226 }
227 Self::collect_required_recursive(&sort.input, required);
228 }
229 LogicalOperator::Aggregate(agg) => {
230 for expr in &agg.group_by {
231 Self::collect_from_expression(expr, required);
232 }
233 for agg_expr in &agg.aggregates {
234 if let Some(ref expr) = agg_expr.expression {
235 Self::collect_from_expression(expr, required);
236 }
237 }
238 if let Some(ref having) = agg.having {
239 Self::collect_from_expression(having, required);
240 }
241 Self::collect_required_recursive(&agg.input, required);
242 }
243 LogicalOperator::Join(join) => {
244 for cond in &join.conditions {
245 Self::collect_from_expression(&cond.left, required);
246 Self::collect_from_expression(&cond.right, required);
247 }
248 Self::collect_required_recursive(&join.left, required);
249 Self::collect_required_recursive(&join.right, required);
250 }
251 LogicalOperator::Expand(expand) => {
252 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
254 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
255 if let Some(ref edge_var) = expand.edge_variable {
256 required.insert(RequiredColumn::Variable(edge_var.clone()));
257 }
258 Self::collect_required_recursive(&expand.input, required);
259 }
260 LogicalOperator::Limit(limit) => {
261 Self::collect_required_recursive(&limit.input, required);
262 }
263 LogicalOperator::Skip(skip) => {
264 Self::collect_required_recursive(&skip.input, required);
265 }
266 LogicalOperator::Distinct(distinct) => {
267 Self::collect_required_recursive(&distinct.input, required);
268 }
269 LogicalOperator::NodeScan(scan) => {
270 required.insert(RequiredColumn::Variable(scan.variable.clone()));
271 }
272 LogicalOperator::EdgeScan(scan) => {
273 required.insert(RequiredColumn::Variable(scan.variable.clone()));
274 }
275 _ => {}
276 }
277 }
278
279 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
281 match expr {
282 LogicalExpression::Variable(var) => {
283 required.insert(RequiredColumn::Variable(var.clone()));
284 }
285 LogicalExpression::Property { variable, property } => {
286 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
287 required.insert(RequiredColumn::Variable(variable.clone()));
288 }
289 LogicalExpression::Binary { left, right, .. } => {
290 Self::collect_from_expression(left, required);
291 Self::collect_from_expression(right, required);
292 }
293 LogicalExpression::Unary { operand, .. } => {
294 Self::collect_from_expression(operand, required);
295 }
296 LogicalExpression::FunctionCall { args, .. } => {
297 for arg in args {
298 Self::collect_from_expression(arg, required);
299 }
300 }
301 LogicalExpression::List(items) => {
302 for item in items {
303 Self::collect_from_expression(item, required);
304 }
305 }
306 LogicalExpression::Map(pairs) => {
307 for (_, value) in pairs {
308 Self::collect_from_expression(value, required);
309 }
310 }
311 LogicalExpression::IndexAccess { base, index } => {
312 Self::collect_from_expression(base, required);
313 Self::collect_from_expression(index, required);
314 }
315 LogicalExpression::SliceAccess { base, start, end } => {
316 Self::collect_from_expression(base, required);
317 if let Some(s) = start {
318 Self::collect_from_expression(s, required);
319 }
320 if let Some(e) = end {
321 Self::collect_from_expression(e, required);
322 }
323 }
324 LogicalExpression::Case {
325 operand,
326 when_clauses,
327 else_clause,
328 } => {
329 if let Some(op) = operand {
330 Self::collect_from_expression(op, required);
331 }
332 for (cond, result) in when_clauses {
333 Self::collect_from_expression(cond, required);
334 Self::collect_from_expression(result, required);
335 }
336 if let Some(else_expr) = else_clause {
337 Self::collect_from_expression(else_expr, required);
338 }
339 }
340 LogicalExpression::Labels(var)
341 | LogicalExpression::Type(var)
342 | LogicalExpression::Id(var) => {
343 required.insert(RequiredColumn::Variable(var.clone()));
344 }
345 LogicalExpression::ListComprehension {
346 list_expr,
347 filter_expr,
348 map_expr,
349 ..
350 } => {
351 Self::collect_from_expression(list_expr, required);
352 if let Some(filter) = filter_expr {
353 Self::collect_from_expression(filter, required);
354 }
355 Self::collect_from_expression(map_expr, required);
356 }
357 _ => {}
358 }
359 }
360
361 fn push_projections_recursive(
363 &self,
364 op: LogicalOperator,
365 required: &HashSet<RequiredColumn>,
366 ) -> LogicalOperator {
367 match op {
368 LogicalOperator::Return(mut ret) => {
369 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
370 LogicalOperator::Return(ret)
371 }
372 LogicalOperator::Project(mut proj) => {
373 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
374 LogicalOperator::Project(proj)
375 }
376 LogicalOperator::Filter(mut filter) => {
377 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
378 LogicalOperator::Filter(filter)
379 }
380 LogicalOperator::Sort(mut sort) => {
381 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
384 LogicalOperator::Sort(sort)
385 }
386 LogicalOperator::Aggregate(mut agg) => {
387 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
388 LogicalOperator::Aggregate(agg)
389 }
390 LogicalOperator::Join(mut join) => {
391 let left_vars = self.collect_output_variables(&join.left);
394 let right_vars = self.collect_output_variables(&join.right);
395
396 let left_required: HashSet<_> = required
398 .iter()
399 .filter(|c| match c {
400 RequiredColumn::Variable(v) => left_vars.contains(v),
401 RequiredColumn::Property(v, _) => left_vars.contains(v),
402 })
403 .cloned()
404 .collect();
405
406 let right_required: HashSet<_> = required
407 .iter()
408 .filter(|c| match c {
409 RequiredColumn::Variable(v) => right_vars.contains(v),
410 RequiredColumn::Property(v, _) => right_vars.contains(v),
411 })
412 .cloned()
413 .collect();
414
415 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
416 join.right =
417 Box::new(self.push_projections_recursive(*join.right, &right_required));
418 LogicalOperator::Join(join)
419 }
420 LogicalOperator::Expand(mut expand) => {
421 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
422 LogicalOperator::Expand(expand)
423 }
424 LogicalOperator::Limit(mut limit) => {
425 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
426 LogicalOperator::Limit(limit)
427 }
428 LogicalOperator::Skip(mut skip) => {
429 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
430 LogicalOperator::Skip(skip)
431 }
432 LogicalOperator::Distinct(mut distinct) => {
433 distinct.input =
434 Box::new(self.push_projections_recursive(*distinct.input, required));
435 LogicalOperator::Distinct(distinct)
436 }
437 other => other,
438 }
439 }
440
441 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
448 let op = self.reorder_joins_recursive(op);
450
451 if let Some((relations, conditions)) = self.extract_join_tree(&op)
453 && relations.len() >= 2
454 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
455 {
456 return optimized;
457 }
458
459 op
460 }
461
462 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
464 match op {
465 LogicalOperator::Return(mut ret) => {
466 ret.input = Box::new(self.reorder_joins(*ret.input));
467 LogicalOperator::Return(ret)
468 }
469 LogicalOperator::Project(mut proj) => {
470 proj.input = Box::new(self.reorder_joins(*proj.input));
471 LogicalOperator::Project(proj)
472 }
473 LogicalOperator::Filter(mut filter) => {
474 filter.input = Box::new(self.reorder_joins(*filter.input));
475 LogicalOperator::Filter(filter)
476 }
477 LogicalOperator::Limit(mut limit) => {
478 limit.input = Box::new(self.reorder_joins(*limit.input));
479 LogicalOperator::Limit(limit)
480 }
481 LogicalOperator::Skip(mut skip) => {
482 skip.input = Box::new(self.reorder_joins(*skip.input));
483 LogicalOperator::Skip(skip)
484 }
485 LogicalOperator::Sort(mut sort) => {
486 sort.input = Box::new(self.reorder_joins(*sort.input));
487 LogicalOperator::Sort(sort)
488 }
489 LogicalOperator::Distinct(mut distinct) => {
490 distinct.input = Box::new(self.reorder_joins(*distinct.input));
491 LogicalOperator::Distinct(distinct)
492 }
493 LogicalOperator::Aggregate(mut agg) => {
494 agg.input = Box::new(self.reorder_joins(*agg.input));
495 LogicalOperator::Aggregate(agg)
496 }
497 LogicalOperator::Expand(mut expand) => {
498 expand.input = Box::new(self.reorder_joins(*expand.input));
499 LogicalOperator::Expand(expand)
500 }
501 other => other,
503 }
504 }
505
506 fn extract_join_tree(
510 &self,
511 op: &LogicalOperator,
512 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
513 let mut relations = Vec::new();
514 let mut join_conditions = Vec::new();
515
516 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
517 return None;
518 }
519
520 if relations.len() < 2 {
521 return None;
522 }
523
524 Some((relations, join_conditions))
525 }
526
527 fn collect_join_tree(
531 &self,
532 op: &LogicalOperator,
533 relations: &mut Vec<(String, LogicalOperator)>,
534 conditions: &mut Vec<JoinInfo>,
535 ) -> bool {
536 match op {
537 LogicalOperator::Join(join) => {
538 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
540 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
541
542 for cond in &join.conditions {
544 if let (Some(left_var), Some(right_var)) = (
545 self.extract_variable_from_expr(&cond.left),
546 self.extract_variable_from_expr(&cond.right),
547 ) {
548 conditions.push(JoinInfo {
549 left_var,
550 right_var,
551 left_expr: cond.left.clone(),
552 right_expr: cond.right.clone(),
553 });
554 }
555 }
556
557 left_ok && right_ok
558 }
559 LogicalOperator::NodeScan(scan) => {
560 relations.push((scan.variable.clone(), op.clone()));
561 true
562 }
563 LogicalOperator::EdgeScan(scan) => {
564 relations.push((scan.variable.clone(), op.clone()));
565 true
566 }
567 LogicalOperator::Filter(filter) => {
568 self.collect_join_tree(&filter.input, relations, conditions)
570 }
571 LogicalOperator::Expand(expand) => {
572 relations.push((expand.to_variable.clone(), op.clone()));
575 true
576 }
577 _ => false,
578 }
579 }
580
581 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
583 match expr {
584 LogicalExpression::Variable(v) => Some(v.clone()),
585 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
586 _ => None,
587 }
588 }
589
590 fn optimize_join_order(
592 &self,
593 relations: &[(String, LogicalOperator)],
594 conditions: &[JoinInfo],
595 ) -> Option<LogicalOperator> {
596 use join_order::{DPccp, JoinGraphBuilder};
597
598 let mut builder = JoinGraphBuilder::new();
600
601 for (var, relation) in relations {
602 builder.add_relation(var, relation.clone());
603 }
604
605 for cond in conditions {
606 builder.add_join_condition(
607 &cond.left_var,
608 &cond.right_var,
609 cond.left_expr.clone(),
610 cond.right_expr.clone(),
611 );
612 }
613
614 let graph = builder.build();
615
616 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
618 let plan = dpccp.optimize()?;
619
620 Some(plan.operator)
621 }
622
623 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
628 match op {
629 LogicalOperator::Filter(filter) => {
631 let optimized_input = self.push_filters_down(*filter.input);
632 self.try_push_filter_into(filter.predicate, optimized_input)
633 }
634 LogicalOperator::Return(mut ret) => {
636 ret.input = Box::new(self.push_filters_down(*ret.input));
637 LogicalOperator::Return(ret)
638 }
639 LogicalOperator::Project(mut proj) => {
640 proj.input = Box::new(self.push_filters_down(*proj.input));
641 LogicalOperator::Project(proj)
642 }
643 LogicalOperator::Limit(mut limit) => {
644 limit.input = Box::new(self.push_filters_down(*limit.input));
645 LogicalOperator::Limit(limit)
646 }
647 LogicalOperator::Skip(mut skip) => {
648 skip.input = Box::new(self.push_filters_down(*skip.input));
649 LogicalOperator::Skip(skip)
650 }
651 LogicalOperator::Sort(mut sort) => {
652 sort.input = Box::new(self.push_filters_down(*sort.input));
653 LogicalOperator::Sort(sort)
654 }
655 LogicalOperator::Distinct(mut distinct) => {
656 distinct.input = Box::new(self.push_filters_down(*distinct.input));
657 LogicalOperator::Distinct(distinct)
658 }
659 LogicalOperator::Expand(mut expand) => {
660 expand.input = Box::new(self.push_filters_down(*expand.input));
661 LogicalOperator::Expand(expand)
662 }
663 LogicalOperator::Join(mut join) => {
664 join.left = Box::new(self.push_filters_down(*join.left));
665 join.right = Box::new(self.push_filters_down(*join.right));
666 LogicalOperator::Join(join)
667 }
668 LogicalOperator::Aggregate(mut agg) => {
669 agg.input = Box::new(self.push_filters_down(*agg.input));
670 LogicalOperator::Aggregate(agg)
671 }
672 other => other,
674 }
675 }
676
677 fn try_push_filter_into(
682 &self,
683 predicate: LogicalExpression,
684 op: LogicalOperator,
685 ) -> LogicalOperator {
686 match op {
687 LogicalOperator::Project(mut proj) => {
689 let predicate_vars = self.extract_variables(&predicate);
690 let computed_vars = self.extract_projection_aliases(&proj.projections);
691
692 if predicate_vars.is_disjoint(&computed_vars) {
694 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
695 LogicalOperator::Project(proj)
696 } else {
697 LogicalOperator::Filter(FilterOp {
699 predicate,
700 input: Box::new(LogicalOperator::Project(proj)),
701 })
702 }
703 }
704
705 LogicalOperator::Return(mut ret) => {
707 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
708 LogicalOperator::Return(ret)
709 }
710
711 LogicalOperator::Expand(mut expand) => {
713 let predicate_vars = self.extract_variables(&predicate);
714
715 let mut introduced_vars = vec![&expand.to_variable];
720 if let Some(ref edge_var) = expand.edge_variable {
721 introduced_vars.push(edge_var);
722 }
723 if let Some(ref path_alias) = expand.path_alias {
724 introduced_vars.push(path_alias);
725 }
726
727 let uses_introduced_vars =
729 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
730
731 if !uses_introduced_vars {
732 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
734 LogicalOperator::Expand(expand)
735 } else {
736 LogicalOperator::Filter(FilterOp {
738 predicate,
739 input: Box::new(LogicalOperator::Expand(expand)),
740 })
741 }
742 }
743
744 LogicalOperator::Join(mut join) => {
746 let predicate_vars = self.extract_variables(&predicate);
747 let left_vars = self.collect_output_variables(&join.left);
748 let right_vars = self.collect_output_variables(&join.right);
749
750 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
751 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
752
753 if uses_left && !uses_right {
754 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
756 LogicalOperator::Join(join)
757 } else if uses_right && !uses_left {
758 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
760 LogicalOperator::Join(join)
761 } else {
762 LogicalOperator::Filter(FilterOp {
764 predicate,
765 input: Box::new(LogicalOperator::Join(join)),
766 })
767 }
768 }
769
770 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
772 predicate,
773 input: Box::new(LogicalOperator::Aggregate(agg)),
774 }),
775
776 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
778 predicate,
779 input: Box::new(LogicalOperator::NodeScan(scan)),
780 }),
781
782 other => LogicalOperator::Filter(FilterOp {
784 predicate,
785 input: Box::new(other),
786 }),
787 }
788 }
789
790 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
792 let mut vars = HashSet::new();
793 Self::collect_output_variables_recursive(op, &mut vars);
794 vars
795 }
796
797 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
799 match op {
800 LogicalOperator::NodeScan(scan) => {
801 vars.insert(scan.variable.clone());
802 }
803 LogicalOperator::EdgeScan(scan) => {
804 vars.insert(scan.variable.clone());
805 }
806 LogicalOperator::Expand(expand) => {
807 vars.insert(expand.to_variable.clone());
808 if let Some(edge_var) = &expand.edge_variable {
809 vars.insert(edge_var.clone());
810 }
811 Self::collect_output_variables_recursive(&expand.input, vars);
812 }
813 LogicalOperator::Filter(filter) => {
814 Self::collect_output_variables_recursive(&filter.input, vars);
815 }
816 LogicalOperator::Project(proj) => {
817 for p in &proj.projections {
818 if let Some(alias) = &p.alias {
819 vars.insert(alias.clone());
820 }
821 }
822 Self::collect_output_variables_recursive(&proj.input, vars);
823 }
824 LogicalOperator::Join(join) => {
825 Self::collect_output_variables_recursive(&join.left, vars);
826 Self::collect_output_variables_recursive(&join.right, vars);
827 }
828 LogicalOperator::Aggregate(agg) => {
829 for expr in &agg.group_by {
830 Self::collect_variables(expr, vars);
831 }
832 for agg_expr in &agg.aggregates {
833 if let Some(alias) = &agg_expr.alias {
834 vars.insert(alias.clone());
835 }
836 }
837 }
838 LogicalOperator::Return(ret) => {
839 Self::collect_output_variables_recursive(&ret.input, vars);
840 }
841 LogicalOperator::Limit(limit) => {
842 Self::collect_output_variables_recursive(&limit.input, vars);
843 }
844 LogicalOperator::Skip(skip) => {
845 Self::collect_output_variables_recursive(&skip.input, vars);
846 }
847 LogicalOperator::Sort(sort) => {
848 Self::collect_output_variables_recursive(&sort.input, vars);
849 }
850 LogicalOperator::Distinct(distinct) => {
851 Self::collect_output_variables_recursive(&distinct.input, vars);
852 }
853 _ => {}
854 }
855 }
856
857 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
859 let mut vars = HashSet::new();
860 Self::collect_variables(expr, &mut vars);
861 vars
862 }
863
864 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
866 match expr {
867 LogicalExpression::Variable(name) => {
868 vars.insert(name.clone());
869 }
870 LogicalExpression::Property { variable, .. } => {
871 vars.insert(variable.clone());
872 }
873 LogicalExpression::Binary { left, right, .. } => {
874 Self::collect_variables(left, vars);
875 Self::collect_variables(right, vars);
876 }
877 LogicalExpression::Unary { operand, .. } => {
878 Self::collect_variables(operand, vars);
879 }
880 LogicalExpression::FunctionCall { args, .. } => {
881 for arg in args {
882 Self::collect_variables(arg, vars);
883 }
884 }
885 LogicalExpression::List(items) => {
886 for item in items {
887 Self::collect_variables(item, vars);
888 }
889 }
890 LogicalExpression::Map(pairs) => {
891 for (_, value) in pairs {
892 Self::collect_variables(value, vars);
893 }
894 }
895 LogicalExpression::IndexAccess { base, index } => {
896 Self::collect_variables(base, vars);
897 Self::collect_variables(index, vars);
898 }
899 LogicalExpression::SliceAccess { base, start, end } => {
900 Self::collect_variables(base, vars);
901 if let Some(s) = start {
902 Self::collect_variables(s, vars);
903 }
904 if let Some(e) = end {
905 Self::collect_variables(e, vars);
906 }
907 }
908 LogicalExpression::Case {
909 operand,
910 when_clauses,
911 else_clause,
912 } => {
913 if let Some(op) = operand {
914 Self::collect_variables(op, vars);
915 }
916 for (cond, result) in when_clauses {
917 Self::collect_variables(cond, vars);
918 Self::collect_variables(result, vars);
919 }
920 if let Some(else_expr) = else_clause {
921 Self::collect_variables(else_expr, vars);
922 }
923 }
924 LogicalExpression::Labels(var)
925 | LogicalExpression::Type(var)
926 | LogicalExpression::Id(var) => {
927 vars.insert(var.clone());
928 }
929 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
930 LogicalExpression::ListComprehension {
931 list_expr,
932 filter_expr,
933 map_expr,
934 ..
935 } => {
936 Self::collect_variables(list_expr, vars);
937 if let Some(filter) = filter_expr {
938 Self::collect_variables(filter, vars);
939 }
940 Self::collect_variables(map_expr, vars);
941 }
942 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
943 }
945 }
946 }
947
948 fn extract_projection_aliases(
950 &self,
951 projections: &[crate::query::plan::Projection],
952 ) -> HashSet<String> {
953 projections.iter().filter_map(|p| p.alias.clone()).collect()
954 }
955}
956
957impl Default for Optimizer {
958 fn default() -> Self {
959 Self::new()
960 }
961}
962
963#[cfg(test)]
964mod tests {
965 use super::*;
966 use crate::query::plan::{
967 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
968 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
969 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
970 };
971 use grafeo_common::types::Value;
972
/// A filter directly above a node scan stays where it is.
#[test]
fn test_optimizer_filter_pushdown_simple() {
    const SHAPE: &str = "Expected Return -> Filter -> NodeScan structure";

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{}", SHAPE);
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{}", SHAPE);
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{}", SHAPE);
    };
    assert_eq!(scan.variable, "n");
}
1015
/// A filter on the source variable moves below the expand.
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
    const SHAPE: &str = "Expected Return -> Expand -> Filter -> NodeScan structure";

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: Some("KNOWS".to_string()),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(scan),
        path_alias: None,
    });
    let source_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: source_age_over_30,
        input: Box::new(expand),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{}", SHAPE);
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("{}", SHAPE);
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("{}", SHAPE);
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{}", SHAPE);
    };
    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1071
/// A filter on the expand's target variable must stay above the expand.
///
/// The predicate check is unconditional: if the optimizer ever changed the
/// predicate's shape, the test fails instead of silently skipping the
/// assertion.
#[test]
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
    const SHAPE: &str = "Expected Return -> Filter -> Expand -> NodeScan structure";

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let expand = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_type: Some("KNOWS".to_string()),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(scan),
        path_alias: None,
    });
    let target_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: target_age_over_30,
        input: Box::new(expand),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{}", SHAPE);
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{}", SHAPE);
    };

    // The predicate must still be the comparison on `b.age`.
    let LogicalExpression::Binary { left, .. } = &filter.predicate else {
        panic!("Expected the filter predicate to remain a binary expression");
    };
    let LogicalExpression::Property { variable, .. } = left.as_ref() else {
        panic!("Expected the left operand to remain a property access");
    };
    assert_eq!(variable, "b");

    let LogicalOperator::Expand(expand) = filter.input.as_ref() else {
        panic!("{}", SHAPE);
    };
    assert!(
        matches!(expand.input.as_ref(), LogicalOperator::NodeScan(_)),
        "{}",
        SHAPE
    );
}
1133
/// `extract_variables` reports the owning variable of a property access.
#[test]
fn test_optimizer_extract_variables() {
    let age = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let threshold = LogicalExpression::Literal(Value::Int64(30));
    let comparison = LogicalExpression::Binary {
        left: Box::new(age),
        op: BinaryOp::Gt,
        right: Box::new(threshold),
    };

    let vars = Optimizer::new().extract_variables(&comparison);

    assert_eq!(vars.len(), 1);
    assert!(vars.contains("n"));
}
1151
/// A default optimizer accepts an empty plan.
#[test]
fn test_optimizer_default() {
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let outcome = Optimizer::default().optimize(empty_plan);
    assert!(outcome.is_ok());
}
1162
/// With filter pushdown disabled the filter remains directly under Return.
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let optimized = optimizer.optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected unchanged structure");
    };
    if !matches!(ret.input.as_ref(), LogicalOperator::Filter(_)) {
        panic!("Expected unchanged structure");
    }
}
1192
1193 #[test]
1194 fn test_optimizer_with_join_reorder_disabled() {
1195 let optimizer = Optimizer::new().with_join_reorder(false);
1196 assert!(
1197 optimizer
1198 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1199 .is_ok()
1200 );
1201 }
1202
1203 #[test]
1204 fn test_optimizer_with_cost_model() {
1205 let cost_model = CostModel::new();
1206 let optimizer = Optimizer::new().with_cost_model(cost_model);
1207 assert!(
1208 optimizer
1209 .cost_model()
1210 .estimate(&LogicalOperator::Empty, 0.0)
1211 .total()
1212 < 0.001
1213 );
1214 }
1215
1216 #[test]
1217 fn test_optimizer_with_cardinality_estimator() {
1218 let mut estimator = CardinalityEstimator::new();
1219 estimator.add_table_stats("Test", TableStats::new(500));
1220 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1221
1222 let scan = LogicalOperator::NodeScan(NodeScanOp {
1223 variable: "n".to_string(),
1224 label: Some("Test".to_string()),
1225 input: None,
1226 });
1227 let plan = LogicalPlan::new(scan);
1228
1229 let cardinality = optimizer.estimate_cardinality(&plan);
1230 assert!((cardinality - 500.0).abs() < 0.001);
1231 }
1232
1233 #[test]
1234 fn test_optimizer_estimate_cost() {
1235 let optimizer = Optimizer::new();
1236 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1237 variable: "n".to_string(),
1238 label: None,
1239 input: None,
1240 }));
1241
1242 let cost = optimizer.estimate_cost(&plan);
1243 assert!(cost.total() > 0.0);
1244 }
1245
1246 #[test]
1249 fn test_filter_pushdown_through_project() {
1250 let optimizer = Optimizer::new();
1251
1252 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1253 predicate: LogicalExpression::Binary {
1254 left: Box::new(LogicalExpression::Property {
1255 variable: "n".to_string(),
1256 property: "age".to_string(),
1257 }),
1258 op: BinaryOp::Gt,
1259 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1260 },
1261 input: Box::new(LogicalOperator::Project(ProjectOp {
1262 projections: vec![Projection {
1263 expression: LogicalExpression::Variable("n".to_string()),
1264 alias: None,
1265 }],
1266 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1267 variable: "n".to_string(),
1268 label: None,
1269 input: None,
1270 })),
1271 })),
1272 }));
1273
1274 let optimized = optimizer.optimize(plan).unwrap();
1275
1276 if let LogicalOperator::Project(proj) = &optimized.root
1278 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1279 {
1280 return;
1281 }
1282 panic!("Expected Project -> Filter structure");
1283 }
1284
1285 #[test]
1286 fn test_filter_not_pushed_through_project_with_alias() {
1287 let optimizer = Optimizer::new();
1288
1289 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1291 predicate: LogicalExpression::Binary {
1292 left: Box::new(LogicalExpression::Variable("x".to_string())),
1293 op: BinaryOp::Gt,
1294 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1295 },
1296 input: Box::new(LogicalOperator::Project(ProjectOp {
1297 projections: vec![Projection {
1298 expression: LogicalExpression::Property {
1299 variable: "n".to_string(),
1300 property: "age".to_string(),
1301 },
1302 alias: Some("x".to_string()),
1303 }],
1304 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1305 variable: "n".to_string(),
1306 label: None,
1307 input: None,
1308 })),
1309 })),
1310 }));
1311
1312 let optimized = optimizer.optimize(plan).unwrap();
1313
1314 if let LogicalOperator::Filter(filter) = &optimized.root
1316 && let LogicalOperator::Project(_) = filter.input.as_ref()
1317 {
1318 return;
1319 }
1320 panic!("Expected Filter -> Project structure");
1321 }
1322
1323 #[test]
1324 fn test_filter_pushdown_through_limit() {
1325 let optimizer = Optimizer::new();
1326
1327 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1328 predicate: LogicalExpression::Literal(Value::Bool(true)),
1329 input: Box::new(LogicalOperator::Limit(LimitOp {
1330 count: 10,
1331 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1332 variable: "n".to_string(),
1333 label: None,
1334 input: None,
1335 })),
1336 })),
1337 }));
1338
1339 let optimized = optimizer.optimize(plan).unwrap();
1340
1341 if let LogicalOperator::Filter(filter) = &optimized.root
1343 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1344 {
1345 return;
1346 }
1347 panic!("Expected Filter -> Limit structure");
1348 }
1349
1350 #[test]
1351 fn test_filter_pushdown_through_sort() {
1352 let optimizer = Optimizer::new();
1353
1354 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1355 predicate: LogicalExpression::Literal(Value::Bool(true)),
1356 input: Box::new(LogicalOperator::Sort(SortOp {
1357 keys: vec![SortKey {
1358 expression: LogicalExpression::Variable("n".to_string()),
1359 order: SortOrder::Ascending,
1360 }],
1361 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362 variable: "n".to_string(),
1363 label: None,
1364 input: None,
1365 })),
1366 })),
1367 }));
1368
1369 let optimized = optimizer.optimize(plan).unwrap();
1370
1371 if let LogicalOperator::Filter(filter) = &optimized.root
1373 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1374 {
1375 return;
1376 }
1377 panic!("Expected Filter -> Sort structure");
1378 }
1379
1380 #[test]
1381 fn test_filter_pushdown_through_distinct() {
1382 let optimizer = Optimizer::new();
1383
1384 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1385 predicate: LogicalExpression::Literal(Value::Bool(true)),
1386 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1387 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1388 variable: "n".to_string(),
1389 label: None,
1390 input: None,
1391 })),
1392 columns: None,
1393 })),
1394 }));
1395
1396 let optimized = optimizer.optimize(plan).unwrap();
1397
1398 if let LogicalOperator::Filter(filter) = &optimized.root
1400 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1401 {
1402 return;
1403 }
1404 panic!("Expected Filter -> Distinct structure");
1405 }
1406
1407 #[test]
1408 fn test_filter_not_pushed_through_aggregate() {
1409 let optimizer = Optimizer::new();
1410
1411 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1412 predicate: LogicalExpression::Binary {
1413 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1414 op: BinaryOp::Gt,
1415 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1416 },
1417 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1418 group_by: vec![],
1419 aggregates: vec![AggregateExpr {
1420 function: AggregateFunction::Count,
1421 expression: None,
1422 distinct: false,
1423 alias: Some("cnt".to_string()),
1424 percentile: None,
1425 }],
1426 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1427 variable: "n".to_string(),
1428 label: None,
1429 input: None,
1430 })),
1431 having: None,
1432 })),
1433 }));
1434
1435 let optimized = optimizer.optimize(plan).unwrap();
1436
1437 if let LogicalOperator::Filter(filter) = &optimized.root
1439 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1440 {
1441 return;
1442 }
1443 panic!("Expected Filter -> Aggregate structure");
1444 }
1445
1446 #[test]
1447 fn test_filter_pushdown_to_left_join_side() {
1448 let optimizer = Optimizer::new();
1449
1450 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1452 predicate: LogicalExpression::Binary {
1453 left: Box::new(LogicalExpression::Property {
1454 variable: "a".to_string(),
1455 property: "age".to_string(),
1456 }),
1457 op: BinaryOp::Gt,
1458 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1459 },
1460 input: Box::new(LogicalOperator::Join(JoinOp {
1461 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1462 variable: "a".to_string(),
1463 label: Some("Person".to_string()),
1464 input: None,
1465 })),
1466 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1467 variable: "b".to_string(),
1468 label: Some("Company".to_string()),
1469 input: None,
1470 })),
1471 join_type: JoinType::Inner,
1472 conditions: vec![],
1473 })),
1474 }));
1475
1476 let optimized = optimizer.optimize(plan).unwrap();
1477
1478 if let LogicalOperator::Join(join) = &optimized.root
1480 && let LogicalOperator::Filter(_) = join.left.as_ref()
1481 {
1482 return;
1483 }
1484 panic!("Expected Join with Filter on left side");
1485 }
1486
1487 #[test]
1488 fn test_filter_pushdown_to_right_join_side() {
1489 let optimizer = Optimizer::new();
1490
1491 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1493 predicate: LogicalExpression::Binary {
1494 left: Box::new(LogicalExpression::Property {
1495 variable: "b".to_string(),
1496 property: "name".to_string(),
1497 }),
1498 op: BinaryOp::Eq,
1499 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1500 },
1501 input: Box::new(LogicalOperator::Join(JoinOp {
1502 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1503 variable: "a".to_string(),
1504 label: Some("Person".to_string()),
1505 input: None,
1506 })),
1507 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1508 variable: "b".to_string(),
1509 label: Some("Company".to_string()),
1510 input: None,
1511 })),
1512 join_type: JoinType::Inner,
1513 conditions: vec![],
1514 })),
1515 }));
1516
1517 let optimized = optimizer.optimize(plan).unwrap();
1518
1519 if let LogicalOperator::Join(join) = &optimized.root
1521 && let LogicalOperator::Filter(_) = join.right.as_ref()
1522 {
1523 return;
1524 }
1525 panic!("Expected Join with Filter on right side");
1526 }
1527
1528 #[test]
1529 fn test_filter_not_pushed_when_uses_both_join_sides() {
1530 let optimizer = Optimizer::new();
1531
1532 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1534 predicate: LogicalExpression::Binary {
1535 left: Box::new(LogicalExpression::Property {
1536 variable: "a".to_string(),
1537 property: "id".to_string(),
1538 }),
1539 op: BinaryOp::Eq,
1540 right: Box::new(LogicalExpression::Property {
1541 variable: "b".to_string(),
1542 property: "a_id".to_string(),
1543 }),
1544 },
1545 input: Box::new(LogicalOperator::Join(JoinOp {
1546 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1547 variable: "a".to_string(),
1548 label: None,
1549 input: None,
1550 })),
1551 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1552 variable: "b".to_string(),
1553 label: None,
1554 input: None,
1555 })),
1556 join_type: JoinType::Inner,
1557 conditions: vec![],
1558 })),
1559 }));
1560
1561 let optimized = optimizer.optimize(plan).unwrap();
1562
1563 if let LogicalOperator::Filter(filter) = &optimized.root
1565 && let LogicalOperator::Join(_) = filter.input.as_ref()
1566 {
1567 return;
1568 }
1569 panic!("Expected Filter -> Join structure");
1570 }
1571
1572 #[test]
1575 fn test_extract_variables_from_variable() {
1576 let optimizer = Optimizer::new();
1577 let expr = LogicalExpression::Variable("x".to_string());
1578 let vars = optimizer.extract_variables(&expr);
1579 assert_eq!(vars.len(), 1);
1580 assert!(vars.contains("x"));
1581 }
1582
1583 #[test]
1584 fn test_extract_variables_from_unary() {
1585 let optimizer = Optimizer::new();
1586 let expr = LogicalExpression::Unary {
1587 op: UnaryOp::Not,
1588 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1589 };
1590 let vars = optimizer.extract_variables(&expr);
1591 assert_eq!(vars.len(), 1);
1592 assert!(vars.contains("x"));
1593 }
1594
1595 #[test]
1596 fn test_extract_variables_from_function_call() {
1597 let optimizer = Optimizer::new();
1598 let expr = LogicalExpression::FunctionCall {
1599 name: "length".to_string(),
1600 args: vec![
1601 LogicalExpression::Variable("a".to_string()),
1602 LogicalExpression::Variable("b".to_string()),
1603 ],
1604 distinct: false,
1605 };
1606 let vars = optimizer.extract_variables(&expr);
1607 assert_eq!(vars.len(), 2);
1608 assert!(vars.contains("a"));
1609 assert!(vars.contains("b"));
1610 }
1611
1612 #[test]
1613 fn test_extract_variables_from_list() {
1614 let optimizer = Optimizer::new();
1615 let expr = LogicalExpression::List(vec![
1616 LogicalExpression::Variable("a".to_string()),
1617 LogicalExpression::Literal(Value::Int64(1)),
1618 LogicalExpression::Variable("b".to_string()),
1619 ]);
1620 let vars = optimizer.extract_variables(&expr);
1621 assert_eq!(vars.len(), 2);
1622 assert!(vars.contains("a"));
1623 assert!(vars.contains("b"));
1624 }
1625
1626 #[test]
1627 fn test_extract_variables_from_map() {
1628 let optimizer = Optimizer::new();
1629 let expr = LogicalExpression::Map(vec![
1630 (
1631 "key1".to_string(),
1632 LogicalExpression::Variable("a".to_string()),
1633 ),
1634 (
1635 "key2".to_string(),
1636 LogicalExpression::Variable("b".to_string()),
1637 ),
1638 ]);
1639 let vars = optimizer.extract_variables(&expr);
1640 assert_eq!(vars.len(), 2);
1641 assert!(vars.contains("a"));
1642 assert!(vars.contains("b"));
1643 }
1644
1645 #[test]
1646 fn test_extract_variables_from_index_access() {
1647 let optimizer = Optimizer::new();
1648 let expr = LogicalExpression::IndexAccess {
1649 base: Box::new(LogicalExpression::Variable("list".to_string())),
1650 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1651 };
1652 let vars = optimizer.extract_variables(&expr);
1653 assert_eq!(vars.len(), 2);
1654 assert!(vars.contains("list"));
1655 assert!(vars.contains("idx"));
1656 }
1657
1658 #[test]
1659 fn test_extract_variables_from_slice_access() {
1660 let optimizer = Optimizer::new();
1661 let expr = LogicalExpression::SliceAccess {
1662 base: Box::new(LogicalExpression::Variable("list".to_string())),
1663 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1664 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1665 };
1666 let vars = optimizer.extract_variables(&expr);
1667 assert_eq!(vars.len(), 3);
1668 assert!(vars.contains("list"));
1669 assert!(vars.contains("s"));
1670 assert!(vars.contains("e"));
1671 }
1672
1673 #[test]
1674 fn test_extract_variables_from_case() {
1675 let optimizer = Optimizer::new();
1676 let expr = LogicalExpression::Case {
1677 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1678 when_clauses: vec![(
1679 LogicalExpression::Literal(Value::Int64(1)),
1680 LogicalExpression::Variable("a".to_string()),
1681 )],
1682 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1683 };
1684 let vars = optimizer.extract_variables(&expr);
1685 assert_eq!(vars.len(), 3);
1686 assert!(vars.contains("x"));
1687 assert!(vars.contains("a"));
1688 assert!(vars.contains("b"));
1689 }
1690
1691 #[test]
1692 fn test_extract_variables_from_labels() {
1693 let optimizer = Optimizer::new();
1694 let expr = LogicalExpression::Labels("n".to_string());
1695 let vars = optimizer.extract_variables(&expr);
1696 assert_eq!(vars.len(), 1);
1697 assert!(vars.contains("n"));
1698 }
1699
1700 #[test]
1701 fn test_extract_variables_from_type() {
1702 let optimizer = Optimizer::new();
1703 let expr = LogicalExpression::Type("e".to_string());
1704 let vars = optimizer.extract_variables(&expr);
1705 assert_eq!(vars.len(), 1);
1706 assert!(vars.contains("e"));
1707 }
1708
1709 #[test]
1710 fn test_extract_variables_from_id() {
1711 let optimizer = Optimizer::new();
1712 let expr = LogicalExpression::Id("n".to_string());
1713 let vars = optimizer.extract_variables(&expr);
1714 assert_eq!(vars.len(), 1);
1715 assert!(vars.contains("n"));
1716 }
1717
1718 #[test]
1719 fn test_extract_variables_from_list_comprehension() {
1720 let optimizer = Optimizer::new();
1721 let expr = LogicalExpression::ListComprehension {
1722 variable: "x".to_string(),
1723 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1724 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1725 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1726 };
1727 let vars = optimizer.extract_variables(&expr);
1728 assert!(vars.contains("items"));
1729 assert!(vars.contains("pred"));
1730 assert!(vars.contains("result"));
1731 }
1732
1733 #[test]
1734 fn test_extract_variables_from_literal_and_parameter() {
1735 let optimizer = Optimizer::new();
1736
1737 let literal = LogicalExpression::Literal(Value::Int64(42));
1738 assert!(optimizer.extract_variables(&literal).is_empty());
1739
1740 let param = LogicalExpression::Parameter("p".to_string());
1741 assert!(optimizer.extract_variables(¶m).is_empty());
1742 }
1743
1744 #[test]
1747 fn test_recursive_filter_pushdown_through_skip() {
1748 let optimizer = Optimizer::new();
1749
1750 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1751 items: vec![ReturnItem {
1752 expression: LogicalExpression::Variable("n".to_string()),
1753 alias: None,
1754 }],
1755 distinct: false,
1756 input: Box::new(LogicalOperator::Filter(FilterOp {
1757 predicate: LogicalExpression::Literal(Value::Bool(true)),
1758 input: Box::new(LogicalOperator::Skip(SkipOp {
1759 count: 5,
1760 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1761 variable: "n".to_string(),
1762 label: None,
1763 input: None,
1764 })),
1765 })),
1766 })),
1767 }));
1768
1769 let optimized = optimizer.optimize(plan).unwrap();
1770
1771 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1773 }
1774
1775 #[test]
1776 fn test_nested_filter_pushdown() {
1777 let optimizer = Optimizer::new();
1778
1779 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1781 items: vec![ReturnItem {
1782 expression: LogicalExpression::Variable("n".to_string()),
1783 alias: None,
1784 }],
1785 distinct: false,
1786 input: Box::new(LogicalOperator::Filter(FilterOp {
1787 predicate: LogicalExpression::Binary {
1788 left: Box::new(LogicalExpression::Property {
1789 variable: "n".to_string(),
1790 property: "x".to_string(),
1791 }),
1792 op: BinaryOp::Gt,
1793 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1794 },
1795 input: Box::new(LogicalOperator::Filter(FilterOp {
1796 predicate: LogicalExpression::Binary {
1797 left: Box::new(LogicalExpression::Property {
1798 variable: "n".to_string(),
1799 property: "y".to_string(),
1800 }),
1801 op: BinaryOp::Lt,
1802 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1803 },
1804 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1805 variable: "n".to_string(),
1806 label: None,
1807 input: None,
1808 })),
1809 })),
1810 })),
1811 }));
1812
1813 let optimized = optimizer.optimize(plan).unwrap();
1814 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1815 }
1816}