1pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19 CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
/// One join condition gathered while flattening a join tree.
///
/// Built by `collect_join_tree` from a join's condition list and later handed
/// to the join-graph builder in `optimize_join_order`.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left side of the condition.
    left_var: String,
    /// Variable referenced by the right side of the condition.
    right_var: String,
    /// Original left-hand expression of the condition (cloned from the join).
    left_expr: LogicalExpression,
    /// Original right-hand expression of the condition (cloned from the join).
    right_expr: LogicalExpression,
}
36
/// A column that some operator or expression in the plan refers to.
///
/// Collected by the projection-pushdown pass into a `HashSet`, hence the
/// `Eq`/`Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable, identified by name.
    Variable(String),
    /// A property access, stored as `(variable, property)`.
    Property(String, String),
}
45
/// Rewrites logical plans using filter pushdown, join reordering and
/// projection pushdown, each of which can be switched on or off.
pub struct Optimizer {
    /// Run the filter-pushdown pass in `optimize`.
    enable_filter_pushdown: bool,
    /// Run the join-reordering pass in `optimize`.
    enable_join_reorder: bool,
    /// Run the projection-pushdown pass in `optimize`.
    enable_projection_pushdown: bool,
    /// Cost model used for cost estimates and join ordering.
    cost_model: CostModel,
    /// Cardinality estimator used for cost estimates and join ordering.
    card_estimator: CardinalityEstimator,
}
62
63impl Optimizer {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 enable_filter_pushdown: true,
69 enable_join_reorder: true,
70 enable_projection_pushdown: true,
71 cost_model: CostModel::new(),
72 card_estimator: CardinalityEstimator::new(),
73 }
74 }
75
76 #[must_use]
82 pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
83 store.ensure_statistics_fresh();
84 let stats = store.statistics();
85 let estimator = CardinalityEstimator::from_statistics(&stats);
86
87 let avg_fanout = if stats.total_nodes > 0 {
89 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
90 } else {
91 10.0
92 };
93
94 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
96 .edge_types
97 .iter()
98 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
99 .collect();
100
101 Self {
102 enable_filter_pushdown: true,
103 enable_join_reorder: true,
104 enable_projection_pushdown: true,
105 cost_model: CostModel::new()
106 .with_avg_fanout(avg_fanout)
107 .with_edge_type_degrees(edge_type_degrees),
108 card_estimator: estimator,
109 }
110 }
111
112 #[must_use]
119 pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
120 let stats = store.statistics();
121 let estimator = CardinalityEstimator::from_statistics(&stats);
122
123 let avg_fanout = if stats.total_nodes > 0 {
124 (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
125 } else {
126 10.0
127 };
128
129 let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
130 .edge_types
131 .iter()
132 .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
133 .collect();
134
135 Self {
136 enable_filter_pushdown: true,
137 enable_join_reorder: true,
138 enable_projection_pushdown: true,
139 cost_model: CostModel::new()
140 .with_avg_fanout(avg_fanout)
141 .with_edge_type_degrees(edge_type_degrees),
142 card_estimator: estimator,
143 }
144 }
145
146 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
148 self.enable_filter_pushdown = enabled;
149 self
150 }
151
152 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
154 self.enable_join_reorder = enabled;
155 self
156 }
157
158 pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
160 self.enable_projection_pushdown = enabled;
161 self
162 }
163
164 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
166 self.cost_model = cost_model;
167 self
168 }
169
170 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
172 self.card_estimator = estimator;
173 self
174 }
175
176 pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
178 self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
179 self
180 }
181
182 pub fn cost_model(&self) -> &CostModel {
184 &self.cost_model
185 }
186
187 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
189 &self.card_estimator
190 }
191
192 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
194 let cardinality = self.card_estimator.estimate(&plan.root);
195 self.cost_model.estimate(&plan.root, cardinality)
196 }
197
198 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
200 self.card_estimator.estimate(&plan.root)
201 }
202
203 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
209 let mut root = plan.root;
210
211 if self.enable_filter_pushdown {
213 root = self.push_filters_down(root);
214 }
215
216 if self.enable_join_reorder {
217 root = self.reorder_joins(root);
218 }
219
220 if self.enable_projection_pushdown {
221 root = self.push_projections_down(root);
222 }
223
224 Ok(LogicalPlan::new(root))
225 }
226
227 fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
234 let required = self.collect_required_columns(&op);
236
237 self.push_projections_recursive(op, &required)
239 }
240
241 fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
243 let mut required = HashSet::new();
244 Self::collect_required_recursive(op, &mut required);
245 required
246 }
247
248 fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
250 match op {
251 LogicalOperator::Return(ret) => {
252 for item in &ret.items {
253 Self::collect_from_expression(&item.expression, required);
254 }
255 Self::collect_required_recursive(&ret.input, required);
256 }
257 LogicalOperator::Project(proj) => {
258 for p in &proj.projections {
259 Self::collect_from_expression(&p.expression, required);
260 }
261 Self::collect_required_recursive(&proj.input, required);
262 }
263 LogicalOperator::Filter(filter) => {
264 Self::collect_from_expression(&filter.predicate, required);
265 Self::collect_required_recursive(&filter.input, required);
266 }
267 LogicalOperator::Sort(sort) => {
268 for key in &sort.keys {
269 Self::collect_from_expression(&key.expression, required);
270 }
271 Self::collect_required_recursive(&sort.input, required);
272 }
273 LogicalOperator::Aggregate(agg) => {
274 for expr in &agg.group_by {
275 Self::collect_from_expression(expr, required);
276 }
277 for agg_expr in &agg.aggregates {
278 if let Some(ref expr) = agg_expr.expression {
279 Self::collect_from_expression(expr, required);
280 }
281 }
282 if let Some(ref having) = agg.having {
283 Self::collect_from_expression(having, required);
284 }
285 Self::collect_required_recursive(&agg.input, required);
286 }
287 LogicalOperator::Join(join) => {
288 for cond in &join.conditions {
289 Self::collect_from_expression(&cond.left, required);
290 Self::collect_from_expression(&cond.right, required);
291 }
292 Self::collect_required_recursive(&join.left, required);
293 Self::collect_required_recursive(&join.right, required);
294 }
295 LogicalOperator::Expand(expand) => {
296 required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
298 required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
299 if let Some(ref edge_var) = expand.edge_variable {
300 required.insert(RequiredColumn::Variable(edge_var.clone()));
301 }
302 Self::collect_required_recursive(&expand.input, required);
303 }
304 LogicalOperator::Limit(limit) => {
305 Self::collect_required_recursive(&limit.input, required);
306 }
307 LogicalOperator::Skip(skip) => {
308 Self::collect_required_recursive(&skip.input, required);
309 }
310 LogicalOperator::Distinct(distinct) => {
311 Self::collect_required_recursive(&distinct.input, required);
312 }
313 LogicalOperator::NodeScan(scan) => {
314 required.insert(RequiredColumn::Variable(scan.variable.clone()));
315 }
316 LogicalOperator::EdgeScan(scan) => {
317 required.insert(RequiredColumn::Variable(scan.variable.clone()));
318 }
319 _ => {}
320 }
321 }
322
323 fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
325 match expr {
326 LogicalExpression::Variable(var) => {
327 required.insert(RequiredColumn::Variable(var.clone()));
328 }
329 LogicalExpression::Property { variable, property } => {
330 required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
331 required.insert(RequiredColumn::Variable(variable.clone()));
332 }
333 LogicalExpression::Binary { left, right, .. } => {
334 Self::collect_from_expression(left, required);
335 Self::collect_from_expression(right, required);
336 }
337 LogicalExpression::Unary { operand, .. } => {
338 Self::collect_from_expression(operand, required);
339 }
340 LogicalExpression::FunctionCall { args, .. } => {
341 for arg in args {
342 Self::collect_from_expression(arg, required);
343 }
344 }
345 LogicalExpression::List(items) => {
346 for item in items {
347 Self::collect_from_expression(item, required);
348 }
349 }
350 LogicalExpression::Map(pairs) => {
351 for (_, value) in pairs {
352 Self::collect_from_expression(value, required);
353 }
354 }
355 LogicalExpression::IndexAccess { base, index } => {
356 Self::collect_from_expression(base, required);
357 Self::collect_from_expression(index, required);
358 }
359 LogicalExpression::SliceAccess { base, start, end } => {
360 Self::collect_from_expression(base, required);
361 if let Some(s) = start {
362 Self::collect_from_expression(s, required);
363 }
364 if let Some(e) = end {
365 Self::collect_from_expression(e, required);
366 }
367 }
368 LogicalExpression::Case {
369 operand,
370 when_clauses,
371 else_clause,
372 } => {
373 if let Some(op) = operand {
374 Self::collect_from_expression(op, required);
375 }
376 for (cond, result) in when_clauses {
377 Self::collect_from_expression(cond, required);
378 Self::collect_from_expression(result, required);
379 }
380 if let Some(else_expr) = else_clause {
381 Self::collect_from_expression(else_expr, required);
382 }
383 }
384 LogicalExpression::Labels(var)
385 | LogicalExpression::Type(var)
386 | LogicalExpression::Id(var) => {
387 required.insert(RequiredColumn::Variable(var.clone()));
388 }
389 LogicalExpression::ListComprehension {
390 list_expr,
391 filter_expr,
392 map_expr,
393 ..
394 } => {
395 Self::collect_from_expression(list_expr, required);
396 if let Some(filter) = filter_expr {
397 Self::collect_from_expression(filter, required);
398 }
399 Self::collect_from_expression(map_expr, required);
400 }
401 _ => {}
402 }
403 }
404
405 fn push_projections_recursive(
407 &self,
408 op: LogicalOperator,
409 required: &HashSet<RequiredColumn>,
410 ) -> LogicalOperator {
411 match op {
412 LogicalOperator::Return(mut ret) => {
413 ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
414 LogicalOperator::Return(ret)
415 }
416 LogicalOperator::Project(mut proj) => {
417 proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
418 LogicalOperator::Project(proj)
419 }
420 LogicalOperator::Filter(mut filter) => {
421 filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
422 LogicalOperator::Filter(filter)
423 }
424 LogicalOperator::Sort(mut sort) => {
425 sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
428 LogicalOperator::Sort(sort)
429 }
430 LogicalOperator::Aggregate(mut agg) => {
431 agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
432 LogicalOperator::Aggregate(agg)
433 }
434 LogicalOperator::Join(mut join) => {
435 let left_vars = self.collect_output_variables(&join.left);
438 let right_vars = self.collect_output_variables(&join.right);
439
440 let left_required: HashSet<_> = required
442 .iter()
443 .filter(|c| match c {
444 RequiredColumn::Variable(v) => left_vars.contains(v),
445 RequiredColumn::Property(v, _) => left_vars.contains(v),
446 })
447 .cloned()
448 .collect();
449
450 let right_required: HashSet<_> = required
451 .iter()
452 .filter(|c| match c {
453 RequiredColumn::Variable(v) => right_vars.contains(v),
454 RequiredColumn::Property(v, _) => right_vars.contains(v),
455 })
456 .cloned()
457 .collect();
458
459 join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
460 join.right =
461 Box::new(self.push_projections_recursive(*join.right, &right_required));
462 LogicalOperator::Join(join)
463 }
464 LogicalOperator::Expand(mut expand) => {
465 expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
466 LogicalOperator::Expand(expand)
467 }
468 LogicalOperator::Limit(mut limit) => {
469 limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
470 LogicalOperator::Limit(limit)
471 }
472 LogicalOperator::Skip(mut skip) => {
473 skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
474 LogicalOperator::Skip(skip)
475 }
476 LogicalOperator::Distinct(mut distinct) => {
477 distinct.input =
478 Box::new(self.push_projections_recursive(*distinct.input, required));
479 LogicalOperator::Distinct(distinct)
480 }
481 LogicalOperator::MapCollect(mut mc) => {
482 mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
483 LogicalOperator::MapCollect(mc)
484 }
485 other => other,
486 }
487 }
488
489 fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
496 let op = self.reorder_joins_recursive(op);
498
499 if let Some((relations, conditions)) = self.extract_join_tree(&op)
501 && relations.len() >= 2
502 && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
503 {
504 return optimized;
505 }
506
507 op
508 }
509
510 fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
512 match op {
513 LogicalOperator::Return(mut ret) => {
514 ret.input = Box::new(self.reorder_joins(*ret.input));
515 LogicalOperator::Return(ret)
516 }
517 LogicalOperator::Project(mut proj) => {
518 proj.input = Box::new(self.reorder_joins(*proj.input));
519 LogicalOperator::Project(proj)
520 }
521 LogicalOperator::Filter(mut filter) => {
522 filter.input = Box::new(self.reorder_joins(*filter.input));
523 LogicalOperator::Filter(filter)
524 }
525 LogicalOperator::Limit(mut limit) => {
526 limit.input = Box::new(self.reorder_joins(*limit.input));
527 LogicalOperator::Limit(limit)
528 }
529 LogicalOperator::Skip(mut skip) => {
530 skip.input = Box::new(self.reorder_joins(*skip.input));
531 LogicalOperator::Skip(skip)
532 }
533 LogicalOperator::Sort(mut sort) => {
534 sort.input = Box::new(self.reorder_joins(*sort.input));
535 LogicalOperator::Sort(sort)
536 }
537 LogicalOperator::Distinct(mut distinct) => {
538 distinct.input = Box::new(self.reorder_joins(*distinct.input));
539 LogicalOperator::Distinct(distinct)
540 }
541 LogicalOperator::Aggregate(mut agg) => {
542 agg.input = Box::new(self.reorder_joins(*agg.input));
543 LogicalOperator::Aggregate(agg)
544 }
545 LogicalOperator::Expand(mut expand) => {
546 expand.input = Box::new(self.reorder_joins(*expand.input));
547 LogicalOperator::Expand(expand)
548 }
549 LogicalOperator::MapCollect(mut mc) => {
550 mc.input = Box::new(self.reorder_joins(*mc.input));
551 LogicalOperator::MapCollect(mc)
552 }
553 other => other,
555 }
556 }
557
558 fn extract_join_tree(
562 &self,
563 op: &LogicalOperator,
564 ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
565 let mut relations = Vec::new();
566 let mut join_conditions = Vec::new();
567
568 if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
569 return None;
570 }
571
572 if relations.len() < 2 {
573 return None;
574 }
575
576 Some((relations, join_conditions))
577 }
578
579 fn collect_join_tree(
583 &self,
584 op: &LogicalOperator,
585 relations: &mut Vec<(String, LogicalOperator)>,
586 conditions: &mut Vec<JoinInfo>,
587 ) -> bool {
588 match op {
589 LogicalOperator::Join(join) => {
590 let left_ok = self.collect_join_tree(&join.left, relations, conditions);
592 let right_ok = self.collect_join_tree(&join.right, relations, conditions);
593
594 for cond in &join.conditions {
596 if let (Some(left_var), Some(right_var)) = (
597 self.extract_variable_from_expr(&cond.left),
598 self.extract_variable_from_expr(&cond.right),
599 ) {
600 conditions.push(JoinInfo {
601 left_var,
602 right_var,
603 left_expr: cond.left.clone(),
604 right_expr: cond.right.clone(),
605 });
606 }
607 }
608
609 left_ok && right_ok
610 }
611 LogicalOperator::NodeScan(scan) => {
612 relations.push((scan.variable.clone(), op.clone()));
613 true
614 }
615 LogicalOperator::EdgeScan(scan) => {
616 relations.push((scan.variable.clone(), op.clone()));
617 true
618 }
619 LogicalOperator::Filter(filter) => {
620 self.collect_join_tree(&filter.input, relations, conditions)
622 }
623 LogicalOperator::Expand(expand) => {
624 relations.push((expand.to_variable.clone(), op.clone()));
627 true
628 }
629 _ => false,
630 }
631 }
632
633 fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
635 match expr {
636 LogicalExpression::Variable(v) => Some(v.clone()),
637 LogicalExpression::Property { variable, .. } => Some(variable.clone()),
638 _ => None,
639 }
640 }
641
642 fn optimize_join_order(
644 &self,
645 relations: &[(String, LogicalOperator)],
646 conditions: &[JoinInfo],
647 ) -> Option<LogicalOperator> {
648 use join_order::{DPccp, JoinGraphBuilder};
649
650 let mut builder = JoinGraphBuilder::new();
652
653 for (var, relation) in relations {
654 builder.add_relation(var, relation.clone());
655 }
656
657 for cond in conditions {
658 builder.add_join_condition(
659 &cond.left_var,
660 &cond.right_var,
661 cond.left_expr.clone(),
662 cond.right_expr.clone(),
663 );
664 }
665
666 let graph = builder.build();
667
668 let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
670 let plan = dpccp.optimize()?;
671
672 Some(plan.operator)
673 }
674
675 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
680 match op {
681 LogicalOperator::Filter(filter) => {
683 let optimized_input = self.push_filters_down(*filter.input);
684 self.try_push_filter_into(filter.predicate, optimized_input)
685 }
686 LogicalOperator::Return(mut ret) => {
688 ret.input = Box::new(self.push_filters_down(*ret.input));
689 LogicalOperator::Return(ret)
690 }
691 LogicalOperator::Project(mut proj) => {
692 proj.input = Box::new(self.push_filters_down(*proj.input));
693 LogicalOperator::Project(proj)
694 }
695 LogicalOperator::Limit(mut limit) => {
696 limit.input = Box::new(self.push_filters_down(*limit.input));
697 LogicalOperator::Limit(limit)
698 }
699 LogicalOperator::Skip(mut skip) => {
700 skip.input = Box::new(self.push_filters_down(*skip.input));
701 LogicalOperator::Skip(skip)
702 }
703 LogicalOperator::Sort(mut sort) => {
704 sort.input = Box::new(self.push_filters_down(*sort.input));
705 LogicalOperator::Sort(sort)
706 }
707 LogicalOperator::Distinct(mut distinct) => {
708 distinct.input = Box::new(self.push_filters_down(*distinct.input));
709 LogicalOperator::Distinct(distinct)
710 }
711 LogicalOperator::Expand(mut expand) => {
712 expand.input = Box::new(self.push_filters_down(*expand.input));
713 LogicalOperator::Expand(expand)
714 }
715 LogicalOperator::Join(mut join) => {
716 join.left = Box::new(self.push_filters_down(*join.left));
717 join.right = Box::new(self.push_filters_down(*join.right));
718 LogicalOperator::Join(join)
719 }
720 LogicalOperator::Aggregate(mut agg) => {
721 agg.input = Box::new(self.push_filters_down(*agg.input));
722 LogicalOperator::Aggregate(agg)
723 }
724 LogicalOperator::MapCollect(mut mc) => {
725 mc.input = Box::new(self.push_filters_down(*mc.input));
726 LogicalOperator::MapCollect(mc)
727 }
728 other => other,
730 }
731 }
732
733 fn try_push_filter_into(
738 &self,
739 predicate: LogicalExpression,
740 op: LogicalOperator,
741 ) -> LogicalOperator {
742 match op {
743 LogicalOperator::Project(mut proj) => {
745 let predicate_vars = self.extract_variables(&predicate);
746 let computed_vars = self.extract_projection_aliases(&proj.projections);
747
748 if predicate_vars.is_disjoint(&computed_vars) {
750 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
751 LogicalOperator::Project(proj)
752 } else {
753 LogicalOperator::Filter(FilterOp {
755 predicate,
756 input: Box::new(LogicalOperator::Project(proj)),
757 })
758 }
759 }
760
761 LogicalOperator::Return(mut ret) => {
763 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
764 LogicalOperator::Return(ret)
765 }
766
767 LogicalOperator::Expand(mut expand) => {
769 let predicate_vars = self.extract_variables(&predicate);
770
771 let mut introduced_vars = vec![&expand.to_variable];
776 if let Some(ref edge_var) = expand.edge_variable {
777 introduced_vars.push(edge_var);
778 }
779 if let Some(ref path_alias) = expand.path_alias {
780 introduced_vars.push(path_alias);
781 }
782
783 let uses_introduced_vars =
785 predicate_vars.iter().any(|v| introduced_vars.contains(&v));
786
787 if !uses_introduced_vars {
788 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
790 LogicalOperator::Expand(expand)
791 } else {
792 LogicalOperator::Filter(FilterOp {
794 predicate,
795 input: Box::new(LogicalOperator::Expand(expand)),
796 })
797 }
798 }
799
800 LogicalOperator::Join(mut join) => {
802 let predicate_vars = self.extract_variables(&predicate);
803 let left_vars = self.collect_output_variables(&join.left);
804 let right_vars = self.collect_output_variables(&join.right);
805
806 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
807 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
808
809 if uses_left && !uses_right {
810 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
812 LogicalOperator::Join(join)
813 } else if uses_right && !uses_left {
814 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
816 LogicalOperator::Join(join)
817 } else {
818 LogicalOperator::Filter(FilterOp {
820 predicate,
821 input: Box::new(LogicalOperator::Join(join)),
822 })
823 }
824 }
825
826 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
828 predicate,
829 input: Box::new(LogicalOperator::Aggregate(agg)),
830 }),
831
832 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
834 predicate,
835 input: Box::new(LogicalOperator::NodeScan(scan)),
836 }),
837
838 other => LogicalOperator::Filter(FilterOp {
840 predicate,
841 input: Box::new(other),
842 }),
843 }
844 }
845
846 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
848 let mut vars = HashSet::new();
849 Self::collect_output_variables_recursive(op, &mut vars);
850 vars
851 }
852
853 fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
855 match op {
856 LogicalOperator::NodeScan(scan) => {
857 vars.insert(scan.variable.clone());
858 }
859 LogicalOperator::EdgeScan(scan) => {
860 vars.insert(scan.variable.clone());
861 }
862 LogicalOperator::Expand(expand) => {
863 vars.insert(expand.to_variable.clone());
864 if let Some(edge_var) = &expand.edge_variable {
865 vars.insert(edge_var.clone());
866 }
867 Self::collect_output_variables_recursive(&expand.input, vars);
868 }
869 LogicalOperator::Filter(filter) => {
870 Self::collect_output_variables_recursive(&filter.input, vars);
871 }
872 LogicalOperator::Project(proj) => {
873 for p in &proj.projections {
874 if let Some(alias) = &p.alias {
875 vars.insert(alias.clone());
876 }
877 }
878 Self::collect_output_variables_recursive(&proj.input, vars);
879 }
880 LogicalOperator::Join(join) => {
881 Self::collect_output_variables_recursive(&join.left, vars);
882 Self::collect_output_variables_recursive(&join.right, vars);
883 }
884 LogicalOperator::Aggregate(agg) => {
885 for expr in &agg.group_by {
886 Self::collect_variables(expr, vars);
887 }
888 for agg_expr in &agg.aggregates {
889 if let Some(alias) = &agg_expr.alias {
890 vars.insert(alias.clone());
891 }
892 }
893 }
894 LogicalOperator::Return(ret) => {
895 Self::collect_output_variables_recursive(&ret.input, vars);
896 }
897 LogicalOperator::Limit(limit) => {
898 Self::collect_output_variables_recursive(&limit.input, vars);
899 }
900 LogicalOperator::Skip(skip) => {
901 Self::collect_output_variables_recursive(&skip.input, vars);
902 }
903 LogicalOperator::Sort(sort) => {
904 Self::collect_output_variables_recursive(&sort.input, vars);
905 }
906 LogicalOperator::Distinct(distinct) => {
907 Self::collect_output_variables_recursive(&distinct.input, vars);
908 }
909 _ => {}
910 }
911 }
912
913 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
915 let mut vars = HashSet::new();
916 Self::collect_variables(expr, &mut vars);
917 vars
918 }
919
920 fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
922 match expr {
923 LogicalExpression::Variable(name) => {
924 vars.insert(name.clone());
925 }
926 LogicalExpression::Property { variable, .. } => {
927 vars.insert(variable.clone());
928 }
929 LogicalExpression::Binary { left, right, .. } => {
930 Self::collect_variables(left, vars);
931 Self::collect_variables(right, vars);
932 }
933 LogicalExpression::Unary { operand, .. } => {
934 Self::collect_variables(operand, vars);
935 }
936 LogicalExpression::FunctionCall { args, .. } => {
937 for arg in args {
938 Self::collect_variables(arg, vars);
939 }
940 }
941 LogicalExpression::List(items) => {
942 for item in items {
943 Self::collect_variables(item, vars);
944 }
945 }
946 LogicalExpression::Map(pairs) => {
947 for (_, value) in pairs {
948 Self::collect_variables(value, vars);
949 }
950 }
951 LogicalExpression::IndexAccess { base, index } => {
952 Self::collect_variables(base, vars);
953 Self::collect_variables(index, vars);
954 }
955 LogicalExpression::SliceAccess { base, start, end } => {
956 Self::collect_variables(base, vars);
957 if let Some(s) = start {
958 Self::collect_variables(s, vars);
959 }
960 if let Some(e) = end {
961 Self::collect_variables(e, vars);
962 }
963 }
964 LogicalExpression::Case {
965 operand,
966 when_clauses,
967 else_clause,
968 } => {
969 if let Some(op) = operand {
970 Self::collect_variables(op, vars);
971 }
972 for (cond, result) in when_clauses {
973 Self::collect_variables(cond, vars);
974 Self::collect_variables(result, vars);
975 }
976 if let Some(else_expr) = else_clause {
977 Self::collect_variables(else_expr, vars);
978 }
979 }
980 LogicalExpression::Labels(var)
981 | LogicalExpression::Type(var)
982 | LogicalExpression::Id(var) => {
983 vars.insert(var.clone());
984 }
985 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
986 LogicalExpression::ListComprehension {
987 list_expr,
988 filter_expr,
989 map_expr,
990 ..
991 } => {
992 Self::collect_variables(list_expr, vars);
993 if let Some(filter) = filter_expr {
994 Self::collect_variables(filter, vars);
995 }
996 Self::collect_variables(map_expr, vars);
997 }
998 LogicalExpression::ListPredicate {
999 list_expr,
1000 predicate,
1001 ..
1002 } => {
1003 Self::collect_variables(list_expr, vars);
1004 Self::collect_variables(predicate, vars);
1005 }
1006 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1007 }
1009 LogicalExpression::PatternComprehension { projection, .. } => {
1010 Self::collect_variables(projection, vars);
1011 }
1012 LogicalExpression::MapProjection { base, entries } => {
1013 vars.insert(base.clone());
1014 for entry in entries {
1015 if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1016 Self::collect_variables(expr, vars);
1017 }
1018 }
1019 }
1020 LogicalExpression::Reduce {
1021 initial,
1022 list,
1023 expression,
1024 ..
1025 } => {
1026 Self::collect_variables(initial, vars);
1027 Self::collect_variables(list, vars);
1028 Self::collect_variables(expression, vars);
1029 }
1030 }
1031 }
1032
1033 fn extract_projection_aliases(
1035 &self,
1036 projections: &[crate::query::plan::Projection],
1037 ) -> HashSet<String> {
1038 projections.iter().filter_map(|p| p.alias.clone()).collect()
1039 }
1040}
1041
1042impl Default for Optimizer {
1043 fn default() -> Self {
1044 Self::new()
1045 }
1046}
1047
1048#[cfg(test)]
1049mod tests {
1050 use super::*;
1051 use crate::query::plan::{
1052 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1053 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1054 ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1055 };
1056 use grafeo_common::types::Value;
1057
// A filter directly above a node scan has nowhere further to go: the plan
// must still read Return -> Filter -> NodeScan after optimization.
#[test]
fn test_optimizer_filter_pushdown_simple() {
    const SHAPE: &str = "Expected Return -> Filter -> NodeScan structure";

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filtered = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(person_scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filtered),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{SHAPE}");
    };
    assert_eq!(scan.variable, "n");
}
1100
// A predicate on the expand's source variable `a` must move below the
// expand: Return -> Expand -> Filter -> NodeScan.
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
    const SHAPE: &str = "Expected Return -> Expand -> Filter -> NodeScan structure";

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    let source_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(LogicalOperator::Filter(FilterOp {
            predicate: source_age_over_30,
            input: Box::new(knows),
        })),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("{SHAPE}");
    };

    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1157
// A predicate on the expand's target variable `b` must stay above the
// expand: Return -> Filter -> Expand -> NodeScan.
#[test]
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
    const SHAPE: &str = "Expected Return -> Filter -> Expand -> NodeScan structure";

    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    let target_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(LogicalOperator::Filter(FilterOp {
            predicate: target_age_over_30,
            input: Box::new(knows),
        })),
    }));

    let optimized = Optimizer::new().optimize(plan).unwrap();

    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("{SHAPE}");
    };

    // The predicate is only checked when it has the expected property-on-left form.
    if let LogicalExpression::Binary { left, .. } = &filter.predicate {
        if let LogicalExpression::Property { variable, .. } = left.as_ref() {
            assert_eq!(variable, "b");
        }
    }

    let LogicalOperator::Expand(expand) = filter.input.as_ref() else {
        panic!("{SHAPE}");
    };
    let LogicalOperator::NodeScan(_) = expand.input.as_ref() else {
        panic!("{SHAPE}");
    };
}
1220
1221 #[test]
1222 fn test_optimizer_extract_variables() {
1223 let optimizer = Optimizer::new();
1224
1225 let expr = LogicalExpression::Binary {
1226 left: Box::new(LogicalExpression::Property {
1227 variable: "n".to_string(),
1228 property: "age".to_string(),
1229 }),
1230 op: BinaryOp::Gt,
1231 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1232 };
1233
1234 let vars = optimizer.extract_variables(&expr);
1235 assert_eq!(vars.len(), 1);
1236 assert!(vars.contains("n"));
1237 }
1238
1239 #[test]
1242 fn test_optimizer_default() {
1243 let optimizer = Optimizer::default();
1244 let plan = LogicalPlan::new(LogicalOperator::Empty);
1246 let result = optimizer.optimize(plan);
1247 assert!(result.is_ok());
1248 }
1249
1250 #[test]
1251 fn test_optimizer_with_filter_pushdown_disabled() {
1252 let optimizer = Optimizer::new().with_filter_pushdown(false);
1253
1254 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1255 items: vec![ReturnItem {
1256 expression: LogicalExpression::Variable("n".to_string()),
1257 alias: None,
1258 }],
1259 distinct: false,
1260 input: Box::new(LogicalOperator::Filter(FilterOp {
1261 predicate: LogicalExpression::Literal(Value::Bool(true)),
1262 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1263 variable: "n".to_string(),
1264 label: None,
1265 input: None,
1266 })),
1267 })),
1268 }));
1269
1270 let optimized = optimizer.optimize(plan).unwrap();
1271 if let LogicalOperator::Return(ret) = &optimized.root
1273 && let LogicalOperator::Filter(_) = ret.input.as_ref()
1274 {
1275 return;
1276 }
1277 panic!("Expected unchanged structure");
1278 }
1279
1280 #[test]
1281 fn test_optimizer_with_join_reorder_disabled() {
1282 let optimizer = Optimizer::new().with_join_reorder(false);
1283 assert!(
1284 optimizer
1285 .optimize(LogicalPlan::new(LogicalOperator::Empty))
1286 .is_ok()
1287 );
1288 }
1289
1290 #[test]
1291 fn test_optimizer_with_cost_model() {
1292 let cost_model = CostModel::new();
1293 let optimizer = Optimizer::new().with_cost_model(cost_model);
1294 assert!(
1295 optimizer
1296 .cost_model()
1297 .estimate(&LogicalOperator::Empty, 0.0)
1298 .total()
1299 < 0.001
1300 );
1301 }
1302
1303 #[test]
1304 fn test_optimizer_with_cardinality_estimator() {
1305 let mut estimator = CardinalityEstimator::new();
1306 estimator.add_table_stats("Test", TableStats::new(500));
1307 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1308
1309 let scan = LogicalOperator::NodeScan(NodeScanOp {
1310 variable: "n".to_string(),
1311 label: Some("Test".to_string()),
1312 input: None,
1313 });
1314 let plan = LogicalPlan::new(scan);
1315
1316 let cardinality = optimizer.estimate_cardinality(&plan);
1317 assert!((cardinality - 500.0).abs() < 0.001);
1318 }
1319
1320 #[test]
1321 fn test_optimizer_estimate_cost() {
1322 let optimizer = Optimizer::new();
1323 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1324 variable: "n".to_string(),
1325 label: None,
1326 input: None,
1327 }));
1328
1329 let cost = optimizer.estimate_cost(&plan);
1330 assert!(cost.total() > 0.0);
1331 }
1332
1333 #[test]
1336 fn test_filter_pushdown_through_project() {
1337 let optimizer = Optimizer::new();
1338
1339 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1340 predicate: LogicalExpression::Binary {
1341 left: Box::new(LogicalExpression::Property {
1342 variable: "n".to_string(),
1343 property: "age".to_string(),
1344 }),
1345 op: BinaryOp::Gt,
1346 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1347 },
1348 input: Box::new(LogicalOperator::Project(ProjectOp {
1349 projections: vec![Projection {
1350 expression: LogicalExpression::Variable("n".to_string()),
1351 alias: None,
1352 }],
1353 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1354 variable: "n".to_string(),
1355 label: None,
1356 input: None,
1357 })),
1358 })),
1359 }));
1360
1361 let optimized = optimizer.optimize(plan).unwrap();
1362
1363 if let LogicalOperator::Project(proj) = &optimized.root
1365 && let LogicalOperator::Filter(_) = proj.input.as_ref()
1366 {
1367 return;
1368 }
1369 panic!("Expected Project -> Filter structure");
1370 }
1371
1372 #[test]
1373 fn test_filter_not_pushed_through_project_with_alias() {
1374 let optimizer = Optimizer::new();
1375
1376 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1378 predicate: LogicalExpression::Binary {
1379 left: Box::new(LogicalExpression::Variable("x".to_string())),
1380 op: BinaryOp::Gt,
1381 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1382 },
1383 input: Box::new(LogicalOperator::Project(ProjectOp {
1384 projections: vec![Projection {
1385 expression: LogicalExpression::Property {
1386 variable: "n".to_string(),
1387 property: "age".to_string(),
1388 },
1389 alias: Some("x".to_string()),
1390 }],
1391 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1392 variable: "n".to_string(),
1393 label: None,
1394 input: None,
1395 })),
1396 })),
1397 }));
1398
1399 let optimized = optimizer.optimize(plan).unwrap();
1400
1401 if let LogicalOperator::Filter(filter) = &optimized.root
1403 && let LogicalOperator::Project(_) = filter.input.as_ref()
1404 {
1405 return;
1406 }
1407 panic!("Expected Filter -> Project structure");
1408 }
1409
1410 #[test]
1411 fn test_filter_pushdown_through_limit() {
1412 let optimizer = Optimizer::new();
1413
1414 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1415 predicate: LogicalExpression::Literal(Value::Bool(true)),
1416 input: Box::new(LogicalOperator::Limit(LimitOp {
1417 count: 10,
1418 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1419 variable: "n".to_string(),
1420 label: None,
1421 input: None,
1422 })),
1423 })),
1424 }));
1425
1426 let optimized = optimizer.optimize(plan).unwrap();
1427
1428 if let LogicalOperator::Filter(filter) = &optimized.root
1430 && let LogicalOperator::Limit(_) = filter.input.as_ref()
1431 {
1432 return;
1433 }
1434 panic!("Expected Filter -> Limit structure");
1435 }
1436
1437 #[test]
1438 fn test_filter_pushdown_through_sort() {
1439 let optimizer = Optimizer::new();
1440
1441 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1442 predicate: LogicalExpression::Literal(Value::Bool(true)),
1443 input: Box::new(LogicalOperator::Sort(SortOp {
1444 keys: vec![SortKey {
1445 expression: LogicalExpression::Variable("n".to_string()),
1446 order: SortOrder::Ascending,
1447 }],
1448 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1449 variable: "n".to_string(),
1450 label: None,
1451 input: None,
1452 })),
1453 })),
1454 }));
1455
1456 let optimized = optimizer.optimize(plan).unwrap();
1457
1458 if let LogicalOperator::Filter(filter) = &optimized.root
1460 && let LogicalOperator::Sort(_) = filter.input.as_ref()
1461 {
1462 return;
1463 }
1464 panic!("Expected Filter -> Sort structure");
1465 }
1466
1467 #[test]
1468 fn test_filter_pushdown_through_distinct() {
1469 let optimizer = Optimizer::new();
1470
1471 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1472 predicate: LogicalExpression::Literal(Value::Bool(true)),
1473 input: Box::new(LogicalOperator::Distinct(DistinctOp {
1474 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1475 variable: "n".to_string(),
1476 label: None,
1477 input: None,
1478 })),
1479 columns: None,
1480 })),
1481 }));
1482
1483 let optimized = optimizer.optimize(plan).unwrap();
1484
1485 if let LogicalOperator::Filter(filter) = &optimized.root
1487 && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1488 {
1489 return;
1490 }
1491 panic!("Expected Filter -> Distinct structure");
1492 }
1493
1494 #[test]
1495 fn test_filter_not_pushed_through_aggregate() {
1496 let optimizer = Optimizer::new();
1497
1498 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1499 predicate: LogicalExpression::Binary {
1500 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1501 op: BinaryOp::Gt,
1502 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1503 },
1504 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1505 group_by: vec![],
1506 aggregates: vec![AggregateExpr {
1507 function: AggregateFunction::Count,
1508 expression: None,
1509 distinct: false,
1510 alias: Some("cnt".to_string()),
1511 percentile: None,
1512 }],
1513 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1514 variable: "n".to_string(),
1515 label: None,
1516 input: None,
1517 })),
1518 having: None,
1519 })),
1520 }));
1521
1522 let optimized = optimizer.optimize(plan).unwrap();
1523
1524 if let LogicalOperator::Filter(filter) = &optimized.root
1526 && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1527 {
1528 return;
1529 }
1530 panic!("Expected Filter -> Aggregate structure");
1531 }
1532
1533 #[test]
1534 fn test_filter_pushdown_to_left_join_side() {
1535 let optimizer = Optimizer::new();
1536
1537 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1539 predicate: LogicalExpression::Binary {
1540 left: Box::new(LogicalExpression::Property {
1541 variable: "a".to_string(),
1542 property: "age".to_string(),
1543 }),
1544 op: BinaryOp::Gt,
1545 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1546 },
1547 input: Box::new(LogicalOperator::Join(JoinOp {
1548 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1549 variable: "a".to_string(),
1550 label: Some("Person".to_string()),
1551 input: None,
1552 })),
1553 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1554 variable: "b".to_string(),
1555 label: Some("Company".to_string()),
1556 input: None,
1557 })),
1558 join_type: JoinType::Inner,
1559 conditions: vec![],
1560 })),
1561 }));
1562
1563 let optimized = optimizer.optimize(plan).unwrap();
1564
1565 if let LogicalOperator::Join(join) = &optimized.root
1567 && let LogicalOperator::Filter(_) = join.left.as_ref()
1568 {
1569 return;
1570 }
1571 panic!("Expected Join with Filter on left side");
1572 }
1573
1574 #[test]
1575 fn test_filter_pushdown_to_right_join_side() {
1576 let optimizer = Optimizer::new();
1577
1578 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1580 predicate: LogicalExpression::Binary {
1581 left: Box::new(LogicalExpression::Property {
1582 variable: "b".to_string(),
1583 property: "name".to_string(),
1584 }),
1585 op: BinaryOp::Eq,
1586 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1587 },
1588 input: Box::new(LogicalOperator::Join(JoinOp {
1589 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1590 variable: "a".to_string(),
1591 label: Some("Person".to_string()),
1592 input: None,
1593 })),
1594 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1595 variable: "b".to_string(),
1596 label: Some("Company".to_string()),
1597 input: None,
1598 })),
1599 join_type: JoinType::Inner,
1600 conditions: vec![],
1601 })),
1602 }));
1603
1604 let optimized = optimizer.optimize(plan).unwrap();
1605
1606 if let LogicalOperator::Join(join) = &optimized.root
1608 && let LogicalOperator::Filter(_) = join.right.as_ref()
1609 {
1610 return;
1611 }
1612 panic!("Expected Join with Filter on right side");
1613 }
1614
1615 #[test]
1616 fn test_filter_not_pushed_when_uses_both_join_sides() {
1617 let optimizer = Optimizer::new();
1618
1619 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1621 predicate: LogicalExpression::Binary {
1622 left: Box::new(LogicalExpression::Property {
1623 variable: "a".to_string(),
1624 property: "id".to_string(),
1625 }),
1626 op: BinaryOp::Eq,
1627 right: Box::new(LogicalExpression::Property {
1628 variable: "b".to_string(),
1629 property: "a_id".to_string(),
1630 }),
1631 },
1632 input: Box::new(LogicalOperator::Join(JoinOp {
1633 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1634 variable: "a".to_string(),
1635 label: None,
1636 input: None,
1637 })),
1638 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1639 variable: "b".to_string(),
1640 label: None,
1641 input: None,
1642 })),
1643 join_type: JoinType::Inner,
1644 conditions: vec![],
1645 })),
1646 }));
1647
1648 let optimized = optimizer.optimize(plan).unwrap();
1649
1650 if let LogicalOperator::Filter(filter) = &optimized.root
1652 && let LogicalOperator::Join(_) = filter.input.as_ref()
1653 {
1654 return;
1655 }
1656 panic!("Expected Filter -> Join structure");
1657 }
1658
1659 #[test]
1662 fn test_extract_variables_from_variable() {
1663 let optimizer = Optimizer::new();
1664 let expr = LogicalExpression::Variable("x".to_string());
1665 let vars = optimizer.extract_variables(&expr);
1666 assert_eq!(vars.len(), 1);
1667 assert!(vars.contains("x"));
1668 }
1669
1670 #[test]
1671 fn test_extract_variables_from_unary() {
1672 let optimizer = Optimizer::new();
1673 let expr = LogicalExpression::Unary {
1674 op: UnaryOp::Not,
1675 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1676 };
1677 let vars = optimizer.extract_variables(&expr);
1678 assert_eq!(vars.len(), 1);
1679 assert!(vars.contains("x"));
1680 }
1681
1682 #[test]
1683 fn test_extract_variables_from_function_call() {
1684 let optimizer = Optimizer::new();
1685 let expr = LogicalExpression::FunctionCall {
1686 name: "length".to_string(),
1687 args: vec![
1688 LogicalExpression::Variable("a".to_string()),
1689 LogicalExpression::Variable("b".to_string()),
1690 ],
1691 distinct: false,
1692 };
1693 let vars = optimizer.extract_variables(&expr);
1694 assert_eq!(vars.len(), 2);
1695 assert!(vars.contains("a"));
1696 assert!(vars.contains("b"));
1697 }
1698
1699 #[test]
1700 fn test_extract_variables_from_list() {
1701 let optimizer = Optimizer::new();
1702 let expr = LogicalExpression::List(vec![
1703 LogicalExpression::Variable("a".to_string()),
1704 LogicalExpression::Literal(Value::Int64(1)),
1705 LogicalExpression::Variable("b".to_string()),
1706 ]);
1707 let vars = optimizer.extract_variables(&expr);
1708 assert_eq!(vars.len(), 2);
1709 assert!(vars.contains("a"));
1710 assert!(vars.contains("b"));
1711 }
1712
1713 #[test]
1714 fn test_extract_variables_from_map() {
1715 let optimizer = Optimizer::new();
1716 let expr = LogicalExpression::Map(vec![
1717 (
1718 "key1".to_string(),
1719 LogicalExpression::Variable("a".to_string()),
1720 ),
1721 (
1722 "key2".to_string(),
1723 LogicalExpression::Variable("b".to_string()),
1724 ),
1725 ]);
1726 let vars = optimizer.extract_variables(&expr);
1727 assert_eq!(vars.len(), 2);
1728 assert!(vars.contains("a"));
1729 assert!(vars.contains("b"));
1730 }
1731
1732 #[test]
1733 fn test_extract_variables_from_index_access() {
1734 let optimizer = Optimizer::new();
1735 let expr = LogicalExpression::IndexAccess {
1736 base: Box::new(LogicalExpression::Variable("list".to_string())),
1737 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1738 };
1739 let vars = optimizer.extract_variables(&expr);
1740 assert_eq!(vars.len(), 2);
1741 assert!(vars.contains("list"));
1742 assert!(vars.contains("idx"));
1743 }
1744
1745 #[test]
1746 fn test_extract_variables_from_slice_access() {
1747 let optimizer = Optimizer::new();
1748 let expr = LogicalExpression::SliceAccess {
1749 base: Box::new(LogicalExpression::Variable("list".to_string())),
1750 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1751 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1752 };
1753 let vars = optimizer.extract_variables(&expr);
1754 assert_eq!(vars.len(), 3);
1755 assert!(vars.contains("list"));
1756 assert!(vars.contains("s"));
1757 assert!(vars.contains("e"));
1758 }
1759
1760 #[test]
1761 fn test_extract_variables_from_case() {
1762 let optimizer = Optimizer::new();
1763 let expr = LogicalExpression::Case {
1764 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1765 when_clauses: vec![(
1766 LogicalExpression::Literal(Value::Int64(1)),
1767 LogicalExpression::Variable("a".to_string()),
1768 )],
1769 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1770 };
1771 let vars = optimizer.extract_variables(&expr);
1772 assert_eq!(vars.len(), 3);
1773 assert!(vars.contains("x"));
1774 assert!(vars.contains("a"));
1775 assert!(vars.contains("b"));
1776 }
1777
1778 #[test]
1779 fn test_extract_variables_from_labels() {
1780 let optimizer = Optimizer::new();
1781 let expr = LogicalExpression::Labels("n".to_string());
1782 let vars = optimizer.extract_variables(&expr);
1783 assert_eq!(vars.len(), 1);
1784 assert!(vars.contains("n"));
1785 }
1786
1787 #[test]
1788 fn test_extract_variables_from_type() {
1789 let optimizer = Optimizer::new();
1790 let expr = LogicalExpression::Type("e".to_string());
1791 let vars = optimizer.extract_variables(&expr);
1792 assert_eq!(vars.len(), 1);
1793 assert!(vars.contains("e"));
1794 }
1795
1796 #[test]
1797 fn test_extract_variables_from_id() {
1798 let optimizer = Optimizer::new();
1799 let expr = LogicalExpression::Id("n".to_string());
1800 let vars = optimizer.extract_variables(&expr);
1801 assert_eq!(vars.len(), 1);
1802 assert!(vars.contains("n"));
1803 }
1804
1805 #[test]
1806 fn test_extract_variables_from_list_comprehension() {
1807 let optimizer = Optimizer::new();
1808 let expr = LogicalExpression::ListComprehension {
1809 variable: "x".to_string(),
1810 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1811 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1812 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1813 };
1814 let vars = optimizer.extract_variables(&expr);
1815 assert!(vars.contains("items"));
1816 assert!(vars.contains("pred"));
1817 assert!(vars.contains("result"));
1818 }
1819
1820 #[test]
1821 fn test_extract_variables_from_literal_and_parameter() {
1822 let optimizer = Optimizer::new();
1823
1824 let literal = LogicalExpression::Literal(Value::Int64(42));
1825 assert!(optimizer.extract_variables(&literal).is_empty());
1826
1827 let param = LogicalExpression::Parameter("p".to_string());
1828 assert!(optimizer.extract_variables(¶m).is_empty());
1829 }
1830
1831 #[test]
1834 fn test_recursive_filter_pushdown_through_skip() {
1835 let optimizer = Optimizer::new();
1836
1837 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1838 items: vec![ReturnItem {
1839 expression: LogicalExpression::Variable("n".to_string()),
1840 alias: None,
1841 }],
1842 distinct: false,
1843 input: Box::new(LogicalOperator::Filter(FilterOp {
1844 predicate: LogicalExpression::Literal(Value::Bool(true)),
1845 input: Box::new(LogicalOperator::Skip(SkipOp {
1846 count: 5,
1847 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1848 variable: "n".to_string(),
1849 label: None,
1850 input: None,
1851 })),
1852 })),
1853 })),
1854 }));
1855
1856 let optimized = optimizer.optimize(plan).unwrap();
1857
1858 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1860 }
1861
1862 #[test]
1863 fn test_nested_filter_pushdown() {
1864 let optimizer = Optimizer::new();
1865
1866 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1868 items: vec![ReturnItem {
1869 expression: LogicalExpression::Variable("n".to_string()),
1870 alias: None,
1871 }],
1872 distinct: false,
1873 input: Box::new(LogicalOperator::Filter(FilterOp {
1874 predicate: LogicalExpression::Binary {
1875 left: Box::new(LogicalExpression::Property {
1876 variable: "n".to_string(),
1877 property: "x".to_string(),
1878 }),
1879 op: BinaryOp::Gt,
1880 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1881 },
1882 input: Box::new(LogicalOperator::Filter(FilterOp {
1883 predicate: LogicalExpression::Binary {
1884 left: Box::new(LogicalExpression::Property {
1885 variable: "n".to_string(),
1886 property: "y".to_string(),
1887 }),
1888 op: BinaryOp::Lt,
1889 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1890 },
1891 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1892 variable: "n".to_string(),
1893 label: None,
1894 input: None,
1895 })),
1896 })),
1897 })),
1898 }));
1899
1900 let optimized = optimizer.optimize(plan).unwrap();
1901 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1902 }
1903}