math_jit/
rpn.rs

1//! Parsing and operations on the program
2
3use crate::{error::JitError, Library};
4
/// RPN Token
///
/// A [`Program`] is a sequence of these tokens, evaluated in order against an
/// implicit value stack.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub enum Token {
    /// Push a value onto the stack
    Push(Value),
    /// Push variable value onto the stack
    PushVar(Var),
    /// Write top of stack to in-out variable
    ///
    /// Produced by the parser for `_1(expr)` / `_2(expr)`. NOTE(review): the
    /// parser emits further operations consuming the same value after a write
    /// (e.g. `tan(_1(2))`), so presumably the value stays on the stack —
    /// confirm against codegen.
    Write(Out),
    /// Binary operation
    ///
    /// Pops 2 values from the stack, performs the operation, and pushes the
    /// result back onto the stack
    Binop(Binop),
    /// Unary operation
    ///
    /// Replaces the top value on the stack with the result of the operation
    Unop(Unop),
    /// Function call
    ///
    /// Pops a number of arguments from the stack, evaluates the function, and
    /// pushes the result back onto the stack.
    Function(Function),
    /// No operation
    ///
    /// Emitted by the parser for unary `+`, and used by
    /// [`Program::fold_constants_step`] as a placeholder for folded tokens
    /// before they are removed.
    Noop,
}
31
/// Constant value
///
/// Named constants stay symbolic until [`Value::value`] resolves them to a
/// concrete number.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Value {
    /// Arbitrary literal value
    Literal(f32),
    /// The constant π
    Pi,
    /// Euler's number `e` (the base of the natural logarithm)
    E,
}

impl Value {
    /// Resolves the constant to a concrete `f32`
    ///
    /// Literals are returned unchanged; named constants map to the matching
    /// entry of `std::f32::consts`.
    pub fn value(self) -> f32 {
        use std::f32::consts;

        match self {
            Self::Pi => consts::PI,
            Self::E => consts::E,
            Self::Literal(number) => number,
        }
    }
}
53
/// Readable variables
///
/// Each variant corresponds to an identifier accepted by
/// [`Program::parse_from_infix`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Var {
    /// Variable `x`
    X,
    /// Variable `y`
    Y,
    /// Variable `a`
    A,
    /// Variable `b`
    B,
    /// Variable `c`
    C,
    /// Variable `d`
    D,
    /// Variable `_1`; presumably reads the in-out variable written through
    /// [`Out::Sig1`] — confirm against codegen
    Sig1,
    /// Variable `_2`; presumably reads the in-out variable written through
    /// [`Out::Sig2`] — confirm against codegen
    Sig2,
}

/// Writeable variables
///
/// Targets of [`Token::Write`], which the parser emits for `_1(expr)` and
/// `_2(expr)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Out {
    /// Written by `_1(expr)`
    Sig1,
    /// Written by `_2(expr)`
    Sig2,
}
73
/// Binary operation
///
/// The left-hand operand is the deeper of the two stack values and the
/// right-hand operand is the top of the stack, so `[a b -]` computes `a - b`.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Binop {
    /// Addition
    Add,
    /// Subtraction (`lhs - rhs`)
    Sub,
    /// Multiplication
    Mul,
    /// Division (`lhs / rhs`)
    Div,
}

/// Unary operation
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Unop {
    /// Negation
    Neg,
}
93
/// Function call
///
/// Looked up by name in the [`Library`] during constant folding. The parser
/// also produces this for `^`, as a call to `pow` with 2 arguments.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct Function {
    /// Name of the function
    pub name: String,
    /// Number of arguments popped from the stack
    ///
    /// The parser sets this to 0 when `meval` reports no argument count.
    pub args: usize,
}
102
/// Parsed program representation
///
/// The program is represented using Reverse Polish Notation, which lends
/// itself to easy iterative translation into CLIF as well as to simple
/// optimizations.
#[derive(Debug, PartialEq, PartialOrd)]
pub struct Program(pub Vec<Token>);
109
110impl Program {
111    /// Constructs program directly from RPN
112    pub fn new(tokens: Vec<Token>) -> Self {
113        Program(tokens)
114    }
115
116    /// Parses an infix notation into RPN
117    pub fn parse_from_infix(expr: &str) -> Result<Self, JitError> {
118        let tokens = meval::tokenizer::tokenize(expr)?;
119        let meval_rpn = meval::shunting_yard::to_rpn(&tokens)?;
120
121        let mut prog = Vec::new();
122        for meval_token in meval_rpn {
123            use meval::tokenizer::Operation as MevalOp;
124            use meval::tokenizer::Token as MevalToken;
125            let token = match meval_token {
126                MevalToken::Var(name) => match name.as_str() {
127                    "x" => Token::PushVar(Var::X),
128                    "y" => Token::PushVar(Var::Y),
129                    "a" => Token::PushVar(Var::A),
130                    "b" => Token::PushVar(Var::B),
131                    "c" => Token::PushVar(Var::C),
132                    "d" => Token::PushVar(Var::D),
133                    "_1" => Token::PushVar(Var::Sig1),
134                    "_2" => Token::PushVar(Var::Sig2),
135                    "pi" => Token::Push(Value::Pi),
136                    "e" => Token::Push(Value::E),
137                    _ => return Err(JitError::ParseUnknownVariable(name.to_string())),
138                },
139                MevalToken::Number(f) => Token::Push(Value::Literal(f as f32)),
140                MevalToken::Binary(op) => match op {
141                    MevalOp::Plus => Token::Binop(Binop::Add),
142                    MevalOp::Minus => Token::Binop(Binop::Sub),
143                    MevalOp::Times => Token::Binop(Binop::Mul),
144                    MevalOp::Div => Token::Binop(Binop::Div),
145                    MevalOp::Pow => Token::Function(Function {
146                        name: "pow".to_string(),
147                        args: 2,
148                    }),
149                    _ => return Err(JitError::ParseUnknownBinop(format!("{op:?}"))),
150                },
151                MevalToken::Unary(op) => match op {
152                    MevalOp::Plus => Token::Noop,
153                    MevalOp::Minus => Token::Unop(Unop::Neg),
154                    _ => return Err(JitError::ParseUnknownUnop(format!("{op:?}"))),
155                },
156                MevalToken::Func(name, Some(1)) if name == "_1" => Token::Write(Out::Sig1),
157                MevalToken::Func(name, Some(1)) if name == "_2" => Token::Write(Out::Sig2),
158                MevalToken::Func(name, args) => Token::Function(Function {
159                    name,
160                    args: args.unwrap_or_default(),
161                }),
162
163                other => return Err(JitError::ParseUnknownToken(format!("{other:?}"))),
164            };
165
166            prog.push(token);
167        }
168
169        Ok(Program(prog))
170    }
171
172    /// Rewrites RPN into a deeper form that's more optimizable
173    ///
174    /// The optimizer isn't able to optimize RPN like `[.. 1 + 1 +]`. This
175    /// function will replace it with `[.. 1 1 + +]`, which the optimizer
176    /// will rewrite as `[.. 2 +]`.
177    ///
178    /// The resultant form has a deeper stack, meaning more variables need to
179    /// be kept alive at the same time.
180    pub fn reorder_ops_deepen(&mut self) {
181        for n in 2..self.0.len() {
182            let (tok0, tok1, tok2) = (
183                self.0[n - 2].clone(),
184                self.0[n - 1].clone(),
185                self.0[n].clone(),
186            );
187
188            let (ntok0, ntok1, ntok2) = match (tok0, tok1, tok2) {
189                (
190                    op1 @ Token::Binop(Binop::Add | Binop::Sub),
191                    push @ (Token::Push(_) | Token::PushVar(_)),
192                    op2 @ Token::Binop(Binop::Add | Binop::Sub),
193                ) => (push, op2, op1),
194                (
195                    op1 @ Token::Binop(Binop::Mul | Binop::Div),
196                    push @ (Token::Push(_) | Token::PushVar(_)),
197                    op2 @ Token::Binop(Binop::Mul | Binop::Div),
198                ) => (push, op2, op1),
199                _ => continue,
200            };
201
202            self.0[n - 2] = ntok0;
203            self.0[n - 1] = ntok1;
204            self.0[n] = ntok2;
205        }
206    }
207
208    /// Rewrites RPN into a form that requires a lower stack
209    ///
210    /// `a * (b / c)` will produce RPN `a b c / *`, which keeps up to 3 variables
211    /// alive at once. This optimization will rewrite it into RPN `a b * c /`,
212    /// which does the same work despite using less memory.
213    ///
214    /// Notably the constant folding algorithm in this library will fail to
215    /// optimize this form.
216    pub fn reorder_ops_flatten(&mut self) {
217        let mut work_done = true;
218        while work_done {
219            work_done = false;
220
221            for n in 2..self.0.len() {
222                let (tok0, tok1, tok2) = (
223                    self.0[n - 2].clone(),
224                    self.0[n - 1].clone(),
225                    self.0[n].clone(),
226                );
227
228                let (ntok0, ntok1, ntok2) = match (tok0, tok1, tok2) {
229                    (
230                        push @ (Token::Push(_) | Token::PushVar(_)),
231                        op2 @ Token::Binop(Binop::Add | Binop::Sub | Binop::Mul | Binop::Div),
232                        op1 @ Token::Binop(Binop::Add | Binop::Sub | Binop::Mul | Binop::Div),
233                    ) => (op1, push, op2),
234                    _ => continue,
235                };
236
237                self.0[n - 2] = ntok0;
238                self.0[n - 1] = ntok1;
239                self.0[n] = ntok2;
240                work_done = true;
241            }
242        }
243    }
244
245    /// Evaluate some constant expressions
246    ///
247    /// Optimizes binary and unary operations:
248    /// - replace `[const0, const1, op]` with `[op(const0, const1)]`
249    /// - replace `[const, op]` with `[op(const)]`
250    ///
251    /// [`Token::Noop`] is removed in the process. Only one pass over the code
252    /// is made. Returns `false` if no further progress can be made.
253    ///
254    /// Doesn't support reordering of associative operations, so
255    /// `[var, const0, add, const1, add]` is *not* replaced with
256    /// `[var, add(const0, const1), add]` and so on.
257    pub fn fold_constants_step(&mut self, library: &Library) -> bool {
258        let mut work_done = false;
259
260        for n in 2..self.0.len() {
261            match self.0[n].clone() {
262                Token::Unop(unop) => {
263                    let Token::Push(a) = self.0[n - 1] else {
264                        continue;
265                    };
266                    let result = match unop {
267                        Unop::Neg => -a.value(),
268                    };
269
270                    self.0[n - 1] = Token::Noop;
271                    self.0[n] = Token::Push(Value::Literal(result));
272                    work_done = true;
273                }
274                Token::Binop(binop) => {
275                    let Token::Push(a) = self.0[n - 2] else {
276                        continue;
277                    };
278                    let Token::Push(b) = self.0[n - 1] else {
279                        continue;
280                    };
281
282                    let (a, b) = (a.value(), b.value());
283                    let result = match binop {
284                        Binop::Add => a + b,
285                        Binop::Sub => a - b,
286                        Binop::Mul => a * b,
287                        Binop::Div => a / b,
288                    };
289
290                    self.0[n - 2] = Token::Noop;
291                    self.0[n - 1] = Token::Noop;
292                    self.0[n] = Token::Push(Value::Literal(result));
293                    work_done = true;
294                }
295                Token::Function(Function { name, args }) => {
296                    let Some(extern_fun) = library.iter().find(|f| f.name == name) else {
297                        log::warn!("No function {name} in library, compilation will fail");
298                        continue;
299                    };
300
301                    let result = match args {
302                        1 => {
303                            let Token::Push(a) = self.0[n - 1] else {
304                                continue;
305                            };
306                            extern_fun.call_1(a.value())
307                        }
308                        2 => {
309                            let Token::Push(a) = self.0[n - 2] else {
310                                continue;
311                            };
312                            let Token::Push(b) = self.0[n - 1] else {
313                                continue;
314                            };
315                            extern_fun.call_2(a.value(), b.value())
316                        }
317                        _ => continue,
318                    };
319
320                    let Some(value) = result else {
321                        log::warn!("Function {name} called with invalid number of arguments, compilation will fail");
322                        continue;
323                    };
324
325                    self.0[n - args..n].fill_with(|| Token::Noop);
326                    self.0[n] = Token::Push(Value::Literal(value));
327                }
328                _ => continue,
329            }
330        }
331
332        self.0.retain(|tok| *tok != Token::Noop);
333
334        work_done
335    }
336
337    /// Rewrites RPN into a form most suitable for codegen
338    ///
339    /// Performs constant folding and minimizes stack usage of the resultant RPN.
340    ///
341    /// For details, see:
342    /// - [`Self::reorder_ops_deepen`]
343    /// - [`Self::reorder_ops_flatten`]
344    /// - [`Self::fold_constants_step`]
345    pub fn optimize(&mut self, library: &Library) {
346        let mut work_done = true;
347        while work_done {
348            self.reorder_ops_deepen();
349            work_done = self.fold_constants_step(library);
350        }
351
352        self.reorder_ops_flatten();
353    }
354}
355
#[cfg(test)]
mod tests {
    use std::f32::consts::PI;

    use crate::{
        rpn::{Token, Value},
        Library, Program,
    };

    use super::{Binop, Function, Out, Unop, Var};

    /// Token pushing a literal
    fn lit(value: f32) -> Token {
        Token::Push(Value::Literal(value))
    }

    /// Token pushing a variable
    fn var(v: Var) -> Token {
        Token::PushVar(v)
    }

    /// Token applying a binary operation
    fn bin(op: Binop) -> Token {
        Token::Binop(op)
    }

    /// Token calling `name` with `args` arguments
    fn call(name: &str, args: usize) -> Token {
        Token::Function(Function {
            name: name.into(),
            args,
        })
    }

    /// Token-wise comparison that tolerates small differences between literals
    fn roughly_equal(actual: &Program, expected: &Program) -> bool {
        const EPS: f32 = 0.00001;

        let differs = |pair: (&Token, &Token)| match pair {
            (Token::Push(Value::Literal(l)), Token::Push(Value::Literal(r))) => {
                (l - r).abs() > EPS
            }
            (l, r) => l != r,
        };

        actual.0.len() == expected.0.len() && !actual.0.iter().zip(&expected.0).any(differs)
    }

    #[test]
    fn test_parse() {
        let cases = [
            ("2", vec![lit(2.0)]),
            ("2 + 2", vec![lit(2.0), lit(2.0), bin(Binop::Add)]),
            ("2 - 2", vec![lit(2.0), lit(2.0), bin(Binop::Sub)]),
            ("2 * 2", vec![lit(2.0), lit(2.0), bin(Binop::Mul)]),
            ("2 / 2", vec![lit(2.0), lit(2.0), bin(Binop::Div)]),
            ("2 ^ 2", vec![lit(2.0), lit(2.0), call("pow", 2)]),
            ("-2", vec![lit(2.0), Token::Unop(Unop::Neg)]),
            (
                "sin(cos(tan(_2(_1(2)))))",
                vec![
                    lit(2.0),
                    Token::Write(Out::Sig1),
                    Token::Write(Out::Sig2),
                    call("tan", 1),
                    call("cos", 1),
                    call("sin", 1),
                ],
            ),
            ("x", vec![var(Var::X)]),
            ("y", vec![var(Var::Y)]),
            ("a", vec![var(Var::A)]),
            ("b", vec![var(Var::B)]),
            ("c", vec![var(Var::C)]),
            ("d", vec![var(Var::D)]),
            ("pi", vec![Token::Push(Value::Pi)]),
            ("e", vec![Token::Push(Value::E)]),
        ];

        for (expr, expected) in cases {
            let parsed = Program::parse_from_infix(expr).unwrap();
            assert_eq!(parsed, Program(expected));
        }
    }

    #[test]
    fn test_optimize() {
        let cases = [
            ("2", vec![lit(2.0)]),
            ("2 + 2", vec![lit(4.0)]),
            ("2 + -2", vec![lit(0.0)]),
            ("sin(pi/2 + pi/2)", vec![lit(0.0)]),
            (
                "sin(pi/2 + pi/2) + x",
                vec![lit(0.0), var(Var::X), bin(Binop::Add)],
            ),
            ("x + 1 + 1", vec![var(Var::X), lit(2.0), bin(Binop::Add)]),
            (
                "x * pi/4/3",
                vec![var(Var::X), lit(PI / 12.0), bin(Binop::Mul)],
            ),
            (
                "a + b + c",
                vec![
                    var(Var::A),
                    var(Var::B),
                    bin(Binop::Add),
                    var(Var::C),
                    bin(Binop::Add),
                ],
            ),
            (
                "x * (a / b)",
                vec![
                    var(Var::X),
                    var(Var::A),
                    bin(Binop::Mul),
                    var(Var::B),
                    bin(Binop::Div),
                ],
            ),
        ];

        for (expr, tokens) in cases {
            let mut program = Program::parse_from_infix(expr).unwrap();
            program.optimize(&Library::default());
            let expected = Program(tokens);
            assert!(
                roughly_equal(&program, &expected),
                "{program:?} != {expected:?}"
            );
        }
    }
}