use std::collections::HashMap;
use thiserror::Error;
use uni_common::Value;
use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, UnaryOp};
use crate::decode::stringify;
/// Errors produced while evaluating a declared expression with [`eval_expr`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum EvalError {
    /// A `$name` parameter referenced by the expression has no binding in the supplied
    /// parameter map. Carries the parameter name without the leading `$`.
    #[error("unbound parameter `${0}` in declared expression")]
    UnboundParameter(String),
    /// The expression uses a form, operator, or function this evaluator does not
    /// implement. Carries a textual description of the offending construct.
    #[error("declared expression uses unsupported Cypher feature: {0}")]
    Unsupported(String),
    /// An operator was applied to operand types it does not accept.
    #[error("type mismatch: cannot apply `{op}` to {lhs} and {rhs}")]
    TypeMismatch {
        /// The operator, rendered as text.
        op: String,
        /// Type name of the left (or only) operand.
        lhs: &'static str,
        /// Type name of the right operand; `"<unary>"` for unary operators.
        rhs: &'static str,
    },
    /// An arithmetic or comparison failure on well-typed operands, such as division by
    /// zero or comparing against NaN.
    #[error("arithmetic error in declared expression: {0}")]
    Arithmetic(String),
}
/// Evaluates a declared Cypher expression against a set of bound parameters.
///
/// Handles literals, `$parameters`, binary and unary operators, list and map literals,
/// `IS NULL` / `IS NOT NULL`, a small set of scalar functions, and both forms of `CASE`.
/// Sub-expressions are evaluated eagerly, left to right.
///
/// # Errors
///
/// - [`EvalError::UnboundParameter`] when a referenced `$name` is missing from `params`.
/// - [`EvalError::Unsupported`] for any other expression form (the message carries the
///   expression's `Debug` rendering).
/// - Any error raised while applying an operator or function to the evaluated operands.
pub fn eval_expr(expr: &Expr, params: &HashMap<String, Value>) -> Result<Value, EvalError> {
    match expr {
        Expr::Literal(lit) => Ok(lit_to_value(lit)),
        Expr::Parameter(name) => match params.get(name) {
            Some(bound) => Ok(bound.clone()),
            None => Err(EvalError::UnboundParameter(name.clone())),
        },
        Expr::BinaryOp { left, op, right } => {
            // The left operand is evaluated first, so its error wins if both sides fail.
            let lhs = eval_expr(left, params)?;
            let rhs = eval_expr(right, params)?;
            apply_binary(*op, lhs, rhs)
        }
        Expr::UnaryOp { op, expr: operand } => apply_unary(*op, eval_expr(operand, params)?),
        Expr::List(items) => {
            let values = items
                .iter()
                .map(|item| eval_expr(item, params))
                .collect::<Result<Vec<Value>, EvalError>>()?;
            Ok(Value::List(values))
        }
        Expr::Map(entries) => {
            let mut fields = HashMap::new();
            for (key, value_expr) in entries.iter() {
                fields.insert(key.clone(), eval_expr(value_expr, params)?);
            }
            Ok(Value::Map(fields))
        }
        Expr::IsNull(inner) => {
            eval_expr(inner, params).map(|v| Value::Bool(matches!(v, Value::Null)))
        }
        Expr::IsNotNull(inner) => {
            eval_expr(inner, params).map(|v| Value::Bool(!matches!(v, Value::Null)))
        }
        Expr::FunctionCall { name, args, .. } => {
            let evaluated = args
                .iter()
                .map(|arg| eval_expr(arg, params))
                .collect::<Result<Vec<Value>, EvalError>>()?;
            apply_function(name, &evaluated)
        }
        Expr::Case {
            expr: scrutinee,
            when_then,
            else_expr,
        } => eval_case(
            scrutinee.as_deref(),
            when_then,
            else_expr.as_deref(),
            params,
        ),
        unsupported => Err(EvalError::Unsupported(format!("{unsupported:?}"))),
    }
}
fn lit_to_value(lit: &CypherLiteral) -> Value {
match lit {
CypherLiteral::Null => Value::Null,
CypherLiteral::Bool(b) => Value::Bool(*b),
CypherLiteral::Integer(i) => Value::Int(*i),
CypherLiteral::Float(f) => Value::Float(*f),
CypherLiteral::String(s) => Value::String(s.clone()),
CypherLiteral::Bytes(b) => Value::Bytes(b.clone()),
}
}
/// Returns a short, human-readable name for the variant of `v`.
///
/// Used to fill the `lhs`/`rhs` fields of [`EvalError::TypeMismatch`]. Variants not
/// listed here are reported as `"Other"`.
fn type_name(v: &Value) -> &'static str {
    match *v {
        // Scalars.
        Value::Null => "Null",
        Value::Bool(..) => "Bool",
        Value::Int(..) => "Int",
        Value::Float(..) => "Float",
        Value::String(..) => "String",
        Value::Bytes(..) => "Bytes",
        // Collections.
        Value::List(..) => "List",
        Value::Map(..) => "Map",
        // Graph elements.
        Value::Node(..) => "Node",
        Value::Edge(..) => "Edge",
        Value::Path(..) => "Path",
        // Remaining named variants, then a catch-all.
        Value::Vector(..) => "Vector",
        Value::Temporal(..) => "Temporal",
        _ => "Other",
    }
}
/// Applies a binary operator to two already-evaluated operands.
///
/// Null handling follows Cypher's three-valued logic:
/// - `false AND null` and `null AND false` are `false`.
/// - `true OR null` and `null OR true` are `true`.
/// - Every other operator with a null operand yields null.
///
/// Non-null operands are dispatched to the operator family's helper.
///
/// # Errors
///
/// - [`EvalError::Unsupported`] for `=~` (regex) and approximate equality.
/// - Whatever the dispatched helper returns (type mismatches, arithmetic failures).
fn apply_binary(op: BinaryOp, l: Value, r: Value) -> Result<Value, EvalError> {
    use BinaryOp::*;
    let l_is_null = matches!(l, Value::Null);
    if l_is_null || matches!(r, Value::Null) {
        // The non-null side (or `r` when both are null) can still decide AND/OR:
        // a definite `false` settles AND and a definite `true` settles OR.
        let known = if l_is_null { &r } else { &l };
        return Ok(match (op, known) {
            (And, Value::Bool(false)) => Value::Bool(false),
            (Or, Value::Bool(true)) => Value::Bool(true),
            _ => Value::Null,
        });
    }
    match op {
        Add => add_values(l, r),
        Sub | Mul | Div | Mod | Pow => arith(op, l, r),
        Eq | NotEq | Lt | LtEq | Gt | GtEq => compare(op, l, r),
        And | Or | Xor => boolean_op(op, l, r),
        Contains | StartsWith | EndsWith => string_match(op, l, r),
        Regex | ApproxEq => Err(EvalError::Unsupported(format!("{op}"))),
    }
}
fn add_values(l: Value, r: Value) -> Result<Value, EvalError> {
match (l, r) {
(Value::Int(a), Value::Int(b)) => Ok(Value::Int(a.saturating_add(b))),
(Value::Float(a), Value::Float(b)) => Ok(Value::Float(a + b)),
(Value::Int(a), Value::Float(b)) => Ok(Value::Float(a as f64 + b)),
(Value::Float(a), Value::Int(b)) => Ok(Value::Float(a + b as f64)),
(Value::String(a), Value::String(b)) => Ok(Value::String(a + &b)),
(Value::String(a), b) => Ok(Value::String(a + &stringify(&b))),
(a, Value::String(b)) => Ok(Value::String(stringify(&a) + &b)),
(Value::List(mut a), Value::List(b)) => {
a.extend(b);
Ok(Value::List(a))
}
(l, r) => Err(EvalError::TypeMismatch {
op: "+".to_owned(),
lhs: type_name(&l),
rhs: type_name(&r),
}),
}
}
/// Applies `-`, `*`, `/`, `%` and `^` to numeric operands.
///
/// Two `Int` operands are computed with exact `i64` arithmetic, so magnitudes beyond 2^53
/// keep full precision. The result stays `Int` except in two cases:
/// - `/` returns `Float` when the division is inexact.
/// - `^` with a negative exponent returns `Float` when the result is not integral.
///
/// Any pairing that involves a `Float` is computed in `f64` and returns `Float`.
///
/// # Errors
///
/// - [`EvalError::Arithmetic`] on division or modulo by zero, or on `i64` overflow.
/// - [`EvalError::TypeMismatch`] when either operand is not `Int` or `Float`.
fn arith(op: BinaryOp, l: Value, r: Value) -> Result<Value, EvalError> {
    match (&l, &r) {
        (Value::Int(a), Value::Int(b)) => int_arith(op, *a, *b),
        (Value::Float(a), Value::Float(b)) => float_arith(op, *a, *b),
        (Value::Int(a), Value::Float(b)) => float_arith(op, *a as f64, *b),
        (Value::Float(a), Value::Int(b)) => float_arith(op, *a, *b as f64),
        _ => Err(EvalError::TypeMismatch {
            op: format!("{op}"),
            lhs: type_name(&l),
            rhs: type_name(&r),
        }),
    }
}

/// Exact integer arithmetic for `arith`; overflow is reported instead of saturated.
fn int_arith(op: BinaryOp, a: i64, b: i64) -> Result<Value, EvalError> {
    let overflow = || EvalError::Arithmetic("integer overflow".to_owned());
    match op {
        BinaryOp::Sub => a.checked_sub(b).map(Value::Int).ok_or_else(overflow),
        BinaryOp::Mul => a.checked_mul(b).map(Value::Int).ok_or_else(overflow),
        BinaryOp::Div => {
            if b == 0 {
                return Err(EvalError::Arithmetic("divide by zero".to_owned()));
            }
            // `wrapping_rem` is 0 for `i64::MIN % -1`. The following `checked_div` then
            // reports that case as an overflow.
            if a.wrapping_rem(b) == 0 {
                a.checked_div(b).map(Value::Int).ok_or_else(overflow)
            } else {
                Ok(Value::Float(a as f64 / b as f64))
            }
        }
        BinaryOp::Mod => {
            if b == 0 {
                return Err(EvalError::Arithmetic("mod by zero".to_owned()));
            }
            // `wrapping_rem` returns the mathematically correct 0 for `i64::MIN % -1`.
            Ok(Value::Int(a.wrapping_rem(b)))
        }
        BinaryOp::Pow if b < 0 => {
            // Negative exponents: compute in floating point and keep integral results
            // (e.g. `1 ^ -3`, `-1 ^ -1`) as Int.
            let out = (a as f64).powf(b as f64);
            if out.is_finite() && out.fract() == 0.0 {
                Ok(Value::Int(out as i64))
            } else {
                Ok(Value::Float(out))
            }
        }
        BinaryOp::Pow if b > i64::from(u32::MAX) => match a {
            // Only these bases stay in range for such huge exponents.
            0 | 1 => Ok(Value::Int(a)),
            -1 => Ok(Value::Int(if b % 2 == 0 { 1 } else { -1 })),
            _ => Err(overflow()),
        },
        // Here 0 <= b <= u32::MAX, so the cast is lossless.
        BinaryOp::Pow => a.checked_pow(b as u32).map(Value::Int).ok_or_else(overflow),
        _ => unreachable!("arith dispatched non-arith op"),
    }
}

/// Floating-point arithmetic for `arith`; always yields `Value::Float`.
fn float_arith(op: BinaryOp, a: f64, b: f64) -> Result<Value, EvalError> {
    let out = match op {
        BinaryOp::Sub => a - b,
        BinaryOp::Mul => a * b,
        BinaryOp::Div => {
            if b == 0.0 {
                return Err(EvalError::Arithmetic("divide by zero".to_owned()));
            }
            a / b
        }
        BinaryOp::Mod => {
            if b == 0.0 {
                return Err(EvalError::Arithmetic("mod by zero".to_owned()));
            }
            a % b
        }
        BinaryOp::Pow => a.powf(b),
        _ => unreachable!("arith dispatched non-arith op"),
    };
    Ok(Value::Float(out))
}
/// Applies a comparison operator (`=`, `<>`, `<`, `<=`, `>`, `>=`).
///
/// Supported pairings:
/// - `Int`/`Float` in any combination. A mixed pair converts the `Int` to `f64`.
/// - `String` with `String`, ordered by `String`'s `Ord`.
/// - `Bool` with `Bool`, where `false < true`.
///
/// # Errors
///
/// - [`EvalError::Arithmetic`] when a float operand is NaN, since no ordering exists.
/// - [`EvalError::TypeMismatch`] for any other operand combination.
fn compare(op: BinaryOp, l: Value, r: Value) -> Result<Value, EvalError> {
    use std::cmp::Ordering;
    // Non-capturing closure, so it is `Copy` and can be reused in every float arm.
    let nan = || EvalError::Arithmetic("NaN comparison".to_owned());
    let ord: Ordering = match (&l, &r) {
        (Value::Int(x), Value::Int(y)) => x.cmp(y),
        (Value::String(x), Value::String(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).ok_or_else(nan)?,
        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y).ok_or_else(nan)?,
        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)).ok_or_else(nan)?,
        _ => {
            return Err(EvalError::TypeMismatch {
                op: format!("{op}"),
                lhs: type_name(&l),
                rhs: type_name(&r),
            });
        }
    };
    let holds = match op {
        BinaryOp::Eq => ord.is_eq(),
        BinaryOp::NotEq => ord.is_ne(),
        BinaryOp::Lt => ord.is_lt(),
        BinaryOp::LtEq => ord.is_le(),
        BinaryOp::Gt => ord.is_gt(),
        BinaryOp::GtEq => ord.is_ge(),
        _ => unreachable!(),
    };
    Ok(Value::Bool(holds))
}
/// Applies `AND`, `OR` or `XOR` to two boolean operands.
///
/// # Errors
///
/// [`EvalError::TypeMismatch`] unless both operands are `Value::Bool`.
fn boolean_op(op: BinaryOp, l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = match (&l, &r) {
        (Value::Bool(a), Value::Bool(b)) => (*a, *b),
        _ => {
            return Err(EvalError::TypeMismatch {
                op: format!("{op}"),
                lhs: type_name(&l),
                rhs: type_name(&r),
            });
        }
    };
    let out = match op {
        BinaryOp::And => lhs && rhs,
        BinaryOp::Or => lhs || rhs,
        // For booleans, "not equal" is exactly exclusive-or.
        BinaryOp::Xor => lhs != rhs,
        _ => unreachable!(),
    };
    Ok(Value::Bool(out))
}
/// Applies `CONTAINS`, `STARTS WITH` or `ENDS WITH` to two strings.
///
/// Matching is case-sensitive substring/prefix/suffix matching on the raw strings.
///
/// # Errors
///
/// [`EvalError::TypeMismatch`] unless both operands are `Value::String`.
fn string_match(op: BinaryOp, l: Value, r: Value) -> Result<Value, EvalError> {
    if let (Value::String(haystack), Value::String(needle)) = (&l, &r) {
        let found = match op {
            BinaryOp::Contains => haystack.contains(needle.as_str()),
            BinaryOp::StartsWith => haystack.starts_with(needle.as_str()),
            BinaryOp::EndsWith => haystack.ends_with(needle.as_str()),
            _ => unreachable!(),
        };
        return Ok(Value::Bool(found));
    }
    Err(EvalError::TypeMismatch {
        op: format!("{op}"),
        lhs: type_name(&l),
        rhs: type_name(&r),
    })
}
/// Applies a unary operator (`-` or `NOT`) to an evaluated operand.
///
/// Both operators propagate null: `-null` and `NOT null` are null. Previously `-null`
/// reported a type mismatch, which was inconsistent with `NOT null`.
///
/// # Errors
///
/// - [`EvalError::Arithmetic`] when negating `i64::MIN`, which has no `i64` negation.
///   Before this fix it panicked in debug builds and wrapped in release builds.
/// - [`EvalError::TypeMismatch`] (with `rhs` set to `"<unary>"`) for other operand types.
fn apply_unary(op: UnaryOp, v: Value) -> Result<Value, EvalError> {
    match (op, v) {
        (UnaryOp::Neg, Value::Int(i)) => i
            .checked_neg()
            .map(Value::Int)
            .ok_or_else(|| EvalError::Arithmetic("integer overflow".to_owned())),
        (UnaryOp::Neg, Value::Float(f)) => Ok(Value::Float(-f)),
        (UnaryOp::Neg, Value::Null) => Ok(Value::Null),
        (UnaryOp::Not, Value::Bool(b)) => Ok(Value::Bool(!b)),
        (UnaryOp::Not, Value::Null) => Ok(Value::Null),
        (op, v) => Err(EvalError::TypeMismatch {
            op: format!("{op}"),
            lhs: type_name(&v),
            rhs: "<unary>",
        }),
    }
}
/// Calls one of the supported scalar functions on already-evaluated arguments.
///
/// Supported functions:
/// - `toString(x)`
/// - `toUpper`/`upper(s)` and `toLower`/`lower(s)`
/// - `trim(s)`
/// - `size`/`length` of a string (in chars) or of a list
/// - `abs` of an `Int` or a `Float`
///
/// Function names are matched case-insensitively, as in Cypher. Apart from `toString`,
/// which stringifies its argument as before, these functions return null when given a
/// single null argument.
///
/// # Errors
///
/// - [`EvalError::Arithmetic`] for `abs(i64::MIN)`, whose result does not fit in `i64`.
///   The previous `unsigned_abs() as i64` silently returned a negative number.
/// - [`EvalError::Unsupported`] for unknown names, wrong arity, or wrong argument types.
///   The message uses the caller's original spelling of the name.
fn apply_function(name: &str, args: &[Value]) -> Result<Value, EvalError> {
    // Dispatch on a case-folded copy so `TOUPPER`, `toUpper` and `toupper` all resolve.
    let folded = name.to_ascii_lowercase();
    match (folded.as_str(), args) {
        ("tostring", [v]) => Ok(Value::String(stringify(v))),
        // Null in, null out — consistent with how the binary operators treat null.
        (
            "upper" | "toupper" | "lower" | "tolower" | "trim" | "length" | "size" | "abs",
            [Value::Null],
        ) => Ok(Value::Null),
        ("upper" | "toupper", [Value::String(s)]) => Ok(Value::String(s.to_uppercase())),
        ("lower" | "tolower", [Value::String(s)]) => Ok(Value::String(s.to_lowercase())),
        ("trim", [Value::String(s)]) => Ok(Value::String(s.trim().to_owned())),
        ("length" | "size", [Value::String(s)]) => Ok(Value::Int(s.chars().count() as i64)),
        ("length" | "size", [Value::List(l)]) => Ok(Value::Int(l.len() as i64)),
        ("abs", [Value::Int(i)]) => i
            .checked_abs()
            .map(Value::Int)
            .ok_or_else(|| EvalError::Arithmetic("integer overflow".to_owned())),
        ("abs", [Value::Float(f)]) => Ok(Value::Float(f.abs())),
        _ => Err(EvalError::Unsupported(format!("function `{name}`"))),
    }
}
/// Evaluates a `CASE` expression in either of its two forms.
///
/// - Simple form (`CASE x WHEN v THEN ...`): `x` is evaluated once. Each `WHEN` value is
///   compared with it via `values_equal`.
/// - Generic form (`CASE WHEN cond THEN ...`): a branch is taken only when its condition
///   evaluates to exactly `true`.
///
/// `WHEN` clauses are evaluated lazily, in order, stopping at the first match. Without a
/// match, the `ELSE` branch is evaluated, or null is returned if there is none.
///
/// # Errors
///
/// Propagates any error from evaluating the scrutinee, a `WHEN` value, or the chosen branch.
fn eval_case(
    scrutinee: Option<&Expr>,
    when_then: &[(Expr, Expr)],
    else_expr: Option<&Expr>,
    params: &HashMap<String, Value>,
) -> Result<Value, EvalError> {
    let subject = scrutinee.map(|e| eval_expr(e, params)).transpose()?;
    for (condition, branch) in when_then {
        let candidate = eval_expr(condition, params)?;
        let hit = match subject.as_ref() {
            Some(s) => values_equal(s, &candidate),
            None => matches!(candidate, Value::Bool(true)),
        };
        if hit {
            return eval_expr(branch, params);
        }
    }
    else_expr.map_or(Ok(Value::Null), |fallback| eval_expr(fallback, params))
}
/// Cypher-style equality used to match the `WHEN` values of a simple `CASE`.
///
/// - Numbers compare across `Int`/`Float` (the `Int` is converted to `f64`).
/// - Strings, bytes and booleans compare by value.
/// - Lists compare element-wise, and maps compare key-by-key, both recursively.
/// - Null never equals anything, not even null. In Cypher `null = null` is not true, so
///   `CASE null WHEN null ...` must not take that branch. Previously it did.
/// - Any other pairing, including mismatched types, is unequal.
fn values_equal(a: &Value, b: &Value) -> bool {
    match (a, b) {
        (Value::Bool(x), Value::Bool(y)) => x == y,
        (Value::Int(x), Value::Int(y)) => x == y,
        (Value::Float(x), Value::Float(y)) => x == y,
        (Value::Int(i), Value::Float(f)) | (Value::Float(f), Value::Int(i)) => *i as f64 == *f,
        (Value::String(x), Value::String(y)) => x == y,
        (Value::Bytes(x), Value::Bytes(y)) => x == y,
        (Value::List(xs), Value::List(ys)) => {
            xs.len() == ys.len() && xs.iter().zip(ys.iter()).all(|(x, y)| values_equal(x, y))
        }
        (Value::Map(xm), Value::Map(ym)) => {
            xm.len() == ym.len()
                && xm
                    .iter()
                    .all(|(k, x)| matches!(ym.get(k), Some(y) if values_equal(x, y)))
        }
        // Includes every pairing with Null.
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::parse_expression;

    /// Parses `src`, binds `params` by name, and evaluates the expression.
    /// Panics with "parse" or "eval" if either step fails.
    fn ev(src: &str, params: &[(&str, Value)]) -> Value {
        let parsed = parse_expression(src).expect("parse");
        let mut bindings: HashMap<String, Value> = HashMap::new();
        for (key, value) in params {
            bindings.insert((*key).to_owned(), value.clone());
        }
        eval_expr(&parsed, &bindings).expect("eval")
    }

    /// Shorthand for building a `Value::String`.
    fn string(text: &str) -> Value {
        Value::String(text.to_owned())
    }

    #[test]
    fn string_concat_with_params() {
        let bindings = [("first", string("Ada")), ("last", string("Lovelace"))];
        assert_eq!(ev("$first + ' ' + $last", &bindings), string("Ada Lovelace"));
    }

    #[test]
    fn integer_arithmetic() {
        let bindings = [("a", Value::Int(3)), ("b", Value::Int(4))];
        assert_eq!(ev("$a * $b + 1", &bindings), Value::Int(13));
    }

    #[test]
    fn boolean_short_circuit_via_eval() {
        // Both comparisons are evaluated before AND combines them.
        assert_eq!(ev("$x > 0 AND $x < 10", &[("x", Value::Int(5))]), Value::Bool(true));
    }

    #[test]
    fn case_when_branch() {
        let src = "CASE WHEN $x > 0 THEN 'pos' WHEN $x < 0 THEN 'neg' ELSE 'zero' END";
        assert_eq!(ev(src, &[("x", Value::Int(-3))]), string("neg"));
    }

    #[test]
    fn unbound_parameter_errors() {
        let parsed = parse_expression("$missing + 1").unwrap();
        match eval_expr(&parsed, &HashMap::new()) {
            Err(EvalError::UnboundParameter(name)) => assert_eq!(name, "missing"),
            other => panic!("expected UnboundParameter, got {other:?}"),
        }
    }

    #[test]
    fn null_propagates_through_arithmetic() {
        assert_eq!(ev("$x + 1", &[("x", Value::Null)]), Value::Null);
    }

    #[test]
    fn toupper_function() {
        assert_eq!(ev("toUpper($s)", &[("s", string("hello"))]), string("HELLO"));
    }
}