Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
26/// Information about a join condition for join reordering.
27#[derive(Debug, Clone)]
28struct JoinInfo {
29    left_var: String,
30    right_var: String,
31    left_expr: LogicalExpression,
32    right_expr: LogicalExpression,
33}
34
/// Something the query reads, tracked by the projection-pushdown pass.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum RequiredColumn {
    /// A whole binding (node, edge, or path), identified by variable name.
    Variable(String),
    /// One property of a binding: `(variable name, property name)`.
    Property(String, String),
}
43
44/// Transforms logical plans for faster execution.
45///
46/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
47/// Use the builder methods to enable/disable specific optimizations.
48pub struct Optimizer {
49    /// Whether to enable filter pushdown.
50    enable_filter_pushdown: bool,
51    /// Whether to enable join reordering.
52    enable_join_reorder: bool,
53    /// Whether to enable projection pushdown.
54    enable_projection_pushdown: bool,
55    /// Cost model for estimation.
56    cost_model: CostModel,
57    /// Cardinality estimator.
58    card_estimator: CardinalityEstimator,
59}
60
61impl Optimizer {
62    /// Creates a new optimizer with default settings.
63    #[must_use]
64    pub fn new() -> Self {
65        Self {
66            enable_filter_pushdown: true,
67            enable_join_reorder: true,
68            enable_projection_pushdown: true,
69            cost_model: CostModel::new(),
70            card_estimator: CardinalityEstimator::new(),
71        }
72    }
73
74    /// Enables or disables filter pushdown.
75    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
76        self.enable_filter_pushdown = enabled;
77        self
78    }
79
80    /// Enables or disables join reordering.
81    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
82        self.enable_join_reorder = enabled;
83        self
84    }
85
86    /// Enables or disables projection pushdown.
87    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
88        self.enable_projection_pushdown = enabled;
89        self
90    }
91
92    /// Sets the cost model.
93    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
94        self.cost_model = cost_model;
95        self
96    }
97
98    /// Sets the cardinality estimator.
99    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
100        self.card_estimator = estimator;
101        self
102    }
103
104    /// Returns a reference to the cost model.
105    pub fn cost_model(&self) -> &CostModel {
106        &self.cost_model
107    }
108
109    /// Returns a reference to the cardinality estimator.
110    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
111        &self.card_estimator
112    }
113
114    /// Estimates the cost of a plan.
115    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
116        let cardinality = self.card_estimator.estimate(&plan.root);
117        self.cost_model.estimate(&plan.root, cardinality)
118    }
119
120    /// Estimates the cardinality of a plan.
121    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
122        self.card_estimator.estimate(&plan.root)
123    }
124
125    /// Optimizes a logical plan.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if optimization fails.
130    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
131        let mut root = plan.root;
132
133        // Apply optimization rules
134        if self.enable_filter_pushdown {
135            root = self.push_filters_down(root);
136        }
137
138        if self.enable_join_reorder {
139            root = self.reorder_joins(root);
140        }
141
142        if self.enable_projection_pushdown {
143            root = self.push_projections_down(root);
144        }
145
146        Ok(LogicalPlan::new(root))
147    }
148
149    /// Pushes projections down the operator tree to eliminate unused columns early.
150    ///
151    /// This optimization:
152    /// 1. Collects required variables/properties from the root
153    /// 2. Propagates requirements down through the tree
154    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
155    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
156        // Collect required columns from the top of the plan
157        let required = self.collect_required_columns(&op);
158
159        // Push projections down
160        self.push_projections_recursive(op, &required)
161    }
162
163    /// Collects all variables and properties required by an operator and its ancestors.
164    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
165        let mut required = HashSet::new();
166        Self::collect_required_recursive(op, &mut required);
167        required
168    }
169
170    /// Recursively collects required columns.
171    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
172        match op {
173            LogicalOperator::Return(ret) => {
174                for item in &ret.items {
175                    Self::collect_from_expression(&item.expression, required);
176                }
177                Self::collect_required_recursive(&ret.input, required);
178            }
179            LogicalOperator::Project(proj) => {
180                for p in &proj.projections {
181                    Self::collect_from_expression(&p.expression, required);
182                }
183                Self::collect_required_recursive(&proj.input, required);
184            }
185            LogicalOperator::Filter(filter) => {
186                Self::collect_from_expression(&filter.predicate, required);
187                Self::collect_required_recursive(&filter.input, required);
188            }
189            LogicalOperator::Sort(sort) => {
190                for key in &sort.keys {
191                    Self::collect_from_expression(&key.expression, required);
192                }
193                Self::collect_required_recursive(&sort.input, required);
194            }
195            LogicalOperator::Aggregate(agg) => {
196                for expr in &agg.group_by {
197                    Self::collect_from_expression(expr, required);
198                }
199                for agg_expr in &agg.aggregates {
200                    if let Some(ref expr) = agg_expr.expression {
201                        Self::collect_from_expression(expr, required);
202                    }
203                }
204                if let Some(ref having) = agg.having {
205                    Self::collect_from_expression(having, required);
206                }
207                Self::collect_required_recursive(&agg.input, required);
208            }
209            LogicalOperator::Join(join) => {
210                for cond in &join.conditions {
211                    Self::collect_from_expression(&cond.left, required);
212                    Self::collect_from_expression(&cond.right, required);
213                }
214                Self::collect_required_recursive(&join.left, required);
215                Self::collect_required_recursive(&join.right, required);
216            }
217            LogicalOperator::Expand(expand) => {
218                // The source and target variables are needed
219                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
220                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
221                if let Some(ref edge_var) = expand.edge_variable {
222                    required.insert(RequiredColumn::Variable(edge_var.clone()));
223                }
224                Self::collect_required_recursive(&expand.input, required);
225            }
226            LogicalOperator::Limit(limit) => {
227                Self::collect_required_recursive(&limit.input, required);
228            }
229            LogicalOperator::Skip(skip) => {
230                Self::collect_required_recursive(&skip.input, required);
231            }
232            LogicalOperator::Distinct(distinct) => {
233                Self::collect_required_recursive(&distinct.input, required);
234            }
235            LogicalOperator::NodeScan(scan) => {
236                required.insert(RequiredColumn::Variable(scan.variable.clone()));
237            }
238            LogicalOperator::EdgeScan(scan) => {
239                required.insert(RequiredColumn::Variable(scan.variable.clone()));
240            }
241            _ => {}
242        }
243    }
244
245    /// Collects required columns from an expression.
246    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
247        match expr {
248            LogicalExpression::Variable(var) => {
249                required.insert(RequiredColumn::Variable(var.clone()));
250            }
251            LogicalExpression::Property { variable, property } => {
252                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
253                required.insert(RequiredColumn::Variable(variable.clone()));
254            }
255            LogicalExpression::Binary { left, right, .. } => {
256                Self::collect_from_expression(left, required);
257                Self::collect_from_expression(right, required);
258            }
259            LogicalExpression::Unary { operand, .. } => {
260                Self::collect_from_expression(operand, required);
261            }
262            LogicalExpression::FunctionCall { args, .. } => {
263                for arg in args {
264                    Self::collect_from_expression(arg, required);
265                }
266            }
267            LogicalExpression::List(items) => {
268                for item in items {
269                    Self::collect_from_expression(item, required);
270                }
271            }
272            LogicalExpression::Map(pairs) => {
273                for (_, value) in pairs {
274                    Self::collect_from_expression(value, required);
275                }
276            }
277            LogicalExpression::IndexAccess { base, index } => {
278                Self::collect_from_expression(base, required);
279                Self::collect_from_expression(index, required);
280            }
281            LogicalExpression::SliceAccess { base, start, end } => {
282                Self::collect_from_expression(base, required);
283                if let Some(s) = start {
284                    Self::collect_from_expression(s, required);
285                }
286                if let Some(e) = end {
287                    Self::collect_from_expression(e, required);
288                }
289            }
290            LogicalExpression::Case {
291                operand,
292                when_clauses,
293                else_clause,
294            } => {
295                if let Some(op) = operand {
296                    Self::collect_from_expression(op, required);
297                }
298                for (cond, result) in when_clauses {
299                    Self::collect_from_expression(cond, required);
300                    Self::collect_from_expression(result, required);
301                }
302                if let Some(else_expr) = else_clause {
303                    Self::collect_from_expression(else_expr, required);
304                }
305            }
306            LogicalExpression::Labels(var)
307            | LogicalExpression::Type(var)
308            | LogicalExpression::Id(var) => {
309                required.insert(RequiredColumn::Variable(var.clone()));
310            }
311            LogicalExpression::ListComprehension {
312                list_expr,
313                filter_expr,
314                map_expr,
315                ..
316            } => {
317                Self::collect_from_expression(list_expr, required);
318                if let Some(filter) = filter_expr {
319                    Self::collect_from_expression(filter, required);
320                }
321                Self::collect_from_expression(map_expr, required);
322            }
323            _ => {}
324        }
325    }
326
327    /// Recursively pushes projections down, adding them before expensive operations.
328    fn push_projections_recursive(
329        &self,
330        op: LogicalOperator,
331        required: &HashSet<RequiredColumn>,
332    ) -> LogicalOperator {
333        match op {
334            LogicalOperator::Return(mut ret) => {
335                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
336                LogicalOperator::Return(ret)
337            }
338            LogicalOperator::Project(mut proj) => {
339                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
340                LogicalOperator::Project(proj)
341            }
342            LogicalOperator::Filter(mut filter) => {
343                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
344                LogicalOperator::Filter(filter)
345            }
346            LogicalOperator::Sort(mut sort) => {
347                // Sort is expensive - consider adding a projection before it
348                // to reduce tuple width
349                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
350                LogicalOperator::Sort(sort)
351            }
352            LogicalOperator::Aggregate(mut agg) => {
353                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
354                LogicalOperator::Aggregate(agg)
355            }
356            LogicalOperator::Join(mut join) => {
357                // Joins are expensive - the required columns help determine
358                // what to project on each side
359                let left_vars = self.collect_output_variables(&join.left);
360                let right_vars = self.collect_output_variables(&join.right);
361
362                // Filter required columns to each side
363                let left_required: HashSet<_> = required
364                    .iter()
365                    .filter(|c| match c {
366                        RequiredColumn::Variable(v) => left_vars.contains(v),
367                        RequiredColumn::Property(v, _) => left_vars.contains(v),
368                    })
369                    .cloned()
370                    .collect();
371
372                let right_required: HashSet<_> = required
373                    .iter()
374                    .filter(|c| match c {
375                        RequiredColumn::Variable(v) => right_vars.contains(v),
376                        RequiredColumn::Property(v, _) => right_vars.contains(v),
377                    })
378                    .cloned()
379                    .collect();
380
381                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
382                join.right =
383                    Box::new(self.push_projections_recursive(*join.right, &right_required));
384                LogicalOperator::Join(join)
385            }
386            LogicalOperator::Expand(mut expand) => {
387                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
388                LogicalOperator::Expand(expand)
389            }
390            LogicalOperator::Limit(mut limit) => {
391                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
392                LogicalOperator::Limit(limit)
393            }
394            LogicalOperator::Skip(mut skip) => {
395                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
396                LogicalOperator::Skip(skip)
397            }
398            LogicalOperator::Distinct(mut distinct) => {
399                distinct.input =
400                    Box::new(self.push_projections_recursive(*distinct.input, required));
401                LogicalOperator::Distinct(distinct)
402            }
403            other => other,
404        }
405    }
406
407    /// Reorders joins in the operator tree using the DPccp algorithm.
408    ///
409    /// This optimization finds the optimal join order by:
410    /// 1. Extracting all base relations (scans) and join conditions
411    /// 2. Building a join graph
412    /// 3. Using dynamic programming to find the cheapest join order
413    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
414        // First, recursively optimize children
415        let op = self.reorder_joins_recursive(op);
416
417        // Then, if this is a join tree, try to optimize it
418        if let Some((relations, conditions)) = self.extract_join_tree(&op) {
419            if relations.len() >= 2 {
420                if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
421                    return optimized;
422                }
423            }
424        }
425
426        op
427    }
428
429    /// Recursively applies join reordering to child operators.
430    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
431        match op {
432            LogicalOperator::Return(mut ret) => {
433                ret.input = Box::new(self.reorder_joins(*ret.input));
434                LogicalOperator::Return(ret)
435            }
436            LogicalOperator::Project(mut proj) => {
437                proj.input = Box::new(self.reorder_joins(*proj.input));
438                LogicalOperator::Project(proj)
439            }
440            LogicalOperator::Filter(mut filter) => {
441                filter.input = Box::new(self.reorder_joins(*filter.input));
442                LogicalOperator::Filter(filter)
443            }
444            LogicalOperator::Limit(mut limit) => {
445                limit.input = Box::new(self.reorder_joins(*limit.input));
446                LogicalOperator::Limit(limit)
447            }
448            LogicalOperator::Skip(mut skip) => {
449                skip.input = Box::new(self.reorder_joins(*skip.input));
450                LogicalOperator::Skip(skip)
451            }
452            LogicalOperator::Sort(mut sort) => {
453                sort.input = Box::new(self.reorder_joins(*sort.input));
454                LogicalOperator::Sort(sort)
455            }
456            LogicalOperator::Distinct(mut distinct) => {
457                distinct.input = Box::new(self.reorder_joins(*distinct.input));
458                LogicalOperator::Distinct(distinct)
459            }
460            LogicalOperator::Aggregate(mut agg) => {
461                agg.input = Box::new(self.reorder_joins(*agg.input));
462                LogicalOperator::Aggregate(agg)
463            }
464            LogicalOperator::Expand(mut expand) => {
465                expand.input = Box::new(self.reorder_joins(*expand.input));
466                LogicalOperator::Expand(expand)
467            }
468            // Join operators are handled by the parent reorder_joins call
469            other => other,
470        }
471    }
472
473    /// Extracts base relations and join conditions from a join tree.
474    ///
475    /// Returns None if the operator is not a join tree.
476    fn extract_join_tree(
477        &self,
478        op: &LogicalOperator,
479    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
480        let mut relations = Vec::new();
481        let mut join_conditions = Vec::new();
482
483        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
484            return None;
485        }
486
487        if relations.len() < 2 {
488            return None;
489        }
490
491        Some((relations, join_conditions))
492    }
493
494    /// Recursively collects base relations and join conditions.
495    ///
496    /// Returns true if this subtree is part of a join tree.
497    fn collect_join_tree(
498        &self,
499        op: &LogicalOperator,
500        relations: &mut Vec<(String, LogicalOperator)>,
501        conditions: &mut Vec<JoinInfo>,
502    ) -> bool {
503        match op {
504            LogicalOperator::Join(join) => {
505                // Collect from both sides
506                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
507                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
508
509                // Add conditions from this join
510                for cond in &join.conditions {
511                    if let (Some(left_var), Some(right_var)) = (
512                        self.extract_variable_from_expr(&cond.left),
513                        self.extract_variable_from_expr(&cond.right),
514                    ) {
515                        conditions.push(JoinInfo {
516                            left_var,
517                            right_var,
518                            left_expr: cond.left.clone(),
519                            right_expr: cond.right.clone(),
520                        });
521                    }
522                }
523
524                left_ok && right_ok
525            }
526            LogicalOperator::NodeScan(scan) => {
527                relations.push((scan.variable.clone(), op.clone()));
528                true
529            }
530            LogicalOperator::EdgeScan(scan) => {
531                relations.push((scan.variable.clone(), op.clone()));
532                true
533            }
534            LogicalOperator::Filter(filter) => {
535                // A filter on a base relation is still part of the join tree
536                self.collect_join_tree(&filter.input, relations, conditions)
537            }
538            LogicalOperator::Expand(expand) => {
539                // Expand is a special case - it's like a join with the adjacency
540                // For now, treat the whole Expand subtree as a single relation
541                relations.push((expand.to_variable.clone(), op.clone()));
542                true
543            }
544            _ => false,
545        }
546    }
547
548    /// Extracts the primary variable from an expression.
549    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
550        match expr {
551            LogicalExpression::Variable(v) => Some(v.clone()),
552            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
553            _ => None,
554        }
555    }
556
557    /// Optimizes the join order using DPccp.
558    fn optimize_join_order(
559        &self,
560        relations: &[(String, LogicalOperator)],
561        conditions: &[JoinInfo],
562    ) -> Option<LogicalOperator> {
563        use join_order::{DPccp, JoinGraphBuilder};
564
565        // Build the join graph
566        let mut builder = JoinGraphBuilder::new();
567
568        for (var, relation) in relations {
569            builder.add_relation(var, relation.clone());
570        }
571
572        for cond in conditions {
573            builder.add_join_condition(
574                &cond.left_var,
575                &cond.right_var,
576                cond.left_expr.clone(),
577                cond.right_expr.clone(),
578            );
579        }
580
581        let graph = builder.build();
582
583        // Run DPccp
584        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
585        let plan = dpccp.optimize()?;
586
587        Some(plan.operator)
588    }
589
590    /// Pushes filters down the operator tree.
591    ///
592    /// This optimization moves filter predicates as close to the data source
593    /// as possible to reduce the amount of data processed by upper operators.
594    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
595        match op {
596            // For Filter operators, try to push the predicate into the child
597            LogicalOperator::Filter(filter) => {
598                let optimized_input = self.push_filters_down(*filter.input);
599                self.try_push_filter_into(filter.predicate, optimized_input)
600            }
601            // Recursively optimize children for other operators
602            LogicalOperator::Return(mut ret) => {
603                ret.input = Box::new(self.push_filters_down(*ret.input));
604                LogicalOperator::Return(ret)
605            }
606            LogicalOperator::Project(mut proj) => {
607                proj.input = Box::new(self.push_filters_down(*proj.input));
608                LogicalOperator::Project(proj)
609            }
610            LogicalOperator::Limit(mut limit) => {
611                limit.input = Box::new(self.push_filters_down(*limit.input));
612                LogicalOperator::Limit(limit)
613            }
614            LogicalOperator::Skip(mut skip) => {
615                skip.input = Box::new(self.push_filters_down(*skip.input));
616                LogicalOperator::Skip(skip)
617            }
618            LogicalOperator::Sort(mut sort) => {
619                sort.input = Box::new(self.push_filters_down(*sort.input));
620                LogicalOperator::Sort(sort)
621            }
622            LogicalOperator::Distinct(mut distinct) => {
623                distinct.input = Box::new(self.push_filters_down(*distinct.input));
624                LogicalOperator::Distinct(distinct)
625            }
626            LogicalOperator::Expand(mut expand) => {
627                expand.input = Box::new(self.push_filters_down(*expand.input));
628                LogicalOperator::Expand(expand)
629            }
630            LogicalOperator::Join(mut join) => {
631                join.left = Box::new(self.push_filters_down(*join.left));
632                join.right = Box::new(self.push_filters_down(*join.right));
633                LogicalOperator::Join(join)
634            }
635            LogicalOperator::Aggregate(mut agg) => {
636                agg.input = Box::new(self.push_filters_down(*agg.input));
637                LogicalOperator::Aggregate(agg)
638            }
639            // Leaf operators and unsupported operators are returned as-is
640            other => other,
641        }
642    }
643
644    /// Tries to push a filter predicate into the given operator.
645    ///
646    /// Returns either the predicate pushed into the operator, or a new
647    /// Filter operator on top if the predicate cannot be pushed further.
648    fn try_push_filter_into(
649        &self,
650        predicate: LogicalExpression,
651        op: LogicalOperator,
652    ) -> LogicalOperator {
653        match op {
654            // Can push through Project if predicate doesn't depend on computed columns
655            LogicalOperator::Project(mut proj) => {
656                let predicate_vars = self.extract_variables(&predicate);
657                let computed_vars = self.extract_projection_aliases(&proj.projections);
658
659                // If predicate doesn't use any computed columns, push through
660                if predicate_vars.is_disjoint(&computed_vars) {
661                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
662                    LogicalOperator::Project(proj)
663                } else {
664                    // Can't push through, keep filter on top
665                    LogicalOperator::Filter(FilterOp {
666                        predicate,
667                        input: Box::new(LogicalOperator::Project(proj)),
668                    })
669                }
670            }
671
672            // Can push through Return (which is like a projection)
673            LogicalOperator::Return(mut ret) => {
674                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
675                LogicalOperator::Return(ret)
676            }
677
678            // Can push through Expand if predicate doesn't use variables introduced by this expand
679            LogicalOperator::Expand(mut expand) => {
680                let predicate_vars = self.extract_variables(&predicate);
681
682                // Variables introduced by this expand are:
683                // - The target variable (to_variable)
684                // - The edge variable (if any)
685                // - The path alias (if any)
686                let mut introduced_vars = vec![&expand.to_variable];
687                if let Some(ref edge_var) = expand.edge_variable {
688                    introduced_vars.push(edge_var);
689                }
690                if let Some(ref path_alias) = expand.path_alias {
691                    introduced_vars.push(path_alias);
692                }
693
694                // Check if predicate uses any variables introduced by this expand
695                let uses_introduced_vars =
696                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
697
698                if !uses_introduced_vars {
699                    // Predicate doesn't use vars from this expand, so push through
700                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
701                    LogicalOperator::Expand(expand)
702                } else {
703                    // Keep filter after expand
704                    LogicalOperator::Filter(FilterOp {
705                        predicate,
706                        input: Box::new(LogicalOperator::Expand(expand)),
707                    })
708                }
709            }
710
711            // Can push through Join to left/right side based on variables used
712            LogicalOperator::Join(mut join) => {
713                let predicate_vars = self.extract_variables(&predicate);
714                let left_vars = self.collect_output_variables(&join.left);
715                let right_vars = self.collect_output_variables(&join.right);
716
717                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
718                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
719
720                if uses_left && !uses_right {
721                    // Push to left side
722                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
723                    LogicalOperator::Join(join)
724                } else if uses_right && !uses_left {
725                    // Push to right side
726                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
727                    LogicalOperator::Join(join)
728                } else {
729                    // Uses both sides - keep above join
730                    LogicalOperator::Filter(FilterOp {
731                        predicate,
732                        input: Box::new(LogicalOperator::Join(join)),
733                    })
734                }
735            }
736
737            // Cannot push through Aggregate (predicate refers to aggregated values)
738            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
739                predicate,
740                input: Box::new(LogicalOperator::Aggregate(agg)),
741            }),
742
743            // For NodeScan, we've reached the bottom - keep filter on top
744            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
745                predicate,
746                input: Box::new(LogicalOperator::NodeScan(scan)),
747            }),
748
749            // For other operators, keep filter on top
750            other => LogicalOperator::Filter(FilterOp {
751                predicate,
752                input: Box::new(other),
753            }),
754        }
755    }
756
757    /// Collects all output variable names from an operator.
758    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
759        let mut vars = HashSet::new();
760        Self::collect_output_variables_recursive(op, &mut vars);
761        vars
762    }
763
764    /// Recursively collects output variables from an operator.
765    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
766        match op {
767            LogicalOperator::NodeScan(scan) => {
768                vars.insert(scan.variable.clone());
769            }
770            LogicalOperator::EdgeScan(scan) => {
771                vars.insert(scan.variable.clone());
772            }
773            LogicalOperator::Expand(expand) => {
774                vars.insert(expand.to_variable.clone());
775                if let Some(edge_var) = &expand.edge_variable {
776                    vars.insert(edge_var.clone());
777                }
778                Self::collect_output_variables_recursive(&expand.input, vars);
779            }
780            LogicalOperator::Filter(filter) => {
781                Self::collect_output_variables_recursive(&filter.input, vars);
782            }
783            LogicalOperator::Project(proj) => {
784                for p in &proj.projections {
785                    if let Some(alias) = &p.alias {
786                        vars.insert(alias.clone());
787                    }
788                }
789                Self::collect_output_variables_recursive(&proj.input, vars);
790            }
791            LogicalOperator::Join(join) => {
792                Self::collect_output_variables_recursive(&join.left, vars);
793                Self::collect_output_variables_recursive(&join.right, vars);
794            }
795            LogicalOperator::Aggregate(agg) => {
796                for expr in &agg.group_by {
797                    Self::collect_variables(expr, vars);
798                }
799                for agg_expr in &agg.aggregates {
800                    if let Some(alias) = &agg_expr.alias {
801                        vars.insert(alias.clone());
802                    }
803                }
804            }
805            LogicalOperator::Return(ret) => {
806                Self::collect_output_variables_recursive(&ret.input, vars);
807            }
808            LogicalOperator::Limit(limit) => {
809                Self::collect_output_variables_recursive(&limit.input, vars);
810            }
811            LogicalOperator::Skip(skip) => {
812                Self::collect_output_variables_recursive(&skip.input, vars);
813            }
814            LogicalOperator::Sort(sort) => {
815                Self::collect_output_variables_recursive(&sort.input, vars);
816            }
817            LogicalOperator::Distinct(distinct) => {
818                Self::collect_output_variables_recursive(&distinct.input, vars);
819            }
820            _ => {}
821        }
822    }
823
824    /// Extracts all variable names referenced in an expression.
825    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
826        let mut vars = HashSet::new();
827        Self::collect_variables(expr, &mut vars);
828        vars
829    }
830
831    /// Recursively collects variable names from an expression.
832    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
833        match expr {
834            LogicalExpression::Variable(name) => {
835                vars.insert(name.clone());
836            }
837            LogicalExpression::Property { variable, .. } => {
838                vars.insert(variable.clone());
839            }
840            LogicalExpression::Binary { left, right, .. } => {
841                Self::collect_variables(left, vars);
842                Self::collect_variables(right, vars);
843            }
844            LogicalExpression::Unary { operand, .. } => {
845                Self::collect_variables(operand, vars);
846            }
847            LogicalExpression::FunctionCall { args, .. } => {
848                for arg in args {
849                    Self::collect_variables(arg, vars);
850                }
851            }
852            LogicalExpression::List(items) => {
853                for item in items {
854                    Self::collect_variables(item, vars);
855                }
856            }
857            LogicalExpression::Map(pairs) => {
858                for (_, value) in pairs {
859                    Self::collect_variables(value, vars);
860                }
861            }
862            LogicalExpression::IndexAccess { base, index } => {
863                Self::collect_variables(base, vars);
864                Self::collect_variables(index, vars);
865            }
866            LogicalExpression::SliceAccess { base, start, end } => {
867                Self::collect_variables(base, vars);
868                if let Some(s) = start {
869                    Self::collect_variables(s, vars);
870                }
871                if let Some(e) = end {
872                    Self::collect_variables(e, vars);
873                }
874            }
875            LogicalExpression::Case {
876                operand,
877                when_clauses,
878                else_clause,
879            } => {
880                if let Some(op) = operand {
881                    Self::collect_variables(op, vars);
882                }
883                for (cond, result) in when_clauses {
884                    Self::collect_variables(cond, vars);
885                    Self::collect_variables(result, vars);
886                }
887                if let Some(else_expr) = else_clause {
888                    Self::collect_variables(else_expr, vars);
889                }
890            }
891            LogicalExpression::Labels(var)
892            | LogicalExpression::Type(var)
893            | LogicalExpression::Id(var) => {
894                vars.insert(var.clone());
895            }
896            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
897            LogicalExpression::ListComprehension {
898                list_expr,
899                filter_expr,
900                map_expr,
901                ..
902            } => {
903                Self::collect_variables(list_expr, vars);
904                if let Some(filter) = filter_expr {
905                    Self::collect_variables(filter, vars);
906                }
907                Self::collect_variables(map_expr, vars);
908            }
909            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
910                // Subqueries have their own variable scope
911            }
912        }
913    }
914
915    /// Extracts aliases from projection expressions.
916    fn extract_projection_aliases(
917        &self,
918        projections: &[crate::query::plan::Projection],
919    ) -> HashSet<String> {
920        projections.iter().filter_map(|p| p.alias.clone()).collect()
921    }
922}
923
924impl Default for Optimizer {
925    fn default() -> Self {
926        Self::new()
927    }
928}
929
930#[cfg(test)]
931mod tests {
932    use super::*;
933    use crate::query::plan::{
934        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
935        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
936        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
937    };
938    use grafeo_common::types::Value;
939
940    #[test]
941    fn test_optimizer_filter_pushdown_simple() {
942        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
943        // Before: Return -> Filter -> NodeScan
944        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
945
946        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
947            items: vec![ReturnItem {
948                expression: LogicalExpression::Variable("n".to_string()),
949                alias: None,
950            }],
951            distinct: false,
952            input: Box::new(LogicalOperator::Filter(FilterOp {
953                predicate: LogicalExpression::Binary {
954                    left: Box::new(LogicalExpression::Property {
955                        variable: "n".to_string(),
956                        property: "age".to_string(),
957                    }),
958                    op: BinaryOp::Gt,
959                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
960                },
961                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
962                    variable: "n".to_string(),
963                    label: Some("Person".to_string()),
964                    input: None,
965                })),
966            })),
967        }));
968
969        let optimizer = Optimizer::new();
970        let optimized = optimizer.optimize(plan).unwrap();
971
972        // The structure should remain similar (filter stays near scan)
973        if let LogicalOperator::Return(ret) = &optimized.root {
974            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
975                if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
976                    assert_eq!(scan.variable, "n");
977                    return;
978                }
979            }
980        }
981        panic!("Expected Return -> Filter -> NodeScan structure");
982    }
983
984    #[test]
985    fn test_optimizer_filter_pushdown_through_expand() {
986        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
987        // The filter on 'a' should be pushed before the expand
988
989        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
990            items: vec![ReturnItem {
991                expression: LogicalExpression::Variable("b".to_string()),
992                alias: None,
993            }],
994            distinct: false,
995            input: Box::new(LogicalOperator::Filter(FilterOp {
996                predicate: LogicalExpression::Binary {
997                    left: Box::new(LogicalExpression::Property {
998                        variable: "a".to_string(),
999                        property: "age".to_string(),
1000                    }),
1001                    op: BinaryOp::Gt,
1002                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1003                },
1004                input: Box::new(LogicalOperator::Expand(ExpandOp {
1005                    from_variable: "a".to_string(),
1006                    to_variable: "b".to_string(),
1007                    edge_variable: None,
1008                    direction: ExpandDirection::Outgoing,
1009                    edge_type: Some("KNOWS".to_string()),
1010                    min_hops: 1,
1011                    max_hops: Some(1),
1012                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1013                        variable: "a".to_string(),
1014                        label: Some("Person".to_string()),
1015                        input: None,
1016                    })),
1017                    path_alias: None,
1018                })),
1019            })),
1020        }));
1021
1022        let optimizer = Optimizer::new();
1023        let optimized = optimizer.optimize(plan).unwrap();
1024
1025        // Filter on 'a' should be pushed before the expand
1026        // Expected: Return -> Expand -> Filter -> NodeScan
1027        if let LogicalOperator::Return(ret) = &optimized.root {
1028            if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1029                if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1030                    if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1031                        assert_eq!(scan.variable, "a");
1032                        assert_eq!(expand.from_variable, "a");
1033                        assert_eq!(expand.to_variable, "b");
1034                        return;
1035                    }
1036                }
1037            }
1038        }
1039        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1040    }
1041
1042    #[test]
1043    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1044        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1045        // The filter on 'b' should NOT be pushed before the expand
1046
1047        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1048            items: vec![ReturnItem {
1049                expression: LogicalExpression::Variable("a".to_string()),
1050                alias: None,
1051            }],
1052            distinct: false,
1053            input: Box::new(LogicalOperator::Filter(FilterOp {
1054                predicate: LogicalExpression::Binary {
1055                    left: Box::new(LogicalExpression::Property {
1056                        variable: "b".to_string(),
1057                        property: "age".to_string(),
1058                    }),
1059                    op: BinaryOp::Gt,
1060                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1061                },
1062                input: Box::new(LogicalOperator::Expand(ExpandOp {
1063                    from_variable: "a".to_string(),
1064                    to_variable: "b".to_string(),
1065                    edge_variable: None,
1066                    direction: ExpandDirection::Outgoing,
1067                    edge_type: Some("KNOWS".to_string()),
1068                    min_hops: 1,
1069                    max_hops: Some(1),
1070                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1071                        variable: "a".to_string(),
1072                        label: Some("Person".to_string()),
1073                        input: None,
1074                    })),
1075                    path_alias: None,
1076                })),
1077            })),
1078        }));
1079
1080        let optimizer = Optimizer::new();
1081        let optimized = optimizer.optimize(plan).unwrap();
1082
1083        // Filter on 'b' should stay after the expand
1084        // Expected: Return -> Filter -> Expand -> NodeScan
1085        if let LogicalOperator::Return(ret) = &optimized.root {
1086            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1087                // Check that the filter is on 'b'
1088                if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1089                    if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1090                        assert_eq!(variable, "b");
1091                    }
1092                }
1093
1094                if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1095                    if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1096                        return;
1097                    }
1098                }
1099            }
1100        }
1101        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1102    }
1103
1104    #[test]
1105    fn test_optimizer_extract_variables() {
1106        let optimizer = Optimizer::new();
1107
1108        let expr = LogicalExpression::Binary {
1109            left: Box::new(LogicalExpression::Property {
1110                variable: "n".to_string(),
1111                property: "age".to_string(),
1112            }),
1113            op: BinaryOp::Gt,
1114            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1115        };
1116
1117        let vars = optimizer.extract_variables(&expr);
1118        assert_eq!(vars.len(), 1);
1119        assert!(vars.contains("n"));
1120    }
1121
1122    // Additional tests for optimizer configuration
1123
1124    #[test]
1125    fn test_optimizer_default() {
1126        let optimizer = Optimizer::default();
1127        // Should be able to optimize an empty plan
1128        let plan = LogicalPlan::new(LogicalOperator::Empty);
1129        let result = optimizer.optimize(plan);
1130        assert!(result.is_ok());
1131    }
1132
1133    #[test]
1134    fn test_optimizer_with_filter_pushdown_disabled() {
1135        let optimizer = Optimizer::new().with_filter_pushdown(false);
1136
1137        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1138            items: vec![ReturnItem {
1139                expression: LogicalExpression::Variable("n".to_string()),
1140                alias: None,
1141            }],
1142            distinct: false,
1143            input: Box::new(LogicalOperator::Filter(FilterOp {
1144                predicate: LogicalExpression::Literal(Value::Bool(true)),
1145                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1146                    variable: "n".to_string(),
1147                    label: None,
1148                    input: None,
1149                })),
1150            })),
1151        }));
1152
1153        let optimized = optimizer.optimize(plan).unwrap();
1154        // Structure should be unchanged
1155        if let LogicalOperator::Return(ret) = &optimized.root {
1156            if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1157                return;
1158            }
1159        }
1160        panic!("Expected unchanged structure");
1161    }
1162
1163    #[test]
1164    fn test_optimizer_with_join_reorder_disabled() {
1165        let optimizer = Optimizer::new().with_join_reorder(false);
1166        assert!(
1167            optimizer
1168                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1169                .is_ok()
1170        );
1171    }
1172
1173    #[test]
1174    fn test_optimizer_with_cost_model() {
1175        let cost_model = CostModel::new();
1176        let optimizer = Optimizer::new().with_cost_model(cost_model);
1177        assert!(
1178            optimizer
1179                .cost_model()
1180                .estimate(&LogicalOperator::Empty, 0.0)
1181                .total()
1182                < 0.001
1183        );
1184    }
1185
1186    #[test]
1187    fn test_optimizer_with_cardinality_estimator() {
1188        let mut estimator = CardinalityEstimator::new();
1189        estimator.add_table_stats("Test", TableStats::new(500));
1190        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1191
1192        let scan = LogicalOperator::NodeScan(NodeScanOp {
1193            variable: "n".to_string(),
1194            label: Some("Test".to_string()),
1195            input: None,
1196        });
1197        let plan = LogicalPlan::new(scan);
1198
1199        let cardinality = optimizer.estimate_cardinality(&plan);
1200        assert!((cardinality - 500.0).abs() < 0.001);
1201    }
1202
1203    #[test]
1204    fn test_optimizer_estimate_cost() {
1205        let optimizer = Optimizer::new();
1206        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1207            variable: "n".to_string(),
1208            label: None,
1209            input: None,
1210        }));
1211
1212        let cost = optimizer.estimate_cost(&plan);
1213        assert!(cost.total() > 0.0);
1214    }
1215
1216    // Filter pushdown through various operators
1217
1218    #[test]
1219    fn test_filter_pushdown_through_project() {
1220        let optimizer = Optimizer::new();
1221
1222        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1223            predicate: LogicalExpression::Binary {
1224                left: Box::new(LogicalExpression::Property {
1225                    variable: "n".to_string(),
1226                    property: "age".to_string(),
1227                }),
1228                op: BinaryOp::Gt,
1229                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1230            },
1231            input: Box::new(LogicalOperator::Project(ProjectOp {
1232                projections: vec![Projection {
1233                    expression: LogicalExpression::Variable("n".to_string()),
1234                    alias: None,
1235                }],
1236                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1237                    variable: "n".to_string(),
1238                    label: None,
1239                    input: None,
1240                })),
1241            })),
1242        }));
1243
1244        let optimized = optimizer.optimize(plan).unwrap();
1245
1246        // Filter should be pushed through Project
1247        if let LogicalOperator::Project(proj) = &optimized.root {
1248            if let LogicalOperator::Filter(_) = proj.input.as_ref() {
1249                return;
1250            }
1251        }
1252        panic!("Expected Project -> Filter structure");
1253    }
1254
1255    #[test]
1256    fn test_filter_not_pushed_through_project_with_alias() {
1257        let optimizer = Optimizer::new();
1258
1259        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1260        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1261            predicate: LogicalExpression::Binary {
1262                left: Box::new(LogicalExpression::Variable("x".to_string())),
1263                op: BinaryOp::Gt,
1264                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1265            },
1266            input: Box::new(LogicalOperator::Project(ProjectOp {
1267                projections: vec![Projection {
1268                    expression: LogicalExpression::Property {
1269                        variable: "n".to_string(),
1270                        property: "age".to_string(),
1271                    },
1272                    alias: Some("x".to_string()),
1273                }],
1274                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1275                    variable: "n".to_string(),
1276                    label: None,
1277                    input: None,
1278                })),
1279            })),
1280        }));
1281
1282        let optimized = optimizer.optimize(plan).unwrap();
1283
1284        // Filter should stay above Project
1285        if let LogicalOperator::Filter(filter) = &optimized.root {
1286            if let LogicalOperator::Project(_) = filter.input.as_ref() {
1287                return;
1288            }
1289        }
1290        panic!("Expected Filter -> Project structure");
1291    }
1292
1293    #[test]
1294    fn test_filter_pushdown_through_limit() {
1295        let optimizer = Optimizer::new();
1296
1297        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1298            predicate: LogicalExpression::Literal(Value::Bool(true)),
1299            input: Box::new(LogicalOperator::Limit(LimitOp {
1300                count: 10,
1301                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1302                    variable: "n".to_string(),
1303                    label: None,
1304                    input: None,
1305                })),
1306            })),
1307        }));
1308
1309        let optimized = optimizer.optimize(plan).unwrap();
1310
1311        // Filter stays above Limit (cannot be pushed through)
1312        if let LogicalOperator::Filter(filter) = &optimized.root {
1313            if let LogicalOperator::Limit(_) = filter.input.as_ref() {
1314                return;
1315            }
1316        }
1317        panic!("Expected Filter -> Limit structure");
1318    }
1319
1320    #[test]
1321    fn test_filter_pushdown_through_sort() {
1322        let optimizer = Optimizer::new();
1323
1324        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1325            predicate: LogicalExpression::Literal(Value::Bool(true)),
1326            input: Box::new(LogicalOperator::Sort(SortOp {
1327                keys: vec![SortKey {
1328                    expression: LogicalExpression::Variable("n".to_string()),
1329                    order: SortOrder::Ascending,
1330                }],
1331                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1332                    variable: "n".to_string(),
1333                    label: None,
1334                    input: None,
1335                })),
1336            })),
1337        }));
1338
1339        let optimized = optimizer.optimize(plan).unwrap();
1340
1341        // Filter stays above Sort
1342        if let LogicalOperator::Filter(filter) = &optimized.root {
1343            if let LogicalOperator::Sort(_) = filter.input.as_ref() {
1344                return;
1345            }
1346        }
1347        panic!("Expected Filter -> Sort structure");
1348    }
1349
1350    #[test]
1351    fn test_filter_pushdown_through_distinct() {
1352        let optimizer = Optimizer::new();
1353
1354        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1355            predicate: LogicalExpression::Literal(Value::Bool(true)),
1356            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1357                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1358                    variable: "n".to_string(),
1359                    label: None,
1360                    input: None,
1361                })),
1362                columns: None,
1363            })),
1364        }));
1365
1366        let optimized = optimizer.optimize(plan).unwrap();
1367
1368        // Filter stays above Distinct
1369        if let LogicalOperator::Filter(filter) = &optimized.root {
1370            if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
1371                return;
1372            }
1373        }
1374        panic!("Expected Filter -> Distinct structure");
1375    }
1376
1377    #[test]
1378    fn test_filter_not_pushed_through_aggregate() {
1379        let optimizer = Optimizer::new();
1380
1381        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1382            predicate: LogicalExpression::Binary {
1383                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1384                op: BinaryOp::Gt,
1385                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1386            },
1387            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1388                group_by: vec![],
1389                aggregates: vec![AggregateExpr {
1390                    function: AggregateFunction::Count,
1391                    expression: None,
1392                    distinct: false,
1393                    alias: Some("cnt".to_string()),
1394                    percentile: None,
1395                }],
1396                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1397                    variable: "n".to_string(),
1398                    label: None,
1399                    input: None,
1400                })),
1401                having: None,
1402            })),
1403        }));
1404
1405        let optimized = optimizer.optimize(plan).unwrap();
1406
1407        // Filter should stay above Aggregate
1408        if let LogicalOperator::Filter(filter) = &optimized.root {
1409            if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
1410                return;
1411            }
1412        }
1413        panic!("Expected Filter -> Aggregate structure");
1414    }
1415
1416    #[test]
1417    fn test_filter_pushdown_to_left_join_side() {
1418        let optimizer = Optimizer::new();
1419
1420        // Filter on left variable should be pushed to left side
1421        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1422            predicate: LogicalExpression::Binary {
1423                left: Box::new(LogicalExpression::Property {
1424                    variable: "a".to_string(),
1425                    property: "age".to_string(),
1426                }),
1427                op: BinaryOp::Gt,
1428                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1429            },
1430            input: Box::new(LogicalOperator::Join(JoinOp {
1431                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1432                    variable: "a".to_string(),
1433                    label: Some("Person".to_string()),
1434                    input: None,
1435                })),
1436                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1437                    variable: "b".to_string(),
1438                    label: Some("Company".to_string()),
1439                    input: None,
1440                })),
1441                join_type: JoinType::Inner,
1442                conditions: vec![],
1443            })),
1444        }));
1445
1446        let optimized = optimizer.optimize(plan).unwrap();
1447
1448        // Filter should be pushed to left side of join
1449        if let LogicalOperator::Join(join) = &optimized.root {
1450            if let LogicalOperator::Filter(_) = join.left.as_ref() {
1451                return;
1452            }
1453        }
1454        panic!("Expected Join with Filter on left side");
1455    }
1456
1457    #[test]
1458    fn test_filter_pushdown_to_right_join_side() {
1459        let optimizer = Optimizer::new();
1460
1461        // Filter on right variable should be pushed to right side
1462        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1463            predicate: LogicalExpression::Binary {
1464                left: Box::new(LogicalExpression::Property {
1465                    variable: "b".to_string(),
1466                    property: "name".to_string(),
1467                }),
1468                op: BinaryOp::Eq,
1469                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1470            },
1471            input: Box::new(LogicalOperator::Join(JoinOp {
1472                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1473                    variable: "a".to_string(),
1474                    label: Some("Person".to_string()),
1475                    input: None,
1476                })),
1477                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1478                    variable: "b".to_string(),
1479                    label: Some("Company".to_string()),
1480                    input: None,
1481                })),
1482                join_type: JoinType::Inner,
1483                conditions: vec![],
1484            })),
1485        }));
1486
1487        let optimized = optimizer.optimize(plan).unwrap();
1488
1489        // Filter should be pushed to right side of join
1490        if let LogicalOperator::Join(join) = &optimized.root {
1491            if let LogicalOperator::Filter(_) = join.right.as_ref() {
1492                return;
1493            }
1494        }
1495        panic!("Expected Join with Filter on right side");
1496    }
1497
/// A predicate that references variables from both join inputs is expected
/// to remain above the join after optimization.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let optimizer = Optimizer::new();

    // Unlabelled scans for `a` and `b`, joined without conditions.
    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });

    // `a.id = b.a_id` needs both sides, so it cannot go to either input alone.
    let predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };

    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate,
        input: Box::new(join),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    // The filter must still be the root, directly over the join.
    match &optimized.root {
        LogicalOperator::Filter(f) if matches!(f.input.as_ref(), LogicalOperator::Join(_)) => {}
        _ => panic!("Expected Filter -> Join structure"),
    }
}
1541
1542    // Variable extraction tests
1543
/// A bare variable expression yields exactly that one variable name.
#[test]
fn test_extract_variables_from_variable() {
    let opt = Optimizer::new();
    let var_expr = LogicalExpression::Variable("x".to_string());

    let found = opt.extract_variables(&var_expr);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1552
/// The operand of a unary expression is searched for variables.
#[test]
fn test_extract_variables_from_unary() {
    let opt = Optimizer::new();

    // NOT x
    let operand = LogicalExpression::Variable("x".to_string());
    let negated = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let found = opt.extract_variables(&negated);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1564
/// Every argument of a function call contributes its variables.
#[test]
fn test_extract_variables_from_function_call() {
    let opt = Optimizer::new();

    // length(a, b)
    let args = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args,
        distinct: false,
    };

    let found = opt.extract_variables(&call);

    assert_eq!(found.len(), 2);
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1581
/// Variables inside a list literal are collected; the literal element adds none.
#[test]
fn test_extract_variables_from_list() {
    let opt = Optimizer::new();

    // [a, 1, b]
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let list = LogicalExpression::List(elements);

    let found = opt.extract_variables(&list);

    assert_eq!(found.len(), 2);
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1595
/// Variables used as map values are collected; the keys `key1`/`key2` are
/// not counted, since exactly two names are expected.
#[test]
fn test_extract_variables_from_map() {
    let opt = Optimizer::new();

    // {key1: a, key2: b}
    let first_entry = (
        "key1".to_string(),
        LogicalExpression::Variable("a".to_string()),
    );
    let second_entry = (
        "key2".to_string(),
        LogicalExpression::Variable("b".to_string()),
    );
    let map = LogicalExpression::Map(vec![first_entry, second_entry]);

    let found = opt.extract_variables(&map);

    assert_eq!(found.len(), 2);
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1614
/// Both the indexed base and the index expression are searched for variables.
#[test]
fn test_extract_variables_from_index_access() {
    let opt = Optimizer::new();

    // list[idx]
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let found = opt.extract_variables(&access);

    assert_eq!(found.len(), 2);
    assert!(found.contains("list"));
    assert!(found.contains("idx"));
}
1627
/// The base and both (present) slice bounds are searched for variables.
#[test]
fn test_extract_variables_from_slice_access() {
    let opt = Optimizer::new();

    // list[s..e]
    let base = LogicalExpression::Variable("list".to_string());
    let lower = LogicalExpression::Variable("s".to_string());
    let upper = LogicalExpression::Variable("e".to_string());
    let slice = LogicalExpression::SliceAccess {
        base: Box::new(base),
        start: Some(Box::new(lower)),
        end: Some(Box::new(upper)),
    };

    let found = opt.extract_variables(&slice);

    assert_eq!(found.len(), 3);
    assert!(found.contains("list"));
    assert!(found.contains("s"));
    assert!(found.contains("e"));
}
1642
/// The operand, the `when` clause pair, and the `else` branch of a CASE
/// expression all contribute variables.
#[test]
fn test_extract_variables_from_case() {
    let opt = Optimizer::new();

    // CASE x WHEN 1 THEN a ELSE b END
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let case = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };

    let found = opt.extract_variables(&case);

    assert_eq!(found.len(), 3);
    assert!(found.contains("x"));
    assert!(found.contains("a"));
    assert!(found.contains("b"));
}
1660
/// A `Labels` expression reports the variable it is applied to.
#[test]
fn test_extract_variables_from_labels() {
    let opt = Optimizer::new();
    let labels_of_n = LogicalExpression::Labels("n".to_string());

    let found = opt.extract_variables(&labels_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1669
/// A `Type` expression reports the variable it is applied to.
#[test]
fn test_extract_variables_from_type() {
    let opt = Optimizer::new();
    let type_of_e = LogicalExpression::Type("e".to_string());

    let found = opt.extract_variables(&type_of_e);

    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1678
/// An `Id` expression reports the variable it is applied to.
#[test]
fn test_extract_variables_from_id() {
    let opt = Optimizer::new();
    let id_of_n = LogicalExpression::Id("n".to_string());

    let found = opt.extract_variables(&id_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1687
/// The source list, the filter, and the mapping expression of a list
/// comprehension are all searched for variables.
///
/// Only membership is checked here; the total number of names (for example
/// whether the comprehension's own binding `x` is reported) is not asserted.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let opt = Optimizer::new();

    let source = LogicalExpression::Variable("items".to_string());
    let filter = LogicalExpression::Variable("pred".to_string());
    let mapping = LogicalExpression::Variable("result".to_string());
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: Box::new(source),
        filter_expr: Some(Box::new(filter)),
        map_expr: Box::new(mapping),
    };

    let found = opt.extract_variables(&comprehension);

    assert!(found.contains("items"));
    assert!(found.contains("pred"));
    assert!(found.contains("result"));
}
1702
/// Literals and query parameters reference no variables at all.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let opt = Optimizer::new();

    // A constant has nothing to bind.
    let constant = LogicalExpression::Literal(Value::Int64(42));
    let from_constant = opt.extract_variables(&constant);
    assert!(from_constant.is_empty());

    // A parameter name is not reported as a variable.
    let parameter = LogicalExpression::Parameter("p".to_string());
    let from_parameter = opt.extract_variables(&parameter);
    assert!(from_parameter.is_empty());
}
1713
1714    // Recursive filter pushdown tests
1715
/// Optimizing `Return -> Filter -> Skip -> NodeScan` succeeds and keeps
/// `Return` as the plan root.
///
/// Where the filter ends up relative to the `Skip` is not asserted.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let optimizer = Optimizer::new();

    // Build the plan bottom-up: scan, then skip, then filter, then return.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skip),
    });
    let returned = ReturnItem {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![returned],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    // Verify optimization succeeded
    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1744
/// Optimizing two stacked filters over a scan succeeds and keeps `Return`
/// as the plan root.
///
/// The final arrangement of the two filters is not asserted.
#[test]
fn test_nested_filter_pushdown() {
    let optimizer = Optimizer::new();

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });

    // Inner filter: n.y < 10
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "y".to_string(),
            }),
            op: BinaryOp::Lt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
        },
        input: Box::new(scan),
    });

    // Outer filter: n.x > 1
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Binary {
            left: Box::new(LogicalExpression::Property {
                variable: "n".to_string(),
                property: "x".to_string(),
            }),
            op: BinaryOp::Gt,
            right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
        },
        input: Box::new(inner_filter),
    });

    let returned = ReturnItem {
        expression: LogicalExpression::Variable("n".to_string()),
        alias: None,
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![returned],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimized = optimizer.optimize(plan).unwrap();
    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1786}