use super::ast::*;
use crate::datatypes::values::Value;
use crate::graph::core::pattern_matching::PatternElement;
use crate::graph::schema::DirGraph;
use std::collections::{HashMap, HashSet};
pub mod cost_model;
pub mod fusion;
pub mod index_selection;
pub mod join_order;
pub mod rel_predicate_pushdown;
pub mod schema_check;
pub mod simplification;
use cost_model::reorder_predicates_by_cost;
use fusion::{
fuse_anchored_edge_count, fuse_count_short_circuits, fuse_match_return_aggregate,
fuse_match_with_aggregate, fuse_match_with_aggregate_top_k, fuse_node_scan_aggregate,
fuse_node_scan_top_k, fuse_optional_match_aggregate, fuse_order_by_top_k, fuse_spatial_join,
fuse_vector_score_order_limit, mark_return_lazy_eligible,
};
use index_selection::push_where_into_match;
use join_order::{
optimize_pattern_start_node, reorder_cyclic_pattern_edges, reorder_match_clauses,
reorder_match_patterns,
};
use rel_predicate_pushdown::extract_pushable_rel_predicates;
use simplification::{
desugar_multi_match_return_aggregate, fold_or_to_in, fold_pass_through_with,
push_distinct_into_match, push_limit_into_aggregate, push_limit_into_match,
rewrite_count_bound_var_to_star,
};
/// Read-only context shared by every optimizer pass.
pub struct PassCtx<'q> {
    /// Graph the query is being planned against; passes read schema-level
    /// facts (e.g. connection-type metadata) from it.
    pub graph: &'q DirGraph,
    /// Query parameters, for passes that can make use of bound values.
    pub params: &'q HashMap<String, Value>,
    /// Names of passes the caller switched off for this optimization run.
    pub disabled: &'q HashSet<String>,
}
/// Uniform signature of an optimizer pass: rewrite `query` in place, reading
/// graph / parameter / disabled-set information from the [`PassCtx`].
type PassFn = fn(&mut CypherQuery, &PassCtx);
/// Ordered registry of optimizer passes executed by `optimize_with_disabled`.
///
/// Passes run in exactly this order, so the position of each entry matters.
/// The string is the name a caller uses to disable that entry (it is what
/// `is_known_pass` checks and `all_pass_names` reports). One function may be
/// registered several times under distinct names, as `pass_push_where_into_match`
/// is (`.1` / `.2`), so that each run can be disabled independently.
pub const PASSES: &[(&str, PassFn)] = &[
    ("optimize_nested_queries", pass_optimize_nested_queries),
    (
        "rewrite_count_bound_var_to_star",
        pass_rewrite_count_bound_var_to_star,
    ),
    ("push_where_into_match.1", pass_push_where_into_match),
    ("fold_or_to_in", pass_fold_or_to_in),
    // Second run of the same pass after `fold_or_to_in`; presumably so that
    // predicates rewritten by that fold can be pushed too — NOTE(review): confirm.
    ("push_where_into_match.2", pass_push_where_into_match),
    (
        "extract_pushable_rel_predicates",
        pass_extract_pushable_rel_predicates,
    ),
    ("fold_pass_through_with", pass_fold_pass_through_with),
    (
        "desugar_multi_match_return_aggregate",
        pass_desugar_multi_match_return_aggregate,
    ),
    ("fuse_spatial_join", pass_fuse_spatial_join),
    // Join-order passes.
    ("reorder_match_clauses", pass_reorder_match_clauses),
    (
        "reorder_cyclic_pattern_edges",
        pass_reorder_cyclic_pattern_edges,
    ),
    (
        "optimize_pattern_start_node",
        pass_optimize_pattern_start_node,
    ),
    ("reorder_match_patterns", pass_reorder_match_patterns),
    // LIMIT / DISTINCT push-down.
    ("push_limit_into_match", pass_push_limit_into_match),
    ("push_limit_into_aggregate", pass_push_limit_into_aggregate),
    ("push_distinct_into_match", pass_push_distinct_into_match),
    // Fusion passes.
    ("fuse_anchored_edge_count", pass_fuse_anchored_edge_count),
    ("fuse_count_short_circuits", pass_fuse_count_short_circuits),
    (
        "fuse_optional_match_aggregate",
        pass_fuse_optional_match_aggregate,
    ),
    (
        "fuse_match_return_aggregate",
        pass_fuse_match_return_aggregate,
    ),
    ("fuse_match_with_aggregate", pass_fuse_match_with_aggregate),
    (
        "fuse_match_with_aggregate_top_k",
        pass_fuse_match_with_aggregate_top_k,
    ),
    ("fuse_node_scan_aggregate", pass_fuse_node_scan_aggregate),
    ("fuse_node_scan_top_k", pass_fuse_node_scan_top_k),
    (
        "fuse_vector_score_order_limit",
        pass_fuse_vector_score_order_limit,
    ),
    ("fuse_order_by_top_k", pass_fuse_order_by_top_k),
    // Final annotation passes: reorder predicates and set per-edge flags.
    (
        "reorder_predicates_by_cost",
        pass_reorder_predicates_by_cost,
    ),
    (
        "mark_fast_var_length_paths",
        pass_mark_fast_var_length_paths,
    ),
    (
        "mark_skip_target_type_check",
        pass_mark_skip_target_type_check,
    ),
];
/// Returns `true` if `name` is the registered name of some entry in [`PASSES`].
pub fn is_known_pass(name: &str) -> bool {
    PASSES
        .iter()
        .map(|&(pass_name, _)| pass_name)
        .any(|pass_name| pass_name == name)
}
/// Lists every registered pass name, in execution order, as owned strings.
pub fn all_pass_names() -> Vec<String> {
    PASSES
        .iter()
        .map(|&(pass_name, _)| String::from(pass_name))
        .collect()
}
/// Delegates to `mark_return_lazy_eligible` unless the query contains a
/// `UNION` or any write clause (`CREATE`, `SET`, `DELETE`, `REMOVE`, `MERGE`),
/// in which case the query is left untouched.
pub fn mark_lazy_eligibility(query: &mut CypherQuery) {
    let has_union_or_write = query.clauses.iter().any(|clause| {
        matches!(
            clause,
            Clause::Union(_)
                | Clause::Create(_)
                | Clause::Set(_)
                | Clause::Delete(_)
                | Clause::Remove(_)
                | Clause::Merge(_)
        )
    });
    if !has_union_or_write {
        mark_return_lazy_eligible(query);
    }
}
/// Runs every pass in [`PASSES`] over `query`, with none disabled.
pub fn optimize(query: &mut CypherQuery, graph: &DirGraph, params: &HashMap<String, Value>) {
    let nothing_disabled = empty_disabled_set();
    optimize_with_disabled(query, graph, params, nothing_disabled);
}
/// Returns a process-wide, lazily created empty set of disabled pass names.
///
/// Every call returns a reference to the same set.
pub fn empty_disabled_set() -> &'static HashSet<String> {
    static NONE_DISABLED: std::sync::OnceLock<HashSet<String>> = std::sync::OnceLock::new();
    NONE_DISABLED.get_or_init(HashSet::default)
}
/// Runs the passes of [`PASSES`] in order over `query`, skipping any whose
/// registered name is in `disabled`.
///
/// In debug builds the IR invariants are re-checked after every executed pass,
/// and a violation panics with the name of the offending pass.
pub fn optimize_with_disabled(
    query: &mut CypherQuery,
    graph: &DirGraph,
    params: &HashMap<String, Value>,
    disabled: &HashSet<String>,
) {
    let ctx = PassCtx {
        graph,
        params,
        disabled,
    };
    for &(name, run_pass) in PASSES {
        if !disabled.contains(name) {
            run_pass(query, &ctx);
            #[cfg(debug_assertions)]
            debug_check_invariants(query, name);
        }
    }
}
#[cfg(debug_assertions)]
fn debug_check_invariants(query: &CypherQuery, after_pass_name: &str) {
if let Err(msg) = check_match_patterns_non_empty(query) {
panic!("Pass `{after_pass_name}` produced invalid IR: {msg}");
}
if let Err(msg) = check_return_with_items_non_empty(query) {
panic!("Pass `{after_pass_name}` produced invalid IR: {msg}");
}
if let Err(msg) = check_limit_skip_nonnegative(query) {
panic!("Pass `{after_pass_name}` produced invalid IR: {msg}");
}
}
/// Debug-only invariant: every `MATCH` / `OPTIONAL MATCH` has at least one
/// pattern, and every pattern has at least one element.
#[cfg(debug_assertions)]
fn check_match_patterns_non_empty(query: &CypherQuery) -> Result<(), String> {
    for (idx, clause) in query.clauses.iter().enumerate() {
        let (Clause::Match(mc) | Clause::OptionalMatch(mc)) = clause else {
            continue;
        };
        if mc.patterns.is_empty() {
            return Err(format!("Match clause at index {idx} has no patterns"));
        }
        if let Some(pi) = mc.patterns.iter().position(|p| p.elements.is_empty()) {
            return Err(format!(
                "Match clause at index {idx}, pattern {pi} has no elements"
            ));
        }
    }
    Ok(())
}
/// Debug-only invariant: no `RETURN` or `WITH` clause has an empty item list.
/// Reports the first offending clause.
#[cfg(debug_assertions)]
fn check_return_with_items_non_empty(query: &CypherQuery) -> Result<(), String> {
    let first_problem = query
        .clauses
        .iter()
        .enumerate()
        .find_map(|(idx, clause)| match clause {
            Clause::Return(r) if r.items.is_empty() => {
                Some(format!("Return clause at index {idx} has no items"))
            }
            Clause::With(w) if w.items.is_empty() => {
                Some(format!("With clause at index {idx} has no items"))
            }
            _ => None,
        });
    match first_problem {
        Some(msg) => Err(msg),
        None => Ok(()),
    }
}
#[cfg(debug_assertions)]
fn check_limit_skip_nonnegative(query: &CypherQuery) -> Result<(), String> {
for (idx, clause) in query.clauses.iter().enumerate() {
match clause {
Clause::Limit(l) => {
if let Expression::Literal(Value::Int64(n)) = &l.count {
if *n < 0 {
return Err(format!(
"Limit clause at index {idx} has negative literal {n}"
));
}
}
}
Clause::Skip(s) => {
if let Expression::Literal(Value::Int64(n)) = &s.count {
if *n < 0 {
return Err(format!(
"Skip clause at index {idx} has negative literal {n}"
));
}
}
}
_ => {}
}
}
Ok(())
}
fn pass_optimize_nested_queries(query: &mut CypherQuery, ctx: &PassCtx) {
for clause in &mut query.clauses {
match clause {
Clause::Union(ref mut u) => {
optimize_with_disabled(&mut u.query, ctx.graph, ctx.params, ctx.disabled);
}
Clause::CallSubquery {
ref import,
ref mut body,
} => {
let anchors = import_pattern_anchors(body, import);
if anchors.is_empty() {
optimize_with_disabled(body, ctx.graph, ctx.params, ctx.disabled);
} else {
let mut merged = ctx.disabled.clone();
merged.extend(seed_ignoring_fusion_passes().iter().cloned());
optimize_with_disabled(body, ctx.graph, ctx.params, &merged);
}
}
_ => {}
}
}
}
/// Collects the imported variable names that appear as node or edge variables
/// in the `MATCH` / `OPTIONAL MATCH` patterns of `body`.
///
/// Each name is reported once, in order of first appearance in the patterns.
pub(crate) fn import_pattern_anchors(body: &CypherQuery, import: &[String]) -> Vec<String> {
    let bound_vars = body
        .clauses
        .iter()
        .filter_map(|clause| match clause {
            Clause::Match(m) | Clause::OptionalMatch(m) => Some(&m.patterns),
            _ => None,
        })
        .flatten()
        .flat_map(|pattern| &pattern.elements)
        .filter_map(|elem| match elem {
            PatternElement::Node(np) => np.variable.as_ref(),
            PatternElement::Edge(ep) => ep.variable.as_ref(),
        });
    let mut anchors: Vec<String> = Vec::new();
    for var in bound_vars {
        let imported = import.iter().any(|name| name == var);
        if imported && !anchors.iter().any(|seen| seen == var) {
            anchors.push(var.clone());
        }
    }
    anchors
}
/// Names of the fusion passes that `pass_optimize_nested_queries` disables when
/// optimizing a `CALL { ... }` body whose patterns use imported variables.
///
/// The set is built once and shared by every caller.
pub(crate) fn seed_ignoring_fusion_passes() -> &'static HashSet<String> {
    const FUSION_PASS_NAMES: [&str; 8] = [
        "fuse_anchored_edge_count",
        "fuse_count_short_circuits",
        "fuse_optional_match_aggregate",
        "fuse_match_return_aggregate",
        "fuse_match_with_aggregate",
        "fuse_match_with_aggregate_top_k",
        "fuse_node_scan_aggregate",
        "fuse_node_scan_top_k",
    ];
    static FUSION_SET: std::sync::OnceLock<HashSet<String>> = std::sync::OnceLock::new();
    FUSION_SET.get_or_init(|| FUSION_PASS_NAMES.into_iter().map(String::from).collect())
}
fn pass_push_where_into_match(query: &mut CypherQuery, ctx: &PassCtx) {
push_where_into_match(query, ctx.params)
}
fn pass_fold_or_to_in(query: &mut CypherQuery, _ctx: &PassCtx) {
fold_or_to_in(query)
}
fn pass_rewrite_count_bound_var_to_star(query: &mut CypherQuery, _ctx: &PassCtx) {
rewrite_count_bound_var_to_star(query)
}
fn pass_extract_pushable_rel_predicates(query: &mut CypherQuery, _ctx: &PassCtx) {
extract_pushable_rel_predicates(query)
}
fn pass_fold_pass_through_with(query: &mut CypherQuery, _ctx: &PassCtx) {
fold_pass_through_with(query)
}
fn pass_desugar_multi_match_return_aggregate(query: &mut CypherQuery, _ctx: &PassCtx) {
desugar_multi_match_return_aggregate(query)
}
fn pass_fuse_spatial_join(query: &mut CypherQuery, ctx: &PassCtx) {
fuse_spatial_join(query, ctx.graph)
}
fn pass_reorder_match_clauses(query: &mut CypherQuery, ctx: &PassCtx) {
reorder_match_clauses(query, ctx.graph)
}
fn pass_reorder_cyclic_pattern_edges(query: &mut CypherQuery, ctx: &PassCtx) {
reorder_cyclic_pattern_edges(query, ctx.graph)
}
fn pass_optimize_pattern_start_node(query: &mut CypherQuery, ctx: &PassCtx) {
optimize_pattern_start_node(query, ctx.graph)
}
fn pass_reorder_match_patterns(query: &mut CypherQuery, ctx: &PassCtx) {
reorder_match_patterns(query, ctx.graph)
}
fn pass_push_limit_into_match(query: &mut CypherQuery, ctx: &PassCtx) {
push_limit_into_match(query, ctx.graph)
}
fn pass_push_limit_into_aggregate(query: &mut CypherQuery, ctx: &PassCtx) {
push_limit_into_aggregate(query, ctx.graph)
}
fn pass_push_distinct_into_match(query: &mut CypherQuery, _ctx: &PassCtx) {
push_distinct_into_match(query)
}
fn pass_fuse_anchored_edge_count(query: &mut CypherQuery, ctx: &PassCtx) {
fuse_anchored_edge_count(query, ctx.graph)
}
fn pass_fuse_count_short_circuits(query: &mut CypherQuery, ctx: &PassCtx) {
fuse_count_short_circuits(
query,
ctx.graph.has_secondary_labels,
ctx.graph.has_type_shadowing_property(),
)
}
fn pass_fuse_optional_match_aggregate(query: &mut CypherQuery, _ctx: &PassCtx) {
fuse_optional_match_aggregate(query)
}
fn pass_fuse_match_return_aggregate(query: &mut CypherQuery, ctx: &PassCtx) {
fuse_match_return_aggregate(query, ctx.graph.has_secondary_labels)
}
fn pass_fuse_match_with_aggregate(query: &mut CypherQuery, ctx: &PassCtx) {
fuse_match_with_aggregate(query, ctx.graph.has_secondary_labels)
}
fn pass_fuse_match_with_aggregate_top_k(query: &mut CypherQuery, _ctx: &PassCtx) {
fuse_match_with_aggregate_top_k(query)
}
fn pass_fuse_node_scan_aggregate(query: &mut CypherQuery, _ctx: &PassCtx) {
fuse_node_scan_aggregate(query)
}
fn pass_fuse_node_scan_top_k(query: &mut CypherQuery, _ctx: &PassCtx) {
fuse_node_scan_top_k(query)
}
fn pass_fuse_vector_score_order_limit(query: &mut CypherQuery, _ctx: &PassCtx) {
fuse_vector_score_order_limit(query)
}
fn pass_fuse_order_by_top_k(query: &mut CypherQuery, _ctx: &PassCtx) {
fuse_order_by_top_k(query)
}
fn pass_reorder_predicates_by_cost(query: &mut CypherQuery, _ctx: &PassCtx) {
reorder_predicates_by_cost(query)
}
fn pass_mark_fast_var_length_paths(query: &mut CypherQuery, _ctx: &PassCtx) {
mark_fast_var_length_paths(query)
}
fn pass_mark_skip_target_type_check(query: &mut CypherQuery, ctx: &PassCtx) {
mark_skip_target_type_check(query, ctx.graph)
}
fn mark_fast_var_length_paths(query: &mut CypherQuery) {
if !downstream_is_dedup_safe(query) {
return;
}
for clause in &mut query.clauses {
let mc = match clause {
Clause::Match(mc) | Clause::OptionalMatch(mc) => mc,
_ => continue,
};
if !mc.path_assignments.is_empty() {
continue;
}
for pattern in &mut mc.patterns {
for element in &mut pattern.elements {
if let PatternElement::Edge(ep) = element {
if ep.var_length.is_some() && ep.variable.is_none() {
ep.needs_path_info = false;
}
}
}
}
}
}
fn downstream_is_dedup_safe(query: &CypherQuery) -> bool {
for clause in &query.clauses {
match clause {
Clause::Return(r) => {
if r.distinct {
return true;
}
let all_agg_distinct = !r.items.is_empty()
&& r.items
.iter()
.all(|item| is_distinct_safe_aggregate(&item.expression));
return all_agg_distinct;
}
Clause::With(w) => {
if w.distinct {
return true;
}
let all_agg_distinct = !w.items.is_empty()
&& w.items
.iter()
.all(|item| is_distinct_safe_aggregate(&item.expression));
return all_agg_distinct;
}
_ => continue,
}
}
false
}
/// Returns `true` for aggregate calls whose result does not depend on how many
/// times each input row is repeated.
///
/// * `min` / `max` qualify with or without `DISTINCT`.
/// * `count` / `collect` qualify only when called with `DISTINCT`.
///
/// The function name is compared case-insensitively. Anything that is not a
/// function call returns `false`.
fn is_distinct_safe_aggregate(expr: &Expression) -> bool {
    let Expression::FunctionCall { name, distinct, .. } = expr else {
        return false;
    };
    match name.to_lowercase().as_str() {
        "min" | "max" => true,
        "count" | "collect" => *distinct,
        _ => false,
    }
}
/// Sets `skip_target_type_check` on an edge when the graph's connection-type
/// metadata shows that the edge can only reach the node type required by the
/// pattern node that follows it.
///
/// The flag is set only when all of these hold:
/// * the edge has a connection type;
/// * the following node has a `node_type` and no extra labels;
/// * the metadata lists exactly that one type on the far side of the edge.
///
/// For `Outgoing` edges the far side is `target_types`. For `Incoming` edges it
/// is `source_types`. Undirected (`Both`) edges are never marked.
fn mark_skip_target_type_check(query: &mut CypherQuery, graph: &DirGraph) {
    use crate::graph::core::pattern_matching::EdgeDirection;
    let match_clauses = query.clauses.iter_mut().filter_map(|clause| match clause {
        Clause::Match(mc) | Clause::OptionalMatch(mc) => Some(mc),
        _ => None,
    });
    for mc in match_clauses {
        for pattern in &mut mc.patterns {
            let elements = &mut pattern.elements;
            // `edge_idx` visits every position that has both a predecessor and a
            // successor; only (edge, node) pairs at `edge_idx`, `edge_idx + 1`
            // are considered.
            for edge_idx in 1..elements.len().saturating_sub(1) {
                let guaranteed = {
                    let (PatternElement::Edge(edge), PatternElement::Node(target)) =
                        (&elements[edge_idx], &elements[edge_idx + 1])
                    else {
                        continue;
                    };
                    if !target.extra_labels.is_empty() {
                        continue;
                    }
                    let (Some(conn_type), Some(node_type)) =
                        (&edge.connection_type, &target.node_type)
                    else {
                        continue;
                    };
                    graph
                        .connection_type_metadata
                        .get(conn_type)
                        .is_some_and(|info| match edge.direction {
                            EdgeDirection::Outgoing => {
                                info.target_types.len() == 1
                                    && info.target_types.contains(node_type)
                            }
                            EdgeDirection::Incoming => {
                                info.source_types.len() == 1
                                    && info.source_types.contains(node_type)
                            }
                            EdgeDirection::Both => false,
                        })
                };
                if guaranteed {
                    if let PatternElement::Edge(ep) = &mut elements[edge_idx] {
                        ep.skip_target_type_check = true;
                    }
                }
            }
        }
    }
}
#[cfg(test)]
#[path = "planner_tests.rs"]
mod tests;