fee 0.2.4

Expression evaluator supporting numeric, logical and bitwise operators
Documentation
use std::borrow::Cow;

use crate::{
    Error, EvalError, UContext,
    expr::{ExprCompiler, NotIndexedResolver, Op, ParseableToken},
    prelude::*,
    resolver::{LockedResolver, ResolverState, UnlockedResolver},
};

/// A single token of an expression stored in Reverse Polish (postfix) order.
///
/// Variable and function names are kept as string slices borrowed for `'e`;
/// the token itself is `Copy`.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum Rpn<'e>
{
    /// Numeric literal.
    Num(f64),
    /// Variable reference, stored by name.
    Var(&'e str),
    /// Function call: the function name and the number of arguments it takes
    /// from the evaluation stack.
    Fn(&'e str, usize),
    /// Operator.
    Op(Op),
}

/// Token constructors used to build [`Rpn`] values.
///
/// Names are stored as borrowed slices; the context passed to `var` and `fun`
/// is not consulted by this implementation.
impl<'a, 'c, S, V, F, LV, LF> ParseableToken<'a, 'c, S, V, F, LV, LF> for Rpn<'a>
where
    S: ResolverState,
    V: Resolver<S, f64>,
    F: Resolver<S, ExprFn>,
{
    /// Wraps a floating-point literal.
    #[inline]
    fn f64(num: f64) -> Self
    {
        Self::Num(num)
    }

    /// Wraps an integer literal, converting it to `f64` with an `as` cast.
    #[inline]
    fn i64(num: i64) -> Self
    {
        let as_float = num as f64;
        Self::Num(as_float)
    }

    /// Encodes a boolean literal as `1.0` (true) or `0.0` (false).
    #[inline]
    fn bool(val: bool) -> Self
    {
        let encoded = if val { 1.0 } else { 0.0 };
        Self::Num(encoded)
    }

    /// Wraps an operator.
    #[inline]
    fn op(op: Op) -> Self
    {
        Self::Op(op)
    }

    /// Records a variable by name; `_ctx` is unused.
    #[inline]
    fn var(name: &'a str, _ctx: &'c Context<S, V, F, LV, LF>) -> Self
    {
        Self::Var(name)
    }

    /// Records a function call by name together with its argument count;
    /// `_ctx` is unused.
    #[inline]
    fn fun(name: &'a str, argc: usize, _ctx: &'c Context<S, V, F, LV, LF>) -> Self
    {
        Self::Fn(name, argc)
    }
}

/// Compilation of expression text into [`Rpn`] tokens against an unlocked context.
impl<'e, 'c, V, F, LV, LF> ExprCompiler<'e, 'c, Unlocked, V, F, LV, LF, Rpn<'e>> for Expr<Rpn<'e>>
where
    V: NotIndexedResolver + UnlockedResolver<f64, LV>,
    F: NotIndexedResolver + UnlockedResolver<ExprFn, LF>,
    LV: LockedResolver<f64>,
    LF: LockedResolver<ExprFn>,
{
    /// Compiles `expr` by delegating to the `TryFrom<(expr, ctx)>` conversion
    /// of the expression type; any conversion error is returned unchanged.
    fn compile(expr: &'e str, ctx: &'c UContext<V, F, LV, LF>) -> Result<Expr<Rpn<'e>>, Error<'e>>
    {
        let source = (expr, ctx);
        Self::try_from(source)
    }
}

impl<'e, V, F, LV, LF> ExprEvaluator<'e, Unlocked, V, F, LV, LF> for Expr<Rpn<'e>>
where
    V: NotIndexedResolver + UnlockedResolver<f64, LV>,
    F: NotIndexedResolver + UnlockedResolver<ExprFn, LF>,
    LV: LockedResolver<f64>,
    LF: LockedResolver<ExprFn>,
{
    fn eval(&self, ctx: &UContext<V, F, LV, LF>, stack: &mut Vec<f64>) -> Result<f64, Error<'e>>
    {
        if self.tokens.len() == 1 {
            if let Rpn::Num(num) = &self.tokens[0] {
                return Ok(*num);
            }
        }

        for tok in self.tokens.iter() {
            match tok {
                Rpn::Num(num) => stack.push(*num),
                Rpn::Var(name) => stack.push(
                    *ctx.get_var(name)
                        .ok_or_else(|| Error::UnknownVar(Cow::Borrowed(name)))?,
                ),
                Rpn::Fn(name, argc) => {
                    if *argc > stack.len() {
                        return Err(Error::EvalError(EvalError::RPNStackUnderflow));
                    }

                    let start = stack.len() - argc;
                    let args = unsafe { stack.get_unchecked(start..) };
                    let val = ctx
                        .get_fn(name)
                        .ok_or_else(|| Error::UnknownFn(Cow::Borrowed(name)))?(
                        args
                    );

                    stack.truncate(start);
                    stack.push(val);
                }
                Rpn::Op(op) => {
                    if op.num_operands() > stack.len() {
                        return Err(Error::EvalError(EvalError::RPNStackUnderflow));
                    }

                    let start = stack.len() - op.num_operands();
                    let args = unsafe { stack.get_unchecked(start..) };
                    let res = op.apply(args);
                    stack.truncate(start);
                    stack.push(res);
                }
            }
        }

        match stack.pop() {
            Some(result) if stack.is_empty() => Ok(result),
            _ => Err(Error::EvalError(EvalError::MalformedExpression)),
        }
    }
}

// Unit tests for compiling expression text into `Rpn` token sequences.
#[cfg(test)]
mod tests
{
    use super::*;

    /// Checks the exact postfix token sequence produced for several expressions.
    #[test]
    fn test_new()
    {
        // The context comes from `Context::empty()`; names such as `p19` and
        // `sqrt` still compile because tokens store them by name.
        let ctx = Context::empty();

        // Nested parentheses: operands appear in source order, each operator
        // follows its operands.
        let expr = "2 - (4 + (p19 - 2) * (p19 + 2))";
        let rpn_expr = Expr::<Rpn>::try_from((expr, &ctx)).unwrap();
        assert_eq!(
            rpn_expr.tokens,
            vec![
                Rpn::Num(2.0),
                Rpn::Num(4.0),
                Rpn::Var("p19"),
                Rpn::Num(2.0),
                Rpn::Op(Op::Sub),
                Rpn::Var("p19"),
                Rpn::Num(2.0),
                Rpn::Op(Op::Add),
                Rpn::Op(Op::Mul),
                Rpn::Op(Op::Add),
                Rpn::Op(Op::Sub)
            ]
        );

        // Single-argument call: the argument precedes the `Fn` token, which
        // carries the argument count.
        let expr = "sqrt(5)";
        let rpn_expr = Expr::<Rpn>::try_from((expr, &ctx)).unwrap();
        assert_eq!(rpn_expr.tokens, vec![Rpn::Num(5.0), Rpn::Fn("sqrt", 1),]);

        // Nested call: the inner call is emitted before the outer one, and the
        // outer `Fn` token records two arguments.
        let expr = "abs(4, sqrt(5))";
        let rpn_expr = Expr::<Rpn>::try_from((expr, &ctx)).unwrap();
        assert_eq!(
            rpn_expr.tokens,
            vec![
                Rpn::Num(4.0),
                Rpn::Num(5.0),
                Rpn::Fn("sqrt", 1),
                Rpn::Fn("abs", 2),
            ]
        );

        // The constant terms are expected to collapse into the single literal
        // -385.0; only the negated variable stays symbolic.
        let expr = "(2 * 21) + 3 + -35 - ((5 * 80) + 5) + 10 + -p0";
        let rpn_expr = Expr::<Rpn>::try_from((expr, &ctx)).unwrap();
        assert_eq!(
            rpn_expr.tokens,
            vec![
                Rpn::Num(-385.0),
                Rpn::Var("p0"),
                Rpn::Op(Op::Neg),
                Rpn::Op(Op::Add),
            ]
        );

        // Unary minus on a variable becomes a `Neg` operator placed right
        // after its operand.
        let expr = "-y1 * (p2 - p3*y0)";
        let rpn_expr = Expr::<Rpn>::try_from((expr, &ctx)).unwrap();
        assert_eq!(
            rpn_expr.tokens,
            vec![
                Rpn::Var("y1"),
                Rpn::Op(Op::Neg),
                Rpn::Var("p2"),
                Rpn::Var("p3"),
                Rpn::Var("y0"),
                Rpn::Op(Op::Mul),
                Rpn::Op(Op::Sub),
                Rpn::Op(Op::Mul),
            ]
        );
    }
}