Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// Information about a join condition for join reordering.
///
/// One value is recorded per join condition whose two sides can each be
/// traced back to a variable; the values are later handed to the join graph
/// builder.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable extracted from the left side of the condition.
    left_var: String,
    /// Variable extracted from the right side of the condition.
    right_var: String,
    /// Copy of the condition's complete left-hand expression.
    left_expr: LogicalExpression,
    /// Copy of the condition's complete right-hand expression.
    right_expr: LogicalExpression,
}
39
/// A column required by the query, used for projection pushdown.
///
/// Values are collected into a `HashSet`, hence the `Eq` and `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding)
    Variable(String),
    /// A specific property of a variable, stored as `(variable, property)`
    Property(String, String),
}
48
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// Every constructor in this file starts with all three passes enabled.
pub struct Optimizer {
    /// Whether to enable filter pushdown.
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering.
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown.
    enable_projection_pushdown: bool,
    /// Cost model for estimation; also handed to the join-order search.
    cost_model: CostModel,
    /// Cardinality estimator; used together with the cost model.
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67    /// Creates a new optimizer with default settings.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            enable_filter_pushdown: true,
72            enable_join_reorder: true,
73            enable_projection_pushdown: true,
74            cost_model: CostModel::new(),
75            card_estimator: CardinalityEstimator::new(),
76        }
77    }
78
79    /// Creates an optimizer with cardinality estimates from the store's statistics.
80    ///
81    /// Pre-populates the cardinality estimator with per-label row counts and
82    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
83    /// and graph totals into the cost model for accurate estimation.
84    #[must_use]
85    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
86        store.ensure_statistics_fresh();
87        let stats = store.statistics();
88        Self::from_statistics(&stats)
89    }
90
91    /// Creates an optimizer from any GraphStore implementation.
92    ///
93    /// Unlike [`from_store`](Self::from_store), this does not call
94    /// `ensure_statistics_fresh()` since external stores manage their own
95    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
96    /// is called directly.
97    #[must_use]
98    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
99        let stats = store.statistics();
100        Self::from_statistics(&stats)
101    }
102
103    /// Creates an optimizer from RDF statistics.
104    ///
105    /// Uses triple pattern cardinality estimates for cost-based optimization
106    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
107    #[cfg(feature = "rdf")]
108    #[must_use]
109    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
110        let total = rdf_stats.total_triples;
111        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
112        Self {
113            enable_filter_pushdown: true,
114            enable_join_reorder: true,
115            enable_projection_pushdown: true,
116            cost_model: CostModel::new().with_graph_totals(total, total),
117            card_estimator: estimator,
118        }
119    }
120
121    /// Creates an optimizer from a Statistics snapshot.
122    ///
123    /// Extracts label cardinalities, edge type degrees, and graph totals
124    /// for both the cardinality estimator and cost model.
125    #[must_use]
126    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
127        let estimator = CardinalityEstimator::from_statistics(stats);
128
129        let avg_fanout = if stats.total_nodes > 0 {
130            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
131        } else {
132            10.0
133        };
134
135        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
136            .edge_types
137            .iter()
138            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
139            .collect();
140
141        let label_cardinalities: std::collections::HashMap<String, u64> = stats
142            .labels
143            .iter()
144            .map(|(name, ls)| (name.clone(), ls.node_count))
145            .collect();
146
147        Self {
148            enable_filter_pushdown: true,
149            enable_join_reorder: true,
150            enable_projection_pushdown: true,
151            cost_model: CostModel::new()
152                .with_avg_fanout(avg_fanout)
153                .with_edge_type_degrees(edge_type_degrees)
154                .with_label_cardinalities(label_cardinalities)
155                .with_graph_totals(stats.total_nodes, stats.total_edges),
156            card_estimator: estimator,
157        }
158    }
159
160    /// Enables or disables filter pushdown.
161    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
162        self.enable_filter_pushdown = enabled;
163        self
164    }
165
166    /// Enables or disables join reordering.
167    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
168        self.enable_join_reorder = enabled;
169        self
170    }
171
172    /// Enables or disables projection pushdown.
173    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
174        self.enable_projection_pushdown = enabled;
175        self
176    }
177
178    /// Sets the cost model.
179    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
180        self.cost_model = cost_model;
181        self
182    }
183
184    /// Sets the cardinality estimator.
185    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
186        self.card_estimator = estimator;
187        self
188    }
189
190    /// Sets the selectivity configuration for the cardinality estimator.
191    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
192        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
193        self
194    }
195
196    /// Returns a reference to the cost model.
197    pub fn cost_model(&self) -> &CostModel {
198        &self.cost_model
199    }
200
201    /// Returns a reference to the cardinality estimator.
202    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
203        &self.card_estimator
204    }
205
206    /// Estimates the total cost of a plan by recursively costing the entire tree.
207    ///
208    /// Walks every operator in the plan, computing cardinality at each level
209    /// and summing the per-operator costs. Uses actual child cardinalities
210    /// for join cost estimation rather than approximations.
211    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
212        self.cost_model
213            .estimate_tree(&plan.root, &self.card_estimator)
214    }
215
216    /// Estimates the cardinality of a plan.
217    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
218        self.card_estimator.estimate(&plan.root)
219    }
220
221    /// Optimizes a logical plan.
222    ///
223    /// # Errors
224    ///
225    /// Returns an error if optimization fails.
226    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
227        let _span = grafeo_debug_span!("grafeo::query::optimize");
228        let mut root = plan.root;
229
230        // Apply optimization rules
231        if self.enable_filter_pushdown {
232            root = self.push_filters_down(root);
233        }
234
235        if self.enable_join_reorder {
236            root = self.reorder_joins(root);
237        }
238
239        if self.enable_projection_pushdown {
240            root = self.push_projections_down(root);
241        }
242
243        Ok(LogicalPlan {
244            root,
245            explain: plan.explain,
246            profile: plan.profile,
247        })
248    }
249
250    /// Pushes projections down the operator tree to eliminate unused columns early.
251    ///
252    /// This optimization:
253    /// 1. Collects required variables/properties from the root
254    /// 2. Propagates requirements down through the tree
255    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
256    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
257        // Collect required columns from the top of the plan
258        let required = self.collect_required_columns(&op);
259
260        // Push projections down
261        self.push_projections_recursive(op, &required)
262    }
263
264    /// Collects all variables and properties required by an operator and its ancestors.
265    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
266        let mut required = HashSet::new();
267        Self::collect_required_recursive(op, &mut required);
268        required
269    }
270
271    /// Recursively collects required columns.
272    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
273        match op {
274            LogicalOperator::Return(ret) => {
275                for item in &ret.items {
276                    Self::collect_from_expression(&item.expression, required);
277                }
278                Self::collect_required_recursive(&ret.input, required);
279            }
280            LogicalOperator::Project(proj) => {
281                for p in &proj.projections {
282                    Self::collect_from_expression(&p.expression, required);
283                }
284                Self::collect_required_recursive(&proj.input, required);
285            }
286            LogicalOperator::Filter(filter) => {
287                Self::collect_from_expression(&filter.predicate, required);
288                Self::collect_required_recursive(&filter.input, required);
289            }
290            LogicalOperator::Sort(sort) => {
291                for key in &sort.keys {
292                    Self::collect_from_expression(&key.expression, required);
293                }
294                Self::collect_required_recursive(&sort.input, required);
295            }
296            LogicalOperator::Aggregate(agg) => {
297                for expr in &agg.group_by {
298                    Self::collect_from_expression(expr, required);
299                }
300                for agg_expr in &agg.aggregates {
301                    if let Some(ref expr) = agg_expr.expression {
302                        Self::collect_from_expression(expr, required);
303                    }
304                }
305                if let Some(ref having) = agg.having {
306                    Self::collect_from_expression(having, required);
307                }
308                Self::collect_required_recursive(&agg.input, required);
309            }
310            LogicalOperator::Join(join) => {
311                for cond in &join.conditions {
312                    Self::collect_from_expression(&cond.left, required);
313                    Self::collect_from_expression(&cond.right, required);
314                }
315                Self::collect_required_recursive(&join.left, required);
316                Self::collect_required_recursive(&join.right, required);
317            }
318            LogicalOperator::Expand(expand) => {
319                // The source and target variables are needed
320                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
321                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
322                if let Some(ref edge_var) = expand.edge_variable {
323                    required.insert(RequiredColumn::Variable(edge_var.clone()));
324                }
325                Self::collect_required_recursive(&expand.input, required);
326            }
327            LogicalOperator::Limit(limit) => {
328                Self::collect_required_recursive(&limit.input, required);
329            }
330            LogicalOperator::Skip(skip) => {
331                Self::collect_required_recursive(&skip.input, required);
332            }
333            LogicalOperator::Distinct(distinct) => {
334                Self::collect_required_recursive(&distinct.input, required);
335            }
336            LogicalOperator::NodeScan(scan) => {
337                required.insert(RequiredColumn::Variable(scan.variable.clone()));
338            }
339            LogicalOperator::EdgeScan(scan) => {
340                required.insert(RequiredColumn::Variable(scan.variable.clone()));
341            }
342            LogicalOperator::MultiWayJoin(mwj) => {
343                for cond in &mwj.conditions {
344                    Self::collect_from_expression(&cond.left, required);
345                    Self::collect_from_expression(&cond.right, required);
346                }
347                for input in &mwj.inputs {
348                    Self::collect_required_recursive(input, required);
349                }
350            }
351            _ => {}
352        }
353    }
354
355    /// Collects required columns from an expression.
356    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
357        match expr {
358            LogicalExpression::Variable(var) => {
359                required.insert(RequiredColumn::Variable(var.clone()));
360            }
361            LogicalExpression::Property { variable, property } => {
362                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
363                required.insert(RequiredColumn::Variable(variable.clone()));
364            }
365            LogicalExpression::Binary { left, right, .. } => {
366                Self::collect_from_expression(left, required);
367                Self::collect_from_expression(right, required);
368            }
369            LogicalExpression::Unary { operand, .. } => {
370                Self::collect_from_expression(operand, required);
371            }
372            LogicalExpression::FunctionCall { args, .. } => {
373                for arg in args {
374                    Self::collect_from_expression(arg, required);
375                }
376            }
377            LogicalExpression::List(items) => {
378                for item in items {
379                    Self::collect_from_expression(item, required);
380                }
381            }
382            LogicalExpression::Map(pairs) => {
383                for (_, value) in pairs {
384                    Self::collect_from_expression(value, required);
385                }
386            }
387            LogicalExpression::IndexAccess { base, index } => {
388                Self::collect_from_expression(base, required);
389                Self::collect_from_expression(index, required);
390            }
391            LogicalExpression::SliceAccess { base, start, end } => {
392                Self::collect_from_expression(base, required);
393                if let Some(s) = start {
394                    Self::collect_from_expression(s, required);
395                }
396                if let Some(e) = end {
397                    Self::collect_from_expression(e, required);
398                }
399            }
400            LogicalExpression::Case {
401                operand,
402                when_clauses,
403                else_clause,
404            } => {
405                if let Some(op) = operand {
406                    Self::collect_from_expression(op, required);
407                }
408                for (cond, result) in when_clauses {
409                    Self::collect_from_expression(cond, required);
410                    Self::collect_from_expression(result, required);
411                }
412                if let Some(else_expr) = else_clause {
413                    Self::collect_from_expression(else_expr, required);
414                }
415            }
416            LogicalExpression::Labels(var)
417            | LogicalExpression::Type(var)
418            | LogicalExpression::Id(var) => {
419                required.insert(RequiredColumn::Variable(var.clone()));
420            }
421            LogicalExpression::ListComprehension {
422                list_expr,
423                filter_expr,
424                map_expr,
425                ..
426            } => {
427                Self::collect_from_expression(list_expr, required);
428                if let Some(filter) = filter_expr {
429                    Self::collect_from_expression(filter, required);
430                }
431                Self::collect_from_expression(map_expr, required);
432            }
433            _ => {}
434        }
435    }
436
437    /// Recursively pushes projections down, adding them before expensive operations.
438    fn push_projections_recursive(
439        &self,
440        op: LogicalOperator,
441        required: &HashSet<RequiredColumn>,
442    ) -> LogicalOperator {
443        match op {
444            LogicalOperator::Return(mut ret) => {
445                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
446                LogicalOperator::Return(ret)
447            }
448            LogicalOperator::Project(mut proj) => {
449                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
450                LogicalOperator::Project(proj)
451            }
452            LogicalOperator::Filter(mut filter) => {
453                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
454                LogicalOperator::Filter(filter)
455            }
456            LogicalOperator::Sort(mut sort) => {
457                // Sort is expensive - consider adding a projection before it
458                // to reduce tuple width
459                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
460                LogicalOperator::Sort(sort)
461            }
462            LogicalOperator::Aggregate(mut agg) => {
463                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
464                LogicalOperator::Aggregate(agg)
465            }
466            LogicalOperator::Join(mut join) => {
467                // Joins are expensive - the required columns help determine
468                // what to project on each side
469                let left_vars = self.collect_output_variables(&join.left);
470                let right_vars = self.collect_output_variables(&join.right);
471
472                // Filter required columns to each side
473                let left_required: HashSet<_> = required
474                    .iter()
475                    .filter(|c| match c {
476                        RequiredColumn::Variable(v) => left_vars.contains(v),
477                        RequiredColumn::Property(v, _) => left_vars.contains(v),
478                    })
479                    .cloned()
480                    .collect();
481
482                let right_required: HashSet<_> = required
483                    .iter()
484                    .filter(|c| match c {
485                        RequiredColumn::Variable(v) => right_vars.contains(v),
486                        RequiredColumn::Property(v, _) => right_vars.contains(v),
487                    })
488                    .cloned()
489                    .collect();
490
491                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
492                join.right =
493                    Box::new(self.push_projections_recursive(*join.right, &right_required));
494                LogicalOperator::Join(join)
495            }
496            LogicalOperator::Expand(mut expand) => {
497                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
498                LogicalOperator::Expand(expand)
499            }
500            LogicalOperator::Limit(mut limit) => {
501                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
502                LogicalOperator::Limit(limit)
503            }
504            LogicalOperator::Skip(mut skip) => {
505                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
506                LogicalOperator::Skip(skip)
507            }
508            LogicalOperator::Distinct(mut distinct) => {
509                distinct.input =
510                    Box::new(self.push_projections_recursive(*distinct.input, required));
511                LogicalOperator::Distinct(distinct)
512            }
513            LogicalOperator::MapCollect(mut mc) => {
514                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
515                LogicalOperator::MapCollect(mc)
516            }
517            LogicalOperator::MultiWayJoin(mut mwj) => {
518                mwj.inputs = mwj
519                    .inputs
520                    .into_iter()
521                    .map(|input| self.push_projections_recursive(input, required))
522                    .collect();
523                LogicalOperator::MultiWayJoin(mwj)
524            }
525            other => other,
526        }
527    }
528
529    /// Reorders joins in the operator tree using the DPccp algorithm.
530    ///
531    /// This optimization finds the optimal join order by:
532    /// 1. Extracting all base relations (scans) and join conditions
533    /// 2. Building a join graph
534    /// 3. Using dynamic programming to find the cheapest join order
535    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
536        // First, recursively optimize children
537        let op = self.reorder_joins_recursive(op);
538
539        // Then, if this is a join tree, try to optimize it
540        if let Some((relations, conditions)) = self.extract_join_tree(&op)
541            && relations.len() >= 2
542            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
543        {
544            return optimized;
545        }
546
547        op
548    }
549
550    /// Recursively applies join reordering to child operators.
551    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
552        match op {
553            LogicalOperator::Return(mut ret) => {
554                ret.input = Box::new(self.reorder_joins(*ret.input));
555                LogicalOperator::Return(ret)
556            }
557            LogicalOperator::Project(mut proj) => {
558                proj.input = Box::new(self.reorder_joins(*proj.input));
559                LogicalOperator::Project(proj)
560            }
561            LogicalOperator::Filter(mut filter) => {
562                filter.input = Box::new(self.reorder_joins(*filter.input));
563                LogicalOperator::Filter(filter)
564            }
565            LogicalOperator::Limit(mut limit) => {
566                limit.input = Box::new(self.reorder_joins(*limit.input));
567                LogicalOperator::Limit(limit)
568            }
569            LogicalOperator::Skip(mut skip) => {
570                skip.input = Box::new(self.reorder_joins(*skip.input));
571                LogicalOperator::Skip(skip)
572            }
573            LogicalOperator::Sort(mut sort) => {
574                sort.input = Box::new(self.reorder_joins(*sort.input));
575                LogicalOperator::Sort(sort)
576            }
577            LogicalOperator::Distinct(mut distinct) => {
578                distinct.input = Box::new(self.reorder_joins(*distinct.input));
579                LogicalOperator::Distinct(distinct)
580            }
581            LogicalOperator::Aggregate(mut agg) => {
582                agg.input = Box::new(self.reorder_joins(*agg.input));
583                LogicalOperator::Aggregate(agg)
584            }
585            LogicalOperator::Expand(mut expand) => {
586                expand.input = Box::new(self.reorder_joins(*expand.input));
587                LogicalOperator::Expand(expand)
588            }
589            LogicalOperator::MapCollect(mut mc) => {
590                mc.input = Box::new(self.reorder_joins(*mc.input));
591                LogicalOperator::MapCollect(mc)
592            }
593            LogicalOperator::MultiWayJoin(mut mwj) => {
594                mwj.inputs = mwj
595                    .inputs
596                    .into_iter()
597                    .map(|input| self.reorder_joins(input))
598                    .collect();
599                LogicalOperator::MultiWayJoin(mwj)
600            }
601            // Join operators are handled by the parent reorder_joins call
602            other => other,
603        }
604    }
605
606    /// Extracts base relations and join conditions from a join tree.
607    ///
608    /// Returns None if the operator is not a join tree.
609    fn extract_join_tree(
610        &self,
611        op: &LogicalOperator,
612    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
613        let mut relations = Vec::new();
614        let mut join_conditions = Vec::new();
615
616        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
617            return None;
618        }
619
620        if relations.len() < 2 {
621            return None;
622        }
623
624        Some((relations, join_conditions))
625    }
626
627    /// Recursively collects base relations and join conditions.
628    ///
629    /// Returns true if this subtree is part of a join tree.
630    fn collect_join_tree(
631        &self,
632        op: &LogicalOperator,
633        relations: &mut Vec<(String, LogicalOperator)>,
634        conditions: &mut Vec<JoinInfo>,
635    ) -> bool {
636        match op {
637            LogicalOperator::Join(join) => {
638                // Collect from both sides
639                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
640                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
641
642                // Add conditions from this join
643                for cond in &join.conditions {
644                    if let (Some(left_var), Some(right_var)) = (
645                        self.extract_variable_from_expr(&cond.left),
646                        self.extract_variable_from_expr(&cond.right),
647                    ) {
648                        conditions.push(JoinInfo {
649                            left_var,
650                            right_var,
651                            left_expr: cond.left.clone(),
652                            right_expr: cond.right.clone(),
653                        });
654                    }
655                }
656
657                left_ok && right_ok
658            }
659            LogicalOperator::NodeScan(scan) => {
660                relations.push((scan.variable.clone(), op.clone()));
661                true
662            }
663            LogicalOperator::EdgeScan(scan) => {
664                relations.push((scan.variable.clone(), op.clone()));
665                true
666            }
667            LogicalOperator::Filter(filter) => {
668                // A filter on a base relation is still part of the join tree
669                self.collect_join_tree(&filter.input, relations, conditions)
670            }
671            LogicalOperator::Expand(expand) => {
672                // Expand is a special case - it's like a join with the adjacency
673                // For now, treat the whole Expand subtree as a single relation
674                relations.push((expand.to_variable.clone(), op.clone()));
675                true
676            }
677            _ => false,
678        }
679    }
680
681    /// Extracts the primary variable from an expression.
682    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
683        match expr {
684            LogicalExpression::Variable(v) => Some(v.clone()),
685            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
686            _ => None,
687        }
688    }
689
690    /// Optimizes the join order using DPccp, or produces a multi-way
691    /// leapfrog join for cyclic patterns when the cost model prefers it.
692    fn optimize_join_order(
693        &self,
694        relations: &[(String, LogicalOperator)],
695        conditions: &[JoinInfo],
696    ) -> Option<LogicalOperator> {
697        use join_order::{DPccp, JoinGraphBuilder};
698
699        // Build the join graph
700        let mut builder = JoinGraphBuilder::new();
701
702        for (var, relation) in relations {
703            builder.add_relation(var, relation.clone());
704        }
705
706        for cond in conditions {
707            builder.add_join_condition(
708                &cond.left_var,
709                &cond.right_var,
710                cond.left_expr.clone(),
711                cond.right_expr.clone(),
712            );
713        }
714
715        let graph = builder.build();
716
717        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
718        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
719        // multi-way intersection rather than binary hash join cascades that can
720        // produce intermediate blowup.
721        if graph.is_cyclic() && relations.len() >= 3 {
722            // Collect shared variables (variables appearing in 2+ conditions)
723            let mut var_counts: std::collections::HashMap<&str, usize> =
724                std::collections::HashMap::new();
725            for cond in conditions {
726                *var_counts.entry(&cond.left_var).or_default() += 1;
727                *var_counts.entry(&cond.right_var).or_default() += 1;
728            }
729            let shared_variables: Vec<String> = var_counts
730                .into_iter()
731                .filter(|(_, count)| *count >= 2)
732                .map(|(var, _)| var.to_string())
733                .collect();
734
735            let join_conditions: Vec<JoinCondition> = conditions
736                .iter()
737                .map(|c| JoinCondition {
738                    left: c.left_expr.clone(),
739                    right: c.right_expr.clone(),
740                })
741                .collect();
742
743            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
744                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
745                conditions: join_conditions,
746                shared_variables,
747            }));
748        }
749
750        // Fall through to DPccp for binary join ordering
751        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
752        let plan = dpccp.optimize()?;
753
754        Some(plan.operator)
755    }
756
757    /// Pushes filters down the operator tree.
758    ///
759    /// This optimization moves filter predicates as close to the data source
760    /// as possible to reduce the amount of data processed by upper operators.
761    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
762        match op {
763            // For Filter operators, try to push the predicate into the child
764            LogicalOperator::Filter(filter) => {
765                let optimized_input = self.push_filters_down(*filter.input);
766                self.try_push_filter_into(filter.predicate, optimized_input)
767            }
768            // Recursively optimize children for other operators
769            LogicalOperator::Return(mut ret) => {
770                ret.input = Box::new(self.push_filters_down(*ret.input));
771                LogicalOperator::Return(ret)
772            }
773            LogicalOperator::Project(mut proj) => {
774                proj.input = Box::new(self.push_filters_down(*proj.input));
775                LogicalOperator::Project(proj)
776            }
777            LogicalOperator::Limit(mut limit) => {
778                limit.input = Box::new(self.push_filters_down(*limit.input));
779                LogicalOperator::Limit(limit)
780            }
781            LogicalOperator::Skip(mut skip) => {
782                skip.input = Box::new(self.push_filters_down(*skip.input));
783                LogicalOperator::Skip(skip)
784            }
785            LogicalOperator::Sort(mut sort) => {
786                sort.input = Box::new(self.push_filters_down(*sort.input));
787                LogicalOperator::Sort(sort)
788            }
789            LogicalOperator::Distinct(mut distinct) => {
790                distinct.input = Box::new(self.push_filters_down(*distinct.input));
791                LogicalOperator::Distinct(distinct)
792            }
793            LogicalOperator::Expand(mut expand) => {
794                expand.input = Box::new(self.push_filters_down(*expand.input));
795                LogicalOperator::Expand(expand)
796            }
797            LogicalOperator::Join(mut join) => {
798                join.left = Box::new(self.push_filters_down(*join.left));
799                join.right = Box::new(self.push_filters_down(*join.right));
800                LogicalOperator::Join(join)
801            }
802            LogicalOperator::Aggregate(mut agg) => {
803                agg.input = Box::new(self.push_filters_down(*agg.input));
804                LogicalOperator::Aggregate(agg)
805            }
806            LogicalOperator::MapCollect(mut mc) => {
807                mc.input = Box::new(self.push_filters_down(*mc.input));
808                LogicalOperator::MapCollect(mc)
809            }
810            LogicalOperator::MultiWayJoin(mut mwj) => {
811                mwj.inputs = mwj
812                    .inputs
813                    .into_iter()
814                    .map(|input| self.push_filters_down(input))
815                    .collect();
816                LogicalOperator::MultiWayJoin(mwj)
817            }
818            // Leaf operators and unsupported operators are returned as-is
819            other => other,
820        }
821    }
822
823    /// Tries to push a filter predicate into the given operator.
824    ///
825    /// Returns either the predicate pushed into the operator, or a new
826    /// Filter operator on top if the predicate cannot be pushed further.
827    fn try_push_filter_into(
828        &self,
829        predicate: LogicalExpression,
830        op: LogicalOperator,
831    ) -> LogicalOperator {
832        match op {
833            // Can push through Project if predicate doesn't depend on computed columns
834            LogicalOperator::Project(mut proj) => {
835                let predicate_vars = self.extract_variables(&predicate);
836                let computed_vars = self.extract_projection_aliases(&proj.projections);
837
838                // If predicate doesn't use any computed columns, push through
839                if predicate_vars.is_disjoint(&computed_vars) {
840                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
841                    LogicalOperator::Project(proj)
842                } else {
843                    // Can't push through, keep filter on top
844                    LogicalOperator::Filter(FilterOp {
845                        predicate,
846                        pushdown_hint: None,
847                        input: Box::new(LogicalOperator::Project(proj)),
848                    })
849                }
850            }
851
852            // Can push through Return (which is like a projection)
853            LogicalOperator::Return(mut ret) => {
854                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
855                LogicalOperator::Return(ret)
856            }
857
858            // Can push through Expand if predicate doesn't use variables introduced by this expand
859            LogicalOperator::Expand(mut expand) => {
860                let predicate_vars = self.extract_variables(&predicate);
861
862                // Variables introduced by this expand are:
863                // - The target variable (to_variable)
864                // - The edge variable (if any)
865                // - The path alias (if any)
866                let mut introduced_vars = vec![&expand.to_variable];
867                if let Some(ref edge_var) = expand.edge_variable {
868                    introduced_vars.push(edge_var);
869                }
870                if let Some(ref path_alias) = expand.path_alias {
871                    introduced_vars.push(path_alias);
872                }
873
874                // Check if predicate uses any variables introduced by this expand
875                let uses_introduced_vars =
876                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
877
878                if !uses_introduced_vars {
879                    // Predicate doesn't use vars from this expand, so push through
880                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
881                    LogicalOperator::Expand(expand)
882                } else {
883                    // Keep filter after expand
884                    LogicalOperator::Filter(FilterOp {
885                        predicate,
886                        pushdown_hint: None,
887                        input: Box::new(LogicalOperator::Expand(expand)),
888                    })
889                }
890            }
891
892            // Can push through Join to left/right side based on variables used
893            LogicalOperator::Join(mut join) => {
894                let predicate_vars = self.extract_variables(&predicate);
895                let left_vars = self.collect_output_variables(&join.left);
896                let right_vars = self.collect_output_variables(&join.right);
897
898                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
899                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
900
901                if uses_left && !uses_right {
902                    // Push to left side
903                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
904                    LogicalOperator::Join(join)
905                } else if uses_right && !uses_left {
906                    // Push to right side
907                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
908                    LogicalOperator::Join(join)
909                } else {
910                    // Uses both sides - keep above join
911                    LogicalOperator::Filter(FilterOp {
912                        predicate,
913                        pushdown_hint: None,
914                        input: Box::new(LogicalOperator::Join(join)),
915                    })
916                }
917            }
918
919            // Cannot push through Aggregate (predicate refers to aggregated values)
920            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
921                predicate,
922                pushdown_hint: None,
923                input: Box::new(LogicalOperator::Aggregate(agg)),
924            }),
925
926            // For NodeScan, we've reached the bottom - keep filter on top
927            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
928                predicate,
929                pushdown_hint: None,
930                input: Box::new(LogicalOperator::NodeScan(scan)),
931            }),
932
933            // For other operators, keep filter on top
934            other => LogicalOperator::Filter(FilterOp {
935                predicate,
936                pushdown_hint: None,
937                input: Box::new(other),
938            }),
939        }
940    }
941
942    /// Collects all output variable names from an operator.
943    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
944        let mut vars = HashSet::new();
945        Self::collect_output_variables_recursive(op, &mut vars);
946        vars
947    }
948
949    /// Recursively collects output variables from an operator.
950    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
951        match op {
952            LogicalOperator::NodeScan(scan) => {
953                vars.insert(scan.variable.clone());
954            }
955            LogicalOperator::EdgeScan(scan) => {
956                vars.insert(scan.variable.clone());
957            }
958            LogicalOperator::Expand(expand) => {
959                vars.insert(expand.to_variable.clone());
960                if let Some(edge_var) = &expand.edge_variable {
961                    vars.insert(edge_var.clone());
962                }
963                Self::collect_output_variables_recursive(&expand.input, vars);
964            }
965            LogicalOperator::Filter(filter) => {
966                Self::collect_output_variables_recursive(&filter.input, vars);
967            }
968            LogicalOperator::Project(proj) => {
969                for p in &proj.projections {
970                    if let Some(alias) = &p.alias {
971                        vars.insert(alias.clone());
972                    }
973                }
974                Self::collect_output_variables_recursive(&proj.input, vars);
975            }
976            LogicalOperator::Join(join) => {
977                Self::collect_output_variables_recursive(&join.left, vars);
978                Self::collect_output_variables_recursive(&join.right, vars);
979            }
980            LogicalOperator::Aggregate(agg) => {
981                for expr in &agg.group_by {
982                    Self::collect_variables(expr, vars);
983                }
984                for agg_expr in &agg.aggregates {
985                    if let Some(alias) = &agg_expr.alias {
986                        vars.insert(alias.clone());
987                    }
988                }
989            }
990            LogicalOperator::Return(ret) => {
991                Self::collect_output_variables_recursive(&ret.input, vars);
992            }
993            LogicalOperator::Limit(limit) => {
994                Self::collect_output_variables_recursive(&limit.input, vars);
995            }
996            LogicalOperator::Skip(skip) => {
997                Self::collect_output_variables_recursive(&skip.input, vars);
998            }
999            LogicalOperator::Sort(sort) => {
1000                Self::collect_output_variables_recursive(&sort.input, vars);
1001            }
1002            LogicalOperator::Distinct(distinct) => {
1003                Self::collect_output_variables_recursive(&distinct.input, vars);
1004            }
1005            _ => {}
1006        }
1007    }
1008
1009    /// Extracts all variable names referenced in an expression.
1010    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1011        let mut vars = HashSet::new();
1012        Self::collect_variables(expr, &mut vars);
1013        vars
1014    }
1015
1016    /// Recursively collects variable names from an expression.
1017    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1018        match expr {
1019            LogicalExpression::Variable(name) => {
1020                vars.insert(name.clone());
1021            }
1022            LogicalExpression::Property { variable, .. } => {
1023                vars.insert(variable.clone());
1024            }
1025            LogicalExpression::Binary { left, right, .. } => {
1026                Self::collect_variables(left, vars);
1027                Self::collect_variables(right, vars);
1028            }
1029            LogicalExpression::Unary { operand, .. } => {
1030                Self::collect_variables(operand, vars);
1031            }
1032            LogicalExpression::FunctionCall { args, .. } => {
1033                for arg in args {
1034                    Self::collect_variables(arg, vars);
1035                }
1036            }
1037            LogicalExpression::List(items) => {
1038                for item in items {
1039                    Self::collect_variables(item, vars);
1040                }
1041            }
1042            LogicalExpression::Map(pairs) => {
1043                for (_, value) in pairs {
1044                    Self::collect_variables(value, vars);
1045                }
1046            }
1047            LogicalExpression::IndexAccess { base, index } => {
1048                Self::collect_variables(base, vars);
1049                Self::collect_variables(index, vars);
1050            }
1051            LogicalExpression::SliceAccess { base, start, end } => {
1052                Self::collect_variables(base, vars);
1053                if let Some(s) = start {
1054                    Self::collect_variables(s, vars);
1055                }
1056                if let Some(e) = end {
1057                    Self::collect_variables(e, vars);
1058                }
1059            }
1060            LogicalExpression::Case {
1061                operand,
1062                when_clauses,
1063                else_clause,
1064            } => {
1065                if let Some(op) = operand {
1066                    Self::collect_variables(op, vars);
1067                }
1068                for (cond, result) in when_clauses {
1069                    Self::collect_variables(cond, vars);
1070                    Self::collect_variables(result, vars);
1071                }
1072                if let Some(else_expr) = else_clause {
1073                    Self::collect_variables(else_expr, vars);
1074                }
1075            }
1076            LogicalExpression::Labels(var)
1077            | LogicalExpression::Type(var)
1078            | LogicalExpression::Id(var) => {
1079                vars.insert(var.clone());
1080            }
1081            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1082            LogicalExpression::ListComprehension {
1083                list_expr,
1084                filter_expr,
1085                map_expr,
1086                ..
1087            } => {
1088                Self::collect_variables(list_expr, vars);
1089                if let Some(filter) = filter_expr {
1090                    Self::collect_variables(filter, vars);
1091                }
1092                Self::collect_variables(map_expr, vars);
1093            }
1094            LogicalExpression::ListPredicate {
1095                list_expr,
1096                predicate,
1097                ..
1098            } => {
1099                Self::collect_variables(list_expr, vars);
1100                Self::collect_variables(predicate, vars);
1101            }
1102            LogicalExpression::ExistsSubquery(_)
1103            | LogicalExpression::CountSubquery(_)
1104            | LogicalExpression::ValueSubquery(_) => {
1105                // Subqueries have their own variable scope
1106            }
1107            LogicalExpression::PatternComprehension { projection, .. } => {
1108                Self::collect_variables(projection, vars);
1109            }
1110            LogicalExpression::MapProjection { base, entries } => {
1111                vars.insert(base.clone());
1112                for entry in entries {
1113                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1114                        Self::collect_variables(expr, vars);
1115                    }
1116                }
1117            }
1118            LogicalExpression::Reduce {
1119                initial,
1120                list,
1121                expression,
1122                ..
1123            } => {
1124                Self::collect_variables(initial, vars);
1125                Self::collect_variables(list, vars);
1126                Self::collect_variables(expression, vars);
1127            }
1128        }
1129    }
1130
1131    /// Extracts aliases from projection expressions.
1132    fn extract_projection_aliases(
1133        &self,
1134        projections: &[crate::query::plan::Projection],
1135    ) -> HashSet<String> {
1136        projections.iter().filter_map(|p| p.alias.clone()).collect()
1137    }
1138}
1139
1140impl Default for Optimizer {
1141    fn default() -> Self {
1142        Self::new()
1143    }
1144}
1145
1146#[cfg(test)]
1147mod tests {
1148    use super::*;
1149    use crate::query::plan::{
1150        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1151        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1152        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1153    };
1154    use grafeo_common::types::Value;
1155
1156    #[test]
1157    fn test_optimizer_filter_pushdown_simple() {
1158        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1159        // Before: Return -> Filter -> NodeScan
1160        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1161
1162        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1163            items: vec![ReturnItem {
1164                expression: LogicalExpression::Variable("n".to_string()),
1165                alias: None,
1166            }],
1167            distinct: false,
1168            input: Box::new(LogicalOperator::Filter(FilterOp {
1169                predicate: LogicalExpression::Binary {
1170                    left: Box::new(LogicalExpression::Property {
1171                        variable: "n".to_string(),
1172                        property: "age".to_string(),
1173                    }),
1174                    op: BinaryOp::Gt,
1175                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1176                },
1177                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1178                    variable: "n".to_string(),
1179                    label: Some("Person".to_string()),
1180                    input: None,
1181                })),
1182                pushdown_hint: None,
1183            })),
1184        }));
1185
1186        let optimizer = Optimizer::new();
1187        let optimized = optimizer.optimize(plan).unwrap();
1188
1189        // The structure should remain similar (filter stays near scan)
1190        if let LogicalOperator::Return(ret) = &optimized.root
1191            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1192            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1193        {
1194            assert_eq!(scan.variable, "n");
1195            return;
1196        }
1197        panic!("Expected Return -> Filter -> NodeScan structure");
1198    }
1199
1200    #[test]
1201    fn test_optimizer_filter_pushdown_through_expand() {
1202        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1203        // The filter on 'a' should be pushed before the expand
1204
1205        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1206            items: vec![ReturnItem {
1207                expression: LogicalExpression::Variable("b".to_string()),
1208                alias: None,
1209            }],
1210            distinct: false,
1211            input: Box::new(LogicalOperator::Filter(FilterOp {
1212                predicate: LogicalExpression::Binary {
1213                    left: Box::new(LogicalExpression::Property {
1214                        variable: "a".to_string(),
1215                        property: "age".to_string(),
1216                    }),
1217                    op: BinaryOp::Gt,
1218                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1219                },
1220                pushdown_hint: None,
1221                input: Box::new(LogicalOperator::Expand(ExpandOp {
1222                    from_variable: "a".to_string(),
1223                    to_variable: "b".to_string(),
1224                    edge_variable: None,
1225                    direction: ExpandDirection::Outgoing,
1226                    edge_types: vec!["KNOWS".to_string()],
1227                    min_hops: 1,
1228                    max_hops: Some(1),
1229                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1230                        variable: "a".to_string(),
1231                        label: Some("Person".to_string()),
1232                        input: None,
1233                    })),
1234                    path_alias: None,
1235                    path_mode: PathMode::Walk,
1236                })),
1237            })),
1238        }));
1239
1240        let optimizer = Optimizer::new();
1241        let optimized = optimizer.optimize(plan).unwrap();
1242
1243        // Filter on 'a' should be pushed before the expand
1244        // Expected: Return -> Expand -> Filter -> NodeScan
1245        if let LogicalOperator::Return(ret) = &optimized.root
1246            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1247            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1248            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1249        {
1250            assert_eq!(scan.variable, "a");
1251            assert_eq!(expand.from_variable, "a");
1252            assert_eq!(expand.to_variable, "b");
1253            return;
1254        }
1255        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1256    }
1257
1258    #[test]
1259    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1260        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1261        // The filter on 'b' should NOT be pushed before the expand
1262
1263        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1264            items: vec![ReturnItem {
1265                expression: LogicalExpression::Variable("a".to_string()),
1266                alias: None,
1267            }],
1268            distinct: false,
1269            input: Box::new(LogicalOperator::Filter(FilterOp {
1270                predicate: LogicalExpression::Binary {
1271                    left: Box::new(LogicalExpression::Property {
1272                        variable: "b".to_string(),
1273                        property: "age".to_string(),
1274                    }),
1275                    op: BinaryOp::Gt,
1276                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1277                },
1278                pushdown_hint: None,
1279                input: Box::new(LogicalOperator::Expand(ExpandOp {
1280                    from_variable: "a".to_string(),
1281                    to_variable: "b".to_string(),
1282                    edge_variable: None,
1283                    direction: ExpandDirection::Outgoing,
1284                    edge_types: vec!["KNOWS".to_string()],
1285                    min_hops: 1,
1286                    max_hops: Some(1),
1287                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1288                        variable: "a".to_string(),
1289                        label: Some("Person".to_string()),
1290                        input: None,
1291                    })),
1292                    path_alias: None,
1293                    path_mode: PathMode::Walk,
1294                })),
1295            })),
1296        }));
1297
1298        let optimizer = Optimizer::new();
1299        let optimized = optimizer.optimize(plan).unwrap();
1300
1301        // Filter on 'b' should stay after the expand
1302        // Expected: Return -> Filter -> Expand -> NodeScan
1303        if let LogicalOperator::Return(ret) = &optimized.root
1304            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1305        {
1306            // Check that the filter is on 'b'
1307            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1308                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1309            {
1310                assert_eq!(variable, "b");
1311            }
1312
1313            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1314                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1315            {
1316                return;
1317            }
1318        }
1319        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1320    }
1321
1322    #[test]
1323    fn test_optimizer_extract_variables() {
1324        let optimizer = Optimizer::new();
1325
1326        let expr = LogicalExpression::Binary {
1327            left: Box::new(LogicalExpression::Property {
1328                variable: "n".to_string(),
1329                property: "age".to_string(),
1330            }),
1331            op: BinaryOp::Gt,
1332            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1333        };
1334
1335        let vars = optimizer.extract_variables(&expr);
1336        assert_eq!(vars.len(), 1);
1337        assert!(vars.contains("n"));
1338    }
1339
1340    // Additional tests for optimizer configuration
1341
1342    #[test]
1343    fn test_optimizer_default() {
1344        let optimizer = Optimizer::default();
1345        // Should be able to optimize an empty plan
1346        let plan = LogicalPlan::new(LogicalOperator::Empty);
1347        let result = optimizer.optimize(plan);
1348        assert!(result.is_ok());
1349    }
1350
1351    #[test]
1352    fn test_optimizer_with_filter_pushdown_disabled() {
1353        let optimizer = Optimizer::new().with_filter_pushdown(false);
1354
1355        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1356            items: vec![ReturnItem {
1357                expression: LogicalExpression::Variable("n".to_string()),
1358                alias: None,
1359            }],
1360            distinct: false,
1361            input: Box::new(LogicalOperator::Filter(FilterOp {
1362                predicate: LogicalExpression::Literal(Value::Bool(true)),
1363                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1364                    variable: "n".to_string(),
1365                    label: None,
1366                    input: None,
1367                })),
1368                pushdown_hint: None,
1369            })),
1370        }));
1371
1372        let optimized = optimizer.optimize(plan).unwrap();
1373        // Structure should be unchanged
1374        if let LogicalOperator::Return(ret) = &optimized.root
1375            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1376        {
1377            return;
1378        }
1379        panic!("Expected unchanged structure");
1380    }
1381
1382    #[test]
1383    fn test_optimizer_with_join_reorder_disabled() {
1384        let optimizer = Optimizer::new().with_join_reorder(false);
1385        assert!(
1386            optimizer
1387                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1388                .is_ok()
1389        );
1390    }
1391
1392    #[test]
1393    fn test_optimizer_with_cost_model() {
1394        let cost_model = CostModel::new();
1395        let optimizer = Optimizer::new().with_cost_model(cost_model);
1396        assert!(
1397            optimizer
1398                .cost_model()
1399                .estimate(&LogicalOperator::Empty, 0.0)
1400                .total()
1401                < 0.001
1402        );
1403    }
1404
1405    #[test]
1406    fn test_optimizer_with_cardinality_estimator() {
1407        let mut estimator = CardinalityEstimator::new();
1408        estimator.add_table_stats("Test", TableStats::new(500));
1409        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1410
1411        let scan = LogicalOperator::NodeScan(NodeScanOp {
1412            variable: "n".to_string(),
1413            label: Some("Test".to_string()),
1414            input: None,
1415        });
1416        let plan = LogicalPlan::new(scan);
1417
1418        let cardinality = optimizer.estimate_cardinality(&plan);
1419        assert!((cardinality - 500.0).abs() < 0.001);
1420    }
1421
1422    #[test]
1423    fn test_optimizer_estimate_cost() {
1424        let optimizer = Optimizer::new();
1425        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1426            variable: "n".to_string(),
1427            label: None,
1428            input: None,
1429        }));
1430
1431        let cost = optimizer.estimate_cost(&plan);
1432        assert!(cost.total() > 0.0);
1433    }
1434
1435    // Filter pushdown through various operators
1436
1437    #[test]
1438    fn test_filter_pushdown_through_project() {
1439        let optimizer = Optimizer::new();
1440
1441        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1442            predicate: LogicalExpression::Binary {
1443                left: Box::new(LogicalExpression::Property {
1444                    variable: "n".to_string(),
1445                    property: "age".to_string(),
1446                }),
1447                op: BinaryOp::Gt,
1448                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1449            },
1450            pushdown_hint: None,
1451            input: Box::new(LogicalOperator::Project(ProjectOp {
1452                projections: vec![Projection {
1453                    expression: LogicalExpression::Variable("n".to_string()),
1454                    alias: None,
1455                }],
1456                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1457                    variable: "n".to_string(),
1458                    label: None,
1459                    input: None,
1460                })),
1461                pass_through_input: false,
1462            })),
1463        }));
1464
1465        let optimized = optimizer.optimize(plan).unwrap();
1466
1467        // Filter should be pushed through Project
1468        if let LogicalOperator::Project(proj) = &optimized.root
1469            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1470        {
1471            return;
1472        }
1473        panic!("Expected Project -> Filter structure");
1474    }
1475
1476    #[test]
1477    fn test_filter_not_pushed_through_project_with_alias() {
1478        let optimizer = Optimizer::new();
1479
1480        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1481        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1482            predicate: LogicalExpression::Binary {
1483                left: Box::new(LogicalExpression::Variable("x".to_string())),
1484                op: BinaryOp::Gt,
1485                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1486            },
1487            pushdown_hint: None,
1488            input: Box::new(LogicalOperator::Project(ProjectOp {
1489                projections: vec![Projection {
1490                    expression: LogicalExpression::Property {
1491                        variable: "n".to_string(),
1492                        property: "age".to_string(),
1493                    },
1494                    alias: Some("x".to_string()),
1495                }],
1496                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1497                    variable: "n".to_string(),
1498                    label: None,
1499                    input: None,
1500                })),
1501                pass_through_input: false,
1502            })),
1503        }));
1504
1505        let optimized = optimizer.optimize(plan).unwrap();
1506
1507        // Filter should stay above Project
1508        if let LogicalOperator::Filter(filter) = &optimized.root
1509            && let LogicalOperator::Project(_) = filter.input.as_ref()
1510        {
1511            return;
1512        }
1513        panic!("Expected Filter -> Project structure");
1514    }
1515
1516    #[test]
1517    fn test_filter_pushdown_through_limit() {
1518        let optimizer = Optimizer::new();
1519
1520        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1521            predicate: LogicalExpression::Literal(Value::Bool(true)),
1522            pushdown_hint: None,
1523            input: Box::new(LogicalOperator::Limit(LimitOp {
1524                count: 10.into(),
1525                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1526                    variable: "n".to_string(),
1527                    label: None,
1528                    input: None,
1529                })),
1530            })),
1531        }));
1532
1533        let optimized = optimizer.optimize(plan).unwrap();
1534
1535        // Filter stays above Limit (cannot be pushed through)
1536        if let LogicalOperator::Filter(filter) = &optimized.root
1537            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1538        {
1539            return;
1540        }
1541        panic!("Expected Filter -> Limit structure");
1542    }
1543
1544    #[test]
1545    fn test_filter_pushdown_through_sort() {
1546        let optimizer = Optimizer::new();
1547
1548        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1549            predicate: LogicalExpression::Literal(Value::Bool(true)),
1550            pushdown_hint: None,
1551            input: Box::new(LogicalOperator::Sort(SortOp {
1552                keys: vec![SortKey {
1553                    expression: LogicalExpression::Variable("n".to_string()),
1554                    order: SortOrder::Ascending,
1555                    nulls: None,
1556                }],
1557                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1558                    variable: "n".to_string(),
1559                    label: None,
1560                    input: None,
1561                })),
1562            })),
1563        }));
1564
1565        let optimized = optimizer.optimize(plan).unwrap();
1566
1567        // Filter stays above Sort
1568        if let LogicalOperator::Filter(filter) = &optimized.root
1569            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1570        {
1571            return;
1572        }
1573        panic!("Expected Filter -> Sort structure");
1574    }
1575
1576    #[test]
1577    fn test_filter_pushdown_through_distinct() {
1578        let optimizer = Optimizer::new();
1579
1580        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1581            predicate: LogicalExpression::Literal(Value::Bool(true)),
1582            pushdown_hint: None,
1583            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1584                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1585                    variable: "n".to_string(),
1586                    label: None,
1587                    input: None,
1588                })),
1589                columns: None,
1590            })),
1591        }));
1592
1593        let optimized = optimizer.optimize(plan).unwrap();
1594
1595        // Filter stays above Distinct
1596        if let LogicalOperator::Filter(filter) = &optimized.root
1597            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1598        {
1599            return;
1600        }
1601        panic!("Expected Filter -> Distinct structure");
1602    }
1603
1604    #[test]
1605    fn test_filter_not_pushed_through_aggregate() {
1606        let optimizer = Optimizer::new();
1607
1608        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1609            predicate: LogicalExpression::Binary {
1610                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1611                op: BinaryOp::Gt,
1612                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1613            },
1614            pushdown_hint: None,
1615            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1616                group_by: vec![],
1617                aggregates: vec![AggregateExpr {
1618                    function: AggregateFunction::Count,
1619                    expression: None,
1620                    expression2: None,
1621                    distinct: false,
1622                    alias: Some("cnt".to_string()),
1623                    percentile: None,
1624                    separator: None,
1625                }],
1626                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1627                    variable: "n".to_string(),
1628                    label: None,
1629                    input: None,
1630                })),
1631                having: None,
1632            })),
1633        }));
1634
1635        let optimized = optimizer.optimize(plan).unwrap();
1636
1637        // Filter should stay above Aggregate
1638        if let LogicalOperator::Filter(filter) = &optimized.root
1639            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1640        {
1641            return;
1642        }
1643        panic!("Expected Filter -> Aggregate structure");
1644    }
1645
1646    #[test]
1647    fn test_filter_pushdown_to_left_join_side() {
1648        let optimizer = Optimizer::new();
1649
1650        // Filter on left variable should be pushed to left side
1651        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1652            predicate: LogicalExpression::Binary {
1653                left: Box::new(LogicalExpression::Property {
1654                    variable: "a".to_string(),
1655                    property: "age".to_string(),
1656                }),
1657                op: BinaryOp::Gt,
1658                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1659            },
1660            pushdown_hint: None,
1661            input: Box::new(LogicalOperator::Join(JoinOp {
1662                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1663                    variable: "a".to_string(),
1664                    label: Some("Person".to_string()),
1665                    input: None,
1666                })),
1667                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1668                    variable: "b".to_string(),
1669                    label: Some("Company".to_string()),
1670                    input: None,
1671                })),
1672                join_type: JoinType::Inner,
1673                conditions: vec![],
1674            })),
1675        }));
1676
1677        let optimized = optimizer.optimize(plan).unwrap();
1678
1679        // Filter should be pushed to left side of join
1680        if let LogicalOperator::Join(join) = &optimized.root
1681            && let LogicalOperator::Filter(_) = join.left.as_ref()
1682        {
1683            return;
1684        }
1685        panic!("Expected Join with Filter on left side");
1686    }
1687
1688    #[test]
1689    fn test_filter_pushdown_to_right_join_side() {
1690        let optimizer = Optimizer::new();
1691
1692        // Filter on right variable should be pushed to right side
1693        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1694            predicate: LogicalExpression::Binary {
1695                left: Box::new(LogicalExpression::Property {
1696                    variable: "b".to_string(),
1697                    property: "name".to_string(),
1698                }),
1699                op: BinaryOp::Eq,
1700                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1701            },
1702            pushdown_hint: None,
1703            input: Box::new(LogicalOperator::Join(JoinOp {
1704                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1705                    variable: "a".to_string(),
1706                    label: Some("Person".to_string()),
1707                    input: None,
1708                })),
1709                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1710                    variable: "b".to_string(),
1711                    label: Some("Company".to_string()),
1712                    input: None,
1713                })),
1714                join_type: JoinType::Inner,
1715                conditions: vec![],
1716            })),
1717        }));
1718
1719        let optimized = optimizer.optimize(plan).unwrap();
1720
1721        // Filter should be pushed to right side of join
1722        if let LogicalOperator::Join(join) = &optimized.root
1723            && let LogicalOperator::Filter(_) = join.right.as_ref()
1724        {
1725            return;
1726        }
1727        panic!("Expected Join with Filter on right side");
1728    }
1729
1730    #[test]
1731    fn test_filter_not_pushed_when_uses_both_join_sides() {
1732        let optimizer = Optimizer::new();
1733
1734        // Filter using both variables should stay above join
1735        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1736            predicate: LogicalExpression::Binary {
1737                left: Box::new(LogicalExpression::Property {
1738                    variable: "a".to_string(),
1739                    property: "id".to_string(),
1740                }),
1741                op: BinaryOp::Eq,
1742                right: Box::new(LogicalExpression::Property {
1743                    variable: "b".to_string(),
1744                    property: "a_id".to_string(),
1745                }),
1746            },
1747            pushdown_hint: None,
1748            input: Box::new(LogicalOperator::Join(JoinOp {
1749                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1750                    variable: "a".to_string(),
1751                    label: None,
1752                    input: None,
1753                })),
1754                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1755                    variable: "b".to_string(),
1756                    label: None,
1757                    input: None,
1758                })),
1759                join_type: JoinType::Inner,
1760                conditions: vec![],
1761            })),
1762        }));
1763
1764        let optimized = optimizer.optimize(plan).unwrap();
1765
1766        // Filter should stay above join
1767        if let LogicalOperator::Filter(filter) = &optimized.root
1768            && let LogicalOperator::Join(_) = filter.input.as_ref()
1769        {
1770            return;
1771        }
1772        panic!("Expected Filter -> Join structure");
1773    }
1774
1775    // Variable extraction tests
1776
1777    #[test]
1778    fn test_extract_variables_from_variable() {
1779        let optimizer = Optimizer::new();
1780        let expr = LogicalExpression::Variable("x".to_string());
1781        let vars = optimizer.extract_variables(&expr);
1782        assert_eq!(vars.len(), 1);
1783        assert!(vars.contains("x"));
1784    }
1785
1786    #[test]
1787    fn test_extract_variables_from_unary() {
1788        let optimizer = Optimizer::new();
1789        let expr = LogicalExpression::Unary {
1790            op: UnaryOp::Not,
1791            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1792        };
1793        let vars = optimizer.extract_variables(&expr);
1794        assert_eq!(vars.len(), 1);
1795        assert!(vars.contains("x"));
1796    }
1797
1798    #[test]
1799    fn test_extract_variables_from_function_call() {
1800        let optimizer = Optimizer::new();
1801        let expr = LogicalExpression::FunctionCall {
1802            name: "length".to_string(),
1803            args: vec![
1804                LogicalExpression::Variable("a".to_string()),
1805                LogicalExpression::Variable("b".to_string()),
1806            ],
1807            distinct: false,
1808        };
1809        let vars = optimizer.extract_variables(&expr);
1810        assert_eq!(vars.len(), 2);
1811        assert!(vars.contains("a"));
1812        assert!(vars.contains("b"));
1813    }
1814
1815    #[test]
1816    fn test_extract_variables_from_list() {
1817        let optimizer = Optimizer::new();
1818        let expr = LogicalExpression::List(vec![
1819            LogicalExpression::Variable("a".to_string()),
1820            LogicalExpression::Literal(Value::Int64(1)),
1821            LogicalExpression::Variable("b".to_string()),
1822        ]);
1823        let vars = optimizer.extract_variables(&expr);
1824        assert_eq!(vars.len(), 2);
1825        assert!(vars.contains("a"));
1826        assert!(vars.contains("b"));
1827    }
1828
1829    #[test]
1830    fn test_extract_variables_from_map() {
1831        let optimizer = Optimizer::new();
1832        let expr = LogicalExpression::Map(vec![
1833            (
1834                "key1".to_string(),
1835                LogicalExpression::Variable("a".to_string()),
1836            ),
1837            (
1838                "key2".to_string(),
1839                LogicalExpression::Variable("b".to_string()),
1840            ),
1841        ]);
1842        let vars = optimizer.extract_variables(&expr);
1843        assert_eq!(vars.len(), 2);
1844        assert!(vars.contains("a"));
1845        assert!(vars.contains("b"));
1846    }
1847
1848    #[test]
1849    fn test_extract_variables_from_index_access() {
1850        let optimizer = Optimizer::new();
1851        let expr = LogicalExpression::IndexAccess {
1852            base: Box::new(LogicalExpression::Variable("list".to_string())),
1853            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1854        };
1855        let vars = optimizer.extract_variables(&expr);
1856        assert_eq!(vars.len(), 2);
1857        assert!(vars.contains("list"));
1858        assert!(vars.contains("idx"));
1859    }
1860
1861    #[test]
1862    fn test_extract_variables_from_slice_access() {
1863        let optimizer = Optimizer::new();
1864        let expr = LogicalExpression::SliceAccess {
1865            base: Box::new(LogicalExpression::Variable("list".to_string())),
1866            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1867            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1868        };
1869        let vars = optimizer.extract_variables(&expr);
1870        assert_eq!(vars.len(), 3);
1871        assert!(vars.contains("list"));
1872        assert!(vars.contains("s"));
1873        assert!(vars.contains("e"));
1874    }
1875
1876    #[test]
1877    fn test_extract_variables_from_case() {
1878        let optimizer = Optimizer::new();
1879        let expr = LogicalExpression::Case {
1880            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1881            when_clauses: vec![(
1882                LogicalExpression::Literal(Value::Int64(1)),
1883                LogicalExpression::Variable("a".to_string()),
1884            )],
1885            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1886        };
1887        let vars = optimizer.extract_variables(&expr);
1888        assert_eq!(vars.len(), 3);
1889        assert!(vars.contains("x"));
1890        assert!(vars.contains("a"));
1891        assert!(vars.contains("b"));
1892    }
1893
1894    #[test]
1895    fn test_extract_variables_from_labels() {
1896        let optimizer = Optimizer::new();
1897        let expr = LogicalExpression::Labels("n".to_string());
1898        let vars = optimizer.extract_variables(&expr);
1899        assert_eq!(vars.len(), 1);
1900        assert!(vars.contains("n"));
1901    }
1902
1903    #[test]
1904    fn test_extract_variables_from_type() {
1905        let optimizer = Optimizer::new();
1906        let expr = LogicalExpression::Type("e".to_string());
1907        let vars = optimizer.extract_variables(&expr);
1908        assert_eq!(vars.len(), 1);
1909        assert!(vars.contains("e"));
1910    }
1911
1912    #[test]
1913    fn test_extract_variables_from_id() {
1914        let optimizer = Optimizer::new();
1915        let expr = LogicalExpression::Id("n".to_string());
1916        let vars = optimizer.extract_variables(&expr);
1917        assert_eq!(vars.len(), 1);
1918        assert!(vars.contains("n"));
1919    }
1920
1921    #[test]
1922    fn test_extract_variables_from_list_comprehension() {
1923        let optimizer = Optimizer::new();
1924        let expr = LogicalExpression::ListComprehension {
1925            variable: "x".to_string(),
1926            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1927            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1928            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1929        };
1930        let vars = optimizer.extract_variables(&expr);
1931        assert!(vars.contains("items"));
1932        assert!(vars.contains("pred"));
1933        assert!(vars.contains("result"));
1934    }
1935
1936    #[test]
1937    fn test_extract_variables_from_literal_and_parameter() {
1938        let optimizer = Optimizer::new();
1939
1940        let literal = LogicalExpression::Literal(Value::Int64(42));
1941        assert!(optimizer.extract_variables(&literal).is_empty());
1942
1943        let param = LogicalExpression::Parameter("p".to_string());
1944        assert!(optimizer.extract_variables(&param).is_empty());
1945    }
1946
1947    // Recursive filter pushdown tests
1948
1949    #[test]
1950    fn test_recursive_filter_pushdown_through_skip() {
1951        let optimizer = Optimizer::new();
1952
1953        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1954            items: vec![ReturnItem {
1955                expression: LogicalExpression::Variable("n".to_string()),
1956                alias: None,
1957            }],
1958            distinct: false,
1959            input: Box::new(LogicalOperator::Filter(FilterOp {
1960                predicate: LogicalExpression::Literal(Value::Bool(true)),
1961                pushdown_hint: None,
1962                input: Box::new(LogicalOperator::Skip(SkipOp {
1963                    count: 5.into(),
1964                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1965                        variable: "n".to_string(),
1966                        label: None,
1967                        input: None,
1968                    })),
1969                })),
1970            })),
1971        }));
1972
1973        let optimized = optimizer.optimize(plan).unwrap();
1974
1975        // Verify optimization succeeded
1976        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1977    }
1978
1979    #[test]
1980    fn test_nested_filter_pushdown() {
1981        let optimizer = Optimizer::new();
1982
1983        // Multiple nested filters
1984        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1985            items: vec![ReturnItem {
1986                expression: LogicalExpression::Variable("n".to_string()),
1987                alias: None,
1988            }],
1989            distinct: false,
1990            input: Box::new(LogicalOperator::Filter(FilterOp {
1991                predicate: LogicalExpression::Binary {
1992                    left: Box::new(LogicalExpression::Property {
1993                        variable: "n".to_string(),
1994                        property: "x".to_string(),
1995                    }),
1996                    op: BinaryOp::Gt,
1997                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1998                },
1999                pushdown_hint: None,
2000                input: Box::new(LogicalOperator::Filter(FilterOp {
2001                    predicate: LogicalExpression::Binary {
2002                        left: Box::new(LogicalExpression::Property {
2003                            variable: "n".to_string(),
2004                            property: "y".to_string(),
2005                        }),
2006                        op: BinaryOp::Lt,
2007                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2008                    },
2009                    pushdown_hint: None,
2010                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2011                        variable: "n".to_string(),
2012                        label: None,
2013                        input: None,
2014                    })),
2015                })),
2016            })),
2017        }));
2018
2019        let optimized = optimizer.optimize(plan).unwrap();
2020        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2021    }
2022
2023    #[test]
2024    fn test_cyclic_join_produces_multi_way_join() {
2025        use crate::query::plan::JoinCondition;
2026
2027        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2028        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2029            variable: "a".to_string(),
2030            label: Some("Person".to_string()),
2031            input: None,
2032        });
2033        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2034            variable: "b".to_string(),
2035            label: Some("Person".to_string()),
2036            input: None,
2037        });
2038        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2039            variable: "c".to_string(),
2040            label: Some("Person".to_string()),
2041            input: None,
2042        });
2043
2044        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2045        let join_ab = LogicalOperator::Join(JoinOp {
2046            left: Box::new(scan_a),
2047            right: Box::new(scan_b),
2048            join_type: JoinType::Inner,
2049            conditions: vec![JoinCondition {
2050                left: LogicalExpression::Variable("a".to_string()),
2051                right: LogicalExpression::Variable("b".to_string()),
2052            }],
2053        });
2054
2055        let join_abc = LogicalOperator::Join(JoinOp {
2056            left: Box::new(join_ab),
2057            right: Box::new(scan_c),
2058            join_type: JoinType::Inner,
2059            conditions: vec![
2060                JoinCondition {
2061                    left: LogicalExpression::Variable("b".to_string()),
2062                    right: LogicalExpression::Variable("c".to_string()),
2063                },
2064                JoinCondition {
2065                    left: LogicalExpression::Variable("c".to_string()),
2066                    right: LogicalExpression::Variable("a".to_string()),
2067                },
2068            ],
2069        });
2070
2071        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2072            items: vec![ReturnItem {
2073                expression: LogicalExpression::Variable("a".to_string()),
2074                alias: None,
2075            }],
2076            distinct: false,
2077            input: Box::new(join_abc),
2078        }));
2079
2080        let mut optimizer = Optimizer::new();
2081        optimizer
2082            .card_estimator
2083            .add_table_stats("Person", cardinality::TableStats::new(1000));
2084
2085        let optimized = optimizer.optimize(plan).unwrap();
2086
2087        // Walk the tree to find a MultiWayJoin
2088        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2089            match op {
2090                LogicalOperator::MultiWayJoin(_) => true,
2091                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2092                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2093                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2094                _ => false,
2095            }
2096        }
2097
2098        assert!(
2099            has_multi_way_join(&optimized.root),
2100            "Expected MultiWayJoin for cyclic triangle pattern"
2101        );
2102    }
2103
2104    #[test]
2105    fn test_acyclic_join_uses_binary_joins() {
2106        use crate::query::plan::JoinCondition;
2107
2108        // Chain: a ⋈ b ⋈ c (acyclic)
2109        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2110            variable: "a".to_string(),
2111            label: Some("Person".to_string()),
2112            input: None,
2113        });
2114        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2115            variable: "b".to_string(),
2116            label: Some("Person".to_string()),
2117            input: None,
2118        });
2119        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2120            variable: "c".to_string(),
2121            label: Some("Company".to_string()),
2122            input: None,
2123        });
2124
2125        let join_ab = LogicalOperator::Join(JoinOp {
2126            left: Box::new(scan_a),
2127            right: Box::new(scan_b),
2128            join_type: JoinType::Inner,
2129            conditions: vec![JoinCondition {
2130                left: LogicalExpression::Variable("a".to_string()),
2131                right: LogicalExpression::Variable("b".to_string()),
2132            }],
2133        });
2134
2135        let join_abc = LogicalOperator::Join(JoinOp {
2136            left: Box::new(join_ab),
2137            right: Box::new(scan_c),
2138            join_type: JoinType::Inner,
2139            conditions: vec![JoinCondition {
2140                left: LogicalExpression::Variable("b".to_string()),
2141                right: LogicalExpression::Variable("c".to_string()),
2142            }],
2143        });
2144
2145        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2146            items: vec![ReturnItem {
2147                expression: LogicalExpression::Variable("a".to_string()),
2148                alias: None,
2149            }],
2150            distinct: false,
2151            input: Box::new(join_abc),
2152        }));
2153
2154        let mut optimizer = Optimizer::new();
2155        optimizer
2156            .card_estimator
2157            .add_table_stats("Person", cardinality::TableStats::new(1000));
2158        optimizer
2159            .card_estimator
2160            .add_table_stats("Company", cardinality::TableStats::new(100));
2161
2162        let optimized = optimizer.optimize(plan).unwrap();
2163
2164        // Should NOT contain MultiWayJoin for acyclic pattern
2165        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2166            match op {
2167                LogicalOperator::MultiWayJoin(_) => true,
2168                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2169                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2170                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2171                LogicalOperator::Join(j) => {
2172                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2173                }
2174                _ => false,
2175            }
2176        }
2177
2178        assert!(
2179            !has_multi_way_join(&optimized.root),
2180            "Acyclic join should NOT produce MultiWayJoin"
2181        );
2182    }
2183}