Skip to main content

grafeo_engine/query/optimizer/
mod.rs

//! Makes your queries faster without changing their meaning.
//!
//! The optimizer transforms logical plans to run more efficiently:
//!
//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans: filter early, process less |
//! | Join Reordering | Picks the cheapest order in which to join relations, using the DPccp algorithm |
//! | Projection Pushdown | Works out which variables and properties each part of the plan needs |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
//!
//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
/// Information about a join condition for join reordering.
///
/// Built from each condition of a `Join` operator whose two sides can each be
/// attributed to a single variable, and later passed to
/// `JoinGraphBuilder::add_join_condition` when the join graph is built.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression of the condition.
    left_var: String,
    /// Variable referenced by the right-hand expression of the condition.
    right_var: String,
    /// The original left-hand expression (a variable or a property access).
    left_expr: LogicalExpression,
    /// The original right-hand expression (a variable or a property access).
    right_expr: LogicalExpression,
}
36
/// A column required by the query, used for projection pushdown.
///
/// Collected by `collect_required_columns` from operator expressions and from
/// the variables bound by scans and expands.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding), identified by its name.
    Variable(String),
    /// A specific property of a variable, stored as `(variable, property)`.
    Property(String, String),
}
45
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new) (default estimates) or with
/// [`from_store()`](Self::from_store) / [`from_graph_store()`](Self::from_graph_store)
/// (estimates seeded from store statistics), then call
/// [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// When enabled, the passes run in this order: filter pushdown, join
/// reordering, projection pushdown.
pub struct Optimizer {
    /// Whether `optimize` runs the filter pushdown pass.
    enable_filter_pushdown: bool,
    /// Whether `optimize` runs the join reordering pass.
    enable_join_reorder: bool,
    /// Whether `optimize` runs the projection pushdown pass.
    enable_projection_pushdown: bool,
    /// Cost model used by `estimate_cost` and by join reordering.
    cost_model: CostModel,
    /// Cardinality estimator used by `estimate_cardinality`, `estimate_cost`
    /// and join reordering.
    card_estimator: CardinalityEstimator,
}
62
63impl Optimizer {
64    /// Creates a new optimizer with default settings.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            enable_filter_pushdown: true,
69            enable_join_reorder: true,
70            enable_projection_pushdown: true,
71            cost_model: CostModel::new(),
72            card_estimator: CardinalityEstimator::new(),
73        }
74    }
75
76    /// Creates an optimizer with cardinality estimates from the store's statistics.
77    ///
78    /// Pre-populates the cardinality estimator with per-label row counts and
79    /// edge type fanout. Feeds per-edge-type degree stats into the cost model
80    /// for accurate expand cost estimation.
81    #[must_use]
82    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
83        store.ensure_statistics_fresh();
84        let stats = store.statistics();
85        let estimator = CardinalityEstimator::from_statistics(&stats);
86
87        // Derive average fanout from statistics for the cost model
88        let avg_fanout = if stats.total_nodes > 0 {
89            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
90        } else {
91            10.0
92        };
93
94        // Collect per-edge-type degree stats for accurate expand costing
95        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
96            .edge_types
97            .iter()
98            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
99            .collect();
100
101        Self {
102            enable_filter_pushdown: true,
103            enable_join_reorder: true,
104            enable_projection_pushdown: true,
105            cost_model: CostModel::new()
106                .with_avg_fanout(avg_fanout)
107                .with_edge_type_degrees(edge_type_degrees),
108            card_estimator: estimator,
109        }
110    }
111
112    /// Creates an optimizer from any GraphStore implementation.
113    ///
114    /// Unlike [`from_store`](Self::from_store), this does not call
115    /// `ensure_statistics_fresh()` since external stores manage their own
116    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
117    /// is called directly.
118    #[must_use]
119    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
120        let stats = store.statistics();
121        let estimator = CardinalityEstimator::from_statistics(&stats);
122
123        let avg_fanout = if stats.total_nodes > 0 {
124            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
125        } else {
126            10.0
127        };
128
129        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
130            .edge_types
131            .iter()
132            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
133            .collect();
134
135        Self {
136            enable_filter_pushdown: true,
137            enable_join_reorder: true,
138            enable_projection_pushdown: true,
139            cost_model: CostModel::new()
140                .with_avg_fanout(avg_fanout)
141                .with_edge_type_degrees(edge_type_degrees),
142            card_estimator: estimator,
143        }
144    }
145
146    /// Enables or disables filter pushdown.
147    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
148        self.enable_filter_pushdown = enabled;
149        self
150    }
151
152    /// Enables or disables join reordering.
153    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
154        self.enable_join_reorder = enabled;
155        self
156    }
157
158    /// Enables or disables projection pushdown.
159    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
160        self.enable_projection_pushdown = enabled;
161        self
162    }
163
164    /// Sets the cost model.
165    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
166        self.cost_model = cost_model;
167        self
168    }
169
170    /// Sets the cardinality estimator.
171    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
172        self.card_estimator = estimator;
173        self
174    }
175
176    /// Sets the selectivity configuration for the cardinality estimator.
177    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
178        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
179        self
180    }
181
182    /// Returns a reference to the cost model.
183    pub fn cost_model(&self) -> &CostModel {
184        &self.cost_model
185    }
186
187    /// Returns a reference to the cardinality estimator.
188    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
189        &self.card_estimator
190    }
191
192    /// Estimates the cost of a plan.
193    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
194        let cardinality = self.card_estimator.estimate(&plan.root);
195        self.cost_model.estimate(&plan.root, cardinality)
196    }
197
198    /// Estimates the cardinality of a plan.
199    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
200        self.card_estimator.estimate(&plan.root)
201    }
202
203    /// Optimizes a logical plan.
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if optimization fails.
208    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
209        let mut root = plan.root;
210
211        // Apply optimization rules
212        if self.enable_filter_pushdown {
213            root = self.push_filters_down(root);
214        }
215
216        if self.enable_join_reorder {
217            root = self.reorder_joins(root);
218        }
219
220        if self.enable_projection_pushdown {
221            root = self.push_projections_down(root);
222        }
223
224        Ok(LogicalPlan::new(root))
225    }
226
227    /// Pushes projections down the operator tree to eliminate unused columns early.
228    ///
229    /// This optimization:
230    /// 1. Collects required variables/properties from the root
231    /// 2. Propagates requirements down through the tree
232    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
233    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
234        // Collect required columns from the top of the plan
235        let required = self.collect_required_columns(&op);
236
237        // Push projections down
238        self.push_projections_recursive(op, &required)
239    }
240
241    /// Collects all variables and properties required by an operator and its ancestors.
242    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
243        let mut required = HashSet::new();
244        Self::collect_required_recursive(op, &mut required);
245        required
246    }
247
248    /// Recursively collects required columns.
249    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
250        match op {
251            LogicalOperator::Return(ret) => {
252                for item in &ret.items {
253                    Self::collect_from_expression(&item.expression, required);
254                }
255                Self::collect_required_recursive(&ret.input, required);
256            }
257            LogicalOperator::Project(proj) => {
258                for p in &proj.projections {
259                    Self::collect_from_expression(&p.expression, required);
260                }
261                Self::collect_required_recursive(&proj.input, required);
262            }
263            LogicalOperator::Filter(filter) => {
264                Self::collect_from_expression(&filter.predicate, required);
265                Self::collect_required_recursive(&filter.input, required);
266            }
267            LogicalOperator::Sort(sort) => {
268                for key in &sort.keys {
269                    Self::collect_from_expression(&key.expression, required);
270                }
271                Self::collect_required_recursive(&sort.input, required);
272            }
273            LogicalOperator::Aggregate(agg) => {
274                for expr in &agg.group_by {
275                    Self::collect_from_expression(expr, required);
276                }
277                for agg_expr in &agg.aggregates {
278                    if let Some(ref expr) = agg_expr.expression {
279                        Self::collect_from_expression(expr, required);
280                    }
281                }
282                if let Some(ref having) = agg.having {
283                    Self::collect_from_expression(having, required);
284                }
285                Self::collect_required_recursive(&agg.input, required);
286            }
287            LogicalOperator::Join(join) => {
288                for cond in &join.conditions {
289                    Self::collect_from_expression(&cond.left, required);
290                    Self::collect_from_expression(&cond.right, required);
291                }
292                Self::collect_required_recursive(&join.left, required);
293                Self::collect_required_recursive(&join.right, required);
294            }
295            LogicalOperator::Expand(expand) => {
296                // The source and target variables are needed
297                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
298                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
299                if let Some(ref edge_var) = expand.edge_variable {
300                    required.insert(RequiredColumn::Variable(edge_var.clone()));
301                }
302                Self::collect_required_recursive(&expand.input, required);
303            }
304            LogicalOperator::Limit(limit) => {
305                Self::collect_required_recursive(&limit.input, required);
306            }
307            LogicalOperator::Skip(skip) => {
308                Self::collect_required_recursive(&skip.input, required);
309            }
310            LogicalOperator::Distinct(distinct) => {
311                Self::collect_required_recursive(&distinct.input, required);
312            }
313            LogicalOperator::NodeScan(scan) => {
314                required.insert(RequiredColumn::Variable(scan.variable.clone()));
315            }
316            LogicalOperator::EdgeScan(scan) => {
317                required.insert(RequiredColumn::Variable(scan.variable.clone()));
318            }
319            _ => {}
320        }
321    }
322
323    /// Collects required columns from an expression.
324    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
325        match expr {
326            LogicalExpression::Variable(var) => {
327                required.insert(RequiredColumn::Variable(var.clone()));
328            }
329            LogicalExpression::Property { variable, property } => {
330                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
331                required.insert(RequiredColumn::Variable(variable.clone()));
332            }
333            LogicalExpression::Binary { left, right, .. } => {
334                Self::collect_from_expression(left, required);
335                Self::collect_from_expression(right, required);
336            }
337            LogicalExpression::Unary { operand, .. } => {
338                Self::collect_from_expression(operand, required);
339            }
340            LogicalExpression::FunctionCall { args, .. } => {
341                for arg in args {
342                    Self::collect_from_expression(arg, required);
343                }
344            }
345            LogicalExpression::List(items) => {
346                for item in items {
347                    Self::collect_from_expression(item, required);
348                }
349            }
350            LogicalExpression::Map(pairs) => {
351                for (_, value) in pairs {
352                    Self::collect_from_expression(value, required);
353                }
354            }
355            LogicalExpression::IndexAccess { base, index } => {
356                Self::collect_from_expression(base, required);
357                Self::collect_from_expression(index, required);
358            }
359            LogicalExpression::SliceAccess { base, start, end } => {
360                Self::collect_from_expression(base, required);
361                if let Some(s) = start {
362                    Self::collect_from_expression(s, required);
363                }
364                if let Some(e) = end {
365                    Self::collect_from_expression(e, required);
366                }
367            }
368            LogicalExpression::Case {
369                operand,
370                when_clauses,
371                else_clause,
372            } => {
373                if let Some(op) = operand {
374                    Self::collect_from_expression(op, required);
375                }
376                for (cond, result) in when_clauses {
377                    Self::collect_from_expression(cond, required);
378                    Self::collect_from_expression(result, required);
379                }
380                if let Some(else_expr) = else_clause {
381                    Self::collect_from_expression(else_expr, required);
382                }
383            }
384            LogicalExpression::Labels(var)
385            | LogicalExpression::Type(var)
386            | LogicalExpression::Id(var) => {
387                required.insert(RequiredColumn::Variable(var.clone()));
388            }
389            LogicalExpression::ListComprehension {
390                list_expr,
391                filter_expr,
392                map_expr,
393                ..
394            } => {
395                Self::collect_from_expression(list_expr, required);
396                if let Some(filter) = filter_expr {
397                    Self::collect_from_expression(filter, required);
398                }
399                Self::collect_from_expression(map_expr, required);
400            }
401            _ => {}
402        }
403    }
404
405    /// Recursively pushes projections down, adding them before expensive operations.
406    fn push_projections_recursive(
407        &self,
408        op: LogicalOperator,
409        required: &HashSet<RequiredColumn>,
410    ) -> LogicalOperator {
411        match op {
412            LogicalOperator::Return(mut ret) => {
413                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
414                LogicalOperator::Return(ret)
415            }
416            LogicalOperator::Project(mut proj) => {
417                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
418                LogicalOperator::Project(proj)
419            }
420            LogicalOperator::Filter(mut filter) => {
421                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
422                LogicalOperator::Filter(filter)
423            }
424            LogicalOperator::Sort(mut sort) => {
425                // Sort is expensive - consider adding a projection before it
426                // to reduce tuple width
427                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
428                LogicalOperator::Sort(sort)
429            }
430            LogicalOperator::Aggregate(mut agg) => {
431                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
432                LogicalOperator::Aggregate(agg)
433            }
434            LogicalOperator::Join(mut join) => {
435                // Joins are expensive - the required columns help determine
436                // what to project on each side
437                let left_vars = self.collect_output_variables(&join.left);
438                let right_vars = self.collect_output_variables(&join.right);
439
440                // Filter required columns to each side
441                let left_required: HashSet<_> = required
442                    .iter()
443                    .filter(|c| match c {
444                        RequiredColumn::Variable(v) => left_vars.contains(v),
445                        RequiredColumn::Property(v, _) => left_vars.contains(v),
446                    })
447                    .cloned()
448                    .collect();
449
450                let right_required: HashSet<_> = required
451                    .iter()
452                    .filter(|c| match c {
453                        RequiredColumn::Variable(v) => right_vars.contains(v),
454                        RequiredColumn::Property(v, _) => right_vars.contains(v),
455                    })
456                    .cloned()
457                    .collect();
458
459                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
460                join.right =
461                    Box::new(self.push_projections_recursive(*join.right, &right_required));
462                LogicalOperator::Join(join)
463            }
464            LogicalOperator::Expand(mut expand) => {
465                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
466                LogicalOperator::Expand(expand)
467            }
468            LogicalOperator::Limit(mut limit) => {
469                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
470                LogicalOperator::Limit(limit)
471            }
472            LogicalOperator::Skip(mut skip) => {
473                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
474                LogicalOperator::Skip(skip)
475            }
476            LogicalOperator::Distinct(mut distinct) => {
477                distinct.input =
478                    Box::new(self.push_projections_recursive(*distinct.input, required));
479                LogicalOperator::Distinct(distinct)
480            }
481            LogicalOperator::MapCollect(mut mc) => {
482                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
483                LogicalOperator::MapCollect(mc)
484            }
485            other => other,
486        }
487    }
488
489    /// Reorders joins in the operator tree using the DPccp algorithm.
490    ///
491    /// This optimization finds the optimal join order by:
492    /// 1. Extracting all base relations (scans) and join conditions
493    /// 2. Building a join graph
494    /// 3. Using dynamic programming to find the cheapest join order
495    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
496        // First, recursively optimize children
497        let op = self.reorder_joins_recursive(op);
498
499        // Then, if this is a join tree, try to optimize it
500        if let Some((relations, conditions)) = self.extract_join_tree(&op)
501            && relations.len() >= 2
502            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
503        {
504            return optimized;
505        }
506
507        op
508    }
509
510    /// Recursively applies join reordering to child operators.
511    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
512        match op {
513            LogicalOperator::Return(mut ret) => {
514                ret.input = Box::new(self.reorder_joins(*ret.input));
515                LogicalOperator::Return(ret)
516            }
517            LogicalOperator::Project(mut proj) => {
518                proj.input = Box::new(self.reorder_joins(*proj.input));
519                LogicalOperator::Project(proj)
520            }
521            LogicalOperator::Filter(mut filter) => {
522                filter.input = Box::new(self.reorder_joins(*filter.input));
523                LogicalOperator::Filter(filter)
524            }
525            LogicalOperator::Limit(mut limit) => {
526                limit.input = Box::new(self.reorder_joins(*limit.input));
527                LogicalOperator::Limit(limit)
528            }
529            LogicalOperator::Skip(mut skip) => {
530                skip.input = Box::new(self.reorder_joins(*skip.input));
531                LogicalOperator::Skip(skip)
532            }
533            LogicalOperator::Sort(mut sort) => {
534                sort.input = Box::new(self.reorder_joins(*sort.input));
535                LogicalOperator::Sort(sort)
536            }
537            LogicalOperator::Distinct(mut distinct) => {
538                distinct.input = Box::new(self.reorder_joins(*distinct.input));
539                LogicalOperator::Distinct(distinct)
540            }
541            LogicalOperator::Aggregate(mut agg) => {
542                agg.input = Box::new(self.reorder_joins(*agg.input));
543                LogicalOperator::Aggregate(agg)
544            }
545            LogicalOperator::Expand(mut expand) => {
546                expand.input = Box::new(self.reorder_joins(*expand.input));
547                LogicalOperator::Expand(expand)
548            }
549            LogicalOperator::MapCollect(mut mc) => {
550                mc.input = Box::new(self.reorder_joins(*mc.input));
551                LogicalOperator::MapCollect(mc)
552            }
553            // Join operators are handled by the parent reorder_joins call
554            other => other,
555        }
556    }
557
558    /// Extracts base relations and join conditions from a join tree.
559    ///
560    /// Returns None if the operator is not a join tree.
561    fn extract_join_tree(
562        &self,
563        op: &LogicalOperator,
564    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
565        let mut relations = Vec::new();
566        let mut join_conditions = Vec::new();
567
568        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
569            return None;
570        }
571
572        if relations.len() < 2 {
573            return None;
574        }
575
576        Some((relations, join_conditions))
577    }
578
579    /// Recursively collects base relations and join conditions.
580    ///
581    /// Returns true if this subtree is part of a join tree.
582    fn collect_join_tree(
583        &self,
584        op: &LogicalOperator,
585        relations: &mut Vec<(String, LogicalOperator)>,
586        conditions: &mut Vec<JoinInfo>,
587    ) -> bool {
588        match op {
589            LogicalOperator::Join(join) => {
590                // Collect from both sides
591                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
592                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
593
594                // Add conditions from this join
595                for cond in &join.conditions {
596                    if let (Some(left_var), Some(right_var)) = (
597                        self.extract_variable_from_expr(&cond.left),
598                        self.extract_variable_from_expr(&cond.right),
599                    ) {
600                        conditions.push(JoinInfo {
601                            left_var,
602                            right_var,
603                            left_expr: cond.left.clone(),
604                            right_expr: cond.right.clone(),
605                        });
606                    }
607                }
608
609                left_ok && right_ok
610            }
611            LogicalOperator::NodeScan(scan) => {
612                relations.push((scan.variable.clone(), op.clone()));
613                true
614            }
615            LogicalOperator::EdgeScan(scan) => {
616                relations.push((scan.variable.clone(), op.clone()));
617                true
618            }
619            LogicalOperator::Filter(filter) => {
620                // A filter on a base relation is still part of the join tree
621                self.collect_join_tree(&filter.input, relations, conditions)
622            }
623            LogicalOperator::Expand(expand) => {
624                // Expand is a special case - it's like a join with the adjacency
625                // For now, treat the whole Expand subtree as a single relation
626                relations.push((expand.to_variable.clone(), op.clone()));
627                true
628            }
629            _ => false,
630        }
631    }
632
633    /// Extracts the primary variable from an expression.
634    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
635        match expr {
636            LogicalExpression::Variable(v) => Some(v.clone()),
637            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
638            _ => None,
639        }
640    }
641
642    /// Optimizes the join order using DPccp.
643    fn optimize_join_order(
644        &self,
645        relations: &[(String, LogicalOperator)],
646        conditions: &[JoinInfo],
647    ) -> Option<LogicalOperator> {
648        use join_order::{DPccp, JoinGraphBuilder};
649
650        // Build the join graph
651        let mut builder = JoinGraphBuilder::new();
652
653        for (var, relation) in relations {
654            builder.add_relation(var, relation.clone());
655        }
656
657        for cond in conditions {
658            builder.add_join_condition(
659                &cond.left_var,
660                &cond.right_var,
661                cond.left_expr.clone(),
662                cond.right_expr.clone(),
663            );
664        }
665
666        let graph = builder.build();
667
668        // Run DPccp
669        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
670        let plan = dpccp.optimize()?;
671
672        Some(plan.operator)
673    }
674
675    /// Pushes filters down the operator tree.
676    ///
677    /// This optimization moves filter predicates as close to the data source
678    /// as possible to reduce the amount of data processed by upper operators.
679    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
680        match op {
681            // For Filter operators, try to push the predicate into the child
682            LogicalOperator::Filter(filter) => {
683                let optimized_input = self.push_filters_down(*filter.input);
684                self.try_push_filter_into(filter.predicate, optimized_input)
685            }
686            // Recursively optimize children for other operators
687            LogicalOperator::Return(mut ret) => {
688                ret.input = Box::new(self.push_filters_down(*ret.input));
689                LogicalOperator::Return(ret)
690            }
691            LogicalOperator::Project(mut proj) => {
692                proj.input = Box::new(self.push_filters_down(*proj.input));
693                LogicalOperator::Project(proj)
694            }
695            LogicalOperator::Limit(mut limit) => {
696                limit.input = Box::new(self.push_filters_down(*limit.input));
697                LogicalOperator::Limit(limit)
698            }
699            LogicalOperator::Skip(mut skip) => {
700                skip.input = Box::new(self.push_filters_down(*skip.input));
701                LogicalOperator::Skip(skip)
702            }
703            LogicalOperator::Sort(mut sort) => {
704                sort.input = Box::new(self.push_filters_down(*sort.input));
705                LogicalOperator::Sort(sort)
706            }
707            LogicalOperator::Distinct(mut distinct) => {
708                distinct.input = Box::new(self.push_filters_down(*distinct.input));
709                LogicalOperator::Distinct(distinct)
710            }
711            LogicalOperator::Expand(mut expand) => {
712                expand.input = Box::new(self.push_filters_down(*expand.input));
713                LogicalOperator::Expand(expand)
714            }
715            LogicalOperator::Join(mut join) => {
716                join.left = Box::new(self.push_filters_down(*join.left));
717                join.right = Box::new(self.push_filters_down(*join.right));
718                LogicalOperator::Join(join)
719            }
720            LogicalOperator::Aggregate(mut agg) => {
721                agg.input = Box::new(self.push_filters_down(*agg.input));
722                LogicalOperator::Aggregate(agg)
723            }
724            LogicalOperator::MapCollect(mut mc) => {
725                mc.input = Box::new(self.push_filters_down(*mc.input));
726                LogicalOperator::MapCollect(mc)
727            }
728            // Leaf operators and unsupported operators are returned as-is
729            other => other,
730        }
731    }
732
733    /// Tries to push a filter predicate into the given operator.
734    ///
735    /// Returns either the predicate pushed into the operator, or a new
736    /// Filter operator on top if the predicate cannot be pushed further.
737    fn try_push_filter_into(
738        &self,
739        predicate: LogicalExpression,
740        op: LogicalOperator,
741    ) -> LogicalOperator {
742        match op {
743            // Can push through Project if predicate doesn't depend on computed columns
744            LogicalOperator::Project(mut proj) => {
745                let predicate_vars = self.extract_variables(&predicate);
746                let computed_vars = self.extract_projection_aliases(&proj.projections);
747
748                // If predicate doesn't use any computed columns, push through
749                if predicate_vars.is_disjoint(&computed_vars) {
750                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
751                    LogicalOperator::Project(proj)
752                } else {
753                    // Can't push through, keep filter on top
754                    LogicalOperator::Filter(FilterOp {
755                        predicate,
756                        input: Box::new(LogicalOperator::Project(proj)),
757                    })
758                }
759            }
760
761            // Can push through Return (which is like a projection)
762            LogicalOperator::Return(mut ret) => {
763                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
764                LogicalOperator::Return(ret)
765            }
766
767            // Can push through Expand if predicate doesn't use variables introduced by this expand
768            LogicalOperator::Expand(mut expand) => {
769                let predicate_vars = self.extract_variables(&predicate);
770
771                // Variables introduced by this expand are:
772                // - The target variable (to_variable)
773                // - The edge variable (if any)
774                // - The path alias (if any)
775                let mut introduced_vars = vec![&expand.to_variable];
776                if let Some(ref edge_var) = expand.edge_variable {
777                    introduced_vars.push(edge_var);
778                }
779                if let Some(ref path_alias) = expand.path_alias {
780                    introduced_vars.push(path_alias);
781                }
782
783                // Check if predicate uses any variables introduced by this expand
784                let uses_introduced_vars =
785                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
786
787                if !uses_introduced_vars {
788                    // Predicate doesn't use vars from this expand, so push through
789                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
790                    LogicalOperator::Expand(expand)
791                } else {
792                    // Keep filter after expand
793                    LogicalOperator::Filter(FilterOp {
794                        predicate,
795                        input: Box::new(LogicalOperator::Expand(expand)),
796                    })
797                }
798            }
799
800            // Can push through Join to left/right side based on variables used
801            LogicalOperator::Join(mut join) => {
802                let predicate_vars = self.extract_variables(&predicate);
803                let left_vars = self.collect_output_variables(&join.left);
804                let right_vars = self.collect_output_variables(&join.right);
805
806                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
807                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
808
809                if uses_left && !uses_right {
810                    // Push to left side
811                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
812                    LogicalOperator::Join(join)
813                } else if uses_right && !uses_left {
814                    // Push to right side
815                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
816                    LogicalOperator::Join(join)
817                } else {
818                    // Uses both sides - keep above join
819                    LogicalOperator::Filter(FilterOp {
820                        predicate,
821                        input: Box::new(LogicalOperator::Join(join)),
822                    })
823                }
824            }
825
826            // Cannot push through Aggregate (predicate refers to aggregated values)
827            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
828                predicate,
829                input: Box::new(LogicalOperator::Aggregate(agg)),
830            }),
831
832            // For NodeScan, we've reached the bottom - keep filter on top
833            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
834                predicate,
835                input: Box::new(LogicalOperator::NodeScan(scan)),
836            }),
837
838            // For other operators, keep filter on top
839            other => LogicalOperator::Filter(FilterOp {
840                predicate,
841                input: Box::new(other),
842            }),
843        }
844    }
845
846    /// Collects all output variable names from an operator.
847    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
848        let mut vars = HashSet::new();
849        Self::collect_output_variables_recursive(op, &mut vars);
850        vars
851    }
852
853    /// Recursively collects output variables from an operator.
854    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
855        match op {
856            LogicalOperator::NodeScan(scan) => {
857                vars.insert(scan.variable.clone());
858            }
859            LogicalOperator::EdgeScan(scan) => {
860                vars.insert(scan.variable.clone());
861            }
862            LogicalOperator::Expand(expand) => {
863                vars.insert(expand.to_variable.clone());
864                if let Some(edge_var) = &expand.edge_variable {
865                    vars.insert(edge_var.clone());
866                }
867                Self::collect_output_variables_recursive(&expand.input, vars);
868            }
869            LogicalOperator::Filter(filter) => {
870                Self::collect_output_variables_recursive(&filter.input, vars);
871            }
872            LogicalOperator::Project(proj) => {
873                for p in &proj.projections {
874                    if let Some(alias) = &p.alias {
875                        vars.insert(alias.clone());
876                    }
877                }
878                Self::collect_output_variables_recursive(&proj.input, vars);
879            }
880            LogicalOperator::Join(join) => {
881                Self::collect_output_variables_recursive(&join.left, vars);
882                Self::collect_output_variables_recursive(&join.right, vars);
883            }
884            LogicalOperator::Aggregate(agg) => {
885                for expr in &agg.group_by {
886                    Self::collect_variables(expr, vars);
887                }
888                for agg_expr in &agg.aggregates {
889                    if let Some(alias) = &agg_expr.alias {
890                        vars.insert(alias.clone());
891                    }
892                }
893            }
894            LogicalOperator::Return(ret) => {
895                Self::collect_output_variables_recursive(&ret.input, vars);
896            }
897            LogicalOperator::Limit(limit) => {
898                Self::collect_output_variables_recursive(&limit.input, vars);
899            }
900            LogicalOperator::Skip(skip) => {
901                Self::collect_output_variables_recursive(&skip.input, vars);
902            }
903            LogicalOperator::Sort(sort) => {
904                Self::collect_output_variables_recursive(&sort.input, vars);
905            }
906            LogicalOperator::Distinct(distinct) => {
907                Self::collect_output_variables_recursive(&distinct.input, vars);
908            }
909            _ => {}
910        }
911    }
912
913    /// Extracts all variable names referenced in an expression.
914    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
915        let mut vars = HashSet::new();
916        Self::collect_variables(expr, &mut vars);
917        vars
918    }
919
920    /// Recursively collects variable names from an expression.
921    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
922        match expr {
923            LogicalExpression::Variable(name) => {
924                vars.insert(name.clone());
925            }
926            LogicalExpression::Property { variable, .. } => {
927                vars.insert(variable.clone());
928            }
929            LogicalExpression::Binary { left, right, .. } => {
930                Self::collect_variables(left, vars);
931                Self::collect_variables(right, vars);
932            }
933            LogicalExpression::Unary { operand, .. } => {
934                Self::collect_variables(operand, vars);
935            }
936            LogicalExpression::FunctionCall { args, .. } => {
937                for arg in args {
938                    Self::collect_variables(arg, vars);
939                }
940            }
941            LogicalExpression::List(items) => {
942                for item in items {
943                    Self::collect_variables(item, vars);
944                }
945            }
946            LogicalExpression::Map(pairs) => {
947                for (_, value) in pairs {
948                    Self::collect_variables(value, vars);
949                }
950            }
951            LogicalExpression::IndexAccess { base, index } => {
952                Self::collect_variables(base, vars);
953                Self::collect_variables(index, vars);
954            }
955            LogicalExpression::SliceAccess { base, start, end } => {
956                Self::collect_variables(base, vars);
957                if let Some(s) = start {
958                    Self::collect_variables(s, vars);
959                }
960                if let Some(e) = end {
961                    Self::collect_variables(e, vars);
962                }
963            }
964            LogicalExpression::Case {
965                operand,
966                when_clauses,
967                else_clause,
968            } => {
969                if let Some(op) = operand {
970                    Self::collect_variables(op, vars);
971                }
972                for (cond, result) in when_clauses {
973                    Self::collect_variables(cond, vars);
974                    Self::collect_variables(result, vars);
975                }
976                if let Some(else_expr) = else_clause {
977                    Self::collect_variables(else_expr, vars);
978                }
979            }
980            LogicalExpression::Labels(var)
981            | LogicalExpression::Type(var)
982            | LogicalExpression::Id(var) => {
983                vars.insert(var.clone());
984            }
985            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
986            LogicalExpression::ListComprehension {
987                list_expr,
988                filter_expr,
989                map_expr,
990                ..
991            } => {
992                Self::collect_variables(list_expr, vars);
993                if let Some(filter) = filter_expr {
994                    Self::collect_variables(filter, vars);
995                }
996                Self::collect_variables(map_expr, vars);
997            }
998            LogicalExpression::ListPredicate {
999                list_expr,
1000                predicate,
1001                ..
1002            } => {
1003                Self::collect_variables(list_expr, vars);
1004                Self::collect_variables(predicate, vars);
1005            }
1006            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1007                // Subqueries have their own variable scope
1008            }
1009            LogicalExpression::PatternComprehension { projection, .. } => {
1010                Self::collect_variables(projection, vars);
1011            }
1012            LogicalExpression::MapProjection { base, entries } => {
1013                vars.insert(base.clone());
1014                for entry in entries {
1015                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1016                        Self::collect_variables(expr, vars);
1017                    }
1018                }
1019            }
1020            LogicalExpression::Reduce {
1021                initial,
1022                list,
1023                expression,
1024                ..
1025            } => {
1026                Self::collect_variables(initial, vars);
1027                Self::collect_variables(list, vars);
1028                Self::collect_variables(expression, vars);
1029            }
1030        }
1031    }
1032
1033    /// Extracts aliases from projection expressions.
1034    fn extract_projection_aliases(
1035        &self,
1036        projections: &[crate::query::plan::Projection],
1037    ) -> HashSet<String> {
1038        projections.iter().filter_map(|p| p.alias.clone()).collect()
1039    }
1040}
1041
1042impl Default for Optimizer {
1043    fn default() -> Self {
1044        Self::new()
1045    }
1046}
1047
1048#[cfg(test)]
1049mod tests {
1050    use super::*;
1051    use crate::query::plan::{
1052        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1053        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1054        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1055    };
1056    use grafeo_common::types::Value;
1057
1058    #[test]
1059    fn test_optimizer_filter_pushdown_simple() {
1060        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1061        // Before: Return -> Filter -> NodeScan
1062        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1063
1064        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1065            items: vec![ReturnItem {
1066                expression: LogicalExpression::Variable("n".to_string()),
1067                alias: None,
1068            }],
1069            distinct: false,
1070            input: Box::new(LogicalOperator::Filter(FilterOp {
1071                predicate: LogicalExpression::Binary {
1072                    left: Box::new(LogicalExpression::Property {
1073                        variable: "n".to_string(),
1074                        property: "age".to_string(),
1075                    }),
1076                    op: BinaryOp::Gt,
1077                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1078                },
1079                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1080                    variable: "n".to_string(),
1081                    label: Some("Person".to_string()),
1082                    input: None,
1083                })),
1084            })),
1085        }));
1086
1087        let optimizer = Optimizer::new();
1088        let optimized = optimizer.optimize(plan).unwrap();
1089
1090        // The structure should remain similar (filter stays near scan)
1091        if let LogicalOperator::Return(ret) = &optimized.root
1092            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1093            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1094        {
1095            assert_eq!(scan.variable, "n");
1096            return;
1097        }
1098        panic!("Expected Return -> Filter -> NodeScan structure");
1099    }
1100
1101    #[test]
1102    fn test_optimizer_filter_pushdown_through_expand() {
1103        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1104        // The filter on 'a' should be pushed before the expand
1105
1106        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1107            items: vec![ReturnItem {
1108                expression: LogicalExpression::Variable("b".to_string()),
1109                alias: None,
1110            }],
1111            distinct: false,
1112            input: Box::new(LogicalOperator::Filter(FilterOp {
1113                predicate: LogicalExpression::Binary {
1114                    left: Box::new(LogicalExpression::Property {
1115                        variable: "a".to_string(),
1116                        property: "age".to_string(),
1117                    }),
1118                    op: BinaryOp::Gt,
1119                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1120                },
1121                input: Box::new(LogicalOperator::Expand(ExpandOp {
1122                    from_variable: "a".to_string(),
1123                    to_variable: "b".to_string(),
1124                    edge_variable: None,
1125                    direction: ExpandDirection::Outgoing,
1126                    edge_types: vec!["KNOWS".to_string()],
1127                    min_hops: 1,
1128                    max_hops: Some(1),
1129                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1130                        variable: "a".to_string(),
1131                        label: Some("Person".to_string()),
1132                        input: None,
1133                    })),
1134                    path_alias: None,
1135                    path_mode: PathMode::Walk,
1136                })),
1137            })),
1138        }));
1139
1140        let optimizer = Optimizer::new();
1141        let optimized = optimizer.optimize(plan).unwrap();
1142
1143        // Filter on 'a' should be pushed before the expand
1144        // Expected: Return -> Expand -> Filter -> NodeScan
1145        if let LogicalOperator::Return(ret) = &optimized.root
1146            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1147            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1148            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1149        {
1150            assert_eq!(scan.variable, "a");
1151            assert_eq!(expand.from_variable, "a");
1152            assert_eq!(expand.to_variable, "b");
1153            return;
1154        }
1155        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1156    }
1157
1158    #[test]
1159    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1160        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1161        // The filter on 'b' should NOT be pushed before the expand
1162
1163        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1164            items: vec![ReturnItem {
1165                expression: LogicalExpression::Variable("a".to_string()),
1166                alias: None,
1167            }],
1168            distinct: false,
1169            input: Box::new(LogicalOperator::Filter(FilterOp {
1170                predicate: LogicalExpression::Binary {
1171                    left: Box::new(LogicalExpression::Property {
1172                        variable: "b".to_string(),
1173                        property: "age".to_string(),
1174                    }),
1175                    op: BinaryOp::Gt,
1176                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1177                },
1178                input: Box::new(LogicalOperator::Expand(ExpandOp {
1179                    from_variable: "a".to_string(),
1180                    to_variable: "b".to_string(),
1181                    edge_variable: None,
1182                    direction: ExpandDirection::Outgoing,
1183                    edge_types: vec!["KNOWS".to_string()],
1184                    min_hops: 1,
1185                    max_hops: Some(1),
1186                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1187                        variable: "a".to_string(),
1188                        label: Some("Person".to_string()),
1189                        input: None,
1190                    })),
1191                    path_alias: None,
1192                    path_mode: PathMode::Walk,
1193                })),
1194            })),
1195        }));
1196
1197        let optimizer = Optimizer::new();
1198        let optimized = optimizer.optimize(plan).unwrap();
1199
1200        // Filter on 'b' should stay after the expand
1201        // Expected: Return -> Filter -> Expand -> NodeScan
1202        if let LogicalOperator::Return(ret) = &optimized.root
1203            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1204        {
1205            // Check that the filter is on 'b'
1206            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1207                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1208            {
1209                assert_eq!(variable, "b");
1210            }
1211
1212            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1213                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1214            {
1215                return;
1216            }
1217        }
1218        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1219    }
1220
1221    #[test]
1222    fn test_optimizer_extract_variables() {
1223        let optimizer = Optimizer::new();
1224
1225        let expr = LogicalExpression::Binary {
1226            left: Box::new(LogicalExpression::Property {
1227                variable: "n".to_string(),
1228                property: "age".to_string(),
1229            }),
1230            op: BinaryOp::Gt,
1231            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1232        };
1233
1234        let vars = optimizer.extract_variables(&expr);
1235        assert_eq!(vars.len(), 1);
1236        assert!(vars.contains("n"));
1237    }
1238
1239    // Additional tests for optimizer configuration
1240
1241    #[test]
1242    fn test_optimizer_default() {
1243        let optimizer = Optimizer::default();
1244        // Should be able to optimize an empty plan
1245        let plan = LogicalPlan::new(LogicalOperator::Empty);
1246        let result = optimizer.optimize(plan);
1247        assert!(result.is_ok());
1248    }
1249
1250    #[test]
1251    fn test_optimizer_with_filter_pushdown_disabled() {
1252        let optimizer = Optimizer::new().with_filter_pushdown(false);
1253
1254        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1255            items: vec![ReturnItem {
1256                expression: LogicalExpression::Variable("n".to_string()),
1257                alias: None,
1258            }],
1259            distinct: false,
1260            input: Box::new(LogicalOperator::Filter(FilterOp {
1261                predicate: LogicalExpression::Literal(Value::Bool(true)),
1262                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1263                    variable: "n".to_string(),
1264                    label: None,
1265                    input: None,
1266                })),
1267            })),
1268        }));
1269
1270        let optimized = optimizer.optimize(plan).unwrap();
1271        // Structure should be unchanged
1272        if let LogicalOperator::Return(ret) = &optimized.root
1273            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1274        {
1275            return;
1276        }
1277        panic!("Expected unchanged structure");
1278    }
1279
1280    #[test]
1281    fn test_optimizer_with_join_reorder_disabled() {
1282        let optimizer = Optimizer::new().with_join_reorder(false);
1283        assert!(
1284            optimizer
1285                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1286                .is_ok()
1287        );
1288    }
1289
1290    #[test]
1291    fn test_optimizer_with_cost_model() {
1292        let cost_model = CostModel::new();
1293        let optimizer = Optimizer::new().with_cost_model(cost_model);
1294        assert!(
1295            optimizer
1296                .cost_model()
1297                .estimate(&LogicalOperator::Empty, 0.0)
1298                .total()
1299                < 0.001
1300        );
1301    }
1302
1303    #[test]
1304    fn test_optimizer_with_cardinality_estimator() {
1305        let mut estimator = CardinalityEstimator::new();
1306        estimator.add_table_stats("Test", TableStats::new(500));
1307        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1308
1309        let scan = LogicalOperator::NodeScan(NodeScanOp {
1310            variable: "n".to_string(),
1311            label: Some("Test".to_string()),
1312            input: None,
1313        });
1314        let plan = LogicalPlan::new(scan);
1315
1316        let cardinality = optimizer.estimate_cardinality(&plan);
1317        assert!((cardinality - 500.0).abs() < 0.001);
1318    }
1319
1320    #[test]
1321    fn test_optimizer_estimate_cost() {
1322        let optimizer = Optimizer::new();
1323        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1324            variable: "n".to_string(),
1325            label: None,
1326            input: None,
1327        }));
1328
1329        let cost = optimizer.estimate_cost(&plan);
1330        assert!(cost.total() > 0.0);
1331    }
1332
1333    // Filter pushdown through various operators
1334
1335    #[test]
1336    fn test_filter_pushdown_through_project() {
1337        let optimizer = Optimizer::new();
1338
1339        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1340            predicate: LogicalExpression::Binary {
1341                left: Box::new(LogicalExpression::Property {
1342                    variable: "n".to_string(),
1343                    property: "age".to_string(),
1344                }),
1345                op: BinaryOp::Gt,
1346                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1347            },
1348            input: Box::new(LogicalOperator::Project(ProjectOp {
1349                projections: vec![Projection {
1350                    expression: LogicalExpression::Variable("n".to_string()),
1351                    alias: None,
1352                }],
1353                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1354                    variable: "n".to_string(),
1355                    label: None,
1356                    input: None,
1357                })),
1358            })),
1359        }));
1360
1361        let optimized = optimizer.optimize(plan).unwrap();
1362
1363        // Filter should be pushed through Project
1364        if let LogicalOperator::Project(proj) = &optimized.root
1365            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1366        {
1367            return;
1368        }
1369        panic!("Expected Project -> Filter structure");
1370    }
1371
1372    #[test]
1373    fn test_filter_not_pushed_through_project_with_alias() {
1374        let optimizer = Optimizer::new();
1375
1376        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1377        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1378            predicate: LogicalExpression::Binary {
1379                left: Box::new(LogicalExpression::Variable("x".to_string())),
1380                op: BinaryOp::Gt,
1381                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1382            },
1383            input: Box::new(LogicalOperator::Project(ProjectOp {
1384                projections: vec![Projection {
1385                    expression: LogicalExpression::Property {
1386                        variable: "n".to_string(),
1387                        property: "age".to_string(),
1388                    },
1389                    alias: Some("x".to_string()),
1390                }],
1391                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1392                    variable: "n".to_string(),
1393                    label: None,
1394                    input: None,
1395                })),
1396            })),
1397        }));
1398
1399        let optimized = optimizer.optimize(plan).unwrap();
1400
1401        // Filter should stay above Project
1402        if let LogicalOperator::Filter(filter) = &optimized.root
1403            && let LogicalOperator::Project(_) = filter.input.as_ref()
1404        {
1405            return;
1406        }
1407        panic!("Expected Filter -> Project structure");
1408    }
1409
1410    #[test]
1411    fn test_filter_pushdown_through_limit() {
1412        let optimizer = Optimizer::new();
1413
1414        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1415            predicate: LogicalExpression::Literal(Value::Bool(true)),
1416            input: Box::new(LogicalOperator::Limit(LimitOp {
1417                count: 10,
1418                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1419                    variable: "n".to_string(),
1420                    label: None,
1421                    input: None,
1422                })),
1423            })),
1424        }));
1425
1426        let optimized = optimizer.optimize(plan).unwrap();
1427
1428        // Filter stays above Limit (cannot be pushed through)
1429        if let LogicalOperator::Filter(filter) = &optimized.root
1430            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1431        {
1432            return;
1433        }
1434        panic!("Expected Filter -> Limit structure");
1435    }
1436
1437    #[test]
1438    fn test_filter_pushdown_through_sort() {
1439        let optimizer = Optimizer::new();
1440
1441        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1442            predicate: LogicalExpression::Literal(Value::Bool(true)),
1443            input: Box::new(LogicalOperator::Sort(SortOp {
1444                keys: vec![SortKey {
1445                    expression: LogicalExpression::Variable("n".to_string()),
1446                    order: SortOrder::Ascending,
1447                }],
1448                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1449                    variable: "n".to_string(),
1450                    label: None,
1451                    input: None,
1452                })),
1453            })),
1454        }));
1455
1456        let optimized = optimizer.optimize(plan).unwrap();
1457
1458        // Filter stays above Sort
1459        if let LogicalOperator::Filter(filter) = &optimized.root
1460            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1461        {
1462            return;
1463        }
1464        panic!("Expected Filter -> Sort structure");
1465    }
1466
1467    #[test]
1468    fn test_filter_pushdown_through_distinct() {
1469        let optimizer = Optimizer::new();
1470
1471        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1472            predicate: LogicalExpression::Literal(Value::Bool(true)),
1473            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1474                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1475                    variable: "n".to_string(),
1476                    label: None,
1477                    input: None,
1478                })),
1479                columns: None,
1480            })),
1481        }));
1482
1483        let optimized = optimizer.optimize(plan).unwrap();
1484
1485        // Filter stays above Distinct
1486        if let LogicalOperator::Filter(filter) = &optimized.root
1487            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1488        {
1489            return;
1490        }
1491        panic!("Expected Filter -> Distinct structure");
1492    }
1493
1494    #[test]
1495    fn test_filter_not_pushed_through_aggregate() {
1496        let optimizer = Optimizer::new();
1497
1498        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1499            predicate: LogicalExpression::Binary {
1500                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1501                op: BinaryOp::Gt,
1502                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1503            },
1504            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1505                group_by: vec![],
1506                aggregates: vec![AggregateExpr {
1507                    function: AggregateFunction::Count,
1508                    expression: None,
1509                    distinct: false,
1510                    alias: Some("cnt".to_string()),
1511                    percentile: None,
1512                }],
1513                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1514                    variable: "n".to_string(),
1515                    label: None,
1516                    input: None,
1517                })),
1518                having: None,
1519            })),
1520        }));
1521
1522        let optimized = optimizer.optimize(plan).unwrap();
1523
1524        // Filter should stay above Aggregate
1525        if let LogicalOperator::Filter(filter) = &optimized.root
1526            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1527        {
1528            return;
1529        }
1530        panic!("Expected Filter -> Aggregate structure");
1531    }
1532
1533    #[test]
1534    fn test_filter_pushdown_to_left_join_side() {
1535        let optimizer = Optimizer::new();
1536
1537        // Filter on left variable should be pushed to left side
1538        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1539            predicate: LogicalExpression::Binary {
1540                left: Box::new(LogicalExpression::Property {
1541                    variable: "a".to_string(),
1542                    property: "age".to_string(),
1543                }),
1544                op: BinaryOp::Gt,
1545                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1546            },
1547            input: Box::new(LogicalOperator::Join(JoinOp {
1548                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1549                    variable: "a".to_string(),
1550                    label: Some("Person".to_string()),
1551                    input: None,
1552                })),
1553                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1554                    variable: "b".to_string(),
1555                    label: Some("Company".to_string()),
1556                    input: None,
1557                })),
1558                join_type: JoinType::Inner,
1559                conditions: vec![],
1560            })),
1561        }));
1562
1563        let optimized = optimizer.optimize(plan).unwrap();
1564
1565        // Filter should be pushed to left side of join
1566        if let LogicalOperator::Join(join) = &optimized.root
1567            && let LogicalOperator::Filter(_) = join.left.as_ref()
1568        {
1569            return;
1570        }
1571        panic!("Expected Join with Filter on left side");
1572    }
1573
1574    #[test]
1575    fn test_filter_pushdown_to_right_join_side() {
1576        let optimizer = Optimizer::new();
1577
1578        // Filter on right variable should be pushed to right side
1579        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1580            predicate: LogicalExpression::Binary {
1581                left: Box::new(LogicalExpression::Property {
1582                    variable: "b".to_string(),
1583                    property: "name".to_string(),
1584                }),
1585                op: BinaryOp::Eq,
1586                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1587            },
1588            input: Box::new(LogicalOperator::Join(JoinOp {
1589                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1590                    variable: "a".to_string(),
1591                    label: Some("Person".to_string()),
1592                    input: None,
1593                })),
1594                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1595                    variable: "b".to_string(),
1596                    label: Some("Company".to_string()),
1597                    input: None,
1598                })),
1599                join_type: JoinType::Inner,
1600                conditions: vec![],
1601            })),
1602        }));
1603
1604        let optimized = optimizer.optimize(plan).unwrap();
1605
1606        // Filter should be pushed to right side of join
1607        if let LogicalOperator::Join(join) = &optimized.root
1608            && let LogicalOperator::Filter(_) = join.right.as_ref()
1609        {
1610            return;
1611        }
1612        panic!("Expected Join with Filter on right side");
1613    }
1614
1615    #[test]
1616    fn test_filter_not_pushed_when_uses_both_join_sides() {
1617        let optimizer = Optimizer::new();
1618
1619        // Filter using both variables should stay above join
1620        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1621            predicate: LogicalExpression::Binary {
1622                left: Box::new(LogicalExpression::Property {
1623                    variable: "a".to_string(),
1624                    property: "id".to_string(),
1625                }),
1626                op: BinaryOp::Eq,
1627                right: Box::new(LogicalExpression::Property {
1628                    variable: "b".to_string(),
1629                    property: "a_id".to_string(),
1630                }),
1631            },
1632            input: Box::new(LogicalOperator::Join(JoinOp {
1633                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1634                    variable: "a".to_string(),
1635                    label: None,
1636                    input: None,
1637                })),
1638                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1639                    variable: "b".to_string(),
1640                    label: None,
1641                    input: None,
1642                })),
1643                join_type: JoinType::Inner,
1644                conditions: vec![],
1645            })),
1646        }));
1647
1648        let optimized = optimizer.optimize(plan).unwrap();
1649
1650        // Filter should stay above join
1651        if let LogicalOperator::Filter(filter) = &optimized.root
1652            && let LogicalOperator::Join(_) = filter.input.as_ref()
1653        {
1654            return;
1655        }
1656        panic!("Expected Filter -> Join structure");
1657    }
1658
1659    // Variable extraction tests
1660
1661    #[test]
1662    fn test_extract_variables_from_variable() {
1663        let optimizer = Optimizer::new();
1664        let expr = LogicalExpression::Variable("x".to_string());
1665        let vars = optimizer.extract_variables(&expr);
1666        assert_eq!(vars.len(), 1);
1667        assert!(vars.contains("x"));
1668    }
1669
1670    #[test]
1671    fn test_extract_variables_from_unary() {
1672        let optimizer = Optimizer::new();
1673        let expr = LogicalExpression::Unary {
1674            op: UnaryOp::Not,
1675            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1676        };
1677        let vars = optimizer.extract_variables(&expr);
1678        assert_eq!(vars.len(), 1);
1679        assert!(vars.contains("x"));
1680    }
1681
1682    #[test]
1683    fn test_extract_variables_from_function_call() {
1684        let optimizer = Optimizer::new();
1685        let expr = LogicalExpression::FunctionCall {
1686            name: "length".to_string(),
1687            args: vec![
1688                LogicalExpression::Variable("a".to_string()),
1689                LogicalExpression::Variable("b".to_string()),
1690            ],
1691            distinct: false,
1692        };
1693        let vars = optimizer.extract_variables(&expr);
1694        assert_eq!(vars.len(), 2);
1695        assert!(vars.contains("a"));
1696        assert!(vars.contains("b"));
1697    }
1698
1699    #[test]
1700    fn test_extract_variables_from_list() {
1701        let optimizer = Optimizer::new();
1702        let expr = LogicalExpression::List(vec![
1703            LogicalExpression::Variable("a".to_string()),
1704            LogicalExpression::Literal(Value::Int64(1)),
1705            LogicalExpression::Variable("b".to_string()),
1706        ]);
1707        let vars = optimizer.extract_variables(&expr);
1708        assert_eq!(vars.len(), 2);
1709        assert!(vars.contains("a"));
1710        assert!(vars.contains("b"));
1711    }
1712
1713    #[test]
1714    fn test_extract_variables_from_map() {
1715        let optimizer = Optimizer::new();
1716        let expr = LogicalExpression::Map(vec![
1717            (
1718                "key1".to_string(),
1719                LogicalExpression::Variable("a".to_string()),
1720            ),
1721            (
1722                "key2".to_string(),
1723                LogicalExpression::Variable("b".to_string()),
1724            ),
1725        ]);
1726        let vars = optimizer.extract_variables(&expr);
1727        assert_eq!(vars.len(), 2);
1728        assert!(vars.contains("a"));
1729        assert!(vars.contains("b"));
1730    }
1731
1732    #[test]
1733    fn test_extract_variables_from_index_access() {
1734        let optimizer = Optimizer::new();
1735        let expr = LogicalExpression::IndexAccess {
1736            base: Box::new(LogicalExpression::Variable("list".to_string())),
1737            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1738        };
1739        let vars = optimizer.extract_variables(&expr);
1740        assert_eq!(vars.len(), 2);
1741        assert!(vars.contains("list"));
1742        assert!(vars.contains("idx"));
1743    }
1744
1745    #[test]
1746    fn test_extract_variables_from_slice_access() {
1747        let optimizer = Optimizer::new();
1748        let expr = LogicalExpression::SliceAccess {
1749            base: Box::new(LogicalExpression::Variable("list".to_string())),
1750            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1751            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1752        };
1753        let vars = optimizer.extract_variables(&expr);
1754        assert_eq!(vars.len(), 3);
1755        assert!(vars.contains("list"));
1756        assert!(vars.contains("s"));
1757        assert!(vars.contains("e"));
1758    }
1759
1760    #[test]
1761    fn test_extract_variables_from_case() {
1762        let optimizer = Optimizer::new();
1763        let expr = LogicalExpression::Case {
1764            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1765            when_clauses: vec![(
1766                LogicalExpression::Literal(Value::Int64(1)),
1767                LogicalExpression::Variable("a".to_string()),
1768            )],
1769            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1770        };
1771        let vars = optimizer.extract_variables(&expr);
1772        assert_eq!(vars.len(), 3);
1773        assert!(vars.contains("x"));
1774        assert!(vars.contains("a"));
1775        assert!(vars.contains("b"));
1776    }
1777
1778    #[test]
1779    fn test_extract_variables_from_labels() {
1780        let optimizer = Optimizer::new();
1781        let expr = LogicalExpression::Labels("n".to_string());
1782        let vars = optimizer.extract_variables(&expr);
1783        assert_eq!(vars.len(), 1);
1784        assert!(vars.contains("n"));
1785    }
1786
1787    #[test]
1788    fn test_extract_variables_from_type() {
1789        let optimizer = Optimizer::new();
1790        let expr = LogicalExpression::Type("e".to_string());
1791        let vars = optimizer.extract_variables(&expr);
1792        assert_eq!(vars.len(), 1);
1793        assert!(vars.contains("e"));
1794    }
1795
1796    #[test]
1797    fn test_extract_variables_from_id() {
1798        let optimizer = Optimizer::new();
1799        let expr = LogicalExpression::Id("n".to_string());
1800        let vars = optimizer.extract_variables(&expr);
1801        assert_eq!(vars.len(), 1);
1802        assert!(vars.contains("n"));
1803    }
1804
1805    #[test]
1806    fn test_extract_variables_from_list_comprehension() {
1807        let optimizer = Optimizer::new();
1808        let expr = LogicalExpression::ListComprehension {
1809            variable: "x".to_string(),
1810            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1811            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1812            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1813        };
1814        let vars = optimizer.extract_variables(&expr);
1815        assert!(vars.contains("items"));
1816        assert!(vars.contains("pred"));
1817        assert!(vars.contains("result"));
1818    }
1819
1820    #[test]
1821    fn test_extract_variables_from_literal_and_parameter() {
1822        let optimizer = Optimizer::new();
1823
1824        let literal = LogicalExpression::Literal(Value::Int64(42));
1825        assert!(optimizer.extract_variables(&literal).is_empty());
1826
1827        let param = LogicalExpression::Parameter("p".to_string());
1828        assert!(optimizer.extract_variables(&param).is_empty());
1829    }
1830
1831    // Recursive filter pushdown tests
1832
1833    #[test]
1834    fn test_recursive_filter_pushdown_through_skip() {
1835        let optimizer = Optimizer::new();
1836
1837        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1838            items: vec![ReturnItem {
1839                expression: LogicalExpression::Variable("n".to_string()),
1840                alias: None,
1841            }],
1842            distinct: false,
1843            input: Box::new(LogicalOperator::Filter(FilterOp {
1844                predicate: LogicalExpression::Literal(Value::Bool(true)),
1845                input: Box::new(LogicalOperator::Skip(SkipOp {
1846                    count: 5,
1847                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1848                        variable: "n".to_string(),
1849                        label: None,
1850                        input: None,
1851                    })),
1852                })),
1853            })),
1854        }));
1855
1856        let optimized = optimizer.optimize(plan).unwrap();
1857
1858        // Verify optimization succeeded
1859        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1860    }
1861
1862    #[test]
1863    fn test_nested_filter_pushdown() {
1864        let optimizer = Optimizer::new();
1865
1866        // Multiple nested filters
1867        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1868            items: vec![ReturnItem {
1869                expression: LogicalExpression::Variable("n".to_string()),
1870                alias: None,
1871            }],
1872            distinct: false,
1873            input: Box::new(LogicalOperator::Filter(FilterOp {
1874                predicate: LogicalExpression::Binary {
1875                    left: Box::new(LogicalExpression::Property {
1876                        variable: "n".to_string(),
1877                        property: "x".to_string(),
1878                    }),
1879                    op: BinaryOp::Gt,
1880                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1881                },
1882                input: Box::new(LogicalOperator::Filter(FilterOp {
1883                    predicate: LogicalExpression::Binary {
1884                        left: Box::new(LogicalExpression::Property {
1885                            variable: "n".to_string(),
1886                            property: "y".to_string(),
1887                        }),
1888                        op: BinaryOp::Lt,
1889                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1890                    },
1891                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1892                        variable: "n".to_string(),
1893                        label: None,
1894                        input: None,
1895                    })),
1896                })),
1897            })),
1898        }));
1899
1900        let optimized = optimizer.optimize(plan).unwrap();
1901        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1902    }
1903}