Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// Information about a join condition for join reordering.
///
/// One value is recorded for each join condition whose two sides both
/// resolve to a variable. The original expressions are kept next to the
/// variable names so the condition can be handed to the join graph builder
/// (or re-emitted as a `JoinCondition`) unchanged.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left side of the condition.
    left_var: String,
    /// Variable referenced by the right side of the condition.
    right_var: String,
    /// Full left-hand expression of the condition.
    left_expr: LogicalExpression,
    /// Full right-hand expression of the condition.
    right_expr: LogicalExpression,
}
38
/// A column required by the query, used for projection pushdown.
///
/// Values are stored in a `HashSet`, hence the `Eq` + `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding)
    Variable(String),
    /// A specific property of a variable, as `(variable, property)`
    Property(String, String),
}
47
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// Each `enable_*` flag gates one pass in [`optimize()`](Self::optimize);
/// the cost model and cardinality estimator are shared by join reordering
/// and by the public `estimate_*` helpers.
pub struct Optimizer {
    /// Whether to enable filter pushdown.
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering.
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown.
    enable_projection_pushdown: bool,
    /// Cost model for estimation.
    cost_model: CostModel,
    /// Cardinality estimator.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66    /// Creates a new optimizer with default settings.
67    #[must_use]
68    pub fn new() -> Self {
69        Self {
70            enable_filter_pushdown: true,
71            enable_join_reorder: true,
72            enable_projection_pushdown: true,
73            cost_model: CostModel::new(),
74            card_estimator: CardinalityEstimator::new(),
75        }
76    }
77
78    /// Creates an optimizer with cardinality estimates from the store's statistics.
79    ///
80    /// Pre-populates the cardinality estimator with per-label row counts and
81    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
82    /// and graph totals into the cost model for accurate estimation.
83    #[must_use]
84    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85        store.ensure_statistics_fresh();
86        let stats = store.statistics();
87        Self::from_statistics(&stats)
88    }
89
90    /// Creates an optimizer from any GraphStore implementation.
91    ///
92    /// Unlike [`from_store`](Self::from_store), this does not call
93    /// `ensure_statistics_fresh()` since external stores manage their own
94    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
95    /// is called directly.
96    #[must_use]
97    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98        let stats = store.statistics();
99        Self::from_statistics(&stats)
100    }
101
102    /// Creates an optimizer from a Statistics snapshot.
103    ///
104    /// Extracts label cardinalities, edge type degrees, and graph totals
105    /// for both the cardinality estimator and cost model.
106    #[must_use]
107    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
108        let estimator = CardinalityEstimator::from_statistics(stats);
109
110        let avg_fanout = if stats.total_nodes > 0 {
111            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
112        } else {
113            10.0
114        };
115
116        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
117            .edge_types
118            .iter()
119            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
120            .collect();
121
122        let label_cardinalities: std::collections::HashMap<String, u64> = stats
123            .labels
124            .iter()
125            .map(|(name, ls)| (name.clone(), ls.node_count))
126            .collect();
127
128        Self {
129            enable_filter_pushdown: true,
130            enable_join_reorder: true,
131            enable_projection_pushdown: true,
132            cost_model: CostModel::new()
133                .with_avg_fanout(avg_fanout)
134                .with_edge_type_degrees(edge_type_degrees)
135                .with_label_cardinalities(label_cardinalities)
136                .with_graph_totals(stats.total_nodes, stats.total_edges),
137            card_estimator: estimator,
138        }
139    }
140
141    /// Enables or disables filter pushdown.
142    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
143        self.enable_filter_pushdown = enabled;
144        self
145    }
146
147    /// Enables or disables join reordering.
148    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
149        self.enable_join_reorder = enabled;
150        self
151    }
152
153    /// Enables or disables projection pushdown.
154    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
155        self.enable_projection_pushdown = enabled;
156        self
157    }
158
159    /// Sets the cost model.
160    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
161        self.cost_model = cost_model;
162        self
163    }
164
165    /// Sets the cardinality estimator.
166    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
167        self.card_estimator = estimator;
168        self
169    }
170
171    /// Sets the selectivity configuration for the cardinality estimator.
172    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
173        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
174        self
175    }
176
177    /// Returns a reference to the cost model.
178    pub fn cost_model(&self) -> &CostModel {
179        &self.cost_model
180    }
181
182    /// Returns a reference to the cardinality estimator.
183    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
184        &self.card_estimator
185    }
186
187    /// Estimates the total cost of a plan by recursively costing the entire tree.
188    ///
189    /// Walks every operator in the plan, computing cardinality at each level
190    /// and summing the per-operator costs. Uses actual child cardinalities
191    /// for join cost estimation rather than approximations.
192    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
193        self.cost_model
194            .estimate_tree(&plan.root, &self.card_estimator)
195    }
196
197    /// Estimates the cardinality of a plan.
198    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
199        self.card_estimator.estimate(&plan.root)
200    }
201
202    /// Optimizes a logical plan.
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if optimization fails.
207    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
208        let mut root = plan.root;
209
210        // Apply optimization rules
211        if self.enable_filter_pushdown {
212            root = self.push_filters_down(root);
213        }
214
215        if self.enable_join_reorder {
216            root = self.reorder_joins(root);
217        }
218
219        if self.enable_projection_pushdown {
220            root = self.push_projections_down(root);
221        }
222
223        Ok(LogicalPlan {
224            root,
225            explain: plan.explain,
226            profile: plan.profile,
227        })
228    }
229
230    /// Pushes projections down the operator tree to eliminate unused columns early.
231    ///
232    /// This optimization:
233    /// 1. Collects required variables/properties from the root
234    /// 2. Propagates requirements down through the tree
235    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
236    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
237        // Collect required columns from the top of the plan
238        let required = self.collect_required_columns(&op);
239
240        // Push projections down
241        self.push_projections_recursive(op, &required)
242    }
243
244    /// Collects all variables and properties required by an operator and its ancestors.
245    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
246        let mut required = HashSet::new();
247        Self::collect_required_recursive(op, &mut required);
248        required
249    }
250
251    /// Recursively collects required columns.
252    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
253        match op {
254            LogicalOperator::Return(ret) => {
255                for item in &ret.items {
256                    Self::collect_from_expression(&item.expression, required);
257                }
258                Self::collect_required_recursive(&ret.input, required);
259            }
260            LogicalOperator::Project(proj) => {
261                for p in &proj.projections {
262                    Self::collect_from_expression(&p.expression, required);
263                }
264                Self::collect_required_recursive(&proj.input, required);
265            }
266            LogicalOperator::Filter(filter) => {
267                Self::collect_from_expression(&filter.predicate, required);
268                Self::collect_required_recursive(&filter.input, required);
269            }
270            LogicalOperator::Sort(sort) => {
271                for key in &sort.keys {
272                    Self::collect_from_expression(&key.expression, required);
273                }
274                Self::collect_required_recursive(&sort.input, required);
275            }
276            LogicalOperator::Aggregate(agg) => {
277                for expr in &agg.group_by {
278                    Self::collect_from_expression(expr, required);
279                }
280                for agg_expr in &agg.aggregates {
281                    if let Some(ref expr) = agg_expr.expression {
282                        Self::collect_from_expression(expr, required);
283                    }
284                }
285                if let Some(ref having) = agg.having {
286                    Self::collect_from_expression(having, required);
287                }
288                Self::collect_required_recursive(&agg.input, required);
289            }
290            LogicalOperator::Join(join) => {
291                for cond in &join.conditions {
292                    Self::collect_from_expression(&cond.left, required);
293                    Self::collect_from_expression(&cond.right, required);
294                }
295                Self::collect_required_recursive(&join.left, required);
296                Self::collect_required_recursive(&join.right, required);
297            }
298            LogicalOperator::Expand(expand) => {
299                // The source and target variables are needed
300                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
301                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
302                if let Some(ref edge_var) = expand.edge_variable {
303                    required.insert(RequiredColumn::Variable(edge_var.clone()));
304                }
305                Self::collect_required_recursive(&expand.input, required);
306            }
307            LogicalOperator::Limit(limit) => {
308                Self::collect_required_recursive(&limit.input, required);
309            }
310            LogicalOperator::Skip(skip) => {
311                Self::collect_required_recursive(&skip.input, required);
312            }
313            LogicalOperator::Distinct(distinct) => {
314                Self::collect_required_recursive(&distinct.input, required);
315            }
316            LogicalOperator::NodeScan(scan) => {
317                required.insert(RequiredColumn::Variable(scan.variable.clone()));
318            }
319            LogicalOperator::EdgeScan(scan) => {
320                required.insert(RequiredColumn::Variable(scan.variable.clone()));
321            }
322            LogicalOperator::MultiWayJoin(mwj) => {
323                for cond in &mwj.conditions {
324                    Self::collect_from_expression(&cond.left, required);
325                    Self::collect_from_expression(&cond.right, required);
326                }
327                for input in &mwj.inputs {
328                    Self::collect_required_recursive(input, required);
329                }
330            }
331            _ => {}
332        }
333    }
334
335    /// Collects required columns from an expression.
336    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
337        match expr {
338            LogicalExpression::Variable(var) => {
339                required.insert(RequiredColumn::Variable(var.clone()));
340            }
341            LogicalExpression::Property { variable, property } => {
342                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
343                required.insert(RequiredColumn::Variable(variable.clone()));
344            }
345            LogicalExpression::Binary { left, right, .. } => {
346                Self::collect_from_expression(left, required);
347                Self::collect_from_expression(right, required);
348            }
349            LogicalExpression::Unary { operand, .. } => {
350                Self::collect_from_expression(operand, required);
351            }
352            LogicalExpression::FunctionCall { args, .. } => {
353                for arg in args {
354                    Self::collect_from_expression(arg, required);
355                }
356            }
357            LogicalExpression::List(items) => {
358                for item in items {
359                    Self::collect_from_expression(item, required);
360                }
361            }
362            LogicalExpression::Map(pairs) => {
363                for (_, value) in pairs {
364                    Self::collect_from_expression(value, required);
365                }
366            }
367            LogicalExpression::IndexAccess { base, index } => {
368                Self::collect_from_expression(base, required);
369                Self::collect_from_expression(index, required);
370            }
371            LogicalExpression::SliceAccess { base, start, end } => {
372                Self::collect_from_expression(base, required);
373                if let Some(s) = start {
374                    Self::collect_from_expression(s, required);
375                }
376                if let Some(e) = end {
377                    Self::collect_from_expression(e, required);
378                }
379            }
380            LogicalExpression::Case {
381                operand,
382                when_clauses,
383                else_clause,
384            } => {
385                if let Some(op) = operand {
386                    Self::collect_from_expression(op, required);
387                }
388                for (cond, result) in when_clauses {
389                    Self::collect_from_expression(cond, required);
390                    Self::collect_from_expression(result, required);
391                }
392                if let Some(else_expr) = else_clause {
393                    Self::collect_from_expression(else_expr, required);
394                }
395            }
396            LogicalExpression::Labels(var)
397            | LogicalExpression::Type(var)
398            | LogicalExpression::Id(var) => {
399                required.insert(RequiredColumn::Variable(var.clone()));
400            }
401            LogicalExpression::ListComprehension {
402                list_expr,
403                filter_expr,
404                map_expr,
405                ..
406            } => {
407                Self::collect_from_expression(list_expr, required);
408                if let Some(filter) = filter_expr {
409                    Self::collect_from_expression(filter, required);
410                }
411                Self::collect_from_expression(map_expr, required);
412            }
413            _ => {}
414        }
415    }
416
417    /// Recursively pushes projections down, adding them before expensive operations.
418    fn push_projections_recursive(
419        &self,
420        op: LogicalOperator,
421        required: &HashSet<RequiredColumn>,
422    ) -> LogicalOperator {
423        match op {
424            LogicalOperator::Return(mut ret) => {
425                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
426                LogicalOperator::Return(ret)
427            }
428            LogicalOperator::Project(mut proj) => {
429                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
430                LogicalOperator::Project(proj)
431            }
432            LogicalOperator::Filter(mut filter) => {
433                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
434                LogicalOperator::Filter(filter)
435            }
436            LogicalOperator::Sort(mut sort) => {
437                // Sort is expensive - consider adding a projection before it
438                // to reduce tuple width
439                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
440                LogicalOperator::Sort(sort)
441            }
442            LogicalOperator::Aggregate(mut agg) => {
443                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
444                LogicalOperator::Aggregate(agg)
445            }
446            LogicalOperator::Join(mut join) => {
447                // Joins are expensive - the required columns help determine
448                // what to project on each side
449                let left_vars = self.collect_output_variables(&join.left);
450                let right_vars = self.collect_output_variables(&join.right);
451
452                // Filter required columns to each side
453                let left_required: HashSet<_> = required
454                    .iter()
455                    .filter(|c| match c {
456                        RequiredColumn::Variable(v) => left_vars.contains(v),
457                        RequiredColumn::Property(v, _) => left_vars.contains(v),
458                    })
459                    .cloned()
460                    .collect();
461
462                let right_required: HashSet<_> = required
463                    .iter()
464                    .filter(|c| match c {
465                        RequiredColumn::Variable(v) => right_vars.contains(v),
466                        RequiredColumn::Property(v, _) => right_vars.contains(v),
467                    })
468                    .cloned()
469                    .collect();
470
471                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
472                join.right =
473                    Box::new(self.push_projections_recursive(*join.right, &right_required));
474                LogicalOperator::Join(join)
475            }
476            LogicalOperator::Expand(mut expand) => {
477                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
478                LogicalOperator::Expand(expand)
479            }
480            LogicalOperator::Limit(mut limit) => {
481                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
482                LogicalOperator::Limit(limit)
483            }
484            LogicalOperator::Skip(mut skip) => {
485                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
486                LogicalOperator::Skip(skip)
487            }
488            LogicalOperator::Distinct(mut distinct) => {
489                distinct.input =
490                    Box::new(self.push_projections_recursive(*distinct.input, required));
491                LogicalOperator::Distinct(distinct)
492            }
493            LogicalOperator::MapCollect(mut mc) => {
494                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
495                LogicalOperator::MapCollect(mc)
496            }
497            LogicalOperator::MultiWayJoin(mut mwj) => {
498                mwj.inputs = mwj
499                    .inputs
500                    .into_iter()
501                    .map(|input| self.push_projections_recursive(input, required))
502                    .collect();
503                LogicalOperator::MultiWayJoin(mwj)
504            }
505            other => other,
506        }
507    }
508
509    /// Reorders joins in the operator tree using the DPccp algorithm.
510    ///
511    /// This optimization finds the optimal join order by:
512    /// 1. Extracting all base relations (scans) and join conditions
513    /// 2. Building a join graph
514    /// 3. Using dynamic programming to find the cheapest join order
515    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
516        // First, recursively optimize children
517        let op = self.reorder_joins_recursive(op);
518
519        // Then, if this is a join tree, try to optimize it
520        if let Some((relations, conditions)) = self.extract_join_tree(&op)
521            && relations.len() >= 2
522            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
523        {
524            return optimized;
525        }
526
527        op
528    }
529
530    /// Recursively applies join reordering to child operators.
531    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
532        match op {
533            LogicalOperator::Return(mut ret) => {
534                ret.input = Box::new(self.reorder_joins(*ret.input));
535                LogicalOperator::Return(ret)
536            }
537            LogicalOperator::Project(mut proj) => {
538                proj.input = Box::new(self.reorder_joins(*proj.input));
539                LogicalOperator::Project(proj)
540            }
541            LogicalOperator::Filter(mut filter) => {
542                filter.input = Box::new(self.reorder_joins(*filter.input));
543                LogicalOperator::Filter(filter)
544            }
545            LogicalOperator::Limit(mut limit) => {
546                limit.input = Box::new(self.reorder_joins(*limit.input));
547                LogicalOperator::Limit(limit)
548            }
549            LogicalOperator::Skip(mut skip) => {
550                skip.input = Box::new(self.reorder_joins(*skip.input));
551                LogicalOperator::Skip(skip)
552            }
553            LogicalOperator::Sort(mut sort) => {
554                sort.input = Box::new(self.reorder_joins(*sort.input));
555                LogicalOperator::Sort(sort)
556            }
557            LogicalOperator::Distinct(mut distinct) => {
558                distinct.input = Box::new(self.reorder_joins(*distinct.input));
559                LogicalOperator::Distinct(distinct)
560            }
561            LogicalOperator::Aggregate(mut agg) => {
562                agg.input = Box::new(self.reorder_joins(*agg.input));
563                LogicalOperator::Aggregate(agg)
564            }
565            LogicalOperator::Expand(mut expand) => {
566                expand.input = Box::new(self.reorder_joins(*expand.input));
567                LogicalOperator::Expand(expand)
568            }
569            LogicalOperator::MapCollect(mut mc) => {
570                mc.input = Box::new(self.reorder_joins(*mc.input));
571                LogicalOperator::MapCollect(mc)
572            }
573            LogicalOperator::MultiWayJoin(mut mwj) => {
574                mwj.inputs = mwj
575                    .inputs
576                    .into_iter()
577                    .map(|input| self.reorder_joins(input))
578                    .collect();
579                LogicalOperator::MultiWayJoin(mwj)
580            }
581            // Join operators are handled by the parent reorder_joins call
582            other => other,
583        }
584    }
585
586    /// Extracts base relations and join conditions from a join tree.
587    ///
588    /// Returns None if the operator is not a join tree.
589    fn extract_join_tree(
590        &self,
591        op: &LogicalOperator,
592    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
593        let mut relations = Vec::new();
594        let mut join_conditions = Vec::new();
595
596        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
597            return None;
598        }
599
600        if relations.len() < 2 {
601            return None;
602        }
603
604        Some((relations, join_conditions))
605    }
606
607    /// Recursively collects base relations and join conditions.
608    ///
609    /// Returns true if this subtree is part of a join tree.
610    fn collect_join_tree(
611        &self,
612        op: &LogicalOperator,
613        relations: &mut Vec<(String, LogicalOperator)>,
614        conditions: &mut Vec<JoinInfo>,
615    ) -> bool {
616        match op {
617            LogicalOperator::Join(join) => {
618                // Collect from both sides
619                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
620                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
621
622                // Add conditions from this join
623                for cond in &join.conditions {
624                    if let (Some(left_var), Some(right_var)) = (
625                        self.extract_variable_from_expr(&cond.left),
626                        self.extract_variable_from_expr(&cond.right),
627                    ) {
628                        conditions.push(JoinInfo {
629                            left_var,
630                            right_var,
631                            left_expr: cond.left.clone(),
632                            right_expr: cond.right.clone(),
633                        });
634                    }
635                }
636
637                left_ok && right_ok
638            }
639            LogicalOperator::NodeScan(scan) => {
640                relations.push((scan.variable.clone(), op.clone()));
641                true
642            }
643            LogicalOperator::EdgeScan(scan) => {
644                relations.push((scan.variable.clone(), op.clone()));
645                true
646            }
647            LogicalOperator::Filter(filter) => {
648                // A filter on a base relation is still part of the join tree
649                self.collect_join_tree(&filter.input, relations, conditions)
650            }
651            LogicalOperator::Expand(expand) => {
652                // Expand is a special case - it's like a join with the adjacency
653                // For now, treat the whole Expand subtree as a single relation
654                relations.push((expand.to_variable.clone(), op.clone()));
655                true
656            }
657            _ => false,
658        }
659    }
660
661    /// Extracts the primary variable from an expression.
662    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
663        match expr {
664            LogicalExpression::Variable(v) => Some(v.clone()),
665            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
666            _ => None,
667        }
668    }
669
670    /// Optimizes the join order using DPccp, or produces a multi-way
671    /// leapfrog join for cyclic patterns when the cost model prefers it.
672    fn optimize_join_order(
673        &self,
674        relations: &[(String, LogicalOperator)],
675        conditions: &[JoinInfo],
676    ) -> Option<LogicalOperator> {
677        use join_order::{DPccp, JoinGraphBuilder};
678
679        // Build the join graph
680        let mut builder = JoinGraphBuilder::new();
681
682        for (var, relation) in relations {
683            builder.add_relation(var, relation.clone());
684        }
685
686        for cond in conditions {
687            builder.add_join_condition(
688                &cond.left_var,
689                &cond.right_var,
690                cond.left_expr.clone(),
691                cond.right_expr.clone(),
692            );
693        }
694
695        let graph = builder.build();
696
697        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
698        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
699        // multi-way intersection rather than binary hash join cascades that can
700        // produce intermediate blowup.
701        if graph.is_cyclic() && relations.len() >= 3 {
702            // Collect shared variables (variables appearing in 2+ conditions)
703            let mut var_counts: std::collections::HashMap<&str, usize> =
704                std::collections::HashMap::new();
705            for cond in conditions {
706                *var_counts.entry(&cond.left_var).or_default() += 1;
707                *var_counts.entry(&cond.right_var).or_default() += 1;
708            }
709            let shared_variables: Vec<String> = var_counts
710                .into_iter()
711                .filter(|(_, count)| *count >= 2)
712                .map(|(var, _)| var.to_string())
713                .collect();
714
715            let join_conditions: Vec<JoinCondition> = conditions
716                .iter()
717                .map(|c| JoinCondition {
718                    left: c.left_expr.clone(),
719                    right: c.right_expr.clone(),
720                })
721                .collect();
722
723            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
724                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
725                conditions: join_conditions,
726                shared_variables,
727            }));
728        }
729
730        // Fall through to DPccp for binary join ordering
731        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
732        let plan = dpccp.optimize()?;
733
734        Some(plan.operator)
735    }
736
737    /// Pushes filters down the operator tree.
738    ///
739    /// This optimization moves filter predicates as close to the data source
740    /// as possible to reduce the amount of data processed by upper operators.
741    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
742        match op {
743            // For Filter operators, try to push the predicate into the child
744            LogicalOperator::Filter(filter) => {
745                let optimized_input = self.push_filters_down(*filter.input);
746                self.try_push_filter_into(filter.predicate, optimized_input)
747            }
748            // Recursively optimize children for other operators
749            LogicalOperator::Return(mut ret) => {
750                ret.input = Box::new(self.push_filters_down(*ret.input));
751                LogicalOperator::Return(ret)
752            }
753            LogicalOperator::Project(mut proj) => {
754                proj.input = Box::new(self.push_filters_down(*proj.input));
755                LogicalOperator::Project(proj)
756            }
757            LogicalOperator::Limit(mut limit) => {
758                limit.input = Box::new(self.push_filters_down(*limit.input));
759                LogicalOperator::Limit(limit)
760            }
761            LogicalOperator::Skip(mut skip) => {
762                skip.input = Box::new(self.push_filters_down(*skip.input));
763                LogicalOperator::Skip(skip)
764            }
765            LogicalOperator::Sort(mut sort) => {
766                sort.input = Box::new(self.push_filters_down(*sort.input));
767                LogicalOperator::Sort(sort)
768            }
769            LogicalOperator::Distinct(mut distinct) => {
770                distinct.input = Box::new(self.push_filters_down(*distinct.input));
771                LogicalOperator::Distinct(distinct)
772            }
773            LogicalOperator::Expand(mut expand) => {
774                expand.input = Box::new(self.push_filters_down(*expand.input));
775                LogicalOperator::Expand(expand)
776            }
777            LogicalOperator::Join(mut join) => {
778                join.left = Box::new(self.push_filters_down(*join.left));
779                join.right = Box::new(self.push_filters_down(*join.right));
780                LogicalOperator::Join(join)
781            }
782            LogicalOperator::Aggregate(mut agg) => {
783                agg.input = Box::new(self.push_filters_down(*agg.input));
784                LogicalOperator::Aggregate(agg)
785            }
786            LogicalOperator::MapCollect(mut mc) => {
787                mc.input = Box::new(self.push_filters_down(*mc.input));
788                LogicalOperator::MapCollect(mc)
789            }
790            LogicalOperator::MultiWayJoin(mut mwj) => {
791                mwj.inputs = mwj
792                    .inputs
793                    .into_iter()
794                    .map(|input| self.push_filters_down(input))
795                    .collect();
796                LogicalOperator::MultiWayJoin(mwj)
797            }
798            // Leaf operators and unsupported operators are returned as-is
799            other => other,
800        }
801    }
802
803    /// Tries to push a filter predicate into the given operator.
804    ///
805    /// Returns either the predicate pushed into the operator, or a new
806    /// Filter operator on top if the predicate cannot be pushed further.
807    fn try_push_filter_into(
808        &self,
809        predicate: LogicalExpression,
810        op: LogicalOperator,
811    ) -> LogicalOperator {
812        match op {
813            // Can push through Project if predicate doesn't depend on computed columns
814            LogicalOperator::Project(mut proj) => {
815                let predicate_vars = self.extract_variables(&predicate);
816                let computed_vars = self.extract_projection_aliases(&proj.projections);
817
818                // If predicate doesn't use any computed columns, push through
819                if predicate_vars.is_disjoint(&computed_vars) {
820                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
821                    LogicalOperator::Project(proj)
822                } else {
823                    // Can't push through, keep filter on top
824                    LogicalOperator::Filter(FilterOp {
825                        predicate,
826                        pushdown_hint: None,
827                        input: Box::new(LogicalOperator::Project(proj)),
828                    })
829                }
830            }
831
832            // Can push through Return (which is like a projection)
833            LogicalOperator::Return(mut ret) => {
834                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
835                LogicalOperator::Return(ret)
836            }
837
838            // Can push through Expand if predicate doesn't use variables introduced by this expand
839            LogicalOperator::Expand(mut expand) => {
840                let predicate_vars = self.extract_variables(&predicate);
841
842                // Variables introduced by this expand are:
843                // - The target variable (to_variable)
844                // - The edge variable (if any)
845                // - The path alias (if any)
846                let mut introduced_vars = vec![&expand.to_variable];
847                if let Some(ref edge_var) = expand.edge_variable {
848                    introduced_vars.push(edge_var);
849                }
850                if let Some(ref path_alias) = expand.path_alias {
851                    introduced_vars.push(path_alias);
852                }
853
854                // Check if predicate uses any variables introduced by this expand
855                let uses_introduced_vars =
856                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
857
858                if !uses_introduced_vars {
859                    // Predicate doesn't use vars from this expand, so push through
860                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
861                    LogicalOperator::Expand(expand)
862                } else {
863                    // Keep filter after expand
864                    LogicalOperator::Filter(FilterOp {
865                        predicate,
866                        pushdown_hint: None,
867                        input: Box::new(LogicalOperator::Expand(expand)),
868                    })
869                }
870            }
871
872            // Can push through Join to left/right side based on variables used
873            LogicalOperator::Join(mut join) => {
874                let predicate_vars = self.extract_variables(&predicate);
875                let left_vars = self.collect_output_variables(&join.left);
876                let right_vars = self.collect_output_variables(&join.right);
877
878                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
879                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
880
881                if uses_left && !uses_right {
882                    // Push to left side
883                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
884                    LogicalOperator::Join(join)
885                } else if uses_right && !uses_left {
886                    // Push to right side
887                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
888                    LogicalOperator::Join(join)
889                } else {
890                    // Uses both sides - keep above join
891                    LogicalOperator::Filter(FilterOp {
892                        predicate,
893                        pushdown_hint: None,
894                        input: Box::new(LogicalOperator::Join(join)),
895                    })
896                }
897            }
898
899            // Cannot push through Aggregate (predicate refers to aggregated values)
900            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
901                predicate,
902                pushdown_hint: None,
903                input: Box::new(LogicalOperator::Aggregate(agg)),
904            }),
905
906            // For NodeScan, we've reached the bottom - keep filter on top
907            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
908                predicate,
909                pushdown_hint: None,
910                input: Box::new(LogicalOperator::NodeScan(scan)),
911            }),
912
913            // For other operators, keep filter on top
914            other => LogicalOperator::Filter(FilterOp {
915                predicate,
916                pushdown_hint: None,
917                input: Box::new(other),
918            }),
919        }
920    }
921
922    /// Collects all output variable names from an operator.
923    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
924        let mut vars = HashSet::new();
925        Self::collect_output_variables_recursive(op, &mut vars);
926        vars
927    }
928
929    /// Recursively collects output variables from an operator.
930    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
931        match op {
932            LogicalOperator::NodeScan(scan) => {
933                vars.insert(scan.variable.clone());
934            }
935            LogicalOperator::EdgeScan(scan) => {
936                vars.insert(scan.variable.clone());
937            }
938            LogicalOperator::Expand(expand) => {
939                vars.insert(expand.to_variable.clone());
940                if let Some(edge_var) = &expand.edge_variable {
941                    vars.insert(edge_var.clone());
942                }
943                Self::collect_output_variables_recursive(&expand.input, vars);
944            }
945            LogicalOperator::Filter(filter) => {
946                Self::collect_output_variables_recursive(&filter.input, vars);
947            }
948            LogicalOperator::Project(proj) => {
949                for p in &proj.projections {
950                    if let Some(alias) = &p.alias {
951                        vars.insert(alias.clone());
952                    }
953                }
954                Self::collect_output_variables_recursive(&proj.input, vars);
955            }
956            LogicalOperator::Join(join) => {
957                Self::collect_output_variables_recursive(&join.left, vars);
958                Self::collect_output_variables_recursive(&join.right, vars);
959            }
960            LogicalOperator::Aggregate(agg) => {
961                for expr in &agg.group_by {
962                    Self::collect_variables(expr, vars);
963                }
964                for agg_expr in &agg.aggregates {
965                    if let Some(alias) = &agg_expr.alias {
966                        vars.insert(alias.clone());
967                    }
968                }
969            }
970            LogicalOperator::Return(ret) => {
971                Self::collect_output_variables_recursive(&ret.input, vars);
972            }
973            LogicalOperator::Limit(limit) => {
974                Self::collect_output_variables_recursive(&limit.input, vars);
975            }
976            LogicalOperator::Skip(skip) => {
977                Self::collect_output_variables_recursive(&skip.input, vars);
978            }
979            LogicalOperator::Sort(sort) => {
980                Self::collect_output_variables_recursive(&sort.input, vars);
981            }
982            LogicalOperator::Distinct(distinct) => {
983                Self::collect_output_variables_recursive(&distinct.input, vars);
984            }
985            _ => {}
986        }
987    }
988
989    /// Extracts all variable names referenced in an expression.
990    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
991        let mut vars = HashSet::new();
992        Self::collect_variables(expr, &mut vars);
993        vars
994    }
995
996    /// Recursively collects variable names from an expression.
997    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
998        match expr {
999            LogicalExpression::Variable(name) => {
1000                vars.insert(name.clone());
1001            }
1002            LogicalExpression::Property { variable, .. } => {
1003                vars.insert(variable.clone());
1004            }
1005            LogicalExpression::Binary { left, right, .. } => {
1006                Self::collect_variables(left, vars);
1007                Self::collect_variables(right, vars);
1008            }
1009            LogicalExpression::Unary { operand, .. } => {
1010                Self::collect_variables(operand, vars);
1011            }
1012            LogicalExpression::FunctionCall { args, .. } => {
1013                for arg in args {
1014                    Self::collect_variables(arg, vars);
1015                }
1016            }
1017            LogicalExpression::List(items) => {
1018                for item in items {
1019                    Self::collect_variables(item, vars);
1020                }
1021            }
1022            LogicalExpression::Map(pairs) => {
1023                for (_, value) in pairs {
1024                    Self::collect_variables(value, vars);
1025                }
1026            }
1027            LogicalExpression::IndexAccess { base, index } => {
1028                Self::collect_variables(base, vars);
1029                Self::collect_variables(index, vars);
1030            }
1031            LogicalExpression::SliceAccess { base, start, end } => {
1032                Self::collect_variables(base, vars);
1033                if let Some(s) = start {
1034                    Self::collect_variables(s, vars);
1035                }
1036                if let Some(e) = end {
1037                    Self::collect_variables(e, vars);
1038                }
1039            }
1040            LogicalExpression::Case {
1041                operand,
1042                when_clauses,
1043                else_clause,
1044            } => {
1045                if let Some(op) = operand {
1046                    Self::collect_variables(op, vars);
1047                }
1048                for (cond, result) in when_clauses {
1049                    Self::collect_variables(cond, vars);
1050                    Self::collect_variables(result, vars);
1051                }
1052                if let Some(else_expr) = else_clause {
1053                    Self::collect_variables(else_expr, vars);
1054                }
1055            }
1056            LogicalExpression::Labels(var)
1057            | LogicalExpression::Type(var)
1058            | LogicalExpression::Id(var) => {
1059                vars.insert(var.clone());
1060            }
1061            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1062            LogicalExpression::ListComprehension {
1063                list_expr,
1064                filter_expr,
1065                map_expr,
1066                ..
1067            } => {
1068                Self::collect_variables(list_expr, vars);
1069                if let Some(filter) = filter_expr {
1070                    Self::collect_variables(filter, vars);
1071                }
1072                Self::collect_variables(map_expr, vars);
1073            }
1074            LogicalExpression::ListPredicate {
1075                list_expr,
1076                predicate,
1077                ..
1078            } => {
1079                Self::collect_variables(list_expr, vars);
1080                Self::collect_variables(predicate, vars);
1081            }
1082            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1083                // Subqueries have their own variable scope
1084            }
1085            LogicalExpression::PatternComprehension { projection, .. } => {
1086                Self::collect_variables(projection, vars);
1087            }
1088            LogicalExpression::MapProjection { base, entries } => {
1089                vars.insert(base.clone());
1090                for entry in entries {
1091                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1092                        Self::collect_variables(expr, vars);
1093                    }
1094                }
1095            }
1096            LogicalExpression::Reduce {
1097                initial,
1098                list,
1099                expression,
1100                ..
1101            } => {
1102                Self::collect_variables(initial, vars);
1103                Self::collect_variables(list, vars);
1104                Self::collect_variables(expression, vars);
1105            }
1106        }
1107    }
1108
1109    /// Extracts aliases from projection expressions.
1110    fn extract_projection_aliases(
1111        &self,
1112        projections: &[crate::query::plan::Projection],
1113    ) -> HashSet<String> {
1114        projections.iter().filter_map(|p| p.alias.clone()).collect()
1115    }
1116}
1117
1118impl Default for Optimizer {
1119    fn default() -> Self {
1120        Self::new()
1121    }
1122}
1123
1124#[cfg(test)]
1125mod tests {
1126    use super::*;
1127    use crate::query::plan::{
1128        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1129        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1130        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1131    };
1132    use grafeo_common::types::Value;
1133
#[test]
fn test_optimizer_filter_pushdown_simple() {
    // MATCH (n:Person) WHERE n.age > 30 RETURN n
    // Input plan:    Return -> Filter -> NodeScan
    // Expected plan: Return -> Filter -> NodeScan (the filter already sits on the scan)
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filtered_scan = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filtered_scan),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Walk the result one level at a time; any deviation is a failure.
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected Return -> Filter -> NodeScan structure");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("Expected Return -> Filter -> NodeScan structure");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("Expected Return -> Filter -> NodeScan structure");
    };
    assert_eq!(scan.variable, "n");
}
1177
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
    // MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
    // The predicate only reads `a`, so it belongs underneath the expand.
    let scan_a = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(scan_a),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    let a_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filter_over_expand = LogicalOperator::Filter(FilterOp {
        predicate: a_age_over_30,
        pushdown_hint: None,
        input: Box::new(knows),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter_over_expand),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Expected shape: Return -> Expand -> Filter -> NodeScan
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1235
#[test]
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
    // MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
    // `b` only exists once the expand has run, so the filter must stay above it.
    let scan_a = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(scan_a),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    let b_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filter_over_expand = LogicalOperator::Filter(FilterOp {
        predicate: b_age_over_30,
        pushdown_hint: None,
        input: Box::new(knows),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("a".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter_over_expand),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Expected shape: Return -> Filter -> Expand -> NodeScan
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    };

    // When the predicate has the `<property> <op> <rhs>` shape, it must be about `b`.
    if let LogicalExpression::Binary { left, .. } = &filter.predicate {
        if let LogicalExpression::Property { variable, .. } = left.as_ref() {
            assert_eq!(variable, "b");
        }
    }

    let LogicalOperator::Expand(expand) = filter.input.as_ref() else {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    };
    if !matches!(expand.input.as_ref(), LogicalOperator::NodeScan(_)) {
        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
    }
}
1299
#[test]
fn test_optimizer_extract_variables() {
    // n.age > 30 mentions exactly one variable: `n`.
    let age_of_n = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let thirty = LogicalExpression::Literal(Value::Int64(30));
    let comparison = LogicalExpression::Binary {
        left: Box::new(age_of_n),
        op: BinaryOp::Gt,
        right: Box::new(thirty),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&comparison);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1317
1318    // Additional tests for optimizer configuration
1319
#[test]
fn test_optimizer_default() {
    // A default-constructed optimizer must cope with an empty plan.
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let optimizer = Optimizer::default();

    let outcome = optimizer.optimize(empty_plan);

    assert!(outcome.is_ok());
}
1328
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    // Plan: Return -> Filter(true) -> NodeScan
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let optimized = optimizer.optimize(plan).unwrap();

    // With pushdown switched off the Filter is still Return's direct input.
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected unchanged structure");
    };
    if !matches!(ret.input.as_ref(), LogicalOperator::Filter(_)) {
        panic!("Expected unchanged structure");
    }
}
1359
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    // Turning join reordering off must not prevent optimizing an empty plan.
    let optimizer = Optimizer::new().with_join_reorder(false);
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);

    let outcome = optimizer.optimize(empty_plan);

    assert!(outcome.is_ok());
}
1369
#[test]
fn test_optimizer_with_cost_model() {
    // The installed cost model is reachable through the accessor and
    // reports a (near-)zero total for the empty operator at cardinality 0.
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());

    let empty_cost = optimizer
        .cost_model()
        .estimate(&LogicalOperator::Empty, 0.0);

    assert!(empty_cost.total() < 0.001);
}
1382
#[test]
fn test_optimizer_with_cardinality_estimator() {
    // Register statistics for label `Test` and hand the estimator over.
    let mut stats_source = CardinalityEstimator::new();
    stats_source.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(stats_source);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    // A scan over `Test` must be estimated at the registered 500 rows.
    let estimated = optimizer.estimate_cardinality(&plan);
    let deviation = (estimated - 500.0).abs();
    assert!(deviation < 0.001);
}
1399
#[test]
fn test_optimizer_estimate_cost() {
    // An unlabeled node scan must be assigned a strictly positive cost.
    let unlabeled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(unlabeled_scan);
    let optimizer = Optimizer::new();

    let estimated = optimizer.estimate_cost(&plan);

    assert!(estimated.total() > 0.0);
}
1412
1413    // Filter pushdown through various operators
1414
#[test]
fn test_filter_pushdown_through_project() {
    // Plan: Filter(n.age > 30) -> Project(n) -> NodeScan(n)
    // The projection defines no alias, so the filter may sink below it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project_n = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        pushdown_hint: None,
        input: Box::new(project_n),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Filter should be pushed through Project
    let LogicalOperator::Project(proj) = &optimized.root else {
        panic!("Expected Project -> Filter structure");
    };
    if !matches!(proj.input.as_ref(), LogicalOperator::Filter(_)) {
        panic!("Expected Project -> Filter structure");
    }
}
1452
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    // Plan: Filter(x > 30) -> Project(n.age AS x) -> NodeScan(n)
    // `x` is created by the projection, so the filter cannot move below it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project_age_as_x = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
    });
    let x_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_over_30,
        pushdown_hint: None,
        input: Box::new(project_age_as_x),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Filter should stay above Project
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Project structure");
    };
    if !matches!(filter.input.as_ref(), LogicalOperator::Project(_)) {
        panic!("Expected Filter -> Project structure");
    }
}
1491
1492    #[test]
1493    fn test_filter_pushdown_through_limit() {
1494        let optimizer = Optimizer::new();
1495
1496        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1497            predicate: LogicalExpression::Literal(Value::Bool(true)),
1498            pushdown_hint: None,
1499            input: Box::new(LogicalOperator::Limit(LimitOp {
1500                count: 10.into(),
1501                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1502                    variable: "n".to_string(),
1503                    label: None,
1504                    input: None,
1505                })),
1506            })),
1507        }));
1508
1509        let optimized = optimizer.optimize(plan).unwrap();
1510
1511        // Filter stays above Limit (cannot be pushed through)
1512        if let LogicalOperator::Filter(filter) = &optimized.root
1513            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1514        {
1515            return;
1516        }
1517        panic!("Expected Filter -> Limit structure");
1518    }
1519
1520    #[test]
1521    fn test_filter_pushdown_through_sort() {
1522        let optimizer = Optimizer::new();
1523
1524        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1525            predicate: LogicalExpression::Literal(Value::Bool(true)),
1526            pushdown_hint: None,
1527            input: Box::new(LogicalOperator::Sort(SortOp {
1528                keys: vec![SortKey {
1529                    expression: LogicalExpression::Variable("n".to_string()),
1530                    order: SortOrder::Ascending,
1531                    nulls: None,
1532                }],
1533                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1534                    variable: "n".to_string(),
1535                    label: None,
1536                    input: None,
1537                })),
1538            })),
1539        }));
1540
1541        let optimized = optimizer.optimize(plan).unwrap();
1542
1543        // Filter stays above Sort
1544        if let LogicalOperator::Filter(filter) = &optimized.root
1545            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1546        {
1547            return;
1548        }
1549        panic!("Expected Filter -> Sort structure");
1550    }
1551
1552    #[test]
1553    fn test_filter_pushdown_through_distinct() {
1554        let optimizer = Optimizer::new();
1555
1556        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1557            predicate: LogicalExpression::Literal(Value::Bool(true)),
1558            pushdown_hint: None,
1559            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1560                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1561                    variable: "n".to_string(),
1562                    label: None,
1563                    input: None,
1564                })),
1565                columns: None,
1566            })),
1567        }));
1568
1569        let optimized = optimizer.optimize(plan).unwrap();
1570
1571        // Filter stays above Distinct
1572        if let LogicalOperator::Filter(filter) = &optimized.root
1573            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1574        {
1575            return;
1576        }
1577        panic!("Expected Filter -> Distinct structure");
1578    }
1579
1580    #[test]
1581    fn test_filter_not_pushed_through_aggregate() {
1582        let optimizer = Optimizer::new();
1583
1584        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1585            predicate: LogicalExpression::Binary {
1586                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1587                op: BinaryOp::Gt,
1588                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1589            },
1590            pushdown_hint: None,
1591            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1592                group_by: vec![],
1593                aggregates: vec![AggregateExpr {
1594                    function: AggregateFunction::Count,
1595                    expression: None,
1596                    expression2: None,
1597                    distinct: false,
1598                    alias: Some("cnt".to_string()),
1599                    percentile: None,
1600                    separator: None,
1601                }],
1602                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1603                    variable: "n".to_string(),
1604                    label: None,
1605                    input: None,
1606                })),
1607                having: None,
1608            })),
1609        }));
1610
1611        let optimized = optimizer.optimize(plan).unwrap();
1612
1613        // Filter should stay above Aggregate
1614        if let LogicalOperator::Filter(filter) = &optimized.root
1615            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1616        {
1617            return;
1618        }
1619        panic!("Expected Filter -> Aggregate structure");
1620    }
1621
1622    #[test]
1623    fn test_filter_pushdown_to_left_join_side() {
1624        let optimizer = Optimizer::new();
1625
1626        // Filter on left variable should be pushed to left side
1627        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1628            predicate: LogicalExpression::Binary {
1629                left: Box::new(LogicalExpression::Property {
1630                    variable: "a".to_string(),
1631                    property: "age".to_string(),
1632                }),
1633                op: BinaryOp::Gt,
1634                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1635            },
1636            pushdown_hint: None,
1637            input: Box::new(LogicalOperator::Join(JoinOp {
1638                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1639                    variable: "a".to_string(),
1640                    label: Some("Person".to_string()),
1641                    input: None,
1642                })),
1643                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1644                    variable: "b".to_string(),
1645                    label: Some("Company".to_string()),
1646                    input: None,
1647                })),
1648                join_type: JoinType::Inner,
1649                conditions: vec![],
1650            })),
1651        }));
1652
1653        let optimized = optimizer.optimize(plan).unwrap();
1654
1655        // Filter should be pushed to left side of join
1656        if let LogicalOperator::Join(join) = &optimized.root
1657            && let LogicalOperator::Filter(_) = join.left.as_ref()
1658        {
1659            return;
1660        }
1661        panic!("Expected Join with Filter on left side");
1662    }
1663
1664    #[test]
1665    fn test_filter_pushdown_to_right_join_side() {
1666        let optimizer = Optimizer::new();
1667
1668        // Filter on right variable should be pushed to right side
1669        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1670            predicate: LogicalExpression::Binary {
1671                left: Box::new(LogicalExpression::Property {
1672                    variable: "b".to_string(),
1673                    property: "name".to_string(),
1674                }),
1675                op: BinaryOp::Eq,
1676                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1677            },
1678            pushdown_hint: None,
1679            input: Box::new(LogicalOperator::Join(JoinOp {
1680                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1681                    variable: "a".to_string(),
1682                    label: Some("Person".to_string()),
1683                    input: None,
1684                })),
1685                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1686                    variable: "b".to_string(),
1687                    label: Some("Company".to_string()),
1688                    input: None,
1689                })),
1690                join_type: JoinType::Inner,
1691                conditions: vec![],
1692            })),
1693        }));
1694
1695        let optimized = optimizer.optimize(plan).unwrap();
1696
1697        // Filter should be pushed to right side of join
1698        if let LogicalOperator::Join(join) = &optimized.root
1699            && let LogicalOperator::Filter(_) = join.right.as_ref()
1700        {
1701            return;
1702        }
1703        panic!("Expected Join with Filter on right side");
1704    }
1705
1706    #[test]
1707    fn test_filter_not_pushed_when_uses_both_join_sides() {
1708        let optimizer = Optimizer::new();
1709
1710        // Filter using both variables should stay above join
1711        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1712            predicate: LogicalExpression::Binary {
1713                left: Box::new(LogicalExpression::Property {
1714                    variable: "a".to_string(),
1715                    property: "id".to_string(),
1716                }),
1717                op: BinaryOp::Eq,
1718                right: Box::new(LogicalExpression::Property {
1719                    variable: "b".to_string(),
1720                    property: "a_id".to_string(),
1721                }),
1722            },
1723            pushdown_hint: None,
1724            input: Box::new(LogicalOperator::Join(JoinOp {
1725                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1726                    variable: "a".to_string(),
1727                    label: None,
1728                    input: None,
1729                })),
1730                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1731                    variable: "b".to_string(),
1732                    label: None,
1733                    input: None,
1734                })),
1735                join_type: JoinType::Inner,
1736                conditions: vec![],
1737            })),
1738        }));
1739
1740        let optimized = optimizer.optimize(plan).unwrap();
1741
1742        // Filter should stay above join
1743        if let LogicalOperator::Filter(filter) = &optimized.root
1744            && let LogicalOperator::Join(_) = filter.input.as_ref()
1745        {
1746            return;
1747        }
1748        panic!("Expected Filter -> Join structure");
1749    }
1750
1751    // Variable extraction tests
1752
1753    #[test]
1754    fn test_extract_variables_from_variable() {
1755        let optimizer = Optimizer::new();
1756        let expr = LogicalExpression::Variable("x".to_string());
1757        let vars = optimizer.extract_variables(&expr);
1758        assert_eq!(vars.len(), 1);
1759        assert!(vars.contains("x"));
1760    }
1761
1762    #[test]
1763    fn test_extract_variables_from_unary() {
1764        let optimizer = Optimizer::new();
1765        let expr = LogicalExpression::Unary {
1766            op: UnaryOp::Not,
1767            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1768        };
1769        let vars = optimizer.extract_variables(&expr);
1770        assert_eq!(vars.len(), 1);
1771        assert!(vars.contains("x"));
1772    }
1773
1774    #[test]
1775    fn test_extract_variables_from_function_call() {
1776        let optimizer = Optimizer::new();
1777        let expr = LogicalExpression::FunctionCall {
1778            name: "length".to_string(),
1779            args: vec![
1780                LogicalExpression::Variable("a".to_string()),
1781                LogicalExpression::Variable("b".to_string()),
1782            ],
1783            distinct: false,
1784        };
1785        let vars = optimizer.extract_variables(&expr);
1786        assert_eq!(vars.len(), 2);
1787        assert!(vars.contains("a"));
1788        assert!(vars.contains("b"));
1789    }
1790
1791    #[test]
1792    fn test_extract_variables_from_list() {
1793        let optimizer = Optimizer::new();
1794        let expr = LogicalExpression::List(vec![
1795            LogicalExpression::Variable("a".to_string()),
1796            LogicalExpression::Literal(Value::Int64(1)),
1797            LogicalExpression::Variable("b".to_string()),
1798        ]);
1799        let vars = optimizer.extract_variables(&expr);
1800        assert_eq!(vars.len(), 2);
1801        assert!(vars.contains("a"));
1802        assert!(vars.contains("b"));
1803    }
1804
1805    #[test]
1806    fn test_extract_variables_from_map() {
1807        let optimizer = Optimizer::new();
1808        let expr = LogicalExpression::Map(vec![
1809            (
1810                "key1".to_string(),
1811                LogicalExpression::Variable("a".to_string()),
1812            ),
1813            (
1814                "key2".to_string(),
1815                LogicalExpression::Variable("b".to_string()),
1816            ),
1817        ]);
1818        let vars = optimizer.extract_variables(&expr);
1819        assert_eq!(vars.len(), 2);
1820        assert!(vars.contains("a"));
1821        assert!(vars.contains("b"));
1822    }
1823
1824    #[test]
1825    fn test_extract_variables_from_index_access() {
1826        let optimizer = Optimizer::new();
1827        let expr = LogicalExpression::IndexAccess {
1828            base: Box::new(LogicalExpression::Variable("list".to_string())),
1829            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1830        };
1831        let vars = optimizer.extract_variables(&expr);
1832        assert_eq!(vars.len(), 2);
1833        assert!(vars.contains("list"));
1834        assert!(vars.contains("idx"));
1835    }
1836
1837    #[test]
1838    fn test_extract_variables_from_slice_access() {
1839        let optimizer = Optimizer::new();
1840        let expr = LogicalExpression::SliceAccess {
1841            base: Box::new(LogicalExpression::Variable("list".to_string())),
1842            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1843            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1844        };
1845        let vars = optimizer.extract_variables(&expr);
1846        assert_eq!(vars.len(), 3);
1847        assert!(vars.contains("list"));
1848        assert!(vars.contains("s"));
1849        assert!(vars.contains("e"));
1850    }
1851
1852    #[test]
1853    fn test_extract_variables_from_case() {
1854        let optimizer = Optimizer::new();
1855        let expr = LogicalExpression::Case {
1856            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1857            when_clauses: vec![(
1858                LogicalExpression::Literal(Value::Int64(1)),
1859                LogicalExpression::Variable("a".to_string()),
1860            )],
1861            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1862        };
1863        let vars = optimizer.extract_variables(&expr);
1864        assert_eq!(vars.len(), 3);
1865        assert!(vars.contains("x"));
1866        assert!(vars.contains("a"));
1867        assert!(vars.contains("b"));
1868    }
1869
1870    #[test]
1871    fn test_extract_variables_from_labels() {
1872        let optimizer = Optimizer::new();
1873        let expr = LogicalExpression::Labels("n".to_string());
1874        let vars = optimizer.extract_variables(&expr);
1875        assert_eq!(vars.len(), 1);
1876        assert!(vars.contains("n"));
1877    }
1878
1879    #[test]
1880    fn test_extract_variables_from_type() {
1881        let optimizer = Optimizer::new();
1882        let expr = LogicalExpression::Type("e".to_string());
1883        let vars = optimizer.extract_variables(&expr);
1884        assert_eq!(vars.len(), 1);
1885        assert!(vars.contains("e"));
1886    }
1887
1888    #[test]
1889    fn test_extract_variables_from_id() {
1890        let optimizer = Optimizer::new();
1891        let expr = LogicalExpression::Id("n".to_string());
1892        let vars = optimizer.extract_variables(&expr);
1893        assert_eq!(vars.len(), 1);
1894        assert!(vars.contains("n"));
1895    }
1896
1897    #[test]
1898    fn test_extract_variables_from_list_comprehension() {
1899        let optimizer = Optimizer::new();
1900        let expr = LogicalExpression::ListComprehension {
1901            variable: "x".to_string(),
1902            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1903            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1904            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1905        };
1906        let vars = optimizer.extract_variables(&expr);
1907        assert!(vars.contains("items"));
1908        assert!(vars.contains("pred"));
1909        assert!(vars.contains("result"));
1910    }
1911
1912    #[test]
1913    fn test_extract_variables_from_literal_and_parameter() {
1914        let optimizer = Optimizer::new();
1915
1916        let literal = LogicalExpression::Literal(Value::Int64(42));
1917        assert!(optimizer.extract_variables(&literal).is_empty());
1918
1919        let param = LogicalExpression::Parameter("p".to_string());
1920        assert!(optimizer.extract_variables(&param).is_empty());
1921    }
1922
1923    // Recursive filter pushdown tests
1924
1925    #[test]
1926    fn test_recursive_filter_pushdown_through_skip() {
1927        let optimizer = Optimizer::new();
1928
1929        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1930            items: vec![ReturnItem {
1931                expression: LogicalExpression::Variable("n".to_string()),
1932                alias: None,
1933            }],
1934            distinct: false,
1935            input: Box::new(LogicalOperator::Filter(FilterOp {
1936                predicate: LogicalExpression::Literal(Value::Bool(true)),
1937                pushdown_hint: None,
1938                input: Box::new(LogicalOperator::Skip(SkipOp {
1939                    count: 5.into(),
1940                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1941                        variable: "n".to_string(),
1942                        label: None,
1943                        input: None,
1944                    })),
1945                })),
1946            })),
1947        }));
1948
1949        let optimized = optimizer.optimize(plan).unwrap();
1950
1951        // Verify optimization succeeded
1952        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1953    }
1954
1955    #[test]
1956    fn test_nested_filter_pushdown() {
1957        let optimizer = Optimizer::new();
1958
1959        // Multiple nested filters
1960        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1961            items: vec![ReturnItem {
1962                expression: LogicalExpression::Variable("n".to_string()),
1963                alias: None,
1964            }],
1965            distinct: false,
1966            input: Box::new(LogicalOperator::Filter(FilterOp {
1967                predicate: LogicalExpression::Binary {
1968                    left: Box::new(LogicalExpression::Property {
1969                        variable: "n".to_string(),
1970                        property: "x".to_string(),
1971                    }),
1972                    op: BinaryOp::Gt,
1973                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1974                },
1975                pushdown_hint: None,
1976                input: Box::new(LogicalOperator::Filter(FilterOp {
1977                    predicate: LogicalExpression::Binary {
1978                        left: Box::new(LogicalExpression::Property {
1979                            variable: "n".to_string(),
1980                            property: "y".to_string(),
1981                        }),
1982                        op: BinaryOp::Lt,
1983                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1984                    },
1985                    pushdown_hint: None,
1986                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1987                        variable: "n".to_string(),
1988                        label: None,
1989                        input: None,
1990                    })),
1991                })),
1992            })),
1993        }));
1994
1995        let optimized = optimizer.optimize(plan).unwrap();
1996        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1997    }
1998
1999    #[test]
2000    fn test_cyclic_join_produces_multi_way_join() {
2001        use crate::query::plan::JoinCondition;
2002
2003        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2004        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2005            variable: "a".to_string(),
2006            label: Some("Person".to_string()),
2007            input: None,
2008        });
2009        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2010            variable: "b".to_string(),
2011            label: Some("Person".to_string()),
2012            input: None,
2013        });
2014        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2015            variable: "c".to_string(),
2016            label: Some("Person".to_string()),
2017            input: None,
2018        });
2019
2020        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2021        let join_ab = LogicalOperator::Join(JoinOp {
2022            left: Box::new(scan_a),
2023            right: Box::new(scan_b),
2024            join_type: JoinType::Inner,
2025            conditions: vec![JoinCondition {
2026                left: LogicalExpression::Variable("a".to_string()),
2027                right: LogicalExpression::Variable("b".to_string()),
2028            }],
2029        });
2030
2031        let join_abc = LogicalOperator::Join(JoinOp {
2032            left: Box::new(join_ab),
2033            right: Box::new(scan_c),
2034            join_type: JoinType::Inner,
2035            conditions: vec![
2036                JoinCondition {
2037                    left: LogicalExpression::Variable("b".to_string()),
2038                    right: LogicalExpression::Variable("c".to_string()),
2039                },
2040                JoinCondition {
2041                    left: LogicalExpression::Variable("c".to_string()),
2042                    right: LogicalExpression::Variable("a".to_string()),
2043                },
2044            ],
2045        });
2046
2047        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2048            items: vec![ReturnItem {
2049                expression: LogicalExpression::Variable("a".to_string()),
2050                alias: None,
2051            }],
2052            distinct: false,
2053            input: Box::new(join_abc),
2054        }));
2055
2056        let mut optimizer = Optimizer::new();
2057        optimizer
2058            .card_estimator
2059            .add_table_stats("Person", cardinality::TableStats::new(1000));
2060
2061        let optimized = optimizer.optimize(plan).unwrap();
2062
2063        // Walk the tree to find a MultiWayJoin
2064        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2065            match op {
2066                LogicalOperator::MultiWayJoin(_) => true,
2067                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2068                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2069                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2070                _ => false,
2071            }
2072        }
2073
2074        assert!(
2075            has_multi_way_join(&optimized.root),
2076            "Expected MultiWayJoin for cyclic triangle pattern"
2077        );
2078    }
2079
2080    #[test]
2081    fn test_acyclic_join_uses_binary_joins() {
2082        use crate::query::plan::JoinCondition;
2083
2084        // Chain: a ⋈ b ⋈ c (acyclic)
2085        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2086            variable: "a".to_string(),
2087            label: Some("Person".to_string()),
2088            input: None,
2089        });
2090        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2091            variable: "b".to_string(),
2092            label: Some("Person".to_string()),
2093            input: None,
2094        });
2095        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2096            variable: "c".to_string(),
2097            label: Some("Company".to_string()),
2098            input: None,
2099        });
2100
2101        let join_ab = LogicalOperator::Join(JoinOp {
2102            left: Box::new(scan_a),
2103            right: Box::new(scan_b),
2104            join_type: JoinType::Inner,
2105            conditions: vec![JoinCondition {
2106                left: LogicalExpression::Variable("a".to_string()),
2107                right: LogicalExpression::Variable("b".to_string()),
2108            }],
2109        });
2110
2111        let join_abc = LogicalOperator::Join(JoinOp {
2112            left: Box::new(join_ab),
2113            right: Box::new(scan_c),
2114            join_type: JoinType::Inner,
2115            conditions: vec![JoinCondition {
2116                left: LogicalExpression::Variable("b".to_string()),
2117                right: LogicalExpression::Variable("c".to_string()),
2118            }],
2119        });
2120
2121        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2122            items: vec![ReturnItem {
2123                expression: LogicalExpression::Variable("a".to_string()),
2124                alias: None,
2125            }],
2126            distinct: false,
2127            input: Box::new(join_abc),
2128        }));
2129
2130        let mut optimizer = Optimizer::new();
2131        optimizer
2132            .card_estimator
2133            .add_table_stats("Person", cardinality::TableStats::new(1000));
2134        optimizer
2135            .card_estimator
2136            .add_table_stats("Company", cardinality::TableStats::new(100));
2137
2138        let optimized = optimizer.optimize(plan).unwrap();
2139
2140        // Should NOT contain MultiWayJoin for acyclic pattern
2141        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2142            match op {
2143                LogicalOperator::MultiWayJoin(_) => true,
2144                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2145                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2146                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2147                LogicalOperator::Join(j) => {
2148                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2149                }
2150                _ => false,
2151            }
2152        }
2153
2154        assert!(
2155            !has_multi_way_join(&optimized.root),
2156            "Acyclic join should NOT produce MultiWayJoin"
2157        );
2158    }
2159}