1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
/// One join condition lifted out of a join tree while its relations are being collected
/// for reordering.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left side of the condition.
    left_var: String,
    /// Variable referenced by the right side of the condition.
    right_var: String,
    /// The complete left-hand expression of the condition.
    left_expr: LogicalExpression,
    /// The complete right-hand expression of the condition.
    right_expr: LogicalExpression,
}
36
/// A column that some operator or expression in the plan reads from its input.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable binding, identified by its name.
    Variable(String),
    /// A single property of a variable, stored as `(variable, property)`.
    Property(String, String),
}
45
/// Rewrites logical plans using filter pushdown, join reordering and projection pushdown.
///
/// Each pass can be switched on or off individually; cost and cardinality estimates come
/// from the models stored here.
pub struct Optimizer {
    /// Whether `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for plan costing and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for plan costing and join ordering.
    card_estimator: CardinalityEstimator,
}
62
63impl Optimizer {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 enable_filter_pushdown: true,
69 enable_join_reorder: true,
70 enable_projection_pushdown: true,
71 cost_model: CostModel::new(),
72 card_estimator: CardinalityEstimator::new(),
73 }
74 }
75
76 #[must_use]
82 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
83 store.ensure_statistics_fresh();
84 let stats = store.statistics();
85 let estimator = CardinalityEstimator::from_statistics(&stats);
86
87 let avg_fanout = if stats.total_nodes > 0 {
89 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
90 } else {
91 10.0
92 };
93
94 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
96 .edge_types
97 .iter()
98 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
99 .collect();
100
101 Self {
102 enable_filter_pushdown: true,
103 enable_join_reorder: true,
104 enable_projection_pushdown: true,
105 cost_model: CostModel::new()
106 .with_avg_fanout(avg_fanout)
107 .with_edge_type_degrees(edge_type_degrees),
108 card_estimator: estimator,
109 }
110 }
111
112 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
114 self.enable_filter_pushdown = enabled;
115 self
116 }
117
118 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
120 self.enable_join_reorder = enabled;
121 self
122 }
123
124 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
126 self.enable_projection_pushdown = enabled;
127 self
128 }
129
130 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
132 self.cost_model = cost_model;
133 self
134 }
135
136 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
138 self.card_estimator = estimator;
139 self
140 }
141
142 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
144 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
145 self
146 }
147
148 pub fn cost_model(&self) -> &CostModel {
150 &self.cost_model
151 }
152
153 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
155 &self.card_estimator
156 }
157
158 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
160 let cardinality = self.card_estimator.estimate(&plan.root);
161 self.cost_model.estimate(&plan.root, cardinality)
162 }
163
164 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
166 self.card_estimator.estimate(&plan.root)
167 }
168
169 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
175 let mut root = plan.root;
176
177 if self.enable_filter_pushdown {
179 root = self.push_filters_down(root);
180 }
181
182 if self.enable_join_reorder {
183 root = self.reorder_joins(root);
184 }
185
186 if self.enable_projection_pushdown {
187 root = self.push_projections_down(root);
188 }
189
190 Ok(LogicalPlan::new(root))
191 }
192
193 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
200 let required = self.collect_required_columns(&op);
202
203 self.push_projections_recursive(op, &required)
205 }
206
207 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
209 let mut required = HashSet::new();
210 Self::collect_required_recursive(op, &mut required);
211 required
212 }
213
214 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
216 match op {
217 LogicalOperator::Return(ret) => {
218 for item in &ret.items {
219 Self::collect_from_expression(&item.expression, required);
220 }
221 Self::collect_required_recursive(&ret.input, required);
222 }
223 LogicalOperator::Project(proj) => {
224 for p in &proj.projections {
225 Self::collect_from_expression(&p.expression, required);
226 }
227 Self::collect_required_recursive(&proj.input, required);
228 }
229 LogicalOperator::Filter(filter) => {
230 Self::collect_from_expression(&filter.predicate, required);
231 Self::collect_required_recursive(&filter.input, required);
232 }
233 LogicalOperator::Sort(sort) => {
234 for key in &sort.keys {
235 Self::collect_from_expression(&key.expression, required);
236 }
237 Self::collect_required_recursive(&sort.input, required);
238 }
239 LogicalOperator::Aggregate(agg) => {
240 for expr in &agg.group_by {
241 Self::collect_from_expression(expr, required);
242 }
243 for agg_expr in &agg.aggregates {
244 if let Some(ref expr) = agg_expr.expression {
245 Self::collect_from_expression(expr, required);
246 }
247 }
248 if let Some(ref having) = agg.having {
249 Self::collect_from_expression(having, required);
250 }
251 Self::collect_required_recursive(&agg.input, required);
252 }
253 LogicalOperator::Join(join) => {
254 for cond in &join.conditions {
255 Self::collect_from_expression(&cond.left, required);
256 Self::collect_from_expression(&cond.right, required);
257 }
258 Self::collect_required_recursive(&join.left, required);
259 Self::collect_required_recursive(&join.right, required);
260 }
261 LogicalOperator::Expand(expand) => {
262 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
264 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
265 if let Some(ref edge_var) = expand.edge_variable {
266 required.insert(RequiredColumn::Variable(edge_var.clone()));
267 }
268 Self::collect_required_recursive(&expand.input, required);
269 }
270 LogicalOperator::Limit(limit) => {
271 Self::collect_required_recursive(&limit.input, required);
272 }
273 LogicalOperator::Skip(skip) => {
274 Self::collect_required_recursive(&skip.input, required);
275 }
276 LogicalOperator::Distinct(distinct) => {
277 Self::collect_required_recursive(&distinct.input, required);
278 }
279 LogicalOperator::NodeScan(scan) => {
280 required.insert(RequiredColumn::Variable(scan.variable.clone()));
281 }
282 LogicalOperator::EdgeScan(scan) => {
283 required.insert(RequiredColumn::Variable(scan.variable.clone()));
284 }
285 _ => {}
286 }
287 }
288
289 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
291 match expr {
292 LogicalExpression::Variable(var) => {
293 required.insert(RequiredColumn::Variable(var.clone()));
294 }
295 LogicalExpression::Property { variable, property } => {
296 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
297 required.insert(RequiredColumn::Variable(variable.clone()));
298 }
299 LogicalExpression::Binary { left, right, .. } => {
300 Self::collect_from_expression(left, required);
301 Self::collect_from_expression(right, required);
302 }
303 LogicalExpression::Unary { operand, .. } => {
304 Self::collect_from_expression(operand, required);
305 }
306 LogicalExpression::FunctionCall { args, .. } => {
307 for arg in args {
308 Self::collect_from_expression(arg, required);
309 }
310 }
311 LogicalExpression::List(items) => {
312 for item in items {
313 Self::collect_from_expression(item, required);
314 }
315 }
316 LogicalExpression::Map(pairs) => {
317 for (_, value) in pairs {
318 Self::collect_from_expression(value, required);
319 }
320 }
321 LogicalExpression::IndexAccess { base, index } => {
322 Self::collect_from_expression(base, required);
323 Self::collect_from_expression(index, required);
324 }
325 LogicalExpression::SliceAccess { base, start, end } => {
326 Self::collect_from_expression(base, required);
327 if let Some(s) = start {
328 Self::collect_from_expression(s, required);
329 }
330 if let Some(e) = end {
331 Self::collect_from_expression(e, required);
332 }
333 }
334 LogicalExpression::Case {
335 operand,
336 when_clauses,
337 else_clause,
338 } => {
339 if let Some(op) = operand {
340 Self::collect_from_expression(op, required);
341 }
342 for (cond, result) in when_clauses {
343 Self::collect_from_expression(cond, required);
344 Self::collect_from_expression(result, required);
345 }
346 if let Some(else_expr) = else_clause {
347 Self::collect_from_expression(else_expr, required);
348 }
349 }
350 LogicalExpression::Labels(var)
351 | LogicalExpression::Type(var)
352 | LogicalExpression::Id(var) => {
353 required.insert(RequiredColumn::Variable(var.clone()));
354 }
355 LogicalExpression::ListComprehension {
356 list_expr,
357 filter_expr,
358 map_expr,
359 ..
360 } => {
361 Self::collect_from_expression(list_expr, required);
362 if let Some(filter) = filter_expr {
363 Self::collect_from_expression(filter, required);
364 }
365 Self::collect_from_expression(map_expr, required);
366 }
367 _ => {}
368 }
369 }
370
371 fn push_projections_recursive(
373 &self,
374 op: LogicalOperator,
375 required: &HashSet<RequiredColumn>,
376 ) -> LogicalOperator {
377 match op {
378 LogicalOperator::Return(mut ret) => {
379 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
380 LogicalOperator::Return(ret)
381 }
382 LogicalOperator::Project(mut proj) => {
383 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
384 LogicalOperator::Project(proj)
385 }
386 LogicalOperator::Filter(mut filter) => {
387 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
388 LogicalOperator::Filter(filter)
389 }
390 LogicalOperator::Sort(mut sort) => {
391 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
394 LogicalOperator::Sort(sort)
395 }
396 LogicalOperator::Aggregate(mut agg) => {
397 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
398 LogicalOperator::Aggregate(agg)
399 }
400 LogicalOperator::Join(mut join) => {
401 let left_vars = self.collect_output_variables(&join.left);
404 let right_vars = self.collect_output_variables(&join.right);
405
406 let left_required: HashSet<_> = required
408 .iter()
409 .filter(|c| match c {
410 RequiredColumn::Variable(v) => left_vars.contains(v),
411 RequiredColumn::Property(v, _) => left_vars.contains(v),
412 })
413 .cloned()
414 .collect();
415
416 let right_required: HashSet<_> = required
417 .iter()
418 .filter(|c| match c {
419 RequiredColumn::Variable(v) => right_vars.contains(v),
420 RequiredColumn::Property(v, _) => right_vars.contains(v),
421 })
422 .cloned()
423 .collect();
424
425 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
426 join.right =
427 Box::new(self.push_projections_recursive(*join.right, &right_required));
428 LogicalOperator::Join(join)
429 }
430 LogicalOperator::Expand(mut expand) => {
431 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
432 LogicalOperator::Expand(expand)
433 }
434 LogicalOperator::Limit(mut limit) => {
435 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
436 LogicalOperator::Limit(limit)
437 }
438 LogicalOperator::Skip(mut skip) => {
439 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
440 LogicalOperator::Skip(skip)
441 }
442 LogicalOperator::Distinct(mut distinct) => {
443 distinct.input =
444 Box::new(self.push_projections_recursive(*distinct.input, required));
445 LogicalOperator::Distinct(distinct)
446 }
447 other => other,
448 }
449 }
450
451 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
458 let op = self.reorder_joins_recursive(op);
460
461 if let Some((relations, conditions)) = self.extract_join_tree(&op)
463 && relations.len() >= 2
464 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
465 {
466 return optimized;
467 }
468
469 op
470 }
471
472 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
474 match op {
475 LogicalOperator::Return(mut ret) => {
476 ret.input = Box::new(self.reorder_joins(*ret.input));
477 LogicalOperator::Return(ret)
478 }
479 LogicalOperator::Project(mut proj) => {
480 proj.input = Box::new(self.reorder_joins(*proj.input));
481 LogicalOperator::Project(proj)
482 }
483 LogicalOperator::Filter(mut filter) => {
484 filter.input = Box::new(self.reorder_joins(*filter.input));
485 LogicalOperator::Filter(filter)
486 }
487 LogicalOperator::Limit(mut limit) => {
488 limit.input = Box::new(self.reorder_joins(*limit.input));
489 LogicalOperator::Limit(limit)
490 }
491 LogicalOperator::Skip(mut skip) => {
492 skip.input = Box::new(self.reorder_joins(*skip.input));
493 LogicalOperator::Skip(skip)
494 }
495 LogicalOperator::Sort(mut sort) => {
496 sort.input = Box::new(self.reorder_joins(*sort.input));
497 LogicalOperator::Sort(sort)
498 }
499 LogicalOperator::Distinct(mut distinct) => {
500 distinct.input = Box::new(self.reorder_joins(*distinct.input));
501 LogicalOperator::Distinct(distinct)
502 }
503 LogicalOperator::Aggregate(mut agg) => {
504 agg.input = Box::new(self.reorder_joins(*agg.input));
505 LogicalOperator::Aggregate(agg)
506 }
507 LogicalOperator::Expand(mut expand) => {
508 expand.input = Box::new(self.reorder_joins(*expand.input));
509 LogicalOperator::Expand(expand)
510 }
511 other => other,
513 }
514 }
515
516 fn extract_join_tree(
520 &self,
521 op: &LogicalOperator,
522 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
523 let mut relations = Vec::new();
524 let mut join_conditions = Vec::new();
525
526 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
527 return None;
528 }
529
530 if relations.len() < 2 {
531 return None;
532 }
533
534 Some((relations, join_conditions))
535 }
536
537 fn collect_join_tree(
541 &self,
542 op: &LogicalOperator,
543 relations: &mut Vec<(String, LogicalOperator)>,
544 conditions: &mut Vec<JoinInfo>,
545 ) -> bool {
546 match op {
547 LogicalOperator::Join(join) => {
548 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
550 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
551
552 for cond in &join.conditions {
554 if let (Some(left_var), Some(right_var)) = (
555 self.extract_variable_from_expr(&cond.left),
556 self.extract_variable_from_expr(&cond.right),
557 ) {
558 conditions.push(JoinInfo {
559 left_var,
560 right_var,
561 left_expr: cond.left.clone(),
562 right_expr: cond.right.clone(),
563 });
564 }
565 }
566
567 left_ok && right_ok
568 }
569 LogicalOperator::NodeScan(scan) => {
570 relations.push((scan.variable.clone(), op.clone()));
571 true
572 }
573 LogicalOperator::EdgeScan(scan) => {
574 relations.push((scan.variable.clone(), op.clone()));
575 true
576 }
577 LogicalOperator::Filter(filter) => {
578 self.collect_join_tree(&filter.input, relations, conditions)
580 }
581 LogicalOperator::Expand(expand) => {
582 relations.push((expand.to_variable.clone(), op.clone()));
585 true
586 }
587 _ => false,
588 }
589 }
590
591 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
593 match expr {
594 LogicalExpression::Variable(v) => Some(v.clone()),
595 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
596 _ => None,
597 }
598 }
599
600 fn optimize_join_order(
602 &self,
603 relations: &[(String, LogicalOperator)],
604 conditions: &[JoinInfo],
605 ) -> Option<LogicalOperator> {
606 use join_order::{DPccp, JoinGraphBuilder};
607
608 let mut builder = JoinGraphBuilder::new();
610
611 for (var, relation) in relations {
612 builder.add_relation(var, relation.clone());
613 }
614
615 for cond in conditions {
616 builder.add_join_condition(
617 &cond.left_var,
618 &cond.right_var,
619 cond.left_expr.clone(),
620 cond.right_expr.clone(),
621 );
622 }
623
624 let graph = builder.build();
625
626 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
628 let plan = dpccp.optimize()?;
629
630 Some(plan.operator)
631 }
632
633 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
638 match op {
639 LogicalOperator::Filter(filter) => {
641 let optimized_input = self.push_filters_down(*filter.input);
642 self.try_push_filter_into(filter.predicate, optimized_input)
643 }
644 LogicalOperator::Return(mut ret) => {
646 ret.input = Box::new(self.push_filters_down(*ret.input));
647 LogicalOperator::Return(ret)
648 }
649 LogicalOperator::Project(mut proj) => {
650 proj.input = Box::new(self.push_filters_down(*proj.input));
651 LogicalOperator::Project(proj)
652 }
653 LogicalOperator::Limit(mut limit) => {
654 limit.input = Box::new(self.push_filters_down(*limit.input));
655 LogicalOperator::Limit(limit)
656 }
657 LogicalOperator::Skip(mut skip) => {
658 skip.input = Box::new(self.push_filters_down(*skip.input));
659 LogicalOperator::Skip(skip)
660 }
661 LogicalOperator::Sort(mut sort) => {
662 sort.input = Box::new(self.push_filters_down(*sort.input));
663 LogicalOperator::Sort(sort)
664 }
665 LogicalOperator::Distinct(mut distinct) => {
666 distinct.input = Box::new(self.push_filters_down(*distinct.input));
667 LogicalOperator::Distinct(distinct)
668 }
669 LogicalOperator::Expand(mut expand) => {
670 expand.input = Box::new(self.push_filters_down(*expand.input));
671 LogicalOperator::Expand(expand)
672 }
673 LogicalOperator::Join(mut join) => {
674 join.left = Box::new(self.push_filters_down(*join.left));
675 join.right = Box::new(self.push_filters_down(*join.right));
676 LogicalOperator::Join(join)
677 }
678 LogicalOperator::Aggregate(mut agg) => {
679 agg.input = Box::new(self.push_filters_down(*agg.input));
680 LogicalOperator::Aggregate(agg)
681 }
682 other => other,
684 }
685 }
686
687 fn try_push_filter_into(
692 &self,
693 predicate: LogicalExpression,
694 op: LogicalOperator,
695 ) -> LogicalOperator {
696 match op {
697 LogicalOperator::Project(mut proj) => {
699 let predicate_vars = self.extract_variables(&predicate);
700 let computed_vars = self.extract_projection_aliases(&proj.projections);
701
702 if predicate_vars.is_disjoint(&computed_vars) {
704 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
705 LogicalOperator::Project(proj)
706 } else {
707 LogicalOperator::Filter(FilterOp {
709 predicate,
710 input: Box::new(LogicalOperator::Project(proj)),
711 })
712 }
713 }
714
715 LogicalOperator::Return(mut ret) => {
717 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
718 LogicalOperator::Return(ret)
719 }
720
721 LogicalOperator::Expand(mut expand) => {
723 let predicate_vars = self.extract_variables(&predicate);
724
725 let mut introduced_vars = vec![&expand.to_variable];
730 if let Some(ref edge_var) = expand.edge_variable {
731 introduced_vars.push(edge_var);
732 }
733 if let Some(ref path_alias) = expand.path_alias {
734 introduced_vars.push(path_alias);
735 }
736
737 let uses_introduced_vars =
739 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
740
741 if !uses_introduced_vars {
742 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
744 LogicalOperator::Expand(expand)
745 } else {
746 LogicalOperator::Filter(FilterOp {
748 predicate,
749 input: Box::new(LogicalOperator::Expand(expand)),
750 })
751 }
752 }
753
754 LogicalOperator::Join(mut join) => {
756 let predicate_vars = self.extract_variables(&predicate);
757 let left_vars = self.collect_output_variables(&join.left);
758 let right_vars = self.collect_output_variables(&join.right);
759
760 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
761 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
762
763 if uses_left && !uses_right {
764 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
766 LogicalOperator::Join(join)
767 } else if uses_right && !uses_left {
768 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
770 LogicalOperator::Join(join)
771 } else {
772 LogicalOperator::Filter(FilterOp {
774 predicate,
775 input: Box::new(LogicalOperator::Join(join)),
776 })
777 }
778 }
779
780 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
782 predicate,
783 input: Box::new(LogicalOperator::Aggregate(agg)),
784 }),
785
786 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
788 predicate,
789 input: Box::new(LogicalOperator::NodeScan(scan)),
790 }),
791
792 other => LogicalOperator::Filter(FilterOp {
794 predicate,
795 input: Box::new(other),
796 }),
797 }
798 }
799
800 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
802 let mut vars = HashSet::new();
803 Self::collect_output_variables_recursive(op, &mut vars);
804 vars
805 }
806
807 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
809 match op {
810 LogicalOperator::NodeScan(scan) => {
811 vars.insert(scan.variable.clone());
812 }
813 LogicalOperator::EdgeScan(scan) => {
814 vars.insert(scan.variable.clone());
815 }
816 LogicalOperator::Expand(expand) => {
817 vars.insert(expand.to_variable.clone());
818 if let Some(edge_var) = &expand.edge_variable {
819 vars.insert(edge_var.clone());
820 }
821 Self::collect_output_variables_recursive(&expand.input, vars);
822 }
823 LogicalOperator::Filter(filter) => {
824 Self::collect_output_variables_recursive(&filter.input, vars);
825 }
826 LogicalOperator::Project(proj) => {
827 for p in &proj.projections {
828 if let Some(alias) = &p.alias {
829 vars.insert(alias.clone());
830 }
831 }
832 Self::collect_output_variables_recursive(&proj.input, vars);
833 }
834 LogicalOperator::Join(join) => {
835 Self::collect_output_variables_recursive(&join.left, vars);
836 Self::collect_output_variables_recursive(&join.right, vars);
837 }
838 LogicalOperator::Aggregate(agg) => {
839 for expr in &agg.group_by {
840 Self::collect_variables(expr, vars);
841 }
842 for agg_expr in &agg.aggregates {
843 if let Some(alias) = &agg_expr.alias {
844 vars.insert(alias.clone());
845 }
846 }
847 }
848 LogicalOperator::Return(ret) => {
849 Self::collect_output_variables_recursive(&ret.input, vars);
850 }
851 LogicalOperator::Limit(limit) => {
852 Self::collect_output_variables_recursive(&limit.input, vars);
853 }
854 LogicalOperator::Skip(skip) => {
855 Self::collect_output_variables_recursive(&skip.input, vars);
856 }
857 LogicalOperator::Sort(sort) => {
858 Self::collect_output_variables_recursive(&sort.input, vars);
859 }
860 LogicalOperator::Distinct(distinct) => {
861 Self::collect_output_variables_recursive(&distinct.input, vars);
862 }
863 _ => {}
864 }
865 }
866
867 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
869 let mut vars = HashSet::new();
870 Self::collect_variables(expr, &mut vars);
871 vars
872 }
873
874 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
876 match expr {
877 LogicalExpression::Variable(name) => {
878 vars.insert(name.clone());
879 }
880 LogicalExpression::Property { variable, .. } => {
881 vars.insert(variable.clone());
882 }
883 LogicalExpression::Binary { left, right, .. } => {
884 Self::collect_variables(left, vars);
885 Self::collect_variables(right, vars);
886 }
887 LogicalExpression::Unary { operand, .. } => {
888 Self::collect_variables(operand, vars);
889 }
890 LogicalExpression::FunctionCall { args, .. } => {
891 for arg in args {
892 Self::collect_variables(arg, vars);
893 }
894 }
895 LogicalExpression::List(items) => {
896 for item in items {
897 Self::collect_variables(item, vars);
898 }
899 }
900 LogicalExpression::Map(pairs) => {
901 for (_, value) in pairs {
902 Self::collect_variables(value, vars);
903 }
904 }
905 LogicalExpression::IndexAccess { base, index } => {
906 Self::collect_variables(base, vars);
907 Self::collect_variables(index, vars);
908 }
909 LogicalExpression::SliceAccess { base, start, end } => {
910 Self::collect_variables(base, vars);
911 if let Some(s) = start {
912 Self::collect_variables(s, vars);
913 }
914 if let Some(e) = end {
915 Self::collect_variables(e, vars);
916 }
917 }
918 LogicalExpression::Case {
919 operand,
920 when_clauses,
921 else_clause,
922 } => {
923 if let Some(op) = operand {
924 Self::collect_variables(op, vars);
925 }
926 for (cond, result) in when_clauses {
927 Self::collect_variables(cond, vars);
928 Self::collect_variables(result, vars);
929 }
930 if let Some(else_expr) = else_clause {
931 Self::collect_variables(else_expr, vars);
932 }
933 }
934 LogicalExpression::Labels(var)
935 | LogicalExpression::Type(var)
936 | LogicalExpression::Id(var) => {
937 vars.insert(var.clone());
938 }
939 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
940 LogicalExpression::ListComprehension {
941 list_expr,
942 filter_expr,
943 map_expr,
944 ..
945 } => {
946 Self::collect_variables(list_expr, vars);
947 if let Some(filter) = filter_expr {
948 Self::collect_variables(filter, vars);
949 }
950 Self::collect_variables(map_expr, vars);
951 }
952 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
953 }
955 }
956 }
957
958 fn extract_projection_aliases(
960 &self,
961 projections: &[crate::query::plan::Projection],
962 ) -> HashSet<String> {
963 projections.iter().filter_map(|p| p.alias.clone()).collect()
964 }
965}
966
967impl Default for Optimizer {
968 fn default() -> Self {
969 Self::new()
970 }
971}
972
973#[cfg(test)]
974mod tests {
975 use super::*;
976 use crate::query::plan::{
977 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
978 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
979 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
980 };
981 use grafeo_common::types::Value;
982
983 #[test]
984 fn test_optimizer_filter_pushdown_simple() {
985 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
990 items: vec![ReturnItem {
991 expression: LogicalExpression::Variable("n".to_string()),
992 alias: None,
993 }],
994 distinct: false,
995 input: Box::new(LogicalOperator::Filter(FilterOp {
996 predicate: LogicalExpression::Binary {
997 left: Box::new(LogicalExpression::Property {
998 variable: "n".to_string(),
999 property: "age".to_string(),
1000 }),
1001 op: BinaryOp::Gt,
1002 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1003 },
1004 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1005 variable: "n".to_string(),
1006 label: Some("Person".to_string()),
1007 input: None,
1008 })),
1009 })),
1010 }));
1011
1012 let optimizer = Optimizer::new();
1013 let optimized = optimizer.optimize(plan).unwrap();
1014
1015 if let LogicalOperator::Return(ret) = &optimized.root
1017 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1018 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1019 {
1020 assert_eq!(scan.variable, "n");
1021 return;
1022 }
1023 panic!("Expected Return -> Filter -> NodeScan structure");
1024 }
1025
1026 #[test]
1027 fn test_optimizer_filter_pushdown_through_expand() {
1028 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1032 items: vec![ReturnItem {
1033 expression: LogicalExpression::Variable("b".to_string()),
1034 alias: None,
1035 }],
1036 distinct: false,
1037 input: Box::new(LogicalOperator::Filter(FilterOp {
1038 predicate: LogicalExpression::Binary {
1039 left: Box::new(LogicalExpression::Property {
1040 variable: "a".to_string(),
1041 property: "age".to_string(),
1042 }),
1043 op: BinaryOp::Gt,
1044 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1045 },
1046 input: Box::new(LogicalOperator::Expand(ExpandOp {
1047 from_variable: "a".to_string(),
1048 to_variable: "b".to_string(),
1049 edge_variable: None,
1050 direction: ExpandDirection::Outgoing,
1051 edge_type: Some("KNOWS".to_string()),
1052 min_hops: 1,
1053 max_hops: Some(1),
1054 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1055 variable: "a".to_string(),
1056 label: Some("Person".to_string()),
1057 input: None,
1058 })),
1059 path_alias: None,
1060 })),
1061 })),
1062 }));
1063
1064 let optimizer = Optimizer::new();
1065 let optimized = optimizer.optimize(plan).unwrap();
1066
1067 if let LogicalOperator::Return(ret) = &optimized.root
1070 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1071 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1072 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1073 {
1074 assert_eq!(scan.variable, "a");
1075 assert_eq!(expand.from_variable, "a");
1076 assert_eq!(expand.to_variable, "b");
1077 return;
1078 }
1079 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1080 }
1081
1082 #[test]
1083 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1084 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1088 items: vec![ReturnItem {
1089 expression: LogicalExpression::Variable("a".to_string()),
1090 alias: None,
1091 }],
1092 distinct: false,
1093 input: Box::new(LogicalOperator::Filter(FilterOp {
1094 predicate: LogicalExpression::Binary {
1095 left: Box::new(LogicalExpression::Property {
1096 variable: "b".to_string(),
1097 property: "age".to_string(),
1098 }),
1099 op: BinaryOp::Gt,
1100 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1101 },
1102 input: Box::new(LogicalOperator::Expand(ExpandOp {
1103 from_variable: "a".to_string(),
1104 to_variable: "b".to_string(),
1105 edge_variable: None,
1106 direction: ExpandDirection::Outgoing,
1107 edge_type: Some("KNOWS".to_string()),
1108 min_hops: 1,
1109 max_hops: Some(1),
1110 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1111 variable: "a".to_string(),
1112 label: Some("Person".to_string()),
1113 input: None,
1114 })),
1115 path_alias: None,
1116 })),
1117 })),
1118 }));
1119
1120 let optimizer = Optimizer::new();
1121 let optimized = optimizer.optimize(plan).unwrap();
1122
1123 if let LogicalOperator::Return(ret) = &optimized.root
1126 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1127 {
1128 if let LogicalExpression::Binary { left, .. } = &filter.predicate
1130 && let LogicalExpression::Property { variable, .. } = left.as_ref()
1131 {
1132 assert_eq!(variable, "b");
1133 }
1134
1135 if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1136 && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1137 {
1138 return;
1139 }
1140 }
1141 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1142 }
1143
1144 #[test]
1145 fn test_optimizer_extract_variables() {
1146 let optimizer = Optimizer::new();
1147
1148 let expr = LogicalExpression::Binary {
1149 left: Box::new(LogicalExpression::Property {
1150 variable: "n".to_string(),
1151 property: "age".to_string(),
1152 }),
1153 op: BinaryOp::Gt,
1154 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1155 };
1156
1157 let vars = optimizer.extract_variables(&expr);
1158 assert_eq!(vars.len(), 1);
1159 assert!(vars.contains("n"));
1160 }
1161
1162 #[test]
1165 fn test_optimizer_default() {
1166 let optimizer = Optimizer::default();
1167 let plan = LogicalPlan::new(LogicalOperator::Empty);
1169 let result = optimizer.optimize(plan);
1170 assert!(result.is_ok());
1171 }
1172
/// With filter pushdown switched off, a `Return -> Filter -> NodeScan` plan keeps
/// its Filter directly beneath the Return.
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    // Assemble the plan from the leaf upwards.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
    });
    let returned = ReturnItem {
        expression: LogicalExpression::Variable(String::from("n")),
        alias: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![returned],
        distinct: false,
        input: Box::new(always_true),
    }));

    let opt = Optimizer::new().with_filter_pushdown(false);
    let result = opt.optimize(plan).unwrap();

    // Accept only Return whose direct input is a Filter.
    match &result.root {
        LogicalOperator::Return(ret)
            if matches!(ret.input.as_ref(), LogicalOperator::Filter(_)) => {}
        _ => panic!("Expected unchanged structure"),
    }
}
1202
/// Turning join reordering off still lets the empty plan optimize successfully.
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    let opt = Optimizer::new().with_join_reorder(false);
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let outcome = opt.optimize(empty_plan);
    assert!(outcome.is_ok());
}
1212
/// A cost model installed via `with_cost_model` is reachable through `cost_model()`
/// and prices the `Empty` operator (second argument `0.0`) at effectively zero.
#[test]
fn test_optimizer_with_cost_model() {
    let opt = Optimizer::new().with_cost_model(CostModel::new());

    let empty_total = opt
        .cost_model()
        .estimate(&LogicalOperator::Empty, 0.0)
        .total();

    // Tolerance comparison rather than exact float equality.
    assert!(empty_total < 0.001);
}
1225
/// A custom cardinality estimator with `TableStats::new(500)` registered for label
/// "Test" makes a scan over that label estimate to 500.
#[test]
fn test_optimizer_with_cardinality_estimator() {
    let mut custom = CardinalityEstimator::new();
    custom.add_table_stats("Test", TableStats::new(500));
    let opt = Optimizer::new().with_cardinality_estimator(custom);

    // Plan consisting solely of a scan over the label that has stats.
    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: Some(String::from("Test")),
        input: None,
    }));

    let estimated = opt.estimate_cardinality(&plan);
    let deviation = (estimated - 500.0).abs();
    assert!(deviation < 0.001);
}
1242
/// Costing an unlabeled node scan with the default optimizer yields a positive total.
#[test]
fn test_optimizer_estimate_cost() {
    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(unlabeled_scan);

    let opt = Optimizer::new();
    let total = opt.estimate_cost(&plan).total();
    assert!(total > 0.0);
}
1255
/// A filter on `n.age > 30` sitting above a Project that passes `n` through
/// un-aliased ends up beneath the Project after optimization.
#[test]
fn test_filter_pushdown_through_project() {
    // Input shape: Filter(n.age > 30) -> Project(n) -> NodeScan(n).
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable(String::from("n")),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(project),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    // Expected output shape: Project at the root with a Filter as its input.
    let LogicalOperator::Project(proj) = &result.root else {
        panic!("Expected Project -> Filter structure");
    };
    if !matches!(proj.input.as_ref(), LogicalOperator::Filter(_)) {
        panic!("Expected Project -> Filter structure");
    }
}
1294
/// A filter that references the alias `x` introduced by a Project stays above
/// that Project after optimization.
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    // Input shape: Filter(x > 30) -> Project(n.age AS x) -> NodeScan(n).
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let aliased = Projection {
        expression: LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("age"),
        },
        alias: Some(String::from("x")),
    };
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![aliased],
        input: Box::new(scan),
    });
    let x_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable(String::from("x"))),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_over_30,
        input: Box::new(project),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    // The Filter must still be the root, with the Project directly below it.
    match &result.root {
        LogicalOperator::Filter(top)
            if matches!(top.input.as_ref(), LogicalOperator::Project(_)) => {}
        _ => panic!("Expected Filter -> Project structure"),
    }
}
1332
/// A constant-true filter placed above `Limit(10)` is still directly above the
/// Limit after optimization.
#[test]
fn test_filter_pushdown_through_limit() {
    // Input shape: Filter(true) -> Limit(10) -> NodeScan(n).
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let limited = LogicalOperator::Limit(LimitOp {
        count: 10,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(limited),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    let LogicalOperator::Filter(top) = &result.root else {
        panic!("Expected Filter -> Limit structure");
    };
    if !matches!(top.input.as_ref(), LogicalOperator::Limit(_)) {
        panic!("Expected Filter -> Limit structure");
    }
}
1359
/// A constant-true filter placed above an ascending Sort on `n` is still directly
/// above the Sort after optimization.
#[test]
fn test_filter_pushdown_through_sort() {
    // Input shape: Filter(true) -> Sort(n ASC) -> NodeScan(n).
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let by_n_ascending = SortKey {
        expression: LogicalExpression::Variable(String::from("n")),
        order: SortOrder::Ascending,
    };
    let sorted = LogicalOperator::Sort(SortOp {
        keys: vec![by_n_ascending],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(sorted),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    match &result.root {
        LogicalOperator::Filter(top)
            if matches!(top.input.as_ref(), LogicalOperator::Sort(_)) => {}
        _ => panic!("Expected Filter -> Sort structure"),
    }
}
1389
/// A constant-true filter placed above a Distinct is still directly above the
/// Distinct after optimization.
#[test]
fn test_filter_pushdown_through_distinct() {
    // Input shape: Filter(true) -> Distinct -> NodeScan(n).
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let deduplicated = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(deduplicated),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    let LogicalOperator::Filter(top) = &result.root else {
        panic!("Expected Filter -> Distinct structure");
    };
    if !matches!(top.input.as_ref(), LogicalOperator::Distinct(_)) {
        panic!("Expected Filter -> Distinct structure");
    }
}
1416
/// A filter on the aggregate alias `cnt` stays above the Aggregate that
/// produces it.
#[test]
fn test_filter_not_pushed_through_aggregate() {
    // Input shape: Filter(cnt > 10) -> Aggregate(count AS cnt) -> NodeScan(n).
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let count_as_cnt = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        distinct: false,
        alias: Some(String::from("cnt")),
        percentile: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_as_cnt],
        input: Box::new(scan),
        having: None,
    });
    let cnt_over_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable(String::from("cnt"))),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cnt_over_10,
        input: Box::new(aggregate),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    match &result.root {
        LogicalOperator::Filter(top)
            if matches!(top.input.as_ref(), LogicalOperator::Aggregate(_)) => {}
        _ => panic!("Expected Filter -> Aggregate structure"),
    }
}
1455
/// A filter that only mentions `a` (the left input of an inner join) is moved
/// onto the join's left side.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    // Join inputs: Person scan bound to `a`, Company scan bound to `b`.
    let persons = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("b"),
        label: Some(String::from("Company")),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(persons),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });

    // Predicate `a.age > 30` touches the left side only.
    let left_only = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("a"),
            property: String::from("age"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: left_only,
        input: Box::new(join),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    // The Join must now be the root, with a Filter as its left child.
    let LogicalOperator::Join(root_join) = &result.root else {
        panic!("Expected Join with Filter on left side");
    };
    if !matches!(root_join.left.as_ref(), LogicalOperator::Filter(_)) {
        panic!("Expected Join with Filter on left side");
    }
}
1496
/// A filter that only mentions `b` (the right input of an inner join) is moved
/// onto the join's right side.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    // Join inputs: Person scan bound to `a`, Company scan bound to `b`.
    let persons = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: Some(String::from("Person")),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("b"),
        label: Some(String::from("Company")),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(persons),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });

    // Predicate `b.name = "Acme"` touches the right side only.
    let right_only = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("b"),
            property: String::from("name"),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: right_only,
        input: Box::new(join),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    // The Join must now be the root, with a Filter as its right child.
    match &result.root {
        LogicalOperator::Join(root_join)
            if matches!(root_join.right.as_ref(), LogicalOperator::Filter(_)) => {}
        _ => panic!("Expected Join with Filter on right side"),
    }
}
1537
/// A filter comparing `a.id` with `b.a_id` references both join inputs and
/// therefore remains above the Join.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    // Two unlabeled scans joined with no explicit conditions.
    let scan_a = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("a"),
        label: None,
        input: None,
    });
    let scan_b = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("b"),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(scan_a),
        right: Box::new(scan_b),
        join_type: JoinType::Inner,
        conditions: vec![],
    });

    // Predicate `a.id = b.a_id` spans both sides of the join.
    let a_id = LogicalExpression::Property {
        variable: String::from("a"),
        property: String::from("id"),
    };
    let b_a_id = LogicalExpression::Property {
        variable: String::from("b"),
        property: String::from("a_id"),
    };
    let cross_side = LogicalExpression::Binary {
        left: Box::new(a_id),
        op: BinaryOp::Eq,
        right: Box::new(b_a_id),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side,
        input: Box::new(join),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    let LogicalOperator::Filter(top) = &result.root else {
        panic!("Expected Filter -> Join structure");
    };
    if !matches!(top.input.as_ref(), LogicalOperator::Join(_)) {
        panic!("Expected Filter -> Join structure");
    }
}
1581
/// A bare `Variable("x")` expression yields exactly the name `x`.
#[test]
fn test_extract_variables_from_variable() {
    let bare_ref = LogicalExpression::Variable(String::from("x"));
    let opt = Optimizer::new();
    let names = opt.extract_variables(&bare_ref);
    assert_eq!(names.len(), 1);
    assert!(names.contains("x"));
}
1592
/// Variables are collected from the operand of a unary `Not`.
#[test]
fn test_extract_variables_from_unary() {
    let operand = LogicalExpression::Variable(String::from("x"));
    let negated = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let opt = Optimizer::new();
    let names = opt.extract_variables(&negated);
    assert_eq!(names.len(), 1);
    assert!(names.contains("x"));
}
1604
/// Variables are collected from every argument of a function call.
#[test]
fn test_extract_variables_from_function_call() {
    // `length(a, b)` with both arguments being plain variable references.
    let call_args: Vec<LogicalExpression> = ["a", "b"]
        .into_iter()
        .map(|name| LogicalExpression::Variable(name.to_string()))
        .collect();
    let call = LogicalExpression::FunctionCall {
        name: String::from("length"),
        args: call_args,
        distinct: false,
    };

    let opt = Optimizer::new();
    let names = opt.extract_variables(&call);
    assert_eq!(names.len(), 2);
    assert!(names.contains("a"));
    assert!(names.contains("b"));
}
1621
/// Variables are collected from list elements; the literal element is not counted.
#[test]
fn test_extract_variables_from_list() {
    // List `[a, 1, b]`: two variable references around one literal.
    let elements = vec![
        LogicalExpression::Variable(String::from("a")),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable(String::from("b")),
    ];
    let list = LogicalExpression::List(elements);

    let opt = Optimizer::new();
    let names = opt.extract_variables(&list);
    assert_eq!(names.len(), 2);
    assert!(names.contains("a"));
    assert!(names.contains("b"));
}
1635
/// Variables are collected from the values of a map expression.
#[test]
fn test_extract_variables_from_map() {
    // Map `{key1: a, key2: b}`.
    let entries: Vec<(String, LogicalExpression)> = [("key1", "a"), ("key2", "b")]
        .into_iter()
        .map(|(key, var)| (key.to_string(), LogicalExpression::Variable(var.to_string())))
        .collect();
    let map = LogicalExpression::Map(entries);

    let opt = Optimizer::new();
    let names = opt.extract_variables(&map);
    assert_eq!(names.len(), 2);
    assert!(names.contains("a"));
    assert!(names.contains("b"));
}
1654
/// Variables are collected from both the base and the index of an index access.
#[test]
fn test_extract_variables_from_index_access() {
    // Expression `list[idx]`.
    let base = LogicalExpression::Variable(String::from("list"));
    let index = LogicalExpression::Variable(String::from("idx"));
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let opt = Optimizer::new();
    let names = opt.extract_variables(&access);
    assert_eq!(names.len(), 2);
    assert!(names.contains("list"));
    assert!(names.contains("idx"));
}
1667
/// Variables are collected from the base and from both bounds of a slice access.
#[test]
fn test_extract_variables_from_slice_access() {
    // Expression `list[s..e]` with both bounds present.
    let base = LogicalExpression::Variable(String::from("list"));
    let lower = LogicalExpression::Variable(String::from("s"));
    let upper = LogicalExpression::Variable(String::from("e"));
    let slice = LogicalExpression::SliceAccess {
        base: Box::new(base),
        start: Some(Box::new(lower)),
        end: Some(Box::new(upper)),
    };

    let opt = Optimizer::new();
    let names = opt.extract_variables(&slice);
    assert_eq!(names.len(), 3);
    assert!(names.contains("list"));
    assert!(names.contains("s"));
    assert!(names.contains("e"));
}
1682
/// Variables are collected from a CASE expression's operand, its WHEN result,
/// and its ELSE branch.
#[test]
fn test_extract_variables_from_case() {
    // `CASE x WHEN 1 THEN a ELSE b END`
    let subject = LogicalExpression::Variable(String::from("x"));
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable(String::from("a")),
    );
    let fallback = LogicalExpression::Variable(String::from("b"));
    let case = LogicalExpression::Case {
        operand: Some(Box::new(subject)),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(fallback)),
    };

    let opt = Optimizer::new();
    let names = opt.extract_variables(&case);
    assert_eq!(names.len(), 3);
    assert!(names.contains("x"));
    assert!(names.contains("a"));
    assert!(names.contains("b"));
}
1700
/// A `Labels("n")` expression yields exactly the name `n`.
#[test]
fn test_extract_variables_from_labels() {
    let labels_of_n = LogicalExpression::Labels(String::from("n"));
    let opt = Optimizer::new();
    let names = opt.extract_variables(&labels_of_n);
    assert_eq!(names.len(), 1);
    assert!(names.contains("n"));
}
1709
/// A `Type("e")` expression yields exactly the name `e`.
#[test]
fn test_extract_variables_from_type() {
    let type_of_e = LogicalExpression::Type(String::from("e"));
    let opt = Optimizer::new();
    let names = opt.extract_variables(&type_of_e);
    assert_eq!(names.len(), 1);
    assert!(names.contains("e"));
}
1718
/// An `Id("n")` expression yields exactly the name `n`.
#[test]
fn test_extract_variables_from_id() {
    let id_of_n = LogicalExpression::Id(String::from("n"));
    let opt = Optimizer::new();
    let names = opt.extract_variables(&id_of_n);
    assert_eq!(names.len(), 1);
    assert!(names.contains("n"));
}
1727
/// Variables referenced by a list comprehension's source list, filter, and map
/// expressions are all collected.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let source = LogicalExpression::Variable(String::from("items"));
    let keep_if = LogicalExpression::Variable(String::from("pred"));
    let mapped = LogicalExpression::Variable(String::from("result"));
    let comprehension = LogicalExpression::ListComprehension {
        variable: String::from("x"),
        list_expr: Box::new(source),
        filter_expr: Some(Box::new(keep_if)),
        map_expr: Box::new(mapped),
    };

    let opt = Optimizer::new();
    let names = opt.extract_variables(&comprehension);

    // Only membership is checked; the total count (and whether the bound
    // variable `x` is included) is deliberately left unasserted.
    assert!(names.contains("items"));
    assert!(names.contains("pred"));
    assert!(names.contains("result"));
}
1742
/// Neither a literal nor a query parameter contributes any variable names.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    // A constant has no variable references.
    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // A parameter is expected to yield no variable names either.
    // The argument must be the borrow `&param`; the previous text carried a
    // mis-decoded pilcrow character in place of `&para`.
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1753
/// Optimizing `Return -> Filter(true) -> Skip(5) -> NodeScan` succeeds and keeps
/// a Return at the root. Only the root operator is checked here.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    // Assemble the plan from the leaf upwards.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });
    let skipped = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skipped),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable(String::from("n")),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    let root_is_return = matches!(&result.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1784
/// Optimizing a Return over two stacked filters (`n.x > 1` above `n.y < 10`)
/// succeeds and keeps a Return at the root. Only the root operator is checked.
#[test]
fn test_nested_filter_pushdown() {
    // Leaf: unlabeled scan binding `n`.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: String::from("n"),
        label: None,
        input: None,
    });

    // Inner filter: n.y < 10.
    let y_under_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("y"),
        }),
        op: BinaryOp::Lt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: y_under_10,
        input: Box::new(scan),
    });

    // Outer filter: n.x > 1.
    let x_over_1 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: String::from("n"),
            property: String::from("x"),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
    };
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: x_over_1,
        input: Box::new(inner_filter),
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable(String::from("n")),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let opt = Optimizer::new();
    let result = opt.optimize(plan).unwrap();

    let root_is_return = matches!(&result.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1826}