use std::collections::{HashMap, HashSet};
use super::ast::{Condition, GraphMatchPredicate};
use super::graph_pattern::{Direction, GraphPattern};
use super::validation_types::{ValidationError, ValidationErrorKind};
/// Validates the anchor of every graph `MATCH` predicate found in `condition`.
///
/// The aliases that occur in two or more `MATCH` predicates of the condition
/// tree are gathered first; every predicate is then checked against
/// `from_aliases` and that shared-alias set.
///
/// # Errors
///
/// Returns the first [`ValidationError`] reported while walking the tree.
pub(super) fn walk_graph_match_anchors(
    condition: &Condition,
    from_aliases: &[String],
) -> Result<(), ValidationError> {
    walk_anchors(
        condition,
        from_aliases,
        &collect_shared_pattern_aliases(condition),
    )
}
/// Returns every node alias that is used by at least two `MATCH` predicates
/// inside `condition`.
///
/// An alias repeated within a single predicate only counts once for that
/// predicate (see `accumulate_pattern_aliases`).
fn collect_shared_pattern_aliases(condition: &Condition) -> HashSet<String> {
    let mut usage: HashMap<String, usize> = HashMap::new();
    accumulate_pattern_aliases(condition, &mut usage);
    // Drop aliases seen in a single predicate; what remains is "shared".
    usage.retain(|_, predicates| *predicates >= 2);
    usage.into_keys().collect()
}
/// Recursively tallies, per alias, how many `MATCH` predicates under
/// `condition` mention it.
///
/// Each predicate contributes at most `1` to an alias, no matter how many of
/// its nodes carry that alias. Nodes without an alias are ignored. `And`,
/// `Or`, `Not` and `Group` are descended into; every other condition variant
/// is skipped.
fn accumulate_pattern_aliases(condition: &Condition, counts: &mut HashMap<String, usize>) {
    match condition {
        Condition::GraphMatch(predicate) => {
            // Tracks aliases already counted for this predicate so repeats
            // inside one pattern do not inflate the tally.
            let mut seen: HashSet<&str> = HashSet::new();
            for node in predicate.pattern.nodes.iter() {
                let Some(alias) = node.alias.as_deref() else {
                    continue;
                };
                if seen.insert(alias) {
                    *counts.entry(alias.to_owned()).or_default() += 1;
                }
            }
        }
        Condition::And(lhs, rhs) | Condition::Or(lhs, rhs) => {
            for side in [lhs, rhs] {
                accumulate_pattern_aliases(side, counts);
            }
        }
        Condition::Not(operand) | Condition::Group(operand) => {
            accumulate_pattern_aliases(operand, counts)
        }
        _ => {}
    }
}
/// Recursively applies `check_graph_match_anchor` to every `MATCH` predicate
/// under `condition`.
///
/// For `And`/`Or` the left operand is validated before the right one, and the
/// right one is not visited when the left already failed. `Not` and `Group`
/// are transparent wrappers; any other condition variant is accepted as-is.
///
/// # Errors
///
/// Returns the first error produced by a predicate check, in left-to-right
/// order.
fn walk_anchors(
    condition: &Condition,
    from_aliases: &[String],
    shared_aliases: &HashSet<String>,
) -> Result<(), ValidationError> {
    match condition {
        Condition::GraphMatch(predicate) => {
            check_graph_match_anchor(predicate, from_aliases, shared_aliases)
        }
        Condition::And(lhs, rhs) | Condition::Or(lhs, rhs) => {
            walk_anchors(lhs, from_aliases, shared_aliases)
                .and_then(|()| walk_anchors(rhs, from_aliases, shared_aliases))
        }
        Condition::Not(operand) | Condition::Group(operand) => {
            walk_anchors(operand, from_aliases, shared_aliases)
        }
        _ => Ok(()),
    }
}
/// Validates the anchor (alias of the first node) of a single `MATCH`
/// predicate.
///
/// * No first node, or a first node without an alias, is rejected.
/// * When `from_aliases` is empty, or the anchor is one of `from_aliases`, the
///   predicate is accepted.
/// * Otherwise the anchor is treated as implicit and handed to
///   `check_implicit_anchor`.
///
/// # Errors
///
/// Returns a `GraphMatchAnchorMismatch` error (built by `v011`) for a missing
/// anchor alias, or whatever `check_implicit_anchor` reports.
fn check_graph_match_anchor(
    predicate: &GraphMatchPredicate,
    from_aliases: &[String],
    shared_aliases: &HashSet<String>,
) -> Result<(), ValidationError> {
    let first_alias = predicate
        .pattern
        .nodes
        .first()
        .and_then(|node| node.alias.as_deref());

    match first_alias {
        None => Err(v011(
            "MATCH (...)",
            "MATCH in SELECT WHERE requires an alias on the first node, \
             e.g. MATCH (d:Doc)-[:REL]->(x)",
        )),
        // Nothing declared to compare with, or the anchor is a declared alias.
        Some(anchor)
            if from_aliases.is_empty()
                || from_aliases.iter().any(|declared| declared == anchor) =>
        {
            Ok(())
        }
        Some(anchor) => check_implicit_anchor(predicate, anchor, from_aliases, shared_aliases),
    }
}
/// Validates a `MATCH` anchor that is *not* one of the declared
/// `from_aliases` (an "implicit" anchor).
///
/// Checks run in this order and the first failure wins:
///
/// 1. `check_g1_inverted_anchor` — a later node of the pattern carries a
///    declared alias while the anchor does not.
/// 2. The first node has a `collection` set — reported with a rewritten
///    pattern anchored on the first entry of `from_aliases`.
/// 3. The anchor is in `shared_aliases`, i.e. it is also used by another
///    `MATCH` predicate of the same condition tree.
///
/// When `from_aliases` is empty the function returns `Ok(())` without running
/// any check. This is the same rule `check_graph_match_anchor` applies before
/// delegating here, and it replaces an unchecked `from_aliases[0]` that would
/// panic if this helper were ever reached with an empty slice.
///
/// # Errors
///
/// Returns a `GraphMatchAnchorMismatch` error (built by `v011`) describing the
/// first rule that is violated.
fn check_implicit_anchor(
    predicate: &GraphMatchPredicate,
    anchor: &str,
    from_aliases: &[String],
    shared_aliases: &HashSet<String>,
) -> Result<(), ValidationError> {
    // No declared alias means there is nothing the anchor could mismatch.
    let Some(primary_alias) = from_aliases.first() else {
        return Ok(());
    };

    check_g1_inverted_anchor(predicate, anchor, from_aliases)?;

    let nodes = &predicate.pattern.nodes;
    if nodes.first().is_some_and(|node| node.collection.is_some()) {
        return Err(v011(
            format!("MATCH ({anchor}@...)"),
            format!(
                "MATCH anchor alias '{anchor}' carries a @collection override and \
                 cannot bind implicitly to the FROM rows. Anchor the pattern on a \
                 declared alias, e.g. {rewritten}",
                rewritten = render_pattern_with_anchor(&predicate.pattern, primary_alias, anchor)
            ),
        ));
    }

    if shared_aliases.contains(anchor) {
        return Err(v011(
            format!("MATCH ({anchor})"),
            format!(
                "implicit MATCH anchor '{anchor}' also appears in another MATCH \
                 predicate of this WHERE clause; chain into a single pattern \
                 instead, e.g. MATCH (m)-[:R]->(f)-[:S]->(g)"
            ),
        ));
    }

    Ok(())
}
/// Rejects a pattern whose anchor is undeclared while one of its later nodes
/// uses a declared alias.
///
/// Nodes after the first are scanned in order; the first one whose alias is in
/// `from_aliases` triggers the error, which suggests the pattern re-rendered
/// with that alias as anchor (see `render_pattern_with_anchor`). If no such
/// node exists the pattern passes.
///
/// # Errors
///
/// Returns a `GraphMatchAnchorMismatch` error (built by `v011`) naming both
/// `anchor` and the declared alias that was found.
fn check_g1_inverted_anchor(
    predicate: &GraphMatchPredicate,
    anchor: &str,
    from_aliases: &[String],
) -> Result<(), ValidationError> {
    for node in predicate.pattern.nodes.iter().skip(1) {
        let Some(declared) = node.alias.as_deref() else {
            continue;
        };
        if !from_aliases.iter().any(|known| known == declared) {
            continue;
        }

        let rewritten = render_pattern_with_anchor(&predicate.pattern, declared, anchor);
        return Err(v011(
            format!("MATCH ({anchor})"),
            format!(
                "MATCH anchor alias '{anchor}' is not declared in FROM/JOIN while \
                 '{declared}' is. Anchor the pattern on '{declared}', \
                 e.g. {rewritten}"
            ),
        ));
    }
    Ok(())
}
/// Builds a [`ValidationError`] of kind
/// [`ValidationErrorKind::GraphMatchAnchorMismatch`] from an offending
/// `fragment` and a human-readable `suggestion`.
///
/// The second constructor argument is always `None`.
// NOTE(review): presumably that slot is a source position/span — confirm
// against `ValidationError::new`.
fn v011(fragment: impl Into<String>, suggestion: impl Into<String>) -> ValidationError {
    let kind = ValidationErrorKind::GraphMatchAnchorMismatch;
    ValidationError::new(kind, None, fragment, suggestion)
}
/// Renders `pattern` as `MATCH (...)` text with `anchor` substituted for the
/// first node.
///
/// The first node is always printed as `(anchor)`. Each following
/// relationship/node pair is appended as an arrow plus `(alias)`:
///
/// * relationship types are joined with `|` after a leading `:`; an empty type
///   list prints nothing,
/// * a range prints as `*lo..hi`,
/// * direction selects `-[..]->`, `<-[..]-` or `-[..]-`,
/// * a later node whose alias equals `anchor` is printed as `displaced`, a
///   node without alias prints as `()`.
///
/// Only aliases are rendered for nodes; no other node attribute is read.
/// Pairs are formed by zipping relationships with the nodes after the first,
/// so rendering stops at the shorter of the two sequences.
fn render_pattern_with_anchor(pattern: &GraphPattern, anchor: &str, displaced: &str) -> String {
    use std::fmt::Write;

    let mut rendered = String::from("MATCH (");
    rendered.push_str(anchor);
    rendered.push(')');

    let hops = pattern
        .relationships
        .iter()
        .zip(pattern.nodes.iter().skip(1));

    for (rel, node) in hops {
        // Opening and closing halves of the arrow around the bracket body.
        let (opening, closing) = match rel.direction {
            Direction::Outgoing => ("-[", "]->"),
            Direction::Incoming => ("<-[", "]-"),
            Direction::Both => ("-[", "]-"),
        };

        rendered.push_str(opening);
        if !rel.types.is_empty() {
            rendered.push(':');
            rendered.push_str(&rel.types.join("|"));
        }
        if let Some((lo, hi)) = rel.range {
            // Writing into a `String` cannot fail; the result is ignored.
            let _ = write!(rendered, "*{lo}..{hi}");
        }
        rendered.push_str(closing);

        let shown = match node.alias.as_deref() {
            Some(alias) if alias == anchor => displaced,
            Some(alias) => alias,
            None => "",
        };
        rendered.push('(');
        rendered.push_str(shown);
        rendered.push(')');
    }

    rendered
}