Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
//! | Projection Pushdown | Works out which variables and properties later operators read |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
26/// Information about a join condition for join reordering.
27#[derive(Debug, Clone)]
28struct JoinInfo {
29    left_var: String,
30    right_var: String,
31    left_expr: LogicalExpression,
32    right_expr: LogicalExpression,
33}
34
/// Something a query reads, tracked by the projection-pushdown pass.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding (node, edge or path variable).
    Variable(String),
    /// One property of a binding: `(variable, property)`.
    Property(String, String),
}
43
44/// Transforms logical plans for faster execution.
45///
46/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
47/// Use the builder methods to enable/disable specific optimizations.
48pub struct Optimizer {
49    /// Whether to enable filter pushdown.
50    enable_filter_pushdown: bool,
51    /// Whether to enable join reordering.
52    enable_join_reorder: bool,
53    /// Whether to enable projection pushdown.
54    enable_projection_pushdown: bool,
55    /// Cost model for estimation.
56    cost_model: CostModel,
57    /// Cardinality estimator.
58    card_estimator: CardinalityEstimator,
59}
60
61impl Optimizer {
62    /// Creates a new optimizer with default settings.
63    #[must_use]
64    pub fn new() -> Self {
65        Self {
66            enable_filter_pushdown: true,
67            enable_join_reorder: true,
68            enable_projection_pushdown: true,
69            cost_model: CostModel::new(),
70            card_estimator: CardinalityEstimator::new(),
71        }
72    }
73
74    /// Creates an optimizer with cardinality estimates from the store's statistics.
75    ///
76    /// Pre-populates the cardinality estimator with per-label row counts and
77    /// edge type fanout. Ensures statistics are fresh before reading.
78    #[must_use]
79    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
80        store.ensure_statistics_fresh();
81        let stats = store.statistics();
82        let estimator = CardinalityEstimator::from_statistics(&stats);
83        Self {
84            enable_filter_pushdown: true,
85            enable_join_reorder: true,
86            enable_projection_pushdown: true,
87            cost_model: CostModel::new(),
88            card_estimator: estimator,
89        }
90    }
91
92    /// Enables or disables filter pushdown.
93    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
94        self.enable_filter_pushdown = enabled;
95        self
96    }
97
98    /// Enables or disables join reordering.
99    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
100        self.enable_join_reorder = enabled;
101        self
102    }
103
104    /// Enables or disables projection pushdown.
105    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
106        self.enable_projection_pushdown = enabled;
107        self
108    }
109
110    /// Sets the cost model.
111    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
112        self.cost_model = cost_model;
113        self
114    }
115
116    /// Sets the cardinality estimator.
117    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
118        self.card_estimator = estimator;
119        self
120    }
121
122    /// Returns a reference to the cost model.
123    pub fn cost_model(&self) -> &CostModel {
124        &self.cost_model
125    }
126
127    /// Returns a reference to the cardinality estimator.
128    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
129        &self.card_estimator
130    }
131
132    /// Estimates the cost of a plan.
133    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
134        let cardinality = self.card_estimator.estimate(&plan.root);
135        self.cost_model.estimate(&plan.root, cardinality)
136    }
137
138    /// Estimates the cardinality of a plan.
139    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
140        self.card_estimator.estimate(&plan.root)
141    }
142
143    /// Optimizes a logical plan.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if optimization fails.
148    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
149        let mut root = plan.root;
150
151        // Apply optimization rules
152        if self.enable_filter_pushdown {
153            root = self.push_filters_down(root);
154        }
155
156        if self.enable_join_reorder {
157            root = self.reorder_joins(root);
158        }
159
160        if self.enable_projection_pushdown {
161            root = self.push_projections_down(root);
162        }
163
164        Ok(LogicalPlan::new(root))
165    }
166
167    /// Pushes projections down the operator tree to eliminate unused columns early.
168    ///
169    /// This optimization:
170    /// 1. Collects required variables/properties from the root
171    /// 2. Propagates requirements down through the tree
172    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
173    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
174        // Collect required columns from the top of the plan
175        let required = self.collect_required_columns(&op);
176
177        // Push projections down
178        self.push_projections_recursive(op, &required)
179    }
180
181    /// Collects all variables and properties required by an operator and its ancestors.
182    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
183        let mut required = HashSet::new();
184        Self::collect_required_recursive(op, &mut required);
185        required
186    }
187
188    /// Recursively collects required columns.
189    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
190        match op {
191            LogicalOperator::Return(ret) => {
192                for item in &ret.items {
193                    Self::collect_from_expression(&item.expression, required);
194                }
195                Self::collect_required_recursive(&ret.input, required);
196            }
197            LogicalOperator::Project(proj) => {
198                for p in &proj.projections {
199                    Self::collect_from_expression(&p.expression, required);
200                }
201                Self::collect_required_recursive(&proj.input, required);
202            }
203            LogicalOperator::Filter(filter) => {
204                Self::collect_from_expression(&filter.predicate, required);
205                Self::collect_required_recursive(&filter.input, required);
206            }
207            LogicalOperator::Sort(sort) => {
208                for key in &sort.keys {
209                    Self::collect_from_expression(&key.expression, required);
210                }
211                Self::collect_required_recursive(&sort.input, required);
212            }
213            LogicalOperator::Aggregate(agg) => {
214                for expr in &agg.group_by {
215                    Self::collect_from_expression(expr, required);
216                }
217                for agg_expr in &agg.aggregates {
218                    if let Some(ref expr) = agg_expr.expression {
219                        Self::collect_from_expression(expr, required);
220                    }
221                }
222                if let Some(ref having) = agg.having {
223                    Self::collect_from_expression(having, required);
224                }
225                Self::collect_required_recursive(&agg.input, required);
226            }
227            LogicalOperator::Join(join) => {
228                for cond in &join.conditions {
229                    Self::collect_from_expression(&cond.left, required);
230                    Self::collect_from_expression(&cond.right, required);
231                }
232                Self::collect_required_recursive(&join.left, required);
233                Self::collect_required_recursive(&join.right, required);
234            }
235            LogicalOperator::Expand(expand) => {
236                // The source and target variables are needed
237                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
238                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
239                if let Some(ref edge_var) = expand.edge_variable {
240                    required.insert(RequiredColumn::Variable(edge_var.clone()));
241                }
242                Self::collect_required_recursive(&expand.input, required);
243            }
244            LogicalOperator::Limit(limit) => {
245                Self::collect_required_recursive(&limit.input, required);
246            }
247            LogicalOperator::Skip(skip) => {
248                Self::collect_required_recursive(&skip.input, required);
249            }
250            LogicalOperator::Distinct(distinct) => {
251                Self::collect_required_recursive(&distinct.input, required);
252            }
253            LogicalOperator::NodeScan(scan) => {
254                required.insert(RequiredColumn::Variable(scan.variable.clone()));
255            }
256            LogicalOperator::EdgeScan(scan) => {
257                required.insert(RequiredColumn::Variable(scan.variable.clone()));
258            }
259            _ => {}
260        }
261    }
262
263    /// Collects required columns from an expression.
264    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
265        match expr {
266            LogicalExpression::Variable(var) => {
267                required.insert(RequiredColumn::Variable(var.clone()));
268            }
269            LogicalExpression::Property { variable, property } => {
270                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
271                required.insert(RequiredColumn::Variable(variable.clone()));
272            }
273            LogicalExpression::Binary { left, right, .. } => {
274                Self::collect_from_expression(left, required);
275                Self::collect_from_expression(right, required);
276            }
277            LogicalExpression::Unary { operand, .. } => {
278                Self::collect_from_expression(operand, required);
279            }
280            LogicalExpression::FunctionCall { args, .. } => {
281                for arg in args {
282                    Self::collect_from_expression(arg, required);
283                }
284            }
285            LogicalExpression::List(items) => {
286                for item in items {
287                    Self::collect_from_expression(item, required);
288                }
289            }
290            LogicalExpression::Map(pairs) => {
291                for (_, value) in pairs {
292                    Self::collect_from_expression(value, required);
293                }
294            }
295            LogicalExpression::IndexAccess { base, index } => {
296                Self::collect_from_expression(base, required);
297                Self::collect_from_expression(index, required);
298            }
299            LogicalExpression::SliceAccess { base, start, end } => {
300                Self::collect_from_expression(base, required);
301                if let Some(s) = start {
302                    Self::collect_from_expression(s, required);
303                }
304                if let Some(e) = end {
305                    Self::collect_from_expression(e, required);
306                }
307            }
308            LogicalExpression::Case {
309                operand,
310                when_clauses,
311                else_clause,
312            } => {
313                if let Some(op) = operand {
314                    Self::collect_from_expression(op, required);
315                }
316                for (cond, result) in when_clauses {
317                    Self::collect_from_expression(cond, required);
318                    Self::collect_from_expression(result, required);
319                }
320                if let Some(else_expr) = else_clause {
321                    Self::collect_from_expression(else_expr, required);
322                }
323            }
324            LogicalExpression::Labels(var)
325            | LogicalExpression::Type(var)
326            | LogicalExpression::Id(var) => {
327                required.insert(RequiredColumn::Variable(var.clone()));
328            }
329            LogicalExpression::ListComprehension {
330                list_expr,
331                filter_expr,
332                map_expr,
333                ..
334            } => {
335                Self::collect_from_expression(list_expr, required);
336                if let Some(filter) = filter_expr {
337                    Self::collect_from_expression(filter, required);
338                }
339                Self::collect_from_expression(map_expr, required);
340            }
341            _ => {}
342        }
343    }
344
345    /// Recursively pushes projections down, adding them before expensive operations.
346    fn push_projections_recursive(
347        &self,
348        op: LogicalOperator,
349        required: &HashSet<RequiredColumn>,
350    ) -> LogicalOperator {
351        match op {
352            LogicalOperator::Return(mut ret) => {
353                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
354                LogicalOperator::Return(ret)
355            }
356            LogicalOperator::Project(mut proj) => {
357                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
358                LogicalOperator::Project(proj)
359            }
360            LogicalOperator::Filter(mut filter) => {
361                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
362                LogicalOperator::Filter(filter)
363            }
364            LogicalOperator::Sort(mut sort) => {
365                // Sort is expensive - consider adding a projection before it
366                // to reduce tuple width
367                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
368                LogicalOperator::Sort(sort)
369            }
370            LogicalOperator::Aggregate(mut agg) => {
371                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
372                LogicalOperator::Aggregate(agg)
373            }
374            LogicalOperator::Join(mut join) => {
375                // Joins are expensive - the required columns help determine
376                // what to project on each side
377                let left_vars = self.collect_output_variables(&join.left);
378                let right_vars = self.collect_output_variables(&join.right);
379
380                // Filter required columns to each side
381                let left_required: HashSet<_> = required
382                    .iter()
383                    .filter(|c| match c {
384                        RequiredColumn::Variable(v) => left_vars.contains(v),
385                        RequiredColumn::Property(v, _) => left_vars.contains(v),
386                    })
387                    .cloned()
388                    .collect();
389
390                let right_required: HashSet<_> = required
391                    .iter()
392                    .filter(|c| match c {
393                        RequiredColumn::Variable(v) => right_vars.contains(v),
394                        RequiredColumn::Property(v, _) => right_vars.contains(v),
395                    })
396                    .cloned()
397                    .collect();
398
399                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
400                join.right =
401                    Box::new(self.push_projections_recursive(*join.right, &right_required));
402                LogicalOperator::Join(join)
403            }
404            LogicalOperator::Expand(mut expand) => {
405                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
406                LogicalOperator::Expand(expand)
407            }
408            LogicalOperator::Limit(mut limit) => {
409                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
410                LogicalOperator::Limit(limit)
411            }
412            LogicalOperator::Skip(mut skip) => {
413                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
414                LogicalOperator::Skip(skip)
415            }
416            LogicalOperator::Distinct(mut distinct) => {
417                distinct.input =
418                    Box::new(self.push_projections_recursive(*distinct.input, required));
419                LogicalOperator::Distinct(distinct)
420            }
421            other => other,
422        }
423    }
424
425    /// Reorders joins in the operator tree using the DPccp algorithm.
426    ///
427    /// This optimization finds the optimal join order by:
428    /// 1. Extracting all base relations (scans) and join conditions
429    /// 2. Building a join graph
430    /// 3. Using dynamic programming to find the cheapest join order
431    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
432        // First, recursively optimize children
433        let op = self.reorder_joins_recursive(op);
434
435        // Then, if this is a join tree, try to optimize it
436        if let Some((relations, conditions)) = self.extract_join_tree(&op) {
437            if relations.len() >= 2 {
438                if let Some(optimized) = self.optimize_join_order(&relations, &conditions) {
439                    return optimized;
440                }
441            }
442        }
443
444        op
445    }
446
447    /// Recursively applies join reordering to child operators.
448    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
449        match op {
450            LogicalOperator::Return(mut ret) => {
451                ret.input = Box::new(self.reorder_joins(*ret.input));
452                LogicalOperator::Return(ret)
453            }
454            LogicalOperator::Project(mut proj) => {
455                proj.input = Box::new(self.reorder_joins(*proj.input));
456                LogicalOperator::Project(proj)
457            }
458            LogicalOperator::Filter(mut filter) => {
459                filter.input = Box::new(self.reorder_joins(*filter.input));
460                LogicalOperator::Filter(filter)
461            }
462            LogicalOperator::Limit(mut limit) => {
463                limit.input = Box::new(self.reorder_joins(*limit.input));
464                LogicalOperator::Limit(limit)
465            }
466            LogicalOperator::Skip(mut skip) => {
467                skip.input = Box::new(self.reorder_joins(*skip.input));
468                LogicalOperator::Skip(skip)
469            }
470            LogicalOperator::Sort(mut sort) => {
471                sort.input = Box::new(self.reorder_joins(*sort.input));
472                LogicalOperator::Sort(sort)
473            }
474            LogicalOperator::Distinct(mut distinct) => {
475                distinct.input = Box::new(self.reorder_joins(*distinct.input));
476                LogicalOperator::Distinct(distinct)
477            }
478            LogicalOperator::Aggregate(mut agg) => {
479                agg.input = Box::new(self.reorder_joins(*agg.input));
480                LogicalOperator::Aggregate(agg)
481            }
482            LogicalOperator::Expand(mut expand) => {
483                expand.input = Box::new(self.reorder_joins(*expand.input));
484                LogicalOperator::Expand(expand)
485            }
486            // Join operators are handled by the parent reorder_joins call
487            other => other,
488        }
489    }
490
491    /// Extracts base relations and join conditions from a join tree.
492    ///
493    /// Returns None if the operator is not a join tree.
494    fn extract_join_tree(
495        &self,
496        op: &LogicalOperator,
497    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
498        let mut relations = Vec::new();
499        let mut join_conditions = Vec::new();
500
501        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
502            return None;
503        }
504
505        if relations.len() < 2 {
506            return None;
507        }
508
509        Some((relations, join_conditions))
510    }
511
512    /// Recursively collects base relations and join conditions.
513    ///
514    /// Returns true if this subtree is part of a join tree.
515    fn collect_join_tree(
516        &self,
517        op: &LogicalOperator,
518        relations: &mut Vec<(String, LogicalOperator)>,
519        conditions: &mut Vec<JoinInfo>,
520    ) -> bool {
521        match op {
522            LogicalOperator::Join(join) => {
523                // Collect from both sides
524                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
525                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
526
527                // Add conditions from this join
528                for cond in &join.conditions {
529                    if let (Some(left_var), Some(right_var)) = (
530                        self.extract_variable_from_expr(&cond.left),
531                        self.extract_variable_from_expr(&cond.right),
532                    ) {
533                        conditions.push(JoinInfo {
534                            left_var,
535                            right_var,
536                            left_expr: cond.left.clone(),
537                            right_expr: cond.right.clone(),
538                        });
539                    }
540                }
541
542                left_ok && right_ok
543            }
544            LogicalOperator::NodeScan(scan) => {
545                relations.push((scan.variable.clone(), op.clone()));
546                true
547            }
548            LogicalOperator::EdgeScan(scan) => {
549                relations.push((scan.variable.clone(), op.clone()));
550                true
551            }
552            LogicalOperator::Filter(filter) => {
553                // A filter on a base relation is still part of the join tree
554                self.collect_join_tree(&filter.input, relations, conditions)
555            }
556            LogicalOperator::Expand(expand) => {
557                // Expand is a special case - it's like a join with the adjacency
558                // For now, treat the whole Expand subtree as a single relation
559                relations.push((expand.to_variable.clone(), op.clone()));
560                true
561            }
562            _ => false,
563        }
564    }
565
566    /// Extracts the primary variable from an expression.
567    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
568        match expr {
569            LogicalExpression::Variable(v) => Some(v.clone()),
570            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
571            _ => None,
572        }
573    }
574
575    /// Optimizes the join order using DPccp.
576    fn optimize_join_order(
577        &self,
578        relations: &[(String, LogicalOperator)],
579        conditions: &[JoinInfo],
580    ) -> Option<LogicalOperator> {
581        use join_order::{DPccp, JoinGraphBuilder};
582
583        // Build the join graph
584        let mut builder = JoinGraphBuilder::new();
585
586        for (var, relation) in relations {
587            builder.add_relation(var, relation.clone());
588        }
589
590        for cond in conditions {
591            builder.add_join_condition(
592                &cond.left_var,
593                &cond.right_var,
594                cond.left_expr.clone(),
595                cond.right_expr.clone(),
596            );
597        }
598
599        let graph = builder.build();
600
601        // Run DPccp
602        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
603        let plan = dpccp.optimize()?;
604
605        Some(plan.operator)
606    }
607
608    /// Pushes filters down the operator tree.
609    ///
610    /// This optimization moves filter predicates as close to the data source
611    /// as possible to reduce the amount of data processed by upper operators.
612    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
613        match op {
614            // For Filter operators, try to push the predicate into the child
615            LogicalOperator::Filter(filter) => {
616                let optimized_input = self.push_filters_down(*filter.input);
617                self.try_push_filter_into(filter.predicate, optimized_input)
618            }
619            // Recursively optimize children for other operators
620            LogicalOperator::Return(mut ret) => {
621                ret.input = Box::new(self.push_filters_down(*ret.input));
622                LogicalOperator::Return(ret)
623            }
624            LogicalOperator::Project(mut proj) => {
625                proj.input = Box::new(self.push_filters_down(*proj.input));
626                LogicalOperator::Project(proj)
627            }
628            LogicalOperator::Limit(mut limit) => {
629                limit.input = Box::new(self.push_filters_down(*limit.input));
630                LogicalOperator::Limit(limit)
631            }
632            LogicalOperator::Skip(mut skip) => {
633                skip.input = Box::new(self.push_filters_down(*skip.input));
634                LogicalOperator::Skip(skip)
635            }
636            LogicalOperator::Sort(mut sort) => {
637                sort.input = Box::new(self.push_filters_down(*sort.input));
638                LogicalOperator::Sort(sort)
639            }
640            LogicalOperator::Distinct(mut distinct) => {
641                distinct.input = Box::new(self.push_filters_down(*distinct.input));
642                LogicalOperator::Distinct(distinct)
643            }
644            LogicalOperator::Expand(mut expand) => {
645                expand.input = Box::new(self.push_filters_down(*expand.input));
646                LogicalOperator::Expand(expand)
647            }
648            LogicalOperator::Join(mut join) => {
649                join.left = Box::new(self.push_filters_down(*join.left));
650                join.right = Box::new(self.push_filters_down(*join.right));
651                LogicalOperator::Join(join)
652            }
653            LogicalOperator::Aggregate(mut agg) => {
654                agg.input = Box::new(self.push_filters_down(*agg.input));
655                LogicalOperator::Aggregate(agg)
656            }
657            // Leaf operators and unsupported operators are returned as-is
658            other => other,
659        }
660    }
661
662    /// Tries to push a filter predicate into the given operator.
663    ///
664    /// Returns either the predicate pushed into the operator, or a new
665    /// Filter operator on top if the predicate cannot be pushed further.
666    fn try_push_filter_into(
667        &self,
668        predicate: LogicalExpression,
669        op: LogicalOperator,
670    ) -> LogicalOperator {
671        match op {
672            // Can push through Project if predicate doesn't depend on computed columns
673            LogicalOperator::Project(mut proj) => {
674                let predicate_vars = self.extract_variables(&predicate);
675                let computed_vars = self.extract_projection_aliases(&proj.projections);
676
677                // If predicate doesn't use any computed columns, push through
678                if predicate_vars.is_disjoint(&computed_vars) {
679                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
680                    LogicalOperator::Project(proj)
681                } else {
682                    // Can't push through, keep filter on top
683                    LogicalOperator::Filter(FilterOp {
684                        predicate,
685                        input: Box::new(LogicalOperator::Project(proj)),
686                    })
687                }
688            }
689
690            // Can push through Return (which is like a projection)
691            LogicalOperator::Return(mut ret) => {
692                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
693                LogicalOperator::Return(ret)
694            }
695
696            // Can push through Expand if predicate doesn't use variables introduced by this expand
697            LogicalOperator::Expand(mut expand) => {
698                let predicate_vars = self.extract_variables(&predicate);
699
700                // Variables introduced by this expand are:
701                // - The target variable (to_variable)
702                // - The edge variable (if any)
703                // - The path alias (if any)
704                let mut introduced_vars = vec![&expand.to_variable];
705                if let Some(ref edge_var) = expand.edge_variable {
706                    introduced_vars.push(edge_var);
707                }
708                if let Some(ref path_alias) = expand.path_alias {
709                    introduced_vars.push(path_alias);
710                }
711
712                // Check if predicate uses any variables introduced by this expand
713                let uses_introduced_vars =
714                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
715
716                if !uses_introduced_vars {
717                    // Predicate doesn't use vars from this expand, so push through
718                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
719                    LogicalOperator::Expand(expand)
720                } else {
721                    // Keep filter after expand
722                    LogicalOperator::Filter(FilterOp {
723                        predicate,
724                        input: Box::new(LogicalOperator::Expand(expand)),
725                    })
726                }
727            }
728
729            // Can push through Join to left/right side based on variables used
730            LogicalOperator::Join(mut join) => {
731                let predicate_vars = self.extract_variables(&predicate);
732                let left_vars = self.collect_output_variables(&join.left);
733                let right_vars = self.collect_output_variables(&join.right);
734
735                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
736                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
737
738                if uses_left && !uses_right {
739                    // Push to left side
740                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
741                    LogicalOperator::Join(join)
742                } else if uses_right && !uses_left {
743                    // Push to right side
744                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
745                    LogicalOperator::Join(join)
746                } else {
747                    // Uses both sides - keep above join
748                    LogicalOperator::Filter(FilterOp {
749                        predicate,
750                        input: Box::new(LogicalOperator::Join(join)),
751                    })
752                }
753            }
754
755            // Cannot push through Aggregate (predicate refers to aggregated values)
756            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
757                predicate,
758                input: Box::new(LogicalOperator::Aggregate(agg)),
759            }),
760
761            // For NodeScan, we've reached the bottom - keep filter on top
762            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
763                predicate,
764                input: Box::new(LogicalOperator::NodeScan(scan)),
765            }),
766
767            // For other operators, keep filter on top
768            other => LogicalOperator::Filter(FilterOp {
769                predicate,
770                input: Box::new(other),
771            }),
772        }
773    }
774
775    /// Collects all output variable names from an operator.
776    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
777        let mut vars = HashSet::new();
778        Self::collect_output_variables_recursive(op, &mut vars);
779        vars
780    }
781
782    /// Recursively collects output variables from an operator.
783    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
784        match op {
785            LogicalOperator::NodeScan(scan) => {
786                vars.insert(scan.variable.clone());
787            }
788            LogicalOperator::EdgeScan(scan) => {
789                vars.insert(scan.variable.clone());
790            }
791            LogicalOperator::Expand(expand) => {
792                vars.insert(expand.to_variable.clone());
793                if let Some(edge_var) = &expand.edge_variable {
794                    vars.insert(edge_var.clone());
795                }
796                Self::collect_output_variables_recursive(&expand.input, vars);
797            }
798            LogicalOperator::Filter(filter) => {
799                Self::collect_output_variables_recursive(&filter.input, vars);
800            }
801            LogicalOperator::Project(proj) => {
802                for p in &proj.projections {
803                    if let Some(alias) = &p.alias {
804                        vars.insert(alias.clone());
805                    }
806                }
807                Self::collect_output_variables_recursive(&proj.input, vars);
808            }
809            LogicalOperator::Join(join) => {
810                Self::collect_output_variables_recursive(&join.left, vars);
811                Self::collect_output_variables_recursive(&join.right, vars);
812            }
813            LogicalOperator::Aggregate(agg) => {
814                for expr in &agg.group_by {
815                    Self::collect_variables(expr, vars);
816                }
817                for agg_expr in &agg.aggregates {
818                    if let Some(alias) = &agg_expr.alias {
819                        vars.insert(alias.clone());
820                    }
821                }
822            }
823            LogicalOperator::Return(ret) => {
824                Self::collect_output_variables_recursive(&ret.input, vars);
825            }
826            LogicalOperator::Limit(limit) => {
827                Self::collect_output_variables_recursive(&limit.input, vars);
828            }
829            LogicalOperator::Skip(skip) => {
830                Self::collect_output_variables_recursive(&skip.input, vars);
831            }
832            LogicalOperator::Sort(sort) => {
833                Self::collect_output_variables_recursive(&sort.input, vars);
834            }
835            LogicalOperator::Distinct(distinct) => {
836                Self::collect_output_variables_recursive(&distinct.input, vars);
837            }
838            _ => {}
839        }
840    }
841
842    /// Extracts all variable names referenced in an expression.
843    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
844        let mut vars = HashSet::new();
845        Self::collect_variables(expr, &mut vars);
846        vars
847    }
848
849    /// Recursively collects variable names from an expression.
850    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
851        match expr {
852            LogicalExpression::Variable(name) => {
853                vars.insert(name.clone());
854            }
855            LogicalExpression::Property { variable, .. } => {
856                vars.insert(variable.clone());
857            }
858            LogicalExpression::Binary { left, right, .. } => {
859                Self::collect_variables(left, vars);
860                Self::collect_variables(right, vars);
861            }
862            LogicalExpression::Unary { operand, .. } => {
863                Self::collect_variables(operand, vars);
864            }
865            LogicalExpression::FunctionCall { args, .. } => {
866                for arg in args {
867                    Self::collect_variables(arg, vars);
868                }
869            }
870            LogicalExpression::List(items) => {
871                for item in items {
872                    Self::collect_variables(item, vars);
873                }
874            }
875            LogicalExpression::Map(pairs) => {
876                for (_, value) in pairs {
877                    Self::collect_variables(value, vars);
878                }
879            }
880            LogicalExpression::IndexAccess { base, index } => {
881                Self::collect_variables(base, vars);
882                Self::collect_variables(index, vars);
883            }
884            LogicalExpression::SliceAccess { base, start, end } => {
885                Self::collect_variables(base, vars);
886                if let Some(s) = start {
887                    Self::collect_variables(s, vars);
888                }
889                if let Some(e) = end {
890                    Self::collect_variables(e, vars);
891                }
892            }
893            LogicalExpression::Case {
894                operand,
895                when_clauses,
896                else_clause,
897            } => {
898                if let Some(op) = operand {
899                    Self::collect_variables(op, vars);
900                }
901                for (cond, result) in when_clauses {
902                    Self::collect_variables(cond, vars);
903                    Self::collect_variables(result, vars);
904                }
905                if let Some(else_expr) = else_clause {
906                    Self::collect_variables(else_expr, vars);
907                }
908            }
909            LogicalExpression::Labels(var)
910            | LogicalExpression::Type(var)
911            | LogicalExpression::Id(var) => {
912                vars.insert(var.clone());
913            }
914            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
915            LogicalExpression::ListComprehension {
916                list_expr,
917                filter_expr,
918                map_expr,
919                ..
920            } => {
921                Self::collect_variables(list_expr, vars);
922                if let Some(filter) = filter_expr {
923                    Self::collect_variables(filter, vars);
924                }
925                Self::collect_variables(map_expr, vars);
926            }
927            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
928                // Subqueries have their own variable scope
929            }
930        }
931    }
932
933    /// Extracts aliases from projection expressions.
934    fn extract_projection_aliases(
935        &self,
936        projections: &[crate::query::plan::Projection],
937    ) -> HashSet<String> {
938        projections.iter().filter_map(|p| p.alias.clone()).collect()
939    }
940}
941
942impl Default for Optimizer {
943    fn default() -> Self {
944        Self::new()
945    }
946}
947
948#[cfg(test)]
949mod tests {
950    use super::*;
951    use crate::query::plan::{
952        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
953        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
954        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
955    };
956    use grafeo_common::types::Value;
957
958    #[test]
959    fn test_optimizer_filter_pushdown_simple() {
960        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
961        // Before: Return -> Filter -> NodeScan
962        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
963
964        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
965            items: vec![ReturnItem {
966                expression: LogicalExpression::Variable("n".to_string()),
967                alias: None,
968            }],
969            distinct: false,
970            input: Box::new(LogicalOperator::Filter(FilterOp {
971                predicate: LogicalExpression::Binary {
972                    left: Box::new(LogicalExpression::Property {
973                        variable: "n".to_string(),
974                        property: "age".to_string(),
975                    }),
976                    op: BinaryOp::Gt,
977                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
978                },
979                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
980                    variable: "n".to_string(),
981                    label: Some("Person".to_string()),
982                    input: None,
983                })),
984            })),
985        }));
986
987        let optimizer = Optimizer::new();
988        let optimized = optimizer.optimize(plan).unwrap();
989
990        // The structure should remain similar (filter stays near scan)
991        if let LogicalOperator::Return(ret) = &optimized.root {
992            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
993                if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
994                    assert_eq!(scan.variable, "n");
995                    return;
996                }
997            }
998        }
999        panic!("Expected Return -> Filter -> NodeScan structure");
1000    }
1001
1002    #[test]
1003    fn test_optimizer_filter_pushdown_through_expand() {
1004        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1005        // The filter on 'a' should be pushed before the expand
1006
1007        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1008            items: vec![ReturnItem {
1009                expression: LogicalExpression::Variable("b".to_string()),
1010                alias: None,
1011            }],
1012            distinct: false,
1013            input: Box::new(LogicalOperator::Filter(FilterOp {
1014                predicate: LogicalExpression::Binary {
1015                    left: Box::new(LogicalExpression::Property {
1016                        variable: "a".to_string(),
1017                        property: "age".to_string(),
1018                    }),
1019                    op: BinaryOp::Gt,
1020                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1021                },
1022                input: Box::new(LogicalOperator::Expand(ExpandOp {
1023                    from_variable: "a".to_string(),
1024                    to_variable: "b".to_string(),
1025                    edge_variable: None,
1026                    direction: ExpandDirection::Outgoing,
1027                    edge_type: Some("KNOWS".to_string()),
1028                    min_hops: 1,
1029                    max_hops: Some(1),
1030                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1031                        variable: "a".to_string(),
1032                        label: Some("Person".to_string()),
1033                        input: None,
1034                    })),
1035                    path_alias: None,
1036                })),
1037            })),
1038        }));
1039
1040        let optimizer = Optimizer::new();
1041        let optimized = optimizer.optimize(plan).unwrap();
1042
1043        // Filter on 'a' should be pushed before the expand
1044        // Expected: Return -> Expand -> Filter -> NodeScan
1045        if let LogicalOperator::Return(ret) = &optimized.root {
1046            if let LogicalOperator::Expand(expand) = ret.input.as_ref() {
1047                if let LogicalOperator::Filter(filter) = expand.input.as_ref() {
1048                    if let LogicalOperator::NodeScan(scan) = filter.input.as_ref() {
1049                        assert_eq!(scan.variable, "a");
1050                        assert_eq!(expand.from_variable, "a");
1051                        assert_eq!(expand.to_variable, "b");
1052                        return;
1053                    }
1054                }
1055            }
1056        }
1057        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1058    }
1059
1060    #[test]
1061    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1062        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1063        // The filter on 'b' should NOT be pushed before the expand
1064
1065        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1066            items: vec![ReturnItem {
1067                expression: LogicalExpression::Variable("a".to_string()),
1068                alias: None,
1069            }],
1070            distinct: false,
1071            input: Box::new(LogicalOperator::Filter(FilterOp {
1072                predicate: LogicalExpression::Binary {
1073                    left: Box::new(LogicalExpression::Property {
1074                        variable: "b".to_string(),
1075                        property: "age".to_string(),
1076                    }),
1077                    op: BinaryOp::Gt,
1078                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1079                },
1080                input: Box::new(LogicalOperator::Expand(ExpandOp {
1081                    from_variable: "a".to_string(),
1082                    to_variable: "b".to_string(),
1083                    edge_variable: None,
1084                    direction: ExpandDirection::Outgoing,
1085                    edge_type: Some("KNOWS".to_string()),
1086                    min_hops: 1,
1087                    max_hops: Some(1),
1088                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1089                        variable: "a".to_string(),
1090                        label: Some("Person".to_string()),
1091                        input: None,
1092                    })),
1093                    path_alias: None,
1094                })),
1095            })),
1096        }));
1097
1098        let optimizer = Optimizer::new();
1099        let optimized = optimizer.optimize(plan).unwrap();
1100
1101        // Filter on 'b' should stay after the expand
1102        // Expected: Return -> Filter -> Expand -> NodeScan
1103        if let LogicalOperator::Return(ret) = &optimized.root {
1104            if let LogicalOperator::Filter(filter) = ret.input.as_ref() {
1105                // Check that the filter is on 'b'
1106                if let LogicalExpression::Binary { left, .. } = &filter.predicate {
1107                    if let LogicalExpression::Property { variable, .. } = left.as_ref() {
1108                        assert_eq!(variable, "b");
1109                    }
1110                }
1111
1112                if let LogicalOperator::Expand(expand) = filter.input.as_ref() {
1113                    if let LogicalOperator::NodeScan(_) = expand.input.as_ref() {
1114                        return;
1115                    }
1116                }
1117            }
1118        }
1119        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1120    }
1121
1122    #[test]
1123    fn test_optimizer_extract_variables() {
1124        let optimizer = Optimizer::new();
1125
1126        let expr = LogicalExpression::Binary {
1127            left: Box::new(LogicalExpression::Property {
1128                variable: "n".to_string(),
1129                property: "age".to_string(),
1130            }),
1131            op: BinaryOp::Gt,
1132            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1133        };
1134
1135        let vars = optimizer.extract_variables(&expr);
1136        assert_eq!(vars.len(), 1);
1137        assert!(vars.contains("n"));
1138    }
1139
1140    // Additional tests for optimizer configuration
1141
1142    #[test]
1143    fn test_optimizer_default() {
1144        let optimizer = Optimizer::default();
1145        // Should be able to optimize an empty plan
1146        let plan = LogicalPlan::new(LogicalOperator::Empty);
1147        let result = optimizer.optimize(plan);
1148        assert!(result.is_ok());
1149    }
1150
1151    #[test]
1152    fn test_optimizer_with_filter_pushdown_disabled() {
1153        let optimizer = Optimizer::new().with_filter_pushdown(false);
1154
1155        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1156            items: vec![ReturnItem {
1157                expression: LogicalExpression::Variable("n".to_string()),
1158                alias: None,
1159            }],
1160            distinct: false,
1161            input: Box::new(LogicalOperator::Filter(FilterOp {
1162                predicate: LogicalExpression::Literal(Value::Bool(true)),
1163                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1164                    variable: "n".to_string(),
1165                    label: None,
1166                    input: None,
1167                })),
1168            })),
1169        }));
1170
1171        let optimized = optimizer.optimize(plan).unwrap();
1172        // Structure should be unchanged
1173        if let LogicalOperator::Return(ret) = &optimized.root {
1174            if let LogicalOperator::Filter(_) = ret.input.as_ref() {
1175                return;
1176            }
1177        }
1178        panic!("Expected unchanged structure");
1179    }
1180
1181    #[test]
1182    fn test_optimizer_with_join_reorder_disabled() {
1183        let optimizer = Optimizer::new().with_join_reorder(false);
1184        assert!(
1185            optimizer
1186                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1187                .is_ok()
1188        );
1189    }
1190
1191    #[test]
1192    fn test_optimizer_with_cost_model() {
1193        let cost_model = CostModel::new();
1194        let optimizer = Optimizer::new().with_cost_model(cost_model);
1195        assert!(
1196            optimizer
1197                .cost_model()
1198                .estimate(&LogicalOperator::Empty, 0.0)
1199                .total()
1200                < 0.001
1201        );
1202    }
1203
1204    #[test]
1205    fn test_optimizer_with_cardinality_estimator() {
1206        let mut estimator = CardinalityEstimator::new();
1207        estimator.add_table_stats("Test", TableStats::new(500));
1208        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1209
1210        let scan = LogicalOperator::NodeScan(NodeScanOp {
1211            variable: "n".to_string(),
1212            label: Some("Test".to_string()),
1213            input: None,
1214        });
1215        let plan = LogicalPlan::new(scan);
1216
1217        let cardinality = optimizer.estimate_cardinality(&plan);
1218        assert!((cardinality - 500.0).abs() < 0.001);
1219    }
1220
1221    #[test]
1222    fn test_optimizer_estimate_cost() {
1223        let optimizer = Optimizer::new();
1224        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1225            variable: "n".to_string(),
1226            label: None,
1227            input: None,
1228        }));
1229
1230        let cost = optimizer.estimate_cost(&plan);
1231        assert!(cost.total() > 0.0);
1232    }
1233
1234    // Filter pushdown through various operators
1235
1236    #[test]
1237    fn test_filter_pushdown_through_project() {
1238        let optimizer = Optimizer::new();
1239
1240        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1241            predicate: LogicalExpression::Binary {
1242                left: Box::new(LogicalExpression::Property {
1243                    variable: "n".to_string(),
1244                    property: "age".to_string(),
1245                }),
1246                op: BinaryOp::Gt,
1247                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1248            },
1249            input: Box::new(LogicalOperator::Project(ProjectOp {
1250                projections: vec![Projection {
1251                    expression: LogicalExpression::Variable("n".to_string()),
1252                    alias: None,
1253                }],
1254                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1255                    variable: "n".to_string(),
1256                    label: None,
1257                    input: None,
1258                })),
1259            })),
1260        }));
1261
1262        let optimized = optimizer.optimize(plan).unwrap();
1263
1264        // Filter should be pushed through Project
1265        if let LogicalOperator::Project(proj) = &optimized.root {
1266            if let LogicalOperator::Filter(_) = proj.input.as_ref() {
1267                return;
1268            }
1269        }
1270        panic!("Expected Project -> Filter structure");
1271    }
1272
1273    #[test]
1274    fn test_filter_not_pushed_through_project_with_alias() {
1275        let optimizer = Optimizer::new();
1276
1277        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1278        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1279            predicate: LogicalExpression::Binary {
1280                left: Box::new(LogicalExpression::Variable("x".to_string())),
1281                op: BinaryOp::Gt,
1282                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1283            },
1284            input: Box::new(LogicalOperator::Project(ProjectOp {
1285                projections: vec![Projection {
1286                    expression: LogicalExpression::Property {
1287                        variable: "n".to_string(),
1288                        property: "age".to_string(),
1289                    },
1290                    alias: Some("x".to_string()),
1291                }],
1292                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1293                    variable: "n".to_string(),
1294                    label: None,
1295                    input: None,
1296                })),
1297            })),
1298        }));
1299
1300        let optimized = optimizer.optimize(plan).unwrap();
1301
1302        // Filter should stay above Project
1303        if let LogicalOperator::Filter(filter) = &optimized.root {
1304            if let LogicalOperator::Project(_) = filter.input.as_ref() {
1305                return;
1306            }
1307        }
1308        panic!("Expected Filter -> Project structure");
1309    }
1310
1311    #[test]
1312    fn test_filter_pushdown_through_limit() {
1313        let optimizer = Optimizer::new();
1314
1315        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1316            predicate: LogicalExpression::Literal(Value::Bool(true)),
1317            input: Box::new(LogicalOperator::Limit(LimitOp {
1318                count: 10,
1319                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1320                    variable: "n".to_string(),
1321                    label: None,
1322                    input: None,
1323                })),
1324            })),
1325        }));
1326
1327        let optimized = optimizer.optimize(plan).unwrap();
1328
1329        // Filter stays above Limit (cannot be pushed through)
1330        if let LogicalOperator::Filter(filter) = &optimized.root {
1331            if let LogicalOperator::Limit(_) = filter.input.as_ref() {
1332                return;
1333            }
1334        }
1335        panic!("Expected Filter -> Limit structure");
1336    }
1337
1338    #[test]
1339    fn test_filter_pushdown_through_sort() {
1340        let optimizer = Optimizer::new();
1341
1342        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1343            predicate: LogicalExpression::Literal(Value::Bool(true)),
1344            input: Box::new(LogicalOperator::Sort(SortOp {
1345                keys: vec![SortKey {
1346                    expression: LogicalExpression::Variable("n".to_string()),
1347                    order: SortOrder::Ascending,
1348                }],
1349                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1350                    variable: "n".to_string(),
1351                    label: None,
1352                    input: None,
1353                })),
1354            })),
1355        }));
1356
1357        let optimized = optimizer.optimize(plan).unwrap();
1358
1359        // Filter stays above Sort
1360        if let LogicalOperator::Filter(filter) = &optimized.root {
1361            if let LogicalOperator::Sort(_) = filter.input.as_ref() {
1362                return;
1363            }
1364        }
1365        panic!("Expected Filter -> Sort structure");
1366    }
1367
1368    #[test]
1369    fn test_filter_pushdown_through_distinct() {
1370        let optimizer = Optimizer::new();
1371
1372        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1373            predicate: LogicalExpression::Literal(Value::Bool(true)),
1374            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1375                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1376                    variable: "n".to_string(),
1377                    label: None,
1378                    input: None,
1379                })),
1380                columns: None,
1381            })),
1382        }));
1383
1384        let optimized = optimizer.optimize(plan).unwrap();
1385
1386        // Filter stays above Distinct
1387        if let LogicalOperator::Filter(filter) = &optimized.root {
1388            if let LogicalOperator::Distinct(_) = filter.input.as_ref() {
1389                return;
1390            }
1391        }
1392        panic!("Expected Filter -> Distinct structure");
1393    }
1394
1395    #[test]
1396    fn test_filter_not_pushed_through_aggregate() {
1397        let optimizer = Optimizer::new();
1398
1399        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1400            predicate: LogicalExpression::Binary {
1401                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1402                op: BinaryOp::Gt,
1403                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1404            },
1405            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1406                group_by: vec![],
1407                aggregates: vec![AggregateExpr {
1408                    function: AggregateFunction::Count,
1409                    expression: None,
1410                    distinct: false,
1411                    alias: Some("cnt".to_string()),
1412                    percentile: None,
1413                }],
1414                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1415                    variable: "n".to_string(),
1416                    label: None,
1417                    input: None,
1418                })),
1419                having: None,
1420            })),
1421        }));
1422
1423        let optimized = optimizer.optimize(plan).unwrap();
1424
1425        // Filter should stay above Aggregate
1426        if let LogicalOperator::Filter(filter) = &optimized.root {
1427            if let LogicalOperator::Aggregate(_) = filter.input.as_ref() {
1428                return;
1429            }
1430        }
1431        panic!("Expected Filter -> Aggregate structure");
1432    }
1433
1434    #[test]
1435    fn test_filter_pushdown_to_left_join_side() {
1436        let optimizer = Optimizer::new();
1437
1438        // Filter on left variable should be pushed to left side
1439        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1440            predicate: LogicalExpression::Binary {
1441                left: Box::new(LogicalExpression::Property {
1442                    variable: "a".to_string(),
1443                    property: "age".to_string(),
1444                }),
1445                op: BinaryOp::Gt,
1446                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1447            },
1448            input: Box::new(LogicalOperator::Join(JoinOp {
1449                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1450                    variable: "a".to_string(),
1451                    label: Some("Person".to_string()),
1452                    input: None,
1453                })),
1454                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1455                    variable: "b".to_string(),
1456                    label: Some("Company".to_string()),
1457                    input: None,
1458                })),
1459                join_type: JoinType::Inner,
1460                conditions: vec![],
1461            })),
1462        }));
1463
1464        let optimized = optimizer.optimize(plan).unwrap();
1465
1466        // Filter should be pushed to left side of join
1467        if let LogicalOperator::Join(join) = &optimized.root {
1468            if let LogicalOperator::Filter(_) = join.left.as_ref() {
1469                return;
1470            }
1471        }
1472        panic!("Expected Join with Filter on left side");
1473    }
1474
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let optimizer = Optimizer::new();

    // The predicate only mentions `b`, which is bound by the right-hand scan,
    // so the filter is expected to sink into the join's right input.
    let company_is_acme = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let company_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let joined = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan),
        right: Box::new(company_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: company_is_acme,
        input: Box::new(joined),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    // The root must be the join, and its right child must be the pushed filter.
    let filter_on_right = matches!(
        &optimized.root,
        LogicalOperator::Join(join) if matches!(join.right.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(filter_on_right, "Expected Join with Filter on right side");
}
1515
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let optimizer = Optimizer::new();

    // `a.id = b.a_id` needs bindings from both join inputs, so neither side
    // can evaluate it alone; the filter has to remain above the join.
    let cross_side_predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let joined = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side_predicate,
        input: Box::new(joined),
    }));

    let optimized = optimizer.optimize(plan).unwrap();

    match &optimized.root {
        LogicalOperator::Filter(filter) => assert!(
            matches!(filter.input.as_ref(), LogicalOperator::Join(_)),
            "Expected Filter -> Join structure"
        ),
        _ => panic!("Expected Filter -> Join structure"),
    }
}
1559
1560    // Variable extraction tests
1561
#[test]
fn test_extract_variables_from_variable() {
    // A bare variable reference yields exactly its own name.
    let optimizer = Optimizer::new();
    let bare = LogicalExpression::Variable("x".to_string());

    let found = optimizer.extract_variables(&bare);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1570
#[test]
fn test_extract_variables_from_unary() {
    // The operand of a unary operator is searched for variables.
    let negated = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(LogicalExpression::Variable("x".to_string())),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&negated);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1582
#[test]
fn test_extract_variables_from_function_call() {
    // Every argument of a function call is searched for variables.
    let call_args = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args: call_args,
        distinct: false,
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&call);

    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1599
#[test]
fn test_extract_variables_from_list() {
    // List elements that are variables are collected; the literal in the
    // middle contributes nothing.
    let optimizer = Optimizer::new();
    let mixed = LogicalExpression::List(vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ]);

    let found = optimizer.extract_variables(&mixed);

    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1613
#[test]
fn test_extract_variables_from_map() {
    // Map values are searched for variables; the keys are plain strings and
    // are not expected to show up in the result.
    let optimizer = Optimizer::new();
    let entries = [("key1", "a"), ("key2", "b")]
        .iter()
        .map(|&(key, var)| (key.to_string(), LogicalExpression::Variable(var.to_string())))
        .collect();
    let map_expr = LogicalExpression::Map(entries);

    let found = optimizer.extract_variables(&map_expr);

    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1632
#[test]
fn test_extract_variables_from_index_access() {
    // Both the indexed base and the index expression contribute variables.
    let base = Box::new(LogicalExpression::Variable("list".to_string()));
    let index = Box::new(LogicalExpression::Variable("idx".to_string()));
    let access = LogicalExpression::IndexAccess { base, index };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&access);

    assert_eq!(found.len(), 2);
    for name in ["list", "idx"] {
        assert!(found.contains(name));
    }
}
1645
#[test]
fn test_extract_variables_from_slice_access() {
    // Base, start bound and end bound are all searched for variables.
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let base = var("list");
    let start = Some(var("s"));
    let end = Some(var("e"));
    let slice = LogicalExpression::SliceAccess { base, start, end };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&slice);

    assert_eq!(found.len(), 3);
    for name in ["list", "s", "e"] {
        assert!(found.contains(name));
    }
}
1660
#[test]
fn test_extract_variables_from_case() {
    // The operand, the second element of the single `when_clauses` pair and
    // the else branch each reference a variable; the first element of the
    // pair is a literal and contributes nothing.
    let var = |name: &str| LogicalExpression::Variable(name.to_string());
    let case_expr = LogicalExpression::Case {
        operand: Some(Box::new(var("x"))),
        when_clauses: vec![(LogicalExpression::Literal(Value::Int64(1)), var("a"))],
        else_clause: Some(Box::new(var("b"))),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&case_expr);

    assert_eq!(found.len(), 3);
    for name in ["x", "a", "b"] {
        assert!(found.contains(name));
    }
}
1678
#[test]
fn test_extract_variables_from_labels() {
    // A `Labels` expression names the single variable it inspects.
    let labels_of_n = LogicalExpression::Labels("n".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&labels_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1687
#[test]
fn test_extract_variables_from_type() {
    // A `Type` expression names the single variable it inspects.
    let type_of_e = LogicalExpression::Type("e".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&type_of_e);

    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1696
#[test]
fn test_extract_variables_from_id() {
    // An `Id` expression names the single variable it inspects.
    let id_of_n = LogicalExpression::Id("n".to_string());
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&id_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1705
#[test]
fn test_extract_variables_from_list_comprehension() {
    // The source list, the filter and the mapping expression are all searched.
    // NOTE(review): this deliberately asserts neither the set size nor whether
    // the comprehension's own bound variable `x` is reported; confirm the
    // intended behavior before tightening it.
    let var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: var("items"),
        filter_expr: Some(var("pred")),
        map_expr: var("result"),
    };
    let optimizer = Optimizer::new();

    let found = optimizer.extract_variables(&comprehension);

    for name in ["items", "pred", "result"] {
        assert!(found.contains(name));
    }
}
1720
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    // Neither a literal nor a query parameter references a variable binding.
    let optimizer = Optimizer::new();
    let variable_free = [
        LogicalExpression::Literal(Value::Int64(42)),
        LogicalExpression::Parameter("p".to_string()),
    ];

    for expr in &variable_free {
        assert!(optimizer.extract_variables(expr).is_empty());
    }
}
1731
1732    // Recursive filter pushdown tests
1733
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    // Plan shape, top to bottom: Return(n) -> Filter(true) -> Skip(5) -> NodeScan(n).
    // NOTE(review): this only checks that optimization succeeds and that the
    // root is still a Return; it does not check where the filter ends up.
    // Assuming Skip drops the first `count` rows, moving a non-trivial filter
    // below it would change the result. A shape assertion may be worth adding
    // once the intended behavior is confirmed.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5,
        input: Box::new(scan),
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(skip),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1762
#[test]
fn test_nested_filter_pushdown() {
    // Two stacked filters on the same scan: `n.x > 1` above `n.y < 10`.
    // Only a successful optimization with an intact Return root is checked.
    let compare = |property: &str, op: BinaryOp, bound: Value| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: property.to_string(),
        }),
        op,
        right: Box::new(LogicalExpression::Literal(bound)),
    };

    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: compare("y", BinaryOp::Lt, Value::Int64(10)),
        input: Box::new(scan),
    });
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: compare("x", BinaryOp::Gt, Value::Int64(1)),
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
1804}