use indexmap::IndexMap;
use smol_str::SmolStr;
use cyrs_hir::{
Clause, Direction as HirDir, Expr as HirExpr, HirSpan, ListPredKind as HirListPredKind,
Pattern, PatternElement, PatternPart, Projection, RelLength as HirRelLen, RemoveItem, SetItem,
Statement, VarId as HirVarId,
};
use crate::{
AggExpr, BinOp, Direction, Expr, LabelSet, ListPredKind, NodeSpec, OpId, OrderKey,
PlanLowerError, Projection as PlanProj, ReadOp, RelLength, RelSpec, UnaryOp, UnionKind, VarId,
WriteOp,
};
/// A lowered statement: a graph of read operators plus the write operations
/// to perform.
///
/// Read operators live in `ops` and refer to one another by [`OpId`], which
/// is simply the operator's index in that vector.
#[derive(Debug, Clone)]
pub struct PlanStatement {
    /// Read operators in the order they were pushed, addressed by [`OpId`].
    pub ops: Vec<ReadOp>,
    /// Write operations in the order they were lowered.
    pub write_ops: Vec<WriteOp>,
    /// Plan variable → the HIR variable it was allocated for.
    pub var_map: IndexMap<VarId, HirVarId>,
}

impl PlanStatement {
    /// Crate-internal constructor; identical to [`PlanStatement::empty`].
    fn new() -> Self {
        Self::empty()
    }

    /// Returns a plan with no operators and no variable mappings.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            ops: vec![],
            write_ops: vec![],
            var_map: IndexMap::default(),
        }
    }

    /// Appends `op` and returns the [`OpId`] that now addresses it.
    fn push(&mut self, op: ReadOp) -> OpId {
        // Ids are dense indices: the new op's id is the length before insertion.
        #[allow(clippy::cast_possible_truncation)] // plans never approach u32::MAX ops
        let next = self.ops.len() as u32;
        self.ops.push(op);
        OpId(next)
    }
}
/// Lowers a resolved, desugared HIR statement into a [`PlanStatement`].
///
/// # Errors
///
/// Returns the first [`PlanLowerError`] found by the pre-lowering check:
/// an unresolved name, an expression that should have been desugared, or a
/// pattern part that does not start with a node.
pub fn lower_statement(stmt: &Statement) -> Result<PlanStatement, PlanLowerError> {
    precheck_statement(stmt).map(|()| {
        let mut ctx = LowerCtx::new(stmt);
        ctx.lower(stmt);
        ctx.into_plan()
    })
}
/// Verifies that every expression and pattern in `stmt` can be lowered.
///
/// Clauses are visited in order and the first error wins. Every error is
/// reported against the span of the clause that contains the offending node.
fn precheck_statement(stmt: &Statement) -> Result<(), PlanLowerError> {
    stmt.clauses.iter().try_for_each(|clause| {
        let span = clause.span();
        match clause {
            Clause::Match { pattern, .. } | Clause::Create { pattern, .. } => {
                check_pattern(pattern, span)
            }
            Clause::Where { predicate, .. } => check_expr(predicate, span),
            Clause::With {
                projections,
                filter,
                ..
            } => {
                projections
                    .iter()
                    .try_for_each(|p| check_expr(&p.expr, span))?;
                filter.iter().try_for_each(|f| check_expr(f, span))
            }
            Clause::Return { projections, .. } => projections
                .iter()
                .try_for_each(|p| check_expr(&p.expr, span)),
            Clause::Unwind { list, .. } => check_expr(list, span),
            Clause::Merge {
                pattern,
                on_create,
                on_match,
                ..
            } => {
                check_pattern(pattern, span)?;
                on_create
                    .iter()
                    .chain(on_match)
                    .try_for_each(|item| check_set_item(item, span))
            }
            Clause::Set { items, .. } => items
                .iter()
                .try_for_each(|item| check_set_item(item, span)),
            Clause::Remove { items, .. } => items
                .iter()
                .try_for_each(|item| check_remove_item(item, span)),
            Clause::Delete { targets, .. } => {
                targets.iter().try_for_each(|t| check_expr(t, span))
            }
            Clause::Call { args, .. } => args.iter().try_for_each(|a| check_expr(a, span)),
        }
    })
}
/// Checks the structure and inline property maps of every part of `pattern`.
///
/// A part that is empty or whose first element is a relationship is rejected
/// with `EmptyPatternPart` (reported at `clause_span`). Inline property
/// expressions are then checked with `check_expr`, using the element's own span.
fn check_pattern(pattern: &Pattern, clause_span: HirSpan) -> Result<(), PlanLowerError> {
    for part in &pattern.parts {
        let starts_with_node = matches!(part.elements.first(), Some(PatternElement::Node { .. }));
        if !starts_with_node {
            return Err(PlanLowerError::EmptyPatternPart { span: clause_span });
        }
        part.elements.iter().try_for_each(|elem| {
            let (PatternElement::Node { props, .. } | PatternElement::Rel { props, .. }) = elem;
            props
                .as_ref()
                .map_or(Ok(()), |p| check_expr(p, elem.span()))
        })?;
    }
    Ok(())
}
/// Checks the expressions carried by a `SET` item (label-only items have none).
fn check_set_item(item: &SetItem, span: HirSpan) -> Result<(), PlanLowerError> {
    match item {
        SetItem::Property { target, value, .. } => {
            check_expr(target, span).and_then(|()| check_expr(value, span))
        }
        SetItem::AssignMap { map, .. } => check_expr(map, span),
        SetItem::Labels { .. } => Ok(()),
    }
}
/// Checks the target expression of a `REMOVE` property item; label removals
/// carry no expression and always pass.
fn check_remove_item(item: &RemoveItem, span: HirSpan) -> Result<(), PlanLowerError> {
    match item {
        RemoveItem::Labels { .. } => Ok(()),
        RemoveItem::Property { target, .. } => check_expr(target, span),
    }
}
/// Recursively verifies that `expr` can be lowered to a plan expression.
///
/// Rejects `Unresolved` names and the constructs desugaring is expected to
/// remove (`ListComprehension`, `MapProjection`). Sub-expressions are visited
/// left to right, so the first offending node decides the error. `span` (the
/// enclosing clause's span) is attached to every error.
fn check_expr(expr: &HirExpr, span: HirSpan) -> Result<(), PlanLowerError> {
    match expr {
        // Leaves: nothing to check.
        HirExpr::Null
        | HirExpr::Bool(_)
        | HirExpr::Int(_)
        | HirExpr::Float(_)
        | HirExpr::String(_)
        | HirExpr::Var(_)
        | HirExpr::Param(_) => Ok(()),

        // Hard failures.
        HirExpr::Unresolved(name) => Err(PlanLowerError::UnresolvedName {
            name: name.clone(),
            span,
        }),
        HirExpr::ListComprehension { .. } => Err(PlanLowerError::UndesugaredExpr {
            kind: "ListComprehension",
            span,
        }),
        HirExpr::MapProjection { .. } => Err(PlanLowerError::UndesugaredExpr {
            kind: "MapProjection",
            span,
        }),

        // Single child.
        HirExpr::PatternPredicate(pattern) => check_pattern(pattern, span),
        HirExpr::Prop { target, .. } => check_expr(target, span),
        HirExpr::UnaryOp { operand, .. } | HirExpr::IsNull { operand, .. } => {
            check_expr(operand, span)
        }

        // Two children, left then right.
        HirExpr::Index { target, index } => {
            check_expr(target, span).and_then(|()| check_expr(index, span))
        }
        HirExpr::BinOp { lhs, rhs, .. } => {
            check_expr(lhs, span).and_then(|()| check_expr(rhs, span))
        }
        HirExpr::InList { operand, list } => {
            check_expr(operand, span).and_then(|()| check_expr(list, span))
        }

        // Optional / repeated children.
        HirExpr::Slice { target, start, end } => {
            check_expr(target, span)?;
            if let Some(lower) = start {
                check_expr(lower, span)?;
            }
            end.as_ref().map_or(Ok(()), |upper| check_expr(upper, span))
        }
        HirExpr::List(items) => items.iter().try_for_each(|item| check_expr(item, span)),
        HirExpr::Map(pairs) => pairs.iter().try_for_each(|(_, v)| check_expr(v, span)),
        HirExpr::Call { args, .. } => args.iter().try_for_each(|a| check_expr(a, span)),
        HirExpr::Case {
            scrutinee,
            arms,
            otherwise,
        } => {
            if let Some(s) = scrutinee {
                check_expr(s, span)?;
            }
            arms.iter()
                .try_for_each(|(w, t)| check_expr(w, span).and_then(|()| check_expr(t, span)))?;
            otherwise.as_ref().map_or(Ok(()), |o| check_expr(o, span))
        }
        HirExpr::ListPredicate {
            iterable,
            predicate,
            ..
        } => {
            check_expr(iterable, span)?;
            predicate.as_ref().map_or(Ok(()), |p| check_expr(p, span))
        }
    }
}
/// Mutable state threaded through the lowering of a single statement.
struct LowerCtx<'s> {
    /// The plan being built.
    plan: PlanStatement,
    /// HIR variable → plan variable, populated lazily by `map_var`.
    hir_to_plan: IndexMap<HirVarId, VarId>,
    /// Counter used to allocate plan variable ids.
    next_var: u32,
    /// The statement being lowered; not read by any `LowerCtx` method.
    _stmt: &'s Statement,
}
/// HIR→plan lowering.
///
/// Read clauses form a chain through a "current operator" id: each new read
/// operator takes the previous one as its `input`. Write clauses append to
/// `plan.write_ops` and leave that chain untouched.
impl<'s> LowerCtx<'s> {
    /// Creates a context for lowering `stmt`, starting from an empty plan.
    fn new(stmt: &'s Statement) -> Self {
        Self {
            plan: PlanStatement::new(),
            hir_to_plan: IndexMap::new(),
            next_var: 0,
            _stmt: stmt,
        }
    }

    /// Consumes the context and returns the finished plan.
    fn into_plan(self) -> PlanStatement {
        self.plan
    }

    /// Allocates a plan variable with no HIR counterpart (anonymous pattern
    /// elements, synthetic scan bindings, unsupported write targets).
    ///
    /// Every allocation goes through here so `next_var` always stays one past
    /// the highest id handed out, and anonymous variables can never alias
    /// variables allocated later by [`Self::map_var`].
    fn fresh_var(&mut self) -> VarId {
        let var = VarId(self.next_var);
        self.next_var += 1;
        var
    }

    /// Returns the plan variable for `hir_var`, allocating it on first use
    /// and recording the reverse mapping in `plan.var_map`.
    fn map_var(&mut self, hir_var: HirVarId) -> VarId {
        if let Some(&plan_var) = self.hir_to_plan.get(&hir_var) {
            return plan_var;
        }
        let plan_var = self.fresh_var();
        self.hir_to_plan.insert(hir_var, plan_var);
        self.plan.var_map.insert(plan_var, hir_var);
        plan_var
    }

    /// Lowers every clause of `stmt`, in order.
    fn lower(&mut self, stmt: &Statement) {
        let mut current_op: Option<OpId> = None;
        for clause in &stmt.clauses {
            match clause {
                // NOTE(review): a non-optional MATCH does not consume
                // `current_op`; its pattern starts from a fresh `Source`, so
                // earlier read clauses are not inputs to it. Confirm intended.
                Clause::Match {
                    pattern, optional, ..
                } => {
                    let (new_op, _) = self.lower_match_pattern(pattern, current_op, *optional);
                    current_op = Some(new_op);
                }
                Clause::Where { predicate, .. } => {
                    let predicate = self.lower_expr(predicate);
                    let input = current_op.unwrap_or_else(|| self.push_source_all());
                    current_op = Some(self.plan.push(ReadOp::Filter { input, predicate }));
                }
                Clause::With {
                    projections,
                    filter,
                    ..
                } => {
                    let input = current_op.unwrap_or_else(|| self.push_source_all());
                    let items = self.lower_projections(projections);
                    let filter = filter.as_ref().map(|f| self.lower_expr(f));
                    current_op = Some(self.plan.push(ReadOp::With {
                        input,
                        items,
                        filter,
                    }));
                }
                Clause::Return {
                    projections,
                    distinct,
                    ..
                } => {
                    let input = current_op.unwrap_or_else(|| self.push_source_all());
                    let projected = self.lower_return(input, projections);
                    let op = if *distinct {
                        self.plan.push(ReadOp::Distinct { input: projected })
                    } else {
                        projected
                    };
                    current_op = Some(op);
                }
                Clause::Unwind { list, bind, .. } => {
                    let input = current_op.unwrap_or_else(|| self.push_source_all());
                    let list = self.lower_expr(list);
                    let bind = self.map_var(*bind);
                    current_op = Some(self.plan.push(ReadOp::Unwind { input, list, bind }));
                }
                Clause::Create { pattern, .. } => {
                    let ops = self.lower_create_pattern(pattern);
                    self.plan.write_ops.extend(ops);
                }
                Clause::Merge {
                    pattern,
                    on_create,
                    on_match,
                    ..
                } => {
                    let ops = self.lower_merge_pattern(pattern, on_create, on_match);
                    self.plan.write_ops.extend(ops);
                }
                Clause::Set { items, .. } => {
                    let ops = self.lower_set_items(items);
                    self.plan.write_ops.extend(ops);
                }
                Clause::Remove { items, .. } => {
                    let ops = self.lower_remove_items(items);
                    self.plan.write_ops.extend(ops);
                }
                Clause::Delete {
                    targets, detach, ..
                } => {
                    let targets: Vec<Expr> = targets.iter().map(|e| self.lower_expr(e)).collect();
                    self.plan.write_ops.push(WriteOp::Delete {
                        targets,
                        detach: *detach,
                    });
                }
                // NOTE(review): procedure calls are not lowered; the clause is
                // dropped from the plan (its arguments are still validated by
                // `precheck_statement`).
                Clause::Call { .. } => {}
            }
        }
    }

    /// Lowers a `RETURN` projection list on top of `input`, returning the
    /// projecting operator.
    ///
    /// Without aggregates this is a single `Project`. With aggregates, the
    /// non-aggregate projections become the grouping keys of an `Aggregate`,
    /// and a `Project` of all projections is stacked on top.
    fn lower_return(&mut self, input: OpId, projections: &[Projection]) -> OpId {
        let (items, aggs) = self.split_projections_agg(projections);
        if aggs.is_empty() {
            // With no aggregates, `split_projections_agg` has already lowered
            // every projection exactly as `lower_projections` would. Reusing
            // that avoids lowering twice, which would duplicate any operators
            // pushed by pattern predicates.
            return self.plan.push(ReadOp::Project { input, items });
        }
        let keys: Vec<Expr> = items.into_iter().map(|p| p.expr).collect();
        let agg_op = self.plan.push(ReadOp::Aggregate { input, keys, aggs });
        let items = self.lower_projections(projections);
        self.plan.push(ReadOp::Project {
            input: agg_op,
            items,
        })
    }

    /// Pushes a label-less `Source` scan bound to a fresh variable.
    ///
    /// Used as the implicit input when a read clause has nothing to consume
    /// (e.g. a statement that starts with `UNWIND` or `RETURN`).
    fn push_source_all(&mut self) -> OpId {
        // The binding comes from `fresh_var`, so it cannot collide with a
        // variable that `map_var` allocates afterwards.
        let bind = self.fresh_var();
        self.plan.push(ReadOp::Source { label: None, bind })
    }

    /// Lowers a `MATCH` / `OPTIONAL MATCH` pattern.
    ///
    /// Returns the operator representing the pattern, plus the named
    /// variables it binds (in binding order, possibly with repeats).
    ///
    /// For an optional match with an existing `current_op`, the result is an
    /// `OptionalJoin` over `current_op`. Its `pattern` is a copy of the
    /// pattern's root operator, whose inputs still refer into `plan.ops`.
    /// In every other case `current_op` is not used.
    ///
    /// NOTE(review): each comma-separated part is lowered, but only the last
    /// part's operator is returned. There is no cartesian-product operator,
    /// so earlier parts are not reachable from the result.
    fn lower_match_pattern(
        &mut self,
        pattern: &Pattern,
        current_op: Option<OpId>,
        optional: bool,
    ) -> (OpId, Vec<VarId>) {
        let mut vars = Vec::new();
        let mut last_part_op: Option<OpId> = None;
        for part in &pattern.parts {
            last_part_op = Some(self.lower_pattern_part(part, &mut vars));
        }
        let inner_op = last_part_op.unwrap_or_else(|| self.push_source_all());
        let op = match current_op {
            Some(outer) if optional => {
                let inner_root = self.plan.ops[inner_op.0 as usize].clone();
                self.plan.push(ReadOp::OptionalJoin {
                    input: outer,
                    pattern: Box::new(inner_root),
                })
            }
            _ => inner_op,
        };
        (op, vars)
    }

    /// Lowers one pattern part into a `Source` followed by `Expand`s.
    ///
    /// A node with no pending relationship becomes a `Source` (label-filtered
    /// when labels are present), followed by a `Filter` when it has inline
    /// properties. Each relationship + node pair becomes an `Expand` from the
    /// previous node. Named variables are appended to `vars`; anonymous
    /// elements get fresh variables.
    fn lower_pattern_part(&mut self, part: &PatternPart, vars: &mut Vec<VarId>) -> OpId {
        let mut last_op: Option<OpId> = None;
        let mut last_node_var: Option<VarId> = None;
        // Relationship seen but not yet closed by its target node.
        let mut pending_rel: Option<&PatternElement> = None;
        for elem in &part.elements {
            let PatternElement::Node {
                bind,
                labels,
                props,
                ..
            } = elem
            else {
                pending_rel = Some(elem);
                continue;
            };
            let named = bind.map(|v| {
                let pv = self.map_var(v);
                vars.push(pv);
                pv
            });
            let node_var = named.unwrap_or_else(|| self.fresh_var());
            let op = match (pending_rel.take(), last_node_var, last_op) {
                (Some(rel_elem), Some(from), Some(input)) => {
                    let (rel, bind_rel) = self.lower_rel_element(rel_elem, vars);
                    let to = NodeSpec {
                        labels: LabelSet(labels.clone()),
                        properties: props.as_ref().map(|e| self.lower_expr(e)),
                    };
                    self.plan.push(ReadOp::Expand {
                        input,
                        from,
                        rel,
                        to,
                        bind_rel,
                        bind_to: node_var,
                    })
                }
                _ => {
                    let label = (!labels.is_empty()).then(|| LabelSet(labels.clone()));
                    let source = self.plan.push(ReadOp::Source {
                        label,
                        bind: node_var,
                    });
                    // NOTE(review): the predicate is the raw property-map
                    // expression, not an equality conjunction; presumably the
                    // executor treats a map predicate as a property match.
                    match props.as_ref() {
                        Some(prop_expr) => {
                            let predicate = self.lower_expr(prop_expr);
                            self.plan.push(ReadOp::Filter {
                                input: source,
                                predicate,
                            })
                        }
                        None => source,
                    }
                }
            };
            last_node_var = Some(node_var);
            last_op = Some(op);
        }
        last_op.unwrap_or_else(|| self.push_source_all())
    }

    /// Lowers a relationship pattern element to a [`RelSpec`] plus the
    /// variable bound to the relationship (fresh if the element is anonymous).
    ///
    /// # Panics
    ///
    /// Panics if `elem` is a node, or carries a direction/length variant this
    /// lowering does not know about.
    fn lower_rel_element(
        &mut self,
        elem: &PatternElement,
        vars: &mut Vec<VarId>,
    ) -> (RelSpec, VarId) {
        let PatternElement::Rel {
            bind,
            types,
            direction,
            length,
            props,
            ..
        } = elem
        else {
            panic!("lower_rel_element called on a Node element");
        };
        let bind_rel = match bind {
            Some(v) => {
                let pv = self.map_var(*v);
                vars.push(pv);
                pv
            }
            None => self.fresh_var(),
        };
        let direction = match direction {
            HirDir::Outgoing => Direction::Outgoing,
            HirDir::Incoming => Direction::Incoming,
            HirDir::Undirected => Direction::Undirected,
            _ => unreachable!("cyrs-plan::lower: unhandled Direction variant"),
        };
        let length = match length {
            HirRelLen::Single => RelLength::Single,
            HirRelLen::Variable { min, max } => RelLength::Variable {
                min: *min,
                max: *max,
            },
            _ => unreachable!("cyrs-plan::lower: unhandled RelLength variant"),
        };
        let rel_spec = RelSpec {
            types: types.clone(),
            direction,
            length,
            properties: props.as_ref().map(|e| self.lower_expr(e)),
        };
        (rel_spec, bind_rel)
    }

    /// Lowers projections, using the explicit alias or a synthesised one.
    fn lower_projections(&mut self, projs: &[Projection]) -> Vec<PlanProj> {
        projs
            .iter()
            .map(|p| {
                let expr = self.lower_expr(&p.expr);
                let alias = p.alias.clone().unwrap_or_else(|| synthesise_alias(&p.expr));
                PlanProj { expr, alias }
            })
            .collect()
    }

    /// Splits projections into plain projections and top-level aggregate calls.
    ///
    /// Only a projection whose whole expression is an aggregate call counts
    /// as an aggregate; aggregates nested inside larger expressions are
    /// lowered as ordinary calls.
    fn split_projections_agg(&mut self, projs: &[Projection]) -> (Vec<PlanProj>, Vec<AggExpr>) {
        let mut non_agg = Vec::new();
        let mut agg = Vec::new();
        for p in projs {
            if let HirExpr::Call {
                name,
                args,
                distinct,
            } = &p.expr
                && is_aggregate_func(name)
            {
                let args: Vec<Expr> = args.iter().map(|a| self.lower_expr(a)).collect();
                agg.push(AggExpr {
                    func: name.clone(),
                    args,
                    distinct: *distinct,
                });
                continue;
            }
            let expr = self.lower_expr(&p.expr);
            let alias = p.alias.clone().unwrap_or_else(|| synthesise_alias(&p.expr));
            non_agg.push(PlanProj { expr, alias });
        }
        (non_agg, agg)
    }

    /// Lowers an optional inline property map, defaulting to an empty map.
    fn lower_props(&mut self, props: Option<&HirExpr>) -> Expr {
        match props {
            Some(e) => self.lower_expr(e),
            None => Expr::Map(vec![]),
        }
    }

    /// Lowers a `CREATE` pattern into `CreateNode` / `CreateRel` write ops.
    fn lower_create_pattern(&mut self, pattern: &Pattern) -> Vec<WriteOp> {
        let mut ops = Vec::new();
        for part in &pattern.parts {
            for pair in create_pattern_pairs(part) {
                let op = match pair {
                    CreatePair::Node {
                        labels,
                        props,
                        bind,
                    } => {
                        let bind = bind.map(|v| self.map_var(v));
                        WriteOp::CreateNode {
                            labels,
                            props: self.lower_props(props),
                            bind,
                        }
                    }
                    CreatePair::Rel {
                        from_bind,
                        to_bind,
                        rel_type,
                        props,
                        bind,
                    } => {
                        let from = self.map_var(from_bind);
                        let to = self.map_var(to_bind);
                        let bind = bind.map(|v| self.map_var(v));
                        WriteOp::CreateRel {
                            from,
                            to,
                            rel_type,
                            props: self.lower_props(props),
                            bind,
                        }
                    }
                };
                ops.push(op);
            }
        }
        ops
    }

    /// Lowers a `MERGE` pattern into `MergeNode` / `MergeRel` write ops.
    ///
    /// The lowered `ON CREATE` / `ON MATCH` items are attached to every node
    /// and relationship produced from the pattern.
    fn lower_merge_pattern(
        &mut self,
        pattern: &Pattern,
        on_create: &[SetItem],
        on_match: &[SetItem],
    ) -> Vec<WriteOp> {
        let mut ops = Vec::new();
        let create_ops = self.lower_set_items(on_create);
        let match_ops = self.lower_set_items(on_match);
        for part in &pattern.parts {
            for pair in create_pattern_pairs(part) {
                let op = match pair {
                    CreatePair::Node {
                        labels,
                        props,
                        bind,
                    } => {
                        let bind = bind.map(|v| self.map_var(v));
                        WriteOp::MergeNode {
                            labels,
                            props: self.lower_props(props),
                            on_create: create_ops.clone(),
                            on_match: match_ops.clone(),
                            bind,
                        }
                    }
                    CreatePair::Rel {
                        from_bind,
                        to_bind,
                        rel_type,
                        props,
                        bind,
                    } => {
                        let from = self.map_var(from_bind);
                        let to = self.map_var(to_bind);
                        let bind = bind.map(|v| self.map_var(v));
                        WriteOp::MergeRel {
                            from,
                            to,
                            rel_type,
                            props: self.lower_props(props),
                            on_create: create_ops.clone(),
                            on_match: match_ops.clone(),
                            bind,
                        }
                    }
                };
                ops.push(op);
            }
        }
        ops
    }

    /// Lowers a list of `SET` items, preserving their order.
    fn lower_set_items(&mut self, items: &[SetItem]) -> Vec<WriteOp> {
        items
            .iter()
            .flat_map(|item| self.lower_set_item(item))
            .collect()
    }

    /// Resolves the variable that a `SET` / `REMOVE` property item writes to.
    ///
    /// NOTE(review): a target that is not a bare variable gets a fresh plan
    /// variable that nothing binds, so the write goes to an unbound slot.
    fn property_target(&mut self, target: &HirExpr) -> VarId {
        match expr_to_var_id(target) {
            Some(hir_var) => self.map_var(hir_var),
            None => self.fresh_var(),
        }
    }

    /// Lowers a single `SET` item.
    fn lower_set_item(&mut self, item: &SetItem) -> Vec<WriteOp> {
        match item {
            SetItem::Property {
                target,
                prop,
                value,
            } => {
                let target = self.property_target(target);
                vec![WriteOp::SetProperty {
                    target,
                    prop: prop.clone(),
                    value: self.lower_expr(value),
                }]
            }
            SetItem::Labels { target, labels } => vec![WriteOp::SetLabels {
                target: self.map_var(*target),
                labels: labels.clone(),
            }],
            // NOTE(review): `SET n = map` / `SET n += map` is not supported
            // yet. It lowers to an empty `SetLabels` placeholder, and the map
            // and replace flag are discarded.
            SetItem::AssignMap {
                target,
                map: _,
                replace: _,
            } => vec![WriteOp::SetLabels {
                target: self.map_var(*target),
                labels: vec![],
            }],
        }
    }

    /// Lowers a list of `REMOVE` items, preserving their order.
    fn lower_remove_items(&mut self, items: &[RemoveItem]) -> Vec<WriteOp> {
        items
            .iter()
            .map(|item| match item {
                RemoveItem::Property { target, prop } => WriteOp::RemoveProperty {
                    target: self.property_target(target),
                    prop: prop.clone(),
                },
                RemoveItem::Labels { target, labels } => WriteOp::RemoveLabels {
                    target: self.map_var(*target),
                    labels: labels.clone(),
                },
            })
            .collect()
    }

    /// Lowers a HIR expression to a plan expression, mapping variables as it goes.
    ///
    /// Pattern predicates are lowered into `plan.ops` and embedded as
    /// `Expr::Exists`, holding a copy of the sub-pattern's root operator.
    /// Constructs rejected by `precheck_statement` (unresolved names,
    /// undesugared comprehensions and map projections) trip a `debug_assert!`
    /// and lower to `Expr::Null` in release builds.
    fn lower_expr(&mut self, expr: &HirExpr) -> Expr {
        match expr {
            HirExpr::Null => Expr::Null,
            HirExpr::Bool(b) => Expr::Bool(*b),
            HirExpr::Int(i) => Expr::Int(*i),
            HirExpr::Float(f) => Expr::Float(*f),
            HirExpr::String(s) => Expr::String(s.clone()),
            HirExpr::Var(v) => Expr::Var(self.map_var(*v)),
            HirExpr::Param(name) => Expr::Param { name: name.clone() },
            HirExpr::Prop { target, prop } => Expr::Prop {
                target: Box::new(self.lower_expr(target)),
                prop: prop.clone(),
            },
            HirExpr::Index { target, index } => Expr::Index {
                target: Box::new(self.lower_expr(target)),
                index: Box::new(self.lower_expr(index)),
            },
            HirExpr::Slice { target, start, end } => Expr::Slice {
                target: Box::new(self.lower_expr(target)),
                start: start.as_ref().map(|s| Box::new(self.lower_expr(s))),
                end: end.as_ref().map(|e| Box::new(self.lower_expr(e))),
            },
            HirExpr::List(items) => Expr::List(items.iter().map(|e| self.lower_expr(e)).collect()),
            HirExpr::Map(pairs) => Expr::Map(
                pairs
                    .iter()
                    .map(|(k, v)| (k.clone(), self.lower_expr(v)))
                    .collect(),
            ),
            HirExpr::Call {
                name,
                args,
                distinct: _,
            } => Expr::Call {
                func: name.clone(),
                args: args.iter().map(|a| self.lower_expr(a)).collect(),
            },
            HirExpr::BinOp { op, lhs, rhs } => Expr::BinOp {
                op: lower_bin_op(*op),
                lhs: Box::new(self.lower_expr(lhs)),
                rhs: Box::new(self.lower_expr(rhs)),
            },
            HirExpr::UnaryOp { op, operand } => Expr::UnaryOp {
                op: match op {
                    cyrs_hir::UnaryOp::Neg => UnaryOp::Neg,
                    cyrs_hir::UnaryOp::Not => UnaryOp::Not,
                },
                operand: Box::new(self.lower_expr(operand)),
            },
            HirExpr::Case {
                scrutinee,
                arms,
                otherwise,
            } => Expr::Case {
                scrutinee: scrutinee.as_ref().map(|s| Box::new(self.lower_expr(s))),
                arms: arms
                    .iter()
                    .map(|(w, t)| (self.lower_expr(w), self.lower_expr(t)))
                    .collect(),
                otherwise: otherwise.as_ref().map(|o| Box::new(self.lower_expr(o))),
            },
            HirExpr::IsNull { operand, negated } => Expr::IsNull {
                operand: Box::new(self.lower_expr(operand)),
                negated: *negated,
            },
            HirExpr::InList { operand, list } => Expr::InList {
                operand: Box::new(self.lower_expr(operand)),
                list: Box::new(self.lower_expr(list)),
            },
            HirExpr::Unresolved(name) => {
                debug_assert!(
                    false,
                    "Unresolved variable `{name}` encountered in HIR→Plan lowering; \
                     run name resolution (cy-b4b) before calling lower_statement"
                );
                Expr::Null
            }
            HirExpr::PatternPredicate(pattern) => {
                let (sub_op, _sub_vars) = self.lower_match_pattern(pattern, None, false);
                let inner_root = self.plan.ops[sub_op.0 as usize].clone();
                Expr::Exists {
                    pattern: Box::new(inner_root),
                }
            }
            HirExpr::ListComprehension { .. } => {
                debug_assert!(
                    false,
                    "ListComprehension encountered in HIR→Plan lowering; \
                     run cyrs_hir::desugar::desugar_statement (cy-mla) first"
                );
                Expr::Null
            }
            HirExpr::ListPredicate {
                kind,
                var,
                iterable,
                predicate,
            } => Expr::ListPredicate {
                kind: lower_list_pred_kind(*kind),
                var: self.map_var(*var),
                iterable: Box::new(self.lower_expr(iterable)),
                predicate: predicate.as_ref().map(|p| Box::new(self.lower_expr(p))),
            },
            HirExpr::MapProjection { .. } => {
                debug_assert!(
                    false,
                    "MapProjection encountered in HIR→Plan lowering; \
                     run cyrs_hir::desugar::desugar_statement (cy-mla) first"
                );
                Expr::Null
            }
        }
    }
}
/// Chooses a column name for a projection that has no explicit `AS` alias.
///
/// A bare variable becomes `_v<hir id>`. A property access uses the property
/// name and a function call uses the function name. Anything else is `_`.
///
/// NOTE(review): two unaliased property projections with the same property
/// name (e.g. `a.name, b.name`) get the same alias.
fn synthesise_alias(expr: &HirExpr) -> SmolStr {
    if let HirExpr::Var(var) = expr {
        return SmolStr::new(format!("_v{}", var.0));
    }
    match expr {
        HirExpr::Prop { prop: name, .. } | HirExpr::Call { name, .. } => name.clone(),
        _ => SmolStr::new("_"),
    }
}
/// Returns the HIR variable when `expr` is a bare variable reference, and
/// `None` for any other expression.
fn expr_to_var_id(expr: &HirExpr) -> Option<HirVarId> {
    if let HirExpr::Var(var) = expr {
        Some(*var)
    } else {
        None
    }
}
/// Maps a HIR list-predicate quantifier (`any`/`all`/`none`/`single`) onto
/// the plan's equivalent.
///
/// # Panics
///
/// Panics on a quantifier variant this lowering does not know about. This
/// matches how `LowerCtx::lower_rel_element` treats unknown direction and
/// length variants. Silently substituting `all` would change query results
/// without any diagnostic.
fn lower_list_pred_kind(kind: HirListPredKind) -> ListPredKind {
    match kind {
        HirListPredKind::Any => ListPredKind::Any,
        HirListPredKind::All => ListPredKind::All,
        HirListPredKind::None => ListPredKind::None,
        HirListPredKind::Single => ListPredKind::Single,
        _ => unreachable!("cyrs-plan::lower: unhandled ListPredKind variant"),
    }
}
fn lower_bin_op(op: cyrs_hir::BinOp) -> BinOp {
match op {
cyrs_hir::BinOp::Add => BinOp::Add,
cyrs_hir::BinOp::Sub => BinOp::Sub,
cyrs_hir::BinOp::Mul => BinOp::Mul,
cyrs_hir::BinOp::Div => BinOp::Div,
cyrs_hir::BinOp::Mod => BinOp::Mod,
cyrs_hir::BinOp::Pow => BinOp::Pow,
cyrs_hir::BinOp::Eq => BinOp::Eq,
cyrs_hir::BinOp::Neq => BinOp::Neq,
cyrs_hir::BinOp::Lt => BinOp::Lt,
cyrs_hir::BinOp::Le => BinOp::Le,
cyrs_hir::BinOp::Gt => BinOp::Gt,
cyrs_hir::BinOp::Ge => BinOp::Ge,
cyrs_hir::BinOp::And => BinOp::And,
cyrs_hir::BinOp::Or => BinOp::Or,
cyrs_hir::BinOp::Xor => BinOp::Xor,
cyrs_hir::BinOp::StartsWith => BinOp::StartsWith,
cyrs_hir::BinOp::EndsWith => BinOp::EndsWith,
cyrs_hir::BinOp::Contains => BinOp::Contains,
cyrs_hir::BinOp::RegexMatch => BinOp::RegexMatch,
cyrs_hir::BinOp::Concat => BinOp::Concat,
}
}
/// Returns `true` if `name` is a built-in aggregate function, compared
/// ASCII-case-insensitively (so `count`, `COUNT` and `Count` all match).
fn is_aggregate_func(name: &str) -> bool {
    const AGGREGATES: [&str; 10] = [
        "count",
        "sum",
        "avg",
        "min",
        "max",
        "collect",
        "stdev",
        "stdevp",
        "percentilecont",
        "percentiledisc",
    ];
    AGGREGATES.iter().any(|agg| name.eq_ignore_ascii_case(agg))
}
/// One write extracted from a `CREATE` / `MERGE` pattern part by
/// `create_pattern_pairs`.
///
/// Binds are HIR variables; callers map them to plan variables.
enum CreatePair<'a> {
    /// A node to create or merge.
    Node {
        labels: Vec<SmolStr>,
        /// Inline property-map expression, if the pattern had one.
        props: Option<&'a HirExpr>,
        bind: Option<HirVarId>,
    },
    /// A relationship between two nodes; both endpoints must be named.
    Rel {
        from_bind: HirVarId,
        to_bind: HirVarId,
        /// First relationship type in the pattern; empty if none was given.
        rel_type: SmolStr,
        /// Inline property-map expression, if the pattern had one.
        props: Option<&'a HirExpr>,
        bind: Option<HirVarId>,
    },
}
fn create_pattern_pairs(part: &PatternPart) -> Vec<CreatePair<'_>> {
let mut result = Vec::new();
let mut node_vars: Vec<Option<HirVarId>> = Vec::new();
let mut elements = part.elements.iter().peekable();
while let Some(elem) = elements.next() {
match elem {
PatternElement::Node {
bind,
labels,
props,
..
} => {
node_vars.push(*bind);
result.push(CreatePair::Node {
labels: labels.clone(),
props: props.as_ref(),
bind: *bind,
});
}
PatternElement::Rel {
bind, types, props, ..
} => {
let Some(from_bind) = node_vars.last().copied().flatten() else {
continue; };
let to_bind = match elements.peek() {
Some(PatternElement::Node { bind: Some(v), .. }) => {
let v = *v;
node_vars.push(Some(v));
let next = elements.next().unwrap();
if let PatternElement::Node {
labels,
props,
bind,
..
} = next
{
result.push(CreatePair::Node {
labels: labels.clone(),
props: props.as_ref(),
bind: *bind,
});
}
v
}
_ => continue,
};
let rel_type = types.first().cloned().unwrap_or_default();
result.push(CreatePair::Rel {
from_bind,
to_bind,
rel_type,
props: props.as_ref(),
bind: *bind,
});
}
}
}
result
}
pub fn lower_union_pair(
left: &Statement,
right: &Statement,
kind: UnionKind,
) -> Result<PlanStatement, PlanLowerError> {
let mut left_plan = lower_statement(left)?;
let right_plan = lower_statement(right)?;
#[allow(clippy::cast_possible_truncation)]
let offset = left_plan.ops.len() as u32;
#[allow(clippy::cast_possible_truncation)]
let right_root = OpId(right_plan.ops.len() as u32 - 1 + offset);
left_plan.ops.extend(right_plan.ops);
left_plan.write_ops.extend(right_plan.write_ops);
for (plan_var, hir_var) in right_plan.var_map {
left_plan
.var_map
.insert(VarId(plan_var.0 + offset), hir_var);
}
let left_root = OpId(offset - 1);
left_plan.ops.push(ReadOp::Union {
left: left_root,
right: right_root,
kind,
});
Ok(left_plan)
}
/// Stacks `ORDER BY`, `SKIP` and `LIMIT` operators, in that order, on top of
/// the plan's last read operator, which is treated as the root.
///
/// Each stage is added only when requested: `OrderBy` when `order_keys` is
/// non-empty, `Skip` / `Limit` when the count is `Some`. A plan with no read
/// operators is left untouched.
pub fn apply_order_skip_limit(
    plan: &mut PlanStatement,
    order_keys: Vec<OrderKey>,
    skip: Option<Expr>,
    limit: Option<Expr>,
) {
    let Some(last) = plan.ops.len().checked_sub(1) else {
        return;
    };
    #[allow(clippy::cast_possible_truncation)]
    let mut input = OpId(last as u32);
    if !order_keys.is_empty() {
        input = plan.push(ReadOp::OrderBy {
            input,
            keys: order_keys,
        });
    }
    if let Some(count) = skip {
        input = plan.push(ReadOp::Skip { input, count });
    }
    if let Some(count) = limit {
        plan.push(ReadOp::Limit { input, count });
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::SortDir;
use cyrs_hir::desugar::desugar_statement;
use cyrs_hir::lower::lower_statement as hir_lower;
fn plan_from(src: &str) -> PlanStatement {
let hir = hir_lower(src);
let hir = desugar_statement(hir);
lower_statement(&hir).expect("plan_from: input HIR must be resolved and desugared")
}
fn render(plan: &PlanStatement) -> String {
use std::fmt::Write;
let mut out = String::new();
writeln!(out, "read_ops: {}", plan.ops.len()).unwrap();
writeln!(out, "write_ops: {}", plan.write_ops.len()).unwrap();
writeln!(out, "var_map_entries: {}", plan.var_map.len()).unwrap();
for (i, op) in plan.ops.iter().enumerate() {
writeln!(out, "op[{i}]: {}", op_tag(op)).unwrap();
}
for (i, wop) in plan.write_ops.iter().enumerate() {
writeln!(out, "write[{i}]: {}", write_op_tag(wop)).unwrap();
}
out
}
fn op_tag(op: &ReadOp) -> String {
match op {
ReadOp::Source { label, bind } => format!(
"Source(label={}, bind={})",
label
.as_ref()
.map_or("None".into(), |l| format!("{:?}", l.0)),
bind.0
),
ReadOp::Expand {
from,
bind_rel,
bind_to,
..
} => {
format!(
"Expand(from={}, bind_rel={}, bind_to={})",
from.0, bind_rel.0, bind_to.0
)
}
ReadOp::Filter { input, .. } => format!("Filter(input={})", input.0),
ReadOp::Project { input, items } => {
format!("Project(input={}, cols={})", input.0, items.len())
}
ReadOp::Aggregate { input, keys, aggs } => {
format!(
"Aggregate(input={}, keys={}, aggs={})",
input.0,
keys.len(),
aggs.len()
)
}
ReadOp::OrderBy { input, keys } => {
format!("OrderBy(input={}, keys={})", input.0, keys.len())
}
ReadOp::Skip { input, .. } => format!("Skip(input={})", input.0),
ReadOp::Limit { input, .. } => format!("Limit(input={})", input.0),
ReadOp::Distinct { input } => format!("Distinct(input={})", input.0),
ReadOp::Unwind { input, bind, .. } => {
format!("Unwind(input={}, bind={})", input.0, bind.0)
}
ReadOp::Union { left, right, kind } => {
format!("Union(left={}, right={}, kind={:?})", left.0, right.0, kind)
}
ReadOp::With {
input,
items,
filter,
} => {
format!(
"With(input={}, cols={}, has_filter={})",
input.0,
items.len(),
filter.is_some()
)
}
ReadOp::OptionalJoin { input, .. } => format!("OptionalJoin(input={})", input.0),
}
}
fn write_op_tag(op: &WriteOp) -> String {
match op {
WriteOp::CreateNode { labels, bind, .. } => {
format!(
"CreateNode(labels={:?}, bind={:?})",
labels,
bind.map(|v| v.0)
)
}
WriteOp::CreateRel { rel_type, bind, .. } => {
format!("CreateRel(type={rel_type}, bind={:?})", bind.map(|v| v.0))
}
WriteOp::MergeNode { labels, bind, .. } => {
format!(
"MergeNode(labels={:?}, bind={:?})",
labels,
bind.map(|v| v.0)
)
}
WriteOp::MergeRel { rel_type, bind, .. } => {
format!("MergeRel(type={rel_type}, bind={:?})", bind.map(|v| v.0))
}
WriteOp::SetProperty { target, prop, .. } => {
format!("SetProperty(target={}, prop={prop})", target.0)
}
WriteOp::SetLabels { target, labels } => {
format!("SetLabels(target={}, labels={:?})", target.0, labels)
}
WriteOp::RemoveProperty { target, prop } => {
format!("RemoveProperty(target={}, prop={prop})", target.0)
}
WriteOp::RemoveLabels { target, labels } => {
format!("RemoveLabels(target={}, labels={:?})", target.0, labels)
}
WriteOp::Delete { detach, targets } => {
format!("Delete(detach={detach}, targets={})", targets.len())
}
}
}
#[test]
fn snap_single_match() {
let plan = plan_from("MATCH (n) RETURN n");
insta::assert_snapshot!("plan_single_match", render(&plan));
}
#[test]
fn snap_match_with_label() {
let plan = plan_from("MATCH (n:Person) RETURN n");
insta::assert_snapshot!("plan_match_with_label", render(&plan));
}
#[test]
fn snap_match_where() {
let plan = plan_from("MATCH (n) WHERE n.age > 18 RETURN n");
insta::assert_snapshot!("plan_match_where", render(&plan));
}
#[test]
fn snap_match_where_pretty_tree() {
use crate::pretty::pretty;
let plan = plan_from("MATCH (a) WHERE a.x = 1 RETURN a");
insta::assert_snapshot!("plan_match_where_pretty_tree", pretty(&plan));
}
#[test]
fn snap_match_with() {
let plan = plan_from("MATCH (n) WITH n RETURN n");
insta::assert_snapshot!("plan_match_with", render(&plan));
}
#[test]
fn snap_match_return_projection() {
let plan = plan_from("MATCH (n:Person) RETURN n.name, n.age");
insta::assert_snapshot!("plan_match_return_projection", render(&plan));
}
#[test]
fn snap_return_distinct() {
let plan = plan_from("MATCH (n) RETURN DISTINCT n.name");
insta::assert_snapshot!("plan_return_distinct", render(&plan));
}
#[test]
fn snap_unwind() {
use cyrs_hir::{
Binding, Clause, Expr as HirExpr, HirSpan, Statement, VarId as HirVarId, VarKind,
};
let span = HirSpan::default();
let mut stmt = Statement::new(span);
let x_var = HirVarId(0);
stmt.bindings.insert(
x_var,
Binding {
id: x_var,
name: "x".into(),
kind: VarKind::Value,
defined_at: span,
},
);
stmt.clauses.push(Clause::Unwind {
id: cyrs_hir::HirId::DUMMY,
list: HirExpr::List(vec![HirExpr::Int(1), HirExpr::Int(2), HirExpr::Int(3)]),
bind: x_var,
span,
});
stmt.clauses.push(Clause::Return {
id: cyrs_hir::HirId::DUMMY,
projections: vec![cyrs_hir::Projection {
expr: HirExpr::Var(x_var),
alias: Some("x".into()),
span,
}],
distinct: false,
span,
});
let plan = lower_statement(&stmt).expect("manually-built HIR must be resolved");
insta::assert_snapshot!("plan_unwind", render(&plan));
}
#[test]
fn snap_create_node() {
let plan = plan_from("CREATE (n:Person)");
insta::assert_snapshot!("plan_create_node", render(&plan));
}
#[test]
fn snap_create_rel() {
let plan = plan_from("MATCH (a:Person), (b:Person) CREATE (a)-[:KNOWS]->(b)");
insta::assert_snapshot!("plan_create_rel", render(&plan));
}
#[test]
fn snap_merge_node() {
let plan = plan_from("MERGE (n:Person {name: 'Alice'})");
insta::assert_snapshot!("plan_merge_node", render(&plan));
}
#[test]
fn snap_set_property() {
let plan = plan_from("MATCH (n:Person) SET n.age = 30");
insta::assert_snapshot!("plan_set_property", render(&plan));
}
#[test]
fn snap_remove_label() {
let plan = plan_from("MATCH (n:Person) REMOVE n:Person");
insta::assert_snapshot!("plan_remove_label", render(&plan));
}
#[test]
fn snap_delete() {
let plan = plan_from("MATCH (n) DELETE n");
insta::assert_snapshot!("plan_delete", render(&plan));
}
#[test]
fn snap_detach_delete() {
let plan = plan_from("MATCH (n) DETACH DELETE n");
insta::assert_snapshot!("plan_detach_delete", render(&plan));
}
#[test]
fn snap_aggregation_count() {
let plan = plan_from("MATCH (n) RETURN count(n)");
insta::assert_snapshot!("plan_aggregation_count", render(&plan));
}
#[test]
fn snap_aggregation_sum() {
let plan = plan_from("MATCH (n) RETURN sum(n.age)");
insta::assert_snapshot!("plan_aggregation_sum", render(&plan));
}
#[test]
fn snap_union_all() {
let left_hir = desugar_statement(hir_lower("MATCH (n:Person) RETURN n"));
let right_hir = desugar_statement(hir_lower("MATCH (n:Animal) RETURN n"));
let plan = lower_union_pair(&left_hir, &right_hir, UnionKind::All)
.expect("UNION arms must be resolved/desugared");
insta::assert_snapshot!("plan_union_all", render(&plan));
}
#[test]
fn snap_union_distinct() {
let left_hir = desugar_statement(hir_lower("MATCH (n:Person) RETURN n"));
let right_hir = desugar_statement(hir_lower("MATCH (n:Animal) RETURN n"));
let plan = lower_union_pair(&left_hir, &right_hir, UnionKind::Distinct)
.expect("UNION arms must be resolved/desugared");
insta::assert_snapshot!("plan_union_distinct", render(&plan));
}
#[test]
fn snap_optional_match() {
let plan = plan_from("MATCH (n) OPTIONAL MATCH (n)-[:KNOWS]->(m) RETURN n, m");
insta::assert_snapshot!("plan_optional_match", render(&plan));
}
#[test]
fn snap_match_rel_chain() {
let plan = plan_from("MATCH (a)-[:KNOWS]->(b) RETURN a, b");
insta::assert_snapshot!("plan_match_rel_chain", render(&plan));
}
#[test]
fn snap_order_skip_limit() {
let mut plan = plan_from("MATCH (n) RETURN n");
apply_order_skip_limit(
&mut plan,
vec![OrderKey {
expr: Expr::Var(VarId(0)),
dir: SortDir::Desc,
}],
Some(Expr::Int(10)),
Some(Expr::Int(5)),
);
insta::assert_snapshot!("plan_order_skip_limit", render(&plan));
}
#[test]
fn plan_lowering_is_deterministic() {
let plan1 = plan_from("MATCH (n:Person) WHERE n.age > 18 RETURN n.name, n.age");
let plan2 = plan_from("MATCH (n:Person) WHERE n.age > 18 RETURN n.name, n.age");
assert_eq!(render(&plan1), render(&plan2));
}
#[test]
fn single_match_returns_source_and_project() {
    // The plan for a trivial MATCH/RETURN is bracketed by a Source at the
    // front and a Project at the back.
    let plan = plan_from("MATCH (n) RETURN n");
    let ops = &plan.ops;
    assert!(ops.len() >= 2);
    assert!(matches!(ops.first(), Some(ReadOp::Source { .. })));
    assert!(matches!(ops.last(), Some(ReadOp::Project { .. })));
}
#[test]
fn match_where_inserts_filter() {
    // A WHERE on a MATCH must surface somewhere in the plan as a Filter op.
    let plan = plan_from("MATCH (n) WHERE n.age > 18 RETURN n");
    let is_filter = |op: &ReadOp| matches!(op, ReadOp::Filter { .. });
    assert!(plan.ops.iter().any(is_filter), "expected Filter op in plan");
}
#[test]
fn create_node_emits_write_op() {
    use cyrs_hir::{
        Binding, Clause, HirSpan, Pattern, PatternElement, PatternPart, Statement,
        VarId as HirVarId, VarKind,
    };
    // Hand-built, already-resolved HIR for `CREATE (n:Person)`.
    let span = HirSpan::default();
    let n = HirVarId(0);
    let mut stmt = Statement::new(span);
    stmt.bindings.insert(
        n,
        Binding {
            id: n,
            name: "n".into(),
            kind: VarKind::Node,
            defined_at: span,
        },
    );
    let person_node = PatternElement::Node {
        id: cyrs_hir::HirId::DUMMY,
        bind: Some(n),
        labels: vec!["Person".into()],
        props: None,
        span,
    };
    let part = PatternPart {
        named_as: None,
        elements: vec![person_node],
    };
    stmt.clauses.push(Clause::Create {
        id: cyrs_hir::HirId::DUMMY,
        pattern: Pattern { parts: vec![part] },
        span,
    });

    let plan = lower_statement(&stmt).expect("manually-built HIR must be resolved");
    let emitted_create = plan
        .write_ops
        .iter()
        .any(|w| matches!(w, WriteOp::CreateNode { .. }));
    assert!(
        emitted_create,
        "expected CreateNode write op; write_ops={:?}",
        plan.write_ops.iter().map(write_op_tag).collect::<Vec<_>>()
    );
}
#[test]
fn delete_emits_write_op() {
use cyrs_hir::{
Binding, Clause, Expr as HirExpr, HirSpan, Pattern, PatternElement, PatternPart,
Statement, VarId as HirVarId, VarKind,
};
let span = HirSpan::default();
let mut stmt = Statement::new(span);
let n_var = HirVarId(0);
stmt.bindings.insert(
n_var,
Binding {
id: n_var,
name: "n".into(),
kind: VarKind::Node,
defined_at: span,
},
);
stmt.clauses.push(Clause::Match {
id: cyrs_hir::HirId::DUMMY,
optional: false,
pattern: Pattern {
parts: vec![PatternPart {
named_as: None,
elements: vec![PatternElement::Node {
id: cyrs_hir::HirId::DUMMY,
bind: Some(n_var),
labels: vec![],
props: None,
span,
}],
}],
},
span,
});
stmt.clauses.push(Clause::Delete {
id: cyrs_hir::HirId::DUMMY,
targets: vec![HirExpr::Var(n_var)],
detach: false,
span,
});
let plan = lower_statement(&stmt).expect("manually-built HIR must be resolved");
assert!(
plan.write_ops
.iter()
.any(|w| matches!(w, WriteOp::Delete { detach: false, .. })),
"expected Delete(detach=false) write op"
);
}
#[test]
fn detach_delete_emits_write_op() {
use cyrs_hir::{
Binding, Clause, Expr as HirExpr, HirSpan, Pattern, PatternElement, PatternPart,
Statement, VarId as HirVarId, VarKind,
};
let span = HirSpan::default();
let mut stmt = Statement::new(span);
let n_var = HirVarId(0);
stmt.bindings.insert(
n_var,
Binding {
id: n_var,
name: "n".into(),
kind: VarKind::Node,
defined_at: span,
},
);
stmt.clauses.push(Clause::Match {
id: cyrs_hir::HirId::DUMMY,
optional: false,
pattern: Pattern {
parts: vec![PatternPart {
named_as: None,
elements: vec![PatternElement::Node {
id: cyrs_hir::HirId::DUMMY,
bind: Some(n_var),
labels: vec![],
props: None,
span,
}],
}],
},
span,
});
stmt.clauses.push(Clause::Delete {
id: cyrs_hir::HirId::DUMMY,
targets: vec![HirExpr::Var(n_var)],
detach: true,
span,
});
let plan = lower_statement(&stmt).expect("manually-built HIR must be resolved");
assert!(
plan.write_ops
.iter()
.any(|w| matches!(w, WriteOp::Delete { detach: true, .. })),
"expected Delete(detach=true) write op"
);
}
#[test]
fn var_map_populated_for_bound_variables() {
    // `n` is bound by the MATCH, so the plan must record at least one mapping.
    let plan = plan_from("MATCH (n) RETURN n");
    let has_mappings = !plan.var_map.is_empty();
    assert!(
        has_mappings,
        "var_map should be populated for bound variables"
    );
}
#[test]
fn with_where_threads_filter_into_plan() {
    // A `WITH ... WHERE ...` must keep its predicate attached to the With op.
    let query = "MATCH (a) UNWIND a.aliases AS alias \
         WITH a, alias WHERE alias CONTAINS 'Fancy' \
         RETURN DISTINCT a.canonical_name";
    let plan = plan_from(query);
    let is_filtered_with = |op: &ReadOp| {
        matches!(
            op,
            ReadOp::With {
                filter: Some(_),
                ..
            }
        )
    };
    assert!(
        plan.ops.iter().any(is_filtered_with),
        "expected ReadOp::With with a Some(filter); plan ops = {:#?}",
        plan.ops
    );
}
fn stmt_with_return_expr(expr: HirExpr) -> Statement {
use cyrs_hir::HirSpan;
let span = HirSpan::default();
let mut stmt = Statement::new(span);
stmt.clauses.push(Clause::Return {
id: cyrs_hir::HirId::DUMMY,
projections: vec![Projection {
expr,
alias: Some("x".into()),
span,
}],
distinct: false,
span,
});
stmt
}
#[test]
fn lower_statement_returns_err_on_unresolved_name() {
    // An identifier the resolver never bound must be rejected, not lowered.
    let stmt = stmt_with_return_expr(HirExpr::Unresolved("foo".into()));
    let name = match lower_statement(&stmt).expect_err("unresolved name must be rejected") {
        PlanLowerError::UnresolvedName { name, .. } => name,
        other => panic!("expected UnresolvedName, got {other:?}"),
    };
    assert_eq!(name, "foo");
}
#[test]
fn lower_statement_returns_err_on_listcomp() {
    // `[x IN [1] | x]` left un-desugared must be refused by plan lowering.
    let x = HirVarId(0);
    let list_comp = HirExpr::ListComprehension {
        filter_var: x,
        iterable: Box::new(HirExpr::List(vec![HirExpr::Int(1)])),
        filter: None,
        map_expr: Box::new(HirExpr::Var(x)),
    };
    let stmt = stmt_with_return_expr(list_comp);
    let kind = match lower_statement(&stmt).expect_err("list comprehension must be rejected") {
        PlanLowerError::UndesugaredExpr { kind, .. } => kind,
        other => panic!("expected UndesugaredExpr(ListComprehension), got {other:?}"),
    };
    assert_eq!(kind, "ListComprehension");
}
#[test]
fn lower_statement_returns_err_on_mapprojection() {
    // An empty map projection `v0{}` left un-desugared must be refused.
    let base = Box::new(HirExpr::Var(HirVarId(0)));
    let stmt = stmt_with_return_expr(HirExpr::MapProjection {
        base,
        items: vec![],
    });
    let kind = match lower_statement(&stmt).expect_err("map projection must be rejected") {
        PlanLowerError::UndesugaredExpr { kind, .. } => kind,
        other => panic!("expected UndesugaredExpr(MapProjection), got {other:?}"),
    };
    assert_eq!(kind, "MapProjection");
}
#[test]
fn lower_statement_returns_err_on_unresolved_inside_patternpredicate() {
    // Pattern predicate over a single anonymous node whose property map
    // refers to an unresolved name: `({k: vaext})`.
    let props = HirExpr::Map(vec![("k".into(), HirExpr::Unresolved("vaext".into()))]);
    let node = PatternElement::Node {
        id: cyrs_hir::HirId::DUMMY,
        bind: None,
        labels: vec![],
        props: Some(props),
        span: HirSpan::default(),
    };
    let part = PatternPart {
        named_as: None,
        elements: vec![node],
    };
    let predicate = HirExpr::PatternPredicate(cyrs_hir::Pattern { parts: vec![part] });
    let stmt = stmt_with_return_expr(predicate);

    let err = lower_statement(&stmt)
        .expect_err("unresolved name inside PatternPredicate must be rejected");
    let name = match err {
        PlanLowerError::UnresolvedName { name, .. } => name,
        other => panic!("expected UnresolvedName, got {other:?}"),
    };
    assert_eq!(name, "vaext");
}
#[test]
fn lower_statement_no_panic_on_unresolved_inside_patternpredicate_text() {
    // Only the absence of a panic is checked; the lowering result is discarded.
    let desugared =
        desugar_statement(hir_lower("MATCH (n) WHERE (n {k: vaext})-->() RETURN n\n"));
    let _ = lower_statement(&desugared);
}
#[test]
fn lower_statement_accepts_patternpredicate_as_exists() {
    // Even an empty pattern predicate must lower successfully and appear as an
    // `Expr::Exists` among the items of some Project op.
    let empty_pattern = cyrs_hir::Pattern { parts: vec![] };
    let stmt = stmt_with_return_expr(HirExpr::PatternPredicate(empty_pattern));
    let plan = lower_statement(&stmt).expect("pattern predicate must lower to Exists");
    let saw_exists = plan.ops.iter().any(|op| match op {
        ReadOp::Project { items, .. } => items
            .iter()
            .any(|item| matches!(item.expr, Expr::Exists { .. })),
        _ => false,
    });
    assert!(
        saw_exists,
        "expected plan to carry Expr::Exists after PatternPredicate lowering, got {plan:?}"
    );
}
}