Skip to main content

grafeo_engine/query/optimizer/
mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
5//! | Optimization | What it does |
6//! | ------------ | ------------ |
7//! | Filter Pushdown | Moves `WHERE` clauses closer to scans - filter early, process less |
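//! | Join Reordering | Picks the best order to join tables using the DPccp algorithm; cyclic patterns of 3+ relations become a single multi-way (leapfrog) join |
//! | Projection Pushdown | Works out which variables and properties each part of the plan actually needs |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |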
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::utils::error::Result;
28use std::collections::HashSet;
29
/// Information about a join condition for join reordering.
///
/// Built by `collect_join_tree` for a join condition whose two sides each
/// resolve to a variable (a bare variable or a property of one), and consumed
/// by `optimize_join_order` when building the join graph.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable the left-hand side of the condition is rooted at.
    left_var: String,
    /// Variable the right-hand side of the condition is rooted at.
    right_var: String,
    /// The complete left-hand expression (e.g. `a` or `a.id`).
    left_expr: LogicalExpression,
    /// The complete right-hand expression.
    right_expr: LogicalExpression,
}
38
/// A column required by the query, used for projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A variable (node, edge, or path binding) referenced as a whole
    Variable(String),
    /// A specific property of a variable, stored as `(variable, property)`
    Property(String, String),
}
47
/// Transforms logical plans for faster execution.
///
/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
/// Use the builder methods to enable/disable specific optimizations.
///
/// The `from_store` / `from_graph_store` constructors (and `from_rdf_statistics`
/// when the `rdf` feature is enabled) seed the cost model and cardinality
/// estimator from store statistics instead of defaults.
pub struct Optimizer {
    /// Whether to enable filter pushdown (the first pass run by `optimize`).
    enable_filter_pushdown: bool,
    /// Whether to enable join reordering (run after filter pushdown).
    enable_join_reorder: bool,
    /// Whether to enable projection pushdown (the last pass run by `optimize`).
    enable_projection_pushdown: bool,
    /// Cost model for estimation.
    cost_model: CostModel,
    /// Cardinality estimator.
    card_estimator: CardinalityEstimator,
}
64
65impl Optimizer {
66    /// Creates a new optimizer with default settings.
67    #[must_use]
68    pub fn new() -> Self {
69        Self {
70            enable_filter_pushdown: true,
71            enable_join_reorder: true,
72            enable_projection_pushdown: true,
73            cost_model: CostModel::new(),
74            card_estimator: CardinalityEstimator::new(),
75        }
76    }
77
78    /// Creates an optimizer with cardinality estimates from the store's statistics.
79    ///
80    /// Pre-populates the cardinality estimator with per-label row counts and
81    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
82    /// and graph totals into the cost model for accurate estimation.
83    #[must_use]
84    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
85        store.ensure_statistics_fresh();
86        let stats = store.statistics();
87        Self::from_statistics(&stats)
88    }
89
90    /// Creates an optimizer from any GraphStore implementation.
91    ///
92    /// Unlike [`from_store`](Self::from_store), this does not call
93    /// `ensure_statistics_fresh()` since external stores manage their own
94    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
95    /// is called directly.
96    #[must_use]
97    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
98        let stats = store.statistics();
99        Self::from_statistics(&stats)
100    }
101
102    /// Creates an optimizer from RDF statistics.
103    ///
104    /// Uses triple pattern cardinality estimates for cost-based optimization
105    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
106    #[cfg(feature = "rdf")]
107    #[must_use]
108    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
109        let total = rdf_stats.total_triples;
110        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
111        Self {
112            enable_filter_pushdown: true,
113            enable_join_reorder: true,
114            enable_projection_pushdown: true,
115            cost_model: CostModel::new().with_graph_totals(total, total),
116            card_estimator: estimator,
117        }
118    }
119
120    /// Creates an optimizer from a Statistics snapshot.
121    ///
122    /// Extracts label cardinalities, edge type degrees, and graph totals
123    /// for both the cardinality estimator and cost model.
124    #[must_use]
125    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
126        let estimator = CardinalityEstimator::from_statistics(stats);
127
128        let avg_fanout = if stats.total_nodes > 0 {
129            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
130        } else {
131            10.0
132        };
133
134        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
135            .edge_types
136            .iter()
137            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
138            .collect();
139
140        let label_cardinalities: std::collections::HashMap<String, u64> = stats
141            .labels
142            .iter()
143            .map(|(name, ls)| (name.clone(), ls.node_count))
144            .collect();
145
146        Self {
147            enable_filter_pushdown: true,
148            enable_join_reorder: true,
149            enable_projection_pushdown: true,
150            cost_model: CostModel::new()
151                .with_avg_fanout(avg_fanout)
152                .with_edge_type_degrees(edge_type_degrees)
153                .with_label_cardinalities(label_cardinalities)
154                .with_graph_totals(stats.total_nodes, stats.total_edges),
155            card_estimator: estimator,
156        }
157    }
158
159    /// Enables or disables filter pushdown.
160    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
161        self.enable_filter_pushdown = enabled;
162        self
163    }
164
165    /// Enables or disables join reordering.
166    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
167        self.enable_join_reorder = enabled;
168        self
169    }
170
171    /// Enables or disables projection pushdown.
172    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
173        self.enable_projection_pushdown = enabled;
174        self
175    }
176
177    /// Sets the cost model.
178    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
179        self.cost_model = cost_model;
180        self
181    }
182
183    /// Sets the cardinality estimator.
184    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
185        self.card_estimator = estimator;
186        self
187    }
188
189    /// Sets the selectivity configuration for the cardinality estimator.
190    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
191        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
192        self
193    }
194
195    /// Returns a reference to the cost model.
196    pub fn cost_model(&self) -> &CostModel {
197        &self.cost_model
198    }
199
200    /// Returns a reference to the cardinality estimator.
201    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
202        &self.card_estimator
203    }
204
205    /// Estimates the total cost of a plan by recursively costing the entire tree.
206    ///
207    /// Walks every operator in the plan, computing cardinality at each level
208    /// and summing the per-operator costs. Uses actual child cardinalities
209    /// for join cost estimation rather than approximations.
210    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
211        self.cost_model
212            .estimate_tree(&plan.root, &self.card_estimator)
213    }
214
215    /// Estimates the cardinality of a plan.
216    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
217        self.card_estimator.estimate(&plan.root)
218    }
219
220    /// Optimizes a logical plan.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if optimization fails.
225    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
226        let mut root = plan.root;
227
228        // Apply optimization rules
229        if self.enable_filter_pushdown {
230            root = self.push_filters_down(root);
231        }
232
233        if self.enable_join_reorder {
234            root = self.reorder_joins(root);
235        }
236
237        if self.enable_projection_pushdown {
238            root = self.push_projections_down(root);
239        }
240
241        Ok(LogicalPlan {
242            root,
243            explain: plan.explain,
244            profile: plan.profile,
245        })
246    }
247
248    /// Pushes projections down the operator tree to eliminate unused columns early.
249    ///
250    /// This optimization:
251    /// 1. Collects required variables/properties from the root
252    /// 2. Propagates requirements down through the tree
253    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
254    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
255        // Collect required columns from the top of the plan
256        let required = self.collect_required_columns(&op);
257
258        // Push projections down
259        self.push_projections_recursive(op, &required)
260    }
261
262    /// Collects all variables and properties required by an operator and its ancestors.
263    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
264        let mut required = HashSet::new();
265        Self::collect_required_recursive(op, &mut required);
266        required
267    }
268
269    /// Recursively collects required columns.
270    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
271        match op {
272            LogicalOperator::Return(ret) => {
273                for item in &ret.items {
274                    Self::collect_from_expression(&item.expression, required);
275                }
276                Self::collect_required_recursive(&ret.input, required);
277            }
278            LogicalOperator::Project(proj) => {
279                for p in &proj.projections {
280                    Self::collect_from_expression(&p.expression, required);
281                }
282                Self::collect_required_recursive(&proj.input, required);
283            }
284            LogicalOperator::Filter(filter) => {
285                Self::collect_from_expression(&filter.predicate, required);
286                Self::collect_required_recursive(&filter.input, required);
287            }
288            LogicalOperator::Sort(sort) => {
289                for key in &sort.keys {
290                    Self::collect_from_expression(&key.expression, required);
291                }
292                Self::collect_required_recursive(&sort.input, required);
293            }
294            LogicalOperator::Aggregate(agg) => {
295                for expr in &agg.group_by {
296                    Self::collect_from_expression(expr, required);
297                }
298                for agg_expr in &agg.aggregates {
299                    if let Some(ref expr) = agg_expr.expression {
300                        Self::collect_from_expression(expr, required);
301                    }
302                }
303                if let Some(ref having) = agg.having {
304                    Self::collect_from_expression(having, required);
305                }
306                Self::collect_required_recursive(&agg.input, required);
307            }
308            LogicalOperator::Join(join) => {
309                for cond in &join.conditions {
310                    Self::collect_from_expression(&cond.left, required);
311                    Self::collect_from_expression(&cond.right, required);
312                }
313                Self::collect_required_recursive(&join.left, required);
314                Self::collect_required_recursive(&join.right, required);
315            }
316            LogicalOperator::Expand(expand) => {
317                // The source and target variables are needed
318                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
319                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
320                if let Some(ref edge_var) = expand.edge_variable {
321                    required.insert(RequiredColumn::Variable(edge_var.clone()));
322                }
323                Self::collect_required_recursive(&expand.input, required);
324            }
325            LogicalOperator::Limit(limit) => {
326                Self::collect_required_recursive(&limit.input, required);
327            }
328            LogicalOperator::Skip(skip) => {
329                Self::collect_required_recursive(&skip.input, required);
330            }
331            LogicalOperator::Distinct(distinct) => {
332                Self::collect_required_recursive(&distinct.input, required);
333            }
334            LogicalOperator::NodeScan(scan) => {
335                required.insert(RequiredColumn::Variable(scan.variable.clone()));
336            }
337            LogicalOperator::EdgeScan(scan) => {
338                required.insert(RequiredColumn::Variable(scan.variable.clone()));
339            }
340            LogicalOperator::MultiWayJoin(mwj) => {
341                for cond in &mwj.conditions {
342                    Self::collect_from_expression(&cond.left, required);
343                    Self::collect_from_expression(&cond.right, required);
344                }
345                for input in &mwj.inputs {
346                    Self::collect_required_recursive(input, required);
347                }
348            }
349            _ => {}
350        }
351    }
352
    /// Collects required columns from an expression.
    ///
    /// Walks `expr` recursively and records every variable it references. Each
    /// `variable.property` access also records a `Property` entry and marks the
    /// owning variable as required. Expression kinds not matched below (the `_`
    /// arm) contribute nothing.
    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
        match expr {
            LogicalExpression::Variable(var) => {
                required.insert(RequiredColumn::Variable(var.clone()));
            }
            LogicalExpression::Property { variable, property } => {
                // Reading a property needs both the property and its owner.
                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
                required.insert(RequiredColumn::Variable(variable.clone()));
            }
            LogicalExpression::Binary { left, right, .. } => {
                Self::collect_from_expression(left, required);
                Self::collect_from_expression(right, required);
            }
            LogicalExpression::Unary { operand, .. } => {
                Self::collect_from_expression(operand, required);
            }
            LogicalExpression::FunctionCall { args, .. } => {
                for arg in args {
                    Self::collect_from_expression(arg, required);
                }
            }
            LogicalExpression::List(items) => {
                for item in items {
                    Self::collect_from_expression(item, required);
                }
            }
            LogicalExpression::Map(pairs) => {
                // Only the values are expressions; the keys are ignored.
                for (_, value) in pairs {
                    Self::collect_from_expression(value, required);
                }
            }
            LogicalExpression::IndexAccess { base, index } => {
                Self::collect_from_expression(base, required);
                Self::collect_from_expression(index, required);
            }
            LogicalExpression::SliceAccess { base, start, end } => {
                Self::collect_from_expression(base, required);
                if let Some(s) = start {
                    Self::collect_from_expression(s, required);
                }
                if let Some(e) = end {
                    Self::collect_from_expression(e, required);
                }
            }
            LogicalExpression::Case {
                operand,
                when_clauses,
                else_clause,
            } => {
                if let Some(op) = operand {
                    Self::collect_from_expression(op, required);
                }
                for (cond, result) in when_clauses {
                    Self::collect_from_expression(cond, required);
                    Self::collect_from_expression(result, required);
                }
                if let Some(else_expr) = else_clause {
                    Self::collect_from_expression(else_expr, required);
                }
            }
            // labels(x) / type(x) / id(x) need the whole variable.
            LogicalExpression::Labels(var)
            | LogicalExpression::Type(var)
            | LogicalExpression::Id(var) => {
                required.insert(RequiredColumn::Variable(var.clone()));
            }
            LogicalExpression::ListComprehension {
                list_expr,
                filter_expr,
                map_expr,
                ..
            } => {
                // NOTE(review): any variable bound by the comprehension itself
                // (skipped via `..`) is not excluded. References to it inside
                // `filter_expr`/`map_expr` are presumably recorded like outer
                // variables.
                Self::collect_from_expression(list_expr, required);
                if let Some(filter) = filter_expr {
                    Self::collect_from_expression(filter, required);
                }
                Self::collect_from_expression(map_expr, required);
            }
            _ => {}
        }
    }
434
435    /// Recursively pushes projections down, adding them before expensive operations.
436    fn push_projections_recursive(
437        &self,
438        op: LogicalOperator,
439        required: &HashSet<RequiredColumn>,
440    ) -> LogicalOperator {
441        match op {
442            LogicalOperator::Return(mut ret) => {
443                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
444                LogicalOperator::Return(ret)
445            }
446            LogicalOperator::Project(mut proj) => {
447                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
448                LogicalOperator::Project(proj)
449            }
450            LogicalOperator::Filter(mut filter) => {
451                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
452                LogicalOperator::Filter(filter)
453            }
454            LogicalOperator::Sort(mut sort) => {
455                // Sort is expensive - consider adding a projection before it
456                // to reduce tuple width
457                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
458                LogicalOperator::Sort(sort)
459            }
460            LogicalOperator::Aggregate(mut agg) => {
461                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
462                LogicalOperator::Aggregate(agg)
463            }
464            LogicalOperator::Join(mut join) => {
465                // Joins are expensive - the required columns help determine
466                // what to project on each side
467                let left_vars = self.collect_output_variables(&join.left);
468                let right_vars = self.collect_output_variables(&join.right);
469
470                // Filter required columns to each side
471                let left_required: HashSet<_> = required
472                    .iter()
473                    .filter(|c| match c {
474                        RequiredColumn::Variable(v) => left_vars.contains(v),
475                        RequiredColumn::Property(v, _) => left_vars.contains(v),
476                    })
477                    .cloned()
478                    .collect();
479
480                let right_required: HashSet<_> = required
481                    .iter()
482                    .filter(|c| match c {
483                        RequiredColumn::Variable(v) => right_vars.contains(v),
484                        RequiredColumn::Property(v, _) => right_vars.contains(v),
485                    })
486                    .cloned()
487                    .collect();
488
489                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
490                join.right =
491                    Box::new(self.push_projections_recursive(*join.right, &right_required));
492                LogicalOperator::Join(join)
493            }
494            LogicalOperator::Expand(mut expand) => {
495                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
496                LogicalOperator::Expand(expand)
497            }
498            LogicalOperator::Limit(mut limit) => {
499                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
500                LogicalOperator::Limit(limit)
501            }
502            LogicalOperator::Skip(mut skip) => {
503                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
504                LogicalOperator::Skip(skip)
505            }
506            LogicalOperator::Distinct(mut distinct) => {
507                distinct.input =
508                    Box::new(self.push_projections_recursive(*distinct.input, required));
509                LogicalOperator::Distinct(distinct)
510            }
511            LogicalOperator::MapCollect(mut mc) => {
512                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
513                LogicalOperator::MapCollect(mc)
514            }
515            LogicalOperator::MultiWayJoin(mut mwj) => {
516                mwj.inputs = mwj
517                    .inputs
518                    .into_iter()
519                    .map(|input| self.push_projections_recursive(input, required))
520                    .collect();
521                LogicalOperator::MultiWayJoin(mwj)
522            }
523            other => other,
524        }
525    }
526
527    /// Reorders joins in the operator tree using the DPccp algorithm.
528    ///
529    /// This optimization finds the optimal join order by:
530    /// 1. Extracting all base relations (scans) and join conditions
531    /// 2. Building a join graph
532    /// 3. Using dynamic programming to find the cheapest join order
533    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
534        // First, recursively optimize children
535        let op = self.reorder_joins_recursive(op);
536
537        // Then, if this is a join tree, try to optimize it
538        if let Some((relations, conditions)) = self.extract_join_tree(&op)
539            && relations.len() >= 2
540            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
541        {
542            return optimized;
543        }
544
545        op
546    }
547
548    /// Recursively applies join reordering to child operators.
549    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
550        match op {
551            LogicalOperator::Return(mut ret) => {
552                ret.input = Box::new(self.reorder_joins(*ret.input));
553                LogicalOperator::Return(ret)
554            }
555            LogicalOperator::Project(mut proj) => {
556                proj.input = Box::new(self.reorder_joins(*proj.input));
557                LogicalOperator::Project(proj)
558            }
559            LogicalOperator::Filter(mut filter) => {
560                filter.input = Box::new(self.reorder_joins(*filter.input));
561                LogicalOperator::Filter(filter)
562            }
563            LogicalOperator::Limit(mut limit) => {
564                limit.input = Box::new(self.reorder_joins(*limit.input));
565                LogicalOperator::Limit(limit)
566            }
567            LogicalOperator::Skip(mut skip) => {
568                skip.input = Box::new(self.reorder_joins(*skip.input));
569                LogicalOperator::Skip(skip)
570            }
571            LogicalOperator::Sort(mut sort) => {
572                sort.input = Box::new(self.reorder_joins(*sort.input));
573                LogicalOperator::Sort(sort)
574            }
575            LogicalOperator::Distinct(mut distinct) => {
576                distinct.input = Box::new(self.reorder_joins(*distinct.input));
577                LogicalOperator::Distinct(distinct)
578            }
579            LogicalOperator::Aggregate(mut agg) => {
580                agg.input = Box::new(self.reorder_joins(*agg.input));
581                LogicalOperator::Aggregate(agg)
582            }
583            LogicalOperator::Expand(mut expand) => {
584                expand.input = Box::new(self.reorder_joins(*expand.input));
585                LogicalOperator::Expand(expand)
586            }
587            LogicalOperator::MapCollect(mut mc) => {
588                mc.input = Box::new(self.reorder_joins(*mc.input));
589                LogicalOperator::MapCollect(mc)
590            }
591            LogicalOperator::MultiWayJoin(mut mwj) => {
592                mwj.inputs = mwj
593                    .inputs
594                    .into_iter()
595                    .map(|input| self.reorder_joins(input))
596                    .collect();
597                LogicalOperator::MultiWayJoin(mwj)
598            }
599            // Join operators are handled by the parent reorder_joins call
600            other => other,
601        }
602    }
603
604    /// Extracts base relations and join conditions from a join tree.
605    ///
606    /// Returns None if the operator is not a join tree.
607    fn extract_join_tree(
608        &self,
609        op: &LogicalOperator,
610    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
611        let mut relations = Vec::new();
612        let mut join_conditions = Vec::new();
613
614        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
615            return None;
616        }
617
618        if relations.len() < 2 {
619            return None;
620        }
621
622        Some((relations, join_conditions))
623    }
624
625    /// Recursively collects base relations and join conditions.
626    ///
627    /// Returns true if this subtree is part of a join tree.
628    fn collect_join_tree(
629        &self,
630        op: &LogicalOperator,
631        relations: &mut Vec<(String, LogicalOperator)>,
632        conditions: &mut Vec<JoinInfo>,
633    ) -> bool {
634        match op {
635            LogicalOperator::Join(join) => {
636                // Collect from both sides
637                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
638                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
639
640                // Add conditions from this join
641                for cond in &join.conditions {
642                    if let (Some(left_var), Some(right_var)) = (
643                        self.extract_variable_from_expr(&cond.left),
644                        self.extract_variable_from_expr(&cond.right),
645                    ) {
646                        conditions.push(JoinInfo {
647                            left_var,
648                            right_var,
649                            left_expr: cond.left.clone(),
650                            right_expr: cond.right.clone(),
651                        });
652                    }
653                }
654
655                left_ok && right_ok
656            }
657            LogicalOperator::NodeScan(scan) => {
658                relations.push((scan.variable.clone(), op.clone()));
659                true
660            }
661            LogicalOperator::EdgeScan(scan) => {
662                relations.push((scan.variable.clone(), op.clone()));
663                true
664            }
665            LogicalOperator::Filter(filter) => {
666                // A filter on a base relation is still part of the join tree
667                self.collect_join_tree(&filter.input, relations, conditions)
668            }
669            LogicalOperator::Expand(expand) => {
670                // Expand is a special case - it's like a join with the adjacency
671                // For now, treat the whole Expand subtree as a single relation
672                relations.push((expand.to_variable.clone(), op.clone()));
673                true
674            }
675            _ => false,
676        }
677    }
678
679    /// Extracts the primary variable from an expression.
680    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
681        match expr {
682            LogicalExpression::Variable(v) => Some(v.clone()),
683            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
684            _ => None,
685        }
686    }
687
688    /// Optimizes the join order using DPccp, or produces a multi-way
689    /// leapfrog join for cyclic patterns when the cost model prefers it.
690    fn optimize_join_order(
691        &self,
692        relations: &[(String, LogicalOperator)],
693        conditions: &[JoinInfo],
694    ) -> Option<LogicalOperator> {
695        use join_order::{DPccp, JoinGraphBuilder};
696
697        // Build the join graph
698        let mut builder = JoinGraphBuilder::new();
699
700        for (var, relation) in relations {
701            builder.add_relation(var, relation.clone());
702        }
703
704        for cond in conditions {
705            builder.add_join_condition(
706                &cond.left_var,
707                &cond.right_var,
708                cond.left_expr.clone(),
709                cond.right_expr.clone(),
710            );
711        }
712
713        let graph = builder.build();
714
715        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
716        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
717        // multi-way intersection rather than binary hash join cascades that can
718        // produce intermediate blowup.
719        if graph.is_cyclic() && relations.len() >= 3 {
720            // Collect shared variables (variables appearing in 2+ conditions)
721            let mut var_counts: std::collections::HashMap<&str, usize> =
722                std::collections::HashMap::new();
723            for cond in conditions {
724                *var_counts.entry(&cond.left_var).or_default() += 1;
725                *var_counts.entry(&cond.right_var).or_default() += 1;
726            }
727            let shared_variables: Vec<String> = var_counts
728                .into_iter()
729                .filter(|(_, count)| *count >= 2)
730                .map(|(var, _)| var.to_string())
731                .collect();
732
733            let join_conditions: Vec<JoinCondition> = conditions
734                .iter()
735                .map(|c| JoinCondition {
736                    left: c.left_expr.clone(),
737                    right: c.right_expr.clone(),
738                })
739                .collect();
740
741            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
742                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
743                conditions: join_conditions,
744                shared_variables,
745            }));
746        }
747
748        // Fall through to DPccp for binary join ordering
749        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
750        let plan = dpccp.optimize()?;
751
752        Some(plan.operator)
753    }
754
755    /// Pushes filters down the operator tree.
756    ///
757    /// This optimization moves filter predicates as close to the data source
758    /// as possible to reduce the amount of data processed by upper operators.
759    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
760        match op {
761            // For Filter operators, try to push the predicate into the child
762            LogicalOperator::Filter(filter) => {
763                let optimized_input = self.push_filters_down(*filter.input);
764                self.try_push_filter_into(filter.predicate, optimized_input)
765            }
766            // Recursively optimize children for other operators
767            LogicalOperator::Return(mut ret) => {
768                ret.input = Box::new(self.push_filters_down(*ret.input));
769                LogicalOperator::Return(ret)
770            }
771            LogicalOperator::Project(mut proj) => {
772                proj.input = Box::new(self.push_filters_down(*proj.input));
773                LogicalOperator::Project(proj)
774            }
775            LogicalOperator::Limit(mut limit) => {
776                limit.input = Box::new(self.push_filters_down(*limit.input));
777                LogicalOperator::Limit(limit)
778            }
779            LogicalOperator::Skip(mut skip) => {
780                skip.input = Box::new(self.push_filters_down(*skip.input));
781                LogicalOperator::Skip(skip)
782            }
783            LogicalOperator::Sort(mut sort) => {
784                sort.input = Box::new(self.push_filters_down(*sort.input));
785                LogicalOperator::Sort(sort)
786            }
787            LogicalOperator::Distinct(mut distinct) => {
788                distinct.input = Box::new(self.push_filters_down(*distinct.input));
789                LogicalOperator::Distinct(distinct)
790            }
791            LogicalOperator::Expand(mut expand) => {
792                expand.input = Box::new(self.push_filters_down(*expand.input));
793                LogicalOperator::Expand(expand)
794            }
795            LogicalOperator::Join(mut join) => {
796                join.left = Box::new(self.push_filters_down(*join.left));
797                join.right = Box::new(self.push_filters_down(*join.right));
798                LogicalOperator::Join(join)
799            }
800            LogicalOperator::Aggregate(mut agg) => {
801                agg.input = Box::new(self.push_filters_down(*agg.input));
802                LogicalOperator::Aggregate(agg)
803            }
804            LogicalOperator::MapCollect(mut mc) => {
805                mc.input = Box::new(self.push_filters_down(*mc.input));
806                LogicalOperator::MapCollect(mc)
807            }
808            LogicalOperator::MultiWayJoin(mut mwj) => {
809                mwj.inputs = mwj
810                    .inputs
811                    .into_iter()
812                    .map(|input| self.push_filters_down(input))
813                    .collect();
814                LogicalOperator::MultiWayJoin(mwj)
815            }
816            // Leaf operators and unsupported operators are returned as-is
817            other => other,
818        }
819    }
820
821    /// Tries to push a filter predicate into the given operator.
822    ///
823    /// Returns either the predicate pushed into the operator, or a new
824    /// Filter operator on top if the predicate cannot be pushed further.
825    fn try_push_filter_into(
826        &self,
827        predicate: LogicalExpression,
828        op: LogicalOperator,
829    ) -> LogicalOperator {
830        match op {
831            // Can push through Project if predicate doesn't depend on computed columns
832            LogicalOperator::Project(mut proj) => {
833                let predicate_vars = self.extract_variables(&predicate);
834                let computed_vars = self.extract_projection_aliases(&proj.projections);
835
836                // If predicate doesn't use any computed columns, push through
837                if predicate_vars.is_disjoint(&computed_vars) {
838                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
839                    LogicalOperator::Project(proj)
840                } else {
841                    // Can't push through, keep filter on top
842                    LogicalOperator::Filter(FilterOp {
843                        predicate,
844                        pushdown_hint: None,
845                        input: Box::new(LogicalOperator::Project(proj)),
846                    })
847                }
848            }
849
850            // Can push through Return (which is like a projection)
851            LogicalOperator::Return(mut ret) => {
852                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
853                LogicalOperator::Return(ret)
854            }
855
856            // Can push through Expand if predicate doesn't use variables introduced by this expand
857            LogicalOperator::Expand(mut expand) => {
858                let predicate_vars = self.extract_variables(&predicate);
859
860                // Variables introduced by this expand are:
861                // - The target variable (to_variable)
862                // - The edge variable (if any)
863                // - The path alias (if any)
864                let mut introduced_vars = vec![&expand.to_variable];
865                if let Some(ref edge_var) = expand.edge_variable {
866                    introduced_vars.push(edge_var);
867                }
868                if let Some(ref path_alias) = expand.path_alias {
869                    introduced_vars.push(path_alias);
870                }
871
872                // Check if predicate uses any variables introduced by this expand
873                let uses_introduced_vars =
874                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
875
876                if !uses_introduced_vars {
877                    // Predicate doesn't use vars from this expand, so push through
878                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
879                    LogicalOperator::Expand(expand)
880                } else {
881                    // Keep filter after expand
882                    LogicalOperator::Filter(FilterOp {
883                        predicate,
884                        pushdown_hint: None,
885                        input: Box::new(LogicalOperator::Expand(expand)),
886                    })
887                }
888            }
889
890            // Can push through Join to left/right side based on variables used
891            LogicalOperator::Join(mut join) => {
892                let predicate_vars = self.extract_variables(&predicate);
893                let left_vars = self.collect_output_variables(&join.left);
894                let right_vars = self.collect_output_variables(&join.right);
895
896                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
897                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
898
899                if uses_left && !uses_right {
900                    // Push to left side
901                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
902                    LogicalOperator::Join(join)
903                } else if uses_right && !uses_left {
904                    // Push to right side
905                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
906                    LogicalOperator::Join(join)
907                } else {
908                    // Uses both sides - keep above join
909                    LogicalOperator::Filter(FilterOp {
910                        predicate,
911                        pushdown_hint: None,
912                        input: Box::new(LogicalOperator::Join(join)),
913                    })
914                }
915            }
916
917            // Cannot push through Aggregate (predicate refers to aggregated values)
918            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
919                predicate,
920                pushdown_hint: None,
921                input: Box::new(LogicalOperator::Aggregate(agg)),
922            }),
923
924            // For NodeScan, we've reached the bottom - keep filter on top
925            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
926                predicate,
927                pushdown_hint: None,
928                input: Box::new(LogicalOperator::NodeScan(scan)),
929            }),
930
931            // For other operators, keep filter on top
932            other => LogicalOperator::Filter(FilterOp {
933                predicate,
934                pushdown_hint: None,
935                input: Box::new(other),
936            }),
937        }
938    }
939
940    /// Collects all output variable names from an operator.
941    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
942        let mut vars = HashSet::new();
943        Self::collect_output_variables_recursive(op, &mut vars);
944        vars
945    }
946
947    /// Recursively collects output variables from an operator.
948    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
949        match op {
950            LogicalOperator::NodeScan(scan) => {
951                vars.insert(scan.variable.clone());
952            }
953            LogicalOperator::EdgeScan(scan) => {
954                vars.insert(scan.variable.clone());
955            }
956            LogicalOperator::Expand(expand) => {
957                vars.insert(expand.to_variable.clone());
958                if let Some(edge_var) = &expand.edge_variable {
959                    vars.insert(edge_var.clone());
960                }
961                Self::collect_output_variables_recursive(&expand.input, vars);
962            }
963            LogicalOperator::Filter(filter) => {
964                Self::collect_output_variables_recursive(&filter.input, vars);
965            }
966            LogicalOperator::Project(proj) => {
967                for p in &proj.projections {
968                    if let Some(alias) = &p.alias {
969                        vars.insert(alias.clone());
970                    }
971                }
972                Self::collect_output_variables_recursive(&proj.input, vars);
973            }
974            LogicalOperator::Join(join) => {
975                Self::collect_output_variables_recursive(&join.left, vars);
976                Self::collect_output_variables_recursive(&join.right, vars);
977            }
978            LogicalOperator::Aggregate(agg) => {
979                for expr in &agg.group_by {
980                    Self::collect_variables(expr, vars);
981                }
982                for agg_expr in &agg.aggregates {
983                    if let Some(alias) = &agg_expr.alias {
984                        vars.insert(alias.clone());
985                    }
986                }
987            }
988            LogicalOperator::Return(ret) => {
989                Self::collect_output_variables_recursive(&ret.input, vars);
990            }
991            LogicalOperator::Limit(limit) => {
992                Self::collect_output_variables_recursive(&limit.input, vars);
993            }
994            LogicalOperator::Skip(skip) => {
995                Self::collect_output_variables_recursive(&skip.input, vars);
996            }
997            LogicalOperator::Sort(sort) => {
998                Self::collect_output_variables_recursive(&sort.input, vars);
999            }
1000            LogicalOperator::Distinct(distinct) => {
1001                Self::collect_output_variables_recursive(&distinct.input, vars);
1002            }
1003            _ => {}
1004        }
1005    }
1006
1007    /// Extracts all variable names referenced in an expression.
1008    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1009        let mut vars = HashSet::new();
1010        Self::collect_variables(expr, &mut vars);
1011        vars
1012    }
1013
1014    /// Recursively collects variable names from an expression.
1015    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
1016        match expr {
1017            LogicalExpression::Variable(name) => {
1018                vars.insert(name.clone());
1019            }
1020            LogicalExpression::Property { variable, .. } => {
1021                vars.insert(variable.clone());
1022            }
1023            LogicalExpression::Binary { left, right, .. } => {
1024                Self::collect_variables(left, vars);
1025                Self::collect_variables(right, vars);
1026            }
1027            LogicalExpression::Unary { operand, .. } => {
1028                Self::collect_variables(operand, vars);
1029            }
1030            LogicalExpression::FunctionCall { args, .. } => {
1031                for arg in args {
1032                    Self::collect_variables(arg, vars);
1033                }
1034            }
1035            LogicalExpression::List(items) => {
1036                for item in items {
1037                    Self::collect_variables(item, vars);
1038                }
1039            }
1040            LogicalExpression::Map(pairs) => {
1041                for (_, value) in pairs {
1042                    Self::collect_variables(value, vars);
1043                }
1044            }
1045            LogicalExpression::IndexAccess { base, index } => {
1046                Self::collect_variables(base, vars);
1047                Self::collect_variables(index, vars);
1048            }
1049            LogicalExpression::SliceAccess { base, start, end } => {
1050                Self::collect_variables(base, vars);
1051                if let Some(s) = start {
1052                    Self::collect_variables(s, vars);
1053                }
1054                if let Some(e) = end {
1055                    Self::collect_variables(e, vars);
1056                }
1057            }
1058            LogicalExpression::Case {
1059                operand,
1060                when_clauses,
1061                else_clause,
1062            } => {
1063                if let Some(op) = operand {
1064                    Self::collect_variables(op, vars);
1065                }
1066                for (cond, result) in when_clauses {
1067                    Self::collect_variables(cond, vars);
1068                    Self::collect_variables(result, vars);
1069                }
1070                if let Some(else_expr) = else_clause {
1071                    Self::collect_variables(else_expr, vars);
1072                }
1073            }
1074            LogicalExpression::Labels(var)
1075            | LogicalExpression::Type(var)
1076            | LogicalExpression::Id(var) => {
1077                vars.insert(var.clone());
1078            }
1079            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
1080            LogicalExpression::ListComprehension {
1081                list_expr,
1082                filter_expr,
1083                map_expr,
1084                ..
1085            } => {
1086                Self::collect_variables(list_expr, vars);
1087                if let Some(filter) = filter_expr {
1088                    Self::collect_variables(filter, vars);
1089                }
1090                Self::collect_variables(map_expr, vars);
1091            }
1092            LogicalExpression::ListPredicate {
1093                list_expr,
1094                predicate,
1095                ..
1096            } => {
1097                Self::collect_variables(list_expr, vars);
1098                Self::collect_variables(predicate, vars);
1099            }
1100            LogicalExpression::ExistsSubquery(_)
1101            | LogicalExpression::CountSubquery(_)
1102            | LogicalExpression::ValueSubquery(_) => {
1103                // Subqueries have their own variable scope
1104            }
1105            LogicalExpression::PatternComprehension { projection, .. } => {
1106                Self::collect_variables(projection, vars);
1107            }
1108            LogicalExpression::MapProjection { base, entries } => {
1109                vars.insert(base.clone());
1110                for entry in entries {
1111                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
1112                        Self::collect_variables(expr, vars);
1113                    }
1114                }
1115            }
1116            LogicalExpression::Reduce {
1117                initial,
1118                list,
1119                expression,
1120                ..
1121            } => {
1122                Self::collect_variables(initial, vars);
1123                Self::collect_variables(list, vars);
1124                Self::collect_variables(expression, vars);
1125            }
1126        }
1127    }
1128
1129    /// Extracts aliases from projection expressions.
1130    fn extract_projection_aliases(
1131        &self,
1132        projections: &[crate::query::plan::Projection],
1133    ) -> HashSet<String> {
1134        projections.iter().filter_map(|p| p.alias.clone()).collect()
1135    }
1136}
1137
1138impl Default for Optimizer {
1139    fn default() -> Self {
1140        Self::new()
1141    }
1142}
1143
1144#[cfg(test)]
1145mod tests {
1146    use super::*;
1147    use crate::query::plan::{
1148        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1149        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1150        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1151    };
1152    use grafeo_common::types::Value;
1153
1154    #[test]
1155    fn test_optimizer_filter_pushdown_simple() {
1156        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1157        // Before: Return -> Filter -> NodeScan
1158        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1159
1160        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1161            items: vec![ReturnItem {
1162                expression: LogicalExpression::Variable("n".to_string()),
1163                alias: None,
1164            }],
1165            distinct: false,
1166            input: Box::new(LogicalOperator::Filter(FilterOp {
1167                predicate: LogicalExpression::Binary {
1168                    left: Box::new(LogicalExpression::Property {
1169                        variable: "n".to_string(),
1170                        property: "age".to_string(),
1171                    }),
1172                    op: BinaryOp::Gt,
1173                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1174                },
1175                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1176                    variable: "n".to_string(),
1177                    label: Some("Person".to_string()),
1178                    input: None,
1179                })),
1180                pushdown_hint: None,
1181            })),
1182        }));
1183
1184        let optimizer = Optimizer::new();
1185        let optimized = optimizer.optimize(plan).unwrap();
1186
1187        // The structure should remain similar (filter stays near scan)
1188        if let LogicalOperator::Return(ret) = &optimized.root
1189            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1190            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1191        {
1192            assert_eq!(scan.variable, "n");
1193            return;
1194        }
1195        panic!("Expected Return -> Filter -> NodeScan structure");
1196    }
1197
1198    #[test]
1199    fn test_optimizer_filter_pushdown_through_expand() {
1200        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1201        // The filter on 'a' should be pushed before the expand
1202
1203        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1204            items: vec![ReturnItem {
1205                expression: LogicalExpression::Variable("b".to_string()),
1206                alias: None,
1207            }],
1208            distinct: false,
1209            input: Box::new(LogicalOperator::Filter(FilterOp {
1210                predicate: LogicalExpression::Binary {
1211                    left: Box::new(LogicalExpression::Property {
1212                        variable: "a".to_string(),
1213                        property: "age".to_string(),
1214                    }),
1215                    op: BinaryOp::Gt,
1216                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1217                },
1218                pushdown_hint: None,
1219                input: Box::new(LogicalOperator::Expand(ExpandOp {
1220                    from_variable: "a".to_string(),
1221                    to_variable: "b".to_string(),
1222                    edge_variable: None,
1223                    direction: ExpandDirection::Outgoing,
1224                    edge_types: vec!["KNOWS".to_string()],
1225                    min_hops: 1,
1226                    max_hops: Some(1),
1227                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1228                        variable: "a".to_string(),
1229                        label: Some("Person".to_string()),
1230                        input: None,
1231                    })),
1232                    path_alias: None,
1233                    path_mode: PathMode::Walk,
1234                })),
1235            })),
1236        }));
1237
1238        let optimizer = Optimizer::new();
1239        let optimized = optimizer.optimize(plan).unwrap();
1240
1241        // Filter on 'a' should be pushed before the expand
1242        // Expected: Return -> Expand -> Filter -> NodeScan
1243        if let LogicalOperator::Return(ret) = &optimized.root
1244            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1245            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1246            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1247        {
1248            assert_eq!(scan.variable, "a");
1249            assert_eq!(expand.from_variable, "a");
1250            assert_eq!(expand.to_variable, "b");
1251            return;
1252        }
1253        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1254    }
1255
1256    #[test]
1257    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1258        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1259        // The filter on 'b' should NOT be pushed before the expand
1260
1261        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1262            items: vec![ReturnItem {
1263                expression: LogicalExpression::Variable("a".to_string()),
1264                alias: None,
1265            }],
1266            distinct: false,
1267            input: Box::new(LogicalOperator::Filter(FilterOp {
1268                predicate: LogicalExpression::Binary {
1269                    left: Box::new(LogicalExpression::Property {
1270                        variable: "b".to_string(),
1271                        property: "age".to_string(),
1272                    }),
1273                    op: BinaryOp::Gt,
1274                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1275                },
1276                pushdown_hint: None,
1277                input: Box::new(LogicalOperator::Expand(ExpandOp {
1278                    from_variable: "a".to_string(),
1279                    to_variable: "b".to_string(),
1280                    edge_variable: None,
1281                    direction: ExpandDirection::Outgoing,
1282                    edge_types: vec!["KNOWS".to_string()],
1283                    min_hops: 1,
1284                    max_hops: Some(1),
1285                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1286                        variable: "a".to_string(),
1287                        label: Some("Person".to_string()),
1288                        input: None,
1289                    })),
1290                    path_alias: None,
1291                    path_mode: PathMode::Walk,
1292                })),
1293            })),
1294        }));
1295
1296        let optimizer = Optimizer::new();
1297        let optimized = optimizer.optimize(plan).unwrap();
1298
1299        // Filter on 'b' should stay after the expand
1300        // Expected: Return -> Filter -> Expand -> NodeScan
1301        if let LogicalOperator::Return(ret) = &optimized.root
1302            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1303        {
1304            // Check that the filter is on 'b'
1305            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1306                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1307            {
1308                assert_eq!(variable, "b");
1309            }
1310
1311            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1312                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1313            {
1314                return;
1315            }
1316        }
1317        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1318    }
1319
1320    #[test]
1321    fn test_optimizer_extract_variables() {
1322        let optimizer = Optimizer::new();
1323
1324        let expr = LogicalExpression::Binary {
1325            left: Box::new(LogicalExpression::Property {
1326                variable: "n".to_string(),
1327                property: "age".to_string(),
1328            }),
1329            op: BinaryOp::Gt,
1330            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1331        };
1332
1333        let vars = optimizer.extract_variables(&expr);
1334        assert_eq!(vars.len(), 1);
1335        assert!(vars.contains("n"));
1336    }
1337
1338    // Additional tests for optimizer configuration
1339
1340    #[test]
1341    fn test_optimizer_default() {
1342        let optimizer = Optimizer::default();
1343        // Should be able to optimize an empty plan
1344        let plan = LogicalPlan::new(LogicalOperator::Empty);
1345        let result = optimizer.optimize(plan);
1346        assert!(result.is_ok());
1347    }
1348
1349    #[test]
1350    fn test_optimizer_with_filter_pushdown_disabled() {
1351        let optimizer = Optimizer::new().with_filter_pushdown(false);
1352
1353        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1354            items: vec![ReturnItem {
1355                expression: LogicalExpression::Variable("n".to_string()),
1356                alias: None,
1357            }],
1358            distinct: false,
1359            input: Box::new(LogicalOperator::Filter(FilterOp {
1360                predicate: LogicalExpression::Literal(Value::Bool(true)),
1361                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1362                    variable: "n".to_string(),
1363                    label: None,
1364                    input: None,
1365                })),
1366                pushdown_hint: None,
1367            })),
1368        }));
1369
1370        let optimized = optimizer.optimize(plan).unwrap();
1371        // Structure should be unchanged
1372        if let LogicalOperator::Return(ret) = &optimized.root
1373            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1374        {
1375            return;
1376        }
1377        panic!("Expected unchanged structure");
1378    }
1379
1380    #[test]
1381    fn test_optimizer_with_join_reorder_disabled() {
1382        let optimizer = Optimizer::new().with_join_reorder(false);
1383        assert!(
1384            optimizer
1385                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1386                .is_ok()
1387        );
1388    }
1389
1390    #[test]
1391    fn test_optimizer_with_cost_model() {
1392        let cost_model = CostModel::new();
1393        let optimizer = Optimizer::new().with_cost_model(cost_model);
1394        assert!(
1395            optimizer
1396                .cost_model()
1397                .estimate(&LogicalOperator::Empty, 0.0)
1398                .total()
1399                < 0.001
1400        );
1401    }
1402
1403    #[test]
1404    fn test_optimizer_with_cardinality_estimator() {
1405        let mut estimator = CardinalityEstimator::new();
1406        estimator.add_table_stats("Test", TableStats::new(500));
1407        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1408
1409        let scan = LogicalOperator::NodeScan(NodeScanOp {
1410            variable: "n".to_string(),
1411            label: Some("Test".to_string()),
1412            input: None,
1413        });
1414        let plan = LogicalPlan::new(scan);
1415
1416        let cardinality = optimizer.estimate_cardinality(&plan);
1417        assert!((cardinality - 500.0).abs() < 0.001);
1418    }
1419
1420    #[test]
1421    fn test_optimizer_estimate_cost() {
1422        let optimizer = Optimizer::new();
1423        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1424            variable: "n".to_string(),
1425            label: None,
1426            input: None,
1427        }));
1428
1429        let cost = optimizer.estimate_cost(&plan);
1430        assert!(cost.total() > 0.0);
1431    }
1432
1433    // Filter pushdown through various operators
1434
1435    #[test]
1436    fn test_filter_pushdown_through_project() {
1437        let optimizer = Optimizer::new();
1438
1439        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1440            predicate: LogicalExpression::Binary {
1441                left: Box::new(LogicalExpression::Property {
1442                    variable: "n".to_string(),
1443                    property: "age".to_string(),
1444                }),
1445                op: BinaryOp::Gt,
1446                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1447            },
1448            pushdown_hint: None,
1449            input: Box::new(LogicalOperator::Project(ProjectOp {
1450                projections: vec![Projection {
1451                    expression: LogicalExpression::Variable("n".to_string()),
1452                    alias: None,
1453                }],
1454                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1455                    variable: "n".to_string(),
1456                    label: None,
1457                    input: None,
1458                })),
1459                pass_through_input: false,
1460            })),
1461        }));
1462
1463        let optimized = optimizer.optimize(plan).unwrap();
1464
1465        // Filter should be pushed through Project
1466        if let LogicalOperator::Project(proj) = &optimized.root
1467            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1468        {
1469            return;
1470        }
1471        panic!("Expected Project -> Filter structure");
1472    }
1473
1474    #[test]
1475    fn test_filter_not_pushed_through_project_with_alias() {
1476        let optimizer = Optimizer::new();
1477
1478        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1479        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1480            predicate: LogicalExpression::Binary {
1481                left: Box::new(LogicalExpression::Variable("x".to_string())),
1482                op: BinaryOp::Gt,
1483                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1484            },
1485            pushdown_hint: None,
1486            input: Box::new(LogicalOperator::Project(ProjectOp {
1487                projections: vec![Projection {
1488                    expression: LogicalExpression::Property {
1489                        variable: "n".to_string(),
1490                        property: "age".to_string(),
1491                    },
1492                    alias: Some("x".to_string()),
1493                }],
1494                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1495                    variable: "n".to_string(),
1496                    label: None,
1497                    input: None,
1498                })),
1499                pass_through_input: false,
1500            })),
1501        }));
1502
1503        let optimized = optimizer.optimize(plan).unwrap();
1504
1505        // Filter should stay above Project
1506        if let LogicalOperator::Filter(filter) = &optimized.root
1507            && let LogicalOperator::Project(_) = filter.input.as_ref()
1508        {
1509            return;
1510        }
1511        panic!("Expected Filter -> Project structure");
1512    }
1513
1514    #[test]
1515    fn test_filter_pushdown_through_limit() {
1516        let optimizer = Optimizer::new();
1517
1518        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1519            predicate: LogicalExpression::Literal(Value::Bool(true)),
1520            pushdown_hint: None,
1521            input: Box::new(LogicalOperator::Limit(LimitOp {
1522                count: 10.into(),
1523                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1524                    variable: "n".to_string(),
1525                    label: None,
1526                    input: None,
1527                })),
1528            })),
1529        }));
1530
1531        let optimized = optimizer.optimize(plan).unwrap();
1532
1533        // Filter stays above Limit (cannot be pushed through)
1534        if let LogicalOperator::Filter(filter) = &optimized.root
1535            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1536        {
1537            return;
1538        }
1539        panic!("Expected Filter -> Limit structure");
1540    }
1541
1542    #[test]
1543    fn test_filter_pushdown_through_sort() {
1544        let optimizer = Optimizer::new();
1545
1546        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1547            predicate: LogicalExpression::Literal(Value::Bool(true)),
1548            pushdown_hint: None,
1549            input: Box::new(LogicalOperator::Sort(SortOp {
1550                keys: vec![SortKey {
1551                    expression: LogicalExpression::Variable("n".to_string()),
1552                    order: SortOrder::Ascending,
1553                    nulls: None,
1554                }],
1555                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1556                    variable: "n".to_string(),
1557                    label: None,
1558                    input: None,
1559                })),
1560            })),
1561        }));
1562
1563        let optimized = optimizer.optimize(plan).unwrap();
1564
1565        // Filter stays above Sort
1566        if let LogicalOperator::Filter(filter) = &optimized.root
1567            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1568        {
1569            return;
1570        }
1571        panic!("Expected Filter -> Sort structure");
1572    }
1573
1574    #[test]
1575    fn test_filter_pushdown_through_distinct() {
1576        let optimizer = Optimizer::new();
1577
1578        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1579            predicate: LogicalExpression::Literal(Value::Bool(true)),
1580            pushdown_hint: None,
1581            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1582                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1583                    variable: "n".to_string(),
1584                    label: None,
1585                    input: None,
1586                })),
1587                columns: None,
1588            })),
1589        }));
1590
1591        let optimized = optimizer.optimize(plan).unwrap();
1592
1593        // Filter stays above Distinct
1594        if let LogicalOperator::Filter(filter) = &optimized.root
1595            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1596        {
1597            return;
1598        }
1599        panic!("Expected Filter -> Distinct structure");
1600    }
1601
1602    #[test]
1603    fn test_filter_not_pushed_through_aggregate() {
1604        let optimizer = Optimizer::new();
1605
1606        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1607            predicate: LogicalExpression::Binary {
1608                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1609                op: BinaryOp::Gt,
1610                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1611            },
1612            pushdown_hint: None,
1613            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1614                group_by: vec![],
1615                aggregates: vec![AggregateExpr {
1616                    function: AggregateFunction::Count,
1617                    expression: None,
1618                    expression2: None,
1619                    distinct: false,
1620                    alias: Some("cnt".to_string()),
1621                    percentile: None,
1622                    separator: None,
1623                }],
1624                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1625                    variable: "n".to_string(),
1626                    label: None,
1627                    input: None,
1628                })),
1629                having: None,
1630            })),
1631        }));
1632
1633        let optimized = optimizer.optimize(plan).unwrap();
1634
1635        // Filter should stay above Aggregate
1636        if let LogicalOperator::Filter(filter) = &optimized.root
1637            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1638        {
1639            return;
1640        }
1641        panic!("Expected Filter -> Aggregate structure");
1642    }
1643
1644    #[test]
1645    fn test_filter_pushdown_to_left_join_side() {
1646        let optimizer = Optimizer::new();
1647
1648        // Filter on left variable should be pushed to left side
1649        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1650            predicate: LogicalExpression::Binary {
1651                left: Box::new(LogicalExpression::Property {
1652                    variable: "a".to_string(),
1653                    property: "age".to_string(),
1654                }),
1655                op: BinaryOp::Gt,
1656                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1657            },
1658            pushdown_hint: None,
1659            input: Box::new(LogicalOperator::Join(JoinOp {
1660                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1661                    variable: "a".to_string(),
1662                    label: Some("Person".to_string()),
1663                    input: None,
1664                })),
1665                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1666                    variable: "b".to_string(),
1667                    label: Some("Company".to_string()),
1668                    input: None,
1669                })),
1670                join_type: JoinType::Inner,
1671                conditions: vec![],
1672            })),
1673        }));
1674
1675        let optimized = optimizer.optimize(plan).unwrap();
1676
1677        // Filter should be pushed to left side of join
1678        if let LogicalOperator::Join(join) = &optimized.root
1679            && let LogicalOperator::Filter(_) = join.left.as_ref()
1680        {
1681            return;
1682        }
1683        panic!("Expected Join with Filter on left side");
1684    }
1685
1686    #[test]
1687    fn test_filter_pushdown_to_right_join_side() {
1688        let optimizer = Optimizer::new();
1689
1690        // Filter on right variable should be pushed to right side
1691        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1692            predicate: LogicalExpression::Binary {
1693                left: Box::new(LogicalExpression::Property {
1694                    variable: "b".to_string(),
1695                    property: "name".to_string(),
1696                }),
1697                op: BinaryOp::Eq,
1698                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1699            },
1700            pushdown_hint: None,
1701            input: Box::new(LogicalOperator::Join(JoinOp {
1702                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1703                    variable: "a".to_string(),
1704                    label: Some("Person".to_string()),
1705                    input: None,
1706                })),
1707                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1708                    variable: "b".to_string(),
1709                    label: Some("Company".to_string()),
1710                    input: None,
1711                })),
1712                join_type: JoinType::Inner,
1713                conditions: vec![],
1714            })),
1715        }));
1716
1717        let optimized = optimizer.optimize(plan).unwrap();
1718
1719        // Filter should be pushed to right side of join
1720        if let LogicalOperator::Join(join) = &optimized.root
1721            && let LogicalOperator::Filter(_) = join.right.as_ref()
1722        {
1723            return;
1724        }
1725        panic!("Expected Join with Filter on right side");
1726    }
1727
1728    #[test]
1729    fn test_filter_not_pushed_when_uses_both_join_sides() {
1730        let optimizer = Optimizer::new();
1731
1732        // Filter using both variables should stay above join
1733        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1734            predicate: LogicalExpression::Binary {
1735                left: Box::new(LogicalExpression::Property {
1736                    variable: "a".to_string(),
1737                    property: "id".to_string(),
1738                }),
1739                op: BinaryOp::Eq,
1740                right: Box::new(LogicalExpression::Property {
1741                    variable: "b".to_string(),
1742                    property: "a_id".to_string(),
1743                }),
1744            },
1745            pushdown_hint: None,
1746            input: Box::new(LogicalOperator::Join(JoinOp {
1747                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1748                    variable: "a".to_string(),
1749                    label: None,
1750                    input: None,
1751                })),
1752                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1753                    variable: "b".to_string(),
1754                    label: None,
1755                    input: None,
1756                })),
1757                join_type: JoinType::Inner,
1758                conditions: vec![],
1759            })),
1760        }));
1761
1762        let optimized = optimizer.optimize(plan).unwrap();
1763
1764        // Filter should stay above join
1765        if let LogicalOperator::Filter(filter) = &optimized.root
1766            && let LogicalOperator::Join(_) = filter.input.as_ref()
1767        {
1768            return;
1769        }
1770        panic!("Expected Filter -> Join structure");
1771    }
1772
1773    // Variable extraction tests
1774
1775    #[test]
1776    fn test_extract_variables_from_variable() {
1777        let optimizer = Optimizer::new();
1778        let expr = LogicalExpression::Variable("x".to_string());
1779        let vars = optimizer.extract_variables(&expr);
1780        assert_eq!(vars.len(), 1);
1781        assert!(vars.contains("x"));
1782    }
1783
1784    #[test]
1785    fn test_extract_variables_from_unary() {
1786        let optimizer = Optimizer::new();
1787        let expr = LogicalExpression::Unary {
1788            op: UnaryOp::Not,
1789            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1790        };
1791        let vars = optimizer.extract_variables(&expr);
1792        assert_eq!(vars.len(), 1);
1793        assert!(vars.contains("x"));
1794    }
1795
1796    #[test]
1797    fn test_extract_variables_from_function_call() {
1798        let optimizer = Optimizer::new();
1799        let expr = LogicalExpression::FunctionCall {
1800            name: "length".to_string(),
1801            args: vec![
1802                LogicalExpression::Variable("a".to_string()),
1803                LogicalExpression::Variable("b".to_string()),
1804            ],
1805            distinct: false,
1806        };
1807        let vars = optimizer.extract_variables(&expr);
1808        assert_eq!(vars.len(), 2);
1809        assert!(vars.contains("a"));
1810        assert!(vars.contains("b"));
1811    }
1812
1813    #[test]
1814    fn test_extract_variables_from_list() {
1815        let optimizer = Optimizer::new();
1816        let expr = LogicalExpression::List(vec![
1817            LogicalExpression::Variable("a".to_string()),
1818            LogicalExpression::Literal(Value::Int64(1)),
1819            LogicalExpression::Variable("b".to_string()),
1820        ]);
1821        let vars = optimizer.extract_variables(&expr);
1822        assert_eq!(vars.len(), 2);
1823        assert!(vars.contains("a"));
1824        assert!(vars.contains("b"));
1825    }
1826
1827    #[test]
1828    fn test_extract_variables_from_map() {
1829        let optimizer = Optimizer::new();
1830        let expr = LogicalExpression::Map(vec![
1831            (
1832                "key1".to_string(),
1833                LogicalExpression::Variable("a".to_string()),
1834            ),
1835            (
1836                "key2".to_string(),
1837                LogicalExpression::Variable("b".to_string()),
1838            ),
1839        ]);
1840        let vars = optimizer.extract_variables(&expr);
1841        assert_eq!(vars.len(), 2);
1842        assert!(vars.contains("a"));
1843        assert!(vars.contains("b"));
1844    }
1845
1846    #[test]
1847    fn test_extract_variables_from_index_access() {
1848        let optimizer = Optimizer::new();
1849        let expr = LogicalExpression::IndexAccess {
1850            base: Box::new(LogicalExpression::Variable("list".to_string())),
1851            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1852        };
1853        let vars = optimizer.extract_variables(&expr);
1854        assert_eq!(vars.len(), 2);
1855        assert!(vars.contains("list"));
1856        assert!(vars.contains("idx"));
1857    }
1858
1859    #[test]
1860    fn test_extract_variables_from_slice_access() {
1861        let optimizer = Optimizer::new();
1862        let expr = LogicalExpression::SliceAccess {
1863            base: Box::new(LogicalExpression::Variable("list".to_string())),
1864            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1865            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1866        };
1867        let vars = optimizer.extract_variables(&expr);
1868        assert_eq!(vars.len(), 3);
1869        assert!(vars.contains("list"));
1870        assert!(vars.contains("s"));
1871        assert!(vars.contains("e"));
1872    }
1873
1874    #[test]
1875    fn test_extract_variables_from_case() {
1876        let optimizer = Optimizer::new();
1877        let expr = LogicalExpression::Case {
1878            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1879            when_clauses: vec![(
1880                LogicalExpression::Literal(Value::Int64(1)),
1881                LogicalExpression::Variable("a".to_string()),
1882            )],
1883            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1884        };
1885        let vars = optimizer.extract_variables(&expr);
1886        assert_eq!(vars.len(), 3);
1887        assert!(vars.contains("x"));
1888        assert!(vars.contains("a"));
1889        assert!(vars.contains("b"));
1890    }
1891
1892    #[test]
1893    fn test_extract_variables_from_labels() {
1894        let optimizer = Optimizer::new();
1895        let expr = LogicalExpression::Labels("n".to_string());
1896        let vars = optimizer.extract_variables(&expr);
1897        assert_eq!(vars.len(), 1);
1898        assert!(vars.contains("n"));
1899    }
1900
1901    #[test]
1902    fn test_extract_variables_from_type() {
1903        let optimizer = Optimizer::new();
1904        let expr = LogicalExpression::Type("e".to_string());
1905        let vars = optimizer.extract_variables(&expr);
1906        assert_eq!(vars.len(), 1);
1907        assert!(vars.contains("e"));
1908    }
1909
1910    #[test]
1911    fn test_extract_variables_from_id() {
1912        let optimizer = Optimizer::new();
1913        let expr = LogicalExpression::Id("n".to_string());
1914        let vars = optimizer.extract_variables(&expr);
1915        assert_eq!(vars.len(), 1);
1916        assert!(vars.contains("n"));
1917    }
1918
1919    #[test]
1920    fn test_extract_variables_from_list_comprehension() {
1921        let optimizer = Optimizer::new();
1922        let expr = LogicalExpression::ListComprehension {
1923            variable: "x".to_string(),
1924            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1925            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1926            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1927        };
1928        let vars = optimizer.extract_variables(&expr);
1929        assert!(vars.contains("items"));
1930        assert!(vars.contains("pred"));
1931        assert!(vars.contains("result"));
1932    }
1933
1934    #[test]
1935    fn test_extract_variables_from_literal_and_parameter() {
1936        let optimizer = Optimizer::new();
1937
1938        let literal = LogicalExpression::Literal(Value::Int64(42));
1939        assert!(optimizer.extract_variables(&literal).is_empty());
1940
1941        let param = LogicalExpression::Parameter("p".to_string());
1942        assert!(optimizer.extract_variables(&param).is_empty());
1943    }
1944
1945    // Recursive filter pushdown tests
1946
1947    #[test]
1948    fn test_recursive_filter_pushdown_through_skip() {
1949        let optimizer = Optimizer::new();
1950
1951        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1952            items: vec![ReturnItem {
1953                expression: LogicalExpression::Variable("n".to_string()),
1954                alias: None,
1955            }],
1956            distinct: false,
1957            input: Box::new(LogicalOperator::Filter(FilterOp {
1958                predicate: LogicalExpression::Literal(Value::Bool(true)),
1959                pushdown_hint: None,
1960                input: Box::new(LogicalOperator::Skip(SkipOp {
1961                    count: 5.into(),
1962                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1963                        variable: "n".to_string(),
1964                        label: None,
1965                        input: None,
1966                    })),
1967                })),
1968            })),
1969        }));
1970
1971        let optimized = optimizer.optimize(plan).unwrap();
1972
1973        // Verify optimization succeeded
1974        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
1975    }
1976
1977    #[test]
1978    fn test_nested_filter_pushdown() {
1979        let optimizer = Optimizer::new();
1980
1981        // Multiple nested filters
1982        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1983            items: vec![ReturnItem {
1984                expression: LogicalExpression::Variable("n".to_string()),
1985                alias: None,
1986            }],
1987            distinct: false,
1988            input: Box::new(LogicalOperator::Filter(FilterOp {
1989                predicate: LogicalExpression::Binary {
1990                    left: Box::new(LogicalExpression::Property {
1991                        variable: "n".to_string(),
1992                        property: "x".to_string(),
1993                    }),
1994                    op: BinaryOp::Gt,
1995                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
1996                },
1997                pushdown_hint: None,
1998                input: Box::new(LogicalOperator::Filter(FilterOp {
1999                    predicate: LogicalExpression::Binary {
2000                        left: Box::new(LogicalExpression::Property {
2001                            variable: "n".to_string(),
2002                            property: "y".to_string(),
2003                        }),
2004                        op: BinaryOp::Lt,
2005                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2006                    },
2007                    pushdown_hint: None,
2008                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2009                        variable: "n".to_string(),
2010                        label: None,
2011                        input: None,
2012                    })),
2013                })),
2014            })),
2015        }));
2016
2017        let optimized = optimizer.optimize(plan).unwrap();
2018        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2019    }
2020
2021    #[test]
2022    fn test_cyclic_join_produces_multi_way_join() {
2023        use crate::query::plan::JoinCondition;
2024
2025        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2026        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2027            variable: "a".to_string(),
2028            label: Some("Person".to_string()),
2029            input: None,
2030        });
2031        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2032            variable: "b".to_string(),
2033            label: Some("Person".to_string()),
2034            input: None,
2035        });
2036        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2037            variable: "c".to_string(),
2038            label: Some("Person".to_string()),
2039            input: None,
2040        });
2041
2042        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2043        let join_ab = LogicalOperator::Join(JoinOp {
2044            left: Box::new(scan_a),
2045            right: Box::new(scan_b),
2046            join_type: JoinType::Inner,
2047            conditions: vec![JoinCondition {
2048                left: LogicalExpression::Variable("a".to_string()),
2049                right: LogicalExpression::Variable("b".to_string()),
2050            }],
2051        });
2052
2053        let join_abc = LogicalOperator::Join(JoinOp {
2054            left: Box::new(join_ab),
2055            right: Box::new(scan_c),
2056            join_type: JoinType::Inner,
2057            conditions: vec![
2058                JoinCondition {
2059                    left: LogicalExpression::Variable("b".to_string()),
2060                    right: LogicalExpression::Variable("c".to_string()),
2061                },
2062                JoinCondition {
2063                    left: LogicalExpression::Variable("c".to_string()),
2064                    right: LogicalExpression::Variable("a".to_string()),
2065                },
2066            ],
2067        });
2068
2069        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2070            items: vec![ReturnItem {
2071                expression: LogicalExpression::Variable("a".to_string()),
2072                alias: None,
2073            }],
2074            distinct: false,
2075            input: Box::new(join_abc),
2076        }));
2077
2078        let mut optimizer = Optimizer::new();
2079        optimizer
2080            .card_estimator
2081            .add_table_stats("Person", cardinality::TableStats::new(1000));
2082
2083        let optimized = optimizer.optimize(plan).unwrap();
2084
2085        // Walk the tree to find a MultiWayJoin
2086        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2087            match op {
2088                LogicalOperator::MultiWayJoin(_) => true,
2089                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2090                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2091                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2092                _ => false,
2093            }
2094        }
2095
2096        assert!(
2097            has_multi_way_join(&optimized.root),
2098            "Expected MultiWayJoin for cyclic triangle pattern"
2099        );
2100    }
2101
2102    #[test]
2103    fn test_acyclic_join_uses_binary_joins() {
2104        use crate::query::plan::JoinCondition;
2105
2106        // Chain: a ⋈ b ⋈ c (acyclic)
2107        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2108            variable: "a".to_string(),
2109            label: Some("Person".to_string()),
2110            input: None,
2111        });
2112        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2113            variable: "b".to_string(),
2114            label: Some("Person".to_string()),
2115            input: None,
2116        });
2117        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2118            variable: "c".to_string(),
2119            label: Some("Company".to_string()),
2120            input: None,
2121        });
2122
2123        let join_ab = LogicalOperator::Join(JoinOp {
2124            left: Box::new(scan_a),
2125            right: Box::new(scan_b),
2126            join_type: JoinType::Inner,
2127            conditions: vec![JoinCondition {
2128                left: LogicalExpression::Variable("a".to_string()),
2129                right: LogicalExpression::Variable("b".to_string()),
2130            }],
2131        });
2132
2133        let join_abc = LogicalOperator::Join(JoinOp {
2134            left: Box::new(join_ab),
2135            right: Box::new(scan_c),
2136            join_type: JoinType::Inner,
2137            conditions: vec![JoinCondition {
2138                left: LogicalExpression::Variable("b".to_string()),
2139                right: LogicalExpression::Variable("c".to_string()),
2140            }],
2141        });
2142
2143        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2144            items: vec![ReturnItem {
2145                expression: LogicalExpression::Variable("a".to_string()),
2146                alias: None,
2147            }],
2148            distinct: false,
2149            input: Box::new(join_abc),
2150        }));
2151
2152        let mut optimizer = Optimizer::new();
2153        optimizer
2154            .card_estimator
2155            .add_table_stats("Person", cardinality::TableStats::new(1000));
2156        optimizer
2157            .card_estimator
2158            .add_table_stats("Company", cardinality::TableStats::new(100));
2159
2160        let optimized = optimizer.optimize(plan).unwrap();
2161
2162        // Should NOT contain MultiWayJoin for acyclic pattern
2163        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2164            match op {
2165                LogicalOperator::MultiWayJoin(_) => true,
2166                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2167                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2168                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2169                LogicalOperator::Join(j) => {
2170                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2171                }
2172                _ => false,
2173            }
2174        }
2175
2176        assert!(
2177            !has_multi_way_join(&optimized.root),
2178            "Acyclic join should NOT produce MultiWayJoin"
2179        );
2180    }
2181}