Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
31/// Information about a join condition for join reordering.
32#[derive(Debug, Clone)]
33struct JoinInfo {
34    left_var: String,
35    right_var: String,
36    left_expr: LogicalExpression,
37    right_expr: LogicalExpression,
38}
39
/// Something the query needs from an upstream operator; the unit of
/// bookkeeping for projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding, identified by its variable name (node, edge, or path).
    Variable(String),
    /// One property of a binding: `(variable name, property name)`.
    Property(String, String),
}
48
49/// Transforms logical plans for faster execution.
50///
51/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
52/// Use the builder methods to enable/disable specific optimizations.
53pub struct Optimizer {
54    /// Whether to enable filter pushdown.
55    enable_filter_pushdown: bool,
56    /// Whether to enable join reordering.
57    enable_join_reorder: bool,
58    /// Whether to enable projection pushdown.
59    enable_projection_pushdown: bool,
60    /// Cost model for estimation.
61    cost_model: CostModel,
62    /// Cardinality estimator.
63    card_estimator: CardinalityEstimator,
64}
65
66impl Optimizer {
67    /// Creates a new optimizer with default settings.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            enable_filter_pushdown: true,
72            enable_join_reorder: true,
73            enable_projection_pushdown: true,
74            cost_model: CostModel::new(),
75            card_estimator: CardinalityEstimator::new(),
76        }
77    }
78
79    /// Creates an optimizer with cardinality estimates from the store's statistics.
80    ///
81    /// Pre-populates the cardinality estimator with per-label row counts and
82    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
83    /// and graph totals into the cost model for accurate estimation.
84    #[cfg(feature = "lpg")]
85    #[must_use]
86    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
87        store.ensure_statistics_fresh();
88        let stats = store.statistics();
89        Self::from_statistics(&stats)
90    }
91
92    /// Creates an optimizer from any GraphStore implementation.
93    ///
94    /// Unlike [`from_store`](Self::from_store), this does not call
95    /// `ensure_statistics_fresh()` since external stores manage their own
96    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
97    /// is called directly.
98    #[must_use]
99    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
100        let stats = store.statistics();
101        Self::from_statistics(&stats)
102    }
103
104    /// Creates an optimizer from RDF statistics.
105    ///
106    /// Uses triple pattern cardinality estimates for cost-based optimization
107    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
108    #[cfg(feature = "triple-store")]
109    #[must_use]
110    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
111        let total = rdf_stats.total_triples;
112        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
113        Self {
114            enable_filter_pushdown: true,
115            enable_join_reorder: true,
116            enable_projection_pushdown: true,
117            cost_model: CostModel::new().with_graph_totals(total, total),
118            card_estimator: estimator,
119        }
120    }
121
122    /// Creates an optimizer from a Statistics snapshot.
123    ///
124    /// Extracts label cardinalities, edge type degrees, and graph totals
125    /// for both the cardinality estimator and cost model.
126    #[must_use]
127    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
128        let estimator = CardinalityEstimator::from_statistics(stats);
129
130        let avg_fanout = if stats.total_nodes > 0 {
131            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
132        } else {
133            10.0
134        };
135
136        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
137            .edge_types
138            .iter()
139            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
140            .collect();
141
142        let label_cardinalities: std::collections::HashMap<String, u64> = stats
143            .labels
144            .iter()
145            .map(|(name, ls)| (name.clone(), ls.node_count))
146            .collect();
147
148        Self {
149            enable_filter_pushdown: true,
150            enable_join_reorder: true,
151            enable_projection_pushdown: true,
152            cost_model: CostModel::new()
153                .with_avg_fanout(avg_fanout)
154                .with_edge_type_degrees(edge_type_degrees)
155                .with_label_cardinalities(label_cardinalities)
156                .with_graph_totals(stats.total_nodes, stats.total_edges),
157            card_estimator: estimator,
158        }
159    }
160
161    /// Enables or disables filter pushdown.
162    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
163        self.enable_filter_pushdown = enabled;
164        self
165    }
166
167    /// Enables or disables join reordering.
168    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
169        self.enable_join_reorder = enabled;
170        self
171    }
172
173    /// Enables or disables projection pushdown.
174    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
175        self.enable_projection_pushdown = enabled;
176        self
177    }
178
179    /// Sets the cost model.
180    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
181        self.cost_model = cost_model;
182        self
183    }
184
185    /// Sets the cardinality estimator.
186    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
187        self.card_estimator = estimator;
188        self
189    }
190
191    /// Sets the selectivity configuration for the cardinality estimator.
192    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
193        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
194        self
195    }
196
197    /// Returns a reference to the cost model.
198    pub fn cost_model(&self) -> &CostModel {
199        &self.cost_model
200    }
201
202    /// Returns a reference to the cardinality estimator.
203    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
204        &self.card_estimator
205    }
206
207    /// Estimates the total cost of a plan by recursively costing the entire tree.
208    ///
209    /// Walks every operator in the plan, computing cardinality at each level
210    /// and summing the per-operator costs. Uses actual child cardinalities
211    /// for join cost estimation rather than approximations.
212    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
213        self.cost_model
214            .estimate_tree(&plan.root, &self.card_estimator)
215    }
216
217    /// Estimates the cardinality of a plan.
218    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
219        self.card_estimator.estimate(&plan.root)
220    }
221
222    /// Optimizes a logical plan.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if optimization fails.
227    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
228        let _span = grafeo_debug_span!("grafeo::query::optimize");
229        let mut root = plan.root;
230
231        // Apply optimization rules
232        if self.enable_filter_pushdown {
233            // Propagate filters across LeftJoin shared variables BEFORE
234            // pushdown, so OPTIONAL MATCH right-side subtrees pick up
235            // the same WHERE constraints the main MATCH applies.
236            // Otherwise pushdown rewrites the left subtree and the
237            // chance is gone.
238            root = self.propagate_join_predicates(root);
239            root = self.push_filters_down(root);
240        }
241
242        if self.enable_join_reorder {
243            root = self.reorder_joins(root);
244        }
245
246        if self.enable_projection_pushdown {
247            root = self.push_projections_down(root);
248        }
249
250        Ok(LogicalPlan {
251            root,
252            explain: plan.explain,
253            profile: plan.profile,
254            default_params: plan.default_params,
255        })
256    }
257
258    /// Pushes projections down the operator tree to eliminate unused columns early.
259    ///
260    /// This optimization:
261    /// 1. Collects required variables/properties from the root
262    /// 2. Propagates requirements down through the tree
263    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
264    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
265        // Collect required columns from the top of the plan
266        let required = self.collect_required_columns(&op);
267
268        // Push projections down
269        self.push_projections_recursive(op, &required)
270    }
271
272    /// Collects all variables and properties required by an operator and its ancestors.
273    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
274        let mut required = HashSet::new();
275        Self::collect_required_recursive(op, &mut required);
276        required
277    }
278
279    /// Recursively collects required columns.
280    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
281        match op {
282            LogicalOperator::Return(ret) => {
283                for item in &ret.items {
284                    Self::collect_from_expression(&item.expression, required);
285                }
286                Self::collect_required_recursive(&ret.input, required);
287            }
288            LogicalOperator::Project(proj) => {
289                for p in &proj.projections {
290                    Self::collect_from_expression(&p.expression, required);
291                }
292                Self::collect_required_recursive(&proj.input, required);
293            }
294            LogicalOperator::Filter(filter) => {
295                Self::collect_from_expression(&filter.predicate, required);
296                Self::collect_required_recursive(&filter.input, required);
297            }
298            LogicalOperator::Sort(sort) => {
299                for key in &sort.keys {
300                    Self::collect_from_expression(&key.expression, required);
301                }
302                Self::collect_required_recursive(&sort.input, required);
303            }
304            LogicalOperator::Aggregate(agg) => {
305                for expr in &agg.group_by {
306                    Self::collect_from_expression(expr, required);
307                }
308                for agg_expr in &agg.aggregates {
309                    if let Some(ref expr) = agg_expr.expression {
310                        Self::collect_from_expression(expr, required);
311                    }
312                }
313                if let Some(ref having) = agg.having {
314                    Self::collect_from_expression(having, required);
315                }
316                Self::collect_required_recursive(&agg.input, required);
317            }
318            LogicalOperator::Join(join) => {
319                for cond in &join.conditions {
320                    Self::collect_from_expression(&cond.left, required);
321                    Self::collect_from_expression(&cond.right, required);
322                }
323                Self::collect_required_recursive(&join.left, required);
324                Self::collect_required_recursive(&join.right, required);
325            }
326            LogicalOperator::Expand(expand) => {
327                // The source and target variables are needed
328                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
329                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
330                if let Some(ref edge_var) = expand.edge_variable {
331                    required.insert(RequiredColumn::Variable(edge_var.clone()));
332                }
333                Self::collect_required_recursive(&expand.input, required);
334            }
335            LogicalOperator::Limit(limit) => {
336                Self::collect_required_recursive(&limit.input, required);
337            }
338            LogicalOperator::Skip(skip) => {
339                Self::collect_required_recursive(&skip.input, required);
340            }
341            LogicalOperator::Distinct(distinct) => {
342                Self::collect_required_recursive(&distinct.input, required);
343            }
344            LogicalOperator::NodeScan(scan) => {
345                required.insert(RequiredColumn::Variable(scan.variable.clone()));
346            }
347            LogicalOperator::EdgeScan(scan) => {
348                required.insert(RequiredColumn::Variable(scan.variable.clone()));
349            }
350            LogicalOperator::MultiWayJoin(mwj) => {
351                for cond in &mwj.conditions {
352                    Self::collect_from_expression(&cond.left, required);
353                    Self::collect_from_expression(&cond.right, required);
354                }
355                for input in &mwj.inputs {
356                    Self::collect_required_recursive(input, required);
357                }
358            }
359            _ => {}
360        }
361    }
362
363    /// Collects required columns from an expression.
364    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
365        match expr {
366            LogicalExpression::Variable(var) => {
367                required.insert(RequiredColumn::Variable(var.clone()));
368            }
369            LogicalExpression::Property { variable, property } => {
370                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
371                required.insert(RequiredColumn::Variable(variable.clone()));
372            }
373            LogicalExpression::Binary { left, right, .. } => {
374                Self::collect_from_expression(left, required);
375                Self::collect_from_expression(right, required);
376            }
377            LogicalExpression::Unary { operand, .. } => {
378                Self::collect_from_expression(operand, required);
379            }
380            LogicalExpression::FunctionCall { args, .. } => {
381                for arg in args {
382                    Self::collect_from_expression(arg, required);
383                }
384            }
385            LogicalExpression::List(items) => {
386                for item in items {
387                    Self::collect_from_expression(item, required);
388                }
389            }
390            LogicalExpression::Map(pairs) => {
391                for (_, value) in pairs {
392                    Self::collect_from_expression(value, required);
393                }
394            }
395            LogicalExpression::IndexAccess { base, index } => {
396                Self::collect_from_expression(base, required);
397                Self::collect_from_expression(index, required);
398            }
399            LogicalExpression::SliceAccess { base, start, end } => {
400                Self::collect_from_expression(base, required);
401                if let Some(s) = start {
402                    Self::collect_from_expression(s, required);
403                }
404                if let Some(e) = end {
405                    Self::collect_from_expression(e, required);
406                }
407            }
408            LogicalExpression::Case {
409                operand,
410                when_clauses,
411                else_clause,
412            } => {
413                if let Some(op) = operand {
414                    Self::collect_from_expression(op, required);
415                }
416                for (cond, result) in when_clauses {
417                    Self::collect_from_expression(cond, required);
418                    Self::collect_from_expression(result, required);
419                }
420                if let Some(else_expr) = else_clause {
421                    Self::collect_from_expression(else_expr, required);
422                }
423            }
424            LogicalExpression::Labels(var)
425            | LogicalExpression::Type(var)
426            | LogicalExpression::Id(var) => {
427                required.insert(RequiredColumn::Variable(var.clone()));
428            }
429            LogicalExpression::ListComprehension {
430                list_expr,
431                filter_expr,
432                map_expr,
433                ..
434            } => {
435                Self::collect_from_expression(list_expr, required);
436                if let Some(filter) = filter_expr {
437                    Self::collect_from_expression(filter, required);
438                }
439                Self::collect_from_expression(map_expr, required);
440            }
441            _ => {}
442        }
443    }
444
445    /// Recursively pushes projections down, adding them before expensive operations.
446    fn push_projections_recursive(
447        &self,
448        op: LogicalOperator,
449        required: &HashSet<RequiredColumn>,
450    ) -> LogicalOperator {
451        match op {
452            LogicalOperator::Return(mut ret) => {
453                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
454                LogicalOperator::Return(ret)
455            }
456            LogicalOperator::Project(mut proj) => {
457                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
458                LogicalOperator::Project(proj)
459            }
460            LogicalOperator::Filter(mut filter) => {
461                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
462                LogicalOperator::Filter(filter)
463            }
464            LogicalOperator::Sort(mut sort) => {
465                // Sort is expensive - consider adding a projection before it
466                // to reduce tuple width
467                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
468                LogicalOperator::Sort(sort)
469            }
470            LogicalOperator::Aggregate(mut agg) => {
471                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
472                LogicalOperator::Aggregate(agg)
473            }
474            LogicalOperator::Join(mut join) => {
475                // Joins are expensive - the required columns help determine
476                // what to project on each side
477                let left_vars = self.collect_output_variables(&join.left);
478                let right_vars = self.collect_output_variables(&join.right);
479
480                // Filter required columns to each side
481                let left_required: HashSet<_> = required
482                    .iter()
483                    .filter(|c| match c {
484                        RequiredColumn::Variable(v) => left_vars.contains(v),
485                        RequiredColumn::Property(v, _) => left_vars.contains(v),
486                    })
487                    .cloned()
488                    .collect();
489
490                let right_required: HashSet<_> = required
491                    .iter()
492                    .filter(|c| match c {
493                        RequiredColumn::Variable(v) => right_vars.contains(v),
494                        RequiredColumn::Property(v, _) => right_vars.contains(v),
495                    })
496                    .cloned()
497                    .collect();
498
499                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
500                join.right =
501                    Box::new(self.push_projections_recursive(*join.right, &right_required));
502                LogicalOperator::Join(join)
503            }
504            LogicalOperator::Expand(mut expand) => {
505                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
506                LogicalOperator::Expand(expand)
507            }
508            LogicalOperator::Limit(mut limit) => {
509                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
510                LogicalOperator::Limit(limit)
511            }
512            LogicalOperator::Skip(mut skip) => {
513                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
514                LogicalOperator::Skip(skip)
515            }
516            LogicalOperator::Distinct(mut distinct) => {
517                distinct.input =
518                    Box::new(self.push_projections_recursive(*distinct.input, required));
519                LogicalOperator::Distinct(distinct)
520            }
521            LogicalOperator::MapCollect(mut mc) => {
522                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
523                LogicalOperator::MapCollect(mc)
524            }
525            LogicalOperator::MultiWayJoin(mut mwj) => {
526                mwj.inputs = mwj
527                    .inputs
528                    .into_iter()
529                    .map(|input| self.push_projections_recursive(input, required))
530                    .collect();
531                LogicalOperator::MultiWayJoin(mwj)
532            }
533            other => other,
534        }
535    }
536
537    /// Reorders joins in the operator tree using the DPccp algorithm.
538    ///
539    /// This optimization finds the optimal join order by:
540    /// 1. Extracting all base relations (scans) and join conditions
541    /// 2. Building a join graph
542    /// 3. Using dynamic programming to find the cheapest join order
543    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
544        // First, recursively optimize children
545        let op = self.reorder_joins_recursive(op);
546
547        // Then, if this is a join tree, try to optimize it
548        if let Some((relations, conditions)) = self.extract_join_tree(&op)
549            && relations.len() >= 2
550            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
551        {
552            return optimized;
553        }
554
555        op
556    }
557
558    /// Recursively applies join reordering to child operators.
559    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
560        match op {
561            LogicalOperator::Return(mut ret) => {
562                ret.input = Box::new(self.reorder_joins(*ret.input));
563                LogicalOperator::Return(ret)
564            }
565            LogicalOperator::Project(mut proj) => {
566                proj.input = Box::new(self.reorder_joins(*proj.input));
567                LogicalOperator::Project(proj)
568            }
569            LogicalOperator::Filter(mut filter) => {
570                filter.input = Box::new(self.reorder_joins(*filter.input));
571                LogicalOperator::Filter(filter)
572            }
573            LogicalOperator::Limit(mut limit) => {
574                limit.input = Box::new(self.reorder_joins(*limit.input));
575                LogicalOperator::Limit(limit)
576            }
577            LogicalOperator::Skip(mut skip) => {
578                skip.input = Box::new(self.reorder_joins(*skip.input));
579                LogicalOperator::Skip(skip)
580            }
581            LogicalOperator::Sort(mut sort) => {
582                sort.input = Box::new(self.reorder_joins(*sort.input));
583                LogicalOperator::Sort(sort)
584            }
585            LogicalOperator::Distinct(mut distinct) => {
586                distinct.input = Box::new(self.reorder_joins(*distinct.input));
587                LogicalOperator::Distinct(distinct)
588            }
589            LogicalOperator::Aggregate(mut agg) => {
590                agg.input = Box::new(self.reorder_joins(*agg.input));
591                LogicalOperator::Aggregate(agg)
592            }
593            LogicalOperator::Expand(mut expand) => {
594                expand.input = Box::new(self.reorder_joins(*expand.input));
595                LogicalOperator::Expand(expand)
596            }
597            LogicalOperator::MapCollect(mut mc) => {
598                mc.input = Box::new(self.reorder_joins(*mc.input));
599                LogicalOperator::MapCollect(mc)
600            }
601            LogicalOperator::MultiWayJoin(mut mwj) => {
602                mwj.inputs = mwj
603                    .inputs
604                    .into_iter()
605                    .map(|input| self.reorder_joins(input))
606                    .collect();
607                LogicalOperator::MultiWayJoin(mwj)
608            }
609            // Join operators are handled by the parent reorder_joins call
610            other => other,
611        }
612    }
613
614    /// Extracts base relations and join conditions from a join tree.
615    ///
616    /// Returns None if the operator is not a join tree.
617    fn extract_join_tree(
618        &self,
619        op: &LogicalOperator,
620    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
621        let mut relations = Vec::new();
622        let mut join_conditions = Vec::new();
623
624        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
625            return None;
626        }
627
628        if relations.len() < 2 {
629            return None;
630        }
631
632        Some((relations, join_conditions))
633    }
634
635    /// Recursively collects base relations and join conditions.
636    ///
637    /// Returns true if this subtree is part of a join tree.
638    fn collect_join_tree(
639        &self,
640        op: &LogicalOperator,
641        relations: &mut Vec<(String, LogicalOperator)>,
642        conditions: &mut Vec<JoinInfo>,
643    ) -> bool {
644        match op {
645            LogicalOperator::Join(join) => {
646                // Collect from both sides
647                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
648                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
649
650                // Add conditions from this join
651                for cond in &join.conditions {
652                    if let (Some(left_var), Some(right_var)) = (
653                        self.extract_variable_from_expr(&cond.left),
654                        self.extract_variable_from_expr(&cond.right),
655                    ) {
656                        conditions.push(JoinInfo {
657                            left_var,
658                            right_var,
659                            left_expr: cond.left.clone(),
660                            right_expr: cond.right.clone(),
661                        });
662                    }
663                }
664
665                left_ok && right_ok
666            }
667            LogicalOperator::NodeScan(scan) => {
668                relations.push((scan.variable.clone(), op.clone()));
669                true
670            }
671            LogicalOperator::EdgeScan(scan) => {
672                relations.push((scan.variable.clone(), op.clone()));
673                true
674            }
675            LogicalOperator::Filter(filter) => {
676                // A filter on a base relation is still part of the join tree
677                self.collect_join_tree(&filter.input, relations, conditions)
678            }
679            LogicalOperator::Expand(expand) => {
680                // Expand is a special case - it's like a join with the adjacency
681                // For now, treat the whole Expand subtree as a single relation
682                relations.push((expand.to_variable.clone(), op.clone()));
683                true
684            }
685            #[cfg(feature = "triple-store")]
686            LogicalOperator::TripleScan(scan) => {
687                // Use the first variable found as the relation name.
688                // For all-constant patterns, generate a unique fallback to
689                // avoid HashMap key collisions in `add_relation`.
690                let name = scan
691                    .subject
692                    .as_variable()
693                    .or_else(|| scan.predicate.as_variable())
694                    .or_else(|| scan.object.as_variable())
695                    .map_or_else(|| format!("__tp_{}", relations.len()), String::from);
696                relations.push((name, op.clone()));
697                true
698            }
699            _ => false,
700        }
701    }
702
703    /// Extracts the primary variable from an expression.
704    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
705        match expr {
706            LogicalExpression::Variable(v) => Some(v.clone()),
707            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
708            _ => None,
709        }
710    }
711
712    /// Optimizes the join order using DPccp, or produces a multi-way
713    /// leapfrog join for cyclic patterns when the cost model prefers it.
714    fn optimize_join_order(
715        &self,
716        relations: &[(String, LogicalOperator)],
717        conditions: &[JoinInfo],
718    ) -> Option<LogicalOperator> {
719        use join_order::{DPccp, JoinGraphBuilder};
720
721        // Build the join graph
722        let mut builder = JoinGraphBuilder::new();
723
724        for (var, relation) in relations {
725            builder.add_relation(var, relation.clone());
726        }
727
728        for cond in conditions {
729            builder.add_join_condition(
730                &cond.left_var,
731                &cond.right_var,
732                cond.left_expr.clone(),
733                cond.right_expr.clone(),
734            );
735        }
736
737        let graph = builder.build();
738
739        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
740        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
741        // multi-way intersection rather than binary hash join cascades that can
742        // produce intermediate blowup.
743        if graph.is_cyclic() && relations.len() >= 3 {
744            // Collect shared variables (variables appearing in 2+ conditions)
745            let mut var_counts: std::collections::HashMap<&str, usize> =
746                std::collections::HashMap::new();
747            for cond in conditions {
748                *var_counts.entry(&cond.left_var).or_default() += 1;
749                *var_counts.entry(&cond.right_var).or_default() += 1;
750            }
751            let shared_variables: Vec<String> = var_counts
752                .into_iter()
753                .filter(|(_, count)| *count >= 2)
754                .map(|(var, _)| var.to_string())
755                .collect();
756
757            let join_conditions: Vec<JoinCondition> = conditions
758                .iter()
759                .map(|c| JoinCondition {
760                    left: c.left_expr.clone(),
761                    right: c.right_expr.clone(),
762                })
763                .collect();
764
765            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
766                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
767                conditions: join_conditions,
768                shared_variables,
769            }));
770        }
771
772        // Fall through to DPccp for binary join ordering
773        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
774        let plan = dpccp.optimize()?;
775
776        Some(plan.operator)
777    }
778
779    /// Propagates filter predicates across LeftJoin boundaries when the
780    /// predicate references variables bound on both sides of the join.
781    ///
782    /// OPTIONAL MATCH compiles to `LeftJoin(left, right)` where the right
783    /// subtree binds the same variable as the left (typically via its own
784    /// NodeScan), so the right side independently scans the full table even
785    /// when the left side is bound to a tiny set. After regular pushdown
786    /// rewrites `Filter(P, MainMatch)` into something like
787    /// `Expand(NodeList(P))`, the original filter is gone and we can no
788    /// longer detect the constraint to mirror to the right. This pass runs
789    /// BEFORE pushdown: for every `LeftJoin(Filter(P, X), Y)` where P's
790    /// variables are bound in both X and Y, wrap Y with the same Filter.
791    ///
792    /// Safe by LeftJoin semantics: matched pairs satisfy left.v = right.v
793    /// for shared variable v, so a right row that fails P also fails P
794    /// when joined to its matching left row, which already failed P. And
795    /// OPTIONAL left rows with no right match are unaffected, since we
796    /// only constrain the right side's enumeration.
797    fn propagate_join_predicates(&self, op: LogicalOperator) -> LogicalOperator {
798        // Recurse into children first (post-order). After children are
799        // processed, propagate at this level if it's a LeftJoin.
800        let op = op.map_children(|child| self.propagate_join_predicates(child));
801
802        let LogicalOperator::LeftJoin(mut left_join) = op else {
803            return op;
804        };
805
806        // Variables bound on both sides act as the implicit join keys
807        // (typically a single shared variable like `c`). Predicates that
808        // reference only these variables are safe to mirror to the right.
809        let left_vars = self.collect_output_variables(&left_join.left);
810        let right_vars = self.collect_output_variables(&left_join.right);
811        let shared_vars: HashSet<String> = left_vars.intersection(&right_vars).cloned().collect();
812
813        if shared_vars.is_empty() {
814            return LogicalOperator::LeftJoin(left_join);
815        }
816
817        let mut shared_filters = Vec::new();
818        self.collect_shared_var_filters(&left_join.left, &shared_vars, &mut shared_filters);
819        for predicate in shared_filters {
820            left_join.right = Box::new(LogicalOperator::Filter(FilterOp {
821                predicate,
822                pushdown_hint: None,
823                input: left_join.right,
824            }));
825        }
826
827        LogicalOperator::LeftJoin(left_join)
828    }
829
830    /// Walks a (sub)tree collecting filter predicates that reference only
831    /// the given shared variables. These are safe to mirror across a
832    /// LeftJoin to the right subtree.
833    fn collect_shared_var_filters(
834        &self,
835        op: &LogicalOperator,
836        shared_vars: &HashSet<String>,
837        out: &mut Vec<LogicalExpression>,
838    ) {
839        match op {
840            LogicalOperator::Filter(f) => {
841                let predicate_vars = self.extract_variables(&f.predicate);
842                if !predicate_vars.is_empty()
843                    && predicate_vars.iter().all(|v| shared_vars.contains(v))
844                {
845                    out.push(f.predicate.clone());
846                }
847                self.collect_shared_var_filters(&f.input, shared_vars, out);
848            }
849            // Don't descend through operators that change row semantics
850            // (Aggregate, Limit, Skip, Sort, Distinct): predicates above
851            // them aren't safe to extract because they apply to a
852            // post-aggregation/limit world. Project, Return, and Expand
853            // are row-preserving so we can keep walking.
854            LogicalOperator::Project(p) => {
855                self.collect_shared_var_filters(&p.input, shared_vars, out);
856            }
857            LogicalOperator::Return(r) => {
858                self.collect_shared_var_filters(&r.input, shared_vars, out);
859            }
860            LogicalOperator::Expand(e) => {
861                self.collect_shared_var_filters(&e.input, shared_vars, out);
862            }
863            // For nested LeftJoins (chained OPTIONAL MATCH), the LEFT side
864            // is the row-driving subtree: its surviving rows feed the outer
865            // join. Filters anywhere in that left subtree are implicitly
866            // applied to every output row, so they're safe to mirror across
867            // the outer join. The right side is optional and may contribute
868            // NULL-padded rows that don't satisfy those filters, so we
869            // deliberately don't walk into it.
870            LogicalOperator::LeftJoin(j) => {
871                self.collect_shared_var_filters(&j.left, shared_vars, out);
872            }
873            LogicalOperator::Join(j) => {
874                self.collect_shared_var_filters(&j.left, shared_vars, out);
875                self.collect_shared_var_filters(&j.right, shared_vars, out);
876            }
877            // Stop at Apply, Aggregate, Limit, Skip, etc.: too risky to
878            // pull predicates across those boundaries.
879            _ => {}
880        }
881    }
882
883    /// Pushes filters down the operator tree.
884    ///
885    /// This optimization moves filter predicates as close to the data source
886    /// as possible to reduce the amount of data processed by upper operators.
887    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
888        match op {
889            // For Filter operators, try to push the predicate into the child
890            LogicalOperator::Filter(filter) => {
891                let optimized_input = self.push_filters_down(*filter.input);
892                self.try_push_filter_into(filter.predicate, optimized_input)
893            }
894            // Recursively optimize children for other operators
895            LogicalOperator::Return(mut ret) => {
896                ret.input = Box::new(self.push_filters_down(*ret.input));
897                LogicalOperator::Return(ret)
898            }
899            LogicalOperator::Project(mut proj) => {
900                proj.input = Box::new(self.push_filters_down(*proj.input));
901                LogicalOperator::Project(proj)
902            }
903            LogicalOperator::Limit(mut limit) => {
904                limit.input = Box::new(self.push_filters_down(*limit.input));
905                LogicalOperator::Limit(limit)
906            }
907            LogicalOperator::Skip(mut skip) => {
908                skip.input = Box::new(self.push_filters_down(*skip.input));
909                LogicalOperator::Skip(skip)
910            }
911            LogicalOperator::Sort(mut sort) => {
912                sort.input = Box::new(self.push_filters_down(*sort.input));
913                LogicalOperator::Sort(sort)
914            }
915            LogicalOperator::Distinct(mut distinct) => {
916                distinct.input = Box::new(self.push_filters_down(*distinct.input));
917                LogicalOperator::Distinct(distinct)
918            }
919            LogicalOperator::Expand(mut expand) => {
920                expand.input = Box::new(self.push_filters_down(*expand.input));
921                LogicalOperator::Expand(expand)
922            }
923            LogicalOperator::Join(mut join) => {
924                join.left = Box::new(self.push_filters_down(*join.left));
925                join.right = Box::new(self.push_filters_down(*join.right));
926                LogicalOperator::Join(join)
927            }
928            LogicalOperator::LeftJoin(mut left_join) => {
929                left_join.left = Box::new(self.push_filters_down(*left_join.left));
930                left_join.right = Box::new(self.push_filters_down(*left_join.right));
931                LogicalOperator::LeftJoin(left_join)
932            }
933            LogicalOperator::AntiJoin(mut anti_join) => {
934                anti_join.left = Box::new(self.push_filters_down(*anti_join.left));
935                anti_join.right = Box::new(self.push_filters_down(*anti_join.right));
936                LogicalOperator::AntiJoin(anti_join)
937            }
938            LogicalOperator::Apply(mut apply) => {
939                apply.input = Box::new(self.push_filters_down(*apply.input));
940                apply.subplan = Box::new(self.push_filters_down(*apply.subplan));
941                LogicalOperator::Apply(apply)
942            }
943            LogicalOperator::Union(mut union) => {
944                union.inputs = union
945                    .inputs
946                    .into_iter()
947                    .map(|input| self.push_filters_down(input))
948                    .collect();
949                LogicalOperator::Union(union)
950            }
951            LogicalOperator::Unwind(mut unwind) => {
952                unwind.input = Box::new(self.push_filters_down(*unwind.input));
953                LogicalOperator::Unwind(unwind)
954            }
955            LogicalOperator::Aggregate(mut agg) => {
956                agg.input = Box::new(self.push_filters_down(*agg.input));
957                LogicalOperator::Aggregate(agg)
958            }
959            LogicalOperator::MapCollect(mut mc) => {
960                mc.input = Box::new(self.push_filters_down(*mc.input));
961                LogicalOperator::MapCollect(mc)
962            }
963            LogicalOperator::MultiWayJoin(mut mwj) => {
964                mwj.inputs = mwj
965                    .inputs
966                    .into_iter()
967                    .map(|input| self.push_filters_down(input))
968                    .collect();
969                LogicalOperator::MultiWayJoin(mwj)
970            }
971            // Leaf operators and unsupported operators are returned as-is
972            other => other,
973        }
974    }
975
976    /// Tries to push a filter predicate into the given operator.
977    ///
978    /// Returns either the predicate pushed into the operator, or a new
979    /// Filter operator on top if the predicate cannot be pushed further.
980    fn try_push_filter_into(
981        &self,
982        predicate: LogicalExpression,
983        op: LogicalOperator,
984    ) -> LogicalOperator {
985        match op {
986            // Can push through Project if predicate doesn't depend on computed columns
987            LogicalOperator::Project(mut proj) => {
988                let predicate_vars = self.extract_variables(&predicate);
989                let computed_vars = self.extract_projection_aliases(&proj.projections);
990
991                // If predicate doesn't use any computed columns, push through
992                if predicate_vars.is_disjoint(&computed_vars) {
993                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
994                    LogicalOperator::Project(proj)
995                } else {
996                    // Can't push through, keep filter on top
997                    LogicalOperator::Filter(FilterOp {
998                        predicate,
999                        pushdown_hint: None,
1000                        input: Box::new(LogicalOperator::Project(proj)),
1001                    })
1002                }
1003            }
1004
1005            // Can push through Return (which is like a projection)
1006            LogicalOperator::Return(mut ret) => {
1007                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
1008                LogicalOperator::Return(ret)
1009            }
1010
1011            // Can push through Expand if predicate doesn't use variables introduced by this expand
1012            LogicalOperator::Expand(mut expand) => {
1013                let predicate_vars = self.extract_variables(&predicate);
1014
1015                // Variables introduced by this expand are:
1016                // - The target variable (to_variable)
1017                // - The edge variable (if any)
1018                // - The path alias (if any)
1019                let mut introduced_vars = vec![&expand.to_variable];
1020                if let Some(ref edge_var) = expand.edge_variable {
1021                    introduced_vars.push(edge_var);
1022                }
1023                if let Some(ref path_alias) = expand.path_alias {
1024                    introduced_vars.push(path_alias);
1025                }
1026
1027                // Check if predicate uses any variables introduced by this expand
1028                let uses_introduced_vars =
1029                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
1030
1031                if !uses_introduced_vars {
1032                    // Predicate doesn't use vars from this expand, so push through
1033                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
1034                    LogicalOperator::Expand(expand)
1035                } else {
1036                    // Keep filter after expand
1037                    LogicalOperator::Filter(FilterOp {
1038                        predicate,
1039                        pushdown_hint: None,
1040                        input: Box::new(LogicalOperator::Expand(expand)),
1041                    })
1042                }
1043            }
1044
1045            // Can push through Join to left/right side based on variables used
1046            LogicalOperator::Join(mut join) => {
1047                let predicate_vars = self.extract_variables(&predicate);
1048                let left_vars = self.collect_output_variables(&join.left);
1049                let right_vars = self.collect_output_variables(&join.right);
1050
1051                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
1052                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
1053
1054                if uses_left && !uses_right {
1055                    // Push to left side
1056                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
1057                    LogicalOperator::Join(join)
1058                } else if uses_right && !uses_left {
1059                    // Push to right side
1060                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
1061                    LogicalOperator::Join(join)
1062                } else {
1063                    // Uses both sides - keep above join
1064                    LogicalOperator::Filter(FilterOp {
1065                        predicate,
1066                        pushdown_hint: None,
1067                        input: Box::new(LogicalOperator::Join(join)),
1068                    })
1069                }
1070            }
1071
1072            // LeftJoin pushdown is semantics-preserving only on the LEFT
1073            // side: anything that filters out a left row also filters out
1074            // every (left, NULL) pair the join would have emitted. Pushing
1075            // to the right side is unsafe because OPTIONAL MATCH must keep
1076            // left rows that have no right match.
1077            LogicalOperator::LeftJoin(mut left_join) => {
1078                let predicate_vars = self.extract_variables(&predicate);
1079                let left_vars = self.collect_output_variables(&left_join.left);
1080                let right_vars = self.collect_output_variables(&left_join.right);
1081
1082                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
1083                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
1084
1085                if uses_left && !uses_right {
1086                    left_join.left =
1087                        Box::new(self.try_push_filter_into(predicate, *left_join.left));
1088                    LogicalOperator::LeftJoin(left_join)
1089                } else if uses_left
1090                    && uses_right
1091                    && predicate_vars
1092                        .iter()
1093                        .all(|v| left_vars.contains(v) && right_vars.contains(v))
1094                {
1095                    // Predicate references only variables that are bound on
1096                    // both sides (i.e. join-key variables). The OPTIONAL
1097                    // MATCH compiles to a LeftJoin where the right subtree
1098                    // independently re-binds the shared variable (typically
1099                    // via its own NodeScan), so the right side can balloon
1100                    // to the full table even when the left side is bound
1101                    // to a tiny set. Duplicating the predicate to both
1102                    // sides is safe: matched pairs satisfy left.x == right.x,
1103                    // so a right row that fails the predicate either has
1104                    // no left match or pairs with a left row that also
1105                    // fails. Unmatched (OPTIONAL) left rows are unaffected.
1106                    // This is what collapses hydrate's three independent
1107                    // 30k-row scans into 6-id index lookups.
1108                    left_join.left =
1109                        Box::new(self.try_push_filter_into(predicate.clone(), *left_join.left));
1110                    left_join.right =
1111                        Box::new(self.try_push_filter_into(predicate, *left_join.right));
1112                    LogicalOperator::LeftJoin(left_join)
1113                } else {
1114                    LogicalOperator::Filter(FilterOp {
1115                        predicate,
1116                        pushdown_hint: None,
1117                        input: Box::new(LogicalOperator::LeftJoin(left_join)),
1118                    })
1119                }
1120            }
1121
1122            // Apply (correlated subquery): the input is the outer plan, the
1123            // subplan re-evaluates per outer row. Predicates that only use
1124            // outer-side variables can safely push into the input.
1125            LogicalOperator::Apply(mut apply) => {
1126                let predicate_vars = self.extract_variables(&predicate);
1127                let input_vars = self.collect_output_variables(&apply.input);
1128                let subplan_vars = self.collect_output_variables(&apply.subplan);
1129
1130                let uses_input = predicate_vars.iter().any(|v| input_vars.contains(v));
1131                let uses_subplan = predicate_vars.iter().any(|v| subplan_vars.contains(v));
1132
1133                if uses_input && !uses_subplan {
1134                    apply.input = Box::new(self.try_push_filter_into(predicate, *apply.input));
1135                    LogicalOperator::Apply(apply)
1136                } else {
1137                    LogicalOperator::Filter(FilterOp {
1138                        predicate,
1139                        pushdown_hint: None,
1140                        input: Box::new(LogicalOperator::Apply(apply)),
1141                    })
1142                }
1143            }
1144
1145            // Cannot push through Aggregate (predicate refers to aggregated values)
1146            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
1147                predicate,
1148                pushdown_hint: None,
1149                input: Box::new(LogicalOperator::Aggregate(agg)),
1150            }),
1151
1152            // For NodeScan, we've reached the bottom - keep filter on top
1153            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
1154                predicate,
1155                pushdown_hint: None,
1156                input: Box::new(LogicalOperator::NodeScan(scan)),
1157            }),
1158
1159            // Filters commute (boolean conjunction is associative/commutative),
1160            // so we can push the outer predicate past an existing inner
1161            // Filter and continue trying to push further down. Without this
1162            // case, a pattern like
1163            // `Filter(r.id IN [...], Filter(hasLabel(d), Expand))` (emitted
1164            // by Cypher when the WHERE clause is conjoined with automatic
1165            // label filters) gets stuck at the top, missing the chance to
1166            // anchor the filter on r's NodeScan and trigger the property-
1167            // index fast path.
1168            //
1169            // Only commute when the outer predicate references variables
1170            // that are bound at or below the inner filter's input. If the
1171            // inner filter is a final scope-narrowing step that introduces
1172            // synthetic columns the outer predicate depends on (e.g. path
1173            // variables produced by a variable-length expand wrapped in a
1174            // label filter), pushing past it would evaluate the outer
1175            // predicate against rows where those columns aren't bound and
1176            // every row would be filtered out.
1177            LogicalOperator::Filter(inner_filter) => {
1178                let predicate_vars = self.extract_variables(&predicate);
1179                let inner_input_vars = self.collect_output_variables(&inner_filter.input);
1180                let safe_to_commute = predicate_vars.iter().all(|v| inner_input_vars.contains(v));
1181                if safe_to_commute {
1182                    let mut inner_filter = inner_filter;
1183                    inner_filter.input =
1184                        Box::new(self.try_push_filter_into(predicate, *inner_filter.input));
1185                    LogicalOperator::Filter(inner_filter)
1186                } else {
1187                    LogicalOperator::Filter(FilterOp {
1188                        predicate,
1189                        pushdown_hint: None,
1190                        input: Box::new(LogicalOperator::Filter(inner_filter)),
1191                    })
1192                }
1193            }
1194
1195            // For other operators, keep filter on top
1196            other => LogicalOperator::Filter(FilterOp {
1197                predicate,
1198                pushdown_hint: None,
1199                input: Box::new(other),
1200            }),
1201        }
1202    }
1203
1204    // NOTE: Filter-into-TripleScan pushdown is intentionally not implemented.
1205    // Binding a variable position in the scan removes it from output columns,
1206    // which breaks downstream operators (SELECT, ORDER BY) that reference it.
1207    // A correct implementation needs a separate `pushed_bindings` field on
1208    // TripleScanOp so the variable stays in output while the scan is constrained.
1209
1210    /// Collects all output variable names from an operator.
1211    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
1212        let mut vars = HashSet::new();
1213        Self::collect_output_variables_recursive(op, &mut vars);
1214        vars
1215    }
1216
1217    /// Recursively collects output variables from an operator.
1218    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
1219        match op {
1220            LogicalOperator::NodeScan(scan) => {
1221                vars.insert(scan.variable.clone());
1222            }
1223            LogicalOperator::EdgeScan(scan) => {
1224                vars.insert(scan.variable.clone());
1225            }
1226            LogicalOperator::Expand(expand) => {
1227                vars.insert(expand.to_variable.clone());
1228                if let Some(edge_var) = &expand.edge_variable {
1229                    vars.insert(edge_var.clone());
1230                }
1231                Self::collect_output_variables_recursive(&expand.input, vars);
1232            }
1233            LogicalOperator::Filter(filter) => {
1234                Self::collect_output_variables_recursive(&filter.input, vars);
1235            }
1236            LogicalOperator::Project(proj) => {
1237                for p in &proj.projections {
1238                    if let Some(alias) = &p.alias {
1239                        vars.insert(alias.clone());
1240                    }
1241                }
1242                Self::collect_output_variables_recursive(&proj.input, vars);
1243            }
1244            LogicalOperator::Join(join) => {
1245                Self::collect_output_variables_recursive(&join.left, vars);
1246                Self::collect_output_variables_recursive(&join.right, vars);
1247            }
1248            LogicalOperator::Aggregate(agg) => {
1249                for expr in &agg.group_by {
1250                    Self::collect_variables(expr, vars);
1251                }
1252                for agg_expr in &agg.aggregates {
1253                    if let Some(alias) = &agg_expr.alias {
1254                        vars.insert(alias.clone());
1255                    }
1256                }
1257            }
1258            LogicalOperator::Return(ret) => {
1259                Self::collect_output_variables_recursive(&ret.input, vars);
1260            }
1261            LogicalOperator::Limit(limit) => {
1262                Self::collect_output_variables_recursive(&limit.input, vars);
1263            }
1264            LogicalOperator::Skip(skip) => {
1265                Self::collect_output_variables_recursive(&skip.input, vars);
1266            }
1267            LogicalOperator::Sort(sort) => {
1268                Self::collect_output_variables_recursive(&sort.input, vars);
1269            }
1270            LogicalOperator::Distinct(distinct) => {
1271                Self::collect_output_variables_recursive(&distinct.input, vars);
1272            }
1273            #[cfg(feature = "triple-store")]
1274            LogicalOperator::TripleScan(scan) => {
1275                if let Some(v) = scan.subject.as_variable() {
1276                    vars.insert(v.to_string());
1277                }
1278                if let Some(v) = scan.predicate.as_variable() {
1279                    vars.insert(v.to_string());
1280                }
1281                if let Some(v) = scan.object.as_variable() {
1282                    vars.insert(v.to_string());
1283                }
1284                if let Some(ref g) = scan.graph
1285                    && let Some(v) = g.as_variable()
1286                {
1287                    vars.insert(v.to_string());
1288                }
1289            }
1290            _ => {}
1291        }
1292    }
1293
1294    /// Extracts all variable names referenced in an expression.
1295    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1296        let mut vars = HashSet::new();
1297        Self::collect_variables(expr, &mut vars);
1298        vars
1299    }
1300
1301    /// Recursively collects variable names from an expression.
1302    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1303        match expr {
1304            LogicalExpression::Variable(name) => {
1305                vars.insert(name.clone());
1306            }
1307            LogicalExpression::Property { variable, .. } => {
1308                vars.insert(variable.clone());
1309            }
1310            LogicalExpression::Binary { left, right, .. } => {
1311                Self::collect_variables(left, vars);
1312                Self::collect_variables(right, vars);
1313            }
1314            LogicalExpression::Unary { operand, .. } => {
1315                Self::collect_variables(operand, vars);
1316            }
1317            LogicalExpression::FunctionCall { args, .. } => {
1318                for arg in args {
1319                    Self::collect_variables(arg, vars);
1320                }
1321            }
1322            LogicalExpression::List(items) => {
1323                for item in items {
1324                    Self::collect_variables(item, vars);
1325                }
1326            }
1327            LogicalExpression::Map(pairs) => {
1328                for (_, value) in pairs {
1329                    Self::collect_variables(value, vars);
1330                }
1331            }
1332            LogicalExpression::IndexAccess { base, index } => {
1333                Self::collect_variables(base, vars);
1334                Self::collect_variables(index, vars);
1335            }
1336            LogicalExpression::SliceAccess { base, start, end } => {
1337                Self::collect_variables(base, vars);
1338                if let Some(s) = start {
1339                    Self::collect_variables(s, vars);
1340                }
1341                if let Some(e) = end {
1342                    Self::collect_variables(e, vars);
1343                }
1344            }
1345            LogicalExpression::Case {
1346                operand,
1347                when_clauses,
1348                else_clause,
1349            } => {
1350                if let Some(op) = operand {
1351                    Self::collect_variables(op, vars);
1352                }
1353                for (cond, result) in when_clauses {
1354                    Self::collect_variables(cond, vars);
1355                    Self::collect_variables(result, vars);
1356                }
1357                if let Some(else_expr) = else_clause {
1358                    Self::collect_variables(else_expr, vars);
1359                }
1360            }
1361            LogicalExpression::Labels(var)
1362            | LogicalExpression::Type(var)
1363            | LogicalExpression::Id(var) => {
1364                vars.insert(var.clone());
1365            }
1366            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1367            LogicalExpression::ListComprehension {
1368                list_expr,
1369                filter_expr,
1370                map_expr,
1371                ..
1372            } => {
1373                Self::collect_variables(list_expr, vars);
1374                if let Some(filter) = filter_expr {
1375                    Self::collect_variables(filter, vars);
1376                }
1377                Self::collect_variables(map_expr, vars);
1378            }
1379            LogicalExpression::ListPredicate {
1380                list_expr,
1381                predicate,
1382                ..
1383            } => {
1384                Self::collect_variables(list_expr, vars);
1385                Self::collect_variables(predicate, vars);
1386            }
1387            LogicalExpression::ExistsSubquery(_)
1388            | LogicalExpression::CountSubquery(_)
1389            | LogicalExpression::ValueSubquery(_) => {
1390                // Subqueries have their own variable scope
1391            }
1392            LogicalExpression::PatternComprehension { projection, .. } => {
1393                Self::collect_variables(projection, vars);
1394            }
1395            LogicalExpression::MapProjection { base, entries } => {
1396                vars.insert(base.clone());
1397                for entry in entries {
1398                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1399                        Self::collect_variables(expr, vars);
1400                    }
1401                }
1402            }
1403            LogicalExpression::Reduce {
1404                initial,
1405                list,
1406                expression,
1407                ..
1408            } => {
1409                Self::collect_variables(initial, vars);
1410                Self::collect_variables(list, vars);
1411                Self::collect_variables(expression, vars);
1412            }
1413        }
1414    }
1415
1416    /// Extracts aliases from projection expressions.
1417    fn extract_projection_aliases(
1418        &self,
1419        projections: &[crate::query::plan::Projection],
1420    ) -> HashSet<String> {
1421        projections.iter().filter_map(|p| p.alias.clone()).collect()
1422    }
1423}
1424
1425impl Default for Optimizer {
1426    fn default() -> Self {
1427        Self::new()
1428    }
1429}
1430
1431#[cfg(test)]
1432mod tests {
1433    use super::*;
1434    use crate::query::plan::{
1435        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1436        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1437        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1438    };
1439    use grafeo_common::types::Value;
1440
1441    #[test]
1442    fn test_optimizer_filter_pushdown_simple() {
1443        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1444        // Before: Return -> Filter -> NodeScan
1445        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1446
1447        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1448            items: vec![ReturnItem {
1449                expression: LogicalExpression::Variable("n".to_string()),
1450                alias: None,
1451            }],
1452            distinct: false,
1453            input: Box::new(LogicalOperator::Filter(FilterOp {
1454                predicate: LogicalExpression::Binary {
1455                    left: Box::new(LogicalExpression::Property {
1456                        variable: "n".to_string(),
1457                        property: "age".to_string(),
1458                    }),
1459                    op: BinaryOp::Gt,
1460                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1461                },
1462                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1463                    variable: "n".to_string(),
1464                    label: Some("Person".to_string()),
1465                    input: None,
1466                })),
1467                pushdown_hint: None,
1468            })),
1469        }));
1470
1471        let optimizer = Optimizer::new();
1472        let optimized = optimizer.optimize(plan).unwrap();
1473
1474        // The structure should remain similar (filter stays near scan)
1475        if let LogicalOperator::Return(ret) = &optimized.root
1476            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1477            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1478        {
1479            assert_eq!(scan.variable, "n");
1480            return;
1481        }
1482        panic!("Expected Return -> Filter -> NodeScan structure");
1483    }
1484
1485    #[test]
1486    fn test_optimizer_filter_pushdown_through_expand() {
1487        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1488        // The filter on 'a' should be pushed before the expand
1489
1490        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1491            items: vec![ReturnItem {
1492                expression: LogicalExpression::Variable("b".to_string()),
1493                alias: None,
1494            }],
1495            distinct: false,
1496            input: Box::new(LogicalOperator::Filter(FilterOp {
1497                predicate: LogicalExpression::Binary {
1498                    left: Box::new(LogicalExpression::Property {
1499                        variable: "a".to_string(),
1500                        property: "age".to_string(),
1501                    }),
1502                    op: BinaryOp::Gt,
1503                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1504                },
1505                pushdown_hint: None,
1506                input: Box::new(LogicalOperator::Expand(ExpandOp {
1507                    from_variable: "a".to_string(),
1508                    to_variable: "b".to_string(),
1509                    edge_variable: None,
1510                    direction: ExpandDirection::Outgoing,
1511                    edge_types: vec!["KNOWS".to_string()],
1512                    min_hops: 1,
1513                    max_hops: Some(1),
1514                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1515                        variable: "a".to_string(),
1516                        label: Some("Person".to_string()),
1517                        input: None,
1518                    })),
1519                    path_alias: None,
1520                    path_mode: PathMode::Walk,
1521                })),
1522            })),
1523        }));
1524
1525        let optimizer = Optimizer::new();
1526        let optimized = optimizer.optimize(plan).unwrap();
1527
1528        // Filter on 'a' should be pushed before the expand
1529        // Expected: Return -> Expand -> Filter -> NodeScan
1530        if let LogicalOperator::Return(ret) = &optimized.root
1531            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1532            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1533            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1534        {
1535            assert_eq!(scan.variable, "a");
1536            assert_eq!(expand.from_variable, "a");
1537            assert_eq!(expand.to_variable, "b");
1538            return;
1539        }
1540        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1541    }
1542
1543    #[test]
1544    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1545        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1546        // The filter on 'b' should NOT be pushed before the expand
1547
1548        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1549            items: vec![ReturnItem {
1550                expression: LogicalExpression::Variable("a".to_string()),
1551                alias: None,
1552            }],
1553            distinct: false,
1554            input: Box::new(LogicalOperator::Filter(FilterOp {
1555                predicate: LogicalExpression::Binary {
1556                    left: Box::new(LogicalExpression::Property {
1557                        variable: "b".to_string(),
1558                        property: "age".to_string(),
1559                    }),
1560                    op: BinaryOp::Gt,
1561                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1562                },
1563                pushdown_hint: None,
1564                input: Box::new(LogicalOperator::Expand(ExpandOp {
1565                    from_variable: "a".to_string(),
1566                    to_variable: "b".to_string(),
1567                    edge_variable: None,
1568                    direction: ExpandDirection::Outgoing,
1569                    edge_types: vec!["KNOWS".to_string()],
1570                    min_hops: 1,
1571                    max_hops: Some(1),
1572                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1573                        variable: "a".to_string(),
1574                        label: Some("Person".to_string()),
1575                        input: None,
1576                    })),
1577                    path_alias: None,
1578                    path_mode: PathMode::Walk,
1579                })),
1580            })),
1581        }));
1582
1583        let optimizer = Optimizer::new();
1584        let optimized = optimizer.optimize(plan).unwrap();
1585
1586        // Filter on 'b' should stay after the expand
1587        // Expected: Return -> Filter -> Expand -> NodeScan
1588        if let LogicalOperator::Return(ret) = &optimized.root
1589            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1590        {
1591            // Check that the filter is on 'b'
1592            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1593                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1594            {
1595                assert_eq!(variable, "b");
1596            }
1597
1598            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1599                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1600            {
1601                return;
1602            }
1603        }
1604        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1605    }
1606
1607    #[test]
1608    fn test_optimizer_extract_variables() {
1609        let optimizer = Optimizer::new();
1610
1611        let expr = LogicalExpression::Binary {
1612            left: Box::new(LogicalExpression::Property {
1613                variable: "n".to_string(),
1614                property: "age".to_string(),
1615            }),
1616            op: BinaryOp::Gt,
1617            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1618        };
1619
1620        let vars = optimizer.extract_variables(&expr);
1621        assert_eq!(vars.len(), 1);
1622        assert!(vars.contains("n"));
1623    }
1624
1625    // Additional tests for optimizer configuration
1626
1627    #[test]
1628    fn test_optimizer_default() {
1629        let optimizer = Optimizer::default();
1630        // Should be able to optimize an empty plan
1631        let plan = LogicalPlan::new(LogicalOperator::Empty);
1632        let result = optimizer.optimize(plan);
1633        assert!(result.is_ok());
1634    }
1635
1636    #[test]
1637    fn test_optimizer_with_filter_pushdown_disabled() {
1638        let optimizer = Optimizer::new().with_filter_pushdown(false);
1639
1640        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1641            items: vec![ReturnItem {
1642                expression: LogicalExpression::Variable("n".to_string()),
1643                alias: None,
1644            }],
1645            distinct: false,
1646            input: Box::new(LogicalOperator::Filter(FilterOp {
1647                predicate: LogicalExpression::Literal(Value::Bool(true)),
1648                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1649                    variable: "n".to_string(),
1650                    label: None,
1651                    input: None,
1652                })),
1653                pushdown_hint: None,
1654            })),
1655        }));
1656
1657        let optimized = optimizer.optimize(plan).unwrap();
1658        // Structure should be unchanged
1659        if let LogicalOperator::Return(ret) = &optimized.root
1660            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1661        {
1662            return;
1663        }
1664        panic!("Expected unchanged structure");
1665    }
1666
1667    #[test]
1668    fn test_optimizer_with_join_reorder_disabled() {
1669        let optimizer = Optimizer::new().with_join_reorder(false);
1670        assert!(
1671            optimizer
1672                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1673                .is_ok()
1674        );
1675    }
1676
1677    #[test]
1678    fn test_optimizer_with_cost_model() {
1679        let cost_model = CostModel::new();
1680        let optimizer = Optimizer::new().with_cost_model(cost_model);
1681        assert!(
1682            optimizer
1683                .cost_model()
1684                .estimate(&LogicalOperator::Empty, 0.0)
1685                .total()
1686                < 0.001
1687        );
1688    }
1689
1690    #[test]
1691    fn test_optimizer_with_cardinality_estimator() {
1692        let mut estimator = CardinalityEstimator::new();
1693        estimator.add_table_stats("Test", TableStats::new(500));
1694        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1695
1696        let scan = LogicalOperator::NodeScan(NodeScanOp {
1697            variable: "n".to_string(),
1698            label: Some("Test".to_string()),
1699            input: None,
1700        });
1701        let plan = LogicalPlan::new(scan);
1702
1703        let cardinality = optimizer.estimate_cardinality(&plan);
1704        assert!((cardinality - 500.0).abs() < 0.001);
1705    }
1706
1707    #[test]
1708    fn test_optimizer_estimate_cost() {
1709        let optimizer = Optimizer::new();
1710        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1711            variable: "n".to_string(),
1712            label: None,
1713            input: None,
1714        }));
1715
1716        let cost = optimizer.estimate_cost(&plan);
1717        assert!(cost.total() > 0.0);
1718    }
1719
1720    // Filter pushdown through various operators
1721
1722    #[test]
1723    fn test_filter_pushdown_through_project() {
1724        let optimizer = Optimizer::new();
1725
1726        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1727            predicate: LogicalExpression::Binary {
1728                left: Box::new(LogicalExpression::Property {
1729                    variable: "n".to_string(),
1730                    property: "age".to_string(),
1731                }),
1732                op: BinaryOp::Gt,
1733                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1734            },
1735            pushdown_hint: None,
1736            input: Box::new(LogicalOperator::Project(ProjectOp {
1737                projections: vec![Projection {
1738                    expression: LogicalExpression::Variable("n".to_string()),
1739                    alias: None,
1740                }],
1741                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1742                    variable: "n".to_string(),
1743                    label: None,
1744                    input: None,
1745                })),
1746                pass_through_input: false,
1747            })),
1748        }));
1749
1750        let optimized = optimizer.optimize(plan).unwrap();
1751
1752        // Filter should be pushed through Project
1753        if let LogicalOperator::Project(proj) = &optimized.root
1754            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1755        {
1756            return;
1757        }
1758        panic!("Expected Project -> Filter structure");
1759    }
1760
1761    #[test]
1762    fn test_filter_not_pushed_through_project_with_alias() {
1763        let optimizer = Optimizer::new();
1764
1765        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1766        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1767            predicate: LogicalExpression::Binary {
1768                left: Box::new(LogicalExpression::Variable("x".to_string())),
1769                op: BinaryOp::Gt,
1770                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1771            },
1772            pushdown_hint: None,
1773            input: Box::new(LogicalOperator::Project(ProjectOp {
1774                projections: vec![Projection {
1775                    expression: LogicalExpression::Property {
1776                        variable: "n".to_string(),
1777                        property: "age".to_string(),
1778                    },
1779                    alias: Some("x".to_string()),
1780                }],
1781                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1782                    variable: "n".to_string(),
1783                    label: None,
1784                    input: None,
1785                })),
1786                pass_through_input: false,
1787            })),
1788        }));
1789
1790        let optimized = optimizer.optimize(plan).unwrap();
1791
1792        // Filter should stay above Project
1793        if let LogicalOperator::Filter(filter) = &optimized.root
1794            && let LogicalOperator::Project(_) = filter.input.as_ref()
1795        {
1796            return;
1797        }
1798        panic!("Expected Filter -> Project structure");
1799    }
1800
1801    #[test]
1802    fn test_filter_pushdown_through_limit() {
1803        let optimizer = Optimizer::new();
1804
1805        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1806            predicate: LogicalExpression::Literal(Value::Bool(true)),
1807            pushdown_hint: None,
1808            input: Box::new(LogicalOperator::Limit(LimitOp {
1809                count: 10.into(),
1810                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1811                    variable: "n".to_string(),
1812                    label: None,
1813                    input: None,
1814                })),
1815            })),
1816        }));
1817
1818        let optimized = optimizer.optimize(plan).unwrap();
1819
1820        // Filter stays above Limit (cannot be pushed through)
1821        if let LogicalOperator::Filter(filter) = &optimized.root
1822            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1823        {
1824            return;
1825        }
1826        panic!("Expected Filter -> Limit structure");
1827    }
1828
1829    #[test]
1830    fn test_filter_pushdown_through_sort() {
1831        let optimizer = Optimizer::new();
1832
1833        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1834            predicate: LogicalExpression::Literal(Value::Bool(true)),
1835            pushdown_hint: None,
1836            input: Box::new(LogicalOperator::Sort(SortOp {
1837                keys: vec![SortKey {
1838                    expression: LogicalExpression::Variable("n".to_string()),
1839                    order: SortOrder::Ascending,
1840                    nulls: None,
1841                }],
1842                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1843                    variable: "n".to_string(),
1844                    label: None,
1845                    input: None,
1846                })),
1847            })),
1848        }));
1849
1850        let optimized = optimizer.optimize(plan).unwrap();
1851
1852        // Filter stays above Sort
1853        if let LogicalOperator::Filter(filter) = &optimized.root
1854            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1855        {
1856            return;
1857        }
1858        panic!("Expected Filter -> Sort structure");
1859    }
1860
1861    #[test]
1862    fn test_filter_pushdown_through_distinct() {
1863        let optimizer = Optimizer::new();
1864
1865        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1866            predicate: LogicalExpression::Literal(Value::Bool(true)),
1867            pushdown_hint: None,
1868            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1869                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1870                    variable: "n".to_string(),
1871                    label: None,
1872                    input: None,
1873                })),
1874                columns: None,
1875            })),
1876        }));
1877
1878        let optimized = optimizer.optimize(plan).unwrap();
1879
1880        // Filter stays above Distinct
1881        if let LogicalOperator::Filter(filter) = &optimized.root
1882            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1883        {
1884            return;
1885        }
1886        panic!("Expected Filter -> Distinct structure");
1887    }
1888
1889    #[test]
1890    fn test_filter_not_pushed_through_aggregate() {
1891        let optimizer = Optimizer::new();
1892
1893        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1894            predicate: LogicalExpression::Binary {
1895                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1896                op: BinaryOp::Gt,
1897                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1898            },
1899            pushdown_hint: None,
1900            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1901                group_by: vec![],
1902                aggregates: vec![AggregateExpr {
1903                    function: AggregateFunction::Count,
1904                    expression: None,
1905                    expression2: None,
1906                    distinct: false,
1907                    alias: Some("cnt".to_string()),
1908                    percentile: None,
1909                    separator: None,
1910                }],
1911                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1912                    variable: "n".to_string(),
1913                    label: None,
1914                    input: None,
1915                })),
1916                having: None,
1917            })),
1918        }));
1919
1920        let optimized = optimizer.optimize(plan).unwrap();
1921
1922        // Filter should stay above Aggregate
1923        if let LogicalOperator::Filter(filter) = &optimized.root
1924            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1925        {
1926            return;
1927        }
1928        panic!("Expected Filter -> Aggregate structure");
1929    }
1930
1931    #[test]
1932    fn test_filter_pushdown_to_left_join_side() {
1933        let optimizer = Optimizer::new();
1934
1935        // Filter on left variable should be pushed to left side
1936        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1937            predicate: LogicalExpression::Binary {
1938                left: Box::new(LogicalExpression::Property {
1939                    variable: "a".to_string(),
1940                    property: "age".to_string(),
1941                }),
1942                op: BinaryOp::Gt,
1943                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1944            },
1945            pushdown_hint: None,
1946            input: Box::new(LogicalOperator::Join(JoinOp {
1947                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1948                    variable: "a".to_string(),
1949                    label: Some("Person".to_string()),
1950                    input: None,
1951                })),
1952                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1953                    variable: "b".to_string(),
1954                    label: Some("Company".to_string()),
1955                    input: None,
1956                })),
1957                join_type: JoinType::Inner,
1958                conditions: vec![],
1959            })),
1960        }));
1961
1962        let optimized = optimizer.optimize(plan).unwrap();
1963
1964        // Filter should be pushed to left side of join
1965        if let LogicalOperator::Join(join) = &optimized.root
1966            && let LogicalOperator::Filter(_) = join.left.as_ref()
1967        {
1968            return;
1969        }
1970        panic!("Expected Join with Filter on left side");
1971    }
1972
1973    #[test]
1974    fn test_filter_pushdown_to_right_join_side() {
1975        let optimizer = Optimizer::new();
1976
1977        // Filter on right variable should be pushed to right side
1978        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1979            predicate: LogicalExpression::Binary {
1980                left: Box::new(LogicalExpression::Property {
1981                    variable: "b".to_string(),
1982                    property: "name".to_string(),
1983                }),
1984                op: BinaryOp::Eq,
1985                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1986            },
1987            pushdown_hint: None,
1988            input: Box::new(LogicalOperator::Join(JoinOp {
1989                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1990                    variable: "a".to_string(),
1991                    label: Some("Person".to_string()),
1992                    input: None,
1993                })),
1994                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1995                    variable: "b".to_string(),
1996                    label: Some("Company".to_string()),
1997                    input: None,
1998                })),
1999                join_type: JoinType::Inner,
2000                conditions: vec![],
2001            })),
2002        }));
2003
2004        let optimized = optimizer.optimize(plan).unwrap();
2005
2006        // Filter should be pushed to right side of join
2007        if let LogicalOperator::Join(join) = &optimized.root
2008            && let LogicalOperator::Filter(_) = join.right.as_ref()
2009        {
2010            return;
2011        }
2012        panic!("Expected Join with Filter on right side");
2013    }
2014
2015    #[test]
2016    fn test_filter_not_pushed_when_uses_both_join_sides() {
2017        let optimizer = Optimizer::new();
2018
2019        // Filter using both variables should stay above join
2020        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
2021            predicate: LogicalExpression::Binary {
2022                left: Box::new(LogicalExpression::Property {
2023                    variable: "a".to_string(),
2024                    property: "id".to_string(),
2025                }),
2026                op: BinaryOp::Eq,
2027                right: Box::new(LogicalExpression::Property {
2028                    variable: "b".to_string(),
2029                    property: "a_id".to_string(),
2030                }),
2031            },
2032            pushdown_hint: None,
2033            input: Box::new(LogicalOperator::Join(JoinOp {
2034                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2035                    variable: "a".to_string(),
2036                    label: None,
2037                    input: None,
2038                })),
2039                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2040                    variable: "b".to_string(),
2041                    label: None,
2042                    input: None,
2043                })),
2044                join_type: JoinType::Inner,
2045                conditions: vec![],
2046            })),
2047        }));
2048
2049        let optimized = optimizer.optimize(plan).unwrap();
2050
2051        // Filter should stay above join
2052        if let LogicalOperator::Filter(filter) = &optimized.root
2053            && let LogicalOperator::Join(_) = filter.input.as_ref()
2054        {
2055            return;
2056        }
2057        panic!("Expected Filter -> Join structure");
2058    }
2059
2060    // Variable extraction tests
2061
2062    #[test]
2063    fn test_extract_variables_from_variable() {
2064        let optimizer = Optimizer::new();
2065        let expr = LogicalExpression::Variable("x".to_string());
2066        let vars = optimizer.extract_variables(&expr);
2067        assert_eq!(vars.len(), 1);
2068        assert!(vars.contains("x"));
2069    }
2070
2071    #[test]
2072    fn test_extract_variables_from_unary() {
2073        let optimizer = Optimizer::new();
2074        let expr = LogicalExpression::Unary {
2075            op: UnaryOp::Not,
2076            operand: Box::new(LogicalExpression::Variable("x".to_string())),
2077        };
2078        let vars = optimizer.extract_variables(&expr);
2079        assert_eq!(vars.len(), 1);
2080        assert!(vars.contains("x"));
2081    }
2082
2083    #[test]
2084    fn test_extract_variables_from_function_call() {
2085        let optimizer = Optimizer::new();
2086        let expr = LogicalExpression::FunctionCall {
2087            name: "length".to_string(),
2088            args: vec![
2089                LogicalExpression::Variable("a".to_string()),
2090                LogicalExpression::Variable("b".to_string()),
2091            ],
2092            distinct: false,
2093        };
2094        let vars = optimizer.extract_variables(&expr);
2095        assert_eq!(vars.len(), 2);
2096        assert!(vars.contains("a"));
2097        assert!(vars.contains("b"));
2098    }
2099
2100    #[test]
2101    fn test_extract_variables_from_list() {
2102        let optimizer = Optimizer::new();
2103        let expr = LogicalExpression::List(vec![
2104            LogicalExpression::Variable("a".to_string()),
2105            LogicalExpression::Literal(Value::Int64(1)),
2106            LogicalExpression::Variable("b".to_string()),
2107        ]);
2108        let vars = optimizer.extract_variables(&expr);
2109        assert_eq!(vars.len(), 2);
2110        assert!(vars.contains("a"));
2111        assert!(vars.contains("b"));
2112    }
2113
2114    #[test]
2115    fn test_extract_variables_from_map() {
2116        let optimizer = Optimizer::new();
2117        let expr = LogicalExpression::Map(vec![
2118            (
2119                "key1".to_string(),
2120                LogicalExpression::Variable("a".to_string()),
2121            ),
2122            (
2123                "key2".to_string(),
2124                LogicalExpression::Variable("b".to_string()),
2125            ),
2126        ]);
2127        let vars = optimizer.extract_variables(&expr);
2128        assert_eq!(vars.len(), 2);
2129        assert!(vars.contains("a"));
2130        assert!(vars.contains("b"));
2131    }
2132
2133    #[test]
2134    fn test_extract_variables_from_index_access() {
2135        let optimizer = Optimizer::new();
2136        let expr = LogicalExpression::IndexAccess {
2137            base: Box::new(LogicalExpression::Variable("list".to_string())),
2138            index: Box::new(LogicalExpression::Variable("idx".to_string())),
2139        };
2140        let vars = optimizer.extract_variables(&expr);
2141        assert_eq!(vars.len(), 2);
2142        assert!(vars.contains("list"));
2143        assert!(vars.contains("idx"));
2144    }
2145
2146    #[test]
2147    fn test_extract_variables_from_slice_access() {
2148        let optimizer = Optimizer::new();
2149        let expr = LogicalExpression::SliceAccess {
2150            base: Box::new(LogicalExpression::Variable("list".to_string())),
2151            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
2152            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
2153        };
2154        let vars = optimizer.extract_variables(&expr);
2155        assert_eq!(vars.len(), 3);
2156        assert!(vars.contains("list"));
2157        assert!(vars.contains("s"));
2158        assert!(vars.contains("e"));
2159    }
2160
2161    #[test]
2162    fn test_extract_variables_from_case() {
2163        let optimizer = Optimizer::new();
2164        let expr = LogicalExpression::Case {
2165            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
2166            when_clauses: vec![(
2167                LogicalExpression::Literal(Value::Int64(1)),
2168                LogicalExpression::Variable("a".to_string()),
2169            )],
2170            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
2171        };
2172        let vars = optimizer.extract_variables(&expr);
2173        assert_eq!(vars.len(), 3);
2174        assert!(vars.contains("x"));
2175        assert!(vars.contains("a"));
2176        assert!(vars.contains("b"));
2177    }
2178
2179    #[test]
2180    fn test_extract_variables_from_labels() {
2181        let optimizer = Optimizer::new();
2182        let expr = LogicalExpression::Labels("n".to_string());
2183        let vars = optimizer.extract_variables(&expr);
2184        assert_eq!(vars.len(), 1);
2185        assert!(vars.contains("n"));
2186    }
2187
2188    #[test]
2189    fn test_extract_variables_from_type() {
2190        let optimizer = Optimizer::new();
2191        let expr = LogicalExpression::Type("e".to_string());
2192        let vars = optimizer.extract_variables(&expr);
2193        assert_eq!(vars.len(), 1);
2194        assert!(vars.contains("e"));
2195    }
2196
2197    #[test]
2198    fn test_extract_variables_from_id() {
2199        let optimizer = Optimizer::new();
2200        let expr = LogicalExpression::Id("n".to_string());
2201        let vars = optimizer.extract_variables(&expr);
2202        assert_eq!(vars.len(), 1);
2203        assert!(vars.contains("n"));
2204    }
2205
2206    #[test]
2207    fn test_extract_variables_from_list_comprehension() {
2208        let optimizer = Optimizer::new();
2209        let expr = LogicalExpression::ListComprehension {
2210            variable: "x".to_string(),
2211            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
2212            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
2213            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
2214        };
2215        let vars = optimizer.extract_variables(&expr);
2216        assert!(vars.contains("items"));
2217        assert!(vars.contains("pred"));
2218        assert!(vars.contains("result"));
2219    }
2220
2221    #[test]
2222    fn test_extract_variables_from_literal_and_parameter() {
2223        let optimizer = Optimizer::new();
2224
2225        let literal = LogicalExpression::Literal(Value::Int64(42));
2226        assert!(optimizer.extract_variables(&literal).is_empty());
2227
2228        let param = LogicalExpression::Parameter("p".to_string());
2229        assert!(optimizer.extract_variables(&param).is_empty());
2230    }
2231
2232    // Recursive filter pushdown tests
2233
2234    #[test]
2235    fn test_recursive_filter_pushdown_through_skip() {
2236        let optimizer = Optimizer::new();
2237
2238        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2239            items: vec![ReturnItem {
2240                expression: LogicalExpression::Variable("n".to_string()),
2241                alias: None,
2242            }],
2243            distinct: false,
2244            input: Box::new(LogicalOperator::Filter(FilterOp {
2245                predicate: LogicalExpression::Literal(Value::Bool(true)),
2246                pushdown_hint: None,
2247                input: Box::new(LogicalOperator::Skip(SkipOp {
2248                    count: 5.into(),
2249                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2250                        variable: "n".to_string(),
2251                        label: None,
2252                        input: None,
2253                    })),
2254                })),
2255            })),
2256        }));
2257
2258        let optimized = optimizer.optimize(plan).unwrap();
2259
2260        // Verify optimization succeeded
2261        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2262    }
2263
2264    #[test]
2265    fn test_nested_filter_pushdown() {
2266        let optimizer = Optimizer::new();
2267
2268        // Multiple nested filters
2269        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2270            items: vec![ReturnItem {
2271                expression: LogicalExpression::Variable("n".to_string()),
2272                alias: None,
2273            }],
2274            distinct: false,
2275            input: Box::new(LogicalOperator::Filter(FilterOp {
2276                predicate: LogicalExpression::Binary {
2277                    left: Box::new(LogicalExpression::Property {
2278                        variable: "n".to_string(),
2279                        property: "x".to_string(),
2280                    }),
2281                    op: BinaryOp::Gt,
2282                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2283                },
2284                pushdown_hint: None,
2285                input: Box::new(LogicalOperator::Filter(FilterOp {
2286                    predicate: LogicalExpression::Binary {
2287                        left: Box::new(LogicalExpression::Property {
2288                            variable: "n".to_string(),
2289                            property: "y".to_string(),
2290                        }),
2291                        op: BinaryOp::Lt,
2292                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2293                    },
2294                    pushdown_hint: None,
2295                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2296                        variable: "n".to_string(),
2297                        label: None,
2298                        input: None,
2299                    })),
2300                })),
2301            })),
2302        }));
2303
2304        let optimized = optimizer.optimize(plan).unwrap();
2305        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2306    }
2307
2308    #[test]
2309    fn test_cyclic_join_produces_multi_way_join() {
2310        use crate::query::plan::JoinCondition;
2311
2312        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2313        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2314            variable: "a".to_string(),
2315            label: Some("Person".to_string()),
2316            input: None,
2317        });
2318        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2319            variable: "b".to_string(),
2320            label: Some("Person".to_string()),
2321            input: None,
2322        });
2323        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2324            variable: "c".to_string(),
2325            label: Some("Person".to_string()),
2326            input: None,
2327        });
2328
2329        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2330        let join_ab = LogicalOperator::Join(JoinOp {
2331            left: Box::new(scan_a),
2332            right: Box::new(scan_b),
2333            join_type: JoinType::Inner,
2334            conditions: vec![JoinCondition {
2335                left: LogicalExpression::Variable("a".to_string()),
2336                right: LogicalExpression::Variable("b".to_string()),
2337            }],
2338        });
2339
2340        let join_abc = LogicalOperator::Join(JoinOp {
2341            left: Box::new(join_ab),
2342            right: Box::new(scan_c),
2343            join_type: JoinType::Inner,
2344            conditions: vec![
2345                JoinCondition {
2346                    left: LogicalExpression::Variable("b".to_string()),
2347                    right: LogicalExpression::Variable("c".to_string()),
2348                },
2349                JoinCondition {
2350                    left: LogicalExpression::Variable("c".to_string()),
2351                    right: LogicalExpression::Variable("a".to_string()),
2352                },
2353            ],
2354        });
2355
2356        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2357            items: vec![ReturnItem {
2358                expression: LogicalExpression::Variable("a".to_string()),
2359                alias: None,
2360            }],
2361            distinct: false,
2362            input: Box::new(join_abc),
2363        }));
2364
2365        let mut optimizer = Optimizer::new();
2366        optimizer
2367            .card_estimator
2368            .add_table_stats("Person", cardinality::TableStats::new(1000));
2369
2370        let optimized = optimizer.optimize(plan).unwrap();
2371
2372        // Walk the tree to find a MultiWayJoin
2373        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2374            match op {
2375                LogicalOperator::MultiWayJoin(_) => true,
2376                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2377                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2378                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2379                _ => false,
2380            }
2381        }
2382
2383        assert!(
2384            has_multi_way_join(&optimized.root),
2385            "Expected MultiWayJoin for cyclic triangle pattern"
2386        );
2387    }
2388
2389    #[test]
2390    fn test_acyclic_join_uses_binary_joins() {
2391        use crate::query::plan::JoinCondition;
2392
2393        // Chain: a ⋈ b ⋈ c (acyclic)
2394        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2395            variable: "a".to_string(),
2396            label: Some("Person".to_string()),
2397            input: None,
2398        });
2399        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2400            variable: "b".to_string(),
2401            label: Some("Person".to_string()),
2402            input: None,
2403        });
2404        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2405            variable: "c".to_string(),
2406            label: Some("Company".to_string()),
2407            input: None,
2408        });
2409
2410        let join_ab = LogicalOperator::Join(JoinOp {
2411            left: Box::new(scan_a),
2412            right: Box::new(scan_b),
2413            join_type: JoinType::Inner,
2414            conditions: vec![JoinCondition {
2415                left: LogicalExpression::Variable("a".to_string()),
2416                right: LogicalExpression::Variable("b".to_string()),
2417            }],
2418        });
2419
2420        let join_abc = LogicalOperator::Join(JoinOp {
2421            left: Box::new(join_ab),
2422            right: Box::new(scan_c),
2423            join_type: JoinType::Inner,
2424            conditions: vec![JoinCondition {
2425                left: LogicalExpression::Variable("b".to_string()),
2426                right: LogicalExpression::Variable("c".to_string()),
2427            }],
2428        });
2429
2430        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2431            items: vec![ReturnItem {
2432                expression: LogicalExpression::Variable("a".to_string()),
2433                alias: None,
2434            }],
2435            distinct: false,
2436            input: Box::new(join_abc),
2437        }));
2438
2439        let mut optimizer = Optimizer::new();
2440        optimizer
2441            .card_estimator
2442            .add_table_stats("Person", cardinality::TableStats::new(1000));
2443        optimizer
2444            .card_estimator
2445            .add_table_stats("Company", cardinality::TableStats::new(100));
2446
2447        let optimized = optimizer.optimize(plan).unwrap();
2448
2449        // Should NOT contain MultiWayJoin for acyclic pattern
2450        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2451            match op {
2452                LogicalOperator::MultiWayJoin(_) => true,
2453                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2454                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2455                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2456                LogicalOperator::Join(j) => {
2457                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2458                }
2459                _ => false,
2460            }
2461        }
2462
2463        assert!(
2464            !has_multi_way_join(&optimized.root),
2465            "Acyclic join should NOT produce MultiWayJoin"
2466        );
2467    }
2468}