1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
/// A join condition extracted from a join tree, in the shape consumed by the
/// join-order search (`optimize_join_order`).
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by `left_expr` (the variable itself, or the owner of a
    /// property access).
    left_var: String,
    /// Variable referenced by `right_expr` (the variable itself, or the owner of a
    /// property access).
    right_var: String,
    /// The condition's original left-hand expression.
    left_expr: LogicalExpression,
    /// The condition's original right-hand expression.
    right_expr: LogicalExpression,
}
36
/// A value that some operator in the plan reads, as collected by the
/// projection-pushdown pass.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A bound variable referenced by name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
45
/// Optimizer for [`LogicalPlan`]s.
///
/// Runs up to three rewrite passes (filter pushdown, join reordering, projection
/// pushdown), each of which can be toggled individually. Join reordering is
/// cost-based and uses `cost_model` together with `card_estimator`.
pub struct Optimizer {
    /// Whether `optimize` runs the filter-pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join-reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection-pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used for plan cost estimation and join-order search.
    cost_model: CostModel,
    /// Cardinality estimator used for cost estimation and join-order search.
    card_estimator: CardinalityEstimator,
}
62
63impl Optimizer {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 enable_filter_pushdown: true,
69 enable_join_reorder: true,
70 enable_projection_pushdown: true,
71 cost_model: CostModel::new(),
72 card_estimator: CardinalityEstimator::new(),
73 }
74 }
75
76 #[must_use]
82 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
83 store.ensure_statistics_fresh();
84 let stats = store.statistics();
85 let estimator = CardinalityEstimator::from_statistics(&stats);
86
87 let avg_fanout = if stats.total_nodes > 0 {
89 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
90 } else {
91 10.0
92 };
93
94 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
96 .edge_types
97 .iter()
98 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
99 .collect();
100
101 Self {
102 enable_filter_pushdown: true,
103 enable_join_reorder: true,
104 enable_projection_pushdown: true,
105 cost_model: CostModel::new()
106 .with_avg_fanout(avg_fanout)
107 .with_edge_type_degrees(edge_type_degrees),
108 card_estimator: estimator,
109 }
110 }
111
112 #[must_use]
119 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
120 let stats = store.statistics();
121 let estimator = CardinalityEstimator::from_statistics(&stats);
122
123 let avg_fanout = if stats.total_nodes > 0 {
124 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
125 } else {
126 10.0
127 };
128
129 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
130 .edge_types
131 .iter()
132 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
133 .collect();
134
135 Self {
136 enable_filter_pushdown: true,
137 enable_join_reorder: true,
138 enable_projection_pushdown: true,
139 cost_model: CostModel::new()
140 .with_avg_fanout(avg_fanout)
141 .with_edge_type_degrees(edge_type_degrees),
142 card_estimator: estimator,
143 }
144 }
145
146 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
148 self.enable_filter_pushdown = enabled;
149 self
150 }
151
152 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
154 self.enable_join_reorder = enabled;
155 self
156 }
157
158 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
160 self.enable_projection_pushdown = enabled;
161 self
162 }
163
164 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
166 self.cost_model = cost_model;
167 self
168 }
169
170 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
172 self.card_estimator = estimator;
173 self
174 }
175
176 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
178 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
179 self
180 }
181
182 pub fn cost_model(&self) -> &CostModel {
184 &self.cost_model
185 }
186
187 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
189 &self.card_estimator
190 }
191
192 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
194 let cardinality = self.card_estimator.estimate(&plan.root);
195 self.cost_model.estimate(&plan.root, cardinality)
196 }
197
198 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
200 self.card_estimator.estimate(&plan.root)
201 }
202
203 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
209 let mut root = plan.root;
210
211 if self.enable_filter_pushdown {
213 root = self.push_filters_down(root);
214 }
215
216 if self.enable_join_reorder {
217 root = self.reorder_joins(root);
218 }
219
220 if self.enable_projection_pushdown {
221 root = self.push_projections_down(root);
222 }
223
224 Ok(LogicalPlan::new(root))
225 }
226
227 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
234 let required = self.collect_required_columns(&op);
236
237 self.push_projections_recursive(op, &required)
239 }
240
241 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
243 let mut required = HashSet::new();
244 Self::collect_required_recursive(op, &mut required);
245 required
246 }
247
248 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
250 match op {
251 LogicalOperator::Return(ret) => {
252 for item in &ret.items {
253 Self::collect_from_expression(&item.expression, required);
254 }
255 Self::collect_required_recursive(&ret.input, required);
256 }
257 LogicalOperator::Project(proj) => {
258 for p in &proj.projections {
259 Self::collect_from_expression(&p.expression, required);
260 }
261 Self::collect_required_recursive(&proj.input, required);
262 }
263 LogicalOperator::Filter(filter) => {
264 Self::collect_from_expression(&filter.predicate, required);
265 Self::collect_required_recursive(&filter.input, required);
266 }
267 LogicalOperator::Sort(sort) => {
268 for key in &sort.keys {
269 Self::collect_from_expression(&key.expression, required);
270 }
271 Self::collect_required_recursive(&sort.input, required);
272 }
273 LogicalOperator::Aggregate(agg) => {
274 for expr in &agg.group_by {
275 Self::collect_from_expression(expr, required);
276 }
277 for agg_expr in &agg.aggregates {
278 if let Some(ref expr) = agg_expr.expression {
279 Self::collect_from_expression(expr, required);
280 }
281 }
282 if let Some(ref having) = agg.having {
283 Self::collect_from_expression(having, required);
284 }
285 Self::collect_required_recursive(&agg.input, required);
286 }
287 LogicalOperator::Join(join) => {
288 for cond in &join.conditions {
289 Self::collect_from_expression(&cond.left, required);
290 Self::collect_from_expression(&cond.right, required);
291 }
292 Self::collect_required_recursive(&join.left, required);
293 Self::collect_required_recursive(&join.right, required);
294 }
295 LogicalOperator::Expand(expand) => {
296 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
298 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
299 if let Some(ref edge_var) = expand.edge_variable {
300 required.insert(RequiredColumn::Variable(edge_var.clone()));
301 }
302 Self::collect_required_recursive(&expand.input, required);
303 }
304 LogicalOperator::Limit(limit) => {
305 Self::collect_required_recursive(&limit.input, required);
306 }
307 LogicalOperator::Skip(skip) => {
308 Self::collect_required_recursive(&skip.input, required);
309 }
310 LogicalOperator::Distinct(distinct) => {
311 Self::collect_required_recursive(&distinct.input, required);
312 }
313 LogicalOperator::NodeScan(scan) => {
314 required.insert(RequiredColumn::Variable(scan.variable.clone()));
315 }
316 LogicalOperator::EdgeScan(scan) => {
317 required.insert(RequiredColumn::Variable(scan.variable.clone()));
318 }
319 _ => {}
320 }
321 }
322
323 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
325 match expr {
326 LogicalExpression::Variable(var) => {
327 required.insert(RequiredColumn::Variable(var.clone()));
328 }
329 LogicalExpression::Property { variable, property } => {
330 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
331 required.insert(RequiredColumn::Variable(variable.clone()));
332 }
333 LogicalExpression::Binary { left, right, .. } => {
334 Self::collect_from_expression(left, required);
335 Self::collect_from_expression(right, required);
336 }
337 LogicalExpression::Unary { operand, .. } => {
338 Self::collect_from_expression(operand, required);
339 }
340 LogicalExpression::FunctionCall { args, .. } => {
341 for arg in args {
342 Self::collect_from_expression(arg, required);
343 }
344 }
345 LogicalExpression::List(items) => {
346 for item in items {
347 Self::collect_from_expression(item, required);
348 }
349 }
350 LogicalExpression::Map(pairs) => {
351 for (_, value) in pairs {
352 Self::collect_from_expression(value, required);
353 }
354 }
355 LogicalExpression::IndexAccess { base, index } => {
356 Self::collect_from_expression(base, required);
357 Self::collect_from_expression(index, required);
358 }
359 LogicalExpression::SliceAccess { base, start, end } => {
360 Self::collect_from_expression(base, required);
361 if let Some(s) = start {
362 Self::collect_from_expression(s, required);
363 }
364 if let Some(e) = end {
365 Self::collect_from_expression(e, required);
366 }
367 }
368 LogicalExpression::Case {
369 operand,
370 when_clauses,
371 else_clause,
372 } => {
373 if let Some(op) = operand {
374 Self::collect_from_expression(op, required);
375 }
376 for (cond, result) in when_clauses {
377 Self::collect_from_expression(cond, required);
378 Self::collect_from_expression(result, required);
379 }
380 if let Some(else_expr) = else_clause {
381 Self::collect_from_expression(else_expr, required);
382 }
383 }
384 LogicalExpression::Labels(var)
385 | LogicalExpression::Type(var)
386 | LogicalExpression::Id(var) => {
387 required.insert(RequiredColumn::Variable(var.clone()));
388 }
389 LogicalExpression::ListComprehension {
390 list_expr,
391 filter_expr,
392 map_expr,
393 ..
394 } => {
395 Self::collect_from_expression(list_expr, required);
396 if let Some(filter) = filter_expr {
397 Self::collect_from_expression(filter, required);
398 }
399 Self::collect_from_expression(map_expr, required);
400 }
401 _ => {}
402 }
403 }
404
405 fn push_projections_recursive(
407 &self,
408 op: LogicalOperator,
409 required: &HashSet<RequiredColumn>,
410 ) -> LogicalOperator {
411 match op {
412 LogicalOperator::Return(mut ret) => {
413 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
414 LogicalOperator::Return(ret)
415 }
416 LogicalOperator::Project(mut proj) => {
417 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
418 LogicalOperator::Project(proj)
419 }
420 LogicalOperator::Filter(mut filter) => {
421 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
422 LogicalOperator::Filter(filter)
423 }
424 LogicalOperator::Sort(mut sort) => {
425 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
428 LogicalOperator::Sort(sort)
429 }
430 LogicalOperator::Aggregate(mut agg) => {
431 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
432 LogicalOperator::Aggregate(agg)
433 }
434 LogicalOperator::Join(mut join) => {
435 let left_vars = self.collect_output_variables(&join.left);
438 let right_vars = self.collect_output_variables(&join.right);
439
440 let left_required: HashSet<_> = required
442 .iter()
443 .filter(|c| match c {
444 RequiredColumn::Variable(v) => left_vars.contains(v),
445 RequiredColumn::Property(v, _) => left_vars.contains(v),
446 })
447 .cloned()
448 .collect();
449
450 let right_required: HashSet<_> = required
451 .iter()
452 .filter(|c| match c {
453 RequiredColumn::Variable(v) => right_vars.contains(v),
454 RequiredColumn::Property(v, _) => right_vars.contains(v),
455 })
456 .cloned()
457 .collect();
458
459 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
460 join.right =
461 Box::new(self.push_projections_recursive(*join.right, &right_required));
462 LogicalOperator::Join(join)
463 }
464 LogicalOperator::Expand(mut expand) => {
465 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
466 LogicalOperator::Expand(expand)
467 }
468 LogicalOperator::Limit(mut limit) => {
469 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
470 LogicalOperator::Limit(limit)
471 }
472 LogicalOperator::Skip(mut skip) => {
473 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
474 LogicalOperator::Skip(skip)
475 }
476 LogicalOperator::Distinct(mut distinct) => {
477 distinct.input =
478 Box::new(self.push_projections_recursive(*distinct.input, required));
479 LogicalOperator::Distinct(distinct)
480 }
481 LogicalOperator::MapCollect(mut mc) => {
482 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
483 LogicalOperator::MapCollect(mc)
484 }
485 other => other,
486 }
487 }
488
489 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
496 let op = self.reorder_joins_recursive(op);
498
499 if let Some((relations, conditions)) = self.extract_join_tree(&op)
501 && relations.len() >= 2
502 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
503 {
504 return optimized;
505 }
506
507 op
508 }
509
510 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
512 match op {
513 LogicalOperator::Return(mut ret) => {
514 ret.input = Box::new(self.reorder_joins(*ret.input));
515 LogicalOperator::Return(ret)
516 }
517 LogicalOperator::Project(mut proj) => {
518 proj.input = Box::new(self.reorder_joins(*proj.input));
519 LogicalOperator::Project(proj)
520 }
521 LogicalOperator::Filter(mut filter) => {
522 filter.input = Box::new(self.reorder_joins(*filter.input));
523 LogicalOperator::Filter(filter)
524 }
525 LogicalOperator::Limit(mut limit) => {
526 limit.input = Box::new(self.reorder_joins(*limit.input));
527 LogicalOperator::Limit(limit)
528 }
529 LogicalOperator::Skip(mut skip) => {
530 skip.input = Box::new(self.reorder_joins(*skip.input));
531 LogicalOperator::Skip(skip)
532 }
533 LogicalOperator::Sort(mut sort) => {
534 sort.input = Box::new(self.reorder_joins(*sort.input));
535 LogicalOperator::Sort(sort)
536 }
537 LogicalOperator::Distinct(mut distinct) => {
538 distinct.input = Box::new(self.reorder_joins(*distinct.input));
539 LogicalOperator::Distinct(distinct)
540 }
541 LogicalOperator::Aggregate(mut agg) => {
542 agg.input = Box::new(self.reorder_joins(*agg.input));
543 LogicalOperator::Aggregate(agg)
544 }
545 LogicalOperator::Expand(mut expand) => {
546 expand.input = Box::new(self.reorder_joins(*expand.input));
547 LogicalOperator::Expand(expand)
548 }
549 LogicalOperator::MapCollect(mut mc) => {
550 mc.input = Box::new(self.reorder_joins(*mc.input));
551 LogicalOperator::MapCollect(mc)
552 }
553 other => other,
555 }
556 }
557
558 fn extract_join_tree(
562 &self,
563 op: &LogicalOperator,
564 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
565 let mut relations = Vec::new();
566 let mut join_conditions = Vec::new();
567
568 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
569 return None;
570 }
571
572 if relations.len() < 2 {
573 return None;
574 }
575
576 Some((relations, join_conditions))
577 }
578
579 fn collect_join_tree(
583 &self,
584 op: &LogicalOperator,
585 relations: &mut Vec<(String, LogicalOperator)>,
586 conditions: &mut Vec<JoinInfo>,
587 ) -> bool {
588 match op {
589 LogicalOperator::Join(join) => {
590 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
592 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
593
594 for cond in &join.conditions {
596 if let (Some(left_var), Some(right_var)) = (
597 self.extract_variable_from_expr(&cond.left),
598 self.extract_variable_from_expr(&cond.right),
599 ) {
600 conditions.push(JoinInfo {
601 left_var,
602 right_var,
603 left_expr: cond.left.clone(),
604 right_expr: cond.right.clone(),
605 });
606 }
607 }
608
609 left_ok && right_ok
610 }
611 LogicalOperator::NodeScan(scan) => {
612 relations.push((scan.variable.clone(), op.clone()));
613 true
614 }
615 LogicalOperator::EdgeScan(scan) => {
616 relations.push((scan.variable.clone(), op.clone()));
617 true
618 }
619 LogicalOperator::Filter(filter) => {
620 self.collect_join_tree(&filter.input, relations, conditions)
622 }
623 LogicalOperator::Expand(expand) => {
624 relations.push((expand.to_variable.clone(), op.clone()));
627 true
628 }
629 _ => false,
630 }
631 }
632
633 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
635 match expr {
636 LogicalExpression::Variable(v) => Some(v.clone()),
637 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
638 _ => None,
639 }
640 }
641
642 fn optimize_join_order(
644 &self,
645 relations: &[(String, LogicalOperator)],
646 conditions: &[JoinInfo],
647 ) -> Option<LogicalOperator> {
648 use join_order::{DPccp, JoinGraphBuilder};
649
650 let mut builder = JoinGraphBuilder::new();
652
653 for (var, relation) in relations {
654 builder.add_relation(var, relation.clone());
655 }
656
657 for cond in conditions {
658 builder.add_join_condition(
659 &cond.left_var,
660 &cond.right_var,
661 cond.left_expr.clone(),
662 cond.right_expr.clone(),
663 );
664 }
665
666 let graph = builder.build();
667
668 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
670 let plan = dpccp.optimize()?;
671
672 Some(plan.operator)
673 }
674
675 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
680 match op {
681 LogicalOperator::Filter(filter) => {
683 let optimized_input = self.push_filters_down(*filter.input);
684 self.try_push_filter_into(filter.predicate, optimized_input)
685 }
686 LogicalOperator::Return(mut ret) => {
688 ret.input = Box::new(self.push_filters_down(*ret.input));
689 LogicalOperator::Return(ret)
690 }
691 LogicalOperator::Project(mut proj) => {
692 proj.input = Box::new(self.push_filters_down(*proj.input));
693 LogicalOperator::Project(proj)
694 }
695 LogicalOperator::Limit(mut limit) => {
696 limit.input = Box::new(self.push_filters_down(*limit.input));
697 LogicalOperator::Limit(limit)
698 }
699 LogicalOperator::Skip(mut skip) => {
700 skip.input = Box::new(self.push_filters_down(*skip.input));
701 LogicalOperator::Skip(skip)
702 }
703 LogicalOperator::Sort(mut sort) => {
704 sort.input = Box::new(self.push_filters_down(*sort.input));
705 LogicalOperator::Sort(sort)
706 }
707 LogicalOperator::Distinct(mut distinct) => {
708 distinct.input = Box::new(self.push_filters_down(*distinct.input));
709 LogicalOperator::Distinct(distinct)
710 }
711 LogicalOperator::Expand(mut expand) => {
712 expand.input = Box::new(self.push_filters_down(*expand.input));
713 LogicalOperator::Expand(expand)
714 }
715 LogicalOperator::Join(mut join) => {
716 join.left = Box::new(self.push_filters_down(*join.left));
717 join.right = Box::new(self.push_filters_down(*join.right));
718 LogicalOperator::Join(join)
719 }
720 LogicalOperator::Aggregate(mut agg) => {
721 agg.input = Box::new(self.push_filters_down(*agg.input));
722 LogicalOperator::Aggregate(agg)
723 }
724 LogicalOperator::MapCollect(mut mc) => {
725 mc.input = Box::new(self.push_filters_down(*mc.input));
726 LogicalOperator::MapCollect(mc)
727 }
728 other => other,
730 }
731 }
732
733 fn try_push_filter_into(
738 &self,
739 predicate: LogicalExpression,
740 op: LogicalOperator,
741 ) -> LogicalOperator {
742 match op {
743 LogicalOperator::Project(mut proj) => {
745 let predicate_vars = self.extract_variables(&predicate);
746 let computed_vars = self.extract_projection_aliases(&proj.projections);
747
748 if predicate_vars.is_disjoint(&computed_vars) {
750 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
751 LogicalOperator::Project(proj)
752 } else {
753 LogicalOperator::Filter(FilterOp {
755 predicate,
756 input: Box::new(LogicalOperator::Project(proj)),
757 })
758 }
759 }
760
761 LogicalOperator::Return(mut ret) => {
763 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
764 LogicalOperator::Return(ret)
765 }
766
767 LogicalOperator::Expand(mut expand) => {
769 let predicate_vars = self.extract_variables(&predicate);
770
771 let mut introduced_vars = vec![&expand.to_variable];
776 if let Some(ref edge_var) = expand.edge_variable {
777 introduced_vars.push(edge_var);
778 }
779 if let Some(ref path_alias) = expand.path_alias {
780 introduced_vars.push(path_alias);
781 }
782
783 let uses_introduced_vars =
785 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
786
787 if !uses_introduced_vars {
788 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
790 LogicalOperator::Expand(expand)
791 } else {
792 LogicalOperator::Filter(FilterOp {
794 predicate,
795 input: Box::new(LogicalOperator::Expand(expand)),
796 })
797 }
798 }
799
800 LogicalOperator::Join(mut join) => {
802 let predicate_vars = self.extract_variables(&predicate);
803 let left_vars = self.collect_output_variables(&join.left);
804 let right_vars = self.collect_output_variables(&join.right);
805
806 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
807 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
808
809 if uses_left && !uses_right {
810 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
812 LogicalOperator::Join(join)
813 } else if uses_right && !uses_left {
814 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
816 LogicalOperator::Join(join)
817 } else {
818 LogicalOperator::Filter(FilterOp {
820 predicate,
821 input: Box::new(LogicalOperator::Join(join)),
822 })
823 }
824 }
825
826 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
828 predicate,
829 input: Box::new(LogicalOperator::Aggregate(agg)),
830 }),
831
832 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
834 predicate,
835 input: Box::new(LogicalOperator::NodeScan(scan)),
836 }),
837
838 other => LogicalOperator::Filter(FilterOp {
840 predicate,
841 input: Box::new(other),
842 }),
843 }
844 }
845
846 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
848 let mut vars = HashSet::new();
849 Self::collect_output_variables_recursive(op, &mut vars);
850 vars
851 }
852
853 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
855 match op {
856 LogicalOperator::NodeScan(scan) => {
857 vars.insert(scan.variable.clone());
858 }
859 LogicalOperator::EdgeScan(scan) => {
860 vars.insert(scan.variable.clone());
861 }
862 LogicalOperator::Expand(expand) => {
863 vars.insert(expand.to_variable.clone());
864 if let Some(edge_var) = &expand.edge_variable {
865 vars.insert(edge_var.clone());
866 }
867 Self::collect_output_variables_recursive(&expand.input, vars);
868 }
869 LogicalOperator::Filter(filter) => {
870 Self::collect_output_variables_recursive(&filter.input, vars);
871 }
872 LogicalOperator::Project(proj) => {
873 for p in &proj.projections {
874 if let Some(alias) = &p.alias {
875 vars.insert(alias.clone());
876 }
877 }
878 Self::collect_output_variables_recursive(&proj.input, vars);
879 }
880 LogicalOperator::Join(join) => {
881 Self::collect_output_variables_recursive(&join.left, vars);
882 Self::collect_output_variables_recursive(&join.right, vars);
883 }
884 LogicalOperator::Aggregate(agg) => {
885 for expr in &agg.group_by {
886 Self::collect_variables(expr, vars);
887 }
888 for agg_expr in &agg.aggregates {
889 if let Some(alias) = &agg_expr.alias {
890 vars.insert(alias.clone());
891 }
892 }
893 }
894 LogicalOperator::Return(ret) => {
895 Self::collect_output_variables_recursive(&ret.input, vars);
896 }
897 LogicalOperator::Limit(limit) => {
898 Self::collect_output_variables_recursive(&limit.input, vars);
899 }
900 LogicalOperator::Skip(skip) => {
901 Self::collect_output_variables_recursive(&skip.input, vars);
902 }
903 LogicalOperator::Sort(sort) => {
904 Self::collect_output_variables_recursive(&sort.input, vars);
905 }
906 LogicalOperator::Distinct(distinct) => {
907 Self::collect_output_variables_recursive(&distinct.input, vars);
908 }
909 _ => {}
910 }
911 }
912
913 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
915 let mut vars = HashSet::new();
916 Self::collect_variables(expr, &mut vars);
917 vars
918 }
919
920 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
922 match expr {
923 LogicalExpression::Variable(name) => {
924 vars.insert(name.clone());
925 }
926 LogicalExpression::Property { variable, .. } => {
927 vars.insert(variable.clone());
928 }
929 LogicalExpression::Binary { left, right, .. } => {
930 Self::collect_variables(left, vars);
931 Self::collect_variables(right, vars);
932 }
933 LogicalExpression::Unary { operand, .. } => {
934 Self::collect_variables(operand, vars);
935 }
936 LogicalExpression::FunctionCall { args, .. } => {
937 for arg in args {
938 Self::collect_variables(arg, vars);
939 }
940 }
941 LogicalExpression::List(items) => {
942 for item in items {
943 Self::collect_variables(item, vars);
944 }
945 }
946 LogicalExpression::Map(pairs) => {
947 for (_, value) in pairs {
948 Self::collect_variables(value, vars);
949 }
950 }
951 LogicalExpression::IndexAccess { base, index } => {
952 Self::collect_variables(base, vars);
953 Self::collect_variables(index, vars);
954 }
955 LogicalExpression::SliceAccess { base, start, end } => {
956 Self::collect_variables(base, vars);
957 if let Some(s) = start {
958 Self::collect_variables(s, vars);
959 }
960 if let Some(e) = end {
961 Self::collect_variables(e, vars);
962 }
963 }
964 LogicalExpression::Case {
965 operand,
966 when_clauses,
967 else_clause,
968 } => {
969 if let Some(op) = operand {
970 Self::collect_variables(op, vars);
971 }
972 for (cond, result) in when_clauses {
973 Self::collect_variables(cond, vars);
974 Self::collect_variables(result, vars);
975 }
976 if let Some(else_expr) = else_clause {
977 Self::collect_variables(else_expr, vars);
978 }
979 }
980 LogicalExpression::Labels(var)
981 | LogicalExpression::Type(var)
982 | LogicalExpression::Id(var) => {
983 vars.insert(var.clone());
984 }
985 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
986 LogicalExpression::ListComprehension {
987 list_expr,
988 filter_expr,
989 map_expr,
990 ..
991 } => {
992 Self::collect_variables(list_expr, vars);
993 if let Some(filter) = filter_expr {
994 Self::collect_variables(filter, vars);
995 }
996 Self::collect_variables(map_expr, vars);
997 }
998 LogicalExpression::ListPredicate {
999 list_expr,
1000 predicate,
1001 ..
1002 } => {
1003 Self::collect_variables(list_expr, vars);
1004 Self::collect_variables(predicate, vars);
1005 }
1006 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1007 }
1009 }
1010 }
1011
1012 fn extract_projection_aliases(
1014 &self,
1015 projections: &[crate::query::plan::Projection],
1016 ) -> HashSet<String> {
1017 projections.iter().filter_map(|p| p.alias.clone()).collect()
1018 }
1019}
1020
1021impl Default for Optimizer {
1022 fn default() -> Self {
1023 Self::new()
1024 }
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029 use super::*;
1030 use crate::query::plan::{
1031 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1032 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
1033 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1034 };
1035 use grafeo_common::types::Value;
1036
1037 #[test]
1038 fn test_optimizer_filter_pushdown_simple() {
1039 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1044 items: vec![ReturnItem {
1045 expression: LogicalExpression::Variable("n".to_string()),
1046 alias: None,
1047 }],
1048 distinct: false,
1049 input: Box::new(LogicalOperator::Filter(FilterOp {
1050 predicate: LogicalExpression::Binary {
1051 left: Box::new(LogicalExpression::Property {
1052 variable: "n".to_string(),
1053 property: "age".to_string(),
1054 }),
1055 op: BinaryOp::Gt,
1056 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1057 },
1058 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1059 variable: "n".to_string(),
1060 label: Some("Person".to_string()),
1061 input: None,
1062 })),
1063 })),
1064 }));
1065
1066 let optimizer = Optimizer::new();
1067 let optimized = optimizer.optimize(plan).unwrap();
1068
1069 if let LogicalOperator::Return(ret) = &optimized.root
1071 && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1072 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1073 {
1074 assert_eq!(scan.variable, "n");
1075 return;
1076 }
1077 panic!("Expected Return -> Filter -> NodeScan structure");
1078 }
1079
1080 #[test]
1081 fn test_optimizer_filter_pushdown_through_expand() {
1082 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1086 items: vec![ReturnItem {
1087 expression: LogicalExpression::Variable("b".to_string()),
1088 alias: None,
1089 }],
1090 distinct: false,
1091 input: Box::new(LogicalOperator::Filter(FilterOp {
1092 predicate: LogicalExpression::Binary {
1093 left: Box::new(LogicalExpression::Property {
1094 variable: "a".to_string(),
1095 property: "age".to_string(),
1096 }),
1097 op: BinaryOp::Gt,
1098 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1099 },
1100 input: Box::new(LogicalOperator::Expand(ExpandOp {
1101 from_variable: "a".to_string(),
1102 to_variable: "b".to_string(),
1103 edge_variable: None,
1104 direction: ExpandDirection::Outgoing,
1105 edge_type: Some("KNOWS".to_string()),
1106 min_hops: 1,
1107 max_hops: Some(1),
1108 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1109 variable: "a".to_string(),
1110 label: Some("Person".to_string()),
1111 input: None,
1112 })),
1113 path_alias: None,
1114 })),
1115 })),
1116 }));
1117
1118 let optimizer = Optimizer::new();
1119 let optimized = optimizer.optimize(plan).unwrap();
1120
1121 if let LogicalOperator::Return(ret) = &optimized.root
1124 && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1125 && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1126 && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1127 {
1128 assert_eq!(scan.variable, "a");
1129 assert_eq!(expand.from_variable, "a");
1130 assert_eq!(expand.to_variable, "b");
1131 return;
1132 }
1133 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1134 }
1135
#[test]
/// Plan: `Return(a) <- Filter(b.age > 30) <- Expand(a -KNOWS-> b) <- NodeScan(a:Person)`.
///
/// The predicate only references `b`, the Expand's target variable, so the filter must
/// remain above the Expand. The predicate itself is checked unconditionally: it must
/// still be a binary comparison whose left operand is a property of `b`.
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(LogicalOperator::Filter(FilterOp {
            predicate: LogicalExpression::Binary {
                left: Box::new(LogicalExpression::Property {
                    variable: "b".to_string(),
                    property: "age".to_string(),
                }),
                op: BinaryOp::Gt,
                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
            },
            input: Box::new(LogicalOperator::Expand(ExpandOp {
                from_variable: "a".to_string(),
                to_variable: "b".to_string(),
                edge_variable: None,
                direction: ExpandDirection::Outgoing,
                edge_type: Some("KNOWS".to_string()),
                min_hops: 1,
                max_hops: Some(1),
                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                    variable: "a".to_string(),
                    label: Some("Person".to_string()),
                    input: None,
                })),
                path_alias: None,
            })),
        })),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    };

    // Verify the predicate unconditionally; a changed shape is a failure, not a skip.
    let LogicalExpression::Binary { left, .. } = &filter.predicate else {
        panic!("Expected the filter predicate to remain a binary comparison");
    };
    let LogicalExpression::Property { variable, .. } = left.as_ref() else {
        panic!("Expected the predicate's left operand to be a property access");
    };
    assert_eq!(variable, "b");

    let LogicalOperator::Expand(expand) = filter.input.as_ref() else {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    };
    assert!(
        matches!(expand.input.as_ref(), LogicalOperator::NodeScan(_)),
        "Expected Return -> Filter -> Expand -> NodeScan structure"
    );
}
1197
#[test]
/// `n.age > 30` mentions exactly one variable, `n`.
fn test_optimizer_extract_variables() {
    let optimizer = Optimizer::new();

    let age_of_n = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let thirty = LogicalExpression::Literal(Value::Int64(30));
    let predicate = LogicalExpression::Binary {
        left: Box::new(age_of_n),
        op: BinaryOp::Gt,
        right: Box::new(thirty),
    };

    let found = optimizer.extract_variables(&predicate);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1215
#[test]
/// An optimizer obtained through `Default` accepts the trivial empty plan.
fn test_optimizer_default() {
    let empty = LogicalPlan::new(LogicalOperator::Empty);
    let defaulted = Optimizer::default();
    let outcome = defaulted.optimize(empty);
    assert!(outcome.is_ok());
}
1226
#[test]
/// With filter pushdown switched off, the Filter under the Return is expected to still
/// be the Return's direct input after optimization.
fn test_optimizer_with_filter_pushdown_disabled() {
    let optimizer = Optimizer::new().with_filter_pushdown(false);

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected unchanged structure");
    };
    assert!(
        matches!(ret.input.as_ref(), LogicalOperator::Filter(_)),
        "Expected unchanged structure"
    );
}
1256
#[test]
/// Turning join reordering off still lets the optimizer process an empty plan.
fn test_optimizer_with_join_reorder_disabled() {
    let no_reorder = Optimizer::new().with_join_reorder(false);
    let empty = LogicalPlan::new(LogicalOperator::Empty);
    let result = no_reorder.optimize(empty);
    assert!(result.is_ok());
}
1266
#[test]
/// After installing a fresh `CostModel`, the optimizer's cost model rates the empty
/// operator (with input cardinality 0.0) at a total cost below 0.001.
fn test_optimizer_with_cost_model() {
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());
    let empty_cost = optimizer
        .cost_model()
        .estimate(&LogicalOperator::Empty, 0.0)
        .total();
    assert!(empty_cost < 0.001);
}
1279
#[test]
/// Statistics registered on a custom estimator (`TableStats::new(500)` for label `Test`)
/// drive the estimate for a scan over that label, which should come out at 500.
fn test_optimizer_with_cardinality_estimator() {
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(estimator);

    let labelled_scan = NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::NodeScan(labelled_scan));

    let estimated = optimizer.estimate_cardinality(&plan);
    let deviation = (estimated - 500.0).abs();
    assert!(deviation < 0.001);
}
1296
#[test]
/// An unlabelled node scan is expected to have a strictly positive total cost.
fn test_optimizer_estimate_cost() {
    let full_scan = NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::NodeScan(full_scan));
    let optimizer = Optimizer::new();

    let estimate = optimizer.estimate_cost(&plan);
    assert!(estimate.total() > 0.0);
}
1309
#[test]
/// `Filter(n.age > 30) <- Project(n) <- NodeScan(n)`. The projection passes `n` through
/// unaliased, and the test expects the filter to end up below the Project.
fn test_filter_pushdown_through_project() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let n_older_than_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: n_older_than_30,
        input: Box::new(project),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let pushed_below = matches!(
        &optimized.root,
        LogicalOperator::Project(p) if matches!(p.input.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(pushed_below, "Expected Project -> Filter structure");
}
1348
#[test]
/// `x` exists only as the alias the projection gives to `n.age`, so a filter on `x`
/// is expected to stay above the Project.
fn test_filter_not_pushed_through_project_with_alias() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let age_as_x = Projection {
        expression: LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        },
        alias: Some("x".to_string()),
    };
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![age_as_x],
        input: Box::new(scan),
    });
    let x_above_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_above_30,
        input: Box::new(project),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Project(_))
    );
    assert!(kept_above, "Expected Filter -> Project structure");
}
1386
#[test]
/// Despite its name, this test expects the filter NOT to move below the Limit: the
/// optimized root must be `Filter -> Limit`. In general, filtering before a limit would
/// change which rows survive.
fn test_filter_pushdown_through_limit() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limit = LogicalOperator::Limit(LimitOp {
        count: 10,
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(limit),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Limit(_))
    );
    assert!(kept_above, "Expected Filter -> Limit structure");
}
1413
#[test]
/// Despite its name, this test pins the shape `Filter -> Sort`: the filter is expected
/// to remain directly above the Sort after optimization.
fn test_filter_pushdown_through_sort() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let ascending_by_n = SortKey {
        expression: LogicalExpression::Variable("n".to_string()),
        order: SortOrder::Ascending,
    };
    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![ascending_by_n],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(sort),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Sort(_))
    );
    assert!(kept_above, "Expected Filter -> Sort structure");
}
1443
#[test]
/// Despite its name, this test pins the shape `Filter -> Distinct`: the filter is
/// expected to remain directly above the Distinct after optimization.
fn test_filter_pushdown_through_distinct() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(distinct),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Distinct(_))
    );
    assert!(kept_above, "Expected Filter -> Distinct structure");
}
1470
#[test]
/// `cnt` is produced by the Aggregate itself (the alias of its COUNT), so a filter on
/// `cnt` is expected to stay above the Aggregate.
fn test_filter_not_pushed_through_aggregate() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_as_cnt = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_as_cnt],
        input: Box::new(scan),
        having: None,
    });
    let cnt_above_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cnt_above_10,
        input: Box::new(aggregate),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Aggregate(_))
    );
    assert!(kept_above, "Expected Filter -> Aggregate structure");
}
1509
#[test]
/// A predicate that only touches `a`, the left input of an inner join, is expected to
/// be pushed into the join's left side.
fn test_filter_pushdown_to_left_join_side() {
    let optimizer = Optimizer::new();

    let persons = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let inner_join = LogicalOperator::Join(JoinOp {
        left: Box::new(persons),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let a_older_than_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: a_older_than_30,
        input: Box::new(inner_join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let filtered_left = matches!(
        &optimized.root,
        LogicalOperator::Join(j) if matches!(j.left.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(filtered_left, "Expected Join with Filter on left side");
}
1550
#[test]
/// A predicate that only touches `b`, the right input of an inner join, is expected to
/// be pushed into the join's right side.
fn test_filter_pushdown_to_right_join_side() {
    let optimizer = Optimizer::new();

    let persons = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let inner_join = LogicalOperator::Join(JoinOp {
        left: Box::new(persons),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let b_named_acme = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: b_named_acme,
        input: Box::new(inner_join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let filtered_right = matches!(
        &optimized.root,
        LogicalOperator::Join(j) if matches!(j.right.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(filtered_right, "Expected Join with Filter on right side");
}
1591
#[test]
/// `a.id = b.a_id` needs columns from both join inputs, so the filter is expected to
/// stay above the Join.
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let optimizer = Optimizer::new();

    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let inner_join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let cross_side_equality = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side_equality,
        input: Box::new(inner_join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let kept_above = matches!(
        &optimized.root,
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Join(_))
    );
    assert!(kept_above, "Expected Filter -> Join structure");
}
1635
#[test]
/// A bare variable reference yields exactly that variable.
fn test_extract_variables_from_variable() {
    let optimizer = Optimizer::new();
    let reference = LogicalExpression::Variable("x".to_string());

    let found = optimizer.extract_variables(&reference);
    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1646
#[test]
/// Variables are collected through a unary operator's operand (`NOT x`).
fn test_extract_variables_from_unary() {
    let optimizer = Optimizer::new();
    let operand = LogicalExpression::Variable("x".to_string());
    let negation = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let found = optimizer.extract_variables(&negation);
    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1658
#[test]
/// Every argument of a function call contributes its variables.
fn test_extract_variables_from_function_call() {
    let optimizer = Optimizer::new();
    let args: Vec<_> = ["a", "b"]
        .into_iter()
        .map(|name| LogicalExpression::Variable(name.to_string()))
        .collect();
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args,
        distinct: false,
    };

    let found = optimizer.extract_variables(&call);
    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1675
#[test]
/// List elements are scanned for variables; literal elements contribute nothing.
fn test_extract_variables_from_list() {
    let optimizer = Optimizer::new();
    let first = LogicalExpression::Variable("a".to_string());
    let middle = LogicalExpression::Literal(Value::Int64(1));
    let last = LogicalExpression::Variable("b".to_string());
    let list = LogicalExpression::List(vec![first, middle, last]);

    let found = optimizer.extract_variables(&list);
    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1689
#[test]
/// Map values are scanned for variables (the keys are plain strings).
fn test_extract_variables_from_map() {
    let optimizer = Optimizer::new();
    let entries = [("key1", "a"), ("key2", "b")]
        .into_iter()
        .map(|(key, var)| (key.to_string(), LogicalExpression::Variable(var.to_string())))
        .collect();
    let map = LogicalExpression::Map(entries);

    let found = optimizer.extract_variables(&map);
    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
1708
#[test]
/// Both the indexed base and the index expression contribute variables.
fn test_extract_variables_from_index_access() {
    let optimizer = Optimizer::new();
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let access = LogicalExpression::IndexAccess {
        base: var("list"),
        index: var("idx"),
    };

    let found = optimizer.extract_variables(&access);
    assert_eq!(found.len(), 2);
    for expected in ["list", "idx"] {
        assert!(found.contains(expected));
    }
}
1721
#[test]
/// A slice contributes variables from its base and from both optional bounds.
fn test_extract_variables_from_slice_access() {
    let optimizer = Optimizer::new();
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let slice = LogicalExpression::SliceAccess {
        base: var("list"),
        start: Some(var("s")),
        end: Some(var("e")),
    };

    let found = optimizer.extract_variables(&slice);
    assert_eq!(found.len(), 3);
    for expected in ["list", "s", "e"] {
        assert!(found.contains(expected));
    }
}
1736
#[test]
/// A CASE expression contributes variables from its operand, its WHEN/THEN branches
/// and its ELSE branch.
fn test_extract_variables_from_case() {
    let optimizer = Optimizer::new();
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let case = LogicalExpression::Case {
        operand: Some(var("x")),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(var("b")),
    };

    let found = optimizer.extract_variables(&case);
    assert_eq!(found.len(), 3);
    for expected in ["x", "a", "b"] {
        assert!(found.contains(expected));
    }
}
1754
#[test]
/// `labels(n)` references the variable `n`.
fn test_extract_variables_from_labels() {
    let optimizer = Optimizer::new();
    let labels_of_n = LogicalExpression::Labels("n".to_string());

    let found = optimizer.extract_variables(&labels_of_n);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1763
#[test]
/// `type(e)` references the variable `e`.
fn test_extract_variables_from_type() {
    let optimizer = Optimizer::new();
    let type_of_e = LogicalExpression::Type("e".to_string());

    let found = optimizer.extract_variables(&type_of_e);
    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1772
#[test]
/// `id(n)` references the variable `n`.
fn test_extract_variables_from_id() {
    let optimizer = Optimizer::new();
    let id_of_n = LogicalExpression::Id("n".to_string());

    let found = optimizer.extract_variables(&id_of_n);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1781
#[test]
/// Variables from each sub-expression of a list comprehension (source list, filter
/// and mapping) are reported. Whether the comprehension's own bound variable `x`
/// appears in the result is not asserted here.
fn test_extract_variables_from_list_comprehension() {
    let optimizer = Optimizer::new();
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: var("items"),
        filter_expr: Some(var("pred")),
        map_expr: var("result"),
    };

    let found = optimizer.extract_variables(&comprehension);
    for expected in ["items", "pred", "result"] {
        assert!(found.contains(expected));
    }
}
1796
#[test]
/// Literals and query parameters reference no graph variables, so the extracted
/// variable set must be empty for both.
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();

    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());

    // A parameter such as `$p` is bound at execution time, not a graph variable.
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
1807
#[test]
/// Optimizing `Return <- Filter(true) <- Skip(5) <- NodeScan(n)` must succeed and keep
/// Return at the root. Where the filter ends up relative to the Skip is not asserted.
fn test_recursive_filter_pushdown_through_skip() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skip),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1838
#[test]
/// Two stacked filters on `n` (`n.y < 10` directly on the scan and `n.x > 1` above it)
/// under a Return: optimization must succeed and keep Return as the root operator.
fn test_nested_filter_pushdown() {
    let optimizer = Optimizer::new();

    // Builds `n.<property> <op> <bound>`.
    let compare = |property: &str, op: BinaryOp, bound| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: property.to_string(),
        }),
        op,
        right: Box::new(LogicalExpression::Literal(Value::Int64(bound))),
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: compare("y", BinaryOp::Lt, 10),
        input: Box::new(scan),
    });
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: compare("x", BinaryOp::Gt, 1),
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1880}