use syn::Expr;
use std::collections::BTreeMap;
mod function;
mod operator;
mod visit;
use self::{function::Fun, operator::Operator};
/// Numeric type produced by evaluation (`f32` unless the `double` feature is enabled).
#[cfg(not(feature = "double"))]
pub type Value = f32;
/// Numeric type produced by evaluation (`f64` because the `double` feature is enabled).
#[cfg(feature = "double")]
pub type Value = f64;
/// Numerically evaluates `expr`, looking identifiers up in `ctx`.
///
/// Each entry of `ctx` binds a name to another expression that is used in its place.
/// Returns `None` when the expression cannot be evaluated.
pub fn eval(ctx: &BTreeMap<&str, &syn::Expr>, expr: &Expr) -> Option<Value> {
    let reflect = Reflect::new(ctx);
    reflect.eval(expr)
}
/// One token of the postfix (reverse Polish) sequence consumed by `evaluate`.
#[derive(Debug)]
enum Output {
    /// A numeric operand, pushed straight onto the evaluation stack.
    V(Value),
    /// An arithmetic operator that pops its operands from the evaluation stack.
    Op(Operator),
    /// A numeric method (e.g. `sqrt`, `powi`) applied to values on the stack.
    Fn(Fun),
}
/// Conversion state that turns a `syn::Expr` into a postfix `Output` sequence.
struct Reflect<'a> {
    /// Expressions bound to identifier names.
    ctx: &'a BTreeMap<&'a str, &'a syn::Expr>,
    /// Operators (and open parentheses) still waiting to be emitted.
    operators: Vec<Operator>,
    /// Postfix tokens emitted so far.
    output: Vec<Output>,
    /// Set once a problem is detected; evaluation then yields `None`.
    on_err: bool,
}
impl<'a> Reflect<'a> {
    /// Creates an empty converter whose identifier bindings come from `ctx`.
    fn new<'n>(ctx: &'n BTreeMap<&'n str, &'n syn::Expr>) -> Reflect<'n> {
        Reflect {
            ctx,
            operators: vec![],
            output: vec![],
            on_err: false,
        }
    }

    /// Visits `e`, flushes the pending operators into the postfix output and
    /// evaluates the result.
    ///
    /// Returns `None` in three cases:
    /// * the visitor flagged an error (`on_err`);
    /// * an opening parenthesis was never closed;
    /// * `evaluate` rejects the postfix sequence.
    #[inline]
    fn eval(mut self, e: &'a Expr) -> Option<Value> {
        self.visit_expr(e);
        // A `(` still on the stack has no matching `)`. Flushing it into the
        // output would hand `evaluate` a token it has no rule for, so reject
        // the expression here rather than letting it reach that point.
        if self.on_err || self.operators.contains(&Operator::ParenLeft) {
            return None;
        }
        // Remaining operators are emitted top-of-stack first.
        let pending = self.operators.drain(..).rev().map(Output::Op);
        self.output.extend(pending);
        evaluate(self.output).ok()
    }

    /// Feeds `op` through one step of the shunting-yard conversion.
    ///
    /// * `)` pops operators into the output until the matching `(`, which is
    ///   discarded. If no `(` is found, `on_err` is set.
    /// * Any other operator for which `ParenLeft.eq_preference` holds
    ///   (presumably just `(`) is pushed unchanged.
    /// * Every other operator first moves stacked operators into the output,
    ///   stopping at the nearest `(`. It moves each one that is either
    ///   `Operator::Fn` or has `ge_preference` relative to `op`, and then
    ///   pushes `op`. Popping on "greater or equal" makes all such operators
    ///   left-associative.
    pub(self) fn push_op(&mut self, op: Operator) {
        if !Operator::ParenLeft.eq_preference(&op) {
            while let Some(top) = self.operators.last() {
                let reduces = *top != Operator::ParenLeft
                    && (*top == Operator::Fn || top.ge_preference(&op));
                if !reduces {
                    break;
                }
                if let Some(top) = self.operators.pop() {
                    self.output.push(Output::Op(top));
                }
            }
            self.operators.push(op);
        } else if op == Operator::ParenRight {
            while let Some(top) = self.operators.pop() {
                if top == Operator::ParenLeft {
                    return;
                }
                self.output.push(Output::Op(top));
            }
            // The stack ran dry without reaching the matching `(`.
            self.on_err = true;
        } else {
            self.operators.push(op);
        }
    }
}
/// Evaluates a postfix (reverse Polish) token sequence.
///
/// Each `Output::V` is pushed onto a value stack. Each operator or function
/// pops its operands and pushes its result; the right-hand operand is popped
/// first.
///
/// # Errors
///
/// Returns `Err(())` in three cases:
/// * an operator or function finds too few operands on the stack;
/// * the sequence does not reduce to exactly one value;
/// * the sequence contains an operator with no arithmetic rule here, i.e.
///   anything other than `Add`, `Sub`, `Mul`, `Div`, `Rem` and `Neg`
///   (for example a stray parenthesis).
#[inline]
fn evaluate(output: Vec<Output>) -> Result<Value, ()> {
    let mut stack: Vec<Value> = Vec::with_capacity(output.len());
    for o in output {
        match o {
            Output::V(v) => stack.push(v),
            Output::Fn(method) => {
                // Two-operand method: `op1.method(op2)`.
                macro_rules! fun_arg {
                    ($m:ident) => {{
                        let op2 = stack.pop().ok_or(())?;
                        let op1 = stack.pop().ok_or(())?;
                        op1.$m(op2)
                    }};
                }
                // One-operand method: `op1.method()`.
                macro_rules! fun {
                    ($m:ident) => {{
                        let op1 = stack.pop().ok_or(())?;
                        op1.$m()
                    }};
                }
                use self::function::Fun::*;
                let e = match method {
                    Atan2 => fun_arg!(atan2),
                    Hypot => fun_arg!(hypot),
                    Log => fun_arg!(log),
                    Max => fun_arg!(max),
                    Min => fun_arg!(min),
                    PowF => fun_arg!(powf),
                    PowI => {
                        let op2 = stack.pop().ok_or(())?;
                        let op1 = stack.pop().ok_or(())?;
                        // The float-to-int `as` cast saturates out-of-range
                        // exponents and maps NaN to 0.
                        op1.powi(op2 as i32)
                    }
                    Abs => fun!(abs),
                    Acos => fun!(acos),
                    Acosh => fun!(acosh),
                    Asin => fun!(asin),
                    Asinh => fun!(asinh),
                    Atan => fun!(atan),
                    Atanh => fun!(atanh),
                    Cbrt => fun!(cbrt),
                    Ceil => fun!(ceil),
                    Cos => fun!(cos),
                    Cosh => fun!(cosh),
                    Exp => fun!(exp),
                    Exp2 => fun!(exp2),
                    ExpM1 => fun!(exp_m1),
                    Floor => fun!(floor),
                    Fract => fun!(fract),
                    Ln => fun!(ln),
                    Ln1p => fun!(ln_1p),
                    Log10 => fun!(log10),
                    Log2 => fun!(log2),
                    Recip => fun!(recip),
                    Round => fun!(round),
                    Signum => fun!(signum),
                    Sin => fun!(sin),
                    Sinh => fun!(sinh),
                    Sqrt => fun!(sqrt),
                    Tan => fun!(tan),
                    Tanh => fun!(tanh),
                    ToDegrees => fun!(to_degrees),
                    ToRadians => fun!(to_radians),
                    Trunc => fun!(trunc),
                };
                stack.push(e);
            }
            Output::Op(ref op) => {
                use Operator::*;
                // Binary infix operator: `op1 <op> op2`.
                macro_rules! two {
                    ($op:tt) => {{
                        let op2 = stack.pop().ok_or(())?;
                        let op1 = stack.pop().ok_or(())?;
                        op1 $op op2
                    }};
                }
                let e = match op {
                    Add => two!(+),
                    Sub => two!(-),
                    Mul => two!(*),
                    Div => two!(/),
                    Rem => two!(%),
                    Neg => {
                        let op1 = stack.pop().ok_or(())?;
                        -op1
                    }
                    // Parentheses and other non-arithmetic markers must not
                    // appear in a well-formed postfix sequence. Report the
                    // sequence as malformed instead of panicking.
                    _ => return Err(()),
                };
                stack.push(e);
            }
        }
    }
    // A well-formed expression leaves exactly one value behind.
    match stack.as_slice() {
        [result] => Ok(*result),
        _ => Err(()),
    }
}
#[cfg(test)]
mod test {
    use syn::parse_str;
    use super::operator::Operator::*;
    use super::Output::*;
    use super::*;

    #[test]
    fn test_evaluate_add() {
        let f = vec![V(1.0), V(1.0), Op(Add)];
        assert_eq!(evaluate(f).unwrap(), 2.0);
    }
    #[test]
    fn test_evaluate_sub() {
        let f = vec![V(1.0), V(1.0), Op(Sub)];
        assert_eq!(evaluate(f).unwrap(), 0.0);
    }
    #[test]
    fn test_evaluate_mul() {
        let f = vec![V(1.0), V(1.0), Op(Mul)];
        assert_eq!(evaluate(f).unwrap(), 1.0);
    }
    #[test]
    fn test_evaluate_div() {
        let f = vec![V(1.0), V(1.0), Op(Div)];
        assert_eq!(evaluate(f).unwrap(), 1.0);
    }
    #[test]
    fn test_evaluate_rem() {
        let f = vec![V(4.0), V(2.0), Op(Rem)];
        assert_eq!(evaluate(f).unwrap(), 0.0);
    }
    /// Non-commutative operators must use the first-pushed value as the left operand.
    #[test]
    fn test_evaluate_operand_order() {
        let f = vec![V(7.0), V(2.0), Op(Sub)];
        assert_eq!(evaluate(f).unwrap(), 5.0);
        let f = vec![V(1.0), V(2.0), Op(Div)];
        assert_eq!(evaluate(f).unwrap(), 0.5);
    }
    #[test]
    fn test_evaluate_neg() {
        let f = vec![V(2.0), Op(Neg)];
        assert_eq!(evaluate(f).unwrap(), -2.0);
    }
    /// Function tokens apply to the values already on the stack.
    #[test]
    fn test_evaluate_fn() {
        let f = vec![V(9.0), Output::Fn(Fun::Sqrt)];
        assert_eq!(evaluate(f).unwrap(), 3.0);
        let f = vec![V(2.0), V(3.0), Output::Fn(Fun::PowI)];
        assert_eq!(evaluate(f).unwrap(), 8.0);
    }
    /// Malformed postfix sequences are reported as errors, not panics.
    #[test]
    fn test_evaluate_malformed() {
        assert!(evaluate(vec![]).is_err());
        assert!(evaluate(vec![Op(Add)]).is_err());
        assert!(evaluate(vec![V(1.0), Op(Add)]).is_err());
        assert!(evaluate(vec![V(1.0), V(2.0)]).is_err());
        assert!(evaluate(vec![Output::Fn(Fun::Sqrt)]).is_err());
    }
    #[test]
    fn test_eval_literal() {
        let src = "-1";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), -1.0);
        let src = "-1.0";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), -1.0);
    }
    #[test]
    fn test_eval_one() {
        let src = "1 + 1";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), 2.0);
    }
    #[test]
    fn test_eval() {
        let src = "1 + 1 - 6 % 5";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), 1.0);
        let src = "1 + 1 - 10 / 5";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), 0.0);
        let src = "foo + (1 * 1 - 10 / 5)";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let mut ctx = BTreeMap::new();
        let arg = parse_str::<syn::Expr>("-1").unwrap();
        ctx.insert("foo", &arg);
        assert_eq!(eval(&ctx, &e).unwrap(), -2.0);
        let src = "(foo * 2) + 1";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let mut ctx = BTreeMap::new();
        let arg = parse_str::<syn::Expr>("1 + -1 + -1 + 1").unwrap();
        ctx.insert("foo", &arg);
        assert_eq!(eval(&ctx, &e).unwrap(), 1.0);
    }
    #[test]
    fn test_eval_fn() {
        let src = "4.sqrt()";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), 2.0);
        let src = "2.powi(2)";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), 4.0);
        let src = "2.5.powi(2)";
        let e = parse_str::<syn::Expr>(src).unwrap();
        let ctx = BTreeMap::new();
        assert_eq!(eval(&ctx, &e).unwrap(), 6.25);
    }
}