use jexl_parser::{
ast::{Expression, OpCode},
Parser,
};
use serde_json::{json as value, Value};
pub mod error;
use error::*;
use std::collections::HashMap;
/// Absolute tolerance used when two numbers are compared with `==` or `!=`:
/// the comparison looks at the absolute difference of the operands relative
/// to this threshold instead of requiring exact floating-point equality.
const EPSILON: f64 = 0.000001f64;
/// Boolean coercion of a value, used wherever an expression result has to be
/// treated as a condition (conditionals, `&&` / `||`, filters).
trait Truthy {
/// Returns `true` when the value counts as "true" in a boolean context.
fn is_truthy(&self) -> bool;
/// Returns the negation of `is_truthy`.
fn is_falsey(&self) -> bool {
!self.is_truthy()
}
}
/// Truthiness rules for JSON values.
impl Truthy for Value {
    /// Returns whether this value counts as "true" in a boolean context.
    ///
    /// * `null` is falsey.
    /// * Booleans are themselves.
    /// * Numbers are truthy unless they are zero.
    /// * Strings are truthy unless they are empty.
    /// * Arrays and objects are always truthy, even when empty.
    fn is_truthy(&self) -> bool {
        match self {
            Value::Bool(b) => *b,
            // Missing identifiers, properties and indices evaluate to `null`,
            // so `null` has to be falsey for `missing ? a : b`, `missing || x`
            // and filters to behave sensibly.
            Value::Null => false,
            // NOTE(review): `as_f64` is expected to be `Some` for every number
            // serde_json produces by default — confirm if `arbitrary_precision`
            // is ever enabled.
            Value::Number(f) => f.as_f64().unwrap() != 0.0,
            Value::String(s) => !s.is_empty(),
            Value::Array(_) | Value::Object(_) => true,
        }
    }
}
/// JSON value in which the identifiers of an expression are looked up.
type Context = Value;
/// A boxed transform callback. It is called with the evaluated subject as the
/// first element of the slice, followed by any evaluated arguments, and
/// returns the transformed value or an error.
pub type TransformFn<'a> = Box<dyn Fn(&[Value]) -> Result<Value, anyhow::Error> + 'a>;
/// Evaluates JEXL expressions, optionally extended with named transforms
/// registered through `with_transform`.
#[derive(Default)]
pub struct Evaluator<'a> {
// Transform callbacks keyed by the name used after `|` in an expression.
transforms: HashMap<String, TransformFn<'a>>,
}
impl<'a> Evaluator<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn with_transform<F>(mut self, name: &str, transform: F) -> Self
where
F: Fn(&[Value]) -> Result<Value, anyhow::Error> + 'a,
{
self.transforms
.insert(name.to_string(), Box::new(transform));
self
}
pub fn eval<'b>(&self, input: &'b str) -> Result<'b, Value> {
let context = value!({});
self.eval_in_context(input, &context)
}
pub fn eval_in_context<'b, T: serde::Serialize>(
&self,
input: &'b str,
context: T,
) -> Result<'b, Value> {
let tree = Parser::parse(input)?;
let context = serde_json::to_value(context)?;
if !context.is_object() {
return Err(EvaluationError::InvalidContext);
}
self.eval_ast(tree, &context)
}
fn eval_ast<'b>(&self, ast: Expression, context: &Context) -> Result<'b, Value> {
match ast {
Expression::Number(n) => Ok(value!(n)),
Expression::Boolean(b) => Ok(value!(b)),
Expression::String(s) => Ok(value!(s)),
Expression::Array(xs) => xs.into_iter().map(|x| self.eval_ast(*x, context)).collect(),
Expression::Object(items) => {
let mut map = serde_json::Map::with_capacity(items.len());
for (key, expr) in items.into_iter() {
if map.contains_key(&key) {
return Err(EvaluationError::DuplicateObjectKey(key));
}
let value = self.eval_ast(*expr, context)?;
map.insert(key, value);
}
Ok(Value::Object(map))
}
Expression::Identifier(inner) => {
Ok(context.get(&inner).unwrap_or(&value!(null)).clone())
}
Expression::DotOperation { subject, ident } => {
let subject = self.eval_ast(*subject, context)?;
Ok(subject.get(&ident).unwrap_or(&value!(null)).clone())
}
Expression::IndexOperation { subject, index } => {
let subject = self.eval_ast(*subject, context)?;
if let Expression::Filter { ident, op, right } = *index {
let subject_arr = subject.as_array().ok_or(EvaluationError::InvalidFilter)?;
let right = self.eval_ast(*right, context)?;
let filtered = subject_arr
.iter()
.filter(|e| {
let left = e.get(&ident).unwrap_or(&value!(null));
Self::apply_op(op, left.clone(), right.clone())
.unwrap_or(value!(false))
.is_truthy()
})
.collect::<Vec<_>>();
return Ok(value!(filtered));
}
let index = self.eval_ast(*index, context)?;
match index {
Value::String(inner) => {
Ok(subject.get(&inner).unwrap_or(&value!(null)).clone())
}
Value::Number(inner) => Ok(subject
.get(inner.as_f64().unwrap().floor() as usize)
.unwrap_or(&value!(null))
.clone()),
_ => Err(EvaluationError::InvalidIndexType),
}
}
Expression::BinaryOperation {
left,
right,
operation,
} => {
let left = self.eval_ast(*left, context)?;
let right = self.eval_ast(*right, context)?;
Self::apply_op(operation, left, right)
}
Expression::Transform {
name,
subject,
args,
} => {
let subject = self.eval_ast(*subject, context)?;
let mut args_arr = Vec::new();
args_arr.push(subject);
if let Some(args) = args {
for arg in args {
args_arr.push(self.eval_ast(*arg, context)?);
}
}
let f = self
.transforms
.get(&name)
.ok_or(EvaluationError::UnknownTransform(name))?;
f(&args_arr).map_err(|e| e.into())
}
Expression::Conditional {
left,
truthy,
falsy,
} => {
if self.eval_ast(*left, context)?.is_truthy() {
self.eval_ast(*truthy, context)
} else {
self.eval_ast(*falsy, context)
}
}
Expression::Filter {
ident: _,
op: _,
right: _,
} => {
return Err(EvaluationError::InvalidFilter);
}
}
}
fn apply_op<'b>(operation: OpCode, left: Value, right: Value) -> Result<'b, Value> {
match (operation, left, right) {
(OpCode::And, a, b) => Ok(if a.is_truthy() { b } else { a }),
(OpCode::Or, a, b) => Ok(if a.is_truthy() { a } else { b }),
(op, Value::Number(a), Value::Number(b)) => {
let left = a.as_f64().unwrap();
let right = b.as_f64().unwrap();
Ok(match op {
OpCode::Add => value!(left + right),
OpCode::Subtract => value!(left - right),
OpCode::Multiply => value!(left * right),
OpCode::Divide => value!(left / right),
OpCode::FloorDivide => value!((left / right).floor()),
OpCode::Modulus => value!(left % right),
OpCode::Exponent => value!(left.powf(right)),
OpCode::Less => value!(left < right),
OpCode::Greater => value!(left > right),
OpCode::LessEqual => value!(left <= right),
OpCode::GreaterEqual => value!(left >= right),
OpCode::Equal => value!((left - right).abs() < EPSILON),
OpCode::NotEqual => value!((left - right).abs() > EPSILON),
OpCode::In => value!(false),
OpCode::And | OpCode::Or => {
unreachable!("Covered by previous case in parent match")
}
})
}
(OpCode::Add, Value::String(a), Value::String(b)) => Ok(value!(format!("{}{}", a, b))),
(OpCode::In, Value::String(a), Value::String(b)) => Ok(value!(b.contains(&a))),
(OpCode::In, left, Value::Array(v)) => Ok(value!(v.contains(&left))),
(OpCode::Equal, Value::String(a), Value::String(b)) => Ok(value!(a == b)),
(operation, left, right) => Err(EvaluationError::InvalidBinaryOp {
operation,
left,
right,
}),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json as value;

    /// Evaluates `expr` with a fresh evaluator and an empty context,
    /// panicking if evaluation fails.
    #[track_caller]
    fn run(expr: &str) -> Value {
        Evaluator::new().eval(expr).unwrap()
    }

    /// Evaluates `expr` against `context` with a fresh evaluator, panicking
    /// if evaluation fails.
    #[track_caller]
    fn run_in(expr: &str, context: &Value) -> Value {
        Evaluator::new().eval_in_context(expr, context).unwrap()
    }

    /// Context shared by the array filter and array index tests.
    fn foo_bar_context() -> Value {
        value!({
            "foo": {
                "bar": [
                    {"tek": "hello"},
                    {"tek": "baz"},
                    {"tok": "baz"},
                ]
            }
        })
    }

    #[test]
    fn test_literal() {
        assert_eq!(run("1"), value!(1.0));
    }

    #[test]
    fn test_binary_expression_addition() {
        assert_eq!(run("1 + 2"), value!(3.0));
    }

    #[test]
    fn test_binary_expression_multiplication() {
        assert_eq!(run("2 * 3"), value!(6.0));
    }

    #[test]
    fn test_precedence() {
        assert_eq!(run("2 + 3 * 4"), value!(14.0));
    }

    #[test]
    fn test_parenthesis() {
        assert_eq!(run("(2 + 3) * 4"), value!(20.0));
    }

    #[test]
    fn test_string_concat() {
        assert_eq!(run("'Hello ' + 'World'"), value!("Hello World"));
    }

    #[test]
    fn test_true_comparison() {
        assert_eq!(run("2 > 1"), value!(true));
    }

    #[test]
    fn test_false_comparison() {
        assert_eq!(run("2 <= 1"), value!(false));
    }

    #[test]
    fn test_boolean_logic() {
        let result = run("'foo' && 6 >= 6 && 0 + 1 && true");
        assert_eq!(result, value!(true));
    }

    #[test]
    fn test_identifier() {
        let context = value!({"a": 1.0});
        assert_eq!(run_in("a", &context), value!(1.0));
    }

    #[test]
    fn test_identifier_chain() {
        let context = value!({"a": {"b": 2.0}});
        assert_eq!(run_in("a.b", &context), value!(2.0));
    }

    #[test]
    fn test_context_filter_arrays() {
        let context = foo_bar_context();
        let matching = run_in("foo.bar[.tek == 'baz']", &context);
        assert_eq!(matching, value!([{"tek": "baz"}]));
    }

    #[test]
    fn test_context_array_index() {
        let context = foo_bar_context();
        let element = run_in("foo.bar[1].tek", &context);
        assert_eq!(element, value!("baz"));
    }

    #[test]
    fn test_object_expression_properties() {
        let context = value!({"foo": {"baz": {"bar": "tek"}}});
        let property = run_in("foo['ba' + 'z'].bar", &context);
        assert_eq!(property, value!("tek"));
    }

    #[test]
    fn test_divfloor() {
        assert_eq!(run("7 // 2"), value!(3.0));
    }

    #[test]
    fn test_empty_object_literal() {
        assert_eq!(run("{}"), value!({}));
    }

    #[test]
    fn test_object_literal_strings() {
        let object = run("{'foo': {'bar': 'tek'}}");
        assert_eq!(object, value!({"foo": {"bar": "tek"}}));
    }

    #[test]
    fn test_object_literal_identifiers() {
        let object = run("{foo: {bar: 'tek'}}");
        assert_eq!(object, value!({"foo": {"bar": "tek"}}));
    }

    #[test]
    fn test_object_literal_properties() {
        assert_eq!(run("{foo: 'bar'}.foo"), value!("bar"));
    }

    #[test]
    fn test_array_literal() {
        assert_eq!(run("['foo', 1+2]"), value!(["foo", 3.0]));
    }

    #[test]
    fn test_array_literal_indexing() {
        assert_eq!(run("[1, 2, 3][1]"), value!(2.0));
    }

    #[test]
    fn test_in_operator_string() {
        assert_eq!(run("'bar' in 'foobartek'"), value!(true));
        assert_eq!(run("'baz' in 'foobartek'"), value!(false));
    }

    #[test]
    fn test_in_operator_array() {
        assert_eq!(run("'bar' in ['foo', 'bar', 'tek']"), value!(true));
        assert_eq!(run("'baz' in ['foo', 'bar', 'tek']"), value!(false));
    }

    #[test]
    fn test_conditional_expression() {
        assert_eq!(run("'foo' ? 1 : 2"), value!(1f64));
        assert_eq!(run("'' ? 1 : 2"), value!(2f64));
    }

    #[test]
    fn test_arbitrary_whitespace() {
        assert_eq!(run("(\t2\n+\n3) *\n4\n\r\n"), value!(20.0));
    }

    #[test]
    fn test_non_integer() {
        assert_eq!(run("1.5 * 3.0"), value!(4.5));
    }

    #[test]
    fn test_string_literal() {
        assert_eq!(run("'hello world'"), value!("hello world"));
        assert_eq!(run("\"hello world\""), value!("hello world"));
    }

    #[test]
    fn test_string_escapes() {
        assert_eq!(run("'a\\'b'"), value!("a'b"));
        assert_eq!(run("\"a\\\"b\""), value!("a\"b"));
    }

    #[test]
    fn test_simple_transform() {
        let evaluator = Evaluator::new().with_transform("lower", |args: &[Value]| {
            let text = args
                .first()
                .expect("There should be one argument!")
                .as_str()
                .expect("Should be a string!");
            Ok(value!(text.to_lowercase()))
        });
        let lowered = evaluator.eval("'T_T'|lower").unwrap();
        assert_eq!(lowered, value!("t_t"));
    }

    #[test]
    fn test_missing_transform() {
        match Evaluator::new().eval("'hello'|world").unwrap_err() {
            EvaluationError::UnknownTransform(transform) => assert_eq!(transform, "world"),
            _ => panic!("Should have thrown an unknown transform error"),
        }
    }

    #[test]
    fn test_add_multiple_transforms() {
        // Both transforms read their single numeric argument the same way.
        fn first_number(args: &[Value]) -> f64 {
            args.first()
                .expect("There should be one argument!")
                .as_f64()
                .expect("Should be a valid number!")
        }

        let evaluator = Evaluator::new()
            .with_transform("sqrt", |args: &[Value]| {
                let root = first_number(args).sqrt() as u64;
                Ok(value!(root))
            })
            .with_transform("square", |args: &[Value]| {
                let squared = (first_number(args) as u64).pow(2);
                Ok(value!(squared))
            });

        assert_eq!(evaluator.eval("4|square").unwrap(), value!(16));
        assert_eq!(evaluator.eval("4|sqrt").unwrap(), value!(2));
        assert_eq!(evaluator.eval("4|square|sqrt").unwrap(), value!(4));
    }

    #[test]
    fn test_transform_with_argument() {
        let evaluator = Evaluator::new().with_transform("split", |args: &[Value]| {
            let subject = args
                .first()
                .expect("Should be a first argument!")
                .as_str()
                .expect("Should be a string!");
            let separator = args
                .get(1)
                .expect("There should be a second argument!")
                .as_str()
                .expect("Should be a string");
            let parts: Vec<&str> = subject.split_terminator(separator).collect();
            Ok(value!(parts))
        });
        let pieces = evaluator.eval("'John Doe'|split(' ')").unwrap();
        assert_eq!(pieces, value!(["John", "Doe"]));
    }

    /// Error type returned by the failing transform in the test below.
    #[derive(Debug, thiserror::Error)]
    enum CustomError {
        #[error("Invalid argument in transform!")]
        InvalidArgument,
    }

    #[test]
    fn test_custom_error_message() {
        let evaluator = Evaluator::new().with_transform("error", |_: &[Value]| {
            Err(CustomError::InvalidArgument.into())
        });
        let res = evaluator.eval("1234|error");
        assert!(res.is_err());
        match res.unwrap_err() {
            EvaluationError::CustomError(e) => {
                assert_eq!(e.to_string(), "Invalid argument in transform!")
            }
            _ => panic!("Should have returned a Custom error!"),
        }
    }

    #[test]
    fn test_filter_collections_many_returned() {
        let context = value!({
            "foo": [
                {"bobo": 50, "fofo": 100},
                {"bobo": 60, "baz": 90},
                {"bobo": 10, "bar": 83},
                {"bobo": 20, "yam": 12},
            ]
        });
        let kept = run_in("foo[.bobo >= 50]", &context);
        assert_eq!(
            kept,
            value!([{"bobo": 50, "fofo": 100}, {"bobo": 60, "baz": 90}])
        );
    }
}