Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
26/// Information about a join condition for join reordering.
27#[derive(Debug, Clone)]
28struct JoinInfo {
29    left_var: String,
30    right_var: String,
31    left_expr: LogicalExpression,
32    right_expr: LogicalExpression,
33}
34
/// Something the query needs to keep available, tracked for projection pushdown.
///
/// Values are hashable and comparable so they can be gathered in a `HashSet`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding, identified by its variable name (node, edge, or path).
    Variable(String),
    /// One property of a binding: `(variable name, property name)`.
    Property(String, String),
}
43
44/// Transforms logical plans for faster execution.
45///
46/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
47/// Use the builder methods to enable/disable specific optimizations.
48pub struct Optimizer {
49    /// Whether to enable filter pushdown.
50    enable_filter_pushdown: bool,
51    /// Whether to enable join reordering.
52    enable_join_reorder: bool,
53    /// Whether to enable projection pushdown.
54    enable_projection_pushdown: bool,
55    /// Cost model for estimation.
56    cost_model: CostModel,
57    /// Cardinality estimator.
58    card_estimator: CardinalityEstimator,
59}
60
61impl Optimizer {
62    /// Creates a new optimizer with default settings.
63    #[must_use]
64    pub fn new() -> Self {
65        Self {
66            enable_filter_pushdown: true,
67            enable_join_reorder: true,
68            enable_projection_pushdown: true,
69            cost_model: CostModel::new(),
70            card_estimator: CardinalityEstimator::new(),
71        }
72    }
73
74    /// Enables or disables filter pushdown.
75    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
76        self.enable_filter_pushdown = enabled;
77        self
78    }
79
80    /// Enables or disables join reordering.
81    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
82        self.enable_join_reorder = enabled;
83        self
84    }
85
86    /// Enables or disables projection pushdown.
87    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
88        self.enable_projection_pushdown = enabled;
89        self
90    }
91
92    /// Sets the cost model.
93    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
94        self.cost_model = cost_model;
95        self
96    }
97
98    /// Sets the cardinality estimator.
99    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
100        self.card_estimator = estimator;
101        self
102    }
103
104    /// Returns a reference to the cost model.
105    pub fn cost_model(&self) -> &CostModel {
106        &self.cost_model
107    }
108
109    /// Returns a reference to the cardinality estimator.
110    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
111        &self.card_estimator
112    }
113
114    /// Estimates the cost of a plan.
115    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
116        let cardinality = self.card_estimator.estimate(&plan.root);
117        self.cost_model.estimate(&plan.root, cardinality)
118    }
119
120    /// Estimates the cardinality of a plan.
121    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
122        self.card_estimator.estimate(&plan.root)
123    }
124
125    /// Optimizes a logical plan.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if optimization fails.
130    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
131        let mut root = plan.root;
132
133        // Apply optimization rules
134        if self.enable_filter_pushdown {
135            root = self.push_filters_down(root);
136        }
137
138        if self.enable_join_reorder {
139            root = self.reorder_joins(root);
140        }
141
142        if self.enable_projection_pushdown {
143            root = self.push_projections_down(root);
144        }
145
146        Ok(LogicalPlan::new(root))
147    }
148
149    /// Pushes projections down the operator tree to eliminate unused columns early.
150    ///
151    /// This optimization:
152    /// 1. Collects required variables/properties from the root
153    /// 2. Propagates requirements down through the tree
154    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
155    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
156        // Collect required columns from the top of the plan
157        let required = self.collect_required_columns(&op);
158
159        // Push projections down
160        self.push_projections_recursive(op, &required)
161    }
162
163    /// Collects all variables and properties required by an operator and its ancestors.
164    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
165        let mut required = HashSet::new();
166        self.collect_required_recursive(op, &mut required);
167        required
168    }
169
170    /// Recursively collects required columns.
171    fn collect_required_recursive(
172        &self,
173        op: &LogicalOperator,
174        required: &mut HashSet<RequiredColumn>,
175    ) {
176        match op {
177            LogicalOperator::Return(ret) => {
178                for item in &ret.items {
179                    self.collect_from_expression(&item.expression, required);
180                }
181                self.collect_required_recursive(&ret.input, required);
182            }
183            LogicalOperator::Project(proj) => {
184                for p in &proj.projections {
185                    self.collect_from_expression(&p.expression, required);
186                }
187                self.collect_required_recursive(&proj.input, required);
188            }
189            LogicalOperator::Filter(filter) => {
190                self.collect_from_expression(&filter.predicate, required);
191                self.collect_required_recursive(&filter.input, required);
192            }
193            LogicalOperator::Sort(sort) => {
194                for key in &sort.keys {
195                    self.collect_from_expression(&key.expression, required);
196                }
197                self.collect_required_recursive(&sort.input, required);
198            }
199            LogicalOperator::Aggregate(agg) => {
200                for expr in &agg.group_by {
201                    self.collect_from_expression(expr, required);
202                }
203                for agg_expr in &agg.aggregates {
204                    if let Some(ref expr) = agg_expr.expression {
205                        self.collect_from_expression(expr, required);
206                    }
207                }
208                if let Some(ref having) = agg.having {
209                    self.collect_from_expression(having, required);
210                }
211                self.collect_required_recursive(&agg.input, required);
212            }
213            LogicalOperator::Join(join) => {
214                for cond in &join.conditions {
215                    self.collect_from_expression(&cond.left, required);
216                    self.collect_from_expression(&cond.right, required);
217                }
218                self.collect_required_recursive(&join.left, required);
219                self.collect_required_recursive(&join.right, required);
220            }
221            LogicalOperator::Expand(expand) => {
222                // The source and target variables are needed
223                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
224                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
225                if let Some(ref edge_var) = expand.edge_variable {
226                    required.insert(RequiredColumn::Variable(edge_var.clone()));
227                }
228                self.collect_required_recursive(&expand.input, required);
229            }
230            LogicalOperator::Limit(limit) => {
231                self.collect_required_recursive(&limit.input, required);
232            }
233            LogicalOperator::Skip(skip) => {
234                self.collect_required_recursive(&skip.input, required);
235            }
236            LogicalOperator::Distinct(distinct) => {
237                self.collect_required_recursive(&distinct.input, required);
238            }
239            LogicalOperator::NodeScan(scan) => {
240                required.insert(RequiredColumn::Variable(scan.variable.clone()));
241            }
242            LogicalOperator::EdgeScan(scan) => {
243                required.insert(RequiredColumn::Variable(scan.variable.clone()));
244            }
245            _ => {}
246        }
247    }
248
249    /// Collects required columns from an expression.
250    fn collect_from_expression(
251        &self,
252        expr: &LogicalExpression,
253        required: &mut HashSet<RequiredColumn>,
254    ) {
255        match expr {
256            LogicalExpression::Variable(var) => {
257                required.insert(RequiredColumn::Variable(var.clone()));
258            }
259            LogicalExpression::Property { variable, property } => {
260                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
261                required.insert(RequiredColumn::Variable(variable.clone()));
262            }
263            LogicalExpression::Binary { left, right, .. } => {
264                self.collect_from_expression(left, required);
265                self.collect_from_expression(right, required);
266            }
267            LogicalExpression::Unary { operand, .. } => {
268                self.collect_from_expression(operand, required);
269            }
270            LogicalExpression::FunctionCall { args, .. } => {
271                for arg in args {
272                    self.collect_from_expression(arg, required);
273                }
274            }
275            LogicalExpression::List(items) => {
276                for item in items {
277                    self.collect_from_expression(item, required);
278                }
279            }
280            LogicalExpression::Map(pairs) => {
281                for (_, value) in pairs {
282                    self.collect_from_expression(value, required);
283                }
284            }
285            LogicalExpression::IndexAccess { base, index } => {
286                self.collect_from_expression(base, required);
287                self.collect_from_expression(index, required);
288            }
289            LogicalExpression::SliceAccess { base, start, end } => {
290                self.collect_from_expression(base, required);
291                if let Some(s) = start {
292                    self.collect_from_expression(s, required);
293                }
294                if let Some(e) = end {
295                    self.collect_from_expression(e, required);
296                }
297            }
298            LogicalExpression::Case {
299                operand,
300                when_clauses,
301                else_clause,
302            } => {
303                if let Some(op) = operand {
304                    self.collect_from_expression(op, required);
305                }
306                for (cond, result) in when_clauses {
307                    self.collect_from_expression(cond, required);
308                    self.collect_from_expression(result, required);
309                }
310                if let Some(else_expr) = else_clause {
311                    self.collect_from_expression(else_expr, required);
312                }
313            }
314            LogicalExpression::Labels(var)
315            | LogicalExpression::Type(var)
316            | LogicalExpression::Id(var) => {
317                required.insert(RequiredColumn::Variable(var.clone()));
318            }
319            LogicalExpression::ListComprehension {
320                list_expr,
321                filter_expr,
322                map_expr,
323                ..
324            } => {
325                self.collect_from_expression(list_expr, required);
326                if let Some(filter) = filter_expr {
327                    self.collect_from_expression(filter, required);
328                }
329                self.collect_from_expression(map_expr, required);
330            }
331            _ => {}
332        }
333    }
334
335    /// Recursively pushes projections down, adding them before expensive operations.
336    fn push_projections_recursive(
337        &self,
338        op: LogicalOperator,
339        required: &HashSet<RequiredColumn>,
340    ) -> LogicalOperator {
341        match op {
342            LogicalOperator::Return(mut ret) => {
343                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
344                LogicalOperator::Return(ret)
345            }
346            LogicalOperator::Project(mut proj) => {
347                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
348                LogicalOperator::Project(proj)
349            }
350            LogicalOperator::Filter(mut filter) => {
351                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
352                LogicalOperator::Filter(filter)
353            }
354            LogicalOperator::Sort(mut sort) => {
355                // Sort is expensive - consider adding a projection before it
356                // to reduce tuple width
357                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
358                LogicalOperator::Sort(sort)
359            }
360            LogicalOperator::Aggregate(mut agg) => {
361                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
362                LogicalOperator::Aggregate(agg)
363            }
364            LogicalOperator::Join(mut join) => {
365                // Joins are expensive - the required columns help determine
366                // what to project on each side
367                let left_vars = self.collect_output_variables(&join.left);
368                let right_vars = self.collect_output_variables(&join.right);
369
370                // Filter required columns to each side
371                let left_required: HashSet<_> = required
372                    .iter()
373                    .filter(|c| match c {
374                        RequiredColumn::Variable(v) => left_vars.contains(v),
375                        RequiredColumn::Property(v, _) => left_vars.contains(v),
376                    })
377                    .cloned()
378                    .collect();
379
380                let right_required: HashSet<_> = required
381                    .iter()
382                    .filter(|c| match c {
383                        RequiredColumn::Variable(v) => right_vars.contains(v),
384                        RequiredColumn::Property(v, _) => right_vars.contains(v),
385                    })
386                    .cloned()
387                    .collect();
388
389                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
390                join.right =
391                    Box::new(self.push_projections_recursive(*join.right, &right_required));
392                LogicalOperator::Join(join)
393            }
394            LogicalOperator::Expand(mut expand) => {
395                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
396                LogicalOperator::Expand(expand)
397            }
398            LogicalOperator::Limit(mut limit) => {
399                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
400                LogicalOperator::Limit(limit)
401            }
402            LogicalOperator::Skip(mut skip) => {
403                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
404                LogicalOperator::Skip(skip)
405            }
406            LogicalOperator::Distinct(mut distinct) => {
407                distinct.input =
408                    Box::new(self.push_projections_recursive(*distinct.input, required));
409                LogicalOperator::Distinct(distinct)
410            }
411            other => other,
412        }
413    }
414
415    /// Reorders joins in the operator tree using the DPccp algorithm.
416    ///
417    /// This optimization finds the optimal join order by:
418    /// 1. Extracting all base relations (scans) and join conditions
419    /// 2. Building a join graph
420    /// 3. Using dynamic programming to find the cheapest join order
421    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
422        // First, recursively optimize children
423        let op = self.reorder_joins_recursive(op);
424
425        // Then, if this is a join tree, try to optimize it
426        if let Some((relations, conditions)) = self.extract_join_tree(&op) {
427            if relations.len() >= 2 {
428                if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
429                    return optimized;
430                }
431            }
432        }
433
434        op
435    }
436
437    /// Recursively applies join reordering to child operators.
438    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
439        match op {
440            LogicalOperator::Return(mut ret) => {
441                ret.input = Box::new(self.reorder_joins(*ret.input));
442                LogicalOperator::Return(ret)
443            }
444            LogicalOperator::Project(mut proj) => {
445                proj.input = Box::new(self.reorder_joins(*proj.input));
446                LogicalOperator::Project(proj)
447            }
448            LogicalOperator::Filter(mut filter) => {
449                filter.input = Box::new(self.reorder_joins(*filter.input));
450                LogicalOperator::Filter(filter)
451            }
452            LogicalOperator::Limit(mut limit) => {
453                limit.input = Box::new(self.reorder_joins(*limit.input));
454                LogicalOperator::Limit(limit)
455            }
456            LogicalOperator::Skip(mut skip) => {
457                skip.input = Box::new(self.reorder_joins(*skip.input));
458                LogicalOperator::Skip(skip)
459            }
460            LogicalOperator::Sort(mut sort) => {
461                sort.input = Box::new(self.reorder_joins(*sort.input));
462                LogicalOperator::Sort(sort)
463            }
464            LogicalOperator::Distinct(mut distinct) => {
465                distinct.input = Box::new(self.reorder_joins(*distinct.input));
466                LogicalOperator::Distinct(distinct)
467            }
468            LogicalOperator::Aggregate(mut agg) => {
469                agg.input = Box::new(self.reorder_joins(*agg.input));
470                LogicalOperator::Aggregate(agg)
471            }
472            LogicalOperator::Expand(mut expand) => {
473                expand.input = Box::new(self.reorder_joins(*expand.input));
474                LogicalOperator::Expand(expand)
475            }
476            // Join operators are handled by the parent reorder_joins call
477            other => other,
478        }
479    }
480
481    /// Extracts base relations and join conditions from a join tree.
482    ///
483    /// Returns None if the operator is not a join tree.
484    fn extract_join_tree(
485        &self,
486        op: &LogicalOperator,
487    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
488        let mut relations = Vec::new();
489        let mut join_conditions = Vec::new();
490
491        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
492            return None;
493        }
494
495        if relations.len() < 2 {
496            return None;
497        }
498
499        Some((relations, join_conditions))
500    }
501
502    /// Recursively collects base relations and join conditions.
503    ///
504    /// Returns true if this subtree is part of a join tree.
505    fn collect_join_tree(
506        &self,
507        op: &LogicalOperator,
508        relations: &mut Vec<(String, LogicalOperator)>,
509        conditions: &mut Vec<JoinInfo>,
510    ) -> bool {
511        match op {
512            LogicalOperator::Join(join) => {
513                // Collect from both sides
514                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
515                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
516
517                // Add conditions from this join
518                for cond in &join.conditions {
519                    if let (Some(left_var), Some(right_var)) = (
520                        self.extract_variable_from_expr(&cond.left),
521                        self.extract_variable_from_expr(&cond.right),
522                    ) {
523                        conditions.push(JoinInfo {
524                            left_var,
525                            right_var,
526                            left_expr: cond.left.clone(),
527                            right_expr: cond.right.clone(),
528                        });
529                    }
530                }
531
532                left_ok && right_ok
533            }
534            LogicalOperator::NodeScan(scan) => {
535                relations.push((scan.variable.clone(), op.clone()));
536                true
537            }
538            LogicalOperator::EdgeScan(scan) => {
539                relations.push((scan.variable.clone(), op.clone()));
540                true
541            }
542            LogicalOperator::Filter(filter) => {
543                // A filter on a base relation is still part of the join tree
544                self.collect_join_tree(&filter.input, relations, conditions)
545            }
546            LogicalOperator::Expand(expand) => {
547                // Expand is a special case - it's like a join with the adjacency
548                // For now, treat the whole Expand subtree as a single relation
549                relations.push((expand.to_variable.clone(), op.clone()));
550                true
551            }
552            _ => false,
553        }
554    }
555
556    /// Extracts the primary variable from an expression.
557    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
558        match expr {
559            LogicalExpression::Variable(v) => Some(v.clone()),
560            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
561            _ => None,
562        }
563    }
564
565    /// Optimizes the join order using DPccp.
566    fn optimize_join_order(
567        &self,
568        relations: &[(String, LogicalOperator)],
569        conditions: &[JoinInfo],
570    ) -> Option<LogicalOperator> {
571        use join_order::{DPccp, JoinGraphBuilder};
572
573        // Build the join graph
574        let mut builder = JoinGraphBuilder::new();
575
576        for (var, relation) in relations {
577            builder.add_relation(var, relation.clone());
578        }
579
580        for cond in conditions {
581            builder.add_join_condition(
582                &cond.left_var,
583                &cond.right_var,
584                cond.left_expr.clone(),
585                cond.right_expr.clone(),
586            );
587        }
588
589        let graph = builder.build();
590
591        // Run DPccp
592        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
593        let plan = dpccp.optimize()?;
594
595        Some(plan.operator)
596    }
597
598    /// Pushes filters down the operator tree.
599    ///
600    /// This optimization moves filter predicates as close to the data source
601    /// as possible to reduce the amount of data processed by upper operators.
602    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
603        match op {
604            // For Filter operators, try to push the predicate into the child
605            LogicalOperator::Filter(filter) => {
606                let optimized_input = self.push_filters_down(*filter.input);
607                self.try_push_filter_into(filter.predicate, optimized_input)
608            }
609            // Recursively optimize children for other operators
610            LogicalOperator::Return(mut ret) => {
611                ret.input = Box::new(self.push_filters_down(*ret.input));
612                LogicalOperator::Return(ret)
613            }
614            LogicalOperator::Project(mut proj) => {
615                proj.input = Box::new(self.push_filters_down(*proj.input));
616                LogicalOperator::Project(proj)
617            }
618            LogicalOperator::Limit(mut limit) => {
619                limit.input = Box::new(self.push_filters_down(*limit.input));
620                LogicalOperator::Limit(limit)
621            }
622            LogicalOperator::Skip(mut skip) => {
623                skip.input = Box::new(self.push_filters_down(*skip.input));
624                LogicalOperator::Skip(skip)
625            }
626            LogicalOperator::Sort(mut sort) => {
627                sort.input = Box::new(self.push_filters_down(*sort.input));
628                LogicalOperator::Sort(sort)
629            }
630            LogicalOperator::Distinct(mut distinct) => {
631                distinct.input = Box::new(self.push_filters_down(*distinct.input));
632                LogicalOperator::Distinct(distinct)
633            }
634            LogicalOperator::Expand(mut expand) => {
635                expand.input = Box::new(self.push_filters_down(*expand.input));
636                LogicalOperator::Expand(expand)
637            }
638            LogicalOperator::Join(mut join) => {
639                join.left = Box::new(self.push_filters_down(*join.left));
640                join.right = Box::new(self.push_filters_down(*join.right));
641                LogicalOperator::Join(join)
642            }
643            LogicalOperator::Aggregate(mut agg) => {
644                agg.input = Box::new(self.push_filters_down(*agg.input));
645                LogicalOperator::Aggregate(agg)
646            }
647            // Leaf operators and unsupported operators are returned as-is
648            other => other,
649        }
650    }
651
652    /// Tries to push a filter predicate into the given operator.
653    ///
654    /// Returns either the predicate pushed into the operator, or a new
655    /// Filter operator on top if the predicate cannot be pushed further.
656    fn try_push_filter_into(
657        &self,
658        predicate: LogicalExpression,
659        op: LogicalOperator,
660    ) -> LogicalOperator {
661        match op {
662            // Can push through Project if predicate doesn't depend on computed columns
663            LogicalOperator::Project(mut proj) => {
664                let predicate_vars = self.extract_variables(&predicate);
665                let computed_vars = self.extract_projection_aliases(&proj.projections);
666
667                // If predicate doesn't use any computed columns, push through
668                if predicate_vars.is_disjoint(&computed_vars) {
669                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
670                    LogicalOperator::Project(proj)
671                } else {
672                    // Can't push through, keep filter on top
673                    LogicalOperator::Filter(FilterOp {
674                        predicate,
675                        input: Box::new(LogicalOperator::Project(proj)),
676                    })
677                }
678            }
679
680            // Can push through Return (which is like a projection)
681            LogicalOperator::Return(mut ret) => {
682                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
683                LogicalOperator::Return(ret)
684            }
685
686            // Can push through Expand if predicate only uses source variable
687            LogicalOperator::Expand(mut expand) => {
688                let predicate_vars = self.extract_variables(&predicate);
689
690                // Check if predicate only uses the source variable
691                let uses_only_source = predicate_vars.iter().all(|v| v == &expand.from_variable);
692
693                if uses_only_source {
694                    // Push the filter before the expand
695                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
696                    LogicalOperator::Expand(expand)
697                } else {
698                    // Keep filter after expand
699                    LogicalOperator::Filter(FilterOp {
700                        predicate,
701                        input: Box::new(LogicalOperator::Expand(expand)),
702                    })
703                }
704            }
705
706            // Can push through Join to left/right side based on variables used
707            LogicalOperator::Join(mut join) => {
708                let predicate_vars = self.extract_variables(&predicate);
709                let left_vars = self.collect_output_variables(&join.left);
710                let right_vars = self.collect_output_variables(&join.right);
711
712                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
713                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
714
715                if uses_left && !uses_right {
716                    // Push to left side
717                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
718                    LogicalOperator::Join(join)
719                } else if uses_right && !uses_left {
720                    // Push to right side
721                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
722                    LogicalOperator::Join(join)
723                } else {
724                    // Uses both sides - keep above join
725                    LogicalOperator::Filter(FilterOp {
726                        predicate,
727                        input: Box::new(LogicalOperator::Join(join)),
728                    })
729                }
730            }
731
732            // Cannot push through Aggregate (predicate refers to aggregated values)
733            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
734                predicate,
735                input: Box::new(LogicalOperator::Aggregate(agg)),
736            }),
737
738            // For NodeScan, we've reached the bottom - keep filter on top
739            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
740                predicate,
741                input: Box::new(LogicalOperator::NodeScan(scan)),
742            }),
743
744            // For other operators, keep filter on top
745            other => LogicalOperator::Filter(FilterOp {
746                predicate,
747                input: Box::new(other),
748            }),
749        }
750    }
751
752    /// Collects all output variable names from an operator.
753    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
754        let mut vars = HashSet::new();
755        Self::collect_output_variables_recursive(op, &mut vars);
756        vars
757    }
758
759    /// Recursively collects output variables from an operator.
760    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
761        match op {
762            LogicalOperator::NodeScan(scan) => {
763                vars.insert(scan.variable.clone());
764            }
765            LogicalOperator::EdgeScan(scan) => {
766                vars.insert(scan.variable.clone());
767            }
768            LogicalOperator::Expand(expand) => {
769                vars.insert(expand.to_variable.clone());
770                if let Some(edge_var) = &expand.edge_variable {
771                    vars.insert(edge_var.clone());
772                }
773                Self::collect_output_variables_recursive(&expand.input, vars);
774            }
775            LogicalOperator::Filter(filter) => {
776                Self::collect_output_variables_recursive(&filter.input, vars);
777            }
778            LogicalOperator::Project(proj) => {
779                for p in &proj.projections {
780                    if let Some(alias) = &p.alias {
781                        vars.insert(alias.clone());
782                    }
783                }
784                Self::collect_output_variables_recursive(&proj.input, vars);
785            }
786            LogicalOperator::Join(join) => {
787                Self::collect_output_variables_recursive(&join.left, vars);
788                Self::collect_output_variables_recursive(&join.right, vars);
789            }
790            LogicalOperator::Aggregate(agg) => {
791                for expr in &agg.group_by {
792                    Self::collect_variables(expr, vars);
793                }
794                for agg_expr in &agg.aggregates {
795                    if let Some(alias) = &agg_expr.alias {
796                        vars.insert(alias.clone());
797                    }
798                }
799            }
800            LogicalOperator::Return(ret) => {
801                Self::collect_output_variables_recursive(&ret.input, vars);
802            }
803            LogicalOperator::Limit(limit) => {
804                Self::collect_output_variables_recursive(&limit.input, vars);
805            }
806            LogicalOperator::Skip(skip) => {
807                Self::collect_output_variables_recursive(&skip.input, vars);
808            }
809            LogicalOperator::Sort(sort) => {
810                Self::collect_output_variables_recursive(&sort.input, vars);
811            }
812            LogicalOperator::Distinct(distinct) => {
813                Self::collect_output_variables_recursive(&distinct.input, vars);
814            }
815            _ => {}
816        }
817    }
818
819    /// Extracts all variable names referenced in an expression.
820    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
821        let mut vars = HashSet::new();
822        Self::collect_variables(expr, &mut vars);
823        vars
824    }
825
826    /// Recursively collects variable names from an expression.
827    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
828        match expr {
829            LogicalExpression::Variable(name) => {
830                vars.insert(name.clone());
831            }
832            LogicalExpression::Property { variable, .. } => {
833                vars.insert(variable.clone());
834            }
835            LogicalExpression::Binary { left, right, .. } => {
836                Self::collect_variables(left, vars);
837                Self::collect_variables(right, vars);
838            }
839            LogicalExpression::Unary { operand, .. } => {
840                Self::collect_variables(operand, vars);
841            }
842            LogicalExpression::FunctionCall { args, .. } => {
843                for arg in args {
844                    Self::collect_variables(arg, vars);
845                }
846            }
847            LogicalExpression::List(items) => {
848                for item in items {
849                    Self::collect_variables(item, vars);
850                }
851            }
852            LogicalExpression::Map(pairs) => {
853                for (_, value) in pairs {
854                    Self::collect_variables(value, vars);
855                }
856            }
857            LogicalExpression::IndexAccess { base, index } => {
858                Self::collect_variables(base, vars);
859                Self::collect_variables(index, vars);
860            }
861            LogicalExpression::SliceAccess { base, start, end } => {
862                Self::collect_variables(base, vars);
863                if let Some(s) = start {
864                    Self::collect_variables(s, vars);
865                }
866                if let Some(e) = end {
867                    Self::collect_variables(e, vars);
868                }
869            }
870            LogicalExpression::Case {
871                operand,
872                when_clauses,
873                else_clause,
874            } => {
875                if let Some(op) = operand {
876                    Self::collect_variables(op, vars);
877                }
878                for (cond, result) in when_clauses {
879                    Self::collect_variables(cond, vars);
880                    Self::collect_variables(result, vars);
881                }
882                if let Some(else_expr) = else_clause {
883                    Self::collect_variables(else_expr, vars);
884                }
885            }
886            LogicalExpression::Labels(var)
887            | LogicalExpression::Type(var)
888            | LogicalExpression::Id(var) => {
889                vars.insert(var.clone());
890            }
891            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
892            LogicalExpression::ListComprehension {
893                list_expr,
894                filter_expr,
895                map_expr,
896                ..
897            } => {
898                Self::collect_variables(list_expr, vars);
899                if let Some(filter) = filter_expr {
900                    Self::collect_variables(filter, vars);
901                }
902                Self::collect_variables(map_expr, vars);
903            }
904            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
905                // Subqueries have their own variable scope
906            }
907        }
908    }
909
910    /// Extracts aliases from projection expressions.
911    fn extract_projection_aliases(
912        &self,
913        projections: &[crate::query::plan::Projection],
914    ) -> HashSet<String> {
915        projections.iter().filter_map(|p| p.alias.clone()).collect()
916    }
917}
918
919impl Default for Optimizer {
920    fn default() -> Self {
921        Self::new()
922    }
923}
924
925#[cfg(test)]
926mod tests {
927    use super::*;
928    use crate::query::plan::{
929        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
930        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
931        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
932    };
933    use grafeo_common::types::Value;
934
935    #[test]
936    fn test_optimizer_filter_pushdown_simple() {
937        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
938        // Before: Return -> Filter -> NodeScan
939        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
940
941        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
942            items: vec![ReturnItem {
943                expression: LogicalExpression::Variable("n".to_string()),
944                alias: None,
945            }],
946            distinct: false,
947            input: Box::new(LogicalOperator::Filter(FilterOp {
948                predicate: LogicalExpression::Binary {
949                    left: Box::new(LogicalExpression::Property {
950                        variable: "n".to_string(),
951                        property: "age".to_string(),
952                    }),
953                    op: BinaryOp::Gt,
954                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
955                },
956                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
957                    variable: "n".to_string(),
958                    label: Some("Person".to_string()),
959                    input: None,
960                })),
961            })),
962        }));
963
964        let optimizer = Optimizer::new();
965        let optimized = optimizer.optimize(plan).unwrap();
966
967        // The structure should remain similar (filter stays near scan)
968        if let LogicalOperator::Return(ret) = &optimized.root {
969            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
970                if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
971                    assert_eq!(scan.variable, "n");
972                    return;
973                }
974            }
975        }
976        panic!("Expected Return -> Filter -> NodeScan structure");
977    }
978
979    #[test]
980    fn test_optimizer_filter_pushdown_through_expand() {
981        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
982        // The filter on 'a' should be pushed before the expand
983
984        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
985            items: vec![ReturnItem {
986                expression: LogicalExpression::Variable("b".to_string()),
987                alias: None,
988            }],
989            distinct: false,
990            input: Box::new(LogicalOperator::Filter(FilterOp {
991                predicate: LogicalExpression::Binary {
992                    left: Box::new(LogicalExpression::Property {
993                        variable: "a".to_string(),
994                        property: "age".to_string(),
995                    }),
996                    op: BinaryOp::Gt,
997                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
998                },
999                input: Box::new(LogicalOperator::Expand(ExpandOp {
1000                    from_variable: "a".to_string(),
1001                    to_variable: "b".to_string(),
1002                    edge_variable: None,
1003                    direction: ExpandDirection::Outgoing,
1004                    edge_type: Some("KNOWS".to_string()),
1005                    min_hops: 1,
1006                    max_hops: Some(1),
1007                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1008                        variable: "a".to_string(),
1009                        label: Some("Person".to_string()),
1010                        input: None,
1011                    })),
1012                    path_alias: None,
1013                })),
1014            })),
1015        }));
1016
1017        let optimizer = Optimizer::new();
1018        let optimized = optimizer.optimize(plan).unwrap();
1019
1020        // Filter on 'a' should be pushed before the expand
1021        // Expected: Return -> Expand -> Filter -> NodeScan
1022        if let LogicalOperator::Return(ret) = &optimized.root {
1023            if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1024                if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1025                    if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1026                        assert_eq!(scan.variable, "a");
1027                        assert_eq!(expand.from_variable, "a");
1028                        assert_eq!(expand.to_variable, "b");
1029                        return;
1030                    }
1031                }
1032            }
1033        }
1034        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1035    }
1036
1037    #[test]
1038    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1039        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1040        // The filter on 'b' should NOT be pushed before the expand
1041
1042        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1043            items: vec![ReturnItem {
1044                expression: LogicalExpression::Variable("a".to_string()),
1045                alias: None,
1046            }],
1047            distinct: false,
1048            input: Box::new(LogicalOperator::Filter(FilterOp {
1049                predicate: LogicalExpression::Binary {
1050                    left: Box::new(LogicalExpression::Property {
1051                        variable: "b".to_string(),
1052                        property: "age".to_string(),
1053                    }),
1054                    op: BinaryOp::Gt,
1055                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1056                },
1057                input: Box::new(LogicalOperator::Expand(ExpandOp {
1058                    from_variable: "a".to_string(),
1059                    to_variable: "b".to_string(),
1060                    edge_variable: None,
1061                    direction: ExpandDirection::Outgoing,
1062                    edge_type: Some("KNOWS".to_string()),
1063                    min_hops: 1,
1064                    max_hops: Some(1),
1065                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1066                        variable: "a".to_string(),
1067                        label: Some("Person".to_string()),
1068                        input: None,
1069                    })),
1070                    path_alias: None,
1071                })),
1072            })),
1073        }));
1074
1075        let optimizer = Optimizer::new();
1076        let optimized = optimizer.optimize(plan).unwrap();
1077
1078        // Filter on 'b' should stay after the expand
1079        // Expected: Return -> Filter -> Expand -> NodeScan
1080        if let LogicalOperator::Return(ret) = &optimized.root {
1081            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1082                // Check that the filter is on 'b'
1083                if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1084                    if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1085                        assert_eq!(variable, "b");
1086                    }
1087                }
1088
1089                if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1090                    if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1091                        return;
1092                    }
1093                }
1094            }
1095        }
1096        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1097    }
1098
1099    #[test]
1100    fn test_optimizer_extract_variables() {
1101        let optimizer = Optimizer::new();
1102
1103        let expr = LogicalExpression::Binary {
1104            left: Box::new(LogicalExpression::Property {
1105                variable: "n".to_string(),
1106                property: "age".to_string(),
1107            }),
1108            op: BinaryOp::Gt,
1109            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1110        };
1111
1112        let vars = optimizer.extract_variables(&expr);
1113        assert_eq!(vars.len(), 1);
1114        assert!(vars.contains("n"));
1115    }
1116
1117    // Additional tests for optimizer configuration
1118
1119    #[test]
1120    fn test_optimizer_default() {
1121        let optimizer = Optimizer::default();
1122        // Should be able to optimize an empty plan
1123        let plan = LogicalPlan::new(LogicalOperator::Empty);
1124        let result = optimizer.optimize(plan);
1125        assert!(result.is_ok());
1126    }
1127
1128    #[test]
1129    fn test_optimizer_with_filter_pushdown_disabled() {
1130        let optimizer = Optimizer::new().with_filter_pushdown(false);
1131
1132        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1133            items: vec![ReturnItem {
1134                expression: LogicalExpression::Variable("n".to_string()),
1135                alias: None,
1136            }],
1137            distinct: false,
1138            input: Box::new(LogicalOperator::Filter(FilterOp {
1139                predicate: LogicalExpression::Literal(Value::Bool(true)),
1140                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1141                    variable: "n".to_string(),
1142                    label: None,
1143                    input: None,
1144                })),
1145            })),
1146        }));
1147
1148        let optimized = optimizer.optimize(plan).unwrap();
1149        // Structure should be unchanged
1150        if let LogicalOperator::Return(ret) = &optimized.root {
1151            if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1152                return;
1153            }
1154        }
1155        panic!("Expected unchanged structure");
1156    }
1157
1158    #[test]
1159    fn test_optimizer_with_join_reorder_disabled() {
1160        let optimizer = Optimizer::new().with_join_reorder(false);
1161        assert!(
1162            optimizer
1163                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1164                .is_ok()
1165        );
1166    }
1167
1168    #[test]
1169    fn test_optimizer_with_cost_model() {
1170        let cost_model = CostModel::new();
1171        let optimizer = Optimizer::new().with_cost_model(cost_model);
1172        assert!(
1173            optimizer
1174                .cost_model()
1175                .estimate(&LogicalOperator::Empty, 0.0)
1176                .total()
1177                < 0.001
1178        );
1179    }
1180
1181    #[test]
1182    fn test_optimizer_with_cardinality_estimator() {
1183        let mut estimator = CardinalityEstimator::new();
1184        estimator.add_table_stats("Test", TableStats::new(500));
1185        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1186
1187        let scan = LogicalOperator::NodeScan(NodeScanOp {
1188            variable: "n".to_string(),
1189            label: Some("Test".to_string()),
1190            input: None,
1191        });
1192        let plan = LogicalPlan::new(scan);
1193
1194        let cardinality = optimizer.estimate_cardinality(&plan);
1195        assert!((cardinality - 500.0).abs() < 0.001);
1196    }
1197
1198    #[test]
1199    fn test_optimizer_estimate_cost() {
1200        let optimizer = Optimizer::new();
1201        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1202            variable: "n".to_string(),
1203            label: None,
1204            input: None,
1205        }));
1206
1207        let cost = optimizer.estimate_cost(&plan);
1208        assert!(cost.total() > 0.0);
1209    }
1210
1211    // Filter pushdown through various operators
1212
1213    #[test]
1214    fn test_filter_pushdown_through_project() {
1215        let optimizer = Optimizer::new();
1216
1217        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1218            predicate: LogicalExpression::Binary {
1219                left: Box::new(LogicalExpression::Property {
1220                    variable: "n".to_string(),
1221                    property: "age".to_string(),
1222                }),
1223                op: BinaryOp::Gt,
1224                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1225            },
1226            input: Box::new(LogicalOperator::Project(ProjectOp {
1227                projections: vec![Projection {
1228                    expression: LogicalExpression::Variable("n".to_string()),
1229                    alias: None,
1230                }],
1231                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1232                    variable: "n".to_string(),
1233                    label: None,
1234                    input: None,
1235                })),
1236            })),
1237        }));
1238
1239        let optimized = optimizer.optimize(plan).unwrap();
1240
1241        // Filter should be pushed through Project
1242        if let LogicalOperator::Project(proj) = &optimized.root {
1243            if let LogicalOperator::Filter(_) = proj.input.as_ref() {
1244                return;
1245            }
1246        }
1247        panic!("Expected Project -> Filter structure");
1248    }
1249
1250    #[test]
1251    fn test_filter_not_pushed_through_project_with_alias() {
1252        let optimizer = Optimizer::new();
1253
1254        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1255        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1256            predicate: LogicalExpression::Binary {
1257                left: Box::new(LogicalExpression::Variable("x".to_string())),
1258                op: BinaryOp::Gt,
1259                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1260            },
1261            input: Box::new(LogicalOperator::Project(ProjectOp {
1262                projections: vec![Projection {
1263                    expression: LogicalExpression::Property {
1264                        variable: "n".to_string(),
1265                        property: "age".to_string(),
1266                    },
1267                    alias: Some("x".to_string()),
1268                }],
1269                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1270                    variable: "n".to_string(),
1271                    label: None,
1272                    input: None,
1273                })),
1274            })),
1275        }));
1276
1277        let optimized = optimizer.optimize(plan).unwrap();
1278
1279        // Filter should stay above Project
1280        if let LogicalOperator::Filter(filter) = &optimized.root {
1281            if let LogicalOperator::Project(_) = filter.input.as_ref() {
1282                return;
1283            }
1284        }
1285        panic!("Expected Filter -> Project structure");
1286    }
1287
1288    #[test]
1289    fn test_filter_pushdown_through_limit() {
1290        let optimizer = Optimizer::new();
1291
1292        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1293            predicate: LogicalExpression::Literal(Value::Bool(true)),
1294            input: Box::new(LogicalOperator::Limit(LimitOp {
1295                count: 10,
1296                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1297                    variable: "n".to_string(),
1298                    label: None,
1299                    input: None,
1300                })),
1301            })),
1302        }));
1303
1304        let optimized = optimizer.optimize(plan).unwrap();
1305
1306        // Filter stays above Limit (cannot be pushed through)
1307        if let LogicalOperator::Filter(filter) = &optimized.root {
1308            if let LogicalOperator::Limit(_) = filter.input.as_ref() {
1309                return;
1310            }
1311        }
1312        panic!("Expected Filter -> Limit structure");
1313    }
1314
1315    #[test]
1316    fn test_filter_pushdown_through_sort() {
1317        let optimizer = Optimizer::new();
1318
1319        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1320            predicate: LogicalExpression::Literal(Value::Bool(true)),
1321            input: Box::new(LogicalOperator::Sort(SortOp {
1322                keys: vec![SortKey {
1323                    expression: LogicalExpression::Variable("n".to_string()),
1324                    order: SortOrder::Ascending,
1325                }],
1326                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1327                    variable: "n".to_string(),
1328                    label: None,
1329                    input: None,
1330                })),
1331            })),
1332        }));
1333
1334        let optimized = optimizer.optimize(plan).unwrap();
1335
1336        // Filter stays above Sort
1337        if let LogicalOperator::Filter(filter) = &optimized.root {
1338            if let LogicalOperator::Sort(_) = filter.input.as_ref() {
1339                return;
1340            }
1341        }
1342        panic!("Expected Filter -> Sort structure");
1343    }
1344
1345    #[test]
1346    fn test_filter_pushdown_through_distinct() {
1347        let optimizer = Optimizer::new();
1348
1349        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1350            predicate: LogicalExpression::Literal(Value::Bool(true)),
1351            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1352                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1353                    variable: "n".to_string(),
1354                    label: None,
1355                    input: None,
1356                })),
1357                columns: None,
1358            })),
1359        }));
1360
1361        let optimized = optimizer.optimize(plan).unwrap();
1362
1363        // Filter stays above Distinct
1364        if let LogicalOperator::Filter(filter) = &optimized.root {
1365            if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
1366                return;
1367            }
1368        }
1369        panic!("Expected Filter -> Distinct structure");
1370    }
1371
1372    #[test]
1373    fn test_filter_not_pushed_through_aggregate() {
1374        let optimizer = Optimizer::new();
1375
1376        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1377            predicate: LogicalExpression::Binary {
1378                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1379                op: BinaryOp::Gt,
1380                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1381            },
1382            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1383                group_by: vec![],
1384                aggregates: vec![AggregateExpr {
1385                    function: AggregateFunction::Count,
1386                    expression: None,
1387                    distinct: false,
1388                    alias: Some("cnt".to_string()),
1389                    percentile: None,
1390                }],
1391                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1392                    variable: "n".to_string(),
1393                    label: None,
1394                    input: None,
1395                })),
1396                having: None,
1397            })),
1398        }));
1399
1400        let optimized = optimizer.optimize(plan).unwrap();
1401
1402        // Filter should stay above Aggregate
1403        if let LogicalOperator::Filter(filter) = &optimized.root {
1404            if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
1405                return;
1406            }
1407        }
1408        panic!("Expected Filter -> Aggregate structure");
1409    }
1410
1411    #[test]
1412    fn test_filter_pushdown_to_left_join_side() {
1413        let optimizer = Optimizer::new();
1414
1415        // Filter on left variable should be pushed to left side
1416        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1417            predicate: LogicalExpression::Binary {
1418                left: Box::new(LogicalExpression::Property {
1419                    variable: "a".to_string(),
1420                    property: "age".to_string(),
1421                }),
1422                op: BinaryOp::Gt,
1423                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1424            },
1425            input: Box::new(LogicalOperator::Join(JoinOp {
1426                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1427                    variable: "a".to_string(),
1428                    label: Some("Person".to_string()),
1429                    input: None,
1430                })),
1431                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1432                    variable: "b".to_string(),
1433                    label: Some("Company".to_string()),
1434                    input: None,
1435                })),
1436                join_type: JoinType::Inner,
1437                conditions: vec![],
1438            })),
1439        }));
1440
1441        let optimized = optimizer.optimize(plan).unwrap();
1442
1443        // Filter should be pushed to left side of join
1444        if let LogicalOperator::Join(join) = &optimized.root {
1445            if let LogicalOperator::Filter(_) = join.left.as_ref() {
1446                return;
1447            }
1448        }
1449        panic!("Expected Join with Filter on left side");
1450    }
1451
1452    #[test]
1453    fn test_filter_pushdown_to_right_join_side() {
1454        let optimizer = Optimizer::new();
1455
1456        // Filter on right variable should be pushed to right side
1457        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1458            predicate: LogicalExpression::Binary {
1459                left: Box::new(LogicalExpression::Property {
1460                    variable: "b".to_string(),
1461                    property: "name".to_string(),
1462                }),
1463                op: BinaryOp::Eq,
1464                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1465            },
1466            input: Box::new(LogicalOperator::Join(JoinOp {
1467                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1468                    variable: "a".to_string(),
1469                    label: Some("Person".to_string()),
1470                    input: None,
1471                })),
1472                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1473                    variable: "b".to_string(),
1474                    label: Some("Company".to_string()),
1475                    input: None,
1476                })),
1477                join_type: JoinType::Inner,
1478                conditions: vec![],
1479            })),
1480        }));
1481
1482        let optimized = optimizer.optimize(plan).unwrap();
1483
1484        // Filter should be pushed to right side of join
1485        if let LogicalOperator::Join(join) = &optimized.root {
1486            if let LogicalOperator::Filter(_) = join.right.as_ref() {
1487                return;
1488            }
1489        }
1490        panic!("Expected Join with Filter on right side");
1491    }
1492
1493    #[test]
1494    fn test_filter_not_pushed_when_uses_both_join_sides() {
1495        let optimizer = Optimizer::new();
1496
1497        // Filter using both variables should stay above join
1498        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1499            predicate: LogicalExpression::Binary {
1500                left: Box::new(LogicalExpression::Property {
1501                    variable: "a".to_string(),
1502                    property: "id".to_string(),
1503                }),
1504                op: BinaryOp::Eq,
1505                right: Box::new(LogicalExpression::Property {
1506                    variable: "b".to_string(),
1507                    property: "a_id".to_string(),
1508                }),
1509            },
1510            input: Box::new(LogicalOperator::Join(JoinOp {
1511                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1512                    variable: "a".to_string(),
1513                    label: None,
1514                    input: None,
1515                })),
1516                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1517                    variable: "b".to_string(),
1518                    label: None,
1519                    input: None,
1520                })),
1521                join_type: JoinType::Inner,
1522                conditions: vec![],
1523            })),
1524        }));
1525
1526        let optimized = optimizer.optimize(plan).unwrap();
1527
1528        // Filter should stay above join
1529        if let LogicalOperator::Filter(filter) = &optimized.root {
1530            if let LogicalOperator::Join(_) = filter.input.as_ref() {
1531                return;
1532            }
1533        }
1534        panic!("Expected Filter -> Join structure");
1535    }
1536
1537    // Variable extraction tests
1538
1539    #[test]
1540    fn test_extract_variables_from_variable() {
1541        let optimizer = Optimizer::new();
1542        let expr = LogicalExpression::Variable("x".to_string());
1543        let vars = optimizer.extract_variables(&expr);
1544        assert_eq!(vars.len(), 1);
1545        assert!(vars.contains("x"));
1546    }
1547
1548    #[test]
1549    fn test_extract_variables_from_unary() {
1550        let optimizer = Optimizer::new();
1551        let expr = LogicalExpression::Unary {
1552            op: UnaryOp::Not,
1553            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1554        };
1555        let vars = optimizer.extract_variables(&expr);
1556        assert_eq!(vars.len(), 1);
1557        assert!(vars.contains("x"));
1558    }
1559
1560    #[test]
1561    fn test_extract_variables_from_function_call() {
1562        let optimizer = Optimizer::new();
1563        let expr = LogicalExpression::FunctionCall {
1564            name: "length".to_string(),
1565            args: vec![
1566                LogicalExpression::Variable("a".to_string()),
1567                LogicalExpression::Variable("b".to_string()),
1568            ],
1569            distinct: false,
1570        };
1571        let vars = optimizer.extract_variables(&expr);
1572        assert_eq!(vars.len(), 2);
1573        assert!(vars.contains("a"));
1574        assert!(vars.contains("b"));
1575    }
1576
1577    #[test]
1578    fn test_extract_variables_from_list() {
1579        let optimizer = Optimizer::new();
1580        let expr = LogicalExpression::List(vec![
1581            LogicalExpression::Variable("a".to_string()),
1582            LogicalExpression::Literal(Value::Int64(1)),
1583            LogicalExpression::Variable("b".to_string()),
1584        ]);
1585        let vars = optimizer.extract_variables(&expr);
1586        assert_eq!(vars.len(), 2);
1587        assert!(vars.contains("a"));
1588        assert!(vars.contains("b"));
1589    }
1590
1591    #[test]
1592    fn test_extract_variables_from_map() {
1593        let optimizer = Optimizer::new();
1594        let expr = LogicalExpression::Map(vec![
1595            (
1596                "key1".to_string(),
1597                LogicalExpression::Variable("a".to_string()),
1598            ),
1599            (
1600                "key2".to_string(),
1601                LogicalExpression::Variable("b".to_string()),
1602            ),
1603        ]);
1604        let vars = optimizer.extract_variables(&expr);
1605        assert_eq!(vars.len(), 2);
1606        assert!(vars.contains("a"));
1607        assert!(vars.contains("b"));
1608    }
1609
1610    #[test]
1611    fn test_extract_variables_from_index_access() {
1612        let optimizer = Optimizer::new();
1613        let expr = LogicalExpression::IndexAccess {
1614            base: Box::new(LogicalExpression::Variable("list".to_string())),
1615            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1616        };
1617        let vars = optimizer.extract_variables(&expr);
1618        assert_eq!(vars.len(), 2);
1619        assert!(vars.contains("list"));
1620        assert!(vars.contains("idx"));
1621    }
1622
1623    #[test]
1624    fn test_extract_variables_from_slice_access() {
1625        let optimizer = Optimizer::new();
1626        let expr = LogicalExpression::SliceAccess {
1627            base: Box::new(LogicalExpression::Variable("list".to_string())),
1628            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1629            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1630        };
1631        let vars = optimizer.extract_variables(&expr);
1632        assert_eq!(vars.len(), 3);
1633        assert!(vars.contains("list"));
1634        assert!(vars.contains("s"));
1635        assert!(vars.contains("e"));
1636    }
1637
1638    #[test]
1639    fn test_extract_variables_from_case() {
1640        let optimizer = Optimizer::new();
1641        let expr = LogicalExpression::Case {
1642            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1643            when_clauses: vec![(
1644                LogicalExpression::Literal(Value::Int64(1)),
1645                LogicalExpression::Variable("a".to_string()),
1646            )],
1647            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1648        };
1649        let vars = optimizer.extract_variables(&expr);
1650        assert_eq!(vars.len(), 3);
1651        assert!(vars.contains("x"));
1652        assert!(vars.contains("a"));
1653        assert!(vars.contains("b"));
1654    }
1655
1656    #[test]
1657    fn test_extract_variables_from_labels() {
1658        let optimizer = Optimizer::new();
1659        let expr = LogicalExpression::Labels("n".to_string());
1660        let vars = optimizer.extract_variables(&expr);
1661        assert_eq!(vars.len(), 1);
1662        assert!(vars.contains("n"));
1663    }
1664
1665    #[test]
1666    fn test_extract_variables_from_type() {
1667        let optimizer = Optimizer::new();
1668        let expr = LogicalExpression::Type("e".to_string());
1669        let vars = optimizer.extract_variables(&expr);
1670        assert_eq!(vars.len(), 1);
1671        assert!(vars.contains("e"));
1672    }
1673
1674    #[test]
1675    fn test_extract_variables_from_id() {
1676        let optimizer = Optimizer::new();
1677        let expr = LogicalExpression::Id("n".to_string());
1678        let vars = optimizer.extract_variables(&expr);
1679        assert_eq!(vars.len(), 1);
1680        assert!(vars.contains("n"));
1681    }
1682
1683    #[test]
1684    fn test_extract_variables_from_list_comprehension() {
1685        let optimizer = Optimizer::new();
1686        let expr = LogicalExpression::ListComprehension {
1687            variable: "x".to_string(),
1688            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1689            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1690            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1691        };
1692        let vars = optimizer.extract_variables(&expr);
1693        assert!(vars.contains("items"));
1694        assert!(vars.contains("pred"));
1695        assert!(vars.contains("result"));
1696    }
1697
1698    #[test]
1699    fn test_extract_variables_from_literal_and_parameter() {
1700        let optimizer = Optimizer::new();
1701
1702        let literal = LogicalExpression::Literal(Value::Int64(42));
1703        assert!(optimizer.extract_variables(&literal).is_empty());
1704
1705        let param = LogicalExpression::Parameter("p".to_string());
1706        assert!(optimizer.extract_variables(&param).is_empty());
1707    }
1708
1709    // Recursive filter pushdown tests
1710
1711    #[test]
1712    fn test_recursive_filter_pushdown_through_skip() {
1713        let optimizer = Optimizer::new();
1714
1715        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1716            items: vec![ReturnItem {
1717                expression: LogicalExpression::Variable("n".to_string()),
1718                alias: None,
1719            }],
1720            distinct: false,
1721            input: Box::new(LogicalOperator::Filter(FilterOp {
1722                predicate: LogicalExpression::Literal(Value::Bool(true)),
1723                input: Box::new(LogicalOperator::Skip(SkipOp {
1724                    count: 5,
1725                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1726                        variable: "n".to_string(),
1727                        label: None,
1728                        input: None,
1729                    })),
1730                })),
1731            })),
1732        }));
1733
1734        let optimized = optimizer.optimize(plan).unwrap();
1735
1736        // Verify optimization succeeded
1737        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1738    }
1739
1740    #[test]
1741    fn test_nested_filter_pushdown() {
1742        let optimizer = Optimizer::new();
1743
1744        // Multiple nested filters
1745        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1746            items: vec![ReturnItem {
1747                expression: LogicalExpression::Variable("n".to_string()),
1748                alias: None,
1749            }],
1750            distinct: false,
1751            input: Box::new(LogicalOperator::Filter(FilterOp {
1752                predicate: LogicalExpression::Binary {
1753                    left: Box::new(LogicalExpression::Property {
1754                        variable: "n".to_string(),
1755                        property: "x".to_string(),
1756                    }),
1757                    op: BinaryOp::Gt,
1758                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1759                },
1760                input: Box::new(LogicalOperator::Filter(FilterOp {
1761                    predicate: LogicalExpression::Binary {
1762                        left: Box::new(LogicalExpression::Property {
1763                            variable: "n".to_string(),
1764                            property: "y".to_string(),
1765                        }),
1766                        op: BinaryOp::Lt,
1767                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1768                    },
1769                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1770                        variable: "n".to_string(),
1771                        label: None,
1772                        input: None,
1773                    })),
1774                })),
1775            })),
1776        }));
1777
1778        let optimized = optimizer.optimize(plan).unwrap();
1779        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1780    }
1781}