Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{FilterOp, LogicalExpression, LogicalOperator, LogicalPlan};
25use grafeo_common::utils::error::Result;
26use std::collections::HashSet;
27
/// Information about a join condition for join reordering.
///
/// One value is recorded for each join condition whose two sides can both be
/// attributed to a variable; the fields are later passed to
/// `JoinGraphBuilder::add_join_condition` when the join graph is built.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left side of the condition.
    left_var: String,
    /// Variable referenced by the right side of the condition.
    right_var: String,
    /// The complete left-hand expression of the condition.
    left_expr: LogicalExpression,
    /// The complete right-hand expression of the condition.
    right_expr: LogicalExpression,
}
36
/// A column required by the query, used for projection pushdown.
///
/// Values are accumulated in a `HashSet`, which is why the type derives
/// `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding)
    Variable(String),
    /// A specific property of a variable, stored as `(variable, property)`.
    Property(String, String),
}
45
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// Enabled passes run in a fixed order: filter pushdown, then join
/// reordering, then projection pushdown.
pub struct Optimizer {
    /// Whether to enable filter pushdown (enabled by `new()`).
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering (enabled by `new()`).
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown (enabled by `new()`).
    enable_projection_pushdown: bool,
    /// Cost model used by `estimate_cost` and by the join-order search.
    cost_model: CostModel,
    /// Cardinality estimator used by the `estimate_*` methods and by the
    /// join-order search.
    card_estimator: CardinalityEstimator,
}
62
63impl Optimizer {
64    /// Creates a new optimizer with default settings.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            enable_filter_pushdown: true,
69            enable_join_reorder: true,
70            enable_projection_pushdown: true,
71            cost_model: CostModel::new(),
72            card_estimator: CardinalityEstimator::new(),
73        }
74    }
75
76    /// Creates an optimizer with cardinality estimates from the store's statistics.
77    ///
78    /// Pre-populates the cardinality estimator with per-label row counts and
79    /// edge type fanout. Ensures statistics are fresh before reading.
80    #[must_use]
81    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
82        store.ensure_statistics_fresh();
83        let stats = store.statistics();
84        let estimator = CardinalityEstimator::from_statistics(&stats);
85        Self {
86            enable_filter_pushdown: true,
87            enable_join_reorder: true,
88            enable_projection_pushdown: true,
89            cost_model: CostModel::new(),
90            card_estimator: estimator,
91        }
92    }
93
94    /// Enables or disables filter pushdown.
95    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
96        self.enable_filter_pushdown = enabled;
97        self
98    }
99
100    /// Enables or disables join reordering.
101    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
102        self.enable_join_reorder = enabled;
103        self
104    }
105
106    /// Enables or disables projection pushdown.
107    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
108        self.enable_projection_pushdown = enabled;
109        self
110    }
111
112    /// Sets the cost model.
113    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
114        self.cost_model = cost_model;
115        self
116    }
117
118    /// Sets the cardinality estimator.
119    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
120        self.card_estimator = estimator;
121        self
122    }
123
124    /// Sets the selectivity configuration for the cardinality estimator.
125    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
126        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
127        self
128    }
129
130    /// Returns a reference to the cost model.
131    pub fn cost_model(&self) -> &CostModel {
132        &self.cost_model
133    }
134
135    /// Returns a reference to the cardinality estimator.
136    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
137        &self.card_estimator
138    }
139
140    /// Estimates the cost of a plan.
141    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
142        let cardinality = self.card_estimator.estimate(&plan.root);
143        self.cost_model.estimate(&plan.root, cardinality)
144    }
145
146    /// Estimates the cardinality of a plan.
147    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
148        self.card_estimator.estimate(&plan.root)
149    }
150
151    /// Optimizes a logical plan.
152    ///
153    /// # Errors
154    ///
155    /// Returns an error if optimization fails.
156    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
157        let mut root = plan.root;
158
159        // Apply optimization rules
160        if self.enable_filter_pushdown {
161            root = self.push_filters_down(root);
162        }
163
164        if self.enable_join_reorder {
165            root = self.reorder_joins(root);
166        }
167
168        if self.enable_projection_pushdown {
169            root = self.push_projections_down(root);
170        }
171
172        Ok(LogicalPlan::new(root))
173    }
174
175    /// Pushes projections down the operator tree to eliminate unused columns early.
176    ///
177    /// This optimization:
178    /// 1. Collects required variables/properties from the root
179    /// 2. Propagates requirements down through the tree
180    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
181    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
182        // Collect required columns from the top of the plan
183        let required = self.collect_required_columns(&op);
184
185        // Push projections down
186        self.push_projections_recursive(op, &required)
187    }
188
189    /// Collects all variables and properties required by an operator and its ancestors.
190    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
191        let mut required = HashSet::new();
192        Self::collect_required_recursive(op, &mut required);
193        required
194    }
195
196    /// Recursively collects required columns.
197    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
198        match op {
199            LogicalOperator::Return(ret) => {
200                for item in &ret.items {
201                    Self::collect_from_expression(&item.expression, required);
202                }
203                Self::collect_required_recursive(&ret.input, required);
204            }
205            LogicalOperator::Project(proj) => {
206                for p in &proj.projections {
207                    Self::collect_from_expression(&p.expression, required);
208                }
209                Self::collect_required_recursive(&proj.input, required);
210            }
211            LogicalOperator::Filter(filter) => {
212                Self::collect_from_expression(&filter.predicate, required);
213                Self::collect_required_recursive(&filter.input, required);
214            }
215            LogicalOperator::Sort(sort) => {
216                for key in &sort.keys {
217                    Self::collect_from_expression(&key.expression, required);
218                }
219                Self::collect_required_recursive(&sort.input, required);
220            }
221            LogicalOperator::Aggregate(agg) => {
222                for expr in &agg.group_by {
223                    Self::collect_from_expression(expr, required);
224                }
225                for agg_expr in &agg.aggregates {
226                    if let Some(ref expr) = agg_expr.expression {
227                        Self::collect_from_expression(expr, required);
228                    }
229                }
230                if let Some(ref having) = agg.having {
231                    Self::collect_from_expression(having, required);
232                }
233                Self::collect_required_recursive(&agg.input, required);
234            }
235            LogicalOperator::Join(join) => {
236                for cond in &join.conditions {
237                    Self::collect_from_expression(&cond.left, required);
238                    Self::collect_from_expression(&cond.right, required);
239                }
240                Self::collect_required_recursive(&join.left, required);
241                Self::collect_required_recursive(&join.right, required);
242            }
243            LogicalOperator::Expand(expand) => {
244                // The source and target variables are needed
245                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
246                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
247                if let Some(ref edge_var) = expand.edge_variable {
248                    required.insert(RequiredColumn::Variable(edge_var.clone()));
249                }
250                Self::collect_required_recursive(&expand.input, required);
251            }
252            LogicalOperator::Limit(limit) => {
253                Self::collect_required_recursive(&limit.input, required);
254            }
255            LogicalOperator::Skip(skip) => {
256                Self::collect_required_recursive(&skip.input, required);
257            }
258            LogicalOperator::Distinct(distinct) => {
259                Self::collect_required_recursive(&distinct.input, required);
260            }
261            LogicalOperator::NodeScan(scan) => {
262                required.insert(RequiredColumn::Variable(scan.variable.clone()));
263            }
264            LogicalOperator::EdgeScan(scan) => {
265                required.insert(RequiredColumn::Variable(scan.variable.clone()));
266            }
267            _ => {}
268        }
269    }
270
271    /// Collects required columns from an expression.
272    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
273        match expr {
274            LogicalExpression::Variable(var) => {
275                required.insert(RequiredColumn::Variable(var.clone()));
276            }
277            LogicalExpression::Property { variable, property } => {
278                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
279                required.insert(RequiredColumn::Variable(variable.clone()));
280            }
281            LogicalExpression::Binary { left, right, .. } => {
282                Self::collect_from_expression(left, required);
283                Self::collect_from_expression(right, required);
284            }
285            LogicalExpression::Unary { operand, .. } => {
286                Self::collect_from_expression(operand, required);
287            }
288            LogicalExpression::FunctionCall { args, .. } => {
289                for arg in args {
290                    Self::collect_from_expression(arg, required);
291                }
292            }
293            LogicalExpression::List(items) => {
294                for item in items {
295                    Self::collect_from_expression(item, required);
296                }
297            }
298            LogicalExpression::Map(pairs) => {
299                for (_, value) in pairs {
300                    Self::collect_from_expression(value, required);
301                }
302            }
303            LogicalExpression::IndexAccess { base, index } => {
304                Self::collect_from_expression(base, required);
305                Self::collect_from_expression(index, required);
306            }
307            LogicalExpression::SliceAccess { base, start, end } => {
308                Self::collect_from_expression(base, required);
309                if let Some(s) = start {
310                    Self::collect_from_expression(s, required);
311                }
312                if let Some(e) = end {
313                    Self::collect_from_expression(e, required);
314                }
315            }
316            LogicalExpression::Case {
317                operand,
318                when_clauses,
319                else_clause,
320            } => {
321                if let Some(op) = operand {
322                    Self::collect_from_expression(op, required);
323                }
324                for (cond, result) in when_clauses {
325                    Self::collect_from_expression(cond, required);
326                    Self::collect_from_expression(result, required);
327                }
328                if let Some(else_expr) = else_clause {
329                    Self::collect_from_expression(else_expr, required);
330                }
331            }
332            LogicalExpression::Labels(var)
333            | LogicalExpression::Type(var)
334            | LogicalExpression::Id(var) => {
335                required.insert(RequiredColumn::Variable(var.clone()));
336            }
337            LogicalExpression::ListComprehension {
338                list_expr,
339                filter_expr,
340                map_expr,
341                ..
342            } => {
343                Self::collect_from_expression(list_expr, required);
344                if let Some(filter) = filter_expr {
345                    Self::collect_from_expression(filter, required);
346                }
347                Self::collect_from_expression(map_expr, required);
348            }
349            _ => {}
350        }
351    }
352
353    /// Recursively pushes projections down, adding them before expensive operations.
354    fn push_projections_recursive(
355        &self,
356        op: LogicalOperator,
357        required: &HashSet<RequiredColumn>,
358    ) -> LogicalOperator {
359        match op {
360            LogicalOperator::Return(mut ret) => {
361                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
362                LogicalOperator::Return(ret)
363            }
364            LogicalOperator::Project(mut proj) => {
365                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
366                LogicalOperator::Project(proj)
367            }
368            LogicalOperator::Filter(mut filter) => {
369                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
370                LogicalOperator::Filter(filter)
371            }
372            LogicalOperator::Sort(mut sort) => {
373                // Sort is expensive - consider adding a projection before it
374                // to reduce tuple width
375                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
376                LogicalOperator::Sort(sort)
377            }
378            LogicalOperator::Aggregate(mut agg) => {
379                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
380                LogicalOperator::Aggregate(agg)
381            }
382            LogicalOperator::Join(mut join) => {
383                // Joins are expensive - the required columns help determine
384                // what to project on each side
385                let left_vars = self.collect_output_variables(&join.left);
386                let right_vars = self.collect_output_variables(&join.right);
387
388                // Filter required columns to each side
389                let left_required: HashSet<_> = required
390                    .iter()
391                    .filter(|c| match c {
392                        RequiredColumn::Variable(v) => left_vars.contains(v),
393                        RequiredColumn::Property(v, _) => left_vars.contains(v),
394                    })
395                    .cloned()
396                    .collect();
397
398                let right_required: HashSet<_> = required
399                    .iter()
400                    .filter(|c| match c {
401                        RequiredColumn::Variable(v) => right_vars.contains(v),
402                        RequiredColumn::Property(v, _) => right_vars.contains(v),
403                    })
404                    .cloned()
405                    .collect();
406
407                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
408                join.right =
409                    Box::new(self.push_projections_recursive(*join.right, &right_required));
410                LogicalOperator::Join(join)
411            }
412            LogicalOperator::Expand(mut expand) => {
413                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
414                LogicalOperator::Expand(expand)
415            }
416            LogicalOperator::Limit(mut limit) => {
417                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
418                LogicalOperator::Limit(limit)
419            }
420            LogicalOperator::Skip(mut skip) => {
421                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
422                LogicalOperator::Skip(skip)
423            }
424            LogicalOperator::Distinct(mut distinct) => {
425                distinct.input =
426                    Box::new(self.push_projections_recursive(*distinct.input, required));
427                LogicalOperator::Distinct(distinct)
428            }
429            other => other,
430        }
431    }
432
433    /// Reorders joins in the operator tree using the DPccp algorithm.
434    ///
435    /// This optimization finds the optimal join order by:
436    /// 1. Extracting all base relations (scans) and join conditions
437    /// 2. Building a join graph
438    /// 3. Using dynamic programming to find the cheapest join order
439    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
440        // First, recursively optimize children
441        let op = self.reorder_joins_recursive(op);
442
443        // Then, if this is a join tree, try to optimize it
444        if let Some((relations, conditions)) = self.extract_join_tree(&op)
445            && relations.len() >= 2
446            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
447        {
448            return optimized;
449        }
450
451        op
452    }
453
454    /// Recursively applies join reordering to child operators.
455    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
456        match op {
457            LogicalOperator::Return(mut ret) => {
458                ret.input = Box::new(self.reorder_joins(*ret.input));
459                LogicalOperator::Return(ret)
460            }
461            LogicalOperator::Project(mut proj) => {
462                proj.input = Box::new(self.reorder_joins(*proj.input));
463                LogicalOperator::Project(proj)
464            }
465            LogicalOperator::Filter(mut filter) => {
466                filter.input = Box::new(self.reorder_joins(*filter.input));
467                LogicalOperator::Filter(filter)
468            }
469            LogicalOperator::Limit(mut limit) => {
470                limit.input = Box::new(self.reorder_joins(*limit.input));
471                LogicalOperator::Limit(limit)
472            }
473            LogicalOperator::Skip(mut skip) => {
474                skip.input = Box::new(self.reorder_joins(*skip.input));
475                LogicalOperator::Skip(skip)
476            }
477            LogicalOperator::Sort(mut sort) => {
478                sort.input = Box::new(self.reorder_joins(*sort.input));
479                LogicalOperator::Sort(sort)
480            }
481            LogicalOperator::Distinct(mut distinct) => {
482                distinct.input = Box::new(self.reorder_joins(*distinct.input));
483                LogicalOperator::Distinct(distinct)
484            }
485            LogicalOperator::Aggregate(mut agg) => {
486                agg.input = Box::new(self.reorder_joins(*agg.input));
487                LogicalOperator::Aggregate(agg)
488            }
489            LogicalOperator::Expand(mut expand) => {
490                expand.input = Box::new(self.reorder_joins(*expand.input));
491                LogicalOperator::Expand(expand)
492            }
493            // Join operators are handled by the parent reorder_joins call
494            other => other,
495        }
496    }
497
498    /// Extracts base relations and join conditions from a join tree.
499    ///
500    /// Returns None if the operator is not a join tree.
501    fn extract_join_tree(
502        &self,
503        op: &LogicalOperator,
504    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
505        let mut relations = Vec::new();
506        let mut join_conditions = Vec::new();
507
508        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
509            return None;
510        }
511
512        if relations.len() < 2 {
513            return None;
514        }
515
516        Some((relations, join_conditions))
517    }
518
519    /// Recursively collects base relations and join conditions.
520    ///
521    /// Returns true if this subtree is part of a join tree.
522    fn collect_join_tree(
523        &self,
524        op: &LogicalOperator,
525        relations: &mut Vec<(String, LogicalOperator)>,
526        conditions: &mut Vec<JoinInfo>,
527    ) -> bool {
528        match op {
529            LogicalOperator::Join(join) => {
530                // Collect from both sides
531                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
532                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
533
534                // Add conditions from this join
535                for cond in &join.conditions {
536                    if let (Some(left_var), Some(right_var)) = (
537                        self.extract_variable_from_expr(&cond.left),
538                        self.extract_variable_from_expr(&cond.right),
539                    ) {
540                        conditions.push(JoinInfo {
541                            left_var,
542                            right_var,
543                            left_expr: cond.left.clone(),
544                            right_expr: cond.right.clone(),
545                        });
546                    }
547                }
548
549                left_ok && right_ok
550            }
551            LogicalOperator::NodeScan(scan) => {
552                relations.push((scan.variable.clone(), op.clone()));
553                true
554            }
555            LogicalOperator::EdgeScan(scan) => {
556                relations.push((scan.variable.clone(), op.clone()));
557                true
558            }
559            LogicalOperator::Filter(filter) => {
560                // A filter on a base relation is still part of the join tree
561                self.collect_join_tree(&filter.input, relations, conditions)
562            }
563            LogicalOperator::Expand(expand) => {
564                // Expand is a special case - it's like a join with the adjacency
565                // For now, treat the whole Expand subtree as a single relation
566                relations.push((expand.to_variable.clone(), op.clone()));
567                true
568            }
569            _ => false,
570        }
571    }
572
573    /// Extracts the primary variable from an expression.
574    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
575        match expr {
576            LogicalExpression::Variable(v) => Some(v.clone()),
577            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
578            _ => None,
579        }
580    }
581
582    /// Optimizes the join order using DPccp.
583    fn optimize_join_order(
584        &self,
585        relations: &[(String, LogicalOperator)],
586        conditions: &[JoinInfo],
587    ) -> Option<LogicalOperator> {
588        use join_order::{DPccp, JoinGraphBuilder};
589
590        // Build the join graph
591        let mut builder = JoinGraphBuilder::new();
592
593        for (var, relation) in relations {
594            builder.add_relation(var, relation.clone());
595        }
596
597        for cond in conditions {
598            builder.add_join_condition(
599                &cond.left_var,
600                &cond.right_var,
601                cond.left_expr.clone(),
602                cond.right_expr.clone(),
603            );
604        }
605
606        let graph = builder.build();
607
608        // Run DPccp
609        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
610        let plan = dpccp.optimize()?;
611
612        Some(plan.operator)
613    }
614
615    /// Pushes filters down the operator tree.
616    ///
617    /// This optimization moves filter predicates as close to the data source
618    /// as possible to reduce the amount of data processed by upper operators.
619    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
620        match op {
621            // For Filter operators, try to push the predicate into the child
622            LogicalOperator::Filter(filter) => {
623                let optimized_input = self.push_filters_down(*filter.input);
624                self.try_push_filter_into(filter.predicate, optimized_input)
625            }
626            // Recursively optimize children for other operators
627            LogicalOperator::Return(mut ret) => {
628                ret.input = Box::new(self.push_filters_down(*ret.input));
629                LogicalOperator::Return(ret)
630            }
631            LogicalOperator::Project(mut proj) => {
632                proj.input = Box::new(self.push_filters_down(*proj.input));
633                LogicalOperator::Project(proj)
634            }
635            LogicalOperator::Limit(mut limit) => {
636                limit.input = Box::new(self.push_filters_down(*limit.input));
637                LogicalOperator::Limit(limit)
638            }
639            LogicalOperator::Skip(mut skip) => {
640                skip.input = Box::new(self.push_filters_down(*skip.input));
641                LogicalOperator::Skip(skip)
642            }
643            LogicalOperator::Sort(mut sort) => {
644                sort.input = Box::new(self.push_filters_down(*sort.input));
645                LogicalOperator::Sort(sort)
646            }
647            LogicalOperator::Distinct(mut distinct) => {
648                distinct.input = Box::new(self.push_filters_down(*distinct.input));
649                LogicalOperator::Distinct(distinct)
650            }
651            LogicalOperator::Expand(mut expand) => {
652                expand.input = Box::new(self.push_filters_down(*expand.input));
653                LogicalOperator::Expand(expand)
654            }
655            LogicalOperator::Join(mut join) => {
656                join.left = Box::new(self.push_filters_down(*join.left));
657                join.right = Box::new(self.push_filters_down(*join.right));
658                LogicalOperator::Join(join)
659            }
660            LogicalOperator::Aggregate(mut agg) => {
661                agg.input = Box::new(self.push_filters_down(*agg.input));
662                LogicalOperator::Aggregate(agg)
663            }
664            // Leaf operators and unsupported operators are returned as-is
665            other => other,
666        }
667    }
668
669    /// Tries to push a filter predicate into the given operator.
670    ///
671    /// Returns either the predicate pushed into the operator, or a new
672    /// Filter operator on top if the predicate cannot be pushed further.
673    fn try_push_filter_into(
674        &self,
675        predicate: LogicalExpression,
676        op: LogicalOperator,
677    ) -> LogicalOperator {
678        match op {
679            // Can push through Project if predicate doesn't depend on computed columns
680            LogicalOperator::Project(mut proj) => {
681                let predicate_vars = self.extract_variables(&predicate);
682                let computed_vars = self.extract_projection_aliases(&proj.projections);
683
684                // If predicate doesn't use any computed columns, push through
685                if predicate_vars.is_disjoint(&computed_vars) {
686                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
687                    LogicalOperator::Project(proj)
688                } else {
689                    // Can't push through, keep filter on top
690                    LogicalOperator::Filter(FilterOp {
691                        predicate,
692                        input: Box::new(LogicalOperator::Project(proj)),
693                    })
694                }
695            }
696
697            // Can push through Return (which is like a projection)
698            LogicalOperator::Return(mut ret) => {
699                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
700                LogicalOperator::Return(ret)
701            }
702
703            // Can push through Expand if predicate doesn't use variables introduced by this expand
704            LogicalOperator::Expand(mut expand) => {
705                let predicate_vars = self.extract_variables(&predicate);
706
707                // Variables introduced by this expand are:
708                // - The target variable (to_variable)
709                // - The edge variable (if any)
710                // - The path alias (if any)
711                let mut introduced_vars = vec![&expand.to_variable];
712                if let Some(ref edge_var) = expand.edge_variable {
713                    introduced_vars.push(edge_var);
714                }
715                if let Some(ref path_alias) = expand.path_alias {
716                    introduced_vars.push(path_alias);
717                }
718
719                // Check if predicate uses any variables introduced by this expand
720                let uses_introduced_vars =
721                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
722
723                if !uses_introduced_vars {
724                    // Predicate doesn't use vars from this expand, so push through
725                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
726                    LogicalOperator::Expand(expand)
727                } else {
728                    // Keep filter after expand
729                    LogicalOperator::Filter(FilterOp {
730                        predicate,
731                        input: Box::new(LogicalOperator::Expand(expand)),
732                    })
733                }
734            }
735
736            // Can push through Join to left/right side based on variables used
737            LogicalOperator::Join(mut join) => {
738                let predicate_vars = self.extract_variables(&predicate);
739                let left_vars = self.collect_output_variables(&join.left);
740                let right_vars = self.collect_output_variables(&join.right);
741
742                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
743                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
744
745                if uses_left && !uses_right {
746                    // Push to left side
747                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
748                    LogicalOperator::Join(join)
749                } else if uses_right && !uses_left {
750                    // Push to right side
751                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
752                    LogicalOperator::Join(join)
753                } else {
754                    // Uses both sides - keep above join
755                    LogicalOperator::Filter(FilterOp {
756                        predicate,
757                        input: Box::new(LogicalOperator::Join(join)),
758                    })
759                }
760            }
761
762            // Cannot push through Aggregate (predicate refers to aggregated values)
763            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
764                predicate,
765                input: Box::new(LogicalOperator::Aggregate(agg)),
766            }),
767
768            // For NodeScan, we've reached the bottom - keep filter on top
769            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
770                predicate,
771                input: Box::new(LogicalOperator::NodeScan(scan)),
772            }),
773
774            // For other operators, keep filter on top
775            other => LogicalOperator::Filter(FilterOp {
776                predicate,
777                input: Box::new(other),
778            }),
779        }
780    }
781
782    /// Collects all output variable names from an operator.
783    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
784        let mut vars = HashSet::new();
785        Self::collect_output_variables_recursive(op, &mut vars);
786        vars
787    }
788
789    /// Recursively collects output variables from an operator.
790    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
791        match op {
792            LogicalOperator::NodeScan(scan) => {
793                vars.insert(scan.variable.clone());
794            }
795            LogicalOperator::EdgeScan(scan) => {
796                vars.insert(scan.variable.clone());
797            }
798            LogicalOperator::Expand(expand) => {
799                vars.insert(expand.to_variable.clone());
800                if let Some(edge_var) = &expand.edge_variable {
801                    vars.insert(edge_var.clone());
802                }
803                Self::collect_output_variables_recursive(&expand.input, vars);
804            }
805            LogicalOperator::Filter(filter) => {
806                Self::collect_output_variables_recursive(&filter.input, vars);
807            }
808            LogicalOperator::Project(proj) => {
809                for p in &proj.projections {
810                    if let Some(alias) = &p.alias {
811                        vars.insert(alias.clone());
812                    }
813                }
814                Self::collect_output_variables_recursive(&proj.input, vars);
815            }
816            LogicalOperator::Join(join) => {
817                Self::collect_output_variables_recursive(&join.left, vars);
818                Self::collect_output_variables_recursive(&join.right, vars);
819            }
820            LogicalOperator::Aggregate(agg) => {
821                for expr in &agg.group_by {
822                    Self::collect_variables(expr, vars);
823                }
824                for agg_expr in &agg.aggregates {
825                    if let Some(alias) = &agg_expr.alias {
826                        vars.insert(alias.clone());
827                    }
828                }
829            }
830            LogicalOperator::Return(ret) => {
831                Self::collect_output_variables_recursive(&ret.input, vars);
832            }
833            LogicalOperator::Limit(limit) => {
834                Self::collect_output_variables_recursive(&limit.input, vars);
835            }
836            LogicalOperator::Skip(skip) => {
837                Self::collect_output_variables_recursive(&skip.input, vars);
838            }
839            LogicalOperator::Sort(sort) => {
840                Self::collect_output_variables_recursive(&sort.input, vars);
841            }
842            LogicalOperator::Distinct(distinct) => {
843                Self::collect_output_variables_recursive(&distinct.input, vars);
844            }
845            _ => {}
846        }
847    }
848
849    /// Extracts all variable names referenced in an expression.
850    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
851        let mut vars = HashSet::new();
852        Self::collect_variables(expr, &mut vars);
853        vars
854    }
855
856    /// Recursively collects variable names from an expression.
857    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
858        match expr {
859            LogicalExpression::Variable(name) => {
860                vars.insert(name.clone());
861            }
862            LogicalExpression::Property { variable, .. } => {
863                vars.insert(variable.clone());
864            }
865            LogicalExpression::Binary { left, right, .. } => {
866                Self::collect_variables(left, vars);
867                Self::collect_variables(right, vars);
868            }
869            LogicalExpression::Unary { operand, .. } => {
870                Self::collect_variables(operand, vars);
871            }
872            LogicalExpression::FunctionCall { args, .. } => {
873                for arg in args {
874                    Self::collect_variables(arg, vars);
875                }
876            }
877            LogicalExpression::List(items) => {
878                for item in items {
879                    Self::collect_variables(item, vars);
880                }
881            }
882            LogicalExpression::Map(pairs) => {
883                for (_, value) in pairs {
884                    Self::collect_variables(value, vars);
885                }
886            }
887            LogicalExpression::IndexAccess { base, index } => {
888                Self::collect_variables(base, vars);
889                Self::collect_variables(index, vars);
890            }
891            LogicalExpression::SliceAccess { base, start, end } => {
892                Self::collect_variables(base, vars);
893                if let Some(s) = start {
894                    Self::collect_variables(s, vars);
895                }
896                if let Some(e) = end {
897                    Self::collect_variables(e, vars);
898                }
899            }
900            LogicalExpression::Case {
901                operand,
902                when_clauses,
903                else_clause,
904            } => {
905                if let Some(op) = operand {
906                    Self::collect_variables(op, vars);
907                }
908                for (cond, result) in when_clauses {
909                    Self::collect_variables(cond, vars);
910                    Self::collect_variables(result, vars);
911                }
912                if let Some(else_expr) = else_clause {
913                    Self::collect_variables(else_expr, vars);
914                }
915            }
916            LogicalExpression::Labels(var)
917            | LogicalExpression::Type(var)
918            | LogicalExpression::Id(var) => {
919                vars.insert(var.clone());
920            }
921            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
922            LogicalExpression::ListComprehension {
923                list_expr,
924                filter_expr,
925                map_expr,
926                ..
927            } => {
928                Self::collect_variables(list_expr, vars);
929                if let Some(filter) = filter_expr {
930                    Self::collect_variables(filter, vars);
931                }
932                Self::collect_variables(map_expr, vars);
933            }
934            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
935                // Subqueries have their own variable scope
936            }
937        }
938    }
939
940    /// Extracts aliases from projection expressions.
941    fn extract_projection_aliases(
942        &self,
943        projections: &[crate::query::plan::Projection],
944    ) -> HashSet<String> {
945        projections.iter().filter_map(|p| p.alias.clone()).collect()
946    }
947}
948
949impl Default for Optimizer {
950    fn default() -> Self {
951        Self::new()
952    }
953}
954
955#[cfg(test)]
956mod tests {
957    use super::*;
958    use crate::query::plan::{
959        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
960        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, ProjectOp, Projection, ReturnItem,
961        ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
962    };
963    use grafeo_common::types::Value;
964
965    #[test]
966    fn test_optimizer_filter_pushdown_simple() {
967        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
968        // Before: Return -> Filter -> NodeScan
969        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
970
971        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
972            items: vec![ReturnItem {
973                expression: LogicalExpression::Variable("n".to_string()),
974                alias: None,
975            }],
976            distinct: false,
977            input: Box::new(LogicalOperator::Filter(FilterOp {
978                predicate: LogicalExpression::Binary {
979                    left: Box::new(LogicalExpression::Property {
980                        variable: "n".to_string(),
981                        property: "age".to_string(),
982                    }),
983                    op: BinaryOp::Gt,
984                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
985                },
986                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
987                    variable: "n".to_string(),
988                    label: Some("Person".to_string()),
989                    input: None,
990                })),
991            })),
992        }));
993
994        let optimizer = Optimizer::new();
995        let optimized = optimizer.optimize(plan).unwrap();
996
997        // The structure should remain similar (filter stays near scan)
998        if let LogicalOperator::Return(ret) = &optimized.root
999            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1000            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1001        {
1002            assert_eq!(scan.variable, "n");
1003            return;
1004        }
1005        panic!("Expected Return -> Filter -> NodeScan structure");
1006    }
1007
1008    #[test]
1009    fn test_optimizer_filter_pushdown_through_expand() {
1010        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1011        // The filter on 'a' should be pushed before the expand
1012
1013        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1014            items: vec![ReturnItem {
1015                expression: LogicalExpression::Variable("b".to_string()),
1016                alias: None,
1017            }],
1018            distinct: false,
1019            input: Box::new(LogicalOperator::Filter(FilterOp {
1020                predicate: LogicalExpression::Binary {
1021                    left: Box::new(LogicalExpression::Property {
1022                        variable: "a".to_string(),
1023                        property: "age".to_string(),
1024                    }),
1025                    op: BinaryOp::Gt,
1026                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1027                },
1028                input: Box::new(LogicalOperator::Expand(ExpandOp {
1029                    from_variable: "a".to_string(),
1030                    to_variable: "b".to_string(),
1031                    edge_variable: None,
1032                    direction: ExpandDirection::Outgoing,
1033                    edge_type: Some("KNOWS".to_string()),
1034                    min_hops: 1,
1035                    max_hops: Some(1),
1036                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1037                        variable: "a".to_string(),
1038                        label: Some("Person".to_string()),
1039                        input: None,
1040                    })),
1041                    path_alias: None,
1042                })),
1043            })),
1044        }));
1045
1046        let optimizer = Optimizer::new();
1047        let optimized = optimizer.optimize(plan).unwrap();
1048
1049        // Filter on 'a' should be pushed before the expand
1050        // Expected: Return -> Expand -> Filter -> NodeScan
1051        if let LogicalOperator::Return(ret) = &optimized.root
1052            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1053            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1054            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1055        {
1056            assert_eq!(scan.variable, "a");
1057            assert_eq!(expand.from_variable, "a");
1058            assert_eq!(expand.to_variable, "b");
1059            return;
1060        }
1061        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1062    }
1063
1064    #[test]
1065    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1066        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1067        // The filter on 'b' should NOT be pushed before the expand
1068
1069        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1070            items: vec![ReturnItem {
1071                expression: LogicalExpression::Variable("a".to_string()),
1072                alias: None,
1073            }],
1074            distinct: false,
1075            input: Box::new(LogicalOperator::Filter(FilterOp {
1076                predicate: LogicalExpression::Binary {
1077                    left: Box::new(LogicalExpression::Property {
1078                        variable: "b".to_string(),
1079                        property: "age".to_string(),
1080                    }),
1081                    op: BinaryOp::Gt,
1082                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1083                },
1084                input: Box::new(LogicalOperator::Expand(ExpandOp {
1085                    from_variable: "a".to_string(),
1086                    to_variable: "b".to_string(),
1087                    edge_variable: None,
1088                    direction: ExpandDirection::Outgoing,
1089                    edge_type: Some("KNOWS".to_string()),
1090                    min_hops: 1,
1091                    max_hops: Some(1),
1092                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1093                        variable: "a".to_string(),
1094                        label: Some("Person".to_string()),
1095                        input: None,
1096                    })),
1097                    path_alias: None,
1098                })),
1099            })),
1100        }));
1101
1102        let optimizer = Optimizer::new();
1103        let optimized = optimizer.optimize(plan).unwrap();
1104
1105        // Filter on 'b' should stay after the expand
1106        // Expected: Return -> Filter -> Expand -> NodeScan
1107        if let LogicalOperator::Return(ret) = &optimized.root
1108            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1109        {
1110            // Check that the filter is on 'b'
1111            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1112                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1113            {
1114                assert_eq!(variable, "b");
1115            }
1116
1117            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1118                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1119            {
1120                return;
1121            }
1122        }
1123        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1124    }
1125
1126    #[test]
1127    fn test_optimizer_extract_variables() {
1128        let optimizer = Optimizer::new();
1129
1130        let expr = LogicalExpression::Binary {
1131            left: Box::new(LogicalExpression::Property {
1132                variable: "n".to_string(),
1133                property: "age".to_string(),
1134            }),
1135            op: BinaryOp::Gt,
1136            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1137        };
1138
1139        let vars = optimizer.extract_variables(&expr);
1140        assert_eq!(vars.len(), 1);
1141        assert!(vars.contains("n"));
1142    }
1143
1144    // Additional tests for optimizer configuration
1145
1146    #[test]
1147    fn test_optimizer_default() {
1148        let optimizer = Optimizer::default();
1149        // Should be able to optimize an empty plan
1150        let plan = LogicalPlan::new(LogicalOperator::Empty);
1151        let result = optimizer.optimize(plan);
1152        assert!(result.is_ok());
1153    }
1154
1155    #[test]
1156    fn test_optimizer_with_filter_pushdown_disabled() {
1157        let optimizer = Optimizer::new().with_filter_pushdown(false);
1158
1159        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1160            items: vec![ReturnItem {
1161                expression: LogicalExpression::Variable("n".to_string()),
1162                alias: None,
1163            }],
1164            distinct: false,
1165            input: Box::new(LogicalOperator::Filter(FilterOp {
1166                predicate: LogicalExpression::Literal(Value::Bool(true)),
1167                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1168                    variable: "n".to_string(),
1169                    label: None,
1170                    input: None,
1171                })),
1172            })),
1173        }));
1174
1175        let optimized = optimizer.optimize(plan).unwrap();
1176        // Structure should be unchanged
1177        if let LogicalOperator::Return(ret) = &optimized.root
1178            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1179        {
1180            return;
1181        }
1182        panic!("Expected unchanged structure");
1183    }
1184
1185    #[test]
1186    fn test_optimizer_with_join_reorder_disabled() {
1187        let optimizer = Optimizer::new().with_join_reorder(false);
1188        assert!(
1189            optimizer
1190                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1191                .is_ok()
1192        );
1193    }
1194
1195    #[test]
1196    fn test_optimizer_with_cost_model() {
1197        let cost_model = CostModel::new();
1198        let optimizer = Optimizer::new().with_cost_model(cost_model);
1199        assert!(
1200            optimizer
1201                .cost_model()
1202                .estimate(&LogicalOperator::Empty, 0.0)
1203                .total()
1204                < 0.001
1205        );
1206    }
1207
1208    #[test]
1209    fn test_optimizer_with_cardinality_estimator() {
1210        let mut estimator = CardinalityEstimator::new();
1211        estimator.add_table_stats("Test", TableStats::new(500));
1212        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1213
1214        let scan = LogicalOperator::NodeScan(NodeScanOp {
1215            variable: "n".to_string(),
1216            label: Some("Test".to_string()),
1217            input: None,
1218        });
1219        let plan = LogicalPlan::new(scan);
1220
1221        let cardinality = optimizer.estimate_cardinality(&plan);
1222        assert!((cardinality - 500.0).abs() < 0.001);
1223    }
1224
1225    #[test]
1226    fn test_optimizer_estimate_cost() {
1227        let optimizer = Optimizer::new();
1228        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1229            variable: "n".to_string(),
1230            label: None,
1231            input: None,
1232        }));
1233
1234        let cost = optimizer.estimate_cost(&plan);
1235        assert!(cost.total() > 0.0);
1236    }
1237
1238    // Filter pushdown through various operators
1239
1240    #[test]
1241    fn test_filter_pushdown_through_project() {
1242        let optimizer = Optimizer::new();
1243
1244        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1245            predicate: LogicalExpression::Binary {
1246                left: Box::new(LogicalExpression::Property {
1247                    variable: "n".to_string(),
1248                    property: "age".to_string(),
1249                }),
1250                op: BinaryOp::Gt,
1251                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1252            },
1253            input: Box::new(LogicalOperator::Project(ProjectOp {
1254                projections: vec![Projection {
1255                    expression: LogicalExpression::Variable("n".to_string()),
1256                    alias: None,
1257                }],
1258                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1259                    variable: "n".to_string(),
1260                    label: None,
1261                    input: None,
1262                })),
1263            })),
1264        }));
1265
1266        let optimized = optimizer.optimize(plan).unwrap();
1267
1268        // Filter should be pushed through Project
1269        if let LogicalOperator::Project(proj) = &optimized.root
1270            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1271        {
1272            return;
1273        }
1274        panic!("Expected Project -> Filter structure");
1275    }
1276
1277    #[test]
1278    fn test_filter_not_pushed_through_project_with_alias() {
1279        let optimizer = Optimizer::new();
1280
1281        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1282        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1283            predicate: LogicalExpression::Binary {
1284                left: Box::new(LogicalExpression::Variable("x".to_string())),
1285                op: BinaryOp::Gt,
1286                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1287            },
1288            input: Box::new(LogicalOperator::Project(ProjectOp {
1289                projections: vec![Projection {
1290                    expression: LogicalExpression::Property {
1291                        variable: "n".to_string(),
1292                        property: "age".to_string(),
1293                    },
1294                    alias: Some("x".to_string()),
1295                }],
1296                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1297                    variable: "n".to_string(),
1298                    label: None,
1299                    input: None,
1300                })),
1301            })),
1302        }));
1303
1304        let optimized = optimizer.optimize(plan).unwrap();
1305
1306        // Filter should stay above Project
1307        if let LogicalOperator::Filter(filter) = &optimized.root
1308            && let LogicalOperator::Project(_) = filter.input.as_ref()
1309        {
1310            return;
1311        }
1312        panic!("Expected Filter -> Project structure");
1313    }
1314
1315    #[test]
1316    fn test_filter_pushdown_through_limit() {
1317        let optimizer = Optimizer::new();
1318
1319        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1320            predicate: LogicalExpression::Literal(Value::Bool(true)),
1321            input: Box::new(LogicalOperator::Limit(LimitOp {
1322                count: 10,
1323                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1324                    variable: "n".to_string(),
1325                    label: None,
1326                    input: None,
1327                })),
1328            })),
1329        }));
1330
1331        let optimized = optimizer.optimize(plan).unwrap();
1332
1333        // Filter stays above Limit (cannot be pushed through)
1334        if let LogicalOperator::Filter(filter) = &optimized.root
1335            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1336        {
1337            return;
1338        }
1339        panic!("Expected Filter -> Limit structure");
1340    }
1341
1342    #[test]
1343    fn test_filter_pushdown_through_sort() {
1344        let optimizer = Optimizer::new();
1345
1346        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1347            predicate: LogicalExpression::Literal(Value::Bool(true)),
1348            input: Box::new(LogicalOperator::Sort(SortOp {
1349                keys: vec![SortKey {
1350                    expression: LogicalExpression::Variable("n".to_string()),
1351                    order: SortOrder::Ascending,
1352                }],
1353                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1354                    variable: "n".to_string(),
1355                    label: None,
1356                    input: None,
1357                })),
1358            })),
1359        }));
1360
1361        let optimized = optimizer.optimize(plan).unwrap();
1362
1363        // Filter stays above Sort
1364        if let LogicalOperator::Filter(filter) = &optimized.root
1365            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1366        {
1367            return;
1368        }
1369        panic!("Expected Filter -> Sort structure");
1370    }
1371
1372    #[test]
1373    fn test_filter_pushdown_through_distinct() {
1374        let optimizer = Optimizer::new();
1375
1376        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1377            predicate: LogicalExpression::Literal(Value::Bool(true)),
1378            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1379                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1380                    variable: "n".to_string(),
1381                    label: None,
1382                    input: None,
1383                })),
1384                columns: None,
1385            })),
1386        }));
1387
1388        let optimized = optimizer.optimize(plan).unwrap();
1389
1390        // Filter stays above Distinct
1391        if let LogicalOperator::Filter(filter) = &optimized.root
1392            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1393        {
1394            return;
1395        }
1396        panic!("Expected Filter -> Distinct structure");
1397    }
1398
1399    #[test]
1400    fn test_filter_not_pushed_through_aggregate() {
1401        let optimizer = Optimizer::new();
1402
1403        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1404            predicate: LogicalExpression::Binary {
1405                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1406                op: BinaryOp::Gt,
1407                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1408            },
1409            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1410                group_by: vec![],
1411                aggregates: vec![AggregateExpr {
1412                    function: AggregateFunction::Count,
1413                    expression: None,
1414                    distinct: false,
1415                    alias: Some("cnt".to_string()),
1416                    percentile: None,
1417                }],
1418                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1419                    variable: "n".to_string(),
1420                    label: None,
1421                    input: None,
1422                })),
1423                having: None,
1424            })),
1425        }));
1426
1427        let optimized = optimizer.optimize(plan).unwrap();
1428
1429        // Filter should stay above Aggregate
1430        if let LogicalOperator::Filter(filter) = &optimized.root
1431            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1432        {
1433            return;
1434        }
1435        panic!("Expected Filter -> Aggregate structure");
1436    }
1437
1438    #[test]
1439    fn test_filter_pushdown_to_left_join_side() {
1440        let optimizer = Optimizer::new();
1441
1442        // Filter on left variable should be pushed to left side
1443        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1444            predicate: LogicalExpression::Binary {
1445                left: Box::new(LogicalExpression::Property {
1446                    variable: "a".to_string(),
1447                    property: "age".to_string(),
1448                }),
1449                op: BinaryOp::Gt,
1450                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1451            },
1452            input: Box::new(LogicalOperator::Join(JoinOp {
1453                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1454                    variable: "a".to_string(),
1455                    label: Some("Person".to_string()),
1456                    input: None,
1457                })),
1458                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1459                    variable: "b".to_string(),
1460                    label: Some("Company".to_string()),
1461                    input: None,
1462                })),
1463                join_type: JoinType::Inner,
1464                conditions: vec![],
1465            })),
1466        }));
1467
1468        let optimized = optimizer.optimize(plan).unwrap();
1469
1470        // Filter should be pushed to left side of join
1471        if let LogicalOperator::Join(join) = &optimized.root
1472            && let LogicalOperator::Filter(_) = join.left.as_ref()
1473        {
1474            return;
1475        }
1476        panic!("Expected Join with Filter on left side");
1477    }
1478
1479    #[test]
1480    fn test_filter_pushdown_to_right_join_side() {
1481        let optimizer = Optimizer::new();
1482
1483        // Filter on right variable should be pushed to right side
1484        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1485            predicate: LogicalExpression::Binary {
1486                left: Box::new(LogicalExpression::Property {
1487                    variable: "b".to_string(),
1488                    property: "name".to_string(),
1489                }),
1490                op: BinaryOp::Eq,
1491                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1492            },
1493            input: Box::new(LogicalOperator::Join(JoinOp {
1494                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1495                    variable: "a".to_string(),
1496                    label: Some("Person".to_string()),
1497                    input: None,
1498                })),
1499                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1500                    variable: "b".to_string(),
1501                    label: Some("Company".to_string()),
1502                    input: None,
1503                })),
1504                join_type: JoinType::Inner,
1505                conditions: vec![],
1506            })),
1507        }));
1508
1509        let optimized = optimizer.optimize(plan).unwrap();
1510
1511        // Filter should be pushed to right side of join
1512        if let LogicalOperator::Join(join) = &optimized.root
1513            && let LogicalOperator::Filter(_) = join.right.as_ref()
1514        {
1515            return;
1516        }
1517        panic!("Expected Join with Filter on right side");
1518    }
1519
1520    #[test]
1521    fn test_filter_not_pushed_when_uses_both_join_sides() {
1522        let optimizer = Optimizer::new();
1523
1524        // Filter using both variables should stay above join
1525        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1526            predicate: LogicalExpression::Binary {
1527                left: Box::new(LogicalExpression::Property {
1528                    variable: "a".to_string(),
1529                    property: "id".to_string(),
1530                }),
1531                op: BinaryOp::Eq,
1532                right: Box::new(LogicalExpression::Property {
1533                    variable: "b".to_string(),
1534                    property: "a_id".to_string(),
1535                }),
1536            },
1537            input: Box::new(LogicalOperator::Join(JoinOp {
1538                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1539                    variable: "a".to_string(),
1540                    label: None,
1541                    input: None,
1542                })),
1543                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1544                    variable: "b".to_string(),
1545                    label: None,
1546                    input: None,
1547                })),
1548                join_type: JoinType::Inner,
1549                conditions: vec![],
1550            })),
1551        }));
1552
1553        let optimized = optimizer.optimize(plan).unwrap();
1554
1555        // Filter should stay above join
1556        if let LogicalOperator::Filter(filter) = &optimized.root
1557            && let LogicalOperator::Join(_) = filter.input.as_ref()
1558        {
1559            return;
1560        }
1561        panic!("Expected Filter -> Join structure");
1562    }
1563
1564    // Variable extraction tests
1565
1566    #[test]
1567    fn test_extract_variables_from_variable() {
1568        let optimizer = Optimizer::new();
1569        let expr = LogicalExpression::Variable("x".to_string());
1570        let vars = optimizer.extract_variables(&expr);
1571        assert_eq!(vars.len(), 1);
1572        assert!(vars.contains("x"));
1573    }
1574
1575    #[test]
1576    fn test_extract_variables_from_unary() {
1577        let optimizer = Optimizer::new();
1578        let expr = LogicalExpression::Unary {
1579            op: UnaryOp::Not,
1580            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1581        };
1582        let vars = optimizer.extract_variables(&expr);
1583        assert_eq!(vars.len(), 1);
1584        assert!(vars.contains("x"));
1585    }
1586
1587    #[test]
1588    fn test_extract_variables_from_function_call() {
1589        let optimizer = Optimizer::new();
1590        let expr = LogicalExpression::FunctionCall {
1591            name: "length".to_string(),
1592            args: vec![
1593                LogicalExpression::Variable("a".to_string()),
1594                LogicalExpression::Variable("b".to_string()),
1595            ],
1596            distinct: false,
1597        };
1598        let vars = optimizer.extract_variables(&expr);
1599        assert_eq!(vars.len(), 2);
1600        assert!(vars.contains("a"));
1601        assert!(vars.contains("b"));
1602    }
1603
1604    #[test]
1605    fn test_extract_variables_from_list() {
1606        let optimizer = Optimizer::new();
1607        let expr = LogicalExpression::List(vec![
1608            LogicalExpression::Variable("a".to_string()),
1609            LogicalExpression::Literal(Value::Int64(1)),
1610            LogicalExpression::Variable("b".to_string()),
1611        ]);
1612        let vars = optimizer.extract_variables(&expr);
1613        assert_eq!(vars.len(), 2);
1614        assert!(vars.contains("a"));
1615        assert!(vars.contains("b"));
1616    }
1617
1618    #[test]
1619    fn test_extract_variables_from_map() {
1620        let optimizer = Optimizer::new();
1621        let expr = LogicalExpression::Map(vec![
1622            (
1623                "key1".to_string(),
1624                LogicalExpression::Variable("a".to_string()),
1625            ),
1626            (
1627                "key2".to_string(),
1628                LogicalExpression::Variable("b".to_string()),
1629            ),
1630        ]);
1631        let vars = optimizer.extract_variables(&expr);
1632        assert_eq!(vars.len(), 2);
1633        assert!(vars.contains("a"));
1634        assert!(vars.contains("b"));
1635    }
1636
1637    #[test]
1638    fn test_extract_variables_from_index_access() {
1639        let optimizer = Optimizer::new();
1640        let expr = LogicalExpression::IndexAccess {
1641            base: Box::new(LogicalExpression::Variable("list".to_string())),
1642            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1643        };
1644        let vars = optimizer.extract_variables(&expr);
1645        assert_eq!(vars.len(), 2);
1646        assert!(vars.contains("list"));
1647        assert!(vars.contains("idx"));
1648    }
1649
1650    #[test]
1651    fn test_extract_variables_from_slice_access() {
1652        let optimizer = Optimizer::new();
1653        let expr = LogicalExpression::SliceAccess {
1654            base: Box::new(LogicalExpression::Variable("list".to_string())),
1655            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1656            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1657        };
1658        let vars = optimizer.extract_variables(&expr);
1659        assert_eq!(vars.len(), 3);
1660        assert!(vars.contains("list"));
1661        assert!(vars.contains("s"));
1662        assert!(vars.contains("e"));
1663    }
1664
1665    #[test]
1666    fn test_extract_variables_from_case() {
1667        let optimizer = Optimizer::new();
1668        let expr = LogicalExpression::Case {
1669            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1670            when_clauses: vec![(
1671                LogicalExpression::Literal(Value::Int64(1)),
1672                LogicalExpression::Variable("a".to_string()),
1673            )],
1674            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1675        };
1676        let vars = optimizer.extract_variables(&expr);
1677        assert_eq!(vars.len(), 3);
1678        assert!(vars.contains("x"));
1679        assert!(vars.contains("a"));
1680        assert!(vars.contains("b"));
1681    }
1682
1683    #[test]
1684    fn test_extract_variables_from_labels() {
1685        let optimizer = Optimizer::new();
1686        let expr = LogicalExpression::Labels("n".to_string());
1687        let vars = optimizer.extract_variables(&expr);
1688        assert_eq!(vars.len(), 1);
1689        assert!(vars.contains("n"));
1690    }
1691
1692    #[test]
1693    fn test_extract_variables_from_type() {
1694        let optimizer = Optimizer::new();
1695        let expr = LogicalExpression::Type("e".to_string());
1696        let vars = optimizer.extract_variables(&expr);
1697        assert_eq!(vars.len(), 1);
1698        assert!(vars.contains("e"));
1699    }
1700
1701    #[test]
1702    fn test_extract_variables_from_id() {
1703        let optimizer = Optimizer::new();
1704        let expr = LogicalExpression::Id("n".to_string());
1705        let vars = optimizer.extract_variables(&expr);
1706        assert_eq!(vars.len(), 1);
1707        assert!(vars.contains("n"));
1708    }
1709
1710    #[test]
1711    fn test_extract_variables_from_list_comprehension() {
1712        let optimizer = Optimizer::new();
1713        let expr = LogicalExpression::ListComprehension {
1714            variable: "x".to_string(),
1715            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1716            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1717            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1718        };
1719        let vars = optimizer.extract_variables(&expr);
1720        assert!(vars.contains("items"));
1721        assert!(vars.contains("pred"));
1722        assert!(vars.contains("result"));
1723    }
1724
1725    #[test]
1726    fn test_extract_variables_from_literal_and_parameter() {
1727        let optimizer = Optimizer::new();
1728
1729        let literal = LogicalExpression::Literal(Value::Int64(42));
1730        assert!(optimizer.extract_variables(&literal).is_empty());
1731
1732        let param = LogicalExpression::Parameter("p".to_string());
1733        assert!(optimizer.extract_variables(&param).is_empty());
1734    }
1735
1736    // Recursive filter pushdown tests
1737
1738    #[test]
1739    fn test_recursive_filter_pushdown_through_skip() {
1740        let optimizer = Optimizer::new();
1741
1742        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1743            items: vec![ReturnItem {
1744                expression: LogicalExpression::Variable("n".to_string()),
1745                alias: None,
1746            }],
1747            distinct: false,
1748            input: Box::new(LogicalOperator::Filter(FilterOp {
1749                predicate: LogicalExpression::Literal(Value::Bool(true)),
1750                input: Box::new(LogicalOperator::Skip(SkipOp {
1751                    count: 5,
1752                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1753                        variable: "n".to_string(),
1754                        label: None,
1755                        input: None,
1756                    })),
1757                })),
1758            })),
1759        }));
1760
1761        let optimized = optimizer.optimize(plan).unwrap();
1762
1763        // Verify optimization succeeded
1764        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1765    }
1766
1767    #[test]
1768    fn test_nested_filter_pushdown() {
1769        let optimizer = Optimizer::new();
1770
1771        // Multiple nested filters
1772        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1773            items: vec![ReturnItem {
1774                expression: LogicalExpression::Variable("n".to_string()),
1775                alias: None,
1776            }],
1777            distinct: false,
1778            input: Box::new(LogicalOperator::Filter(FilterOp {
1779                predicate: LogicalExpression::Binary {
1780                    left: Box::new(LogicalExpression::Property {
1781                        variable: "n".to_string(),
1782                        property: "x".to_string(),
1783                    }),
1784                    op: BinaryOp::Gt,
1785                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1786                },
1787                input: Box::new(LogicalOperator::Filter(FilterOp {
1788                    predicate: LogicalExpression::Binary {
1789                        left: Box::new(LogicalExpression::Property {
1790                            variable: "n".to_string(),
1791                            property: "y".to_string(),
1792                        }),
1793                        op: BinaryOp::Lt,
1794                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1795                    },
1796                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1797                        variable: "n".to_string(),
1798                        label: None,
1799                        input: None,
1800                    })),
1801                })),
1802            })),
1803        }));
1804
1805        let optimized = optimizer.optimize(plan).unwrap();
1806        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1807    }
1808}