1pub mod cardinality;
18pub mod cost;
19pub mod join_order;
20
21pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
22pub use cost::{Cost, CostModel};
23pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
24
25use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
26use graphos_common::utils::error::Result;
27use std::collections::HashSet;
28
29pub struct Optimizer {
31 enable_filter_pushdown: bool,
33 enable_join_reorder: bool,
35 cost_model: CostModel,
37 card_estimator: CardinalityEstimator,
39}
40
41impl Optimizer {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 enable_filter_pushdown: true,
47 enable_join_reorder: true,
48 cost_model: CostModel::new(),
49 card_estimator: CardinalityEstimator::new(),
50 }
51 }
52
53 pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
55 self.enable_filter_pushdown = enabled;
56 self
57 }
58
59 pub fn with_join_reorder(mut self, enabled: bool) -> Self {
61 self.enable_join_reorder = enabled;
62 self
63 }
64
65 pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
67 self.cost_model = cost_model;
68 self
69 }
70
71 pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
73 self.card_estimator = estimator;
74 self
75 }
76
77 pub fn cost_model(&self) -> &CostModel {
79 &self.cost_model
80 }
81
82 pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
84 &self.card_estimator
85 }
86
87 pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
89 let cardinality = self.card_estimator.estimate(&plan.root);
90 self.cost_model.estimate(&plan.root, cardinality)
91 }
92
93 pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
95 self.card_estimator.estimate(&plan.root)
96 }
97
98 pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
104 let mut root = plan.root;
105
106 if self.enable_filter_pushdown {
108 root = self.push_filters_down(root);
109 }
110
111 Ok(LogicalPlan::new(root))
112 }
113
114 fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
119 match op {
120 LogicalOperator::Filter(filter) => {
122 let optimized_input = self.push_filters_down(*filter.input);
123 self.try_push_filter_into(filter.predicate, optimized_input)
124 }
125 LogicalOperator::Return(mut ret) => {
127 ret.input = Box::new(self.push_filters_down(*ret.input));
128 LogicalOperator::Return(ret)
129 }
130 LogicalOperator::Project(mut proj) => {
131 proj.input = Box::new(self.push_filters_down(*proj.input));
132 LogicalOperator::Project(proj)
133 }
134 LogicalOperator::Limit(mut limit) => {
135 limit.input = Box::new(self.push_filters_down(*limit.input));
136 LogicalOperator::Limit(limit)
137 }
138 LogicalOperator::Skip(mut skip) => {
139 skip.input = Box::new(self.push_filters_down(*skip.input));
140 LogicalOperator::Skip(skip)
141 }
142 LogicalOperator::Sort(mut sort) => {
143 sort.input = Box::new(self.push_filters_down(*sort.input));
144 LogicalOperator::Sort(sort)
145 }
146 LogicalOperator::Distinct(mut distinct) => {
147 distinct.input = Box::new(self.push_filters_down(*distinct.input));
148 LogicalOperator::Distinct(distinct)
149 }
150 LogicalOperator::Expand(mut expand) => {
151 expand.input = Box::new(self.push_filters_down(*expand.input));
152 LogicalOperator::Expand(expand)
153 }
154 LogicalOperator::Join(mut join) => {
155 join.left = Box::new(self.push_filters_down(*join.left));
156 join.right = Box::new(self.push_filters_down(*join.right));
157 LogicalOperator::Join(join)
158 }
159 LogicalOperator::Aggregate(mut agg) => {
160 agg.input = Box::new(self.push_filters_down(*agg.input));
161 LogicalOperator::Aggregate(agg)
162 }
163 other => other,
165 }
166 }
167
168 fn try_push_filter_into(
173 &self,
174 predicate: LogicalExpression,
175 op: LogicalOperator,
176 ) -> LogicalOperator {
177 match op {
178 LogicalOperator::Project(mut proj) => {
180 let predicate_vars = self.extract_variables(&predicate);
181 let computed_vars = self.extract_projection_aliases(&proj.projections);
182
183 if predicate_vars.is_disjoint(&computed_vars) {
185 proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
186 LogicalOperator::Project(proj)
187 } else {
188 LogicalOperator::Filter(FilterOp {
190 predicate,
191 input: Box::new(LogicalOperator::Project(proj)),
192 })
193 }
194 }
195
196 LogicalOperator::Return(mut ret) => {
198 ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
199 LogicalOperator::Return(ret)
200 }
201
202 LogicalOperator::Expand(mut expand) => {
204 let predicate_vars = self.extract_variables(&predicate);
205
206 let uses_only_source = predicate_vars.iter().all(|v| v == &expand.from_variable);
208
209 if uses_only_source {
210 expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
212 LogicalOperator::Expand(expand)
213 } else {
214 LogicalOperator::Filter(FilterOp {
216 predicate,
217 input: Box::new(LogicalOperator::Expand(expand)),
218 })
219 }
220 }
221
222 LogicalOperator::Join(mut join) => {
224 let predicate_vars = self.extract_variables(&predicate);
225 let left_vars = self.collect_output_variables(&join.left);
226 let right_vars = self.collect_output_variables(&join.right);
227
228 let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
229 let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
230
231 if uses_left && !uses_right {
232 join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
234 LogicalOperator::Join(join)
235 } else if uses_right && !uses_left {
236 join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
238 LogicalOperator::Join(join)
239 } else {
240 LogicalOperator::Filter(FilterOp {
242 predicate,
243 input: Box::new(LogicalOperator::Join(join)),
244 })
245 }
246 }
247
248 LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
250 predicate,
251 input: Box::new(LogicalOperator::Aggregate(agg)),
252 }),
253
254 LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
256 predicate,
257 input: Box::new(LogicalOperator::NodeScan(scan)),
258 }),
259
260 other => LogicalOperator::Filter(FilterOp {
262 predicate,
263 input: Box::new(other),
264 }),
265 }
266 }
267
268 fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
270 let mut vars = HashSet::new();
271 self.collect_output_variables_recursive(op, &mut vars);
272 vars
273 }
274
275 fn collect_output_variables_recursive(&self, op: &LogicalOperator, vars: &mut HashSet<String>) {
277 match op {
278 LogicalOperator::NodeScan(scan) => {
279 vars.insert(scan.variable.clone());
280 }
281 LogicalOperator::EdgeScan(scan) => {
282 vars.insert(scan.variable.clone());
283 }
284 LogicalOperator::Expand(expand) => {
285 vars.insert(expand.to_variable.clone());
286 if let Some(edge_var) = &expand.edge_variable {
287 vars.insert(edge_var.clone());
288 }
289 self.collect_output_variables_recursive(&expand.input, vars);
290 }
291 LogicalOperator::Filter(filter) => {
292 self.collect_output_variables_recursive(&filter.input, vars);
293 }
294 LogicalOperator::Project(proj) => {
295 for p in &proj.projections {
296 if let Some(alias) = &p.alias {
297 vars.insert(alias.clone());
298 }
299 }
300 self.collect_output_variables_recursive(&proj.input, vars);
301 }
302 LogicalOperator::Join(join) => {
303 self.collect_output_variables_recursive(&join.left, vars);
304 self.collect_output_variables_recursive(&join.right, vars);
305 }
306 LogicalOperator::Aggregate(agg) => {
307 for expr in &agg.group_by {
308 self.collect_variables(expr, vars);
309 }
310 for agg_expr in &agg.aggregates {
311 if let Some(alias) = &agg_expr.alias {
312 vars.insert(alias.clone());
313 }
314 }
315 }
316 LogicalOperator::Return(ret) => {
317 self.collect_output_variables_recursive(&ret.input, vars);
318 }
319 LogicalOperator::Limit(limit) => {
320 self.collect_output_variables_recursive(&limit.input, vars);
321 }
322 LogicalOperator::Skip(skip) => {
323 self.collect_output_variables_recursive(&skip.input, vars);
324 }
325 LogicalOperator::Sort(sort) => {
326 self.collect_output_variables_recursive(&sort.input, vars);
327 }
328 LogicalOperator::Distinct(distinct) => {
329 self.collect_output_variables_recursive(&distinct.input, vars);
330 }
331 _ => {}
332 }
333 }
334
335 fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
337 let mut vars = HashSet::new();
338 self.collect_variables(expr, &mut vars);
339 vars
340 }
341
342 fn collect_variables(&self, expr: &LogicalExpression, vars: &mut HashSet<String>) {
344 match expr {
345 LogicalExpression::Variable(name) => {
346 vars.insert(name.clone());
347 }
348 LogicalExpression::Property { variable, .. } => {
349 vars.insert(variable.clone());
350 }
351 LogicalExpression::Binary { left, right, .. } => {
352 self.collect_variables(left, vars);
353 self.collect_variables(right, vars);
354 }
355 LogicalExpression::Unary { operand, .. } => {
356 self.collect_variables(operand, vars);
357 }
358 LogicalExpression::FunctionCall { args, .. } => {
359 for arg in args {
360 self.collect_variables(arg, vars);
361 }
362 }
363 LogicalExpression::List(items) => {
364 for item in items {
365 self.collect_variables(item, vars);
366 }
367 }
368 LogicalExpression::Map(pairs) => {
369 for (_, value) in pairs {
370 self.collect_variables(value, vars);
371 }
372 }
373 LogicalExpression::IndexAccess { base, index } => {
374 self.collect_variables(base, vars);
375 self.collect_variables(index, vars);
376 }
377 LogicalExpression::SliceAccess { base, start, end } => {
378 self.collect_variables(base, vars);
379 if let Some(s) = start {
380 self.collect_variables(s, vars);
381 }
382 if let Some(e) = end {
383 self.collect_variables(e, vars);
384 }
385 }
386 LogicalExpression::Case {
387 operand,
388 when_clauses,
389 else_clause,
390 } => {
391 if let Some(op) = operand {
392 self.collect_variables(op, vars);
393 }
394 for (cond, result) in when_clauses {
395 self.collect_variables(cond, vars);
396 self.collect_variables(result, vars);
397 }
398 if let Some(else_expr) = else_clause {
399 self.collect_variables(else_expr, vars);
400 }
401 }
402 LogicalExpression::Labels(var)
403 | LogicalExpression::Type(var)
404 | LogicalExpression::Id(var) => {
405 vars.insert(var.clone());
406 }
407 LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
408 LogicalExpression::ListComprehension {
409 list_expr,
410 filter_expr,
411 map_expr,
412 ..
413 } => {
414 self.collect_variables(list_expr, vars);
415 if let Some(filter) = filter_expr {
416 self.collect_variables(filter, vars);
417 }
418 self.collect_variables(map_expr, vars);
419 }
420 LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
421 }
423 }
424 }
425
426 fn extract_projection_aliases(
428 &self,
429 projections: &[crate::query::plan::Projection],
430 ) -> HashSet<String> {
431 projections.iter().filter_map(|p| p.alias.clone()).collect()
432 }
433}
434
435impl Default for Optimizer {
436 fn default() -> Self {
437 Self::new()
438 }
439}
440
441#[cfg(test)]
442mod tests {
443 use super::*;
444 use crate::query::plan::{
445 AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
446 ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
447 ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
448 };
449 use graphos_common::types::Value;
450
451 #[test]
452 fn test_optimizer_filter_pushdown_simple() {
453 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
458 items: vec![ReturnItem {
459 expression: LogicalExpression::Variable("n".to_string()),
460 alias: None,
461 }],
462 distinct: false,
463 input: Box::new(LogicalOperator::Filter(FilterOp {
464 predicate: LogicalExpression::Binary {
465 left: Box::new(LogicalExpression::Property {
466 variable: "n".to_string(),
467 property: "age".to_string(),
468 }),
469 op: BinaryOp::Gt,
470 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
471 },
472 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
473 variable: "n".to_string(),
474 label: Some("Person".to_string()),
475 input: None,
476 })),
477 })),
478 }));
479
480 let optimizer = Optimizer::new();
481 let optimized = optimizer.optimize(plan).unwrap();
482
483 if let LogicalOperator::Return(ret) = &optimized.root {
485 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
486 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
487 assert_eq!(scan.variable, "n");
488 return;
489 }
490 }
491 }
492 panic!("Expected Return -> Filter -> NodeScan structure");
493 }
494
495 #[test]
496 fn test_optimizer_filter_pushdown_through_expand() {
497 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
501 items: vec![ReturnItem {
502 expression: LogicalExpression::Variable("b".to_string()),
503 alias: None,
504 }],
505 distinct: false,
506 input: Box::new(LogicalOperator::Filter(FilterOp {
507 predicate: LogicalExpression::Binary {
508 left: Box::new(LogicalExpression::Property {
509 variable: "a".to_string(),
510 property: "age".to_string(),
511 }),
512 op: BinaryOp::Gt,
513 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
514 },
515 input: Box::new(LogicalOperator::Expand(ExpandOp {
516 from_variable: "a".to_string(),
517 to_variable: "b".to_string(),
518 edge_variable: None,
519 direction: ExpandDirection::Outgoing,
520 edge_type: Some("KNOWS".to_string()),
521 min_hops: 1,
522 max_hops: Some(1),
523 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
524 variable: "a".to_string(),
525 label: Some("Person".to_string()),
526 input: None,
527 })),
528 })),
529 })),
530 }));
531
532 let optimizer = Optimizer::new();
533 let optimized = optimizer.optimize(plan).unwrap();
534
535 if let LogicalOperator::Return(ret) = &optimized.root {
538 if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
539 if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
540 if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
541 assert_eq!(scan.variable, "a");
542 assert_eq!(expand.from_variable, "a");
543 assert_eq!(expand.to_variable, "b");
544 return;
545 }
546 }
547 }
548 }
549 panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
550 }
551
552 #[test]
553 fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
554 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
558 items: vec![ReturnItem {
559 expression: LogicalExpression::Variable("a".to_string()),
560 alias: None,
561 }],
562 distinct: false,
563 input: Box::new(LogicalOperator::Filter(FilterOp {
564 predicate: LogicalExpression::Binary {
565 left: Box::new(LogicalExpression::Property {
566 variable: "b".to_string(),
567 property: "age".to_string(),
568 }),
569 op: BinaryOp::Gt,
570 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
571 },
572 input: Box::new(LogicalOperator::Expand(ExpandOp {
573 from_variable: "a".to_string(),
574 to_variable: "b".to_string(),
575 edge_variable: None,
576 direction: ExpandDirection::Outgoing,
577 edge_type: Some("KNOWS".to_string()),
578 min_hops: 1,
579 max_hops: Some(1),
580 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
581 variable: "a".to_string(),
582 label: Some("Person".to_string()),
583 input: None,
584 })),
585 })),
586 })),
587 }));
588
589 let optimizer = Optimizer::new();
590 let optimized = optimizer.optimize(plan).unwrap();
591
592 if let LogicalOperator::Return(ret) = &optimized.root {
595 if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
596 if let LogicalExpression::Binary { left, .. } = &filter.predicate {
598 if let LogicalExpression::Property { variable, .. } = left.as_ref() {
599 assert_eq!(variable, "b");
600 }
601 }
602
603 if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
604 if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
605 return;
606 }
607 }
608 }
609 }
610 panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
611 }
612
613 #[test]
614 fn test_optimizer_extract_variables() {
615 let optimizer = Optimizer::new();
616
617 let expr = LogicalExpression::Binary {
618 left: Box::new(LogicalExpression::Property {
619 variable: "n".to_string(),
620 property: "age".to_string(),
621 }),
622 op: BinaryOp::Gt,
623 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
624 };
625
626 let vars = optimizer.extract_variables(&expr);
627 assert_eq!(vars.len(), 1);
628 assert!(vars.contains("n"));
629 }
630
631 #[test]
634 fn test_optimizer_default() {
635 let optimizer = Optimizer::default();
636 let plan = LogicalPlan::new(LogicalOperator::Empty);
638 let result = optimizer.optimize(plan);
639 assert!(result.is_ok());
640 }
641
642 #[test]
643 fn test_optimizer_with_filter_pushdown_disabled() {
644 let optimizer = Optimizer::new().with_filter_pushdown(false);
645
646 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
647 items: vec![ReturnItem {
648 expression: LogicalExpression::Variable("n".to_string()),
649 alias: None,
650 }],
651 distinct: false,
652 input: Box::new(LogicalOperator::Filter(FilterOp {
653 predicate: LogicalExpression::Literal(Value::Bool(true)),
654 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
655 variable: "n".to_string(),
656 label: None,
657 input: None,
658 })),
659 })),
660 }));
661
662 let optimized = optimizer.optimize(plan).unwrap();
663 if let LogicalOperator::Return(ret) = &optimized.root {
665 if let LogicalOperator::Filter(_) = ret.input.as_ref() {
666 return;
667 }
668 }
669 panic!("Expected unchanged structure");
670 }
671
672 #[test]
673 fn test_optimizer_with_join_reorder_disabled() {
674 let optimizer = Optimizer::new().with_join_reorder(false);
675 assert!(
676 optimizer
677 .optimize(LogicalPlan::new(LogicalOperator::Empty))
678 .is_ok()
679 );
680 }
681
682 #[test]
683 fn test_optimizer_with_cost_model() {
684 let cost_model = CostModel::new();
685 let optimizer = Optimizer::new().with_cost_model(cost_model);
686 assert!(
687 optimizer
688 .cost_model()
689 .estimate(&LogicalOperator::Empty, 0.0)
690 .total()
691 < 0.001
692 );
693 }
694
695 #[test]
696 fn test_optimizer_with_cardinality_estimator() {
697 let mut estimator = CardinalityEstimator::new();
698 estimator.add_table_stats("Test", TableStats::new(500));
699 let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
700
701 let scan = LogicalOperator::NodeScan(NodeScanOp {
702 variable: "n".to_string(),
703 label: Some("Test".to_string()),
704 input: None,
705 });
706 let plan = LogicalPlan::new(scan);
707
708 let cardinality = optimizer.estimate_cardinality(&plan);
709 assert!((cardinality - 500.0).abs() < 0.001);
710 }
711
712 #[test]
713 fn test_optimizer_estimate_cost() {
714 let optimizer = Optimizer::new();
715 let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
716 variable: "n".to_string(),
717 label: None,
718 input: None,
719 }));
720
721 let cost = optimizer.estimate_cost(&plan);
722 assert!(cost.total() > 0.0);
723 }
724
725 #[test]
728 fn test_filter_pushdown_through_project() {
729 let optimizer = Optimizer::new();
730
731 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
732 predicate: LogicalExpression::Binary {
733 left: Box::new(LogicalExpression::Property {
734 variable: "n".to_string(),
735 property: "age".to_string(),
736 }),
737 op: BinaryOp::Gt,
738 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
739 },
740 input: Box::new(LogicalOperator::Project(ProjectOp {
741 projections: vec![Projection {
742 expression: LogicalExpression::Variable("n".to_string()),
743 alias: None,
744 }],
745 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
746 variable: "n".to_string(),
747 label: None,
748 input: None,
749 })),
750 })),
751 }));
752
753 let optimized = optimizer.optimize(plan).unwrap();
754
755 if let LogicalOperator::Project(proj) = &optimized.root {
757 if let LogicalOperator::Filter(_) = proj.input.as_ref() {
758 return;
759 }
760 }
761 panic!("Expected Project -> Filter structure");
762 }
763
764 #[test]
765 fn test_filter_not_pushed_through_project_with_alias() {
766 let optimizer = Optimizer::new();
767
768 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
770 predicate: LogicalExpression::Binary {
771 left: Box::new(LogicalExpression::Variable("x".to_string())),
772 op: BinaryOp::Gt,
773 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
774 },
775 input: Box::new(LogicalOperator::Project(ProjectOp {
776 projections: vec![Projection {
777 expression: LogicalExpression::Property {
778 variable: "n".to_string(),
779 property: "age".to_string(),
780 },
781 alias: Some("x".to_string()),
782 }],
783 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
784 variable: "n".to_string(),
785 label: None,
786 input: None,
787 })),
788 })),
789 }));
790
791 let optimized = optimizer.optimize(plan).unwrap();
792
793 if let LogicalOperator::Filter(filter) = &optimized.root {
795 if let LogicalOperator::Project(_) = filter.input.as_ref() {
796 return;
797 }
798 }
799 panic!("Expected Filter -> Project structure");
800 }
801
802 #[test]
803 fn test_filter_pushdown_through_limit() {
804 let optimizer = Optimizer::new();
805
806 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
807 predicate: LogicalExpression::Literal(Value::Bool(true)),
808 input: Box::new(LogicalOperator::Limit(LimitOp {
809 count: 10,
810 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
811 variable: "n".to_string(),
812 label: None,
813 input: None,
814 })),
815 })),
816 }));
817
818 let optimized = optimizer.optimize(plan).unwrap();
819
820 if let LogicalOperator::Filter(filter) = &optimized.root {
822 if let LogicalOperator::Limit(_) = filter.input.as_ref() {
823 return;
824 }
825 }
826 panic!("Expected Filter -> Limit structure");
827 }
828
829 #[test]
830 fn test_filter_pushdown_through_sort() {
831 let optimizer = Optimizer::new();
832
833 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
834 predicate: LogicalExpression::Literal(Value::Bool(true)),
835 input: Box::new(LogicalOperator::Sort(SortOp {
836 keys: vec![SortKey {
837 expression: LogicalExpression::Variable("n".to_string()),
838 order: SortOrder::Ascending,
839 }],
840 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
841 variable: "n".to_string(),
842 label: None,
843 input: None,
844 })),
845 })),
846 }));
847
848 let optimized = optimizer.optimize(plan).unwrap();
849
850 if let LogicalOperator::Filter(filter) = &optimized.root {
852 if let LogicalOperator::Sort(_) = filter.input.as_ref() {
853 return;
854 }
855 }
856 panic!("Expected Filter -> Sort structure");
857 }
858
859 #[test]
860 fn test_filter_pushdown_through_distinct() {
861 let optimizer = Optimizer::new();
862
863 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
864 predicate: LogicalExpression::Literal(Value::Bool(true)),
865 input: Box::new(LogicalOperator::Distinct(DistinctOp {
866 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
867 variable: "n".to_string(),
868 label: None,
869 input: None,
870 })),
871 })),
872 }));
873
874 let optimized = optimizer.optimize(plan).unwrap();
875
876 if let LogicalOperator::Filter(filter) = &optimized.root {
878 if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
879 return;
880 }
881 }
882 panic!("Expected Filter -> Distinct structure");
883 }
884
885 #[test]
886 fn test_filter_not_pushed_through_aggregate() {
887 let optimizer = Optimizer::new();
888
889 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
890 predicate: LogicalExpression::Binary {
891 left: Box::new(LogicalExpression::Variable("cnt".to_string())),
892 op: BinaryOp::Gt,
893 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
894 },
895 input: Box::new(LogicalOperator::Aggregate(AggregateOp {
896 group_by: vec![],
897 aggregates: vec![AggregateExpr {
898 function: AggregateFunction::Count,
899 expression: None,
900 distinct: false,
901 alias: Some("cnt".to_string()),
902 }],
903 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
904 variable: "n".to_string(),
905 label: None,
906 input: None,
907 })),
908 })),
909 }));
910
911 let optimized = optimizer.optimize(plan).unwrap();
912
913 if let LogicalOperator::Filter(filter) = &optimized.root {
915 if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
916 return;
917 }
918 }
919 panic!("Expected Filter -> Aggregate structure");
920 }
921
922 #[test]
923 fn test_filter_pushdown_to_left_join_side() {
924 let optimizer = Optimizer::new();
925
926 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
928 predicate: LogicalExpression::Binary {
929 left: Box::new(LogicalExpression::Property {
930 variable: "a".to_string(),
931 property: "age".to_string(),
932 }),
933 op: BinaryOp::Gt,
934 right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
935 },
936 input: Box::new(LogicalOperator::Join(JoinOp {
937 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
938 variable: "a".to_string(),
939 label: Some("Person".to_string()),
940 input: None,
941 })),
942 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
943 variable: "b".to_string(),
944 label: Some("Company".to_string()),
945 input: None,
946 })),
947 join_type: JoinType::Inner,
948 conditions: vec![],
949 })),
950 }));
951
952 let optimized = optimizer.optimize(plan).unwrap();
953
954 if let LogicalOperator::Join(join) = &optimized.root {
956 if let LogicalOperator::Filter(_) = join.left.as_ref() {
957 return;
958 }
959 }
960 panic!("Expected Join with Filter on left side");
961 }
962
963 #[test]
964 fn test_filter_pushdown_to_right_join_side() {
965 let optimizer = Optimizer::new();
966
967 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
969 predicate: LogicalExpression::Binary {
970 left: Box::new(LogicalExpression::Property {
971 variable: "b".to_string(),
972 property: "name".to_string(),
973 }),
974 op: BinaryOp::Eq,
975 right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
976 },
977 input: Box::new(LogicalOperator::Join(JoinOp {
978 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
979 variable: "a".to_string(),
980 label: Some("Person".to_string()),
981 input: None,
982 })),
983 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
984 variable: "b".to_string(),
985 label: Some("Company".to_string()),
986 input: None,
987 })),
988 join_type: JoinType::Inner,
989 conditions: vec![],
990 })),
991 }));
992
993 let optimized = optimizer.optimize(plan).unwrap();
994
995 if let LogicalOperator::Join(join) = &optimized.root {
997 if let LogicalOperator::Filter(_) = join.right.as_ref() {
998 return;
999 }
1000 }
1001 panic!("Expected Join with Filter on right side");
1002 }
1003
1004 #[test]
1005 fn test_filter_not_pushed_when_uses_both_join_sides() {
1006 let optimizer = Optimizer::new();
1007
1008 let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1010 predicate: LogicalExpression::Binary {
1011 left: Box::new(LogicalExpression::Property {
1012 variable: "a".to_string(),
1013 property: "id".to_string(),
1014 }),
1015 op: BinaryOp::Eq,
1016 right: Box::new(LogicalExpression::Property {
1017 variable: "b".to_string(),
1018 property: "a_id".to_string(),
1019 }),
1020 },
1021 input: Box::new(LogicalOperator::Join(JoinOp {
1022 left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1023 variable: "a".to_string(),
1024 label: None,
1025 input: None,
1026 })),
1027 right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1028 variable: "b".to_string(),
1029 label: None,
1030 input: None,
1031 })),
1032 join_type: JoinType::Inner,
1033 conditions: vec![],
1034 })),
1035 }));
1036
1037 let optimized = optimizer.optimize(plan).unwrap();
1038
1039 if let LogicalOperator::Filter(filter) = &optimized.root {
1041 if let LogicalOperator::Join(_) = filter.input.as_ref() {
1042 return;
1043 }
1044 }
1045 panic!("Expected Filter -> Join structure");
1046 }
1047
1048 #[test]
1051 fn test_extract_variables_from_variable() {
1052 let optimizer = Optimizer::new();
1053 let expr = LogicalExpression::Variable("x".to_string());
1054 let vars = optimizer.extract_variables(&expr);
1055 assert_eq!(vars.len(), 1);
1056 assert!(vars.contains("x"));
1057 }
1058
1059 #[test]
1060 fn test_extract_variables_from_unary() {
1061 let optimizer = Optimizer::new();
1062 let expr = LogicalExpression::Unary {
1063 op: UnaryOp::Not,
1064 operand: Box::new(LogicalExpression::Variable("x".to_string())),
1065 };
1066 let vars = optimizer.extract_variables(&expr);
1067 assert_eq!(vars.len(), 1);
1068 assert!(vars.contains("x"));
1069 }
1070
1071 #[test]
1072 fn test_extract_variables_from_function_call() {
1073 let optimizer = Optimizer::new();
1074 let expr = LogicalExpression::FunctionCall {
1075 name: "length".to_string(),
1076 args: vec![
1077 LogicalExpression::Variable("a".to_string()),
1078 LogicalExpression::Variable("b".to_string()),
1079 ],
1080 };
1081 let vars = optimizer.extract_variables(&expr);
1082 assert_eq!(vars.len(), 2);
1083 assert!(vars.contains("a"));
1084 assert!(vars.contains("b"));
1085 }
1086
1087 #[test]
1088 fn test_extract_variables_from_list() {
1089 let optimizer = Optimizer::new();
1090 let expr = LogicalExpression::List(vec![
1091 LogicalExpression::Variable("a".to_string()),
1092 LogicalExpression::Literal(Value::Int64(1)),
1093 LogicalExpression::Variable("b".to_string()),
1094 ]);
1095 let vars = optimizer.extract_variables(&expr);
1096 assert_eq!(vars.len(), 2);
1097 assert!(vars.contains("a"));
1098 assert!(vars.contains("b"));
1099 }
1100
1101 #[test]
1102 fn test_extract_variables_from_map() {
1103 let optimizer = Optimizer::new();
1104 let expr = LogicalExpression::Map(vec![
1105 (
1106 "key1".to_string(),
1107 LogicalExpression::Variable("a".to_string()),
1108 ),
1109 (
1110 "key2".to_string(),
1111 LogicalExpression::Variable("b".to_string()),
1112 ),
1113 ]);
1114 let vars = optimizer.extract_variables(&expr);
1115 assert_eq!(vars.len(), 2);
1116 assert!(vars.contains("a"));
1117 assert!(vars.contains("b"));
1118 }
1119
1120 #[test]
1121 fn test_extract_variables_from_index_access() {
1122 let optimizer = Optimizer::new();
1123 let expr = LogicalExpression::IndexAccess {
1124 base: Box::new(LogicalExpression::Variable("list".to_string())),
1125 index: Box::new(LogicalExpression::Variable("idx".to_string())),
1126 };
1127 let vars = optimizer.extract_variables(&expr);
1128 assert_eq!(vars.len(), 2);
1129 assert!(vars.contains("list"));
1130 assert!(vars.contains("idx"));
1131 }
1132
1133 #[test]
1134 fn test_extract_variables_from_slice_access() {
1135 let optimizer = Optimizer::new();
1136 let expr = LogicalExpression::SliceAccess {
1137 base: Box::new(LogicalExpression::Variable("list".to_string())),
1138 start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1139 end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1140 };
1141 let vars = optimizer.extract_variables(&expr);
1142 assert_eq!(vars.len(), 3);
1143 assert!(vars.contains("list"));
1144 assert!(vars.contains("s"));
1145 assert!(vars.contains("e"));
1146 }
1147
1148 #[test]
1149 fn test_extract_variables_from_case() {
1150 let optimizer = Optimizer::new();
1151 let expr = LogicalExpression::Case {
1152 operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1153 when_clauses: vec![(
1154 LogicalExpression::Literal(Value::Int64(1)),
1155 LogicalExpression::Variable("a".to_string()),
1156 )],
1157 else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1158 };
1159 let vars = optimizer.extract_variables(&expr);
1160 assert_eq!(vars.len(), 3);
1161 assert!(vars.contains("x"));
1162 assert!(vars.contains("a"));
1163 assert!(vars.contains("b"));
1164 }
1165
1166 #[test]
1167 fn test_extract_variables_from_labels() {
1168 let optimizer = Optimizer::new();
1169 let expr = LogicalExpression::Labels("n".to_string());
1170 let vars = optimizer.extract_variables(&expr);
1171 assert_eq!(vars.len(), 1);
1172 assert!(vars.contains("n"));
1173 }
1174
1175 #[test]
1176 fn test_extract_variables_from_type() {
1177 let optimizer = Optimizer::new();
1178 let expr = LogicalExpression::Type("e".to_string());
1179 let vars = optimizer.extract_variables(&expr);
1180 assert_eq!(vars.len(), 1);
1181 assert!(vars.contains("e"));
1182 }
1183
1184 #[test]
1185 fn test_extract_variables_from_id() {
1186 let optimizer = Optimizer::new();
1187 let expr = LogicalExpression::Id("n".to_string());
1188 let vars = optimizer.extract_variables(&expr);
1189 assert_eq!(vars.len(), 1);
1190 assert!(vars.contains("n"));
1191 }
1192
1193 #[test]
1194 fn test_extract_variables_from_list_comprehension() {
1195 let optimizer = Optimizer::new();
1196 let expr = LogicalExpression::ListComprehension {
1197 variable: "x".to_string(),
1198 list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1199 filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1200 map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1201 };
1202 let vars = optimizer.extract_variables(&expr);
1203 assert!(vars.contains("items"));
1204 assert!(vars.contains("pred"));
1205 assert!(vars.contains("result"));
1206 }
1207
1208 #[test]
1209 fn test_extract_variables_from_literal_and_parameter() {
1210 let optimizer = Optimizer::new();
1211
1212 let literal = LogicalExpression::Literal(Value::Int64(42));
1213 assert!(optimizer.extract_variables(&literal).is_empty());
1214
1215 let param = LogicalExpression::Parameter("p".to_string());
1216 assert!(optimizer.extract_variables(¶m).is_empty());
1217 }
1218
1219 #[test]
1222 fn test_recursive_filter_pushdown_through_skip() {
1223 let optimizer = Optimizer::new();
1224
1225 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1226 items: vec![ReturnItem {
1227 expression: LogicalExpression::Variable("n".to_string()),
1228 alias: None,
1229 }],
1230 distinct: false,
1231 input: Box::new(LogicalOperator::Filter(FilterOp {
1232 predicate: LogicalExpression::Literal(Value::Bool(true)),
1233 input: Box::new(LogicalOperator::Skip(SkipOp {
1234 count: 5,
1235 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1236 variable: "n".to_string(),
1237 label: None,
1238 input: None,
1239 })),
1240 })),
1241 })),
1242 }));
1243
1244 let optimized = optimizer.optimize(plan).unwrap();
1245
1246 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1248 }
1249
1250 #[test]
1251 fn test_nested_filter_pushdown() {
1252 let optimizer = Optimizer::new();
1253
1254 let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1256 items: vec![ReturnItem {
1257 expression: LogicalExpression::Variable("n".to_string()),
1258 alias: None,
1259 }],
1260 distinct: false,
1261 input: Box::new(LogicalOperator::Filter(FilterOp {
1262 predicate: LogicalExpression::Binary {
1263 left: Box::new(LogicalExpression::Property {
1264 variable: "n".to_string(),
1265 property: "x".to_string(),
1266 }),
1267 op: BinaryOp::Gt,
1268 right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1269 },
1270 input: Box::new(LogicalOperator::Filter(FilterOp {
1271 predicate: LogicalExpression::Binary {
1272 left: Box::new(LogicalExpression::Property {
1273 variable: "n".to_string(),
1274 property: "y".to_string(),
1275 }),
1276 op: BinaryOp::Lt,
1277 right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1278 },
1279 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1280 variable: "n".to_string(),
1281 label: None,
1282 input: None,
1283 })),
1284 })),
1285 })),
1286 }));
1287
1288 let optimized = optimizer.optimize(plan).unwrap();
1289 assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1290 }
1291}