use crate::{
ast::{
call::{InlineProcedureCall, ProcedureCall},
ddl::{DdlStatement, TypePropertyConstraint},
expr::{ExistsBody, IsCheckKind, ValueExpr},
mutation::{MutationPipeline, MutationStatement, MutationTerminator, SetItem},
pattern::{GraphPattern, MatchClause, PatternElement},
statement::{PipelineStatement, QueryPipeline, ReturnClause, Statement, WithClause},
},
error::ParserError,
};
/// Deepest value-expression nesting accepted by `reject_excessive_expr_depth`.
///
/// A clause's top-level expression counts as depth 1 and every nested
/// sub-expression adds one; an expression whose depth exceeds this value is
/// rejected with `ParserError::ComplexityLimitExceeded`.
pub(crate) const MAX_EXPR_DEPTH: u32 = 256;
/// Depth assigned to the top-level value expressions of a statement's clauses.
const ROOT_DEPTH: u32 = 1;

/// A pending unit of work for the iterative traversal in
/// `reject_excessive_expr_depth`.
///
/// The AST is walked with an explicit stack, so the traversal's own call-stack
/// usage does not grow with the nesting depth it is measuring.
///
/// Every variant that can be reached from inside a value expression (through
/// `EXISTS { ... }` or `VALUE { ... }` subqueries) carries the depth at which
/// the value expressions it directly owns start. Top-level clauses start at
/// `ROOT_DEPTH`. A subquery appearing in an expression at depth `d` starts at
/// `d + 1`. As a result, nesting that alternates between expressions and
/// subqueries is counted cumulatively instead of restarting at 1 behind every
/// subquery boundary.
enum Node<'a> {
    Statement(&'a Statement),
    Pipeline(&'a QueryPipeline, u32),
    Mutation(&'a MutationPipeline),
    MatchClause(&'a MatchClause, u32),
    Pattern(&'a GraphPattern, u32),
    Return(&'a ReturnClause, u32),
    With(&'a WithClause, u32),
    Call(&'a ProcedureCall, u32),
    Ddl(&'a DdlStatement),
    Expr(&'a ValueExpr, u32),
}

/// Rejects `statement` if any value expression in it is nested more than
/// `MAX_EXPR_DEPTH` levels deep.
///
/// Depth starts at 1 for each top-level clause expression and grows by one per
/// nested sub-expression. Expressions inside an `EXISTS` or `VALUE` subquery
/// continue counting from the depth of the subquery expression itself.
///
/// # Errors
///
/// Returns `ParserError::ComplexityLimitExceeded` with the limit and the span
/// of the over-deep expression the traversal reaches first. The traversal is
/// LIFO, so this is not necessarily the first one in source order.
pub(super) fn reject_excessive_expr_depth(statement: &Statement) -> Result<(), ParserError> {
    let mut work = vec![Node::Statement(statement)];
    while let Some(node) = work.pop() {
        match node {
            Node::Statement(statement) => push_statement(statement, &mut work),
            Node::Pipeline(pipeline, depth) => push_pipeline(pipeline, depth, &mut work),
            Node::Mutation(pipeline) => push_mutation(pipeline, &mut work),
            Node::MatchClause(clause, depth) => push_match_clause(clause, depth, &mut work),
            Node::Pattern(pattern, depth) => push_pattern(pattern, depth, &mut work),
            Node::Return(clause, depth) => push_return(clause, depth, &mut work),
            Node::With(clause, depth) => push_with(clause, depth, &mut work),
            Node::Call(call, depth) => push_call(call, depth, &mut work),
            Node::Ddl(statement) => push_ddl(statement, &mut work),
            Node::Expr(expr, depth) => push_expr(expr, depth, &mut work)?,
        }
    }
    Ok(())
}

/// Checks `expr` (sitting at `depth`) against the limit and queues its
/// children at `depth + 1`.
///
/// Subquery bodies (`EXISTS` / `VALUE`) are queued with `depth + 1` as their
/// starting depth, so their expressions keep accumulating depth.
fn push_expr<'a>(
    expr: &'a ValueExpr,
    depth: u32,
    work: &mut Vec<Node<'a>>,
) -> Result<(), ParserError> {
    if depth > MAX_EXPR_DEPTH {
        return Err(ParserError::ComplexityLimitExceeded {
            limit: MAX_EXPR_DEPTH,
            span: expr.span(),
        });
    }
    let next = depth.saturating_add(1);
    match expr {
        ValueExpr::Literal(_) | ValueExpr::Variable { .. } | ValueExpr::Parameter { .. } => {}
        ValueExpr::PropertyAccess { target, .. } | ValueExpr::PropertyExists { target, .. } => {
            work.push(Node::Expr(target, next));
        }
        ValueExpr::ListLiteral { items, .. }
        | ValueExpr::PathConstructor {
            elements: items, ..
        }
        | ValueExpr::AllDifferent { items, .. }
        | ValueExpr::Same { items, .. } => {
            for item in items {
                work.push(Node::Expr(item, next));
            }
        }
        ValueExpr::RecordLiteral { fields, .. } => {
            for (_, value) in fields {
                work.push(Node::Expr(value, next));
            }
        }
        ValueExpr::BinaryOp { lhs, rhs, .. } => {
            work.push(Node::Expr(lhs, next));
            work.push(Node::Expr(rhs, next));
        }
        ValueExpr::UnaryOp { operand, .. } => work.push(Node::Expr(operand, next)),
        ValueExpr::FunctionCall { args, .. } => {
            for arg in args {
                work.push(Node::Expr(arg, next));
            }
        }
        ValueExpr::DurationBetween { start, end, .. } => {
            work.push(Node::Expr(start, next));
            work.push(Node::Expr(end, next));
        }
        ValueExpr::IsCheck { operand, kind, .. } => {
            work.push(Node::Expr(operand, next));
            match kind {
                IsCheckKind::SourceOf(value) | IsCheckKind::DestinationOf(value) => {
                    work.push(Node::Expr(value, next));
                }
                IsCheckKind::Null
                | IsCheckKind::Directed
                | IsCheckKind::Labeled(_)
                | IsCheckKind::TruthValue(_)
                | IsCheckKind::Typed(_)
                | IsCheckKind::Normalized(_) => {}
            }
        }
        ValueExpr::InList { operand, list, .. } => {
            work.push(Node::Expr(operand, next));
            for item in list {
                work.push(Node::Expr(item, next));
            }
        }
        ValueExpr::InListExpression { operand, list, .. } => {
            work.push(Node::Expr(operand, next));
            work.push(Node::Expr(list, next));
        }
        ValueExpr::Case {
            branches,
            else_branch,
            ..
        } => {
            for (condition, value) in branches {
                work.push(Node::Expr(condition, next));
                work.push(Node::Expr(value, next));
            }
            if let Some(value) = else_branch {
                work.push(Node::Expr(value, next));
            }
        }
        ValueExpr::Cast { value, .. } | ValueExpr::Normalize { source: value, .. } => {
            work.push(Node::Expr(value, next));
        }
        ValueExpr::Trim {
            character, source, ..
        } => {
            if let Some(character) = character {
                work.push(Node::Expr(character, next));
            }
            work.push(Node::Expr(source, next));
        }
        // Subquery bodies inherit the current nesting depth instead of
        // restarting at ROOT_DEPTH; otherwise alternating expression/subquery
        // nesting would never reach the limit.
        ValueExpr::Exists { body, .. } => match body {
            ExistsBody::Match(clause) => work.push(Node::MatchClause(clause, next)),
            ExistsBody::Query(pipeline) => work.push(Node::Pipeline(pipeline, next)),
        },
        ValueExpr::ValueSubquery { body, .. } => work.push(Node::Pipeline(body, next)),
    }
    Ok(())
}

/// Queues the pipelines, mutation, DDL, call, or session value held by a
/// top-level statement. Every expression reached from here starts at
/// `ROOT_DEPTH`.
///
/// Transaction-control statements and the remaining session statements are not
/// inspected.
fn push_statement<'a>(statement: &'a Statement, work: &mut Vec<Node<'a>>) {
    match statement {
        Statement::Query(pipeline) => work.push(Node::Pipeline(pipeline, ROOT_DEPTH)),
        Statement::Composite { first, rest, .. } => {
            work.push(Node::Pipeline(first, ROOT_DEPTH));
            for (_, pipeline) in rest.iter() {
                work.push(Node::Pipeline(pipeline, ROOT_DEPTH));
            }
        }
        Statement::Chained { blocks, .. } => {
            for pipeline in blocks {
                work.push(Node::Pipeline(pipeline, ROOT_DEPTH));
            }
        }
        Statement::Mutate(pipeline) => work.push(Node::Mutation(pipeline)),
        Statement::Ddl(statement) => work.push(Node::Ddl(statement)),
        Statement::Call(call) => work.push(Node::Call(call, ROOT_DEPTH)),
        Statement::Explain { inner, .. } => work.push(Node::Statement(inner)),
        Statement::SessionSetValue { value, .. } => work.push(Node::Expr(value, ROOT_DEPTH)),
        Statement::StartTransaction { .. }
        | Statement::Commit { .. }
        | Statement::Rollback { .. }
        | Statement::SessionSetTimeZone { .. }
        | Statement::SessionSetGraph { .. }
        | Statement::SessionReset { .. }
        | Statement::SessionClose { .. } => {}
    }
}

/// Queues every clause of a query pipeline. The pipeline's own value
/// expressions start at `depth`.
///
/// `LIMIT` / `OFFSET` statements are not inspected.
fn push_pipeline<'a>(pipeline: &'a QueryPipeline, depth: u32, work: &mut Vec<Node<'a>>) {
    for statement in &pipeline.statements {
        match statement {
            PipelineStatement::Match(clause) => work.push(Node::MatchClause(clause, depth)),
            PipelineStatement::Filter(expr) => work.push(Node::Expr(expr, depth)),
            PipelineStatement::Let(bindings) => {
                for binding in bindings {
                    work.push(Node::Expr(&binding.value, depth));
                }
            }
            PipelineStatement::For(statement) => {
                work.push(Node::Expr(&statement.source, depth));
            }
            PipelineStatement::Sorting(terms) => {
                for term in terms {
                    work.push(Node::Expr(&term.expr, depth));
                }
            }
            PipelineStatement::Limit(_) | PipelineStatement::Offset(_) => {}
            PipelineStatement::Return(clause) => work.push(Node::Return(clause, depth)),
            PipelineStatement::With(clause) => work.push(Node::With(clause, depth)),
            PipelineStatement::Call(call) => work.push(Node::Call(call, depth)),
            PipelineStatement::CallSubquery(call) => push_inline_call(call, depth, work),
        }
    }
}

/// Queues the body pipeline of an inline `CALL { ... }` subquery, starting at
/// `depth`.
fn push_inline_call<'a>(call: &'a InlineProcedureCall, depth: u32, work: &mut Vec<Node<'a>>) {
    work.push(Node::Pipeline(&call.body, depth));
}

/// Queues the match clauses, filters, insert patterns, `SET` values, and
/// `RETURN` terminator of a top-level mutation pipeline. All of these start at
/// `ROOT_DEPTH`.
///
/// `REMOVE` / `DELETE` statements and non-`RETURN` terminators are not
/// inspected.
fn push_mutation<'a>(pipeline: &'a MutationPipeline, work: &mut Vec<Node<'a>>) {
    for statement in pipeline.statements.iter() {
        match statement {
            MutationStatement::Match(clause) => {
                work.push(Node::MatchClause(clause, ROOT_DEPTH));
            }
            MutationStatement::Filter(expr) => work.push(Node::Expr(expr, ROOT_DEPTH)),
            MutationStatement::Insert(insert) => {
                for pattern in &insert.patterns {
                    work.push(Node::Pattern(pattern, ROOT_DEPTH));
                }
            }
            MutationStatement::Set(items) => {
                for item in items {
                    push_set_item(item, work);
                }
            }
            MutationStatement::Remove(_) | MutationStatement::Delete(_) => {}
        }
    }
    if let Some(MutationTerminator::Return(clause)) = &pipeline.terminator {
        work.push(Node::Return(clause, ROOT_DEPTH));
    }
}

/// Queues the value expressions assigned by a single `SET` item at
/// `ROOT_DEPTH`. Label items carry no values.
fn push_set_item<'a>(item: &'a SetItem, work: &mut Vec<Node<'a>>) {
    match item {
        SetItem::Property { value, .. } => work.push(Node::Expr(value, ROOT_DEPTH)),
        SetItem::PropertyMerge { properties, .. } => {
            for (_, value) in properties {
                work.push(Node::Expr(value, ROOT_DEPTH));
            }
        }
        SetItem::Label { .. } => {}
    }
}

/// Queues a match clause's graph patterns and optional `WHERE` predicate.
/// Both start at `depth`.
fn push_match_clause<'a>(clause: &'a MatchClause, depth: u32, work: &mut Vec<Node<'a>>) {
    for pattern in &clause.patterns {
        work.push(Node::Pattern(pattern, depth));
    }
    if let Some(where_clause) = &clause.where_clause {
        work.push(Node::Expr(where_clause, depth));
    }
}

/// Queues the property values and inline `WHERE` predicates of every node and
/// edge element in `pattern`, at `depth`.
fn push_pattern<'a>(pattern: &'a GraphPattern, depth: u32, work: &mut Vec<Node<'a>>) {
    for element in &pattern.elements {
        let (properties, inline_where) = match element {
            PatternElement::Node(node) => (&node.properties, &node.inline_where),
            PatternElement::Edge(edge) => (&edge.properties, &edge.inline_where),
        };
        for (_, value) in properties {
            work.push(Node::Expr(value, depth));
        }
        if let Some(inline_where) = inline_where {
            work.push(Node::Expr(inline_where, depth));
        }
    }
}

/// Queues a `RETURN` clause's item, `GROUP BY`, and `HAVING` expressions at
/// `depth`.
fn push_return<'a>(clause: &'a ReturnClause, depth: u32, work: &mut Vec<Node<'a>>) {
    for item in &clause.items {
        work.push(Node::Expr(&item.expr, depth));
    }
    if let Some(group_by) = &clause.group_by {
        for expr in group_by {
            work.push(Node::Expr(expr, depth));
        }
    }
    if let Some(having) = &clause.having {
        work.push(Node::Expr(having, depth));
    }
}

/// Queues a `WITH` clause's item, `GROUP BY`, `HAVING`, and `WHERE`
/// expressions at `depth`.
fn push_with<'a>(clause: &'a WithClause, depth: u32, work: &mut Vec<Node<'a>>) {
    for item in &clause.items {
        work.push(Node::Expr(&item.expr, depth));
    }
    if let Some(group_by) = &clause.group_by {
        for expr in group_by {
            work.push(Node::Expr(expr, depth));
        }
    }
    if let Some(having) = &clause.having {
        work.push(Node::Expr(having, depth));
    }
    if let Some(where_clause) = &clause.where_clause {
        work.push(Node::Expr(where_clause, depth));
    }
}

/// Queues the arguments of a procedure call at `depth`.
fn push_call<'a>(call: &'a ProcedureCall, depth: u32, work: &mut Vec<Node<'a>>) {
    for arg in &call.args {
        work.push(Node::Expr(arg, depth));
    }
}

/// Queues the `TypePropertyConstraint::Default` value expressions declared on
/// the properties of node/edge type creation statements, at `ROOT_DEPTH`.
///
/// All other DDL statements are not inspected.
fn push_ddl<'a>(statement: &'a DdlStatement, work: &mut Vec<Node<'a>>) {
    let properties = match statement {
        DdlStatement::CreateNodeType { properties, .. }
        | DdlStatement::CreateEdgeType { properties, .. } => properties,
        DdlStatement::CreateGraph { .. }
        | DdlStatement::DropGraph { .. }
        | DdlStatement::DropNodeType { .. }
        | DdlStatement::DropEdgeType { .. }
        | DdlStatement::TruncateNodeType { .. }
        | DdlStatement::TruncateEdgeType { .. }
        | DdlStatement::CreateIndex { .. }
        | DdlStatement::DropIndex { .. }
        | DdlStatement::ShowNodeTypes(_)
        | DdlStatement::ShowEdgeTypes(_)
        | DdlStatement::ShowIndexes(_)
        | DdlStatement::ShowProcedures(_) => return,
    };
    for property in properties {
        for constraint in &property.constraints {
            if let TypePropertyConstraint::Default(value, _) = constraint {
                work.push(Node::Expr(value, ROOT_DEPTH));
            }
        }
    }
}