Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
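//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
//! | Join Reordering | Picks the best join order using the DPccp algorithm; cyclic patterns over three or more relations become a single multi-way (leapfrog) join instead |
//! | Projection Pushdown | Tracks which variables and properties the query actually uses, so unneeded columns can be dropped early |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |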
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
/// Information about a join condition for join reordering.
///
/// One entry is recorded per join condition whose two sides each resolve to a
/// variable (either a bare variable or the owner of a property access).
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand side of the condition.
    left_var: String,
    /// Variable referenced by the right-hand side of the condition.
    right_var: String,
    /// The original left-hand expression, passed on unchanged to the rebuilt join.
    left_expr: LogicalExpression,
    /// The original right-hand expression, passed on unchanged to the rebuilt join.
    right_expr: LogicalExpression,
}
39
/// A column required by the query, used for projection pushdown.
///
/// Hashable so the full set of requirements can be gathered into a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding)
    Variable(String),
    /// A specific property of a variable, stored as `(variable, property)`
    Property(String, String),
}
48
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations; every
/// constructor starts with all passes enabled.
pub struct Optimizer {
    /// Whether to enable filter pushdown.
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering.
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown.
    enable_projection_pushdown: bool,
    /// Cost model for estimation (also handed to DPccp when ordering joins).
    cost_model: CostModel,
    /// Cardinality estimator (also handed to DPccp when ordering joins).
    card_estimator: CardinalityEstimator,
}
65
66impl Optimizer {
67    /// Creates a new optimizer with default settings.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            enable_filter_pushdown: true,
72            enable_join_reorder: true,
73            enable_projection_pushdown: true,
74            cost_model: CostModel::new(),
75            card_estimator: CardinalityEstimator::new(),
76        }
77    }
78
79    /// Creates an optimizer with cardinality estimates from the store's statistics.
80    ///
81    /// Pre-populates the cardinality estimator with per-label row counts and
82    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
83    /// and graph totals into the cost model for accurate estimation.
84    #[must_use]
85    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
86        store.ensure_statistics_fresh();
87        let stats = store.statistics();
88        Self::from_statistics(&stats)
89    }
90
91    /// Creates an optimizer from any GraphStore implementation.
92    ///
93    /// Unlike [`from_store`](Self::from_store), this does not call
94    /// `ensure_statistics_fresh()` since external stores manage their own
95    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
96    /// is called directly.
97    #[must_use]
98    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
99        let stats = store.statistics();
100        Self::from_statistics(&stats)
101    }
102
103    /// Creates an optimizer from RDF statistics.
104    ///
105    /// Uses triple pattern cardinality estimates for cost-based optimization
106    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
107    #[cfg(feature = "rdf")]
108    #[must_use]
109    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
110        let total = rdf_stats.total_triples;
111        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
112        Self {
113            enable_filter_pushdown: true,
114            enable_join_reorder: true,
115            enable_projection_pushdown: true,
116            cost_model: CostModel::new().with_graph_totals(total, total),
117            card_estimator: estimator,
118        }
119    }
120
121    /// Creates an optimizer from a Statistics snapshot.
122    ///
123    /// Extracts label cardinalities, edge type degrees, and graph totals
124    /// for both the cardinality estimator and cost model.
125    #[must_use]
126    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
127        let estimator = CardinalityEstimator::from_statistics(stats);
128
129        let avg_fanout = if stats.total_nodes > 0 {
130            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
131        } else {
132            10.0
133        };
134
135        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
136            .edge_types
137            .iter()
138            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
139            .collect();
140
141        let label_cardinalities: std::collections::HashMap<String, u64> = stats
142            .labels
143            .iter()
144            .map(|(name, ls)| (name.clone(), ls.node_count))
145            .collect();
146
147        Self {
148            enable_filter_pushdown: true,
149            enable_join_reorder: true,
150            enable_projection_pushdown: true,
151            cost_model: CostModel::new()
152                .with_avg_fanout(avg_fanout)
153                .with_edge_type_degrees(edge_type_degrees)
154                .with_label_cardinalities(label_cardinalities)
155                .with_graph_totals(stats.total_nodes, stats.total_edges),
156            card_estimator: estimator,
157        }
158    }
159
160    /// Enables or disables filter pushdown.
161    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
162        self.enable_filter_pushdown = enabled;
163        self
164    }
165
166    /// Enables or disables join reordering.
167    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
168        self.enable_join_reorder = enabled;
169        self
170    }
171
172    /// Enables or disables projection pushdown.
173    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
174        self.enable_projection_pushdown = enabled;
175        self
176    }
177
178    /// Sets the cost model.
179    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
180        self.cost_model = cost_model;
181        self
182    }
183
184    /// Sets the cardinality estimator.
185    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
186        self.card_estimator = estimator;
187        self
188    }
189
190    /// Sets the selectivity configuration for the cardinality estimator.
191    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
192        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
193        self
194    }
195
196    /// Returns a reference to the cost model.
197    pub fn cost_model(&self) -> &CostModel {
198        &self.cost_model
199    }
200
201    /// Returns a reference to the cardinality estimator.
202    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
203        &self.card_estimator
204    }
205
206    /// Estimates the total cost of a plan by recursively costing the entire tree.
207    ///
208    /// Walks every operator in the plan, computing cardinality at each level
209    /// and summing the per-operator costs. Uses actual child cardinalities
210    /// for join cost estimation rather than approximations.
211    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
212        self.cost_model
213            .estimate_tree(&plan.root, &self.card_estimator)
214    }
215
216    /// Estimates the cardinality of a plan.
217    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
218        self.card_estimator.estimate(&plan.root)
219    }
220
221    /// Optimizes a logical plan.
222    ///
223    /// # Errors
224    ///
225    /// Returns an error if optimization fails.
226    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
227        let _span = grafeo_debug_span!("grafeo::query::optimize");
228        let mut root = plan.root;
229
230        // Apply optimization rules
231        if self.enable_filter_pushdown {
232            root = self.push_filters_down(root);
233        }
234
235        if self.enable_join_reorder {
236            root = self.reorder_joins(root);
237        }
238
239        if self.enable_projection_pushdown {
240            root = self.push_projections_down(root);
241        }
242
243        Ok(LogicalPlan {
244            root,
245            explain: plan.explain,
246            profile: plan.profile,
247            default_params: plan.default_params,
248        })
249    }
250
251    /// Pushes projections down the operator tree to eliminate unused columns early.
252    ///
253    /// This optimization:
254    /// 1. Collects required variables/properties from the root
255    /// 2. Propagates requirements down through the tree
256    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
257    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
258        // Collect required columns from the top of the plan
259        let required = self.collect_required_columns(&op);
260
261        // Push projections down
262        self.push_projections_recursive(op, &required)
263    }
264
265    /// Collects all variables and properties required by an operator and its ancestors.
266    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
267        let mut required = HashSet::new();
268        Self::collect_required_recursive(op, &mut required);
269        required
270    }
271
272    /// Recursively collects required columns.
273    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
274        match op {
275            LogicalOperator::Return(ret) => {
276                for item in &ret.items {
277                    Self::collect_from_expression(&item.expression, required);
278                }
279                Self::collect_required_recursive(&ret.input, required);
280            }
281            LogicalOperator::Project(proj) => {
282                for p in &proj.projections {
283                    Self::collect_from_expression(&p.expression, required);
284                }
285                Self::collect_required_recursive(&proj.input, required);
286            }
287            LogicalOperator::Filter(filter) => {
288                Self::collect_from_expression(&filter.predicate, required);
289                Self::collect_required_recursive(&filter.input, required);
290            }
291            LogicalOperator::Sort(sort) => {
292                for key in &sort.keys {
293                    Self::collect_from_expression(&key.expression, required);
294                }
295                Self::collect_required_recursive(&sort.input, required);
296            }
297            LogicalOperator::Aggregate(agg) => {
298                for expr in &agg.group_by {
299                    Self::collect_from_expression(expr, required);
300                }
301                for agg_expr in &agg.aggregates {
302                    if let Some(ref expr) = agg_expr.expression {
303                        Self::collect_from_expression(expr, required);
304                    }
305                }
306                if let Some(ref having) = agg.having {
307                    Self::collect_from_expression(having, required);
308                }
309                Self::collect_required_recursive(&agg.input, required);
310            }
311            LogicalOperator::Join(join) => {
312                for cond in &join.conditions {
313                    Self::collect_from_expression(&cond.left, required);
314                    Self::collect_from_expression(&cond.right, required);
315                }
316                Self::collect_required_recursive(&join.left, required);
317                Self::collect_required_recursive(&join.right, required);
318            }
319            LogicalOperator::Expand(expand) => {
320                // The source and target variables are needed
321                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
322                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
323                if let Some(ref edge_var) = expand.edge_variable {
324                    required.insert(RequiredColumn::Variable(edge_var.clone()));
325                }
326                Self::collect_required_recursive(&expand.input, required);
327            }
328            LogicalOperator::Limit(limit) => {
329                Self::collect_required_recursive(&limit.input, required);
330            }
331            LogicalOperator::Skip(skip) => {
332                Self::collect_required_recursive(&skip.input, required);
333            }
334            LogicalOperator::Distinct(distinct) => {
335                Self::collect_required_recursive(&distinct.input, required);
336            }
337            LogicalOperator::NodeScan(scan) => {
338                required.insert(RequiredColumn::Variable(scan.variable.clone()));
339            }
340            LogicalOperator::EdgeScan(scan) => {
341                required.insert(RequiredColumn::Variable(scan.variable.clone()));
342            }
343            LogicalOperator::MultiWayJoin(mwj) => {
344                for cond in &mwj.conditions {
345                    Self::collect_from_expression(&cond.left, required);
346                    Self::collect_from_expression(&cond.right, required);
347                }
348                for input in &mwj.inputs {
349                    Self::collect_required_recursive(input, required);
350                }
351            }
352            _ => {}
353        }
354    }
355
356    /// Collects required columns from an expression.
357    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
358        match expr {
359            LogicalExpression::Variable(var) => {
360                required.insert(RequiredColumn::Variable(var.clone()));
361            }
362            LogicalExpression::Property { variable, property } => {
363                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
364                required.insert(RequiredColumn::Variable(variable.clone()));
365            }
366            LogicalExpression::Binary { left, right, .. } => {
367                Self::collect_from_expression(left, required);
368                Self::collect_from_expression(right, required);
369            }
370            LogicalExpression::Unary { operand, .. } => {
371                Self::collect_from_expression(operand, required);
372            }
373            LogicalExpression::FunctionCall { args, .. } => {
374                for arg in args {
375                    Self::collect_from_expression(arg, required);
376                }
377            }
378            LogicalExpression::List(items) => {
379                for item in items {
380                    Self::collect_from_expression(item, required);
381                }
382            }
383            LogicalExpression::Map(pairs) => {
384                for (_, value) in pairs {
385                    Self::collect_from_expression(value, required);
386                }
387            }
388            LogicalExpression::IndexAccess { base, index } => {
389                Self::collect_from_expression(base, required);
390                Self::collect_from_expression(index, required);
391            }
392            LogicalExpression::SliceAccess { base, start, end } => {
393                Self::collect_from_expression(base, required);
394                if let Some(s) = start {
395                    Self::collect_from_expression(s, required);
396                }
397                if let Some(e) = end {
398                    Self::collect_from_expression(e, required);
399                }
400            }
401            LogicalExpression::Case {
402                operand,
403                when_clauses,
404                else_clause,
405            } => {
406                if let Some(op) = operand {
407                    Self::collect_from_expression(op, required);
408                }
409                for (cond, result) in when_clauses {
410                    Self::collect_from_expression(cond, required);
411                    Self::collect_from_expression(result, required);
412                }
413                if let Some(else_expr) = else_clause {
414                    Self::collect_from_expression(else_expr, required);
415                }
416            }
417            LogicalExpression::Labels(var)
418            | LogicalExpression::Type(var)
419            | LogicalExpression::Id(var) => {
420                required.insert(RequiredColumn::Variable(var.clone()));
421            }
422            LogicalExpression::ListComprehension {
423                list_expr,
424                filter_expr,
425                map_expr,
426                ..
427            } => {
428                Self::collect_from_expression(list_expr, required);
429                if let Some(filter) = filter_expr {
430                    Self::collect_from_expression(filter, required);
431                }
432                Self::collect_from_expression(map_expr, required);
433            }
434            _ => {}
435        }
436    }
437
438    /// Recursively pushes projections down, adding them before expensive operations.
439    fn push_projections_recursive(
440        &self,
441        op: LogicalOperator,
442        required: &HashSet<RequiredColumn>,
443    ) -> LogicalOperator {
444        match op {
445            LogicalOperator::Return(mut ret) => {
446                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
447                LogicalOperator::Return(ret)
448            }
449            LogicalOperator::Project(mut proj) => {
450                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
451                LogicalOperator::Project(proj)
452            }
453            LogicalOperator::Filter(mut filter) => {
454                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
455                LogicalOperator::Filter(filter)
456            }
457            LogicalOperator::Sort(mut sort) => {
458                // Sort is expensive - consider adding a projection before it
459                // to reduce tuple width
460                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
461                LogicalOperator::Sort(sort)
462            }
463            LogicalOperator::Aggregate(mut agg) => {
464                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
465                LogicalOperator::Aggregate(agg)
466            }
467            LogicalOperator::Join(mut join) => {
468                // Joins are expensive - the required columns help determine
469                // what to project on each side
470                let left_vars = self.collect_output_variables(&join.left);
471                let right_vars = self.collect_output_variables(&join.right);
472
473                // Filter required columns to each side
474                let left_required: HashSet<_> = required
475                    .iter()
476                    .filter(|c| match c {
477                        RequiredColumn::Variable(v) => left_vars.contains(v),
478                        RequiredColumn::Property(v, _) => left_vars.contains(v),
479                    })
480                    .cloned()
481                    .collect();
482
483                let right_required: HashSet<_> = required
484                    .iter()
485                    .filter(|c| match c {
486                        RequiredColumn::Variable(v) => right_vars.contains(v),
487                        RequiredColumn::Property(v, _) => right_vars.contains(v),
488                    })
489                    .cloned()
490                    .collect();
491
492                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
493                join.right =
494                    Box::new(self.push_projections_recursive(*join.right, &right_required));
495                LogicalOperator::Join(join)
496            }
497            LogicalOperator::Expand(mut expand) => {
498                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
499                LogicalOperator::Expand(expand)
500            }
501            LogicalOperator::Limit(mut limit) => {
502                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
503                LogicalOperator::Limit(limit)
504            }
505            LogicalOperator::Skip(mut skip) => {
506                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
507                LogicalOperator::Skip(skip)
508            }
509            LogicalOperator::Distinct(mut distinct) => {
510                distinct.input =
511                    Box::new(self.push_projections_recursive(*distinct.input, required));
512                LogicalOperator::Distinct(distinct)
513            }
514            LogicalOperator::MapCollect(mut mc) => {
515                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
516                LogicalOperator::MapCollect(mc)
517            }
518            LogicalOperator::MultiWayJoin(mut mwj) => {
519                mwj.inputs = mwj
520                    .inputs
521                    .into_iter()
522                    .map(|input| self.push_projections_recursive(input, required))
523                    .collect();
524                LogicalOperator::MultiWayJoin(mwj)
525            }
526            other => other,
527        }
528    }
529
530    /// Reorders joins in the operator tree using the DPccp algorithm.
531    ///
532    /// This optimization finds the optimal join order by:
533    /// 1. Extracting all base relations (scans) and join conditions
534    /// 2. Building a join graph
535    /// 3. Using dynamic programming to find the cheapest join order
536    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
537        // First, recursively optimize children
538        let op = self.reorder_joins_recursive(op);
539
540        // Then, if this is a join tree, try to optimize it
541        if let Some((relations, conditions)) = self.extract_join_tree(&op)
542            && relations.len() >= 2
543            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
544        {
545            return optimized;
546        }
547
548        op
549    }
550
551    /// Recursively applies join reordering to child operators.
552    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
553        match op {
554            LogicalOperator::Return(mut ret) => {
555                ret.input = Box::new(self.reorder_joins(*ret.input));
556                LogicalOperator::Return(ret)
557            }
558            LogicalOperator::Project(mut proj) => {
559                proj.input = Box::new(self.reorder_joins(*proj.input));
560                LogicalOperator::Project(proj)
561            }
562            LogicalOperator::Filter(mut filter) => {
563                filter.input = Box::new(self.reorder_joins(*filter.input));
564                LogicalOperator::Filter(filter)
565            }
566            LogicalOperator::Limit(mut limit) => {
567                limit.input = Box::new(self.reorder_joins(*limit.input));
568                LogicalOperator::Limit(limit)
569            }
570            LogicalOperator::Skip(mut skip) => {
571                skip.input = Box::new(self.reorder_joins(*skip.input));
572                LogicalOperator::Skip(skip)
573            }
574            LogicalOperator::Sort(mut sort) => {
575                sort.input = Box::new(self.reorder_joins(*sort.input));
576                LogicalOperator::Sort(sort)
577            }
578            LogicalOperator::Distinct(mut distinct) => {
579                distinct.input = Box::new(self.reorder_joins(*distinct.input));
580                LogicalOperator::Distinct(distinct)
581            }
582            LogicalOperator::Aggregate(mut agg) => {
583                agg.input = Box::new(self.reorder_joins(*agg.input));
584                LogicalOperator::Aggregate(agg)
585            }
586            LogicalOperator::Expand(mut expand) => {
587                expand.input = Box::new(self.reorder_joins(*expand.input));
588                LogicalOperator::Expand(expand)
589            }
590            LogicalOperator::MapCollect(mut mc) => {
591                mc.input = Box::new(self.reorder_joins(*mc.input));
592                LogicalOperator::MapCollect(mc)
593            }
594            LogicalOperator::MultiWayJoin(mut mwj) => {
595                mwj.inputs = mwj
596                    .inputs
597                    .into_iter()
598                    .map(|input| self.reorder_joins(input))
599                    .collect();
600                LogicalOperator::MultiWayJoin(mwj)
601            }
602            // Join operators are handled by the parent reorder_joins call
603            other => other,
604        }
605    }
606
607    /// Extracts base relations and join conditions from a join tree.
608    ///
609    /// Returns None if the operator is not a join tree.
610    fn extract_join_tree(
611        &self,
612        op: &LogicalOperator,
613    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
614        let mut relations = Vec::new();
615        let mut join_conditions = Vec::new();
616
617        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
618            return None;
619        }
620
621        if relations.len() < 2 {
622            return None;
623        }
624
625        Some((relations, join_conditions))
626    }
627
628    /// Recursively collects base relations and join conditions.
629    ///
630    /// Returns true if this subtree is part of a join tree.
631    fn collect_join_tree(
632        &self,
633        op: &LogicalOperator,
634        relations: &mut Vec<(String, LogicalOperator)>,
635        conditions: &mut Vec<JoinInfo>,
636    ) -> bool {
637        match op {
638            LogicalOperator::Join(join) => {
639                // Collect from both sides
640                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
641                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
642
643                // Add conditions from this join
644                for cond in &join.conditions {
645                    if let (Some(left_var), Some(right_var)) = (
646                        self.extract_variable_from_expr(&cond.left),
647                        self.extract_variable_from_expr(&cond.right),
648                    ) {
649                        conditions.push(JoinInfo {
650                            left_var,
651                            right_var,
652                            left_expr: cond.left.clone(),
653                            right_expr: cond.right.clone(),
654                        });
655                    }
656                }
657
658                left_ok && right_ok
659            }
660            LogicalOperator::NodeScan(scan) => {
661                relations.push((scan.variable.clone(), op.clone()));
662                true
663            }
664            LogicalOperator::EdgeScan(scan) => {
665                relations.push((scan.variable.clone(), op.clone()));
666                true
667            }
668            LogicalOperator::Filter(filter) => {
669                // A filter on a base relation is still part of the join tree
670                self.collect_join_tree(&filter.input, relations, conditions)
671            }
672            LogicalOperator::Expand(expand) => {
673                // Expand is a special case - it's like a join with the adjacency
674                // For now, treat the whole Expand subtree as a single relation
675                relations.push((expand.to_variable.clone(), op.clone()));
676                true
677            }
678            _ => false,
679        }
680    }
681
682    /// Extracts the primary variable from an expression.
683    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
684        match expr {
685            LogicalExpression::Variable(v) => Some(v.clone()),
686            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
687            _ => None,
688        }
689    }
690
691    /// Optimizes the join order using DPccp, or produces a multi-way
692    /// leapfrog join for cyclic patterns when the cost model prefers it.
693    fn optimize_join_order(
694        &self,
695        relations: &[(String, LogicalOperator)],
696        conditions: &[JoinInfo],
697    ) -> Option<LogicalOperator> {
698        use join_order::{DPccp, JoinGraphBuilder};
699
700        // Build the join graph
701        let mut builder = JoinGraphBuilder::new();
702
703        for (var, relation) in relations {
704            builder.add_relation(var, relation.clone());
705        }
706
707        for cond in conditions {
708            builder.add_join_condition(
709                &cond.left_var,
710                &cond.right_var,
711                cond.left_expr.clone(),
712                cond.right_expr.clone(),
713            );
714        }
715
716        let graph = builder.build();
717
718        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
719        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
720        // multi-way intersection rather than binary hash join cascades that can
721        // produce intermediate blowup.
722        if graph.is_cyclic() && relations.len() >= 3 {
723            // Collect shared variables (variables appearing in 2+ conditions)
724            let mut var_counts: std::collections::HashMap<&str, usize> =
725                std::collections::HashMap::new();
726            for cond in conditions {
727                *var_counts.entry(&cond.left_var).or_default() += 1;
728                *var_counts.entry(&cond.right_var).or_default() += 1;
729            }
730            let shared_variables: Vec<String> = var_counts
731                .into_iter()
732                .filter(|(_, count)| *count >= 2)
733                .map(|(var, _)| var.to_string())
734                .collect();
735
736            let join_conditions: Vec<JoinCondition> = conditions
737                .iter()
738                .map(|c| JoinCondition {
739                    left: c.left_expr.clone(),
740                    right: c.right_expr.clone(),
741                })
742                .collect();
743
744            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
745                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
746                conditions: join_conditions,
747                shared_variables,
748            }));
749        }
750
751        // Fall through to DPccp for binary join ordering
752        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
753        let plan = dpccp.optimize()?;
754
755        Some(plan.operator)
756    }
757
758    /// Pushes filters down the operator tree.
759    ///
760    /// This optimization moves filter predicates as close to the data source
761    /// as possible to reduce the amount of data processed by upper operators.
762    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
763        match op {
764            // For Filter operators, try to push the predicate into the child
765            LogicalOperator::Filter(filter) => {
766                let optimized_input = self.push_filters_down(*filter.input);
767                self.try_push_filter_into(filter.predicate, optimized_input)
768            }
769            // Recursively optimize children for other operators
770            LogicalOperator::Return(mut ret) => {
771                ret.input = Box::new(self.push_filters_down(*ret.input));
772                LogicalOperator::Return(ret)
773            }
774            LogicalOperator::Project(mut proj) => {
775                proj.input = Box::new(self.push_filters_down(*proj.input));
776                LogicalOperator::Project(proj)
777            }
778            LogicalOperator::Limit(mut limit) => {
779                limit.input = Box::new(self.push_filters_down(*limit.input));
780                LogicalOperator::Limit(limit)
781            }
782            LogicalOperator::Skip(mut skip) => {
783                skip.input = Box::new(self.push_filters_down(*skip.input));
784                LogicalOperator::Skip(skip)
785            }
786            LogicalOperator::Sort(mut sort) => {
787                sort.input = Box::new(self.push_filters_down(*sort.input));
788                LogicalOperator::Sort(sort)
789            }
790            LogicalOperator::Distinct(mut distinct) => {
791                distinct.input = Box::new(self.push_filters_down(*distinct.input));
792                LogicalOperator::Distinct(distinct)
793            }
794            LogicalOperator::Expand(mut expand) => {
795                expand.input = Box::new(self.push_filters_down(*expand.input));
796                LogicalOperator::Expand(expand)
797            }
798            LogicalOperator::Join(mut join) => {
799                join.left = Box::new(self.push_filters_down(*join.left));
800                join.right = Box::new(self.push_filters_down(*join.right));
801                LogicalOperator::Join(join)
802            }
803            LogicalOperator::Aggregate(mut agg) => {
804                agg.input = Box::new(self.push_filters_down(*agg.input));
805                LogicalOperator::Aggregate(agg)
806            }
807            LogicalOperator::MapCollect(mut mc) => {
808                mc.input = Box::new(self.push_filters_down(*mc.input));
809                LogicalOperator::MapCollect(mc)
810            }
811            LogicalOperator::MultiWayJoin(mut mwj) => {
812                mwj.inputs = mwj
813                    .inputs
814                    .into_iter()
815                    .map(|input| self.push_filters_down(input))
816                    .collect();
817                LogicalOperator::MultiWayJoin(mwj)
818            }
819            // Leaf operators and unsupported operators are returned as-is
820            other => other,
821        }
822    }
823
824    /// Tries to push a filter predicate into the given operator.
825    ///
826    /// Returns either the predicate pushed into the operator, or a new
827    /// Filter operator on top if the predicate cannot be pushed further.
828    fn try_push_filter_into(
829        &self,
830        predicate: LogicalExpression,
831        op: LogicalOperator,
832    ) -> LogicalOperator {
833        match op {
834            // Can push through Project if predicate doesn't depend on computed columns
835            LogicalOperator::Project(mut proj) => {
836                let predicate_vars = self.extract_variables(&predicate);
837                let computed_vars = self.extract_projection_aliases(&proj.projections);
838
839                // If predicate doesn't use any computed columns, push through
840                if predicate_vars.is_disjoint(&computed_vars) {
841                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
842                    LogicalOperator::Project(proj)
843                } else {
844                    // Can't push through, keep filter on top
845                    LogicalOperator::Filter(FilterOp {
846                        predicate,
847                        pushdown_hint: None,
848                        input: Box::new(LogicalOperator::Project(proj)),
849                    })
850                }
851            }
852
853            // Can push through Return (which is like a projection)
854            LogicalOperator::Return(mut ret) => {
855                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
856                LogicalOperator::Return(ret)
857            }
858
859            // Can push through Expand if predicate doesn't use variables introduced by this expand
860            LogicalOperator::Expand(mut expand) => {
861                let predicate_vars = self.extract_variables(&predicate);
862
863                // Variables introduced by this expand are:
864                // - The target variable (to_variable)
865                // - The edge variable (if any)
866                // - The path alias (if any)
867                let mut introduced_vars = vec![&expand.to_variable];
868                if let Some(ref edge_var) = expand.edge_variable {
869                    introduced_vars.push(edge_var);
870                }
871                if let Some(ref path_alias) = expand.path_alias {
872                    introduced_vars.push(path_alias);
873                }
874
875                // Check if predicate uses any variables introduced by this expand
876                let uses_introduced_vars =
877                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
878
879                if !uses_introduced_vars {
880                    // Predicate doesn't use vars from this expand, so push through
881                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
882                    LogicalOperator::Expand(expand)
883                } else {
884                    // Keep filter after expand
885                    LogicalOperator::Filter(FilterOp {
886                        predicate,
887                        pushdown_hint: None,
888                        input: Box::new(LogicalOperator::Expand(expand)),
889                    })
890                }
891            }
892
893            // Can push through Join to left/right side based on variables used
894            LogicalOperator::Join(mut join) => {
895                let predicate_vars = self.extract_variables(&predicate);
896                let left_vars = self.collect_output_variables(&join.left);
897                let right_vars = self.collect_output_variables(&join.right);
898
899                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
900                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
901
902                if uses_left && !uses_right {
903                    // Push to left side
904                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
905                    LogicalOperator::Join(join)
906                } else if uses_right && !uses_left {
907                    // Push to right side
908                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
909                    LogicalOperator::Join(join)
910                } else {
911                    // Uses both sides - keep above join
912                    LogicalOperator::Filter(FilterOp {
913                        predicate,
914                        pushdown_hint: None,
915                        input: Box::new(LogicalOperator::Join(join)),
916                    })
917                }
918            }
919
920            // Cannot push through Aggregate (predicate refers to aggregated values)
921            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
922                predicate,
923                pushdown_hint: None,
924                input: Box::new(LogicalOperator::Aggregate(agg)),
925            }),
926
927            // For NodeScan, we've reached the bottom - keep filter on top
928            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
929                predicate,
930                pushdown_hint: None,
931                input: Box::new(LogicalOperator::NodeScan(scan)),
932            }),
933
934            // For other operators, keep filter on top
935            other => LogicalOperator::Filter(FilterOp {
936                predicate,
937                pushdown_hint: None,
938                input: Box::new(other),
939            }),
940        }
941    }
942
943    /// Collects all output variable names from an operator.
944    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
945        let mut vars = HashSet::new();
946        Self::collect_output_variables_recursive(op, &mut vars);
947        vars
948    }
949
950    /// Recursively collects output variables from an operator.
951    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
952        match op {
953            LogicalOperator::NodeScan(scan) => {
954                vars.insert(scan.variable.clone());
955            }
956            LogicalOperator::EdgeScan(scan) => {
957                vars.insert(scan.variable.clone());
958            }
959            LogicalOperator::Expand(expand) => {
960                vars.insert(expand.to_variable.clone());
961                if let Some(edge_var) = &expand.edge_variable {
962                    vars.insert(edge_var.clone());
963                }
964                Self::collect_output_variables_recursive(&expand.input, vars);
965            }
966            LogicalOperator::Filter(filter) => {
967                Self::collect_output_variables_recursive(&filter.input, vars);
968            }
969            LogicalOperator::Project(proj) => {
970                for p in &proj.projections {
971                    if let Some(alias) = &p.alias {
972                        vars.insert(alias.clone());
973                    }
974                }
975                Self::collect_output_variables_recursive(&proj.input, vars);
976            }
977            LogicalOperator::Join(join) => {
978                Self::collect_output_variables_recursive(&join.left, vars);
979                Self::collect_output_variables_recursive(&join.right, vars);
980            }
981            LogicalOperator::Aggregate(agg) => {
982                for expr in &agg.group_by {
983                    Self::collect_variables(expr, vars);
984                }
985                for agg_expr in &agg.aggregates {
986                    if let Some(alias) = &agg_expr.alias {
987                        vars.insert(alias.clone());
988                    }
989                }
990            }
991            LogicalOperator::Return(ret) => {
992                Self::collect_output_variables_recursive(&ret.input, vars);
993            }
994            LogicalOperator::Limit(limit) => {
995                Self::collect_output_variables_recursive(&limit.input, vars);
996            }
997            LogicalOperator::Skip(skip) => {
998                Self::collect_output_variables_recursive(&skip.input, vars);
999            }
1000            LogicalOperator::Sort(sort) => {
1001                Self::collect_output_variables_recursive(&sort.input, vars);
1002            }
1003            LogicalOperator::Distinct(distinct) => {
1004                Self::collect_output_variables_recursive(&distinct.input, vars);
1005            }
1006            _ => {}
1007        }
1008    }
1009
1010    /// Extracts all variable names referenced in an expression.
1011    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1012        let mut vars = HashSet::new();
1013        Self::collect_variables(expr, &mut vars);
1014        vars
1015    }
1016
1017    /// Recursively collects variable names from an expression.
1018    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1019        match expr {
1020            LogicalExpression::Variable(name) => {
1021                vars.insert(name.clone());
1022            }
1023            LogicalExpression::Property { variable, .. } => {
1024                vars.insert(variable.clone());
1025            }
1026            LogicalExpression::Binary { left, right, .. } => {
1027                Self::collect_variables(left, vars);
1028                Self::collect_variables(right, vars);
1029            }
1030            LogicalExpression::Unary { operand, .. } => {
1031                Self::collect_variables(operand, vars);
1032            }
1033            LogicalExpression::FunctionCall { args, .. } => {
1034                for arg in args {
1035                    Self::collect_variables(arg, vars);
1036                }
1037            }
1038            LogicalExpression::List(items) => {
1039                for item in items {
1040                    Self::collect_variables(item, vars);
1041                }
1042            }
1043            LogicalExpression::Map(pairs) => {
1044                for (_, value) in pairs {
1045                    Self::collect_variables(value, vars);
1046                }
1047            }
1048            LogicalExpression::IndexAccess { base, index } => {
1049                Self::collect_variables(base, vars);
1050                Self::collect_variables(index, vars);
1051            }
1052            LogicalExpression::SliceAccess { base, start, end } => {
1053                Self::collect_variables(base, vars);
1054                if let Some(s) = start {
1055                    Self::collect_variables(s, vars);
1056                }
1057                if let Some(e) = end {
1058                    Self::collect_variables(e, vars);
1059                }
1060            }
1061            LogicalExpression::Case {
1062                operand,
1063                when_clauses,
1064                else_clause,
1065            } => {
1066                if let Some(op) = operand {
1067                    Self::collect_variables(op, vars);
1068                }
1069                for (cond, result) in when_clauses {
1070                    Self::collect_variables(cond, vars);
1071                    Self::collect_variables(result, vars);
1072                }
1073                if let Some(else_expr) = else_clause {
1074                    Self::collect_variables(else_expr, vars);
1075                }
1076            }
1077            LogicalExpression::Labels(var)
1078            | LogicalExpression::Type(var)
1079            | LogicalExpression::Id(var) => {
1080                vars.insert(var.clone());
1081            }
1082            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1083            LogicalExpression::ListComprehension {
1084                list_expr,
1085                filter_expr,
1086                map_expr,
1087                ..
1088            } => {
1089                Self::collect_variables(list_expr, vars);
1090                if let Some(filter) = filter_expr {
1091                    Self::collect_variables(filter, vars);
1092                }
1093                Self::collect_variables(map_expr, vars);
1094            }
1095            LogicalExpression::ListPredicate {
1096                list_expr,
1097                predicate,
1098                ..
1099            } => {
1100                Self::collect_variables(list_expr, vars);
1101                Self::collect_variables(predicate, vars);
1102            }
1103            LogicalExpression::ExistsSubquery(_)
1104            | LogicalExpression::CountSubquery(_)
1105            | LogicalExpression::ValueSubquery(_) => {
1106                // Subqueries have their own variable scope
1107            }
1108            LogicalExpression::PatternComprehension { projection, .. } => {
1109                Self::collect_variables(projection, vars);
1110            }
1111            LogicalExpression::MapProjection { base, entries } => {
1112                vars.insert(base.clone());
1113                for entry in entries {
1114                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1115                        Self::collect_variables(expr, vars);
1116                    }
1117                }
1118            }
1119            LogicalExpression::Reduce {
1120                initial,
1121                list,
1122                expression,
1123                ..
1124            } => {
1125                Self::collect_variables(initial, vars);
1126                Self::collect_variables(list, vars);
1127                Self::collect_variables(expression, vars);
1128            }
1129        }
1130    }
1131
1132    /// Extracts aliases from projection expressions.
1133    fn extract_projection_aliases(
1134        &self,
1135        projections: &[crate::query::plan::Projection],
1136    ) -> HashSet<String> {
1137        projections.iter().filter_map(|p| p.alias.clone()).collect()
1138    }
1139}
1140
1141impl Default for Optimizer {
1142    fn default() -> Self {
1143        Self::new()
1144    }
1145}
1146
1147#[cfg(test)]
1148mod tests {
1149    use super::*;
1150    use crate::query::plan::{
1151        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1152        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1153        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1154    };
1155    use grafeo_common::types::Value;
1156
1157    #[test]
1158    fn test_optimizer_filter_pushdown_simple() {
1159        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1160        // Before: Return -> Filter -> NodeScan
1161        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1162
1163        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1164            items: vec![ReturnItem {
1165                expression: LogicalExpression::Variable("n".to_string()),
1166                alias: None,
1167            }],
1168            distinct: false,
1169            input: Box::new(LogicalOperator::Filter(FilterOp {
1170                predicate: LogicalExpression::Binary {
1171                    left: Box::new(LogicalExpression::Property {
1172                        variable: "n".to_string(),
1173                        property: "age".to_string(),
1174                    }),
1175                    op: BinaryOp::Gt,
1176                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1177                },
1178                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1179                    variable: "n".to_string(),
1180                    label: Some("Person".to_string()),
1181                    input: None,
1182                })),
1183                pushdown_hint: None,
1184            })),
1185        }));
1186
1187        let optimizer = Optimizer::new();
1188        let optimized = optimizer.optimize(plan).unwrap();
1189
1190        // The structure should remain similar (filter stays near scan)
1191        if let LogicalOperator::Return(ret) = &optimized.root
1192            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1193            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1194        {
1195            assert_eq!(scan.variable, "n");
1196            return;
1197        }
1198        panic!("Expected Return -> Filter -> NodeScan structure");
1199    }
1200
1201    #[test]
1202    fn test_optimizer_filter_pushdown_through_expand() {
1203        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1204        // The filter on 'a' should be pushed before the expand
1205
1206        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1207            items: vec![ReturnItem {
1208                expression: LogicalExpression::Variable("b".to_string()),
1209                alias: None,
1210            }],
1211            distinct: false,
1212            input: Box::new(LogicalOperator::Filter(FilterOp {
1213                predicate: LogicalExpression::Binary {
1214                    left: Box::new(LogicalExpression::Property {
1215                        variable: "a".to_string(),
1216                        property: "age".to_string(),
1217                    }),
1218                    op: BinaryOp::Gt,
1219                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1220                },
1221                pushdown_hint: None,
1222                input: Box::new(LogicalOperator::Expand(ExpandOp {
1223                    from_variable: "a".to_string(),
1224                    to_variable: "b".to_string(),
1225                    edge_variable: None,
1226                    direction: ExpandDirection::Outgoing,
1227                    edge_types: vec!["KNOWS".to_string()],
1228                    min_hops: 1,
1229                    max_hops: Some(1),
1230                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1231                        variable: "a".to_string(),
1232                        label: Some("Person".to_string()),
1233                        input: None,
1234                    })),
1235                    path_alias: None,
1236                    path_mode: PathMode::Walk,
1237                })),
1238            })),
1239        }));
1240
1241        let optimizer = Optimizer::new();
1242        let optimized = optimizer.optimize(plan).unwrap();
1243
1244        // Filter on 'a' should be pushed before the expand
1245        // Expected: Return -> Expand -> Filter -> NodeScan
1246        if let LogicalOperator::Return(ret) = &optimized.root
1247            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1248            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1249            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1250        {
1251            assert_eq!(scan.variable, "a");
1252            assert_eq!(expand.from_variable, "a");
1253            assert_eq!(expand.to_variable, "b");
1254            return;
1255        }
1256        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1257    }
1258
1259    #[test]
1260    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1261        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1262        // The filter on 'b' should NOT be pushed before the expand
1263
1264        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1265            items: vec![ReturnItem {
1266                expression: LogicalExpression::Variable("a".to_string()),
1267                alias: None,
1268            }],
1269            distinct: false,
1270            input: Box::new(LogicalOperator::Filter(FilterOp {
1271                predicate: LogicalExpression::Binary {
1272                    left: Box::new(LogicalExpression::Property {
1273                        variable: "b".to_string(),
1274                        property: "age".to_string(),
1275                    }),
1276                    op: BinaryOp::Gt,
1277                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1278                },
1279                pushdown_hint: None,
1280                input: Box::new(LogicalOperator::Expand(ExpandOp {
1281                    from_variable: "a".to_string(),
1282                    to_variable: "b".to_string(),
1283                    edge_variable: None,
1284                    direction: ExpandDirection::Outgoing,
1285                    edge_types: vec!["KNOWS".to_string()],
1286                    min_hops: 1,
1287                    max_hops: Some(1),
1288                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1289                        variable: "a".to_string(),
1290                        label: Some("Person".to_string()),
1291                        input: None,
1292                    })),
1293                    path_alias: None,
1294                    path_mode: PathMode::Walk,
1295                })),
1296            })),
1297        }));
1298
1299        let optimizer = Optimizer::new();
1300        let optimized = optimizer.optimize(plan).unwrap();
1301
1302        // Filter on 'b' should stay after the expand
1303        // Expected: Return -> Filter -> Expand -> NodeScan
1304        if let LogicalOperator::Return(ret) = &optimized.root
1305            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1306        {
1307            // Check that the filter is on 'b'
1308            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1309                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1310            {
1311                assert_eq!(variable, "b");
1312            }
1313
1314            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1315                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1316            {
1317                return;
1318            }
1319        }
1320        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1321    }
1322
1323    #[test]
1324    fn test_optimizer_extract_variables() {
1325        let optimizer = Optimizer::new();
1326
1327        let expr = LogicalExpression::Binary {
1328            left: Box::new(LogicalExpression::Property {
1329                variable: "n".to_string(),
1330                property: "age".to_string(),
1331            }),
1332            op: BinaryOp::Gt,
1333            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1334        };
1335
1336        let vars = optimizer.extract_variables(&expr);
1337        assert_eq!(vars.len(), 1);
1338        assert!(vars.contains("n"));
1339    }
1340
1341    // Additional tests for optimizer configuration
1342
1343    #[test]
1344    fn test_optimizer_default() {
1345        let optimizer = Optimizer::default();
1346        // Should be able to optimize an empty plan
1347        let plan = LogicalPlan::new(LogicalOperator::Empty);
1348        let result = optimizer.optimize(plan);
1349        assert!(result.is_ok());
1350    }
1351
1352    #[test]
1353    fn test_optimizer_with_filter_pushdown_disabled() {
1354        let optimizer = Optimizer::new().with_filter_pushdown(false);
1355
1356        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1357            items: vec![ReturnItem {
1358                expression: LogicalExpression::Variable("n".to_string()),
1359                alias: None,
1360            }],
1361            distinct: false,
1362            input: Box::new(LogicalOperator::Filter(FilterOp {
1363                predicate: LogicalExpression::Literal(Value::Bool(true)),
1364                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1365                    variable: "n".to_string(),
1366                    label: None,
1367                    input: None,
1368                })),
1369                pushdown_hint: None,
1370            })),
1371        }));
1372
1373        let optimized = optimizer.optimize(plan).unwrap();
1374        // Structure should be unchanged
1375        if let LogicalOperator::Return(ret) = &optimized.root
1376            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1377        {
1378            return;
1379        }
1380        panic!("Expected unchanged structure");
1381    }
1382
1383    #[test]
1384    fn test_optimizer_with_join_reorder_disabled() {
1385        let optimizer = Optimizer::new().with_join_reorder(false);
1386        assert!(
1387            optimizer
1388                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1389                .is_ok()
1390        );
1391    }
1392
1393    #[test]
1394    fn test_optimizer_with_cost_model() {
1395        let cost_model = CostModel::new();
1396        let optimizer = Optimizer::new().with_cost_model(cost_model);
1397        assert!(
1398            optimizer
1399                .cost_model()
1400                .estimate(&LogicalOperator::Empty, 0.0)
1401                .total()
1402                < 0.001
1403        );
1404    }
1405
1406    #[test]
1407    fn test_optimizer_with_cardinality_estimator() {
1408        let mut estimator = CardinalityEstimator::new();
1409        estimator.add_table_stats("Test", TableStats::new(500));
1410        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1411
1412        let scan = LogicalOperator::NodeScan(NodeScanOp {
1413            variable: "n".to_string(),
1414            label: Some("Test".to_string()),
1415            input: None,
1416        });
1417        let plan = LogicalPlan::new(scan);
1418
1419        let cardinality = optimizer.estimate_cardinality(&plan);
1420        assert!((cardinality - 500.0).abs() < 0.001);
1421    }
1422
1423    #[test]
1424    fn test_optimizer_estimate_cost() {
1425        let optimizer = Optimizer::new();
1426        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1427            variable: "n".to_string(),
1428            label: None,
1429            input: None,
1430        }));
1431
1432        let cost = optimizer.estimate_cost(&plan);
1433        assert!(cost.total() > 0.0);
1434    }
1435
1436    // Filter pushdown through various operators
1437
1438    #[test]
1439    fn test_filter_pushdown_through_project() {
1440        let optimizer = Optimizer::new();
1441
1442        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1443            predicate: LogicalExpression::Binary {
1444                left: Box::new(LogicalExpression::Property {
1445                    variable: "n".to_string(),
1446                    property: "age".to_string(),
1447                }),
1448                op: BinaryOp::Gt,
1449                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1450            },
1451            pushdown_hint: None,
1452            input: Box::new(LogicalOperator::Project(ProjectOp {
1453                projections: vec![Projection {
1454                    expression: LogicalExpression::Variable("n".to_string()),
1455                    alias: None,
1456                }],
1457                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1458                    variable: "n".to_string(),
1459                    label: None,
1460                    input: None,
1461                })),
1462                pass_through_input: false,
1463            })),
1464        }));
1465
1466        let optimized = optimizer.optimize(plan).unwrap();
1467
1468        // Filter should be pushed through Project
1469        if let LogicalOperator::Project(proj) = &optimized.root
1470            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1471        {
1472            return;
1473        }
1474        panic!("Expected Project -> Filter structure");
1475    }
1476
1477    #[test]
1478    fn test_filter_not_pushed_through_project_with_alias() {
1479        let optimizer = Optimizer::new();
1480
1481        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1482        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1483            predicate: LogicalExpression::Binary {
1484                left: Box::new(LogicalExpression::Variable("x".to_string())),
1485                op: BinaryOp::Gt,
1486                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1487            },
1488            pushdown_hint: None,
1489            input: Box::new(LogicalOperator::Project(ProjectOp {
1490                projections: vec![Projection {
1491                    expression: LogicalExpression::Property {
1492                        variable: "n".to_string(),
1493                        property: "age".to_string(),
1494                    },
1495                    alias: Some("x".to_string()),
1496                }],
1497                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1498                    variable: "n".to_string(),
1499                    label: None,
1500                    input: None,
1501                })),
1502                pass_through_input: false,
1503            })),
1504        }));
1505
1506        let optimized = optimizer.optimize(plan).unwrap();
1507
1508        // Filter should stay above Project
1509        if let LogicalOperator::Filter(filter) = &optimized.root
1510            && let LogicalOperator::Project(_) = filter.input.as_ref()
1511        {
1512            return;
1513        }
1514        panic!("Expected Filter -> Project structure");
1515    }
1516
1517    #[test]
1518    fn test_filter_pushdown_through_limit() {
1519        let optimizer = Optimizer::new();
1520
1521        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1522            predicate: LogicalExpression::Literal(Value::Bool(true)),
1523            pushdown_hint: None,
1524            input: Box::new(LogicalOperator::Limit(LimitOp {
1525                count: 10.into(),
1526                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1527                    variable: "n".to_string(),
1528                    label: None,
1529                    input: None,
1530                })),
1531            })),
1532        }));
1533
1534        let optimized = optimizer.optimize(plan).unwrap();
1535
1536        // Filter stays above Limit (cannot be pushed through)
1537        if let LogicalOperator::Filter(filter) = &optimized.root
1538            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1539        {
1540            return;
1541        }
1542        panic!("Expected Filter -> Limit structure");
1543    }
1544
1545    #[test]
1546    fn test_filter_pushdown_through_sort() {
1547        let optimizer = Optimizer::new();
1548
1549        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1550            predicate: LogicalExpression::Literal(Value::Bool(true)),
1551            pushdown_hint: None,
1552            input: Box::new(LogicalOperator::Sort(SortOp {
1553                keys: vec![SortKey {
1554                    expression: LogicalExpression::Variable("n".to_string()),
1555                    order: SortOrder::Ascending,
1556                    nulls: None,
1557                }],
1558                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1559                    variable: "n".to_string(),
1560                    label: None,
1561                    input: None,
1562                })),
1563            })),
1564        }));
1565
1566        let optimized = optimizer.optimize(plan).unwrap();
1567
1568        // Filter stays above Sort
1569        if let LogicalOperator::Filter(filter) = &optimized.root
1570            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1571        {
1572            return;
1573        }
1574        panic!("Expected Filter -> Sort structure");
1575    }
1576
1577    #[test]
1578    fn test_filter_pushdown_through_distinct() {
1579        let optimizer = Optimizer::new();
1580
1581        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1582            predicate: LogicalExpression::Literal(Value::Bool(true)),
1583            pushdown_hint: None,
1584            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1585                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1586                    variable: "n".to_string(),
1587                    label: None,
1588                    input: None,
1589                })),
1590                columns: None,
1591            })),
1592        }));
1593
1594        let optimized = optimizer.optimize(plan).unwrap();
1595
1596        // Filter stays above Distinct
1597        if let LogicalOperator::Filter(filter) = &optimized.root
1598            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1599        {
1600            return;
1601        }
1602        panic!("Expected Filter -> Distinct structure");
1603    }
1604
1605    #[test]
1606    fn test_filter_not_pushed_through_aggregate() {
1607        let optimizer = Optimizer::new();
1608
1609        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1610            predicate: LogicalExpression::Binary {
1611                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1612                op: BinaryOp::Gt,
1613                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1614            },
1615            pushdown_hint: None,
1616            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1617                group_by: vec![],
1618                aggregates: vec![AggregateExpr {
1619                    function: AggregateFunction::Count,
1620                    expression: None,
1621                    expression2: None,
1622                    distinct: false,
1623                    alias: Some("cnt".to_string()),
1624                    percentile: None,
1625                    separator: None,
1626                }],
1627                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1628                    variable: "n".to_string(),
1629                    label: None,
1630                    input: None,
1631                })),
1632                having: None,
1633            })),
1634        }));
1635
1636        let optimized = optimizer.optimize(plan).unwrap();
1637
1638        // Filter should stay above Aggregate
1639        if let LogicalOperator::Filter(filter) = &optimized.root
1640            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1641        {
1642            return;
1643        }
1644        panic!("Expected Filter -> Aggregate structure");
1645    }
1646
1647    #[test]
1648    fn test_filter_pushdown_to_left_join_side() {
1649        let optimizer = Optimizer::new();
1650
1651        // Filter on left variable should be pushed to left side
1652        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1653            predicate: LogicalExpression::Binary {
1654                left: Box::new(LogicalExpression::Property {
1655                    variable: "a".to_string(),
1656                    property: "age".to_string(),
1657                }),
1658                op: BinaryOp::Gt,
1659                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1660            },
1661            pushdown_hint: None,
1662            input: Box::new(LogicalOperator::Join(JoinOp {
1663                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1664                    variable: "a".to_string(),
1665                    label: Some("Person".to_string()),
1666                    input: None,
1667                })),
1668                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1669                    variable: "b".to_string(),
1670                    label: Some("Company".to_string()),
1671                    input: None,
1672                })),
1673                join_type: JoinType::Inner,
1674                conditions: vec![],
1675            })),
1676        }));
1677
1678        let optimized = optimizer.optimize(plan).unwrap();
1679
1680        // Filter should be pushed to left side of join
1681        if let LogicalOperator::Join(join) = &optimized.root
1682            && let LogicalOperator::Filter(_) = join.left.as_ref()
1683        {
1684            return;
1685        }
1686        panic!("Expected Join with Filter on left side");
1687    }
1688
1689    #[test]
1690    fn test_filter_pushdown_to_right_join_side() {
1691        let optimizer = Optimizer::new();
1692
1693        // Filter on right variable should be pushed to right side
1694        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1695            predicate: LogicalExpression::Binary {
1696                left: Box::new(LogicalExpression::Property {
1697                    variable: "b".to_string(),
1698                    property: "name".to_string(),
1699                }),
1700                op: BinaryOp::Eq,
1701                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1702            },
1703            pushdown_hint: None,
1704            input: Box::new(LogicalOperator::Join(JoinOp {
1705                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1706                    variable: "a".to_string(),
1707                    label: Some("Person".to_string()),
1708                    input: None,
1709                })),
1710                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1711                    variable: "b".to_string(),
1712                    label: Some("Company".to_string()),
1713                    input: None,
1714                })),
1715                join_type: JoinType::Inner,
1716                conditions: vec![],
1717            })),
1718        }));
1719
1720        let optimized = optimizer.optimize(plan).unwrap();
1721
1722        // Filter should be pushed to right side of join
1723        if let LogicalOperator::Join(join) = &optimized.root
1724            && let LogicalOperator::Filter(_) = join.right.as_ref()
1725        {
1726            return;
1727        }
1728        panic!("Expected Join with Filter on right side");
1729    }
1730
1731    #[test]
1732    fn test_filter_not_pushed_when_uses_both_join_sides() {
1733        let optimizer = Optimizer::new();
1734
1735        // Filter using both variables should stay above join
1736        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1737            predicate: LogicalExpression::Binary {
1738                left: Box::new(LogicalExpression::Property {
1739                    variable: "a".to_string(),
1740                    property: "id".to_string(),
1741                }),
1742                op: BinaryOp::Eq,
1743                right: Box::new(LogicalExpression::Property {
1744                    variable: "b".to_string(),
1745                    property: "a_id".to_string(),
1746                }),
1747            },
1748            pushdown_hint: None,
1749            input: Box::new(LogicalOperator::Join(JoinOp {
1750                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1751                    variable: "a".to_string(),
1752                    label: None,
1753                    input: None,
1754                })),
1755                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1756                    variable: "b".to_string(),
1757                    label: None,
1758                    input: None,
1759                })),
1760                join_type: JoinType::Inner,
1761                conditions: vec![],
1762            })),
1763        }));
1764
1765        let optimized = optimizer.optimize(plan).unwrap();
1766
1767        // Filter should stay above join
1768        if let LogicalOperator::Filter(filter) = &optimized.root
1769            && let LogicalOperator::Join(_) = filter.input.as_ref()
1770        {
1771            return;
1772        }
1773        panic!("Expected Filter -> Join structure");
1774    }
1775
1776    // Variable extraction tests
1777
1778    #[test]
1779    fn test_extract_variables_from_variable() {
1780        let optimizer = Optimizer::new();
1781        let expr = LogicalExpression::Variable("x".to_string());
1782        let vars = optimizer.extract_variables(&expr);
1783        assert_eq!(vars.len(), 1);
1784        assert!(vars.contains("x"));
1785    }
1786
1787    #[test]
1788    fn test_extract_variables_from_unary() {
1789        let optimizer = Optimizer::new();
1790        let expr = LogicalExpression::Unary {
1791            op: UnaryOp::Not,
1792            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1793        };
1794        let vars = optimizer.extract_variables(&expr);
1795        assert_eq!(vars.len(), 1);
1796        assert!(vars.contains("x"));
1797    }
1798
1799    #[test]
1800    fn test_extract_variables_from_function_call() {
1801        let optimizer = Optimizer::new();
1802        let expr = LogicalExpression::FunctionCall {
1803            name: "length".to_string(),
1804            args: vec![
1805                LogicalExpression::Variable("a".to_string()),
1806                LogicalExpression::Variable("b".to_string()),
1807            ],
1808            distinct: false,
1809        };
1810        let vars = optimizer.extract_variables(&expr);
1811        assert_eq!(vars.len(), 2);
1812        assert!(vars.contains("a"));
1813        assert!(vars.contains("b"));
1814    }
1815
1816    #[test]
1817    fn test_extract_variables_from_list() {
1818        let optimizer = Optimizer::new();
1819        let expr = LogicalExpression::List(vec![
1820            LogicalExpression::Variable("a".to_string()),
1821            LogicalExpression::Literal(Value::Int64(1)),
1822            LogicalExpression::Variable("b".to_string()),
1823        ]);
1824        let vars = optimizer.extract_variables(&expr);
1825        assert_eq!(vars.len(), 2);
1826        assert!(vars.contains("a"));
1827        assert!(vars.contains("b"));
1828    }
1829
1830    #[test]
1831    fn test_extract_variables_from_map() {
1832        let optimizer = Optimizer::new();
1833        let expr = LogicalExpression::Map(vec![
1834            (
1835                "key1".to_string(),
1836                LogicalExpression::Variable("a".to_string()),
1837            ),
1838            (
1839                "key2".to_string(),
1840                LogicalExpression::Variable("b".to_string()),
1841            ),
1842        ]);
1843        let vars = optimizer.extract_variables(&expr);
1844        assert_eq!(vars.len(), 2);
1845        assert!(vars.contains("a"));
1846        assert!(vars.contains("b"));
1847    }
1848
1849    #[test]
1850    fn test_extract_variables_from_index_access() {
1851        let optimizer = Optimizer::new();
1852        let expr = LogicalExpression::IndexAccess {
1853            base: Box::new(LogicalExpression::Variable("list".to_string())),
1854            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1855        };
1856        let vars = optimizer.extract_variables(&expr);
1857        assert_eq!(vars.len(), 2);
1858        assert!(vars.contains("list"));
1859        assert!(vars.contains("idx"));
1860    }
1861
1862    #[test]
1863    fn test_extract_variables_from_slice_access() {
1864        let optimizer = Optimizer::new();
1865        let expr = LogicalExpression::SliceAccess {
1866            base: Box::new(LogicalExpression::Variable("list".to_string())),
1867            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1868            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1869        };
1870        let vars = optimizer.extract_variables(&expr);
1871        assert_eq!(vars.len(), 3);
1872        assert!(vars.contains("list"));
1873        assert!(vars.contains("s"));
1874        assert!(vars.contains("e"));
1875    }
1876
1877    #[test]
1878    fn test_extract_variables_from_case() {
1879        let optimizer = Optimizer::new();
1880        let expr = LogicalExpression::Case {
1881            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1882            when_clauses: vec![(
1883                LogicalExpression::Literal(Value::Int64(1)),
1884                LogicalExpression::Variable("a".to_string()),
1885            )],
1886            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1887        };
1888        let vars = optimizer.extract_variables(&expr);
1889        assert_eq!(vars.len(), 3);
1890        assert!(vars.contains("x"));
1891        assert!(vars.contains("a"));
1892        assert!(vars.contains("b"));
1893    }
1894
1895    #[test]
1896    fn test_extract_variables_from_labels() {
1897        let optimizer = Optimizer::new();
1898        let expr = LogicalExpression::Labels("n".to_string());
1899        let vars = optimizer.extract_variables(&expr);
1900        assert_eq!(vars.len(), 1);
1901        assert!(vars.contains("n"));
1902    }
1903
1904    #[test]
1905    fn test_extract_variables_from_type() {
1906        let optimizer = Optimizer::new();
1907        let expr = LogicalExpression::Type("e".to_string());
1908        let vars = optimizer.extract_variables(&expr);
1909        assert_eq!(vars.len(), 1);
1910        assert!(vars.contains("e"));
1911    }
1912
1913    #[test]
1914    fn test_extract_variables_from_id() {
1915        let optimizer = Optimizer::new();
1916        let expr = LogicalExpression::Id("n".to_string());
1917        let vars = optimizer.extract_variables(&expr);
1918        assert_eq!(vars.len(), 1);
1919        assert!(vars.contains("n"));
1920    }
1921
1922    #[test]
1923    fn test_extract_variables_from_list_comprehension() {
1924        let optimizer = Optimizer::new();
1925        let expr = LogicalExpression::ListComprehension {
1926            variable: "x".to_string(),
1927            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1928            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1929            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1930        };
1931        let vars = optimizer.extract_variables(&expr);
1932        assert!(vars.contains("items"));
1933        assert!(vars.contains("pred"));
1934        assert!(vars.contains("result"));
1935    }
1936
1937    #[test]
1938    fn test_extract_variables_from_literal_and_parameter() {
1939        let optimizer = Optimizer::new();
1940
1941        let literal = LogicalExpression::Literal(Value::Int64(42));
1942        assert!(optimizer.extract_variables(&literal).is_empty());
1943
1944        let param = LogicalExpression::Parameter("p".to_string());
1945        assert!(optimizer.extract_variables(&param).is_empty());
1946    }
1947
1948    // Recursive filter pushdown tests
1949
1950    #[test]
1951    fn test_recursive_filter_pushdown_through_skip() {
1952        let optimizer = Optimizer::new();
1953
1954        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1955            items: vec![ReturnItem {
1956                expression: LogicalExpression::Variable("n".to_string()),
1957                alias: None,
1958            }],
1959            distinct: false,
1960            input: Box::new(LogicalOperator::Filter(FilterOp {
1961                predicate: LogicalExpression::Literal(Value::Bool(true)),
1962                pushdown_hint: None,
1963                input: Box::new(LogicalOperator::Skip(SkipOp {
1964                    count: 5.into(),
1965                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1966                        variable: "n".to_string(),
1967                        label: None,
1968                        input: None,
1969                    })),
1970                })),
1971            })),
1972        }));
1973
1974        let optimized = optimizer.optimize(plan).unwrap();
1975
1976        // Verify optimization succeeded
1977        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1978    }
1979
1980    #[test]
1981    fn test_nested_filter_pushdown() {
1982        let optimizer = Optimizer::new();
1983
1984        // Multiple nested filters
1985        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1986            items: vec![ReturnItem {
1987                expression: LogicalExpression::Variable("n".to_string()),
1988                alias: None,
1989            }],
1990            distinct: false,
1991            input: Box::new(LogicalOperator::Filter(FilterOp {
1992                predicate: LogicalExpression::Binary {
1993                    left: Box::new(LogicalExpression::Property {
1994                        variable: "n".to_string(),
1995                        property: "x".to_string(),
1996                    }),
1997                    op: BinaryOp::Gt,
1998                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1999                },
2000                pushdown_hint: None,
2001                input: Box::new(LogicalOperator::Filter(FilterOp {
2002                    predicate: LogicalExpression::Binary {
2003                        left: Box::new(LogicalExpression::Property {
2004                            variable: "n".to_string(),
2005                            property: "y".to_string(),
2006                        }),
2007                        op: BinaryOp::Lt,
2008                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2009                    },
2010                    pushdown_hint: None,
2011                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2012                        variable: "n".to_string(),
2013                        label: None,
2014                        input: None,
2015                    })),
2016                })),
2017            })),
2018        }));
2019
2020        let optimized = optimizer.optimize(plan).unwrap();
2021        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2022    }
2023
2024    #[test]
2025    fn test_cyclic_join_produces_multi_way_join() {
2026        use crate::query::plan::JoinCondition;
2027
2028        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2029        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2030            variable: "a".to_string(),
2031            label: Some("Person".to_string()),
2032            input: None,
2033        });
2034        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2035            variable: "b".to_string(),
2036            label: Some("Person".to_string()),
2037            input: None,
2038        });
2039        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2040            variable: "c".to_string(),
2041            label: Some("Person".to_string()),
2042            input: None,
2043        });
2044
2045        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2046        let join_ab = LogicalOperator::Join(JoinOp {
2047            left: Box::new(scan_a),
2048            right: Box::new(scan_b),
2049            join_type: JoinType::Inner,
2050            conditions: vec![JoinCondition {
2051                left: LogicalExpression::Variable("a".to_string()),
2052                right: LogicalExpression::Variable("b".to_string()),
2053            }],
2054        });
2055
2056        let join_abc = LogicalOperator::Join(JoinOp {
2057            left: Box::new(join_ab),
2058            right: Box::new(scan_c),
2059            join_type: JoinType::Inner,
2060            conditions: vec![
2061                JoinCondition {
2062                    left: LogicalExpression::Variable("b".to_string()),
2063                    right: LogicalExpression::Variable("c".to_string()),
2064                },
2065                JoinCondition {
2066                    left: LogicalExpression::Variable("c".to_string()),
2067                    right: LogicalExpression::Variable("a".to_string()),
2068                },
2069            ],
2070        });
2071
2072        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2073            items: vec![ReturnItem {
2074                expression: LogicalExpression::Variable("a".to_string()),
2075                alias: None,
2076            }],
2077            distinct: false,
2078            input: Box::new(join_abc),
2079        }));
2080
2081        let mut optimizer = Optimizer::new();
2082        optimizer
2083            .card_estimator
2084            .add_table_stats("Person", cardinality::TableStats::new(1000));
2085
2086        let optimized = optimizer.optimize(plan).unwrap();
2087
2088        // Walk the tree to find a MultiWayJoin
2089        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2090            match op {
2091                LogicalOperator::MultiWayJoin(_) => true,
2092                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2093                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2094                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2095                _ => false,
2096            }
2097        }
2098
2099        assert!(
2100            has_multi_way_join(&optimized.root),
2101            "Expected MultiWayJoin for cyclic triangle pattern"
2102        );
2103    }
2104
2105    #[test]
2106    fn test_acyclic_join_uses_binary_joins() {
2107        use crate::query::plan::JoinCondition;
2108
2109        // Chain: a ⋈ b ⋈ c (acyclic)
2110        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2111            variable: "a".to_string(),
2112            label: Some("Person".to_string()),
2113            input: None,
2114        });
2115        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2116            variable: "b".to_string(),
2117            label: Some("Person".to_string()),
2118            input: None,
2119        });
2120        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2121            variable: "c".to_string(),
2122            label: Some("Company".to_string()),
2123            input: None,
2124        });
2125
2126        let join_ab = LogicalOperator::Join(JoinOp {
2127            left: Box::new(scan_a),
2128            right: Box::new(scan_b),
2129            join_type: JoinType::Inner,
2130            conditions: vec![JoinCondition {
2131                left: LogicalExpression::Variable("a".to_string()),
2132                right: LogicalExpression::Variable("b".to_string()),
2133            }],
2134        });
2135
2136        let join_abc = LogicalOperator::Join(JoinOp {
2137            left: Box::new(join_ab),
2138            right: Box::new(scan_c),
2139            join_type: JoinType::Inner,
2140            conditions: vec![JoinCondition {
2141                left: LogicalExpression::Variable("b".to_string()),
2142                right: LogicalExpression::Variable("c".to_string()),
2143            }],
2144        });
2145
2146        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2147            items: vec![ReturnItem {
2148                expression: LogicalExpression::Variable("a".to_string()),
2149                alias: None,
2150            }],
2151            distinct: false,
2152            input: Box::new(join_abc),
2153        }));
2154
2155        let mut optimizer = Optimizer::new();
2156        optimizer
2157            .card_estimator
2158            .add_table_stats("Person", cardinality::TableStats::new(1000));
2159        optimizer
2160            .card_estimator
2161            .add_table_stats("Company", cardinality::TableStats::new(100));
2162
2163        let optimized = optimizer.optimize(plan).unwrap();
2164
2165        // Should NOT contain MultiWayJoin for acyclic pattern
2166        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2167            match op {
2168                LogicalOperator::MultiWayJoin(_) => true,
2169                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2170                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2171                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2172                LogicalOperator::Join(j) => {
2173                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2174                }
2175                _ => false,
2176            }
2177        }
2178
2179        assert!(
2180            !has_multi_way_join(&optimized.root),
2181            "Acyclic join should NOT produce MultiWayJoin"
2182        );
2183    }
2184}