use super::super::ast::*;
use crate::datatypes::values::Value;
use crate::graph::core::pattern_matching::PatternElement;
use crate::graph::schema::DirGraph;
/// Collapses a leading `MATCH` + `RETURN count(...)` over a single anchored edge
/// pattern into one `Clause::FusedCountAnchoredEdges`.
///
/// The rewrite happens only when all of the following hold:
/// * the first two clauses are `MATCH` then `RETURN`;
/// * the `RETURN` has neither `DISTINCT` nor `HAVING`, and has exactly one item
///   that `is_count_of_var_or_star` accepts for the free node's variable;
/// * the `MATCH` has exactly one pattern, no path assignments, and the pattern
///   is node / edge / node;
/// * the edge has no property filter, is not variable-length and is directed;
/// * one endpoint is a bare variable (no type, no properties) and the other is
///   an anonymous, untyped node whose only property matcher is an equality on
///   `"id"`;
/// * that id resolves through `graph.lookup_by_id_readonly` for some key of
///   `graph.type_indices`.
///
/// In every other case the query is left untouched.
pub(super) fn fuse_anchored_edge_count(query: &mut CypherQuery, graph: &DirGraph) {
    use crate::graph::core::pattern_matching::{EdgeDirection, NodePattern, PropertyMatcher};

    /// Value of the sole `id = <value>` matcher on an anonymous, untyped node.
    fn anchor_id_of(np: &NodePattern) -> Option<Value> {
        if np.node_type.is_some() || np.variable.is_some() {
            return None;
        }
        let props = np.properties.as_ref()?;
        if props.len() != 1 {
            return None;
        }
        match props.get("id") {
            Some(PropertyMatcher::Equals(val)) => Some(val.clone()),
            _ => None,
        }
    }

    /// Variable name of a node that carries neither a type nor properties.
    fn bare_variable_of(np: &NodePattern) -> Option<&String> {
        if np.node_type.is_some() || np.properties.is_some() {
            return None;
        }
        np.variable.as_ref()
    }

    let [Clause::Match(match_clause), Clause::Return(return_clause), ..] =
        query.clauses.as_slice()
    else {
        return;
    };
    if return_clause.distinct || return_clause.having.is_some() {
        return;
    }
    if !match_clause.path_assignments.is_empty() {
        return;
    }
    let [pattern] = match_clause.patterns.as_slice() else {
        return;
    };
    let [PatternElement::Node(src_node), PatternElement::Edge(edge), PatternElement::Node(tgt_node)] =
        pattern.elements.as_slice()
    else {
        return;
    };
    if edge.properties.is_some()
        || edge.var_length.is_some()
        || edge.direction == EdgeDirection::Both
    {
        return;
    }

    // The direction is expressed from the anchor's point of view: when the
    // anchor is the pattern's right-hand node the edge direction is flipped.
    let (var_name, anchor_val, anchor_dir) =
        if let (Some(var), Some(id)) = (bare_variable_of(src_node), anchor_id_of(tgt_node)) {
            let dir = match edge.direction {
                EdgeDirection::Outgoing => petgraph::Direction::Incoming,
                EdgeDirection::Incoming => petgraph::Direction::Outgoing,
                EdgeDirection::Both => return,
            };
            (var, id, dir)
        } else if let (Some(id), Some(var)) = (anchor_id_of(src_node), bare_variable_of(tgt_node))
        {
            let dir = match edge.direction {
                EdgeDirection::Outgoing => petgraph::Direction::Outgoing,
                EdgeDirection::Incoming => petgraph::Direction::Incoming,
                EdgeDirection::Both => return,
            };
            (var, id, dir)
        } else {
            return;
        };

    let [only_item] = return_clause.items.as_slice() else {
        return;
    };
    if !is_count_of_var_or_star(&only_item.expression, Some(var_name)) {
        return;
    }

    // First type (in `type_indices` key order) whose id lookup succeeds wins.
    let Some(anchor_node) = graph
        .type_indices
        .keys()
        .find_map(|node_type| graph.lookup_by_id_readonly(node_type, &anchor_val))
    else {
        return;
    };
    let anchor_idx = anchor_node.index() as u32;

    let alias = return_item_column_name(only_item);
    let edge_type = edge.connection_type.clone();

    query.clauses.drain(0..2);
    query.clauses.insert(
        0,
        Clause::FusedCountAnchoredEdges {
            anchor_idx,
            anchor_direction: anchor_dir,
            edge_type,
            alias,
        },
    );
}
pub(super) fn fuse_count_short_circuits(query: &mut CypherQuery, has_secondary_labels: bool) {
use crate::graph::core::pattern_matching::EdgeDirection;
if query.clauses.len() < 2 {
return;
}
let is_match_return = matches!(
(&query.clauses[0], &query.clauses[1]),
(Clause::Match(_), Clause::Return(_))
);
if !is_match_return {
return;
}
let match_clause = if let Clause::Match(m) = &query.clauses[0] {
m
} else {
return;
};
let return_clause = if let Clause::Return(r) = &query.clauses[1] {
r
} else {
return;
};
if return_clause.distinct {
return;
}
if match_clause.patterns.len() != 1 {
return;
}
let pat = &match_clause.patterns[0];
if pat.elements.len() == 1 {
let node = match &pat.elements[0] {
PatternElement::Node(np) => np,
_ => return,
};
if node.properties.is_some() {
return;
}
if !node.extra_labels.is_empty() {
return;
}
let node_var = node.variable.as_deref();
if let Some(ref node_type) = node.node_type {
if return_clause.items.len() == 1
&& is_count_of_var_or_star(&return_clause.items[0].expression, node_var)
{
let alias = return_item_column_name(&return_clause.items[0]);
let nt = node_type.clone();
query.clauses.drain(0..2);
query.clauses.insert(
0,
Clause::FusedCountTypedNode {
node_type: nt,
alias,
},
);
}
return;
}
if return_clause.items.len() == 1 {
let item = &return_clause.items[0];
if !is_count_of_var_or_star(&item.expression, node_var) {
return;
}
let alias = return_item_column_name(item);
query.clauses.drain(0..2);
query.clauses.insert(0, Clause::FusedCountAll { alias });
return;
}
if return_clause.items.len() == 2 {
let (type_idx, count_idx) =
identify_type_count_pair(&return_clause.items, node_var, has_secondary_labels);
if let Some((ti, ci)) = type_idx.zip(count_idx) {
let type_alias = return_item_column_name(&return_clause.items[ti]);
let count_alias = return_item_column_name(&return_clause.items[ci]);
query.clauses.drain(0..2);
query.clauses.insert(
0,
Clause::FusedCountByType {
type_alias,
count_alias,
},
);
return;
}
}
return;
}
if pat.elements.len() == 3 {
let src_node = match &pat.elements[0] {
PatternElement::Node(np) => np,
_ => return,
};
let edge = match &pat.elements[1] {
PatternElement::Edge(ep) => ep,
_ => return,
};
let tgt_node = match &pat.elements[2] {
PatternElement::Node(np) => np,
_ => return,
};
if src_node.node_type.is_some()
|| src_node.properties.is_some()
|| tgt_node.node_type.is_some()
|| tgt_node.properties.is_some()
{
return;
}
if edge.properties.is_some()
|| edge.var_length.is_some()
|| edge.direction == EdgeDirection::Both
{
return;
}
let edge_var = edge.variable.as_deref();
if let Some(ref edge_type) = edge.connection_type {
if return_clause.items.len() == 1
&& is_count_of_var_or_star(&return_clause.items[0].expression, edge_var)
{
let alias = return_item_column_name(&return_clause.items[0]);
let et = edge_type.clone();
query.clauses.drain(0..2);
query.clauses.insert(
0,
Clause::FusedCountTypedEdge {
edge_type: et,
alias,
},
);
}
return;
}
if return_clause.items.len() != 2 {
return;
}
let (type_idx, count_idx) = identify_edge_type_count_pair(&return_clause.items, edge_var);
if let Some((ti, ci)) = type_idx.zip(count_idx) {
let type_alias = return_item_column_name(&return_clause.items[ti]);
let count_alias = return_item_column_name(&return_clause.items[ci]);
query.clauses.drain(0..2);
query.clauses.insert(
0,
Clause::FusedCountEdgesByType {
type_alias,
count_alias,
},
);
}
}
}
/// Returns `true` when `expr` is a non-DISTINCT call to `count` with exactly
/// one argument that is either `*` or the variable named by `node_var`.
///
/// With `node_var == None`, only `count(*)` qualifies.
pub(super) fn is_count_of_var_or_star(expr: &Expression, node_var: Option<&str>) -> bool {
    let Expression::FunctionCall {
        name,
        args,
        distinct,
    } = expr
    else {
        return false;
    };
    if *distinct || name != "count" {
        return false;
    }
    match args.as_slice() {
        [Expression::Star] => true,
        [Expression::Variable(v)] => node_var == Some(v.as_str()),
        _ => false,
    }
}
/// Locates the "type" column and the "count" column among `items`.
///
/// Returns `(type_idx, count_idx)`. An item counts as the count column when
/// `is_count_of_var_or_star` accepts it; otherwise it is the type column when
/// it is a primary type accessor on `node_var`, or a `labels(node_var)` call
/// while `has_secondary_labels` is `false`. When several items qualify for the
/// same role, the last one wins.
pub(super) fn identify_type_count_pair(
    items: &[ReturnItem],
    node_var: Option<&str>,
    has_secondary_labels: bool,
) -> (Option<usize>, Option<usize>) {
    let is_count = |item: &ReturnItem| is_count_of_var_or_star(&item.expression, node_var);
    let is_type = |item: &ReturnItem| {
        is_primary_type_accessor(&item.expression, node_var)
            || (!has_secondary_labels && is_labels_call(&item.expression, node_var))
    };

    let count_idx = items.iter().rposition(|item| is_count(item));
    let type_idx = items
        .iter()
        .rposition(|item| !is_count(item) && is_type(item));
    (type_idx, count_idx)
}
/// Returns `true` when `expr` reads the `type`, `node_type` or `label`
/// property of the variable named by `node_var`.
pub(super) fn is_primary_type_accessor(expr: &Expression, node_var: Option<&str>) -> bool {
    let Expression::PropertyAccess { variable, property } = expr else {
        return false;
    };
    let on_node_var = node_var.is_some_and(|nv| variable == nv);
    on_node_var && matches!(property.as_str(), "type" | "node_type" | "label")
}
/// Returns `true` when `expr` is `labels(v)` with `v` equal to `node_var`.
pub(super) fn is_labels_call(expr: &Expression, node_var: Option<&str>) -> bool {
    let Expression::FunctionCall { name, args, .. } = expr else {
        return false;
    };
    if name != "labels" {
        return false;
    }
    match args.as_slice() {
        [Expression::Variable(v)] => node_var.is_some_and(|nv| v == nv),
        _ => false,
    }
}
/// Locates the `type(edge_var)` column and the count column among `items`.
///
/// Returns `(type_idx, count_idx)`. An item accepted by
/// `is_count_of_var_or_star` is the count column; any other item accepted by
/// `is_edge_type_function` is the type column. The last qualifying item wins
/// for each role.
pub(super) fn identify_edge_type_count_pair(
    items: &[ReturnItem],
    edge_var: Option<&str>,
) -> (Option<usize>, Option<usize>) {
    let is_count = |item: &ReturnItem| is_count_of_var_or_star(&item.expression, edge_var);

    let count_idx = items.iter().rposition(|item| is_count(item));
    let type_idx = items
        .iter()
        .rposition(|item| !is_count(item) && is_edge_type_function(&item.expression, edge_var));
    (type_idx, count_idx)
}
/// Returns `true` when `expr` is `type(v)` with `v` equal to `edge_var`.
pub(super) fn is_edge_type_function(expr: &Expression, edge_var: Option<&str>) -> bool {
    let Expression::FunctionCall { name, args, .. } = expr else {
        return false;
    };
    if name != "type" {
        return false;
    }
    match args.as_slice() {
        [Expression::Variable(v)] => edge_var.is_some_and(|ev| v == ev),
        _ => false,
    }
}
pub(super) fn fuse_optional_match_aggregate(query: &mut CypherQuery) {
let mut i = 0;
while i + 1 < query.clauses.len() {
let can_fuse = matches!(
(&query.clauses[i], &query.clauses[i + 1]),
(Clause::OptionalMatch(_), Clause::With(_))
| (Clause::OptionalMatch(_), Clause::Return(_))
);
if !can_fuse {
i += 1;
continue;
}
let collect_all_pattern_vars =
|patterns: &[crate::graph::core::pattern_matching::Pattern]| -> Vec<String> {
let mut vars = Vec::new();
for pattern in patterns {
for element in &pattern.elements {
match element {
PatternElement::Node(np) => {
if let Some(ref v) = np.variable {
vars.push(v.clone());
}
}
PatternElement::Edge(ep) => {
if let Some(ref v) = ep.variable {
vars.push(v.clone());
}
}
}
}
}
vars
};
let pre_bound_vars: std::collections::HashSet<String> = query.clauses[..i]
.iter()
.flat_map(|c| match c {
Clause::Match(m) | Clause::OptionalMatch(m) => {
collect_all_pattern_vars(&m.patterns)
}
Clause::With(w) => w
.items
.iter()
.filter_map(|it| {
it.alias.clone().or_else(|| match &it.expression {
Expression::Variable(v) => Some(v.clone()),
_ => None,
})
})
.collect(),
Clause::Unwind(u) => vec![u.alias.clone()],
_ => Vec::new(),
})
.collect();
let opt_match_vars: std::collections::HashSet<String> =
if let Clause::OptionalMatch(m) = &query.clauses[i] {
collect_all_pattern_vars(&m.patterns)
.into_iter()
.filter(|v| !pre_bound_vars.contains(v))
.collect()
} else {
i += 1;
continue;
};
let fusable = match &query.clauses[i + 1] {
Clause::With(w) => is_fusable_with_clause(w),
Clause::Return(r) => is_fusable_return_clause(r, &opt_match_vars),
_ => false,
};
if !fusable {
i += 1;
continue;
}
let items = match &query.clauses[i + 1] {
Clause::With(w) => &w.items,
Clause::Return(r) => &r.items,
_ => {
i += 1;
continue;
}
};
let all_counts_local = items
.iter()
.all(|item| count_args_local_to_opt(&item.expression, &opt_match_vars));
if !all_counts_local {
i += 1;
continue;
}
let with_clause = match query.clauses.remove(i + 1) {
Clause::With(w) => w,
Clause::Return(r) => WithClause {
items: r.items,
distinct: r.distinct,
where_clause: r.having.map(|pred| WhereClause { predicate: pred }),
group_limit_hint: r.group_limit_hint,
},
_ => unreachable!(),
};
let match_clause = if let Clause::OptionalMatch(m) = query.clauses.remove(i) {
m
} else {
unreachable!()
};
query.clauses.insert(
i,
Clause::FusedOptionalMatchAggregate {
match_clause,
with_clause,
},
);
i += 1;
}
}
/// Checks whether a `WITH` clause can be fused with a preceding OPTIONAL MATCH.
///
/// Every aggregate item must either be a top-level `count(...)` call or pass
/// `aggregates_only_count`; every non-aggregate item must be a plain variable.
/// At least one aggregate item is required.
pub(super) fn is_fusable_with_clause(with: &WithClause) -> bool {
    use super::super::ast::is_aggregate_expression;

    let mut saw_count = false;
    for item in &with.items {
        let expr = &item.expression;
        if !is_aggregate_expression(expr) {
            if let Expression::Variable(_) = expr {
                continue;
            }
            return false;
        }
        let bare_count =
            matches!(expr, Expression::FunctionCall { name, .. } if name == "count");
        if !(bare_count || aggregates_only_count(expr)) {
            return false;
        }
        saw_count = true;
    }
    saw_count
}
/// Returns `true` when every aggregate function call reachable through the
/// traversed parts of `expr` is named `count`.
///
/// Traversed: function-call arguments, binary arithmetic / concat operands,
/// negation, index and slice operands, a list comprehension's list and map
/// expressions, CASE results and ELSE branch, the base of an expression
/// property access, and map-literal values. Any other expression kind is
/// accepted as-is.
fn aggregates_only_count(expr: &Expression) -> bool {
    use super::super::ast::is_aggregate_expression;

    match expr {
        Expression::FunctionCall { name, args, .. } => {
            // A call is acceptable here if it is `count` or not an aggregate at all.
            let acceptable_here = name == "count" || !is_aggregate_expression(expr);
            acceptable_here && args.iter().all(aggregates_only_count)
        }
        Expression::Add(lhs, rhs)
        | Expression::Subtract(lhs, rhs)
        | Expression::Multiply(lhs, rhs)
        | Expression::Divide(lhs, rhs)
        | Expression::Modulo(lhs, rhs)
        | Expression::Concat(lhs, rhs) => {
            if !aggregates_only_count(lhs) {
                return false;
            }
            aggregates_only_count(rhs)
        }
        Expression::Negate(operand) => aggregates_only_count(operand),
        Expression::IndexAccess { expr, index } => {
            if !aggregates_only_count(expr) {
                return false;
            }
            aggregates_only_count(index)
        }
        Expression::ListSlice { expr, start, end } => {
            aggregates_only_count(expr)
                && start.as_deref().is_none_or(aggregates_only_count)
                && end.as_deref().is_none_or(aggregates_only_count)
        }
        Expression::ListComprehension {
            list_expr,
            map_expr,
            ..
        } => {
            aggregates_only_count(list_expr)
                && map_expr.as_deref().is_none_or(aggregates_only_count)
        }
        Expression::Case {
            when_clauses,
            else_expr,
            ..
        } => {
            let results_ok = when_clauses
                .iter()
                .all(|(_, outcome)| aggregates_only_count(outcome));
            results_ok && else_expr.as_deref().is_none_or(aggregates_only_count)
        }
        Expression::ExprPropertyAccess { expr, .. } => aggregates_only_count(expr),
        Expression::MapLiteral(entries) => entries
            .iter()
            .all(|(_, value)| aggregates_only_count(value)),
        _ => true,
    }
}
/// Checks whether a `RETURN` clause can be fused with a preceding OPTIONAL MATCH.
///
/// * An aggregate item must be a top-level `count(...)`, or pass
///   `aggregates_only_count` without touching `opt_match_vars` outside of
///   `count` (see `expression_touches_vars`).
/// * A non-aggregate item must be a plain variable, or a property access on a
///   variable that is not in `opt_match_vars`.
/// * At least one aggregate item is required.
pub(super) fn is_fusable_return_clause(
    ret: &ReturnClause,
    opt_match_vars: &std::collections::HashSet<String>,
) -> bool {
    use super::super::ast::is_aggregate_expression;

    let mut saw_count = false;
    for item in &ret.items {
        let expr = &item.expression;

        if is_aggregate_expression(expr) {
            let bare_count =
                matches!(expr, Expression::FunctionCall { name, .. } if name == "count");
            if !bare_count
                && (!aggregates_only_count(expr)
                    || expression_touches_vars(expr, opt_match_vars))
            {
                return false;
            }
            saw_count = true;
            continue;
        }

        let passthrough_ok = match expr {
            Expression::Variable(_) => true,
            Expression::PropertyAccess { variable, .. } => !opt_match_vars.contains(variable),
            _ => false,
        };
        if !passthrough_ok {
            return false;
        }
    }
    saw_count
}
/// Returns `true` when every `count(...)` reachable in `expr` is non-DISTINCT
/// and counts either `*` or a variable from `opt_match_vars`, and no other
/// aggregate function call is reachable.
///
/// The traversal mirrors `aggregates_only_count`: besides function-call
/// arguments, arithmetic / concat operands and negation, it descends into
/// index and slice operands, a list comprehension's list and map expressions,
/// CASE results and ELSE branch, the base of an expression property access and
/// map-literal values. Without this, a `count` wrapped in one of those forms
/// would pass `aggregates_only_count` yet skip the locality check here.
///
/// CASE conditions are not inspected (same as `aggregates_only_count`).
/// Expression kinds not listed are accepted as-is.
fn count_args_local_to_opt(
    expr: &Expression,
    opt_match_vars: &std::collections::HashSet<String>,
) -> bool {
    match expr {
        Expression::FunctionCall {
            name,
            args,
            distinct,
        } => {
            if name == "count" {
                if *distinct {
                    return false;
                }
                match args.as_slice() {
                    [Expression::Star] => true,
                    [Expression::Variable(v)] => opt_match_vars.contains(v),
                    _ => false,
                }
            } else {
                // Any aggregate other than `count` disqualifies the expression.
                if super::super::ast::is_aggregate_expression(expr) {
                    return false;
                }
                args.iter()
                    .all(|a| count_args_local_to_opt(a, opt_match_vars))
            }
        }
        Expression::Add(l, r)
        | Expression::Subtract(l, r)
        | Expression::Multiply(l, r)
        | Expression::Divide(l, r)
        | Expression::Modulo(l, r)
        | Expression::Concat(l, r) => {
            count_args_local_to_opt(l, opt_match_vars) && count_args_local_to_opt(r, opt_match_vars)
        }
        Expression::Negate(inner) => count_args_local_to_opt(inner, opt_match_vars),
        Expression::IndexAccess { expr, index } => {
            count_args_local_to_opt(expr, opt_match_vars)
                && count_args_local_to_opt(index, opt_match_vars)
        }
        Expression::ListSlice { expr, start, end } => {
            count_args_local_to_opt(expr, opt_match_vars)
                && start
                    .as_deref()
                    .is_none_or(|e| count_args_local_to_opt(e, opt_match_vars))
                && end
                    .as_deref()
                    .is_none_or(|e| count_args_local_to_opt(e, opt_match_vars))
        }
        Expression::ListComprehension {
            list_expr,
            map_expr,
            ..
        } => {
            count_args_local_to_opt(list_expr, opt_match_vars)
                && map_expr
                    .as_deref()
                    .is_none_or(|e| count_args_local_to_opt(e, opt_match_vars))
        }
        Expression::Case {
            when_clauses,
            else_expr,
            ..
        } => {
            when_clauses
                .iter()
                .all(|(_, result)| count_args_local_to_opt(result, opt_match_vars))
                && else_expr
                    .as_deref()
                    .is_none_or(|e| count_args_local_to_opt(e, opt_match_vars))
        }
        Expression::ExprPropertyAccess { expr, .. } => {
            count_args_local_to_opt(expr, opt_match_vars)
        }
        Expression::MapLiteral(entries) => entries
            .iter()
            .all(|(_, e)| count_args_local_to_opt(e, opt_match_vars)),
        _ => true,
    }
}
/// Returns `true` when `expr` references any of `vars` outside of a `count`
/// call.
///
/// Looks at plain variables, property accesses, non-`count` function-call
/// arguments, binary arithmetic / concat operands and negation. Arguments of
/// `count(...)` are deliberately not inspected, and every other expression
/// kind is treated as not touching `vars`.
fn expression_touches_vars(expr: &Expression, vars: &std::collections::HashSet<String>) -> bool {
    match expr {
        Expression::Variable(name) => vars.contains(name),
        Expression::PropertyAccess { variable, .. } => vars.contains(variable),
        Expression::FunctionCall { name, args, .. } => {
            name != "count" && args.iter().any(|arg| expression_touches_vars(arg, vars))
        }
        Expression::Add(lhs, rhs)
        | Expression::Subtract(lhs, rhs)
        | Expression::Multiply(lhs, rhs)
        | Expression::Divide(lhs, rhs)
        | Expression::Modulo(lhs, rhs)
        | Expression::Concat(lhs, rhs) => {
            if expression_touches_vars(lhs, vars) {
                return true;
            }
            expression_touches_vars(rhs, vars)
        }
        Expression::Negate(operand) => expression_touches_vars(operand, vars),
        _ => false,
    }
}
/// Gate used before fusing a `count(DISTINCT ...)` aggregate.
///
/// Returns `true` only when `match_clause` is a `MATCH` with a single
/// three-element pattern whose first and last elements are nodes, and the
/// grouping variable of `next_clause` (the variable of its first non-aggregate
/// item that is a plain variable or a property access) names one of those two
/// nodes, and that node carries a type or a non-empty property map.
///
/// `next_clause` must be a `RETURN` or `WITH`; anything else yields `false`.
fn distinct_fusable_3elem_with_constrained_group(
    match_clause: &Clause,
    next_clause: &Clause,
) -> bool {
    use super::super::ast::is_aggregate_expression;

    let Clause::Match(m) = match_clause else {
        return false;
    };
    let [pattern] = m.patterns.as_slice() else {
        return false;
    };
    let [PatternElement::Node(first), _, PatternElement::Node(last)] = pattern.elements.as_slice()
    else {
        return false;
    };

    let items = match next_clause {
        Clause::Return(r) => &r.items,
        Clause::With(w) => &w.items,
        _ => return false,
    };
    let grouping = items
        .iter()
        .filter(|item| !is_aggregate_expression(&item.expression))
        .find_map(|item| match &item.expression {
            Expression::Variable(v) => Some(v.as_str()),
            Expression::PropertyAccess { variable, .. } => Some(variable.as_str()),
            _ => None,
        });
    let Some(gv) = grouping else {
        return false;
    };

    // Prefer the first node when both endpoints share the grouping variable.
    let Some(group_node) = [first, last]
        .into_iter()
        .find(|node| node.variable.as_deref() == Some(gv))
    else {
        return false;
    };

    let constrained_by_props = group_node
        .properties
        .as_ref()
        .is_some_and(|p| !p.is_empty());
    group_node.node_type.is_some() || constrained_by_props
}
pub(super) fn fuse_match_return_aggregate(query: &mut CypherQuery, has_secondary_labels: bool) {
use super::super::ast::is_aggregate_expression;
if has_secondary_labels {
return;
}
let mut i = 0;
while i + 1 < query.clauses.len() {
if i > 0 {
i += 1;
continue;
}
let can_fuse = matches!(
(&query.clauses[i], &query.clauses[i + 1]),
(Clause::Match(_), Clause::Return(_))
);
if !can_fuse {
i += 1;
continue;
}
let (first_var, second_var, edge_has_props, edge_var) = if let Clause::Match(m) =
&query.clauses[i]
{
let n_elems = m.patterns[0].elements.len();
if m.patterns.len() != 1 || (n_elems != 3 && n_elems != 5) {
i += 1;
continue;
}
let pat = &m.patterns[0];
let first_var = match &pat.elements[0] {
PatternElement::Node(np) => np.variable.clone(),
_ => {
i += 1;
continue;
}
};
let (edge_has_props, edge_var) = match &pat.elements[1] {
PatternElement::Edge(ep) => (
ep.properties.is_some() || ep.var_length.is_some(),
ep.variable.clone(),
),
_ => {
i += 1;
continue;
}
};
if n_elems == 5 {
let mid_has_props = match &pat.elements[2] {
PatternElement::Node(np) => np.properties.is_some(),
_ => {
i += 1;
continue;
}
};
let edge2_has_props = match &pat.elements[3] {
PatternElement::Edge(ep) => ep.properties.is_some() || ep.var_length.is_some(),
_ => {
i += 1;
continue;
}
};
let (last_var, last_has_props) = match &pat.elements[4] {
PatternElement::Node(np) => (np.variable.clone(), np.properties.is_some()),
_ => {
i += 1;
continue;
}
};
if mid_has_props || edge2_has_props || last_has_props {
i += 1;
continue;
}
(first_var, last_var, edge_has_props, edge_var)
} else {
let second_var = match &pat.elements[2] {
PatternElement::Node(np) => np.variable.clone(),
_ => {
i += 1;
continue;
}
};
(first_var, second_var, edge_has_props, edge_var)
}
} else {
i += 1;
continue;
};
if edge_has_props {
i += 1;
continue;
}
if first_var.is_none() && second_var.is_none() {
i += 1;
continue;
}
let (fusable, distinct_count) = if let Clause::Return(r) = &query.clauses[i + 1] {
if r.distinct {
(false, false)
} else {
let mut has_count = false;
let mut all_valid = true;
let mut group_var: Option<&str> = None;
let mut count_var_ok = true;
let mut saw_distinct = false;
for item in &r.items {
if !is_aggregate_expression(&item.expression) {
let refs_var = match &item.expression {
Expression::PropertyAccess { variable, .. } => Some(variable.as_str()),
Expression::Variable(v) => Some(v.as_str()),
_ => None,
};
match refs_var {
Some(v) => {
if group_var.is_none() {
group_var = Some(v);
} else if group_var != Some(v) {
all_valid = false;
break;
}
}
None => {
all_valid = false;
break;
}
}
}
}
if all_valid {
if let Some(gv) = group_var {
let is_first = first_var.as_deref() == Some(gv);
let is_second = second_var.as_deref() == Some(gv);
if !is_first && !is_second {
all_valid = false;
}
} else {
all_valid = false; }
}
if all_valid {
let other_var = if group_var == first_var.as_deref() {
&second_var
} else {
&first_var
};
for item in &r.items {
if is_aggregate_expression(&item.expression) {
match &item.expression {
Expression::FunctionCall {
name,
args,
distinct,
} if name == "count" => {
if args.len() == 1 && matches!(args[0], Expression::Star) {
if *distinct {
count_var_ok = false;
break;
}
has_count = true;
continue;
}
if let Some(Expression::Variable(var)) = args.first() {
let matches_other =
other_var.as_deref() == Some(var.as_str());
let matches_edge =
edge_var.as_deref() == Some(var.as_str());
if matches_other || matches_edge {
has_count = true;
if *distinct {
saw_distinct = true;
}
continue;
}
}
count_var_ok = false;
break;
}
_ => {
count_var_ok = false;
break;
}
}
}
}
}
(has_count && all_valid && count_var_ok, saw_distinct)
}
} else {
(false, false)
};
if !fusable {
i += 1;
continue;
}
if distinct_count
&& !distinct_fusable_3elem_with_constrained_group(
&query.clauses[i],
&query.clauses[i + 1],
)
{
i += 1;
continue;
}
let return_clause = if let Clause::Return(r) = query.clauses.remove(i + 1) {
r
} else {
unreachable!()
};
let match_clause = if let Clause::Match(m) = query.clauses.remove(i) {
m
} else {
unreachable!()
};
query.clauses.insert(
i,
Clause::FusedMatchReturnAggregate {
match_clause,
return_clause,
top_k: None,
candidate_emit: None,
distinct_count,
},
);
i += 1;
}
fuse_aggregate_order_limit(query);
}
pub(super) fn fuse_aggregate_order_limit(query: &mut CypherQuery) {
use super::super::ast::is_aggregate_expression;
let mut i = 0;
while i + 2 < query.clauses.len() {
let is_pattern = matches!(
(
&query.clauses[i],
&query.clauses[i + 1],
&query.clauses[i + 2]
),
(
Clause::FusedMatchReturnAggregate { .. },
Clause::OrderBy(_),
Clause::Limit(_)
)
);
if !is_pattern {
i += 1;
continue;
}
if let Clause::FusedMatchReturnAggregate { return_clause, .. } = &query.clauses[i] {
if return_clause.having.is_some() {
i += 1;
continue;
}
}
let (sort_expr_idx, descending, multi_key) = if let Clause::OrderBy(ob) =
&query.clauses[i + 1]
{
if ob.items.is_empty() {
i += 1;
continue;
}
let sort_item = &ob.items[0];
if let Clause::FusedMatchReturnAggregate { return_clause, .. } = &query.clauses[i] {
let sort_alias = match &sort_item.expression {
Expression::Variable(v) => Some(v.clone()),
_ => None,
};
let sort_expr_str = expression_to_column_name(&sort_item.expression);
let mut found_idx = None;
for (ri, item) in return_clause.items.iter().enumerate() {
if !is_aggregate_expression(&item.expression) {
continue;
}
let matches_alias = sort_alias
.as_deref()
.zip(item.alias.as_deref())
.is_some_and(|(s, a)| s == a);
let matches_expr = expression_to_column_name(&item.expression) == sort_expr_str;
if matches_alias || matches_expr {
found_idx = Some(ri);
break;
}
}
match found_idx {
Some(idx) => (idx, !sort_item.ascending, ob.items.len() > 1),
None => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
}
} else {
i += 1;
continue;
};
let limit = if let Clause::Limit(l) = &query.clauses[i + 2] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => *n as usize,
_ => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
};
if multi_key {
if let Clause::FusedMatchReturnAggregate { candidate_emit, .. } = &mut query.clauses[i]
{
*candidate_emit = Some((sort_expr_idx, descending, limit));
}
} else {
query.clauses.remove(i + 2); query.clauses.remove(i + 1); if let Clause::FusedMatchReturnAggregate { top_k, .. } = &mut query.clauses[i] {
*top_k = Some((sort_expr_idx, descending, limit));
}
}
i += 1;
}
}
pub(super) fn fuse_node_scan_aggregate(query: &mut CypherQuery) {
use super::super::ast::is_aggregate_expression;
let mut i = 0;
while i + 1 < query.clauses.len() {
if i > 0 {
i += 1;
continue;
}
let match_idx = i;
if !matches!(&query.clauses[match_idx], Clause::Match(_)) {
i += 1;
continue;
}
let (where_idx, return_idx) = if i + 2 < query.clauses.len()
&& matches!(&query.clauses[i + 1], Clause::Where(_))
&& matches!(&query.clauses[i + 2], Clause::Return(_))
{
(Some(i + 1), i + 2)
} else if matches!(&query.clauses[i + 1], Clause::Return(_)) {
(None, i + 1)
} else {
i += 1;
continue;
};
let is_single_node = if let Clause::Match(mc) = &query.clauses[match_idx] {
mc.patterns.len() == 1
&& mc.patterns[0].elements.len() == 1
&& matches!(mc.patterns[0].elements[0], PatternElement::Node(_))
&& mc.path_assignments.is_empty()
} else {
false
};
if !is_single_node {
i += 1;
continue;
}
let has_supported_agg = if let Clause::Return(r) = &query.clauses[return_idx] {
let has_any_agg = r
.items
.iter()
.any(|item| is_aggregate_expression(&item.expression));
let all_supported = r.items.iter().all(|item| {
if !is_aggregate_expression(&item.expression) {
return true; }
match &item.expression {
Expression::FunctionCall { name, distinct, .. } => {
if *distinct {
return false; }
matches!(
name.to_lowercase().as_str(),
"count" | "sum" | "avg" | "mean" | "average" | "min" | "max"
)
}
_ => false,
}
});
has_any_agg && all_supported
} else {
false
};
if !has_supported_agg {
i += 1;
continue;
}
let where_predicate = if let Some(wi) = where_idx {
if let Clause::Where(w) = query.clauses.remove(wi) {
Some(w.predicate)
} else {
None
}
} else {
None
};
let ret_idx = if where_idx.is_some() {
return_idx - 1
} else {
return_idx
};
let return_clause = if let Clause::Return(r) = query.clauses.remove(ret_idx) {
r
} else {
unreachable!()
};
let match_clause = if let Clause::Match(mc) = query.clauses.remove(match_idx) {
mc
} else {
unreachable!()
};
query.clauses.insert(
match_idx,
Clause::FusedNodeScanAggregate {
match_clause,
where_predicate,
return_clause,
},
);
i += 1;
}
}
/// Tries to fuse `MATCH m1 MATCH m2 WITH <group>, count(...)` starting at
/// clause index `i` into one `Clause::FusedMatchWithAggregate` carrying `m2`
/// as `secondary_match`. Returns `true` when the rewrite was applied.
///
/// Requirements checked here:
/// * `m1`: one three-element pattern with nodes at both ends; if the middle
///   element is an edge it has no properties and is not variable-length;
/// * `m2`: one three-element pattern starting with a node followed by an edge
///   that has a variable, no properties and is not variable-length; the first
///   node's variable must be one of `m1`'s endpoint variables;
/// * the `WITH` is not `DISTINCT`; every aggregate item is a non-DISTINCT
///   top-level `count` over `*` or `m2`'s edge variable; every other item is a
///   variable or property access on one single `m1` endpoint variable that is
///   not the edge variable and equals `m2`'s shared variable.
fn try_fuse_two_match_with_aggregate(query: &mut CypherQuery, i: usize) -> bool {
    use super::super::ast::is_aggregate_expression;

    let Some([Clause::Match(m1), Clause::Match(m2), Clause::With(w)]) =
        query.clauses.get(i..i + 3)
    else {
        return false;
    };

    // --- first MATCH ------------------------------------------------------
    let [p1] = m1.patterns.as_slice() else {
        return false;
    };
    let [PatternElement::Node(m1_first), m1_middle, PatternElement::Node(m1_last)] =
        p1.elements.as_slice()
    else {
        return false;
    };
    if matches!(
        m1_middle,
        PatternElement::Edge(ep) if ep.properties.is_some() || ep.var_length.is_some()
    ) {
        return false;
    }
    let m1_endpoint_vars = [m1_first.variable.as_deref(), m1_last.variable.as_deref()];

    // --- second MATCH -----------------------------------------------------
    let [p2] = m2.patterns.as_slice() else {
        return false;
    };
    let [PatternElement::Node(m2_first), PatternElement::Edge(m2_edge), _] =
        p2.elements.as_slice()
    else {
        return false;
    };
    if m2_edge.properties.is_some() || m2_edge.var_length.is_some() {
        return false;
    }
    let Some(m2_edge_var) = m2_edge.variable.as_deref() else {
        return false;
    };
    let Some(m2_shared_var) = m2_first
        .variable
        .as_deref()
        .filter(|v| m1_endpoint_vars.contains(&Some(*v)))
    else {
        return false;
    };

    // --- WITH -------------------------------------------------------------
    if w.distinct {
        return false;
    }
    let mut counts_edge = false;
    let mut group_var: Option<&str> = None;
    for item in &w.items {
        if is_aggregate_expression(&item.expression) {
            let Expression::FunctionCall {
                name,
                args,
                distinct,
            } = &item.expression
            else {
                return false;
            };
            if name != "count" || *distinct {
                return false;
            }
            let counts_star = args.len() == 1 && matches!(args[0], Expression::Star);
            let counts_edge_var =
                matches!(args.first(), Some(Expression::Variable(v)) if v == m2_edge_var);
            if !(counts_star || counts_edge_var) {
                return false;
            }
            counts_edge = true;
            continue;
        }

        let referenced = match &item.expression {
            Expression::Variable(v) => v.as_str(),
            Expression::PropertyAccess { variable, .. } => variable.as_str(),
            _ => return false,
        };
        if referenced == m2_edge_var || !m1_endpoint_vars.contains(&Some(referenced)) {
            return false;
        }
        match group_var {
            None => group_var = Some(referenced),
            Some(existing) if existing == referenced => {}
            Some(_) => return false,
        }
    }
    if !counts_edge || group_var != Some(m2_shared_var) {
        return false;
    }

    // --- rewrite ----------------------------------------------------------
    let Clause::With(with_clause) = query.clauses.remove(i + 2) else {
        unreachable!()
    };
    let Clause::Match(secondary) = query.clauses.remove(i + 1) else {
        unreachable!()
    };
    let Clause::Match(primary) = query.clauses.remove(i) else {
        unreachable!()
    };
    query.clauses.insert(
        i,
        Clause::FusedMatchWithAggregate {
            match_clause: primary,
            with_clause,
            secondary_match: Some(secondary),
            top_k: None,
            distinct_count: false,
        },
    );
    true
}
pub(super) fn fuse_match_with_aggregate(query: &mut CypherQuery, has_secondary_labels: bool) {
use super::super::ast::is_aggregate_expression;
if has_secondary_labels {
return;
}
let mut i = 0;
while i + 1 < query.clauses.len() {
if i > 0 {
i += 1;
continue;
}
if try_fuse_two_match_with_aggregate(query, i) {
i += 1;
continue;
}
let can_fuse = matches!(
(&query.clauses[i], &query.clauses[i + 1]),
(Clause::Match(_), Clause::With(_))
);
if !can_fuse {
i += 1;
continue;
}
let (first_var, second_var, edge_has_props, second_has_props, edge_var) =
if let Clause::Match(m) = &query.clauses[i] {
if m.patterns.len() != 1 || m.patterns[0].elements.len() != 3 {
i += 1;
continue;
}
let pat = &m.patterns[0];
let first_var = match &pat.elements[0] {
PatternElement::Node(np) => np.variable.clone(),
_ => {
i += 1;
continue;
}
};
let (edge_has_props, edge_var) = match &pat.elements[1] {
PatternElement::Edge(ep) => (
ep.properties.is_some() || ep.var_length.is_some(),
ep.variable.clone(),
),
_ => {
i += 1;
continue;
}
};
let (second_var, second_has_props) = match &pat.elements[2] {
PatternElement::Node(np) => (np.variable.clone(), np.properties.is_some()),
_ => {
i += 1;
continue;
}
};
(
first_var,
second_var,
edge_has_props,
second_has_props,
edge_var,
)
} else {
i += 1;
continue;
};
if edge_has_props || second_has_props {
i += 1;
continue;
}
if first_var.is_none() && second_var.is_none() {
i += 1;
continue;
}
let (fusable, distinct_count) = if let Clause::With(w) = &query.clauses[i + 1] {
if w.distinct {
(false, false)
} else {
let mut has_count = false;
let mut all_valid = true;
let mut group_var: Option<&str> = None;
let mut count_var_ok = true;
let mut saw_distinct = false;
for item in &w.items {
if !is_aggregate_expression(&item.expression) {
let refs_var = match &item.expression {
Expression::Variable(v) => Some(v.as_str()),
_ => None,
};
match refs_var {
Some(v) => {
if group_var.is_none() {
group_var = Some(v);
} else if group_var != Some(v) {
all_valid = false;
break;
}
}
None => {
all_valid = false;
break;
}
}
}
}
if all_valid {
if let Some(gv) = group_var {
let is_first = first_var.as_deref() == Some(gv);
let is_second = second_var.as_deref() == Some(gv);
if !is_first && !is_second {
all_valid = false;
}
} else {
all_valid = false;
}
}
if all_valid {
let other_var = if group_var == first_var.as_deref() {
&second_var
} else {
&first_var
};
for item in &w.items {
if is_aggregate_expression(&item.expression) {
match &item.expression {
Expression::FunctionCall {
name,
args,
distinct,
} if name == "count" => {
if args.len() == 1 && matches!(args[0], Expression::Star) {
if *distinct {
count_var_ok = false;
break;
}
has_count = true;
continue;
}
if let Some(Expression::Variable(var)) = args.first() {
let matches_other =
other_var.as_deref() == Some(var.as_str());
let matches_edge =
edge_var.as_deref() == Some(var.as_str());
if matches_other || matches_edge {
has_count = true;
if *distinct {
saw_distinct = true;
}
continue;
}
}
count_var_ok = false;
break;
}
_ => {
count_var_ok = false;
break;
}
}
}
}
}
(has_count && all_valid && count_var_ok, saw_distinct)
}
} else {
(false, false)
};
if !fusable {
i += 1;
continue;
}
if distinct_count
&& !distinct_fusable_3elem_with_constrained_group(
&query.clauses[i],
&query.clauses[i + 1],
)
{
i += 1;
continue;
}
let with_clause = if let Clause::With(w) = query.clauses.remove(i + 1) {
w
} else {
unreachable!()
};
let match_clause = if let Clause::Match(m) = query.clauses.remove(i) {
m
} else {
unreachable!()
};
query.clauses.insert(
i,
Clause::FusedMatchWithAggregate {
match_clause,
with_clause,
secondary_match: None,
top_k: None,
distinct_count,
},
);
i += 1;
}
}
/// Sets `lazy_eligible` on the query's `RETURN` clause when the query has a
/// shape this pass accepts.
///
/// All of the following must hold, otherwise the query is left untouched:
/// * every clause is `MATCH`, `OPTIONAL MATCH`, `RETURN`, `SKIP` or `LIMIT`;
/// * there is exactly one `RETURN`;
/// * only `SKIP` / `LIMIT` clauses follow that `RETURN`;
/// * the `RETURN` is neither `DISTINCT` nor carries a `HAVING`;
/// * every returned item is a plain property access.
pub(super) fn mark_return_lazy_eligible(query: &mut CypherQuery) {
    let only_supported = query.clauses.iter().all(|clause| {
        matches!(
            clause,
            Clause::Match(_)
                | Clause::OptionalMatch(_)
                | Clause::Return(_)
                | Clause::Skip(_)
                | Clause::Limit(_)
        )
    });
    if !only_supported {
        return;
    }

    // Exactly one RETURN: the first `next()` must yield, the second must not.
    let mut return_positions = query
        .clauses
        .iter()
        .enumerate()
        .filter_map(|(pos, clause)| matches!(clause, Clause::Return(_)).then_some(pos));
    let (Some(pos), None) = (return_positions.next(), return_positions.next()) else {
        return;
    };

    let tail_ok = query.clauses[pos + 1..]
        .iter()
        .all(|clause| matches!(clause, Clause::Skip(_) | Clause::Limit(_)));
    if !tail_ok {
        return;
    }

    let Clause::Return(ret) = &mut query.clauses[pos] else {
        return;
    };
    if ret.distinct || ret.having.is_some() {
        return;
    }
    let projections_are_simple = ret
        .items
        .iter()
        .all(|item| matches!(item.expression, Expression::PropertyAccess { .. }));
    if projections_are_simple {
        ret.lazy_eligible = true;
    }
}
pub(super) fn fuse_match_with_aggregate_top_k(query: &mut CypherQuery) {
use super::super::ast::is_aggregate_expression;
let mut i = 0;
while i + 3 < query.clauses.len() {
if !matches!(
(
&query.clauses[i],
&query.clauses[i + 1],
&query.clauses[i + 2],
&query.clauses[i + 3],
),
(
Clause::FusedMatchWithAggregate { .. },
Clause::Return(_),
Clause::OrderBy(_),
Clause::Limit(_)
)
) {
i += 1;
continue;
}
let with_items = match &query.clauses[i] {
Clause::FusedMatchWithAggregate { with_clause, .. } => with_clause.items.clone(),
_ => unreachable!(),
};
let already_has_top_k = matches!(
&query.clauses[i],
Clause::FusedMatchWithAggregate { top_k: Some(_), .. }
);
if already_has_top_k {
i += 1;
continue;
}
let mut count_alias: Option<String> = None;
let mut count_count = 0usize;
let mut aliases: std::collections::HashSet<String> = std::collections::HashSet::new();
for item in &with_items {
let alias = item
.alias
.clone()
.unwrap_or_else(|| match &item.expression {
Expression::Variable(v) => v.clone(),
Expression::PropertyAccess { variable, property } => {
format!("{variable}.{property}")
}
_ => format!("{:?}", item.expression),
});
if is_aggregate_expression(&item.expression) {
count_count += 1;
count_alias = Some(alias.clone());
}
aliases.insert(alias);
}
if count_count != 1 {
i += 1;
continue;
}
let count_alias = match count_alias {
Some(s) => s,
None => {
i += 1;
continue;
}
};
let mut group_vars: std::collections::HashSet<String> = std::collections::HashSet::new();
for item in &with_items {
if !is_aggregate_expression(&item.expression) {
match &item.expression {
Expression::Variable(v) => {
group_vars.insert(v.clone());
}
Expression::PropertyAccess { variable, .. } => {
group_vars.insert(variable.clone());
}
_ => {}
}
}
}
let return_ok = if let Clause::Return(r) = &query.clauses[i + 1] {
!r.distinct
&& r.items.iter().all(|item| match &item.expression {
Expression::Variable(v) => aliases.contains(v),
Expression::PropertyAccess { variable, .. } => group_vars.contains(variable),
_ => false,
})
} else {
false
};
if !return_ok {
i += 1;
continue;
}
let (target_count, descending) = if let Clause::OrderBy(o) = &query.clauses[i + 2] {
if o.items.len() != 1 {
(false, false)
} else {
let target = match &o.items[0].expression {
Expression::Variable(v) => v == &count_alias,
_ => false,
};
(target, !o.items[0].ascending)
}
} else {
(false, false)
};
if !target_count {
i += 1;
continue;
}
let limit = if let Clause::Limit(l) = &query.clauses[i + 3] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => *n as usize,
_ => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
};
if let Clause::FusedMatchWithAggregate { top_k, .. } = &mut query.clauses[i] {
*top_k = Some(AggregateTopK { limit, descending });
}
i += 1;
}
}
pub(super) fn fuse_node_scan_top_k(query: &mut CypherQuery) {
use super::super::ast::is_aggregate_expression;
if query.clauses.len() < 4 {
return;
}
let mut i = 0;
while i + 3 < query.clauses.len() {
if i > 0 {
i += 1;
continue;
}
let (match_idx, where_idx, return_idx, orderby_idx, limit_idx) =
if matches!(&query.clauses[i], Clause::Match(_))
&& matches!(&query.clauses[i + 1], Clause::Where(_))
&& i + 4 < query.clauses.len()
&& matches!(&query.clauses[i + 2], Clause::Return(_))
&& matches!(&query.clauses[i + 3], Clause::OrderBy(_))
&& matches!(&query.clauses[i + 4], Clause::Limit(_))
{
(i, Some(i + 1), i + 2, i + 3, i + 4)
} else if matches!(&query.clauses[i], Clause::Match(_))
&& matches!(&query.clauses[i + 1], Clause::Return(_))
&& matches!(&query.clauses[i + 2], Clause::OrderBy(_))
&& matches!(&query.clauses[i + 3], Clause::Limit(_))
{
(i, None, i + 1, i + 2, i + 3)
} else {
i += 1;
continue;
};
let is_single_node = if let Clause::Match(mc) = &query.clauses[match_idx] {
mc.patterns.len() == 1
&& mc.patterns[0].elements.len() == 1
&& matches!(
mc.patterns[0].elements[0],
crate::graph::core::pattern_matching::PatternElement::Node(_)
)
&& mc.path_assignments.is_empty()
} else {
false
};
if !is_single_node {
i += 1;
continue;
}
let return_ok = if let Clause::Return(r) = &query.clauses[return_idx] {
!r.distinct
&& !r
.items
.iter()
.any(|item| is_aggregate_expression(&item.expression))
&& !r
.items
.iter()
.any(|item| matches!(item.expression, Expression::FunctionCall { .. }))
} else {
false
};
if !return_ok {
i += 1;
continue;
}
let sort_info = if let Clause::OrderBy(o) = &query.clauses[orderby_idx] {
if o.items.len() == 1 {
Some((o.items[0].expression.clone(), !o.items[0].ascending))
} else {
None
}
} else {
None
};
let Some((sort_expr, descending)) = sort_info else {
i += 1;
continue;
};
if let Clause::Return(r) = &query.clauses[return_idx] {
let return_aliases: std::collections::HashSet<String> = r
.items
.iter()
.filter_map(|item| item.alias.clone())
.collect();
if expression_touches_vars(&sort_expr, &return_aliases) {
i += 1;
continue;
}
}
let limit_val = if let Clause::Limit(l) = &query.clauses[limit_idx] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => Some(*n as usize),
_ => None,
}
} else {
None
};
let Some(limit) = limit_val else {
i += 1;
continue;
};
query.clauses.remove(limit_idx);
query.clauses.remove(orderby_idx);
let return_clause = if let Clause::Return(r) = query.clauses.remove(return_idx) {
r
} else {
unreachable!()
};
let where_predicate = if let Some(wi) = where_idx {
if let Clause::Where(w) = query.clauses.remove(wi) {
Some(w.predicate)
} else {
None
}
} else {
None
};
let match_clause = if let Clause::Match(mc) = query.clauses.remove(match_idx) {
mc
} else {
unreachable!()
};
query.clauses.insert(
match_idx,
Clause::FusedNodeScanTopK {
match_clause,
where_predicate,
return_clause,
sort_expression: sort_expr,
descending,
limit,
},
);
i += 1;
}
}
/// Replaces every `RETURN, ORDER BY, LIMIT` triple whose sort key is a
/// returned `vector_score(...)` column with a single `FusedVectorScoreTopK`.
///
/// A triple qualifies when the `RETURN` is not `DISTINCT`, has no aggregate
/// items, contains a `vector_score` call, the `ORDER BY` has exactly one key
/// naming that item's output column, and the `LIMIT` is a positive integer
/// literal.
pub(super) fn fuse_vector_score_order_limit(query: &mut CypherQuery) {
    use super::super::ast::is_aggregate_expression;

    /// Returns `(score item index, descending, limit)` for the triple
    /// starting at `at`, or `None` when it does not qualify.
    fn plan_at(query: &CypherQuery, at: usize) -> Option<(usize, bool, usize)> {
        let (Clause::Return(ret), Clause::OrderBy(order), Clause::Limit(lim)) = (
            &query.clauses[at],
            &query.clauses[at + 1],
            &query.clauses[at + 2],
        ) else {
            return None;
        };

        let has_aggregate = ret
            .items
            .iter()
            .any(|item| is_aggregate_expression(&item.expression));
        if ret.distinct || has_aggregate {
            return None;
        }

        // First returned item that is a `vector_score` call.
        let score_idx = ret.items.iter().position(|item| {
            matches!(
                &item.expression,
                Expression::FunctionCall { name, .. } if name == "vector_score"
            )
        })?;
        let score_column = return_item_column_name(&ret.items[score_idx]);

        if order.items.len() != 1 {
            return None;
        }
        let key = &order.items[0];
        if expression_to_column_name(&key.expression) != score_column {
            return None;
        }

        let limit = match &lim.count {
            Expression::Literal(Value::Int64(n)) if *n > 0 => *n as usize,
            _ => return None,
        };
        Some((score_idx, !key.ascending, limit))
    }

    let mut i = 0;
    while i + 2 < query.clauses.len() {
        if let Some((score_item_index, descending, limit)) = plan_at(query, i) {
            // Drop LIMIT and ORDER BY, then swap the RETURN for the fused clause.
            query.clauses.remove(i + 2);
            query.clauses.remove(i + 1);
            let Clause::Return(return_clause) = query.clauses.remove(i) else {
                unreachable!()
            };
            query.clauses.insert(
                i,
                Clause::FusedVectorScoreTopK {
                    return_clause,
                    score_item_index,
                    descending,
                    limit,
                },
            );
        }
        i += 1;
    }
}
pub(super) fn return_item_column_name(item: &ReturnItem) -> String {
if let Some(ref alias) = item.alias {
alias.clone()
} else {
expression_to_column_name(&item.expression)
}
}
/// Derives a column name from an expression.
///
/// * variable: the variable name;
/// * property access: `variable.property`;
/// * function call: `name(arg, arg, ...)` with each argument rendered
///   recursively;
/// * anything else: the expression's `Debug` rendering.
pub(super) fn expression_to_column_name(expr: &Expression) -> String {
    match expr {
        Expression::Variable(name) => name.clone(),
        Expression::PropertyAccess { variable, property } => format!("{variable}.{property}"),
        Expression::FunctionCall { name, args, .. } => {
            let rendered_args = args
                .iter()
                .map(expression_to_column_name)
                .collect::<Vec<_>>()
                .join(", ");
            format!("{name}({rendered_args})")
        }
        other => format!("{other:?}"),
    }
}
pub(super) fn fuse_order_by_top_k(query: &mut CypherQuery) {
if query.clauses.len() < 3 {
return;
}
let mut i = 0;
while i + 2 < query.clauses.len() {
let is_pattern = matches!(
(
&query.clauses[i],
&query.clauses[i + 1],
&query.clauses[i + 2]
),
(Clause::Return(_), Clause::OrderBy(_), Clause::Limit(_))
);
if !is_pattern {
i += 1;
continue;
}
let (score_idx, sort_expression) = if let Clause::Return(r) = &query.clauses[i] {
if r.distinct {
i += 1;
continue;
}
if r.items
.iter()
.any(|item| super::super::ast::is_aggregate_expression(&item.expression))
{
i += 1;
continue;
}
if r.items
.iter()
.any(|item| matches!(item.expression, Expression::WindowFunction { .. }))
{
i += 1;
continue;
}
let order_info = if let Clause::OrderBy(o) = &query.clauses[i + 1] {
if o.items.len() != 1 {
i += 1;
continue;
}
let order_alias = match &o.items[0].expression {
Expression::Variable(v) => v.clone(),
other => expression_to_column_name(other),
};
let found = r
.items
.iter()
.enumerate()
.find(|(_, item)| return_item_column_name(item) == order_alias);
match found {
Some((idx, _)) => (idx, None), None => {
(0, Some(o.items[0].expression.clone()))
}
}
} else {
i += 1;
continue;
};
order_info
} else {
i += 1;
continue;
};
let descending = if let Clause::OrderBy(o) = &query.clauses[i + 1] {
!o.items[0].ascending
} else {
i += 1;
continue;
};
let limit = if let Clause::Limit(l) = &query.clauses[i + 2] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => *n as usize,
_ => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
};
query.clauses.remove(i + 2); query.clauses.remove(i + 1); let return_clause = if let Clause::Return(r) = query.clauses.remove(i) {
r
} else {
unreachable!()
};
query.clauses.insert(
i,
Clause::FusedOrderByTopK {
return_clause,
score_item_index: score_idx,
descending,
limit,
sort_expression,
},
);
i += 1;
}
}
/// Looks for a spatial `contains` test in a `WHERE` predicate.
///
/// Recognised shapes:
/// * a not-equals comparison of a qualifying `contains(...)` call against the
///   literal `false` — yields the two variables, the probe kind and no
///   remainder;
/// * an `AND` where one side is such a comparison — the other side is
///   returned as the remainder. The left side is tried first. A side that
///   would itself produce a remainder is not accepted.
///
/// Returns `(container variable, probe variable, probe kind, remainder)`.
fn extract_spatial_join_contains(
    pred: &Predicate,
) -> Option<(
    String,
    String,
    super::super::ast::SpatialProbeKind,
    Option<Predicate>,
)> {
    match pred {
        Predicate::Comparison {
            left,
            operator: ComparisonOp::NotEquals,
            right: Expression::Literal(Value::Boolean(false)),
        } => extract_contains_call_vars(left)
            .map(|(container, probe, kind)| (container, probe, kind, None)),
        Predicate::And(lhs, rhs) => {
            for (candidate, rest) in [(lhs, rhs), (rhs, lhs)] {
                if let Some((container, probe, kind, None)) =
                    extract_spatial_join_contains(candidate)
                {
                    return Some((container, probe, kind, Some((**rest).clone())));
                }
            }
            None
        }
        _ => None,
    }
}
/// Decomposes a two-argument `contains(...)` call.
///
/// The first argument must be a variable (the container). The second is
/// either a variable (probe kind `Location`) or `centroid(<variable>)`
/// (probe kind `Centroid`). Container and probe must be different variables.
///
/// Returns `(container variable, probe variable, probe kind)`.
fn extract_contains_call_vars(
    expr: &Expression,
) -> Option<(String, String, super::super::ast::SpatialProbeKind)> {
    use super::super::ast::SpatialProbeKind;

    let Expression::FunctionCall { name, args, .. } = expr else {
        return None;
    };
    if name != "contains" || args.len() != 2 {
        return None;
    }
    let Expression::Variable(container) = &args[0] else {
        return None;
    };
    let (probe, kind) = match &args[1] {
        Expression::Variable(var) => (var, SpatialProbeKind::Location),
        Expression::FunctionCall {
            name: inner_name,
            args: inner_args,
            ..
        } if inner_name == "centroid" && inner_args.len() == 1 => {
            let Expression::Variable(var) = &inner_args[0] else {
                return None;
            };
            (var, SpatialProbeKind::Centroid)
        }
        _ => return None,
    };
    if container == probe {
        return None;
    }
    Some((container.clone(), probe.clone(), kind))
}
/// Runs the spatial-join fusion over every clause position.
///
/// Does nothing when the graph reports secondary labels. At each position the
/// single-`MATCH` form is attempted first; the multi-`MATCH` form is only
/// attempted when the first did not fuse.
pub(super) fn fuse_spatial_join(query: &mut CypherQuery, graph: &DirGraph) {
    if graph.has_secondary_labels {
        return;
    }
    let mut pos = 0;
    // The length is re-read every iteration because a fusion removes clauses.
    while pos < query.clauses.len() {
        if !try_fuse_spatial_single_match(query, graph, pos) {
            let _ = try_fuse_spatial_multi_match(query, graph, pos);
        }
        pos += 1;
    }
}
/// Tries to fuse `MATCH (a:A), (b:B) WHERE contains(...)` at position `i`
/// into one `SpatialJoin` clause.
///
/// Requirements: the `MATCH` has exactly two patterns, each a single typed
/// and named node, with no path assignments, limit hint or distinct-node
/// hint; the `WHERE` yields a spatial `contains` test whose container and
/// probe variables are the two matched variables (in either order); and
/// `spatial_schema_ok` accepts the resulting types.
///
/// Returns `true` when the two clauses were replaced.
fn try_fuse_spatial_single_match(query: &mut CypherQuery, graph: &DirGraph, i: usize) -> bool {
    let (Some(Clause::Match(mc)), Some(Clause::Where(filter))) =
        (query.clauses.get(i), query.clauses.get(i + 1))
    else {
        return false;
    };

    let has_hints = mc.limit_hint.is_some() || mc.distinct_node_hint.is_some();
    if mc.patterns.len() != 2 || !mc.path_assignments.is_empty() || has_hints {
        return false;
    }
    let Some((first_var, first_type)) = extract_single_typed_node(&mc.patterns[0]) else {
        return false;
    };
    let Some((second_var, second_type)) = extract_single_typed_node(&mc.patterns[1]) else {
        return false;
    };

    let Some((container_var, probe_var, probe_kind, remainder)) =
        extract_spatial_join_contains(&filter.predicate)
    else {
        return false;
    };

    // Map container/probe onto the two patterns, whichever way round.
    let (container_type, probe_type) = if container_var == first_var && probe_var == second_var {
        (first_type, second_type)
    } else if container_var == second_var && probe_var == first_var {
        (second_type, first_type)
    } else {
        return false;
    };
    if !spatial_schema_ok(graph, &container_type, &probe_type, probe_kind) {
        return false;
    }

    query.clauses.remove(i + 1);
    query.clauses[i] = Clause::SpatialJoin {
        container_var,
        probe_var,
        container_type,
        probe_type,
        probe_kind,
        remainder,
    };
    true
}
/// Tries to fuse `MATCH (a:A) [WHERE pre] MATCH (b:B) WHERE contains(...)`
/// starting at position `i` into one `SpatialJoin` clause.
///
/// Each `MATCH` must hold exactly one pattern that is a single typed, named
/// node, without path assignments or hints, and the two variables must
/// differ. The final `WHERE` must yield a spatial `contains` test with probe
/// kind `Centroid` over those two variables, and `spatial_schema_ok` must
/// accept the types. An optional `WHERE` between the two `MATCH` clauses is
/// AND-ed in front of any remainder of the final `WHERE`.
///
/// Returns `true` when the clauses were replaced.
fn try_fuse_spatial_multi_match(query: &mut CypherQuery, graph: &DirGraph, i: usize) -> bool {
    use super::super::ast::SpatialProbeKind;

    /// `(variable, type)` of a MATCH clause that is one plain typed node.
    fn lone_typed_node(clause: &Clause) -> Option<(String, String)> {
        let Clause::Match(mc) = clause else {
            return None;
        };
        let plain = mc.patterns.len() == 1
            && mc.path_assignments.is_empty()
            && mc.limit_hint.is_none()
            && mc.distinct_node_hint.is_none();
        if !plain {
            return None;
        }
        extract_single_typed_node(&mc.patterns[0])
    }

    if i + 2 >= query.clauses.len() {
        return false;
    }

    // A WHERE right after the first MATCH is only acceptable when the second
    // MATCH follows it directly.
    let where_follows = matches!(query.clauses.get(i + 1), Some(Clause::Where(_)));
    let pre_where_idx = if where_follows {
        if !matches!(query.clauses.get(i + 2), Some(Clause::Match(_))) {
            return false;
        }
        Some(i + 1)
    } else {
        None
    };
    let second_match_idx = if pre_where_idx.is_some() { i + 2 } else { i + 1 };
    let spatial_where_idx = second_match_idx + 1;

    if !matches!(query.clauses.get(second_match_idx), Some(Clause::Match(_))) {
        return false;
    }
    let Some(Clause::Where(spatial_where)) = query.clauses.get(spatial_where_idx) else {
        return false;
    };

    let Some((first_var, first_type)) = lone_typed_node(&query.clauses[i]) else {
        return false;
    };
    let Some((second_var, second_type)) = lone_typed_node(&query.clauses[second_match_idx]) else {
        return false;
    };
    if first_var == second_var {
        return false;
    }

    let Some((container_var, probe_var, probe_kind, remainder)) =
        extract_spatial_join_contains(&spatial_where.predicate)
    else {
        return false;
    };
    if probe_kind != SpatialProbeKind::Centroid {
        return false;
    }

    let (container_type, probe_type) = if container_var == first_var && probe_var == second_var {
        (first_type, second_type)
    } else if container_var == second_var && probe_var == first_var {
        (second_type, first_type)
    } else {
        return false;
    };
    if !spatial_schema_ok(graph, &container_type, &probe_type, probe_kind) {
        return false;
    }

    let merged_remainder = match pre_where_idx {
        None => remainder,
        Some(idx) => {
            let Clause::Where(pre_where) = &query.clauses[idx] else {
                return false;
            };
            let pre = pre_where.predicate.clone();
            Some(match remainder {
                Some(rest) => Predicate::And(Box::new(pre), Box::new(rest)),
                None => pre,
            })
        }
    };

    // Remove from the back so the earlier indices stay valid.
    query.clauses.remove(spatial_where_idx);
    query.clauses.remove(second_match_idx);
    if let Some(idx) = pre_where_idx {
        query.clauses.remove(idx);
    }
    query.clauses[i] = Clause::SpatialJoin {
        container_var,
        probe_var,
        container_type,
        probe_type,
        probe_kind,
        remainder: merged_remainder,
    };
    true
}
/// Returns `(variable, node type)` when `pat` is exactly one node element
/// that has both a variable and a type and carries no inline property
/// constraints.
///
/// Patterns with inline properties are rejected because the callers turn the
/// result into a `SpatialJoin` clause, which records only variables and
/// types; accepting such a pattern would drop its property filter.
fn extract_single_typed_node(
    pat: &crate::graph::core::pattern_matching::Pattern,
) -> Option<(String, String)> {
    if pat.elements.len() != 1 {
        return None;
    }
    let PatternElement::Node(np) = &pat.elements[0] else {
        return None;
    };
    if np.properties.is_some() {
        return None;
    }
    let variable = np.variable.as_ref()?.clone();
    let node_type = np.node_type.as_ref()?.clone();
    Some((variable, node_type))
}
/// Checks that the graph's spatial configuration supports a join between the
/// two node types.
///
/// The container type must have a spatial config with a geometry. The probe
/// type must have a spatial config with a location (probe kind `Location`)
/// or with a geometry (probe kind `Centroid`).
fn spatial_schema_ok(
    graph: &DirGraph,
    container_type: &str,
    probe_type: &str,
    probe_kind: super::super::ast::SpatialProbeKind,
) -> bool {
    use super::super::ast::SpatialProbeKind;

    let has_geometry = |type_name: &str| {
        graph
            .get_spatial_config(type_name)
            .is_some_and(|cfg| cfg.geometry.is_some())
    };
    let container_ok = has_geometry(container_type);
    let probe_ok = match probe_kind {
        SpatialProbeKind::Location => graph
            .get_spatial_config(probe_type)
            .is_some_and(|cfg| cfg.location.is_some()),
        SpatialProbeKind::Centroid => has_geometry(probe_type),
    };
    container_ok && probe_ok
}
// Test module, compiled only for test builds; its source is loaded from
// `fusion_spatial_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "fusion_spatial_tests.rs"]
mod spatial_join_tests;