Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
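//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm; cyclic patterns with 3+ relations use a multi-way (leapfrog) join instead |
//! | Projection Pushdown | Determines which variables and properties each part of the plan needs (groundwork for dropping unused columns early) |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |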
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// Information about a join condition for join reordering.
///
/// Built by `collect_join_tree` from each `Join` condition whose two sides
/// both resolve to a single variable (see `extract_variable_from_expr`), and
/// consumed by `optimize_join_order` when building the join graph.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression of the condition.
    left_var: String,
    /// Variable referenced by the right-hand expression of the condition.
    right_var: String,
    /// The original left-hand expression (a variable or a property access).
    left_expr: LogicalExpression,
    /// The original right-hand expression (a variable or a property access).
    right_expr: LogicalExpression,
}
38
/// A column required by the query, used for projection pushdown.
///
/// Produced by `collect_required_recursive` / `collect_from_expression`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding), identified by name.
    Variable(String),
    /// A specific property of a variable, as `(variable, property)`.
    Property(String, String),
}
47
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// `new`, `from_store`, `from_graph_store` and `from_rdf_statistics` all start
/// with every pass enabled; the statistics-based constructors additionally
/// seed the cost model and cardinality estimator.
pub struct Optimizer {
    /// Whether to enable filter pushdown (the first pass run by `optimize`).
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering (run after filter pushdown).
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown (the last pass run by `optimize`).
    enable_projection_pushdown: bool,
    /// Cost model for estimation.
    cost_model: CostModel,
    /// Cardinality estimator.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66    /// Creates a new optimizer with default settings.
67    #[must_use]
68    pub fn new() -> Self {
69        Self {
70            enable_filter_pushdown: true,
71            enable_join_reorder: true,
72            enable_projection_pushdown: true,
73            cost_model: CostModel::new(),
74            card_estimator: CardinalityEstimator::new(),
75        }
76    }
77
78    /// Creates an optimizer with cardinality estimates from the store's statistics.
79    ///
80    /// Pre-populates the cardinality estimator with per-label row counts and
81    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
82    /// and graph totals into the cost model for accurate estimation.
83    #[must_use]
84    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85        store.ensure_statistics_fresh();
86        let stats = store.statistics();
87        Self::from_statistics(&stats)
88    }
89
90    /// Creates an optimizer from any GraphStore implementation.
91    ///
92    /// Unlike [`from_store`](Self::from_store), this does not call
93    /// `ensure_statistics_fresh()` since external stores manage their own
94    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
95    /// is called directly.
96    #[must_use]
97    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98        let stats = store.statistics();
99        Self::from_statistics(&stats)
100    }
101
102    /// Creates an optimizer from RDF statistics.
103    ///
104    /// Uses triple pattern cardinality estimates for cost-based optimization
105    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
106    #[cfg(feature = "rdf")]
107    #[must_use]
108    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
109        let total = rdf_stats.total_triples;
110        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
111        Self {
112            enable_filter_pushdown: true,
113            enable_join_reorder: true,
114            enable_projection_pushdown: true,
115            cost_model: CostModel::new().with_graph_totals(total, total),
116            card_estimator: estimator,
117        }
118    }
119
120    /// Creates an optimizer from a Statistics snapshot.
121    ///
122    /// Extracts label cardinalities, edge type degrees, and graph totals
123    /// for both the cardinality estimator and cost model.
124    #[must_use]
125    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
126        let estimator = CardinalityEstimator::from_statistics(stats);
127
128        let avg_fanout = if stats.total_nodes > 0 {
129            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
130        } else {
131            10.0
132        };
133
134        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
135            .edge_types
136            .iter()
137            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
138            .collect();
139
140        let label_cardinalities: std::collections::HashMap<String, u64> = stats
141            .labels
142            .iter()
143            .map(|(name, ls)| (name.clone(), ls.node_count))
144            .collect();
145
146        Self {
147            enable_filter_pushdown: true,
148            enable_join_reorder: true,
149            enable_projection_pushdown: true,
150            cost_model: CostModel::new()
151                .with_avg_fanout(avg_fanout)
152                .with_edge_type_degrees(edge_type_degrees)
153                .with_label_cardinalities(label_cardinalities)
154                .with_graph_totals(stats.total_nodes, stats.total_edges),
155            card_estimator: estimator,
156        }
157    }
158
159    /// Enables or disables filter pushdown.
160    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
161        self.enable_filter_pushdown = enabled;
162        self
163    }
164
165    /// Enables or disables join reordering.
166    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
167        self.enable_join_reorder = enabled;
168        self
169    }
170
171    /// Enables or disables projection pushdown.
172    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
173        self.enable_projection_pushdown = enabled;
174        self
175    }
176
177    /// Sets the cost model.
178    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
179        self.cost_model = cost_model;
180        self
181    }
182
183    /// Sets the cardinality estimator.
184    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
185        self.card_estimator = estimator;
186        self
187    }
188
189    /// Sets the selectivity configuration for the cardinality estimator.
190    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
191        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
192        self
193    }
194
195    /// Returns a reference to the cost model.
196    pub fn cost_model(&self) -> &CostModel {
197        &self.cost_model
198    }
199
200    /// Returns a reference to the cardinality estimator.
201    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
202        &self.card_estimator
203    }
204
205    /// Estimates the total cost of a plan by recursively costing the entire tree.
206    ///
207    /// Walks every operator in the plan, computing cardinality at each level
208    /// and summing the per-operator costs. Uses actual child cardinalities
209    /// for join cost estimation rather than approximations.
210    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
211        self.cost_model
212            .estimate_tree(&plan.root, &self.card_estimator)
213    }
214
215    /// Estimates the cardinality of a plan.
216    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
217        self.card_estimator.estimate(&plan.root)
218    }
219
220    /// Optimizes a logical plan.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if optimization fails.
225    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
226        let _span = tracing::debug_span!("grafeo::query::optimize").entered();
227        let mut root = plan.root;
228
229        // Apply optimization rules
230        if self.enable_filter_pushdown {
231            root = self.push_filters_down(root);
232        }
233
234        if self.enable_join_reorder {
235            root = self.reorder_joins(root);
236        }
237
238        if self.enable_projection_pushdown {
239            root = self.push_projections_down(root);
240        }
241
242        Ok(LogicalPlan {
243            root,
244            explain: plan.explain,
245            profile: plan.profile,
246        })
247    }
248
249    /// Pushes projections down the operator tree to eliminate unused columns early.
250    ///
251    /// This optimization:
252    /// 1. Collects required variables/properties from the root
253    /// 2. Propagates requirements down through the tree
254    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
255    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
256        // Collect required columns from the top of the plan
257        let required = self.collect_required_columns(&op);
258
259        // Push projections down
260        self.push_projections_recursive(op, &required)
261    }
262
263    /// Collects all variables and properties required by an operator and its ancestors.
264    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
265        let mut required = HashSet::new();
266        Self::collect_required_recursive(op, &mut required);
267        required
268    }
269
270    /// Recursively collects required columns.
271    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
272        match op {
273            LogicalOperator::Return(ret) => {
274                for item in &ret.items {
275                    Self::collect_from_expression(&item.expression, required);
276                }
277                Self::collect_required_recursive(&ret.input, required);
278            }
279            LogicalOperator::Project(proj) => {
280                for p in &proj.projections {
281                    Self::collect_from_expression(&p.expression, required);
282                }
283                Self::collect_required_recursive(&proj.input, required);
284            }
285            LogicalOperator::Filter(filter) => {
286                Self::collect_from_expression(&filter.predicate, required);
287                Self::collect_required_recursive(&filter.input, required);
288            }
289            LogicalOperator::Sort(sort) => {
290                for key in &sort.keys {
291                    Self::collect_from_expression(&key.expression, required);
292                }
293                Self::collect_required_recursive(&sort.input, required);
294            }
295            LogicalOperator::Aggregate(agg) => {
296                for expr in &agg.group_by {
297                    Self::collect_from_expression(expr, required);
298                }
299                for agg_expr in &agg.aggregates {
300                    if let Some(ref expr) = agg_expr.expression {
301                        Self::collect_from_expression(expr, required);
302                    }
303                }
304                if let Some(ref having) = agg.having {
305                    Self::collect_from_expression(having, required);
306                }
307                Self::collect_required_recursive(&agg.input, required);
308            }
309            LogicalOperator::Join(join) => {
310                for cond in &join.conditions {
311                    Self::collect_from_expression(&cond.left, required);
312                    Self::collect_from_expression(&cond.right, required);
313                }
314                Self::collect_required_recursive(&join.left, required);
315                Self::collect_required_recursive(&join.right, required);
316            }
317            LogicalOperator::Expand(expand) => {
318                // The source and target variables are needed
319                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
320                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
321                if let Some(ref edge_var) = expand.edge_variable {
322                    required.insert(RequiredColumn::Variable(edge_var.clone()));
323                }
324                Self::collect_required_recursive(&expand.input, required);
325            }
326            LogicalOperator::Limit(limit) => {
327                Self::collect_required_recursive(&limit.input, required);
328            }
329            LogicalOperator::Skip(skip) => {
330                Self::collect_required_recursive(&skip.input, required);
331            }
332            LogicalOperator::Distinct(distinct) => {
333                Self::collect_required_recursive(&distinct.input, required);
334            }
335            LogicalOperator::NodeScan(scan) => {
336                required.insert(RequiredColumn::Variable(scan.variable.clone()));
337            }
338            LogicalOperator::EdgeScan(scan) => {
339                required.insert(RequiredColumn::Variable(scan.variable.clone()));
340            }
341            LogicalOperator::MultiWayJoin(mwj) => {
342                for cond in &mwj.conditions {
343                    Self::collect_from_expression(&cond.left, required);
344                    Self::collect_from_expression(&cond.right, required);
345                }
346                for input in &mwj.inputs {
347                    Self::collect_required_recursive(input, required);
348                }
349            }
350            _ => {}
351        }
352    }
353
354    /// Collects required columns from an expression.
355    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
356        match expr {
357            LogicalExpression::Variable(var) => {
358                required.insert(RequiredColumn::Variable(var.clone()));
359            }
360            LogicalExpression::Property { variable, property } => {
361                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
362                required.insert(RequiredColumn::Variable(variable.clone()));
363            }
364            LogicalExpression::Binary { left, right, .. } => {
365                Self::collect_from_expression(left, required);
366                Self::collect_from_expression(right, required);
367            }
368            LogicalExpression::Unary { operand, .. } => {
369                Self::collect_from_expression(operand, required);
370            }
371            LogicalExpression::FunctionCall { args, .. } => {
372                for arg in args {
373                    Self::collect_from_expression(arg, required);
374                }
375            }
376            LogicalExpression::List(items) => {
377                for item in items {
378                    Self::collect_from_expression(item, required);
379                }
380            }
381            LogicalExpression::Map(pairs) => {
382                for (_, value) in pairs {
383                    Self::collect_from_expression(value, required);
384                }
385            }
386            LogicalExpression::IndexAccess { base, index } => {
387                Self::collect_from_expression(base, required);
388                Self::collect_from_expression(index, required);
389            }
390            LogicalExpression::SliceAccess { base, start, end } => {
391                Self::collect_from_expression(base, required);
392                if let Some(s) = start {
393                    Self::collect_from_expression(s, required);
394                }
395                if let Some(e) = end {
396                    Self::collect_from_expression(e, required);
397                }
398            }
399            LogicalExpression::Case {
400                operand,
401                when_clauses,
402                else_clause,
403            } => {
404                if let Some(op) = operand {
405                    Self::collect_from_expression(op, required);
406                }
407                for (cond, result) in when_clauses {
408                    Self::collect_from_expression(cond, required);
409                    Self::collect_from_expression(result, required);
410                }
411                if let Some(else_expr) = else_clause {
412                    Self::collect_from_expression(else_expr, required);
413                }
414            }
415            LogicalExpression::Labels(var)
416            | LogicalExpression::Type(var)
417            | LogicalExpression::Id(var) => {
418                required.insert(RequiredColumn::Variable(var.clone()));
419            }
420            LogicalExpression::ListComprehension {
421                list_expr,
422                filter_expr,
423                map_expr,
424                ..
425            } => {
426                Self::collect_from_expression(list_expr, required);
427                if let Some(filter) = filter_expr {
428                    Self::collect_from_expression(filter, required);
429                }
430                Self::collect_from_expression(map_expr, required);
431            }
432            _ => {}
433        }
434    }
435
436    /// Recursively pushes projections down, adding them before expensive operations.
437    fn push_projections_recursive(
438        &self,
439        op: LogicalOperator,
440        required: &HashSet<RequiredColumn>,
441    ) -> LogicalOperator {
442        match op {
443            LogicalOperator::Return(mut ret) => {
444                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
445                LogicalOperator::Return(ret)
446            }
447            LogicalOperator::Project(mut proj) => {
448                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
449                LogicalOperator::Project(proj)
450            }
451            LogicalOperator::Filter(mut filter) => {
452                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
453                LogicalOperator::Filter(filter)
454            }
455            LogicalOperator::Sort(mut sort) => {
456                // Sort is expensive - consider adding a projection before it
457                // to reduce tuple width
458                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
459                LogicalOperator::Sort(sort)
460            }
461            LogicalOperator::Aggregate(mut agg) => {
462                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
463                LogicalOperator::Aggregate(agg)
464            }
465            LogicalOperator::Join(mut join) => {
466                // Joins are expensive - the required columns help determine
467                // what to project on each side
468                let left_vars = self.collect_output_variables(&join.left);
469                let right_vars = self.collect_output_variables(&join.right);
470
471                // Filter required columns to each side
472                let left_required: HashSet<_> = required
473                    .iter()
474                    .filter(|c| match c {
475                        RequiredColumn::Variable(v) => left_vars.contains(v),
476                        RequiredColumn::Property(v, _) => left_vars.contains(v),
477                    })
478                    .cloned()
479                    .collect();
480
481                let right_required: HashSet<_> = required
482                    .iter()
483                    .filter(|c| match c {
484                        RequiredColumn::Variable(v) => right_vars.contains(v),
485                        RequiredColumn::Property(v, _) => right_vars.contains(v),
486                    })
487                    .cloned()
488                    .collect();
489
490                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
491                join.right =
492                    Box::new(self.push_projections_recursive(*join.right, &right_required));
493                LogicalOperator::Join(join)
494            }
495            LogicalOperator::Expand(mut expand) => {
496                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
497                LogicalOperator::Expand(expand)
498            }
499            LogicalOperator::Limit(mut limit) => {
500                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
501                LogicalOperator::Limit(limit)
502            }
503            LogicalOperator::Skip(mut skip) => {
504                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
505                LogicalOperator::Skip(skip)
506            }
507            LogicalOperator::Distinct(mut distinct) => {
508                distinct.input =
509                    Box::new(self.push_projections_recursive(*distinct.input, required));
510                LogicalOperator::Distinct(distinct)
511            }
512            LogicalOperator::MapCollect(mut mc) => {
513                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
514                LogicalOperator::MapCollect(mc)
515            }
516            LogicalOperator::MultiWayJoin(mut mwj) => {
517                mwj.inputs = mwj
518                    .inputs
519                    .into_iter()
520                    .map(|input| self.push_projections_recursive(input, required))
521                    .collect();
522                LogicalOperator::MultiWayJoin(mwj)
523            }
524            other => other,
525        }
526    }
527
528    /// Reorders joins in the operator tree using the DPccp algorithm.
529    ///
530    /// This optimization finds the optimal join order by:
531    /// 1. Extracting all base relations (scans) and join conditions
532    /// 2. Building a join graph
533    /// 3. Using dynamic programming to find the cheapest join order
534    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
535        // First, recursively optimize children
536        let op = self.reorder_joins_recursive(op);
537
538        // Then, if this is a join tree, try to optimize it
539        if let Some((relations, conditions)) = self.extract_join_tree(&op)
540            && relations.len() >= 2
541            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
542        {
543            return optimized;
544        }
545
546        op
547    }
548
549    /// Recursively applies join reordering to child operators.
550    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
551        match op {
552            LogicalOperator::Return(mut ret) => {
553                ret.input = Box::new(self.reorder_joins(*ret.input));
554                LogicalOperator::Return(ret)
555            }
556            LogicalOperator::Project(mut proj) => {
557                proj.input = Box::new(self.reorder_joins(*proj.input));
558                LogicalOperator::Project(proj)
559            }
560            LogicalOperator::Filter(mut filter) => {
561                filter.input = Box::new(self.reorder_joins(*filter.input));
562                LogicalOperator::Filter(filter)
563            }
564            LogicalOperator::Limit(mut limit) => {
565                limit.input = Box::new(self.reorder_joins(*limit.input));
566                LogicalOperator::Limit(limit)
567            }
568            LogicalOperator::Skip(mut skip) => {
569                skip.input = Box::new(self.reorder_joins(*skip.input));
570                LogicalOperator::Skip(skip)
571            }
572            LogicalOperator::Sort(mut sort) => {
573                sort.input = Box::new(self.reorder_joins(*sort.input));
574                LogicalOperator::Sort(sort)
575            }
576            LogicalOperator::Distinct(mut distinct) => {
577                distinct.input = Box::new(self.reorder_joins(*distinct.input));
578                LogicalOperator::Distinct(distinct)
579            }
580            LogicalOperator::Aggregate(mut agg) => {
581                agg.input = Box::new(self.reorder_joins(*agg.input));
582                LogicalOperator::Aggregate(agg)
583            }
584            LogicalOperator::Expand(mut expand) => {
585                expand.input = Box::new(self.reorder_joins(*expand.input));
586                LogicalOperator::Expand(expand)
587            }
588            LogicalOperator::MapCollect(mut mc) => {
589                mc.input = Box::new(self.reorder_joins(*mc.input));
590                LogicalOperator::MapCollect(mc)
591            }
592            LogicalOperator::MultiWayJoin(mut mwj) => {
593                mwj.inputs = mwj
594                    .inputs
595                    .into_iter()
596                    .map(|input| self.reorder_joins(input))
597                    .collect();
598                LogicalOperator::MultiWayJoin(mwj)
599            }
600            // Join operators are handled by the parent reorder_joins call
601            other => other,
602        }
603    }
604
605    /// Extracts base relations and join conditions from a join tree.
606    ///
607    /// Returns None if the operator is not a join tree.
608    fn extract_join_tree(
609        &self,
610        op: &LogicalOperator,
611    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
612        let mut relations = Vec::new();
613        let mut join_conditions = Vec::new();
614
615        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
616            return None;
617        }
618
619        if relations.len() < 2 {
620            return None;
621        }
622
623        Some((relations, join_conditions))
624    }
625
626    /// Recursively collects base relations and join conditions.
627    ///
628    /// Returns true if this subtree is part of a join tree.
629    fn collect_join_tree(
630        &self,
631        op: &LogicalOperator,
632        relations: &mut Vec<(String, LogicalOperator)>,
633        conditions: &mut Vec<JoinInfo>,
634    ) -> bool {
635        match op {
636            LogicalOperator::Join(join) => {
637                // Collect from both sides
638                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
639                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
640
641                // Add conditions from this join
642                for cond in &join.conditions {
643                    if let (Some(left_var), Some(right_var)) = (
644                        self.extract_variable_from_expr(&cond.left),
645                        self.extract_variable_from_expr(&cond.right),
646                    ) {
647                        conditions.push(JoinInfo {
648                            left_var,
649                            right_var,
650                            left_expr: cond.left.clone(),
651                            right_expr: cond.right.clone(),
652                        });
653                    }
654                }
655
656                left_ok && right_ok
657            }
658            LogicalOperator::NodeScan(scan) => {
659                relations.push((scan.variable.clone(), op.clone()));
660                true
661            }
662            LogicalOperator::EdgeScan(scan) => {
663                relations.push((scan.variable.clone(), op.clone()));
664                true
665            }
666            LogicalOperator::Filter(filter) => {
667                // A filter on a base relation is still part of the join tree
668                self.collect_join_tree(&filter.input, relations, conditions)
669            }
670            LogicalOperator::Expand(expand) => {
671                // Expand is a special case - it's like a join with the adjacency
672                // For now, treat the whole Expand subtree as a single relation
673                relations.push((expand.to_variable.clone(), op.clone()));
674                true
675            }
676            _ => false,
677        }
678    }
679
680    /// Extracts the primary variable from an expression.
681    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
682        match expr {
683            LogicalExpression::Variable(v) => Some(v.clone()),
684            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
685            _ => None,
686        }
687    }
688
689    /// Optimizes the join order using DPccp, or produces a multi-way
690    /// leapfrog join for cyclic patterns when the cost model prefers it.
691    fn optimize_join_order(
692        &self,
693        relations: &[(String, LogicalOperator)],
694        conditions: &[JoinInfo],
695    ) -> Option<LogicalOperator> {
696        use join_order::{DPccp, JoinGraphBuilder};
697
698        // Build the join graph
699        let mut builder = JoinGraphBuilder::new();
700
701        for (var, relation) in relations {
702            builder.add_relation(var, relation.clone());
703        }
704
705        for cond in conditions {
706            builder.add_join_condition(
707                &cond.left_var,
708                &cond.right_var,
709                cond.left_expr.clone(),
710                cond.right_expr.clone(),
711            );
712        }
713
714        let graph = builder.build();
715
716        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
717        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
718        // multi-way intersection rather than binary hash join cascades that can
719        // produce intermediate blowup.
720        if graph.is_cyclic() && relations.len() >= 3 {
721            // Collect shared variables (variables appearing in 2+ conditions)
722            let mut var_counts: std::collections::HashMap<&str, usize> =
723                std::collections::HashMap::new();
724            for cond in conditions {
725                *var_counts.entry(&cond.left_var).or_default() += 1;
726                *var_counts.entry(&cond.right_var).or_default() += 1;
727            }
728            let shared_variables: Vec<String> = var_counts
729                .into_iter()
730                .filter(|(_, count)| *count >= 2)
731                .map(|(var, _)| var.to_string())
732                .collect();
733
734            let join_conditions: Vec<JoinCondition> = conditions
735                .iter()
736                .map(|c| JoinCondition {
737                    left: c.left_expr.clone(),
738                    right: c.right_expr.clone(),
739                })
740                .collect();
741
742            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
743                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
744                conditions: join_conditions,
745                shared_variables,
746            }));
747        }
748
749        // Fall through to DPccp for binary join ordering
750        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
751        let plan = dpccp.optimize()?;
752
753        Some(plan.operator)
754    }
755
756    /// Pushes filters down the operator tree.
757    ///
758    /// This optimization moves filter predicates as close to the data source
759    /// as possible to reduce the amount of data processed by upper operators.
760    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
761        match op {
762            // For Filter operators, try to push the predicate into the child
763            LogicalOperator::Filter(filter) => {
764                let optimized_input = self.push_filters_down(*filter.input);
765                self.try_push_filter_into(filter.predicate, optimized_input)
766            }
767            // Recursively optimize children for other operators
768            LogicalOperator::Return(mut ret) => {
769                ret.input = Box::new(self.push_filters_down(*ret.input));
770                LogicalOperator::Return(ret)
771            }
772            LogicalOperator::Project(mut proj) => {
773                proj.input = Box::new(self.push_filters_down(*proj.input));
774                LogicalOperator::Project(proj)
775            }
776            LogicalOperator::Limit(mut limit) => {
777                limit.input = Box::new(self.push_filters_down(*limit.input));
778                LogicalOperator::Limit(limit)
779            }
780            LogicalOperator::Skip(mut skip) => {
781                skip.input = Box::new(self.push_filters_down(*skip.input));
782                LogicalOperator::Skip(skip)
783            }
784            LogicalOperator::Sort(mut sort) => {
785                sort.input = Box::new(self.push_filters_down(*sort.input));
786                LogicalOperator::Sort(sort)
787            }
788            LogicalOperator::Distinct(mut distinct) => {
789                distinct.input = Box::new(self.push_filters_down(*distinct.input));
790                LogicalOperator::Distinct(distinct)
791            }
792            LogicalOperator::Expand(mut expand) => {
793                expand.input = Box::new(self.push_filters_down(*expand.input));
794                LogicalOperator::Expand(expand)
795            }
796            LogicalOperator::Join(mut join) => {
797                join.left = Box::new(self.push_filters_down(*join.left));
798                join.right = Box::new(self.push_filters_down(*join.right));
799                LogicalOperator::Join(join)
800            }
801            LogicalOperator::Aggregate(mut agg) => {
802                agg.input = Box::new(self.push_filters_down(*agg.input));
803                LogicalOperator::Aggregate(agg)
804            }
805            LogicalOperator::MapCollect(mut mc) => {
806                mc.input = Box::new(self.push_filters_down(*mc.input));
807                LogicalOperator::MapCollect(mc)
808            }
809            LogicalOperator::MultiWayJoin(mut mwj) => {
810                mwj.inputs = mwj
811                    .inputs
812                    .into_iter()
813                    .map(|input| self.push_filters_down(input))
814                    .collect();
815                LogicalOperator::MultiWayJoin(mwj)
816            }
817            // Leaf operators and unsupported operators are returned as-is
818            other => other,
819        }
820    }
821
822    /// Tries to push a filter predicate into the given operator.
823    ///
824    /// Returns either the predicate pushed into the operator, or a new
825    /// Filter operator on top if the predicate cannot be pushed further.
826    fn try_push_filter_into(
827        &self,
828        predicate: LogicalExpression,
829        op: LogicalOperator,
830    ) -> LogicalOperator {
831        match op {
832            // Can push through Project if predicate doesn't depend on computed columns
833            LogicalOperator::Project(mut proj) => {
834                let predicate_vars = self.extract_variables(&predicate);
835                let computed_vars = self.extract_projection_aliases(&proj.projections);
836
837                // If predicate doesn't use any computed columns, push through
838                if predicate_vars.is_disjoint(&computed_vars) {
839                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
840                    LogicalOperator::Project(proj)
841                } else {
842                    // Can't push through, keep filter on top
843                    LogicalOperator::Filter(FilterOp {
844                        predicate,
845                        pushdown_hint: None,
846                        input: Box::new(LogicalOperator::Project(proj)),
847                    })
848                }
849            }
850
851            // Can push through Return (which is like a projection)
852            LogicalOperator::Return(mut ret) => {
853                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
854                LogicalOperator::Return(ret)
855            }
856
857            // Can push through Expand if predicate doesn't use variables introduced by this expand
858            LogicalOperator::Expand(mut expand) => {
859                let predicate_vars = self.extract_variables(&predicate);
860
861                // Variables introduced by this expand are:
862                // - The target variable (to_variable)
863                // - The edge variable (if any)
864                // - The path alias (if any)
865                let mut introduced_vars = vec![&expand.to_variable];
866                if let Some(ref edge_var) = expand.edge_variable {
867                    introduced_vars.push(edge_var);
868                }
869                if let Some(ref path_alias) = expand.path_alias {
870                    introduced_vars.push(path_alias);
871                }
872
873                // Check if predicate uses any variables introduced by this expand
874                let uses_introduced_vars =
875                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
876
877                if !uses_introduced_vars {
878                    // Predicate doesn't use vars from this expand, so push through
879                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
880                    LogicalOperator::Expand(expand)
881                } else {
882                    // Keep filter after expand
883                    LogicalOperator::Filter(FilterOp {
884                        predicate,
885                        pushdown_hint: None,
886                        input: Box::new(LogicalOperator::Expand(expand)),
887                    })
888                }
889            }
890
891            // Can push through Join to left/right side based on variables used
892            LogicalOperator::Join(mut join) => {
893                let predicate_vars = self.extract_variables(&predicate);
894                let left_vars = self.collect_output_variables(&join.left);
895                let right_vars = self.collect_output_variables(&join.right);
896
897                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
898                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
899
900                if uses_left && !uses_right {
901                    // Push to left side
902                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
903                    LogicalOperator::Join(join)
904                } else if uses_right && !uses_left {
905                    // Push to right side
906                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
907                    LogicalOperator::Join(join)
908                } else {
909                    // Uses both sides - keep above join
910                    LogicalOperator::Filter(FilterOp {
911                        predicate,
912                        pushdown_hint: None,
913                        input: Box::new(LogicalOperator::Join(join)),
914                    })
915                }
916            }
917
918            // Cannot push through Aggregate (predicate refers to aggregated values)
919            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
920                predicate,
921                pushdown_hint: None,
922                input: Box::new(LogicalOperator::Aggregate(agg)),
923            }),
924
925            // For NodeScan, we've reached the bottom - keep filter on top
926            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
927                predicate,
928                pushdown_hint: None,
929                input: Box::new(LogicalOperator::NodeScan(scan)),
930            }),
931
932            // For other operators, keep filter on top
933            other => LogicalOperator::Filter(FilterOp {
934                predicate,
935                pushdown_hint: None,
936                input: Box::new(other),
937            }),
938        }
939    }
940
941    /// Collects all output variable names from an operator.
942    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
943        let mut vars = HashSet::new();
944        Self::collect_output_variables_recursive(op, &mut vars);
945        vars
946    }
947
948    /// Recursively collects output variables from an operator.
949    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
950        match op {
951            LogicalOperator::NodeScan(scan) => {
952                vars.insert(scan.variable.clone());
953            }
954            LogicalOperator::EdgeScan(scan) => {
955                vars.insert(scan.variable.clone());
956            }
957            LogicalOperator::Expand(expand) => {
958                vars.insert(expand.to_variable.clone());
959                if let Some(edge_var) = &expand.edge_variable {
960                    vars.insert(edge_var.clone());
961                }
962                Self::collect_output_variables_recursive(&expand.input, vars);
963            }
964            LogicalOperator::Filter(filter) => {
965                Self::collect_output_variables_recursive(&filter.input, vars);
966            }
967            LogicalOperator::Project(proj) => {
968                for p in &proj.projections {
969                    if let Some(alias) = &p.alias {
970                        vars.insert(alias.clone());
971                    }
972                }
973                Self::collect_output_variables_recursive(&proj.input, vars);
974            }
975            LogicalOperator::Join(join) => {
976                Self::collect_output_variables_recursive(&join.left, vars);
977                Self::collect_output_variables_recursive(&join.right, vars);
978            }
979            LogicalOperator::Aggregate(agg) => {
980                for expr in &agg.group_by {
981                    Self::collect_variables(expr, vars);
982                }
983                for agg_expr in &agg.aggregates {
984                    if let Some(alias) = &agg_expr.alias {
985                        vars.insert(alias.clone());
986                    }
987                }
988            }
989            LogicalOperator::Return(ret) => {
990                Self::collect_output_variables_recursive(&ret.input, vars);
991            }
992            LogicalOperator::Limit(limit) => {
993                Self::collect_output_variables_recursive(&limit.input, vars);
994            }
995            LogicalOperator::Skip(skip) => {
996                Self::collect_output_variables_recursive(&skip.input, vars);
997            }
998            LogicalOperator::Sort(sort) => {
999                Self::collect_output_variables_recursive(&sort.input, vars);
1000            }
1001            LogicalOperator::Distinct(distinct) => {
1002                Self::collect_output_variables_recursive(&distinct.input, vars);
1003            }
1004            _ => {}
1005        }
1006    }
1007
1008    /// Extracts all variable names referenced in an expression.
1009    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1010        let mut vars = HashSet::new();
1011        Self::collect_variables(expr, &mut vars);
1012        vars
1013    }
1014
1015    /// Recursively collects variable names from an expression.
1016    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1017        match expr {
1018            LogicalExpression::Variable(name) => {
1019                vars.insert(name.clone());
1020            }
1021            LogicalExpression::Property { variable, .. } => {
1022                vars.insert(variable.clone());
1023            }
1024            LogicalExpression::Binary { left, right, .. } => {
1025                Self::collect_variables(left, vars);
1026                Self::collect_variables(right, vars);
1027            }
1028            LogicalExpression::Unary { operand, .. } => {
1029                Self::collect_variables(operand, vars);
1030            }
1031            LogicalExpression::FunctionCall { args, .. } => {
1032                for arg in args {
1033                    Self::collect_variables(arg, vars);
1034                }
1035            }
1036            LogicalExpression::List(items) => {
1037                for item in items {
1038                    Self::collect_variables(item, vars);
1039                }
1040            }
1041            LogicalExpression::Map(pairs) => {
1042                for (_, value) in pairs {
1043                    Self::collect_variables(value, vars);
1044                }
1045            }
1046            LogicalExpression::IndexAccess { base, index } => {
1047                Self::collect_variables(base, vars);
1048                Self::collect_variables(index, vars);
1049            }
1050            LogicalExpression::SliceAccess { base, start, end } => {
1051                Self::collect_variables(base, vars);
1052                if let Some(s) = start {
1053                    Self::collect_variables(s, vars);
1054                }
1055                if let Some(e) = end {
1056                    Self::collect_variables(e, vars);
1057                }
1058            }
1059            LogicalExpression::Case {
1060                operand,
1061                when_clauses,
1062                else_clause,
1063            } => {
1064                if let Some(op) = operand {
1065                    Self::collect_variables(op, vars);
1066                }
1067                for (cond, result) in when_clauses {
1068                    Self::collect_variables(cond, vars);
1069                    Self::collect_variables(result, vars);
1070                }
1071                if let Some(else_expr) = else_clause {
1072                    Self::collect_variables(else_expr, vars);
1073                }
1074            }
1075            LogicalExpression::Labels(var)
1076            | LogicalExpression::Type(var)
1077            | LogicalExpression::Id(var) => {
1078                vars.insert(var.clone());
1079            }
1080            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1081            LogicalExpression::ListComprehension {
1082                list_expr,
1083                filter_expr,
1084                map_expr,
1085                ..
1086            } => {
1087                Self::collect_variables(list_expr, vars);
1088                if let Some(filter) = filter_expr {
1089                    Self::collect_variables(filter, vars);
1090                }
1091                Self::collect_variables(map_expr, vars);
1092            }
1093            LogicalExpression::ListPredicate {
1094                list_expr,
1095                predicate,
1096                ..
1097            } => {
1098                Self::collect_variables(list_expr, vars);
1099                Self::collect_variables(predicate, vars);
1100            }
1101            LogicalExpression::ExistsSubquery(_)
1102            | LogicalExpression::CountSubquery(_)
1103            | LogicalExpression::ValueSubquery(_) => {
1104                // Subqueries have their own variable scope
1105            }
1106            LogicalExpression::PatternComprehension { projection, .. } => {
1107                Self::collect_variables(projection, vars);
1108            }
1109            LogicalExpression::MapProjection { base, entries } => {
1110                vars.insert(base.clone());
1111                for entry in entries {
1112                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1113                        Self::collect_variables(expr, vars);
1114                    }
1115                }
1116            }
1117            LogicalExpression::Reduce {
1118                initial,
1119                list,
1120                expression,
1121                ..
1122            } => {
1123                Self::collect_variables(initial, vars);
1124                Self::collect_variables(list, vars);
1125                Self::collect_variables(expression, vars);
1126            }
1127        }
1128    }
1129
1130    /// Extracts aliases from projection expressions.
1131    fn extract_projection_aliases(
1132        &self,
1133        projections: &[crate::query::plan::Projection],
1134    ) -> HashSet<String> {
1135        projections.iter().filter_map(|p| p.alias.clone()).collect()
1136    }
1137}
1138
1139impl Default for Optimizer {
1140    fn default() -> Self {
1141        Self::new()
1142    }
1143}
1144
1145#[cfg(test)]
1146mod tests {
1147    use super::*;
1148    use crate::query::plan::{
1149        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1150        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1151        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1152    };
1153    use grafeo_common::types::Value;
1154
1155    #[test]
1156    fn test_optimizer_filter_pushdown_simple() {
1157        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1158        // Before: Return -> Filter -> NodeScan
1159        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1160
1161        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1162            items: vec![ReturnItem {
1163                expression: LogicalExpression::Variable("n".to_string()),
1164                alias: None,
1165            }],
1166            distinct: false,
1167            input: Box::new(LogicalOperator::Filter(FilterOp {
1168                predicate: LogicalExpression::Binary {
1169                    left: Box::new(LogicalExpression::Property {
1170                        variable: "n".to_string(),
1171                        property: "age".to_string(),
1172                    }),
1173                    op: BinaryOp::Gt,
1174                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1175                },
1176                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1177                    variable: "n".to_string(),
1178                    label: Some("Person".to_string()),
1179                    input: None,
1180                })),
1181                pushdown_hint: None,
1182            })),
1183        }));
1184
1185        let optimizer = Optimizer::new();
1186        let optimized = optimizer.optimize(plan).unwrap();
1187
1188        // The structure should remain similar (filter stays near scan)
1189        if let LogicalOperator::Return(ret) = &optimized.root
1190            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1191            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1192        {
1193            assert_eq!(scan.variable, "n");
1194            return;
1195        }
1196        panic!("Expected Return -> Filter -> NodeScan structure");
1197    }
1198
1199    #[test]
1200    fn test_optimizer_filter_pushdown_through_expand() {
1201        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1202        // The filter on 'a' should be pushed before the expand
1203
1204        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1205            items: vec![ReturnItem {
1206                expression: LogicalExpression::Variable("b".to_string()),
1207                alias: None,
1208            }],
1209            distinct: false,
1210            input: Box::new(LogicalOperator::Filter(FilterOp {
1211                predicate: LogicalExpression::Binary {
1212                    left: Box::new(LogicalExpression::Property {
1213                        variable: "a".to_string(),
1214                        property: "age".to_string(),
1215                    }),
1216                    op: BinaryOp::Gt,
1217                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1218                },
1219                pushdown_hint: None,
1220                input: Box::new(LogicalOperator::Expand(ExpandOp {
1221                    from_variable: "a".to_string(),
1222                    to_variable: "b".to_string(),
1223                    edge_variable: None,
1224                    direction: ExpandDirection::Outgoing,
1225                    edge_types: vec!["KNOWS".to_string()],
1226                    min_hops: 1,
1227                    max_hops: Some(1),
1228                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1229                        variable: "a".to_string(),
1230                        label: Some("Person".to_string()),
1231                        input: None,
1232                    })),
1233                    path_alias: None,
1234                    path_mode: PathMode::Walk,
1235                })),
1236            })),
1237        }));
1238
1239        let optimizer = Optimizer::new();
1240        let optimized = optimizer.optimize(plan).unwrap();
1241
1242        // Filter on 'a' should be pushed before the expand
1243        // Expected: Return -> Expand -> Filter -> NodeScan
1244        if let LogicalOperator::Return(ret) = &optimized.root
1245            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1246            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1247            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1248        {
1249            assert_eq!(scan.variable, "a");
1250            assert_eq!(expand.from_variable, "a");
1251            assert_eq!(expand.to_variable, "b");
1252            return;
1253        }
1254        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1255    }
1256
1257    #[test]
1258    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1259        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1260        // The filter on 'b' should NOT be pushed before the expand
1261
1262        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1263            items: vec![ReturnItem {
1264                expression: LogicalExpression::Variable("a".to_string()),
1265                alias: None,
1266            }],
1267            distinct: false,
1268            input: Box::new(LogicalOperator::Filter(FilterOp {
1269                predicate: LogicalExpression::Binary {
1270                    left: Box::new(LogicalExpression::Property {
1271                        variable: "b".to_string(),
1272                        property: "age".to_string(),
1273                    }),
1274                    op: BinaryOp::Gt,
1275                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1276                },
1277                pushdown_hint: None,
1278                input: Box::new(LogicalOperator::Expand(ExpandOp {
1279                    from_variable: "a".to_string(),
1280                    to_variable: "b".to_string(),
1281                    edge_variable: None,
1282                    direction: ExpandDirection::Outgoing,
1283                    edge_types: vec!["KNOWS".to_string()],
1284                    min_hops: 1,
1285                    max_hops: Some(1),
1286                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1287                        variable: "a".to_string(),
1288                        label: Some("Person".to_string()),
1289                        input: None,
1290                    })),
1291                    path_alias: None,
1292                    path_mode: PathMode::Walk,
1293                })),
1294            })),
1295        }));
1296
1297        let optimizer = Optimizer::new();
1298        let optimized = optimizer.optimize(plan).unwrap();
1299
1300        // Filter on 'b' should stay after the expand
1301        // Expected: Return -> Filter -> Expand -> NodeScan
1302        if let LogicalOperator::Return(ret) = &optimized.root
1303            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1304        {
1305            // Check that the filter is on 'b'
1306            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1307                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1308            {
1309                assert_eq!(variable, "b");
1310            }
1311
1312            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1313                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1314            {
1315                return;
1316            }
1317        }
1318        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1319    }
1320
1321    #[test]
1322    fn test_optimizer_extract_variables() {
1323        let optimizer = Optimizer::new();
1324
1325        let expr = LogicalExpression::Binary {
1326            left: Box::new(LogicalExpression::Property {
1327                variable: "n".to_string(),
1328                property: "age".to_string(),
1329            }),
1330            op: BinaryOp::Gt,
1331            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1332        };
1333
1334        let vars = optimizer.extract_variables(&expr);
1335        assert_eq!(vars.len(), 1);
1336        assert!(vars.contains("n"));
1337    }
1338
1339    // Additional tests for optimizer configuration
1340
1341    #[test]
1342    fn test_optimizer_default() {
1343        let optimizer = Optimizer::default();
1344        // Should be able to optimize an empty plan
1345        let plan = LogicalPlan::new(LogicalOperator::Empty);
1346        let result = optimizer.optimize(plan);
1347        assert!(result.is_ok());
1348    }
1349
1350    #[test]
1351    fn test_optimizer_with_filter_pushdown_disabled() {
1352        let optimizer = Optimizer::new().with_filter_pushdown(false);
1353
1354        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1355            items: vec![ReturnItem {
1356                expression: LogicalExpression::Variable("n".to_string()),
1357                alias: None,
1358            }],
1359            distinct: false,
1360            input: Box::new(LogicalOperator::Filter(FilterOp {
1361                predicate: LogicalExpression::Literal(Value::Bool(true)),
1362                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1363                    variable: "n".to_string(),
1364                    label: None,
1365                    input: None,
1366                })),
1367                pushdown_hint: None,
1368            })),
1369        }));
1370
1371        let optimized = optimizer.optimize(plan).unwrap();
1372        // Structure should be unchanged
1373        if let LogicalOperator::Return(ret) = &optimized.root
1374            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1375        {
1376            return;
1377        }
1378        panic!("Expected unchanged structure");
1379    }
1380
1381    #[test]
1382    fn test_optimizer_with_join_reorder_disabled() {
1383        let optimizer = Optimizer::new().with_join_reorder(false);
1384        assert!(
1385            optimizer
1386                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1387                .is_ok()
1388        );
1389    }
1390
1391    #[test]
1392    fn test_optimizer_with_cost_model() {
1393        let cost_model = CostModel::new();
1394        let optimizer = Optimizer::new().with_cost_model(cost_model);
1395        assert!(
1396            optimizer
1397                .cost_model()
1398                .estimate(&LogicalOperator::Empty, 0.0)
1399                .total()
1400                < 0.001
1401        );
1402    }
1403
1404    #[test]
1405    fn test_optimizer_with_cardinality_estimator() {
1406        let mut estimator = CardinalityEstimator::new();
1407        estimator.add_table_stats("Test", TableStats::new(500));
1408        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1409
1410        let scan = LogicalOperator::NodeScan(NodeScanOp {
1411            variable: "n".to_string(),
1412            label: Some("Test".to_string()),
1413            input: None,
1414        });
1415        let plan = LogicalPlan::new(scan);
1416
1417        let cardinality = optimizer.estimate_cardinality(&plan);
1418        assert!((cardinality - 500.0).abs() < 0.001);
1419    }
1420
1421    #[test]
1422    fn test_optimizer_estimate_cost() {
1423        let optimizer = Optimizer::new();
1424        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1425            variable: "n".to_string(),
1426            label: None,
1427            input: None,
1428        }));
1429
1430        let cost = optimizer.estimate_cost(&plan);
1431        assert!(cost.total() > 0.0);
1432    }
1433
1434    // Filter pushdown through various operators
1435
1436    #[test]
1437    fn test_filter_pushdown_through_project() {
1438        let optimizer = Optimizer::new();
1439
1440        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1441            predicate: LogicalExpression::Binary {
1442                left: Box::new(LogicalExpression::Property {
1443                    variable: "n".to_string(),
1444                    property: "age".to_string(),
1445                }),
1446                op: BinaryOp::Gt,
1447                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1448            },
1449            pushdown_hint: None,
1450            input: Box::new(LogicalOperator::Project(ProjectOp {
1451                projections: vec![Projection {
1452                    expression: LogicalExpression::Variable("n".to_string()),
1453                    alias: None,
1454                }],
1455                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1456                    variable: "n".to_string(),
1457                    label: None,
1458                    input: None,
1459                })),
1460                pass_through_input: false,
1461            })),
1462        }));
1463
1464        let optimized = optimizer.optimize(plan).unwrap();
1465
1466        // Filter should be pushed through Project
1467        if let LogicalOperator::Project(proj) = &optimized.root
1468            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1469        {
1470            return;
1471        }
1472        panic!("Expected Project -> Filter structure");
1473    }
1474
1475    #[test]
1476    fn test_filter_not_pushed_through_project_with_alias() {
1477        let optimizer = Optimizer::new();
1478
1479        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1480        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1481            predicate: LogicalExpression::Binary {
1482                left: Box::new(LogicalExpression::Variable("x".to_string())),
1483                op: BinaryOp::Gt,
1484                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1485            },
1486            pushdown_hint: None,
1487            input: Box::new(LogicalOperator::Project(ProjectOp {
1488                projections: vec![Projection {
1489                    expression: LogicalExpression::Property {
1490                        variable: "n".to_string(),
1491                        property: "age".to_string(),
1492                    },
1493                    alias: Some("x".to_string()),
1494                }],
1495                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1496                    variable: "n".to_string(),
1497                    label: None,
1498                    input: None,
1499                })),
1500                pass_through_input: false,
1501            })),
1502        }));
1503
1504        let optimized = optimizer.optimize(plan).unwrap();
1505
1506        // Filter should stay above Project
1507        if let LogicalOperator::Filter(filter) = &optimized.root
1508            && let LogicalOperator::Project(_) = filter.input.as_ref()
1509        {
1510            return;
1511        }
1512        panic!("Expected Filter -> Project structure");
1513    }
1514
1515    #[test]
1516    fn test_filter_pushdown_through_limit() {
1517        let optimizer = Optimizer::new();
1518
1519        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1520            predicate: LogicalExpression::Literal(Value::Bool(true)),
1521            pushdown_hint: None,
1522            input: Box::new(LogicalOperator::Limit(LimitOp {
1523                count: 10.into(),
1524                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1525                    variable: "n".to_string(),
1526                    label: None,
1527                    input: None,
1528                })),
1529            })),
1530        }));
1531
1532        let optimized = optimizer.optimize(plan).unwrap();
1533
1534        // Filter stays above Limit (cannot be pushed through)
1535        if let LogicalOperator::Filter(filter) = &optimized.root
1536            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1537        {
1538            return;
1539        }
1540        panic!("Expected Filter -> Limit structure");
1541    }
1542
1543    #[test]
1544    fn test_filter_pushdown_through_sort() {
1545        let optimizer = Optimizer::new();
1546
1547        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1548            predicate: LogicalExpression::Literal(Value::Bool(true)),
1549            pushdown_hint: None,
1550            input: Box::new(LogicalOperator::Sort(SortOp {
1551                keys: vec![SortKey {
1552                    expression: LogicalExpression::Variable("n".to_string()),
1553                    order: SortOrder::Ascending,
1554                    nulls: None,
1555                }],
1556                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1557                    variable: "n".to_string(),
1558                    label: None,
1559                    input: None,
1560                })),
1561            })),
1562        }));
1563
1564        let optimized = optimizer.optimize(plan).unwrap();
1565
1566        // Filter stays above Sort
1567        if let LogicalOperator::Filter(filter) = &optimized.root
1568            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1569        {
1570            return;
1571        }
1572        panic!("Expected Filter -> Sort structure");
1573    }
1574
1575    #[test]
1576    fn test_filter_pushdown_through_distinct() {
1577        let optimizer = Optimizer::new();
1578
1579        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1580            predicate: LogicalExpression::Literal(Value::Bool(true)),
1581            pushdown_hint: None,
1582            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1583                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1584                    variable: "n".to_string(),
1585                    label: None,
1586                    input: None,
1587                })),
1588                columns: None,
1589            })),
1590        }));
1591
1592        let optimized = optimizer.optimize(plan).unwrap();
1593
1594        // Filter stays above Distinct
1595        if let LogicalOperator::Filter(filter) = &optimized.root
1596            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1597        {
1598            return;
1599        }
1600        panic!("Expected Filter -> Distinct structure");
1601    }
1602
1603    #[test]
1604    fn test_filter_not_pushed_through_aggregate() {
1605        let optimizer = Optimizer::new();
1606
1607        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1608            predicate: LogicalExpression::Binary {
1609                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1610                op: BinaryOp::Gt,
1611                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1612            },
1613            pushdown_hint: None,
1614            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1615                group_by: vec![],
1616                aggregates: vec![AggregateExpr {
1617                    function: AggregateFunction::Count,
1618                    expression: None,
1619                    expression2: None,
1620                    distinct: false,
1621                    alias: Some("cnt".to_string()),
1622                    percentile: None,
1623                    separator: None,
1624                }],
1625                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1626                    variable: "n".to_string(),
1627                    label: None,
1628                    input: None,
1629                })),
1630                having: None,
1631            })),
1632        }));
1633
1634        let optimized = optimizer.optimize(plan).unwrap();
1635
1636        // Filter should stay above Aggregate
1637        if let LogicalOperator::Filter(filter) = &optimized.root
1638            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1639        {
1640            return;
1641        }
1642        panic!("Expected Filter -> Aggregate structure");
1643    }
1644
1645    #[test]
1646    fn test_filter_pushdown_to_left_join_side() {
1647        let optimizer = Optimizer::new();
1648
1649        // Filter on left variable should be pushed to left side
1650        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1651            predicate: LogicalExpression::Binary {
1652                left: Box::new(LogicalExpression::Property {
1653                    variable: "a".to_string(),
1654                    property: "age".to_string(),
1655                }),
1656                op: BinaryOp::Gt,
1657                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1658            },
1659            pushdown_hint: None,
1660            input: Box::new(LogicalOperator::Join(JoinOp {
1661                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1662                    variable: "a".to_string(),
1663                    label: Some("Person".to_string()),
1664                    input: None,
1665                })),
1666                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1667                    variable: "b".to_string(),
1668                    label: Some("Company".to_string()),
1669                    input: None,
1670                })),
1671                join_type: JoinType::Inner,
1672                conditions: vec![],
1673            })),
1674        }));
1675
1676        let optimized = optimizer.optimize(plan).unwrap();
1677
1678        // Filter should be pushed to left side of join
1679        if let LogicalOperator::Join(join) = &optimized.root
1680            && let LogicalOperator::Filter(_) = join.left.as_ref()
1681        {
1682            return;
1683        }
1684        panic!("Expected Join with Filter on left side");
1685    }
1686
1687    #[test]
1688    fn test_filter_pushdown_to_right_join_side() {
1689        let optimizer = Optimizer::new();
1690
1691        // Filter on right variable should be pushed to right side
1692        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1693            predicate: LogicalExpression::Binary {
1694                left: Box::new(LogicalExpression::Property {
1695                    variable: "b".to_string(),
1696                    property: "name".to_string(),
1697                }),
1698                op: BinaryOp::Eq,
1699                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1700            },
1701            pushdown_hint: None,
1702            input: Box::new(LogicalOperator::Join(JoinOp {
1703                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1704                    variable: "a".to_string(),
1705                    label: Some("Person".to_string()),
1706                    input: None,
1707                })),
1708                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1709                    variable: "b".to_string(),
1710                    label: Some("Company".to_string()),
1711                    input: None,
1712                })),
1713                join_type: JoinType::Inner,
1714                conditions: vec![],
1715            })),
1716        }));
1717
1718        let optimized = optimizer.optimize(plan).unwrap();
1719
1720        // Filter should be pushed to right side of join
1721        if let LogicalOperator::Join(join) = &optimized.root
1722            && let LogicalOperator::Filter(_) = join.right.as_ref()
1723        {
1724            return;
1725        }
1726        panic!("Expected Join with Filter on right side");
1727    }
1728
1729    #[test]
1730    fn test_filter_not_pushed_when_uses_both_join_sides() {
1731        let optimizer = Optimizer::new();
1732
1733        // Filter using both variables should stay above join
1734        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1735            predicate: LogicalExpression::Binary {
1736                left: Box::new(LogicalExpression::Property {
1737                    variable: "a".to_string(),
1738                    property: "id".to_string(),
1739                }),
1740                op: BinaryOp::Eq,
1741                right: Box::new(LogicalExpression::Property {
1742                    variable: "b".to_string(),
1743                    property: "a_id".to_string(),
1744                }),
1745            },
1746            pushdown_hint: None,
1747            input: Box::new(LogicalOperator::Join(JoinOp {
1748                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1749                    variable: "a".to_string(),
1750                    label: None,
1751                    input: None,
1752                })),
1753                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1754                    variable: "b".to_string(),
1755                    label: None,
1756                    input: None,
1757                })),
1758                join_type: JoinType::Inner,
1759                conditions: vec![],
1760            })),
1761        }));
1762
1763        let optimized = optimizer.optimize(plan).unwrap();
1764
1765        // Filter should stay above join
1766        if let LogicalOperator::Filter(filter) = &optimized.root
1767            && let LogicalOperator::Join(_) = filter.input.as_ref()
1768        {
1769            return;
1770        }
1771        panic!("Expected Filter -> Join structure");
1772    }
1773
1774    // Variable extraction tests
1775
1776    #[test]
1777    fn test_extract_variables_from_variable() {
1778        let optimizer = Optimizer::new();
1779        let expr = LogicalExpression::Variable("x".to_string());
1780        let vars = optimizer.extract_variables(&expr);
1781        assert_eq!(vars.len(), 1);
1782        assert!(vars.contains("x"));
1783    }
1784
1785    #[test]
1786    fn test_extract_variables_from_unary() {
1787        let optimizer = Optimizer::new();
1788        let expr = LogicalExpression::Unary {
1789            op: UnaryOp::Not,
1790            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1791        };
1792        let vars = optimizer.extract_variables(&expr);
1793        assert_eq!(vars.len(), 1);
1794        assert!(vars.contains("x"));
1795    }
1796
1797    #[test]
1798    fn test_extract_variables_from_function_call() {
1799        let optimizer = Optimizer::new();
1800        let expr = LogicalExpression::FunctionCall {
1801            name: "length".to_string(),
1802            args: vec![
1803                LogicalExpression::Variable("a".to_string()),
1804                LogicalExpression::Variable("b".to_string()),
1805            ],
1806            distinct: false,
1807        };
1808        let vars = optimizer.extract_variables(&expr);
1809        assert_eq!(vars.len(), 2);
1810        assert!(vars.contains("a"));
1811        assert!(vars.contains("b"));
1812    }
1813
1814    #[test]
1815    fn test_extract_variables_from_list() {
1816        let optimizer = Optimizer::new();
1817        let expr = LogicalExpression::List(vec![
1818            LogicalExpression::Variable("a".to_string()),
1819            LogicalExpression::Literal(Value::Int64(1)),
1820            LogicalExpression::Variable("b".to_string()),
1821        ]);
1822        let vars = optimizer.extract_variables(&expr);
1823        assert_eq!(vars.len(), 2);
1824        assert!(vars.contains("a"));
1825        assert!(vars.contains("b"));
1826    }
1827
1828    #[test]
1829    fn test_extract_variables_from_map() {
1830        let optimizer = Optimizer::new();
1831        let expr = LogicalExpression::Map(vec![
1832            (
1833                "key1".to_string(),
1834                LogicalExpression::Variable("a".to_string()),
1835            ),
1836            (
1837                "key2".to_string(),
1838                LogicalExpression::Variable("b".to_string()),
1839            ),
1840        ]);
1841        let vars = optimizer.extract_variables(&expr);
1842        assert_eq!(vars.len(), 2);
1843        assert!(vars.contains("a"));
1844        assert!(vars.contains("b"));
1845    }
1846
1847    #[test]
1848    fn test_extract_variables_from_index_access() {
1849        let optimizer = Optimizer::new();
1850        let expr = LogicalExpression::IndexAccess {
1851            base: Box::new(LogicalExpression::Variable("list".to_string())),
1852            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1853        };
1854        let vars = optimizer.extract_variables(&expr);
1855        assert_eq!(vars.len(), 2);
1856        assert!(vars.contains("list"));
1857        assert!(vars.contains("idx"));
1858    }
1859
1860    #[test]
1861    fn test_extract_variables_from_slice_access() {
1862        let optimizer = Optimizer::new();
1863        let expr = LogicalExpression::SliceAccess {
1864            base: Box::new(LogicalExpression::Variable("list".to_string())),
1865            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1866            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1867        };
1868        let vars = optimizer.extract_variables(&expr);
1869        assert_eq!(vars.len(), 3);
1870        assert!(vars.contains("list"));
1871        assert!(vars.contains("s"));
1872        assert!(vars.contains("e"));
1873    }
1874
1875    #[test]
1876    fn test_extract_variables_from_case() {
1877        let optimizer = Optimizer::new();
1878        let expr = LogicalExpression::Case {
1879            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1880            when_clauses: vec![(
1881                LogicalExpression::Literal(Value::Int64(1)),
1882                LogicalExpression::Variable("a".to_string()),
1883            )],
1884            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1885        };
1886        let vars = optimizer.extract_variables(&expr);
1887        assert_eq!(vars.len(), 3);
1888        assert!(vars.contains("x"));
1889        assert!(vars.contains("a"));
1890        assert!(vars.contains("b"));
1891    }
1892
1893    #[test]
1894    fn test_extract_variables_from_labels() {
1895        let optimizer = Optimizer::new();
1896        let expr = LogicalExpression::Labels("n".to_string());
1897        let vars = optimizer.extract_variables(&expr);
1898        assert_eq!(vars.len(), 1);
1899        assert!(vars.contains("n"));
1900    }
1901
1902    #[test]
1903    fn test_extract_variables_from_type() {
1904        let optimizer = Optimizer::new();
1905        let expr = LogicalExpression::Type("e".to_string());
1906        let vars = optimizer.extract_variables(&expr);
1907        assert_eq!(vars.len(), 1);
1908        assert!(vars.contains("e"));
1909    }
1910
1911    #[test]
1912    fn test_extract_variables_from_id() {
1913        let optimizer = Optimizer::new();
1914        let expr = LogicalExpression::Id("n".to_string());
1915        let vars = optimizer.extract_variables(&expr);
1916        assert_eq!(vars.len(), 1);
1917        assert!(vars.contains("n"));
1918    }
1919
1920    #[test]
1921    fn test_extract_variables_from_list_comprehension() {
1922        let optimizer = Optimizer::new();
1923        let expr = LogicalExpression::ListComprehension {
1924            variable: "x".to_string(),
1925            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1926            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1927            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1928        };
1929        let vars = optimizer.extract_variables(&expr);
1930        assert!(vars.contains("items"));
1931        assert!(vars.contains("pred"));
1932        assert!(vars.contains("result"));
1933    }
1934
1935    #[test]
1936    fn test_extract_variables_from_literal_and_parameter() {
1937        let optimizer = Optimizer::new();
1938
1939        let literal = LogicalExpression::Literal(Value::Int64(42));
1940        assert!(optimizer.extract_variables(&literal).is_empty());
1941
1942        let param = LogicalExpression::Parameter("p".to_string());
1943        assert!(optimizer.extract_variables(&param).is_empty());
1944    }
1945
1946    // Recursive filter pushdown tests
1947
1948    #[test]
1949    fn test_recursive_filter_pushdown_through_skip() {
1950        let optimizer = Optimizer::new();
1951
1952        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1953            items: vec![ReturnItem {
1954                expression: LogicalExpression::Variable("n".to_string()),
1955                alias: None,
1956            }],
1957            distinct: false,
1958            input: Box::new(LogicalOperator::Filter(FilterOp {
1959                predicate: LogicalExpression::Literal(Value::Bool(true)),
1960                pushdown_hint: None,
1961                input: Box::new(LogicalOperator::Skip(SkipOp {
1962                    count: 5.into(),
1963                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1964                        variable: "n".to_string(),
1965                        label: None,
1966                        input: None,
1967                    })),
1968                })),
1969            })),
1970        }));
1971
1972        let optimized = optimizer.optimize(plan).unwrap();
1973
1974        // Verify optimization succeeded
1975        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1976    }
1977
1978    #[test]
1979    fn test_nested_filter_pushdown() {
1980        let optimizer = Optimizer::new();
1981
1982        // Multiple nested filters
1983        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1984            items: vec![ReturnItem {
1985                expression: LogicalExpression::Variable("n".to_string()),
1986                alias: None,
1987            }],
1988            distinct: false,
1989            input: Box::new(LogicalOperator::Filter(FilterOp {
1990                predicate: LogicalExpression::Binary {
1991                    left: Box::new(LogicalExpression::Property {
1992                        variable: "n".to_string(),
1993                        property: "x".to_string(),
1994                    }),
1995                    op: BinaryOp::Gt,
1996                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1997                },
1998                pushdown_hint: None,
1999                input: Box::new(LogicalOperator::Filter(FilterOp {
2000                    predicate: LogicalExpression::Binary {
2001                        left: Box::new(LogicalExpression::Property {
2002                            variable: "n".to_string(),
2003                            property: "y".to_string(),
2004                        }),
2005                        op: BinaryOp::Lt,
2006                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2007                    },
2008                    pushdown_hint: None,
2009                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2010                        variable: "n".to_string(),
2011                        label: None,
2012                        input: None,
2013                    })),
2014                })),
2015            })),
2016        }));
2017
2018        let optimized = optimizer.optimize(plan).unwrap();
2019        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2020    }
2021
2022    #[test]
2023    fn test_cyclic_join_produces_multi_way_join() {
2024        use crate::query::plan::JoinCondition;
2025
2026        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2027        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2028            variable: "a".to_string(),
2029            label: Some("Person".to_string()),
2030            input: None,
2031        });
2032        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2033            variable: "b".to_string(),
2034            label: Some("Person".to_string()),
2035            input: None,
2036        });
2037        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2038            variable: "c".to_string(),
2039            label: Some("Person".to_string()),
2040            input: None,
2041        });
2042
2043        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2044        let join_ab = LogicalOperator::Join(JoinOp {
2045            left: Box::new(scan_a),
2046            right: Box::new(scan_b),
2047            join_type: JoinType::Inner,
2048            conditions: vec![JoinCondition {
2049                left: LogicalExpression::Variable("a".to_string()),
2050                right: LogicalExpression::Variable("b".to_string()),
2051            }],
2052        });
2053
2054        let join_abc = LogicalOperator::Join(JoinOp {
2055            left: Box::new(join_ab),
2056            right: Box::new(scan_c),
2057            join_type: JoinType::Inner,
2058            conditions: vec![
2059                JoinCondition {
2060                    left: LogicalExpression::Variable("b".to_string()),
2061                    right: LogicalExpression::Variable("c".to_string()),
2062                },
2063                JoinCondition {
2064                    left: LogicalExpression::Variable("c".to_string()),
2065                    right: LogicalExpression::Variable("a".to_string()),
2066                },
2067            ],
2068        });
2069
2070        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2071            items: vec![ReturnItem {
2072                expression: LogicalExpression::Variable("a".to_string()),
2073                alias: None,
2074            }],
2075            distinct: false,
2076            input: Box::new(join_abc),
2077        }));
2078
2079        let mut optimizer = Optimizer::new();
2080        optimizer
2081            .card_estimator
2082            .add_table_stats("Person", cardinality::TableStats::new(1000));
2083
2084        let optimized = optimizer.optimize(plan).unwrap();
2085
2086        // Walk the tree to find a MultiWayJoin
2087        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2088            match op {
2089                LogicalOperator::MultiWayJoin(_) => true,
2090                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2091                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2092                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2093                _ => false,
2094            }
2095        }
2096
2097        assert!(
2098            has_multi_way_join(&optimized.root),
2099            "Expected MultiWayJoin for cyclic triangle pattern"
2100        );
2101    }
2102
2103    #[test]
2104    fn test_acyclic_join_uses_binary_joins() {
2105        use crate::query::plan::JoinCondition;
2106
2107        // Chain: a ⋈ b ⋈ c (acyclic)
2108        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2109            variable: "a".to_string(),
2110            label: Some("Person".to_string()),
2111            input: None,
2112        });
2113        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2114            variable: "b".to_string(),
2115            label: Some("Person".to_string()),
2116            input: None,
2117        });
2118        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2119            variable: "c".to_string(),
2120            label: Some("Company".to_string()),
2121            input: None,
2122        });
2123
2124        let join_ab = LogicalOperator::Join(JoinOp {
2125            left: Box::new(scan_a),
2126            right: Box::new(scan_b),
2127            join_type: JoinType::Inner,
2128            conditions: vec![JoinCondition {
2129                left: LogicalExpression::Variable("a".to_string()),
2130                right: LogicalExpression::Variable("b".to_string()),
2131            }],
2132        });
2133
2134        let join_abc = LogicalOperator::Join(JoinOp {
2135            left: Box::new(join_ab),
2136            right: Box::new(scan_c),
2137            join_type: JoinType::Inner,
2138            conditions: vec![JoinCondition {
2139                left: LogicalExpression::Variable("b".to_string()),
2140                right: LogicalExpression::Variable("c".to_string()),
2141            }],
2142        });
2143
2144        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2145            items: vec![ReturnItem {
2146                expression: LogicalExpression::Variable("a".to_string()),
2147                alias: None,
2148            }],
2149            distinct: false,
2150            input: Box::new(join_abc),
2151        }));
2152
2153        let mut optimizer = Optimizer::new();
2154        optimizer
2155            .card_estimator
2156            .add_table_stats("Person", cardinality::TableStats::new(1000));
2157        optimizer
2158            .card_estimator
2159            .add_table_stats("Company", cardinality::TableStats::new(100));
2160
2161        let optimized = optimizer.optimize(plan).unwrap();
2162
2163        // Should NOT contain MultiWayJoin for acyclic pattern
2164        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2165            match op {
2166                LogicalOperator::MultiWayJoin(_) => true,
2167                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2168                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2169                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2170                LogicalOperator::Join(j) => {
2171                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2172                }
2173                _ => false,
2174            }
2175        }
2176
2177        assert!(
2178            !has_multi_way_join(&optimized.root),
2179            "Acyclic join should NOT produce MultiWayJoin"
2180        );
2181    }
2182}