// Source: grafeo_engine/query/optimizer/mod.rs

1//! Makes your queries faster without changing their meaning.
2//!
3//! The optimizer transforms logical plans to run more efficiently:
4//!
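//! | Optimization | What it does |
//! | ------------ | ------------ |
//! | Filter Pushdown | Moves `WHERE` clauses closer to scans: filter early, process less |
//! | Join Reordering | Picks the best join order with the DPccp algorithm; cyclic patterns of 3+ relations (e.g. triangles) use a worst-case optimal multi-way join instead |
//! | Projection Pushdown | Works out which variables and properties the query actually needs |
//! | Predicate Simplification | Folds constants like `1 + 1` into `2` |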
10//!
11//! The optimizer uses [`CostModel`] and [`CardinalityEstimator`] to predict
12//! how expensive different plans are, then picks the cheapest.
13
14pub mod cardinality;
15pub mod cost;
16pub mod join_order;
17
18pub use cardinality::{
19    CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
20};
21pub use cost::{Cost, CostModel};
22pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
23
24use crate::query::plan::{
25    FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
26};
27use grafeo_common::grafeo_debug_span;
28use grafeo_common::utils::error::Result;
29use std::collections::HashSet;
30
31/// Information about a join condition for join reordering.
32#[derive(Debug, Clone)]
33struct JoinInfo {
34    left_var: String,
35    right_var: String,
36    left_expr: LogicalExpression,
37    right_expr: LogicalExpression,
38}
39
/// Something the query still needs while projection pushdown propagates
/// requirements from the top of the plan toward the scans.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole binding: a node, edge, or path variable.
    Variable(String),
    /// A single property of a binding, as `(variable, property)`.
    Property(String, String),
}
48
49/// Transforms logical plans for faster execution.
50///
51/// Create with [`new()`](Self::new), then call [`optimize()`](Self::optimize).
52/// Use the builder methods to enable/disable specific optimizations.
53pub struct Optimizer {
54    /// Whether to enable filter pushdown.
55    enable_filter_pushdown: bool,
56    /// Whether to enable join reordering.
57    enable_join_reorder: bool,
58    /// Whether to enable projection pushdown.
59    enable_projection_pushdown: bool,
60    /// Cost model for estimation.
61    cost_model: CostModel,
62    /// Cardinality estimator.
63    card_estimator: CardinalityEstimator,
64}
65
66impl Optimizer {
67    /// Creates a new optimizer with default settings.
68    #[must_use]
69    pub fn new() -> Self {
70        Self {
71            enable_filter_pushdown: true,
72            enable_join_reorder: true,
73            enable_projection_pushdown: true,
74            cost_model: CostModel::new(),
75            card_estimator: CardinalityEstimator::new(),
76        }
77    }
78
79    /// Creates an optimizer with cardinality estimates from the store's statistics.
80    ///
81    /// Pre-populates the cardinality estimator with per-label row counts and
82    /// edge type fanout. Feeds per-edge-type degree stats, label cardinalities,
83    /// and graph totals into the cost model for accurate estimation.
84    #[cfg(feature = "lpg")]
85    #[must_use]
86    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
87        store.ensure_statistics_fresh();
88        let stats = store.statistics();
89        Self::from_statistics(&stats)
90    }
91
92    /// Creates an optimizer from any GraphStore implementation.
93    ///
94    /// Unlike [`from_store`](Self::from_store), this does not call
95    /// `ensure_statistics_fresh()` since external stores manage their own
96    /// statistics. The store's [`statistics()`](grafeo_core::graph::GraphStore::statistics) method
97    /// is called directly.
98    #[must_use]
99    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
100        let stats = store.statistics();
101        Self::from_statistics(&stats)
102    }
103
104    /// Creates an optimizer from RDF statistics.
105    ///
106    /// Uses triple pattern cardinality estimates for cost-based optimization
107    /// of SPARQL queries. Maps total triples to graph totals for the cost model.
108    #[cfg(feature = "triple-store")]
109    #[must_use]
110    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
111        let total = rdf_stats.total_triples;
112        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
113        Self {
114            enable_filter_pushdown: true,
115            enable_join_reorder: true,
116            enable_projection_pushdown: true,
117            cost_model: CostModel::new().with_graph_totals(total, total),
118            card_estimator: estimator,
119        }
120    }
121
122    /// Creates an optimizer from a Statistics snapshot.
123    ///
124    /// Extracts label cardinalities, edge type degrees, and graph totals
125    /// for both the cardinality estimator and cost model.
126    #[must_use]
127    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
128        let estimator = CardinalityEstimator::from_statistics(stats);
129
130        let avg_fanout = if stats.total_nodes > 0 {
131            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
132        } else {
133            10.0
134        };
135
136        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
137            .edge_types
138            .iter()
139            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
140            .collect();
141
142        let label_cardinalities: std::collections::HashMap<String, u64> = stats
143            .labels
144            .iter()
145            .map(|(name, ls)| (name.clone(), ls.node_count))
146            .collect();
147
148        Self {
149            enable_filter_pushdown: true,
150            enable_join_reorder: true,
151            enable_projection_pushdown: true,
152            cost_model: CostModel::new()
153                .with_avg_fanout(avg_fanout)
154                .with_edge_type_degrees(edge_type_degrees)
155                .with_label_cardinalities(label_cardinalities)
156                .with_graph_totals(stats.total_nodes, stats.total_edges),
157            card_estimator: estimator,
158        }
159    }
160
161    /// Enables or disables filter pushdown.
162    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
163        self.enable_filter_pushdown = enabled;
164        self
165    }
166
167    /// Enables or disables join reordering.
168    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
169        self.enable_join_reorder = enabled;
170        self
171    }
172
173    /// Enables or disables projection pushdown.
174    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
175        self.enable_projection_pushdown = enabled;
176        self
177    }
178
179    /// Sets the cost model.
180    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
181        self.cost_model = cost_model;
182        self
183    }
184
185    /// Sets the cardinality estimator.
186    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
187        self.card_estimator = estimator;
188        self
189    }
190
191    /// Sets the selectivity configuration for the cardinality estimator.
192    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
193        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
194        self
195    }
196
197    /// Returns a reference to the cost model.
198    pub fn cost_model(&self) -> &CostModel {
199        &self.cost_model
200    }
201
202    /// Returns a reference to the cardinality estimator.
203    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
204        &self.card_estimator
205    }
206
207    /// Estimates the total cost of a plan by recursively costing the entire tree.
208    ///
209    /// Walks every operator in the plan, computing cardinality at each level
210    /// and summing the per-operator costs. Uses actual child cardinalities
211    /// for join cost estimation rather than approximations.
212    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
213        self.cost_model
214            .estimate_tree(&plan.root, &self.card_estimator)
215    }
216
217    /// Estimates the cardinality of a plan.
218    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
219        self.card_estimator.estimate(&plan.root)
220    }
221
222    /// Optimizes a logical plan.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if optimization fails.
227    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
228        let _span = grafeo_debug_span!("grafeo::query::optimize");
229        let mut root = plan.root;
230
231        // Apply optimization rules
232        if self.enable_filter_pushdown {
233            root = self.push_filters_down(root);
234        }
235
236        if self.enable_join_reorder {
237            root = self.reorder_joins(root);
238        }
239
240        if self.enable_projection_pushdown {
241            root = self.push_projections_down(root);
242        }
243
244        Ok(LogicalPlan {
245            root,
246            explain: plan.explain,
247            profile: plan.profile,
248            default_params: plan.default_params,
249        })
250    }
251
252    /// Pushes projections down the operator tree to eliminate unused columns early.
253    ///
254    /// This optimization:
255    /// 1. Collects required variables/properties from the root
256    /// 2. Propagates requirements down through the tree
257    /// 3. Inserts projections to eliminate unneeded columns before expensive operations
258    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
259        // Collect required columns from the top of the plan
260        let required = self.collect_required_columns(&op);
261
262        // Push projections down
263        self.push_projections_recursive(op, &required)
264    }
265
266    /// Collects all variables and properties required by an operator and its ancestors.
267    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
268        let mut required = HashSet::new();
269        Self::collect_required_recursive(op, &mut required);
270        required
271    }
272
273    /// Recursively collects required columns.
274    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
275        match op {
276            LogicalOperator::Return(ret) => {
277                for item in &ret.items {
278                    Self::collect_from_expression(&item.expression, required);
279                }
280                Self::collect_required_recursive(&ret.input, required);
281            }
282            LogicalOperator::Project(proj) => {
283                for p in &proj.projections {
284                    Self::collect_from_expression(&p.expression, required);
285                }
286                Self::collect_required_recursive(&proj.input, required);
287            }
288            LogicalOperator::Filter(filter) => {
289                Self::collect_from_expression(&filter.predicate, required);
290                Self::collect_required_recursive(&filter.input, required);
291            }
292            LogicalOperator::Sort(sort) => {
293                for key in &sort.keys {
294                    Self::collect_from_expression(&key.expression, required);
295                }
296                Self::collect_required_recursive(&sort.input, required);
297            }
298            LogicalOperator::Aggregate(agg) => {
299                for expr in &agg.group_by {
300                    Self::collect_from_expression(expr, required);
301                }
302                for agg_expr in &agg.aggregates {
303                    if let Some(ref expr) = agg_expr.expression {
304                        Self::collect_from_expression(expr, required);
305                    }
306                }
307                if let Some(ref having) = agg.having {
308                    Self::collect_from_expression(having, required);
309                }
310                Self::collect_required_recursive(&agg.input, required);
311            }
312            LogicalOperator::Join(join) => {
313                for cond in &join.conditions {
314                    Self::collect_from_expression(&cond.left, required);
315                    Self::collect_from_expression(&cond.right, required);
316                }
317                Self::collect_required_recursive(&join.left, required);
318                Self::collect_required_recursive(&join.right, required);
319            }
320            LogicalOperator::Expand(expand) => {
321                // The source and target variables are needed
322                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
323                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
324                if let Some(ref edge_var) = expand.edge_variable {
325                    required.insert(RequiredColumn::Variable(edge_var.clone()));
326                }
327                Self::collect_required_recursive(&expand.input, required);
328            }
329            LogicalOperator::Limit(limit) => {
330                Self::collect_required_recursive(&limit.input, required);
331            }
332            LogicalOperator::Skip(skip) => {
333                Self::collect_required_recursive(&skip.input, required);
334            }
335            LogicalOperator::Distinct(distinct) => {
336                Self::collect_required_recursive(&distinct.input, required);
337            }
338            LogicalOperator::NodeScan(scan) => {
339                required.insert(RequiredColumn::Variable(scan.variable.clone()));
340            }
341            LogicalOperator::EdgeScan(scan) => {
342                required.insert(RequiredColumn::Variable(scan.variable.clone()));
343            }
344            LogicalOperator::MultiWayJoin(mwj) => {
345                for cond in &mwj.conditions {
346                    Self::collect_from_expression(&cond.left, required);
347                    Self::collect_from_expression(&cond.right, required);
348                }
349                for input in &mwj.inputs {
350                    Self::collect_required_recursive(input, required);
351                }
352            }
353            _ => {}
354        }
355    }
356
357    /// Collects required columns from an expression.
358    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
359        match expr {
360            LogicalExpression::Variable(var) => {
361                required.insert(RequiredColumn::Variable(var.clone()));
362            }
363            LogicalExpression::Property { variable, property } => {
364                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
365                required.insert(RequiredColumn::Variable(variable.clone()));
366            }
367            LogicalExpression::Binary { left, right, .. } => {
368                Self::collect_from_expression(left, required);
369                Self::collect_from_expression(right, required);
370            }
371            LogicalExpression::Unary { operand, .. } => {
372                Self::collect_from_expression(operand, required);
373            }
374            LogicalExpression::FunctionCall { args, .. } => {
375                for arg in args {
376                    Self::collect_from_expression(arg, required);
377                }
378            }
379            LogicalExpression::List(items) => {
380                for item in items {
381                    Self::collect_from_expression(item, required);
382                }
383            }
384            LogicalExpression::Map(pairs) => {
385                for (_, value) in pairs {
386                    Self::collect_from_expression(value, required);
387                }
388            }
389            LogicalExpression::IndexAccess { base, index } => {
390                Self::collect_from_expression(base, required);
391                Self::collect_from_expression(index, required);
392            }
393            LogicalExpression::SliceAccess { base, start, end } => {
394                Self::collect_from_expression(base, required);
395                if let Some(s) = start {
396                    Self::collect_from_expression(s, required);
397                }
398                if let Some(e) = end {
399                    Self::collect_from_expression(e, required);
400                }
401            }
402            LogicalExpression::Case {
403                operand,
404                when_clauses,
405                else_clause,
406            } => {
407                if let Some(op) = operand {
408                    Self::collect_from_expression(op, required);
409                }
410                for (cond, result) in when_clauses {
411                    Self::collect_from_expression(cond, required);
412                    Self::collect_from_expression(result, required);
413                }
414                if let Some(else_expr) = else_clause {
415                    Self::collect_from_expression(else_expr, required);
416                }
417            }
418            LogicalExpression::Labels(var)
419            | LogicalExpression::Type(var)
420            | LogicalExpression::Id(var) => {
421                required.insert(RequiredColumn::Variable(var.clone()));
422            }
423            LogicalExpression::ListComprehension {
424                list_expr,
425                filter_expr,
426                map_expr,
427                ..
428            } => {
429                Self::collect_from_expression(list_expr, required);
430                if let Some(filter) = filter_expr {
431                    Self::collect_from_expression(filter, required);
432                }
433                Self::collect_from_expression(map_expr, required);
434            }
435            _ => {}
436        }
437    }
438
439    /// Recursively pushes projections down, adding them before expensive operations.
440    fn push_projections_recursive(
441        &self,
442        op: LogicalOperator,
443        required: &HashSet<RequiredColumn>,
444    ) -> LogicalOperator {
445        match op {
446            LogicalOperator::Return(mut ret) => {
447                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
448                LogicalOperator::Return(ret)
449            }
450            LogicalOperator::Project(mut proj) => {
451                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
452                LogicalOperator::Project(proj)
453            }
454            LogicalOperator::Filter(mut filter) => {
455                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
456                LogicalOperator::Filter(filter)
457            }
458            LogicalOperator::Sort(mut sort) => {
459                // Sort is expensive - consider adding a projection before it
460                // to reduce tuple width
461                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
462                LogicalOperator::Sort(sort)
463            }
464            LogicalOperator::Aggregate(mut agg) => {
465                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
466                LogicalOperator::Aggregate(agg)
467            }
468            LogicalOperator::Join(mut join) => {
469                // Joins are expensive - the required columns help determine
470                // what to project on each side
471                let left_vars = self.collect_output_variables(&join.left);
472                let right_vars = self.collect_output_variables(&join.right);
473
474                // Filter required columns to each side
475                let left_required: HashSet<_> = required
476                    .iter()
477                    .filter(|c| match c {
478                        RequiredColumn::Variable(v) => left_vars.contains(v),
479                        RequiredColumn::Property(v, _) => left_vars.contains(v),
480                    })
481                    .cloned()
482                    .collect();
483
484                let right_required: HashSet<_> = required
485                    .iter()
486                    .filter(|c| match c {
487                        RequiredColumn::Variable(v) => right_vars.contains(v),
488                        RequiredColumn::Property(v, _) => right_vars.contains(v),
489                    })
490                    .cloned()
491                    .collect();
492
493                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
494                join.right =
495                    Box::new(self.push_projections_recursive(*join.right, &right_required));
496                LogicalOperator::Join(join)
497            }
498            LogicalOperator::Expand(mut expand) => {
499                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
500                LogicalOperator::Expand(expand)
501            }
502            LogicalOperator::Limit(mut limit) => {
503                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
504                LogicalOperator::Limit(limit)
505            }
506            LogicalOperator::Skip(mut skip) => {
507                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
508                LogicalOperator::Skip(skip)
509            }
510            LogicalOperator::Distinct(mut distinct) => {
511                distinct.input =
512                    Box::new(self.push_projections_recursive(*distinct.input, required));
513                LogicalOperator::Distinct(distinct)
514            }
515            LogicalOperator::MapCollect(mut mc) => {
516                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
517                LogicalOperator::MapCollect(mc)
518            }
519            LogicalOperator::MultiWayJoin(mut mwj) => {
520                mwj.inputs = mwj
521                    .inputs
522                    .into_iter()
523                    .map(|input| self.push_projections_recursive(input, required))
524                    .collect();
525                LogicalOperator::MultiWayJoin(mwj)
526            }
527            other => other,
528        }
529    }
530
531    /// Reorders joins in the operator tree using the DPccp algorithm.
532    ///
533    /// This optimization finds the optimal join order by:
534    /// 1. Extracting all base relations (scans) and join conditions
535    /// 2. Building a join graph
536    /// 3. Using dynamic programming to find the cheapest join order
537    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
538        // First, recursively optimize children
539        let op = self.reorder_joins_recursive(op);
540
541        // Then, if this is a join tree, try to optimize it
542        if let Some((relations, conditions)) = self.extract_join_tree(&op)
543            && relations.len() >= 2
544            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
545        {
546            return optimized;
547        }
548
549        op
550    }
551
552    /// Recursively applies join reordering to child operators.
553    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
554        match op {
555            LogicalOperator::Return(mut ret) => {
556                ret.input = Box::new(self.reorder_joins(*ret.input));
557                LogicalOperator::Return(ret)
558            }
559            LogicalOperator::Project(mut proj) => {
560                proj.input = Box::new(self.reorder_joins(*proj.input));
561                LogicalOperator::Project(proj)
562            }
563            LogicalOperator::Filter(mut filter) => {
564                filter.input = Box::new(self.reorder_joins(*filter.input));
565                LogicalOperator::Filter(filter)
566            }
567            LogicalOperator::Limit(mut limit) => {
568                limit.input = Box::new(self.reorder_joins(*limit.input));
569                LogicalOperator::Limit(limit)
570            }
571            LogicalOperator::Skip(mut skip) => {
572                skip.input = Box::new(self.reorder_joins(*skip.input));
573                LogicalOperator::Skip(skip)
574            }
575            LogicalOperator::Sort(mut sort) => {
576                sort.input = Box::new(self.reorder_joins(*sort.input));
577                LogicalOperator::Sort(sort)
578            }
579            LogicalOperator::Distinct(mut distinct) => {
580                distinct.input = Box::new(self.reorder_joins(*distinct.input));
581                LogicalOperator::Distinct(distinct)
582            }
583            LogicalOperator::Aggregate(mut agg) => {
584                agg.input = Box::new(self.reorder_joins(*agg.input));
585                LogicalOperator::Aggregate(agg)
586            }
587            LogicalOperator::Expand(mut expand) => {
588                expand.input = Box::new(self.reorder_joins(*expand.input));
589                LogicalOperator::Expand(expand)
590            }
591            LogicalOperator::MapCollect(mut mc) => {
592                mc.input = Box::new(self.reorder_joins(*mc.input));
593                LogicalOperator::MapCollect(mc)
594            }
595            LogicalOperator::MultiWayJoin(mut mwj) => {
596                mwj.inputs = mwj
597                    .inputs
598                    .into_iter()
599                    .map(|input| self.reorder_joins(input))
600                    .collect();
601                LogicalOperator::MultiWayJoin(mwj)
602            }
603            // Join operators are handled by the parent reorder_joins call
604            other => other,
605        }
606    }
607
608    /// Extracts base relations and join conditions from a join tree.
609    ///
610    /// Returns None if the operator is not a join tree.
611    fn extract_join_tree(
612        &self,
613        op: &LogicalOperator,
614    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
615        let mut relations = Vec::new();
616        let mut join_conditions = Vec::new();
617
618        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
619            return None;
620        }
621
622        if relations.len() < 2 {
623            return None;
624        }
625
626        Some((relations, join_conditions))
627    }
628
629    /// Recursively collects base relations and join conditions.
630    ///
631    /// Returns true if this subtree is part of a join tree.
632    fn collect_join_tree(
633        &self,
634        op: &LogicalOperator,
635        relations: &mut Vec<(String, LogicalOperator)>,
636        conditions: &mut Vec<JoinInfo>,
637    ) -> bool {
638        match op {
639            LogicalOperator::Join(join) => {
640                // Collect from both sides
641                let left_ok = self.collect_join_tree(&join.left, relations, conditions);
642                let right_ok = self.collect_join_tree(&join.right, relations, conditions);
643
644                // Add conditions from this join
645                for cond in &join.conditions {
646                    if let (Some(left_var), Some(right_var)) = (
647                        self.extract_variable_from_expr(&cond.left),
648                        self.extract_variable_from_expr(&cond.right),
649                    ) {
650                        conditions.push(JoinInfo {
651                            left_var,
652                            right_var,
653                            left_expr: cond.left.clone(),
654                            right_expr: cond.right.clone(),
655                        });
656                    }
657                }
658
659                left_ok && right_ok
660            }
661            LogicalOperator::NodeScan(scan) => {
662                relations.push((scan.variable.clone(), op.clone()));
663                true
664            }
665            LogicalOperator::EdgeScan(scan) => {
666                relations.push((scan.variable.clone(), op.clone()));
667                true
668            }
669            LogicalOperator::Filter(filter) => {
670                // A filter on a base relation is still part of the join tree
671                self.collect_join_tree(&filter.input, relations, conditions)
672            }
673            LogicalOperator::Expand(expand) => {
674                // Expand is a special case - it's like a join with the adjacency
675                // For now, treat the whole Expand subtree as a single relation
676                relations.push((expand.to_variable.clone(), op.clone()));
677                true
678            }
679            #[cfg(feature = "triple-store")]
680            LogicalOperator::TripleScan(scan) => {
681                // Use the first variable found as the relation name.
682                // For all-constant patterns, generate a unique fallback to
683                // avoid HashMap key collisions in `add_relation`.
684                let name = scan
685                    .subject
686                    .as_variable()
687                    .or_else(|| scan.predicate.as_variable())
688                    .or_else(|| scan.object.as_variable())
689                    .map_or_else(|| format!("__tp_{}", relations.len()), String::from);
690                relations.push((name, op.clone()));
691                true
692            }
693            _ => false,
694        }
695    }
696
697    /// Extracts the primary variable from an expression.
698    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
699        match expr {
700            LogicalExpression::Variable(v) => Some(v.clone()),
701            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
702            _ => None,
703        }
704    }
705
706    /// Optimizes the join order using DPccp, or produces a multi-way
707    /// leapfrog join for cyclic patterns when the cost model prefers it.
708    fn optimize_join_order(
709        &self,
710        relations: &[(String, LogicalOperator)],
711        conditions: &[JoinInfo],
712    ) -> Option<LogicalOperator> {
713        use join_order::{DPccp, JoinGraphBuilder};
714
715        // Build the join graph
716        let mut builder = JoinGraphBuilder::new();
717
718        for (var, relation) in relations {
719            builder.add_relation(var, relation.clone());
720        }
721
722        for cond in conditions {
723            builder.add_join_condition(
724                &cond.left_var,
725                &cond.right_var,
726                cond.left_expr.clone(),
727                cond.right_expr.clone(),
728            );
729        }
730
731        let graph = builder.build();
732
733        // For cyclic graphs with 3+ relations, use leapfrog (WCOJ) join.
734        // Cyclic joins (e.g. triangle patterns) benefit from worst-case optimal
735        // multi-way intersection rather than binary hash join cascades that can
736        // produce intermediate blowup.
737        if graph.is_cyclic() && relations.len() >= 3 {
738            // Collect shared variables (variables appearing in 2+ conditions)
739            let mut var_counts: std::collections::HashMap<&str, usize> =
740                std::collections::HashMap::new();
741            for cond in conditions {
742                *var_counts.entry(&cond.left_var).or_default() += 1;
743                *var_counts.entry(&cond.right_var).or_default() += 1;
744            }
745            let shared_variables: Vec<String> = var_counts
746                .into_iter()
747                .filter(|(_, count)| *count >= 2)
748                .map(|(var, _)| var.to_string())
749                .collect();
750
751            let join_conditions: Vec<JoinCondition> = conditions
752                .iter()
753                .map(|c| JoinCondition {
754                    left: c.left_expr.clone(),
755                    right: c.right_expr.clone(),
756                })
757                .collect();
758
759            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
760                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
761                conditions: join_conditions,
762                shared_variables,
763            }));
764        }
765
766        // Fall through to DPccp for binary join ordering
767        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
768        let plan = dpccp.optimize()?;
769
770        Some(plan.operator)
771    }
772
773    /// Pushes filters down the operator tree.
774    ///
775    /// This optimization moves filter predicates as close to the data source
776    /// as possible to reduce the amount of data processed by upper operators.
777    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
778        match op {
779            // For Filter operators, try to push the predicate into the child
780            LogicalOperator::Filter(filter) => {
781                let optimized_input = self.push_filters_down(*filter.input);
782                self.try_push_filter_into(filter.predicate, optimized_input)
783            }
784            // Recursively optimize children for other operators
785            LogicalOperator::Return(mut ret) => {
786                ret.input = Box::new(self.push_filters_down(*ret.input));
787                LogicalOperator::Return(ret)
788            }
789            LogicalOperator::Project(mut proj) => {
790                proj.input = Box::new(self.push_filters_down(*proj.input));
791                LogicalOperator::Project(proj)
792            }
793            LogicalOperator::Limit(mut limit) => {
794                limit.input = Box::new(self.push_filters_down(*limit.input));
795                LogicalOperator::Limit(limit)
796            }
797            LogicalOperator::Skip(mut skip) => {
798                skip.input = Box::new(self.push_filters_down(*skip.input));
799                LogicalOperator::Skip(skip)
800            }
801            LogicalOperator::Sort(mut sort) => {
802                sort.input = Box::new(self.push_filters_down(*sort.input));
803                LogicalOperator::Sort(sort)
804            }
805            LogicalOperator::Distinct(mut distinct) => {
806                distinct.input = Box::new(self.push_filters_down(*distinct.input));
807                LogicalOperator::Distinct(distinct)
808            }
809            LogicalOperator::Expand(mut expand) => {
810                expand.input = Box::new(self.push_filters_down(*expand.input));
811                LogicalOperator::Expand(expand)
812            }
813            LogicalOperator::Join(mut join) => {
814                join.left = Box::new(self.push_filters_down(*join.left));
815                join.right = Box::new(self.push_filters_down(*join.right));
816                LogicalOperator::Join(join)
817            }
818            LogicalOperator::Aggregate(mut agg) => {
819                agg.input = Box::new(self.push_filters_down(*agg.input));
820                LogicalOperator::Aggregate(agg)
821            }
822            LogicalOperator::MapCollect(mut mc) => {
823                mc.input = Box::new(self.push_filters_down(*mc.input));
824                LogicalOperator::MapCollect(mc)
825            }
826            LogicalOperator::MultiWayJoin(mut mwj) => {
827                mwj.inputs = mwj
828                    .inputs
829                    .into_iter()
830                    .map(|input| self.push_filters_down(input))
831                    .collect();
832                LogicalOperator::MultiWayJoin(mwj)
833            }
834            // Leaf operators and unsupported operators are returned as-is
835            other => other,
836        }
837    }
838
839    /// Tries to push a filter predicate into the given operator.
840    ///
841    /// Returns either the predicate pushed into the operator, or a new
842    /// Filter operator on top if the predicate cannot be pushed further.
843    fn try_push_filter_into(
844        &self,
845        predicate: LogicalExpression,
846        op: LogicalOperator,
847    ) -> LogicalOperator {
848        match op {
849            // Can push through Project if predicate doesn't depend on computed columns
850            LogicalOperator::Project(mut proj) => {
851                let predicate_vars = self.extract_variables(&predicate);
852                let computed_vars = self.extract_projection_aliases(&proj.projections);
853
854                // If predicate doesn't use any computed columns, push through
855                if predicate_vars.is_disjoint(&computed_vars) {
856                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
857                    LogicalOperator::Project(proj)
858                } else {
859                    // Can't push through, keep filter on top
860                    LogicalOperator::Filter(FilterOp {
861                        predicate,
862                        pushdown_hint: None,
863                        input: Box::new(LogicalOperator::Project(proj)),
864                    })
865                }
866            }
867
868            // Can push through Return (which is like a projection)
869            LogicalOperator::Return(mut ret) => {
870                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
871                LogicalOperator::Return(ret)
872            }
873
874            // Can push through Expand if predicate doesn't use variables introduced by this expand
875            LogicalOperator::Expand(mut expand) => {
876                let predicate_vars = self.extract_variables(&predicate);
877
878                // Variables introduced by this expand are:
879                // - The target variable (to_variable)
880                // - The edge variable (if any)
881                // - The path alias (if any)
882                let mut introduced_vars = vec![&expand.to_variable];
883                if let Some(ref edge_var) = expand.edge_variable {
884                    introduced_vars.push(edge_var);
885                }
886                if let Some(ref path_alias) = expand.path_alias {
887                    introduced_vars.push(path_alias);
888                }
889
890                // Check if predicate uses any variables introduced by this expand
891                let uses_introduced_vars =
892                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
893
894                if !uses_introduced_vars {
895                    // Predicate doesn't use vars from this expand, so push through
896                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
897                    LogicalOperator::Expand(expand)
898                } else {
899                    // Keep filter after expand
900                    LogicalOperator::Filter(FilterOp {
901                        predicate,
902                        pushdown_hint: None,
903                        input: Box::new(LogicalOperator::Expand(expand)),
904                    })
905                }
906            }
907
908            // Can push through Join to left/right side based on variables used
909            LogicalOperator::Join(mut join) => {
910                let predicate_vars = self.extract_variables(&predicate);
911                let left_vars = self.collect_output_variables(&join.left);
912                let right_vars = self.collect_output_variables(&join.right);
913
914                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
915                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
916
917                if uses_left && !uses_right {
918                    // Push to left side
919                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
920                    LogicalOperator::Join(join)
921                } else if uses_right && !uses_left {
922                    // Push to right side
923                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
924                    LogicalOperator::Join(join)
925                } else {
926                    // Uses both sides - keep above join
927                    LogicalOperator::Filter(FilterOp {
928                        predicate,
929                        pushdown_hint: None,
930                        input: Box::new(LogicalOperator::Join(join)),
931                    })
932                }
933            }
934
935            // Cannot push through Aggregate (predicate refers to aggregated values)
936            LogicalOperator::Aggregate(agg) => LogicalOperator::Filter(FilterOp {
937                predicate,
938                pushdown_hint: None,
939                input: Box::new(LogicalOperator::Aggregate(agg)),
940            }),
941
942            // For NodeScan, we've reached the bottom - keep filter on top
943            LogicalOperator::NodeScan(scan) => LogicalOperator::Filter(FilterOp {
944                predicate,
945                pushdown_hint: None,
946                input: Box::new(LogicalOperator::NodeScan(scan)),
947            }),
948
949            // For other operators, keep filter on top
950            other => LogicalOperator::Filter(FilterOp {
951                predicate,
952                pushdown_hint: None,
953                input: Box::new(other),
954            }),
955        }
956    }
957
958    // NOTE: Filter-into-TripleScan pushdown is intentionally not implemented.
959    // Binding a variable position in the scan removes it from output columns,
960    // which breaks downstream operators (SELECT, ORDER BY) that reference it.
961    // A correct implementation needs a separate `pushed_bindings` field on
962    // TripleScanOp so the variable stays in output while the scan is constrained.
963
964    /// Collects all output variable names from an operator.
965    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
966        let mut vars = HashSet::new();
967        Self::collect_output_variables_recursive(op, &mut vars);
968        vars
969    }
970
971    /// Recursively collects output variables from an operator.
972    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
973        match op {
974            LogicalOperator::NodeScan(scan) => {
975                vars.insert(scan.variable.clone());
976            }
977            LogicalOperator::EdgeScan(scan) => {
978                vars.insert(scan.variable.clone());
979            }
980            LogicalOperator::Expand(expand) => {
981                vars.insert(expand.to_variable.clone());
982                if let Some(edge_var) = &expand.edge_variable {
983                    vars.insert(edge_var.clone());
984                }
985                Self::collect_output_variables_recursive(&expand.input, vars);
986            }
987            LogicalOperator::Filter(filter) => {
988                Self::collect_output_variables_recursive(&filter.input, vars);
989            }
990            LogicalOperator::Project(proj) => {
991                for p in &proj.projections {
992                    if let Some(alias) = &p.alias {
993                        vars.insert(alias.clone());
994                    }
995                }
996                Self::collect_output_variables_recursive(&proj.input, vars);
997            }
998            LogicalOperator::Join(join) => {
999                Self::collect_output_variables_recursive(&join.left, vars);
1000                Self::collect_output_variables_recursive(&join.right, vars);
1001            }
1002            LogicalOperator::Aggregate(agg) => {
1003                for expr in &agg.group_by {
1004                    Self::collect_variables(expr, vars);
1005                }
1006                for agg_expr in &agg.aggregates {
1007                    if let Some(alias) = &agg_expr.alias {
1008                        vars.insert(alias.clone());
1009                    }
1010                }
1011            }
1012            LogicalOperator::Return(ret) => {
1013                Self::collect_output_variables_recursive(&ret.input, vars);
1014            }
1015            LogicalOperator::Limit(limit) => {
1016                Self::collect_output_variables_recursive(&limit.input, vars);
1017            }
1018            LogicalOperator::Skip(skip) => {
1019                Self::collect_output_variables_recursive(&skip.input, vars);
1020            }
1021            LogicalOperator::Sort(sort) => {
1022                Self::collect_output_variables_recursive(&sort.input, vars);
1023            }
1024            LogicalOperator::Distinct(distinct) => {
1025                Self::collect_output_variables_recursive(&distinct.input, vars);
1026            }
1027            #[cfg(feature = "triple-store")]
1028            LogicalOperator::TripleScan(scan) => {
1029                if let Some(v) = scan.subject.as_variable() {
1030                    vars.insert(v.to_string());
1031                }
1032                if let Some(v) = scan.predicate.as_variable() {
1033                    vars.insert(v.to_string());
1034                }
1035                if let Some(v) = scan.object.as_variable() {
1036                    vars.insert(v.to_string());
1037                }
1038                if let Some(ref g) = scan.graph
1039                    && let Some(v) = g.as_variable()
1040                {
1041                    vars.insert(v.to_string());
1042                }
1043            }
1044            _ => {}
1045        }
1046    }
1047
1048    /// Extracts all variable names referenced in an expression.
1049    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
1050        let mut vars = HashSet::new();
1051        Self::collect_variables(expr, &mut vars);
1052        vars
1053    }
1054
    /// Recursively inserts into `vars` the name of every variable referenced by
    /// `expr`.
    ///
    /// These expressions contribute a name directly:
    ///
    /// - Variable references and property accesses: the variable.
    /// - `Labels`, `Type` and `Id`: their variable.
    /// - `MapProjection`: its base variable.
    ///
    /// Composite expressions are walked into. Literals, parameters and
    /// subqueries contribute nothing.
    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
        match expr {
            LogicalExpression::Variable(name) => {
                vars.insert(name.clone());
            }
            LogicalExpression::Property { variable, .. } => {
                vars.insert(variable.clone());
            }
            LogicalExpression::Binary { left, right, .. } => {
                Self::collect_variables(left, vars);
                Self::collect_variables(right, vars);
            }
            LogicalExpression::Unary { operand, .. } => {
                Self::collect_variables(operand, vars);
            }
            LogicalExpression::FunctionCall { args, .. } => {
                for arg in args {
                    Self::collect_variables(arg, vars);
                }
            }
            LogicalExpression::List(items) => {
                for item in items {
                    Self::collect_variables(item, vars);
                }
            }
            // Only map values can reference variables; keys are not visited.
            LogicalExpression::Map(pairs) => {
                for (_, value) in pairs {
                    Self::collect_variables(value, vars);
                }
            }
            LogicalExpression::IndexAccess { base, index } => {
                Self::collect_variables(base, vars);
                Self::collect_variables(index, vars);
            }
            LogicalExpression::SliceAccess { base, start, end } => {
                Self::collect_variables(base, vars);
                if let Some(s) = start {
                    Self::collect_variables(s, vars);
                }
                if let Some(e) = end {
                    Self::collect_variables(e, vars);
                }
            }
            LogicalExpression::Case {
                operand,
                when_clauses,
                else_clause,
            } => {
                if let Some(op) = operand {
                    Self::collect_variables(op, vars);
                }
                for (cond, result) in when_clauses {
                    Self::collect_variables(cond, vars);
                    Self::collect_variables(result, vars);
                }
                if let Some(else_expr) = else_clause {
                    Self::collect_variables(else_expr, vars);
                }
            }
            LogicalExpression::Labels(var)
            | LogicalExpression::Type(var)
            | LogicalExpression::Id(var) => {
                vars.insert(var.clone());
            }
            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
            // NOTE(review): the fields elided by `..` here (and in ListPredicate /
            // Reduce below) presumably hold the locally bound iteration or
            // accumulator variable. Those names are not excluded, so references to
            // them inside the filter/map/predicate/expression sub-expressions are
            // reported as if they were outer variables. TODO: confirm.
            LogicalExpression::ListComprehension {
                list_expr,
                filter_expr,
                map_expr,
                ..
            } => {
                Self::collect_variables(list_expr, vars);
                if let Some(filter) = filter_expr {
                    Self::collect_variables(filter, vars);
                }
                Self::collect_variables(map_expr, vars);
            }
            LogicalExpression::ListPredicate {
                list_expr,
                predicate,
                ..
            } => {
                Self::collect_variables(list_expr, vars);
                Self::collect_variables(predicate, vars);
            }
            LogicalExpression::ExistsSubquery(_)
            | LogicalExpression::CountSubquery(_)
            | LogicalExpression::ValueSubquery(_) => {
                // Subqueries have their own variable scope
                // NOTE(review): references to *outer* variables inside a
                // correlated subquery are not reported either. Filter pushdown
                // uses this set to decide placement, so a predicate whose only
                // outer reference sits inside a subquery looks variable-free and
                // may be moved below the operator that binds that variable.
                // Confirm subqueries cannot correlate here, or collect their outer
                // references.
            }
            // Only the projection is visited; any other parts of the pattern
            // (elided by `..`) are not inspected.
            LogicalExpression::PatternComprehension { projection, .. } => {
                Self::collect_variables(projection, vars);
            }
            // The base variable is always referenced; only literal entries carry
            // an extra expression that may reference other variables.
            LogicalExpression::MapProjection { base, entries } => {
                vars.insert(base.clone());
                for entry in entries {
                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
                        Self::collect_variables(expr, vars);
                    }
                }
            }
            LogicalExpression::Reduce {
                initial,
                list,
                expression,
                ..
            } => {
                Self::collect_variables(initial, vars);
                Self::collect_variables(list, vars);
                Self::collect_variables(expression, vars);
            }
        }
    }
1169
1170    /// Extracts aliases from projection expressions.
1171    fn extract_projection_aliases(
1172        &self,
1173        projections: &[crate::query::plan::Projection],
1174    ) -> HashSet<String> {
1175        projections.iter().filter_map(|p| p.alias.clone()).collect()
1176    }
1177}
1178
1179impl Default for Optimizer {
1180    fn default() -> Self {
1181        Self::new()
1182    }
1183}
1184
1185#[cfg(test)]
1186mod tests {
1187    use super::*;
1188    use crate::query::plan::{
1189        AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
1190        ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
1191        ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
1192    };
1193    use grafeo_common::types::Value;
1194
1195    #[test]
1196    fn test_optimizer_filter_pushdown_simple() {
1197        // Query: MATCH (n:Person) WHERE n.age > 30 RETURN n
1198        // Before: Return -> Filter -> NodeScan
1199        // After:  Return -> Filter -> NodeScan (filter stays at bottom)
1200
1201        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1202            items: vec![ReturnItem {
1203                expression: LogicalExpression::Variable("n".to_string()),
1204                alias: None,
1205            }],
1206            distinct: false,
1207            input: Box::new(LogicalOperator::Filter(FilterOp {
1208                predicate: LogicalExpression::Binary {
1209                    left: Box::new(LogicalExpression::Property {
1210                        variable: "n".to_string(),
1211                        property: "age".to_string(),
1212                    }),
1213                    op: BinaryOp::Gt,
1214                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1215                },
1216                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1217                    variable: "n".to_string(),
1218                    label: Some("Person".to_string()),
1219                    input: None,
1220                })),
1221                pushdown_hint: None,
1222            })),
1223        }));
1224
1225        let optimizer = Optimizer::new();
1226        let optimized = optimizer.optimize(plan).unwrap();
1227
1228        // The structure should remain similar (filter stays near scan)
1229        if let LogicalOperator::Return(ret) = &optimized.root
1230            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1231            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1232        {
1233            assert_eq!(scan.variable, "n");
1234            return;
1235        }
1236        panic!("Expected Return -> Filter -> NodeScan structure");
1237    }
1238
1239    #[test]
1240    fn test_optimizer_filter_pushdown_through_expand() {
1241        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE a.age > 30 RETURN b
1242        // The filter on 'a' should be pushed before the expand
1243
1244        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1245            items: vec![ReturnItem {
1246                expression: LogicalExpression::Variable("b".to_string()),
1247                alias: None,
1248            }],
1249            distinct: false,
1250            input: Box::new(LogicalOperator::Filter(FilterOp {
1251                predicate: LogicalExpression::Binary {
1252                    left: Box::new(LogicalExpression::Property {
1253                        variable: "a".to_string(),
1254                        property: "age".to_string(),
1255                    }),
1256                    op: BinaryOp::Gt,
1257                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1258                },
1259                pushdown_hint: None,
1260                input: Box::new(LogicalOperator::Expand(ExpandOp {
1261                    from_variable: "a".to_string(),
1262                    to_variable: "b".to_string(),
1263                    edge_variable: None,
1264                    direction: ExpandDirection::Outgoing,
1265                    edge_types: vec!["KNOWS".to_string()],
1266                    min_hops: 1,
1267                    max_hops: Some(1),
1268                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1269                        variable: "a".to_string(),
1270                        label: Some("Person".to_string()),
1271                        input: None,
1272                    })),
1273                    path_alias: None,
1274                    path_mode: PathMode::Walk,
1275                })),
1276            })),
1277        }));
1278
1279        let optimizer = Optimizer::new();
1280        let optimized = optimizer.optimize(plan).unwrap();
1281
1282        // Filter on 'a' should be pushed before the expand
1283        // Expected: Return -> Expand -> Filter -> NodeScan
1284        if let LogicalOperator::Return(ret) = &optimized.root
1285            && let LogicalOperator::Expand(expand) = ret.input.as_ref()
1286            && let LogicalOperator::Filter(filter) = expand.input.as_ref()
1287            && let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
1288        {
1289            assert_eq!(scan.variable, "a");
1290            assert_eq!(expand.from_variable, "a");
1291            assert_eq!(expand.to_variable, "b");
1292            return;
1293        }
1294        panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
1295    }
1296
1297    #[test]
1298    fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
1299        // Query: MATCH (a:Person)-[:KNOWS]->(b) WHERE b.age > 30 RETURN a
1300        // The filter on 'b' should NOT be pushed before the expand
1301
1302        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1303            items: vec![ReturnItem {
1304                expression: LogicalExpression::Variable("a".to_string()),
1305                alias: None,
1306            }],
1307            distinct: false,
1308            input: Box::new(LogicalOperator::Filter(FilterOp {
1309                predicate: LogicalExpression::Binary {
1310                    left: Box::new(LogicalExpression::Property {
1311                        variable: "b".to_string(),
1312                        property: "age".to_string(),
1313                    }),
1314                    op: BinaryOp::Gt,
1315                    right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1316                },
1317                pushdown_hint: None,
1318                input: Box::new(LogicalOperator::Expand(ExpandOp {
1319                    from_variable: "a".to_string(),
1320                    to_variable: "b".to_string(),
1321                    edge_variable: None,
1322                    direction: ExpandDirection::Outgoing,
1323                    edge_types: vec!["KNOWS".to_string()],
1324                    min_hops: 1,
1325                    max_hops: Some(1),
1326                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1327                        variable: "a".to_string(),
1328                        label: Some("Person".to_string()),
1329                        input: None,
1330                    })),
1331                    path_alias: None,
1332                    path_mode: PathMode::Walk,
1333                })),
1334            })),
1335        }));
1336
1337        let optimizer = Optimizer::new();
1338        let optimized = optimizer.optimize(plan).unwrap();
1339
1340        // Filter on 'b' should stay after the expand
1341        // Expected: Return -> Filter -> Expand -> NodeScan
1342        if let LogicalOperator::Return(ret) = &optimized.root
1343            && let LogicalOperator::Filter(filter) = ret.input.as_ref()
1344        {
1345            // Check that the filter is on 'b'
1346            if let LogicalExpression::Binary { left, .. } = &filter.predicate
1347                && let LogicalExpression::Property { variable, .. } = left.as_ref()
1348            {
1349                assert_eq!(variable, "b");
1350            }
1351
1352            if let LogicalOperator::Expand(expand) = filter.input.as_ref()
1353                && let LogicalOperator::NodeScan(_) = expand.input.as_ref()
1354            {
1355                return;
1356            }
1357        }
1358        panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
1359    }
1360
1361    #[test]
1362    fn test_optimizer_extract_variables() {
1363        let optimizer = Optimizer::new();
1364
1365        let expr = LogicalExpression::Binary {
1366            left: Box::new(LogicalExpression::Property {
1367                variable: "n".to_string(),
1368                property: "age".to_string(),
1369            }),
1370            op: BinaryOp::Gt,
1371            right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1372        };
1373
1374        let vars = optimizer.extract_variables(&expr);
1375        assert_eq!(vars.len(), 1);
1376        assert!(vars.contains("n"));
1377    }
1378
1379    // Additional tests for optimizer configuration
1380
1381    #[test]
1382    fn test_optimizer_default() {
1383        let optimizer = Optimizer::default();
1384        // Should be able to optimize an empty plan
1385        let plan = LogicalPlan::new(LogicalOperator::Empty);
1386        let result = optimizer.optimize(plan);
1387        assert!(result.is_ok());
1388    }
1389
1390    #[test]
1391    fn test_optimizer_with_filter_pushdown_disabled() {
1392        let optimizer = Optimizer::new().with_filter_pushdown(false);
1393
1394        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1395            items: vec![ReturnItem {
1396                expression: LogicalExpression::Variable("n".to_string()),
1397                alias: None,
1398            }],
1399            distinct: false,
1400            input: Box::new(LogicalOperator::Filter(FilterOp {
1401                predicate: LogicalExpression::Literal(Value::Bool(true)),
1402                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1403                    variable: "n".to_string(),
1404                    label: None,
1405                    input: None,
1406                })),
1407                pushdown_hint: None,
1408            })),
1409        }));
1410
1411        let optimized = optimizer.optimize(plan).unwrap();
1412        // Structure should be unchanged
1413        if let LogicalOperator::Return(ret) = &optimized.root
1414            && let LogicalOperator::Filter(_) = ret.input.as_ref()
1415        {
1416            return;
1417        }
1418        panic!("Expected unchanged structure");
1419    }
1420
1421    #[test]
1422    fn test_optimizer_with_join_reorder_disabled() {
1423        let optimizer = Optimizer::new().with_join_reorder(false);
1424        assert!(
1425            optimizer
1426                .optimize(LogicalPlan::new(LogicalOperator::Empty))
1427                .is_ok()
1428        );
1429    }
1430
1431    #[test]
1432    fn test_optimizer_with_cost_model() {
1433        let cost_model = CostModel::new();
1434        let optimizer = Optimizer::new().with_cost_model(cost_model);
1435        assert!(
1436            optimizer
1437                .cost_model()
1438                .estimate(&LogicalOperator::Empty, 0.0)
1439                .total()
1440                < 0.001
1441        );
1442    }
1443
1444    #[test]
1445    fn test_optimizer_with_cardinality_estimator() {
1446        let mut estimator = CardinalityEstimator::new();
1447        estimator.add_table_stats("Test", TableStats::new(500));
1448        let optimizer = Optimizer::new().with_cardinality_estimator(estimator);
1449
1450        let scan = LogicalOperator::NodeScan(NodeScanOp {
1451            variable: "n".to_string(),
1452            label: Some("Test".to_string()),
1453            input: None,
1454        });
1455        let plan = LogicalPlan::new(scan);
1456
1457        let cardinality = optimizer.estimate_cardinality(&plan);
1458        assert!((cardinality - 500.0).abs() < 0.001);
1459    }
1460
1461    #[test]
1462    fn test_optimizer_estimate_cost() {
1463        let optimizer = Optimizer::new();
1464        let plan = LogicalPlan::new(LogicalOperator::NodeScan(NodeScanOp {
1465            variable: "n".to_string(),
1466            label: None,
1467            input: None,
1468        }));
1469
1470        let cost = optimizer.estimate_cost(&plan);
1471        assert!(cost.total() > 0.0);
1472    }
1473
1474    // Filter pushdown through various operators
1475
1476    #[test]
1477    fn test_filter_pushdown_through_project() {
1478        let optimizer = Optimizer::new();
1479
1480        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1481            predicate: LogicalExpression::Binary {
1482                left: Box::new(LogicalExpression::Property {
1483                    variable: "n".to_string(),
1484                    property: "age".to_string(),
1485                }),
1486                op: BinaryOp::Gt,
1487                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1488            },
1489            pushdown_hint: None,
1490            input: Box::new(LogicalOperator::Project(ProjectOp {
1491                projections: vec![Projection {
1492                    expression: LogicalExpression::Variable("n".to_string()),
1493                    alias: None,
1494                }],
1495                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1496                    variable: "n".to_string(),
1497                    label: None,
1498                    input: None,
1499                })),
1500                pass_through_input: false,
1501            })),
1502        }));
1503
1504        let optimized = optimizer.optimize(plan).unwrap();
1505
1506        // Filter should be pushed through Project
1507        if let LogicalOperator::Project(proj) = &optimized.root
1508            && let LogicalOperator::Filter(_) = proj.input.as_ref()
1509        {
1510            return;
1511        }
1512        panic!("Expected Project -> Filter structure");
1513    }
1514
1515    #[test]
1516    fn test_filter_not_pushed_through_project_with_alias() {
1517        let optimizer = Optimizer::new();
1518
1519        // Filter on computed column 'x' should not be pushed through project that creates 'x'
1520        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1521            predicate: LogicalExpression::Binary {
1522                left: Box::new(LogicalExpression::Variable("x".to_string())),
1523                op: BinaryOp::Gt,
1524                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1525            },
1526            pushdown_hint: None,
1527            input: Box::new(LogicalOperator::Project(ProjectOp {
1528                projections: vec![Projection {
1529                    expression: LogicalExpression::Property {
1530                        variable: "n".to_string(),
1531                        property: "age".to_string(),
1532                    },
1533                    alias: Some("x".to_string()),
1534                }],
1535                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1536                    variable: "n".to_string(),
1537                    label: None,
1538                    input: None,
1539                })),
1540                pass_through_input: false,
1541            })),
1542        }));
1543
1544        let optimized = optimizer.optimize(plan).unwrap();
1545
1546        // Filter should stay above Project
1547        if let LogicalOperator::Filter(filter) = &optimized.root
1548            && let LogicalOperator::Project(_) = filter.input.as_ref()
1549        {
1550            return;
1551        }
1552        panic!("Expected Filter -> Project structure");
1553    }
1554
1555    #[test]
1556    fn test_filter_pushdown_through_limit() {
1557        let optimizer = Optimizer::new();
1558
1559        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1560            predicate: LogicalExpression::Literal(Value::Bool(true)),
1561            pushdown_hint: None,
1562            input: Box::new(LogicalOperator::Limit(LimitOp {
1563                count: 10.into(),
1564                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1565                    variable: "n".to_string(),
1566                    label: None,
1567                    input: None,
1568                })),
1569            })),
1570        }));
1571
1572        let optimized = optimizer.optimize(plan).unwrap();
1573
1574        // Filter stays above Limit (cannot be pushed through)
1575        if let LogicalOperator::Filter(filter) = &optimized.root
1576            && let LogicalOperator::Limit(_) = filter.input.as_ref()
1577        {
1578            return;
1579        }
1580        panic!("Expected Filter -> Limit structure");
1581    }
1582
1583    #[test]
1584    fn test_filter_pushdown_through_sort() {
1585        let optimizer = Optimizer::new();
1586
1587        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1588            predicate: LogicalExpression::Literal(Value::Bool(true)),
1589            pushdown_hint: None,
1590            input: Box::new(LogicalOperator::Sort(SortOp {
1591                keys: vec![SortKey {
1592                    expression: LogicalExpression::Variable("n".to_string()),
1593                    order: SortOrder::Ascending,
1594                    nulls: None,
1595                }],
1596                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1597                    variable: "n".to_string(),
1598                    label: None,
1599                    input: None,
1600                })),
1601            })),
1602        }));
1603
1604        let optimized = optimizer.optimize(plan).unwrap();
1605
1606        // Filter stays above Sort
1607        if let LogicalOperator::Filter(filter) = &optimized.root
1608            && let LogicalOperator::Sort(_) = filter.input.as_ref()
1609        {
1610            return;
1611        }
1612        panic!("Expected Filter -> Sort structure");
1613    }
1614
1615    #[test]
1616    fn test_filter_pushdown_through_distinct() {
1617        let optimizer = Optimizer::new();
1618
1619        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1620            predicate: LogicalExpression::Literal(Value::Bool(true)),
1621            pushdown_hint: None,
1622            input: Box::new(LogicalOperator::Distinct(DistinctOp {
1623                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1624                    variable: "n".to_string(),
1625                    label: None,
1626                    input: None,
1627                })),
1628                columns: None,
1629            })),
1630        }));
1631
1632        let optimized = optimizer.optimize(plan).unwrap();
1633
1634        // Filter stays above Distinct
1635        if let LogicalOperator::Filter(filter) = &optimized.root
1636            && let LogicalOperator::Distinct(_) = filter.input.as_ref()
1637        {
1638            return;
1639        }
1640        panic!("Expected Filter -> Distinct structure");
1641    }
1642
1643    #[test]
1644    fn test_filter_not_pushed_through_aggregate() {
1645        let optimizer = Optimizer::new();
1646
1647        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1648            predicate: LogicalExpression::Binary {
1649                left: Box::new(LogicalExpression::Variable("cnt".to_string())),
1650                op: BinaryOp::Gt,
1651                right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
1652            },
1653            pushdown_hint: None,
1654            input: Box::new(LogicalOperator::Aggregate(AggregateOp {
1655                group_by: vec![],
1656                aggregates: vec![AggregateExpr {
1657                    function: AggregateFunction::Count,
1658                    expression: None,
1659                    expression2: None,
1660                    distinct: false,
1661                    alias: Some("cnt".to_string()),
1662                    percentile: None,
1663                    separator: None,
1664                }],
1665                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1666                    variable: "n".to_string(),
1667                    label: None,
1668                    input: None,
1669                })),
1670                having: None,
1671            })),
1672        }));
1673
1674        let optimized = optimizer.optimize(plan).unwrap();
1675
1676        // Filter should stay above Aggregate
1677        if let LogicalOperator::Filter(filter) = &optimized.root
1678            && let LogicalOperator::Aggregate(_) = filter.input.as_ref()
1679        {
1680            return;
1681        }
1682        panic!("Expected Filter -> Aggregate structure");
1683    }
1684
1685    #[test]
1686    fn test_filter_pushdown_to_left_join_side() {
1687        let optimizer = Optimizer::new();
1688
1689        // Filter on left variable should be pushed to left side
1690        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1691            predicate: LogicalExpression::Binary {
1692                left: Box::new(LogicalExpression::Property {
1693                    variable: "a".to_string(),
1694                    property: "age".to_string(),
1695                }),
1696                op: BinaryOp::Gt,
1697                right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
1698            },
1699            pushdown_hint: None,
1700            input: Box::new(LogicalOperator::Join(JoinOp {
1701                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1702                    variable: "a".to_string(),
1703                    label: Some("Person".to_string()),
1704                    input: None,
1705                })),
1706                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1707                    variable: "b".to_string(),
1708                    label: Some("Company".to_string()),
1709                    input: None,
1710                })),
1711                join_type: JoinType::Inner,
1712                conditions: vec![],
1713            })),
1714        }));
1715
1716        let optimized = optimizer.optimize(plan).unwrap();
1717
1718        // Filter should be pushed to left side of join
1719        if let LogicalOperator::Join(join) = &optimized.root
1720            && let LogicalOperator::Filter(_) = join.left.as_ref()
1721        {
1722            return;
1723        }
1724        panic!("Expected Join with Filter on left side");
1725    }
1726
1727    #[test]
1728    fn test_filter_pushdown_to_right_join_side() {
1729        let optimizer = Optimizer::new();
1730
1731        // Filter on right variable should be pushed to right side
1732        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1733            predicate: LogicalExpression::Binary {
1734                left: Box::new(LogicalExpression::Property {
1735                    variable: "b".to_string(),
1736                    property: "name".to_string(),
1737                }),
1738                op: BinaryOp::Eq,
1739                right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
1740            },
1741            pushdown_hint: None,
1742            input: Box::new(LogicalOperator::Join(JoinOp {
1743                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1744                    variable: "a".to_string(),
1745                    label: Some("Person".to_string()),
1746                    input: None,
1747                })),
1748                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1749                    variable: "b".to_string(),
1750                    label: Some("Company".to_string()),
1751                    input: None,
1752                })),
1753                join_type: JoinType::Inner,
1754                conditions: vec![],
1755            })),
1756        }));
1757
1758        let optimized = optimizer.optimize(plan).unwrap();
1759
1760        // Filter should be pushed to right side of join
1761        if let LogicalOperator::Join(join) = &optimized.root
1762            && let LogicalOperator::Filter(_) = join.right.as_ref()
1763        {
1764            return;
1765        }
1766        panic!("Expected Join with Filter on right side");
1767    }
1768
1769    #[test]
1770    fn test_filter_not_pushed_when_uses_both_join_sides() {
1771        let optimizer = Optimizer::new();
1772
1773        // Filter using both variables should stay above join
1774        let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
1775            predicate: LogicalExpression::Binary {
1776                left: Box::new(LogicalExpression::Property {
1777                    variable: "a".to_string(),
1778                    property: "id".to_string(),
1779                }),
1780                op: BinaryOp::Eq,
1781                right: Box::new(LogicalExpression::Property {
1782                    variable: "b".to_string(),
1783                    property: "a_id".to_string(),
1784                }),
1785            },
1786            pushdown_hint: None,
1787            input: Box::new(LogicalOperator::Join(JoinOp {
1788                left: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1789                    variable: "a".to_string(),
1790                    label: None,
1791                    input: None,
1792                })),
1793                right: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1794                    variable: "b".to_string(),
1795                    label: None,
1796                    input: None,
1797                })),
1798                join_type: JoinType::Inner,
1799                conditions: vec![],
1800            })),
1801        }));
1802
1803        let optimized = optimizer.optimize(plan).unwrap();
1804
1805        // Filter should stay above join
1806        if let LogicalOperator::Filter(filter) = &optimized.root
1807            && let LogicalOperator::Join(_) = filter.input.as_ref()
1808        {
1809            return;
1810        }
1811        panic!("Expected Filter -> Join structure");
1812    }
1813
1814    // Variable extraction tests
1815
1816    #[test]
1817    fn test_extract_variables_from_variable() {
1818        let optimizer = Optimizer::new();
1819        let expr = LogicalExpression::Variable("x".to_string());
1820        let vars = optimizer.extract_variables(&expr);
1821        assert_eq!(vars.len(), 1);
1822        assert!(vars.contains("x"));
1823    }
1824
1825    #[test]
1826    fn test_extract_variables_from_unary() {
1827        let optimizer = Optimizer::new();
1828        let expr = LogicalExpression::Unary {
1829            op: UnaryOp::Not,
1830            operand: Box::new(LogicalExpression::Variable("x".to_string())),
1831        };
1832        let vars = optimizer.extract_variables(&expr);
1833        assert_eq!(vars.len(), 1);
1834        assert!(vars.contains("x"));
1835    }
1836
1837    #[test]
1838    fn test_extract_variables_from_function_call() {
1839        let optimizer = Optimizer::new();
1840        let expr = LogicalExpression::FunctionCall {
1841            name: "length".to_string(),
1842            args: vec![
1843                LogicalExpression::Variable("a".to_string()),
1844                LogicalExpression::Variable("b".to_string()),
1845            ],
1846            distinct: false,
1847        };
1848        let vars = optimizer.extract_variables(&expr);
1849        assert_eq!(vars.len(), 2);
1850        assert!(vars.contains("a"));
1851        assert!(vars.contains("b"));
1852    }
1853
1854    #[test]
1855    fn test_extract_variables_from_list() {
1856        let optimizer = Optimizer::new();
1857        let expr = LogicalExpression::List(vec![
1858            LogicalExpression::Variable("a".to_string()),
1859            LogicalExpression::Literal(Value::Int64(1)),
1860            LogicalExpression::Variable("b".to_string()),
1861        ]);
1862        let vars = optimizer.extract_variables(&expr);
1863        assert_eq!(vars.len(), 2);
1864        assert!(vars.contains("a"));
1865        assert!(vars.contains("b"));
1866    }
1867
1868    #[test]
1869    fn test_extract_variables_from_map() {
1870        let optimizer = Optimizer::new();
1871        let expr = LogicalExpression::Map(vec![
1872            (
1873                "key1".to_string(),
1874                LogicalExpression::Variable("a".to_string()),
1875            ),
1876            (
1877                "key2".to_string(),
1878                LogicalExpression::Variable("b".to_string()),
1879            ),
1880        ]);
1881        let vars = optimizer.extract_variables(&expr);
1882        assert_eq!(vars.len(), 2);
1883        assert!(vars.contains("a"));
1884        assert!(vars.contains("b"));
1885    }
1886
1887    #[test]
1888    fn test_extract_variables_from_index_access() {
1889        let optimizer = Optimizer::new();
1890        let expr = LogicalExpression::IndexAccess {
1891            base: Box::new(LogicalExpression::Variable("list".to_string())),
1892            index: Box::new(LogicalExpression::Variable("idx".to_string())),
1893        };
1894        let vars = optimizer.extract_variables(&expr);
1895        assert_eq!(vars.len(), 2);
1896        assert!(vars.contains("list"));
1897        assert!(vars.contains("idx"));
1898    }
1899
1900    #[test]
1901    fn test_extract_variables_from_slice_access() {
1902        let optimizer = Optimizer::new();
1903        let expr = LogicalExpression::SliceAccess {
1904            base: Box::new(LogicalExpression::Variable("list".to_string())),
1905            start: Some(Box::new(LogicalExpression::Variable("s".to_string()))),
1906            end: Some(Box::new(LogicalExpression::Variable("e".to_string()))),
1907        };
1908        let vars = optimizer.extract_variables(&expr);
1909        assert_eq!(vars.len(), 3);
1910        assert!(vars.contains("list"));
1911        assert!(vars.contains("s"));
1912        assert!(vars.contains("e"));
1913    }
1914
1915    #[test]
1916    fn test_extract_variables_from_case() {
1917        let optimizer = Optimizer::new();
1918        let expr = LogicalExpression::Case {
1919            operand: Some(Box::new(LogicalExpression::Variable("x".to_string()))),
1920            when_clauses: vec![(
1921                LogicalExpression::Literal(Value::Int64(1)),
1922                LogicalExpression::Variable("a".to_string()),
1923            )],
1924            else_clause: Some(Box::new(LogicalExpression::Variable("b".to_string()))),
1925        };
1926        let vars = optimizer.extract_variables(&expr);
1927        assert_eq!(vars.len(), 3);
1928        assert!(vars.contains("x"));
1929        assert!(vars.contains("a"));
1930        assert!(vars.contains("b"));
1931    }
1932
1933    #[test]
1934    fn test_extract_variables_from_labels() {
1935        let optimizer = Optimizer::new();
1936        let expr = LogicalExpression::Labels("n".to_string());
1937        let vars = optimizer.extract_variables(&expr);
1938        assert_eq!(vars.len(), 1);
1939        assert!(vars.contains("n"));
1940    }
1941
1942    #[test]
1943    fn test_extract_variables_from_type() {
1944        let optimizer = Optimizer::new();
1945        let expr = LogicalExpression::Type("e".to_string());
1946        let vars = optimizer.extract_variables(&expr);
1947        assert_eq!(vars.len(), 1);
1948        assert!(vars.contains("e"));
1949    }
1950
1951    #[test]
1952    fn test_extract_variables_from_id() {
1953        let optimizer = Optimizer::new();
1954        let expr = LogicalExpression::Id("n".to_string());
1955        let vars = optimizer.extract_variables(&expr);
1956        assert_eq!(vars.len(), 1);
1957        assert!(vars.contains("n"));
1958    }
1959
1960    #[test]
1961    fn test_extract_variables_from_list_comprehension() {
1962        let optimizer = Optimizer::new();
1963        let expr = LogicalExpression::ListComprehension {
1964            variable: "x".to_string(),
1965            list_expr: Box::new(LogicalExpression::Variable("items".to_string())),
1966            filter_expr: Some(Box::new(LogicalExpression::Variable("pred".to_string()))),
1967            map_expr: Box::new(LogicalExpression::Variable("result".to_string())),
1968        };
1969        let vars = optimizer.extract_variables(&expr);
1970        assert!(vars.contains("items"));
1971        assert!(vars.contains("pred"));
1972        assert!(vars.contains("result"));
1973    }
1974
1975    #[test]
1976    fn test_extract_variables_from_literal_and_parameter() {
1977        let optimizer = Optimizer::new();
1978
1979        let literal = LogicalExpression::Literal(Value::Int64(42));
1980        assert!(optimizer.extract_variables(&literal).is_empty());
1981
1982        let param = LogicalExpression::Parameter("p".to_string());
1983        assert!(optimizer.extract_variables(&param).is_empty());
1984    }
1985
1986    // Recursive filter pushdown tests
1987
1988    #[test]
1989    fn test_recursive_filter_pushdown_through_skip() {
1990        let optimizer = Optimizer::new();
1991
1992        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
1993            items: vec![ReturnItem {
1994                expression: LogicalExpression::Variable("n".to_string()),
1995                alias: None,
1996            }],
1997            distinct: false,
1998            input: Box::new(LogicalOperator::Filter(FilterOp {
1999                predicate: LogicalExpression::Literal(Value::Bool(true)),
2000                pushdown_hint: None,
2001                input: Box::new(LogicalOperator::Skip(SkipOp {
2002                    count: 5.into(),
2003                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2004                        variable: "n".to_string(),
2005                        label: None,
2006                        input: None,
2007                    })),
2008                })),
2009            })),
2010        }));
2011
2012        let optimized = optimizer.optimize(plan).unwrap();
2013
2014        // Verify optimization succeeded
2015        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2016    }
2017
2018    #[test]
2019    fn test_nested_filter_pushdown() {
2020        let optimizer = Optimizer::new();
2021
2022        // Multiple nested filters
2023        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2024            items: vec![ReturnItem {
2025                expression: LogicalExpression::Variable("n".to_string()),
2026                alias: None,
2027            }],
2028            distinct: false,
2029            input: Box::new(LogicalOperator::Filter(FilterOp {
2030                predicate: LogicalExpression::Binary {
2031                    left: Box::new(LogicalExpression::Property {
2032                        variable: "n".to_string(),
2033                        property: "x".to_string(),
2034                    }),
2035                    op: BinaryOp::Gt,
2036                    right: Box::new(LogicalExpression::Literal(Value::Int64(1))),
2037                },
2038                pushdown_hint: None,
2039                input: Box::new(LogicalOperator::Filter(FilterOp {
2040                    predicate: LogicalExpression::Binary {
2041                        left: Box::new(LogicalExpression::Property {
2042                            variable: "n".to_string(),
2043                            property: "y".to_string(),
2044                        }),
2045                        op: BinaryOp::Lt,
2046                        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
2047                    },
2048                    pushdown_hint: None,
2049                    input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
2050                        variable: "n".to_string(),
2051                        label: None,
2052                        input: None,
2053                    })),
2054                })),
2055            })),
2056        }));
2057
2058        let optimized = optimizer.optimize(plan).unwrap();
2059        assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
2060    }
2061
2062    #[test]
2063    fn test_cyclic_join_produces_multi_way_join() {
2064        use crate::query::plan::JoinCondition;
2065
2066        // Triangle pattern: a ⋈ b ⋈ c ⋈ a (cyclic)
2067        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2068            variable: "a".to_string(),
2069            label: Some("Person".to_string()),
2070            input: None,
2071        });
2072        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2073            variable: "b".to_string(),
2074            label: Some("Person".to_string()),
2075            input: None,
2076        });
2077        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2078            variable: "c".to_string(),
2079            label: Some("Person".to_string()),
2080            input: None,
2081        });
2082
2083        // Build: Join(Join(a, b, a=b), c, b=c) with extra condition c=a
2084        let join_ab = LogicalOperator::Join(JoinOp {
2085            left: Box::new(scan_a),
2086            right: Box::new(scan_b),
2087            join_type: JoinType::Inner,
2088            conditions: vec![JoinCondition {
2089                left: LogicalExpression::Variable("a".to_string()),
2090                right: LogicalExpression::Variable("b".to_string()),
2091            }],
2092        });
2093
2094        let join_abc = LogicalOperator::Join(JoinOp {
2095            left: Box::new(join_ab),
2096            right: Box::new(scan_c),
2097            join_type: JoinType::Inner,
2098            conditions: vec![
2099                JoinCondition {
2100                    left: LogicalExpression::Variable("b".to_string()),
2101                    right: LogicalExpression::Variable("c".to_string()),
2102                },
2103                JoinCondition {
2104                    left: LogicalExpression::Variable("c".to_string()),
2105                    right: LogicalExpression::Variable("a".to_string()),
2106                },
2107            ],
2108        });
2109
2110        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2111            items: vec![ReturnItem {
2112                expression: LogicalExpression::Variable("a".to_string()),
2113                alias: None,
2114            }],
2115            distinct: false,
2116            input: Box::new(join_abc),
2117        }));
2118
2119        let mut optimizer = Optimizer::new();
2120        optimizer
2121            .card_estimator
2122            .add_table_stats("Person", cardinality::TableStats::new(1000));
2123
2124        let optimized = optimizer.optimize(plan).unwrap();
2125
2126        // Walk the tree to find a MultiWayJoin
2127        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2128            match op {
2129                LogicalOperator::MultiWayJoin(_) => true,
2130                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2131                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2132                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2133                _ => false,
2134            }
2135        }
2136
2137        assert!(
2138            has_multi_way_join(&optimized.root),
2139            "Expected MultiWayJoin for cyclic triangle pattern"
2140        );
2141    }
2142
2143    #[test]
2144    fn test_acyclic_join_uses_binary_joins() {
2145        use crate::query::plan::JoinCondition;
2146
2147        // Chain: a ⋈ b ⋈ c (acyclic)
2148        let scan_a = LogicalOperator::NodeScan(NodeScanOp {
2149            variable: "a".to_string(),
2150            label: Some("Person".to_string()),
2151            input: None,
2152        });
2153        let scan_b = LogicalOperator::NodeScan(NodeScanOp {
2154            variable: "b".to_string(),
2155            label: Some("Person".to_string()),
2156            input: None,
2157        });
2158        let scan_c = LogicalOperator::NodeScan(NodeScanOp {
2159            variable: "c".to_string(),
2160            label: Some("Company".to_string()),
2161            input: None,
2162        });
2163
2164        let join_ab = LogicalOperator::Join(JoinOp {
2165            left: Box::new(scan_a),
2166            right: Box::new(scan_b),
2167            join_type: JoinType::Inner,
2168            conditions: vec![JoinCondition {
2169                left: LogicalExpression::Variable("a".to_string()),
2170                right: LogicalExpression::Variable("b".to_string()),
2171            }],
2172        });
2173
2174        let join_abc = LogicalOperator::Join(JoinOp {
2175            left: Box::new(join_ab),
2176            right: Box::new(scan_c),
2177            join_type: JoinType::Inner,
2178            conditions: vec![JoinCondition {
2179                left: LogicalExpression::Variable("b".to_string()),
2180                right: LogicalExpression::Variable("c".to_string()),
2181            }],
2182        });
2183
2184        let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
2185            items: vec![ReturnItem {
2186                expression: LogicalExpression::Variable("a".to_string()),
2187                alias: None,
2188            }],
2189            distinct: false,
2190            input: Box::new(join_abc),
2191        }));
2192
2193        let mut optimizer = Optimizer::new();
2194        optimizer
2195            .card_estimator
2196            .add_table_stats("Person", cardinality::TableStats::new(1000));
2197        optimizer
2198            .card_estimator
2199            .add_table_stats("Company", cardinality::TableStats::new(100));
2200
2201        let optimized = optimizer.optimize(plan).unwrap();
2202
2203        // Should NOT contain MultiWayJoin for acyclic pattern
2204        fn has_multi_way_join(op: &LogicalOperator) -> bool {
2205            match op {
2206                LogicalOperator::MultiWayJoin(_) => true,
2207                LogicalOperator::Return(ret) => has_multi_way_join(&ret.input),
2208                LogicalOperator::Filter(f) => has_multi_way_join(&f.input),
2209                LogicalOperator::Project(p) => has_multi_way_join(&p.input),
2210                LogicalOperator::Join(j) => {
2211                    has_multi_way_join(&j.left) || has_multi_way_join(&j.right)
2212                }
2213                _ => false,
2214            }
2215        }
2216
2217        assert!(
2218            !has_multi_way_join(&optimized.root),
2219            "Acyclic join should NOT produce MultiWayJoin"
2220        );
2221    }
2222}