use crate::query::rewrite::context::RewriteContext;
use crate::query::rewrite::error::RewriteError;
use uni_cypher::ast::Expr;
/// A rule that rewrites calls to one named function into another expression.
///
/// The `Send + Sync` supertraits allow a rule object to be shared across threads.
pub trait RewriteRule: Send + Sync {
    /// Name of the function whose calls this rule handles.
    fn function_name(&self) -> &str;

    /// Validates the arguments of a call to [`RewriteRule::function_name`].
    ///
    /// # Errors
    ///
    /// Returns a [`RewriteError`] describing why `args` are not acceptable.
    fn validate_args(&self, args: &[Expr]) -> Result<(), RewriteError>;

    /// Builds the replacement expression for a call, taking ownership of its arguments.
    ///
    /// # Errors
    ///
    /// Returns a [`RewriteError`] if the replacement cannot be produced.
    fn rewrite(&self, args: Vec<Expr>, ctx: &RewriteContext) -> Result<Expr, RewriteError>;

    /// Reports whether this rule should be used in the given context.
    ///
    /// The default implementation ignores the context and always returns `true`.
    fn is_applicable(&self, _context: &RewriteContext) -> bool {
        true
    }
}
/// The number of arguments a rewritable function accepts.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Arity {
    /// Exactly `n` arguments.
    Exact(usize),
    /// Between `min` and `max` arguments, both bounds inclusive.
    Range(usize, usize),
    /// At least `min` arguments, with no upper bound.
    VarArgs(usize),
}
impl Arity {
    /// Checks that a call with `count` arguments satisfies this arity.
    ///
    /// # Errors
    ///
    /// * [`RewriteError::ArityMismatch`] when the accepted range is a single
    ///   value: `Exact(n)`, or a degenerate `Range(n, n)`.
    /// * [`RewriteError::ArityOutOfRange`] otherwise. For `VarArgs(min)` the
    ///   reported `max` is `usize::MAX`, which stands in for "no upper bound".
    pub fn check(&self, count: usize) -> Result<(), RewriteError> {
        let (min, max) = match *self {
            Arity::Exact(n) => (n, n),
            Arity::Range(min, max) => (min, max),
            // `usize::MAX` acts as "unbounded": no argument count can exceed it.
            Arity::VarArgs(min) => (min, usize::MAX),
        };

        // Inclusive on both ends. An inverted `Range` (min > max) is empty and
        // therefore rejects every count.
        if (min..=max).contains(&count) {
            Ok(())
        } else if min == max {
            Err(RewriteError::ArityMismatch {
                expected: min,
                got: count,
            })
        } else {
            Err(RewriteError::ArityOutOfRange {
                min,
                max,
                got: count,
            })
        }
    }
}
/// Declarative argument requirements for a rewritable function call,
/// checked by [`ArgConstraints::validate`].
#[derive(Clone, Debug)]
pub struct ArgConstraints {
    /// Accepted number of arguments.
    pub arity: Arity,
    /// Indices of arguments that must be `Expr::Literal` when they are present.
    pub literal_args: Vec<usize>,
    /// Index of the argument, if any, that must be an entity reference
    /// (`Expr::Variable` or `Expr::Property`) when it is present.
    pub entity_arg: Option<usize>,
}
impl ArgConstraints {
    /// Validates `args` against these constraints, returning the first failure.
    ///
    /// Checks run in this order:
    /// 1. the argument count, via [`Arity::check`];
    /// 2. every index in `literal_args`, in list order, must hold an `Expr::Literal`;
    /// 3. `entity_arg`, if set, must hold an `Expr::Variable` or `Expr::Property`.
    ///
    /// An index at or beyond `args.len()` is skipped rather than reported. Presumably
    /// this lets constraints on optional arguments (e.g. under a `Range` arity) apply
    /// only when those arguments are actually supplied.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Arity::check`], or returns
    /// [`RewriteError::ExpectedStringLiteral`] / [`RewriteError::ExpectedEntityReference`]
    /// carrying the offending argument index.
    pub fn validate(&self, args: &[Expr]) -> Result<(), RewriteError> {
        self.arity.check(args.len())?;

        // NOTE(review): despite the error's name, any literal kind passes this check,
        // not only string literals. Confirm that this is intended.
        let first_non_literal = self.literal_args.iter().copied().find(|&idx| {
            matches!(args.get(idx), Some(arg) if !matches!(arg, Expr::Literal(_)))
        });
        if let Some(arg_index) = first_non_literal {
            return Err(RewriteError::ExpectedStringLiteral { arg_index });
        }

        let non_entity = self.entity_arg.filter(|&idx| {
            matches!(
                args.get(idx),
                Some(arg) if !matches!(arg, Expr::Variable(_) | Expr::Property(_, _))
            )
        });
        match non_entity {
            Some(arg_index) => Err(RewriteError::ExpectedEntityReference { arg_index }),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// Builds a variable reference expression.
    fn var(name: &'static str) -> Expr {
        Expr::Variable(name.into())
    }

    /// Builds a string literal expression.
    fn string_lit(value: &'static str) -> Expr {
        Expr::Literal(CypherLiteral::String(value.into()))
    }

    #[test]
    fn test_arity_exact() {
        let arity = Arity::Exact(3);
        assert!(arity.check(3).is_ok());
        assert!(arity.check(2).is_err());
        assert!(arity.check(4).is_err());
    }

    #[test]
    fn test_arity_range() {
        let arity = Arity::Range(2, 4);
        assert!(arity.check(1).is_err());
        assert!(arity.check(2).is_ok());
        assert!(arity.check(3).is_ok());
        assert!(arity.check(4).is_ok());
        assert!(arity.check(5).is_err());
    }

    #[test]
    fn test_arity_varargs() {
        let arity = Arity::VarArgs(2);
        assert!(arity.check(1).is_err());
        assert!(arity.check(2).is_ok());
        assert!(arity.check(10).is_ok());
    }

    /// A single-value range reports `ArityMismatch`; wider ranges report `ArityOutOfRange`.
    #[test]
    fn test_arity_error_variants() {
        assert!(matches!(
            Arity::Exact(3).check(2),
            Err(RewriteError::ArityMismatch { expected: 3, got: 2 })
        ));
        assert!(matches!(
            Arity::Range(2, 4).check(5),
            Err(RewriteError::ArityOutOfRange { min: 2, max: 4, got: 5 })
        ));
        assert!(matches!(
            Arity::VarArgs(2).check(1),
            Err(RewriteError::ArityOutOfRange { min: 2, got: 1, .. })
        ));
    }

    #[test]
    fn test_arg_constraints_validate() {
        let constraints = ArgConstraints {
            arity: Arity::Exact(3),
            literal_args: vec![1],
            entity_arg: Some(0),
        };
        let valid_args = vec![var("e"), string_lit("prop"), var("x")];
        assert!(constraints.validate(&valid_args).is_ok());

        let wrong_arity = vec![var("e")];
        assert!(matches!(
            constraints.validate(&wrong_arity),
            Err(RewriteError::ArityMismatch { expected: 3, got: 1 })
        ));

        let non_literal = vec![var("e"), var("prop"), var("x")];
        assert!(matches!(
            constraints.validate(&non_literal),
            Err(RewriteError::ExpectedStringLiteral { arg_index: 1 })
        ));

        let non_entity = vec![string_lit("e"), string_lit("prop"), var("x")];
        assert!(matches!(
            constraints.validate(&non_entity),
            Err(RewriteError::ExpectedEntityReference { arg_index: 0 })
        ));
    }

    /// Constraints on indices past the supplied arguments are skipped, not reported.
    #[test]
    fn test_arg_constraints_skip_absent_optional_args() {
        let constraints = ArgConstraints {
            arity: Arity::Range(1, 3),
            literal_args: vec![1, 2],
            entity_arg: Some(0),
        };
        assert!(constraints.validate(&[var("e")]).is_ok());
        assert!(constraints.validate(&[var("e"), string_lit("p")]).is_ok());
        assert!(matches!(
            constraints.validate(&[]),
            Err(RewriteError::ArityOutOfRange { min: 1, max: 3, got: 0 })
        ));
    }
}