fee/expr/irpn.rs

1use std::borrow::Cow;
2
3use crate::{
4    Error, EvalError, IndexedResolver, UContext,
5    expr::{ExprCompiler, Op, ParseableToken},
6    parsing,
7    prelude::*,
8    resolver::ResolverState,
9};
10
11#[derive(Debug, PartialEq, Copy, Clone)]
12pub enum IRpn
13{
14    Num(f64),
15    Var(usize, usize),
16    Fn(usize, usize, usize),
17    Op(Op),
18}
19
20impl<'a, 'c, S, V, F, LV, LF> ParseableToken<'a, 'c, S, V, F, LV, LF> for IRpn
21where
22    S: ResolverState,
23    V: Resolver<S, f64>,
24    F: Resolver<S, ExprFn>,
25{
26    #[inline]
27    fn f64(num: f64) -> Self
28    {
29        IRpn::Num(num)
30    }
31
32    #[inline]
33    fn i64(num: i64) -> Self
34    {
35        IRpn::Num(num as f64)
36    }
37
38    #[inline]
39    fn bool(val: bool) -> Self
40    {
41        IRpn::Num(if val { 1.0 } else { 0.0 })
42    }
43
44    #[inline]
45    fn op(op: Op) -> Self
46    {
47        IRpn::Op(op)
48    }
49
50    #[inline]
51    fn var(name: &'a str, _ctx: &'c Context<S, V, F, LV, LF>) -> Self
52    {
53        let name_bytes = name.as_bytes();
54        let letter = name_bytes[0] - b'a';
55        let idx = parsing::parse_usize(&name_bytes[1..]);
56        IRpn::Var(letter as usize, idx)
57    }
58
59    #[inline]
60    fn fun(name: &'a str, argc: usize, _ctx: &'c Context<S, V, F, LV, LF>) -> Self
61    {
62        let name_bytes = name.as_bytes();
63        let letter = name_bytes[0] - b'a';
64        let idx = parsing::parse_usize(&name_bytes[1..]);
65        IRpn::Fn(letter as usize, idx, argc)
66    }
67}
68
69impl<'e, 'c>
70    ExprCompiler<
71        'e,
72        'c,
73        Unlocked,
74        IndexedResolver<Unlocked, f64>,
75        IndexedResolver<Unlocked, ExprFn>,
76        IndexedResolver<Locked, f64>,
77        IndexedResolver<Locked, ExprFn>,
78        IRpn,
79    > for Expr<IRpn>
80{
81    fn compile(
82        expr: &'e str,
83        ctx: &'c UContext<
84            IndexedResolver<Unlocked, f64>,
85            IndexedResolver<Unlocked, ExprFn>,
86            IndexedResolver<Locked, f64>,
87            IndexedResolver<Locked, ExprFn>,
88        >,
89    ) -> Result<Expr<IRpn>, Error<'e>>
90    {
91        Expr::try_from((expr, ctx))
92    }
93}
94
95impl<'e>
96    ExprEvaluator<
97        'e,
98        Unlocked,
99        IndexedResolver<Unlocked, f64>,
100        IndexedResolver<Unlocked, ExprFn>,
101        IndexedResolver<Locked, f64>,
102        IndexedResolver<Locked, ExprFn>,
103    > for Expr<IRpn>
104{
105    fn eval(
106        &self,
107        ctx: &UContext<
108            IndexedResolver<Unlocked, f64>,
109            IndexedResolver<Unlocked, ExprFn>,
110            IndexedResolver<Locked, f64>,
111            IndexedResolver<Locked, ExprFn>,
112        >,
113        stack: &mut Vec<f64>,
114    ) -> Result<f64, Error<'e>>
115    {
116        if self.tokens.len() == 1 {
117            if let IRpn::Num(num) = &self.tokens[0] {
118                return Ok(*num);
119            }
120        }
121
122        for tok in self.tokens.iter() {
123            match tok {
124                IRpn::Num(num) => stack.push(*num),
125                IRpn::Var(id, idx) => {
126                    stack.push(*ctx.get_var_by_index(*id, *idx).ok_or_else(|| {
127                        Error::UnknownVar(Cow::Owned(format!(
128                            "{}{}",
129                            (*id as u8 + b'a') as char,
130                            idx
131                        )))
132                    })?)
133                }
134                IRpn::Fn(id, idx, argc) => {
135                    if *argc > stack.len() {
136                        return Err(Error::EvalError(EvalError::RPNStackUnderflow));
137                    }
138
139                    let start = stack.len() - argc;
140                    let args = unsafe { stack.get_unchecked(start..) };
141                    let val = ctx.call_fn_by_index(*id, *idx, args).ok_or_else(|| {
142                        Error::UnknownFn(Cow::Owned(format!(
143                            "{}{}",
144                            (*id as u8 + b'a') as char,
145                            idx
146                        )))
147                    })?;
148                    stack.truncate(start);
149                    stack.push(val);
150                }
151                IRpn::Op(op) => {
152                    if op.num_operands() > stack.len() {
153                        return Err(Error::EvalError(EvalError::RPNStackUnderflow));
154                    }
155
156                    let start = stack.len() - op.num_operands();
157                    let args = unsafe { stack.get_unchecked(start..) };
158                    let res = op.apply(args);
159                    stack.truncate(start);
160                    stack.push(res);
161                }
162            }
163        }
164
165        match stack.pop() {
166            Some(result) if stack.is_empty() => Ok(result),
167            _ => Err(Error::EvalError(EvalError::MalformedExpression)),
168        }
169    }
170}
171
#[cfg(test)]
mod tests
{
    use super::*;

    /// Letter id the compiler assigns to identifiers that start with `letter`.
    fn letter_id(letter: u8) -> usize
    {
        (letter - b'a') as usize
    }

    #[test]
    fn test_new()
    {
        let ctx = Context::empty();

        let f = letter_id(b'f');
        let p = letter_id(b'p');
        let y = letter_id(b'y');

        // Each expression paired with the postfix tokens it must compile to
        // (constant sub-expressions are expected to be folded).
        let cases = [
            (
                "2 - (4 + (p19 - 2) * (p19 + 2))",
                vec![
                    IRpn::Num(2.0),
                    IRpn::Num(4.0),
                    IRpn::Var(p, 19),
                    IRpn::Num(2.0),
                    IRpn::Op(Op::Sub),
                    IRpn::Var(p, 19),
                    IRpn::Num(2.0),
                    IRpn::Op(Op::Add),
                    IRpn::Op(Op::Mul),
                    IRpn::Op(Op::Add),
                    IRpn::Op(Op::Sub),
                ],
            ),
            (
                "f0((2 + 3) * 4, f1(5))",
                vec![
                    IRpn::Num(20.0),
                    IRpn::Num(5.0),
                    IRpn::Fn(f, 1, 1),
                    IRpn::Fn(f, 0, 2),
                ],
            ),
            (
                "(2 * 21) + 3 + -35 - ((5 * 80) + 5) + 10 + -p0",
                vec![
                    IRpn::Num(-385.0),
                    IRpn::Var(p, 0),
                    IRpn::Op(Op::Neg),
                    IRpn::Op(Op::Add),
                ],
            ),
            (
                "-y1 * (p2 - p3*y0)",
                vec![
                    IRpn::Var(y, 1),
                    IRpn::Op(Op::Neg),
                    IRpn::Var(p, 2),
                    IRpn::Var(p, 3),
                    IRpn::Var(y, 0),
                    IRpn::Op(Op::Mul),
                    IRpn::Op(Op::Sub),
                    IRpn::Op(Op::Mul),
                ],
            ),
        ];

        for (expr, expected) in cases.iter() {
            let rpn_expr = Expr::<IRpn>::try_from((*expr, &ctx)).unwrap();
            assert_eq!(rpn_expr.tokens, *expected, "unexpected tokens for {:?}", expr);
        }
    }
}