Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
8//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm |
9//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// Information about a join condition for join reordering.
///
/// One value is recorded per join condition while a join tree is being
/// flattened; the variables identify which relations the condition connects
/// and the expressions keep the full condition so it can be re-attached.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left side of the condition.
    left_var: String,
    /// Variable referenced by the right side of the condition.
    right_var: String,
    /// Complete left-hand expression of the condition.
    left_expr: LogicalExpression,
    /// Complete right-hand expression of the condition.
    right_expr: LogicalExpression,
}
39
/// A column required by the query, used for projection pushdown.
///
/// Values are stored in a `HashSet`, hence the `Eq` and `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding)
    Variable(String),
    /// A specific property of a variable: `(variable name, property name)`
    Property(String, String),
}
48
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// `optimize()` runs the enabled passes in a fixed order: filter pushdown,
/// then join reordering, then projection pushdown.
pub struct Optimizer {
    /// Whether to enable filter pushdown.
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering.
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown.
    enable_projection_pushdown: bool,
    /// Cost model for estimation (used by `estimate_cost` and join ordering).
    cost_model: CostModel,
    /// Cardinality estimator (used by `estimate_cardinality`, `estimate_cost`
    /// and join ordering).
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67    /// Creates a new optimizer with default settings.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            enable_filter_pushdown: true,
72            enable_join_reorder: true,
73            enable_projection_pushdown: true,
74            cost_model: CostModel::new(),
75            card_estimator: CardinalityEstimator::new(),
76        }
77    }
78
79    /// Creates an optimizer with cardinality estimates from the store's statistics.
80    ///
81    /// Pre-populates the cardinality estimator with per-label row counts and
82    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
83    /// and graph totals into the cost model for accurate estimation.
84    #[cfg(feature = "lpg")]
85    #[must_use]
86    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
87        store.ensure_statistics_fresh();
88        let stats = store.statistics();
89        Self::from_statistics(&stats)
90    }
91
92    /// Creates an optimizer from any GraphStore implementation.
93    ///
94    /// Unlike [`from_store`](Self::from_store), this does not call
95    /// `ensure_statistics_fresh()` since external stores manage their own
96    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
97    /// is called directly.
98    #[must_use]
99    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
100        let stats = store.statistics();
101        Self::from_statistics(&stats)
102    }
103
104    /// Creates an optimizer from RDF statistics.
105    ///
106    /// Uses triple pattern cardinality estimates for cost-based optimization
107    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
108    #[cfg(feature = "triple-store")]
109    #[must_use]
110    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
111        let total = rdf_stats.total_triples;
112        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
113        Self {
114            enable_filter_pushdown: true,
115            enable_join_reorder: true,
116            enable_projection_pushdown: true,
117            cost_model: CostModel::new().with_graph_totals(total, total),
118            card_estimator: estimator,
119        }
120    }
121
122    /// Creates an optimizer from a Statistics snapshot.
123    ///
124    /// Extracts label cardinalities, edge type degrees, and graph totals
125    /// for both the cardinality estimator and cost model.
126    #[must_use]
127    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
128        let estimator = CardinalityEstimator::from_statistics(stats);
129
130        let avg_fanout = if stats.total_nodes > 0 {
131            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
132        } else {
133            10.0
134        };
135
136        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
137            .edge_types
138            .iter()
139            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
140            .collect();
141
142        let label_cardinalities: std::collections::HashMap<String, u64> = stats
143            .labels
144            .iter()
145            .map(|(name, ls)| (name.clone(), ls.node_count))
146            .collect();
147
148        Self {
149            enable_filter_pushdown: true,
150            enable_join_reorder: true,
151            enable_projection_pushdown: true,
152            cost_model: CostModel::new()
153                .with_avg_fanout(avg_fanout)
154                .with_edge_type_degrees(edge_type_degrees)
155                .with_label_cardinalities(label_cardinalities)
156                .with_graph_totals(stats.total_nodes, stats.total_edges),
157            card_estimator: estimator,
158        }
159    }
160
161    /// Enables or disables filter pushdown.
162    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
163        self.enable_filter_pushdown = enabled;
164        self
165    }
166
167    /// Enables or disables join reordering.
168    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
169        self.enable_join_reorder = enabled;
170        self
171    }
172
173    /// Enables or disables projection pushdown.
174    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
175        self.enable_projection_pushdown = enabled;
176        self
177    }
178
179    /// Sets the cost model.
180    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
181        self.cost_model = cost_model;
182        self
183    }
184
185    /// Sets the cardinality estimator.
186    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
187        self.card_estimator = estimator;
188        self
189    }
190
191    /// Sets the selectivity configuration for the cardinality estimator.
192    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
193        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
194        self
195    }
196
197    /// Returns a reference to the cost model.
198    pub fn cost_model(&self) -> &CostModel {
199        &self.cost_model
200    }
201
202    /// Returns a reference to the cardinality estimator.
203    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
204        &self.card_estimator
205    }
206
207    /// Estimates the total cost of a plan by recursively costing the entire tree.
208    ///
209    /// Walks every operator in the plan, computing cardinality at each level
210    /// and summing the per-operator costs. Uses actual child cardinalities
211    /// for join cost estimation rather than approximations.
212    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
213        self.cost_model
214            .estimate_tree(&plan.root, &self.card_estimator)
215    }
216
217    /// Estimates the cardinality of a plan.
218    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
219        self.card_estimator.estimate(&plan.root)
220    }
221
222    /// Optimizes a logical plan.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if optimization fails.
227    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
228        let _span = grafeo_debug_span!("grafeo::query::optimize");
229        let mut root = plan.root;
230
231        // Apply optimization rules
232        if self.enable_filter_pushdown {
233            root = self.push_filters_down(root);
234        }
235
236        if self.enable_join_reorder {
237            root = self.reorder_joins(root);
238        }
239
240        if self.enable_projection_pushdown {
241            root = self.push_projections_down(root);
242        }
243
244        Ok(LogicalPlan {
245            root,
246            explain: plan.explain,
247            profile: plan.profile,
248            default_params: plan.default_params,
249        })
250    }
251
252    /// Pushes projections down the operator tree to eliminate unused columns early.
253    ///
254    /// This optimization:
255    /// 1. Collects required variables/properties from the root
256    /// 2. Propagates requirements down through the tree
257    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
258    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
259        // Collect required columns from the top of the plan
260        let required = self.collect_required_columns(&op);
261
262        // Push projections down
263        self.push_projections_recursive(op, &required)
264    }
265
266    /// Collects all variables and properties required by an operator and its ancestors.
267    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
268        let mut required = HashSet::new();
269        Self::collect_required_recursive(op, &mut required);
270        required
271    }
272
273    /// Recursively collects required columns.
274    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
275        match op {
276            LogicalOperator::Return(ret) => {
277                for item in &ret.items {
278                    Self::collect_from_expression(&item.expression, required);
279                }
280                Self::collect_required_recursive(&ret.input, required);
281            }
282            LogicalOperator::Project(proj) => {
283                for p in &proj.projections {
284                    Self::collect_from_expression(&p.expression, required);
285                }
286                Self::collect_required_recursive(&proj.input, required);
287            }
288            LogicalOperator::Filter(filter) => {
289                Self::collect_from_expression(&filter.predicate, required);
290                Self::collect_required_recursive(&filter.input, required);
291            }
292            LogicalOperator::Sort(sort) => {
293                for key in &sort.keys {
294                    Self::collect_from_expression(&key.expression, required);
295                }
296                Self::collect_required_recursive(&sort.input, required);
297            }
298            LogicalOperator::Aggregate(agg) => {
299                for expr in &agg.group_by {
300                    Self::collect_from_expression(expr, required);
301                }
302                for agg_expr in &agg.aggregates {
303                    if let Some(ref expr) = agg_expr.expression {
304                        Self::collect_from_expression(expr, required);
305                    }
306                }
307                if let Some(ref having) = agg.having {
308                    Self::collect_from_expression(having, required);
309                }
310                Self::collect_required_recursive(&agg.input, required);
311            }
312            LogicalOperator::Join(join) => {
313                for cond in &join.conditions {
314                    Self::collect_from_expression(&cond.left, required);
315                    Self::collect_from_expression(&cond.right, required);
316                }
317                Self::collect_required_recursive(&join.left, required);
318                Self::collect_required_recursive(&join.right, required);
319            }
320            LogicalOperator::Expand(expand) => {
321                // The source and target variables are needed
322                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
323                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
324                if let Some(ref edge_var) = expand.edge_variable {
325                    required.insert(RequiredColumn::Variable(edge_var.clone()));
326                }
327                Self::collect_required_recursive(&expand.input, required);
328            }
329            LogicalOperator::Limit(limit) => {
330                Self::collect_required_recursive(&limit.input, required);
331            }
332            LogicalOperator::Skip(skip) => {
333                Self::collect_required_recursive(&skip.input, required);
334            }
335            LogicalOperator::Distinct(distinct) => {
336                Self::collect_required_recursive(&distinct.input, required);
337            }
338            LogicalOperator::NodeScan(scan) => {
339                required.insert(RequiredColumn::Variable(scan.variable.clone()));
340            }
341            LogicalOperator::EdgeScan(scan) => {
342                required.insert(RequiredColumn::Variable(scan.variable.clone()));
343            }
344            LogicalOperator::MultiWayJoin(mwj) => {
345                for cond in &mwj.conditions {
346                    Self::collect_from_expression(&cond.left, required);
347                    Self::collect_from_expression(&cond.right, required);
348                }
349                for input in &mwj.inputs {
350                    Self::collect_required_recursive(input, required);
351                }
352            }
353            _ => {}
354        }
355    }
356
357    /// Collects required columns from an expression.
358    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
359        match expr {
360            LogicalExpression::Variable(var) => {
361                required.insert(RequiredColumn::Variable(var.clone()));
362            }
363            LogicalExpression::Property { variable, property } => {
364                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
365                required.insert(RequiredColumn::Variable(variable.clone()));
366            }
367            LogicalExpression::Binary { left, right, .. } => {
368                Self::collect_from_expression(left, required);
369                Self::collect_from_expression(right, required);
370            }
371            LogicalExpression::Unary { operand, .. } => {
372                Self::collect_from_expression(operand, required);
373            }
374            LogicalExpression::FunctionCall { args, .. } => {
375                for arg in args {
376                    Self::collect_from_expression(arg, required);
377                }
378            }
379            LogicalExpression::List(items) => {
380                for item in items {
381                    Self::collect_from_expression(item, required);
382                }
383            }
384            LogicalExpression::Map(pairs) => {
385                for (_, value) in pairs {
386                    Self::collect_from_expression(value, required);
387                }
388            }
389            LogicalExpression::IndexAccess { base, index } => {
390                Self::collect_from_expression(base, required);
391                Self::collect_from_expression(index, required);
392            }
393            LogicalExpression::SliceAccess { base, start, end } => {
394                Self::collect_from_expression(base, required);
395                if let Some(s) = start {
396                    Self::collect_from_expression(s, required);
397                }
398                if let Some(e) = end {
399                    Self::collect_from_expression(e, required);
400                }
401            }
402            LogicalExpression::Case {
403                operand,
404                when_clauses,
405                else_clause,
406            } => {
407                if let Some(op) = operand {
408                    Self::collect_from_expression(op, required);
409                }
410                for (cond, result) in when_clauses {
411                    Self::collect_from_expression(cond, required);
412                    Self::collect_from_expression(result, required);
413                }
414                if let Some(else_expr) = else_clause {
415                    Self::collect_from_expression(else_expr, required);
416                }
417            }
418            LogicalExpression::Labels(var)
419            | LogicalExpression::Type(var)
420            | LogicalExpression::Id(var) => {
421                required.insert(RequiredColumn::Variable(var.clone()));
422            }
423            LogicalExpression::ListComprehension {
424                list_expr,
425                filter_expr,
426                map_expr,
427                ..
428            } => {
429                Self::collect_from_expression(list_expr, required);
430                if let Some(filter) = filter_expr {
431                    Self::collect_from_expression(filter, required);
432                }
433                Self::collect_from_expression(map_expr, required);
434            }
435            _ => {}
436        }
437    }
438
439    /// Recursively pushes projections down, adding them before expensive operations.
440    fn push_projections_recursive(
441        &self,
442        op: LogicalOperator,
443        required: &HashSet<RequiredColumn>,
444    ) -> LogicalOperator {
445        match op {
446            LogicalOperator::Return(mut ret) => {
447                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
448                LogicalOperator::Return(ret)
449            }
450            LogicalOperator::Project(mut proj) => {
451                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
452                LogicalOperator::Project(proj)
453            }
454            LogicalOperator::Filter(mut filter) => {
455                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
456                LogicalOperator::Filter(filter)
457            }
458            LogicalOperator::Sort(mut sort) => {
459                // Sort is expensive - consider adding a projection before it
460                // to reduce tuple width
461                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
462                LogicalOperator::Sort(sort)
463            }
464            LogicalOperator::Aggregate(mut agg) => {
465                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
466                LogicalOperator::Aggregate(agg)
467            }
468            LogicalOperator::Join(mut join) => {
469                // Joins are expensive - the required columns help determine
470                // what to project on each side
471                let left_vars = self.collect_output_variables(&join.left);
472                let right_vars = self.collect_output_variables(&join.right);
473
474                // Filter required columns to each side
475                let left_required: HashSet<_> = required
476                    .iter()
477                    .filter(|c| match c {
478                        RequiredColumn::Variable(v) => left_vars.contains(v),
479                        RequiredColumn::Property(v, _) => left_vars.contains(v),
480                    })
481                    .cloned()
482                    .collect();
483
484                let right_required: HashSet<_> = required
485                    .iter()
486                    .filter(|c| match c {
487                        RequiredColumn::Variable(v) => right_vars.contains(v),
488                        RequiredColumn::Property(v, _) => right_vars.contains(v),
489                    })
490                    .cloned()
491                    .collect();
492
493                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
494                join.right =
495                    Box::new(self.push_projections_recursive(*join.right, &right_required));
496                LogicalOperator::Join(join)
497            }
498            LogicalOperator::Expand(mut expand) => {
499                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
500                LogicalOperator::Expand(expand)
501            }
502            LogicalOperator::Limit(mut limit) => {
503                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
504                LogicalOperator::Limit(limit)
505            }
506            LogicalOperator::Skip(mut skip) => {
507                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
508                LogicalOperator::Skip(skip)
509            }
510            LogicalOperator::Distinct(mut distinct) => {
511                distinct.input =
512                    Box::new(self.push_projections_recursive(*distinct.input, required));
513                LogicalOperator::Distinct(distinct)
514            }
515            LogicalOperator::MapCollect(mut mc) => {
516                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
517                LogicalOperator::MapCollect(mc)
518            }
519            LogicalOperator::MultiWayJoin(mut mwj) => {
520                mwj.inputs = mwj
521                    .inputs
522                    .into_iter()
523                    .map(|input| self.push_projections_recursive(input, required))
524                    .collect();
525                LogicalOperator::MultiWayJoin(mwj)
526            }
527            other => other,
528        }
529    }
530
531    /// Reorders joins in the operator tree using the DPccp algorithm.
532    ///
533    /// This optimization finds the optimal join order by:
534    /// 1. Extracting all base relations (scans) and join conditions
535    /// 2. Building a join graph
536    /// 3. Using dynamic programming to find the cheapest join order
537    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
538        // First, recursively optimize children
539        let op = self.reorder_joins_recursive(op);
540
541        // Then, if this is a join tree, try to optimize it
542        if let Some((relations, conditions)) = self.extract_join_tree(&op)
543            && relations.len() >= 2
544            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
545        {
546            return optimized;
547        }
548
549        op
550    }
551
552    /// Recursively applies join reordering to child operators.
553    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
554        match op {
555            LogicalOperator::Return(mut ret) => {
556                ret.input = Box::new(self.reorder_joins(*ret.input));
557                LogicalOperator::Return(ret)
558            }
559            LogicalOperator::Project(mut proj) => {
560                proj.input = Box::new(self.reorder_joins(*proj.input));
561                LogicalOperator::Project(proj)
562            }
563            LogicalOperator::Filter(mut filter) => {
564                filter.input = Box::new(self.reorder_joins(*filter.input));
565                LogicalOperator::Filter(filter)
566            }
567            LogicalOperator::Limit(mut limit) => {
568                limit.input = Box::new(self.reorder_joins(*limit.input));
569                LogicalOperator::Limit(limit)
570            }
571            LogicalOperator::Skip(mut skip) => {
572                skip.input = Box::new(self.reorder_joins(*skip.input));
573                LogicalOperator::Skip(skip)
574            }
575            LogicalOperator::Sort(mut sort) => {
576                sort.input = Box::new(self.reorder_joins(*sort.input));
577                LogicalOperator::Sort(sort)
578            }
579            LogicalOperator::Distinct(mut distinct) => {
580                distinct.input = Box::new(self.reorder_joins(*distinct.input));
581                LogicalOperator::Distinct(distinct)
582            }
583            LogicalOperator::Aggregate(mut agg) => {
584                agg.input = Box::new(self.reorder_joins(*agg.input));
585                LogicalOperator::Aggregate(agg)
586            }
587            LogicalOperator::Expand(mut expand) => {
588                expand.input = Box::new(self.reorder_joins(*expand.input));
589                LogicalOperator::Expand(expand)
590            }
591            LogicalOperator::MapCollect(mut mc) => {
592                mc.input = Box::new(self.reorder_joins(*mc.input));
593                LogicalOperator::MapCollect(mc)
594            }
595            LogicalOperator::MultiWayJoin(mut mwj) => {
596                mwj.inputs = mwj
597                    .inputs
598                    .into_iter()
599                    .map(|input| self.reorder_joins(input))
600                    .collect();
601                LogicalOperator::MultiWayJoin(mwj)
602            }
603            // Join operators are handled by the parent reorder_joins call
604            other => other,
605        }
606    }
607
608    /// Extracts base relations and join conditions from a join tree.
609    ///
610    /// Returns None if the operator is not a join tree.
611    fn extract_join_tree(
612        &self,
613        op: &LogicalOperator,
614    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
615        let mut relations = Vec::new();
616        let mut join_conditions = Vec::new();
617
618        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
619            return None;
620        }
621
622        if relations.len() < 2 {
623            return None;
624        }
625
626        Some((relations, join_conditions))
627    }
628
629    /// Recursively collects base relations and join conditions.
630    ///
631    /// Returns true if this subtree is part of a join tree.
632    fn collect_join_tree(
633        &self,
634        op: &LogicalOperator,
635        relations: &mut Vec<(String, LogicalOperator)>,
636        conditions: &mut Vec<JoinInfo>,
637    ) -> bool {
638        match op {
639            LogicalOperator::Join(join) => {
640                // Collect from both sides
641                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
642                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
643
644                // Add conditions from this join
645                for cond in &join.conditions {
646                    if let (Some(left_var), Some(right_var)) = (
647                        self.extract_variable_from_expr(&cond.left),
648                        self.extract_variable_from_expr(&cond.right),
649                    ) {
650                        conditions.push(JoinInfo {
651                            left_var,
652                            right_var,
653                            left_expr: cond.left.clone(),
654                            right_expr: cond.right.clone(),
655                        });
656                    }
657                }
658
659                left_ok && right_ok
660            }
661            LogicalOperator::NodeScan(scan) => {
662                relations.push((scan.variable.clone(), op.clone()));
663                true
664            }
665            LogicalOperator::EdgeScan(scan) => {
666                relations.push((scan.variable.clone(), op.clone()));
667                true
668            }
669            LogicalOperator::Filter(filter) => {
670                // A filter on a base relation is still part of the join tree
671                self.collect_join_tree(&filter.input, relations, conditions)
672            }
673            LogicalOperator::Expand(expand) => {
674                // Expand is a special case - it's like a join with the adjacency
675                // For now, treat the whole Expand subtree as a single relation
676                relations.push((expand.to_variable.clone(), op.clone()));
677                true
678            }
679            _ => false,
680        }
681    }
682
683    /// Extracts the primary variable from an expression.
684    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
685        match expr {
686            LogicalExpression::Variable(v) => Some(v.clone()),
687            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
688            _ => None,
689        }
690    }
691
692    /// Optimizes the join order using DPccp, or produces a multi-way
693    /// leapfrog join for cyclic patterns when the cost model prefers it.
694    fn optimize_join_order(
695        &self,
696        relations: &[(String, LogicalOperator)],
697        conditions: &[JoinInfo],
698    ) -> Option<LogicalOperator> {
699        use join_order::{DPccp, JoinGraphBuilder};
700
701        // Build the join graph
702        let mut builder = JoinGraphBuilder::new();
703
704        for (var, relation) in relations {
705            builder.add_relation(var, relation.clone());
706        }
707
708        for cond in conditions {
709            builder.add_join_condition(
710                &cond.left_var,
711                &cond.right_var,
712                cond.left_expr.clone(),
713                cond.right_expr.clone(),
714            );
715        }
716
717        let graph = builder.build();
718
719        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
720        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
721        // multi-way intersection rather than binary hash join cascades that can
722        // produce intermediate blowup.
723        if graph.is_cyclic() && relations.len() >= 3 {
724            // Collect shared variables (variables appearing in 2+ conditions)
725            let mut var_counts: std::collections::HashMap<&str, usize> =
726                std::collections::HashMap::new();
727            for cond in conditions {
728                *var_counts.entry(&cond.left_var).or_default() += 1;
729                *var_counts.entry(&cond.right_var).or_default() += 1;
730            }
731            let shared_variables: Vec<String> = var_counts
732                .into_iter()
733                .filter(|(_, count)| *count >= 2)
734                .map(|(var, _)| var.to_string())
735                .collect();
736
737            let join_conditions: Vec<JoinCondition> = conditions
738                .iter()
739                .map(|c| JoinCondition {
740                    left: c.left_expr.clone(),
741                    right: c.right_expr.clone(),
742                })
743                .collect();
744
745            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
746                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
747                conditions: join_conditions,
748                shared_variables,
749            }));
750        }
751
752        // Fall through to DPccp for binary join ordering
753        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
754        let plan = dpccp.optimize()?;
755
756        Some(plan.operator)
757    }
758
759    /// Pushes filters down the operator tree.
760    ///
761    /// This optimization moves filter predicates as close to the data source
762    /// as possible to reduce the amount of data processed by upper operators.
763    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
764        match op {
765            // For Filter operators, try to push the predicate into the child
766            LogicalOperator::Filter(filter) => {
767                let optimized_input = self.push_filters_down(*filter.input);
768                self.try_push_filter_into(filter.predicate, optimized_input)
769            }
770            // Recursively optimize children for other operators
771            LogicalOperator::Return(mut ret) => {
772                ret.input = Box::new(self.push_filters_down(*ret.input));
773                LogicalOperator::Return(ret)
774            }
775            LogicalOperator::Project(mut proj) => {
776                proj.input = Box::new(self.push_filters_down(*proj.input));
777                LogicalOperator::Project(proj)
778            }
779            LogicalOperator::Limit(mut limit) => {
780                limit.input = Box::new(self.push_filters_down(*limit.input));
781                LogicalOperator::Limit(limit)
782            }
783            LogicalOperator::Skip(mut skip) => {
784                skip.input = Box::new(self.push_filters_down(*skip.input));
785                LogicalOperator::Skip(skip)
786            }
787            LogicalOperator::Sort(mut sort) => {
788                sort.input = Box::new(self.push_filters_down(*sort.input));
789                LogicalOperator::Sort(sort)
790            }
791            LogicalOperator::Distinct(mut distinct) => {
792                distinct.input = Box::new(self.push_filters_down(*distinct.input));
793                LogicalOperator::Distinct(distinct)
794            }
795            LogicalOperator::Expand(mut expand) => {
796                expand.input = Box::new(self.push_filters_down(*expand.input));
797                LogicalOperator::Expand(expand)
798            }
799            LogicalOperator::Join(mut join) => {
800                join.left = Box::new(self.push_filters_down(*join.left));
801                join.right = Box::new(self.push_filters_down(*join.right));
802                LogicalOperator::Join(join)
803            }
804            LogicalOperator::Aggregate(mut agg) => {
805                agg.input = Box::new(self.push_filters_down(*agg.input));
806                LogicalOperator::Aggregate(agg)
807            }
808            LogicalOperator::MapCollect(mut mc) => {
809                mc.input = Box::new(self.push_filters_down(*mc.input));
810                LogicalOperator::MapCollect(mc)
811            }
812            LogicalOperator::MultiWayJoin(mut mwj) => {
813                mwj.inputs = mwj
814                    .inputs
815                    .into_iter()
816                    .map(|input| self.push_filters_down(input))
817                    .collect();
818                LogicalOperator::MultiWayJoin(mwj)
819            }
820            // Leaf operators and unsupported operators are returned as-is
821            other => other,
822        }
823    }
824
825    /// Tries to push a filter predicate into the given operator.
826    ///
827    /// Returns either the predicate pushed into the operator, or a new
828    /// Filter operator on top if the predicate cannot be pushed further.
829    fn try_push_filter_into(
830        &self,
831        predicate: LogicalExpression,
832        op: LogicalOperator,
833    ) -> LogicalOperator {
834        match op {
835            // Can push through Project if predicate doesn't depend on computed columns
836            LogicalOperator::Project(mut proj) => {
837                let predicate_vars = self.extract_variables(&predicate);
838                let computed_vars = self.extract_projection_aliases(&proj.projections);
839
840                // If predicate doesn't use any computed columns, push through
841                if predicate_vars.is_disjoint(&computed_vars) {
842                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
843                    LogicalOperator::Project(proj)
844                } else {
845                    // Can't push through, keep filter on top
846                    LogicalOperator::Filter(FilterOp {
847                        predicate,
848                        pushdown_hint: None,
849                        input: Box::new(LogicalOperator::Project(proj)),
850                    })
851                }
852            }
853
854            // Can push through Return (which is like a projection)
855            LogicalOperator::Return(mut ret) => {
856                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
857                LogicalOperator::Return(ret)
858            }
859
860            // Can push through Expand if predicate doesn't use variables introduced by this expand
861            LogicalOperator::Expand(mut expand) => {
862                let predicate_vars = self.extract_variables(&predicate);
863
864                // Variables introduced by this expand are:
865                // - The target variable (to_variable)
866                // - The edge variable (if any)
867                // - The path alias (if any)
868                let mut introduced_vars = vec![&expand.to_variable];
869                if let Some(ref edge_var) = expand.edge_variable {
870                    introduced_vars.push(edge_var);
871                }
872                if let Some(ref path_alias) = expand.path_alias {
873                    introduced_vars.push(path_alias);
874                }
875
876                // Check if predicate uses any variables introduced by this expand
877                let uses_introduced_vars =
878                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
879
880                if !uses_introduced_vars {
881                    // Predicate doesn't use vars from this expand, so push through
882                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
883                    LogicalOperator::Expand(expand)
884                } else {
885                    // Keep filter after expand
886                    LogicalOperator::Filter(FilterOp {
887                        predicate,
888                        pushdown_hint: None,
889                        input: Box::new(LogicalOperator::Expand(expand)),
890                    })
891                }
892            }
893
894            // Can push through Join to left/right side based on variables used
895            LogicalOperator::Join(mut join) => {
896                let predicate_vars = self.extract_variables(&predicate);
897                let left_vars = self.collect_output_variables(&join.left);
898                let right_vars = self.collect_output_variables(&join.right);
899
900                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
901                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
902
903                if uses_left && !uses_right {
904                    // Push to left side
905                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
906                    LogicalOperator::Join(join)
907                } else if uses_right && !uses_left {
908                    // Push to right side
909                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
910                    LogicalOperator::Join(join)
911                } else {
912                    // Uses both sides - keep above join
913                    LogicalOperator::Filter(FilterOp {
914                        predicate,
915                        pushdown_hint: None,
916                        input: Box::new(LogicalOperator::Join(join)),
917                    })
918                }
919            }
920
921            // Cannot push through Aggregate (predicate refers to aggregated values)
922            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
923                predicate,
924                pushdown_hint: None,
925                input: Box::new(LogicalOperator::Aggregate(agg)),
926            }),
927
928            // For NodeScan, we've reached the bottom - keep filter on top
929            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
930                predicate,
931                pushdown_hint: None,
932                input: Box::new(LogicalOperator::NodeScan(scan)),
933            }),
934
935            // For other operators, keep filter on top
936            other => LogicalOperator::Filter(FilterOp {
937                predicate,
938                pushdown_hint: None,
939                input: Box::new(other),
940            }),
941        }
942    }
943
944    /// Collects all output variable names from an operator.
945    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
946        let mut vars = HashSet::new();
947        Self::collect_output_variables_recursive(op, &mut vars);
948        vars
949    }
950
951    /// Recursively collects output variables from an operator.
952    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
953        match op {
954            LogicalOperator::NodeScan(scan) => {
955                vars.insert(scan.variable.clone());
956            }
957            LogicalOperator::EdgeScan(scan) => {
958                vars.insert(scan.variable.clone());
959            }
960            LogicalOperator::Expand(expand) => {
961                vars.insert(expand.to_variable.clone());
962                if let Some(edge_var) = &expand.edge_variable {
963                    vars.insert(edge_var.clone());
964                }
965                Self::collect_output_variables_recursive(&expand.input, vars);
966            }
967            LogicalOperator::Filter(filter) => {
968                Self::collect_output_variables_recursive(&filter.input, vars);
969            }
970            LogicalOperator::Project(proj) => {
971                for p in &proj.projections {
972                    if let Some(alias) = &p.alias {
973                        vars.insert(alias.clone());
974                    }
975                }
976                Self::collect_output_variables_recursive(&proj.input, vars);
977            }
978            LogicalOperator::Join(join) => {
979                Self::collect_output_variables_recursive(&join.left, vars);
980                Self::collect_output_variables_recursive(&join.right, vars);
981            }
982            LogicalOperator::Aggregate(agg) => {
983                for expr in &agg.group_by {
984                    Self::collect_variables(expr, vars);
985                }
986                for agg_expr in &agg.aggregates {
987                    if let Some(alias) = &agg_expr.alias {
988                        vars.insert(alias.clone());
989                    }
990                }
991            }
992            LogicalOperator::Return(ret) => {
993                Self::collect_output_variables_recursive(&ret.input, vars);
994            }
995            LogicalOperator::Limit(limit) => {
996                Self::collect_output_variables_recursive(&limit.input, vars);
997            }
998            LogicalOperator::Skip(skip) => {
999                Self::collect_output_variables_recursive(&skip.input, vars);
1000            }
1001            LogicalOperator::Sort(sort) => {
1002                Self::collect_output_variables_recursive(&sort.input, vars);
1003            }
1004            LogicalOperator::Distinct(distinct) => {
1005                Self::collect_output_variables_recursive(&distinct.input, vars);
1006            }
1007            _ => {}
1008        }
1009    }
1010
1011    /// Extracts all variable names referenced in an expression.
1012    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1013        let mut vars = HashSet::new();
1014        Self::collect_variables(expr, &mut vars);
1015        vars
1016    }
1017
1018    /// Recursively collects variable names from an expression.
1019    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1020        match expr {
1021            LogicalExpression::Variable(name) => {
1022                vars.insert(name.clone());
1023            }
1024            LogicalExpression::Property { variable, .. } => {
1025                vars.insert(variable.clone());
1026            }
1027            LogicalExpression::Binary { left, right, .. } => {
1028                Self::collect_variables(left, vars);
1029                Self::collect_variables(right, vars);
1030            }
1031            LogicalExpression::Unary { operand, .. } => {
1032                Self::collect_variables(operand, vars);
1033            }
1034            LogicalExpression::FunctionCall { args, .. } => {
1035                for arg in args {
1036                    Self::collect_variables(arg, vars);
1037                }
1038            }
1039            LogicalExpression::List(items) => {
1040                for item in items {
1041                    Self::collect_variables(item, vars);
1042                }
1043            }
1044            LogicalExpression::Map(pairs) => {
1045                for (_, value) in pairs {
1046                    Self::collect_variables(value, vars);
1047                }
1048            }
1049            LogicalExpression::IndexAccess { base, index } => {
1050                Self::collect_variables(base, vars);
1051                Self::collect_variables(index, vars);
1052            }
1053            LogicalExpression::SliceAccess { base, start, end } => {
1054                Self::collect_variables(base, vars);
1055                if let Some(s) = start {
1056                    Self::collect_variables(s, vars);
1057                }
1058                if let Some(e) = end {
1059                    Self::collect_variables(e, vars);
1060                }
1061            }
1062            LogicalExpression::Case {
1063                operand,
1064                when_clauses,
1065                else_clause,
1066            } => {
1067                if let Some(op) = operand {
1068                    Self::collect_variables(op, vars);
1069                }
1070                for (cond, result) in when_clauses {
1071                    Self::collect_variables(cond, vars);
1072                    Self::collect_variables(result, vars);
1073                }
1074                if let Some(else_expr) = else_clause {
1075                    Self::collect_variables(else_expr, vars);
1076                }
1077            }
1078            LogicalExpression::Labels(var)
1079            | LogicalExpression::Type(var)
1080            | LogicalExpression::Id(var) => {
1081                vars.insert(var.clone());
1082            }
1083            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1084            LogicalExpression::ListComprehension {
1085                list_expr,
1086                filter_expr,
1087                map_expr,
1088                ..
1089            } => {
1090                Self::collect_variables(list_expr, vars);
1091                if let Some(filter) = filter_expr {
1092                    Self::collect_variables(filter, vars);
1093                }
1094                Self::collect_variables(map_expr, vars);
1095            }
1096            LogicalExpression::ListPredicate {
1097                list_expr,
1098                predicate,
1099                ..
1100            } => {
1101                Self::collect_variables(list_expr, vars);
1102                Self::collect_variables(predicate, vars);
1103            }
1104            LogicalExpression::ExistsSubquery(_)
1105            | LogicalExpression::CountSubquery(_)
1106            | LogicalExpression::ValueSubquery(_) => {
1107                // Subqueries have their own variable scope
1108            }
1109            LogicalExpression::PatternComprehension { projection, .. } => {
1110                Self::collect_variables(projection, vars);
1111            }
1112            LogicalExpression::MapProjection { base, entries } => {
1113                vars.insert(base.clone());
1114                for entry in entries {
1115                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1116                        Self::collect_variables(expr, vars);
1117                    }
1118                }
1119            }
1120            LogicalExpression::Reduce {
1121                initial,
1122                list,
1123                expression,
1124                ..
1125            } => {
1126                Self::collect_variables(initial, vars);
1127                Self::collect_variables(list, vars);
1128                Self::collect_variables(expression, vars);
1129            }
1130        }
1131    }
1132
1133    /// Extracts aliases from projection expressions.
1134    fn extract_projection_aliases(
1135        &self,
1136        projections: &[crate::query::plan::Projection],
1137    ) -> HashSet<String> {
1138        projections.iter().filter_map(|p| p.alias.clone()).collect()
1139    }
1140}
1141
1142impl Default for Optimizer {
1143    fn default() -> Self {
1144        Self::new()
1145    }
1146}
1147
1148#[cfg(test)]
1149mod tests {
1150    use super::*;
1151    use crate::query::plan::{
1152        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1153        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1154        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1155    };
1156    use grafeo_common::types::Value;
1157
#[test]
fn test_optimizer_filter_pushdown_simple() {
    // MATCH (n:Person) WHERE n.age > 30 RETURN n
    // Input plan:    Return -> Filter -> NodeScan
    // Expected plan: Return -> Filter -> NodeScan (the filter already sits on the scan)
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let filtered = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filtered),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Walk the expected shape one level at a time
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected Return -> Filter -> NodeScan structure");
    };
    let LogicalOperator::Filter(filter) = ret.input.as_ref() else {
        panic!("Expected Return -> Filter -> NodeScan structure");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("Expected Return -> Filter -> NodeScan structure");
    };
    assert_eq!(scan.variable, "n");
}
1201
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
    // MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
    // The predicate only reads `a`, so it should end up below the expand.
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let knows = LogicalOperator::Expand(ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: vec!["KNOWS".to_string()],
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(person_scan),
        path_alias: None,
        path_mode: PathMode::Walk,
    });
    let a_age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("b".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(LogicalOperator::Filter(FilterOp {
            predicate: a_age_over_30,
            pushdown_hint: None,
            input: Box::new(knows),
        })),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Expected: Return -> Expand -> Filter -> NodeScan
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    let LogicalOperator::Expand(expand) = ret.input.as_ref() else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    let LogicalOperator::Filter(filter) = expand.input.as_ref() else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    let LogicalOperator::NodeScan(scan) = filter.input.as_ref() else {
        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
    };
    assert_eq!(scan.variable, "a");
    assert_eq!(expand.from_variable, "a");
    assert_eq!(expand.to_variable, "b");
}
1259
1260    #[test]
1261    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1262        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1263        // The filter on 'b' should NOT be pushed before the expand
1264
1265        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1266            items: vec![ReturnItem {
1267                expression: LogicalExpression::Variable("a".to_string()),
1268                alias: None,
1269            }],
1270            distinct: false,
1271            input: Box::new(LogicalOperator::Filter(FilterOp {
1272                predicate: LogicalExpression::Binary {
1273                    left: Box::new(LogicalExpression::Property {
1274                        variable: "b".to_string(),
1275                        property: "age".to_string(),
1276                    }),
1277                    op: BinaryOp::Gt,
1278                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1279                },
1280                pushdown_hint: None,
1281                input: Box::new(LogicalOperator::Expand(ExpandOp {
1282                    from_variable: "a".to_string(),
1283                    to_variable: "b".to_string(),
1284                    edge_variable: None,
1285                    direction: ExpandDirection::Outgoing,
1286                    edge_types: vec!["KNOWS".to_string()],
1287                    min_hops: 1,
1288                    max_hops: Some(1),
1289                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1290                        variable: "a".to_string(),
1291                        label: Some("Person".to_string()),
1292                        input: None,
1293                    })),
1294                    path_alias: None,
1295                    path_mode: PathMode::Walk,
1296                })),
1297            })),
1298        }));
1299
1300        let optimizer = Optimizer::new();
1301        let optimized = optimizer.optimize(plan).unwrap();
1302
1303        // Filter on 'b' should stay after the expand
1304        // Expected: Return -> Filter -> Expand -> NodeScan
1305        if let LogicalOperator::Return(ret) = &optimized.root
1306            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1307        {
1308            // Check that the filter is on 'b'
1309            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1310                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1311            {
1312                assert_eq!(variable, "b");
1313            }
1314
1315            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1316                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1317            {
1318                return;
1319            }
1320        }
1321        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1322    }
1323
#[test]
fn test_optimizer_extract_variables() {
    // n.age > 30 references exactly one variable: n
    let property = LogicalExpression::Property {
        variable: "n".to_string(),
        property: "age".to_string(),
    };
    let threshold = LogicalExpression::Literal(Value::Int64(30));
    let comparison = LogicalExpression::Binary {
        left: Box::new(property),
        op: BinaryOp::Gt,
        right: Box::new(threshold),
    };

    let optimizer = Optimizer::new();
    let found = optimizer.extract_variables(&comparison);

    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
1341
1342    // Additional tests for optimizer configuration
1343
#[test]
fn test_optimizer_default() {
    // A default-constructed optimizer must accept an empty plan
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let optimizer = Optimizer::default();
    assert!(optimizer.optimize(empty_plan).is_ok());
}
1352
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    // Plan: Return -> Filter(true) -> NodeScan
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));

    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let optimized = optimizer.optimize(plan).unwrap();

    // With pushdown off, the filter must still sit directly under Return
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected unchanged structure");
    };
    let LogicalOperator::Filter(_) = ret.input.as_ref() else {
        panic!("Expected unchanged structure");
    };
}
1383
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    // Turning join reordering off must not stop an empty plan from optimizing
    let optimizer = Optimizer::new().with_join_reorder(false);
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let outcome = optimizer.optimize(empty_plan);
    assert!(outcome.is_ok());
}
1393
#[test]
fn test_optimizer_with_cost_model() {
    // An injected cost model is the one the optimizer reports, and it
    // prices an empty operator at (close to) zero
    let optimizer = Optimizer::new().with_cost_model(CostModel::new());
    let empty = LogicalOperator::Empty;
    let total = optimizer.cost_model().estimate(&empty, 0.0).total();
    assert!(total < 0.001);
}
1406
#[test]
fn test_optimizer_with_cardinality_estimator() {
    // Register 500 rows for label "Test" and expect a scan of it to report 500
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(estimator);

    let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    }));

    let estimated = optimizer.estimate_cardinality(&plan);
    let error = (estimated - 500.0).abs();
    assert!(error < 0.001);
}
1423
#[test]
fn test_optimizer_estimate_cost() {
    // An unlabeled node scan must have a strictly positive estimated cost
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(scan);

    let optimizer = Optimizer::new();
    let estimated = optimizer.estimate_cost(&plan);
    assert!(estimated.total() > 0.0);
}
1436
1437    // Filter pushdown through various operators
1438
#[test]
fn test_filter_pushdown_through_project() {
    // Plan: Filter(n.age > 30) -> Project(n) -> NodeScan(n)
    // The projection declares no alias, so nothing blocks the predicate.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
        pass_through_input: false,
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        pushdown_hint: None,
        input: Box::new(project),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Expected: Project on top, Filter directly beneath it
    let LogicalOperator::Project(proj) = &optimized.root else {
        panic!("Expected Project -> Filter structure");
    };
    let LogicalOperator::Filter(_) = proj.input.as_ref() else {
        panic!("Expected Project -> Filter structure");
    };
}
1477
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    // Plan: Filter(x > 30) -> Project(n.age AS x) -> NodeScan(n)
    // `x` is produced by the projection, so the filter cannot go below it.
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project_age_as_x = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Property {
                variable: "n".to_string(),
                property: "age".to_string(),
            },
            alias: Some("x".to_string()),
        }],
        input: Box::new(scan),
        pass_through_input: false,
    });
    let x_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_over_30,
        pushdown_hint: None,
        input: Box::new(project_age_as_x),
    }));

    let optimizer = Optimizer::new();
    let optimized = optimizer.optimize(plan).unwrap();

    // Expected: Filter stays on top of Project
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Project structure");
    };
    let LogicalOperator::Project(_) = filter.input.as_ref() else {
        panic!("Expected Filter -> Project structure");
    };
}
1517
1518    #[test]
1519    fn test_filter_pushdown_through_limit() {
1520        let optimizer = Optimizer::new();
1521
1522        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1523            predicate: LogicalExpression::Literal(Value::Bool(true)),
1524            pushdown_hint: None,
1525            input: Box::new(LogicalOperator::Limit(LimitOp {
1526                count: 10.into(),
1527                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1528                    variable: "n".to_string(),
1529                    label: None,
1530                    input: None,
1531                })),
1532            })),
1533        }));
1534
1535        let optimized = optimizer.optimize(plan).unwrap();
1536
1537        // Filter stays above Limit (cannot be pushed through)
1538        if let LogicalOperator::Filter(filter) = &optimized.root
1539            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1540        {
1541            return;
1542        }
1543        panic!("Expected Filter -> Limit structure");
1544    }
1545
1546    #[test]
1547    fn test_filter_pushdown_through_sort() {
1548        let optimizer = Optimizer::new();
1549
1550        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1551            predicate: LogicalExpression::Literal(Value::Bool(true)),
1552            pushdown_hint: None,
1553            input: Box::new(LogicalOperator::Sort(SortOp {
1554                keys: vec![SortKey {
1555                    expression: LogicalExpression::Variable("n".to_string()),
1556                    order: SortOrder::Ascending,
1557                    nulls: None,
1558                }],
1559                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1560                    variable: "n".to_string(),
1561                    label: None,
1562                    input: None,
1563                })),
1564            })),
1565        }));
1566
1567        let optimized = optimizer.optimize(plan).unwrap();
1568
1569        // Filter stays above Sort
1570        if let LogicalOperator::Filter(filter) = &optimized.root
1571            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1572        {
1573            return;
1574        }
1575        panic!("Expected Filter -> Sort structure");
1576    }
1577
1578    #[test]
1579    fn test_filter_pushdown_through_distinct() {
1580        let optimizer = Optimizer::new();
1581
1582        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1583            predicate: LogicalExpression::Literal(Value::Bool(true)),
1584            pushdown_hint: None,
1585            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1586                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1587                    variable: "n".to_string(),
1588                    label: None,
1589                    input: None,
1590                })),
1591                columns: None,
1592            })),
1593        }));
1594
1595        let optimized = optimizer.optimize(plan).unwrap();
1596
1597        // Filter stays above Distinct
1598        if let LogicalOperator::Filter(filter) = &optimized.root
1599            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1600        {
1601            return;
1602        }
1603        panic!("Expected Filter -> Distinct structure");
1604    }
1605
1606    #[test]
1607    fn test_filter_not_pushed_through_aggregate() {
1608        let optimizer = Optimizer::new();
1609
1610        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1611            predicate: LogicalExpression::Binary {
1612                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1613                op: BinaryOp::Gt,
1614                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1615            },
1616            pushdown_hint: None,
1617            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1618                group_by: vec![],
1619                aggregates: vec![AggregateExpr {
1620                    function: AggregateFunction::Count,
1621                    expression: None,
1622                    expression2: None,
1623                    distinct: false,
1624                    alias: Some("cnt".to_string()),
1625                    percentile: None,
1626                    separator: None,
1627                }],
1628                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1629                    variable: "n".to_string(),
1630                    label: None,
1631                    input: None,
1632                })),
1633                having: None,
1634            })),
1635        }));
1636
1637        let optimized = optimizer.optimize(plan).unwrap();
1638
1639        // Filter should stay above Aggregate
1640        if let LogicalOperator::Filter(filter) = &optimized.root
1641            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1642        {
1643            return;
1644        }
1645        panic!("Expected Filter -> Aggregate structure");
1646    }
1647
1648    #[test]
1649    fn test_filter_pushdown_to_left_join_side() {
1650        let optimizer = Optimizer::new();
1651
1652        // Filter on left variable should be pushed to left side
1653        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1654            predicate: LogicalExpression::Binary {
1655                left: Box::new(LogicalExpression::Property {
1656                    variable: "a".to_string(),
1657                    property: "age".to_string(),
1658                }),
1659                op: BinaryOp::Gt,
1660                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1661            },
1662            pushdown_hint: None,
1663            input: Box::new(LogicalOperator::Join(JoinOp {
1664                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1665                    variable: "a".to_string(),
1666                    label: Some("Person".to_string()),
1667                    input: None,
1668                })),
1669                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1670                    variable: "b".to_string(),
1671                    label: Some("Company".to_string()),
1672                    input: None,
1673                })),
1674                join_type: JoinType::Inner,
1675                conditions: vec![],
1676            })),
1677        }));
1678
1679        let optimized = optimizer.optimize(plan).unwrap();
1680
1681        // Filter should be pushed to left side of join
1682        if let LogicalOperator::Join(join) = &optimized.root
1683            && let LogicalOperator::Filter(_) = join.left.as_ref()
1684        {
1685            return;
1686        }
1687        panic!("Expected Join with Filter on left side");
1688    }
1689
1690    #[test]
1691    fn test_filter_pushdown_to_right_join_side() {
1692        let optimizer = Optimizer::new();
1693
1694        // Filter on right variable should be pushed to right side
1695        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1696            predicate: LogicalExpression::Binary {
1697                left: Box::new(LogicalExpression::Property {
1698                    variable: "b".to_string(),
1699                    property: "name".to_string(),
1700                }),
1701                op: BinaryOp::Eq,
1702                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1703            },
1704            pushdown_hint: None,
1705            input: Box::new(LogicalOperator::Join(JoinOp {
1706                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1707                    variable: "a".to_string(),
1708                    label: Some("Person".to_string()),
1709                    input: None,
1710                })),
1711                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1712                    variable: "b".to_string(),
1713                    label: Some("Company".to_string()),
1714                    input: None,
1715                })),
1716                join_type: JoinType::Inner,
1717                conditions: vec![],
1718            })),
1719        }));
1720
1721        let optimized = optimizer.optimize(plan).unwrap();
1722
1723        // Filter should be pushed to right side of join
1724        if let LogicalOperator::Join(join) = &optimized.root
1725            && let LogicalOperator::Filter(_) = join.right.as_ref()
1726        {
1727            return;
1728        }
1729        panic!("Expected Join with Filter on right side");
1730    }
1731
1732    #[test]
1733    fn test_filter_not_pushed_when_uses_both_join_sides() {
1734        let optimizer = Optimizer::new();
1735
1736        // Filter using both variables should stay above join
1737        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1738            predicate: LogicalExpression::Binary {
1739                left: Box::new(LogicalExpression::Property {
1740                    variable: "a".to_string(),
1741                    property: "id".to_string(),
1742                }),
1743                op: BinaryOp::Eq,
1744                right: Box::new(LogicalExpression::Property {
1745                    variable: "b".to_string(),
1746                    property: "a_id".to_string(),
1747                }),
1748            },
1749            pushdown_hint: None,
1750            input: Box::new(LogicalOperator::Join(JoinOp {
1751                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1752                    variable: "a".to_string(),
1753                    label: None,
1754                    input: None,
1755                })),
1756                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1757                    variable: "b".to_string(),
1758                    label: None,
1759                    input: None,
1760                })),
1761                join_type: JoinType::Inner,
1762                conditions: vec![],
1763            })),
1764        }));
1765
1766        let optimized = optimizer.optimize(plan).unwrap();
1767
1768        // Filter should stay above join
1769        if let LogicalOperator::Filter(filter) = &optimized.root
1770            && let LogicalOperator::Join(_) = filter.input.as_ref()
1771        {
1772            return;
1773        }
1774        panic!("Expected Filter -> Join structure");
1775    }
1776
1777    // Variable extraction tests
1778
1779    #[test]
1780    fn test_extract_variables_from_variable() {
1781        let optimizer = Optimizer::new();
1782        let expr = LogicalExpression::Variable("x".to_string());
1783        let vars = optimizer.extract_variables(&expr);
1784        assert_eq!(vars.len(), 1);
1785        assert!(vars.contains("x"));
1786    }
1787
1788    #[test]
1789    fn test_extract_variables_from_unary() {
1790        let optimizer = Optimizer::new();
1791        let expr = LogicalExpression::Unary {
1792            op: UnaryOp::Not,
1793            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1794        };
1795        let vars = optimizer.extract_variables(&expr);
1796        assert_eq!(vars.len(), 1);
1797        assert!(vars.contains("x"));
1798    }
1799
1800    #[test]
1801    fn test_extract_variables_from_function_call() {
1802        let optimizer = Optimizer::new();
1803        let expr = LogicalExpression::FunctionCall {
1804            name: "length".to_string(),
1805            args: vec![
1806                LogicalExpression::Variable("a".to_string()),
1807                LogicalExpression::Variable("b".to_string()),
1808            ],
1809            distinct: false,
1810        };
1811        let vars = optimizer.extract_variables(&expr);
1812        assert_eq!(vars.len(), 2);
1813        assert!(vars.contains("a"));
1814        assert!(vars.contains("b"));
1815    }
1816
1817    #[test]
1818    fn test_extract_variables_from_list() {
1819        let optimizer = Optimizer::new();
1820        let expr = LogicalExpression::List(vec![
1821            LogicalExpression::Variable("a".to_string()),
1822            LogicalExpression::Literal(Value::Int64(1)),
1823            LogicalExpression::Variable("b".to_string()),
1824        ]);
1825        let vars = optimizer.extract_variables(&expr);
1826        assert_eq!(vars.len(), 2);
1827        assert!(vars.contains("a"));
1828        assert!(vars.contains("b"));
1829    }
1830
1831    #[test]
1832    fn test_extract_variables_from_map() {
1833        let optimizer = Optimizer::new();
1834        let expr = LogicalExpression::Map(vec![
1835            (
1836                "key1".to_string(),
1837                LogicalExpression::Variable("a".to_string()),
1838            ),
1839            (
1840                "key2".to_string(),
1841                LogicalExpression::Variable("b".to_string()),
1842            ),
1843        ]);
1844        let vars = optimizer.extract_variables(&expr);
1845        assert_eq!(vars.len(), 2);
1846        assert!(vars.contains("a"));
1847        assert!(vars.contains("b"));
1848    }
1849
1850    #[test]
1851    fn test_extract_variables_from_index_access() {
1852        let optimizer = Optimizer::new();
1853        let expr = LogicalExpression::IndexAccess {
1854            base: Box::new(LogicalExpression::Variable("list".to_string())),
1855            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1856        };
1857        let vars = optimizer.extract_variables(&expr);
1858        assert_eq!(vars.len(), 2);
1859        assert!(vars.contains("list"));
1860        assert!(vars.contains("idx"));
1861    }
1862
1863    #[test]
1864    fn test_extract_variables_from_slice_access() {
1865        let optimizer = Optimizer::new();
1866        let expr = LogicalExpression::SliceAccess {
1867            base: Box::new(LogicalExpression::Variable("list".to_string())),
1868            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1869            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1870        };
1871        let vars = optimizer.extract_variables(&expr);
1872        assert_eq!(vars.len(), 3);
1873        assert!(vars.contains("list"));
1874        assert!(vars.contains("s"));
1875        assert!(vars.contains("e"));
1876    }
1877
1878    #[test]
1879    fn test_extract_variables_from_case() {
1880        let optimizer = Optimizer::new();
1881        let expr = LogicalExpression::Case {
1882            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1883            when_clauses: vec![(
1884                LogicalExpression::Literal(Value::Int64(1)),
1885                LogicalExpression::Variable("a".to_string()),
1886            )],
1887            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1888        };
1889        let vars = optimizer.extract_variables(&expr);
1890        assert_eq!(vars.len(), 3);
1891        assert!(vars.contains("x"));
1892        assert!(vars.contains("a"));
1893        assert!(vars.contains("b"));
1894    }
1895
1896    #[test]
1897    fn test_extract_variables_from_labels() {
1898        let optimizer = Optimizer::new();
1899        let expr = LogicalExpression::Labels("n".to_string());
1900        let vars = optimizer.extract_variables(&expr);
1901        assert_eq!(vars.len(), 1);
1902        assert!(vars.contains("n"));
1903    }
1904
1905    #[test]
1906    fn test_extract_variables_from_type() {
1907        let optimizer = Optimizer::new();
1908        let expr = LogicalExpression::Type("e".to_string());
1909        let vars = optimizer.extract_variables(&expr);
1910        assert_eq!(vars.len(), 1);
1911        assert!(vars.contains("e"));
1912    }
1913
1914    #[test]
1915    fn test_extract_variables_from_id() {
1916        let optimizer = Optimizer::new();
1917        let expr = LogicalExpression::Id("n".to_string());
1918        let vars = optimizer.extract_variables(&expr);
1919        assert_eq!(vars.len(), 1);
1920        assert!(vars.contains("n"));
1921    }
1922
1923    #[test]
1924    fn test_extract_variables_from_list_comprehension() {
1925        let optimizer = Optimizer::new();
1926        let expr = LogicalExpression::ListComprehension {
1927            variable: "x".to_string(),
1928            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1929            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1930            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1931        };
1932        let vars = optimizer.extract_variables(&expr);
1933        assert!(vars.contains("items"));
1934        assert!(vars.contains("pred"));
1935        assert!(vars.contains("result"));
1936    }
1937
1938    #[test]
1939    fn test_extract_variables_from_literal_and_parameter() {
1940        let optimizer = Optimizer::new();
1941
1942        let literal = LogicalExpression::Literal(Value::Int64(42));
1943        assert!(optimizer.extract_variables(&literal).is_empty());
1944
1945        let param = LogicalExpression::Parameter("p".to_string());
1946        assert!(optimizer.extract_variables(&param).is_empty());
1947    }
1948
1949    // Recursive filter pushdown tests
1950
1951    #[test]
1952    fn test_recursive_filter_pushdown_through_skip() {
1953        let optimizer = Optimizer::new();
1954
1955        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1956            items: vec![ReturnItem {
1957                expression: LogicalExpression::Variable("n".to_string()),
1958                alias: None,
1959            }],
1960            distinct: false,
1961            input: Box::new(LogicalOperator::Filter(FilterOp {
1962                predicate: LogicalExpression::Literal(Value::Bool(true)),
1963                pushdown_hint: None,
1964                input: Box::new(LogicalOperator::Skip(SkipOp {
1965                    count: 5.into(),
1966                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1967                        variable: "n".to_string(),
1968                        label: None,
1969                        input: None,
1970                    })),
1971                })),
1972            })),
1973        }));
1974
1975        let optimized = optimizer.optimize(plan).unwrap();
1976
1977        // Verify optimization succeeded
1978        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1979    }
1980
1981    #[test]
1982    fn test_nested_filter_pushdown() {
1983        let optimizer = Optimizer::new();
1984
1985        // Multiple nested filters
1986        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1987            items: vec![ReturnItem {
1988                expression: LogicalExpression::Variable("n".to_string()),
1989                alias: None,
1990            }],
1991            distinct: false,
1992            input: Box::new(LogicalOperator::Filter(FilterOp {
1993                predicate: LogicalExpression::Binary {
1994                    left: Box::new(LogicalExpression::Property {
1995                        variable: "n".to_string(),
1996                        property: "x".to_string(),
1997                    }),
1998                    op: BinaryOp::Gt,
1999                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2000                },
2001                pushdown_hint: None,
2002                input: Box::new(LogicalOperator::Filter(FilterOp {
2003                    predicate: LogicalExpression::Binary {
2004                        left: Box::new(LogicalExpression::Property {
2005                            variable: "n".to_string(),
2006                            property: "y".to_string(),
2007                        }),
2008                        op: BinaryOp::Lt,
2009                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2010                    },
2011                    pushdown_hint: None,
2012                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2013                        variable: "n".to_string(),
2014                        label: None,
2015                        input: None,
2016                    })),
2017                })),
2018            })),
2019        }));
2020
2021        let optimized = optimizer.optimize(plan).unwrap();
2022        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2023    }
2024
2025    #[test]
2026    fn test_cyclic_join_produces_multi_way_join() {
2027        use crate::query::plan::JoinCondition;
2028
2029        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2030        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2031            variable: "a".to_string(),
2032            label: Some("Person".to_string()),
2033            input: None,
2034        });
2035        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2036            variable: "b".to_string(),
2037            label: Some("Person".to_string()),
2038            input: None,
2039        });
2040        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2041            variable: "c".to_string(),
2042            label: Some("Person".to_string()),
2043            input: None,
2044        });
2045
2046        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2047        let join_ab = LogicalOperator::Join(JoinOp {
2048            left: Box::new(scan_a),
2049            right: Box::new(scan_b),
2050            join_type: JoinType::Inner,
2051            conditions: vec![JoinCondition {
2052                left: LogicalExpression::Variable("a".to_string()),
2053                right: LogicalExpression::Variable("b".to_string()),
2054            }],
2055        });
2056
2057        let join_abc = LogicalOperator::Join(JoinOp {
2058            left: Box::new(join_ab),
2059            right: Box::new(scan_c),
2060            join_type: JoinType::Inner,
2061            conditions: vec![
2062                JoinCondition {
2063                    left: LogicalExpression::Variable("b".to_string()),
2064                    right: LogicalExpression::Variable("c".to_string()),
2065                },
2066                JoinCondition {
2067                    left: LogicalExpression::Variable("c".to_string()),
2068                    right: LogicalExpression::Variable("a".to_string()),
2069                },
2070            ],
2071        });
2072
2073        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2074            items: vec![ReturnItem {
2075                expression: LogicalExpression::Variable("a".to_string()),
2076                alias: None,
2077            }],
2078            distinct: false,
2079            input: Box::new(join_abc),
2080        }));
2081
2082        let mut optimizer = Optimizer::new();
2083        optimizer
2084            .card_estimator
2085            .add_table_stats("Person", cardinality::TableStats::new(1000));
2086
2087        let optimized = optimizer.optimize(plan).unwrap();
2088
2089        // Walk the tree to find a MultiWayJoin
2090        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2091            match op {
2092                LogicalOperator::MultiWayJoin(_) => true,
2093                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2094                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2095                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2096                _ => false,
2097            }
2098        }
2099
2100        assert!(
2101            has_multi_way_join(&optimized.root),
2102            "Expected MultiWayJoin for cyclic triangle pattern"
2103        );
2104    }
2105
2106    #[test]
2107    fn test_acyclic_join_uses_binary_joins() {
2108        use crate::query::plan::JoinCondition;
2109
2110        // Chain: a ⋈ b ⋈ c (acyclic)
2111        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2112            variable: "a".to_string(),
2113            label: Some("Person".to_string()),
2114            input: None,
2115        });
2116        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2117            variable: "b".to_string(),
2118            label: Some("Person".to_string()),
2119            input: None,
2120        });
2121        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2122            variable: "c".to_string(),
2123            label: Some("Company".to_string()),
2124            input: None,
2125        });
2126
2127        let join_ab = LogicalOperator::Join(JoinOp {
2128            left: Box::new(scan_a),
2129            right: Box::new(scan_b),
2130            join_type: JoinType::Inner,
2131            conditions: vec![JoinCondition {
2132                left: LogicalExpression::Variable("a".to_string()),
2133                right: LogicalExpression::Variable("b".to_string()),
2134            }],
2135        });
2136
2137        let join_abc = LogicalOperator::Join(JoinOp {
2138            left: Box::new(join_ab),
2139            right: Box::new(scan_c),
2140            join_type: JoinType::Inner,
2141            conditions: vec![JoinCondition {
2142                left: LogicalExpression::Variable("b".to_string()),
2143                right: LogicalExpression::Variable("c".to_string()),
2144            }],
2145        });
2146
2147        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2148            items: vec![ReturnItem {
2149                expression: LogicalExpression::Variable("a".to_string()),
2150                alias: None,
2151            }],
2152            distinct: false,
2153            input: Box::new(join_abc),
2154        }));
2155
2156        let mut optimizer = Optimizer::new();
2157        optimizer
2158            .card_estimator
2159            .add_table_stats("Person", cardinality::TableStats::new(1000));
2160        optimizer
2161            .card_estimator
2162            .add_table_stats("Company", cardinality::TableStats::new(100));
2163
2164        let optimized = optimizer.optimize(plan).unwrap();
2165
2166        // Should NOT contain MultiWayJoin for acyclic pattern
2167        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2168            match op {
2169                LogicalOperator::MultiWayJoin(_) => true,
2170                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2171                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2172                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2173                LogicalOperator::Join(j) => {
2174                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2175                }
2176                _ => false,
2177            }
2178        }
2179
2180        assert!(
2181            !has_multi_way_join(&optimized.root),
2182            "Acyclic join should NOT produce MultiWayJoin"
2183        );
2184    }
2185}