Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
//! | Projection Pushdown | Tracks which variables and properties are needed so unused columns can be dropped early |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
30/// Information about a join condition for join reordering.
31#[derive(Debug, Clone)]
32struct JoinInfo {
33    left_var: String,
34    right_var: String,
35    left_expr: LogicalExpression,
36    right_expr: LogicalExpression,
37}
38
/// Something the query needs to keep available, as tracked by projection pushdown.
///
/// Values are hashable and comparable so they can be stored in a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding (node, edge, or path), identified by its variable name.
    Variable(String),
    /// A single property, given as `(variable name, property name)`.
    Property(String, String),
}
47
48/// Transforms logical plans for faster execution.
49///
50/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
51/// Use the builder methods to enable/disable specific optimizations.
52pub struct Optimizer {
53    /// Whether to enable filter pushdown.
54    enable_filter_pushdown: bool,
55    /// Whether to enable join reordering.
56    enable_join_reorder: bool,
57    /// Whether to enable projection pushdown.
58    enable_projection_pushdown: bool,
59    /// Cost model for estimation.
60    cost_model: CostModel,
61    /// Cardinality estimator.
62    card_estimator: CardinalityEstimator,
63}
64
65impl Optimizer {
66    /// Creates a new optimizer with default settings.
67    #[must_use]
68    pub fn new() -> Self {
69        Self {
70            enable_filter_pushdown: true,
71            enable_join_reorder: true,
72            enable_projection_pushdown: true,
73            cost_model: CostModel::new(),
74            card_estimator: CardinalityEstimator::new(),
75        }
76    }
77
78    /// Creates an optimizer with cardinality estimates from the store's statistics.
79    ///
80    /// Pre-populates the cardinality estimator with per-label row counts and
81    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
82    /// and graph totals into the cost model for accurate estimation.
83    #[must_use]
84    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85        store.ensure_statistics_fresh();
86        let stats = store.statistics();
87        Self::from_statistics(&stats)
88    }
89
90    /// Creates an optimizer from any GraphStore implementation.
91    ///
92    /// Unlike [`from_store`](Self::from_store), this does not call
93    /// `ensure_statistics_fresh()` since external stores manage their own
94    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
95    /// is called directly.
96    #[must_use]
97    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98        let stats = store.statistics();
99        Self::from_statistics(&stats)
100    }
101
102    /// Creates an optimizer from a Statistics snapshot.
103    ///
104    /// Extracts label cardinalities, edge type degrees, and graph totals
105    /// for both the cardinality estimator and cost model.
106    #[must_use]
107    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
108        let estimator = CardinalityEstimator::from_statistics(stats);
109
110        let avg_fanout = if stats.total_nodes > 0 {
111            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
112        } else {
113            10.0
114        };
115
116        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
117            .edge_types
118            .iter()
119            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
120            .collect();
121
122        let label_cardinalities: std::collections::HashMap<String, u64> = stats
123            .labels
124            .iter()
125            .map(|(name, ls)| (name.clone(), ls.node_count))
126            .collect();
127
128        Self {
129            enable_filter_pushdown: true,
130            enable_join_reorder: true,
131            enable_projection_pushdown: true,
132            cost_model: CostModel::new()
133                .with_avg_fanout(avg_fanout)
134                .with_edge_type_degrees(edge_type_degrees)
135                .with_label_cardinalities(label_cardinalities)
136                .with_graph_totals(stats.total_nodes, stats.total_edges),
137            card_estimator: estimator,
138        }
139    }
140
141    /// Enables or disables filter pushdown.
142    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
143        self.enable_filter_pushdown = enabled;
144        self
145    }
146
147    /// Enables or disables join reordering.
148    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
149        self.enable_join_reorder = enabled;
150        self
151    }
152
153    /// Enables or disables projection pushdown.
154    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
155        self.enable_projection_pushdown = enabled;
156        self
157    }
158
159    /// Sets the cost model.
160    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
161        self.cost_model = cost_model;
162        self
163    }
164
165    /// Sets the cardinality estimator.
166    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
167        self.card_estimator = estimator;
168        self
169    }
170
171    /// Sets the selectivity configuration for the cardinality estimator.
172    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
173        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
174        self
175    }
176
177    /// Returns a reference to the cost model.
178    pub fn cost_model(&self) -> &CostModel {
179        &self.cost_model
180    }
181
182    /// Returns a reference to the cardinality estimator.
183    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
184        &self.card_estimator
185    }
186
187    /// Estimates the total cost of a plan by recursively costing the entire tree.
188    ///
189    /// Walks every operator in the plan, computing cardinality at each level
190    /// and summing the per-operator costs. Uses actual child cardinalities
191    /// for join cost estimation rather than approximations.
192    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
193        self.cost_model
194            .estimate_tree(&plan.root, &self.card_estimator)
195    }
196
197    /// Estimates the cardinality of a plan.
198    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
199        self.card_estimator.estimate(&plan.root)
200    }
201
202    /// Optimizes a logical plan.
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if optimization fails.
207    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
208        let mut root = plan.root;
209
210        // Apply optimization rules
211        if self.enable_filter_pushdown {
212            root = self.push_filters_down(root);
213        }
214
215        if self.enable_join_reorder {
216            root = self.reorder_joins(root);
217        }
218
219        if self.enable_projection_pushdown {
220            root = self.push_projections_down(root);
221        }
222
223        Ok(LogicalPlan {
224            root,
225            explain: plan.explain,
226            profile: plan.profile,
227        })
228    }
229
230    /// Pushes projections down the operator tree to eliminate unused columns early.
231    ///
232    /// This optimization:
233    /// 1. Collects required variables/properties from the root
234    /// 2. Propagates requirements down through the tree
235    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
236    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
237        // Collect required columns from the top of the plan
238        let required = self.collect_required_columns(&op);
239
240        // Push projections down
241        self.push_projections_recursive(op, &required)
242    }
243
244    /// Collects all variables and properties required by an operator and its ancestors.
245    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
246        let mut required = HashSet::new();
247        Self::collect_required_recursive(op, &mut required);
248        required
249    }
250
251    /// Recursively collects required columns.
252    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
253        match op {
254            LogicalOperator::Return(ret) => {
255                for item in &ret.items {
256                    Self::collect_from_expression(&item.expression, required);
257                }
258                Self::collect_required_recursive(&ret.input, required);
259            }
260            LogicalOperator::Project(proj) => {
261                for p in &proj.projections {
262                    Self::collect_from_expression(&p.expression, required);
263                }
264                Self::collect_required_recursive(&proj.input, required);
265            }
266            LogicalOperator::Filter(filter) => {
267                Self::collect_from_expression(&filter.predicate, required);
268                Self::collect_required_recursive(&filter.input, required);
269            }
270            LogicalOperator::Sort(sort) => {
271                for key in &sort.keys {
272                    Self::collect_from_expression(&key.expression, required);
273                }
274                Self::collect_required_recursive(&sort.input, required);
275            }
276            LogicalOperator::Aggregate(agg) => {
277                for expr in &agg.group_by {
278                    Self::collect_from_expression(expr, required);
279                }
280                for agg_expr in &agg.aggregates {
281                    if let Some(ref expr) = agg_expr.expression {
282                        Self::collect_from_expression(expr, required);
283                    }
284                }
285                if let Some(ref having) = agg.having {
286                    Self::collect_from_expression(having, required);
287                }
288                Self::collect_required_recursive(&agg.input, required);
289            }
290            LogicalOperator::Join(join) => {
291                for cond in &join.conditions {
292                    Self::collect_from_expression(&cond.left, required);
293                    Self::collect_from_expression(&cond.right, required);
294                }
295                Self::collect_required_recursive(&join.left, required);
296                Self::collect_required_recursive(&join.right, required);
297            }
298            LogicalOperator::Expand(expand) => {
299                // The source and target variables are needed
300                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
301                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
302                if let Some(ref edge_var) = expand.edge_variable {
303                    required.insert(RequiredColumn::Variable(edge_var.clone()));
304                }
305                Self::collect_required_recursive(&expand.input, required);
306            }
307            LogicalOperator::Limit(limit) => {
308                Self::collect_required_recursive(&limit.input, required);
309            }
310            LogicalOperator::Skip(skip) => {
311                Self::collect_required_recursive(&skip.input, required);
312            }
313            LogicalOperator::Distinct(distinct) => {
314                Self::collect_required_recursive(&distinct.input, required);
315            }
316            LogicalOperator::NodeScan(scan) => {
317                required.insert(RequiredColumn::Variable(scan.variable.clone()));
318            }
319            LogicalOperator::EdgeScan(scan) => {
320                required.insert(RequiredColumn::Variable(scan.variable.clone()));
321            }
322            LogicalOperator::MultiWayJoin(mwj) => {
323                for cond in &mwj.conditions {
324                    Self::collect_from_expression(&cond.left, required);
325                    Self::collect_from_expression(&cond.right, required);
326                }
327                for input in &mwj.inputs {
328                    Self::collect_required_recursive(input, required);
329                }
330            }
331            _ => {}
332        }
333    }
334
335    /// Collects required columns from an expression.
336    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
337        match expr {
338            LogicalExpression::Variable(var) => {
339                required.insert(RequiredColumn::Variable(var.clone()));
340            }
341            LogicalExpression::Property { variable, property } => {
342                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
343                required.insert(RequiredColumn::Variable(variable.clone()));
344            }
345            LogicalExpression::Binary { left, right, .. } => {
346                Self::collect_from_expression(left, required);
347                Self::collect_from_expression(right, required);
348            }
349            LogicalExpression::Unary { operand, .. } => {
350                Self::collect_from_expression(operand, required);
351            }
352            LogicalExpression::FunctionCall { args, .. } => {
353                for arg in args {
354                    Self::collect_from_expression(arg, required);
355                }
356            }
357            LogicalExpression::List(items) => {
358                for item in items {
359                    Self::collect_from_expression(item, required);
360                }
361            }
362            LogicalExpression::Map(pairs) => {
363                for (_, value) in pairs {
364                    Self::collect_from_expression(value, required);
365                }
366            }
367            LogicalExpression::IndexAccess { base, index } => {
368                Self::collect_from_expression(base, required);
369                Self::collect_from_expression(index, required);
370            }
371            LogicalExpression::SliceAccess { base, start, end } => {
372                Self::collect_from_expression(base, required);
373                if let Some(s) = start {
374                    Self::collect_from_expression(s, required);
375                }
376                if let Some(e) = end {
377                    Self::collect_from_expression(e, required);
378                }
379            }
380            LogicalExpression::Case {
381                operand,
382                when_clauses,
383                else_clause,
384            } => {
385                if let Some(op) = operand {
386                    Self::collect_from_expression(op, required);
387                }
388                for (cond, result) in when_clauses {
389                    Self::collect_from_expression(cond, required);
390                    Self::collect_from_expression(result, required);
391                }
392                if let Some(else_expr) = else_clause {
393                    Self::collect_from_expression(else_expr, required);
394                }
395            }
396            LogicalExpression::Labels(var)
397            | LogicalExpression::Type(var)
398            | LogicalExpression::Id(var) => {
399                required.insert(RequiredColumn::Variable(var.clone()));
400            }
401            LogicalExpression::ListComprehension {
402                list_expr,
403                filter_expr,
404                map_expr,
405                ..
406            } => {
407                Self::collect_from_expression(list_expr, required);
408                if let Some(filter) = filter_expr {
409                    Self::collect_from_expression(filter, required);
410                }
411                Self::collect_from_expression(map_expr, required);
412            }
413            _ => {}
414        }
415    }
416
417    /// Recursively pushes projections down, adding them before expensive operations.
418    fn push_projections_recursive(
419        &self,
420        op: LogicalOperator,
421        required: &HashSet<RequiredColumn>,
422    ) -> LogicalOperator {
423        match op {
424            LogicalOperator::Return(mut ret) => {
425                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
426                LogicalOperator::Return(ret)
427            }
428            LogicalOperator::Project(mut proj) => {
429                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
430                LogicalOperator::Project(proj)
431            }
432            LogicalOperator::Filter(mut filter) => {
433                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
434                LogicalOperator::Filter(filter)
435            }
436            LogicalOperator::Sort(mut sort) => {
437                // Sort is expensive - consider adding a projection before it
438                // to reduce tuple width
439                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
440                LogicalOperator::Sort(sort)
441            }
442            LogicalOperator::Aggregate(mut agg) => {
443                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
444                LogicalOperator::Aggregate(agg)
445            }
446            LogicalOperator::Join(mut join) => {
447                // Joins are expensive - the required columns help determine
448                // what to project on each side
449                let left_vars = self.collect_output_variables(&join.left);
450                let right_vars = self.collect_output_variables(&join.right);
451
452                // Filter required columns to each side
453                let left_required: HashSet<_> = required
454                    .iter()
455                    .filter(|c| match c {
456                        RequiredColumn::Variable(v) => left_vars.contains(v),
457                        RequiredColumn::Property(v, _) => left_vars.contains(v),
458                    })
459                    .cloned()
460                    .collect();
461
462                let right_required: HashSet<_> = required
463                    .iter()
464                    .filter(|c| match c {
465                        RequiredColumn::Variable(v) => right_vars.contains(v),
466                        RequiredColumn::Property(v, _) => right_vars.contains(v),
467                    })
468                    .cloned()
469                    .collect();
470
471                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
472                join.right =
473                    Box::new(self.push_projections_recursive(*join.right, &right_required));
474                LogicalOperator::Join(join)
475            }
476            LogicalOperator::Expand(mut expand) => {
477                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
478                LogicalOperator::Expand(expand)
479            }
480            LogicalOperator::Limit(mut limit) => {
481                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
482                LogicalOperator::Limit(limit)
483            }
484            LogicalOperator::Skip(mut skip) => {
485                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
486                LogicalOperator::Skip(skip)
487            }
488            LogicalOperator::Distinct(mut distinct) => {
489                distinct.input =
490                    Box::new(self.push_projections_recursive(*distinct.input, required));
491                LogicalOperator::Distinct(distinct)
492            }
493            LogicalOperator::MapCollect(mut mc) => {
494                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
495                LogicalOperator::MapCollect(mc)
496            }
497            LogicalOperator::MultiWayJoin(mut mwj) => {
498                mwj.inputs = mwj
499                    .inputs
500                    .into_iter()
501                    .map(|input| self.push_projections_recursive(input, required))
502                    .collect();
503                LogicalOperator::MultiWayJoin(mwj)
504            }
505            other => other,
506        }
507    }
508
509    /// Reorders joins in the operator tree using the DPccp algorithm.
510    ///
511    /// This optimization finds the optimal join order by:
512    /// 1. Extracting all base relations (scans) and join conditions
513    /// 2. Building a join graph
514    /// 3. Using dynamic programming to find the cheapest join order
515    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
516        // First, recursively optimize children
517        let op = self.reorder_joins_recursive(op);
518
519        // Then, if this is a join tree, try to optimize it
520        if let Some((relations, conditions)) = self.extract_join_tree(&op)
521            && relations.len() >= 2
522            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
523        {
524            return optimized;
525        }
526
527        op
528    }
529
530    /// Recursively applies join reordering to child operators.
531    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
532        match op {
533            LogicalOperator::Return(mut ret) => {
534                ret.input = Box::new(self.reorder_joins(*ret.input));
535                LogicalOperator::Return(ret)
536            }
537            LogicalOperator::Project(mut proj) => {
538                proj.input = Box::new(self.reorder_joins(*proj.input));
539                LogicalOperator::Project(proj)
540            }
541            LogicalOperator::Filter(mut filter) => {
542                filter.input = Box::new(self.reorder_joins(*filter.input));
543                LogicalOperator::Filter(filter)
544            }
545            LogicalOperator::Limit(mut limit) => {
546                limit.input = Box::new(self.reorder_joins(*limit.input));
547                LogicalOperator::Limit(limit)
548            }
549            LogicalOperator::Skip(mut skip) => {
550                skip.input = Box::new(self.reorder_joins(*skip.input));
551                LogicalOperator::Skip(skip)
552            }
553            LogicalOperator::Sort(mut sort) => {
554                sort.input = Box::new(self.reorder_joins(*sort.input));
555                LogicalOperator::Sort(sort)
556            }
557            LogicalOperator::Distinct(mut distinct) => {
558                distinct.input = Box::new(self.reorder_joins(*distinct.input));
559                LogicalOperator::Distinct(distinct)
560            }
561            LogicalOperator::Aggregate(mut agg) => {
562                agg.input = Box::new(self.reorder_joins(*agg.input));
563                LogicalOperator::Aggregate(agg)
564            }
565            LogicalOperator::Expand(mut expand) => {
566                expand.input = Box::new(self.reorder_joins(*expand.input));
567                LogicalOperator::Expand(expand)
568            }
569            LogicalOperator::MapCollect(mut mc) => {
570                mc.input = Box::new(self.reorder_joins(*mc.input));
571                LogicalOperator::MapCollect(mc)
572            }
573            LogicalOperator::MultiWayJoin(mut mwj) => {
574                mwj.inputs = mwj
575                    .inputs
576                    .into_iter()
577                    .map(|input| self.reorder_joins(input))
578                    .collect();
579                LogicalOperator::MultiWayJoin(mwj)
580            }
581            // Join operators are handled by the parent reorder_joins call
582            other => other,
583        }
584    }
585
586    /// Extracts base relations and join conditions from a join tree.
587    ///
588    /// Returns None if the operator is not a join tree.
589    fn extract_join_tree(
590        &self,
591        op: &LogicalOperator,
592    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
593        let mut relations = Vec::new();
594        let mut join_conditions = Vec::new();
595
596        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
597            return None;
598        }
599
600        if relations.len() < 2 {
601            return None;
602        }
603
604        Some((relations, join_conditions))
605    }
606
607    /// Recursively collects base relations and join conditions.
608    ///
609    /// Returns true if this subtree is part of a join tree.
610    fn collect_join_tree(
611        &self,
612        op: &LogicalOperator,
613        relations: &mut Vec<(String, LogicalOperator)>,
614        conditions: &mut Vec<JoinInfo>,
615    ) -> bool {
616        match op {
617            LogicalOperator::Join(join) => {
618                // Collect from both sides
619                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
620                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
621
622                // Add conditions from this join
623                for cond in &join.conditions {
624                    if let (Some(left_var), Some(right_var)) = (
625                        self.extract_variable_from_expr(&cond.left),
626                        self.extract_variable_from_expr(&cond.right),
627                    ) {
628                        conditions.push(JoinInfo {
629                            left_var,
630                            right_var,
631                            left_expr: cond.left.clone(),
632                            right_expr: cond.right.clone(),
633                        });
634                    }
635                }
636
637                left_ok && right_ok
638            }
639            LogicalOperator::NodeScan(scan) => {
640                relations.push((scan.variable.clone(), op.clone()));
641                true
642            }
643            LogicalOperator::EdgeScan(scan) => {
644                relations.push((scan.variable.clone(), op.clone()));
645                true
646            }
647            LogicalOperator::Filter(filter) => {
648                // A filter on a base relation is still part of the join tree
649                self.collect_join_tree(&filter.input, relations, conditions)
650            }
651            LogicalOperator::Expand(expand) => {
652                // Expand is a special case - it's like a join with the adjacency
653                // For now, treat the whole Expand subtree as a single relation
654                relations.push((expand.to_variable.clone(), op.clone()));
655                true
656            }
657            _ => false,
658        }
659    }
660
661    /// Extracts the primary variable from an expression.
662    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
663        match expr {
664            LogicalExpression::Variable(v) => Some(v.clone()),
665            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
666            _ => None,
667        }
668    }
669
670    /// Optimizes the join order using DPccp, or produces a multi-way
671    /// leapfrog join for cyclic patterns when the cost model prefers it.
672    fn optimize_join_order(
673        &self,
674        relations: &[(String, LogicalOperator)],
675        conditions: &[JoinInfo],
676    ) -> Option<LogicalOperator> {
677        use join_order::{DPccp, JoinGraphBuilder};
678
679        // Build the join graph
680        let mut builder = JoinGraphBuilder::new();
681
682        for (var, relation) in relations {
683            builder.add_relation(var, relation.clone());
684        }
685
686        for cond in conditions {
687            builder.add_join_condition(
688                &cond.left_var,
689                &cond.right_var,
690                cond.left_expr.clone(),
691                cond.right_expr.clone(),
692            );
693        }
694
695        let graph = builder.build();
696
697        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
698        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
699        // multi-way intersection rather than binary hash join cascades that can
700        // produce intermediate blowup.
701        if graph.is_cyclic() && relations.len() >= 3 {
702            // Collect shared variables (variables appearing in 2+ conditions)
703            let mut var_counts: std::collections::HashMap<&str, usize> =
704                std::collections::HashMap::new();
705            for cond in conditions {
706                *var_counts.entry(&cond.left_var).or_default() += 1;
707                *var_counts.entry(&cond.right_var).or_default() += 1;
708            }
709            let shared_variables: Vec<String> = var_counts
710                .into_iter()
711                .filter(|(_, count)| *count >= 2)
712                .map(|(var, _)| var.to_string())
713                .collect();
714
715            let join_conditions: Vec<JoinCondition> = conditions
716                .iter()
717                .map(|c| JoinCondition {
718                    left: c.left_expr.clone(),
719                    right: c.right_expr.clone(),
720                })
721                .collect();
722
723            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
724                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
725                conditions: join_conditions,
726                shared_variables,
727            }));
728        }
729
730        // Fall through to DPccp for binary join ordering
731        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
732        let plan = dpccp.optimize()?;
733
734        Some(plan.operator)
735    }
736
737    /// Pushes filters down the operator tree.
738    ///
739    /// This optimization moves filter predicates as close to the data source
740    /// as possible to reduce the amount of data processed by upper operators.
741    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
742        match op {
743            // For Filter operators, try to push the predicate into the child
744            LogicalOperator::Filter(filter) => {
745                let optimized_input = self.push_filters_down(*filter.input);
746                self.try_push_filter_into(filter.predicate, optimized_input)
747            }
748            // Recursively optimize children for other operators
749            LogicalOperator::Return(mut ret) => {
750                ret.input = Box::new(self.push_filters_down(*ret.input));
751                LogicalOperator::Return(ret)
752            }
753            LogicalOperator::Project(mut proj) => {
754                proj.input = Box::new(self.push_filters_down(*proj.input));
755                LogicalOperator::Project(proj)
756            }
757            LogicalOperator::Limit(mut limit) => {
758                limit.input = Box::new(self.push_filters_down(*limit.input));
759                LogicalOperator::Limit(limit)
760            }
761            LogicalOperator::Skip(mut skip) => {
762                skip.input = Box::new(self.push_filters_down(*skip.input));
763                LogicalOperator::Skip(skip)
764            }
765            LogicalOperator::Sort(mut sort) => {
766                sort.input = Box::new(self.push_filters_down(*sort.input));
767                LogicalOperator::Sort(sort)
768            }
769            LogicalOperator::Distinct(mut distinct) => {
770                distinct.input = Box::new(self.push_filters_down(*distinct.input));
771                LogicalOperator::Distinct(distinct)
772            }
773            LogicalOperator::Expand(mut expand) => {
774                expand.input = Box::new(self.push_filters_down(*expand.input));
775                LogicalOperator::Expand(expand)
776            }
777            LogicalOperator::Join(mut join) => {
778                join.left = Box::new(self.push_filters_down(*join.left));
779                join.right = Box::new(self.push_filters_down(*join.right));
780                LogicalOperator::Join(join)
781            }
782            LogicalOperator::Aggregate(mut agg) => {
783                agg.input = Box::new(self.push_filters_down(*agg.input));
784                LogicalOperator::Aggregate(agg)
785            }
786            LogicalOperator::MapCollect(mut mc) => {
787                mc.input = Box::new(self.push_filters_down(*mc.input));
788                LogicalOperator::MapCollect(mc)
789            }
790            LogicalOperator::MultiWayJoin(mut mwj) => {
791                mwj.inputs = mwj
792                    .inputs
793                    .into_iter()
794                    .map(|input| self.push_filters_down(input))
795                    .collect();
796                LogicalOperator::MultiWayJoin(mwj)
797            }
798            // Leaf operators and unsupported operators are returned as-is
799            other => other,
800        }
801    }
802
803    /// Tries to push a filter predicate into the given operator.
804    ///
805    /// Returns either the predicate pushed into the operator, or a new
806    /// Filter operator on top if the predicate cannot be pushed further.
807    fn try_push_filter_into(
808        &self,
809        predicate: LogicalExpression,
810        op: LogicalOperator,
811    ) -> LogicalOperator {
812        match op {
813            // Can push through Project if predicate doesn't depend on computed columns
814            LogicalOperator::Project(mut proj) => {
815                let predicate_vars = self.extract_variables(&predicate);
816                let computed_vars = self.extract_projection_aliases(&proj.projections);
817
818                // If predicate doesn't use any computed columns, push through
819                if predicate_vars.is_disjoint(&computed_vars) {
820                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
821                    LogicalOperator::Project(proj)
822                } else {
823                    // Can't push through, keep filter on top
824                    LogicalOperator::Filter(FilterOp {
825                        predicate,
826                        pushdown_hint: None,
827                        input: Box::new(LogicalOperator::Project(proj)),
828                    })
829                }
830            }
831
832            // Can push through Return (which is like a projection)
833            LogicalOperator::Return(mut ret) => {
834                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
835                LogicalOperator::Return(ret)
836            }
837
838            // Can push through Expand if predicate doesn't use variables introduced by this expand
839            LogicalOperator::Expand(mut expand) => {
840                let predicate_vars = self.extract_variables(&predicate);
841
842                // Variables introduced by this expand are:
843                // - The target variable (to_variable)
844                // - The edge variable (if any)
845                // - The path alias (if any)
846                let mut introduced_vars = vec![&expand.to_variable];
847                if let Some(ref edge_var) = expand.edge_variable {
848                    introduced_vars.push(edge_var);
849                }
850                if let Some(ref path_alias) = expand.path_alias {
851                    introduced_vars.push(path_alias);
852                }
853
854                // Check if predicate uses any variables introduced by this expand
855                let uses_introduced_vars =
856                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
857
858                if !uses_introduced_vars {
859                    // Predicate doesn't use vars from this expand, so push through
860                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
861                    LogicalOperator::Expand(expand)
862                } else {
863                    // Keep filter after expand
864                    LogicalOperator::Filter(FilterOp {
865                        predicate,
866                        pushdown_hint: None,
867                        input: Box::new(LogicalOperator::Expand(expand)),
868                    })
869                }
870            }
871
872            // Can push through Join to left/right side based on variables used
873            LogicalOperator::Join(mut join) => {
874                let predicate_vars = self.extract_variables(&predicate);
875                let left_vars = self.collect_output_variables(&join.left);
876                let right_vars = self.collect_output_variables(&join.right);
877
878                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
879                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
880
881                if uses_left && !uses_right {
882                    // Push to left side
883                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
884                    LogicalOperator::Join(join)
885                } else if uses_right && !uses_left {
886                    // Push to right side
887                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
888                    LogicalOperator::Join(join)
889                } else {
890                    // Uses both sides - keep above join
891                    LogicalOperator::Filter(FilterOp {
892                        predicate,
893                        pushdown_hint: None,
894                        input: Box::new(LogicalOperator::Join(join)),
895                    })
896                }
897            }
898
899            // Cannot push through Aggregate (predicate refers to aggregated values)
900            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
901                predicate,
902                pushdown_hint: None,
903                input: Box::new(LogicalOperator::Aggregate(agg)),
904            }),
905
906            // For NodeScan, we've reached the bottom - keep filter on top
907            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
908                predicate,
909                pushdown_hint: None,
910                input: Box::new(LogicalOperator::NodeScan(scan)),
911            }),
912
913            // For other operators, keep filter on top
914            other => LogicalOperator::Filter(FilterOp {
915                predicate,
916                pushdown_hint: None,
917                input: Box::new(other),
918            }),
919        }
920    }
921
922    /// Collects all output variable names from an operator.
923    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
924        let mut vars = HashSet::new();
925        Self::collect_output_variables_recursive(op, &mut vars);
926        vars
927    }
928
929    /// Recursively collects output variables from an operator.
930    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
931        match op {
932            LogicalOperator::NodeScan(scan) => {
933                vars.insert(scan.variable.clone());
934            }
935            LogicalOperator::EdgeScan(scan) => {
936                vars.insert(scan.variable.clone());
937            }
938            LogicalOperator::Expand(expand) => {
939                vars.insert(expand.to_variable.clone());
940                if let Some(edge_var) = &expand.edge_variable {
941                    vars.insert(edge_var.clone());
942                }
943                Self::collect_output_variables_recursive(&expand.input, vars);
944            }
945            LogicalOperator::Filter(filter) => {
946                Self::collect_output_variables_recursive(&filter.input, vars);
947            }
948            LogicalOperator::Project(proj) => {
949                for p in &proj.projections {
950                    if let Some(alias) = &p.alias {
951                        vars.insert(alias.clone());
952                    }
953                }
954                Self::collect_output_variables_recursive(&proj.input, vars);
955            }
956            LogicalOperator::Join(join) => {
957                Self::collect_output_variables_recursive(&join.left, vars);
958                Self::collect_output_variables_recursive(&join.right, vars);
959            }
960            LogicalOperator::Aggregate(agg) => {
961                for expr in &agg.group_by {
962                    Self::collect_variables(expr, vars);
963                }
964                for agg_expr in &agg.aggregates {
965                    if let Some(alias) = &agg_expr.alias {
966                        vars.insert(alias.clone());
967                    }
968                }
969            }
970            LogicalOperator::Return(ret) => {
971                Self::collect_output_variables_recursive(&ret.input, vars);
972            }
973            LogicalOperator::Limit(limit) => {
974                Self::collect_output_variables_recursive(&limit.input, vars);
975            }
976            LogicalOperator::Skip(skip) => {
977                Self::collect_output_variables_recursive(&skip.input, vars);
978            }
979            LogicalOperator::Sort(sort) => {
980                Self::collect_output_variables_recursive(&sort.input, vars);
981            }
982            LogicalOperator::Distinct(distinct) => {
983                Self::collect_output_variables_recursive(&distinct.input, vars);
984            }
985            _ => {}
986        }
987    }
988
989    /// Extracts all variable names referenced in an expression.
990    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
991        let mut vars = HashSet::new();
992        Self::collect_variables(expr, &mut vars);
993        vars
994    }
995
996    /// Recursively collects variable names from an expression.
997    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
998        match expr {
999            LogicalExpression::Variable(name) => {
1000                vars.insert(name.clone());
1001            }
1002            LogicalExpression::Property { variable, .. } => {
1003                vars.insert(variable.clone());
1004            }
1005            LogicalExpression::Binary { left, right, .. } => {
1006                Self::collect_variables(left, vars);
1007                Self::collect_variables(right, vars);
1008            }
1009            LogicalExpression::Unary { operand, .. } => {
1010                Self::collect_variables(operand, vars);
1011            }
1012            LogicalExpression::FunctionCall { args, .. } => {
1013                for arg in args {
1014                    Self::collect_variables(arg, vars);
1015                }
1016            }
1017            LogicalExpression::List(items) => {
1018                for item in items {
1019                    Self::collect_variables(item, vars);
1020                }
1021            }
1022            LogicalExpression::Map(pairs) => {
1023                for (_, value) in pairs {
1024                    Self::collect_variables(value, vars);
1025                }
1026            }
1027            LogicalExpression::IndexAccess { base, index } => {
1028                Self::collect_variables(base, vars);
1029                Self::collect_variables(index, vars);
1030            }
1031            LogicalExpression::SliceAccess { base, start, end } => {
1032                Self::collect_variables(base, vars);
1033                if let Some(s) = start {
1034                    Self::collect_variables(s, vars);
1035                }
1036                if let Some(e) = end {
1037                    Self::collect_variables(e, vars);
1038                }
1039            }
1040            LogicalExpression::Case {
1041                operand,
1042                when_clauses,
1043                else_clause,
1044            } => {
1045                if let Some(op) = operand {
1046                    Self::collect_variables(op, vars);
1047                }
1048                for (cond, result) in when_clauses {
1049                    Self::collect_variables(cond, vars);
1050                    Self::collect_variables(result, vars);
1051                }
1052                if let Some(else_expr) = else_clause {
1053                    Self::collect_variables(else_expr, vars);
1054                }
1055            }
1056            LogicalExpression::Labels(var)
1057            | LogicalExpression::Type(var)
1058            | LogicalExpression::Id(var) => {
1059                vars.insert(var.clone());
1060            }
1061            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1062            LogicalExpression::ListComprehension {
1063                list_expr,
1064                filter_expr,
1065                map_expr,
1066                ..
1067            } => {
1068                Self::collect_variables(list_expr, vars);
1069                if let Some(filter) = filter_expr {
1070                    Self::collect_variables(filter, vars);
1071                }
1072                Self::collect_variables(map_expr, vars);
1073            }
1074            LogicalExpression::ListPredicate {
1075                list_expr,
1076                predicate,
1077                ..
1078            } => {
1079                Self::collect_variables(list_expr, vars);
1080                Self::collect_variables(predicate, vars);
1081            }
1082            LogicalExpression::ExistsSubquery(_)
1083            | LogicalExpression::CountSubquery(_)
1084            | LogicalExpression::ValueSubquery(_) => {
1085                // Subqueries have their own variable scope
1086            }
1087            LogicalExpression::PatternComprehension { projection, .. } => {
1088                Self::collect_variables(projection, vars);
1089            }
1090            LogicalExpression::MapProjection { base, entries } => {
1091                vars.insert(base.clone());
1092                for entry in entries {
1093                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1094                        Self::collect_variables(expr, vars);
1095                    }
1096                }
1097            }
1098            LogicalExpression::Reduce {
1099                initial,
1100                list,
1101                expression,
1102                ..
1103            } => {
1104                Self::collect_variables(initial, vars);
1105                Self::collect_variables(list, vars);
1106                Self::collect_variables(expression, vars);
1107            }
1108        }
1109    }
1110
1111    /// Extracts aliases from projection expressions.
1112    fn extract_projection_aliases(
1113        &self,
1114        projections: &[crate::query::plan::Projection],
1115    ) -> HashSet<String> {
1116        projections.iter().filter_map(|p| p.alias.clone()).collect()
1117    }
1118}
1119
1120impl Default for Optimizer {
1121    fn default() -> Self {
1122        Self::new()
1123    }
1124}
1125
1126#[cfg(test)]
1127mod tests {
1128    use super::*;
1129    use crate::query::plan::{
1130        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1131        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1132        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1133    };
1134    use grafeo_common::types::Value;
1135
1136    #[test]
1137    fn test_optimizer_filter_pushdown_simple() {
1138        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1139        // Before: Return -> Filter -> NodeScan
1140        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1141
1142        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1143            items: vec![ReturnItem {
1144                expression: LogicalExpression::Variable("n".to_string()),
1145                alias: None,
1146            }],
1147            distinct: false,
1148            input: Box::new(LogicalOperator::Filter(FilterOp {
1149                predicate: LogicalExpression::Binary {
1150                    left: Box::new(LogicalExpression::Property {
1151                        variable: "n".to_string(),
1152                        property: "age".to_string(),
1153                    }),
1154                    op: BinaryOp::Gt,
1155                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1156                },
1157                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1158                    variable: "n".to_string(),
1159                    label: Some("Person".to_string()),
1160                    input: None,
1161                })),
1162                pushdown_hint: None,
1163            })),
1164        }));
1165
1166        let optimizer = Optimizer::new();
1167        let optimized = optimizer.optimize(plan).unwrap();
1168
1169        // The structure should remain similar (filter stays near scan)
1170        if let LogicalOperator::Return(ret) = &optimized.root
1171            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1172            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1173        {
1174            assert_eq!(scan.variable, "n");
1175            return;
1176        }
1177        panic!("Expected Return -> Filter -> NodeScan structure");
1178    }
1179
1180    #[test]
1181    fn test_optimizer_filter_pushdown_through_expand() {
1182        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1183        // The filter on 'a' should be pushed before the expand
1184
1185        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1186            items: vec![ReturnItem {
1187                expression: LogicalExpression::Variable("b".to_string()),
1188                alias: None,
1189            }],
1190            distinct: false,
1191            input: Box::new(LogicalOperator::Filter(FilterOp {
1192                predicate: LogicalExpression::Binary {
1193                    left: Box::new(LogicalExpression::Property {
1194                        variable: "a".to_string(),
1195                        property: "age".to_string(),
1196                    }),
1197                    op: BinaryOp::Gt,
1198                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1199                },
1200                pushdown_hint: None,
1201                input: Box::new(LogicalOperator::Expand(ExpandOp {
1202                    from_variable: "a".to_string(),
1203                    to_variable: "b".to_string(),
1204                    edge_variable: None,
1205                    direction: ExpandDirection::Outgoing,
1206                    edge_types: vec!["KNOWS".to_string()],
1207                    min_hops: 1,
1208                    max_hops: Some(1),
1209                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1210                        variable: "a".to_string(),
1211                        label: Some("Person".to_string()),
1212                        input: None,
1213                    })),
1214                    path_alias: None,
1215                    path_mode: PathMode::Walk,
1216                })),
1217            })),
1218        }));
1219
1220        let optimizer = Optimizer::new();
1221        let optimized = optimizer.optimize(plan).unwrap();
1222
1223        // Filter on 'a' should be pushed before the expand
1224        // Expected: Return -> Expand -> Filter -> NodeScan
1225        if let LogicalOperator::Return(ret) = &optimized.root
1226            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1227            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1228            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1229        {
1230            assert_eq!(scan.variable, "a");
1231            assert_eq!(expand.from_variable, "a");
1232            assert_eq!(expand.to_variable, "b");
1233            return;
1234        }
1235        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1236    }
1237
1238    #[test]
1239    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1240        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1241        // The filter on 'b' should NOT be pushed before the expand
1242
1243        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1244            items: vec![ReturnItem {
1245                expression: LogicalExpression::Variable("a".to_string()),
1246                alias: None,
1247            }],
1248            distinct: false,
1249            input: Box::new(LogicalOperator::Filter(FilterOp {
1250                predicate: LogicalExpression::Binary {
1251                    left: Box::new(LogicalExpression::Property {
1252                        variable: "b".to_string(),
1253                        property: "age".to_string(),
1254                    }),
1255                    op: BinaryOp::Gt,
1256                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1257                },
1258                pushdown_hint: None,
1259                input: Box::new(LogicalOperator::Expand(ExpandOp {
1260                    from_variable: "a".to_string(),
1261                    to_variable: "b".to_string(),
1262                    edge_variable: None,
1263                    direction: ExpandDirection::Outgoing,
1264                    edge_types: vec!["KNOWS".to_string()],
1265                    min_hops: 1,
1266                    max_hops: Some(1),
1267                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1268                        variable: "a".to_string(),
1269                        label: Some("Person".to_string()),
1270                        input: None,
1271                    })),
1272                    path_alias: None,
1273                    path_mode: PathMode::Walk,
1274                })),
1275            })),
1276        }));
1277
1278        let optimizer = Optimizer::new();
1279        let optimized = optimizer.optimize(plan).unwrap();
1280
1281        // Filter on 'b' should stay after the expand
1282        // Expected: Return -> Filter -> Expand -> NodeScan
1283        if let LogicalOperator::Return(ret) = &optimized.root
1284            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1285        {
1286            // Check that the filter is on 'b'
1287            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1288                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1289            {
1290                assert_eq!(variable, "b");
1291            }
1292
1293            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1294                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1295            {
1296                return;
1297            }
1298        }
1299        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1300    }
1301
1302    #[test]
1303    fn test_optimizer_extract_variables() {
1304        let optimizer = Optimizer::new();
1305
1306        let expr = LogicalExpression::Binary {
1307            left: Box::new(LogicalExpression::Property {
1308                variable: "n".to_string(),
1309                property: "age".to_string(),
1310            }),
1311            op: BinaryOp::Gt,
1312            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1313        };
1314
1315        let vars = optimizer.extract_variables(&expr);
1316        assert_eq!(vars.len(), 1);
1317        assert!(vars.contains("n"));
1318    }
1319
1320    // Additional tests for optimizer configuration
1321
1322    #[test]
1323    fn test_optimizer_default() {
1324        let optimizer = Optimizer::default();
1325        // Should be able to optimize an empty plan
1326        let plan = LogicalPlan::new(LogicalOperator::Empty);
1327        let result = optimizer.optimize(plan);
1328        assert!(result.is_ok());
1329    }
1330
1331    #[test]
1332    fn test_optimizer_with_filter_pushdown_disabled() {
1333        let optimizer = Optimizer::new().with_filter_pushdown(false);
1334
1335        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1336            items: vec![ReturnItem {
1337                expression: LogicalExpression::Variable("n".to_string()),
1338                alias: None,
1339            }],
1340            distinct: false,
1341            input: Box::new(LogicalOperator::Filter(FilterOp {
1342                predicate: LogicalExpression::Literal(Value::Bool(true)),
1343                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1344                    variable: "n".to_string(),
1345                    label: None,
1346                    input: None,
1347                })),
1348                pushdown_hint: None,
1349            })),
1350        }));
1351
1352        let optimized = optimizer.optimize(plan).unwrap();
1353        // Structure should be unchanged
1354        if let LogicalOperator::Return(ret) = &optimized.root
1355            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1356        {
1357            return;
1358        }
1359        panic!("Expected unchanged structure");
1360    }
1361
1362    #[test]
1363    fn test_optimizer_with_join_reorder_disabled() {
1364        let optimizer = Optimizer::new().with_join_reorder(false);
1365        assert!(
1366            optimizer
1367                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1368                .is_ok()
1369        );
1370    }
1371
1372    #[test]
1373    fn test_optimizer_with_cost_model() {
1374        let cost_model = CostModel::new();
1375        let optimizer = Optimizer::new().with_cost_model(cost_model);
1376        assert!(
1377            optimizer
1378                .cost_model()
1379                .estimate(&LogicalOperator::Empty, 0.0)
1380                .total()
1381                < 0.001
1382        );
1383    }
1384
1385    #[test]
1386    fn test_optimizer_with_cardinality_estimator() {
1387        let mut estimator = CardinalityEstimator::new();
1388        estimator.add_table_stats("Test", TableStats::new(500));
1389        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1390
1391        let scan = LogicalOperator::NodeScan(NodeScanOp {
1392            variable: "n".to_string(),
1393            label: Some("Test".to_string()),
1394            input: None,
1395        });
1396        let plan = LogicalPlan::new(scan);
1397
1398        let cardinality = optimizer.estimate_cardinality(&plan);
1399        assert!((cardinality - 500.0).abs() < 0.001);
1400    }
1401
1402    #[test]
1403    fn test_optimizer_estimate_cost() {
1404        let optimizer = Optimizer::new();
1405        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1406            variable: "n".to_string(),
1407            label: None,
1408            input: None,
1409        }));
1410
1411        let cost = optimizer.estimate_cost(&plan);
1412        assert!(cost.total() > 0.0);
1413    }
1414
1415    // Filter pushdown through various operators
1416
1417    #[test]
1418    fn test_filter_pushdown_through_project() {
1419        let optimizer = Optimizer::new();
1420
1421        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1422            predicate: LogicalExpression::Binary {
1423                left: Box::new(LogicalExpression::Property {
1424                    variable: "n".to_string(),
1425                    property: "age".to_string(),
1426                }),
1427                op: BinaryOp::Gt,
1428                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1429            },
1430            pushdown_hint: None,
1431            input: Box::new(LogicalOperator::Project(ProjectOp {
1432                projections: vec![Projection {
1433                    expression: LogicalExpression::Variable("n".to_string()),
1434                    alias: None,
1435                }],
1436                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1437                    variable: "n".to_string(),
1438                    label: None,
1439                    input: None,
1440                })),
1441                pass_through_input: false,
1442            })),
1443        }));
1444
1445        let optimized = optimizer.optimize(plan).unwrap();
1446
1447        // Filter should be pushed through Project
1448        if let LogicalOperator::Project(proj) = &optimized.root
1449            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1450        {
1451            return;
1452        }
1453        panic!("Expected Project -> Filter structure");
1454    }
1455
/// A filter on the alias `x` must not be moved below the projection that creates `x`.
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    let opt = Optimizer::new();

    // Projection `n.age AS x` over an unlabelled scan of `n`.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
        pass_through_input: false,
    });

    // Predicate `x > 30` refers to the computed column, not to `n`.
    let predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate,
        pushdown_hint: None,
        input: Box::new(project),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // The filter has to remain directly above the projection.
    match &optimized.root {
        LogicalOperator::Filter(filter)
            if matches!(&*filter.input, LogicalOperator::Project(_)) => {}
        _ => panic!("Expected Filter -> Project structure"),
    }
}
1495
/// A filter sitting on top of a `Limit` must still be on top of it after optimization.
#[test]
fn test_filter_pushdown_through_limit() {
    let opt = Optimizer::new();

    let limited_scan = LogicalOperator::Limit(LimitOp {
        count: 10.into(),
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        })),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(limited_scan),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // Filter stays above Limit (cannot be pushed through).
    let LogicalOperator::Filter(top) = &optimized.root else {
        panic!("Expected Filter -> Limit structure");
    };
    assert!(
        matches!(&*top.input, LogicalOperator::Limit(_)),
        "Expected Filter -> Limit structure"
    );
}
1523
/// A filter sitting on top of a `Sort` must still be on top of it after optimization.
#[test]
fn test_filter_pushdown_through_sort() {
    let opt = Optimizer::new();

    // Sort the scan of `n` by `n` ascending.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let sorted = LogicalOperator::Sort(SortOp {
        keys: vec![SortKey {
            expression: LogicalExpression::Variable("n".to_string()),
            order: SortOrder::Ascending,
            nulls: None,
        }],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(sorted),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // Filter stays above Sort.
    match &optimized.root {
        LogicalOperator::Filter(top) if matches!(&*top.input, LogicalOperator::Sort(_)) => {}
        _ => panic!("Expected Filter -> Sort structure"),
    }
}
1555
/// A filter sitting on top of a `Distinct` must still be on top of it after optimization.
#[test]
fn test_filter_pushdown_through_distinct() {
    let opt = Optimizer::new();

    let deduplicated = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        })),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(deduplicated),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // Filter stays above Distinct.
    let LogicalOperator::Filter(top) = &optimized.root else {
        panic!("Expected Filter -> Distinct structure");
    };
    assert!(
        matches!(&*top.input, LogicalOperator::Distinct(_)),
        "Expected Filter -> Distinct structure"
    );
}
1583
/// A filter on the aggregate output `cnt` must remain above the `Aggregate` operator.
#[test]
fn test_filter_not_pushed_through_aggregate() {
    let opt = Optimizer::new();

    // Ungrouped `count(*) AS cnt` over a scan of `n`.
    let count_all = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        expression2: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
        separator: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_all],
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        })),
        having: None,
    });

    // Predicate `cnt > 10` depends on the aggregate result.
    let predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate,
        pushdown_hint: None,
        input: Box::new(aggregate),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // Filter should stay above Aggregate.
    match &optimized.root {
        LogicalOperator::Filter(top)
            if matches!(&*top.input, LogicalOperator::Aggregate(_)) => {}
        _ => panic!("Expected Filter -> Aggregate structure"),
    }
}
1625
/// A predicate that only mentions the left join input should end up on the left side.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    let opt = Optimizer::new();

    // Inner join without conditions: (a:Person) x (b:Company).
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: Vec::new(),
    });

    // `a.age > 30` references the left variable only.
    let predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate,
        pushdown_hint: None,
        input: Box::new(join),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // The root must now be the join, with a filter as its left child.
    let LogicalOperator::Join(root_join) = &optimized.root else {
        panic!("Expected Join with Filter on left side");
    };
    assert!(
        matches!(&*root_join.left, LogicalOperator::Filter(_)),
        "Expected Join with Filter on left side"
    );
}
1667
/// A predicate that only mentions the right join input should end up on the right side.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let opt = Optimizer::new();

    // Inner join without conditions: (a:Person) x (b:Company).
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: Vec::new(),
    });

    // `b.name = "Acme"` references the right variable only.
    let predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate,
        pushdown_hint: None,
        input: Box::new(join),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // The root must now be the join, with a filter as its right child.
    match &optimized.root {
        LogicalOperator::Join(root_join)
            if matches!(&*root_join.right, LogicalOperator::Filter(_)) => {}
        _ => panic!("Expected Join with Filter on right side"),
    }
}
1709
/// A predicate that mentions variables from both join inputs must stay above the join.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let opt = Optimizer::new();

    // Inner join without conditions over two unlabelled scans, `a` and `b`.
    let left_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: None,
        input: None,
    });
    let right_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: None,
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(left_scan),
        right: Box::new(right_scan),
        join_type: JoinType::Inner,
        conditions: Vec::new(),
    });

    // `a.id = b.a_id` needs both sides at once.
    let predicate = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate,
        pushdown_hint: None,
        input: Box::new(join),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // Filter should stay above join.
    let LogicalOperator::Filter(top) = &optimized.root else {
        panic!("Expected Filter -> Join structure");
    };
    assert!(
        matches!(&*top.input, LogicalOperator::Join(_)),
        "Expected Filter -> Join structure"
    );
}
1754
1755    // Variable extraction tests
1756
/// A bare variable reference yields exactly that one variable name.
#[test]
fn test_extract_variables_from_variable() {
    let opt = Optimizer::new();
    let variable = LogicalExpression::Variable("x".to_string());

    let found = opt.extract_variables(&variable);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1765
/// The operand of a unary expression (`NOT x`) contributes its variable.
#[test]
fn test_extract_variables_from_unary() {
    let opt = Optimizer::new();
    let operand = LogicalExpression::Variable("x".to_string());
    let negation = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };

    let found = opt.extract_variables(&negation);

    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
1777
/// Every argument of a function call contributes its variables.
#[test]
fn test_extract_variables_from_function_call() {
    let opt = Optimizer::new();
    let args = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Variable("b".to_string()),
    ];
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args,
        distinct: false,
    };

    let found = opt.extract_variables(&call);

    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1794
/// Variables inside a list literal are collected; the literal element adds nothing.
#[test]
fn test_extract_variables_from_list() {
    let opt = Optimizer::new();
    let elements = vec![
        LogicalExpression::Variable("a".to_string()),
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("b".to_string()),
    ];
    let list = LogicalExpression::List(elements);

    let found = opt.extract_variables(&list);

    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1808
/// Variables used as map values are collected; the two keys are not counted.
#[test]
fn test_extract_variables_from_map() {
    let opt = Optimizer::new();
    let first = (
        "key1".to_string(),
        LogicalExpression::Variable("a".to_string()),
    );
    let second = (
        "key2".to_string(),
        LogicalExpression::Variable("b".to_string()),
    );
    let map = LogicalExpression::Map(vec![first, second]);

    let found = opt.extract_variables(&map);

    assert_eq!(found.len(), 2);
    for name in ["a", "b"] {
        assert!(found.contains(name));
    }
}
1827
/// Both the indexed base and the index expression contribute variables.
#[test]
fn test_extract_variables_from_index_access() {
    let opt = Optimizer::new();
    let base = LogicalExpression::Variable("list".to_string());
    let index = LogicalExpression::Variable("idx".to_string());
    let access = LogicalExpression::IndexAccess {
        base: Box::new(base),
        index: Box::new(index),
    };

    let found = opt.extract_variables(&access);

    assert_eq!(found.len(), 2);
    for name in ["list", "idx"] {
        assert!(found.contains(name));
    }
}
1840
/// The sliced base and both (present) bounds contribute variables.
#[test]
fn test_extract_variables_from_slice_access() {
    let opt = Optimizer::new();
    let boxed_var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let slice = LogicalExpression::SliceAccess {
        base: boxed_var("list"),
        start: Some(boxed_var("s")),
        end: Some(boxed_var("e")),
    };

    let found = opt.extract_variables(&slice);

    assert_eq!(found.len(), 3);
    for name in ["list", "s", "e"] {
        assert!(found.contains(name));
    }
}
1855
/// The operand, the WHEN result and the ELSE branch of a CASE each contribute a variable.
#[test]
fn test_extract_variables_from_case() {
    let opt = Optimizer::new();
    // CASE x WHEN 1 THEN a ELSE b
    let when_one_then_a = (
        LogicalExpression::Literal(Value::Int64(1)),
        LogicalExpression::Variable("a".to_string()),
    );
    let case = LogicalExpression::Case {
        operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
    };

    let found = opt.extract_variables(&case);

    assert_eq!(found.len(), 3);
    for name in ["x", "a", "b"] {
        assert!(found.contains(name));
    }
}
1873
/// A `Labels` expression reports the variable it is applied to.
#[test]
fn test_extract_variables_from_labels() {
    let opt = Optimizer::new();
    let labels_of_n = LogicalExpression::Labels("n".to_string());

    let found = opt.extract_variables(&labels_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1882
/// A `Type` expression reports the variable it is applied to.
#[test]
fn test_extract_variables_from_type() {
    let opt = Optimizer::new();
    let type_of_e = LogicalExpression::Type("e".to_string());

    let found = opt.extract_variables(&type_of_e);

    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
1891
/// An `Id` expression reports the variable it is applied to.
#[test]
fn test_extract_variables_from_id() {
    let opt = Optimizer::new();
    let id_of_n = LogicalExpression::Id("n".to_string());

    let found = opt.extract_variables(&id_of_n);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1900
/// The source list, filter and mapping expressions of a comprehension all contribute variables.
///
/// Only membership is checked: the test makes no claim about the total count or
/// about whether the comprehension's own binding `x` is reported.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let opt = Optimizer::new();
    let boxed_var = |name: &str| Box::new(LogicalExpression::Variable(name.to_string()));
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: boxed_var("items"),
        filter_expr: Some(boxed_var("pred")),
        map_expr: boxed_var("result"),
    };

    let found = opt.extract_variables(&comprehension);

    for name in ["items", "pred", "result"] {
        assert!(found.contains(name));
    }
}
1915
/// Literals and query parameters reference no variables at all.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let opt = Optimizer::new();

    let from_literal = opt.extract_variables(&LogicalExpression::Literal(Value::Int64(42)));
    assert!(from_literal.is_empty());

    let from_parameter = opt.extract_variables(&LogicalExpression::Parameter("p".to_string()));
    assert!(from_parameter.is_empty());
}
1926
1927    // Recursive filter pushdown tests
1928
/// Optimizing `Return -> Filter -> Skip -> NodeScan` succeeds and keeps `Return` at the root.
///
/// Only the root operator is inspected; the position of the filter relative
/// to the `Skip` is not asserted here.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let opt = Optimizer::new();

    let skipped_scan = LogicalOperator::Skip(SkipOp {
        count: 5.into(),
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        })),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(skipped_scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));

    let optimized = opt.optimize(plan).unwrap();

    // Verify optimization succeeded and the plan is still rooted at Return.
    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
1958
/// Optimizing two stacked filters under a `Return` succeeds and keeps `Return` at the root.
///
/// Only the root operator is inspected; what happens to the two filters is
/// not asserted here.
#[test]
fn test_nested_filter_pushdown() {
    let opt = Optimizer::new();

    // Builds `n.<property> <op> <bound>`.
    let compare_n = |property: &str, op: BinaryOp, bound: i64| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: property.to_string(),
        }),
        op,
        right: Box::new(LogicalExpression::Literal(Value::Int64(bound))),
    };

    // Inner filter: `n.y < 10` directly over the scan.
    let inner_filter = LogicalOperator::Filter(FilterOp {
        predicate: compare_n("y", BinaryOp::Lt, 10),
        pushdown_hint: None,
        input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        })),
    });
    // Outer filter: `n.x > 1` over the inner filter.
    let outer_filter = LogicalOperator::Filter(FilterOp {
        predicate: compare_n("x", BinaryOp::Gt, 1),
        pushdown_hint: None,
        input: Box::new(inner_filter),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer_filter),
    }));

    let optimized = opt.optimize(plan).unwrap();

    let root_is_return = matches!(&optimized.root, LogicalOperator::Return(_));
    assert!(root_is_return);
}
2002
/// A triangle of join conditions over three `Person` scans must yield a `MultiWayJoin`.
#[test]
fn test_cyclic_join_produces_multi_way_join() {
    use crate::query::plan::JoinCondition;

    // Looks for a MultiWayJoin, descending only through Return, Filter and
    // Project; any other operator ends the search with `false`.
    fn contains_multi_way_join(op: &LogicalOperator) -> bool {
        let child = match op {
            LogicalOperator::MultiWayJoin(_) => return true,
            LogicalOperator::Return(ret) => &ret.input,
            LogicalOperator::Filter(filter) => &filter.input,
            LogicalOperator::Project(project) => &project.input,
            _ => return false,
        };
        contains_multi_way_join(child)
    }

    let person_scan = |variable: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some("Person".to_string()),
            input: None,
        })
    };
    let var_eq = |lhs: &str, rhs: &str| JoinCondition {
        left: LogicalExpression::Variable(lhs.to_string()),
        right: LogicalExpression::Variable(rhs.to_string()),
    };

    // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic).
    // Shape: Join(Join(a, b, a=b), c, b=c) with the extra condition c=a.
    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(person_scan("a")),
        right: Box::new(person_scan("b")),
        join_type: JoinType::Inner,
        conditions: vec![var_eq("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(person_scan("c")),
        join_type: JoinType::Inner,
        conditions: vec![var_eq("b", "c"), var_eq("c", "a")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));

    // Register table statistics for the only label used by the scans.
    let mut opt = Optimizer::new();
    opt.card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));

    let optimized = opt.optimize(plan).unwrap();

    assert!(
        contains_multi_way_join(&optimized.root),
        "Expected MultiWayJoin for cyclic triangle pattern"
    );
}
2083
/// A chain of join conditions (no cycle) must not be turned into a `MultiWayJoin`.
#[test]
fn test_acyclic_join_uses_binary_joins() {
    use crate::query::plan::JoinCondition;

    // Looks for a MultiWayJoin, descending through Return, Filter, Project and
    // both inputs of Join; any other operator ends that branch with `false`.
    fn contains_multi_way_join(op: &LogicalOperator) -> bool {
        let child = match op {
            LogicalOperator::MultiWayJoin(_) => return true,
            LogicalOperator::Join(join) => {
                return contains_multi_way_join(&join.left)
                    || contains_multi_way_join(&join.right);
            }
            LogicalOperator::Return(ret) => &ret.input,
            LogicalOperator::Filter(filter) => &filter.input,
            LogicalOperator::Project(project) => &project.input,
            _ => return false,
        };
        contains_multi_way_join(child)
    }

    let labelled_scan = |variable: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: variable.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };
    let var_eq = |lhs: &str, rhs: &str| JoinCondition {
        left: LogicalExpression::Variable(lhs.to_string()),
        right: LogicalExpression::Variable(rhs.to_string()),
    };

    // Chain: a ⋈ b ⋈ c (acyclic).
    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(labelled_scan("a", "Person")),
        right: Box::new(labelled_scan("b", "Person")),
        join_type: JoinType::Inner,
        conditions: vec![var_eq("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(labelled_scan("c", "Company")),
        join_type: JoinType::Inner,
        conditions: vec![var_eq("b", "c")],
    });

    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));

    // Register table statistics for both labels used by the scans.
    let mut opt = Optimizer::new();
    opt.card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));
    opt.card_estimator
        .add_table_stats("Company", cardinality::TableStats::new(100));

    let optimized = opt.optimize(plan).unwrap();

    // Should NOT contain MultiWayJoin for acyclic pattern.
    assert!(
        !contains_multi_way_join(&optimized.root),
        "Acyclic join should NOT produce MultiWayJoin"
    );
}
2163}