Skip to main content

grafeo_engine/query/optimizer/
mod.rs

//! Makes your queries faster without changing their meaning.
//!
//! The optimizer transforms logical plans to run more efficiently:
//!
//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
//! | Projection Pushdown | Tracks which variables and properties the query needs, as a basis for eliminating unused columns early |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
//!
//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
28/// Information about a join condition for join reordering.
29#[derive(Debug, Clone)]
30struct JoinInfo {
31    left_var: String,
32    right_var: String,
33    left_expr: LogicalExpression,
34    right_expr: LogicalExpression,
35}
36
/// Something the query reads from its input; tracked by projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding (node, edge, or path variable), identified by name.
    Variable(String),
    /// A single property of a binding, as `(variable, property)`.
    Property(String, String),
}
45
46/// Transforms logical plans for faster execution.
47///
48/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
49/// Use the builder methods to enable/disable specific optimizations.
50pub struct Optimizer {
51    /// Whether to enable filter pushdown.
52    enable_filter_pushdown: bool,
53    /// Whether to enable join reordering.
54    enable_join_reorder: bool,
55    /// Whether to enable projection pushdown.
56    enable_projection_pushdown: bool,
57    /// Cost model for estimation.
58    cost_model: CostModel,
59    /// Cardinality estimator.
60    card_estimator: CardinalityEstimator,
61}
62
63impl Optimizer {
64    /// Creates a new optimizer with default settings.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            enable_filter_pushdown: true,
69            enable_join_reorder: true,
70            enable_projection_pushdown: true,
71            cost_model: CostModel::new(),
72            card_estimator: CardinalityEstimator::new(),
73        }
74    }
75
76    /// Creates an optimizer with cardinality estimates from the store's statistics.
77    ///
78    /// Pre-populates the cardinality estimator with per-label row counts and
79    /// edge type fanout. Feeds per-edge-type degree stats into the cost model
80    /// for accurate expand cost estimation.
81    #[must_use]
82    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
83        store.ensure_statistics_fresh();
84        let stats = store.statistics();
85        let estimator = CardinalityEstimator::from_statistics(&stats);
86
87        // Derive average fanout from statistics for the cost model
88        let avg_fanout = if stats.total_nodes > 0 {
89            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
90        } else {
91            10.0
92        };
93
94        // Collect per-edge-type degree stats for accurate expand costing
95        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
96            .edge_types
97            .iter()
98            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
99            .collect();
100
101        Self {
102            enable_filter_pushdown: true,
103            enable_join_reorder: true,
104            enable_projection_pushdown: true,
105            cost_model: CostModel::new()
106                .with_avg_fanout(avg_fanout)
107                .with_edge_type_degrees(edge_type_degrees),
108            card_estimator: estimator,
109        }
110    }
111
112    /// Enables or disables filter pushdown.
113    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
114        self.enable_filter_pushdown = enabled;
115        self
116    }
117
118    /// Enables or disables join reordering.
119    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
120        self.enable_join_reorder = enabled;
121        self
122    }
123
124    /// Enables or disables projection pushdown.
125    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
126        self.enable_projection_pushdown = enabled;
127        self
128    }
129
130    /// Sets the cost model.
131    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
132        self.cost_model = cost_model;
133        self
134    }
135
136    /// Sets the cardinality estimator.
137    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
138        self.card_estimator = estimator;
139        self
140    }
141
142    /// Sets the selectivity configuration for the cardinality estimator.
143    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
144        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
145        self
146    }
147
148    /// Returns a reference to the cost model.
149    pub fn cost_model(&self) -> &CostModel {
150        &self.cost_model
151    }
152
153    /// Returns a reference to the cardinality estimator.
154    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
155        &self.card_estimator
156    }
157
158    /// Estimates the cost of a plan.
159    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
160        let cardinality = self.card_estimator.estimate(&plan.root);
161        self.cost_model.estimate(&plan.root, cardinality)
162    }
163
164    /// Estimates the cardinality of a plan.
165    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
166        self.card_estimator.estimate(&plan.root)
167    }
168
169    /// Optimizes a logical plan.
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if optimization fails.
174    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
175        let mut root = plan.root;
176
177        // Apply optimization rules
178        if self.enable_filter_pushdown {
179            root = self.push_filters_down(root);
180        }
181
182        if self.enable_join_reorder {
183            root = self.reorder_joins(root);
184        }
185
186        if self.enable_projection_pushdown {
187            root = self.push_projections_down(root);
188        }
189
190        Ok(LogicalPlan::new(root))
191    }
192
193    /// Pushes projections down the operator tree to eliminate unused columns early.
194    ///
195    /// This optimization:
196    /// 1. Collects required variables/properties from the root
197    /// 2. Propagates requirements down through the tree
198    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
199    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
200        // Collect required columns from the top of the plan
201        let required = self.collect_required_columns(&op);
202
203        // Push projections down
204        self.push_projections_recursive(op, &required)
205    }
206
207    /// Collects all variables and properties required by an operator and its ancestors.
208    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
209        let mut required = HashSet::new();
210        Self::collect_required_recursive(op, &mut required);
211        required
212    }
213
214    /// Recursively collects required columns.
215    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
216        match op {
217            LogicalOperator::Return(ret) => {
218                for item in &ret.items {
219                    Self::collect_from_expression(&item.expression, required);
220                }
221                Self::collect_required_recursive(&ret.input, required);
222            }
223            LogicalOperator::Project(proj) => {
224                for p in &proj.projections {
225                    Self::collect_from_expression(&p.expression, required);
226                }
227                Self::collect_required_recursive(&proj.input, required);
228            }
229            LogicalOperator::Filter(filter) => {
230                Self::collect_from_expression(&filter.predicate, required);
231                Self::collect_required_recursive(&filter.input, required);
232            }
233            LogicalOperator::Sort(sort) => {
234                for key in &sort.keys {
235                    Self::collect_from_expression(&key.expression, required);
236                }
237                Self::collect_required_recursive(&sort.input, required);
238            }
239            LogicalOperator::Aggregate(agg) => {
240                for expr in &agg.group_by {
241                    Self::collect_from_expression(expr, required);
242                }
243                for agg_expr in &agg.aggregates {
244                    if let Some(ref expr) = agg_expr.expression {
245                        Self::collect_from_expression(expr, required);
246                    }
247                }
248                if let Some(ref having) = agg.having {
249                    Self::collect_from_expression(having, required);
250                }
251                Self::collect_required_recursive(&agg.input, required);
252            }
253            LogicalOperator::Join(join) => {
254                for cond in &join.conditions {
255                    Self::collect_from_expression(&cond.left, required);
256                    Self::collect_from_expression(&cond.right, required);
257                }
258                Self::collect_required_recursive(&join.left, required);
259                Self::collect_required_recursive(&join.right, required);
260            }
261            LogicalOperator::Expand(expand) => {
262                // The source and target variables are needed
263                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
264                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
265                if let Some(ref edge_var) = expand.edge_variable {
266                    required.insert(RequiredColumn::Variable(edge_var.clone()));
267                }
268                Self::collect_required_recursive(&expand.input, required);
269            }
270            LogicalOperator::Limit(limit) => {
271                Self::collect_required_recursive(&limit.input, required);
272            }
273            LogicalOperator::Skip(skip) => {
274                Self::collect_required_recursive(&skip.input, required);
275            }
276            LogicalOperator::Distinct(distinct) => {
277                Self::collect_required_recursive(&distinct.input, required);
278            }
279            LogicalOperator::NodeScan(scan) => {
280                required.insert(RequiredColumn::Variable(scan.variable.clone()));
281            }
282            LogicalOperator::EdgeScan(scan) => {
283                required.insert(RequiredColumn::Variable(scan.variable.clone()));
284            }
285            _ => {}
286        }
287    }
288
289    /// Collects required columns from an expression.
290    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
291        match expr {
292            LogicalExpression::Variable(var) => {
293                required.insert(RequiredColumn::Variable(var.clone()));
294            }
295            LogicalExpression::Property { variable, property } => {
296                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
297                required.insert(RequiredColumn::Variable(variable.clone()));
298            }
299            LogicalExpression::Binary { left, right, .. } => {
300                Self::collect_from_expression(left, required);
301                Self::collect_from_expression(right, required);
302            }
303            LogicalExpression::Unary { operand, .. } => {
304                Self::collect_from_expression(operand, required);
305            }
306            LogicalExpression::FunctionCall { args, .. } => {
307                for arg in args {
308                    Self::collect_from_expression(arg, required);
309                }
310            }
311            LogicalExpression::List(items) => {
312                for item in items {
313                    Self::collect_from_expression(item, required);
314                }
315            }
316            LogicalExpression::Map(pairs) => {
317                for (_, value) in pairs {
318                    Self::collect_from_expression(value, required);
319                }
320            }
321            LogicalExpression::IndexAccess { base, index } => {
322                Self::collect_from_expression(base, required);
323                Self::collect_from_expression(index, required);
324            }
325            LogicalExpression::SliceAccess { base, start, end } => {
326                Self::collect_from_expression(base, required);
327                if let Some(s) = start {
328                    Self::collect_from_expression(s, required);
329                }
330                if let Some(e) = end {
331                    Self::collect_from_expression(e, required);
332                }
333            }
334            LogicalExpression::Case {
335                operand,
336                when_clauses,
337                else_clause,
338            } => {
339                if let Some(op) = operand {
340                    Self::collect_from_expression(op, required);
341                }
342                for (cond, result) in when_clauses {
343                    Self::collect_from_expression(cond, required);
344                    Self::collect_from_expression(result, required);
345                }
346                if let Some(else_expr) = else_clause {
347                    Self::collect_from_expression(else_expr, required);
348                }
349            }
350            LogicalExpression::Labels(var)
351            | LogicalExpression::Type(var)
352            | LogicalExpression::Id(var) => {
353                required.insert(RequiredColumn::Variable(var.clone()));
354            }
355            LogicalExpression::ListComprehension {
356                list_expr,
357                filter_expr,
358                map_expr,
359                ..
360            } => {
361                Self::collect_from_expression(list_expr, required);
362                if let Some(filter) = filter_expr {
363                    Self::collect_from_expression(filter, required);
364                }
365                Self::collect_from_expression(map_expr, required);
366            }
367            _ => {}
368        }
369    }
370
371    /// Recursively pushes projections down, adding them before expensive operations.
372    fn push_projections_recursive(
373        &self,
374        op: LogicalOperator,
375        required: &HashSet<RequiredColumn>,
376    ) -> LogicalOperator {
377        match op {
378            LogicalOperator::Return(mut ret) => {
379                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
380                LogicalOperator::Return(ret)
381            }
382            LogicalOperator::Project(mut proj) => {
383                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
384                LogicalOperator::Project(proj)
385            }
386            LogicalOperator::Filter(mut filter) => {
387                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
388                LogicalOperator::Filter(filter)
389            }
390            LogicalOperator::Sort(mut sort) => {
391                // Sort is expensive - consider adding a projection before it
392                // to reduce tuple width
393                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
394                LogicalOperator::Sort(sort)
395            }
396            LogicalOperator::Aggregate(mut agg) => {
397                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
398                LogicalOperator::Aggregate(agg)
399            }
400            LogicalOperator::Join(mut join) => {
401                // Joins are expensive - the required columns help determine
402                // what to project on each side
403                let left_vars = self.collect_output_variables(&join.left);
404                let right_vars = self.collect_output_variables(&join.right);
405
406                // Filter required columns to each side
407                let left_required: HashSet<_> = required
408                    .iter()
409                    .filter(|c| match c {
410                        RequiredColumn::Variable(v) => left_vars.contains(v),
411                        RequiredColumn::Property(v, _) => left_vars.contains(v),
412                    })
413                    .cloned()
414                    .collect();
415
416                let right_required: HashSet<_> = required
417                    .iter()
418                    .filter(|c| match c {
419                        RequiredColumn::Variable(v) => right_vars.contains(v),
420                        RequiredColumn::Property(v, _) => right_vars.contains(v),
421                    })
422                    .cloned()
423                    .collect();
424
425                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
426                join.right =
427                    Box::new(self.push_projections_recursive(*join.right, &right_required));
428                LogicalOperator::Join(join)
429            }
430            LogicalOperator::Expand(mut expand) => {
431                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
432                LogicalOperator::Expand(expand)
433            }
434            LogicalOperator::Limit(mut limit) => {
435                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
436                LogicalOperator::Limit(limit)
437            }
438            LogicalOperator::Skip(mut skip) => {
439                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
440                LogicalOperator::Skip(skip)
441            }
442            LogicalOperator::Distinct(mut distinct) => {
443                distinct.input =
444                    Box::new(self.push_projections_recursive(*distinct.input, required));
445                LogicalOperator::Distinct(distinct)
446            }
447            other => other,
448        }
449    }
450
451    /// Reorders joins in the operator tree using the DPccp algorithm.
452    ///
453    /// This optimization finds the optimal join order by:
454    /// 1. Extracting all base relations (scans) and join conditions
455    /// 2. Building a join graph
456    /// 3. Using dynamic programming to find the cheapest join order
457    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
458        // First, recursively optimize children
459        let op = self.reorder_joins_recursive(op);
460
461        // Then, if this is a join tree, try to optimize it
462        if let Some((relations, conditions)) = self.extract_join_tree(&op)
463            && relations.len() >= 2
464            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
465        {
466            return optimized;
467        }
468
469        op
470    }
471
472    /// Recursively applies join reordering to child operators.
473    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
474        match op {
475            LogicalOperator::Return(mut ret) => {
476                ret.input = Box::new(self.reorder_joins(*ret.input));
477                LogicalOperator::Return(ret)
478            }
479            LogicalOperator::Project(mut proj) => {
480                proj.input = Box::new(self.reorder_joins(*proj.input));
481                LogicalOperator::Project(proj)
482            }
483            LogicalOperator::Filter(mut filter) => {
484                filter.input = Box::new(self.reorder_joins(*filter.input));
485                LogicalOperator::Filter(filter)
486            }
487            LogicalOperator::Limit(mut limit) => {
488                limit.input = Box::new(self.reorder_joins(*limit.input));
489                LogicalOperator::Limit(limit)
490            }
491            LogicalOperator::Skip(mut skip) => {
492                skip.input = Box::new(self.reorder_joins(*skip.input));
493                LogicalOperator::Skip(skip)
494            }
495            LogicalOperator::Sort(mut sort) => {
496                sort.input = Box::new(self.reorder_joins(*sort.input));
497                LogicalOperator::Sort(sort)
498            }
499            LogicalOperator::Distinct(mut distinct) => {
500                distinct.input = Box::new(self.reorder_joins(*distinct.input));
501                LogicalOperator::Distinct(distinct)
502            }
503            LogicalOperator::Aggregate(mut agg) => {
504                agg.input = Box::new(self.reorder_joins(*agg.input));
505                LogicalOperator::Aggregate(agg)
506            }
507            LogicalOperator::Expand(mut expand) => {
508                expand.input = Box::new(self.reorder_joins(*expand.input));
509                LogicalOperator::Expand(expand)
510            }
511            // Join operators are handled by the parent reorder_joins call
512            other => other,
513        }
514    }
515
516    /// Extracts base relations and join conditions from a join tree.
517    ///
518    /// Returns None if the operator is not a join tree.
519    fn extract_join_tree(
520        &self,
521        op: &LogicalOperator,
522    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
523        let mut relations = Vec::new();
524        let mut join_conditions = Vec::new();
525
526        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
527            return None;
528        }
529
530        if relations.len() < 2 {
531            return None;
532        }
533
534        Some((relations, join_conditions))
535    }
536
537    /// Recursively collects base relations and join conditions.
538    ///
539    /// Returns true if this subtree is part of a join tree.
540    fn collect_join_tree(
541        &self,
542        op: &LogicalOperator,
543        relations: &mut Vec<(String, LogicalOperator)>,
544        conditions: &mut Vec<JoinInfo>,
545    ) -> bool {
546        match op {
547            LogicalOperator::Join(join) => {
548                // Collect from both sides
549                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
550                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
551
552                // Add conditions from this join
553                for cond in &join.conditions {
554                    if let (Some(left_var), Some(right_var)) = (
555                        self.extract_variable_from_expr(&cond.left),
556                        self.extract_variable_from_expr(&cond.right),
557                    ) {
558                        conditions.push(JoinInfo {
559                            left_var,
560                            right_var,
561                            left_expr: cond.left.clone(),
562                            right_expr: cond.right.clone(),
563                        });
564                    }
565                }
566
567                left_ok && right_ok
568            }
569            LogicalOperator::NodeScan(scan) => {
570                relations.push((scan.variable.clone(), op.clone()));
571                true
572            }
573            LogicalOperator::EdgeScan(scan) => {
574                relations.push((scan.variable.clone(), op.clone()));
575                true
576            }
577            LogicalOperator::Filter(filter) => {
578                // A filter on a base relation is still part of the join tree
579                self.collect_join_tree(&filter.input, relations, conditions)
580            }
581            LogicalOperator::Expand(expand) => {
582                // Expand is a special case - it's like a join with the adjacency
583                // For now, treat the whole Expand subtree as a single relation
584                relations.push((expand.to_variable.clone(), op.clone()));
585                true
586            }
587            _ => false,
588        }
589    }
590
591    /// Extracts the primary variable from an expression.
592    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
593        match expr {
594            LogicalExpression::Variable(v) => Some(v.clone()),
595            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
596            _ => None,
597        }
598    }
599
600    /// Optimizes the join order using DPccp.
601    fn optimize_join_order(
602        &self,
603        relations: &[(String, LogicalOperator)],
604        conditions: &[JoinInfo],
605    ) -> Option<LogicalOperator> {
606        use join_order::{DPccp, JoinGraphBuilder};
607
608        // Build the join graph
609        let mut builder = JoinGraphBuilder::new();
610
611        for (var, relation) in relations {
612            builder.add_relation(var, relation.clone());
613        }
614
615        for cond in conditions {
616            builder.add_join_condition(
617                &cond.left_var,
618                &cond.right_var,
619                cond.left_expr.clone(),
620                cond.right_expr.clone(),
621            );
622        }
623
624        let graph = builder.build();
625
626        // Run DPccp
627        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
628        let plan = dpccp.optimize()?;
629
630        Some(plan.operator)
631    }
632
633    /// Pushes filters down the operator tree.
634    ///
635    /// This optimization moves filter predicates as close to the data source
636    /// as possible to reduce the amount of data processed by upper operators.
637    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
638        match op {
639            // For Filter operators, try to push the predicate into the child
640            LogicalOperator::Filter(filter) => {
641                let optimized_input = self.push_filters_down(*filter.input);
642                self.try_push_filter_into(filter.predicate, optimized_input)
643            }
644            // Recursively optimize children for other operators
645            LogicalOperator::Return(mut ret) => {
646                ret.input = Box::new(self.push_filters_down(*ret.input));
647                LogicalOperator::Return(ret)
648            }
649            LogicalOperator::Project(mut proj) => {
650                proj.input = Box::new(self.push_filters_down(*proj.input));
651                LogicalOperator::Project(proj)
652            }
653            LogicalOperator::Limit(mut limit) => {
654                limit.input = Box::new(self.push_filters_down(*limit.input));
655                LogicalOperator::Limit(limit)
656            }
657            LogicalOperator::Skip(mut skip) => {
658                skip.input = Box::new(self.push_filters_down(*skip.input));
659                LogicalOperator::Skip(skip)
660            }
661            LogicalOperator::Sort(mut sort) => {
662                sort.input = Box::new(self.push_filters_down(*sort.input));
663                LogicalOperator::Sort(sort)
664            }
665            LogicalOperator::Distinct(mut distinct) => {
666                distinct.input = Box::new(self.push_filters_down(*distinct.input));
667                LogicalOperator::Distinct(distinct)
668            }
669            LogicalOperator::Expand(mut expand) => {
670                expand.input = Box::new(self.push_filters_down(*expand.input));
671                LogicalOperator::Expand(expand)
672            }
673            LogicalOperator::Join(mut join) => {
674                join.left = Box::new(self.push_filters_down(*join.left));
675                join.right = Box::new(self.push_filters_down(*join.right));
676                LogicalOperator::Join(join)
677            }
678            LogicalOperator::Aggregate(mut agg) => {
679                agg.input = Box::new(self.push_filters_down(*agg.input));
680                LogicalOperator::Aggregate(agg)
681            }
682            // Leaf operators and unsupported operators are returned as-is
683            other => other,
684        }
685    }
686
687    /// Tries to push a filter predicate into the given operator.
688    ///
689    /// Returns either the predicate pushed into the operator, or a new
690    /// Filter operator on top if the predicate cannot be pushed further.
691    fn try_push_filter_into(
692        &self,
693        predicate: LogicalExpression,
694        op: LogicalOperator,
695    ) -> LogicalOperator {
696        match op {
697            // Can push through Project if predicate doesn't depend on computed columns
698            LogicalOperator::Project(mut proj) => {
699                let predicate_vars = self.extract_variables(&predicate);
700                let computed_vars = self.extract_projection_aliases(&proj.projections);
701
702                // If predicate doesn't use any computed columns, push through
703                if predicate_vars.is_disjoint(&computed_vars) {
704                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
705                    LogicalOperator::Project(proj)
706                } else {
707                    // Can't push through, keep filter on top
708                    LogicalOperator::Filter(FilterOp {
709                        predicate,
710                        input: Box::new(LogicalOperator::Project(proj)),
711                    })
712                }
713            }
714
715            // Can push through Return (which is like a projection)
716            LogicalOperator::Return(mut ret) => {
717                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
718                LogicalOperator::Return(ret)
719            }
720
721            // Can push through Expand if predicate doesn't use variables introduced by this expand
722            LogicalOperator::Expand(mut expand) => {
723                let predicate_vars = self.extract_variables(&predicate);
724
725                // Variables introduced by this expand are:
726                // - The target variable (to_variable)
727                // - The edge variable (if any)
728                // - The path alias (if any)
729                let mut introduced_vars = vec![&expand.to_variable];
730                if let Some(ref edge_var) = expand.edge_variable {
731                    introduced_vars.push(edge_var);
732                }
733                if let Some(ref path_alias) = expand.path_alias {
734                    introduced_vars.push(path_alias);
735                }
736
737                // Check if predicate uses any variables introduced by this expand
738                let uses_introduced_vars =
739                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
740
741                if !uses_introduced_vars {
742                    // Predicate doesn't use vars from this expand, so push through
743                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
744                    LogicalOperator::Expand(expand)
745                } else {
746                    // Keep filter after expand
747                    LogicalOperator::Filter(FilterOp {
748                        predicate,
749                        input: Box::new(LogicalOperator::Expand(expand)),
750                    })
751                }
752            }
753
754            // Can push through Join to left/right side based on variables used
755            LogicalOperator::Join(mut join) => {
756                let predicate_vars = self.extract_variables(&predicate);
757                let left_vars = self.collect_output_variables(&join.left);
758                let right_vars = self.collect_output_variables(&join.right);
759
760                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
761                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
762
763                if uses_left && !uses_right {
764                    // Push to left side
765                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
766                    LogicalOperator::Join(join)
767                } else if uses_right && !uses_left {
768                    // Push to right side
769                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
770                    LogicalOperator::Join(join)
771                } else {
772                    // Uses both sides - keep above join
773                    LogicalOperator::Filter(FilterOp {
774                        predicate,
775                        input: Box::new(LogicalOperator::Join(join)),
776                    })
777                }
778            }
779
780            // Cannot push through Aggregate (predicate refers to aggregated values)
781            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
782                predicate,
783                input: Box::new(LogicalOperator::Aggregate(agg)),
784            }),
785
786            // For NodeScan, we've reached the bottom - keep filter on top
787            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
788                predicate,
789                input: Box::new(LogicalOperator::NodeScan(scan)),
790            }),
791
792            // For other operators, keep filter on top
793            other => LogicalOperator::Filter(FilterOp {
794                predicate,
795                input: Box::new(other),
796            }),
797        }
798    }
799
800    /// Collects all output variable names from an operator.
801    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
802        let mut vars = HashSet::new();
803        Self::collect_output_variables_recursive(op, &mut vars);
804        vars
805    }
806
807    /// Recursively collects output variables from an operator.
808    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
809        match op {
810            LogicalOperator::NodeScan(scan) => {
811                vars.insert(scan.variable.clone());
812            }
813            LogicalOperator::EdgeScan(scan) => {
814                vars.insert(scan.variable.clone());
815            }
816            LogicalOperator::Expand(expand) => {
817                vars.insert(expand.to_variable.clone());
818                if let Some(edge_var) = &expand.edge_variable {
819                    vars.insert(edge_var.clone());
820                }
821                Self::collect_output_variables_recursive(&expand.input, vars);
822            }
823            LogicalOperator::Filter(filter) => {
824                Self::collect_output_variables_recursive(&filter.input, vars);
825            }
826            LogicalOperator::Project(proj) => {
827                for p in &proj.projections {
828                    if let Some(alias) = &p.alias {
829                        vars.insert(alias.clone());
830                    }
831                }
832                Self::collect_output_variables_recursive(&proj.input, vars);
833            }
834            LogicalOperator::Join(join) => {
835                Self::collect_output_variables_recursive(&join.left, vars);
836                Self::collect_output_variables_recursive(&join.right, vars);
837            }
838            LogicalOperator::Aggregate(agg) => {
839                for expr in &agg.group_by {
840                    Self::collect_variables(expr, vars);
841                }
842                for agg_expr in &agg.aggregates {
843                    if let Some(alias) = &agg_expr.alias {
844                        vars.insert(alias.clone());
845                    }
846                }
847            }
848            LogicalOperator::Return(ret) => {
849                Self::collect_output_variables_recursive(&ret.input, vars);
850            }
851            LogicalOperator::Limit(limit) => {
852                Self::collect_output_variables_recursive(&limit.input, vars);
853            }
854            LogicalOperator::Skip(skip) => {
855                Self::collect_output_variables_recursive(&skip.input, vars);
856            }
857            LogicalOperator::Sort(sort) => {
858                Self::collect_output_variables_recursive(&sort.input, vars);
859            }
860            LogicalOperator::Distinct(distinct) => {
861                Self::collect_output_variables_recursive(&distinct.input, vars);
862            }
863            _ => {}
864        }
865    }
866
867    /// Extracts all variable names referenced in an expression.
868    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
869        let mut vars = HashSet::new();
870        Self::collect_variables(expr, &mut vars);
871        vars
872    }
873
874    /// Recursively collects variable names from an expression.
875    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
876        match expr {
877            LogicalExpression::Variable(name) => {
878                vars.insert(name.clone());
879            }
880            LogicalExpression::Property { variable, .. } => {
881                vars.insert(variable.clone());
882            }
883            LogicalExpression::Binary { left, right, .. } => {
884                Self::collect_variables(left, vars);
885                Self::collect_variables(right, vars);
886            }
887            LogicalExpression::Unary { operand, .. } => {
888                Self::collect_variables(operand, vars);
889            }
890            LogicalExpression::FunctionCall { args, .. } => {
891                for arg in args {
892                    Self::collect_variables(arg, vars);
893                }
894            }
895            LogicalExpression::List(items) => {
896                for item in items {
897                    Self::collect_variables(item, vars);
898                }
899            }
900            LogicalExpression::Map(pairs) => {
901                for (_, value) in pairs {
902                    Self::collect_variables(value, vars);
903                }
904            }
905            LogicalExpression::IndexAccess { base, index } => {
906                Self::collect_variables(base, vars);
907                Self::collect_variables(index, vars);
908            }
909            LogicalExpression::SliceAccess { base, start, end } => {
910                Self::collect_variables(base, vars);
911                if let Some(s) = start {
912                    Self::collect_variables(s, vars);
913                }
914                if let Some(e) = end {
915                    Self::collect_variables(e, vars);
916                }
917            }
918            LogicalExpression::Case {
919                operand,
920                when_clauses,
921                else_clause,
922            } => {
923                if let Some(op) = operand {
924                    Self::collect_variables(op, vars);
925                }
926                for (cond, result) in when_clauses {
927                    Self::collect_variables(cond, vars);
928                    Self::collect_variables(result, vars);
929                }
930                if let Some(else_expr) = else_clause {
931                    Self::collect_variables(else_expr, vars);
932                }
933            }
934            LogicalExpression::Labels(var)
935            | LogicalExpression::Type(var)
936            | LogicalExpression::Id(var) => {
937                vars.insert(var.clone());
938            }
939            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
940            LogicalExpression::ListComprehension {
941                list_expr,
942                filter_expr,
943                map_expr,
944                ..
945            } => {
946                Self::collect_variables(list_expr, vars);
947                if let Some(filter) = filter_expr {
948                    Self::collect_variables(filter, vars);
949                }
950                Self::collect_variables(map_expr, vars);
951            }
952            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
953                // Subqueries have their own variable scope
954            }
955        }
956    }
957
958    /// Extracts aliases from projection expressions.
959    fn extract_projection_aliases(
960        &self,
961        projections: &[crate::query::plan::Projection],
962    ) -> HashSet<String> {
963        projections.iter().filter_map(|p| p.alias.clone()).collect()
964    }
965}
966
967impl Default for Optimizer {
968    fn default() -> Self {
969        Self::new()
970    }
971}
972
973#[cfg(test)]
974mod tests {
975    use super::*;
976    use crate::query::plan::{
977        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
978        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
979        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
980    };
981    use grafeo_common::types::Value;
982
983    #[test]
984    fn test_optimizer_filter_pushdown_simple() {
985        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
986        // Before: Return -> Filter -> NodeScan
987        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
988
989        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
990            items: vec![ReturnItem {
991                expression: LogicalExpression::Variable("n".to_string()),
992                alias: None,
993            }],
994            distinct: false,
995            input: Box::new(LogicalOperator::Filter(FilterOp {
996                predicate: LogicalExpression::Binary {
997                    left: Box::new(LogicalExpression::Property {
998                        variable: "n".to_string(),
999                        property: "age".to_string(),
1000                    }),
1001                    op: BinaryOp::Gt,
1002                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1003                },
1004                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1005                    variable: "n".to_string(),
1006                    label: Some("Person".to_string()),
1007                    input: None,
1008                })),
1009            })),
1010        }));
1011
1012        let optimizer = Optimizer::new();
1013        let optimized = optimizer.optimize(plan).unwrap();
1014
1015        // The structure should remain similar (filter stays near scan)
1016        if let LogicalOperator::Return(ret) = &optimized.root
1017            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1018            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1019        {
1020            assert_eq!(scan.variable, "n");
1021            return;
1022        }
1023        panic!("Expected Return -> Filter -> NodeScan structure");
1024    }
1025
1026    #[test]
1027    fn test_optimizer_filter_pushdown_through_expand() {
1028        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1029        // The filter on 'a' should be pushed before the expand
1030
1031        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1032            items: vec![ReturnItem {
1033                expression: LogicalExpression::Variable("b".to_string()),
1034                alias: None,
1035            }],
1036            distinct: false,
1037            input: Box::new(LogicalOperator::Filter(FilterOp {
1038                predicate: LogicalExpression::Binary {
1039                    left: Box::new(LogicalExpression::Property {
1040                        variable: "a".to_string(),
1041                        property: "age".to_string(),
1042                    }),
1043                    op: BinaryOp::Gt,
1044                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1045                },
1046                input: Box::new(LogicalOperator::Expand(ExpandOp {
1047                    from_variable: "a".to_string(),
1048                    to_variable: "b".to_string(),
1049                    edge_variable: None,
1050                    direction: ExpandDirection::Outgoing,
1051                    edge_type: Some("KNOWS".to_string()),
1052                    min_hops: 1,
1053                    max_hops: Some(1),
1054                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1055                        variable: "a".to_string(),
1056                        label: Some("Person".to_string()),
1057                        input: None,
1058                    })),
1059                    path_alias: None,
1060                })),
1061            })),
1062        }));
1063
1064        let optimizer = Optimizer::new();
1065        let optimized = optimizer.optimize(plan).unwrap();
1066
1067        // Filter on 'a' should be pushed before the expand
1068        // Expected: Return -> Expand -> Filter -> NodeScan
1069        if let LogicalOperator::Return(ret) = &optimized.root
1070            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1071            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1072            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1073        {
1074            assert_eq!(scan.variable, "a");
1075            assert_eq!(expand.from_variable, "a");
1076            assert_eq!(expand.to_variable, "b");
1077            return;
1078        }
1079        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1080    }
1081
1082    #[test]
1083    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1084        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1085        // The filter on 'b' should NOT be pushed before the expand
1086
1087        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1088            items: vec![ReturnItem {
1089                expression: LogicalExpression::Variable("a".to_string()),
1090                alias: None,
1091            }],
1092            distinct: false,
1093            input: Box::new(LogicalOperator::Filter(FilterOp {
1094                predicate: LogicalExpression::Binary {
1095                    left: Box::new(LogicalExpression::Property {
1096                        variable: "b".to_string(),
1097                        property: "age".to_string(),
1098                    }),
1099                    op: BinaryOp::Gt,
1100                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1101                },
1102                input: Box::new(LogicalOperator::Expand(ExpandOp {
1103                    from_variable: "a".to_string(),
1104                    to_variable: "b".to_string(),
1105                    edge_variable: None,
1106                    direction: ExpandDirection::Outgoing,
1107                    edge_type: Some("KNOWS".to_string()),
1108                    min_hops: 1,
1109                    max_hops: Some(1),
1110                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1111                        variable: "a".to_string(),
1112                        label: Some("Person".to_string()),
1113                        input: None,
1114                    })),
1115                    path_alias: None,
1116                })),
1117            })),
1118        }));
1119
1120        let optimizer = Optimizer::new();
1121        let optimized = optimizer.optimize(plan).unwrap();
1122
1123        // Filter on 'b' should stay after the expand
1124        // Expected: Return -> Filter -> Expand -> NodeScan
1125        if let LogicalOperator::Return(ret) = &optimized.root
1126            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1127        {
1128            // Check that the filter is on 'b'
1129            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1130                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1131            {
1132                assert_eq!(variable, "b");
1133            }
1134
1135            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1136                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1137            {
1138                return;
1139            }
1140        }
1141        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1142    }
1143
1144    #[test]
1145    fn test_optimizer_extract_variables() {
1146        let optimizer = Optimizer::new();
1147
1148        let expr = LogicalExpression::Binary {
1149            left: Box::new(LogicalExpression::Property {
1150                variable: "n".to_string(),
1151                property: "age".to_string(),
1152            }),
1153            op: BinaryOp::Gt,
1154            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1155        };
1156
1157        let vars = optimizer.extract_variables(&expr);
1158        assert_eq!(vars.len(), 1);
1159        assert!(vars.contains("n"));
1160    }
1161
1162    // Additional tests for optimizer configuration
1163
1164    #[test]
1165    fn test_optimizer_default() {
1166        let optimizer = Optimizer::default();
1167        // Should be able to optimize an empty plan
1168        let plan = LogicalPlan::new(LogicalOperator::Empty);
1169        let result = optimizer.optimize(plan);
1170        assert!(result.is_ok());
1171    }
1172
1173    #[test]
1174    fn test_optimizer_with_filter_pushdown_disabled() {
1175        let optimizer = Optimizer::new().with_filter_pushdown(false);
1176
1177        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1178            items: vec![ReturnItem {
1179                expression: LogicalExpression::Variable("n".to_string()),
1180                alias: None,
1181            }],
1182            distinct: false,
1183            input: Box::new(LogicalOperator::Filter(FilterOp {
1184                predicate: LogicalExpression::Literal(Value::Bool(true)),
1185                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1186                    variable: "n".to_string(),
1187                    label: None,
1188                    input: None,
1189                })),
1190            })),
1191        }));
1192
1193        let optimized = optimizer.optimize(plan).unwrap();
1194        // Structure should be unchanged
1195        if let LogicalOperator::Return(ret) = &optimized.root
1196            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1197        {
1198            return;
1199        }
1200        panic!("Expected unchanged structure");
1201    }
1202
1203    #[test]
1204    fn test_optimizer_with_join_reorder_disabled() {
1205        let optimizer = Optimizer::new().with_join_reorder(false);
1206        assert!(
1207            optimizer
1208                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1209                .is_ok()
1210        );
1211    }
1212
1213    #[test]
1214    fn test_optimizer_with_cost_model() {
1215        let cost_model = CostModel::new();
1216        let optimizer = Optimizer::new().with_cost_model(cost_model);
1217        assert!(
1218            optimizer
1219                .cost_model()
1220                .estimate(&LogicalOperator::Empty, 0.0)
1221                .total()
1222                < 0.001
1223        );
1224    }
1225
1226    #[test]
1227    fn test_optimizer_with_cardinality_estimator() {
1228        let mut estimator = CardinalityEstimator::new();
1229        estimator.add_table_stats("Test", TableStats::new(500));
1230        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1231
1232        let scan = LogicalOperator::NodeScan(NodeScanOp {
1233            variable: "n".to_string(),
1234            label: Some("Test".to_string()),
1235            input: None,
1236        });
1237        let plan = LogicalPlan::new(scan);
1238
1239        let cardinality = optimizer.estimate_cardinality(&plan);
1240        assert!((cardinality - 500.0).abs() < 0.001);
1241    }
1242
1243    #[test]
1244    fn test_optimizer_estimate_cost() {
1245        let optimizer = Optimizer::new();
1246        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1247            variable: "n".to_string(),
1248            label: None,
1249            input: None,
1250        }));
1251
1252        let cost = optimizer.estimate_cost(&plan);
1253        assert!(cost.total() > 0.0);
1254    }
1255
1256    // Filter pushdown through various operators
1257
1258    #[test]
1259    fn test_filter_pushdown_through_project() {
1260        let optimizer = Optimizer::new();
1261
1262        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1263            predicate: LogicalExpression::Binary {
1264                left: Box::new(LogicalExpression::Property {
1265                    variable: "n".to_string(),
1266                    property: "age".to_string(),
1267                }),
1268                op: BinaryOp::Gt,
1269                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1270            },
1271            input: Box::new(LogicalOperator::Project(ProjectOp {
1272                projections: vec![Projection {
1273                    expression: LogicalExpression::Variable("n".to_string()),
1274                    alias: None,
1275                }],
1276                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1277                    variable: "n".to_string(),
1278                    label: None,
1279                    input: None,
1280                })),
1281            })),
1282        }));
1283
1284        let optimized = optimizer.optimize(plan).unwrap();
1285
1286        // Filter should be pushed through Project
1287        if let LogicalOperator::Project(proj) = &optimized.root
1288            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1289        {
1290            return;
1291        }
1292        panic!("Expected Project -> Filter structure");
1293    }
1294
1295    #[test]
1296    fn test_filter_not_pushed_through_project_with_alias() {
1297        let optimizer = Optimizer::new();
1298
1299        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1300        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1301            predicate: LogicalExpression::Binary {
1302                left: Box::new(LogicalExpression::Variable("x".to_string())),
1303                op: BinaryOp::Gt,
1304                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1305            },
1306            input: Box::new(LogicalOperator::Project(ProjectOp {
1307                projections: vec![Projection {
1308                    expression: LogicalExpression::Property {
1309                        variable: "n".to_string(),
1310                        property: "age".to_string(),
1311                    },
1312                    alias: Some("x".to_string()),
1313                }],
1314                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1315                    variable: "n".to_string(),
1316                    label: None,
1317                    input: None,
1318                })),
1319            })),
1320        }));
1321
1322        let optimized = optimizer.optimize(plan).unwrap();
1323
1324        // Filter should stay above Project
1325        if let LogicalOperator::Filter(filter) = &optimized.root
1326            && let LogicalOperator::Project(_) = filter.input.as_ref()
1327        {
1328            return;
1329        }
1330        panic!("Expected Filter -> Project structure");
1331    }
1332
1333    #[test]
1334    fn test_filter_pushdown_through_limit() {
1335        let optimizer = Optimizer::new();
1336
1337        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1338            predicate: LogicalExpression::Literal(Value::Bool(true)),
1339            input: Box::new(LogicalOperator::Limit(LimitOp {
1340                count: 10,
1341                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1342                    variable: "n".to_string(),
1343                    label: None,
1344                    input: None,
1345                })),
1346            })),
1347        }));
1348
1349        let optimized = optimizer.optimize(plan).unwrap();
1350
1351        // Filter stays above Limit (cannot be pushed through)
1352        if let LogicalOperator::Filter(filter) = &optimized.root
1353            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1354        {
1355            return;
1356        }
1357        panic!("Expected Filter -> Limit structure");
1358    }
1359
1360    #[test]
1361    fn test_filter_pushdown_through_sort() {
1362        let optimizer = Optimizer::new();
1363
1364        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1365            predicate: LogicalExpression::Literal(Value::Bool(true)),
1366            input: Box::new(LogicalOperator::Sort(SortOp {
1367                keys: vec![SortKey {
1368                    expression: LogicalExpression::Variable("n".to_string()),
1369                    order: SortOrder::Ascending,
1370                }],
1371                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1372                    variable: "n".to_string(),
1373                    label: None,
1374                    input: None,
1375                })),
1376            })),
1377        }));
1378
1379        let optimized = optimizer.optimize(plan).unwrap();
1380
1381        // Filter stays above Sort
1382        if let LogicalOperator::Filter(filter) = &optimized.root
1383            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1384        {
1385            return;
1386        }
1387        panic!("Expected Filter -> Sort structure");
1388    }
1389
1390    #[test]
1391    fn test_filter_pushdown_through_distinct() {
1392        let optimizer = Optimizer::new();
1393
1394        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1395            predicate: LogicalExpression::Literal(Value::Bool(true)),
1396            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1397                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1398                    variable: "n".to_string(),
1399                    label: None,
1400                    input: None,
1401                })),
1402                columns: None,
1403            })),
1404        }));
1405
1406        let optimized = optimizer.optimize(plan).unwrap();
1407
1408        // Filter stays above Distinct
1409        if let LogicalOperator::Filter(filter) = &optimized.root
1410            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1411        {
1412            return;
1413        }
1414        panic!("Expected Filter -> Distinct structure");
1415    }
1416
1417    #[test]
1418    fn test_filter_not_pushed_through_aggregate() {
1419        let optimizer = Optimizer::new();
1420
1421        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1422            predicate: LogicalExpression::Binary {
1423                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1424                op: BinaryOp::Gt,
1425                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1426            },
1427            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1428                group_by: vec![],
1429                aggregates: vec![AggregateExpr {
1430                    function: AggregateFunction::Count,
1431                    expression: None,
1432                    distinct: false,
1433                    alias: Some("cnt".to_string()),
1434                    percentile: None,
1435                }],
1436                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1437                    variable: "n".to_string(),
1438                    label: None,
1439                    input: None,
1440                })),
1441                having: None,
1442            })),
1443        }));
1444
1445        let optimized = optimizer.optimize(plan).unwrap();
1446
1447        // Filter should stay above Aggregate
1448        if let LogicalOperator::Filter(filter) = &optimized.root
1449            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1450        {
1451            return;
1452        }
1453        panic!("Expected Filter -> Aggregate structure");
1454    }
1455
1456    #[test]
1457    fn test_filter_pushdown_to_left_join_side() {
1458        let optimizer = Optimizer::new();
1459
1460        // Filter on left variable should be pushed to left side
1461        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1462            predicate: LogicalExpression::Binary {
1463                left: Box::new(LogicalExpression::Property {
1464                    variable: "a".to_string(),
1465                    property: "age".to_string(),
1466                }),
1467                op: BinaryOp::Gt,
1468                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1469            },
1470            input: Box::new(LogicalOperator::Join(JoinOp {
1471                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1472                    variable: "a".to_string(),
1473                    label: Some("Person".to_string()),
1474                    input: None,
1475                })),
1476                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1477                    variable: "b".to_string(),
1478                    label: Some("Company".to_string()),
1479                    input: None,
1480                })),
1481                join_type: JoinType::Inner,
1482                conditions: vec![],
1483            })),
1484        }));
1485
1486        let optimized = optimizer.optimize(plan).unwrap();
1487
1488        // Filter should be pushed to left side of join
1489        if let LogicalOperator::Join(join) = &optimized.root
1490            && let LogicalOperator::Filter(_) = join.left.as_ref()
1491        {
1492            return;
1493        }
1494        panic!("Expected Join with Filter on left side");
1495    }
1496
1497    #[test]
1498    fn test_filter_pushdown_to_right_join_side() {
1499        let optimizer = Optimizer::new();
1500
1501        // Filter on right variable should be pushed to right side
1502        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1503            predicate: LogicalExpression::Binary {
1504                left: Box::new(LogicalExpression::Property {
1505                    variable: "b".to_string(),
1506                    property: "name".to_string(),
1507                }),
1508                op: BinaryOp::Eq,
1509                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1510            },
1511            input: Box::new(LogicalOperator::Join(JoinOp {
1512                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1513                    variable: "a".to_string(),
1514                    label: Some("Person".to_string()),
1515                    input: None,
1516                })),
1517                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1518                    variable: "b".to_string(),
1519                    label: Some("Company".to_string()),
1520                    input: None,
1521                })),
1522                join_type: JoinType::Inner,
1523                conditions: vec![],
1524            })),
1525        }));
1526
1527        let optimized = optimizer.optimize(plan).unwrap();
1528
1529        // Filter should be pushed to right side of join
1530        if let LogicalOperator::Join(join) = &optimized.root
1531            && let LogicalOperator::Filter(_) = join.right.as_ref()
1532        {
1533            return;
1534        }
1535        panic!("Expected Join with Filter on right side");
1536    }
1537
1538    #[test]
1539    fn test_filter_not_pushed_when_uses_both_join_sides() {
1540        let optimizer = Optimizer::new();
1541
1542        // Filter using both variables should stay above join
1543        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1544            predicate: LogicalExpression::Binary {
1545                left: Box::new(LogicalExpression::Property {
1546                    variable: "a".to_string(),
1547                    property: "id".to_string(),
1548                }),
1549                op: BinaryOp::Eq,
1550                right: Box::new(LogicalExpression::Property {
1551                    variable: "b".to_string(),
1552                    property: "a_id".to_string(),
1553                }),
1554            },
1555            input: Box::new(LogicalOperator::Join(JoinOp {
1556                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1557                    variable: "a".to_string(),
1558                    label: None,
1559                    input: None,
1560                })),
1561                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1562                    variable: "b".to_string(),
1563                    label: None,
1564                    input: None,
1565                })),
1566                join_type: JoinType::Inner,
1567                conditions: vec![],
1568            })),
1569        }));
1570
1571        let optimized = optimizer.optimize(plan).unwrap();
1572
1573        // Filter should stay above join
1574        if let LogicalOperator::Filter(filter) = &optimized.root
1575            && let LogicalOperator::Join(_) = filter.input.as_ref()
1576        {
1577            return;
1578        }
1579        panic!("Expected Filter -> Join structure");
1580    }
1581
1582    // Variable extraction tests
1583
1584    #[test]
1585    fn test_extract_variables_from_variable() {
1586        let optimizer = Optimizer::new();
1587        let expr = LogicalExpression::Variable("x".to_string());
1588        let vars = optimizer.extract_variables(&expr);
1589        assert_eq!(vars.len(), 1);
1590        assert!(vars.contains("x"));
1591    }
1592
1593    #[test]
1594    fn test_extract_variables_from_unary() {
1595        let optimizer = Optimizer::new();
1596        let expr = LogicalExpression::Unary {
1597            op: UnaryOp::Not,
1598            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1599        };
1600        let vars = optimizer.extract_variables(&expr);
1601        assert_eq!(vars.len(), 1);
1602        assert!(vars.contains("x"));
1603    }
1604
1605    #[test]
1606    fn test_extract_variables_from_function_call() {
1607        let optimizer = Optimizer::new();
1608        let expr = LogicalExpression::FunctionCall {
1609            name: "length".to_string(),
1610            args: vec![
1611                LogicalExpression::Variable("a".to_string()),
1612                LogicalExpression::Variable("b".to_string()),
1613            ],
1614            distinct: false,
1615        };
1616        let vars = optimizer.extract_variables(&expr);
1617        assert_eq!(vars.len(), 2);
1618        assert!(vars.contains("a"));
1619        assert!(vars.contains("b"));
1620    }
1621
1622    #[test]
1623    fn test_extract_variables_from_list() {
1624        let optimizer = Optimizer::new();
1625        let expr = LogicalExpression::List(vec![
1626            LogicalExpression::Variable("a".to_string()),
1627            LogicalExpression::Literal(Value::Int64(1)),
1628            LogicalExpression::Variable("b".to_string()),
1629        ]);
1630        let vars = optimizer.extract_variables(&expr);
1631        assert_eq!(vars.len(), 2);
1632        assert!(vars.contains("a"));
1633        assert!(vars.contains("b"));
1634    }
1635
1636    #[test]
1637    fn test_extract_variables_from_map() {
1638        let optimizer = Optimizer::new();
1639        let expr = LogicalExpression::Map(vec![
1640            (
1641                "key1".to_string(),
1642                LogicalExpression::Variable("a".to_string()),
1643            ),
1644            (
1645                "key2".to_string(),
1646                LogicalExpression::Variable("b".to_string()),
1647            ),
1648        ]);
1649        let vars = optimizer.extract_variables(&expr);
1650        assert_eq!(vars.len(), 2);
1651        assert!(vars.contains("a"));
1652        assert!(vars.contains("b"));
1653    }
1654
1655    #[test]
1656    fn test_extract_variables_from_index_access() {
1657        let optimizer = Optimizer::new();
1658        let expr = LogicalExpression::IndexAccess {
1659            base: Box::new(LogicalExpression::Variable("list".to_string())),
1660            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1661        };
1662        let vars = optimizer.extract_variables(&expr);
1663        assert_eq!(vars.len(), 2);
1664        assert!(vars.contains("list"));
1665        assert!(vars.contains("idx"));
1666    }
1667
1668    #[test]
1669    fn test_extract_variables_from_slice_access() {
1670        let optimizer = Optimizer::new();
1671        let expr = LogicalExpression::SliceAccess {
1672            base: Box::new(LogicalExpression::Variable("list".to_string())),
1673            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1674            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1675        };
1676        let vars = optimizer.extract_variables(&expr);
1677        assert_eq!(vars.len(), 3);
1678        assert!(vars.contains("list"));
1679        assert!(vars.contains("s"));
1680        assert!(vars.contains("e"));
1681    }
1682
1683    #[test]
1684    fn test_extract_variables_from_case() {
1685        let optimizer = Optimizer::new();
1686        let expr = LogicalExpression::Case {
1687            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1688            when_clauses: vec![(
1689                LogicalExpression::Literal(Value::Int64(1)),
1690                LogicalExpression::Variable("a".to_string()),
1691            )],
1692            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1693        };
1694        let vars = optimizer.extract_variables(&expr);
1695        assert_eq!(vars.len(), 3);
1696        assert!(vars.contains("x"));
1697        assert!(vars.contains("a"));
1698        assert!(vars.contains("b"));
1699    }
1700
1701    #[test]
1702    fn test_extract_variables_from_labels() {
1703        let optimizer = Optimizer::new();
1704        let expr = LogicalExpression::Labels("n".to_string());
1705        let vars = optimizer.extract_variables(&expr);
1706        assert_eq!(vars.len(), 1);
1707        assert!(vars.contains("n"));
1708    }
1709
1710    #[test]
1711    fn test_extract_variables_from_type() {
1712        let optimizer = Optimizer::new();
1713        let expr = LogicalExpression::Type("e".to_string());
1714        let vars = optimizer.extract_variables(&expr);
1715        assert_eq!(vars.len(), 1);
1716        assert!(vars.contains("e"));
1717    }
1718
1719    #[test]
1720    fn test_extract_variables_from_id() {
1721        let optimizer = Optimizer::new();
1722        let expr = LogicalExpression::Id("n".to_string());
1723        let vars = optimizer.extract_variables(&expr);
1724        assert_eq!(vars.len(), 1);
1725        assert!(vars.contains("n"));
1726    }
1727
1728    #[test]
1729    fn test_extract_variables_from_list_comprehension() {
1730        let optimizer = Optimizer::new();
1731        let expr = LogicalExpression::ListComprehension {
1732            variable: "x".to_string(),
1733            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1734            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1735            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1736        };
1737        let vars = optimizer.extract_variables(&expr);
1738        assert!(vars.contains("items"));
1739        assert!(vars.contains("pred"));
1740        assert!(vars.contains("result"));
1741    }
1742
1743    #[test]
1744    fn test_extract_variables_from_literal_and_parameter() {
1745        let optimizer = Optimizer::new();
1746
1747        let literal = LogicalExpression::Literal(Value::Int64(42));
1748        assert!(optimizer.extract_variables(&literal).is_empty());
1749
1750        let param = LogicalExpression::Parameter("p".to_string());
1751        assert!(optimizer.extract_variables(&param).is_empty());
1752    }
1753
1754    // Recursive filter pushdown tests
1755
1756    #[test]
1757    fn test_recursive_filter_pushdown_through_skip() {
1758        let optimizer = Optimizer::new();
1759
1760        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1761            items: vec![ReturnItem {
1762                expression: LogicalExpression::Variable("n".to_string()),
1763                alias: None,
1764            }],
1765            distinct: false,
1766            input: Box::new(LogicalOperator::Filter(FilterOp {
1767                predicate: LogicalExpression::Literal(Value::Bool(true)),
1768                input: Box::new(LogicalOperator::Skip(SkipOp {
1769                    count: 5,
1770                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1771                        variable: "n".to_string(),
1772                        label: None,
1773                        input: None,
1774                    })),
1775                })),
1776            })),
1777        }));
1778
1779        let optimized = optimizer.optimize(plan).unwrap();
1780
1781        // Verify optimization succeeded
1782        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1783    }
1784
1785    #[test]
1786    fn test_nested_filter_pushdown() {
1787        let optimizer = Optimizer::new();
1788
1789        // Multiple nested filters
1790        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1791            items: vec![ReturnItem {
1792                expression: LogicalExpression::Variable("n".to_string()),
1793                alias: None,
1794            }],
1795            distinct: false,
1796            input: Box::new(LogicalOperator::Filter(FilterOp {
1797                predicate: LogicalExpression::Binary {
1798                    left: Box::new(LogicalExpression::Property {
1799                        variable: "n".to_string(),
1800                        property: "x".to_string(),
1801                    }),
1802                    op: BinaryOp::Gt,
1803                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1804                },
1805                input: Box::new(LogicalOperator::Filter(FilterOp {
1806                    predicate: LogicalExpression::Binary {
1807                        left: Box::new(LogicalExpression::Property {
1808                            variable: "n".to_string(),
1809                            property: "y".to_string(),
1810                        }),
1811                        op: BinaryOp::Lt,
1812                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1813                    },
1814                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1815                        variable: "n".to_string(),
1816                        label: None,
1817                        input: None,
1818                    })),
1819                })),
1820            })),
1821        }));
1822
1823        let optimized = optimizer.optimize(plan).unwrap();
1824        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1825    }
1826}