Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
28/// Information about a join condition for join reordering.
29#[derive(Debug, Clone)]
30struct JoinInfo {
31    left_var: String,
32    right_var: String,
33    left_expr: LogicalExpression,
34    right_expr: LogicalExpression,
35}
36
/// Something the query needs to keep available while projections are
/// pushed down: either a whole binding or one property of a binding.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A bound variable (node, edge, or path binding), identified by name.
    Variable(String),
    /// A single property, given as `(variable name, property name)`.
    Property(String, String),
}
45
46/// Transforms logical plans for faster execution.
47///
48/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
49/// Use the builder methods to enable/disable specific optimizations.
50pub struct Optimizer {
51    /// Whether to enable filter pushdown.
52    enable_filter_pushdown: bool,
53    /// Whether to enable join reordering.
54    enable_join_reorder: bool,
55    /// Whether to enable projection pushdown.
56    enable_projection_pushdown: bool,
57    /// Cost model for estimation.
58    cost_model: CostModel,
59    /// Cardinality estimator.
60    card_estimator: CardinalityEstimator,
61}
62
63impl Optimizer {
64    /// Creates a new optimizer with default settings.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            enable_filter_pushdown: true,
69            enable_join_reorder: true,
70            enable_projection_pushdown: true,
71            cost_model: CostModel::new(),
72            card_estimator: CardinalityEstimator::new(),
73        }
74    }
75
76    /// Creates an optimizer with cardinality estimates from the store's statistics.
77    ///
78    /// Pre-populates the cardinality estimator with per-label row counts and
79    /// edge type fanout. Feeds per-edge-type degree stats into the cost model
80    /// for accurate expand cost estimation.
81    #[must_use]
82    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
83        store.ensure_statistics_fresh();
84        let stats = store.statistics();
85        let estimator = CardinalityEstimator::from_statistics(&stats);
86
87        // Derive average fanout from statistics for the cost model
88        let avg_fanout = if stats.total_nodes > 0 {
89            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
90        } else {
91            10.0
92        };
93
94        // Collect per-edge-type degree stats for accurate expand costing
95        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
96            .edge_types
97            .iter()
98            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
99            .collect();
100
101        Self {
102            enable_filter_pushdown: true,
103            enable_join_reorder: true,
104            enable_projection_pushdown: true,
105            cost_model: CostModel::new()
106                .with_avg_fanout(avg_fanout)
107                .with_edge_type_degrees(edge_type_degrees),
108            card_estimator: estimator,
109        }
110    }
111
112    /// Creates an optimizer from any GraphStore implementation.
113    ///
114    /// Unlike [`from_store`](Self::from_store), this does not call
115    /// `ensure_statistics_fresh()` since external stores manage their own
116    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
117    /// is called directly.
118    #[must_use]
119    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
120        let stats = store.statistics();
121        let estimator = CardinalityEstimator::from_statistics(&stats);
122
123        let avg_fanout = if stats.total_nodes > 0 {
124            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
125        } else {
126            10.0
127        };
128
129        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
130            .edge_types
131            .iter()
132            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
133            .collect();
134
135        Self {
136            enable_filter_pushdown: true,
137            enable_join_reorder: true,
138            enable_projection_pushdown: true,
139            cost_model: CostModel::new()
140                .with_avg_fanout(avg_fanout)
141                .with_edge_type_degrees(edge_type_degrees),
142            card_estimator: estimator,
143        }
144    }
145
146    /// Enables or disables filter pushdown.
147    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
148        self.enable_filter_pushdown = enabled;
149        self
150    }
151
152    /// Enables or disables join reordering.
153    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
154        self.enable_join_reorder = enabled;
155        self
156    }
157
158    /// Enables or disables projection pushdown.
159    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
160        self.enable_projection_pushdown = enabled;
161        self
162    }
163
164    /// Sets the cost model.
165    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
166        self.cost_model = cost_model;
167        self
168    }
169
170    /// Sets the cardinality estimator.
171    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
172        self.card_estimator = estimator;
173        self
174    }
175
176    /// Sets the selectivity configuration for the cardinality estimator.
177    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
178        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
179        self
180    }
181
182    /// Returns a reference to the cost model.
183    pub fn cost_model(&self) -> &CostModel {
184        &self.cost_model
185    }
186
187    /// Returns a reference to the cardinality estimator.
188    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
189        &self.card_estimator
190    }
191
192    /// Estimates the cost of a plan.
193    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
194        let cardinality = self.card_estimator.estimate(&plan.root);
195        self.cost_model.estimate(&plan.root, cardinality)
196    }
197
198    /// Estimates the cardinality of a plan.
199    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
200        self.card_estimator.estimate(&plan.root)
201    }
202
203    /// Optimizes a logical plan.
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if optimization fails.
208    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
209        let mut root = plan.root;
210
211        // Apply optimization rules
212        if self.enable_filter_pushdown {
213            root = self.push_filters_down(root);
214        }
215
216        if self.enable_join_reorder {
217            root = self.reorder_joins(root);
218        }
219
220        if self.enable_projection_pushdown {
221            root = self.push_projections_down(root);
222        }
223
224        Ok(LogicalPlan::new(root))
225    }
226
227    /// Pushes projections down the operator tree to eliminate unused columns early.
228    ///
229    /// This optimization:
230    /// 1. Collects required variables/properties from the root
231    /// 2. Propagates requirements down through the tree
232    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
233    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
234        // Collect required columns from the top of the plan
235        let required = self.collect_required_columns(&op);
236
237        // Push projections down
238        self.push_projections_recursive(op, &required)
239    }
240
241    /// Collects all variables and properties required by an operator and its ancestors.
242    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
243        let mut required = HashSet::new();
244        Self::collect_required_recursive(op, &mut required);
245        required
246    }
247
248    /// Recursively collects required columns.
249    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
250        match op {
251            LogicalOperator::Return(ret) => {
252                for item in &ret.items {
253                    Self::collect_from_expression(&item.expression, required);
254                }
255                Self::collect_required_recursive(&ret.input, required);
256            }
257            LogicalOperator::Project(proj) => {
258                for p in &proj.projections {
259                    Self::collect_from_expression(&p.expression, required);
260                }
261                Self::collect_required_recursive(&proj.input, required);
262            }
263            LogicalOperator::Filter(filter) => {
264                Self::collect_from_expression(&filter.predicate, required);
265                Self::collect_required_recursive(&filter.input, required);
266            }
267            LogicalOperator::Sort(sort) => {
268                for key in &sort.keys {
269                    Self::collect_from_expression(&key.expression, required);
270                }
271                Self::collect_required_recursive(&sort.input, required);
272            }
273            LogicalOperator::Aggregate(agg) => {
274                for expr in &agg.group_by {
275                    Self::collect_from_expression(expr, required);
276                }
277                for agg_expr in &agg.aggregates {
278                    if let Some(ref expr) = agg_expr.expression {
279                        Self::collect_from_expression(expr, required);
280                    }
281                }
282                if let Some(ref having) = agg.having {
283                    Self::collect_from_expression(having, required);
284                }
285                Self::collect_required_recursive(&agg.input, required);
286            }
287            LogicalOperator::Join(join) => {
288                for cond in &join.conditions {
289                    Self::collect_from_expression(&cond.left, required);
290                    Self::collect_from_expression(&cond.right, required);
291                }
292                Self::collect_required_recursive(&join.left, required);
293                Self::collect_required_recursive(&join.right, required);
294            }
295            LogicalOperator::Expand(expand) => {
296                // The source and target variables are needed
297                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
298                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
299                if let Some(ref edge_var) = expand.edge_variable {
300                    required.insert(RequiredColumn::Variable(edge_var.clone()));
301                }
302                Self::collect_required_recursive(&expand.input, required);
303            }
304            LogicalOperator::Limit(limit) => {
305                Self::collect_required_recursive(&limit.input, required);
306            }
307            LogicalOperator::Skip(skip) => {
308                Self::collect_required_recursive(&skip.input, required);
309            }
310            LogicalOperator::Distinct(distinct) => {
311                Self::collect_required_recursive(&distinct.input, required);
312            }
313            LogicalOperator::NodeScan(scan) => {
314                required.insert(RequiredColumn::Variable(scan.variable.clone()));
315            }
316            LogicalOperator::EdgeScan(scan) => {
317                required.insert(RequiredColumn::Variable(scan.variable.clone()));
318            }
319            _ => {}
320        }
321    }
322
323    /// Collects required columns from an expression.
324    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
325        match expr {
326            LogicalExpression::Variable(var) => {
327                required.insert(RequiredColumn::Variable(var.clone()));
328            }
329            LogicalExpression::Property { variable, property } => {
330                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
331                required.insert(RequiredColumn::Variable(variable.clone()));
332            }
333            LogicalExpression::Binary { left, right, .. } => {
334                Self::collect_from_expression(left, required);
335                Self::collect_from_expression(right, required);
336            }
337            LogicalExpression::Unary { operand, .. } => {
338                Self::collect_from_expression(operand, required);
339            }
340            LogicalExpression::FunctionCall { args, .. } => {
341                for arg in args {
342                    Self::collect_from_expression(arg, required);
343                }
344            }
345            LogicalExpression::List(items) => {
346                for item in items {
347                    Self::collect_from_expression(item, required);
348                }
349            }
350            LogicalExpression::Map(pairs) => {
351                for (_, value) in pairs {
352                    Self::collect_from_expression(value, required);
353                }
354            }
355            LogicalExpression::IndexAccess { base, index } => {
356                Self::collect_from_expression(base, required);
357                Self::collect_from_expression(index, required);
358            }
359            LogicalExpression::SliceAccess { base, start, end } => {
360                Self::collect_from_expression(base, required);
361                if let Some(s) = start {
362                    Self::collect_from_expression(s, required);
363                }
364                if let Some(e) = end {
365                    Self::collect_from_expression(e, required);
366                }
367            }
368            LogicalExpression::Case {
369                operand,
370                when_clauses,
371                else_clause,
372            } => {
373                if let Some(op) = operand {
374                    Self::collect_from_expression(op, required);
375                }
376                for (cond, result) in when_clauses {
377                    Self::collect_from_expression(cond, required);
378                    Self::collect_from_expression(result, required);
379                }
380                if let Some(else_expr) = else_clause {
381                    Self::collect_from_expression(else_expr, required);
382                }
383            }
384            LogicalExpression::Labels(var)
385            | LogicalExpression::Type(var)
386            | LogicalExpression::Id(var) => {
387                required.insert(RequiredColumn::Variable(var.clone()));
388            }
389            LogicalExpression::ListComprehension {
390                list_expr,
391                filter_expr,
392                map_expr,
393                ..
394            } => {
395                Self::collect_from_expression(list_expr, required);
396                if let Some(filter) = filter_expr {
397                    Self::collect_from_expression(filter, required);
398                }
399                Self::collect_from_expression(map_expr, required);
400            }
401            _ => {}
402        }
403    }
404
405    /// Recursively pushes projections down, adding them before expensive operations.
406    fn push_projections_recursive(
407        &self,
408        op: LogicalOperator,
409        required: &HashSet<RequiredColumn>,
410    ) -> LogicalOperator {
411        match op {
412            LogicalOperator::Return(mut ret) => {
413                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
414                LogicalOperator::Return(ret)
415            }
416            LogicalOperator::Project(mut proj) => {
417                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
418                LogicalOperator::Project(proj)
419            }
420            LogicalOperator::Filter(mut filter) => {
421                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
422                LogicalOperator::Filter(filter)
423            }
424            LogicalOperator::Sort(mut sort) => {
425                // Sort is expensive - consider adding a projection before it
426                // to reduce tuple width
427                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
428                LogicalOperator::Sort(sort)
429            }
430            LogicalOperator::Aggregate(mut agg) => {
431                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
432                LogicalOperator::Aggregate(agg)
433            }
434            LogicalOperator::Join(mut join) => {
435                // Joins are expensive - the required columns help determine
436                // what to project on each side
437                let left_vars = self.collect_output_variables(&join.left);
438                let right_vars = self.collect_output_variables(&join.right);
439
440                // Filter required columns to each side
441                let left_required: HashSet<_> = required
442                    .iter()
443                    .filter(|c| match c {
444                        RequiredColumn::Variable(v) => left_vars.contains(v),
445                        RequiredColumn::Property(v, _) => left_vars.contains(v),
446                    })
447                    .cloned()
448                    .collect();
449
450                let right_required: HashSet<_> = required
451                    .iter()
452                    .filter(|c| match c {
453                        RequiredColumn::Variable(v) => right_vars.contains(v),
454                        RequiredColumn::Property(v, _) => right_vars.contains(v),
455                    })
456                    .cloned()
457                    .collect();
458
459                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
460                join.right =
461                    Box::new(self.push_projections_recursive(*join.right, &right_required));
462                LogicalOperator::Join(join)
463            }
464            LogicalOperator::Expand(mut expand) => {
465                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
466                LogicalOperator::Expand(expand)
467            }
468            LogicalOperator::Limit(mut limit) => {
469                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
470                LogicalOperator::Limit(limit)
471            }
472            LogicalOperator::Skip(mut skip) => {
473                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
474                LogicalOperator::Skip(skip)
475            }
476            LogicalOperator::Distinct(mut distinct) => {
477                distinct.input =
478                    Box::new(self.push_projections_recursive(*distinct.input, required));
479                LogicalOperator::Distinct(distinct)
480            }
481            LogicalOperator::MapCollect(mut mc) => {
482                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
483                LogicalOperator::MapCollect(mc)
484            }
485            other => other,
486        }
487    }
488
489    /// Reorders joins in the operator tree using the DPccp algorithm.
490    ///
491    /// This optimization finds the optimal join order by:
492    /// 1. Extracting all base relations (scans) and join conditions
493    /// 2. Building a join graph
494    /// 3. Using dynamic programming to find the cheapest join order
495    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
496        // First, recursively optimize children
497        let op = self.reorder_joins_recursive(op);
498
499        // Then, if this is a join tree, try to optimize it
500        if let Some((relations, conditions)) = self.extract_join_tree(&op)
501            && relations.len() >= 2
502            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
503        {
504            return optimized;
505        }
506
507        op
508    }
509
510    /// Recursively applies join reordering to child operators.
511    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
512        match op {
513            LogicalOperator::Return(mut ret) => {
514                ret.input = Box::new(self.reorder_joins(*ret.input));
515                LogicalOperator::Return(ret)
516            }
517            LogicalOperator::Project(mut proj) => {
518                proj.input = Box::new(self.reorder_joins(*proj.input));
519                LogicalOperator::Project(proj)
520            }
521            LogicalOperator::Filter(mut filter) => {
522                filter.input = Box::new(self.reorder_joins(*filter.input));
523                LogicalOperator::Filter(filter)
524            }
525            LogicalOperator::Limit(mut limit) => {
526                limit.input = Box::new(self.reorder_joins(*limit.input));
527                LogicalOperator::Limit(limit)
528            }
529            LogicalOperator::Skip(mut skip) => {
530                skip.input = Box::new(self.reorder_joins(*skip.input));
531                LogicalOperator::Skip(skip)
532            }
533            LogicalOperator::Sort(mut sort) => {
534                sort.input = Box::new(self.reorder_joins(*sort.input));
535                LogicalOperator::Sort(sort)
536            }
537            LogicalOperator::Distinct(mut distinct) => {
538                distinct.input = Box::new(self.reorder_joins(*distinct.input));
539                LogicalOperator::Distinct(distinct)
540            }
541            LogicalOperator::Aggregate(mut agg) => {
542                agg.input = Box::new(self.reorder_joins(*agg.input));
543                LogicalOperator::Aggregate(agg)
544            }
545            LogicalOperator::Expand(mut expand) => {
546                expand.input = Box::new(self.reorder_joins(*expand.input));
547                LogicalOperator::Expand(expand)
548            }
549            LogicalOperator::MapCollect(mut mc) => {
550                mc.input = Box::new(self.reorder_joins(*mc.input));
551                LogicalOperator::MapCollect(mc)
552            }
553            // Join operators are handled by the parent reorder_joins call
554            other => other,
555        }
556    }
557
558    /// Extracts base relations and join conditions from a join tree.
559    ///
560    /// Returns None if the operator is not a join tree.
561    fn extract_join_tree(
562        &self,
563        op: &LogicalOperator,
564    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
565        let mut relations = Vec::new();
566        let mut join_conditions = Vec::new();
567
568        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
569            return None;
570        }
571
572        if relations.len() < 2 {
573            return None;
574        }
575
576        Some((relations, join_conditions))
577    }
578
579    /// Recursively collects base relations and join conditions.
580    ///
581    /// Returns true if this subtree is part of a join tree.
582    fn collect_join_tree(
583        &self,
584        op: &LogicalOperator,
585        relations: &mut Vec<(String, LogicalOperator)>,
586        conditions: &mut Vec<JoinInfo>,
587    ) -> bool {
588        match op {
589            LogicalOperator::Join(join) => {
590                // Collect from both sides
591                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
592                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
593
594                // Add conditions from this join
595                for cond in &join.conditions {
596                    if let (Some(left_var), Some(right_var)) = (
597                        self.extract_variable_from_expr(&cond.left),
598                        self.extract_variable_from_expr(&cond.right),
599                    ) {
600                        conditions.push(JoinInfo {
601                            left_var,
602                            right_var,
603                            left_expr: cond.left.clone(),
604                            right_expr: cond.right.clone(),
605                        });
606                    }
607                }
608
609                left_ok && right_ok
610            }
611            LogicalOperator::NodeScan(scan) => {
612                relations.push((scan.variable.clone(), op.clone()));
613                true
614            }
615            LogicalOperator::EdgeScan(scan) => {
616                relations.push((scan.variable.clone(), op.clone()));
617                true
618            }
619            LogicalOperator::Filter(filter) => {
620                // A filter on a base relation is still part of the join tree
621                self.collect_join_tree(&filter.input, relations, conditions)
622            }
623            LogicalOperator::Expand(expand) => {
624                // Expand is a special case - it's like a join with the adjacency
625                // For now, treat the whole Expand subtree as a single relation
626                relations.push((expand.to_variable.clone(), op.clone()));
627                true
628            }
629            _ => false,
630        }
631    }
632
633    /// Extracts the primary variable from an expression.
634    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
635        match expr {
636            LogicalExpression::Variable(v) => Some(v.clone()),
637            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
638            _ => None,
639        }
640    }
641
642    /// Optimizes the join order using DPccp.
643    fn optimize_join_order(
644        &self,
645        relations: &[(String, LogicalOperator)],
646        conditions: &[JoinInfo],
647    ) -> Option<LogicalOperator> {
648        use join_order::{DPccp, JoinGraphBuilder};
649
650        // Build the join graph
651        let mut builder = JoinGraphBuilder::new();
652
653        for (var, relation) in relations {
654            builder.add_relation(var, relation.clone());
655        }
656
657        for cond in conditions {
658            builder.add_join_condition(
659                &cond.left_var,
660                &cond.right_var,
661                cond.left_expr.clone(),
662                cond.right_expr.clone(),
663            );
664        }
665
666        let graph = builder.build();
667
668        // Run DPccp
669        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
670        let plan = dpccp.optimize()?;
671
672        Some(plan.operator)
673    }
674
675    /// Pushes filters down the operator tree.
676    ///
677    /// This optimization moves filter predicates as close to the data source
678    /// as possible to reduce the amount of data processed by upper operators.
679    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
680        match op {
681            // For Filter operators, try to push the predicate into the child
682            LogicalOperator::Filter(filter) => {
683                let optimized_input = self.push_filters_down(*filter.input);
684                self.try_push_filter_into(filter.predicate, optimized_input)
685            }
686            // Recursively optimize children for other operators
687            LogicalOperator::Return(mut ret) => {
688                ret.input = Box::new(self.push_filters_down(*ret.input));
689                LogicalOperator::Return(ret)
690            }
691            LogicalOperator::Project(mut proj) => {
692                proj.input = Box::new(self.push_filters_down(*proj.input));
693                LogicalOperator::Project(proj)
694            }
695            LogicalOperator::Limit(mut limit) => {
696                limit.input = Box::new(self.push_filters_down(*limit.input));
697                LogicalOperator::Limit(limit)
698            }
699            LogicalOperator::Skip(mut skip) => {
700                skip.input = Box::new(self.push_filters_down(*skip.input));
701                LogicalOperator::Skip(skip)
702            }
703            LogicalOperator::Sort(mut sort) => {
704                sort.input = Box::new(self.push_filters_down(*sort.input));
705                LogicalOperator::Sort(sort)
706            }
707            LogicalOperator::Distinct(mut distinct) => {
708                distinct.input = Box::new(self.push_filters_down(*distinct.input));
709                LogicalOperator::Distinct(distinct)
710            }
711            LogicalOperator::Expand(mut expand) => {
712                expand.input = Box::new(self.push_filters_down(*expand.input));
713                LogicalOperator::Expand(expand)
714            }
715            LogicalOperator::Join(mut join) => {
716                join.left = Box::new(self.push_filters_down(*join.left));
717                join.right = Box::new(self.push_filters_down(*join.right));
718                LogicalOperator::Join(join)
719            }
720            LogicalOperator::Aggregate(mut agg) => {
721                agg.input = Box::new(self.push_filters_down(*agg.input));
722                LogicalOperator::Aggregate(agg)
723            }
724            LogicalOperator::MapCollect(mut mc) => {
725                mc.input = Box::new(self.push_filters_down(*mc.input));
726                LogicalOperator::MapCollect(mc)
727            }
728            // Leaf operators and unsupported operators are returned as-is
729            other => other,
730        }
731    }
732
733    /// Tries to push a filter predicate into the given operator.
734    ///
735    /// Returns either the predicate pushed into the operator, or a new
736    /// Filter operator on top if the predicate cannot be pushed further.
737    fn try_push_filter_into(
738        &self,
739        predicate: LogicalExpression,
740        op: LogicalOperator,
741    ) -> LogicalOperator {
742        match op {
743            // Can push through Project if predicate doesn't depend on computed columns
744            LogicalOperator::Project(mut proj) => {
745                let predicate_vars = self.extract_variables(&predicate);
746                let computed_vars = self.extract_projection_aliases(&proj.projections);
747
748                // If predicate doesn't use any computed columns, push through
749                if predicate_vars.is_disjoint(&computed_vars) {
750                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
751                    LogicalOperator::Project(proj)
752                } else {
753                    // Can't push through, keep filter on top
754                    LogicalOperator::Filter(FilterOp {
755                        predicate,
756                        input: Box::new(LogicalOperator::Project(proj)),
757                    })
758                }
759            }
760
761            // Can push through Return (which is like a projection)
762            LogicalOperator::Return(mut ret) => {
763                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
764                LogicalOperator::Return(ret)
765            }
766
767            // Can push through Expand if predicate doesn't use variables introduced by this expand
768            LogicalOperator::Expand(mut expand) => {
769                let predicate_vars = self.extract_variables(&predicate);
770
771                // Variables introduced by this expand are:
772                // - The target variable (to_variable)
773                // - The edge variable (if any)
774                // - The path alias (if any)
775                let mut introduced_vars = vec![&expand.to_variable];
776                if let Some(ref edge_var) = expand.edge_variable {
777                    introduced_vars.push(edge_var);
778                }
779                if let Some(ref path_alias) = expand.path_alias {
780                    introduced_vars.push(path_alias);
781                }
782
783                // Check if predicate uses any variables introduced by this expand
784                let uses_introduced_vars =
785                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
786
787                if !uses_introduced_vars {
788                    // Predicate doesn't use vars from this expand, so push through
789                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
790                    LogicalOperator::Expand(expand)
791                } else {
792                    // Keep filter after expand
793                    LogicalOperator::Filter(FilterOp {
794                        predicate,
795                        input: Box::new(LogicalOperator::Expand(expand)),
796                    })
797                }
798            }
799
800            // Can push through Join to left/right side based on variables used
801            LogicalOperator::Join(mut join) => {
802                let predicate_vars = self.extract_variables(&predicate);
803                let left_vars = self.collect_output_variables(&join.left);
804                let right_vars = self.collect_output_variables(&join.right);
805
806                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
807                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
808
809                if uses_left && !uses_right {
810                    // Push to left side
811                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
812                    LogicalOperator::Join(join)
813                } else if uses_right && !uses_left {
814                    // Push to right side
815                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
816                    LogicalOperator::Join(join)
817                } else {
818                    // Uses both sides - keep above join
819                    LogicalOperator::Filter(FilterOp {
820                        predicate,
821                        input: Box::new(LogicalOperator::Join(join)),
822                    })
823                }
824            }
825
826            // Cannot push through Aggregate (predicate refers to aggregated values)
827            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
828                predicate,
829                input: Box::new(LogicalOperator::Aggregate(agg)),
830            }),
831
832            // For NodeScan, we've reached the bottom - keep filter on top
833            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
834                predicate,
835                input: Box::new(LogicalOperator::NodeScan(scan)),
836            }),
837
838            // For other operators, keep filter on top
839            other => LogicalOperator::Filter(FilterOp {
840                predicate,
841                input: Box::new(other),
842            }),
843        }
844    }
845
846    /// Collects all output variable names from an operator.
847    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
848        let mut vars = HashSet::new();
849        Self::collect_output_variables_recursive(op, &mut vars);
850        vars
851    }
852
853    /// Recursively collects output variables from an operator.
854    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
855        match op {
856            LogicalOperator::NodeScan(scan) => {
857                vars.insert(scan.variable.clone());
858            }
859            LogicalOperator::EdgeScan(scan) => {
860                vars.insert(scan.variable.clone());
861            }
862            LogicalOperator::Expand(expand) => {
863                vars.insert(expand.to_variable.clone());
864                if let Some(edge_var) = &expand.edge_variable {
865                    vars.insert(edge_var.clone());
866                }
867                Self::collect_output_variables_recursive(&expand.input, vars);
868            }
869            LogicalOperator::Filter(filter) => {
870                Self::collect_output_variables_recursive(&filter.input, vars);
871            }
872            LogicalOperator::Project(proj) => {
873                for p in &proj.projections {
874                    if let Some(alias) = &p.alias {
875                        vars.insert(alias.clone());
876                    }
877                }
878                Self::collect_output_variables_recursive(&proj.input, vars);
879            }
880            LogicalOperator::Join(join) => {
881                Self::collect_output_variables_recursive(&join.left, vars);
882                Self::collect_output_variables_recursive(&join.right, vars);
883            }
884            LogicalOperator::Aggregate(agg) => {
885                for expr in &agg.group_by {
886                    Self::collect_variables(expr, vars);
887                }
888                for agg_expr in &agg.aggregates {
889                    if let Some(alias) = &agg_expr.alias {
890                        vars.insert(alias.clone());
891                    }
892                }
893            }
894            LogicalOperator::Return(ret) => {
895                Self::collect_output_variables_recursive(&ret.input, vars);
896            }
897            LogicalOperator::Limit(limit) => {
898                Self::collect_output_variables_recursive(&limit.input, vars);
899            }
900            LogicalOperator::Skip(skip) => {
901                Self::collect_output_variables_recursive(&skip.input, vars);
902            }
903            LogicalOperator::Sort(sort) => {
904                Self::collect_output_variables_recursive(&sort.input, vars);
905            }
906            LogicalOperator::Distinct(distinct) => {
907                Self::collect_output_variables_recursive(&distinct.input, vars);
908            }
909            _ => {}
910        }
911    }
912
913    /// Extracts all variable names referenced in an expression.
914    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
915        let mut vars = HashSet::new();
916        Self::collect_variables(expr, &mut vars);
917        vars
918    }
919
920    /// Recursively collects variable names from an expression.
921    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
922        match expr {
923            LogicalExpression::Variable(name) => {
924                vars.insert(name.clone());
925            }
926            LogicalExpression::Property { variable, .. } => {
927                vars.insert(variable.clone());
928            }
929            LogicalExpression::Binary { left, right, .. } => {
930                Self::collect_variables(left, vars);
931                Self::collect_variables(right, vars);
932            }
933            LogicalExpression::Unary { operand, .. } => {
934                Self::collect_variables(operand, vars);
935            }
936            LogicalExpression::FunctionCall { args, .. } => {
937                for arg in args {
938                    Self::collect_variables(arg, vars);
939                }
940            }
941            LogicalExpression::List(items) => {
942                for item in items {
943                    Self::collect_variables(item, vars);
944                }
945            }
946            LogicalExpression::Map(pairs) => {
947                for (_, value) in pairs {
948                    Self::collect_variables(value, vars);
949                }
950            }
951            LogicalExpression::IndexAccess { base, index } => {
952                Self::collect_variables(base, vars);
953                Self::collect_variables(index, vars);
954            }
955            LogicalExpression::SliceAccess { base, start, end } => {
956                Self::collect_variables(base, vars);
957                if let Some(s) = start {
958                    Self::collect_variables(s, vars);
959                }
960                if let Some(e) = end {
961                    Self::collect_variables(e, vars);
962                }
963            }
964            LogicalExpression::Case {
965                operand,
966                when_clauses,
967                else_clause,
968            } => {
969                if let Some(op) = operand {
970                    Self::collect_variables(op, vars);
971                }
972                for (cond, result) in when_clauses {
973                    Self::collect_variables(cond, vars);
974                    Self::collect_variables(result, vars);
975                }
976                if let Some(else_expr) = else_clause {
977                    Self::collect_variables(else_expr, vars);
978                }
979            }
980            LogicalExpression::Labels(var)
981            | LogicalExpression::Type(var)
982            | LogicalExpression::Id(var) => {
983                vars.insert(var.clone());
984            }
985            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
986            LogicalExpression::ListComprehension {
987                list_expr,
988                filter_expr,
989                map_expr,
990                ..
991            } => {
992                Self::collect_variables(list_expr, vars);
993                if let Some(filter) = filter_expr {
994                    Self::collect_variables(filter, vars);
995                }
996                Self::collect_variables(map_expr, vars);
997            }
998            LogicalExpression::ListPredicate {
999                list_expr,
1000                predicate,
1001                ..
1002            } => {
1003                Self::collect_variables(list_expr, vars);
1004                Self::collect_variables(predicate, vars);
1005            }
1006            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1007                // Subqueries have their own variable scope
1008            }
1009        }
1010    }
1011
1012    /// Extracts aliases from projection expressions.
1013    fn extract_projection_aliases(
1014        &self,
1015        projections: &[crate::query::plan::Projection],
1016    ) -> HashSet<String> {
1017        projections.iter().filter_map(|p| p.alias.clone()).collect()
1018    }
1019}
1020
1021impl Default for Optimizer {
1022    fn default() -> Self {
1023        Self::new()
1024    }
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029    use super::*;
1030    use crate::query::plan::{
1031        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1032        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
1033        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1034    };
1035    use grafeo_common::types::Value;
1036
1037    #[test]
1038    fn test_optimizer_filter_pushdown_simple() {
1039        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1040        // Before: Return -> Filter -> NodeScan
1041        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1042
1043        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1044            items: vec![ReturnItem {
1045                expression: LogicalExpression::Variable("n".to_string()),
1046                alias: None,
1047            }],
1048            distinct: false,
1049            input: Box::new(LogicalOperator::Filter(FilterOp {
1050                predicate: LogicalExpression::Binary {
1051                    left: Box::new(LogicalExpression::Property {
1052                        variable: "n".to_string(),
1053                        property: "age".to_string(),
1054                    }),
1055                    op: BinaryOp::Gt,
1056                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1057                },
1058                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1059                    variable: "n".to_string(),
1060                    label: Some("Person".to_string()),
1061                    input: None,
1062                })),
1063            })),
1064        }));
1065
1066        let optimizer = Optimizer::new();
1067        let optimized = optimizer.optimize(plan).unwrap();
1068
1069        // The structure should remain similar (filter stays near scan)
1070        if let LogicalOperator::Return(ret) = &optimized.root
1071            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1072            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1073        {
1074            assert_eq!(scan.variable, "n");
1075            return;
1076        }
1077        panic!("Expected Return -> Filter -> NodeScan structure");
1078    }
1079
1080    #[test]
1081    fn test_optimizer_filter_pushdown_through_expand() {
1082        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1083        // The filter on 'a' should be pushed before the expand
1084
1085        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1086            items: vec![ReturnItem {
1087                expression: LogicalExpression::Variable("b".to_string()),
1088                alias: None,
1089            }],
1090            distinct: false,
1091            input: Box::new(LogicalOperator::Filter(FilterOp {
1092                predicate: LogicalExpression::Binary {
1093                    left: Box::new(LogicalExpression::Property {
1094                        variable: "a".to_string(),
1095                        property: "age".to_string(),
1096                    }),
1097                    op: BinaryOp::Gt,
1098                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1099                },
1100                input: Box::new(LogicalOperator::Expand(ExpandOp {
1101                    from_variable: "a".to_string(),
1102                    to_variable: "b".to_string(),
1103                    edge_variable: None,
1104                    direction: ExpandDirection::Outgoing,
1105                    edge_type: Some("KNOWS".to_string()),
1106                    min_hops: 1,
1107                    max_hops: Some(1),
1108                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1109                        variable: "a".to_string(),
1110                        label: Some("Person".to_string()),
1111                        input: None,
1112                    })),
1113                    path_alias: None,
1114                })),
1115            })),
1116        }));
1117
1118        let optimizer = Optimizer::new();
1119        let optimized = optimizer.optimize(plan).unwrap();
1120
1121        // Filter on 'a' should be pushed before the expand
1122        // Expected: Return -> Expand -> Filter -> NodeScan
1123        if let LogicalOperator::Return(ret) = &optimized.root
1124            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1125            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1126            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1127        {
1128            assert_eq!(scan.variable, "a");
1129            assert_eq!(expand.from_variable, "a");
1130            assert_eq!(expand.to_variable, "b");
1131            return;
1132        }
1133        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1134    }
1135
1136    #[test]
1137    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1138        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1139        // The filter on 'b' should NOT be pushed before the expand
1140
1141        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1142            items: vec![ReturnItem {
1143                expression: LogicalExpression::Variable("a".to_string()),
1144                alias: None,
1145            }],
1146            distinct: false,
1147            input: Box::new(LogicalOperator::Filter(FilterOp {
1148                predicate: LogicalExpression::Binary {
1149                    left: Box::new(LogicalExpression::Property {
1150                        variable: "b".to_string(),
1151                        property: "age".to_string(),
1152                    }),
1153                    op: BinaryOp::Gt,
1154                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1155                },
1156                input: Box::new(LogicalOperator::Expand(ExpandOp {
1157                    from_variable: "a".to_string(),
1158                    to_variable: "b".to_string(),
1159                    edge_variable: None,
1160                    direction: ExpandDirection::Outgoing,
1161                    edge_type: Some("KNOWS".to_string()),
1162                    min_hops: 1,
1163                    max_hops: Some(1),
1164                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1165                        variable: "a".to_string(),
1166                        label: Some("Person".to_string()),
1167                        input: None,
1168                    })),
1169                    path_alias: None,
1170                })),
1171            })),
1172        }));
1173
1174        let optimizer = Optimizer::new();
1175        let optimized = optimizer.optimize(plan).unwrap();
1176
1177        // Filter on 'b' should stay after the expand
1178        // Expected: Return -> Filter -> Expand -> NodeScan
1179        if let LogicalOperator::Return(ret) = &optimized.root
1180            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1181        {
1182            // Check that the filter is on 'b'
1183            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1184                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1185            {
1186                assert_eq!(variable, "b");
1187            }
1188
1189            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1190                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1191            {
1192                return;
1193            }
1194        }
1195        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1196    }
1197
1198    #[test]
1199    fn test_optimizer_extract_variables() {
1200        let optimizer = Optimizer::new();
1201
1202        let expr = LogicalExpression::Binary {
1203            left: Box::new(LogicalExpression::Property {
1204                variable: "n".to_string(),
1205                property: "age".to_string(),
1206            }),
1207            op: BinaryOp::Gt,
1208            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1209        };
1210
1211        let vars = optimizer.extract_variables(&expr);
1212        assert_eq!(vars.len(), 1);
1213        assert!(vars.contains("n"));
1214    }
1215
1216    // Additional tests for optimizer configuration
1217
1218    #[test]
1219    fn test_optimizer_default() {
1220        let optimizer = Optimizer::default();
1221        // Should be able to optimize an empty plan
1222        let plan = LogicalPlan::new(LogicalOperator::Empty);
1223        let result = optimizer.optimize(plan);
1224        assert!(result.is_ok());
1225    }
1226
1227    #[test]
1228    fn test_optimizer_with_filter_pushdown_disabled() {
1229        let optimizer = Optimizer::new().with_filter_pushdown(false);
1230
1231        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1232            items: vec![ReturnItem {
1233                expression: LogicalExpression::Variable("n".to_string()),
1234                alias: None,
1235            }],
1236            distinct: false,
1237            input: Box::new(LogicalOperator::Filter(FilterOp {
1238                predicate: LogicalExpression::Literal(Value::Bool(true)),
1239                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1240                    variable: "n".to_string(),
1241                    label: None,
1242                    input: None,
1243                })),
1244            })),
1245        }));
1246
1247        let optimized = optimizer.optimize(plan).unwrap();
1248        // Structure should be unchanged
1249        if let LogicalOperator::Return(ret) = &optimized.root
1250            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1251        {
1252            return;
1253        }
1254        panic!("Expected unchanged structure");
1255    }
1256
1257    #[test]
1258    fn test_optimizer_with_join_reorder_disabled() {
1259        let optimizer = Optimizer::new().with_join_reorder(false);
1260        assert!(
1261            optimizer
1262                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1263                .is_ok()
1264        );
1265    }
1266
1267    #[test]
1268    fn test_optimizer_with_cost_model() {
1269        let cost_model = CostModel::new();
1270        let optimizer = Optimizer::new().with_cost_model(cost_model);
1271        assert!(
1272            optimizer
1273                .cost_model()
1274                .estimate(&LogicalOperator::Empty, 0.0)
1275                .total()
1276                < 0.001
1277        );
1278    }
1279
1280    #[test]
1281    fn test_optimizer_with_cardinality_estimator() {
1282        let mut estimator = CardinalityEstimator::new();
1283        estimator.add_table_stats("Test", TableStats::new(500));
1284        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1285
1286        let scan = LogicalOperator::NodeScan(NodeScanOp {
1287            variable: "n".to_string(),
1288            label: Some("Test".to_string()),
1289            input: None,
1290        });
1291        let plan = LogicalPlan::new(scan);
1292
1293        let cardinality = optimizer.estimate_cardinality(&plan);
1294        assert!((cardinality - 500.0).abs() < 0.001);
1295    }
1296
1297    #[test]
1298    fn test_optimizer_estimate_cost() {
1299        let optimizer = Optimizer::new();
1300        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1301            variable: "n".to_string(),
1302            label: None,
1303            input: None,
1304        }));
1305
1306        let cost = optimizer.estimate_cost(&plan);
1307        assert!(cost.total() > 0.0);
1308    }
1309
1310    // Filter pushdown through various operators
1311
1312    #[test]
1313    fn test_filter_pushdown_through_project() {
1314        let optimizer = Optimizer::new();
1315
1316        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1317            predicate: LogicalExpression::Binary {
1318                left: Box::new(LogicalExpression::Property {
1319                    variable: "n".to_string(),
1320                    property: "age".to_string(),
1321                }),
1322                op: BinaryOp::Gt,
1323                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1324            },
1325            input: Box::new(LogicalOperator::Project(ProjectOp {
1326                projections: vec![Projection {
1327                    expression: LogicalExpression::Variable("n".to_string()),
1328                    alias: None,
1329                }],
1330                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1331                    variable: "n".to_string(),
1332                    label: None,
1333                    input: None,
1334                })),
1335            })),
1336        }));
1337
1338        let optimized = optimizer.optimize(plan).unwrap();
1339
1340        // Filter should be pushed through Project
1341        if let LogicalOperator::Project(proj) = &optimized.root
1342            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1343        {
1344            return;
1345        }
1346        panic!("Expected Project -> Filter structure");
1347    }
1348
1349    #[test]
1350    fn test_filter_not_pushed_through_project_with_alias() {
1351        let optimizer = Optimizer::new();
1352
1353        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1354        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1355            predicate: LogicalExpression::Binary {
1356                left: Box::new(LogicalExpression::Variable("x".to_string())),
1357                op: BinaryOp::Gt,
1358                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1359            },
1360            input: Box::new(LogicalOperator::Project(ProjectOp {
1361                projections: vec![Projection {
1362                    expression: LogicalExpression::Property {
1363                        variable: "n".to_string(),
1364                        property: "age".to_string(),
1365                    },
1366                    alias: Some("x".to_string()),
1367                }],
1368                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1369                    variable: "n".to_string(),
1370                    label: None,
1371                    input: None,
1372                })),
1373            })),
1374        }));
1375
1376        let optimized = optimizer.optimize(plan).unwrap();
1377
1378        // Filter should stay above Project
1379        if let LogicalOperator::Filter(filter) = &optimized.root
1380            && let LogicalOperator::Project(_) = filter.input.as_ref()
1381        {
1382            return;
1383        }
1384        panic!("Expected Filter -> Project structure");
1385    }
1386
1387    #[test]
1388    fn test_filter_pushdown_through_limit() {
1389        let optimizer = Optimizer::new();
1390
1391        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1392            predicate: LogicalExpression::Literal(Value::Bool(true)),
1393            input: Box::new(LogicalOperator::Limit(LimitOp {
1394                count: 10,
1395                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1396                    variable: "n".to_string(),
1397                    label: None,
1398                    input: None,
1399                })),
1400            })),
1401        }));
1402
1403        let optimized = optimizer.optimize(plan).unwrap();
1404
1405        // Filter stays above Limit (cannot be pushed through)
1406        if let LogicalOperator::Filter(filter) = &optimized.root
1407            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1408        {
1409            return;
1410        }
1411        panic!("Expected Filter -> Limit structure");
1412    }
1413
1414    #[test]
1415    fn test_filter_pushdown_through_sort() {
1416        let optimizer = Optimizer::new();
1417
1418        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1419            predicate: LogicalExpression::Literal(Value::Bool(true)),
1420            input: Box::new(LogicalOperator::Sort(SortOp {
1421                keys: vec![SortKey {
1422                    expression: LogicalExpression::Variable("n".to_string()),
1423                    order: SortOrder::Ascending,
1424                }],
1425                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1426                    variable: "n".to_string(),
1427                    label: None,
1428                    input: None,
1429                })),
1430            })),
1431        }));
1432
1433        let optimized = optimizer.optimize(plan).unwrap();
1434
1435        // Filter stays above Sort
1436        if let LogicalOperator::Filter(filter) = &optimized.root
1437            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1438        {
1439            return;
1440        }
1441        panic!("Expected Filter -> Sort structure");
1442    }
1443
1444    #[test]
1445    fn test_filter_pushdown_through_distinct() {
1446        let optimizer = Optimizer::new();
1447
1448        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1449            predicate: LogicalExpression::Literal(Value::Bool(true)),
1450            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1451                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1452                    variable: "n".to_string(),
1453                    label: None,
1454                    input: None,
1455                })),
1456                columns: None,
1457            })),
1458        }));
1459
1460        let optimized = optimizer.optimize(plan).unwrap();
1461
1462        // Filter stays above Distinct
1463        if let LogicalOperator::Filter(filter) = &optimized.root
1464            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1465        {
1466            return;
1467        }
1468        panic!("Expected Filter -> Distinct structure");
1469    }
1470
1471    #[test]
1472    fn test_filter_not_pushed_through_aggregate() {
1473        let optimizer = Optimizer::new();
1474
1475        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1476            predicate: LogicalExpression::Binary {
1477                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1478                op: BinaryOp::Gt,
1479                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1480            },
1481            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1482                group_by: vec![],
1483                aggregates: vec![AggregateExpr {
1484                    function: AggregateFunction::Count,
1485                    expression: None,
1486                    distinct: false,
1487                    alias: Some("cnt".to_string()),
1488                    percentile: None,
1489                }],
1490                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1491                    variable: "n".to_string(),
1492                    label: None,
1493                    input: None,
1494                })),
1495                having: None,
1496            })),
1497        }));
1498
1499        let optimized = optimizer.optimize(plan).unwrap();
1500
1501        // Filter should stay above Aggregate
1502        if let LogicalOperator::Filter(filter) = &optimized.root
1503            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1504        {
1505            return;
1506        }
1507        panic!("Expected Filter -> Aggregate structure");
1508    }
1509
1510    #[test]
1511    fn test_filter_pushdown_to_left_join_side() {
1512        let optimizer = Optimizer::new();
1513
1514        // Filter on left variable should be pushed to left side
1515        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1516            predicate: LogicalExpression::Binary {
1517                left: Box::new(LogicalExpression::Property {
1518                    variable: "a".to_string(),
1519                    property: "age".to_string(),
1520                }),
1521                op: BinaryOp::Gt,
1522                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1523            },
1524            input: Box::new(LogicalOperator::Join(JoinOp {
1525                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1526                    variable: "a".to_string(),
1527                    label: Some("Person".to_string()),
1528                    input: None,
1529                })),
1530                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1531                    variable: "b".to_string(),
1532                    label: Some("Company".to_string()),
1533                    input: None,
1534                })),
1535                join_type: JoinType::Inner,
1536                conditions: vec![],
1537            })),
1538        }));
1539
1540        let optimized = optimizer.optimize(plan).unwrap();
1541
1542        // Filter should be pushed to left side of join
1543        if let LogicalOperator::Join(join) = &optimized.root
1544            && let LogicalOperator::Filter(_) = join.left.as_ref()
1545        {
1546            return;
1547        }
1548        panic!("Expected Join with Filter on left side");
1549    }
1550
1551    #[test]
1552    fn test_filter_pushdown_to_right_join_side() {
1553        let optimizer = Optimizer::new();
1554
1555        // Filter on right variable should be pushed to right side
1556        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1557            predicate: LogicalExpression::Binary {
1558                left: Box::new(LogicalExpression::Property {
1559                    variable: "b".to_string(),
1560                    property: "name".to_string(),
1561                }),
1562                op: BinaryOp::Eq,
1563                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1564            },
1565            input: Box::new(LogicalOperator::Join(JoinOp {
1566                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1567                    variable: "a".to_string(),
1568                    label: Some("Person".to_string()),
1569                    input: None,
1570                })),
1571                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1572                    variable: "b".to_string(),
1573                    label: Some("Company".to_string()),
1574                    input: None,
1575                })),
1576                join_type: JoinType::Inner,
1577                conditions: vec![],
1578            })),
1579        }));
1580
1581        let optimized = optimizer.optimize(plan).unwrap();
1582
1583        // Filter should be pushed to right side of join
1584        if let LogicalOperator::Join(join) = &optimized.root
1585            && let LogicalOperator::Filter(_) = join.right.as_ref()
1586        {
1587            return;
1588        }
1589        panic!("Expected Join with Filter on right side");
1590    }
1591
1592    #[test]
1593    fn test_filter_not_pushed_when_uses_both_join_sides() {
1594        let optimizer = Optimizer::new();
1595
1596        // Filter using both variables should stay above join
1597        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1598            predicate: LogicalExpression::Binary {
1599                left: Box::new(LogicalExpression::Property {
1600                    variable: "a".to_string(),
1601                    property: "id".to_string(),
1602                }),
1603                op: BinaryOp::Eq,
1604                right: Box::new(LogicalExpression::Property {
1605                    variable: "b".to_string(),
1606                    property: "a_id".to_string(),
1607                }),
1608            },
1609            input: Box::new(LogicalOperator::Join(JoinOp {
1610                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1611                    variable: "a".to_string(),
1612                    label: None,
1613                    input: None,
1614                })),
1615                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1616                    variable: "b".to_string(),
1617                    label: None,
1618                    input: None,
1619                })),
1620                join_type: JoinType::Inner,
1621                conditions: vec![],
1622            })),
1623        }));
1624
1625        let optimized = optimizer.optimize(plan).unwrap();
1626
1627        // Filter should stay above join
1628        if let LogicalOperator::Filter(filter) = &optimized.root
1629            && let LogicalOperator::Join(_) = filter.input.as_ref()
1630        {
1631            return;
1632        }
1633        panic!("Expected Filter -> Join structure");
1634    }
1635
1636    // Variable extraction tests
1637
1638    #[test]
1639    fn test_extract_variables_from_variable() {
1640        let optimizer = Optimizer::new();
1641        let expr = LogicalExpression::Variable("x".to_string());
1642        let vars = optimizer.extract_variables(&expr);
1643        assert_eq!(vars.len(), 1);
1644        assert!(vars.contains("x"));
1645    }
1646
1647    #[test]
1648    fn test_extract_variables_from_unary() {
1649        let optimizer = Optimizer::new();
1650        let expr = LogicalExpression::Unary {
1651            op: UnaryOp::Not,
1652            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1653        };
1654        let vars = optimizer.extract_variables(&expr);
1655        assert_eq!(vars.len(), 1);
1656        assert!(vars.contains("x"));
1657    }
1658
1659    #[test]
1660    fn test_extract_variables_from_function_call() {
1661        let optimizer = Optimizer::new();
1662        let expr = LogicalExpression::FunctionCall {
1663            name: "length".to_string(),
1664            args: vec![
1665                LogicalExpression::Variable("a".to_string()),
1666                LogicalExpression::Variable("b".to_string()),
1667            ],
1668            distinct: false,
1669        };
1670        let vars = optimizer.extract_variables(&expr);
1671        assert_eq!(vars.len(), 2);
1672        assert!(vars.contains("a"));
1673        assert!(vars.contains("b"));
1674    }
1675
1676    #[test]
1677    fn test_extract_variables_from_list() {
1678        let optimizer = Optimizer::new();
1679        let expr = LogicalExpression::List(vec![
1680            LogicalExpression::Variable("a".to_string()),
1681            LogicalExpression::Literal(Value::Int64(1)),
1682            LogicalExpression::Variable("b".to_string()),
1683        ]);
1684        let vars = optimizer.extract_variables(&expr);
1685        assert_eq!(vars.len(), 2);
1686        assert!(vars.contains("a"));
1687        assert!(vars.contains("b"));
1688    }
1689
1690    #[test]
1691    fn test_extract_variables_from_map() {
1692        let optimizer = Optimizer::new();
1693        let expr = LogicalExpression::Map(vec![
1694            (
1695                "key1".to_string(),
1696                LogicalExpression::Variable("a".to_string()),
1697            ),
1698            (
1699                "key2".to_string(),
1700                LogicalExpression::Variable("b".to_string()),
1701            ),
1702        ]);
1703        let vars = optimizer.extract_variables(&expr);
1704        assert_eq!(vars.len(), 2);
1705        assert!(vars.contains("a"));
1706        assert!(vars.contains("b"));
1707    }
1708
1709    #[test]
1710    fn test_extract_variables_from_index_access() {
1711        let optimizer = Optimizer::new();
1712        let expr = LogicalExpression::IndexAccess {
1713            base: Box::new(LogicalExpression::Variable("list".to_string())),
1714            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1715        };
1716        let vars = optimizer.extract_variables(&expr);
1717        assert_eq!(vars.len(), 2);
1718        assert!(vars.contains("list"));
1719        assert!(vars.contains("idx"));
1720    }
1721
1722    #[test]
1723    fn test_extract_variables_from_slice_access() {
1724        let optimizer = Optimizer::new();
1725        let expr = LogicalExpression::SliceAccess {
1726            base: Box::new(LogicalExpression::Variable("list".to_string())),
1727            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1728            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1729        };
1730        let vars = optimizer.extract_variables(&expr);
1731        assert_eq!(vars.len(), 3);
1732        assert!(vars.contains("list"));
1733        assert!(vars.contains("s"));
1734        assert!(vars.contains("e"));
1735    }
1736
1737    #[test]
1738    fn test_extract_variables_from_case() {
1739        let optimizer = Optimizer::new();
1740        let expr = LogicalExpression::Case {
1741            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1742            when_clauses: vec![(
1743                LogicalExpression::Literal(Value::Int64(1)),
1744                LogicalExpression::Variable("a".to_string()),
1745            )],
1746            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1747        };
1748        let vars = optimizer.extract_variables(&expr);
1749        assert_eq!(vars.len(), 3);
1750        assert!(vars.contains("x"));
1751        assert!(vars.contains("a"));
1752        assert!(vars.contains("b"));
1753    }
1754
1755    #[test]
1756    fn test_extract_variables_from_labels() {
1757        let optimizer = Optimizer::new();
1758        let expr = LogicalExpression::Labels("n".to_string());
1759        let vars = optimizer.extract_variables(&expr);
1760        assert_eq!(vars.len(), 1);
1761        assert!(vars.contains("n"));
1762    }
1763
1764    #[test]
1765    fn test_extract_variables_from_type() {
1766        let optimizer = Optimizer::new();
1767        let expr = LogicalExpression::Type("e".to_string());
1768        let vars = optimizer.extract_variables(&expr);
1769        assert_eq!(vars.len(), 1);
1770        assert!(vars.contains("e"));
1771    }
1772
1773    #[test]
1774    fn test_extract_variables_from_id() {
1775        let optimizer = Optimizer::new();
1776        let expr = LogicalExpression::Id("n".to_string());
1777        let vars = optimizer.extract_variables(&expr);
1778        assert_eq!(vars.len(), 1);
1779        assert!(vars.contains("n"));
1780    }
1781
1782    #[test]
1783    fn test_extract_variables_from_list_comprehension() {
1784        let optimizer = Optimizer::new();
1785        let expr = LogicalExpression::ListComprehension {
1786            variable: "x".to_string(),
1787            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1788            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1789            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1790        };
1791        let vars = optimizer.extract_variables(&expr);
1792        assert!(vars.contains("items"));
1793        assert!(vars.contains("pred"));
1794        assert!(vars.contains("result"));
1795    }
1796
1797    #[test]
1798    fn test_extract_variables_from_literal_and_parameter() {
1799        let optimizer = Optimizer::new();
1800
1801        let literal = LogicalExpression::Literal(Value::Int64(42));
1802        assert!(optimizer.extract_variables(&literal).is_empty());
1803
1804        let param = LogicalExpression::Parameter("p".to_string());
1805        assert!(optimizer.extract_variables(&param).is_empty());
1806    }
1807
1808    // Recursive filter pushdown tests
1809
1810    #[test]
1811    fn test_recursive_filter_pushdown_through_skip() {
1812        let optimizer = Optimizer::new();
1813
1814        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1815            items: vec![ReturnItem {
1816                expression: LogicalExpression::Variable("n".to_string()),
1817                alias: None,
1818            }],
1819            distinct: false,
1820            input: Box::new(LogicalOperator::Filter(FilterOp {
1821                predicate: LogicalExpression::Literal(Value::Bool(true)),
1822                input: Box::new(LogicalOperator::Skip(SkipOp {
1823                    count: 5,
1824                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1825                        variable: "n".to_string(),
1826                        label: None,
1827                        input: None,
1828                    })),
1829                })),
1830            })),
1831        }));
1832
1833        let optimized = optimizer.optimize(plan).unwrap();
1834
1835        // Verify optimization succeeded
1836        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1837    }
1838
1839    #[test]
1840    fn test_nested_filter_pushdown() {
1841        let optimizer = Optimizer::new();
1842
1843        // Multiple nested filters
1844        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1845            items: vec![ReturnItem {
1846                expression: LogicalExpression::Variable("n".to_string()),
1847                alias: None,
1848            }],
1849            distinct: false,
1850            input: Box::new(LogicalOperator::Filter(FilterOp {
1851                predicate: LogicalExpression::Binary {
1852                    left: Box::new(LogicalExpression::Property {
1853                        variable: "n".to_string(),
1854                        property: "x".to_string(),
1855                    }),
1856                    op: BinaryOp::Gt,
1857                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1858                },
1859                input: Box::new(LogicalOperator::Filter(FilterOp {
1860                    predicate: LogicalExpression::Binary {
1861                        left: Box::new(LogicalExpression::Property {
1862                            variable: "n".to_string(),
1863                            property: "y".to_string(),
1864                        }),
1865                        op: BinaryOp::Lt,
1866                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1867                    },
1868                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1869                        variable: "n".to_string(),
1870                        label: None,
1871                        input: None,
1872                    })),
1873                })),
1874            })),
1875        }));
1876
1877        let optimized = optimizer.optimize(plan).unwrap();
1878        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1879    }
1880}