pub mod cardinality;
pub mod cost;
pub mod join_order;
pub use cardinality::{
CardinalityEstimator, ColumnStats, EstimationLog, SelectivityConfig, TableStats,
};
pub use cost::{Cost, CostModel};
pub use join_order::{BitSet, DPccp, JoinGraph, JoinGraphBuilder, JoinPlan};
use crate::query::plan::{
FilterOp, JoinCondition, LogicalExpression, LogicalOperator, LogicalPlan, MultiWayJoinOp,
};
use grafeo_common::grafeo_debug_span;
use grafeo_common::utils::error::Result;
use std::collections::HashSet;
/// A join predicate extracted from a `JoinOp` while flattening a join tree
/// for reordering.
///
/// Holds the variable each side of the condition references together with
/// the original expressions, so the condition can be re-attached to whichever
/// join the reordered plan places it on.
#[derive(Debug, Clone)]
struct JoinInfo {
    /// Variable referenced by the left-hand expression.
    left_var: String,
    /// Variable referenced by the right-hand expression.
    right_var: String,
    /// Original left-hand expression of the join condition.
    left_expr: LogicalExpression,
    /// Original right-hand expression of the join condition.
    right_expr: LogicalExpression,
}
/// A value that some operator in the plan needs from its input, as collected
/// by projection pushdown.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum RequiredColumn {
    /// A whole variable (node, edge, path or scalar binding).
    Variable(String),
    /// A single property of a variable: `(variable, property)`.
    Property(String, String),
}
/// Rule- and cost-based optimizer for logical query plans.
///
/// Each rewrite pass can be toggled individually; join ordering and cost
/// estimates use the configured [`CostModel`] and [`CardinalityEstimator`].
pub struct Optimizer {
    /// Whether filter predicates are pushed towards the scans.
    enable_filter_pushdown: bool,
    /// Whether join trees are reordered using the cost model.
    enable_join_reorder: bool,
    /// Whether the projection-pushdown pass runs.
    enable_projection_pushdown: bool,
    /// Cost model used for costing and join ordering.
    cost_model: CostModel,
    /// Row-count estimator used for costing and join ordering.
    card_estimator: CardinalityEstimator,
}
impl Optimizer {
    /// Creates an optimizer with every rewrite pass enabled and default cost
    /// and cardinality models (no store statistics).
    #[must_use]
    pub fn new() -> Self {
        Self {
            enable_filter_pushdown: true,
            enable_join_reorder: true,
            enable_projection_pushdown: true,
            cost_model: CostModel::new(),
            card_estimator: CardinalityEstimator::new(),
        }
    }

    /// Creates an optimizer whose cost model and cardinality estimator are
    /// seeded from an LPG store's statistics.
    ///
    /// Calls `ensure_statistics_fresh` on the store before reading them.
    #[must_use]
    pub fn from_store(store: &grafeo_core::graph::lpg::LpgStore) -> Self {
        store.ensure_statistics_fresh();
        let stats = store.statistics();
        Self::from_statistics(&stats)
    }

    /// Creates an optimizer seeded from the statistics of any graph store.
    ///
    /// Unlike `from_store`, this does not ask the store to refresh its
    /// statistics first; it uses whatever `statistics()` returns.
    #[must_use]
    pub fn from_graph_store(store: &dyn grafeo_core::graph::GraphStore) -> Self {
        let stats = store.statistics();
        Self::from_statistics(&stats)
    }

    /// Creates an optimizer from RDF triple statistics.
    ///
    /// The total triple count is handed to the cost model as both the node
    /// total and the edge total.
    #[cfg(feature = "rdf")]
    #[must_use]
    pub fn from_rdf_statistics(rdf_stats: grafeo_core::statistics::RdfStatistics) -> Self {
        let total = rdf_stats.total_triples;
        let estimator = CardinalityEstimator::from_rdf_statistics(rdf_stats);
        Self {
            enable_filter_pushdown: true,
            enable_join_reorder: true,
            enable_projection_pushdown: true,
            cost_model: CostModel::new().with_graph_totals(total, total),
            card_estimator: estimator,
        }
    }

    /// Builds an optimizer from generic store statistics.
    ///
    /// The cost model receives:
    /// - an average fanout of `total_edges / total_nodes` (at least 1.0), or
    ///   10.0 when the statistics report no nodes;
    /// - per-edge-type `(avg_out_degree, avg_in_degree)` pairs;
    /// - per-label node counts;
    /// - the node and edge totals.
    #[must_use]
    fn from_statistics(stats: &grafeo_core::statistics::Statistics) -> Self {
        let estimator = CardinalityEstimator::from_statistics(stats);
        let avg_fanout = if stats.total_nodes > 0 {
            (stats.total_edges as f64 / stats.total_nodes as f64).max(1.0)
        } else {
            10.0
        };
        let edge_type_degrees: std::collections::HashMap<String, (f64, f64)> = stats
            .edge_types
            .iter()
            .map(|(name, et)| (name.clone(), (et.avg_out_degree, et.avg_in_degree)))
            .collect();
        let label_cardinalities: std::collections::HashMap<String, u64> = stats
            .labels
            .iter()
            .map(|(name, ls)| (name.clone(), ls.node_count))
            .collect();
        Self {
            enable_filter_pushdown: true,
            enable_join_reorder: true,
            enable_projection_pushdown: true,
            cost_model: CostModel::new()
                .with_avg_fanout(avg_fanout)
                .with_edge_type_degrees(edge_type_degrees)
                .with_label_cardinalities(label_cardinalities)
                .with_graph_totals(stats.total_nodes, stats.total_edges),
            card_estimator: estimator,
        }
    }

    /// Enables or disables the filter-pushdown pass.
    #[must_use]
    pub fn with_filter_pushdown(mut self, enabled: bool) -> Self {
        self.enable_filter_pushdown = enabled;
        self
    }

    /// Enables or disables the join-reordering pass.
    #[must_use]
    pub fn with_join_reorder(mut self, enabled: bool) -> Self {
        self.enable_join_reorder = enabled;
        self
    }

    /// Enables or disables the projection-pushdown pass.
    #[must_use]
    pub fn with_projection_pushdown(mut self, enabled: bool) -> Self {
        self.enable_projection_pushdown = enabled;
        self
    }

    /// Replaces the cost model.
    #[must_use]
    pub fn with_cost_model(mut self, cost_model: CostModel) -> Self {
        self.cost_model = cost_model;
        self
    }

    /// Replaces the cardinality estimator.
    #[must_use]
    pub fn with_cardinality_estimator(mut self, estimator: CardinalityEstimator) -> Self {
        self.card_estimator = estimator;
        self
    }

    /// Replaces the cardinality estimator with a fresh one built from `config`.
    ///
    /// This swaps out the whole estimator, so an estimator built from store
    /// statistics (e.g. by `from_store`) is discarded.
    #[must_use]
    pub fn with_selectivity_config(mut self, config: SelectivityConfig) -> Self {
        self.card_estimator = CardinalityEstimator::with_selectivity_config(config);
        self
    }

    /// Returns the cost model used for costing and join ordering.
    pub fn cost_model(&self) -> &CostModel {
        &self.cost_model
    }

    /// Returns the cardinality estimator.
    pub fn cardinality_estimator(&self) -> &CardinalityEstimator {
        &self.card_estimator
    }

    /// Estimates the cost of the whole operator tree of `plan`.
    pub fn estimate_cost(&self, plan: &LogicalPlan) -> Cost {
        self.cost_model
            .estimate_tree(&plan.root, &self.card_estimator)
    }

    /// Estimates the number of rows produced by `plan`'s root operator.
    pub fn estimate_cardinality(&self, plan: &LogicalPlan) -> f64 {
        self.card_estimator.estimate(&plan.root)
    }

    /// Optimizes a logical plan.
    ///
    /// Runs the enabled passes in order: filter pushdown, join reordering,
    /// projection pushdown. The plan's `explain`, `profile` and
    /// `default_params` settings are carried over unchanged.
    ///
    /// # Errors
    ///
    /// The current passes are infallible, so this always returns `Ok`.
    pub fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
        let _span = grafeo_debug_span!("grafeo::query::optimize");
        let mut root = plan.root;
        if self.enable_filter_pushdown {
            root = self.push_filters_down(root);
        }
        if self.enable_join_reorder {
            root = self.reorder_joins(root);
        }
        if self.enable_projection_pushdown {
            root = self.push_projections_down(root);
        }
        Ok(LogicalPlan {
            root,
            explain: plan.explain,
            profile: plan.profile,
            default_params: plan.default_params,
        })
    }

    /// Projection-pushdown pass: collects every column the plan references,
    /// then walks the tree with that set.
    fn push_projections_down(&self, op: LogicalOperator) -> LogicalOperator {
        let required = self.collect_required_columns(&op);
        self.push_projections_recursive(op, &required)
    }

    /// Collects every variable/property referenced anywhere in `op`'s tree.
    fn collect_required_columns(&self, op: &LogicalOperator) -> HashSet<RequiredColumn> {
        let mut required = HashSet::new();
        Self::collect_required_recursive(op, &mut required);
        required
    }

    /// Adds the columns referenced by `op` and all of its descendants to
    /// `required`. Operators not listed here contribute nothing.
    fn collect_required_recursive(op: &LogicalOperator, required: &mut HashSet<RequiredColumn>) {
        match op {
            LogicalOperator::Return(ret) => {
                for item in &ret.items {
                    Self::collect_from_expression(&item.expression, required);
                }
                Self::collect_required_recursive(&ret.input, required);
            }
            LogicalOperator::Project(proj) => {
                for p in &proj.projections {
                    Self::collect_from_expression(&p.expression, required);
                }
                Self::collect_required_recursive(&proj.input, required);
            }
            LogicalOperator::Filter(filter) => {
                Self::collect_from_expression(&filter.predicate, required);
                Self::collect_required_recursive(&filter.input, required);
            }
            LogicalOperator::Sort(sort) => {
                for key in &sort.keys {
                    Self::collect_from_expression(&key.expression, required);
                }
                Self::collect_required_recursive(&sort.input, required);
            }
            LogicalOperator::Aggregate(agg) => {
                for expr in &agg.group_by {
                    Self::collect_from_expression(expr, required);
                }
                for agg_expr in &agg.aggregates {
                    if let Some(ref expr) = agg_expr.expression {
                        Self::collect_from_expression(expr, required);
                    }
                }
                if let Some(ref having) = agg.having {
                    Self::collect_from_expression(having, required);
                }
                Self::collect_required_recursive(&agg.input, required);
            }
            LogicalOperator::Join(join) => {
                for cond in &join.conditions {
                    Self::collect_from_expression(&cond.left, required);
                    Self::collect_from_expression(&cond.right, required);
                }
                Self::collect_required_recursive(&join.left, required);
                Self::collect_required_recursive(&join.right, required);
            }
            LogicalOperator::Expand(expand) => {
                required.insert(RequiredColumn::Variable(expand.from_variable.clone()));
                required.insert(RequiredColumn::Variable(expand.to_variable.clone()));
                if let Some(ref edge_var) = expand.edge_variable {
                    required.insert(RequiredColumn::Variable(edge_var.clone()));
                }
                Self::collect_required_recursive(&expand.input, required);
            }
            LogicalOperator::Limit(limit) => {
                Self::collect_required_recursive(&limit.input, required);
            }
            LogicalOperator::Skip(skip) => {
                Self::collect_required_recursive(&skip.input, required);
            }
            LogicalOperator::Distinct(distinct) => {
                Self::collect_required_recursive(&distinct.input, required);
            }
            LogicalOperator::NodeScan(scan) => {
                required.insert(RequiredColumn::Variable(scan.variable.clone()));
            }
            LogicalOperator::EdgeScan(scan) => {
                required.insert(RequiredColumn::Variable(scan.variable.clone()));
            }
            LogicalOperator::MultiWayJoin(mwj) => {
                for cond in &mwj.conditions {
                    Self::collect_from_expression(&cond.left, required);
                    Self::collect_from_expression(&cond.right, required);
                }
                for input in &mwj.inputs {
                    Self::collect_required_recursive(input, required);
                }
            }
            _ => {}
        }
    }

    /// Adds every variable and `variable.property` referenced by `expr` to
    /// `required`.
    ///
    /// The match is exhaustive (mirroring `collect_variables`) so that new
    /// expression kinds must be handled explicitly rather than silently
    /// contributing nothing. Subquery bodies are not inspected.
    fn collect_from_expression(expr: &LogicalExpression, required: &mut HashSet<RequiredColumn>) {
        match expr {
            LogicalExpression::Variable(var) => {
                required.insert(RequiredColumn::Variable(var.clone()));
            }
            LogicalExpression::Property { variable, property } => {
                required.insert(RequiredColumn::Property(variable.clone(), property.clone()));
                required.insert(RequiredColumn::Variable(variable.clone()));
            }
            LogicalExpression::Binary { left, right, .. } => {
                Self::collect_from_expression(left, required);
                Self::collect_from_expression(right, required);
            }
            LogicalExpression::Unary { operand, .. } => {
                Self::collect_from_expression(operand, required);
            }
            LogicalExpression::FunctionCall { args, .. } => {
                for arg in args {
                    Self::collect_from_expression(arg, required);
                }
            }
            LogicalExpression::List(items) => {
                for item in items {
                    Self::collect_from_expression(item, required);
                }
            }
            LogicalExpression::Map(pairs) => {
                for (_, value) in pairs {
                    Self::collect_from_expression(value, required);
                }
            }
            LogicalExpression::IndexAccess { base, index } => {
                Self::collect_from_expression(base, required);
                Self::collect_from_expression(index, required);
            }
            LogicalExpression::SliceAccess { base, start, end } => {
                Self::collect_from_expression(base, required);
                if let Some(s) = start {
                    Self::collect_from_expression(s, required);
                }
                if let Some(e) = end {
                    Self::collect_from_expression(e, required);
                }
            }
            LogicalExpression::Case {
                operand,
                when_clauses,
                else_clause,
            } => {
                if let Some(op) = operand {
                    Self::collect_from_expression(op, required);
                }
                for (cond, result) in when_clauses {
                    Self::collect_from_expression(cond, required);
                    Self::collect_from_expression(result, required);
                }
                if let Some(else_expr) = else_clause {
                    Self::collect_from_expression(else_expr, required);
                }
            }
            LogicalExpression::Labels(var)
            | LogicalExpression::Type(var)
            | LogicalExpression::Id(var) => {
                required.insert(RequiredColumn::Variable(var.clone()));
            }
            LogicalExpression::ListComprehension {
                list_expr,
                filter_expr,
                map_expr,
                ..
            } => {
                Self::collect_from_expression(list_expr, required);
                if let Some(filter) = filter_expr {
                    Self::collect_from_expression(filter, required);
                }
                Self::collect_from_expression(map_expr, required);
            }
            LogicalExpression::ListPredicate {
                list_expr,
                predicate,
                ..
            } => {
                Self::collect_from_expression(list_expr, required);
                Self::collect_from_expression(predicate, required);
            }
            LogicalExpression::PatternComprehension { projection, .. } => {
                Self::collect_from_expression(projection, required);
            }
            LogicalExpression::MapProjection { base, entries } => {
                required.insert(RequiredColumn::Variable(base.clone()));
                for entry in entries {
                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
                        Self::collect_from_expression(expr, required);
                    }
                }
            }
            LogicalExpression::Reduce {
                initial,
                list,
                expression,
                ..
            } => {
                Self::collect_from_expression(initial, required);
                Self::collect_from_expression(list, required);
                Self::collect_from_expression(expression, required);
            }
            LogicalExpression::Literal(_)
            | LogicalExpression::Parameter(_)
            | LogicalExpression::ExistsSubquery(_)
            | LogicalExpression::CountSubquery(_)
            | LogicalExpression::ValueSubquery(_) => {}
        }
    }

    /// Walks the tree carrying the set of required columns.
    ///
    /// At joins the set is split by which side outputs each variable. The
    /// current implementation only rebuilds the tree; it does not insert any
    /// pruning projections yet.
    fn push_projections_recursive(
        &self,
        op: LogicalOperator,
        required: &HashSet<RequiredColumn>,
    ) -> LogicalOperator {
        match op {
            LogicalOperator::Return(mut ret) => {
                ret.input = Box::new(self.push_projections_recursive(*ret.input, required));
                LogicalOperator::Return(ret)
            }
            LogicalOperator::Project(mut proj) => {
                proj.input = Box::new(self.push_projections_recursive(*proj.input, required));
                LogicalOperator::Project(proj)
            }
            LogicalOperator::Filter(mut filter) => {
                filter.input = Box::new(self.push_projections_recursive(*filter.input, required));
                LogicalOperator::Filter(filter)
            }
            LogicalOperator::Sort(mut sort) => {
                sort.input = Box::new(self.push_projections_recursive(*sort.input, required));
                LogicalOperator::Sort(sort)
            }
            LogicalOperator::Aggregate(mut agg) => {
                agg.input = Box::new(self.push_projections_recursive(*agg.input, required));
                LogicalOperator::Aggregate(agg)
            }
            LogicalOperator::Join(mut join) => {
                let left_vars = self.collect_output_variables(&join.left);
                let right_vars = self.collect_output_variables(&join.right);
                let left_required: HashSet<_> = required
                    .iter()
                    .filter(|c| match c {
                        RequiredColumn::Variable(v) => left_vars.contains(v),
                        RequiredColumn::Property(v, _) => left_vars.contains(v),
                    })
                    .cloned()
                    .collect();
                let right_required: HashSet<_> = required
                    .iter()
                    .filter(|c| match c {
                        RequiredColumn::Variable(v) => right_vars.contains(v),
                        RequiredColumn::Property(v, _) => right_vars.contains(v),
                    })
                    .cloned()
                    .collect();
                join.left = Box::new(self.push_projections_recursive(*join.left, &left_required));
                join.right =
                    Box::new(self.push_projections_recursive(*join.right, &right_required));
                LogicalOperator::Join(join)
            }
            LogicalOperator::Expand(mut expand) => {
                expand.input = Box::new(self.push_projections_recursive(*expand.input, required));
                LogicalOperator::Expand(expand)
            }
            LogicalOperator::Limit(mut limit) => {
                limit.input = Box::new(self.push_projections_recursive(*limit.input, required));
                LogicalOperator::Limit(limit)
            }
            LogicalOperator::Skip(mut skip) => {
                skip.input = Box::new(self.push_projections_recursive(*skip.input, required));
                LogicalOperator::Skip(skip)
            }
            LogicalOperator::Distinct(mut distinct) => {
                distinct.input =
                    Box::new(self.push_projections_recursive(*distinct.input, required));
                LogicalOperator::Distinct(distinct)
            }
            LogicalOperator::MapCollect(mut mc) => {
                mc.input = Box::new(self.push_projections_recursive(*mc.input, required));
                LogicalOperator::MapCollect(mc)
            }
            LogicalOperator::MultiWayJoin(mut mwj) => {
                mwj.inputs = mwj
                    .inputs
                    .into_iter()
                    .map(|input| self.push_projections_recursive(input, required))
                    .collect();
                LogicalOperator::MultiWayJoin(mwj)
            }
            other => other,
        }
    }

    /// Join-reordering pass.
    ///
    /// Reorders joins nested below `op` first, then, if `op` itself roots a
    /// flattenable join tree with at least two relations, replaces it with
    /// the cost-based ordering.
    fn reorder_joins(&self, op: LogicalOperator) -> LogicalOperator {
        let op = self.reorder_joins_recursive(op);
        if let Some((relations, conditions)) = self.extract_join_tree(&op)
            && relations.len() >= 2
            && let Some(optimized) = self.optimize_join_order(&relations, &conditions)
        {
            return optimized;
        }
        op
    }

    /// Applies `reorder_joins` to the inputs of unary operators and of
    /// multi-way joins. Binary `Join` nodes are left to `reorder_joins`, which
    /// flattens them as a whole.
    fn reorder_joins_recursive(&self, op: LogicalOperator) -> LogicalOperator {
        match op {
            LogicalOperator::Return(mut ret) => {
                ret.input = Box::new(self.reorder_joins(*ret.input));
                LogicalOperator::Return(ret)
            }
            LogicalOperator::Project(mut proj) => {
                proj.input = Box::new(self.reorder_joins(*proj.input));
                LogicalOperator::Project(proj)
            }
            LogicalOperator::Filter(mut filter) => {
                filter.input = Box::new(self.reorder_joins(*filter.input));
                LogicalOperator::Filter(filter)
            }
            LogicalOperator::Limit(mut limit) => {
                limit.input = Box::new(self.reorder_joins(*limit.input));
                LogicalOperator::Limit(limit)
            }
            LogicalOperator::Skip(mut skip) => {
                skip.input = Box::new(self.reorder_joins(*skip.input));
                LogicalOperator::Skip(skip)
            }
            LogicalOperator::Sort(mut sort) => {
                sort.input = Box::new(self.reorder_joins(*sort.input));
                LogicalOperator::Sort(sort)
            }
            LogicalOperator::Distinct(mut distinct) => {
                distinct.input = Box::new(self.reorder_joins(*distinct.input));
                LogicalOperator::Distinct(distinct)
            }
            LogicalOperator::Aggregate(mut agg) => {
                agg.input = Box::new(self.reorder_joins(*agg.input));
                LogicalOperator::Aggregate(agg)
            }
            LogicalOperator::Expand(mut expand) => {
                expand.input = Box::new(self.reorder_joins(*expand.input));
                LogicalOperator::Expand(expand)
            }
            LogicalOperator::MapCollect(mut mc) => {
                mc.input = Box::new(self.reorder_joins(*mc.input));
                LogicalOperator::MapCollect(mc)
            }
            LogicalOperator::MultiWayJoin(mut mwj) => {
                mwj.inputs = mwj
                    .inputs
                    .into_iter()
                    .map(|input| self.reorder_joins(input))
                    .collect();
                LogicalOperator::MultiWayJoin(mwj)
            }
            other => other,
        }
    }

    /// Flattens the join tree rooted at `op` into relations and conditions.
    ///
    /// Returns `None` when the tree cannot be flattened without losing
    /// information, or when it has fewer than two relations.
    fn extract_join_tree(
        &self,
        op: &LogicalOperator,
    ) -> Option<(Vec<(String, LogicalOperator)>, Vec<JoinInfo>)> {
        let mut relations = Vec::new();
        let mut join_conditions = Vec::new();
        if !self.collect_join_tree(op, &mut relations, &mut join_conditions) {
            return None;
        }
        if relations.len() < 2 {
            return None;
        }
        Some((relations, join_conditions))
    }

    /// Recursively collects the relations and join conditions of a join tree.
    ///
    /// Returns `false` when the tree contains something the reordering
    /// rebuild could not reproduce. In that case the partially filled
    /// vectors must be discarded. This covers:
    /// - a non-relation operator (including a filter sitting above a join,
    ///   whose predicate would otherwise be dropped);
    /// - a join condition whose sides are not plain variable/property
    ///   references.
    ///
    /// Filters stacked directly on a scan or expand are kept as part of that
    /// relation so their predicates survive reordering.
    ///
    /// NOTE(review): the join type is not inspected; if outer joins are
    /// represented by `JoinOp`, reordering them like inner joins would change
    /// results — confirm.
    fn collect_join_tree(
        &self,
        op: &LogicalOperator,
        relations: &mut Vec<(String, LogicalOperator)>,
        conditions: &mut Vec<JoinInfo>,
    ) -> bool {
        let LogicalOperator::Join(join) = op else {
            return match Self::relation_variable(op) {
                Some(var) => {
                    relations.push((var, op.clone()));
                    true
                }
                None => false,
            };
        };
        if !self.collect_join_tree(&join.left, relations, conditions)
            || !self.collect_join_tree(&join.right, relations, conditions)
        {
            return false;
        }
        for cond in &join.conditions {
            let (Some(left_var), Some(right_var)) = (
                self.extract_variable_from_expr(&cond.left),
                self.extract_variable_from_expr(&cond.right),
            ) else {
                // The rebuilt plan only carries conditions recorded here, so an
                // unattributable condition would silently disappear.
                return false;
            };
            conditions.push(JoinInfo {
                left_var,
                right_var,
                left_expr: cond.left.clone(),
                right_expr: cond.right.clone(),
            });
        }
        true
    }

    /// Returns the variable identifying `op` as a join-tree leaf relation.
    ///
    /// Scans are keyed by their variable and expands by their target
    /// variable. A filter (or stack of filters) is a relation when its input
    /// is one. Any other operator is not a relation.
    fn relation_variable(op: &LogicalOperator) -> Option<String> {
        match op {
            LogicalOperator::NodeScan(scan) => Some(scan.variable.clone()),
            LogicalOperator::EdgeScan(scan) => Some(scan.variable.clone()),
            LogicalOperator::Expand(expand) => Some(expand.to_variable.clone()),
            LogicalOperator::Filter(filter) => Self::relation_variable(&filter.input),
            _ => None,
        }
    }

    /// Returns the variable referenced by a `var` or `var.prop` expression.
    fn extract_variable_from_expr(&self, expr: &LogicalExpression) -> Option<String> {
        match expr {
            LogicalExpression::Variable(v) => Some(v.clone()),
            LogicalExpression::Property { variable, .. } => Some(variable.clone()),
            _ => None,
        }
    }

    /// Computes a new operator for the flattened join tree.
    ///
    /// Cyclic join graphs with at least three relations become a single
    /// `MultiWayJoin`. Its shared variables are those that appear in two or
    /// more conditions; they are sorted so the plan is deterministic.
    /// Otherwise the DPccp enumerator picks the cheapest plan. Returns `None`
    /// when DPccp finds no plan.
    fn optimize_join_order(
        &self,
        relations: &[(String, LogicalOperator)],
        conditions: &[JoinInfo],
    ) -> Option<LogicalOperator> {
        use join_order::{DPccp, JoinGraphBuilder};
        let mut builder = JoinGraphBuilder::new();
        for (var, relation) in relations {
            builder.add_relation(var, relation.clone());
        }
        for cond in conditions {
            builder.add_join_condition(
                &cond.left_var,
                &cond.right_var,
                cond.left_expr.clone(),
                cond.right_expr.clone(),
            );
        }
        let graph = builder.build();
        if graph.is_cyclic() && relations.len() >= 3 {
            let mut var_counts: std::collections::HashMap<&str, usize> =
                std::collections::HashMap::new();
            for cond in conditions {
                *var_counts.entry(&cond.left_var).or_default() += 1;
                *var_counts.entry(&cond.right_var).or_default() += 1;
            }
            let mut shared_variables: Vec<String> = var_counts
                .into_iter()
                .filter(|(_, count)| *count >= 2)
                .map(|(var, _)| var.to_string())
                .collect();
            // HashMap iteration order is unspecified; sort for stable plans.
            shared_variables.sort_unstable();
            let join_conditions: Vec<JoinCondition> = conditions
                .iter()
                .map(|c| JoinCondition {
                    left: c.left_expr.clone(),
                    right: c.right_expr.clone(),
                })
                .collect();
            return Some(LogicalOperator::MultiWayJoin(MultiWayJoinOp {
                inputs: relations.iter().map(|(_, rel)| rel.clone()).collect(),
                conditions: join_conditions,
                shared_variables,
            }));
        }
        let mut dpccp = DPccp::new(&graph, &self.cost_model, &self.card_estimator);
        let plan = dpccp.optimize()?;
        Some(plan.operator)
    }

    /// Filter-pushdown pass.
    ///
    /// Optimizes each subtree bottom-up. Every `Filter` has its input
    /// optimized first, then its predicate is pushed as far down as
    /// `try_push_filter_into` allows.
    fn push_filters_down(&self, op: LogicalOperator) -> LogicalOperator {
        match op {
            LogicalOperator::Filter(filter) => {
                let optimized_input = self.push_filters_down(*filter.input);
                self.try_push_filter_into(filter.predicate, optimized_input)
            }
            LogicalOperator::Return(mut ret) => {
                ret.input = Box::new(self.push_filters_down(*ret.input));
                LogicalOperator::Return(ret)
            }
            LogicalOperator::Project(mut proj) => {
                proj.input = Box::new(self.push_filters_down(*proj.input));
                LogicalOperator::Project(proj)
            }
            LogicalOperator::Limit(mut limit) => {
                limit.input = Box::new(self.push_filters_down(*limit.input));
                LogicalOperator::Limit(limit)
            }
            LogicalOperator::Skip(mut skip) => {
                skip.input = Box::new(self.push_filters_down(*skip.input));
                LogicalOperator::Skip(skip)
            }
            LogicalOperator::Sort(mut sort) => {
                sort.input = Box::new(self.push_filters_down(*sort.input));
                LogicalOperator::Sort(sort)
            }
            LogicalOperator::Distinct(mut distinct) => {
                distinct.input = Box::new(self.push_filters_down(*distinct.input));
                LogicalOperator::Distinct(distinct)
            }
            LogicalOperator::Expand(mut expand) => {
                expand.input = Box::new(self.push_filters_down(*expand.input));
                LogicalOperator::Expand(expand)
            }
            LogicalOperator::Join(mut join) => {
                join.left = Box::new(self.push_filters_down(*join.left));
                join.right = Box::new(self.push_filters_down(*join.right));
                LogicalOperator::Join(join)
            }
            LogicalOperator::Aggregate(mut agg) => {
                agg.input = Box::new(self.push_filters_down(*agg.input));
                LogicalOperator::Aggregate(agg)
            }
            LogicalOperator::MapCollect(mut mc) => {
                mc.input = Box::new(self.push_filters_down(*mc.input));
                LogicalOperator::MapCollect(mc)
            }
            LogicalOperator::MultiWayJoin(mut mwj) => {
                mwj.inputs = mwj
                    .inputs
                    .into_iter()
                    .map(|input| self.push_filters_down(input))
                    .collect();
                LogicalOperator::MultiWayJoin(mwj)
            }
            other => other,
        }
    }

    /// Wraps `input` in a `Filter` with no pushdown hint.
    fn filter_over(predicate: LogicalExpression, input: LogicalOperator) -> LogicalOperator {
        LogicalOperator::Filter(FilterOp {
            predicate,
            pushdown_hint: None,
            input: Box::new(input),
        })
    }

    /// Places `predicate` as low as possible inside `op`.
    ///
    /// Push-through rules:
    /// - `Return`: always.
    /// - `Project`: when the predicate uses none of the projection aliases.
    /// - `Expand`: when the predicate uses none of the variables the expand
    ///   introduces (target, edge, path alias).
    /// - `Join`: into the single side whose output variables the predicate
    ///   uses.
    ///
    /// Anything else — including aggregates, whose outputs a predicate may
    /// refer to — gets the filter placed directly on top.
    ///
    /// Predicates containing subqueries or pattern comprehensions are never
    /// moved. Their variable references cannot be fully seen by
    /// `collect_variables`, so pushing them could detach them from the
    /// operator that binds those variables.
    fn try_push_filter_into(
        &self,
        predicate: LogicalExpression,
        op: LogicalOperator,
    ) -> LogicalOperator {
        if Self::has_hidden_variable_references(&predicate) {
            return Self::filter_over(predicate, op);
        }
        match op {
            LogicalOperator::Project(mut proj) => {
                let predicate_vars = self.extract_variables(&predicate);
                let computed_vars = self.extract_projection_aliases(&proj.projections);
                if predicate_vars.is_disjoint(&computed_vars) {
                    proj.input = Box::new(self.try_push_filter_into(predicate, *proj.input));
                    LogicalOperator::Project(proj)
                } else {
                    Self::filter_over(predicate, LogicalOperator::Project(proj))
                }
            }
            LogicalOperator::Return(mut ret) => {
                ret.input = Box::new(self.try_push_filter_into(predicate, *ret.input));
                LogicalOperator::Return(ret)
            }
            LogicalOperator::Expand(mut expand) => {
                let predicate_vars = self.extract_variables(&predicate);
                let mut introduced_vars = vec![&expand.to_variable];
                if let Some(ref edge_var) = expand.edge_variable {
                    introduced_vars.push(edge_var);
                }
                if let Some(ref path_alias) = expand.path_alias {
                    introduced_vars.push(path_alias);
                }
                let uses_introduced_vars =
                    predicate_vars.iter().any(|v| introduced_vars.contains(&v));
                if uses_introduced_vars {
                    Self::filter_over(predicate, LogicalOperator::Expand(expand))
                } else {
                    expand.input = Box::new(self.try_push_filter_into(predicate, *expand.input));
                    LogicalOperator::Expand(expand)
                }
            }
            // NOTE(review): the join type is not inspected; pushing into the
            // nullable side of an outer join would change results if outer
            // joins are represented by `JoinOp` — confirm.
            LogicalOperator::Join(mut join) => {
                let predicate_vars = self.extract_variables(&predicate);
                let left_vars = self.collect_output_variables(&join.left);
                let right_vars = self.collect_output_variables(&join.right);
                let uses_left = predicate_vars.iter().any(|v| left_vars.contains(v));
                let uses_right = predicate_vars.iter().any(|v| right_vars.contains(v));
                if uses_left && !uses_right {
                    join.left = Box::new(self.try_push_filter_into(predicate, *join.left));
                    LogicalOperator::Join(join)
                } else if uses_right && !uses_left {
                    join.right = Box::new(self.try_push_filter_into(predicate, *join.right));
                    LogicalOperator::Join(join)
                } else {
                    Self::filter_over(predicate, LogicalOperator::Join(join))
                }
            }
            other => Self::filter_over(predicate, other),
        }
    }

    /// Returns the set of variables bound in the output of `op`.
    fn collect_output_variables(&self, op: &LogicalOperator) -> HashSet<String> {
        let mut vars = HashSet::new();
        Self::collect_output_variables_recursive(op, &mut vars);
        vars
    }

    /// Adds the variables produced by `op` (and the operators feeding it) to
    /// `vars`. Operators not listed here contribute nothing.
    fn collect_output_variables_recursive(op: &LogicalOperator, vars: &mut HashSet<String>) {
        match op {
            LogicalOperator::NodeScan(scan) => {
                vars.insert(scan.variable.clone());
            }
            LogicalOperator::EdgeScan(scan) => {
                vars.insert(scan.variable.clone());
            }
            LogicalOperator::Expand(expand) => {
                vars.insert(expand.to_variable.clone());
                if let Some(edge_var) = &expand.edge_variable {
                    vars.insert(edge_var.clone());
                }
                // The path alias is bound by the expand as well; filter
                // pushdown treats it as introduced here.
                if let Some(path_alias) = &expand.path_alias {
                    vars.insert(path_alias.clone());
                }
                Self::collect_output_variables_recursive(&expand.input, vars);
            }
            LogicalOperator::Filter(filter) => {
                Self::collect_output_variables_recursive(&filter.input, vars);
            }
            LogicalOperator::Project(proj) => {
                for p in &proj.projections {
                    if let Some(alias) = &p.alias {
                        vars.insert(alias.clone());
                    }
                }
                Self::collect_output_variables_recursive(&proj.input, vars);
            }
            LogicalOperator::Join(join) => {
                Self::collect_output_variables_recursive(&join.left, vars);
                Self::collect_output_variables_recursive(&join.right, vars);
            }
            LogicalOperator::MultiWayJoin(mwj) => {
                for input in &mwj.inputs {
                    Self::collect_output_variables_recursive(input, vars);
                }
            }
            LogicalOperator::Aggregate(agg) => {
                for expr in &agg.group_by {
                    Self::collect_variables(expr, vars);
                }
                for agg_expr in &agg.aggregates {
                    if let Some(alias) = &agg_expr.alias {
                        vars.insert(alias.clone());
                    }
                }
            }
            LogicalOperator::Return(ret) => {
                Self::collect_output_variables_recursive(&ret.input, vars);
            }
            LogicalOperator::Limit(limit) => {
                Self::collect_output_variables_recursive(&limit.input, vars);
            }
            LogicalOperator::Skip(skip) => {
                Self::collect_output_variables_recursive(&skip.input, vars);
            }
            LogicalOperator::Sort(sort) => {
                Self::collect_output_variables_recursive(&sort.input, vars);
            }
            LogicalOperator::Distinct(distinct) => {
                Self::collect_output_variables_recursive(&distinct.input, vars);
            }
            _ => {}
        }
    }

    /// Returns the set of variable names referenced by `expr`.
    fn extract_variables(&self, expr: &LogicalExpression) -> HashSet<String> {
        let mut vars = HashSet::new();
        Self::collect_variables(expr, &mut vars);
        vars
    }

    /// Adds the variable names referenced by `expr` to `vars`.
    ///
    /// Subquery bodies are not inspected, and for pattern comprehensions only
    /// the projection is visited. `has_hidden_variable_references` reports
    /// when that makes the result incomplete.
    fn collect_variables(expr: &LogicalExpression, vars: &mut HashSet<String>) {
        match expr {
            LogicalExpression::Variable(name) => {
                vars.insert(name.clone());
            }
            LogicalExpression::Property { variable, .. } => {
                vars.insert(variable.clone());
            }
            LogicalExpression::Binary { left, right, .. } => {
                Self::collect_variables(left, vars);
                Self::collect_variables(right, vars);
            }
            LogicalExpression::Unary { operand, .. } => {
                Self::collect_variables(operand, vars);
            }
            LogicalExpression::FunctionCall { args, .. } => {
                for arg in args {
                    Self::collect_variables(arg, vars);
                }
            }
            LogicalExpression::List(items) => {
                for item in items {
                    Self::collect_variables(item, vars);
                }
            }
            LogicalExpression::Map(pairs) => {
                for (_, value) in pairs {
                    Self::collect_variables(value, vars);
                }
            }
            LogicalExpression::IndexAccess { base, index } => {
                Self::collect_variables(base, vars);
                Self::collect_variables(index, vars);
            }
            LogicalExpression::SliceAccess { base, start, end } => {
                Self::collect_variables(base, vars);
                if let Some(s) = start {
                    Self::collect_variables(s, vars);
                }
                if let Some(e) = end {
                    Self::collect_variables(e, vars);
                }
            }
            LogicalExpression::Case {
                operand,
                when_clauses,
                else_clause,
            } => {
                if let Some(op) = operand {
                    Self::collect_variables(op, vars);
                }
                for (cond, result) in when_clauses {
                    Self::collect_variables(cond, vars);
                    Self::collect_variables(result, vars);
                }
                if let Some(else_expr) = else_clause {
                    Self::collect_variables(else_expr, vars);
                }
            }
            LogicalExpression::Labels(var)
            | LogicalExpression::Type(var)
            | LogicalExpression::Id(var) => {
                vars.insert(var.clone());
            }
            LogicalExpression::Literal(_) | LogicalExpression::Parameter(_) => {}
            LogicalExpression::ListComprehension {
                list_expr,
                filter_expr,
                map_expr,
                ..
            } => {
                Self::collect_variables(list_expr, vars);
                if let Some(filter) = filter_expr {
                    Self::collect_variables(filter, vars);
                }
                Self::collect_variables(map_expr, vars);
            }
            LogicalExpression::ListPredicate {
                list_expr,
                predicate,
                ..
            } => {
                Self::collect_variables(list_expr, vars);
                Self::collect_variables(predicate, vars);
            }
            LogicalExpression::ExistsSubquery(_)
            | LogicalExpression::CountSubquery(_)
            | LogicalExpression::ValueSubquery(_) => {}
            LogicalExpression::PatternComprehension { projection, .. } => {
                Self::collect_variables(projection, vars);
            }
            LogicalExpression::MapProjection { base, entries } => {
                vars.insert(base.clone());
                for entry in entries {
                    if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
                        Self::collect_variables(expr, vars);
                    }
                }
            }
            LogicalExpression::Reduce {
                initial,
                list,
                expression,
                ..
            } => {
                Self::collect_variables(initial, vars);
                Self::collect_variables(list, vars);
                Self::collect_variables(expression, vars);
            }
        }
    }

    /// Returns `true` if `expr` contains a subquery or a pattern
    /// comprehension anywhere in its tree.
    ///
    /// `collect_variables` skips subquery bodies and pattern parts, so the
    /// variable set it reports for such an expression may be incomplete.
    fn has_hidden_variable_references(expr: &LogicalExpression) -> bool {
        match expr {
            LogicalExpression::ExistsSubquery(_)
            | LogicalExpression::CountSubquery(_)
            | LogicalExpression::ValueSubquery(_)
            | LogicalExpression::PatternComprehension { .. } => true,
            LogicalExpression::Variable(_)
            | LogicalExpression::Property { .. }
            | LogicalExpression::Labels(_)
            | LogicalExpression::Type(_)
            | LogicalExpression::Id(_)
            | LogicalExpression::Literal(_)
            | LogicalExpression::Parameter(_) => false,
            LogicalExpression::Binary { left, right, .. } => {
                Self::has_hidden_variable_references(left)
                    || Self::has_hidden_variable_references(right)
            }
            LogicalExpression::Unary { operand, .. } => {
                Self::has_hidden_variable_references(operand)
            }
            LogicalExpression::FunctionCall { args, .. } => args
                .iter()
                .any(|arg| Self::has_hidden_variable_references(arg)),
            LogicalExpression::List(items) => items
                .iter()
                .any(|item| Self::has_hidden_variable_references(item)),
            LogicalExpression::Map(pairs) => pairs
                .iter()
                .any(|(_, value)| Self::has_hidden_variable_references(value)),
            LogicalExpression::IndexAccess { base, index } => {
                Self::has_hidden_variable_references(base)
                    || Self::has_hidden_variable_references(index)
            }
            LogicalExpression::SliceAccess { base, start, end } => {
                Self::has_hidden_variable_references(base)
                    || start
                        .iter()
                        .any(|s| Self::has_hidden_variable_references(s))
                    || end
                        .iter()
                        .any(|e| Self::has_hidden_variable_references(e))
            }
            LogicalExpression::Case {
                operand,
                when_clauses,
                else_clause,
            } => {
                operand
                    .iter()
                    .any(|op| Self::has_hidden_variable_references(op))
                    || when_clauses.iter().any(|(cond, result)| {
                        Self::has_hidden_variable_references(cond)
                            || Self::has_hidden_variable_references(result)
                    })
                    || else_clause
                        .iter()
                        .any(|e| Self::has_hidden_variable_references(e))
            }
            LogicalExpression::ListComprehension {
                list_expr,
                filter_expr,
                map_expr,
                ..
            } => {
                Self::has_hidden_variable_references(list_expr)
                    || filter_expr
                        .iter()
                        .any(|f| Self::has_hidden_variable_references(f))
                    || Self::has_hidden_variable_references(map_expr)
            }
            LogicalExpression::ListPredicate {
                list_expr,
                predicate,
                ..
            } => {
                Self::has_hidden_variable_references(list_expr)
                    || Self::has_hidden_variable_references(predicate)
            }
            LogicalExpression::MapProjection { entries, .. } => entries.iter().any(|entry| {
                if let crate::query::plan::MapProjectionEntry::LiteralEntry(_, expr) = entry {
                    Self::has_hidden_variable_references(expr)
                } else {
                    false
                }
            }),
            LogicalExpression::Reduce {
                initial,
                list,
                expression,
                ..
            } => {
                Self::has_hidden_variable_references(initial)
                    || Self::has_hidden_variable_references(list)
                    || Self::has_hidden_variable_references(expression)
            }
        }
    }

    /// Returns the aliases introduced by a list of projections.
    fn extract_projection_aliases(
        &self,
        projections: &[crate::query::plan::Projection],
    ) -> HashSet<String> {
        projections.iter().filter_map(|p| p.alias.clone()).collect()
    }
}
impl Default for Optimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::query::plan::{
AggregateExpr, AggregateFunction, AggregateOp, BinaryOp, DistinctOp, ExpandDirection,
ExpandOp, JoinOp, JoinType, LimitOp, NodeScanOp, PathMode, ProjectOp, Projection,
ReturnItem, ReturnOp, SkipOp, SortKey, SortOp, SortOrder, UnaryOp,
};
use grafeo_common::types::Value;
#[test]
fn test_optimizer_filter_pushdown_simple() {
let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
items: vec![ReturnItem {
expression: LogicalExpression::Variable("n".to_string()),
alias: None,
}],
distinct: false,
input: Box::new(LogicalOperator::Filter(FilterOp {
predicate: LogicalExpression::Binary {
left: Box::new(LogicalExpression::Property {
variable: "n".to_string(),
property: "age".to_string(),
}),
op: BinaryOp::Gt,
right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
},
input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
variable: "n".to_string(),
label: Some("Person".to_string()),
input: None,
})),
pushdown_hint: None,
})),
}));
let optimizer = Optimizer::new();
let optimized = optimizer.optimize(plan).unwrap();
if let LogicalOperator::Return(ret) = &optimized.root
&& let LogicalOperator::Filter(filter) = ret.input.as_ref()
&& let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
{
assert_eq!(scan.variable, "n");
return;
}
panic!("Expected Return -> Filter -> NodeScan structure");
}
#[test]
fn test_optimizer_filter_pushdown_through_expand() {
let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
items: vec![ReturnItem {
expression: LogicalExpression::Variable("b".to_string()),
alias: None,
}],
distinct: false,
input: Box::new(LogicalOperator::Filter(FilterOp {
predicate: LogicalExpression::Binary {
left: Box::new(LogicalExpression::Property {
variable: "a".to_string(),
property: "age".to_string(),
}),
op: BinaryOp::Gt,
right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
},
pushdown_hint: None,
input: Box::new(LogicalOperator::Expand(ExpandOp {
from_variable: "a".to_string(),
to_variable: "b".to_string(),
edge_variable: None,
direction: ExpandDirection::Outgoing,
edge_types: vec!["KNOWS".to_string()],
min_hops: 1,
max_hops: Some(1),
input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
variable: "a".to_string(),
label: Some("Person".to_string()),
input: None,
})),
path_alias: None,
path_mode: PathMode::Walk,
})),
})),
}));
let optimizer = Optimizer::new();
let optimized = optimizer.optimize(plan).unwrap();
if let LogicalOperator::Return(ret) = &optimized.root
&& let LogicalOperator::Expand(expand) = ret.input.as_ref()
&& let LogicalOperator::Filter(filter) = expand.input.as_ref()
&& let LogicalOperator::NodeScan(scan) = filter.input.as_ref()
{
assert_eq!(scan.variable, "a");
assert_eq!(expand.from_variable, "a");
assert_eq!(expand.to_variable, "b");
return;
}
panic!("Expected Return -> Expand -> Filter -> NodeScan structure");
}
#[test]
fn test_optimizer_filter_not_pushed_through_expand_for_target_var() {
let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
items: vec![ReturnItem {
expression: LogicalExpression::Variable("a".to_string()),
alias: None,
}],
distinct: false,
input: Box::new(LogicalOperator::Filter(FilterOp {
predicate: LogicalExpression::Binary {
left: Box::new(LogicalExpression::Property {
variable: "b".to_string(),
property: "age".to_string(),
}),
op: BinaryOp::Gt,
right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
},
pushdown_hint: None,
input: Box::new(LogicalOperator::Expand(ExpandOp {
from_variable: "a".to_string(),
to_variable: "b".to_string(),
edge_variable: None,
direction: ExpandDirection::Outgoing,
edge_types: vec!["KNOWS".to_string()],
min_hops: 1,
max_hops: Some(1),
input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
variable: "a".to_string(),
label: Some("Person".to_string()),
input: None,
})),
path_alias: None,
path_mode: PathMode::Walk,
})),
})),
}));
let optimizer = Optimizer::new();
let optimized = optimizer.optimize(plan).unwrap();
if let LogicalOperator::Return(ret) = &optimized.root
&& let LogicalOperator::Filter(filter) = ret.input.as_ref()
{
if let LogicalExpression::Binary { left, .. } = &filter.predicate
&& let LogicalExpression::Property { variable, .. } = left.as_ref()
{
assert_eq!(variable, "b");
}
if let LogicalOperator::Expand(expand) = filter.input.as_ref()
&& let LogicalOperator::NodeScan(_) = expand.input.as_ref()
{
return;
}
}
panic!("Expected Return -> Filter -> Expand -> NodeScan structure");
}
#[test]
fn test_optimizer_extract_variables() {
let optimizer = Optimizer::new();
let expr = LogicalExpression::Binary {
left: Box::new(LogicalExpression::Property {
variable: "n".to_string(),
property: "age".to_string(),
}),
op: BinaryOp::Gt,
right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
};
let vars = optimizer.extract_variables(&expr);
assert_eq!(vars.len(), 1);
assert!(vars.contains("n"));
}
#[test]
fn test_optimizer_default() {
let optimizer = Optimizer::default();
let plan = LogicalPlan::new(LogicalOperator::Empty);
let result = optimizer.optimize(plan);
assert!(result.is_ok());
}
/// With filter pushdown switched off, `Return -> Filter -> NodeScan` keeps the
/// Filter directly beneath the Return.
#[test]
fn test_optimizer_with_filter_pushdown_disabled() {
    let optimizer = Optimizer::new().with_filter_pushdown(false);
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let always_true = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        input: Box::new(scan),
        pushdown_hint: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(always_true),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Return(ret) = &optimized.root else {
        panic!("Expected unchanged structure");
    };
    assert!(
        matches!(ret.input.as_ref(), LogicalOperator::Filter(_)),
        "Expected unchanged structure"
    );
}
/// Disabling join reordering still lets the optimizer process an `Empty` plan.
#[test]
fn test_optimizer_with_join_reorder_disabled() {
    let optimizer = Optimizer::new().with_join_reorder(false);
    let empty_plan = LogicalPlan::new(LogicalOperator::Empty);
    let outcome = optimizer.optimize(empty_plan);
    assert!(outcome.is_ok());
}
/// Builds an optimizer with an explicit `CostModel` and checks that the model
/// exposed by `cost_model()` prices `LogicalOperator::Empty` at (near) zero.
#[test]
fn test_optimizer_with_cost_model() {
    let model = CostModel::new();
    let optimizer = Optimizer::new().with_cost_model(model);
    let empty_cost = optimizer
        .cost_model()
        .estimate(&LogicalOperator::Empty, 0.0);
    assert!(empty_cost.total() < 0.001);
}
/// Table statistics registered on a custom estimator (500 rows for label
/// `Test`) determine the cardinality estimate of a labelled node scan.
#[test]
fn test_optimizer_with_cardinality_estimator() {
    let mut custom_estimator = CardinalityEstimator::new();
    custom_estimator.add_table_stats("Test", TableStats::new(500));
    let optimizer = Optimizer::new().with_cardinality_estimator(custom_estimator);
    let labelled_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Test".to_string()),
        input: None,
    });
    let plan = LogicalPlan::new(labelled_scan);
    let estimated_rows = optimizer.estimate_cardinality(&plan);
    assert!((estimated_rows - 500.0).abs() < 0.001);
}
/// An unlabelled full node scan has a strictly positive estimated cost.
#[test]
fn test_optimizer_estimate_cost() {
    let optimizer = Optimizer::new();
    let full_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let plan = LogicalPlan::new(full_scan);
    let estimate = optimizer.estimate_cost(&plan);
    assert!(estimate.total() > 0.0);
}
/// A filter on `n.age` above a projection that passes `n` through unchanged is
/// expected to move beneath the projection (`Project -> Filter`).
#[test]
fn test_filter_pushdown_through_project() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![Projection {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        input: Box::new(scan),
        pass_through_input: false,
    });
    let age_gt_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: age_gt_30,
        pushdown_hint: None,
        input: Box::new(project),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let filter_below_project = matches!(
        &optimized.root,
        LogicalOperator::Project(proj)
            if matches!(proj.input.as_ref(), LogicalOperator::Filter(_))
    );
    assert!(filter_below_project, "Expected Project -> Filter structure");
}
/// The predicate references `x`, an alias that only exists after the
/// projection, so the Filter must remain above the Project.
#[test]
fn test_filter_not_pushed_through_project_with_alias() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let aliased_age = Projection {
        expression: LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        },
        alias: Some("x".to_string()),
    };
    let project = LogicalOperator::Project(ProjectOp {
        projections: vec![aliased_age],
        input: Box::new(scan),
        pass_through_input: false,
    });
    let x_gt_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("x".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: x_gt_30,
        pushdown_hint: None,
        input: Box::new(project),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Project structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Project(_)),
        "Expected Filter -> Project structure"
    );
}
/// Despite the name, this asserts the Filter is NOT moved below the Limit: the
/// optimized root must still be `Filter -> Limit`. Filtering the first N rows
/// is not equivalent to taking the first N filtered rows.
#[test]
fn test_filter_pushdown_through_limit() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let limit = LogicalOperator::Limit(LimitOp {
        count: 10.into(),
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(limit),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Limit structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Limit(_)),
        "Expected Filter -> Limit structure"
    );
}
/// Asserts that, after optimization, the Filter still sits directly above the
/// Sort (`Filter -> Sort`).
#[test]
fn test_filter_pushdown_through_sort() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let ascending_by_n = SortKey {
        expression: LogicalExpression::Variable("n".to_string()),
        order: SortOrder::Ascending,
        nulls: None,
    };
    let sort = LogicalOperator::Sort(SortOp {
        keys: vec![ascending_by_n],
        input: Box::new(scan),
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(sort),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Sort structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Sort(_)),
        "Expected Filter -> Sort structure"
    );
}
/// Asserts that, after optimization, the Filter still sits directly above the
/// Distinct (`Filter -> Distinct`).
#[test]
fn test_filter_pushdown_through_distinct() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let distinct = LogicalOperator::Distinct(DistinctOp {
        input: Box::new(scan),
        columns: None,
    });
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(distinct),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Distinct structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Distinct(_)),
        "Expected Filter -> Distinct structure"
    );
}
/// `cnt` is produced by the aggregation itself, so a predicate on it must stay
/// above the Aggregate (`Filter -> Aggregate`).
#[test]
fn test_filter_not_pushed_through_aggregate() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let count_star = AggregateExpr {
        function: AggregateFunction::Count,
        expression: None,
        expression2: None,
        distinct: false,
        alias: Some("cnt".to_string()),
        percentile: None,
        separator: None,
    };
    let aggregate = LogicalOperator::Aggregate(AggregateOp {
        group_by: vec![],
        aggregates: vec![count_star],
        input: Box::new(scan),
        having: None,
    });
    let cnt_gt_10 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Variable("cnt".to_string())),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(10))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cnt_gt_10,
        pushdown_hint: None,
        input: Box::new(aggregate),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Aggregate structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Aggregate(_)),
        "Expected Filter -> Aggregate structure"
    );
}
/// A predicate that only touches `a` (produced by the left input) is pushed
/// onto the left side of the join.
#[test]
fn test_filter_pushdown_to_left_join_side() {
    let optimizer = Optimizer::new();
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let a_age_gt_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "age".to_string(),
        }),
        op: BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(Value::Int64(30))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: a_age_gt_30,
        pushdown_hint: None,
        input: Box::new(join),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Join(join) = &optimized.root else {
        panic!("Expected Join with Filter on left side");
    };
    assert!(
        matches!(join.left.as_ref(), LogicalOperator::Filter(_)),
        "Expected Join with Filter on left side"
    );
}
/// A predicate that only touches `b` (produced by the right input) is pushed
/// onto the right side of the join.
#[test]
fn test_filter_pushdown_to_right_join_side() {
    let optimizer = Optimizer::new();
    let people = LogicalOperator::NodeScan(NodeScanOp {
        variable: "a".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let companies = LogicalOperator::NodeScan(NodeScanOp {
        variable: "b".to_string(),
        label: Some("Company".to_string()),
        input: None,
    });
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(people),
        right: Box::new(companies),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let b_named_acme = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "name".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Literal(Value::String("Acme".into()))),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: b_named_acme,
        pushdown_hint: None,
        input: Box::new(join),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Join(join) = &optimized.root else {
        panic!("Expected Join with Filter on right side");
    };
    assert!(
        matches!(join.right.as_ref(), LogicalOperator::Filter(_)),
        "Expected Join with Filter on right side"
    );
}
/// `a.id = b.a_id` needs columns from both join inputs, so it cannot be pushed
/// to either side and stays above the Join.
#[test]
fn test_filter_not_pushed_when_uses_both_join_sides() {
    let optimizer = Optimizer::new();
    let scan_of = |var: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: None,
            input: None,
        })
    };
    let join = LogicalOperator::Join(JoinOp {
        left: Box::new(scan_of("a")),
        right: Box::new(scan_of("b")),
        join_type: JoinType::Inner,
        conditions: vec![],
    });
    let cross_side_eq = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "a".to_string(),
            property: "id".to_string(),
        }),
        op: BinaryOp::Eq,
        right: Box::new(LogicalExpression::Property {
            variable: "b".to_string(),
            property: "a_id".to_string(),
        }),
    };
    let plan = LogicalPlan::new(LogicalOperator::Filter(FilterOp {
        predicate: cross_side_eq,
        pushdown_hint: None,
        input: Box::new(join),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    let LogicalOperator::Filter(filter) = &optimized.root else {
        panic!("Expected Filter -> Join structure");
    };
    assert!(
        matches!(filter.input.as_ref(), LogicalOperator::Join(_)),
        "Expected Filter -> Join structure"
    );
}
/// A bare variable reference yields exactly that variable.
#[test]
fn test_extract_variables_from_variable() {
    let optimizer = Optimizer::new();
    let reference = LogicalExpression::Variable("x".to_string());
    let found = optimizer.extract_variables(&reference);
    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
/// Variables inside a unary operand (`NOT x`) are found.
#[test]
fn test_extract_variables_from_unary() {
    let optimizer = Optimizer::new();
    let operand = LogicalExpression::Variable("x".to_string());
    let negation = LogicalExpression::Unary {
        op: UnaryOp::Not,
        operand: Box::new(operand),
    };
    let found = optimizer.extract_variables(&negation);
    assert_eq!(found.len(), 1);
    assert!(found.contains("x"));
}
/// Every argument of a function call is scanned for variables.
#[test]
fn test_extract_variables_from_function_call() {
    let optimizer = Optimizer::new();
    let call_args = ["a", "b"]
        .into_iter()
        .map(|v| LogicalExpression::Variable(v.to_string()))
        .collect();
    let call = LogicalExpression::FunctionCall {
        name: "length".to_string(),
        args: call_args,
        distinct: false,
    };
    let found = optimizer.extract_variables(&call);
    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
/// List elements are scanned; literal elements contribute nothing.
#[test]
fn test_extract_variables_from_list() {
    let optimizer = Optimizer::new();
    let var = |v: &str| LogicalExpression::Variable(v.to_string());
    let list = LogicalExpression::List(vec![
        var("a"),
        LogicalExpression::Literal(Value::Int64(1)),
        var("b"),
    ]);
    let found = optimizer.extract_variables(&list);
    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
/// Map values are scanned for variables; the keys are plain strings.
#[test]
fn test_extract_variables_from_map() {
    let optimizer = Optimizer::new();
    let entries = ["key1", "key2"]
        .into_iter()
        .zip(["a", "b"])
        .map(|(key, v)| (key.to_string(), LogicalExpression::Variable(v.to_string())))
        .collect();
    let map = LogicalExpression::Map(entries);
    let found = optimizer.extract_variables(&map);
    assert_eq!(found.len(), 2);
    for expected in ["a", "b"] {
        assert!(found.contains(expected));
    }
}
/// Both the base and the index of `list[idx]` contribute variables.
#[test]
fn test_extract_variables_from_index_access() {
    let optimizer = Optimizer::new();
    let var = |v: &str| Box::new(LogicalExpression::Variable(v.to_string()));
    let access = LogicalExpression::IndexAccess {
        base: var("list"),
        index: var("idx"),
    };
    let found = optimizer.extract_variables(&access);
    assert_eq!(found.len(), 2);
    for expected in ["list", "idx"] {
        assert!(found.contains(expected));
    }
}
/// Base, start and end of `list[s..e]` all contribute variables.
#[test]
fn test_extract_variables_from_slice_access() {
    let optimizer = Optimizer::new();
    let var = |v: &str| Box::new(LogicalExpression::Variable(v.to_string()));
    let slice = LogicalExpression::SliceAccess {
        base: var("list"),
        start: Some(var("s")),
        end: Some(var("e")),
    };
    let found = optimizer.extract_variables(&slice);
    assert_eq!(found.len(), 3);
    for expected in ["list", "s", "e"] {
        assert!(found.contains(expected));
    }
}
/// The CASE operand, the WHEN results and the ELSE branch are all scanned.
#[test]
fn test_extract_variables_from_case() {
    let optimizer = Optimizer::new();
    let var = |v: &str| LogicalExpression::Variable(v.to_string());
    let when_one_then_a = (LogicalExpression::Literal(Value::Int64(1)), var("a"));
    let case_expr = LogicalExpression::Case {
        operand: Some(Box::new(var("x"))),
        when_clauses: vec![when_one_then_a],
        else_clause: Some(Box::new(var("b"))),
    };
    let found = optimizer.extract_variables(&case_expr);
    assert_eq!(found.len(), 3);
    for expected in ["x", "a", "b"] {
        assert!(found.contains(expected));
    }
}
/// `labels(n)` references the variable `n`.
#[test]
fn test_extract_variables_from_labels() {
    let optimizer = Optimizer::new();
    let labels_of_n = LogicalExpression::Labels("n".to_string());
    let found = optimizer.extract_variables(&labels_of_n);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
/// `type(e)` references the variable `e`.
#[test]
fn test_extract_variables_from_type() {
    let optimizer = Optimizer::new();
    let type_of_e = LogicalExpression::Type("e".to_string());
    let found = optimizer.extract_variables(&type_of_e);
    assert_eq!(found.len(), 1);
    assert!(found.contains("e"));
}
/// `id(n)` references the variable `n`.
#[test]
fn test_extract_variables_from_id() {
    let optimizer = Optimizer::new();
    let id_of_n = LogicalExpression::Id("n".to_string());
    let found = optimizer.extract_variables(&id_of_n);
    assert_eq!(found.len(), 1);
    assert!(found.contains("n"));
}
/// The source list, filter and map expressions of a list comprehension are
/// scanned. Whether the bound iteration variable `x` is reported is not
/// asserted here.
#[test]
fn test_extract_variables_from_list_comprehension() {
    let optimizer = Optimizer::new();
    let var = |v: &str| Box::new(LogicalExpression::Variable(v.to_string()));
    let comprehension = LogicalExpression::ListComprehension {
        variable: "x".to_string(),
        list_expr: var("items"),
        filter_expr: Some(var("pred")),
        map_expr: var("result"),
    };
    let found = optimizer.extract_variables(&comprehension);
    for expected in ["items", "pred", "result"] {
        assert!(found.contains(expected));
    }
}
/// A literal and a query parameter reference both yield an empty variable set.
#[test]
fn test_extract_variables_from_literal_and_parameter() {
    let optimizer = Optimizer::new();
    let literal = LogicalExpression::Literal(Value::Int64(42));
    assert!(optimizer.extract_variables(&literal).is_empty());
    let param = LogicalExpression::Parameter("p".to_string());
    assert!(optimizer.extract_variables(&param).is_empty());
}
/// Optimizing `Return -> Filter -> Skip -> NodeScan` succeeds and keeps Return
/// at the root. Where the Filter ends up relative to Skip is not asserted.
#[test]
fn test_recursive_filter_pushdown_through_skip() {
    let optimizer = Optimizer::new();
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let skip = LogicalOperator::Skip(SkipOp {
        count: 5.into(),
        input: Box::new(scan),
    });
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: LogicalExpression::Literal(Value::Bool(true)),
        pushdown_hint: None,
        input: Box::new(skip),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
/// Two stacked filters (`n.x > 1` over `n.y < 10`) beneath a Return optimize
/// without error, and Return stays at the root.
#[test]
fn test_nested_filter_pushdown() {
    let optimizer = Optimizer::new();
    let compare = |prop: &str, op: BinaryOp, bound: i64| LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: prop.to_string(),
        }),
        op,
        right: Box::new(LogicalExpression::Literal(Value::Int64(bound))),
    };
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    });
    let inner = LogicalOperator::Filter(FilterOp {
        predicate: compare("y", BinaryOp::Lt, 10),
        pushdown_hint: None,
        input: Box::new(scan),
    });
    let outer = LogicalOperator::Filter(FilterOp {
        predicate: compare("x", BinaryOp::Gt, 1),
        pushdown_hint: None,
        input: Box::new(inner),
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(outer),
    }));
    let optimized = optimizer.optimize(plan).unwrap();
    assert!(matches!(&optimized.root, LogicalOperator::Return(_)));
}
/// Join conditions a=b, b=c and c=a form a triangle (a cyclic join graph). The
/// optimized plan must contain a `MultiWayJoin` reachable from the root through
/// Return/Filter/Project operators only.
#[test]
fn test_cyclic_join_produces_multi_way_join() {
    use crate::query::plan::JoinCondition;
    let person = |var: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some("Person".to_string()),
            input: None,
        })
    };
    let var = |name: &str| LogicalExpression::Variable(name.to_string());
    let equate = |l: &str, r: &str| JoinCondition {
        left: var(l),
        right: var(r),
    };
    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(person("a")),
        right: Box::new(person("b")),
        join_type: JoinType::Inner,
        conditions: vec![equate("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(person("c")),
        join_type: JoinType::Inner,
        conditions: vec![equate("b", "c"), equate("c", "a")],
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: var("a"),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));
    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));
    let optimized = optimizer.optimize(plan).unwrap();

    // Follow the single-input spine; any operator other than Return/Filter/
    // Project (including a binary Join) ends the search unsuccessfully.
    fn spine_has_multi_way_join(mut op: &LogicalOperator) -> bool {
        loop {
            op = match op {
                LogicalOperator::MultiWayJoin(_) => return true,
                LogicalOperator::Return(ret) => ret.input.as_ref(),
                LogicalOperator::Filter(f) => f.input.as_ref(),
                LogicalOperator::Project(p) => p.input.as_ref(),
                _ => return false,
            };
        }
    }

    assert!(
        spine_has_multi_way_join(&optimized.root),
        "Expected MultiWayJoin for cyclic triangle pattern"
    );
}
/// Join conditions a=b and b=c form a chain (an acyclic join graph). No
/// `MultiWayJoin` may appear anywhere among the Return/Filter/Project/Join
/// operators of the optimized plan.
#[test]
fn test_acyclic_join_uses_binary_joins() {
    use crate::query::plan::JoinCondition;
    let scan = |var: &str, label: &str| {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    };
    let var = |name: &str| LogicalExpression::Variable(name.to_string());
    let equate = |l: &str, r: &str| JoinCondition {
        left: var(l),
        right: var(r),
    };
    let join_ab = LogicalOperator::Join(JoinOp {
        left: Box::new(scan("a", "Person")),
        right: Box::new(scan("b", "Person")),
        join_type: JoinType::Inner,
        conditions: vec![equate("a", "b")],
    });
    let join_abc = LogicalOperator::Join(JoinOp {
        left: Box::new(join_ab),
        right: Box::new(scan("c", "Company")),
        join_type: JoinType::Inner,
        conditions: vec![equate("b", "c")],
    });
    let plan = LogicalPlan::new(LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: var("a"),
            alias: None,
        }],
        distinct: false,
        input: Box::new(join_abc),
    }));
    let mut optimizer = Optimizer::new();
    optimizer
        .card_estimator
        .add_table_stats("Person", cardinality::TableStats::new(1000));
    optimizer
        .card_estimator
        .add_table_stats("Company", cardinality::TableStats::new(100));
    let optimized = optimizer.optimize(plan).unwrap();

    // Depth-first walk with an explicit stack over Return/Filter/Project and
    // both children of every Join; other operators are treated as leaves.
    fn tree_has_multi_way_join(root: &LogicalOperator) -> bool {
        let mut pending = vec![root];
        while let Some(op) = pending.pop() {
            match op {
                LogicalOperator::MultiWayJoin(_) => return true,
                LogicalOperator::Return(ret) => pending.push(ret.input.as_ref()),
                LogicalOperator::Filter(f) => pending.push(f.input.as_ref()),
                LogicalOperator::Project(p) => pending.push(p.input.as_ref()),
                LogicalOperator::Join(j) => {
                    pending.push(j.left.as_ref());
                    pending.push(j.right.as_ref());
                }
                _ => {}
            }
        }
        false
    }

    assert!(
        !tree_has_multi_way_join(&optimized.root),
        "Acyclic join should NOT produce MultiWayJoin"
    );
}
}