Skip to main content

grafeo_engine/query/optimizer/
mod.rs

//! Query optimizer.
//!
//! Transforms logical plans for better performance.
//!
//! ## Optimization Rules
//!
//! - **Filter Pushdown**: Pushes filters closer to scans to reduce data early
//! - **Predicate Simplification**: Simplifies constant expressions
//! - **Join Reordering**: Optimizes join order using DPccp algorithm
//!
//! ## Submodules
//!
//! - [`cost`] - Cost model for estimating operator costs
//! - [`cardinality`] - Cardinality estimation for query operators
//! - [`join_order`] - DPccp join ordering algorithm
16
17pub mod cardinality;
18pub mod cost;
19pub mod join_order;
20
21pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
22pub use cost::{Cost, CostModel};
23pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
24
25use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
26use grafeo_common::utils::error::Result;
27use std::collections::HashSet;
28
29/// Query optimizer that transforms logical plans for better performance.
30pub struct Optimizer {
31    /// Whether to enable filter pushdown.
32    enable_filter_pushdown: bool,
33    /// Whether to enable join reordering.
34    enable_join_reorder: bool,
35    /// Cost model for estimation.
36    cost_model: CostModel,
37    /// Cardinality estimator.
38    card_estimator: CardinalityEstimator,
39}
40
41impl Optimizer {
42    /// Creates a new optimizer with default settings.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            enable_filter_pushdown: true,
47            enable_join_reorder: true,
48            cost_model: CostModel::new(),
49            card_estimator: CardinalityEstimator::new(),
50        }
51    }
52
53    /// Enables or disables filter pushdown.
54    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
55        self.enable_filter_pushdown = enabled;
56        self
57    }
58
59    /// Enables or disables join reordering.
60    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
61        self.enable_join_reorder = enabled;
62        self
63    }
64
65    /// Sets the cost model.
66    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
67        self.cost_model = cost_model;
68        self
69    }
70
71    /// Sets the cardinality estimator.
72    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
73        self.card_estimator = estimator;
74        self
75    }
76
77    /// Returns a reference to the cost model.
78    pub fn cost_model(&self) -> &CostModel {
79        &self.cost_model
80    }
81
82    /// Returns a reference to the cardinality estimator.
83    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
84        &self.card_estimator
85    }
86
87    /// Estimates the cost of a plan.
88    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
89        let cardinality = self.card_estimator.estimate(&plan.root);
90        self.cost_model.estimate(&plan.root, cardinality)
91    }
92
93    /// Estimates the cardinality of a plan.
94    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
95        self.card_estimator.estimate(&plan.root)
96    }
97
98    /// Optimizes a logical plan.
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if optimization fails.
103    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
104        let mut root = plan.root;
105
106        // Apply optimization rules
107        if self.enable_filter_pushdown {
108            root = self.push_filters_down(root);
109        }
110
111        Ok(LogicalPlan::new(root))
112    }
113
114    /// Pushes filters down the operator tree.
115    ///
116    /// This optimization moves filter predicates as close to the data source
117    /// as possible to reduce the amount of data processed by upper operators.
118    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
119        match op {
120            // For Filter operators, try to push the predicate into the child
121            LogicalOperator::Filter(filter) => {
122                let optimized_input = self.push_filters_down(*filter.input);
123                self.try_push_filter_into(filter.predicate, optimized_input)
124            }
125            // Recursively optimize children for other operators
126            LogicalOperator::Return(mut ret) => {
127                ret.input = Box::new(self.push_filters_down(*ret.input));
128                LogicalOperator::Return(ret)
129            }
130            LogicalOperator::Project(mut proj) => {
131                proj.input = Box::new(self.push_filters_down(*proj.input));
132                LogicalOperator::Project(proj)
133            }
134            LogicalOperator::Limit(mut limit) => {
135                limit.input = Box::new(self.push_filters_down(*limit.input));
136                LogicalOperator::Limit(limit)
137            }
138            LogicalOperator::Skip(mut skip) => {
139                skip.input = Box::new(self.push_filters_down(*skip.input));
140                LogicalOperator::Skip(skip)
141            }
142            LogicalOperator::Sort(mut sort) => {
143                sort.input = Box::new(self.push_filters_down(*sort.input));
144                LogicalOperator::Sort(sort)
145            }
146            LogicalOperator::Distinct(mut distinct) => {
147                distinct.input = Box::new(self.push_filters_down(*distinct.input));
148                LogicalOperator::Distinct(distinct)
149            }
150            LogicalOperator::Expand(mut expand) => {
151                expand.input = Box::new(self.push_filters_down(*expand.input));
152                LogicalOperator::Expand(expand)
153            }
154            LogicalOperator::Join(mut join) => {
155                join.left = Box::new(self.push_filters_down(*join.left));
156                join.right = Box::new(self.push_filters_down(*join.right));
157                LogicalOperator::Join(join)
158            }
159            LogicalOperator::Aggregate(mut agg) => {
160                agg.input = Box::new(self.push_filters_down(*agg.input));
161                LogicalOperator::Aggregate(agg)
162            }
163            // Leaf operators and unsupported operators are returned as-is
164            other => other,
165        }
166    }
167
168    /// Tries to push a filter predicate into the given operator.
169    ///
170    /// Returns either the predicate pushed into the operator, or a new
171    /// Filter operator on top if the predicate cannot be pushed further.
172    fn try_push_filter_into(
173        &self,
174        predicate: LogicalExpression,
175        op: LogicalOperator,
176    ) -> LogicalOperator {
177        match op {
178            // Can push through Project if predicate doesn't depend on computed columns
179            LogicalOperator::Project(mut proj) => {
180                let predicate_vars = self.extract_variables(&predicate);
181                let computed_vars = self.extract_projection_aliases(&proj.projections);
182
183                // If predicate doesn't use any computed columns, push through
184                if predicate_vars.is_disjoint(&computed_vars) {
185                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
186                    LogicalOperator::Project(proj)
187                } else {
188                    // Can't push through, keep filter on top
189                    LogicalOperator::Filter(FilterOp {
190                        predicate,
191                        input: Box::new(LogicalOperator::Project(proj)),
192                    })
193                }
194            }
195
196            // Can push through Return (which is like a projection)
197            LogicalOperator::Return(mut ret) => {
198                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
199                LogicalOperator::Return(ret)
200            }
201
202            // Can push through Expand if predicate only uses source variable
203            LogicalOperator::Expand(mut expand) => {
204                let predicate_vars = self.extract_variables(&predicate);
205
206                // Check if predicate only uses the source variable
207                let uses_only_source = predicate_vars.iter().all(|v| v == &expand.from_variable);
208
209                if uses_only_source {
210                    // Push the filter before the expand
211                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
212                    LogicalOperator::Expand(expand)
213                } else {
214                    // Keep filter after expand
215                    LogicalOperator::Filter(FilterOp {
216                        predicate,
217                        input: Box::new(LogicalOperator::Expand(expand)),
218                    })
219                }
220            }
221
222            // Can push through Join to left/right side based on variables used
223            LogicalOperator::Join(mut join) => {
224                let predicate_vars = self.extract_variables(&predicate);
225                let left_vars = self.collect_output_variables(&join.left);
226                let right_vars = self.collect_output_variables(&join.right);
227
228                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
229                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
230
231                if uses_left && !uses_right {
232                    // Push to left side
233                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
234                    LogicalOperator::Join(join)
235                } else if uses_right && !uses_left {
236                    // Push to right side
237                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
238                    LogicalOperator::Join(join)
239                } else {
240                    // Uses both sides - keep above join
241                    LogicalOperator::Filter(FilterOp {
242                        predicate,
243                        input: Box::new(LogicalOperator::Join(join)),
244                    })
245                }
246            }
247
248            // Cannot push through Aggregate (predicate refers to aggregated values)
249            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
250                predicate,
251                input: Box::new(LogicalOperator::Aggregate(agg)),
252            }),
253
254            // For NodeScan, we've reached the bottom - keep filter on top
255            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
256                predicate,
257                input: Box::new(LogicalOperator::NodeScan(scan)),
258            }),
259
260            // For other operators, keep filter on top
261            other => LogicalOperator::Filter(FilterOp {
262                predicate,
263                input: Box::new(other),
264            }),
265        }
266    }
267
268    /// Collects all output variable names from an operator.
269    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
270        let mut vars = HashSet::new();
271        self.collect_output_variables_recursive(op, &mut vars);
272        vars
273    }
274
275    /// Recursively collects output variables from an operator.
276    fn collect_output_variables_recursive(&self, op: &LogicalOperator, vars: &mut HashSet<String>) {
277        match op {
278            LogicalOperator::NodeScan(scan) => {
279                vars.insert(scan.variable.clone());
280            }
281            LogicalOperator::EdgeScan(scan) => {
282                vars.insert(scan.variable.clone());
283            }
284            LogicalOperator::Expand(expand) => {
285                vars.insert(expand.to_variable.clone());
286                if let Some(edge_var) = &expand.edge_variable {
287                    vars.insert(edge_var.clone());
288                }
289                self.collect_output_variables_recursive(&expand.input, vars);
290            }
291            LogicalOperator::Filter(filter) => {
292                self.collect_output_variables_recursive(&filter.input, vars);
293            }
294            LogicalOperator::Project(proj) => {
295                for p in &proj.projections {
296                    if let Some(alias) = &p.alias {
297                        vars.insert(alias.clone());
298                    }
299                }
300                self.collect_output_variables_recursive(&proj.input, vars);
301            }
302            LogicalOperator::Join(join) => {
303                self.collect_output_variables_recursive(&join.left, vars);
304                self.collect_output_variables_recursive(&join.right, vars);
305            }
306            LogicalOperator::Aggregate(agg) => {
307                for expr in &agg.group_by {
308                    self.collect_variables(expr, vars);
309                }
310                for agg_expr in &agg.aggregates {
311                    if let Some(alias) = &agg_expr.alias {
312                        vars.insert(alias.clone());
313                    }
314                }
315            }
316            LogicalOperator::Return(ret) => {
317                self.collect_output_variables_recursive(&ret.input, vars);
318            }
319            LogicalOperator::Limit(limit) => {
320                self.collect_output_variables_recursive(&limit.input, vars);
321            }
322            LogicalOperator::Skip(skip) => {
323                self.collect_output_variables_recursive(&skip.input, vars);
324            }
325            LogicalOperator::Sort(sort) => {
326                self.collect_output_variables_recursive(&sort.input, vars);
327            }
328            LogicalOperator::Distinct(distinct) => {
329                self.collect_output_variables_recursive(&distinct.input, vars);
330            }
331            _ => {}
332        }
333    }
334
335    /// Extracts all variable names referenced in an expression.
336    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
337        let mut vars = HashSet::new();
338        self.collect_variables(expr, &mut vars);
339        vars
340    }
341
342    /// Recursively collects variable names from an expression.
343    fn collect_variables(&self, expr: &LogicalExpression, vars: &mut HashSet<String>) {
344        match expr {
345            LogicalExpression::Variable(name) => {
346                vars.insert(name.clone());
347            }
348            LogicalExpression::Property { variable, .. } => {
349                vars.insert(variable.clone());
350            }
351            LogicalExpression::Binary { left, right, .. } => {
352                self.collect_variables(left, vars);
353                self.collect_variables(right, vars);
354            }
355            LogicalExpression::Unary { operand, .. } => {
356                self.collect_variables(operand, vars);
357            }
358            LogicalExpression::FunctionCall { args, .. } => {
359                for arg in args {
360                    self.collect_variables(arg, vars);
361                }
362            }
363            LogicalExpression::List(items) => {
364                for item in items {
365                    self.collect_variables(item, vars);
366                }
367            }
368            LogicalExpression::Map(pairs) => {
369                for (_, value) in pairs {
370                    self.collect_variables(value, vars);
371                }
372            }
373            LogicalExpression::IndexAccess { base, index } => {
374                self.collect_variables(base, vars);
375                self.collect_variables(index, vars);
376            }
377            LogicalExpression::SliceAccess { base, start, end } => {
378                self.collect_variables(base, vars);
379                if let Some(s) = start {
380                    self.collect_variables(s, vars);
381                }
382                if let Some(e) = end {
383                    self.collect_variables(e, vars);
384                }
385            }
386            LogicalExpression::Case {
387                operand,
388                when_clauses,
389                else_clause,
390            } => {
391                if let Some(op) = operand {
392                    self.collect_variables(op, vars);
393                }
394                for (cond, result) in when_clauses {
395                    self.collect_variables(cond, vars);
396                    self.collect_variables(result, vars);
397                }
398                if let Some(else_expr) = else_clause {
399                    self.collect_variables(else_expr, vars);
400                }
401            }
402            LogicalExpression::Labels(var)
403            | LogicalExpression::Type(var)
404            | LogicalExpression::Id(var) => {
405                vars.insert(var.clone());
406            }
407            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
408            LogicalExpression::ListComprehension {
409                list_expr,
410                filter_expr,
411                map_expr,
412                ..
413            } => {
414                self.collect_variables(list_expr, vars);
415                if let Some(filter) = filter_expr {
416                    self.collect_variables(filter, vars);
417                }
418                self.collect_variables(map_expr, vars);
419            }
420            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
421                // Subqueries have their own variable scope
422            }
423        }
424    }
425
426    /// Extracts aliases from projection expressions.
427    fn extract_projection_aliases(
428        &self,
429        projections: &[crate::query::plan::Projection],
430    ) -> HashSet<String> {
431        projections.iter().filter_map(|p| p.alias.clone()).collect()
432    }
433}
434
435impl Default for Optimizer {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
441#[cfg(test)]
442mod tests {
443    use super::*;
444    use crate::query::plan::{
445        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
446        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
447        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
448    };
449    use grafeo_common::types::Value;
450
451    #[test]
452    fn test_optimizer_filter_pushdown_simple() {
453        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
454        // Before: Return -> Filter -> NodeScan
455        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
456
457        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
458            items: vec![ReturnItem {
459                expression: LogicalExpression::Variable("n".to_string()),
460                alias: None,
461            }],
462            distinct: false,
463            input: Box::new(LogicalOperator::Filter(FilterOp {
464                predicate: LogicalExpression::Binary {
465                    left: Box::new(LogicalExpression::Property {
466                        variable: "n".to_string(),
467                        property: "age".to_string(),
468                    }),
469                    op: BinaryOp::Gt,
470                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
471                },
472                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
473                    variable: "n".to_string(),
474                    label: Some("Person".to_string()),
475                    input: None,
476                })),
477            })),
478        }));
479
480        let optimizer = Optimizer::new();
481        let optimized = optimizer.optimize(plan).unwrap();
482
483        // The structure should remain similar (filter stays near scan)
484        if let LogicalOperator::Return(ret) = &optimized.root {
485            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
486                if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
487                    assert_eq!(scan.variable, "n");
488                    return;
489                }
490            }
491        }
492        panic!("Expected Return -> Filter -> NodeScan structure");
493    }
494
495    #[test]
496    fn test_optimizer_filter_pushdown_through_expand() {
497        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
498        // The filter on 'a' should be pushed before the expand
499
500        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
501            items: vec![ReturnItem {
502                expression: LogicalExpression::Variable("b".to_string()),
503                alias: None,
504            }],
505            distinct: false,
506            input: Box::new(LogicalOperator::Filter(FilterOp {
507                predicate: LogicalExpression::Binary {
508                    left: Box::new(LogicalExpression::Property {
509                        variable: "a".to_string(),
510                        property: "age".to_string(),
511                    }),
512                    op: BinaryOp::Gt,
513                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
514                },
515                input: Box::new(LogicalOperator::Expand(ExpandOp {
516                    from_variable: "a".to_string(),
517                    to_variable: "b".to_string(),
518                    edge_variable: None,
519                    direction: ExpandDirection::Outgoing,
520                    edge_type: Some("KNOWS".to_string()),
521                    min_hops: 1,
522                    max_hops: Some(1),
523                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
524                        variable: "a".to_string(),
525                        label: Some("Person".to_string()),
526                        input: None,
527                    })),
528                })),
529            })),
530        }));
531
532        let optimizer = Optimizer::new();
533        let optimized = optimizer.optimize(plan).unwrap();
534
535        // Filter on 'a' should be pushed before the expand
536        // Expected: Return -> Expand -> Filter -> NodeScan
537        if let LogicalOperator::Return(ret) = &optimized.root {
538            if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
539                if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
540                    if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
541                        assert_eq!(scan.variable, "a");
542                        assert_eq!(expand.from_variable, "a");
543                        assert_eq!(expand.to_variable, "b");
544                        return;
545                    }
546                }
547            }
548        }
549        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
550    }
551
552    #[test]
553    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
554        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
555        // The filter on 'b' should NOT be pushed before the expand
556
557        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
558            items: vec![ReturnItem {
559                expression: LogicalExpression::Variable("a".to_string()),
560                alias: None,
561            }],
562            distinct: false,
563            input: Box::new(LogicalOperator::Filter(FilterOp {
564                predicate: LogicalExpression::Binary {
565                    left: Box::new(LogicalExpression::Property {
566                        variable: "b".to_string(),
567                        property: "age".to_string(),
568                    }),
569                    op: BinaryOp::Gt,
570                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
571                },
572                input: Box::new(LogicalOperator::Expand(ExpandOp {
573                    from_variable: "a".to_string(),
574                    to_variable: "b".to_string(),
575                    edge_variable: None,
576                    direction: ExpandDirection::Outgoing,
577                    edge_type: Some("KNOWS".to_string()),
578                    min_hops: 1,
579                    max_hops: Some(1),
580                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
581                        variable: "a".to_string(),
582                        label: Some("Person".to_string()),
583                        input: None,
584                    })),
585                })),
586            })),
587        }));
588
589        let optimizer = Optimizer::new();
590        let optimized = optimizer.optimize(plan).unwrap();
591
592        // Filter on 'b' should stay after the expand
593        // Expected: Return -> Filter -> Expand -> NodeScan
594        if let LogicalOperator::Return(ret) = &optimized.root {
595            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
596                // Check that the filter is on 'b'
597                if let LogicalExpression::Binary { left, .. } = &filter.predicate {
598                    if let LogicalExpression::Property { variable, .. } = left.as_ref() {
599                        assert_eq!(variable, "b");
600                    }
601                }
602
603                if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
604                    if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
605                        return;
606                    }
607                }
608            }
609        }
610        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
611    }
612
613    #[test]
614    fn test_optimizer_extract_variables() {
615        let optimizer = Optimizer::new();
616
617        let expr = LogicalExpression::Binary {
618            left: Box::new(LogicalExpression::Property {
619                variable: "n".to_string(),
620                property: "age".to_string(),
621            }),
622            op: BinaryOp::Gt,
623            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
624        };
625
626        let vars = optimizer.extract_variables(&expr);
627        assert_eq!(vars.len(), 1);
628        assert!(vars.contains("n"));
629    }
630
631    // Additional tests for optimizer configuration
632
633    #[test]
634    fn test_optimizer_default() {
635        let optimizer = Optimizer::default();
636        // Should be able to optimize an empty plan
637        let plan = LogicalPlan::new(LogicalOperator::Empty);
638        let result = optimizer.optimize(plan);
639        assert!(result.is_ok());
640    }
641
642    #[test]
643    fn test_optimizer_with_filter_pushdown_disabled() {
644        let optimizer = Optimizer::new().with_filter_pushdown(false);
645
646        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
647            items: vec![ReturnItem {
648                expression: LogicalExpression::Variable("n".to_string()),
649                alias: None,
650            }],
651            distinct: false,
652            input: Box::new(LogicalOperator::Filter(FilterOp {
653                predicate: LogicalExpression::Literal(Value::Bool(true)),
654                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
655                    variable: "n".to_string(),
656                    label: None,
657                    input: None,
658                })),
659            })),
660        }));
661
662        let optimized = optimizer.optimize(plan).unwrap();
663        // Structure should be unchanged
664        if let LogicalOperator::Return(ret) = &optimized.root {
665            if let LogicalOperator::Filter(_) = ret.input.as_ref() {
666                return;
667            }
668        }
669        panic!("Expected unchanged structure");
670    }
671
672    #[test]
673    fn test_optimizer_with_join_reorder_disabled() {
674        let optimizer = Optimizer::new().with_join_reorder(false);
675        assert!(
676            optimizer
677                .optimize(LogicalPlan::new(LogicalOperator::Empty))
678                .is_ok()
679        );
680    }
681
682    #[test]
683    fn test_optimizer_with_cost_model() {
684        let cost_model = CostModel::new();
685        let optimizer = Optimizer::new().with_cost_model(cost_model);
686        assert!(
687            optimizer
688                .cost_model()
689                .estimate(&LogicalOperator::Empty, 0.0)
690                .total()
691                < 0.001
692        );
693    }
694
695    #[test]
696    fn test_optimizer_with_cardinality_estimator() {
697        let mut estimator = CardinalityEstimator::new();
698        estimator.add_table_stats("Test", TableStats::new(500));
699        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
700
701        let scan = LogicalOperator::NodeScan(NodeScanOp {
702            variable: "n".to_string(),
703            label: Some("Test".to_string()),
704            input: None,
705        });
706        let plan = LogicalPlan::new(scan);
707
708        let cardinality = optimizer.estimate_cardinality(&plan);
709        assert!((cardinality - 500.0).abs() < 0.001);
710    }
711
712    #[test]
713    fn test_optimizer_estimate_cost() {
714        let optimizer = Optimizer::new();
715        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
716            variable: "n".to_string(),
717            label: None,
718            input: None,
719        }));
720
721        let cost = optimizer.estimate_cost(&plan);
722        assert!(cost.total() > 0.0);
723    }
724
725    // Filter pushdown through various operators
726
727    #[test]
728    fn test_filter_pushdown_through_project() {
729        let optimizer = Optimizer::new();
730
731        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
732            predicate: LogicalExpression::Binary {
733                left: Box::new(LogicalExpression::Property {
734                    variable: "n".to_string(),
735                    property: "age".to_string(),
736                }),
737                op: BinaryOp::Gt,
738                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
739            },
740            input: Box::new(LogicalOperator::Project(ProjectOp {
741                projections: vec![Projection {
742                    expression: LogicalExpression::Variable("n".to_string()),
743                    alias: None,
744                }],
745                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
746                    variable: "n".to_string(),
747                    label: None,
748                    input: None,
749                })),
750            })),
751        }));
752
753        let optimized = optimizer.optimize(plan).unwrap();
754
755        // Filter should be pushed through Project
756        if let LogicalOperator::Project(proj) = &optimized.root {
757            if let LogicalOperator::Filter(_) = proj.input.as_ref() {
758                return;
759            }
760        }
761        panic!("Expected Project -> Filter structure");
762    }
763
764    #[test]
765    fn test_filter_not_pushed_through_project_with_alias() {
766        let optimizer = Optimizer::new();
767
768        // Filter on computed column 'x' should not be pushed through project that creates 'x'
769        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
770            predicate: LogicalExpression::Binary {
771                left: Box::new(LogicalExpression::Variable("x".to_string())),
772                op: BinaryOp::Gt,
773                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
774            },
775            input: Box::new(LogicalOperator::Project(ProjectOp {
776                projections: vec![Projection {
777                    expression: LogicalExpression::Property {
778                        variable: "n".to_string(),
779                        property: "age".to_string(),
780                    },
781                    alias: Some("x".to_string()),
782                }],
783                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
784                    variable: "n".to_string(),
785                    label: None,
786                    input: None,
787                })),
788            })),
789        }));
790
791        let optimized = optimizer.optimize(plan).unwrap();
792
793        // Filter should stay above Project
794        if let LogicalOperator::Filter(filter) = &optimized.root {
795            if let LogicalOperator::Project(_) = filter.input.as_ref() {
796                return;
797            }
798        }
799        panic!("Expected Filter -> Project structure");
800    }
801
802    #[test]
803    fn test_filter_pushdown_through_limit() {
804        let optimizer = Optimizer::new();
805
806        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
807            predicate: LogicalExpression::Literal(Value::Bool(true)),
808            input: Box::new(LogicalOperator::Limit(LimitOp {
809                count: 10,
810                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
811                    variable: "n".to_string(),
812                    label: None,
813                    input: None,
814                })),
815            })),
816        }));
817
818        let optimized = optimizer.optimize(plan).unwrap();
819
820        // Filter stays above Limit (cannot be pushed through)
821        if let LogicalOperator::Filter(filter) = &optimized.root {
822            if let LogicalOperator::Limit(_) = filter.input.as_ref() {
823                return;
824            }
825        }
826        panic!("Expected Filter -> Limit structure");
827    }
828
829    #[test]
830    fn test_filter_pushdown_through_sort() {
831        let optimizer = Optimizer::new();
832
833        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
834            predicate: LogicalExpression::Literal(Value::Bool(true)),
835            input: Box::new(LogicalOperator::Sort(SortOp {
836                keys: vec![SortKey {
837                    expression: LogicalExpression::Variable("n".to_string()),
838                    order: SortOrder::Ascending,
839                }],
840                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
841                    variable: "n".to_string(),
842                    label: None,
843                    input: None,
844                })),
845            })),
846        }));
847
848        let optimized = optimizer.optimize(plan).unwrap();
849
850        // Filter stays above Sort
851        if let LogicalOperator::Filter(filter) = &optimized.root {
852            if let LogicalOperator::Sort(_) = filter.input.as_ref() {
853                return;
854            }
855        }
856        panic!("Expected Filter -> Sort structure");
857    }
858
859    #[test]
860    fn test_filter_pushdown_through_distinct() {
861        let optimizer = Optimizer::new();
862
863        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
864            predicate: LogicalExpression::Literal(Value::Bool(true)),
865            input: Box::new(LogicalOperator::Distinct(DistinctOp {
866                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
867                    variable: "n".to_string(),
868                    label: None,
869                    input: None,
870                })),
871            })),
872        }));
873
874        let optimized = optimizer.optimize(plan).unwrap();
875
876        // Filter stays above Distinct
877        if let LogicalOperator::Filter(filter) = &optimized.root {
878            if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
879                return;
880            }
881        }
882        panic!("Expected Filter -> Distinct structure");
883    }
884
885    #[test]
886    fn test_filter_not_pushed_through_aggregate() {
887        let optimizer = Optimizer::new();
888
889        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
890            predicate: LogicalExpression::Binary {
891                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
892                op: BinaryOp::Gt,
893                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
894            },
895            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
896                group_by: vec![],
897                aggregates: vec![AggregateExpr {
898                    function: AggregateFunction::Count,
899                    expression: None,
900                    distinct: false,
901                    alias: Some("cnt".to_string()),
902                }],
903                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
904                    variable: "n".to_string(),
905                    label: None,
906                    input: None,
907                })),
908            })),
909        }));
910
911        let optimized = optimizer.optimize(plan).unwrap();
912
913        // Filter should stay above Aggregate
914        if let LogicalOperator::Filter(filter) = &optimized.root {
915            if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
916                return;
917            }
918        }
919        panic!("Expected Filter -> Aggregate structure");
920    }
921
922    #[test]
923    fn test_filter_pushdown_to_left_join_side() {
924        let optimizer = Optimizer::new();
925
926        // Filter on left variable should be pushed to left side
927        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
928            predicate: LogicalExpression::Binary {
929                left: Box::new(LogicalExpression::Property {
930                    variable: "a".to_string(),
931                    property: "age".to_string(),
932                }),
933                op: BinaryOp::Gt,
934                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
935            },
936            input: Box::new(LogicalOperator::Join(JoinOp {
937                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
938                    variable: "a".to_string(),
939                    label: Some("Person".to_string()),
940                    input: None,
941                })),
942                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
943                    variable: "b".to_string(),
944                    label: Some("Company".to_string()),
945                    input: None,
946                })),
947                join_type: JoinType::Inner,
948                conditions: vec![],
949            })),
950        }));
951
952        let optimized = optimizer.optimize(plan).unwrap();
953
954        // Filter should be pushed to left side of join
955        if let LogicalOperator::Join(join) = &optimized.root {
956            if let LogicalOperator::Filter(_) = join.left.as_ref() {
957                return;
958            }
959        }
960        panic!("Expected Join with Filter on left side");
961    }
962
963    #[test]
964    fn test_filter_pushdown_to_right_join_side() {
965        let optimizer = Optimizer::new();
966
967        // Filter on right variable should be pushed to right side
968        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
969            predicate: LogicalExpression::Binary {
970                left: Box::new(LogicalExpression::Property {
971                    variable: "b".to_string(),
972                    property: "name".to_string(),
973                }),
974                op: BinaryOp::Eq,
975                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
976            },
977            input: Box::new(LogicalOperator::Join(JoinOp {
978                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
979                    variable: "a".to_string(),
980                    label: Some("Person".to_string()),
981                    input: None,
982                })),
983                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
984                    variable: "b".to_string(),
985                    label: Some("Company".to_string()),
986                    input: None,
987                })),
988                join_type: JoinType::Inner,
989                conditions: vec![],
990            })),
991        }));
992
993        let optimized = optimizer.optimize(plan).unwrap();
994
995        // Filter should be pushed to right side of join
996        if let LogicalOperator::Join(join) = &optimized.root {
997            if let LogicalOperator::Filter(_) = join.right.as_ref() {
998                return;
999            }
1000        }
1001        panic!("Expected Join with Filter on right side");
1002    }
1003
1004    #[test]
1005    fn test_filter_not_pushed_when_uses_both_join_sides() {
1006        let optimizer = Optimizer::new();
1007
1008        // Filter using both variables should stay above join
1009        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1010            predicate: LogicalExpression::Binary {
1011                left: Box::new(LogicalExpression::Property {
1012                    variable: "a".to_string(),
1013                    property: "id".to_string(),
1014                }),
1015                op: BinaryOp::Eq,
1016                right: Box::new(LogicalExpression::Property {
1017                    variable: "b".to_string(),
1018                    property: "a_id".to_string(),
1019                }),
1020            },
1021            input: Box::new(LogicalOperator::Join(JoinOp {
1022                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1023                    variable: "a".to_string(),
1024                    label: None,
1025                    input: None,
1026                })),
1027                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1028                    variable: "b".to_string(),
1029                    label: None,
1030                    input: None,
1031                })),
1032                join_type: JoinType::Inner,
1033                conditions: vec![],
1034            })),
1035        }));
1036
1037        let optimized = optimizer.optimize(plan).unwrap();
1038
1039        // Filter should stay above join
1040        if let LogicalOperator::Filter(filter) = &optimized.root {
1041            if let LogicalOperator::Join(_) = filter.input.as_ref() {
1042                return;
1043            }
1044        }
1045        panic!("Expected Filter -> Join structure");
1046    }
1047
1048    // Variable extraction tests
1049
1050    #[test]
1051    fn test_extract_variables_from_variable() {
1052        let optimizer = Optimizer::new();
1053        let expr = LogicalExpression::Variable("x".to_string());
1054        let vars = optimizer.extract_variables(&expr);
1055        assert_eq!(vars.len(), 1);
1056        assert!(vars.contains("x"));
1057    }
1058
1059    #[test]
1060    fn test_extract_variables_from_unary() {
1061        let optimizer = Optimizer::new();
1062        let expr = LogicalExpression::Unary {
1063            op: UnaryOp::Not,
1064            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1065        };
1066        let vars = optimizer.extract_variables(&expr);
1067        assert_eq!(vars.len(), 1);
1068        assert!(vars.contains("x"));
1069    }
1070
1071    #[test]
1072    fn test_extract_variables_from_function_call() {
1073        let optimizer = Optimizer::new();
1074        let expr = LogicalExpression::FunctionCall {
1075            name: "length".to_string(),
1076            args: vec![
1077                LogicalExpression::Variable("a".to_string()),
1078                LogicalExpression::Variable("b".to_string()),
1079            ],
1080        };
1081        let vars = optimizer.extract_variables(&expr);
1082        assert_eq!(vars.len(), 2);
1083        assert!(vars.contains("a"));
1084        assert!(vars.contains("b"));
1085    }
1086
1087    #[test]
1088    fn test_extract_variables_from_list() {
1089        let optimizer = Optimizer::new();
1090        let expr = LogicalExpression::List(vec![
1091            LogicalExpression::Variable("a".to_string()),
1092            LogicalExpression::Literal(Value::Int64(1)),
1093            LogicalExpression::Variable("b".to_string()),
1094        ]);
1095        let vars = optimizer.extract_variables(&expr);
1096        assert_eq!(vars.len(), 2);
1097        assert!(vars.contains("a"));
1098        assert!(vars.contains("b"));
1099    }
1100
1101    #[test]
1102    fn test_extract_variables_from_map() {
1103        let optimizer = Optimizer::new();
1104        let expr = LogicalExpression::Map(vec![
1105            (
1106                "key1".to_string(),
1107                LogicalExpression::Variable("a".to_string()),
1108            ),
1109            (
1110                "key2".to_string(),
1111                LogicalExpression::Variable("b".to_string()),
1112            ),
1113        ]);
1114        let vars = optimizer.extract_variables(&expr);
1115        assert_eq!(vars.len(), 2);
1116        assert!(vars.contains("a"));
1117        assert!(vars.contains("b"));
1118    }
1119
1120    #[test]
1121    fn test_extract_variables_from_index_access() {
1122        let optimizer = Optimizer::new();
1123        let expr = LogicalExpression::IndexAccess {
1124            base: Box::new(LogicalExpression::Variable("list".to_string())),
1125            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1126        };
1127        let vars = optimizer.extract_variables(&expr);
1128        assert_eq!(vars.len(), 2);
1129        assert!(vars.contains("list"));
1130        assert!(vars.contains("idx"));
1131    }
1132
1133    #[test]
1134    fn test_extract_variables_from_slice_access() {
1135        let optimizer = Optimizer::new();
1136        let expr = LogicalExpression::SliceAccess {
1137            base: Box::new(LogicalExpression::Variable("list".to_string())),
1138            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1139            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1140        };
1141        let vars = optimizer.extract_variables(&expr);
1142        assert_eq!(vars.len(), 3);
1143        assert!(vars.contains("list"));
1144        assert!(vars.contains("s"));
1145        assert!(vars.contains("e"));
1146    }
1147
1148    #[test]
1149    fn test_extract_variables_from_case() {
1150        let optimizer = Optimizer::new();
1151        let expr = LogicalExpression::Case {
1152            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1153            when_clauses: vec![(
1154                LogicalExpression::Literal(Value::Int64(1)),
1155                LogicalExpression::Variable("a".to_string()),
1156            )],
1157            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1158        };
1159        let vars = optimizer.extract_variables(&expr);
1160        assert_eq!(vars.len(), 3);
1161        assert!(vars.contains("x"));
1162        assert!(vars.contains("a"));
1163        assert!(vars.contains("b"));
1164    }
1165
1166    #[test]
1167    fn test_extract_variables_from_labels() {
1168        let optimizer = Optimizer::new();
1169        let expr = LogicalExpression::Labels("n".to_string());
1170        let vars = optimizer.extract_variables(&expr);
1171        assert_eq!(vars.len(), 1);
1172        assert!(vars.contains("n"));
1173    }
1174
1175    #[test]
1176    fn test_extract_variables_from_type() {
1177        let optimizer = Optimizer::new();
1178        let expr = LogicalExpression::Type("e".to_string());
1179        let vars = optimizer.extract_variables(&expr);
1180        assert_eq!(vars.len(), 1);
1181        assert!(vars.contains("e"));
1182    }
1183
1184    #[test]
1185    fn test_extract_variables_from_id() {
1186        let optimizer = Optimizer::new();
1187        let expr = LogicalExpression::Id("n".to_string());
1188        let vars = optimizer.extract_variables(&expr);
1189        assert_eq!(vars.len(), 1);
1190        assert!(vars.contains("n"));
1191    }
1192
1193    #[test]
1194    fn test_extract_variables_from_list_comprehension() {
1195        let optimizer = Optimizer::new();
1196        let expr = LogicalExpression::ListComprehension {
1197            variable: "x".to_string(),
1198            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1199            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1200            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1201        };
1202        let vars = optimizer.extract_variables(&expr);
1203        assert!(vars.contains("items"));
1204        assert!(vars.contains("pred"));
1205        assert!(vars.contains("result"));
1206    }
1207
1208    #[test]
1209    fn test_extract_variables_from_literal_and_parameter() {
1210        let optimizer = Optimizer::new();
1211
1212        let literal = LogicalExpression::Literal(Value::Int64(42));
1213        assert!(optimizer.extract_variables(&literal).is_empty());
1214
1215        let param = LogicalExpression::Parameter("p".to_string());
1216        assert!(optimizer.extract_variables(&param).is_empty());
1217    }
1218
1219    // Recursive filter pushdown tests
1220
1221    #[test]
1222    fn test_recursive_filter_pushdown_through_skip() {
1223        let optimizer = Optimizer::new();
1224
1225        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1226            items: vec![ReturnItem {
1227                expression: LogicalExpression::Variable("n".to_string()),
1228                alias: None,
1229            }],
1230            distinct: false,
1231            input: Box::new(LogicalOperator::Filter(FilterOp {
1232                predicate: LogicalExpression::Literal(Value::Bool(true)),
1233                input: Box::new(LogicalOperator::Skip(SkipOp {
1234                    count: 5,
1235                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1236                        variable: "n".to_string(),
1237                        label: None,
1238                        input: None,
1239                    })),
1240                })),
1241            })),
1242        }));
1243
1244        let optimized = optimizer.optimize(plan).unwrap();
1245
1246        // Verify optimization succeeded
1247        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1248    }
1249
1250    #[test]
1251    fn test_nested_filter_pushdown() {
1252        let optimizer = Optimizer::new();
1253
1254        // Multiple nested filters
1255        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1256            items: vec![ReturnItem {
1257                expression: LogicalExpression::Variable("n".to_string()),
1258                alias: None,
1259            }],
1260            distinct: false,
1261            input: Box::new(LogicalOperator::Filter(FilterOp {
1262                predicate: LogicalExpression::Binary {
1263                    left: Box::new(LogicalExpression::Property {
1264                        variable: "n".to_string(),
1265                        property: "x".to_string(),
1266                    }),
1267                    op: BinaryOp::Gt,
1268                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1269                },
1270                input: Box::new(LogicalOperator::Filter(FilterOp {
1271                    predicate: LogicalExpression::Binary {
1272                        left: Box::new(LogicalExpression::Property {
1273                            variable: "n".to_string(),
1274                            property: "y".to_string(),
1275                        }),
1276                        op: BinaryOp::Lt,
1277                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1278                    },
1279                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1280                        variable: "n".to_string(),
1281                        label: None,
1282                        input: None,
1283                    })),
1284                })),
1285            })),
1286        }));
1287
1288        let optimized = optimizer.optimize(plan).unwrap();
1289        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1290    }
1291}