Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
//! | Join Reordering | Picks the cheapest order to join relations using the DPccp algorithm |
//! | Projection Pushdown | Determines which variables and properties each part of the plan needs |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
/// Information about a join condition for join reordering.
///
/// Produced while flattening a join tree into base relations and conditions;
/// each entry becomes one edge of the join graph handed to DPccp.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable `left_expr` is anchored to (the variable itself, or the owner
    /// of a property access).
    left_var: String,
    /// Variable `right_expr` is anchored to (the variable itself, or the owner
    /// of a property access).
    right_var: String,
    /// Left-hand expression of the condition, copied from the original join.
    left_expr: LogicalExpression,
    /// Right-hand expression of the condition, copied from the original join.
    right_expr: LogicalExpression,
}
36
/// A column required by the query, used for projection pushdown.
///
/// A property access `v.p` is recorded both as `Property(v, p)` and as
/// `Variable(v)`, so the owning variable is always kept alongside its property.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding), by name
    Variable(String),
    /// A specific property of a variable: `(variable, property)`
    Property(String, String),
}
45
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), or [`from_store()`](Self::from_store) to seed
/// the estimators from store statistics, then call [`optimize()`](Self::optimize).
/// Both constructors enable every pass; use the builder methods to enable or
/// disable specific optimizations.
pub struct Optimizer {
    /// Whether to enable filter pushdown.
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering.
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown.
    enable_projection_pushdown: bool,
    /// Cost model used by `estimate_cost` and by DPccp during join reordering.
    cost_model: CostModel,
    /// Cardinality estimator used for cost estimation and by DPccp during join reordering.
    card_estimator: CardinalityEstimator,
}
62
63impl Optimizer {
64    /// Creates a new optimizer with default settings.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            enable_filter_pushdown: true,
69            enable_join_reorder: true,
70            enable_projection_pushdown: true,
71            cost_model: CostModel::new(),
72            card_estimator: CardinalityEstimator::new(),
73        }
74    }
75
76    /// Creates an optimizer with cardinality estimates from the store's statistics.
77    ///
78    /// Pre-populates the cardinality estimator with per-label row counts and
79    /// edge type fanout. Ensures statistics are fresh before reading.
80    #[must_use]
81    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
82        store.ensure_statistics_fresh();
83        let stats = store.statistics();
84        let estimator = CardinalityEstimator::from_statistics(&stats);
85
86        // Derive average fanout from statistics for the cost model
87        let avg_fanout = if stats.total_nodes > 0 {
88            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
89        } else {
90            10.0
91        };
92
93        Self {
94            enable_filter_pushdown: true,
95            enable_join_reorder: true,
96            enable_projection_pushdown: true,
97            cost_model: CostModel::new().with_avg_fanout(avg_fanout),
98            card_estimator: estimator,
99        }
100    }
101
102    /// Enables or disables filter pushdown.
103    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
104        self.enable_filter_pushdown = enabled;
105        self
106    }
107
108    /// Enables or disables join reordering.
109    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
110        self.enable_join_reorder = enabled;
111        self
112    }
113
114    /// Enables or disables projection pushdown.
115    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
116        self.enable_projection_pushdown = enabled;
117        self
118    }
119
120    /// Sets the cost model.
121    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
122        self.cost_model = cost_model;
123        self
124    }
125
126    /// Sets the cardinality estimator.
127    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
128        self.card_estimator = estimator;
129        self
130    }
131
132    /// Sets the selectivity configuration for the cardinality estimator.
133    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
134        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
135        self
136    }
137
138    /// Returns a reference to the cost model.
139    pub fn cost_model(&self) -> &CostModel {
140        &self.cost_model
141    }
142
143    /// Returns a reference to the cardinality estimator.
144    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
145        &self.card_estimator
146    }
147
148    /// Estimates the cost of a plan.
149    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
150        let cardinality = self.card_estimator.estimate(&plan.root);
151        self.cost_model.estimate(&plan.root, cardinality)
152    }
153
154    /// Estimates the cardinality of a plan.
155    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
156        self.card_estimator.estimate(&plan.root)
157    }
158
159    /// Optimizes a logical plan.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if optimization fails.
164    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
165        let mut root = plan.root;
166
167        // Apply optimization rules
168        if self.enable_filter_pushdown {
169            root = self.push_filters_down(root);
170        }
171
172        if self.enable_join_reorder {
173            root = self.reorder_joins(root);
174        }
175
176        if self.enable_projection_pushdown {
177            root = self.push_projections_down(root);
178        }
179
180        Ok(LogicalPlan::new(root))
181    }
182
183    /// Pushes projections down the operator tree to eliminate unused columns early.
184    ///
185    /// This optimization:
186    /// 1. Collects required variables/properties from the root
187    /// 2. Propagates requirements down through the tree
188    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
189    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
190        // Collect required columns from the top of the plan
191        let required = self.collect_required_columns(&op);
192
193        // Push projections down
194        self.push_projections_recursive(op, &required)
195    }
196
197    /// Collects all variables and properties required by an operator and its ancestors.
198    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
199        let mut required = HashSet::new();
200        Self::collect_required_recursive(op, &mut required);
201        required
202    }
203
    /// Recursively collects the columns referenced by `op` and everything beneath it.
    ///
    /// Expressions owned by an operator contribute their variables and properties:
    /// return items, projections, filter predicates, sort keys, grouping keys,
    /// aggregate arguments, `HAVING`, and join conditions. Scans and expands
    /// contribute the variables they bind. Operators not listed in the match
    /// (the `_` arm) contribute nothing and are not descended into.
    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
        match op {
            LogicalOperator::Return(ret) => {
                for item in &ret.items {
                    Self::collect_from_expression(&item.expression, required);
                }
                Self::collect_required_recursive(&ret.input, required);
            }
            LogicalOperator::Project(proj) => {
                for p in &proj.projections {
                    Self::collect_from_expression(&p.expression, required);
                }
                Self::collect_required_recursive(&proj.input, required);
            }
            LogicalOperator::Filter(filter) => {
                Self::collect_from_expression(&filter.predicate, required);
                Self::collect_required_recursive(&filter.input, required);
            }
            LogicalOperator::Sort(sort) => {
                for key in &sort.keys {
                    Self::collect_from_expression(&key.expression, required);
                }
                Self::collect_required_recursive(&sort.input, required);
            }
            LogicalOperator::Aggregate(agg) => {
                for expr in &agg.group_by {
                    Self::collect_from_expression(expr, required);
                }
                // Aggregates without an argument expression (`expression == None`)
                // contribute nothing.
                for agg_expr in &agg.aggregates {
                    if let Some(ref expr) = agg_expr.expression {
                        Self::collect_from_expression(expr, required);
                    }
                }
                if let Some(ref having) = agg.having {
                    Self::collect_from_expression(having, required);
                }
                Self::collect_required_recursive(&agg.input, required);
            }
            LogicalOperator::Join(join) => {
                for cond in &join.conditions {
                    Self::collect_from_expression(&cond.left, required);
                    Self::collect_from_expression(&cond.right, required);
                }
                Self::collect_required_recursive(&join.left, required);
                Self::collect_required_recursive(&join.right, required);
            }
            LogicalOperator::Expand(expand) => {
                // The source and target variables are needed
                // NOTE(review): `expand.path_alias` is not recorded here; it is only
                // collected if some expression above references it.
                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
                if let Some(ref edge_var) = expand.edge_variable {
                    required.insert(RequiredColumn::Variable(edge_var.clone()));
                }
                Self::collect_required_recursive(&expand.input, required);
            }
            LogicalOperator::Limit(limit) => {
                Self::collect_required_recursive(&limit.input, required);
            }
            LogicalOperator::Skip(skip) => {
                Self::collect_required_recursive(&skip.input, required);
            }
            LogicalOperator::Distinct(distinct) => {
                Self::collect_required_recursive(&distinct.input, required);
            }
            // Leaves: the bound variable itself is required.
            LogicalOperator::NodeScan(scan) => {
                required.insert(RequiredColumn::Variable(scan.variable.clone()));
            }
            LogicalOperator::EdgeScan(scan) => {
                required.insert(RequiredColumn::Variable(scan.variable.clone()));
            }
            _ => {}
        }
    }
278
    /// Collects the columns referenced by a single expression into `required`.
    ///
    /// Columns recorded per expression kind:
    /// - A bare variable adds `Variable(v)`.
    /// - A property access `v.p` adds both `Property(v, p)` and `Variable(v)`.
    /// - `Labels`, `Type` and `Id` expressions add `Variable(v)`.
    /// - Compound expressions (binary/unary operators, function calls, lists,
    ///   map values, index/slice access, `CASE`, list comprehensions) are walked
    ///   recursively.
    /// - Expression kinds not matched below (the `_` arm) contribute nothing.
    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
        match expr {
            LogicalExpression::Variable(var) => {
                required.insert(RequiredColumn::Variable(var.clone()));
            }
            LogicalExpression::Property { variable, property } => {
                // Keep the owning variable as well as the specific property.
                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
                required.insert(RequiredColumn::Variable(variable.clone()));
            }
            LogicalExpression::Binary { left, right, .. } => {
                Self::collect_from_expression(left, required);
                Self::collect_from_expression(right, required);
            }
            LogicalExpression::Unary { operand, .. } => {
                Self::collect_from_expression(operand, required);
            }
            LogicalExpression::FunctionCall { args, .. } => {
                for arg in args {
                    Self::collect_from_expression(arg, required);
                }
            }
            LogicalExpression::List(items) => {
                for item in items {
                    Self::collect_from_expression(item, required);
                }
            }
            LogicalExpression::Map(pairs) => {
                // Only the values are expressions; keys are ignored.
                for (_, value) in pairs {
                    Self::collect_from_expression(value, required);
                }
            }
            LogicalExpression::IndexAccess { base, index } => {
                Self::collect_from_expression(base, required);
                Self::collect_from_expression(index, required);
            }
            LogicalExpression::SliceAccess { base, start, end } => {
                Self::collect_from_expression(base, required);
                if let Some(s) = start {
                    Self::collect_from_expression(s, required);
                }
                if let Some(e) = end {
                    Self::collect_from_expression(e, required);
                }
            }
            LogicalExpression::Case {
                operand,
                when_clauses,
                else_clause,
            } => {
                if let Some(op) = operand {
                    Self::collect_from_expression(op, required);
                }
                for (cond, result) in when_clauses {
                    Self::collect_from_expression(cond, required);
                    Self::collect_from_expression(result, required);
                }
                if let Some(else_expr) = else_clause {
                    Self::collect_from_expression(else_expr, required);
                }
            }
            LogicalExpression::Labels(var)
            | LogicalExpression::Type(var)
            | LogicalExpression::Id(var) => {
                required.insert(RequiredColumn::Variable(var.clone()));
            }
            LogicalExpression::ListComprehension {
                list_expr,
                filter_expr,
                map_expr,
                ..
            } => {
                // NOTE(review): the comprehension's own iteration variable (elided by
                // `..`) is not excluded, so references to it inside `filter_expr` /
                // `map_expr` are presumably recorded as required variables too.
                Self::collect_from_expression(list_expr, required);
                if let Some(filter) = filter_expr {
                    Self::collect_from_expression(filter, required);
                }
                Self::collect_from_expression(map_expr, required);
            }
            _ => {}
        }
    }
360
    /// Recursively walks the tree, passing the set of `required` columns down.
    ///
    /// At a `Join`, the set is narrowed for each side to the columns whose
    /// variable appears in that side's output variables (as reported by
    /// `collect_output_variables`). Every other supported operator passes the
    /// set through unchanged. Unlisted operators (including leaves) are returned
    /// as-is.
    ///
    /// NOTE: despite the intent, no `Project` is inserted anywhere yet. Each arm
    /// rebuilds the same operator around its rewritten child(ren), so the output
    /// tree has the same shape as the input.
    fn push_projections_recursive(
        &self,
        op: LogicalOperator,
        required: &HashSet<RequiredColumn>,
    ) -> LogicalOperator {
        match op {
            LogicalOperator::Return(mut ret) => {
                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
                LogicalOperator::Return(ret)
            }
            LogicalOperator::Project(mut proj) => {
                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
                LogicalOperator::Project(proj)
            }
            LogicalOperator::Filter(mut filter) => {
                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
                LogicalOperator::Filter(filter)
            }
            LogicalOperator::Sort(mut sort) => {
                // Sort is expensive - consider adding a projection before it
                // to reduce tuple width
                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
                LogicalOperator::Sort(sort)
            }
            LogicalOperator::Aggregate(mut agg) => {
                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
                LogicalOperator::Aggregate(agg)
            }
            LogicalOperator::Join(mut join) => {
                // Joins are expensive - the required columns help determine
                // what to project on each side
                let left_vars = self.collect_output_variables(&join.left);
                let right_vars = self.collect_output_variables(&join.right);

                // Filter required columns to each side
                let left_required: HashSet<_> = required
                    .iter()
                    .filter(|c| match c {
                        RequiredColumn::Variable(v) => left_vars.contains(v),
                        RequiredColumn::Property(v, _) => left_vars.contains(v),
                    })
                    .cloned()
                    .collect();

                let right_required: HashSet<_> = required
                    .iter()
                    .filter(|c| match c {
                        RequiredColumn::Variable(v) => right_vars.contains(v),
                        RequiredColumn::Property(v, _) => right_vars.contains(v),
                    })
                    .cloned()
                    .collect();

                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
                join.right =
                    Box::new(self.push_projections_recursive(*join.right, &right_required));
                LogicalOperator::Join(join)
            }
            LogicalOperator::Expand(mut expand) => {
                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
                LogicalOperator::Expand(expand)
            }
            LogicalOperator::Limit(mut limit) => {
                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
                LogicalOperator::Limit(limit)
            }
            LogicalOperator::Skip(mut skip) => {
                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
                LogicalOperator::Skip(skip)
            }
            LogicalOperator::Distinct(mut distinct) => {
                distinct.input =
                    Box::new(self.push_projections_recursive(*distinct.input, required));
                LogicalOperator::Distinct(distinct)
            }
            other => other,
        }
    }
440
441    /// Reorders joins in the operator tree using the DPccp algorithm.
442    ///
443    /// This optimization finds the optimal join order by:
444    /// 1. Extracting all base relations (scans) and join conditions
445    /// 2. Building a join graph
446    /// 3. Using dynamic programming to find the cheapest join order
447    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
448        // First, recursively optimize children
449        let op = self.reorder_joins_recursive(op);
450
451        // Then, if this is a join tree, try to optimize it
452        if let Some((relations, conditions)) = self.extract_join_tree(&op)
453            && relations.len() >= 2
454            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
455        {
456            return optimized;
457        }
458
459        op
460    }
461
462    /// Recursively applies join reordering to child operators.
463    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
464        match op {
465            LogicalOperator::Return(mut ret) => {
466                ret.input = Box::new(self.reorder_joins(*ret.input));
467                LogicalOperator::Return(ret)
468            }
469            LogicalOperator::Project(mut proj) => {
470                proj.input = Box::new(self.reorder_joins(*proj.input));
471                LogicalOperator::Project(proj)
472            }
473            LogicalOperator::Filter(mut filter) => {
474                filter.input = Box::new(self.reorder_joins(*filter.input));
475                LogicalOperator::Filter(filter)
476            }
477            LogicalOperator::Limit(mut limit) => {
478                limit.input = Box::new(self.reorder_joins(*limit.input));
479                LogicalOperator::Limit(limit)
480            }
481            LogicalOperator::Skip(mut skip) => {
482                skip.input = Box::new(self.reorder_joins(*skip.input));
483                LogicalOperator::Skip(skip)
484            }
485            LogicalOperator::Sort(mut sort) => {
486                sort.input = Box::new(self.reorder_joins(*sort.input));
487                LogicalOperator::Sort(sort)
488            }
489            LogicalOperator::Distinct(mut distinct) => {
490                distinct.input = Box::new(self.reorder_joins(*distinct.input));
491                LogicalOperator::Distinct(distinct)
492            }
493            LogicalOperator::Aggregate(mut agg) => {
494                agg.input = Box::new(self.reorder_joins(*agg.input));
495                LogicalOperator::Aggregate(agg)
496            }
497            LogicalOperator::Expand(mut expand) => {
498                expand.input = Box::new(self.reorder_joins(*expand.input));
499                LogicalOperator::Expand(expand)
500            }
501            // Join operators are handled by the parent reorder_joins call
502            other => other,
503        }
504    }
505
506    /// Extracts base relations and join conditions from a join tree.
507    ///
508    /// Returns None if the operator is not a join tree.
509    fn extract_join_tree(
510        &self,
511        op: &LogicalOperator,
512    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
513        let mut relations = Vec::new();
514        let mut join_conditions = Vec::new();
515
516        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
517            return None;
518        }
519
520        if relations.len() < 2 {
521            return None;
522        }
523
524        Some((relations, join_conditions))
525    }
526
527    /// Recursively collects base relations and join conditions.
528    ///
529    /// Returns true if this subtree is part of a join tree.
530    fn collect_join_tree(
531        &self,
532        op: &LogicalOperator,
533        relations: &mut Vec<(String, LogicalOperator)>,
534        conditions: &mut Vec<JoinInfo>,
535    ) -> bool {
536        match op {
537            LogicalOperator::Join(join) => {
538                // Collect from both sides
539                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
540                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
541
542                // Add conditions from this join
543                for cond in &join.conditions {
544                    if let (Some(left_var), Some(right_var)) = (
545                        self.extract_variable_from_expr(&cond.left),
546                        self.extract_variable_from_expr(&cond.right),
547                    ) {
548                        conditions.push(JoinInfo {
549                            left_var,
550                            right_var,
551                            left_expr: cond.left.clone(),
552                            right_expr: cond.right.clone(),
553                        });
554                    }
555                }
556
557                left_ok && right_ok
558            }
559            LogicalOperator::NodeScan(scan) => {
560                relations.push((scan.variable.clone(), op.clone()));
561                true
562            }
563            LogicalOperator::EdgeScan(scan) => {
564                relations.push((scan.variable.clone(), op.clone()));
565                true
566            }
567            LogicalOperator::Filter(filter) => {
568                // A filter on a base relation is still part of the join tree
569                self.collect_join_tree(&filter.input, relations, conditions)
570            }
571            LogicalOperator::Expand(expand) => {
572                // Expand is a special case - it's like a join with the adjacency
573                // For now, treat the whole Expand subtree as a single relation
574                relations.push((expand.to_variable.clone(), op.clone()));
575                true
576            }
577            _ => false,
578        }
579    }
580
581    /// Extracts the primary variable from an expression.
582    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
583        match expr {
584            LogicalExpression::Variable(v) => Some(v.clone()),
585            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
586            _ => None,
587        }
588    }
589
590    /// Optimizes the join order using DPccp.
591    fn optimize_join_order(
592        &self,
593        relations: &[(String, LogicalOperator)],
594        conditions: &[JoinInfo],
595    ) -> Option<LogicalOperator> {
596        use join_order::{DPccp, JoinGraphBuilder};
597
598        // Build the join graph
599        let mut builder = JoinGraphBuilder::new();
600
601        for (var, relation) in relations {
602            builder.add_relation(var, relation.clone());
603        }
604
605        for cond in conditions {
606            builder.add_join_condition(
607                &cond.left_var,
608                &cond.right_var,
609                cond.left_expr.clone(),
610                cond.right_expr.clone(),
611            );
612        }
613
614        let graph = builder.build();
615
616        // Run DPccp
617        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
618        let plan = dpccp.optimize()?;
619
620        Some(plan.operator)
621    }
622
623    /// Pushes filters down the operator tree.
624    ///
625    /// This optimization moves filter predicates as close to the data source
626    /// as possible to reduce the amount of data processed by upper operators.
627    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
628        match op {
629            // For Filter operators, try to push the predicate into the child
630            LogicalOperator::Filter(filter) => {
631                let optimized_input = self.push_filters_down(*filter.input);
632                self.try_push_filter_into(filter.predicate, optimized_input)
633            }
634            // Recursively optimize children for other operators
635            LogicalOperator::Return(mut ret) => {
636                ret.input = Box::new(self.push_filters_down(*ret.input));
637                LogicalOperator::Return(ret)
638            }
639            LogicalOperator::Project(mut proj) => {
640                proj.input = Box::new(self.push_filters_down(*proj.input));
641                LogicalOperator::Project(proj)
642            }
643            LogicalOperator::Limit(mut limit) => {
644                limit.input = Box::new(self.push_filters_down(*limit.input));
645                LogicalOperator::Limit(limit)
646            }
647            LogicalOperator::Skip(mut skip) => {
648                skip.input = Box::new(self.push_filters_down(*skip.input));
649                LogicalOperator::Skip(skip)
650            }
651            LogicalOperator::Sort(mut sort) => {
652                sort.input = Box::new(self.push_filters_down(*sort.input));
653                LogicalOperator::Sort(sort)
654            }
655            LogicalOperator::Distinct(mut distinct) => {
656                distinct.input = Box::new(self.push_filters_down(*distinct.input));
657                LogicalOperator::Distinct(distinct)
658            }
659            LogicalOperator::Expand(mut expand) => {
660                expand.input = Box::new(self.push_filters_down(*expand.input));
661                LogicalOperator::Expand(expand)
662            }
663            LogicalOperator::Join(mut join) => {
664                join.left = Box::new(self.push_filters_down(*join.left));
665                join.right = Box::new(self.push_filters_down(*join.right));
666                LogicalOperator::Join(join)
667            }
668            LogicalOperator::Aggregate(mut agg) => {
669                agg.input = Box::new(self.push_filters_down(*agg.input));
670                LogicalOperator::Aggregate(agg)
671            }
672            // Leaf operators and unsupported operators are returned as-is
673            other => other,
674        }
675    }
676
677    /// Tries to push a filter predicate into the given operator.
678    ///
679    /// Returns either the predicate pushed into the operator, or a new
680    /// Filter operator on top if the predicate cannot be pushed further.
681    fn try_push_filter_into(
682        &self,
683        predicate: LogicalExpression,
684        op: LogicalOperator,
685    ) -> LogicalOperator {
686        match op {
687            // Can push through Project if predicate doesn't depend on computed columns
688            LogicalOperator::Project(mut proj) => {
689                let predicate_vars = self.extract_variables(&predicate);
690                let computed_vars = self.extract_projection_aliases(&proj.projections);
691
692                // If predicate doesn't use any computed columns, push through
693                if predicate_vars.is_disjoint(&computed_vars) {
694                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
695                    LogicalOperator::Project(proj)
696                } else {
697                    // Can't push through, keep filter on top
698                    LogicalOperator::Filter(FilterOp {
699                        predicate,
700                        input: Box::new(LogicalOperator::Project(proj)),
701                    })
702                }
703            }
704
705            // Can push through Return (which is like a projection)
706            LogicalOperator::Return(mut ret) => {
707                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
708                LogicalOperator::Return(ret)
709            }
710
711            // Can push through Expand if predicate doesn't use variables introduced by this expand
712            LogicalOperator::Expand(mut expand) => {
713                let predicate_vars = self.extract_variables(&predicate);
714
715                // Variables introduced by this expand are:
716                // - The target variable (to_variable)
717                // - The edge variable (if any)
718                // - The path alias (if any)
719                let mut introduced_vars = vec![&expand.to_variable];
720                if let Some(ref edge_var) = expand.edge_variable {
721                    introduced_vars.push(edge_var);
722                }
723                if let Some(ref path_alias) = expand.path_alias {
724                    introduced_vars.push(path_alias);
725                }
726
727                // Check if predicate uses any variables introduced by this expand
728                let uses_introduced_vars =
729                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
730
731                if !uses_introduced_vars {
732                    // Predicate doesn't use vars from this expand, so push through
733                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
734                    LogicalOperator::Expand(expand)
735                } else {
736                    // Keep filter after expand
737                    LogicalOperator::Filter(FilterOp {
738                        predicate,
739                        input: Box::new(LogicalOperator::Expand(expand)),
740                    })
741                }
742            }
743
744            // Can push through Join to left/right side based on variables used
745            LogicalOperator::Join(mut join) => {
746                let predicate_vars = self.extract_variables(&predicate);
747                let left_vars = self.collect_output_variables(&join.left);
748                let right_vars = self.collect_output_variables(&join.right);
749
750                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
751                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
752
753                if uses_left && !uses_right {
754                    // Push to left side
755                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
756                    LogicalOperator::Join(join)
757                } else if uses_right && !uses_left {
758                    // Push to right side
759                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
760                    LogicalOperator::Join(join)
761                } else {
762                    // Uses both sides - keep above join
763                    LogicalOperator::Filter(FilterOp {
764                        predicate,
765                        input: Box::new(LogicalOperator::Join(join)),
766                    })
767                }
768            }
769
770            // Cannot push through Aggregate (predicate refers to aggregated values)
771            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
772                predicate,
773                input: Box::new(LogicalOperator::Aggregate(agg)),
774            }),
775
776            // For NodeScan, we've reached the bottom - keep filter on top
777            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
778                predicate,
779                input: Box::new(LogicalOperator::NodeScan(scan)),
780            }),
781
782            // For other operators, keep filter on top
783            other => LogicalOperator::Filter(FilterOp {
784                predicate,
785                input: Box::new(other),
786            }),
787        }
788    }
789
790    /// Collects all output variable names from an operator.
791    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
792        let mut vars = HashSet::new();
793        Self::collect_output_variables_recursive(op, &mut vars);
794        vars
795    }
796
797    /// Recursively collects output variables from an operator.
798    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
799        match op {
800            LogicalOperator::NodeScan(scan) => {
801                vars.insert(scan.variable.clone());
802            }
803            LogicalOperator::EdgeScan(scan) => {
804                vars.insert(scan.variable.clone());
805            }
806            LogicalOperator::Expand(expand) => {
807                vars.insert(expand.to_variable.clone());
808                if let Some(edge_var) = &expand.edge_variable {
809                    vars.insert(edge_var.clone());
810                }
811                Self::collect_output_variables_recursive(&expand.input, vars);
812            }
813            LogicalOperator::Filter(filter) => {
814                Self::collect_output_variables_recursive(&filter.input, vars);
815            }
816            LogicalOperator::Project(proj) => {
817                for p in &proj.projections {
818                    if let Some(alias) = &p.alias {
819                        vars.insert(alias.clone());
820                    }
821                }
822                Self::collect_output_variables_recursive(&proj.input, vars);
823            }
824            LogicalOperator::Join(join) => {
825                Self::collect_output_variables_recursive(&join.left, vars);
826                Self::collect_output_variables_recursive(&join.right, vars);
827            }
828            LogicalOperator::Aggregate(agg) => {
829                for expr in &agg.group_by {
830                    Self::collect_variables(expr, vars);
831                }
832                for agg_expr in &agg.aggregates {
833                    if let Some(alias) = &agg_expr.alias {
834                        vars.insert(alias.clone());
835                    }
836                }
837            }
838            LogicalOperator::Return(ret) => {
839                Self::collect_output_variables_recursive(&ret.input, vars);
840            }
841            LogicalOperator::Limit(limit) => {
842                Self::collect_output_variables_recursive(&limit.input, vars);
843            }
844            LogicalOperator::Skip(skip) => {
845                Self::collect_output_variables_recursive(&skip.input, vars);
846            }
847            LogicalOperator::Sort(sort) => {
848                Self::collect_output_variables_recursive(&sort.input, vars);
849            }
850            LogicalOperator::Distinct(distinct) => {
851                Self::collect_output_variables_recursive(&distinct.input, vars);
852            }
853            _ => {}
854        }
855    }
856
857    /// Extracts all variable names referenced in an expression.
858    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
859        let mut vars = HashSet::new();
860        Self::collect_variables(expr, &mut vars);
861        vars
862    }
863
864    /// Recursively collects variable names from an expression.
865    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
866        match expr {
867            LogicalExpression::Variable(name) => {
868                vars.insert(name.clone());
869            }
870            LogicalExpression::Property { variable, .. } => {
871                vars.insert(variable.clone());
872            }
873            LogicalExpression::Binary { left, right, .. } => {
874                Self::collect_variables(left, vars);
875                Self::collect_variables(right, vars);
876            }
877            LogicalExpression::Unary { operand, .. } => {
878                Self::collect_variables(operand, vars);
879            }
880            LogicalExpression::FunctionCall { args, .. } => {
881                for arg in args {
882                    Self::collect_variables(arg, vars);
883                }
884            }
885            LogicalExpression::List(items) => {
886                for item in items {
887                    Self::collect_variables(item, vars);
888                }
889            }
890            LogicalExpression::Map(pairs) => {
891                for (_, value) in pairs {
892                    Self::collect_variables(value, vars);
893                }
894            }
895            LogicalExpression::IndexAccess { base, index } => {
896                Self::collect_variables(base, vars);
897                Self::collect_variables(index, vars);
898            }
899            LogicalExpression::SliceAccess { base, start, end } => {
900                Self::collect_variables(base, vars);
901                if let Some(s) = start {
902                    Self::collect_variables(s, vars);
903                }
904                if let Some(e) = end {
905                    Self::collect_variables(e, vars);
906                }
907            }
908            LogicalExpression::Case {
909                operand,
910                when_clauses,
911                else_clause,
912            } => {
913                if let Some(op) = operand {
914                    Self::collect_variables(op, vars);
915                }
916                for (cond, result) in when_clauses {
917                    Self::collect_variables(cond, vars);
918                    Self::collect_variables(result, vars);
919                }
920                if let Some(else_expr) = else_clause {
921                    Self::collect_variables(else_expr, vars);
922                }
923            }
924            LogicalExpression::Labels(var)
925            | LogicalExpression::Type(var)
926            | LogicalExpression::Id(var) => {
927                vars.insert(var.clone());
928            }
929            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
930            LogicalExpression::ListComprehension {
931                list_expr,
932                filter_expr,
933                map_expr,
934                ..
935            } => {
936                Self::collect_variables(list_expr, vars);
937                if let Some(filter) = filter_expr {
938                    Self::collect_variables(filter, vars);
939                }
940                Self::collect_variables(map_expr, vars);
941            }
942            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
943                // Subqueries have their own variable scope
944            }
945        }
946    }
947
948    /// Extracts aliases from projection expressions.
949    fn extract_projection_aliases(
950        &self,
951        projections: &[crate::query::plan::Projection],
952    ) -> HashSet<String> {
953        projections.iter().filter_map(|p| p.alias.clone()).collect()
954    }
955}
956
957impl Default for Optimizer {
958    fn default() -> Self {
959        Self::new()
960    }
961}
962
963#[cfg(test)]
964mod tests {
965    use super::*;
966    use crate::query::plan::{
967        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
968        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
969        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
970    };
971    use grafeo_common::types::Value;
972
973    #[test]
974    fn test_optimizer_filter_pushdown_simple() {
975        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
976        // Before: Return -> Filter -> NodeScan
977        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
978
979        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
980            items: vec![ReturnItem {
981                expression: LogicalExpression::Variable("n".to_string()),
982                alias: None,
983            }],
984            distinct: false,
985            input: Box::new(LogicalOperator::Filter(FilterOp {
986                predicate: LogicalExpression::Binary {
987                    left: Box::new(LogicalExpression::Property {
988                        variable: "n".to_string(),
989                        property: "age".to_string(),
990                    }),
991                    op: BinaryOp::Gt,
992                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
993                },
994                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
995                    variable: "n".to_string(),
996                    label: Some("Person".to_string()),
997                    input: None,
998                })),
999            })),
1000        }));
1001
1002        let optimizer = Optimizer::new();
1003        let optimized = optimizer.optimize(plan).unwrap();
1004
1005        // The structure should remain similar (filter stays near scan)
1006        if let LogicalOperator::Return(ret) = &optimized.root
1007            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1008            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1009        {
1010            assert_eq!(scan.variable, "n");
1011            return;
1012        }
1013        panic!("Expected Return -> Filter -> NodeScan structure");
1014    }
1015
1016    #[test]
1017    fn test_optimizer_filter_pushdown_through_expand() {
1018        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1019        // The filter on 'a' should be pushed before the expand
1020
1021        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1022            items: vec![ReturnItem {
1023                expression: LogicalExpression::Variable("b".to_string()),
1024                alias: None,
1025            }],
1026            distinct: false,
1027            input: Box::new(LogicalOperator::Filter(FilterOp {
1028                predicate: LogicalExpression::Binary {
1029                    left: Box::new(LogicalExpression::Property {
1030                        variable: "a".to_string(),
1031                        property: "age".to_string(),
1032                    }),
1033                    op: BinaryOp::Gt,
1034                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1035                },
1036                input: Box::new(LogicalOperator::Expand(ExpandOp {
1037                    from_variable: "a".to_string(),
1038                    to_variable: "b".to_string(),
1039                    edge_variable: None,
1040                    direction: ExpandDirection::Outgoing,
1041                    edge_type: Some("KNOWS".to_string()),
1042                    min_hops: 1,
1043                    max_hops: Some(1),
1044                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1045                        variable: "a".to_string(),
1046                        label: Some("Person".to_string()),
1047                        input: None,
1048                    })),
1049                    path_alias: None,
1050                })),
1051            })),
1052        }));
1053
1054        let optimizer = Optimizer::new();
1055        let optimized = optimizer.optimize(plan).unwrap();
1056
1057        // Filter on 'a' should be pushed before the expand
1058        // Expected: Return -> Expand -> Filter -> NodeScan
1059        if let LogicalOperator::Return(ret) = &optimized.root
1060            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1061            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1062            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1063        {
1064            assert_eq!(scan.variable, "a");
1065            assert_eq!(expand.from_variable, "a");
1066            assert_eq!(expand.to_variable, "b");
1067            return;
1068        }
1069        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1070    }
1071
1072    #[test]
1073    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1074        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1075        // The filter on 'b' should NOT be pushed before the expand
1076
1077        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1078            items: vec![ReturnItem {
1079                expression: LogicalExpression::Variable("a".to_string()),
1080                alias: None,
1081            }],
1082            distinct: false,
1083            input: Box::new(LogicalOperator::Filter(FilterOp {
1084                predicate: LogicalExpression::Binary {
1085                    left: Box::new(LogicalExpression::Property {
1086                        variable: "b".to_string(),
1087                        property: "age".to_string(),
1088                    }),
1089                    op: BinaryOp::Gt,
1090                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1091                },
1092                input: Box::new(LogicalOperator::Expand(ExpandOp {
1093                    from_variable: "a".to_string(),
1094                    to_variable: "b".to_string(),
1095                    edge_variable: None,
1096                    direction: ExpandDirection::Outgoing,
1097                    edge_type: Some("KNOWS".to_string()),
1098                    min_hops: 1,
1099                    max_hops: Some(1),
1100                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1101                        variable: "a".to_string(),
1102                        label: Some("Person".to_string()),
1103                        input: None,
1104                    })),
1105                    path_alias: None,
1106                })),
1107            })),
1108        }));
1109
1110        let optimizer = Optimizer::new();
1111        let optimized = optimizer.optimize(plan).unwrap();
1112
1113        // Filter on 'b' should stay after the expand
1114        // Expected: Return -> Filter -> Expand -> NodeScan
1115        if let LogicalOperator::Return(ret) = &optimized.root
1116            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1117        {
1118            // Check that the filter is on 'b'
1119            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1120                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1121            {
1122                assert_eq!(variable, "b");
1123            }
1124
1125            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1126                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1127            {
1128                return;
1129            }
1130        }
1131        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1132    }
1133
1134    #[test]
1135    fn test_optimizer_extract_variables() {
1136        let optimizer = Optimizer::new();
1137
1138        let expr = LogicalExpression::Binary {
1139            left: Box::new(LogicalExpression::Property {
1140                variable: "n".to_string(),
1141                property: "age".to_string(),
1142            }),
1143            op: BinaryOp::Gt,
1144            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1145        };
1146
1147        let vars = optimizer.extract_variables(&expr);
1148        assert_eq!(vars.len(), 1);
1149        assert!(vars.contains("n"));
1150    }
1151
1152    // Additional tests for optimizer configuration
1153
1154    #[test]
1155    fn test_optimizer_default() {
1156        let optimizer = Optimizer::default();
1157        // Should be able to optimize an empty plan
1158        let plan = LogicalPlan::new(LogicalOperator::Empty);
1159        let result = optimizer.optimize(plan);
1160        assert!(result.is_ok());
1161    }
1162
1163    #[test]
1164    fn test_optimizer_with_filter_pushdown_disabled() {
1165        let optimizer = Optimizer::new().with_filter_pushdown(false);
1166
1167        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1168            items: vec![ReturnItem {
1169                expression: LogicalExpression::Variable("n".to_string()),
1170                alias: None,
1171            }],
1172            distinct: false,
1173            input: Box::new(LogicalOperator::Filter(FilterOp {
1174                predicate: LogicalExpression::Literal(Value::Bool(true)),
1175                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1176                    variable: "n".to_string(),
1177                    label: None,
1178                    input: None,
1179                })),
1180            })),
1181        }));
1182
1183        let optimized = optimizer.optimize(plan).unwrap();
1184        // Structure should be unchanged
1185        if let LogicalOperator::Return(ret) = &optimized.root
1186            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1187        {
1188            return;
1189        }
1190        panic!("Expected unchanged structure");
1191    }
1192
1193    #[test]
1194    fn test_optimizer_with_join_reorder_disabled() {
1195        let optimizer = Optimizer::new().with_join_reorder(false);
1196        assert!(
1197            optimizer
1198                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1199                .is_ok()
1200        );
1201    }
1202
1203    #[test]
1204    fn test_optimizer_with_cost_model() {
1205        let cost_model = CostModel::new();
1206        let optimizer = Optimizer::new().with_cost_model(cost_model);
1207        assert!(
1208            optimizer
1209                .cost_model()
1210                .estimate(&LogicalOperator::Empty, 0.0)
1211                .total()
1212                < 0.001
1213        );
1214    }
1215
1216    #[test]
1217    fn test_optimizer_with_cardinality_estimator() {
1218        let mut estimator = CardinalityEstimator::new();
1219        estimator.add_table_stats("Test", TableStats::new(500));
1220        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1221
1222        let scan = LogicalOperator::NodeScan(NodeScanOp {
1223            variable: "n".to_string(),
1224            label: Some("Test".to_string()),
1225            input: None,
1226        });
1227        let plan = LogicalPlan::new(scan);
1228
1229        let cardinality = optimizer.estimate_cardinality(&plan);
1230        assert!((cardinality - 500.0).abs() < 0.001);
1231    }
1232
1233    #[test]
1234    fn test_optimizer_estimate_cost() {
1235        let optimizer = Optimizer::new();
1236        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1237            variable: "n".to_string(),
1238            label: None,
1239            input: None,
1240        }));
1241
1242        let cost = optimizer.estimate_cost(&plan);
1243        assert!(cost.total() > 0.0);
1244    }
1245
1246    // Filter pushdown through various operators
1247
1248    #[test]
1249    fn test_filter_pushdown_through_project() {
1250        let optimizer = Optimizer::new();
1251
1252        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1253            predicate: LogicalExpression::Binary {
1254                left: Box::new(LogicalExpression::Property {
1255                    variable: "n".to_string(),
1256                    property: "age".to_string(),
1257                }),
1258                op: BinaryOp::Gt,
1259                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1260            },
1261            input: Box::new(LogicalOperator::Project(ProjectOp {
1262                projections: vec![Projection {
1263                    expression: LogicalExpression::Variable("n".to_string()),
1264                    alias: None,
1265                }],
1266                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1267                    variable: "n".to_string(),
1268                    label: None,
1269                    input: None,
1270                })),
1271            })),
1272        }));
1273
1274        let optimized = optimizer.optimize(plan).unwrap();
1275
1276        // Filter should be pushed through Project
1277        if let LogicalOperator::Project(proj) = &optimized.root
1278            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1279        {
1280            return;
1281        }
1282        panic!("Expected Project -> Filter structure");
1283    }
1284
1285    #[test]
1286    fn test_filter_not_pushed_through_project_with_alias() {
1287        let optimizer = Optimizer::new();
1288
1289        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1290        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1291            predicate: LogicalExpression::Binary {
1292                left: Box::new(LogicalExpression::Variable("x".to_string())),
1293                op: BinaryOp::Gt,
1294                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1295            },
1296            input: Box::new(LogicalOperator::Project(ProjectOp {
1297                projections: vec![Projection {
1298                    expression: LogicalExpression::Property {
1299                        variable: "n".to_string(),
1300                        property: "age".to_string(),
1301                    },
1302                    alias: Some("x".to_string()),
1303                }],
1304                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1305                    variable: "n".to_string(),
1306                    label: None,
1307                    input: None,
1308                })),
1309            })),
1310        }));
1311
1312        let optimized = optimizer.optimize(plan).unwrap();
1313
1314        // Filter should stay above Project
1315        if let LogicalOperator::Filter(filter) = &optimized.root
1316            && let LogicalOperator::Project(_) = filter.input.as_ref()
1317        {
1318            return;
1319        }
1320        panic!("Expected Filter -> Project structure");
1321    }
1322
1323    #[test]
1324    fn test_filter_pushdown_through_limit() {
1325        let optimizer = Optimizer::new();
1326
1327        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1328            predicate: LogicalExpression::Literal(Value::Bool(true)),
1329            input: Box::new(LogicalOperator::Limit(LimitOp {
1330                count: 10,
1331                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1332                    variable: "n".to_string(),
1333                    label: None,
1334                    input: None,
1335                })),
1336            })),
1337        }));
1338
1339        let optimized = optimizer.optimize(plan).unwrap();
1340
1341        // Filter stays above Limit (cannot be pushed through)
1342        if let LogicalOperator::Filter(filter) = &optimized.root
1343            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1344        {
1345            return;
1346        }
1347        panic!("Expected Filter -> Limit structure");
1348    }
1349
1350    #[test]
1351    fn test_filter_pushdown_through_sort() {
1352        let optimizer = Optimizer::new();
1353
1354        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1355            predicate: LogicalExpression::Literal(Value::Bool(true)),
1356            input: Box::new(LogicalOperator::Sort(SortOp {
1357                keys: vec![SortKey {
1358                    expression: LogicalExpression::Variable("n".to_string()),
1359                    order: SortOrder::Ascending,
1360                }],
1361                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362                    variable: "n".to_string(),
1363                    label: None,
1364                    input: None,
1365                })),
1366            })),
1367        }));
1368
1369        let optimized = optimizer.optimize(plan).unwrap();
1370
1371        // Filter stays above Sort
1372        if let LogicalOperator::Filter(filter) = &optimized.root
1373            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1374        {
1375            return;
1376        }
1377        panic!("Expected Filter -> Sort structure");
1378    }
1379
1380    #[test]
1381    fn test_filter_pushdown_through_distinct() {
1382        let optimizer = Optimizer::new();
1383
1384        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1385            predicate: LogicalExpression::Literal(Value::Bool(true)),
1386            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1387                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1388                    variable: "n".to_string(),
1389                    label: None,
1390                    input: None,
1391                })),
1392                columns: None,
1393            })),
1394        }));
1395
1396        let optimized = optimizer.optimize(plan).unwrap();
1397
1398        // Filter stays above Distinct
1399        if let LogicalOperator::Filter(filter) = &optimized.root
1400            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1401        {
1402            return;
1403        }
1404        panic!("Expected Filter -> Distinct structure");
1405    }
1406
1407    #[test]
1408    fn test_filter_not_pushed_through_aggregate() {
1409        let optimizer = Optimizer::new();
1410
1411        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1412            predicate: LogicalExpression::Binary {
1413                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1414                op: BinaryOp::Gt,
1415                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1416            },
1417            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1418                group_by: vec![],
1419                aggregates: vec![AggregateExpr {
1420                    function: AggregateFunction::Count,
1421                    expression: None,
1422                    distinct: false,
1423                    alias: Some("cnt".to_string()),
1424                    percentile: None,
1425                }],
1426                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1427                    variable: "n".to_string(),
1428                    label: None,
1429                    input: None,
1430                })),
1431                having: None,
1432            })),
1433        }));
1434
1435        let optimized = optimizer.optimize(plan).unwrap();
1436
1437        // Filter should stay above Aggregate
1438        if let LogicalOperator::Filter(filter) = &optimized.root
1439            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1440        {
1441            return;
1442        }
1443        panic!("Expected Filter -> Aggregate structure");
1444    }
1445
1446    #[test]
1447    fn test_filter_pushdown_to_left_join_side() {
1448        let optimizer = Optimizer::new();
1449
1450        // Filter on left variable should be pushed to left side
1451        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1452            predicate: LogicalExpression::Binary {
1453                left: Box::new(LogicalExpression::Property {
1454                    variable: "a".to_string(),
1455                    property: "age".to_string(),
1456                }),
1457                op: BinaryOp::Gt,
1458                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1459            },
1460            input: Box::new(LogicalOperator::Join(JoinOp {
1461                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1462                    variable: "a".to_string(),
1463                    label: Some("Person".to_string()),
1464                    input: None,
1465                })),
1466                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1467                    variable: "b".to_string(),
1468                    label: Some("Company".to_string()),
1469                    input: None,
1470                })),
1471                join_type: JoinType::Inner,
1472                conditions: vec![],
1473            })),
1474        }));
1475
1476        let optimized = optimizer.optimize(plan).unwrap();
1477
1478        // Filter should be pushed to left side of join
1479        if let LogicalOperator::Join(join) = &optimized.root
1480            && let LogicalOperator::Filter(_) = join.left.as_ref()
1481        {
1482            return;
1483        }
1484        panic!("Expected Join with Filter on left side");
1485    }
1486
1487    #[test]
1488    fn test_filter_pushdown_to_right_join_side() {
1489        let optimizer = Optimizer::new();
1490
1491        // Filter on right variable should be pushed to right side
1492        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1493            predicate: LogicalExpression::Binary {
1494                left: Box::new(LogicalExpression::Property {
1495                    variable: "b".to_string(),
1496                    property: "name".to_string(),
1497                }),
1498                op: BinaryOp::Eq,
1499                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1500            },
1501            input: Box::new(LogicalOperator::Join(JoinOp {
1502                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1503                    variable: "a".to_string(),
1504                    label: Some("Person".to_string()),
1505                    input: None,
1506                })),
1507                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1508                    variable: "b".to_string(),
1509                    label: Some("Company".to_string()),
1510                    input: None,
1511                })),
1512                join_type: JoinType::Inner,
1513                conditions: vec![],
1514            })),
1515        }));
1516
1517        let optimized = optimizer.optimize(plan).unwrap();
1518
1519        // Filter should be pushed to right side of join
1520        if let LogicalOperator::Join(join) = &optimized.root
1521            && let LogicalOperator::Filter(_) = join.right.as_ref()
1522        {
1523            return;
1524        }
1525        panic!("Expected Join with Filter on right side");
1526    }
1527
1528    #[test]
1529    fn test_filter_not_pushed_when_uses_both_join_sides() {
1530        let optimizer = Optimizer::new();
1531
1532        // Filter using both variables should stay above join
1533        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1534            predicate: LogicalExpression::Binary {
1535                left: Box::new(LogicalExpression::Property {
1536                    variable: "a".to_string(),
1537                    property: "id".to_string(),
1538                }),
1539                op: BinaryOp::Eq,
1540                right: Box::new(LogicalExpression::Property {
1541                    variable: "b".to_string(),
1542                    property: "a_id".to_string(),
1543                }),
1544            },
1545            input: Box::new(LogicalOperator::Join(JoinOp {
1546                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1547                    variable: "a".to_string(),
1548                    label: None,
1549                    input: None,
1550                })),
1551                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1552                    variable: "b".to_string(),
1553                    label: None,
1554                    input: None,
1555                })),
1556                join_type: JoinType::Inner,
1557                conditions: vec![],
1558            })),
1559        }));
1560
1561        let optimized = optimizer.optimize(plan).unwrap();
1562
1563        // Filter should stay above join
1564        if let LogicalOperator::Filter(filter) = &optimized.root
1565            && let LogicalOperator::Join(_) = filter.input.as_ref()
1566        {
1567            return;
1568        }
1569        panic!("Expected Filter -> Join structure");
1570    }
1571
1572    // Variable extraction tests
1573
1574    #[test]
1575    fn test_extract_variables_from_variable() {
1576        let optimizer = Optimizer::new();
1577        let expr = LogicalExpression::Variable("x".to_string());
1578        let vars = optimizer.extract_variables(&expr);
1579        assert_eq!(vars.len(), 1);
1580        assert!(vars.contains("x"));
1581    }
1582
1583    #[test]
1584    fn test_extract_variables_from_unary() {
1585        let optimizer = Optimizer::new();
1586        let expr = LogicalExpression::Unary {
1587            op: UnaryOp::Not,
1588            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1589        };
1590        let vars = optimizer.extract_variables(&expr);
1591        assert_eq!(vars.len(), 1);
1592        assert!(vars.contains("x"));
1593    }
1594
1595    #[test]
1596    fn test_extract_variables_from_function_call() {
1597        let optimizer = Optimizer::new();
1598        let expr = LogicalExpression::FunctionCall {
1599            name: "length".to_string(),
1600            args: vec![
1601                LogicalExpression::Variable("a".to_string()),
1602                LogicalExpression::Variable("b".to_string()),
1603            ],
1604            distinct: false,
1605        };
1606        let vars = optimizer.extract_variables(&expr);
1607        assert_eq!(vars.len(), 2);
1608        assert!(vars.contains("a"));
1609        assert!(vars.contains("b"));
1610    }
1611
1612    #[test]
1613    fn test_extract_variables_from_list() {
1614        let optimizer = Optimizer::new();
1615        let expr = LogicalExpression::List(vec![
1616            LogicalExpression::Variable("a".to_string()),
1617            LogicalExpression::Literal(Value::Int64(1)),
1618            LogicalExpression::Variable("b".to_string()),
1619        ]);
1620        let vars = optimizer.extract_variables(&expr);
1621        assert_eq!(vars.len(), 2);
1622        assert!(vars.contains("a"));
1623        assert!(vars.contains("b"));
1624    }
1625
1626    #[test]
1627    fn test_extract_variables_from_map() {
1628        let optimizer = Optimizer::new();
1629        let expr = LogicalExpression::Map(vec![
1630            (
1631                "key1".to_string(),
1632                LogicalExpression::Variable("a".to_string()),
1633            ),
1634            (
1635                "key2".to_string(),
1636                LogicalExpression::Variable("b".to_string()),
1637            ),
1638        ]);
1639        let vars = optimizer.extract_variables(&expr);
1640        assert_eq!(vars.len(), 2);
1641        assert!(vars.contains("a"));
1642        assert!(vars.contains("b"));
1643    }
1644
1645    #[test]
1646    fn test_extract_variables_from_index_access() {
1647        let optimizer = Optimizer::new();
1648        let expr = LogicalExpression::IndexAccess {
1649            base: Box::new(LogicalExpression::Variable("list".to_string())),
1650            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1651        };
1652        let vars = optimizer.extract_variables(&expr);
1653        assert_eq!(vars.len(), 2);
1654        assert!(vars.contains("list"));
1655        assert!(vars.contains("idx"));
1656    }
1657
1658    #[test]
1659    fn test_extract_variables_from_slice_access() {
1660        let optimizer = Optimizer::new();
1661        let expr = LogicalExpression::SliceAccess {
1662            base: Box::new(LogicalExpression::Variable("list".to_string())),
1663            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1664            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1665        };
1666        let vars = optimizer.extract_variables(&expr);
1667        assert_eq!(vars.len(), 3);
1668        assert!(vars.contains("list"));
1669        assert!(vars.contains("s"));
1670        assert!(vars.contains("e"));
1671    }
1672
1673    #[test]
1674    fn test_extract_variables_from_case() {
1675        let optimizer = Optimizer::new();
1676        let expr = LogicalExpression::Case {
1677            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1678            when_clauses: vec![(
1679                LogicalExpression::Literal(Value::Int64(1)),
1680                LogicalExpression::Variable("a".to_string()),
1681            )],
1682            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1683        };
1684        let vars = optimizer.extract_variables(&expr);
1685        assert_eq!(vars.len(), 3);
1686        assert!(vars.contains("x"));
1687        assert!(vars.contains("a"));
1688        assert!(vars.contains("b"));
1689    }
1690
1691    #[test]
1692    fn test_extract_variables_from_labels() {
1693        let optimizer = Optimizer::new();
1694        let expr = LogicalExpression::Labels("n".to_string());
1695        let vars = optimizer.extract_variables(&expr);
1696        assert_eq!(vars.len(), 1);
1697        assert!(vars.contains("n"));
1698    }
1699
1700    #[test]
1701    fn test_extract_variables_from_type() {
1702        let optimizer = Optimizer::new();
1703        let expr = LogicalExpression::Type("e".to_string());
1704        let vars = optimizer.extract_variables(&expr);
1705        assert_eq!(vars.len(), 1);
1706        assert!(vars.contains("e"));
1707    }
1708
1709    #[test]
1710    fn test_extract_variables_from_id() {
1711        let optimizer = Optimizer::new();
1712        let expr = LogicalExpression::Id("n".to_string());
1713        let vars = optimizer.extract_variables(&expr);
1714        assert_eq!(vars.len(), 1);
1715        assert!(vars.contains("n"));
1716    }
1717
1718    #[test]
1719    fn test_extract_variables_from_list_comprehension() {
1720        let optimizer = Optimizer::new();
1721        let expr = LogicalExpression::ListComprehension {
1722            variable: "x".to_string(),
1723            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1724            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1725            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1726        };
1727        let vars = optimizer.extract_variables(&expr);
1728        assert!(vars.contains("items"));
1729        assert!(vars.contains("pred"));
1730        assert!(vars.contains("result"));
1731    }
1732
1733    #[test]
1734    fn test_extract_variables_from_literal_and_parameter() {
1735        let optimizer = Optimizer::new();
1736
1737        let literal = LogicalExpression::Literal(Value::Int64(42));
1738        assert!(optimizer.extract_variables(&literal).is_empty());
1739
1740        let param = LogicalExpression::Parameter("p".to_string());
1741        assert!(optimizer.extract_variables(&param).is_empty());
1742    }
1743
1744    // Recursive filter pushdown tests
1745
1746    #[test]
1747    fn test_recursive_filter_pushdown_through_skip() {
1748        let optimizer = Optimizer::new();
1749
1750        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1751            items: vec![ReturnItem {
1752                expression: LogicalExpression::Variable("n".to_string()),
1753                alias: None,
1754            }],
1755            distinct: false,
1756            input: Box::new(LogicalOperator::Filter(FilterOp {
1757                predicate: LogicalExpression::Literal(Value::Bool(true)),
1758                input: Box::new(LogicalOperator::Skip(SkipOp {
1759                    count: 5,
1760                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1761                        variable: "n".to_string(),
1762                        label: None,
1763                        input: None,
1764                    })),
1765                })),
1766            })),
1767        }));
1768
1769        let optimized = optimizer.optimize(plan).unwrap();
1770
1771        // Verify optimization succeeded
1772        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1773    }
1774
1775    #[test]
1776    fn test_nested_filter_pushdown() {
1777        let optimizer = Optimizer::new();
1778
1779        // Multiple nested filters
1780        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1781            items: vec![ReturnItem {
1782                expression: LogicalExpression::Variable("n".to_string()),
1783                alias: None,
1784            }],
1785            distinct: false,
1786            input: Box::new(LogicalOperator::Filter(FilterOp {
1787                predicate: LogicalExpression::Binary {
1788                    left: Box::new(LogicalExpression::Property {
1789                        variable: "n".to_string(),
1790                        property: "x".to_string(),
1791                    }),
1792                    op: BinaryOp::Gt,
1793                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1794                },
1795                input: Box::new(LogicalOperator::Filter(FilterOp {
1796                    predicate: LogicalExpression::Binary {
1797                        left: Box::new(LogicalExpression::Property {
1798                            variable: "n".to_string(),
1799                            property: "y".to_string(),
1800                        }),
1801                        op: BinaryOp::Lt,
1802                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1803                    },
1804                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1805                        variable: "n".to_string(),
1806                        label: None,
1807                        input: None,
1808                    })),
1809                })),
1810            })),
1811        }));
1812
1813        let optimized = optimizer.optimize(plan).unwrap();
1814        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1815    }
1816}