1use crate::{error::JitError, Library};
4
5#[derive(Clone, Debug, PartialEq, PartialOrd)]
7pub enum Token {
8 Push(Value),
10 PushVar(Var),
12 Write(Out),
14 Binop(Binop),
19 Unop(Unop),
23 Function(Function),
28 Noop,
30}
31
/// A constant that can be pushed onto the stack.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub enum Value {
    /// A numeric literal taken verbatim from the source.
    Literal(f32),
    /// The constant π, written `pi` in infix source.
    Pi,
    /// Euler's number, written `e` in infix source.
    E,
}
42
43impl Value {
44 pub fn value(self) -> f32 {
46 match self {
47 Value::Literal(f) => f,
48 Value::Pi => std::f32::consts::PI,
49 Value::E => std::f32::consts::E,
50 }
51 }
52}
53
/// An input variable read by [`Token::PushVar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Var {
    /// `x` in infix source.
    X,
    /// `y` in infix source.
    Y,
    /// `a` in infix source.
    A,
    /// `b` in infix source.
    B,
    /// `c` in infix source.
    C,
    /// `d` in infix source.
    D,
    /// `_1` in infix source.
    ///
    /// NOTE(review): presumably the value last written to `Out::Sig1` — confirm.
    Sig1,
    /// `_2` in infix source.
    ///
    /// NOTE(review): presumably the value last written to `Out::Sig2` — confirm.
    Sig2,
}
66
/// An output signal targeted by [`Token::Write`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Out {
    /// Written by a one-argument `_1(..)` call in infix source.
    Sig1,
    /// Written by a one-argument `_2(..)` call in infix source.
    Sig2,
}
73
/// A binary arithmetic operator combining the two topmost stack values as `lhs op rhs`,
/// where `lhs` was pushed first.
///
/// Like [`Var`] and [`Out`], this is a plain field-less enum, so it is also `Eq` and `Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Binop {
    /// `lhs + rhs`.
    Add,
    /// `lhs - rhs`.
    Sub,
    /// `lhs * rhs`.
    Mul,
    /// `lhs / rhs`.
    Div,
}
86
/// A unary arithmetic operator applied to the topmost stack value.
///
/// Like [`Var`] and [`Out`], this is a plain field-less enum, so it is also `Eq` and `Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Unop {
    /// Arithmetic negation, `-v`.
    Neg,
}
93
/// A call to a named library function.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct Function {
    /// Name used to look the function up in the `Library`.
    pub name: String,
    /// Number of arguments the call takes from the stack.
    pub args: usize,
}
102
103#[derive(Debug, PartialEq, PartialOrd)]
108pub struct Program(pub Vec<Token>);
109
110impl Program {
111 pub fn new(tokens: Vec<Token>) -> Self {
113 Program(tokens)
114 }
115
116 pub fn parse_from_infix(expr: &str) -> Result<Self, JitError> {
118 let tokens = meval::tokenizer::tokenize(expr)?;
119 let meval_rpn = meval::shunting_yard::to_rpn(&tokens)?;
120
121 let mut prog = Vec::new();
122 for meval_token in meval_rpn {
123 use meval::tokenizer::Operation as MevalOp;
124 use meval::tokenizer::Token as MevalToken;
125 let token = match meval_token {
126 MevalToken::Var(name) => match name.as_str() {
127 "x" => Token::PushVar(Var::X),
128 "y" => Token::PushVar(Var::Y),
129 "a" => Token::PushVar(Var::A),
130 "b" => Token::PushVar(Var::B),
131 "c" => Token::PushVar(Var::C),
132 "d" => Token::PushVar(Var::D),
133 "_1" => Token::PushVar(Var::Sig1),
134 "_2" => Token::PushVar(Var::Sig2),
135 "pi" => Token::Push(Value::Pi),
136 "e" => Token::Push(Value::E),
137 _ => return Err(JitError::ParseUnknownVariable(name.to_string())),
138 },
139 MevalToken::Number(f) => Token::Push(Value::Literal(f as f32)),
140 MevalToken::Binary(op) => match op {
141 MevalOp::Plus => Token::Binop(Binop::Add),
142 MevalOp::Minus => Token::Binop(Binop::Sub),
143 MevalOp::Times => Token::Binop(Binop::Mul),
144 MevalOp::Div => Token::Binop(Binop::Div),
145 MevalOp::Pow => Token::Function(Function {
146 name: "pow".to_string(),
147 args: 2,
148 }),
149 _ => return Err(JitError::ParseUnknownBinop(format!("{op:?}"))),
150 },
151 MevalToken::Unary(op) => match op {
152 MevalOp::Plus => Token::Noop,
153 MevalOp::Minus => Token::Unop(Unop::Neg),
154 _ => return Err(JitError::ParseUnknownUnop(format!("{op:?}"))),
155 },
156 MevalToken::Func(name, Some(1)) if name == "_1" => Token::Write(Out::Sig1),
157 MevalToken::Func(name, Some(1)) if name == "_2" => Token::Write(Out::Sig2),
158 MevalToken::Func(name, args) => Token::Function(Function {
159 name,
160 args: args.unwrap_or_default(),
161 }),
162
163 other => return Err(JitError::ParseUnknownToken(format!("{other:?}"))),
164 };
165
166 prog.push(token);
167 }
168
169 Ok(Program(prog))
170 }
171
172 pub fn reorder_ops_deepen(&mut self) {
181 for n in 2..self.0.len() {
182 let (tok0, tok1, tok2) = (
183 self.0[n - 2].clone(),
184 self.0[n - 1].clone(),
185 self.0[n].clone(),
186 );
187
188 let (ntok0, ntok1, ntok2) = match (tok0, tok1, tok2) {
189 (
190 op1 @ Token::Binop(Binop::Add | Binop::Sub),
191 push @ (Token::Push(_) | Token::PushVar(_)),
192 op2 @ Token::Binop(Binop::Add | Binop::Sub),
193 ) => (push, op2, op1),
194 (
195 op1 @ Token::Binop(Binop::Mul | Binop::Div),
196 push @ (Token::Push(_) | Token::PushVar(_)),
197 op2 @ Token::Binop(Binop::Mul | Binop::Div),
198 ) => (push, op2, op1),
199 _ => continue,
200 };
201
202 self.0[n - 2] = ntok0;
203 self.0[n - 1] = ntok1;
204 self.0[n] = ntok2;
205 }
206 }
207
208 pub fn reorder_ops_flatten(&mut self) {
217 let mut work_done = true;
218 while work_done {
219 work_done = false;
220
221 for n in 2..self.0.len() {
222 let (tok0, tok1, tok2) = (
223 self.0[n - 2].clone(),
224 self.0[n - 1].clone(),
225 self.0[n].clone(),
226 );
227
228 let (ntok0, ntok1, ntok2) = match (tok0, tok1, tok2) {
229 (
230 push @ (Token::Push(_) | Token::PushVar(_)),
231 op2 @ Token::Binop(Binop::Add | Binop::Sub | Binop::Mul | Binop::Div),
232 op1 @ Token::Binop(Binop::Add | Binop::Sub | Binop::Mul | Binop::Div),
233 ) => (op1, push, op2),
234 _ => continue,
235 };
236
237 self.0[n - 2] = ntok0;
238 self.0[n - 1] = ntok1;
239 self.0[n] = ntok2;
240 work_done = true;
241 }
242 }
243 }
244
245 pub fn fold_constants_step(&mut self, library: &Library) -> bool {
258 let mut work_done = false;
259
260 for n in 2..self.0.len() {
261 match self.0[n].clone() {
262 Token::Unop(unop) => {
263 let Token::Push(a) = self.0[n - 1] else {
264 continue;
265 };
266 let result = match unop {
267 Unop::Neg => -a.value(),
268 };
269
270 self.0[n - 1] = Token::Noop;
271 self.0[n] = Token::Push(Value::Literal(result));
272 work_done = true;
273 }
274 Token::Binop(binop) => {
275 let Token::Push(a) = self.0[n - 2] else {
276 continue;
277 };
278 let Token::Push(b) = self.0[n - 1] else {
279 continue;
280 };
281
282 let (a, b) = (a.value(), b.value());
283 let result = match binop {
284 Binop::Add => a + b,
285 Binop::Sub => a - b,
286 Binop::Mul => a * b,
287 Binop::Div => a / b,
288 };
289
290 self.0[n - 2] = Token::Noop;
291 self.0[n - 1] = Token::Noop;
292 self.0[n] = Token::Push(Value::Literal(result));
293 work_done = true;
294 }
295 Token::Function(Function { name, args }) => {
296 let Some(extern_fun) = library.iter().find(|f| f.name == name) else {
297 log::warn!("No function {name} in library, compilation will fail");
298 continue;
299 };
300
301 let result = match args {
302 1 => {
303 let Token::Push(a) = self.0[n - 1] else {
304 continue;
305 };
306 extern_fun.call_1(a.value())
307 }
308 2 => {
309 let Token::Push(a) = self.0[n - 2] else {
310 continue;
311 };
312 let Token::Push(b) = self.0[n - 1] else {
313 continue;
314 };
315 extern_fun.call_2(a.value(), b.value())
316 }
317 _ => continue,
318 };
319
320 let Some(value) = result else {
321 log::warn!("Function {name} called with invalid number of arguments, compilation will fail");
322 continue;
323 };
324
325 self.0[n - args..n].fill_with(|| Token::Noop);
326 self.0[n] = Token::Push(Value::Literal(value));
327 }
328 _ => continue,
329 }
330 }
331
332 self.0.retain(|tok| *tok != Token::Noop);
333
334 work_done
335 }
336
337 pub fn optimize(&mut self, library: &Library) {
346 let mut work_done = true;
347 while work_done {
348 self.reorder_ops_deepen();
349 work_done = self.fold_constants_step(library);
350 }
351
352 self.reorder_ops_flatten();
353 }
354}
355
#[cfg(test)]
mod tests {
    use std::f32::consts::PI;

    use crate::{
        rpn::{Token, Value},
        Library, Program,
    };

    use super::{Binop, Function, Out, Unop, Var};

    /// Token pushing the literal `v`.
    fn lit(v: f32) -> Token {
        Token::Push(Value::Literal(v))
    }

    /// Token calling the library function `name` with `args` arguments.
    fn call(name: &str, args: usize) -> Token {
        Token::Function(Function {
            name: name.into(),
            args,
        })
    }

    /// Token-wise program comparison. Literal pushes may differ by up to `EPS`; every other
    /// token must match exactly.
    fn roughly_equal(actual: &Program, expected: &Program) -> bool {
        const EPS: f32 = 0.00001;

        let differs = |(left, right): (&Token, &Token)| match (left, right) {
            (Token::Push(Value::Literal(l)), Token::Push(Value::Literal(r))) => {
                (l - r).abs() > EPS
            }
            _ => left != right,
        };

        actual.0.len() == expected.0.len()
            && !actual.0.iter().zip(&expected.0).any(differs)
    }

    #[test]
    fn test_parse() {
        let cases = [
            ("2", vec![lit(2.0)]),
            ("2 + 2", vec![lit(2.0), lit(2.0), Token::Binop(Binop::Add)]),
            ("2 - 2", vec![lit(2.0), lit(2.0), Token::Binop(Binop::Sub)]),
            ("2 * 2", vec![lit(2.0), lit(2.0), Token::Binop(Binop::Mul)]),
            ("2 / 2", vec![lit(2.0), lit(2.0), Token::Binop(Binop::Div)]),
            ("2 ^ 2", vec![lit(2.0), lit(2.0), call("pow", 2)]),
            ("-2", vec![lit(2.0), Token::Unop(Unop::Neg)]),
            (
                "sin(cos(tan(_2(_1(2)))))",
                vec![
                    lit(2.0),
                    Token::Write(Out::Sig1),
                    Token::Write(Out::Sig2),
                    call("tan", 1),
                    call("cos", 1),
                    call("sin", 1),
                ],
            ),
            ("x", vec![Token::PushVar(Var::X)]),
            ("y", vec![Token::PushVar(Var::Y)]),
            ("a", vec![Token::PushVar(Var::A)]),
            ("b", vec![Token::PushVar(Var::B)]),
            ("c", vec![Token::PushVar(Var::C)]),
            ("d", vec![Token::PushVar(Var::D)]),
            ("pi", vec![Token::Push(Value::Pi)]),
            ("e", vec![Token::Push(Value::E)]),
        ];

        for (source, expected_tokens) in cases {
            let parsed = Program::parse_from_infix(source).unwrap();
            assert_eq!(parsed, Program(expected_tokens));
        }
    }

    #[test]
    fn test_optimize() {
        let cases = [
            ("2", vec![lit(2.0)]),
            ("2 + 2", vec![lit(4.0)]),
            ("2 + -2", vec![lit(0.0)]),
            ("sin(pi/2 + pi/2)", vec![lit(0.0)]),
            (
                "sin(pi/2 + pi/2) + x",
                vec![lit(0.0), Token::PushVar(Var::X), Token::Binop(Binop::Add)],
            ),
            (
                "x + 1 + 1",
                vec![Token::PushVar(Var::X), lit(2.0), Token::Binop(Binop::Add)],
            ),
            (
                "x * pi/4/3",
                vec![
                    Token::PushVar(Var::X),
                    lit(PI / 12.0),
                    Token::Binop(Binop::Mul),
                ],
            ),
            (
                "a + b + c",
                vec![
                    Token::PushVar(Var::A),
                    Token::PushVar(Var::B),
                    Token::Binop(Binop::Add),
                    Token::PushVar(Var::C),
                    Token::Binop(Binop::Add),
                ],
            ),
            (
                "x * (a / b)",
                vec![
                    Token::PushVar(Var::X),
                    Token::PushVar(Var::A),
                    Token::Binop(Binop::Mul),
                    Token::PushVar(Var::B),
                    Token::Binop(Binop::Div),
                ],
            ),
        ];

        for (source, expected_tokens) in cases {
            let mut program = Program::parse_from_infix(source).unwrap();
            program.optimize(&Library::default());

            let expected = Program(expected_tokens);
            assert!(
                roughly_equal(&program, &expected),
                "{program:?} != {expected:?}"
            );
        }
    }
}