Skip to main content

grafeo_engine/query/optimizer/
mod.rs

//! Makes your queries faster without changing their meaning.
//!
//! The optimizer transforms logical plans to run more efficiently:
//!
//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
//! | Join Reordering | Picks the best order to join relations using the DPccp algorithm; cyclic patterns of three or more relations become a multi-way (leapfrog) join instead |
//! | Projection Pushdown | Works out which variables and properties the query actually needs |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
//!
//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
30/// Information about a join condition for join reordering.
31#[derive(Debug, Clone)]
32struct JoinInfo {
33    left_var: String,
34    right_var: String,
35    left_expr: LogicalExpression,
36    right_expr: LogicalExpression,
37}
38
/// Something the query reads, tracked during projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding: node, edge, or path variable (by name).
    Variable(String),
    /// A single property, as `(variable, property)`.
    Property(String, String),
}
47
48/// Transforms logical plans for faster execution.
49///
50/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
51/// Use the builder methods to enable/disable specific optimizations.
52pub struct Optimizer {
53    /// Whether to enable filter pushdown.
54    enable_filter_pushdown: bool,
55    /// Whether to enable join reordering.
56    enable_join_reorder: bool,
57    /// Whether to enable projection pushdown.
58    enable_projection_pushdown: bool,
59    /// Cost model for estimation.
60    cost_model: CostModel,
61    /// Cardinality estimator.
62    card_estimator: CardinalityEstimator,
63}
64
65impl Optimizer {
66    /// Creates a new optimizer with default settings.
67    #[must_use]
68    pub fn new() -> Self {
69        Self {
70            enable_filter_pushdown: true,
71            enable_join_reorder: true,
72            enable_projection_pushdown: true,
73            cost_model: CostModel::new(),
74            card_estimator: CardinalityEstimator::new(),
75        }
76    }
77
78    /// Creates an optimizer with cardinality estimates from the store's statistics.
79    ///
80    /// Pre-populates the cardinality estimator with per-label row counts and
81    /// edge type fanout. Feeds per-edge-type degree stats into the cost model
82    /// for accurate expand cost estimation.
83    #[must_use]
84    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85        store.ensure_statistics_fresh();
86        let stats = store.statistics();
87        let estimator = CardinalityEstimator::from_statistics(&stats);
88
89        // Derive average fanout from statistics for the cost model
90        let avg_fanout = if stats.total_nodes > 0 {
91            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
92        } else {
93            10.0
94        };
95
96        // Collect per-edge-type degree stats for accurate expand costing
97        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
98            .edge_types
99            .iter()
100            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
101            .collect();
102
103        Self {
104            enable_filter_pushdown: true,
105            enable_join_reorder: true,
106            enable_projection_pushdown: true,
107            cost_model: CostModel::new()
108                .with_avg_fanout(avg_fanout)
109                .with_edge_type_degrees(edge_type_degrees),
110            card_estimator: estimator,
111        }
112    }
113
114    /// Creates an optimizer from any GraphStore implementation.
115    ///
116    /// Unlike [`from_store`](Self::from_store), this does not call
117    /// `ensure_statistics_fresh()` since external stores manage their own
118    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
119    /// is called directly.
120    #[must_use]
121    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
122        let stats = store.statistics();
123        let estimator = CardinalityEstimator::from_statistics(&stats);
124
125        let avg_fanout = if stats.total_nodes > 0 {
126            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
127        } else {
128            10.0
129        };
130
131        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
132            .edge_types
133            .iter()
134            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
135            .collect();
136
137        Self {
138            enable_filter_pushdown: true,
139            enable_join_reorder: true,
140            enable_projection_pushdown: true,
141            cost_model: CostModel::new()
142                .with_avg_fanout(avg_fanout)
143                .with_edge_type_degrees(edge_type_degrees),
144            card_estimator: estimator,
145        }
146    }
147
148    /// Enables or disables filter pushdown.
149    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
150        self.enable_filter_pushdown = enabled;
151        self
152    }
153
154    /// Enables or disables join reordering.
155    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
156        self.enable_join_reorder = enabled;
157        self
158    }
159
160    /// Enables or disables projection pushdown.
161    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
162        self.enable_projection_pushdown = enabled;
163        self
164    }
165
166    /// Sets the cost model.
167    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
168        self.cost_model = cost_model;
169        self
170    }
171
172    /// Sets the cardinality estimator.
173    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
174        self.card_estimator = estimator;
175        self
176    }
177
178    /// Sets the selectivity configuration for the cardinality estimator.
179    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
180        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
181        self
182    }
183
184    /// Returns a reference to the cost model.
185    pub fn cost_model(&self) -> &CostModel {
186        &self.cost_model
187    }
188
189    /// Returns a reference to the cardinality estimator.
190    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
191        &self.card_estimator
192    }
193
194    /// Estimates the cost of a plan.
195    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
196        let cardinality = self.card_estimator.estimate(&plan.root);
197        self.cost_model.estimate(&plan.root, cardinality)
198    }
199
200    /// Estimates the cardinality of a plan.
201    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
202        self.card_estimator.estimate(&plan.root)
203    }
204
205    /// Optimizes a logical plan.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if optimization fails.
210    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
211        let mut root = plan.root;
212
213        // Apply optimization rules
214        if self.enable_filter_pushdown {
215            root = self.push_filters_down(root);
216        }
217
218        if self.enable_join_reorder {
219            root = self.reorder_joins(root);
220        }
221
222        if self.enable_projection_pushdown {
223            root = self.push_projections_down(root);
224        }
225
226        Ok(LogicalPlan::new(root))
227    }
228
229    /// Pushes projections down the operator tree to eliminate unused columns early.
230    ///
231    /// This optimization:
232    /// 1. Collects required variables/properties from the root
233    /// 2. Propagates requirements down through the tree
234    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
235    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
236        // Collect required columns from the top of the plan
237        let required = self.collect_required_columns(&op);
238
239        // Push projections down
240        self.push_projections_recursive(op, &required)
241    }
242
243    /// Collects all variables and properties required by an operator and its ancestors.
244    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
245        let mut required = HashSet::new();
246        Self::collect_required_recursive(op, &mut required);
247        required
248    }
249
250    /// Recursively collects required columns.
251    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
252        match op {
253            LogicalOperator::Return(ret) => {
254                for item in &ret.items {
255                    Self::collect_from_expression(&item.expression, required);
256                }
257                Self::collect_required_recursive(&ret.input, required);
258            }
259            LogicalOperator::Project(proj) => {
260                for p in &proj.projections {
261                    Self::collect_from_expression(&p.expression, required);
262                }
263                Self::collect_required_recursive(&proj.input, required);
264            }
265            LogicalOperator::Filter(filter) => {
266                Self::collect_from_expression(&filter.predicate, required);
267                Self::collect_required_recursive(&filter.input, required);
268            }
269            LogicalOperator::Sort(sort) => {
270                for key in &sort.keys {
271                    Self::collect_from_expression(&key.expression, required);
272                }
273                Self::collect_required_recursive(&sort.input, required);
274            }
275            LogicalOperator::Aggregate(agg) => {
276                for expr in &agg.group_by {
277                    Self::collect_from_expression(expr, required);
278                }
279                for agg_expr in &agg.aggregates {
280                    if let Some(ref expr) = agg_expr.expression {
281                        Self::collect_from_expression(expr, required);
282                    }
283                }
284                if let Some(ref having) = agg.having {
285                    Self::collect_from_expression(having, required);
286                }
287                Self::collect_required_recursive(&agg.input, required);
288            }
289            LogicalOperator::Join(join) => {
290                for cond in &join.conditions {
291                    Self::collect_from_expression(&cond.left, required);
292                    Self::collect_from_expression(&cond.right, required);
293                }
294                Self::collect_required_recursive(&join.left, required);
295                Self::collect_required_recursive(&join.right, required);
296            }
297            LogicalOperator::Expand(expand) => {
298                // The source and target variables are needed
299                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
300                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
301                if let Some(ref edge_var) = expand.edge_variable {
302                    required.insert(RequiredColumn::Variable(edge_var.clone()));
303                }
304                Self::collect_required_recursive(&expand.input, required);
305            }
306            LogicalOperator::Limit(limit) => {
307                Self::collect_required_recursive(&limit.input, required);
308            }
309            LogicalOperator::Skip(skip) => {
310                Self::collect_required_recursive(&skip.input, required);
311            }
312            LogicalOperator::Distinct(distinct) => {
313                Self::collect_required_recursive(&distinct.input, required);
314            }
315            LogicalOperator::NodeScan(scan) => {
316                required.insert(RequiredColumn::Variable(scan.variable.clone()));
317            }
318            LogicalOperator::EdgeScan(scan) => {
319                required.insert(RequiredColumn::Variable(scan.variable.clone()));
320            }
321            LogicalOperator::MultiWayJoin(mwj) => {
322                for cond in &mwj.conditions {
323                    Self::collect_from_expression(&cond.left, required);
324                    Self::collect_from_expression(&cond.right, required);
325                }
326                for input in &mwj.inputs {
327                    Self::collect_required_recursive(input, required);
328                }
329            }
330            _ => {}
331        }
332    }
333
334    /// Collects required columns from an expression.
335    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
336        match expr {
337            LogicalExpression::Variable(var) => {
338                required.insert(RequiredColumn::Variable(var.clone()));
339            }
340            LogicalExpression::Property { variable, property } => {
341                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
342                required.insert(RequiredColumn::Variable(variable.clone()));
343            }
344            LogicalExpression::Binary { left, right, .. } => {
345                Self::collect_from_expression(left, required);
346                Self::collect_from_expression(right, required);
347            }
348            LogicalExpression::Unary { operand, .. } => {
349                Self::collect_from_expression(operand, required);
350            }
351            LogicalExpression::FunctionCall { args, .. } => {
352                for arg in args {
353                    Self::collect_from_expression(arg, required);
354                }
355            }
356            LogicalExpression::List(items) => {
357                for item in items {
358                    Self::collect_from_expression(item, required);
359                }
360            }
361            LogicalExpression::Map(pairs) => {
362                for (_, value) in pairs {
363                    Self::collect_from_expression(value, required);
364                }
365            }
366            LogicalExpression::IndexAccess { base, index } => {
367                Self::collect_from_expression(base, required);
368                Self::collect_from_expression(index, required);
369            }
370            LogicalExpression::SliceAccess { base, start, end } => {
371                Self::collect_from_expression(base, required);
372                if let Some(s) = start {
373                    Self::collect_from_expression(s, required);
374                }
375                if let Some(e) = end {
376                    Self::collect_from_expression(e, required);
377                }
378            }
379            LogicalExpression::Case {
380                operand,
381                when_clauses,
382                else_clause,
383            } => {
384                if let Some(op) = operand {
385                    Self::collect_from_expression(op, required);
386                }
387                for (cond, result) in when_clauses {
388                    Self::collect_from_expression(cond, required);
389                    Self::collect_from_expression(result, required);
390                }
391                if let Some(else_expr) = else_clause {
392                    Self::collect_from_expression(else_expr, required);
393                }
394            }
395            LogicalExpression::Labels(var)
396            | LogicalExpression::Type(var)
397            | LogicalExpression::Id(var) => {
398                required.insert(RequiredColumn::Variable(var.clone()));
399            }
400            LogicalExpression::ListComprehension {
401                list_expr,
402                filter_expr,
403                map_expr,
404                ..
405            } => {
406                Self::collect_from_expression(list_expr, required);
407                if let Some(filter) = filter_expr {
408                    Self::collect_from_expression(filter, required);
409                }
410                Self::collect_from_expression(map_expr, required);
411            }
412            _ => {}
413        }
414    }
415
416    /// Recursively pushes projections down, adding them before expensive operations.
417    fn push_projections_recursive(
418        &self,
419        op: LogicalOperator,
420        required: &HashSet<RequiredColumn>,
421    ) -> LogicalOperator {
422        match op {
423            LogicalOperator::Return(mut ret) => {
424                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
425                LogicalOperator::Return(ret)
426            }
427            LogicalOperator::Project(mut proj) => {
428                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
429                LogicalOperator::Project(proj)
430            }
431            LogicalOperator::Filter(mut filter) => {
432                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
433                LogicalOperator::Filter(filter)
434            }
435            LogicalOperator::Sort(mut sort) => {
436                // Sort is expensive - consider adding a projection before it
437                // to reduce tuple width
438                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
439                LogicalOperator::Sort(sort)
440            }
441            LogicalOperator::Aggregate(mut agg) => {
442                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
443                LogicalOperator::Aggregate(agg)
444            }
445            LogicalOperator::Join(mut join) => {
446                // Joins are expensive - the required columns help determine
447                // what to project on each side
448                let left_vars = self.collect_output_variables(&join.left);
449                let right_vars = self.collect_output_variables(&join.right);
450
451                // Filter required columns to each side
452                let left_required: HashSet<_> = required
453                    .iter()
454                    .filter(|c| match c {
455                        RequiredColumn::Variable(v) => left_vars.contains(v),
456                        RequiredColumn::Property(v, _) => left_vars.contains(v),
457                    })
458                    .cloned()
459                    .collect();
460
461                let right_required: HashSet<_> = required
462                    .iter()
463                    .filter(|c| match c {
464                        RequiredColumn::Variable(v) => right_vars.contains(v),
465                        RequiredColumn::Property(v, _) => right_vars.contains(v),
466                    })
467                    .cloned()
468                    .collect();
469
470                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
471                join.right =
472                    Box::new(self.push_projections_recursive(*join.right, &right_required));
473                LogicalOperator::Join(join)
474            }
475            LogicalOperator::Expand(mut expand) => {
476                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
477                LogicalOperator::Expand(expand)
478            }
479            LogicalOperator::Limit(mut limit) => {
480                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
481                LogicalOperator::Limit(limit)
482            }
483            LogicalOperator::Skip(mut skip) => {
484                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
485                LogicalOperator::Skip(skip)
486            }
487            LogicalOperator::Distinct(mut distinct) => {
488                distinct.input =
489                    Box::new(self.push_projections_recursive(*distinct.input, required));
490                LogicalOperator::Distinct(distinct)
491            }
492            LogicalOperator::MapCollect(mut mc) => {
493                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
494                LogicalOperator::MapCollect(mc)
495            }
496            LogicalOperator::MultiWayJoin(mut mwj) => {
497                mwj.inputs = mwj
498                    .inputs
499                    .into_iter()
500                    .map(|input| self.push_projections_recursive(input, required))
501                    .collect();
502                LogicalOperator::MultiWayJoin(mwj)
503            }
504            other => other,
505        }
506    }
507
508    /// Reorders joins in the operator tree using the DPccp algorithm.
509    ///
510    /// This optimization finds the optimal join order by:
511    /// 1. Extracting all base relations (scans) and join conditions
512    /// 2. Building a join graph
513    /// 3. Using dynamic programming to find the cheapest join order
514    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
515        // First, recursively optimize children
516        let op = self.reorder_joins_recursive(op);
517
518        // Then, if this is a join tree, try to optimize it
519        if let Some((relations, conditions)) = self.extract_join_tree(&op)
520            && relations.len() >= 2
521            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
522        {
523            return optimized;
524        }
525
526        op
527    }
528
529    /// Recursively applies join reordering to child operators.
530    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
531        match op {
532            LogicalOperator::Return(mut ret) => {
533                ret.input = Box::new(self.reorder_joins(*ret.input));
534                LogicalOperator::Return(ret)
535            }
536            LogicalOperator::Project(mut proj) => {
537                proj.input = Box::new(self.reorder_joins(*proj.input));
538                LogicalOperator::Project(proj)
539            }
540            LogicalOperator::Filter(mut filter) => {
541                filter.input = Box::new(self.reorder_joins(*filter.input));
542                LogicalOperator::Filter(filter)
543            }
544            LogicalOperator::Limit(mut limit) => {
545                limit.input = Box::new(self.reorder_joins(*limit.input));
546                LogicalOperator::Limit(limit)
547            }
548            LogicalOperator::Skip(mut skip) => {
549                skip.input = Box::new(self.reorder_joins(*skip.input));
550                LogicalOperator::Skip(skip)
551            }
552            LogicalOperator::Sort(mut sort) => {
553                sort.input = Box::new(self.reorder_joins(*sort.input));
554                LogicalOperator::Sort(sort)
555            }
556            LogicalOperator::Distinct(mut distinct) => {
557                distinct.input = Box::new(self.reorder_joins(*distinct.input));
558                LogicalOperator::Distinct(distinct)
559            }
560            LogicalOperator::Aggregate(mut agg) => {
561                agg.input = Box::new(self.reorder_joins(*agg.input));
562                LogicalOperator::Aggregate(agg)
563            }
564            LogicalOperator::Expand(mut expand) => {
565                expand.input = Box::new(self.reorder_joins(*expand.input));
566                LogicalOperator::Expand(expand)
567            }
568            LogicalOperator::MapCollect(mut mc) => {
569                mc.input = Box::new(self.reorder_joins(*mc.input));
570                LogicalOperator::MapCollect(mc)
571            }
572            LogicalOperator::MultiWayJoin(mut mwj) => {
573                mwj.inputs = mwj
574                    .inputs
575                    .into_iter()
576                    .map(|input| self.reorder_joins(input))
577                    .collect();
578                LogicalOperator::MultiWayJoin(mwj)
579            }
580            // Join operators are handled by the parent reorder_joins call
581            other => other,
582        }
583    }
584
585    /// Extracts base relations and join conditions from a join tree.
586    ///
587    /// Returns None if the operator is not a join tree.
588    fn extract_join_tree(
589        &self,
590        op: &LogicalOperator,
591    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
592        let mut relations = Vec::new();
593        let mut join_conditions = Vec::new();
594
595        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
596            return None;
597        }
598
599        if relations.len() < 2 {
600            return None;
601        }
602
603        Some((relations, join_conditions))
604    }
605
606    /// Recursively collects base relations and join conditions.
607    ///
608    /// Returns true if this subtree is part of a join tree.
609    fn collect_join_tree(
610        &self,
611        op: &LogicalOperator,
612        relations: &mut Vec<(String, LogicalOperator)>,
613        conditions: &mut Vec<JoinInfo>,
614    ) -> bool {
615        match op {
616            LogicalOperator::Join(join) => {
617                // Collect from both sides
618                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
619                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
620
621                // Add conditions from this join
622                for cond in &join.conditions {
623                    if let (Some(left_var), Some(right_var)) = (
624                        self.extract_variable_from_expr(&cond.left),
625                        self.extract_variable_from_expr(&cond.right),
626                    ) {
627                        conditions.push(JoinInfo {
628                            left_var,
629                            right_var,
630                            left_expr: cond.left.clone(),
631                            right_expr: cond.right.clone(),
632                        });
633                    }
634                }
635
636                left_ok && right_ok
637            }
638            LogicalOperator::NodeScan(scan) => {
639                relations.push((scan.variable.clone(), op.clone()));
640                true
641            }
642            LogicalOperator::EdgeScan(scan) => {
643                relations.push((scan.variable.clone(), op.clone()));
644                true
645            }
646            LogicalOperator::Filter(filter) => {
647                // A filter on a base relation is still part of the join tree
648                self.collect_join_tree(&filter.input, relations, conditions)
649            }
650            LogicalOperator::Expand(expand) => {
651                // Expand is a special case - it's like a join with the adjacency
652                // For now, treat the whole Expand subtree as a single relation
653                relations.push((expand.to_variable.clone(), op.clone()));
654                true
655            }
656            _ => false,
657        }
658    }
659
660    /// Extracts the primary variable from an expression.
661    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
662        match expr {
663            LogicalExpression::Variable(v) => Some(v.clone()),
664            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
665            _ => None,
666        }
667    }
668
669    /// Optimizes the join order using DPccp, or produces a multi-way
670    /// leapfrog join for cyclic patterns when the cost model prefers it.
671    fn optimize_join_order(
672        &self,
673        relations: &[(String, LogicalOperator)],
674        conditions: &[JoinInfo],
675    ) -> Option<LogicalOperator> {
676        use join_order::{DPccp, JoinGraphBuilder};
677
678        // Build the join graph
679        let mut builder = JoinGraphBuilder::new();
680
681        for (var, relation) in relations {
682            builder.add_relation(var, relation.clone());
683        }
684
685        for cond in conditions {
686            builder.add_join_condition(
687                &cond.left_var,
688                &cond.right_var,
689                cond.left_expr.clone(),
690                cond.right_expr.clone(),
691            );
692        }
693
694        let graph = builder.build();
695
696        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
697        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
698        // multi-way intersection rather than binary hash join cascades that can
699        // produce intermediate blowup.
700        if graph.is_cyclic() && relations.len() >= 3 {
701            // Collect shared variables (variables appearing in 2+ conditions)
702            let mut var_counts: std::collections::HashMap<&str, usize> =
703                std::collections::HashMap::new();
704            for cond in conditions {
705                *var_counts.entry(&cond.left_var).or_default() += 1;
706                *var_counts.entry(&cond.right_var).or_default() += 1;
707            }
708            let shared_variables: Vec<String> = var_counts
709                .into_iter()
710                .filter(|(_, count)| *count >= 2)
711                .map(|(var, _)| var.to_string())
712                .collect();
713
714            let join_conditions: Vec<JoinCondition> = conditions
715                .iter()
716                .map(|c| JoinCondition {
717                    left: c.left_expr.clone(),
718                    right: c.right_expr.clone(),
719                })
720                .collect();
721
722            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
723                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
724                conditions: join_conditions,
725                shared_variables,
726            }));
727        }
728
729        // Fall through to DPccp for binary join ordering
730        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
731        let plan = dpccp.optimize()?;
732
733        Some(plan.operator)
734    }
735
736    /// Pushes filters down the operator tree.
737    ///
738    /// This optimization moves filter predicates as close to the data source
739    /// as possible to reduce the amount of data processed by upper operators.
740    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
741        match op {
742            // For Filter operators, try to push the predicate into the child
743            LogicalOperator::Filter(filter) => {
744                let optimized_input = self.push_filters_down(*filter.input);
745                self.try_push_filter_into(filter.predicate, optimized_input)
746            }
747            // Recursively optimize children for other operators
748            LogicalOperator::Return(mut ret) => {
749                ret.input = Box::new(self.push_filters_down(*ret.input));
750                LogicalOperator::Return(ret)
751            }
752            LogicalOperator::Project(mut proj) => {
753                proj.input = Box::new(self.push_filters_down(*proj.input));
754                LogicalOperator::Project(proj)
755            }
756            LogicalOperator::Limit(mut limit) => {
757                limit.input = Box::new(self.push_filters_down(*limit.input));
758                LogicalOperator::Limit(limit)
759            }
760            LogicalOperator::Skip(mut skip) => {
761                skip.input = Box::new(self.push_filters_down(*skip.input));
762                LogicalOperator::Skip(skip)
763            }
764            LogicalOperator::Sort(mut sort) => {
765                sort.input = Box::new(self.push_filters_down(*sort.input));
766                LogicalOperator::Sort(sort)
767            }
768            LogicalOperator::Distinct(mut distinct) => {
769                distinct.input = Box::new(self.push_filters_down(*distinct.input));
770                LogicalOperator::Distinct(distinct)
771            }
772            LogicalOperator::Expand(mut expand) => {
773                expand.input = Box::new(self.push_filters_down(*expand.input));
774                LogicalOperator::Expand(expand)
775            }
776            LogicalOperator::Join(mut join) => {
777                join.left = Box::new(self.push_filters_down(*join.left));
778                join.right = Box::new(self.push_filters_down(*join.right));
779                LogicalOperator::Join(join)
780            }
781            LogicalOperator::Aggregate(mut agg) => {
782                agg.input = Box::new(self.push_filters_down(*agg.input));
783                LogicalOperator::Aggregate(agg)
784            }
785            LogicalOperator::MapCollect(mut mc) => {
786                mc.input = Box::new(self.push_filters_down(*mc.input));
787                LogicalOperator::MapCollect(mc)
788            }
789            LogicalOperator::MultiWayJoin(mut mwj) => {
790                mwj.inputs = mwj
791                    .inputs
792                    .into_iter()
793                    .map(|input| self.push_filters_down(input))
794                    .collect();
795                LogicalOperator::MultiWayJoin(mwj)
796            }
797            // Leaf operators and unsupported operators are returned as-is
798            other => other,
799        }
800    }
801
802    /// Tries to push a filter predicate into the given operator.
803    ///
804    /// Returns either the predicate pushed into the operator, or a new
805    /// Filter operator on top if the predicate cannot be pushed further.
806    fn try_push_filter_into(
807        &self,
808        predicate: LogicalExpression,
809        op: LogicalOperator,
810    ) -> LogicalOperator {
811        match op {
812            // Can push through Project if predicate doesn't depend on computed columns
813            LogicalOperator::Project(mut proj) => {
814                let predicate_vars = self.extract_variables(&predicate);
815                let computed_vars = self.extract_projection_aliases(&proj.projections);
816
817                // If predicate doesn't use any computed columns, push through
818                if predicate_vars.is_disjoint(&computed_vars) {
819                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
820                    LogicalOperator::Project(proj)
821                } else {
822                    // Can't push through, keep filter on top
823                    LogicalOperator::Filter(FilterOp {
824                        predicate,
825                        pushdown_hint: None,
826                        input: Box::new(LogicalOperator::Project(proj)),
827                    })
828                }
829            }
830
831            // Can push through Return (which is like a projection)
832            LogicalOperator::Return(mut ret) => {
833                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
834                LogicalOperator::Return(ret)
835            }
836
837            // Can push through Expand if predicate doesn't use variables introduced by this expand
838            LogicalOperator::Expand(mut expand) => {
839                let predicate_vars = self.extract_variables(&predicate);
840
841                // Variables introduced by this expand are:
842                // - The target variable (to_variable)
843                // - The edge variable (if any)
844                // - The path alias (if any)
845                let mut introduced_vars = vec![&expand.to_variable];
846                if let Some(ref edge_var) = expand.edge_variable {
847                    introduced_vars.push(edge_var);
848                }
849                if let Some(ref path_alias) = expand.path_alias {
850                    introduced_vars.push(path_alias);
851                }
852
853                // Check if predicate uses any variables introduced by this expand
854                let uses_introduced_vars =
855                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
856
857                if !uses_introduced_vars {
858                    // Predicate doesn't use vars from this expand, so push through
859                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
860                    LogicalOperator::Expand(expand)
861                } else {
862                    // Keep filter after expand
863                    LogicalOperator::Filter(FilterOp {
864                        predicate,
865                        pushdown_hint: None,
866                        input: Box::new(LogicalOperator::Expand(expand)),
867                    })
868                }
869            }
870
871            // Can push through Join to left/right side based on variables used
872            LogicalOperator::Join(mut join) => {
873                let predicate_vars = self.extract_variables(&predicate);
874                let left_vars = self.collect_output_variables(&join.left);
875                let right_vars = self.collect_output_variables(&join.right);
876
877                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
878                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
879
880                if uses_left && !uses_right {
881                    // Push to left side
882                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
883                    LogicalOperator::Join(join)
884                } else if uses_right && !uses_left {
885                    // Push to right side
886                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
887                    LogicalOperator::Join(join)
888                } else {
889                    // Uses both sides - keep above join
890                    LogicalOperator::Filter(FilterOp {
891                        predicate,
892                        pushdown_hint: None,
893                        input: Box::new(LogicalOperator::Join(join)),
894                    })
895                }
896            }
897
898            // Cannot push through Aggregate (predicate refers to aggregated values)
899            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
900                predicate,
901                pushdown_hint: None,
902                input: Box::new(LogicalOperator::Aggregate(agg)),
903            }),
904
905            // For NodeScan, we've reached the bottom - keep filter on top
906            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
907                predicate,
908                pushdown_hint: None,
909                input: Box::new(LogicalOperator::NodeScan(scan)),
910            }),
911
912            // For other operators, keep filter on top
913            other => LogicalOperator::Filter(FilterOp {
914                predicate,
915                pushdown_hint: None,
916                input: Box::new(other),
917            }),
918        }
919    }
920
921    /// Collects all output variable names from an operator.
922    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
923        let mut vars = HashSet::new();
924        Self::collect_output_variables_recursive(op, &mut vars);
925        vars
926    }
927
928    /// Recursively collects output variables from an operator.
929    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
930        match op {
931            LogicalOperator::NodeScan(scan) => {
932                vars.insert(scan.variable.clone());
933            }
934            LogicalOperator::EdgeScan(scan) => {
935                vars.insert(scan.variable.clone());
936            }
937            LogicalOperator::Expand(expand) => {
938                vars.insert(expand.to_variable.clone());
939                if let Some(edge_var) = &expand.edge_variable {
940                    vars.insert(edge_var.clone());
941                }
942                Self::collect_output_variables_recursive(&expand.input, vars);
943            }
944            LogicalOperator::Filter(filter) => {
945                Self::collect_output_variables_recursive(&filter.input, vars);
946            }
947            LogicalOperator::Project(proj) => {
948                for p in &proj.projections {
949                    if let Some(alias) = &p.alias {
950                        vars.insert(alias.clone());
951                    }
952                }
953                Self::collect_output_variables_recursive(&proj.input, vars);
954            }
955            LogicalOperator::Join(join) => {
956                Self::collect_output_variables_recursive(&join.left, vars);
957                Self::collect_output_variables_recursive(&join.right, vars);
958            }
959            LogicalOperator::Aggregate(agg) => {
960                for expr in &agg.group_by {
961                    Self::collect_variables(expr, vars);
962                }
963                for agg_expr in &agg.aggregates {
964                    if let Some(alias) = &agg_expr.alias {
965                        vars.insert(alias.clone());
966                    }
967                }
968            }
969            LogicalOperator::Return(ret) => {
970                Self::collect_output_variables_recursive(&ret.input, vars);
971            }
972            LogicalOperator::Limit(limit) => {
973                Self::collect_output_variables_recursive(&limit.input, vars);
974            }
975            LogicalOperator::Skip(skip) => {
976                Self::collect_output_variables_recursive(&skip.input, vars);
977            }
978            LogicalOperator::Sort(sort) => {
979                Self::collect_output_variables_recursive(&sort.input, vars);
980            }
981            LogicalOperator::Distinct(distinct) => {
982                Self::collect_output_variables_recursive(&distinct.input, vars);
983            }
984            _ => {}
985        }
986    }
987
988    /// Extracts all variable names referenced in an expression.
989    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
990        let mut vars = HashSet::new();
991        Self::collect_variables(expr, &mut vars);
992        vars
993    }
994
995    /// Recursively collects variable names from an expression.
996    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
997        match expr {
998            LogicalExpression::Variable(name) => {
999                vars.insert(name.clone());
1000            }
1001            LogicalExpression::Property { variable, .. } => {
1002                vars.insert(variable.clone());
1003            }
1004            LogicalExpression::Binary { left, right, .. } => {
1005                Self::collect_variables(left, vars);
1006                Self::collect_variables(right, vars);
1007            }
1008            LogicalExpression::Unary { operand, .. } => {
1009                Self::collect_variables(operand, vars);
1010            }
1011            LogicalExpression::FunctionCall { args, .. } => {
1012                for arg in args {
1013                    Self::collect_variables(arg, vars);
1014                }
1015            }
1016            LogicalExpression::List(items) => {
1017                for item in items {
1018                    Self::collect_variables(item, vars);
1019                }
1020            }
1021            LogicalExpression::Map(pairs) => {
1022                for (_, value) in pairs {
1023                    Self::collect_variables(value, vars);
1024                }
1025            }
1026            LogicalExpression::IndexAccess { base, index } => {
1027                Self::collect_variables(base, vars);
1028                Self::collect_variables(index, vars);
1029            }
1030            LogicalExpression::SliceAccess { base, start, end } => {
1031                Self::collect_variables(base, vars);
1032                if let Some(s) = start {
1033                    Self::collect_variables(s, vars);
1034                }
1035                if let Some(e) = end {
1036                    Self::collect_variables(e, vars);
1037                }
1038            }
1039            LogicalExpression::Case {
1040                operand,
1041                when_clauses,
1042                else_clause,
1043            } => {
1044                if let Some(op) = operand {
1045                    Self::collect_variables(op, vars);
1046                }
1047                for (cond, result) in when_clauses {
1048                    Self::collect_variables(cond, vars);
1049                    Self::collect_variables(result, vars);
1050                }
1051                if let Some(else_expr) = else_clause {
1052                    Self::collect_variables(else_expr, vars);
1053                }
1054            }
1055            LogicalExpression::Labels(var)
1056            | LogicalExpression::Type(var)
1057            | LogicalExpression::Id(var) => {
1058                vars.insert(var.clone());
1059            }
1060            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1061            LogicalExpression::ListComprehension {
1062                list_expr,
1063                filter_expr,
1064                map_expr,
1065                ..
1066            } => {
1067                Self::collect_variables(list_expr, vars);
1068                if let Some(filter) = filter_expr {
1069                    Self::collect_variables(filter, vars);
1070                }
1071                Self::collect_variables(map_expr, vars);
1072            }
1073            LogicalExpression::ListPredicate {
1074                list_expr,
1075                predicate,
1076                ..
1077            } => {
1078                Self::collect_variables(list_expr, vars);
1079                Self::collect_variables(predicate, vars);
1080            }
1081            LogicalExpression::ExistsSubquery(_) | LogicalExpression::CountSubquery(_) => {
1082                // Subqueries have their own variable scope
1083            }
1084            LogicalExpression::PatternComprehension { projection, .. } => {
1085                Self::collect_variables(projection, vars);
1086            }
1087            LogicalExpression::MapProjection { base, entries } => {
1088                vars.insert(base.clone());
1089                for entry in entries {
1090                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1091                        Self::collect_variables(expr, vars);
1092                    }
1093                }
1094            }
1095            LogicalExpression::Reduce {
1096                initial,
1097                list,
1098                expression,
1099                ..
1100            } => {
1101                Self::collect_variables(initial, vars);
1102                Self::collect_variables(list, vars);
1103                Self::collect_variables(expression, vars);
1104            }
1105        }
1106    }
1107
1108    /// Extracts aliases from projection expressions.
1109    fn extract_projection_aliases(
1110        &self,
1111        projections: &[crate::query::plan::Projection],
1112    ) -> HashSet<String> {
1113        projections.iter().filter_map(|p| p.alias.clone()).collect()
1114    }
1115}
1116
1117impl Default for Optimizer {
1118    fn default() -> Self {
1119        Self::new()
1120    }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125    use super::*;
1126    use crate::query::plan::{
1127        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1128        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1129        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1130    };
1131    use grafeo_common::types::Value;
1132
1133    #[test]
1134    fn test_optimizer_filter_pushdown_simple() {
1135        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1136        // Before: Return -> Filter -> NodeScan
1137        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1138
1139        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1140            items: vec![ReturnItem {
1141                expression: LogicalExpression::Variable("n".to_string()),
1142                alias: None,
1143            }],
1144            distinct: false,
1145            input: Box::new(LogicalOperator::Filter(FilterOp {
1146                predicate: LogicalExpression::Binary {
1147                    left: Box::new(LogicalExpression::Property {
1148                        variable: "n".to_string(),
1149                        property: "age".to_string(),
1150                    }),
1151                    op: BinaryOp::Gt,
1152                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1153                },
1154                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1155                    variable: "n".to_string(),
1156                    label: Some("Person".to_string()),
1157                    input: None,
1158                })),
1159                pushdown_hint: None,
1160            })),
1161        }));
1162
1163        let optimizer = Optimizer::new();
1164        let optimized = optimizer.optimize(plan).unwrap();
1165
1166        // The structure should remain similar (filter stays near scan)
1167        if let LogicalOperator::Return(ret) = &optimized.root
1168            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1169            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1170        {
1171            assert_eq!(scan.variable, "n");
1172            return;
1173        }
1174        panic!("Expected Return -> Filter -> NodeScan structure");
1175    }
1176
1177    #[test]
1178    fn test_optimizer_filter_pushdown_through_expand() {
1179        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1180        // The filter on 'a' should be pushed before the expand
1181
1182        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1183            items: vec![ReturnItem {
1184                expression: LogicalExpression::Variable("b".to_string()),
1185                alias: None,
1186            }],
1187            distinct: false,
1188            input: Box::new(LogicalOperator::Filter(FilterOp {
1189                predicate: LogicalExpression::Binary {
1190                    left: Box::new(LogicalExpression::Property {
1191                        variable: "a".to_string(),
1192                        property: "age".to_string(),
1193                    }),
1194                    op: BinaryOp::Gt,
1195                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1196                },
1197                pushdown_hint: None,
1198                input: Box::new(LogicalOperator::Expand(ExpandOp {
1199                    from_variable: "a".to_string(),
1200                    to_variable: "b".to_string(),
1201                    edge_variable: None,
1202                    direction: ExpandDirection::Outgoing,
1203                    edge_types: vec!["KNOWS".to_string()],
1204                    min_hops: 1,
1205                    max_hops: Some(1),
1206                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1207                        variable: "a".to_string(),
1208                        label: Some("Person".to_string()),
1209                        input: None,
1210                    })),
1211                    path_alias: None,
1212                    path_mode: PathMode::Walk,
1213                })),
1214            })),
1215        }));
1216
1217        let optimizer = Optimizer::new();
1218        let optimized = optimizer.optimize(plan).unwrap();
1219
1220        // Filter on 'a' should be pushed before the expand
1221        // Expected: Return -> Expand -> Filter -> NodeScan
1222        if let LogicalOperator::Return(ret) = &optimized.root
1223            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1224            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1225            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1226        {
1227            assert_eq!(scan.variable, "a");
1228            assert_eq!(expand.from_variable, "a");
1229            assert_eq!(expand.to_variable, "b");
1230            return;
1231        }
1232        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1233    }
1234
1235    #[test]
1236    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1237        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1238        // The filter on 'b' should NOT be pushed before the expand
1239
1240        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1241            items: vec![ReturnItem {
1242                expression: LogicalExpression::Variable("a".to_string()),
1243                alias: None,
1244            }],
1245            distinct: false,
1246            input: Box::new(LogicalOperator::Filter(FilterOp {
1247                predicate: LogicalExpression::Binary {
1248                    left: Box::new(LogicalExpression::Property {
1249                        variable: "b".to_string(),
1250                        property: "age".to_string(),
1251                    }),
1252                    op: BinaryOp::Gt,
1253                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1254                },
1255                pushdown_hint: None,
1256                input: Box::new(LogicalOperator::Expand(ExpandOp {
1257                    from_variable: "a".to_string(),
1258                    to_variable: "b".to_string(),
1259                    edge_variable: None,
1260                    direction: ExpandDirection::Outgoing,
1261                    edge_types: vec!["KNOWS".to_string()],
1262                    min_hops: 1,
1263                    max_hops: Some(1),
1264                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1265                        variable: "a".to_string(),
1266                        label: Some("Person".to_string()),
1267                        input: None,
1268                    })),
1269                    path_alias: None,
1270                    path_mode: PathMode::Walk,
1271                })),
1272            })),
1273        }));
1274
1275        let optimizer = Optimizer::new();
1276        let optimized = optimizer.optimize(plan).unwrap();
1277
1278        // Filter on 'b' should stay after the expand
1279        // Expected: Return -> Filter -> Expand -> NodeScan
1280        if let LogicalOperator::Return(ret) = &optimized.root
1281            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1282        {
1283            // Check that the filter is on 'b'
1284            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1285                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1286            {
1287                assert_eq!(variable, "b");
1288            }
1289
1290            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1291                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1292            {
1293                return;
1294            }
1295        }
1296        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1297    }
1298
1299    #[test]
1300    fn test_optimizer_extract_variables() {
1301        let optimizer = Optimizer::new();
1302
1303        let expr = LogicalExpression::Binary {
1304            left: Box::new(LogicalExpression::Property {
1305                variable: "n".to_string(),
1306                property: "age".to_string(),
1307            }),
1308            op: BinaryOp::Gt,
1309            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1310        };
1311
1312        let vars = optimizer.extract_variables(&expr);
1313        assert_eq!(vars.len(), 1);
1314        assert!(vars.contains("n"));
1315    }
1316
1317    // Additional tests for optimizer configuration
1318
1319    #[test]
1320    fn test_optimizer_default() {
1321        let optimizer = Optimizer::default();
1322        // Should be able to optimize an empty plan
1323        let plan = LogicalPlan::new(LogicalOperator::Empty);
1324        let result = optimizer.optimize(plan);
1325        assert!(result.is_ok());
1326    }
1327
1328    #[test]
1329    fn test_optimizer_with_filter_pushdown_disabled() {
1330        let optimizer = Optimizer::new().with_filter_pushdown(false);
1331
1332        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1333            items: vec![ReturnItem {
1334                expression: LogicalExpression::Variable("n".to_string()),
1335                alias: None,
1336            }],
1337            distinct: false,
1338            input: Box::new(LogicalOperator::Filter(FilterOp {
1339                predicate: LogicalExpression::Literal(Value::Bool(true)),
1340                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1341                    variable: "n".to_string(),
1342                    label: None,
1343                    input: None,
1344                })),
1345                pushdown_hint: None,
1346            })),
1347        }));
1348
1349        let optimized = optimizer.optimize(plan).unwrap();
1350        // Structure should be unchanged
1351        if let LogicalOperator::Return(ret) = &optimized.root
1352            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1353        {
1354            return;
1355        }
1356        panic!("Expected unchanged structure");
1357    }
1358
1359    #[test]
1360    fn test_optimizer_with_join_reorder_disabled() {
1361        let optimizer = Optimizer::new().with_join_reorder(false);
1362        assert!(
1363            optimizer
1364                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1365                .is_ok()
1366        );
1367    }
1368
1369    #[test]
1370    fn test_optimizer_with_cost_model() {
1371        let cost_model = CostModel::new();
1372        let optimizer = Optimizer::new().with_cost_model(cost_model);
1373        assert!(
1374            optimizer
1375                .cost_model()
1376                .estimate(&LogicalOperator::Empty, 0.0)
1377                .total()
1378                < 0.001
1379        );
1380    }
1381
1382    #[test]
1383    fn test_optimizer_with_cardinality_estimator() {
1384        let mut estimator = CardinalityEstimator::new();
1385        estimator.add_table_stats("Test", TableStats::new(500));
1386        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1387
1388        let scan = LogicalOperator::NodeScan(NodeScanOp {
1389            variable: "n".to_string(),
1390            label: Some("Test".to_string()),
1391            input: None,
1392        });
1393        let plan = LogicalPlan::new(scan);
1394
1395        let cardinality = optimizer.estimate_cardinality(&plan);
1396        assert!((cardinality - 500.0).abs() < 0.001);
1397    }
1398
1399    #[test]
1400    fn test_optimizer_estimate_cost() {
1401        let optimizer = Optimizer::new();
1402        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1403            variable: "n".to_string(),
1404            label: None,
1405            input: None,
1406        }));
1407
1408        let cost = optimizer.estimate_cost(&plan);
1409        assert!(cost.total() > 0.0);
1410    }
1411
1412    // Filter pushdown through various operators
1413
1414    #[test]
1415    fn test_filter_pushdown_through_project() {
1416        let optimizer = Optimizer::new();
1417
1418        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1419            predicate: LogicalExpression::Binary {
1420                left: Box::new(LogicalExpression::Property {
1421                    variable: "n".to_string(),
1422                    property: "age".to_string(),
1423                }),
1424                op: BinaryOp::Gt,
1425                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1426            },
1427            pushdown_hint: None,
1428            input: Box::new(LogicalOperator::Project(ProjectOp {
1429                projections: vec![Projection {
1430                    expression: LogicalExpression::Variable("n".to_string()),
1431                    alias: None,
1432                }],
1433                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1434                    variable: "n".to_string(),
1435                    label: None,
1436                    input: None,
1437                })),
1438            })),
1439        }));
1440
1441        let optimized = optimizer.optimize(plan).unwrap();
1442
1443        // Filter should be pushed through Project
1444        if let LogicalOperator::Project(proj) = &optimized.root
1445            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1446        {
1447            return;
1448        }
1449        panic!("Expected Project -> Filter structure");
1450    }
1451
1452    #[test]
1453    fn test_filter_not_pushed_through_project_with_alias() {
1454        let optimizer = Optimizer::new();
1455
1456        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1457        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1458            predicate: LogicalExpression::Binary {
1459                left: Box::new(LogicalExpression::Variable("x".to_string())),
1460                op: BinaryOp::Gt,
1461                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1462            },
1463            pushdown_hint: None,
1464            input: Box::new(LogicalOperator::Project(ProjectOp {
1465                projections: vec![Projection {
1466                    expression: LogicalExpression::Property {
1467                        variable: "n".to_string(),
1468                        property: "age".to_string(),
1469                    },
1470                    alias: Some("x".to_string()),
1471                }],
1472                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1473                    variable: "n".to_string(),
1474                    label: None,
1475                    input: None,
1476                })),
1477            })),
1478        }));
1479
1480        let optimized = optimizer.optimize(plan).unwrap();
1481
1482        // Filter should stay above Project
1483        if let LogicalOperator::Filter(filter) = &optimized.root
1484            && let LogicalOperator::Project(_) = filter.input.as_ref()
1485        {
1486            return;
1487        }
1488        panic!("Expected Filter -> Project structure");
1489    }
1490
1491    #[test]
1492    fn test_filter_pushdown_through_limit() {
1493        let optimizer = Optimizer::new();
1494
1495        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1496            predicate: LogicalExpression::Literal(Value::Bool(true)),
1497            pushdown_hint: None,
1498            input: Box::new(LogicalOperator::Limit(LimitOp {
1499                count: 10,
1500                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1501                    variable: "n".to_string(),
1502                    label: None,
1503                    input: None,
1504                })),
1505            })),
1506        }));
1507
1508        let optimized = optimizer.optimize(plan).unwrap();
1509
1510        // Filter stays above Limit (cannot be pushed through)
1511        if let LogicalOperator::Filter(filter) = &optimized.root
1512            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1513        {
1514            return;
1515        }
1516        panic!("Expected Filter -> Limit structure");
1517    }
1518
1519    #[test]
1520    fn test_filter_pushdown_through_sort() {
1521        let optimizer = Optimizer::new();
1522
1523        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1524            predicate: LogicalExpression::Literal(Value::Bool(true)),
1525            pushdown_hint: None,
1526            input: Box::new(LogicalOperator::Sort(SortOp {
1527                keys: vec![SortKey {
1528                    expression: LogicalExpression::Variable("n".to_string()),
1529                    order: SortOrder::Ascending,
1530                    nulls: None,
1531                }],
1532                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1533                    variable: "n".to_string(),
1534                    label: None,
1535                    input: None,
1536                })),
1537            })),
1538        }));
1539
1540        let optimized = optimizer.optimize(plan).unwrap();
1541
1542        // Filter stays above Sort
1543        if let LogicalOperator::Filter(filter) = &optimized.root
1544            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1545        {
1546            return;
1547        }
1548        panic!("Expected Filter -> Sort structure");
1549    }
1550
1551    #[test]
1552    fn test_filter_pushdown_through_distinct() {
1553        let optimizer = Optimizer::new();
1554
1555        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1556            predicate: LogicalExpression::Literal(Value::Bool(true)),
1557            pushdown_hint: None,
1558            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1559                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1560                    variable: "n".to_string(),
1561                    label: None,
1562                    input: None,
1563                })),
1564                columns: None,
1565            })),
1566        }));
1567
1568        let optimized = optimizer.optimize(plan).unwrap();
1569
1570        // Filter stays above Distinct
1571        if let LogicalOperator::Filter(filter) = &optimized.root
1572            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1573        {
1574            return;
1575        }
1576        panic!("Expected Filter -> Distinct structure");
1577    }
1578
1579    #[test]
1580    fn test_filter_not_pushed_through_aggregate() {
1581        let optimizer = Optimizer::new();
1582
1583        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1584            predicate: LogicalExpression::Binary {
1585                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1586                op: BinaryOp::Gt,
1587                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1588            },
1589            pushdown_hint: None,
1590            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1591                group_by: vec![],
1592                aggregates: vec![AggregateExpr {
1593                    function: AggregateFunction::Count,
1594                    expression: None,
1595                    expression2: None,
1596                    distinct: false,
1597                    alias: Some("cnt".to_string()),
1598                    percentile: None,
1599                }],
1600                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1601                    variable: "n".to_string(),
1602                    label: None,
1603                    input: None,
1604                })),
1605                having: None,
1606            })),
1607        }));
1608
1609        let optimized = optimizer.optimize(plan).unwrap();
1610
1611        // Filter should stay above Aggregate
1612        if let LogicalOperator::Filter(filter) = &optimized.root
1613            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1614        {
1615            return;
1616        }
1617        panic!("Expected Filter -> Aggregate structure");
1618    }
1619
1620    #[test]
1621    fn test_filter_pushdown_to_left_join_side() {
1622        let optimizer = Optimizer::new();
1623
1624        // Filter on left variable should be pushed to left side
1625        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1626            predicate: LogicalExpression::Binary {
1627                left: Box::new(LogicalExpression::Property {
1628                    variable: "a".to_string(),
1629                    property: "age".to_string(),
1630                }),
1631                op: BinaryOp::Gt,
1632                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1633            },
1634            pushdown_hint: None,
1635            input: Box::new(LogicalOperator::Join(JoinOp {
1636                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1637                    variable: "a".to_string(),
1638                    label: Some("Person".to_string()),
1639                    input: None,
1640                })),
1641                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1642                    variable: "b".to_string(),
1643                    label: Some("Company".to_string()),
1644                    input: None,
1645                })),
1646                join_type: JoinType::Inner,
1647                conditions: vec![],
1648            })),
1649        }));
1650
1651        let optimized = optimizer.optimize(plan).unwrap();
1652
1653        // Filter should be pushed to left side of join
1654        if let LogicalOperator::Join(join) = &optimized.root
1655            && let LogicalOperator::Filter(_) = join.left.as_ref()
1656        {
1657            return;
1658        }
1659        panic!("Expected Join with Filter on left side");
1660    }
1661
1662    #[test]
1663    fn test_filter_pushdown_to_right_join_side() {
1664        let optimizer = Optimizer::new();
1665
1666        // Filter on right variable should be pushed to right side
1667        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1668            predicate: LogicalExpression::Binary {
1669                left: Box::new(LogicalExpression::Property {
1670                    variable: "b".to_string(),
1671                    property: "name".to_string(),
1672                }),
1673                op: BinaryOp::Eq,
1674                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1675            },
1676            pushdown_hint: None,
1677            input: Box::new(LogicalOperator::Join(JoinOp {
1678                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1679                    variable: "a".to_string(),
1680                    label: Some("Person".to_string()),
1681                    input: None,
1682                })),
1683                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1684                    variable: "b".to_string(),
1685                    label: Some("Company".to_string()),
1686                    input: None,
1687                })),
1688                join_type: JoinType::Inner,
1689                conditions: vec![],
1690            })),
1691        }));
1692
1693        let optimized = optimizer.optimize(plan).unwrap();
1694
1695        // Filter should be pushed to right side of join
1696        if let LogicalOperator::Join(join) = &optimized.root
1697            && let LogicalOperator::Filter(_) = join.right.as_ref()
1698        {
1699            return;
1700        }
1701        panic!("Expected Join with Filter on right side");
1702    }
1703
1704    #[test]
1705    fn test_filter_not_pushed_when_uses_both_join_sides() {
1706        let optimizer = Optimizer::new();
1707
1708        // Filter using both variables should stay above join
1709        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1710            predicate: LogicalExpression::Binary {
1711                left: Box::new(LogicalExpression::Property {
1712                    variable: "a".to_string(),
1713                    property: "id".to_string(),
1714                }),
1715                op: BinaryOp::Eq,
1716                right: Box::new(LogicalExpression::Property {
1717                    variable: "b".to_string(),
1718                    property: "a_id".to_string(),
1719                }),
1720            },
1721            pushdown_hint: None,
1722            input: Box::new(LogicalOperator::Join(JoinOp {
1723                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1724                    variable: "a".to_string(),
1725                    label: None,
1726                    input: None,
1727                })),
1728                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1729                    variable: "b".to_string(),
1730                    label: None,
1731                    input: None,
1732                })),
1733                join_type: JoinType::Inner,
1734                conditions: vec![],
1735            })),
1736        }));
1737
1738        let optimized = optimizer.optimize(plan).unwrap();
1739
1740        // Filter should stay above join
1741        if let LogicalOperator::Filter(filter) = &optimized.root
1742            && let LogicalOperator::Join(_) = filter.input.as_ref()
1743        {
1744            return;
1745        }
1746        panic!("Expected Filter -> Join structure");
1747    }
1748
1749    // Variable extraction tests
1750
1751    #[test]
1752    fn test_extract_variables_from_variable() {
1753        let optimizer = Optimizer::new();
1754        let expr = LogicalExpression::Variable("x".to_string());
1755        let vars = optimizer.extract_variables(&expr);
1756        assert_eq!(vars.len(), 1);
1757        assert!(vars.contains("x"));
1758    }
1759
1760    #[test]
1761    fn test_extract_variables_from_unary() {
1762        let optimizer = Optimizer::new();
1763        let expr = LogicalExpression::Unary {
1764            op: UnaryOp::Not,
1765            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1766        };
1767        let vars = optimizer.extract_variables(&expr);
1768        assert_eq!(vars.len(), 1);
1769        assert!(vars.contains("x"));
1770    }
1771
1772    #[test]
1773    fn test_extract_variables_from_function_call() {
1774        let optimizer = Optimizer::new();
1775        let expr = LogicalExpression::FunctionCall {
1776            name: "length".to_string(),
1777            args: vec![
1778                LogicalExpression::Variable("a".to_string()),
1779                LogicalExpression::Variable("b".to_string()),
1780            ],
1781            distinct: false,
1782        };
1783        let vars = optimizer.extract_variables(&expr);
1784        assert_eq!(vars.len(), 2);
1785        assert!(vars.contains("a"));
1786        assert!(vars.contains("b"));
1787    }
1788
1789    #[test]
1790    fn test_extract_variables_from_list() {
1791        let optimizer = Optimizer::new();
1792        let expr = LogicalExpression::List(vec![
1793            LogicalExpression::Variable("a".to_string()),
1794            LogicalExpression::Literal(Value::Int64(1)),
1795            LogicalExpression::Variable("b".to_string()),
1796        ]);
1797        let vars = optimizer.extract_variables(&expr);
1798        assert_eq!(vars.len(), 2);
1799        assert!(vars.contains("a"));
1800        assert!(vars.contains("b"));
1801    }
1802
1803    #[test]
1804    fn test_extract_variables_from_map() {
1805        let optimizer = Optimizer::new();
1806        let expr = LogicalExpression::Map(vec![
1807            (
1808                "key1".to_string(),
1809                LogicalExpression::Variable("a".to_string()),
1810            ),
1811            (
1812                "key2".to_string(),
1813                LogicalExpression::Variable("b".to_string()),
1814            ),
1815        ]);
1816        let vars = optimizer.extract_variables(&expr);
1817        assert_eq!(vars.len(), 2);
1818        assert!(vars.contains("a"));
1819        assert!(vars.contains("b"));
1820    }
1821
1822    #[test]
1823    fn test_extract_variables_from_index_access() {
1824        let optimizer = Optimizer::new();
1825        let expr = LogicalExpression::IndexAccess {
1826            base: Box::new(LogicalExpression::Variable("list".to_string())),
1827            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1828        };
1829        let vars = optimizer.extract_variables(&expr);
1830        assert_eq!(vars.len(), 2);
1831        assert!(vars.contains("list"));
1832        assert!(vars.contains("idx"));
1833    }
1834
1835    #[test]
1836    fn test_extract_variables_from_slice_access() {
1837        let optimizer = Optimizer::new();
1838        let expr = LogicalExpression::SliceAccess {
1839            base: Box::new(LogicalExpression::Variable("list".to_string())),
1840            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1841            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1842        };
1843        let vars = optimizer.extract_variables(&expr);
1844        assert_eq!(vars.len(), 3);
1845        assert!(vars.contains("list"));
1846        assert!(vars.contains("s"));
1847        assert!(vars.contains("e"));
1848    }
1849
1850    #[test]
1851    fn test_extract_variables_from_case() {
1852        let optimizer = Optimizer::new();
1853        let expr = LogicalExpression::Case {
1854            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1855            when_clauses: vec![(
1856                LogicalExpression::Literal(Value::Int64(1)),
1857                LogicalExpression::Variable("a".to_string()),
1858            )],
1859            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1860        };
1861        let vars = optimizer.extract_variables(&expr);
1862        assert_eq!(vars.len(), 3);
1863        assert!(vars.contains("x"));
1864        assert!(vars.contains("a"));
1865        assert!(vars.contains("b"));
1866    }
1867
1868    #[test]
1869    fn test_extract_variables_from_labels() {
1870        let optimizer = Optimizer::new();
1871        let expr = LogicalExpression::Labels("n".to_string());
1872        let vars = optimizer.extract_variables(&expr);
1873        assert_eq!(vars.len(), 1);
1874        assert!(vars.contains("n"));
1875    }
1876
1877    #[test]
1878    fn test_extract_variables_from_type() {
1879        let optimizer = Optimizer::new();
1880        let expr = LogicalExpression::Type("e".to_string());
1881        let vars = optimizer.extract_variables(&expr);
1882        assert_eq!(vars.len(), 1);
1883        assert!(vars.contains("e"));
1884    }
1885
1886    #[test]
1887    fn test_extract_variables_from_id() {
1888        let optimizer = Optimizer::new();
1889        let expr = LogicalExpression::Id("n".to_string());
1890        let vars = optimizer.extract_variables(&expr);
1891        assert_eq!(vars.len(), 1);
1892        assert!(vars.contains("n"));
1893    }
1894
1895    #[test]
1896    fn test_extract_variables_from_list_comprehension() {
1897        let optimizer = Optimizer::new();
1898        let expr = LogicalExpression::ListComprehension {
1899            variable: "x".to_string(),
1900            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1901            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1902            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1903        };
1904        let vars = optimizer.extract_variables(&expr);
1905        assert!(vars.contains("items"));
1906        assert!(vars.contains("pred"));
1907        assert!(vars.contains("result"));
1908    }
1909
1910    #[test]
1911    fn test_extract_variables_from_literal_and_parameter() {
1912        let optimizer = Optimizer::new();
1913
1914        let literal = LogicalExpression::Literal(Value::Int64(42));
1915        assert!(optimizer.extract_variables(&literal).is_empty());
1916
1917        let param = LogicalExpression::Parameter("p".to_string());
1918        assert!(optimizer.extract_variables(&param).is_empty());
1919    }
1920
1921    // Recursive filter pushdown tests
1922
1923    #[test]
1924    fn test_recursive_filter_pushdown_through_skip() {
1925        let optimizer = Optimizer::new();
1926
1927        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1928            items: vec![ReturnItem {
1929                expression: LogicalExpression::Variable("n".to_string()),
1930                alias: None,
1931            }],
1932            distinct: false,
1933            input: Box::new(LogicalOperator::Filter(FilterOp {
1934                predicate: LogicalExpression::Literal(Value::Bool(true)),
1935                pushdown_hint: None,
1936                input: Box::new(LogicalOperator::Skip(SkipOp {
1937                    count: 5,
1938                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1939                        variable: "n".to_string(),
1940                        label: None,
1941                        input: None,
1942                    })),
1943                })),
1944            })),
1945        }));
1946
1947        let optimized = optimizer.optimize(plan).unwrap();
1948
1949        // Verify optimization succeeded
1950        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1951    }
1952
1953    #[test]
1954    fn test_nested_filter_pushdown() {
1955        let optimizer = Optimizer::new();
1956
1957        // Multiple nested filters
1958        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1959            items: vec![ReturnItem {
1960                expression: LogicalExpression::Variable("n".to_string()),
1961                alias: None,
1962            }],
1963            distinct: false,
1964            input: Box::new(LogicalOperator::Filter(FilterOp {
1965                predicate: LogicalExpression::Binary {
1966                    left: Box::new(LogicalExpression::Property {
1967                        variable: "n".to_string(),
1968                        property: "x".to_string(),
1969                    }),
1970                    op: BinaryOp::Gt,
1971                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1972                },
1973                pushdown_hint: None,
1974                input: Box::new(LogicalOperator::Filter(FilterOp {
1975                    predicate: LogicalExpression::Binary {
1976                        left: Box::new(LogicalExpression::Property {
1977                            variable: "n".to_string(),
1978                            property: "y".to_string(),
1979                        }),
1980                        op: BinaryOp::Lt,
1981                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1982                    },
1983                    pushdown_hint: None,
1984                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1985                        variable: "n".to_string(),
1986                        label: None,
1987                        input: None,
1988                    })),
1989                })),
1990            })),
1991        }));
1992
1993        let optimized = optimizer.optimize(plan).unwrap();
1994        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1995    }
1996
1997    #[test]
1998    fn test_cyclic_join_produces_multi_way_join() {
1999        use crate::query::plan::JoinCondition;
2000
2001        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2002        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2003            variable: "a".to_string(),
2004            label: Some("Person".to_string()),
2005            input: None,
2006        });
2007        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2008            variable: "b".to_string(),
2009            label: Some("Person".to_string()),
2010            input: None,
2011        });
2012        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2013            variable: "c".to_string(),
2014            label: Some("Person".to_string()),
2015            input: None,
2016        });
2017
2018        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2019        let join_ab = LogicalOperator::Join(JoinOp {
2020            left: Box::new(scan_a),
2021            right: Box::new(scan_b),
2022            join_type: JoinType::Inner,
2023            conditions: vec![JoinCondition {
2024                left: LogicalExpression::Variable("a".to_string()),
2025                right: LogicalExpression::Variable("b".to_string()),
2026            }],
2027        });
2028
2029        let join_abc = LogicalOperator::Join(JoinOp {
2030            left: Box::new(join_ab),
2031            right: Box::new(scan_c),
2032            join_type: JoinType::Inner,
2033            conditions: vec![
2034                JoinCondition {
2035                    left: LogicalExpression::Variable("b".to_string()),
2036                    right: LogicalExpression::Variable("c".to_string()),
2037                },
2038                JoinCondition {
2039                    left: LogicalExpression::Variable("c".to_string()),
2040                    right: LogicalExpression::Variable("a".to_string()),
2041                },
2042            ],
2043        });
2044
2045        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2046            items: vec![ReturnItem {
2047                expression: LogicalExpression::Variable("a".to_string()),
2048                alias: None,
2049            }],
2050            distinct: false,
2051            input: Box::new(join_abc),
2052        }));
2053
2054        let mut optimizer = Optimizer::new();
2055        optimizer
2056            .card_estimator
2057            .add_table_stats("Person", cardinality::TableStats::new(1000));
2058
2059        let optimized = optimizer.optimize(plan).unwrap();
2060
2061        // Walk the tree to find a MultiWayJoin
2062        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2063            match op {
2064                LogicalOperator::MultiWayJoin(_) => true,
2065                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2066                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2067                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2068                _ => false,
2069            }
2070        }
2071
2072        assert!(
2073            has_multi_way_join(&optimized.root),
2074            "Expected MultiWayJoin for cyclic triangle pattern"
2075        );
2076    }
2077
2078    #[test]
2079    fn test_acyclic_join_uses_binary_joins() {
2080        use crate::query::plan::JoinCondition;
2081
2082        // Chain: a ⋈ b ⋈ c (acyclic)
2083        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2084            variable: "a".to_string(),
2085            label: Some("Person".to_string()),
2086            input: None,
2087        });
2088        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2089            variable: "b".to_string(),
2090            label: Some("Person".to_string()),
2091            input: None,
2092        });
2093        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2094            variable: "c".to_string(),
2095            label: Some("Company".to_string()),
2096            input: None,
2097        });
2098
2099        let join_ab = LogicalOperator::Join(JoinOp {
2100            left: Box::new(scan_a),
2101            right: Box::new(scan_b),
2102            join_type: JoinType::Inner,
2103            conditions: vec![JoinCondition {
2104                left: LogicalExpression::Variable("a".to_string()),
2105                right: LogicalExpression::Variable("b".to_string()),
2106            }],
2107        });
2108
2109        let join_abc = LogicalOperator::Join(JoinOp {
2110            left: Box::new(join_ab),
2111            right: Box::new(scan_c),
2112            join_type: JoinType::Inner,
2113            conditions: vec![JoinCondition {
2114                left: LogicalExpression::Variable("b".to_string()),
2115                right: LogicalExpression::Variable("c".to_string()),
2116            }],
2117        });
2118
2119        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2120            items: vec![ReturnItem {
2121                expression: LogicalExpression::Variable("a".to_string()),
2122                alias: None,
2123            }],
2124            distinct: false,
2125            input: Box::new(join_abc),
2126        }));
2127
2128        let mut optimizer = Optimizer::new();
2129        optimizer
2130            .card_estimator
2131            .add_table_stats("Person", cardinality::TableStats::new(1000));
2132        optimizer
2133            .card_estimator
2134            .add_table_stats("Company", cardinality::TableStats::new(100));
2135
2136        let optimized = optimizer.optimize(plan).unwrap();
2137
2138        // Should NOT contain MultiWayJoin for acyclic pattern
2139        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2140            match op {
2141                LogicalOperator::MultiWayJoin(_) => true,
2142                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2143                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2144                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2145                LogicalOperator::Join(j) => {
2146                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2147                }
2148                _ => false,
2149            }
2150        }
2151
2152        assert!(
2153            !has_multi_way_join(&optimized.root),
2154            "Acyclic join should NOT produce MultiWayJoin"
2155        );
2156    }
2157}