Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{CardinalityEstimator, ColumnStats, TableStats};
19pub use cost::{Cost, CostModel};
20pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
21
22use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
23use grafeo_common::utils::error::Result;
24use std::collections::HashSet;
25
/// Information about a join condition for join reordering.
///
/// One value is recorded for each join condition whose two sides can both be
/// attributed to a variable.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand side of the condition.
    left_var: String,
    /// Variable referenced by the right-hand side of the condition.
    right_var: String,
    /// Full left-hand expression, cloned from the join condition.
    left_expr: LogicalExpression,
    /// Full right-hand expression, cloned from the join condition.
    right_expr: LogicalExpression,
}
34
/// A column required by the query, used for projection pushdown.
///
/// Values are stored in a `HashSet`, hence the `Eq` and `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding), identified by its name.
    Variable(String),
    /// A specific property of a variable: `(variable name, property name)`.
    Property(String, String),
}
43
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
pub struct Optimizer {
    /// Whether to enable filter pushdown (on by default).
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering (on by default).
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown (on by default).
    enable_projection_pushdown: bool,
    /// Cost model for estimation; also handed to the join-order search.
    cost_model: CostModel,
    /// Cardinality estimator; also handed to the join-order search.
    card_estimator: CardinalityEstimator,
}
60
61impl Optimizer {
62    /// Creates a new optimizer with default settings.
63    #[must_use]
64    pub fn new() -> Self {
65        Self {
66            enable_filter_pushdown: true,
67            enable_join_reorder: true,
68            enable_projection_pushdown: true,
69            cost_model: CostModel::new(),
70            card_estimator: CardinalityEstimator::new(),
71        }
72    }
73
74    /// Creates an optimizer with cardinality estimates from the store's statistics.
75    ///
76    /// Pre-populates the cardinality estimator with per-label row counts and
77    /// edge type fanout. Ensures statistics are fresh before reading.
78    #[must_use]
79    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
80        store.ensure_statistics_fresh();
81        let stats = store.statistics();
82        let estimator = CardinalityEstimator::from_statistics(&stats);
83        Self {
84            enable_filter_pushdown: true,
85            enable_join_reorder: true,
86            enable_projection_pushdown: true,
87            cost_model: CostModel::new(),
88            card_estimator: estimator,
89        }
90    }
91
92    /// Enables or disables filter pushdown.
93    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
94        self.enable_filter_pushdown = enabled;
95        self
96    }
97
98    /// Enables or disables join reordering.
99    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
100        self.enable_join_reorder = enabled;
101        self
102    }
103
104    /// Enables or disables projection pushdown.
105    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
106        self.enable_projection_pushdown = enabled;
107        self
108    }
109
110    /// Sets the cost model.
111    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
112        self.cost_model = cost_model;
113        self
114    }
115
116    /// Sets the cardinality estimator.
117    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
118        self.card_estimator = estimator;
119        self
120    }
121
122    /// Returns a reference to the cost model.
123    pub fn cost_model(&self) -> &CostModel {
124        &self.cost_model
125    }
126
127    /// Returns a reference to the cardinality estimator.
128    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
129        &self.card_estimator
130    }
131
132    /// Estimates the cost of a plan.
133    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
134        let cardinality = self.card_estimator.estimate(&plan.root);
135        self.cost_model.estimate(&plan.root, cardinality)
136    }
137
138    /// Estimates the cardinality of a plan.
139    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
140        self.card_estimator.estimate(&plan.root)
141    }
142
143    /// Optimizes a logical plan.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if optimization fails.
148    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
149        let mut root = plan.root;
150
151        // Apply optimization rules
152        if self.enable_filter_pushdown {
153            root = self.push_filters_down(root);
154        }
155
156        if self.enable_join_reorder {
157            root = self.reorder_joins(root);
158        }
159
160        if self.enable_projection_pushdown {
161            root = self.push_projections_down(root);
162        }
163
164        Ok(LogicalPlan::new(root))
165    }
166
167    /// Pushes projections down the operator tree to eliminate unused columns early.
168    ///
169    /// This optimization:
170    /// 1. Collects required variables/properties from the root
171    /// 2. Propagates requirements down through the tree
172    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
173    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
174        // Collect required columns from the top of the plan
175        let required = self.collect_required_columns(&op);
176
177        // Push projections down
178        self.push_projections_recursive(op, &required)
179    }
180
181    /// Collects all variables and properties required by an operator and its ancestors.
182    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
183        let mut required = HashSet::new();
184        Self::collect_required_recursive(op, &mut required);
185        required
186    }
187
188    /// Recursively collects required columns.
189    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
190        match op {
191            LogicalOperator::Return(ret) => {
192                for item in &ret.items {
193                    Self::collect_from_expression(&item.expression, required);
194                }
195                Self::collect_required_recursive(&ret.input, required);
196            }
197            LogicalOperator::Project(proj) => {
198                for p in &proj.projections {
199                    Self::collect_from_expression(&p.expression, required);
200                }
201                Self::collect_required_recursive(&proj.input, required);
202            }
203            LogicalOperator::Filter(filter) => {
204                Self::collect_from_expression(&filter.predicate, required);
205                Self::collect_required_recursive(&filter.input, required);
206            }
207            LogicalOperator::Sort(sort) => {
208                for key in &sort.keys {
209                    Self::collect_from_expression(&key.expression, required);
210                }
211                Self::collect_required_recursive(&sort.input, required);
212            }
213            LogicalOperator::Aggregate(agg) => {
214                for expr in &agg.group_by {
215                    Self::collect_from_expression(expr, required);
216                }
217                for agg_expr in &agg.aggregates {
218                    if let Some(ref expr) = agg_expr.expression {
219                        Self::collect_from_expression(expr, required);
220                    }
221                }
222                if let Some(ref having) = agg.having {
223                    Self::collect_from_expression(having, required);
224                }
225                Self::collect_required_recursive(&agg.input, required);
226            }
227            LogicalOperator::Join(join) => {
228                for cond in &join.conditions {
229                    Self::collect_from_expression(&cond.left, required);
230                    Self::collect_from_expression(&cond.right, required);
231                }
232                Self::collect_required_recursive(&join.left, required);
233                Self::collect_required_recursive(&join.right, required);
234            }
235            LogicalOperator::Expand(expand) => {
236                // The source and target variables are needed
237                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
238                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
239                if let Some(ref edge_var) = expand.edge_variable {
240                    required.insert(RequiredColumn::Variable(edge_var.clone()));
241                }
242                Self::collect_required_recursive(&expand.input, required);
243            }
244            LogicalOperator::Limit(limit) => {
245                Self::collect_required_recursive(&limit.input, required);
246            }
247            LogicalOperator::Skip(skip) => {
248                Self::collect_required_recursive(&skip.input, required);
249            }
250            LogicalOperator::Distinct(distinct) => {
251                Self::collect_required_recursive(&distinct.input, required);
252            }
253            LogicalOperator::NodeScan(scan) => {
254                required.insert(RequiredColumn::Variable(scan.variable.clone()));
255            }
256            LogicalOperator::EdgeScan(scan) => {
257                required.insert(RequiredColumn::Variable(scan.variable.clone()));
258            }
259            _ => {}
260        }
261    }
262
263    /// Collects required columns from an expression.
264    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
265        match expr {
266            LogicalExpression::Variable(var) => {
267                required.insert(RequiredColumn::Variable(var.clone()));
268            }
269            LogicalExpression::Property { variable, property } => {
270                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
271                required.insert(RequiredColumn::Variable(variable.clone()));
272            }
273            LogicalExpression::Binary { left, right, .. } => {
274                Self::collect_from_expression(left, required);
275                Self::collect_from_expression(right, required);
276            }
277            LogicalExpression::Unary { operand, .. } => {
278                Self::collect_from_expression(operand, required);
279            }
280            LogicalExpression::FunctionCall { args, .. } => {
281                for arg in args {
282                    Self::collect_from_expression(arg, required);
283                }
284            }
285            LogicalExpression::List(items) => {
286                for item in items {
287                    Self::collect_from_expression(item, required);
288                }
289            }
290            LogicalExpression::Map(pairs) => {
291                for (_, value) in pairs {
292                    Self::collect_from_expression(value, required);
293                }
294            }
295            LogicalExpression::IndexAccess { base, index } => {
296                Self::collect_from_expression(base, required);
297                Self::collect_from_expression(index, required);
298            }
299            LogicalExpression::SliceAccess { base, start, end } => {
300                Self::collect_from_expression(base, required);
301                if let Some(s) = start {
302                    Self::collect_from_expression(s, required);
303                }
304                if let Some(e) = end {
305                    Self::collect_from_expression(e, required);
306                }
307            }
308            LogicalExpression::Case {
309                operand,
310                when_clauses,
311                else_clause,
312            } => {
313                if let Some(op) = operand {
314                    Self::collect_from_expression(op, required);
315                }
316                for (cond, result) in when_clauses {
317                    Self::collect_from_expression(cond, required);
318                    Self::collect_from_expression(result, required);
319                }
320                if let Some(else_expr) = else_clause {
321                    Self::collect_from_expression(else_expr, required);
322                }
323            }
324            LogicalExpression::Labels(var)
325            | LogicalExpression::Type(var)
326            | LogicalExpression::Id(var) => {
327                required.insert(RequiredColumn::Variable(var.clone()));
328            }
329            LogicalExpression::ListComprehension {
330                list_expr,
331                filter_expr,
332                map_expr,
333                ..
334            } => {
335                Self::collect_from_expression(list_expr, required);
336                if let Some(filter) = filter_expr {
337                    Self::collect_from_expression(filter, required);
338                }
339                Self::collect_from_expression(map_expr, required);
340            }
341            _ => {}
342        }
343    }
344
345    /// Recursively pushes projections down, adding them before expensive operations.
346    fn push_projections_recursive(
347        &self,
348        op: LogicalOperator,
349        required: &HashSet<RequiredColumn>,
350    ) -> LogicalOperator {
351        match op {
352            LogicalOperator::Return(mut ret) => {
353                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
354                LogicalOperator::Return(ret)
355            }
356            LogicalOperator::Project(mut proj) => {
357                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
358                LogicalOperator::Project(proj)
359            }
360            LogicalOperator::Filter(mut filter) => {
361                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
362                LogicalOperator::Filter(filter)
363            }
364            LogicalOperator::Sort(mut sort) => {
365                // Sort is expensive - consider adding a projection before it
366                // to reduce tuple width
367                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
368                LogicalOperator::Sort(sort)
369            }
370            LogicalOperator::Aggregate(mut agg) => {
371                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
372                LogicalOperator::Aggregate(agg)
373            }
374            LogicalOperator::Join(mut join) => {
375                // Joins are expensive - the required columns help determine
376                // what to project on each side
377                let left_vars = self.collect_output_variables(&join.left);
378                let right_vars = self.collect_output_variables(&join.right);
379
380                // Filter required columns to each side
381                let left_required: HashSet<_> = required
382                    .iter()
383                    .filter(|c| match c {
384                        RequiredColumn::Variable(v) => left_vars.contains(v),
385                        RequiredColumn::Property(v, _) => left_vars.contains(v),
386                    })
387                    .cloned()
388                    .collect();
389
390                let right_required: HashSet<_> = required
391                    .iter()
392                    .filter(|c| match c {
393                        RequiredColumn::Variable(v) => right_vars.contains(v),
394                        RequiredColumn::Property(v, _) => right_vars.contains(v),
395                    })
396                    .cloned()
397                    .collect();
398
399                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
400                join.right =
401                    Box::new(self.push_projections_recursive(*join.right, &right_required));
402                LogicalOperator::Join(join)
403            }
404            LogicalOperator::Expand(mut expand) => {
405                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
406                LogicalOperator::Expand(expand)
407            }
408            LogicalOperator::Limit(mut limit) => {
409                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
410                LogicalOperator::Limit(limit)
411            }
412            LogicalOperator::Skip(mut skip) => {
413                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
414                LogicalOperator::Skip(skip)
415            }
416            LogicalOperator::Distinct(mut distinct) => {
417                distinct.input =
418                    Box::new(self.push_projections_recursive(*distinct.input, required));
419                LogicalOperator::Distinct(distinct)
420            }
421            other => other,
422        }
423    }
424
425    /// Reorders joins in the operator tree using the DPccp algorithm.
426    ///
427    /// This optimization finds the optimal join order by:
428    /// 1. Extracting all base relations (scans) and join conditions
429    /// 2. Building a join graph
430    /// 3. Using dynamic programming to find the cheapest join order
431    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
432        // First, recursively optimize children
433        let op = self.reorder_joins_recursive(op);
434
435        // Then, if this is a join tree, try to optimize it
436        if let Some((relations, conditions)) = self.extract_join_tree(&op)
437            && relations.len() >= 2
438            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
439        {
440            return optimized;
441        }
442
443        op
444    }
445
446    /// Recursively applies join reordering to child operators.
447    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
448        match op {
449            LogicalOperator::Return(mut ret) => {
450                ret.input = Box::new(self.reorder_joins(*ret.input));
451                LogicalOperator::Return(ret)
452            }
453            LogicalOperator::Project(mut proj) => {
454                proj.input = Box::new(self.reorder_joins(*proj.input));
455                LogicalOperator::Project(proj)
456            }
457            LogicalOperator::Filter(mut filter) => {
458                filter.input = Box::new(self.reorder_joins(*filter.input));
459                LogicalOperator::Filter(filter)
460            }
461            LogicalOperator::Limit(mut limit) => {
462                limit.input = Box::new(self.reorder_joins(*limit.input));
463                LogicalOperator::Limit(limit)
464            }
465            LogicalOperator::Skip(mut skip) => {
466                skip.input = Box::new(self.reorder_joins(*skip.input));
467                LogicalOperator::Skip(skip)
468            }
469            LogicalOperator::Sort(mut sort) => {
470                sort.input = Box::new(self.reorder_joins(*sort.input));
471                LogicalOperator::Sort(sort)
472            }
473            LogicalOperator::Distinct(mut distinct) => {
474                distinct.input = Box::new(self.reorder_joins(*distinct.input));
475                LogicalOperator::Distinct(distinct)
476            }
477            LogicalOperator::Aggregate(mut agg) => {
478                agg.input = Box::new(self.reorder_joins(*agg.input));
479                LogicalOperator::Aggregate(agg)
480            }
481            LogicalOperator::Expand(mut expand) => {
482                expand.input = Box::new(self.reorder_joins(*expand.input));
483                LogicalOperator::Expand(expand)
484            }
485            // Join operators are handled by the parent reorder_joins call
486            other => other,
487        }
488    }
489
490    /// Extracts base relations and join conditions from a join tree.
491    ///
492    /// Returns None if the operator is not a join tree.
493    fn extract_join_tree(
494        &self,
495        op: &LogicalOperator,
496    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
497        let mut relations = Vec::new();
498        let mut join_conditions = Vec::new();
499
500        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
501            return None;
502        }
503
504        if relations.len() < 2 {
505            return None;
506        }
507
508        Some((relations, join_conditions))
509    }
510
511    /// Recursively collects base relations and join conditions.
512    ///
513    /// Returns true if this subtree is part of a join tree.
514    fn collect_join_tree(
515        &self,
516        op: &LogicalOperator,
517        relations: &mut Vec<(String, LogicalOperator)>,
518        conditions: &mut Vec<JoinInfo>,
519    ) -> bool {
520        match op {
521            LogicalOperator::Join(join) => {
522                // Collect from both sides
523                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
524                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
525
526                // Add conditions from this join
527                for cond in &join.conditions {
528                    if let (Some(left_var), Some(right_var)) = (
529                        self.extract_variable_from_expr(&cond.left),
530                        self.extract_variable_from_expr(&cond.right),
531                    ) {
532                        conditions.push(JoinInfo {
533                            left_var,
534                            right_var,
535                            left_expr: cond.left.clone(),
536                            right_expr: cond.right.clone(),
537                        });
538                    }
539                }
540
541                left_ok && right_ok
542            }
543            LogicalOperator::NodeScan(scan) => {
544                relations.push((scan.variable.clone(), op.clone()));
545                true
546            }
547            LogicalOperator::EdgeScan(scan) => {
548                relations.push((scan.variable.clone(), op.clone()));
549                true
550            }
551            LogicalOperator::Filter(filter) => {
552                // A filter on a base relation is still part of the join tree
553                self.collect_join_tree(&filter.input, relations, conditions)
554            }
555            LogicalOperator::Expand(expand) => {
556                // Expand is a special case - it's like a join with the adjacency
557                // For now, treat the whole Expand subtree as a single relation
558                relations.push((expand.to_variable.clone(), op.clone()));
559                true
560            }
561            _ => false,
562        }
563    }
564
565    /// Extracts the primary variable from an expression.
566    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
567        match expr {
568            LogicalExpression::Variable(v) => Some(v.clone()),
569            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
570            _ => None,
571        }
572    }
573
574    /// Optimizes the join order using DPccp.
575    fn optimize_join_order(
576        &self,
577        relations: &[(String, LogicalOperator)],
578        conditions: &[JoinInfo],
579    ) -> Option<LogicalOperator> {
580        use join_order::{DPccp, JoinGraphBuilder};
581
582        // Build the join graph
583        let mut builder = JoinGraphBuilder::new();
584
585        for (var, relation) in relations {
586            builder.add_relation(var, relation.clone());
587        }
588
589        for cond in conditions {
590            builder.add_join_condition(
591                &cond.left_var,
592                &cond.right_var,
593                cond.left_expr.clone(),
594                cond.right_expr.clone(),
595            );
596        }
597
598        let graph = builder.build();
599
600        // Run DPccp
601        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
602        let plan = dpccp.optimize()?;
603
604        Some(plan.operator)
605    }
606
607    /// Pushes filters down the operator tree.
608    ///
609    /// This optimization moves filter predicates as close to the data source
610    /// as possible to reduce the amount of data processed by upper operators.
611    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
612        match op {
613            // For Filter operators, try to push the predicate into the child
614            LogicalOperator::Filter(filter) => {
615                let optimized_input = self.push_filters_down(*filter.input);
616                self.try_push_filter_into(filter.predicate, optimized_input)
617            }
618            // Recursively optimize children for other operators
619            LogicalOperator::Return(mut ret) => {
620                ret.input = Box::new(self.push_filters_down(*ret.input));
621                LogicalOperator::Return(ret)
622            }
623            LogicalOperator::Project(mut proj) => {
624                proj.input = Box::new(self.push_filters_down(*proj.input));
625                LogicalOperator::Project(proj)
626            }
627            LogicalOperator::Limit(mut limit) => {
628                limit.input = Box::new(self.push_filters_down(*limit.input));
629                LogicalOperator::Limit(limit)
630            }
631            LogicalOperator::Skip(mut skip) => {
632                skip.input = Box::new(self.push_filters_down(*skip.input));
633                LogicalOperator::Skip(skip)
634            }
635            LogicalOperator::Sort(mut sort) => {
636                sort.input = Box::new(self.push_filters_down(*sort.input));
637                LogicalOperator::Sort(sort)
638            }
639            LogicalOperator::Distinct(mut distinct) => {
640                distinct.input = Box::new(self.push_filters_down(*distinct.input));
641                LogicalOperator::Distinct(distinct)
642            }
643            LogicalOperator::Expand(mut expand) => {
644                expand.input = Box::new(self.push_filters_down(*expand.input));
645                LogicalOperator::Expand(expand)
646            }
647            LogicalOperator::Join(mut join) => {
648                join.left = Box::new(self.push_filters_down(*join.left));
649                join.right = Box::new(self.push_filters_down(*join.right));
650                LogicalOperator::Join(join)
651            }
652            LogicalOperator::Aggregate(mut agg) => {
653                agg.input = Box::new(self.push_filters_down(*agg.input));
654                LogicalOperator::Aggregate(agg)
655            }
656            // Leaf operators and unsupported operators are returned as-is
657            other => other,
658        }
659    }
660
661    /// Tries to push a filter predicate into the given operator.
662    ///
663    /// Returns either the predicate pushed into the operator, or a new
664    /// Filter operator on top if the predicate cannot be pushed further.
665    fn try_push_filter_into(
666        &self,
667        predicate: LogicalExpression,
668        op: LogicalOperator,
669    ) -> LogicalOperator {
670        match op {
671            // Can push through Project if predicate doesn't depend on computed columns
672            LogicalOperator::Project(mut proj) => {
673                let predicate_vars = self.extract_variables(&predicate);
674                let computed_vars = self.extract_projection_aliases(&proj.projections);
675
676                // If predicate doesn't use any computed columns, push through
677                if predicate_vars.is_disjoint(&computed_vars) {
678                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
679                    LogicalOperator::Project(proj)
680                } else {
681                    // Can't push through, keep filter on top
682                    LogicalOperator::Filter(FilterOp {
683                        predicate,
684                        input: Box::new(LogicalOperator::Project(proj)),
685                    })
686                }
687            }
688
689            // Can push through Return (which is like a projection)
690            LogicalOperator::Return(mut ret) => {
691                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
692                LogicalOperator::Return(ret)
693            }
694
695            // Can push through Expand if predicate doesn't use variables introduced by this expand
696            LogicalOperator::Expand(mut expand) => {
697                let predicate_vars = self.extract_variables(&predicate);
698
699                // Variables introduced by this expand are:
700                // - The target variable (to_variable)
701                // - The edge variable (if any)
702                // - The path alias (if any)
703                let mut introduced_vars = vec![&expand.to_variable];
704                if let Some(ref edge_var) = expand.edge_variable {
705                    introduced_vars.push(edge_var);
706                }
707                if let Some(ref path_alias) = expand.path_alias {
708                    introduced_vars.push(path_alias);
709                }
710
711                // Check if predicate uses any variables introduced by this expand
712                let uses_introduced_vars =
713                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
714
715                if !uses_introduced_vars {
716                    // Predicate doesn't use vars from this expand, so push through
717                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
718                    LogicalOperator::Expand(expand)
719                } else {
720                    // Keep filter after expand
721                    LogicalOperator::Filter(FilterOp {
722                        predicate,
723                        input: Box::new(LogicalOperator::Expand(expand)),
724                    })
725                }
726            }
727
728            // Can push through Join to left/right side based on variables used
729            LogicalOperator::Join(mut join) => {
730                let predicate_vars = self.extract_variables(&predicate);
731                let left_vars = self.collect_output_variables(&join.left);
732                let right_vars = self.collect_output_variables(&join.right);
733
734                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
735                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
736
737                if uses_left && !uses_right {
738                    // Push to left side
739                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
740                    LogicalOperator::Join(join)
741                } else if uses_right && !uses_left {
742                    // Push to right side
743                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
744                    LogicalOperator::Join(join)
745                } else {
746                    // Uses both sides - keep above join
747                    LogicalOperator::Filter(FilterOp {
748                        predicate,
749                        input: Box::new(LogicalOperator::Join(join)),
750                    })
751                }
752            }
753
754            // Cannot push through Aggregate (predicate refers to aggregated values)
755            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
756                predicate,
757                input: Box::new(LogicalOperator::Aggregate(agg)),
758            }),
759
760            // For NodeScan, we've reached the bottom - keep filter on top
761            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
762                predicate,
763                input: Box::new(LogicalOperator::NodeScan(scan)),
764            }),
765
766            // For other operators, keep filter on top
767            other => LogicalOperator::Filter(FilterOp {
768                predicate,
769                input: Box::new(other),
770            }),
771        }
772    }
773
774    /// Collects all output variable names from an operator.
775    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
776        let mut vars = HashSet::new();
777        Self::collect_output_variables_recursive(op, &mut vars);
778        vars
779    }
780
781    /// Recursively collects output variables from an operator.
782    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
783        match op {
784            LogicalOperator::NodeScan(scan) => {
785                vars.insert(scan.variable.clone());
786            }
787            LogicalOperator::EdgeScan(scan) => {
788                vars.insert(scan.variable.clone());
789            }
790            LogicalOperator::Expand(expand) => {
791                vars.insert(expand.to_variable.clone());
792                if let Some(edge_var) = &expand.edge_variable {
793                    vars.insert(edge_var.clone());
794                }
795                Self::collect_output_variables_recursive(&expand.input, vars);
796            }
797            LogicalOperator::Filter(filter) => {
798                Self::collect_output_variables_recursive(&filter.input, vars);
799            }
800            LogicalOperator::Project(proj) => {
801                for p in &proj.projections {
802                    if let Some(alias) = &p.alias {
803                        vars.insert(alias.clone());
804                    }
805                }
806                Self::collect_output_variables_recursive(&proj.input, vars);
807            }
808            LogicalOperator::Join(join) => {
809                Self::collect_output_variables_recursive(&join.left, vars);
810                Self::collect_output_variables_recursive(&join.right, vars);
811            }
812            LogicalOperator::Aggregate(agg) => {
813                for expr in &agg.group_by {
814                    Self::collect_variables(expr, vars);
815                }
816                for agg_expr in &agg.aggregates {
817                    if let Some(alias) = &agg_expr.alias {
818                        vars.insert(alias.clone());
819                    }
820                }
821            }
822            LogicalOperator::Return(ret) => {
823                Self::collect_output_variables_recursive(&ret.input, vars);
824            }
825            LogicalOperator::Limit(limit) => {
826                Self::collect_output_variables_recursive(&limit.input, vars);
827            }
828            LogicalOperator::Skip(skip) => {
829                Self::collect_output_variables_recursive(&skip.input, vars);
830            }
831            LogicalOperator::Sort(sort) => {
832                Self::collect_output_variables_recursive(&sort.input, vars);
833            }
834            LogicalOperator::Distinct(distinct) => {
835                Self::collect_output_variables_recursive(&distinct.input, vars);
836            }
837            _ => {}
838        }
839    }
840
841    /// Extracts all variable names referenced in an expression.
842    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
843        let mut vars = HashSet::new();
844        Self::collect_variables(expr, &mut vars);
845        vars
846    }
847
848    /// Recursively collects variable names from an expression.
849    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
850        match expr {
851            LogicalExpression::Variable(name) => {
852                vars.insert(name.clone());
853            }
854            LogicalExpression::Property { variable, .. } => {
855                vars.insert(variable.clone());
856            }
857            LogicalExpression::Binary { left, right, .. } => {
858                Self::collect_variables(left, vars);
859                Self::collect_variables(right, vars);
860            }
861            LogicalExpression::Unary { operand, .. } => {
862                Self::collect_variables(operand, vars);
863            }
864            LogicalExpression::FunctionCall { args, .. } => {
865                for arg in args {
866                    Self::collect_variables(arg, vars);
867                }
868            }
869            LogicalExpression::List(items) => {
870                for item in items {
871                    Self::collect_variables(item, vars);
872                }
873            }
874            LogicalExpression::Map(pairs) => {
875                for (_, value) in pairs {
876                    Self::collect_variables(value, vars);
877                }
878            }
879            LogicalExpression::IndexAccess { base, index } => {
880                Self::collect_variables(base, vars);
881                Self::collect_variables(index, vars);
882            }
883            LogicalExpression::SliceAccess { base, start, end } => {
884                Self::collect_variables(base, vars);
885                if let Some(s) = start {
886                    Self::collect_variables(s, vars);
887                }
888                if let Some(e) = end {
889                    Self::collect_variables(e, vars);
890                }
891            }
892            LogicalExpression::Case {
893                operand,
894                when_clauses,
895                else_clause,
896            } => {
897                if let Some(op) = operand {
898                    Self::collect_variables(op, vars);
899                }
900                for (cond, result) in when_clauses {
901                    Self::collect_variables(cond, vars);
902                    Self::collect_variables(result, vars);
903                }
904                if let Some(else_expr) = else_clause {
905                    Self::collect_variables(else_expr, vars);
906                }
907            }
908            LogicalExpression::Labels(var)
909            | LogicalExpression::Type(var)
910            | LogicalExpression::Id(var) => {
911                vars.insert(var.clone());
912            }
913            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
914            LogicalExpression::ListComprehension {
915                list_expr,
916                filter_expr,
917                map_expr,
918                ..
919            } => {
920                Self::collect_variables(list_expr, vars);
921                if let Some(filter) = filter_expr {
922                    Self::collect_variables(filter, vars);
923                }
924                Self::collect_variables(map_expr, vars);
925            }
926            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
927                // Subqueries have their own variable scope
928            }
929        }
930    }
931
932    /// Extracts aliases from projection expressions.
933    fn extract_projection_aliases(
934        &self,
935        projections: &[crate::query::plan::Projection],
936    ) -> HashSet<String> {
937        projections.iter().filter_map(|p| p.alias.clone()).collect()
938    }
939}
940
941impl Default for Optimizer {
942    fn default() -> Self {
943        Self::new()
944    }
945}
946
947#[cfg(test)]
948mod tests {
949    use super::*;
950    use crate::query::plan::{
951        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
952        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
953        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
954    };
955    use grafeo_common::types::Value;
956
957    #[test]
958    fn test_optimizer_filter_pushdown_simple() {
959        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
960        // Before: Return -> Filter -> NodeScan
961        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
962
963        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
964            items: vec![ReturnItem {
965                expression: LogicalExpression::Variable("n".to_string()),
966                alias: None,
967            }],
968            distinct: false,
969            input: Box::new(LogicalOperator::Filter(FilterOp {
970                predicate: LogicalExpression::Binary {
971                    left: Box::new(LogicalExpression::Property {
972                        variable: "n".to_string(),
973                        property: "age".to_string(),
974                    }),
975                    op: BinaryOp::Gt,
976                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
977                },
978                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
979                    variable: "n".to_string(),
980                    label: Some("Person".to_string()),
981                    input: None,
982                })),
983            })),
984        }));
985
986        let optimizer = Optimizer::new();
987        let optimized = optimizer.optimize(plan).unwrap();
988
989        // The structure should remain similar (filter stays near scan)
990        if let LogicalOperator::Return(ret) = &optimized.root
991            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
992            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
993        {
994            assert_eq!(scan.variable, "n");
995            return;
996        }
997        panic!("Expected Return -> Filter -> NodeScan structure");
998    }
999
1000    #[test]
1001    fn test_optimizer_filter_pushdown_through_expand() {
1002        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1003        // The filter on 'a' should be pushed before the expand
1004
1005        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1006            items: vec![ReturnItem {
1007                expression: LogicalExpression::Variable("b".to_string()),
1008                alias: None,
1009            }],
1010            distinct: false,
1011            input: Box::new(LogicalOperator::Filter(FilterOp {
1012                predicate: LogicalExpression::Binary {
1013                    left: Box::new(LogicalExpression::Property {
1014                        variable: "a".to_string(),
1015                        property: "age".to_string(),
1016                    }),
1017                    op: BinaryOp::Gt,
1018                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1019                },
1020                input: Box::new(LogicalOperator::Expand(ExpandOp {
1021                    from_variable: "a".to_string(),
1022                    to_variable: "b".to_string(),
1023                    edge_variable: None,
1024                    direction: ExpandDirection::Outgoing,
1025                    edge_type: Some("KNOWS".to_string()),
1026                    min_hops: 1,
1027                    max_hops: Some(1),
1028                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1029                        variable: "a".to_string(),
1030                        label: Some("Person".to_string()),
1031                        input: None,
1032                    })),
1033                    path_alias: None,
1034                })),
1035            })),
1036        }));
1037
1038        let optimizer = Optimizer::new();
1039        let optimized = optimizer.optimize(plan).unwrap();
1040
1041        // Filter on 'a' should be pushed before the expand
1042        // Expected: Return -> Expand -> Filter -> NodeScan
1043        if let LogicalOperator::Return(ret) = &optimized.root
1044            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1045            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1046            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1047        {
1048            assert_eq!(scan.variable, "a");
1049            assert_eq!(expand.from_variable, "a");
1050            assert_eq!(expand.to_variable, "b");
1051            return;
1052        }
1053        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1054    }
1055
1056    #[test]
1057    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1058        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1059        // The filter on 'b' should NOT be pushed before the expand
1060
1061        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1062            items: vec![ReturnItem {
1063                expression: LogicalExpression::Variable("a".to_string()),
1064                alias: None,
1065            }],
1066            distinct: false,
1067            input: Box::new(LogicalOperator::Filter(FilterOp {
1068                predicate: LogicalExpression::Binary {
1069                    left: Box::new(LogicalExpression::Property {
1070                        variable: "b".to_string(),
1071                        property: "age".to_string(),
1072                    }),
1073                    op: BinaryOp::Gt,
1074                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1075                },
1076                input: Box::new(LogicalOperator::Expand(ExpandOp {
1077                    from_variable: "a".to_string(),
1078                    to_variable: "b".to_string(),
1079                    edge_variable: None,
1080                    direction: ExpandDirection::Outgoing,
1081                    edge_type: Some("KNOWS".to_string()),
1082                    min_hops: 1,
1083                    max_hops: Some(1),
1084                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1085                        variable: "a".to_string(),
1086                        label: Some("Person".to_string()),
1087                        input: None,
1088                    })),
1089                    path_alias: None,
1090                })),
1091            })),
1092        }));
1093
1094        let optimizer = Optimizer::new();
1095        let optimized = optimizer.optimize(plan).unwrap();
1096
1097        // Filter on 'b' should stay after the expand
1098        // Expected: Return -> Filter -> Expand -> NodeScan
1099        if let LogicalOperator::Return(ret) = &optimized.root
1100            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1101        {
1102            // Check that the filter is on 'b'
1103            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1104                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1105            {
1106                assert_eq!(variable, "b");
1107            }
1108
1109            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1110                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1111            {
1112                return;
1113            }
1114        }
1115        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1116    }
1117
1118    #[test]
1119    fn test_optimizer_extract_variables() {
1120        let optimizer = Optimizer::new();
1121
1122        let expr = LogicalExpression::Binary {
1123            left: Box::new(LogicalExpression::Property {
1124                variable: "n".to_string(),
1125                property: "age".to_string(),
1126            }),
1127            op: BinaryOp::Gt,
1128            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1129        };
1130
1131        let vars = optimizer.extract_variables(&expr);
1132        assert_eq!(vars.len(), 1);
1133        assert!(vars.contains("n"));
1134    }
1135
1136    // Additional tests for optimizer configuration
1137
1138    #[test]
1139    fn test_optimizer_default() {
1140        let optimizer = Optimizer::default();
1141        // Should be able to optimize an empty plan
1142        let plan = LogicalPlan::new(LogicalOperator::Empty);
1143        let result = optimizer.optimize(plan);
1144        assert!(result.is_ok());
1145    }
1146
1147    #[test]
1148    fn test_optimizer_with_filter_pushdown_disabled() {
1149        let optimizer = Optimizer::new().with_filter_pushdown(false);
1150
1151        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1152            items: vec![ReturnItem {
1153                expression: LogicalExpression::Variable("n".to_string()),
1154                alias: None,
1155            }],
1156            distinct: false,
1157            input: Box::new(LogicalOperator::Filter(FilterOp {
1158                predicate: LogicalExpression::Literal(Value::Bool(true)),
1159                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1160                    variable: "n".to_string(),
1161                    label: None,
1162                    input: None,
1163                })),
1164            })),
1165        }));
1166
1167        let optimized = optimizer.optimize(plan).unwrap();
1168        // Structure should be unchanged
1169        if let LogicalOperator::Return(ret) = &optimized.root
1170            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1171        {
1172            return;
1173        }
1174        panic!("Expected unchanged structure");
1175    }
1176
1177    #[test]
1178    fn test_optimizer_with_join_reorder_disabled() {
1179        let optimizer = Optimizer::new().with_join_reorder(false);
1180        assert!(
1181            optimizer
1182                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1183                .is_ok()
1184        );
1185    }
1186
1187    #[test]
1188    fn test_optimizer_with_cost_model() {
1189        let cost_model = CostModel::new();
1190        let optimizer = Optimizer::new().with_cost_model(cost_model);
1191        assert!(
1192            optimizer
1193                .cost_model()
1194                .estimate(&LogicalOperator::Empty, 0.0)
1195                .total()
1196                < 0.001
1197        );
1198    }
1199
1200    #[test]
1201    fn test_optimizer_with_cardinality_estimator() {
1202        let mut estimator = CardinalityEstimator::new();
1203        estimator.add_table_stats("Test", TableStats::new(500));
1204        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1205
1206        let scan = LogicalOperator::NodeScan(NodeScanOp {
1207            variable: "n".to_string(),
1208            label: Some("Test".to_string()),
1209            input: None,
1210        });
1211        let plan = LogicalPlan::new(scan);
1212
1213        let cardinality = optimizer.estimate_cardinality(&plan);
1214        assert!((cardinality - 500.0).abs() < 0.001);
1215    }
1216
1217    #[test]
1218    fn test_optimizer_estimate_cost() {
1219        let optimizer = Optimizer::new();
1220        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1221            variable: "n".to_string(),
1222            label: None,
1223            input: None,
1224        }));
1225
1226        let cost = optimizer.estimate_cost(&plan);
1227        assert!(cost.total() > 0.0);
1228    }
1229
1230    // Filter pushdown through various operators
1231
1232    #[test]
1233    fn test_filter_pushdown_through_project() {
1234        let optimizer = Optimizer::new();
1235
1236        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1237            predicate: LogicalExpression::Binary {
1238                left: Box::new(LogicalExpression::Property {
1239                    variable: "n".to_string(),
1240                    property: "age".to_string(),
1241                }),
1242                op: BinaryOp::Gt,
1243                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1244            },
1245            input: Box::new(LogicalOperator::Project(ProjectOp {
1246                projections: vec![Projection {
1247                    expression: LogicalExpression::Variable("n".to_string()),
1248                    alias: None,
1249                }],
1250                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1251                    variable: "n".to_string(),
1252                    label: None,
1253                    input: None,
1254                })),
1255            })),
1256        }));
1257
1258        let optimized = optimizer.optimize(plan).unwrap();
1259
1260        // Filter should be pushed through Project
1261        if let LogicalOperator::Project(proj) = &optimized.root
1262            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1263        {
1264            return;
1265        }
1266        panic!("Expected Project -> Filter structure");
1267    }
1268
1269    #[test]
1270    fn test_filter_not_pushed_through_project_with_alias() {
1271        let optimizer = Optimizer::new();
1272
1273        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1274        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1275            predicate: LogicalExpression::Binary {
1276                left: Box::new(LogicalExpression::Variable("x".to_string())),
1277                op: BinaryOp::Gt,
1278                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1279            },
1280            input: Box::new(LogicalOperator::Project(ProjectOp {
1281                projections: vec![Projection {
1282                    expression: LogicalExpression::Property {
1283                        variable: "n".to_string(),
1284                        property: "age".to_string(),
1285                    },
1286                    alias: Some("x".to_string()),
1287                }],
1288                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1289                    variable: "n".to_string(),
1290                    label: None,
1291                    input: None,
1292                })),
1293            })),
1294        }));
1295
1296        let optimized = optimizer.optimize(plan).unwrap();
1297
1298        // Filter should stay above Project
1299        if let LogicalOperator::Filter(filter) = &optimized.root
1300            && let LogicalOperator::Project(_) = filter.input.as_ref()
1301        {
1302            return;
1303        }
1304        panic!("Expected Filter -> Project structure");
1305    }
1306
1307    #[test]
1308    fn test_filter_pushdown_through_limit() {
1309        let optimizer = Optimizer::new();
1310
1311        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1312            predicate: LogicalExpression::Literal(Value::Bool(true)),
1313            input: Box::new(LogicalOperator::Limit(LimitOp {
1314                count: 10,
1315                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1316                    variable: "n".to_string(),
1317                    label: None,
1318                    input: None,
1319                })),
1320            })),
1321        }));
1322
1323        let optimized = optimizer.optimize(plan).unwrap();
1324
1325        // Filter stays above Limit (cannot be pushed through)
1326        if let LogicalOperator::Filter(filter) = &optimized.root
1327            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1328        {
1329            return;
1330        }
1331        panic!("Expected Filter -> Limit structure");
1332    }
1333
1334    #[test]
1335    fn test_filter_pushdown_through_sort() {
1336        let optimizer = Optimizer::new();
1337
1338        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1339            predicate: LogicalExpression::Literal(Value::Bool(true)),
1340            input: Box::new(LogicalOperator::Sort(SortOp {
1341                keys: vec![SortKey {
1342                    expression: LogicalExpression::Variable("n".to_string()),
1343                    order: SortOrder::Ascending,
1344                }],
1345                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1346                    variable: "n".to_string(),
1347                    label: None,
1348                    input: None,
1349                })),
1350            })),
1351        }));
1352
1353        let optimized = optimizer.optimize(plan).unwrap();
1354
1355        // Filter stays above Sort
1356        if let LogicalOperator::Filter(filter) = &optimized.root
1357            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1358        {
1359            return;
1360        }
1361        panic!("Expected Filter -> Sort structure");
1362    }
1363
1364    #[test]
1365    fn test_filter_pushdown_through_distinct() {
1366        let optimizer = Optimizer::new();
1367
1368        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1369            predicate: LogicalExpression::Literal(Value::Bool(true)),
1370            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1371                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1372                    variable: "n".to_string(),
1373                    label: None,
1374                    input: None,
1375                })),
1376                columns: None,
1377            })),
1378        }));
1379
1380        let optimized = optimizer.optimize(plan).unwrap();
1381
1382        // Filter stays above Distinct
1383        if let LogicalOperator::Filter(filter) = &optimized.root
1384            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1385        {
1386            return;
1387        }
1388        panic!("Expected Filter -> Distinct structure");
1389    }
1390
1391    #[test]
1392    fn test_filter_not_pushed_through_aggregate() {
1393        let optimizer = Optimizer::new();
1394
1395        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1396            predicate: LogicalExpression::Binary {
1397                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1398                op: BinaryOp::Gt,
1399                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1400            },
1401            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1402                group_by: vec![],
1403                aggregates: vec![AggregateExpr {
1404                    function: AggregateFunction::Count,
1405                    expression: None,
1406                    distinct: false,
1407                    alias: Some("cnt".to_string()),
1408                    percentile: None,
1409                }],
1410                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1411                    variable: "n".to_string(),
1412                    label: None,
1413                    input: None,
1414                })),
1415                having: None,
1416            })),
1417        }));
1418
1419        let optimized = optimizer.optimize(plan).unwrap();
1420
1421        // Filter should stay above Aggregate
1422        if let LogicalOperator::Filter(filter) = &optimized.root
1423            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1424        {
1425            return;
1426        }
1427        panic!("Expected Filter -> Aggregate structure");
1428    }
1429
1430    #[test]
1431    fn test_filter_pushdown_to_left_join_side() {
1432        let optimizer = Optimizer::new();
1433
1434        // Filter on left variable should be pushed to left side
1435        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1436            predicate: LogicalExpression::Binary {
1437                left: Box::new(LogicalExpression::Property {
1438                    variable: "a".to_string(),
1439                    property: "age".to_string(),
1440                }),
1441                op: BinaryOp::Gt,
1442                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1443            },
1444            input: Box::new(LogicalOperator::Join(JoinOp {
1445                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1446                    variable: "a".to_string(),
1447                    label: Some("Person".to_string()),
1448                    input: None,
1449                })),
1450                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1451                    variable: "b".to_string(),
1452                    label: Some("Company".to_string()),
1453                    input: None,
1454                })),
1455                join_type: JoinType::Inner,
1456                conditions: vec![],
1457            })),
1458        }));
1459
1460        let optimized = optimizer.optimize(plan).unwrap();
1461
1462        // Filter should be pushed to left side of join
1463        if let LogicalOperator::Join(join) = &optimized.root
1464            && let LogicalOperator::Filter(_) = join.left.as_ref()
1465        {
1466            return;
1467        }
1468        panic!("Expected Join with Filter on left side");
1469    }
1470
1471    #[test]
1472    fn test_filter_pushdown_to_right_join_side() {
1473        let optimizer = Optimizer::new();
1474
1475        // Filter on right variable should be pushed to right side
1476        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1477            predicate: LogicalExpression::Binary {
1478                left: Box::new(LogicalExpression::Property {
1479                    variable: "b".to_string(),
1480                    property: "name".to_string(),
1481                }),
1482                op: BinaryOp::Eq,
1483                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1484            },
1485            input: Box::new(LogicalOperator::Join(JoinOp {
1486                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1487                    variable: "a".to_string(),
1488                    label: Some("Person".to_string()),
1489                    input: None,
1490                })),
1491                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1492                    variable: "b".to_string(),
1493                    label: Some("Company".to_string()),
1494                    input: None,
1495                })),
1496                join_type: JoinType::Inner,
1497                conditions: vec![],
1498            })),
1499        }));
1500
1501        let optimized = optimizer.optimize(plan).unwrap();
1502
1503        // Filter should be pushed to right side of join
1504        if let LogicalOperator::Join(join) = &optimized.root
1505            && let LogicalOperator::Filter(_) = join.right.as_ref()
1506        {
1507            return;
1508        }
1509        panic!("Expected Join with Filter on right side");
1510    }
1511
1512    #[test]
1513    fn test_filter_not_pushed_when_uses_both_join_sides() {
1514        let optimizer = Optimizer::new();
1515
1516        // Filter using both variables should stay above join
1517        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1518            predicate: LogicalExpression::Binary {
1519                left: Box::new(LogicalExpression::Property {
1520                    variable: "a".to_string(),
1521                    property: "id".to_string(),
1522                }),
1523                op: BinaryOp::Eq,
1524                right: Box::new(LogicalExpression::Property {
1525                    variable: "b".to_string(),
1526                    property: "a_id".to_string(),
1527                }),
1528            },
1529            input: Box::new(LogicalOperator::Join(JoinOp {
1530                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1531                    variable: "a".to_string(),
1532                    label: None,
1533                    input: None,
1534                })),
1535                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1536                    variable: "b".to_string(),
1537                    label: None,
1538                    input: None,
1539                })),
1540                join_type: JoinType::Inner,
1541                conditions: vec![],
1542            })),
1543        }));
1544
1545        let optimized = optimizer.optimize(plan).unwrap();
1546
1547        // Filter should stay above join
1548        if let LogicalOperator::Filter(filter) = &optimized.root
1549            && let LogicalOperator::Join(_) = filter.input.as_ref()
1550        {
1551            return;
1552        }
1553        panic!("Expected Filter -> Join structure");
1554    }
1555
1556    // Variable extraction tests
1557
1558    #[test]
1559    fn test_extract_variables_from_variable() {
1560        let optimizer = Optimizer::new();
1561        let expr = LogicalExpression::Variable("x".to_string());
1562        let vars = optimizer.extract_variables(&expr);
1563        assert_eq!(vars.len(), 1);
1564        assert!(vars.contains("x"));
1565    }
1566
1567    #[test]
1568    fn test_extract_variables_from_unary() {
1569        let optimizer = Optimizer::new();
1570        let expr = LogicalExpression::Unary {
1571            op: UnaryOp::Not,
1572            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1573        };
1574        let vars = optimizer.extract_variables(&expr);
1575        assert_eq!(vars.len(), 1);
1576        assert!(vars.contains("x"));
1577    }
1578
1579    #[test]
1580    fn test_extract_variables_from_function_call() {
1581        let optimizer = Optimizer::new();
1582        let expr = LogicalExpression::FunctionCall {
1583            name: "length".to_string(),
1584            args: vec![
1585                LogicalExpression::Variable("a".to_string()),
1586                LogicalExpression::Variable("b".to_string()),
1587            ],
1588            distinct: false,
1589        };
1590        let vars = optimizer.extract_variables(&expr);
1591        assert_eq!(vars.len(), 2);
1592        assert!(vars.contains("a"));
1593        assert!(vars.contains("b"));
1594    }
1595
1596    #[test]
1597    fn test_extract_variables_from_list() {
1598        let optimizer = Optimizer::new();
1599        let expr = LogicalExpression::List(vec![
1600            LogicalExpression::Variable("a".to_string()),
1601            LogicalExpression::Literal(Value::Int64(1)),
1602            LogicalExpression::Variable("b".to_string()),
1603        ]);
1604        let vars = optimizer.extract_variables(&expr);
1605        assert_eq!(vars.len(), 2);
1606        assert!(vars.contains("a"));
1607        assert!(vars.contains("b"));
1608    }
1609
1610    #[test]
1611    fn test_extract_variables_from_map() {
1612        let optimizer = Optimizer::new();
1613        let expr = LogicalExpression::Map(vec![
1614            (
1615                "key1".to_string(),
1616                LogicalExpression::Variable("a".to_string()),
1617            ),
1618            (
1619                "key2".to_string(),
1620                LogicalExpression::Variable("b".to_string()),
1621            ),
1622        ]);
1623        let vars = optimizer.extract_variables(&expr);
1624        assert_eq!(vars.len(), 2);
1625        assert!(vars.contains("a"));
1626        assert!(vars.contains("b"));
1627    }
1628
1629    #[test]
1630    fn test_extract_variables_from_index_access() {
1631        let optimizer = Optimizer::new();
1632        let expr = LogicalExpression::IndexAccess {
1633            base: Box::new(LogicalExpression::Variable("list".to_string())),
1634            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1635        };
1636        let vars = optimizer.extract_variables(&expr);
1637        assert_eq!(vars.len(), 2);
1638        assert!(vars.contains("list"));
1639        assert!(vars.contains("idx"));
1640    }
1641
1642    #[test]
1643    fn test_extract_variables_from_slice_access() {
1644        let optimizer = Optimizer::new();
1645        let expr = LogicalExpression::SliceAccess {
1646            base: Box::new(LogicalExpression::Variable("list".to_string())),
1647            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1648            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1649        };
1650        let vars = optimizer.extract_variables(&expr);
1651        assert_eq!(vars.len(), 3);
1652        assert!(vars.contains("list"));
1653        assert!(vars.contains("s"));
1654        assert!(vars.contains("e"));
1655    }
1656
1657    #[test]
1658    fn test_extract_variables_from_case() {
1659        let optimizer = Optimizer::new();
1660        let expr = LogicalExpression::Case {
1661            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1662            when_clauses: vec![(
1663                LogicalExpression::Literal(Value::Int64(1)),
1664                LogicalExpression::Variable("a".to_string()),
1665            )],
1666            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1667        };
1668        let vars = optimizer.extract_variables(&expr);
1669        assert_eq!(vars.len(), 3);
1670        assert!(vars.contains("x"));
1671        assert!(vars.contains("a"));
1672        assert!(vars.contains("b"));
1673    }
1674
1675    #[test]
1676    fn test_extract_variables_from_labels() {
1677        let optimizer = Optimizer::new();
1678        let expr = LogicalExpression::Labels("n".to_string());
1679        let vars = optimizer.extract_variables(&expr);
1680        assert_eq!(vars.len(), 1);
1681        assert!(vars.contains("n"));
1682    }
1683
1684    #[test]
1685    fn test_extract_variables_from_type() {
1686        let optimizer = Optimizer::new();
1687        let expr = LogicalExpression::Type("e".to_string());
1688        let vars = optimizer.extract_variables(&expr);
1689        assert_eq!(vars.len(), 1);
1690        assert!(vars.contains("e"));
1691    }
1692
1693    #[test]
1694    fn test_extract_variables_from_id() {
1695        let optimizer = Optimizer::new();
1696        let expr = LogicalExpression::Id("n".to_string());
1697        let vars = optimizer.extract_variables(&expr);
1698        assert_eq!(vars.len(), 1);
1699        assert!(vars.contains("n"));
1700    }
1701
1702    #[test]
1703    fn test_extract_variables_from_list_comprehension() {
1704        let optimizer = Optimizer::new();
1705        let expr = LogicalExpression::ListComprehension {
1706            variable: "x".to_string(),
1707            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1708            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1709            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1710        };
1711        let vars = optimizer.extract_variables(&expr);
1712        assert!(vars.contains("items"));
1713        assert!(vars.contains("pred"));
1714        assert!(vars.contains("result"));
1715    }
1716
1717    #[test]
1718    fn test_extract_variables_from_literal_and_parameter() {
1719        let optimizer = Optimizer::new();
1720
1721        let literal = LogicalExpression::Literal(Value::Int64(42));
1722        assert!(optimizer.extract_variables(&literal).is_empty());
1723
1724        let param = LogicalExpression::Parameter("p".to_string());
1725        assert!(optimizer.extract_variables(&param).is_empty());
1726    }
1727
1728    // Recursive filter pushdown tests
1729
1730    #[test]
1731    fn test_recursive_filter_pushdown_through_skip() {
1732        let optimizer = Optimizer::new();
1733
1734        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1735            items: vec![ReturnItem {
1736                expression: LogicalExpression::Variable("n".to_string()),
1737                alias: None,
1738            }],
1739            distinct: false,
1740            input: Box::new(LogicalOperator::Filter(FilterOp {
1741                predicate: LogicalExpression::Literal(Value::Bool(true)),
1742                input: Box::new(LogicalOperator::Skip(SkipOp {
1743                    count: 5,
1744                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1745                        variable: "n".to_string(),
1746                        label: None,
1747                        input: None,
1748                    })),
1749                })),
1750            })),
1751        }));
1752
1753        let optimized = optimizer.optimize(plan).unwrap();
1754
1755        // Verify optimization succeeded
1756        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1757    }
1758
1759    #[test]
1760    fn test_nested_filter_pushdown() {
1761        let optimizer = Optimizer::new();
1762
1763        // Multiple nested filters
1764        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1765            items: vec![ReturnItem {
1766                expression: LogicalExpression::Variable("n".to_string()),
1767                alias: None,
1768            }],
1769            distinct: false,
1770            input: Box::new(LogicalOperator::Filter(FilterOp {
1771                predicate: LogicalExpression::Binary {
1772                    left: Box::new(LogicalExpression::Property {
1773                        variable: "n".to_string(),
1774                        property: "x".to_string(),
1775                    }),
1776                    op: BinaryOp::Gt,
1777                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1778                },
1779                input: Box::new(LogicalOperator::Filter(FilterOp {
1780                    predicate: LogicalExpression::Binary {
1781                        left: Box::new(LogicalExpression::Property {
1782                            variable: "n".to_string(),
1783                            property: "y".to_string(),
1784                        }),
1785                        op: BinaryOp::Lt,
1786                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1787                    },
1788                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1789                        variable: "n".to_string(),
1790                        label: None,
1791                        input: None,
1792                    })),
1793                })),
1794            })),
1795        }));
1796
1797        let optimized = optimizer.optimize(plan).unwrap();
1798        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1799    }
1800}