//! mathew 0.0.2
//!
//! Mathematical expression evaluator with context.
use syn::Expr;

use std::collections::BTreeMap;

mod function;
mod operator;
mod visit;

use self::{function::Fun, operator::Operator};

/// Numeric type produced by the evaluator.
///
/// This is `f32` unless the `double` feature is enabled.
#[cfg(not(feature = "double"))]
pub type Value = f32;

/// Numeric type produced by the evaluator.
///
/// The `double` feature is enabled, so this is `f64`.
#[cfg(feature = "double")]
pub type Value = f64;

/// Evaluates the arithmetic expression `expr`.
///
/// Identifiers appearing in `expr` are resolved through `ctx`, which maps a
/// variable name to the expression it stands for.
///
/// Returns `None` when the expression cannot be evaluated.
pub fn eval(ctx: &BTreeMap<&str, &syn::Expr>, expr: &Expr) -> Option<Value> {
    let reflect = Reflect::new(ctx);
    reflect.eval(expr)
}

/// One item of the postfix (reverse Polish) sequence consumed by `evaluate`.
#[derive(Debug)]
enum Output {
    /// An arithmetic operator applied to operands already on the evaluation stack.
    Op(Operator),
    /// A numeric operand.
    V(Value),
    /// A math function (e.g. `sqrt`, `powi`) applied to operands on the evaluation stack.
    Fn(Fun),
}

/// Shunting-yard state used to turn an expression into postfix form.
struct Reflect<'a> {
    /// Variable bindings: each name maps to the expression it stands for.
    ctx: &'a BTreeMap<&'a str, &'a syn::Expr>,
    /// Pending operators, including `(` markers, not yet moved to `output`.
    operators: Vec<Operator>,
    /// Postfix sequence handed to `evaluate` once the expression is visited.
    output: Vec<Output>,
    /// Set when an error is detected (e.g. a `)` with no matching `(`);
    /// evaluation then yields `None`.
    on_err: bool,
}

impl<'a> Reflect<'a> {
    /// Creates an evaluator with empty operator and output stacks that
    /// resolves identifiers through `ctx`.
    fn new<'n>(ctx: &'n BTreeMap<&'n str, &'n syn::Expr>) -> Reflect<'n> {
        Reflect {
            ctx,
            operators: Vec::new(),
            output: Vec::new(),
            on_err: false,
        }
    }

    /// Converts `e` into postfix form and evaluates it.
    ///
    /// Returns `None` in three cases:
    /// - visiting flagged an error (`on_err`);
    /// - an unmatched `(` is still on the operator stack;
    /// - the resulting postfix sequence cannot be evaluated.
    #[inline]
    fn eval(mut self, e: &'a Expr) -> Option<Value> {
        self.visit_expr(e);

        // A '(' left on the stack would otherwise be flushed into the output,
        // where it has no arithmetic meaning.
        if self.on_err || self.operators.contains(&Operator::ParenLeft) {
            return None;
        }

        // Flush the remaining operators, most recently pushed first.
        self.output
            .extend(self.operators.drain(..).rev().map(Output::Op));
        evaluate(self.output).ok()
    }

    /// Pushes `op` following shunting-yard rules.
    ///
    /// - `(` and other operators with the same preference as `(` (other than
    ///   `)`) go straight onto the operator stack.
    /// - `)` moves operators to the output until the matching `(` is found,
    ///   and discards the `(`. If the stack runs out first, `on_err` is set.
    /// - Any other operator first moves every stacked operator that is not
    ///   `(` and is either `Operator::Fn` or has preference >= `op` to the
    ///   output. Then `op` is pushed.
    pub(self) fn push_op(&mut self, op: Operator) {
        if Operator::ParenLeft.eq_preference(&op) {
            if op == Operator::ParenRight {
                loop {
                    match self.operators.pop() {
                        // Matching '(' found: it is consumed, not emitted.
                        Some(top) if top == Operator::ParenLeft => break,
                        Some(top) => self.output.push(Output::Op(top)),
                        // Ran out of operators: the ')' is unbalanced.
                        None => {
                            self.on_err = true;
                            break;
                        }
                    }
                }
            } else {
                // Fn and '('
                self.operators.push(op);
            }
            return;
        }

        while let Some(top) = self.operators.last() {
            let top_goes_first = *top != Operator::ParenLeft
                && (*top == Operator::Fn || top.ge_preference(&op));
            if !top_goes_first {
                break;
            }
            // `last()` just returned Some, so this pop always succeeds.
            if let Some(popped) = self.operators.pop() {
                self.output.push(Output::Op(popped));
            }
        }
        self.operators.push(op);
    }
}

/// Evaluates a postfix (reverse Polish) sequence.
///
/// Values are pushed on a stack. Each operator or function pops its operands
/// (right operand first) and pushes its result.
///
/// # Errors
///
/// Returns `Err(())` in three cases:
/// - an operator or function finds too few operands on the stack;
/// - the sequence contains an operator other than
///   `Add`/`Sub`/`Mul`/`Div`/`Rem`/`Neg`;
/// - the sequence does not reduce to exactly one value.
#[inline]
fn evaluate(output: Vec<Output>) -> Result<Value, ()> {
    let mut stack: Vec<Value> = Vec::new();

    // Pops the top operand, returning `Err(())` on stack underflow.
    macro_rules! pop {
        () => {
            stack.pop().ok_or(())?
        };
    }

    // `lhs.method(rhs)` for two-argument float methods; the receiver was pushed first.
    macro_rules! binary_fn {
        ($m:ident) => {{
            let rhs = pop!();
            let lhs = pop!();
            lhs.$m(rhs)
        }};
    }

    // `arg.method()` for one-argument float methods.
    macro_rules! unary_fn {
        ($m:ident) => {{
            let arg = pop!();
            arg.$m()
        }};
    }

    // `lhs <op> rhs` for infix arithmetic operators.
    macro_rules! binary_op {
        ($op:tt) => {{
            let rhs = pop!();
            let lhs = pop!();
            lhs $op rhs
        }};
    }

    for item in output {
        let value = match item {
            Output::V(v) => v,
            Output::Fn(method) => {
                use self::function::Fun::*;
                match method {
                    Atan2 => binary_fn!(atan2),
                    Hypot => binary_fn!(hypot),
                    Log => binary_fn!(log),
                    Max => binary_fn!(max),
                    Min => binary_fn!(min),
                    PowF => binary_fn!(powf),
                    PowI => {
                        let rhs = pop!();
                        let lhs = pop!();
                        // Float -> i32 `as` truncates toward zero, saturates at
                        // the i32 bounds and maps NaN to 0.
                        lhs.powi(rhs as i32)
                    }
                    Abs => unary_fn!(abs),
                    Acos => unary_fn!(acos),
                    Acosh => unary_fn!(acosh),
                    Asin => unary_fn!(asin),
                    Asinh => unary_fn!(asinh),
                    Atan => unary_fn!(atan),
                    Atanh => unary_fn!(atanh),
                    Cbrt => unary_fn!(cbrt),
                    Ceil => unary_fn!(ceil),
                    Cos => unary_fn!(cos),
                    Cosh => unary_fn!(cosh),
                    Exp => unary_fn!(exp),
                    Exp2 => unary_fn!(exp2),
                    ExpM1 => unary_fn!(exp_m1),
                    Floor => unary_fn!(floor),
                    Fract => unary_fn!(fract),
                    Ln => unary_fn!(ln),
                    Ln1p => unary_fn!(ln_1p),
                    Log10 => unary_fn!(log10),
                    Log2 => unary_fn!(log2),
                    Recip => unary_fn!(recip),
                    Round => unary_fn!(round),
                    Signum => unary_fn!(signum),
                    Sin => unary_fn!(sin),
                    Sinh => unary_fn!(sinh),
                    Sqrt => unary_fn!(sqrt),
                    Tan => unary_fn!(tan),
                    Tanh => unary_fn!(tanh),
                    ToDegrees => unary_fn!(to_degrees),
                    ToRadians => unary_fn!(to_radians),
                    Trunc => unary_fn!(trunc),
                }
            }
            Output::Op(op) => {
                use Operator::*;
                match op {
                    Add => binary_op!(+),
                    Sub => binary_op!(-),
                    Mul => binary_op!(*),
                    Div => binary_op!(/),
                    Rem => binary_op!(%),
                    Neg => {
                        let arg = pop!();
                        -arg
                    }
                    // Parenthesis or function markers reaching this point mean
                    // the postfix sequence is malformed. Report it instead of
                    // panicking.
                    _ => return Err(()),
                }
            }
        };
        stack.push(value);
    }

    // A well-formed sequence leaves exactly one value behind.
    if stack.len() == 1 {
        stack.pop().ok_or(())
    } else {
        Err(())
    }
}

#[cfg(test)]
mod test {
    use syn::parse_str;

    use super::operator::Operator::*;
    use super::Output::*;
    use super::*;

    /// Parses `src` as a Rust expression, panicking on syntax errors.
    fn parse(src: &str) -> syn::Expr {
        parse_str::<syn::Expr>(src).unwrap()
    }

    /// Evaluates `src` with no variable bindings, panicking if evaluation fails.
    fn eval_plain(src: &str) -> Value {
        let ctx = BTreeMap::new();
        eval(&ctx, &parse(src)).unwrap()
    }

    /// Evaluates `src` with `name` bound to the expression `binding`.
    fn eval_with(src: &str, name: &str, binding: &str) -> Value {
        let bound = parse(binding);
        let mut ctx = BTreeMap::new();
        ctx.insert(name, &bound);
        eval(&ctx, &parse(src)).unwrap()
    }

    #[test]
    fn test_evaluate_add() {
        assert_eq!(evaluate(vec![V(1.0), V(1.0), Op(Add)]).unwrap(), 2.0);
    }

    #[test]
    fn test_evaluate_sub() {
        assert_eq!(evaluate(vec![V(1.0), V(1.0), Op(Sub)]).unwrap(), 0.0);
    }

    #[test]
    fn test_evaluate_mul() {
        assert_eq!(evaluate(vec![V(1.0), V(1.0), Op(Mul)]).unwrap(), 1.0);
    }

    #[test]
    fn test_evaluate_div() {
        assert_eq!(evaluate(vec![V(1.0), V(1.0), Op(Div)]).unwrap(), 1.0);
    }

    #[test]
    fn test_evaluate_rem() {
        assert_eq!(evaluate(vec![V(4.0), V(2.0), Op(Rem)]).unwrap(), 0.0);
    }

    #[test]
    fn test_eval_literal() {
        assert_eq!(eval_plain("-1"), -1.0);
        assert_eq!(eval_plain("-1.0"), -1.0);
    }

    #[test]
    fn test_eval_one() {
        assert_eq!(eval_plain("1 + 1"), 2.0);
    }

    #[test]
    fn test_eval() {
        assert_eq!(eval_plain("1 + 1 - 6 % 5"), 1.0);
        assert_eq!(eval_plain("1 + 1 - 10 / 5"), 0.0);
        assert_eq!(eval_with("foo + (1 * 1 - 10 / 5)", "foo", "-1"), -2.0);
        assert_eq!(eval_with("(foo * 2) + 1", "foo", "1 + -1 + -1 + 1"), 1.0);
    }

    #[test]
    fn test_eval_fn() {
        assert_eq!(eval_plain("4.sqrt()"), 2.0);
        assert_eq!(eval_plain("2.powi(2)"), 4.0);
        assert_eq!(eval_plain("2.5.powi(2)"), 6.25);
    }
}